Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions rdagent/scenarios/data_science/proposal/exp_gen/draft/draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,31 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.supports_response_schema = APIBackend().supports_response_schema()

# Packages requested by the LLM during the draft stage. Cached for reuse.
self._requested_pkgs: list[str] | None = None

# ---------------------------------------------------------------------
# New: ask the LLM which third-party packages it plans to use so we can
# query their versions dynamically.
# ---------------------------------------------------------------------

def _package_query(self, scenario_desc: str) -> list[str]:
    """Ask the LLM which third-party packages it intends to import.

    Parameters
    ----------
    scenario_desc : str
        Full competition/scenario description rendered into the user prompt.

    Returns
    -------
    list[str]
        Unique, lowercase PyPI package names, in first-seen order.
    """
    sys_prompt = T(".prompts_draft:pkg_query.system").r()
    user_prompt = T(".prompts_draft:pkg_query.user").r(scenario_desc=scenario_desc)

    response = APIBackend().build_messages_and_create_chat_completion(
        system_prompt=sys_prompt,
        user_prompt=user_prompt,
        json_mode=True,
        json_target_type=Dict[str, List[str]],
    )
    # dict.fromkeys deduplicates while preserving the LLM's ordering; a set
    # comprehension would make the resulting list order non-deterministic,
    # which in turn makes the downstream prompt (and LLM cache keys) unstable.
    return list(dict.fromkeys(pkg.lower() for pkg in json.loads(response)["packages"]))

def tag_gen(self, scenario_desc: str) -> str:
sys_prompt = T(".prompts_draft:tag_gen.system").r(tag_desc=T(".prompts_draft:description.tag_description").r())
user_prompt = T(".prompts_draft:tag_gen.user").r(
Expand All @@ -144,12 +169,25 @@ def tag_gen(self, scenario_desc: str) -> str:
return json.loads(response)["tag"].lower()

def knowledge_gen(self) -> str:
    """Generate the general-knowledge section with tailored runtime information.

    Returns
    -------
    str
        Rendered ``knowledge.general`` template enriched with version info
        for the packages the LLM plans to use.
    """
    # Step 1: ask for required packages once per draft session; the result is
    # memoized on the instance so repeated calls do not re-query the LLM.
    if self._requested_pkgs is None:
        scenario_desc = self.scen.get_competition_full_desc()
        self._requested_pkgs = self._package_query(scenario_desc)

    # Cache the package list on the scenario so later stages
    # (e.g. get_runtime_environment) can reuse it.
    setattr(self.scen, "required_packages", self._requested_pkgs)

    # Step 2: query versions for exactly those packages.
    runtime_environment = self.scen.get_runtime_environment(self._requested_pkgs)

    # Step 3: render the knowledge template with the enriched environment info.
    return T(".prompts_draft:knowledge.general").r(
        runtime_environment=runtime_environment,
        component_desc=T(".prompts_draft:description.component_description").r(),
    )

def hypothesis_gen(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ knowledge:
- Keep the ensembling method **simple, reproducible**.
- Ensemble logic must not bypass earlier validation steps.

pkg_query:
system: |-
You are an experienced data scientist. Based on the competition description, list **all third-party Python packages** (PyPI names) you plan to import in the forthcoming implementation. Output MUST be a JSON object with a single key "packages" whose value is an array of package names.
user: |-
# Scenario Description
{{ scenario_desc }}

hypothesis_draft:
system: |-
{% include "scenarios.data_science.share:scen.role" %}
Expand Down
16 changes: 14 additions & 2 deletions rdagent/scenarios/data_science/scen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,27 @@ def get_scenario_all_desc(self, eda_output=None) -> str:
debug_time_limit=f"{DS_RD_SETTING.debug_timeout / 60 / 60 : .2f} hours",
)

def get_runtime_environment(self, pkgs: list[str] | None = None) -> str:
    # TODO: add it into the base class. The environment (i.e. `DSDockerConf`)
    # should be part of the scenario class.
    """Return runtime environment information gathered inside the DS environment.

    Parameters
    ----------
    pkgs : list[str] | None
        Packages whose versions should be queried. When ``None``, falls back to
        ``self.required_packages`` (cached during the Draft stage, if set);
        failing that, ``runtime_info.py`` uses its built-in default list, which
        keeps full backward compatibility with the zero-argument call.

    Returns
    -------
    str
        Captured stdout of ``runtime_info.py`` executed in the environment.
    """
    import shlex  # local import: only needed to quote the CLI arguments below

    # Reuse the package list cached during the Draft stage when available.
    if pkgs is None:
        pkgs = getattr(self, "required_packages", None)

    env = get_ds_env()
    implementation = FBWorkspace()
    fname = "runtime_info.py"
    implementation.inject_files(
        **{fname: (Path(__file__).absolute().resolve().parent / "runtime_info.py").read_text()}
    )

    # Package names come from the LLM; quote each one so an unusual string
    # cannot break (or inject into) the shell command line.
    pkg_args = " ".join(shlex.quote(p) for p in pkgs) if pkgs else ""
    stdout = implementation.execute(env=env, entry=f"python {fname} {pkg_args}".strip())
    return stdout

def _get_data_folder_description(self) -> str:
Expand Down
54 changes: 36 additions & 18 deletions rdagent/scenarios/data_science/scen/runtime_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,42 @@ def get_gpu_info():

if __name__ == "__main__":
    print_runtime_info()

    # Allow the caller to pass a custom package list via command-line
    # arguments, e.g. `python runtime_info.py pandas torch scikit-learn`.
    # With no extra arguments we fall back to the original default list to
    # keep full backward compatibility.
    filtered_packages = (
        sys.argv[1:]
        if len(sys.argv) > 1
        else [
            "transformers",
            "accelerate",
            "torch",
            "tensorflow",
            "pandas",
            "numpy",
            "scikit-learn",
            "scipy",
            "xgboost",
            "sklearn",
            "lightgbm",
            "vtk",
            "opencv-python",
            "keras",
            "matplotlib",
            "pydicom",
        ]
    )

    installed_packages = get_installed_packages()

    print_filtered_packages(installed_packages, filtered_packages)

    # Report packages that were requested (e.g. by the LLM) but are not
    # installed.
    # NOTE(review): this assumes get_installed_packages() yields lowercase
    # names — confirm; otherwise normalize both sides before comparing.
    missing_pkgs = [pkg for pkg in filtered_packages if pkg.lower() not in installed_packages]
    if missing_pkgs:
        print("\n=== Missing Packages ===")
        for pkg in missing_pkgs:
            print(pkg)

    get_gpu_info()