Skip to content

Commit f468595

Browse files
you-n-gxuangu-fang
andauthored
fix: scheduler next selection parallel disorder (#1028)
* fix: improve scheduler API (suggest_sel) and add timer.remain_time() * chore: exclude .venv from auto-black and auto-isort tasks * refactor: wrap RoundRobinScheduler commit and selection in retry loop * set search_type="ancestors" for experiment_and_feedback_list_after_init * refactor: merge sync_dag_parent_and_hist and hist.append into one call * fix uncommited rec bug * lint --------- Co-authored-by: xuangu-fang <[email protected]>
1 parent 0059a6a commit f468595

File tree

6 files changed

+42
-45
lines changed

6 files changed

+42
-45
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ pre-commit:
119119

120120
# Auto lint with black.
121121
auto-black:
122-
$(PIPRUN) python -m black . --extend-exclude test/scripts --extend-exclude git_ignore_folder -l 120
122+
$(PIPRUN) python -m black . --extend-exclude test/scripts --extend-exclude git_ignore_folder --extend-exclude .venv -l 120
123123

124124
# Auto lint with isort.
125125
auto-isort:
126-
$(PIPRUN) python -m isort . -s git_ignore_folder -s test/scripts
126+
$(PIPRUN) python -m isort . -s git_ignore_folder -s test/scripts -s .venv
127127

128128
# Auto lint with toml-sort.
129129
auto-toml-sort:

rdagent/log/utils/folder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_first_session_file_after_duration(log_folder: str | Path, duration: str
2424
session_obj: LoopBase = pickle.load(f)
2525
timer = session_obj.timer
2626
all_duration = timer.all_duration
27-
remain_time_duration = timer.remain_time_duration
27+
remain_time_duration = timer.remain_time()
2828
if all_duration is None or remain_time_duration is None:
2929
msg = "Timer is not configured"
3030
raise ValueError(msg)

rdagent/scenarios/data_science/loop.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,19 +210,14 @@ def record(self, prev_out: dict[str, Any]):
210210
# set the local selection to the trace as global selection, then set the DAG parent for the trace
211211
if exp.local_selection is not None:
212212
self.trace.set_current_selection(exp.local_selection)
213-
self.trace.sync_dag_parent_and_hist()
214-
215-
self.trace.hist.append((exp, prev_out["feedback"]))
216-
213+
self.trace.sync_dag_parent_and_hist((exp, prev_out["feedback"]))
217214
else:
218215
exp: DSExperiment = prev_out["direct_exp_gen"] if isinstance(e, CoderError) else prev_out["coding"]
219216

220217
# set the local selection to the trace as global selection, then set the DAG parent for the trace
221218
if exp.local_selection is not None:
222219
self.trace.set_current_selection(exp.local_selection)
223-
self.trace.sync_dag_parent_and_hist()
224-
225-
self.trace.hist.append(
220+
self.trace.sync_dag_parent_and_hist(
226221
(
227222
exp,
228223
ExperimentFeedback.from_exception(e),

rdagent/scenarios/data_science/proposal/exp_gen/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from rdagent.app.data_science.conf import DS_RD_SETTING
55
from rdagent.core.evolving_framework import KnowledgeBase
6+
from rdagent.core.experiment import Experiment
67
from rdagent.core.proposal import ExperimentFeedback, Hypothesis, Trace
78
from rdagent.scenarios.data_science.experiment.experiment import COMPONENT, DSExperiment
89
from rdagent.scenarios.data_science.scen import DataScienceScen
@@ -93,6 +94,7 @@ def get_leaves(self) -> list[int, ...]:
9394

9495
def sync_dag_parent_and_hist(
9596
self,
97+
exp_and_fb: tuple[Experiment, ExperimentFeedback],
9698
) -> None:
9799
"""
98100
Adding corresponding parent index to the dag_parent when the hist is going to be changed.
@@ -111,6 +113,7 @@ def sync_dag_parent_and_hist(
111113
current_node_idx = len(self.hist) - 1
112114

113115
self.dag_parent.append((current_node_idx,))
116+
self.hist.append(exp_and_fb)
114117

115118
def retrieve_search_list(
116119
self,
@@ -171,7 +174,7 @@ def has_component(
171174
def experiment_and_feedback_list_after_init(
172175
self,
173176
return_type: Literal["sota", "failed", "all"],
174-
search_type: Literal["all", "ancestors"] = "all",
177+
search_type: Literal["all", "ancestors"] = "ancestors",
175178
selection: tuple[int, ...] | None = None,
176179
max_retrieve_num: int | None = None,
177180
) -> list[tuple[DSExperiment, ExperimentFeedback]]:

rdagent/scenarios/data_science/proposal/exp_gen/parallel.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ def __init__(self, *args, **kwargs):
3636
# The underlying generator for creating a single experiment
3737
self.exp_gen = DataScienceRDLoop.default_exp_gen(self.scen)
3838
self.merge_exp_gen = ExpGen2Hypothesis(self.scen)
39-
self.trace_scheduler: TraceScheduler = RoundRobinScheduler()
40-
self.max_trace_num = DS_RD_SETTING.max_trace_num
39+
self.trace_scheduler: TraceScheduler = RoundRobinScheduler(DS_RD_SETTING.max_trace_num)
4140

4241
def gen(self, trace: "DSTrace") -> "Experiment":
4342
raise NotImplementedError(
@@ -67,15 +66,9 @@ async def async_gen(self, trace: DSTrace, loop: LoopBase) -> DSExperiment:
6766
else:
6867
# set the knowledge base option back to False for the other traces
6968
DS_RD_SETTING.enable_knowledge_base = False
70-
# step 1: select the parant trace to expand
71-
# Policy: if we have fewer traces than our target, start a new one.
72-
if trace.sub_trace_count < self.max_trace_num:
73-
local_selection = trace.NEW_ROOT
74-
else:
75-
# Otherwise, use the scheduler to pick an existing trace to expand.
76-
local_selection = await self.trace_scheduler.select_trace(trace)
7769

7870
if loop.get_unfinished_loop_cnt(loop.loop_idx) < RD_AGENT_SETTINGS.get_max_parallel():
71+
local_selection = await self.trace_scheduler.next(trace)
7972

8073
# set the local selection as the global current selection for the trace
8174
trace.set_current_selection(local_selection)

rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
from abc import ABC, abstractmethod
5+
from collections import defaultdict
56
from typing import TYPE_CHECKING
67

78
if TYPE_CHECKING:
@@ -15,11 +16,14 @@ class TraceScheduler(ABC):
1516
"""
1617

1718
@abstractmethod
18-
async def select_trace(self, trace: DSTrace) -> tuple[int, ...]:
19+
async def next(self, trace: DSTrace) -> tuple[int, ...]:
1920
"""
2021
Selects the next trace to expand.
2122
22-
This method must be async to allow for safe concurrent access.
23+
For proposing selections, we have to follow the rules
24+
- Suggest selection: suggest a selection that is suitable for the current trace.
25+
- Suggested should be garenteed to be recorded at last!!!
26+
- If no suitable selection is found, the function should async wait!!!!
2327
2428
Args:
2529
trace: The DSTrace object containing the full experiment history.
@@ -39,31 +43,33 @@ class RoundRobinScheduler(TraceScheduler):
3943
NOTE: we don't need to use asyncio.Lock here as the kickoff_loop ensures the ExpGen is always sequential, instead of parallel.
4044
"""
4145

42-
def __init__(self):
46+
def __init__(self, max_trace_num: int):
47+
self.max_trace_num = max_trace_num
4348
self._last_selected_leaf_id = -1
49+
self.rec_commit_idx = 0 # the node before rec_idx is already committed.
50+
self.uncommited_rec_status = defaultdict(int) # the uncommited record status
4451

45-
async def select_trace(self, trace: DSTrace) -> tuple[int, ...]:
52+
async def next(self, trace: DSTrace) -> tuple[int, ...]:
4653
"""
4754
Atomically selects the next leaf node from the trace in order.
4855
"""
49-
50-
leaves = trace.get_leaves()
51-
if not leaves:
52-
# This is the very first experiment in a new tree.
53-
return trace.NEW_ROOT
54-
55-
# Find the index of the last selected leaf in the current list of leaves
56-
try:
57-
current_position = leaves.index(self._last_selected_leaf_id)
58-
# Move to the next position, wrapping around if necessary
59-
next_position = (current_position + 1) % len(leaves)
60-
except ValueError:
61-
# This can happen if the last selected leaf is no longer a leaf
62-
# (it has been expanded) or if this is the first selection.
63-
# In either case, start from the beginning.
64-
next_position = 0
65-
66-
selected_leaf = leaves[next_position]
67-
self._last_selected_leaf_id = selected_leaf
68-
69-
return (selected_leaf,)
56+
while True:
57+
# step 0: Commit the pending selections
58+
for i in range(self.rec_commit_idx, len(trace.dag_parent)):
59+
for p in trace.dag_parent[i]:
60+
self.uncommited_rec_status[p] -= 1
61+
self.rec_commit_idx = len(trace.hist)
62+
63+
# step 1: select the parant trace to expand
64+
# Policy: if we have fewer traces than our target, start a new one.
65+
if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
66+
self.uncommited_rec_status[trace.NEW_ROOT] += 1
67+
return trace.NEW_ROOT
68+
69+
# Step2: suggest a selection to a not expanding leave
70+
leaves = trace.get_leaves()
71+
for leaf in leaves:
72+
if self.uncommited_rec_status[leaf] == 0:
73+
self.uncommited_rec_status[leaf] += 1
74+
return (leaf,)
75+
await asyncio.sleep(1)

0 commit comments

Comments
 (0)