-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add virtual nodes to the Monte Carlo Tree #1300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -332,10 +332,16 @@ def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs | |
| super().__init__(max_trace_num, temperature) | ||
| # Read c_puct from settings if available, otherwise fall back to default 1.0 | ||
| self.c_puct = getattr(DS_RD_SETTING, "scheduler_c_puct", 1.0) or 1.0 | ||
| self.c_uct = getattr(DS_RD_SETTING, "scheduler_c_uct", 1.0) or 1.0 | ||
| # Statistics keyed by leaf node index | ||
| self.node_visit_count: dict[int, int] = {} | ||
| self.node_value_sum: dict[int, float] = {} | ||
| self.node_prior: dict[int, float] = {} | ||
|
|
||
| self.root_id = -1 | ||
| self.node_visit_count[self.root_id] = 1 | ||
| self.node_value_sum[self.root_id] = 0.0 | ||
|
|
||
| # Global counter to stabilize U term | ||
| self.global_visit_count: int = 0 | ||
| # Last observed commit index for batch feedback observation | ||
|
|
@@ -349,37 +355,61 @@ def _get_q(self, node_id: int) -> float: | |
| return 0.0 | ||
| return value_sum / visits | ||
|
|
||
| def _get_u(self, node_id: int) -> float: | ||
| prior = self.node_prior.get(node_id, 0.0) | ||
| # def _get_u(self, node_id: int) -> float: | ||
| # prior = self.node_prior.get(node_id, 0.0) | ||
| # visits = self.node_visit_count.get(node_id, 0) | ||
| # # Avoid div-by-zero; encourage exploration when visits are small | ||
| # return self.c_puct * prior * math.sqrt(max(1, self.global_visit_count)) / (1 + visits) | ||
|
|
||
|
||
| def _get_u_uct(self, node_id: int, trace: DSTrace) -> float: | ||
| parents = trace.get_parents(node_id) | ||
|
|
||
| #last_parent_id = parents[-2] if len(parents) > 1 else 0 | ||
| if len(parents) < 2: | ||
| last_parent_id = self.root_id | ||
| else: | ||
| last_parent_id = parents[-2] | ||
|
|
||
| parent_visits = self.node_visit_count.get(last_parent_id, 0) | ||
| visits = self.node_visit_count.get(node_id, 0) | ||
| # Avoid div-by-zero; encourage exploration when visits are small | ||
| return self.c_puct * prior * math.sqrt(max(1, self.global_visit_count)) / (1 + visits) | ||
| N = max(1, parent_visits) | ||
| n = max(1, visits) | ||
| return self.c_uct * math.sqrt(math.log(N) / n) | ||
|
|
||
| def select(self, trace: DSTrace) -> tuple[int, ...] | None: | ||
| # Step 1: keep same policy to reach target number of parallel traces | ||
| # TODO: expanding from the virtual root node is implemented in a rule-based way. | ||
| if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num: | ||
| return trace.NEW_ROOT | ||
|
|
||
|
|
||
|
|
||
| # if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < 1: | ||
| # return trace.NEW_ROOT | ||
|
|
||
| # Step 2: consider only available leaves (not being expanded) | ||
| available_leaves = list(set(range(len(trace.hist)))) | ||
| if not available_leaves: | ||
| return None | ||
|
|
||
| # Step 3: compute priors (P) from potentials via softmax | ||
| potentials = [self.calculate_potential(trace, leaf) for leaf in available_leaves] | ||
| if any(p < 0 for p in potentials): | ||
| raise ValueError("Potential function returned a negative value.") | ||
| priors = self._softmax_probabilities(potentials) | ||
| for leaf, p in zip(available_leaves, priors): | ||
| self.node_prior[leaf] = p | ||
| candidates = list(available_leaves) # copy | ||
| candidates_with_root = candidates + [self.root_id] | ||
|
|
||
|
|
||
| # # Step 3: compute priors (P) from potentials via softmax | ||
| # potentials = [self.calculate_potential(trace, leaf) for leaf in available_leaves] | ||
| # if any(p < 0 for p in potentials): | ||
| # raise ValueError("Potential function returned a negative value.") | ||
| # priors = self._softmax_probabilities(potentials) | ||
| # for leaf, p in zip(available_leaves, priors): | ||
| # self.node_prior[leaf] = p | ||
|
|
||
| # Step 4: score each leaf using PUCT-like rule: Q + U | ||
| best_leaf = None | ||
| best_score = -float("inf") | ||
| for leaf in available_leaves: | ||
| for leaf in candidates_with_root: | ||
| q = self._get_q(leaf) | ||
| u = self._get_u(leaf) | ||
| #u = self._get_u(leaf) | ||
| #u = self._get_u_uct(leaf,trace) | ||
| u = self._get_u_uct(leaf, trace) if leaf != self.root_id else self.c_uct * math.sqrt( | ||
|
||
| math.log(max(1, self.node_visit_count.get(self.root_id, 1))) / max(1, self.node_visit_count.get(self.root_id, 1)) | ||
| ) | ||
| score = q + u | ||
| if score > best_score: | ||
| best_score = score | ||
|
|
@@ -388,10 +418,41 @@ def select(self, trace: DSTrace) -> tuple[int, ...] | None: | |
| if best_leaf is None: | ||
| return None | ||
|
|
||
| if best_leaf == self.root_id: | ||
| capacity = trace.sub_trace_count + self.uncommited_rec_status.get(trace.NEW_ROOT, 0) | ||
| if capacity >= self.max_trace_num: | ||
| # capacity full: pick next best real leaf | ||
|
||
| second_best = None | ||
| second_score = -float("inf") | ||
| for node in candidates: | ||
| q = self._get_q(node) | ||
| u = self._get_u_uct(node, trace) | ||
| score = q + u | ||
| if score > second_score: | ||
| second_score = score | ||
| second_best = node | ||
| if second_best is None: | ||
| return None | ||
| # optimistic visit update for chosen leaf (optional) | ||
| # self.node_visit_count[second_best] = self.node_visit_count.get(second_best, 0) + 1 | ||
| return (second_best,) | ||
| else: | ||
| # choose to expand from virtual root | ||
| # optimistic visit update for root if desired: | ||
| # self.node_visit_count[self.root_id] += 1 | ||
| return trace.NEW_ROOT | ||
|
|
||
| # # Step 5: optimistic visit update on selection; value update deferred to observe_feedback | ||
| self.global_visit_count += 1 | ||
| #self.global_visit_count += 1 | ||
|
|
||
| return (best_leaf,) | ||
|
|
||
| def sigmoid(self, x): | ||
| return 1 / (1 + math.exp(-x)) | ||
|
|
||
| def scaled_tanh(self, x): | ||
| # tanh -> (-1,1), then scale to (0,1) | ||
| return (math.tanh(x) + 1.0) / 2.0 | ||
|
|
||
| def observe_feedback(self, trace: DSTrace, new_idx: int) -> None: | ||
| """ | ||
|
|
@@ -406,13 +467,19 @@ def observe_feedback(self, trace: DSTrace, new_idx: int) -> None: | |
| re, fb = trace.hist[new_idx] | ||
| if DS_RD_SETTING.enable_score_reward: | ||
| bigger_is_better = get_metric_direction(trace.scen.competition) | ||
| if getattr(fb, "decision", False): | ||
| reward = math.tanh(re.result.loc["ensemble"].iloc[0].round(3)) * (1 if bigger_is_better else -1) | ||
| if re.result is not None: | ||
| if bigger_is_better: | ||
| reward = self.scaled_tanh(re.result.loc["ensemble"].iloc[0]) | ||
| else: | ||
| reward = 1- self.scaled_tanh(re.result.loc["ensemble"].iloc[0]) | ||
| else: | ||
| reward = -1 if bigger_is_better else 1 | ||
| reward = 0 if bigger_is_better else 1 | ||
| else: | ||
| reward = 1.0 if getattr(fb, "decision", False) else 0.0 | ||
|
|
||
| id_list = trace.get_parents(new_idx) | ||
| id_list = [self.root_id] + id_list | ||
|
|
||
| for id in id_list: | ||
| self.node_value_sum[id] = self.node_value_sum.get(id, 0.0) + float(reward) | ||
| self.node_visit_count[id] = self.node_visit_count.get(id, 0) + 1 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make it a class attribute.