gabriel / musehub public
proposal_dag.py python
379 lines 14.4 KB
Raw
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor ⚠ breaking 1 day ago
"""Proposal dependency DAG engine — issue #37 Phase 2.

Provides:
- ``CycleError`` — raised when the dependency graph contains a cycle
- ``ProposalDag`` — adjacency-set representation of proposal dependencies
- ``build_dag`` — construct a ``ProposalDag`` from (dependent, dependency) edge pairs
- ``topological_sort`` — Kahn's algorithm over unmerged proposals; raises ``CycleError`` on a cycle
- ``detect_cycle`` — boolean cycle test built on ``topological_sort``
- ``blocked_by_numbers`` — proposal numbers of a proposal's unmerged dependencies
- ``blocks_numbers`` — proposal numbers of unmerged proposals waiting on a given proposal
- ``is_blocked`` — whether a proposal has at least one unmerged dependency
- ``load_dag`` / ``load_dag_for_proposals`` — build a ``ProposalDag`` from DB rows
- ``create_dependency_edges`` — validate and persist new dependency edges

Everything except ``load_dag``, ``load_dag_for_proposals`` and
``create_dependency_edges`` works purely in memory and can be tested without
a DB session.
"""
15
16 from __future__ import annotations
17
18 import logging
19 from collections import defaultdict, deque
20 from dataclasses import dataclass, field
21 from typing import Iterable
22
23 from sqlalchemy import select
24 from sqlalchemy.ext.asyncio import AsyncSession
25
26 from musehub.db.musehub_social_models import MusehubProposal, MusehubProposalDependency
27
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
29
30
31 # ─────────────────────────────────────────────────────────────────────────────
32 # Errors
33 # ─────────────────────────────────────────────────────────────────────────────
34
35
class CycleError(ValueError):
    """Signal that a proposal dependency graph is not acyclic.

    It subclasses ``ValueError``; the message lists the implicated IDs sorted.

    Attributes:
        cycle_ids: Proposal IDs implicated in the cycle. When raised by
            ``topological_sort`` this is the set of unmerged nodes that
            Kahn's algorithm could not drain. That set contains every node
            lying on a cycle, plus any node that (transitively) depends on
            one, so it is a superset of the exact cycle membership.
    """

    def __init__(self, cycle_ids: set[str]) -> None:
        listed = sorted(cycle_ids)
        self.cycle_ids = cycle_ids
        super().__init__(f"Dependency cycle detected among proposals: {listed}")
49
50
51 # ─────────────────────────────────────────────────────────────────────────────
52 # Core data structures
53 # ─────────────────────────────────────────────────────────────────────────────
54
55
def _empty_adjacency() -> dict[str, set[str]]:
    """Return a fresh adjacency map whose missing keys default to an empty set."""
    return defaultdict(set)


@dataclass
class ProposalDag:
    """Proposal dependency graph stored as two mirrored adjacency maps.

    Every node is a proposal_id string. For an edge "A must wait for B":

    * ``depends_on[A]`` contains ``B`` (the proposals A must wait for);
    * ``required_by[B]`` contains ``A`` (the proposals waiting on B).

    Both maps default to ``defaultdict(set)``.

    Attributes:
        depends_on: proposal_id -> proposal_ids it cannot merge before.
        required_by: proposal_id -> proposal_ids it is blocking.
        nodes: Every proposal_id seen in an edge (as populated by ``build_dag``).
        merged_ids: Proposals already in MERGED state; edges pointing at them
            no longer block anything.
        number_by_id: proposal_id -> proposal_number, for display in the
            blocked-by / blocks lists.
    """

    depends_on: dict[str, set[str]] = field(default_factory=_empty_adjacency)
    required_by: dict[str, set[str]] = field(default_factory=_empty_adjacency)
    nodes: set[str] = field(default_factory=set)
    merged_ids: set[str] = field(default_factory=set)
    number_by_id: dict[str, int] = field(default_factory=dict)
79
80
def build_dag(
    edges: Iterable[tuple[str, str]],
    *,
    merged_ids: Iterable[str] = (),
    number_by_id: dict[str, int] | None = None,
) -> ProposalDag:
    """Assemble a ``ProposalDag`` from ``(dependent_id, dependency_id)`` pairs.

    Each pair reads "``dependent_id`` must wait for ``dependency_id``". Both
    endpoints become nodes, and the edge is recorded in both directions
    (``depends_on`` and ``required_by``). Because adjacency values are sets,
    duplicate pairs collapse.

    Args:
        edges: Iterable of ``(dependent_id, dependency_id)`` tuples.
        merged_ids: Proposal IDs already in MERGED state; copied into a new
            set. Edges whose dependency is in this set are no longer live.
        number_by_id: Optional proposal_id -> proposal_number map; copied,
            ``None`` meaning empty.

    Returns:
        A newly populated ``ProposalDag``.
    """
    merged = set(merged_ids)
    numbers = dict(number_by_id or {})
    forward: dict[str, set[str]] = defaultdict(set)
    backward: dict[str, set[str]] = defaultdict(set)
    seen: set[str] = set()

    for waiter, prerequisite in edges:
        seen.update((waiter, prerequisite))
        forward[waiter].add(prerequisite)
        backward[prerequisite].add(waiter)

    return ProposalDag(
        depends_on=forward,
        required_by=backward,
        nodes=seen,
        merged_ids=merged,
        number_by_id=numbers,
    )
112
113
def topological_sort(dag: ProposalDag) -> list[str]:
    """Return a valid merge order for the unmerged proposals (Kahn's algorithm).

    Only *live* edges are considered, i.e. edges whose dependency is not yet
    merged. Merged proposals count as already-satisfied prerequisites and do
    not appear in the output.

    The result is deterministic. Ready proposals are sorted by proposal_id
    before being enqueued, so repeated calls agree even across processes
    with different string hash seeds. The DAG is only read, never modified.

    Args:
        dag: The graph to order.

    Returns:
        Unmerged proposal_ids in a valid merge order. Proposals with no
        unmerged dependencies come first.

    Raises:
        CycleError: If the unmerged proposals contain a cycle. ``cycle_ids``
            holds every unmerged node that could not be drained.
    """
    # Only unmerged nodes that appear in at least one edge are ordered.
    unmerged_nodes = dag.nodes - dag.merged_ids

    # Live in-degree = number of not-yet-merged prerequisites. ``.get`` rather
    # than ``[]``: indexing a defaultdict would insert empty entries into the
    # caller's DAG, and indexing a plain dict would raise KeyError.
    in_degree: dict[str, int] = {
        node: len(dag.depends_on.get(node, set()) - dag.merged_ids)
        for node in unmerged_nodes
    }

    queue: deque[str] = deque(sorted(n for n, d in in_degree.items() if d == 0))
    order: list[str] = []

    while queue:
        node = queue.popleft()
        order.append(node)
        for dependent in sorted(dag.required_by.get(node, set())):
            # Merged dependents (and anything outside the node set) are not
            # being ordered, so they have no in-degree to decrement.
            if dependent not in in_degree:
                continue
            in_degree[dependent] -= 1
            if in_degree[dependent] == 0:
                queue.append(dependent)

    cycle_members = unmerged_nodes - set(order)
    if cycle_members:
        raise CycleError(cycle_members)

    return order
157
158
def detect_cycle(dag: ProposalDag) -> bool:
    """Report whether the unmerged part of ``dag`` contains a cycle.

    Delegates to ``topological_sort`` and turns its ``CycleError`` into
    ``True``. The computed order itself is discarded.
    """
    try:
        topological_sort(dag)
    except CycleError:
        return True
    else:
        return False
166
167
def blocked_by_numbers(dag: ProposalDag, proposal_id: str) -> list[int]:
    """List the proposal_numbers still blocking ``proposal_id``, ascending.

    These are the unmerged proposals ``proposal_id`` depends on; each must
    reach MERGED state before ``proposal_id`` can merge. Dependencies with no
    entry in ``dag.number_by_id`` are skipped. Returns ``[]`` when nothing
    blocks the proposal or it is unknown to the DAG.
    """
    outstanding = dag.depends_on.get(proposal_id, set()) - dag.merged_ids
    numbers = [dag.number_by_id[prereq] for prereq in outstanding if prereq in dag.number_by_id]
    numbers.sort()
    return numbers
176
177
def blocks_numbers(dag: ProposalDag, proposal_id: str) -> list[int]:
    """List the proposal_numbers of unmerged proposals waiting on ``proposal_id``.

    A waiter is a proposal that lists ``proposal_id`` as a hard dependency
    and has not itself been merged. Waiters with no entry in
    ``dag.number_by_id`` are skipped. The result is sorted ascending.
    """
    numbers: list[int] = []
    for waiter in dag.required_by.get(proposal_id, set()) - dag.merged_ids:
        if waiter in dag.number_by_id:
            numbers.append(dag.number_by_id[waiter])
    numbers.sort()
    return numbers
186
187
def is_blocked(dag: ProposalDag, proposal_id: str) -> bool:
    """Tell whether ``proposal_id`` still waits on at least one unmerged proposal."""
    prerequisites = dag.depends_on.get(proposal_id, set())
    # Unblocked exactly when every prerequisite is already merged.
    return not prerequisites <= dag.merged_ids
191
192
193 # ─────────────────────────────────────────────────────────────────────────────
194 # DB helpers
195 # ─────────────────────────────────────────────────────────────────────────────
196
197
async def load_dag(session: AsyncSession, repo_id: str) -> ProposalDag:
    """Load the dependency DAG for one repo from the DB.

    Runs two queries, in this order:

    1. Every dependency edge whose *dependent* proposal belongs to ``repo_id``.
    2. ``(proposal_id, proposal_number, state)`` for every proposal in the repo.

    In the returned DAG, ``nodes`` are the proposals appearing in at least
    one edge. ``number_by_id`` covers every proposal in the repo, and
    ``merged_ids`` is every repo proposal whose state equals ``"merged"``.
    """
    edge_stmt = (
        select(
            MusehubProposalDependency.dependent_proposal_id,
            MusehubProposalDependency.dependency_proposal_id,
        )
        .join(
            MusehubProposal,
            MusehubProposal.proposal_id == MusehubProposalDependency.dependent_proposal_id,
        )
        .where(MusehubProposal.repo_id == repo_id)
    )
    proposal_stmt = select(
        MusehubProposal.proposal_id,
        MusehubProposal.proposal_number,
        MusehubProposal.state,
    ).where(MusehubProposal.repo_id == repo_id)

    edge_result = await session.execute(edge_stmt)
    edges = [(dependent, dependency) for dependent, dependency in edge_result.all()]

    proposal_result = await session.execute(proposal_stmt)
    number_by_id: dict[str, int] = {}
    merged_ids: set[str] = set()
    for proposal_id, number, state in proposal_result.all():
        number_by_id[proposal_id] = number
        if state == "merged":
            merged_ids.add(proposal_id)

    return build_dag(edges=edges, merged_ids=merged_ids, number_by_id=number_by_id)
243
244
async def load_dag_for_proposals(
    session: AsyncSession,
    proposal_ids: list[str],
) -> ProposalDag:
    """Load the slice of the dependency DAG that touches ``proposal_ids``.

    Used by the batch enrichment path. Every edge with at least one endpoint
    in ``proposal_ids`` is loaded, i.e. each requested proposal's direct
    dependencies and direct dependents. Number and state metadata are then
    fetched for the requested proposals plus those neighbours. Edges further
    away are NOT followed: this is a one-hop neighbourhood, not a transitive
    closure.

    Returns an empty ``ProposalDag`` without querying when ``proposal_ids``
    is empty. Otherwise issues exactly two queries: edges, then metadata.
    """
    if not proposal_ids:
        return ProposalDag()

    dep = MusehubProposalDependency
    touches_request = dep.dependent_proposal_id.in_(proposal_ids) | dep.dependency_proposal_id.in_(
        proposal_ids
    )
    edge_result = await session.execute(
        select(dep.dependent_proposal_id, dep.dependency_proposal_id).where(touches_request)
    )
    edges = [(waiter, prerequisite) for waiter, prerequisite in edge_result.all()]

    # Metadata is needed for the requested proposals and every edge endpoint.
    wanted = set(proposal_ids)
    for waiter, prerequisite in edges:
        wanted.update((waiter, prerequisite))

    proposal_result = await session.execute(
        select(
            MusehubProposal.proposal_id,
            MusehubProposal.proposal_number,
            MusehubProposal.state,
        ).where(MusehubProposal.proposal_id.in_(list(wanted)))
    )
    number_by_id: dict[str, int] = {}
    merged_ids: set[str] = set()
    for proposal_id, number, state in proposal_result.all():
        number_by_id[proposal_id] = number
        if state == "merged":
            merged_ids.add(proposal_id)

    return build_dag(edges=edges, merged_ids=merged_ids, number_by_id=number_by_id)
301
302
async def create_dependency_edges(
    session: AsyncSession,
    dependent_proposal_id: str,
    dependency_proposal_ids: list[str],
) -> None:
    """Persist dependency edges for a newly created proposal.

    Validates the request *before* adding any rows to the session.

    First, every dependency ID must name an existing proposal.

    Second, the new edges must not close a cycle. A new edge
    ``dependent -> d`` closes a cycle exactly when ``dependent`` is reachable
    from ``d`` through existing "depends on" edges. So every existing edge
    reachable from the new dependencies is loaded (one query per level of
    the walk) and the combined graph is checked with ``detect_cycle``.
    Looking only at edges adjacent to the new proposals would miss cycles of
    length three or more.

    Duplicate IDs in ``dependency_proposal_ids`` are collapsed, so each
    (dependent, dependency) pair produces exactly one row. Rows are only
    ``session.add``-ed; this function does not flush or commit.

    Args:
        session: Async DB session (must be inside a transaction).
        dependent_proposal_id: The proposal being created.
        dependency_proposal_ids: Proposal IDs it must wait for.

    Raises:
        ValueError: If any dependency ID does not match an existing proposal.
        CycleError: If the checked graph contains a cycle. That means the new
            edges close one, or a pre-existing cycle is reachable from the
            new dependencies. ``cycle_ids`` is the dependent plus the
            requested dependency IDs.
    """
    if not dependency_proposal_ids:
        return

    from musehub.core.genesis import _genesis_hash

    # Collapse duplicates but keep the caller's order. A repeated ID would
    # otherwise yield two edges with the same deterministic ``dep_id``.
    dependency_ids = list(dict.fromkeys(dependency_proposal_ids))

    # Verify all dependency IDs exist.
    existing = set(
        (
            await session.execute(
                select(MusehubProposal.proposal_id).where(
                    MusehubProposal.proposal_id.in_(dependency_ids)
                )
            )
        ).scalars()
    )
    missing = set(dependency_ids) - existing
    if missing:
        raise ValueError(f"depends_on references unknown proposal IDs: {sorted(missing)}")

    # Breadth-first walk over "depends on" edges starting at the new
    # dependencies. Every node is expanded at most once, so the loop ends.
    reachable_edges: list[tuple[str, str]] = []
    visited: set[str] = set()
    frontier: set[str] = set(dependency_ids)
    while frontier:
        visited |= frontier
        rows = (
            await session.execute(
                select(
                    MusehubProposalDependency.dependent_proposal_id,
                    MusehubProposalDependency.dependency_proposal_id,
                ).where(
                    MusehubProposalDependency.dependent_proposal_id.in_(sorted(frontier))
                )
            )
        ).all()
        next_frontier: set[str] = set()
        for waiter, prerequisite in rows:
            reachable_edges.append((waiter, prerequisite))
            if prerequisite not in visited:
                next_frontier.add(prerequisite)
        frontier = next_frontier

    new_edges = [(dependent_proposal_id, dep_id) for dep_id in dependency_ids]
    tentative_dag = build_dag(edges=reachable_edges + new_edges)

    if detect_cycle(tentative_dag):
        raise CycleError({dependent_proposal_id} | set(dependency_ids))

    # Safe to persist.
    for dep_id in dependency_ids:
        session.add(
            MusehubProposalDependency(
                dep_id=_genesis_hash(dependent_proposal_id, dep_id),
                dependent_proposal_id=dependent_proposal_id,
                dependency_proposal_id=dep_id,
            )
        )

    logger.info(
        "✅ Created %d dependency edge(s) for proposal %s",
        len(dependency_ids),
        dependent_proposal_id,
    )
File History 1 commit
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor 1 day ago