diff --git a/rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py b/rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py index f079c2f6a..b32bada4a 100644 --- a/rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py +++ b/rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py @@ -328,16 +328,18 @@ class MCTSScheduler(ProbabilisticScheduler): - Keep NEW_ROOT policy and uncommitted status handling identical to base classes. """ + ROOT_ID = -1 + def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs): super().__init__(max_trace_num, temperature) - # Read c_puct from settings if available, otherwise fall back to default 1.0 - self.c_puct = getattr(DS_RD_SETTING, "scheduler_c_puct", 1.0) or 1.0 + self.c_uct = getattr(DS_RD_SETTING, "scheduler_c_uct", 1.0) or 1.0 # Statistics keyed by leaf node index self.node_visit_count: dict[int, int] = {} self.node_value_sum: dict[int, float] = {} - self.node_prior: dict[int, float] = {} - # Global counter to stabilize U term - self.global_visit_count: int = 0 + + self.node_visit_count[self.ROOT_ID] = 1 + self.node_value_sum[self.ROOT_ID] = 0.0 + # Last observed commit index for batch feedback observation self.last_observed_commit_idx: int = 0 @@ -349,49 +351,81 @@ def _get_q(self, node_id: int) -> float: return 0.0 return value_sum / visits - def _get_u(self, node_id: int) -> float: - prior = self.node_prior.get(node_id, 0.0) + def _get_parents(self, node_id: int, trace: DSTrace) -> list[int]: + """ + The MCTS algorithm uses a virtual root node, which does not exist in the trace data structure. + """ + if node_id == self.ROOT_ID: + parents = [] + else: + parents = trace.get_parents(node_id) + parents_with_root = [self.ROOT_ID] + parents + return parents_with_root + + def _get_all_nodes(self, trace: DSTrace) -> list[int]: + """ + The MCTS algorithm uses a virtual root node, which does not exist in the trace data structure. 
+ """ + return [self.ROOT_ID] + list(range(len(trace.hist))) + + def _get_u_uct(self, node_id: int, trace: DSTrace) -> float: + parents = self._get_parents(node_id, trace) + + if node_id == self.ROOT_ID: + last_parent_id = self.ROOT_ID + else: + last_parent_id = parents[-2] + + parent_visits = self.node_visit_count.get(last_parent_id, 0) visits = self.node_visit_count.get(node_id, 0) - # Avoid div-by-zero; encourage exploration when visits are small - return self.c_puct * prior * math.sqrt(max(1, self.global_visit_count)) / (1 + visits) + N = max(1, parent_visits) + n = max(1, visits) + return self.c_uct * math.sqrt(math.log(N) / n) - def select(self, trace: DSTrace) -> tuple[int, ...] | None: + def _select(self, trace: DSTrace) -> tuple[int, ...] | None: # Step 1: keep same policy to reach target number of parallel traces - # TODO: expanding from the virtual root node is implemented in a rule-based way. - if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num: - return trace.NEW_ROOT - - # Step 2: consider only available leaves (not being expanded) available_leaves = list(set(range(len(trace.hist)))) - if not available_leaves: - return None - # Step 3: compute priors (P) from potentials via softmax - potentials = [self.calculate_potential(trace, leaf) for leaf in available_leaves] - if any(p < 0 for p in potentials): - raise ValueError("Potential function returned a negative value.") - priors = self._softmax_probabilities(potentials) - for leaf, p in zip(available_leaves, priors): - self.node_prior[leaf] = p + candidates = list(available_leaves) # copy + candidates_with_root = candidates + [self.ROOT_ID] + + candidates_with_root = self._get_all_nodes(trace) # Step 4: score each leaf using PUCT-like rule: Q + U - best_leaf = None - best_score = -float("inf") - for leaf in available_leaves: - q = self._get_q(leaf) - u = self._get_u(leaf) - score = q + u - if score > best_score: - best_score = score - best_leaf = leaf - - if 
best_leaf is None: + score_id_pairs = [(self._get_q(nid) + self._get_u_uct(nid, trace), nid) for nid in candidates_with_root] + score_id_pairs.sort(reverse=True) + + if len(score_id_pairs) == 0: return None - # # Step 5: optimistic visit update on selection; value update deferred to observe_feedback - self.global_visit_count += 1 + best_node, _ = score_id_pairs[0] + + if best_node == self.ROOT_ID and len(score_id_pairs) > 1: + # Motivation: we sometimes want to limit the expansion of the root node. + # capacity full: pick next best real leaf + capacity = trace.sub_trace_count + self.uncommited_rec_status.get(trace.NEW_ROOT, 0) + if capacity >= self.max_trace_num: + second_best, _ = score_id_pairs[1] + return (second_best,) - return (best_leaf,) + return (best_node,) + + def select(self, trace: DSTrace) -> tuple[int, ...] | None: + """ + In MCTS, we have a virtual root node, expanding from the virtual root node will return (-1,). + But in the trace DAG, expanding a new node from the root node should return (trace.NEW_ROOT,). 
+ """ + base_nodes = self._select(trace) + if base_nodes == (self.ROOT_ID,): + return trace.NEW_ROOT + return base_nodes + + def sigmoid(self, x): + return 1 / (1 + math.exp(-x)) + + def scaled_tanh(self, x): + # tanh -> (-1,1), then scale to (0,1) + return (math.tanh(x) + 1.0) / 2.0 def observe_feedback(self, trace: DSTrace, new_idx: int) -> None: """ @@ -406,13 +440,18 @@ def observe_feedback(self, trace: DSTrace, new_idx: int) -> None: re, fb = trace.hist[new_idx] if DS_RD_SETTING.enable_score_reward: bigger_is_better = get_metric_direction(trace.scen.competition) - if getattr(fb, "decision", False): - reward = math.tanh(re.result.loc["ensemble"].iloc[0].round(3)) * (1 if bigger_is_better else -1) + if re.result is not None: + if bigger_is_better: + reward = self.scaled_tanh(re.result.loc["ensemble"].iloc[0]) + else: + reward = 1 - self.scaled_tanh(re.result.loc["ensemble"].iloc[0]) else: - reward = -1 if bigger_is_better else 1 + reward = 0 if bigger_is_better else 1 else: reward = 1.0 if getattr(fb, "decision", False) else 0.0 - id_list = trace.get_parents(new_idx) + + id_list = self._get_parents(new_idx, trace) + for id in id_list: self.node_value_sum[id] = self.node_value_sum.get(id, 0.0) + float(reward) self.node_visit_count[id] = self.node_visit_count.get(id, 0) + 1 @@ -424,8 +463,6 @@ def reset(self) -> None: super().reset() self.node_visit_count.clear() self.node_value_sum.clear() - self.node_prior.clear() - self.global_visit_count = 0 self.last_observed_commit_idx = 0 def process_uncommitted_nodes(self, trace: DSTrace) -> None: