Skip to content

Commit 22428a4

Browse files
Hoder-zyf and you-n-g authored
fix: refine prompts and add additional package info (#1179)
* refine prompts and add additional package info * refine prompts to be specific for GBDT models * minor refine prompts * use include to replace duplicate info * refine prompts * refactor: import DSTrace from base and remove exp_gen __init__ * lint --------- Co-authored-by: Young <[email protected]>
1 parent ee2a029 commit 22428a4

File tree

11 files changed

+143
-43
lines changed

11 files changed

+143
-43
lines changed

rdagent/components/coder/data_science/pipeline/prompts.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ pipeline_coder:
1313
{{ runtime_environment }}
1414
1515
{% if package_info is not none %}
16-
To help you write the runnable code, the user has provided the package information which contains the package names and versions.
17-
You should be careful about the package versions, as the code will be executed in the environment with the specified version and the api might be different from the latest version.
18-
The user might provide the packages the environment doesn't have, you should avoid using any of them.
16+
- To help you write the runnable code, the user has provided the package information which contains the package names and versions.
17+
- You should be careful about the package versions, as the code will be executed in the environment with the specified version and the api might be different from the latest version.
18+
- While the environment is fixed, you should not limit yourself to only the provided packages - feel free to explore other libraries that might better suit the task. However, prioritize using the available packages first, and only suggest alternatives when they would provide significant improvements or are more appropriate for the specific problem.
1919
## Package Information
2020
{{ package_info }}
2121
{% endif %}

rdagent/scenarios/data_science/dev/feedback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from rdagent.log.utils import dict_get_with_warning
1414
from rdagent.oai.llm_utils import APIBackend
1515
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
16-
from rdagent.scenarios.data_science.proposal.exp_gen import DSTrace
16+
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace
1717
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import DSIdea
1818
from rdagent.utils import convert2bool
1919
from rdagent.utils.agent.tpl import T

rdagent/scenarios/data_science/loop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
from rdagent.scenarios.data_science.dev.feedback import DSExperiment2Feedback
3131
from rdagent.scenarios.data_science.dev.runner import DSCoSTEERRunner
3232
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
33-
from rdagent.scenarios.data_science.proposal.exp_gen import DSTrace
34-
from rdagent.scenarios.data_science.proposal.exp_gen.base import DataScienceScen
33+
from rdagent.scenarios.data_science.proposal.exp_gen.base import (
34+
DataScienceScen,
35+
DSTrace,
36+
)
3537
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import DSKnowledgeBase
3638
from rdagent.scenarios.data_science.proposal.exp_gen.proposal import DSProposalV2ExpGen
3739
from rdagent.utils.workflow.misc import wait_retry
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace
2-
3-
__all__ = ["DSTrace"]

rdagent/scenarios/data_science/proposal/exp_gen/package_info.py

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,71 @@
11
import sys
22
from importlib.metadata import distributions
33

4+
# Kaggle competition packages - based on usage frequency
5+
# General-purpose data-science libraries (shown under "Basic Libraries").
PYTHON_BASE_PACKAGES = [
    "catboost",
    "lightgbm",
    "numpy",
    "optuna",
    "pandas",
    "scikit-learn",
    "scipy",
    "shap",
    "xgboost",
]

# Domain-specific libraries (shown under "Advanced Tools").
PYTHON_ADVANCED_PACKAGES = [
    "accelerate",
    "albumentations",
    "bayesian-optimization",
    "category_encoders",
    "datasets",
    "featuretools",
    "imbalanced-learn",
    "nltk",
    "opencv-python",
    "pillow",
    "polars",
    "sentence-transformers",
    "spacy",
    "tensorflow",
    "timm",
    "tokenizers",
    "torch",
    "torchvision",
    "transformers",
]
28+
29+
30+
def get_all_excepted_packages():
    """Return every recommended package name exactly once, sorted alphabetically.

    The base and advanced package lists are merged and de-duplicated before
    sorting.
    """
    unique_names = {*PYTHON_BASE_PACKAGES, *PYTHON_ADVANCED_PACKAGES}
    return sorted(unique_names)
34+
35+
36+
def get_available_recommended_packages_prompt():
    """Build the prompt section listing the recommended packages that are installed.

    Names from ``PYTHON_BASE_PACKAGES`` and ``PYTHON_ADVANCED_PACKAGES`` are
    matched against the installed distributions after normalizing both sides
    (lower-case, runs of ``-``, ``_`` and ``.`` collapsed to ``-``). This way
    ``category_encoders`` and ``category-encoders`` are treated as the same
    project, whichever spelling the distribution metadata uses.

    Returns:
        str: A markdown-style block with one section per non-empty package
        group, followed by a closing instruction sentence.
    """
    import re  # local import keeps this script self-contained

    def _canonical(name):
        # Name normalization rule from the PyPA "Names and normalization" spec.
        return re.sub(r"[-_.]+", "-", name).lower()

    installed = {_canonical(name) for name in get_installed_packages()}

    # Keep the original spelling and list order of each recommended package.
    base_available = [pkg for pkg in PYTHON_BASE_PACKAGES if _canonical(pkg) in installed]
    advanced_available = [pkg for pkg in PYTHON_ADVANCED_PACKAGES if _canonical(pkg) in installed]

    sections = (
        ("## [Basic Libraries] (general tools for data science tasks):", base_available),
        ("## [Advanced Tools] (specialized for specific domains):", advanced_available),
    )

    prompt_parts = ["# Available packages in environment:\n"]
    for heading, names in sections:
        if names:
            # Heading, bullet line, then a blank separator line.
            prompt_parts.extend((heading, f"- {', '.join(names)}", ""))

    prompt_parts.append(
        "You should choose appropriate tool combinations based on the specific context and current situation. Feel free to use any other packages you think are necessary to achieve the best performance."
    )

    return "\n".join(prompt_parts).strip()
62+
63+
64+
def print_available_packages_prompt():
    """Write the available-packages prompt to stdout so a caller can capture it."""
    print(get_available_recommended_packages_prompt())
68+
469

570
def get_installed_packages():
    """Map each installed distribution's lower-cased name to its version.

    Distributions whose metadata has no ``Name`` field (for example a
    corrupted or half-removed ``*.dist-info`` directory) are skipped, so one
    broken entry does not abort the whole scan.

    Returns:
        dict: ``{lower-cased distribution name: version}``. If two
        distributions share a name, the one yielded last wins.
    """
    installed = {}
    for dist in distributions():
        # ``.get`` returns None for a missing field instead of failing.
        name = dist.metadata.get("Name")
        if not name:
            continue
        installed[name.lower()] = dist.version
    return installed
@@ -26,24 +91,7 @@ def get_python_packages():
2691
# Example: `python package_info.py pandas torch scikit-learn`
2792
# If no extra arguments are provided we fall back to the original default list
2893
# to keep full backward-compatibility.
29-
packages_list = [ # default packages
30-
"transformers",
31-
"accelerate",
32-
"torch",
33-
"tensorflow",
34-
"pandas",
35-
"numpy",
36-
"scikit-learn",
37-
"scipy",
38-
"xgboost",
39-
"sklearn",
40-
"lightgbm",
41-
"vtk",
42-
"opencv-python",
43-
"keras",
44-
"matplotlib",
45-
"pydicom",
46-
]
94+
packages_list = get_all_excepted_packages()
4795
if len(sys.argv) > 1:
4896
packages_list = list(set(packages_list) | set(sys.argv[1:]))
4997

@@ -61,4 +109,8 @@ def get_python_packages():
61109

62110

63111
if __name__ == "__main__":
    # `--packages-prompt` as the first CLI argument selects the prompt output;
    # every other invocation falls through to get_python_packages().
    cli_args = sys.argv[1:]
    if cli_args[:1] == ["--packages-prompt"]:
        print_available_packages_prompt()
    else:
        get_python_packages()

rdagent/scenarios/data_science/proposal/exp_gen/prompts_v2.yaml

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ hypothesis_critique:
250250
- **Metric Impact**: Will this meaningfully improve the competition's evaluation metric?
251251
- **Historical Context**: Have similar approaches been tried? Key learnings from past attempts?
252252
- **Innovation vs History Balance**: Distinguish between implementation failures (worth retrying with improvements) vs fundamental approach failures (multiple attempts failed due to core unsuitability - should avoid)
253+
- **Tool Selection Appropriateness**: Are the suggested tools/packages well-suited for the problem? Consider both modern capabilities and traditional reliability
253254
254255
### 3. Improvement Direction
255256
- **Clarity Issues**: If vague, identify specific methods or strategies that address the core problem
@@ -268,11 +269,13 @@ hypothesis_critique:
268269
**Good Critiques:**
269270
- "The hypothesis lacks specificity about which ensemble method to use. Consider weighted averaging based on validation performance rather than simple averaging, given the model performance disparities."
270271
- "This hypothesis proposes LSTM for tabular data. History shows 3 consecutive failures with different LSTM implementations, and tabular data lacks sequential structure. Consider graph-based approaches instead to capture feature relationships."
272+
- "The hypothesis jumps to LightGBM without establishing a baseline. Consider starting with XGBoost to ensure a working solution, then explore LightGBM for potential improvements if the baseline performs adequately."
271273
272274
**Poor Critiques:**
273275
- "Set max_depth=10, learning_rate=0.05, and use 500 trees." (too specific)
274276
- "This might not work." (too vague)
275277
- "LSTM is innovative, let's try again with different hyperparameters." (ignores fundamental mismatch)
278+
- "Use the latest deep learning model because it's new." (ignores problem-solution fit)
276279
277280
{% if critique_output_format is not none %}
278281
## Output Format
@@ -320,6 +323,12 @@ hypothesis_rewrite:
320323
{% endif %}
321324
322325
## Guidelines for Writing Rewritten Hypotheses
326+
327+
### Available Tools Consideration
328+
- When rewriting, consider if the hypothesis leverages appropriate tools from the available packages
329+
- Balance innovation with practical tool selection - prefer modern packages when they offer clear advantages
330+
- Ensure tool choices align with the problem requirements and constraints
331+
- Be pragmatic: use whatever works best for the task - whether it's a cutting-edge transformer or traditional logistic regression
323332
324333
1. **Critique-Informed Specificity**:
325334
- Address technical gaps identified in the critique and replace vague terms with specific algorithms, methods, or parameters.
@@ -379,6 +388,10 @@ hypothesis_rewrite:
379388
{{ time_status }}
380389
{% endif %}
381390
391+
{% if packages_prompt is not none %}
392+
{{ packages_prompt }}
393+
{% endif %}
394+
382395
383396
task_gen:
384397
system: |-
@@ -429,12 +442,8 @@ task_gen:
429442
- Ensure validation metrics and processes are consistent across all parts of the pipeline. Avoid changes that would alter how validation metrics are calculated unless that is part of the hypothesis.
430443
8. **Submission File (`submission.csv`)**: Generate `submission.csv` in the **exact format** required (column names, order, data types), as detailed in the '====== Submission Format ======' section of the Competition Scenario Description (DO NOT read the sample_submission.csv file directly in the code). This is a critical step.
431444
9. **Preferred Packages Notes**:
432-
- You can choose the most proper packages for the task to best achieve the hypothesis.
433-
- When facing a choice between two packages which both can achieve the same goal, you should choose the one which is more commonly used and less likely to cause bugs in coding. Especially those you are not familiar with.
434-
- For GBDT models, prefer XGBoost or RandomForest over LightGBM unless the SOTA or hypothesis dictates otherwise. Prefer not using GPU for GBDT models unless the SOTA or hypothesis dictates otherwise.
435-
- For neural networks, prefer PyTorch or PyTorch based library (over TensorFlow) unless the SOTA or hypothesis dictates otherwise.
436-
- For neural networks, prefer fine-tuning pre-trained models over training from scratch.
437-
445+
{% include "scenarios.data_science.share:guidelines.package_selection" %}
446+
438447
## Package Declaration
439448
At the end of your design, **you MUST** provide a key `packages` in the final JSON output.
440449
It should be an **array of PyPI package names** (strings) that you expect to `import` in the forthcoming implementation.

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727
DSExperimentPlan,
2828
RD_Agent_TIMER_wrapper,
2929
)
30-
from rdagent.scenarios.data_science.proposal.exp_gen.utils import get_packages
30+
from rdagent.scenarios.data_science.proposal.exp_gen.utils import (
31+
get_available_packages_prompt,
32+
get_packages,
33+
)
3134
from rdagent.utils.agent.tpl import T
3235
from rdagent.utils.repo.diff import generate_diff_from_dict
3336
from rdagent.utils.workflow import wait_retry
@@ -588,16 +591,22 @@ def hypothesis_gen(
588591
enable_idea_pool: bool,
589592
inject_diverse: bool = False,
590593
exp_gen_plan: Optional[Dict] = None,
594+
packages_prompt: str = "",
591595
) -> Dict:
592596
problem_formatted_str = ""
593597
for i, (problem_name, problem_dict) in enumerate(problems.items()):
594598
problem_formatted_str += f"## {i+1}. {problem_name}\n"
595-
problem_formatted_str += f"{problem_dict['problem']}\n"
599+
problem_formatted_str += f"Statement: {problem_dict['problem']}\n"
600+
problem_formatted_str += f"Reason: {problem_dict['reason']}\n"
596601
if "idea" in problem_dict:
597602
idea_formatted_str = DSIdea(problem_dict["idea"]).to_formatted_str()
598603
problem_formatted_str += f"Sampled Idea by user: \n{idea_formatted_str}\n"
599604
problem_formatted_str += "\n\n"
600605

606+
# add available packages prompt
607+
if packages_prompt:
608+
problem_formatted_str += f"\n{packages_prompt}\n"
609+
601610
sys_prompt = T(".prompts_v2:hypothesis_gen.system").r(
602611
hypothesis_output_format=(
603612
T(".prompts_v2:output_format.hypothesis").r(pipeline=pipeline, enable_idea_pool=enable_idea_pool)
@@ -731,6 +740,7 @@ def hypothesis_rewrite(
731740
scenario_desc: str,
732741
sota_exp_desc: str,
733742
exp_feedback_list_desc: str,
743+
packages_prompt: str = "",
734744
) -> Dict:
735745
"""
736746
Generate improved hypotheses based on critique feedback for each original hypothesis.
@@ -769,6 +779,7 @@ def hypothesis_rewrite(
769779
sota_exp_desc=sota_exp_desc,
770780
hypothesis_critique_pairs=hypothesis_critique_pairs,
771781
time_status=time_status,
782+
packages_prompt=packages_prompt,
772783
)
773784

774785
response = APIBackend().build_messages_and_create_chat_completion(
@@ -1056,6 +1067,9 @@ def gen(
10561067
else:
10571068
inject_diverse = False
10581069

1070+
# add available packages prompt
1071+
packages_prompt = get_available_packages_prompt()
1072+
10591073
# Step 1: Identify problems
10601074
all_problems = self.identify_problem(
10611075
current_sub_trace=trace.get_parent_exps(),
@@ -1087,6 +1101,7 @@ def gen(
10871101
enable_idea_pool=DS_RD_SETTING.enable_knowledge_base,
10881102
inject_diverse=inject_diverse,
10891103
exp_gen_plan=plan.get("exp_gen") if plan else None,
1104+
packages_prompt=packages_prompt,
10901105
)
10911106
if not pipeline:
10921107
sota_exp_model_file_count = len(
@@ -1130,6 +1145,7 @@ def gen(
11301145
scenario_desc=scenario_desc,
11311146
sota_exp_desc=sota_exp_desc,
11321147
exp_feedback_list_desc=exp_feedback_list_desc,
1148+
packages_prompt=packages_prompt,
11331149
)
11341150
logger.info(f"Successfully completed hypothesis critique and rewrite process")
11351151
except Exception as e:

rdagent/scenarios/data_science/proposal/exp_gen/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,16 @@ def get_packages(pkgs: list[str] | None = None) -> str:
103103
pkg_args = " ".join(pkgs) if pkgs else ""
104104
stdout = implementation.execute(env=env, entry=f"python {fname} {pkg_args}")
105105
return stdout
106+
107+
108+
def get_available_packages_prompt() -> str:
    """Run ``package_info.py --packages-prompt`` and return what it prints.

    A copy of the sibling ``package_info.py`` is injected into a new
    ``FBWorkspace`` and executed with the environment from ``get_ds_env()``,
    the same mechanism ``get_packages`` uses.

    Returns:
        The script's stdout with leading and trailing whitespace removed.
    """
    runtime_env = get_ds_env()
    workspace = FBWorkspace()

    script_name = "package_info.py"
    script_path = Path(__file__).absolute().resolve().parent / "package_info.py"
    workspace.inject_files(**{script_name: script_path.read_text()})

    captured = workspace.execute(env=runtime_env, entry=f"python {script_name} --packages-prompt")
    return captured.strip()

rdagent/scenarios/data_science/scen/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from rdagent.log import rdagent_logger as logger
1111
from rdagent.oai.llm_utils import APIBackend
1212
from rdagent.scenarios.data_science.debug.data import create_debug_data
13+
from rdagent.scenarios.data_science.proposal.exp_gen.utils import (
14+
get_available_packages_prompt,
15+
)
1316
from rdagent.scenarios.data_science.scen.utils import describe_data_folder_v2
1417
from rdagent.scenarios.kaggle.kaggle_crawler import (
1518
crawl_descriptions,
@@ -209,6 +212,7 @@ def get_scenario_all_desc(self, eda_output=None) -> str:
209212
f"{self.recommend_debug_timeout() / 60 : .2f} minutes" if DS_RD_SETTING.sample_data_by_LLM else None
210213
),
211214
runtime_environment=self.get_runtime_environment(),
215+
available_packages_prompt=get_available_packages_prompt(),
212216
)
213217

214218
def get_runtime_environment(self) -> str:

rdagent/scenarios/data_science/scen/prompts.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ scenario_description: |-
5454
{{ runtime_environment }}
5555
{% endif %}
5656
57+
{% if available_packages_prompt is not none %}
58+
====== Available Packages ======
59+
{{ available_packages_prompt }}
60+
{% endif %}
61+
5762
competition_description_template:
5863
system: |-
5964
You are a data science assistant that extracts structured information from unstructured text.

0 commit comments

Comments
 (0)