Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions rdagent/utils/workflow/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ async def _run_step(self, li: int, force_subproc: bool = False) -> None:
# NOTE: each step are aware are of current loop index
# It is very important to set it before calling the step function!
self.loop_prev_out[li][self.LOOP_IDX_KEY] = li

try:
# Call function with current loop's output, await if coroutine or use ProcessPoolExecutor for sync if required
if force_subproc:
Expand All @@ -236,9 +237,6 @@ async def _run_step(self, li: int, force_subproc: bool = False) -> None:
result = func(self.loop_prev_out[li])
# Store result in the nested dictionary
self.loop_prev_out[li][name] = result

# Save snapshot after completing the step
self.dump(self.session_folder / f"{li}" / f"{si}_{name}")
except Exception as e:
if isinstance(e, self.skip_loop_error):
logger.warning(f"Skip loop {li} due to {e}")
Expand All @@ -256,6 +254,8 @@ async def _run_step(self, li: int, force_subproc: bool = False) -> None:
else:
raise # re-raise unhandled exceptions
finally:
# No matter the execution succeed or not, we have to finish the following steps

# Record the trace
end = datetime.now(timezone.utc)
self.loop_trace[li].append(LoopTrace(start, end, step_idx=si))
Expand All @@ -279,6 +279,12 @@ async def _run_step(self, li: int, force_subproc: bool = False) -> None:
step_index=next_step,
step_name=self.steps[next_step],
)

# Save snapshot after completing the step;
# 1) It has to be after the step_idx is updated, so loading the snapshot will be on the right step.
# 2) Only save it when the step forward, withdraw does not worth saving.
self.dump(self.session_folder / f"{li}" / f"{si}_{name}")

self._check_exit_conditions_on_step(loop_id=li, step_id=si)
else:
logger.warning(f"Step forward {si} of loop {li} is skipped.")
Expand Down