Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 88 additions & 21 deletions rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,16 @@ def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs
super().__init__(max_trace_num, temperature)
# Read c_puct from settings if available, otherwise fall back to default 1.0
self.c_puct = getattr(DS_RD_SETTING, "scheduler_c_puct", 1.0) or 1.0
self.c_uct = getattr(DS_RD_SETTING, "scheduler_c_uct", 1.0) or 1.0
# Statistics keyed by leaf node index
self.node_visit_count: dict[int, int] = {}
self.node_value_sum: dict[int, float] = {}
self.node_prior: dict[int, float] = {}

self.root_id = -1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make `root_id` a class attribute instead of assigning it on each instance in `__init__`.

self.node_visit_count[self.root_id] = 1
self.node_value_sum[self.root_id] = 0.0

# Global counter to stabilize U term
self.global_visit_count: int = 0
# Last observed commit index for batch feedback observation
Expand All @@ -349,37 +355,61 @@ def _get_q(self, node_id: int) -> float:
return 0.0
return value_sum / visits

def _get_u(self, node_id: int) -> float:
prior = self.node_prior.get(node_id, 0.0)
# def _get_u(self, node_id: int) -> float:
# prior = self.node_prior.get(node_id, 0.0)
# visits = self.node_visit_count.get(node_id, 0)
# # Avoid div-by-zero; encourage exploration when visits are small
# return self.c_puct * prior * math.sqrt(max(1, self.global_visit_count)) / (1 + visits)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create a new get_parents function

def _get_u_uct(self, node_id: int, trace: DSTrace) -> float:
    """Return the UCT exploration bonus ``c_uct * sqrt(ln(N) / n)`` for a node.

    ``N`` is the visit count of the node's direct parent and ``n`` is the
    node's own visit count, both read from ``self.node_visit_count`` and
    clamped to at least 1 so that neither ``log`` nor the division can fail.

    Args:
        node_id: Index of the node to score, or ``self.root_id`` for the
            virtual root.
        trace: Trace object used to look up the ancestry of ``node_id`` via
            ``trace.get_parents``.

    Returns:
        The exploration term to be added to the node's Q value.
    """
    visit_count = self.node_visit_count

    # The virtual root has no entry in the trace, so it is scored against
    # its own visit count (same formula the caller used inline for the root).
    if node_id == self.root_id:
        root_visits = max(1, visit_count.get(self.root_id, 1))
        return self.c_uct * math.sqrt(math.log(root_visits) / root_visits)

    # NOTE(review): presumably get_parents returns the ancestor chain ending
    # with node_id itself, so [-2] is the direct parent — confirm in DSTrace.
    parents = trace.get_parents(node_id)
    last_parent_id = parents[-2] if len(parents) >= 2 else self.root_id

    parent_visits = max(1, visit_count.get(last_parent_id, 0))
    visits = max(1, visit_count.get(node_id, 0))
    return self.c_uct * math.sqrt(math.log(parent_visits) / visits)

def select(self, trace: DSTrace) -> tuple[int, ...] | None:
# Step 1: keep same policy to reach target number of parallel traces
# TODO: expanding from the virtual root node is implemented in a rule-based way.
if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
return trace.NEW_ROOT



# if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < 1:
# return trace.NEW_ROOT

# Step 2: consider only available leaves (not being expanded)
available_leaves = list(set(range(len(trace.hist))))
if not available_leaves:
return None

# Step 3: compute priors (P) from potentials via softmax
potentials = [self.calculate_potential(trace, leaf) for leaf in available_leaves]
if any(p < 0 for p in potentials):
raise ValueError("Potential function returned a negative value.")
priors = self._softmax_probabilities(potentials)
for leaf, p in zip(available_leaves, priors):
self.node_prior[leaf] = p
candidates = list(available_leaves) # copy
candidates_with_root = candidates + [self.root_id]


# # Step 3: compute priors (P) from potentials via softmax
# potentials = [self.calculate_potential(trace, leaf) for leaf in available_leaves]
# if any(p < 0 for p in potentials):
# raise ValueError("Potential function returned a negative value.")
# priors = self._softmax_probabilities(potentials)
# for leaf, p in zip(available_leaves, priors):
# self.node_prior[leaf] = p

# Step 4: score each leaf using PUCT-like rule: Q + U
best_leaf = None
best_score = -float("inf")
for leaf in available_leaves:
for leaf in candidates_with_root:
q = self._get_q(leaf)
u = self._get_u(leaf)
#u = self._get_u(leaf)
#u = self._get_u_uct(leaf,trace)
u = self._get_u_uct(leaf, trace) if leaf != self.root_id else self.c_uct * math.sqrt(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this if else should be placed into the _get_u_uct function.

math.log(max(1, self.node_visit_count.get(self.root_id, 1))) / max(1, self.node_visit_count.get(self.root_id, 1))
)
score = q + u
if score > best_score:
best_score = score
Expand All @@ -388,10 +418,41 @@ def select(self, trace: DSTrace) -> tuple[int, ...] | None:
if best_leaf is None:
return None

if best_leaf == self.root_id:
capacity = trace.sub_trace_count + self.uncommited_rec_status.get(trace.NEW_ROOT, 0)
if capacity >= self.max_trace_num:
# capacity full: pick next best real leaf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be simplified: the fallback loop below repeats the same Q + U scoring as the main loop above.

second_best = None
second_score = -float("inf")
for node in candidates:
q = self._get_q(node)
u = self._get_u_uct(node, trace)
score = q + u
if score > second_score:
second_score = score
second_best = node
if second_best is None:
return None
# optimistic visit update for chosen leaf (optional)
# self.node_visit_count[second_best] = self.node_visit_count.get(second_best, 0) + 1
return (second_best,)
else:
# choose to expand from virtual root
# optimistic visit update for root if desired:
# self.node_visit_count[self.root_id] += 1
return trace.NEW_ROOT

# # Step 5: optimistic visit update on selection; value update deferred to observe_feedback
self.global_visit_count += 1
#self.global_visit_count += 1

return (best_leaf,)

def sigmoid(self, x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``.

    The result lies in [0, 1]. The computation is split by sign so that
    ``math.exp`` is only ever called with a non-positive argument; the naive
    form raises ``OverflowError`` for large negative ``x``.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    # Equivalent form for x < 0: exp(x) underflows to 0.0 instead of overflowing.
    exp_x = math.exp(x)
    return exp_x / (1 + exp_x)

def scaled_tanh(self, x):
    """Squash ``x`` with tanh and rescale the result from (-1, 1) to (0, 1)."""
    shifted = 1.0 + math.tanh(x)  # tanh output moved from (-1, 1) to (0, 2)
    return shifted / 2.0

def observe_feedback(self, trace: DSTrace, new_idx: int) -> None:
"""
Expand All @@ -406,13 +467,19 @@ def observe_feedback(self, trace: DSTrace, new_idx: int) -> None:
re, fb = trace.hist[new_idx]
if DS_RD_SETTING.enable_score_reward:
bigger_is_better = get_metric_direction(trace.scen.competition)
if getattr(fb, "decision", False):
reward = math.tanh(re.result.loc["ensemble"].iloc[0].round(3)) * (1 if bigger_is_better else -1)
if re.result is not None:
if bigger_is_better:
reward = self.scaled_tanh(re.result.loc["ensemble"].iloc[0])
else:
reward = 1- self.scaled_tanh(re.result.loc["ensemble"].iloc[0])
else:
reward = -1 if bigger_is_better else 1
reward = 0 if bigger_is_better else 1
else:
reward = 1.0 if getattr(fb, "decision", False) else 0.0

id_list = trace.get_parents(new_idx)
id_list = [self.root_id] + id_list

for id in id_list:
self.node_value_sum[id] = self.node_value_sum.get(id, 0.0) + float(reward)
self.node_visit_count[id] = self.node_visit_count.get(id, 0) + 1
Expand Down
Loading