Merged
44 changes: 34 additions & 10 deletions rdagent/oai/backend/base.py
@@ -11,7 +11,7 @@
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, cast
from typing import Any, Callable, List, Optional, Tuple, Type, Union, cast

import pytz
from pydantic import BaseModel, TypeAdapter
@@ -36,13 +36,14 @@
class JSONParser:
"""JSON parser supporting multiple strategies"""

def __init__(self) -> None:
def __init__(self, add_json_in_prompt: bool = False) -> None:
self.strategies: List[Callable[[str], str]] = [
self._direct_parse,
self._extract_from_code_block,
self._fix_python_syntax,
self._extract_with_fix_combined,
]
self.add_json_in_prompt = add_json_in_prompt

def parse(self, content: str) -> str:
"""Parse JSON content, automatically trying multiple strategies"""
@@ -55,7 +56,16 @@ def parse(self, content: str) -> str:
continue

# All strategies failed
raise json.JSONDecodeError("Failed to parse JSON after all attempts", original_content, 0)
if not self.add_json_in_prompt:
error = json.JSONDecodeError(
"Failed to parse JSON after all attempts, maybe because 'messages' must contain the word 'json' in some form",
original_content,
0,
)
error.message = "Failed to parse JSON after all attempts, maybe because 'messages' must contain the word 'json' in some form" # type: ignore[attr-defined]
raise error
else:
raise json.JSONDecodeError("Failed to parse JSON after all attempts", original_content, 0)

def _direct_parse(self, content: str) -> str:
"""Strategy 1: Direct parsing (including handling extra data)"""
@@ -528,12 +538,16 @@ def _create_chat_completion_auto_continue(
seed: Optional[int] = None,
json_target_type: Optional[str] = None,
add_json_in_prompt: bool = False,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
**kwargs: Any,
) -> str:
"""
Call the chat completion function and automatically continue the conversation if the finish_reason is length.
"""

if response_format is None and json_mode:
response_format = {"type": "json_object"}

# 0) return directly if cache is hit
if seed is None and LLM_SETTINGS.use_auto_chat_cache_seed_gen:
seed = LLM_CACHE_SEED_GEN.get_next_seed()
@@ -555,11 +569,11 @@ def _create_chat_completion_auto_continue(
# Loop to get a full response
try_n = 6
for _ in range(try_n): # for some long code, 3 times may not enough for reasoning models
if json_mode and add_json_in_prompt:
if response_format == {"type": "json_object"} and add_json_in_prompt:
self._add_json_in_prompt(new_messages)
response, finish_reason = self._create_chat_completion_inner_function(
messages=new_messages,
json_mode=json_mode,
response_format=response_format,
**kwargs,
)
all_response += response
@@ -571,21 +585,31 @@ def _create_chat_completion_auto_continue(

# 2) refine the response and return
if LLM_SETTINGS.reasoning_think_rm:
# Strategy 1: Try to match complete <think>...</think> pattern
match = re.search(r"<think>(.*?)</think>(.*)", all_response, re.DOTALL)
_, all_response = match.groups() if match else ("", all_response)
if match:
_, all_response = match.groups()
else:
# Strategy 2: If no complete match, try to match only </think>
match = re.search(r"</think>(.*)", all_response, re.DOTALL)
if match:
all_response = match.group(1)
# If no match at all, keep original content

# 3) format checking
if json_mode or json_target_type:
parser = JSONParser()
if response_format == {"type": "json_object"} or json_target_type:
parser = JSONParser(add_json_in_prompt=add_json_in_prompt)
all_response = parser.parse(all_response)
if json_target_type:
# deepseek will enter this branch
TypeAdapter(json_target_type).validate_json(all_response)

if (response_format := kwargs.get("response_format")) is not None:
if response_format is not None:
if not isinstance(response_format, dict) and issubclass(response_format, BaseModel):
# It may raise TypeError if initialization fails
response_format(**json.loads(all_response))
elif response_format == {"type": "json_object"}:
logger.info(f"Using OpenAI response format: {response_format}")
else:
logger.warning(f"Unknown response_format: {response_format}, skipping validation.")
if self.dump_chat_cache:
@@ -642,7 +666,7 @@ def _create_embedding_inner_function( # type: ignore[no-untyped-def]
def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # noqa: C901, PLR0912, PLR0915
self,
messages: list[dict[str, Any]],
json_mode: bool = False,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
*args,
**kwargs,
) -> tuple[str, str | None]:
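To summarize this file's change for readers of the hunks above: the boolean `json_mode` is now normalized into a `response_format` value (either `{"type": "json_object"}` or a Pydantic `BaseModel` subclass), and the final text is validated against it. The following standalone sketch mirrors that logic; the helper names `resolve_response_format` and `validate_response` and the `Answer` schema are illustrative only and are not part of this PR.

```python
import json
from typing import Optional, Type, Union

from pydantic import BaseModel


def resolve_response_format(
    json_mode: bool,
    response_format: Optional[Union[dict, Type[BaseModel]]],
) -> Optional[Union[dict, Type[BaseModel]]]:
    # Backward compatibility: a bare json_mode=True maps to OpenAI-style json_object mode.
    if response_format is None and json_mode:
        return {"type": "json_object"}
    return response_format


def validate_response(all_response: str, response_format: Optional[Union[dict, Type[BaseModel]]]) -> None:
    # Mirrors the validation step added above: Pydantic schemas are checked by
    # instantiation, json_object needs no extra check, anything else is skipped.
    if response_format is None:
        return
    if not isinstance(response_format, dict) and issubclass(response_format, BaseModel):
        response_format(**json.loads(all_response))  # raises if the JSON does not match the schema


class Answer(BaseModel):  # hypothetical schema for the usage example
    description: str


fmt = resolve_response_format(json_mode=True, response_format=None)
assert fmt == {"type": "json_object"}
validate_response('{"description": "a task"}', Answer)
```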
20 changes: 11 additions & 9 deletions rdagent/oai/backend/deprec.py
@@ -12,12 +12,13 @@
import uuid
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional, cast
from typing import Any, Optional, Type, Union, cast

import numpy as np
import openai
import tiktoken
from openai.types.chat import ChatCompletion
from pydantic import BaseModel

from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass, import_class
from rdagent.log import LogColors
@@ -294,7 +295,7 @@ def _create_embedding_inner_function( # type: ignore[no-untyped-def]
def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # noqa: C901, PLR0912, PLR0915
self,
messages: list[dict[str, Any]],
json_mode: bool = False,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
add_json_in_prompt: bool = False,
*args,
**kwargs,
@@ -414,13 +415,14 @@ def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # no
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
if json_mode:
if add_json_in_prompt:
for message in messages[::-1]:
message["content"] = message["content"] + "\nPlease respond in json format."
if message["role"] == LLM_SETTINGS.system_prompt_role:
# NOTE: assumption: systemprompt is always the first message
break

# FIX what if the model does not support response_schema
if response_format == {"type": "json_object"} and add_json_in_prompt:
for message in messages[::-1]:
message["content"] = message["content"] + "\nPlease respond in json format."
if message["role"] == LLM_SETTINGS.system_prompt_role:
# NOTE: assumption: systemprompt is always the first message
break
call_kwargs["response_format"] = {"type": "json_object"}
response = self.chat_client.chat.completions.create(**call_kwargs)

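A small sketch of the prompt tweak this hunk now gates on `response_format == {"type": "json_object"}`: every message from the end backwards, up to and including the system prompt, gets a JSON reminder appended. The example messages and the literal `"system"` role are assumptions for illustration; the real code uses `LLM_SETTINGS.system_prompt_role`.

```python
# Assumed conversation; only the suffix-injection behaviour is being illustrated.
messages = [
    {"role": "system", "content": "You are a data-science assistant."},
    {"role": "user", "content": "Propose a feature-engineering task."},
]

for message in messages[::-1]:
    message["content"] = message["content"] + "\nPlease respond in json format."
    if message["role"] == "system":
        # Same assumption as the original code: the system prompt is the first message.
        break

# Both messages now end with the JSON reminder; the loop stopped at the system prompt.
print(messages)
```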
14 changes: 8 additions & 6 deletions rdagent/oai/backend/litellm.py
@@ -1,5 +1,5 @@
import copyreg
from typing import Any, Literal, cast
from typing import Any, Literal, Optional, Type, Union, cast

import numpy as np
from litellm import (
@@ -11,6 +11,7 @@
supports_response_schema,
token_counter,
)
from pydantic import BaseModel

from rdagent.log import LogColors
from rdagent.log import rdagent_logger as logger
@@ -86,23 +87,24 @@ def _create_embedding_inner_function(
def _create_chat_completion_inner_function( # type: ignore[no-untyped-def] # noqa: C901, PLR0912, PLR0915
self,
messages: list[dict[str, Any]],
json_mode: bool = False,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
*args,
**kwargs,
) -> tuple[str, str | None]:
"""
Call the chat completion function
"""
if json_mode and supports_response_schema(model=LITELLM_SETTINGS.chat_model):
kwargs["response_format"] = {"type": "json_object"}

elif not supports_response_schema(model=LITELLM_SETTINGS.chat_model) and "response_format" in kwargs:
if response_format and not supports_response_schema(model=LITELLM_SETTINGS.chat_model):
# Deepseek will enter this branch
logger.warning(
f"{LogColors.RED}Model {LITELLM_SETTINGS.chat_model} does not support response schema, ignoring response_format argument.{LogColors.END}",
tag="llm_messages",
)
kwargs.pop("response_format")
response_format = None

if response_format:
kwargs["response_format"] = response_format

if LITELLM_SETTINGS.log_llm_chat_content:
logger.info(self._build_log_messages(messages), tag="llm_messages")
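A hedged sketch of the capability check this hunk introduces: `supports_response_schema` is litellm's public helper (as used in the diff), while the model name and `MySchema` below are placeholders chosen for illustration.

```python
from typing import Optional, Type, Union

from litellm import supports_response_schema
from pydantic import BaseModel


class MySchema(BaseModel):  # placeholder schema, not part of the PR
    answer: str


model = "deepseek/deepseek-chat"  # example model name, not from the PR
response_format: Optional[Union[dict, Type[BaseModel]]] = MySchema
kwargs: dict = {}

if response_format and not supports_response_schema(model=model):
    # Models without structured-output support (the diff comment mentions DeepSeek)
    # fall back to plain text; downstream JSON parsing still applies.
    response_format = None

if response_format:
    kwargs["response_format"] = response_format
```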
@@ -344,5 +344,5 @@ output_format:
Design a specific and detailed Pipeline task based on the given hypothesis. The output should be detailed enough to directly implement the corresponding code.
The output should follow JSON format. The schema is as follows:
{
"description": "A precise and comprehensive description of the main workflow script (`main.py`)",
"description": "A detailed, step-by-step implementation guide for `main.py` that synthesizes planned modifications and code structure into a comprehensive coding plan. Must be formatted in Markdown with level-3 headings (###) organizing logical sections, key decision points, and implementation steps. Should provide sufficient detail covering implementation flow, algorithms, data handling, and key logic points for unambiguous developer execution.",
}
30 changes: 22 additions & 8 deletions rdagent/scenarios/data_science/proposal/exp_gen/prompts_v2.yaml
@@ -355,23 +355,36 @@ task_gen:
If you are confident in a specific value based on strong evidence, prior experiments, or clear rationale, specify the value clearly.
{% include "scenarios.data_science.share:spec.hyperparameter" %}


{% if task_output_format is not none %}
## [Partial Response Format 1] Task Output Format:

# Output Format

{% if not workflow_check %}

{{ task_output_format }}

{% else %}

There are two steps in this task, but your final answer should adhere to the output format below.

## [Partial Response Format 1]
### Step 1: **Task Output Format**:
{{ task_output_format }}

{% if workflow_check %}
# Step 2: Workflow Update
### Step 2: **Workflow Update**:
Since components have dependencies, your second task is to update the workflow to reflect the changes made to the target component. Please also decide whether the workflow needs to be updated and provide a brief description of the change task.
{{ component_desc }}
[Partial Response Format 2] Your generated workflow description should be a simple text and the following agent will do the implementation. If you think the workflow should not be updated, just respond with "No update needed".
{% endif %}

Your final output should strictly adhere to the following JSON format.
## [Partial Response Format 2] Your generated workflow description should be a simple text and the following agent will do the implementation. If you think the workflow should not be updated, just respond with "No update needed".

Finally, your output should strictly adhere to the following JSON format.
{
"task_design": ---The dict corresponding to task output format---,
{% if workflow_check %}"workflow_update": ---A string corresponding to workflow description--- {% endif %}
"task_design": a dict which strictly adheres to the **Task Output Format** in Step 1,
"workflow_update": "A string which is a precise and comprehensive description of the Workflow Update, or 'No update needed' if no changes are required."
}
{% endif %}
{% endif %}

user: |-
# Competition Scenario Description
@@ -489,3 +502,4 @@ output_format:
}



42 changes: 26 additions & 16 deletions rdagent/scenarios/data_science/proposal/exp_gen/proposal.py
@@ -729,11 +729,11 @@ def task_gen(
else:
component_info = get_component(hypotheses[0].component)
data_folder_info = self.scen.processed_data_folder_description
workflow_check = not pipeline and hypotheses[0].component != "Workflow"
sys_prompt = T(".prompts_v2:task_gen.system").r(
task_output_format=component_info["task_output_format"] if not self.supports_response_schema else None,
# task_output_format=component_info["task_output_format"],
component_desc=component_desc,
workflow_check=not pipeline and hypotheses[0].component != "Workflow",
workflow_check=workflow_check,
)
user_prompt = T(".prompts_v2:task_gen.user").r(
scenario_desc=scenario_desc,
@@ -743,37 +743,47 @@ def task_gen(
failed_exp_and_feedback_list_desc=failed_exp_feedback_list_desc,
eda_improvement=fb_to_sota_exp.eda_improvement if fb_to_sota_exp else None,
)

response = APIBackend().build_messages_and_create_chat_completion(
user_prompt=user_prompt,
system_prompt=sys_prompt,
response_format=CodingSketch if self.supports_response_schema else {"type": "json_object"},
json_target_type=Dict[str, str | Dict[str, str]] if not self.supports_response_schema else None,
)

task_dict = json.loads(response)
task_design = (
task_dict.get("task_design", {}) if not self.supports_response_schema else task_dict.get("sketch", {})
)
logger.info(f"Task design:\n{task_design}")

# 1) explain the response and get main task_description
not_found_str = f"{component_info['target_name']} description not provided"
if self.supports_response_schema:
# task_dict: {"sketch": str, ...}
task_desc = task_dict.get("sketch", not_found_str)
else:
if workflow_check:
# task_dict: {"task_design": ...., "workflow_update": ....}
task_desc = task_dict.get("task_design", {}).get("description", not_found_str)
else:
# task_dict: {"description": ....}
task_desc = task_dict.get("description", not_found_str)
# task_desc: str, a description of the task

# 2) create the main task
logger.info(f"Task design:\n{task_desc}")
task_name = hypotheses[0].component
description = (
task_design
if isinstance(task_design, str)
else task_design.get("description", f"{component_info['target_name']} description not provided")
)
task_class = component_info["task_class"]
task = task_class(
name=task_name,
description=description,
description=task_desc,
)
new_workflow_desc = task_dict.get("workflow_update", "No update needed")
exp = DSExperiment(pending_tasks_list=[[task]], hypothesis=hypotheses[0])
# exp.experiment_workspace.inject_code_from_folder(sota_exp.experiment_workspace.workspace_path)
if sota_exp is not None:
exp.experiment_workspace.inject_code_from_file_dict(sota_exp.experiment_workspace)
if not pipeline and new_workflow_desc != "No update needed":

# 3) create the workflow update task
if workflow_check:
workflow_task = WorkflowTask(
name="Workflow",
description=new_workflow_desc,
description=task_dict.get("workflow_update", "No update needed"),
)
exp.pending_tasks_list.append([workflow_task])
return exp
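To make the new branching in `task_gen` easier to follow, the sketch below walks the three response shapes it distinguishes. The JSON payloads and the `not_found_str` value are fabricated examples, not real model output or PR code.

```python
import json

not_found_str = "Pipeline description not provided"  # example target_name message

samples = [
    # (raw response, supports_response_schema, workflow_check)
    (json.dumps({"sketch": "### Step 1\nLoad data and train a baseline."}), True, False),
    (json.dumps({"task_design": {"description": "Add target encoding."},
                 "workflow_update": "No update needed"}), False, True),
    (json.dumps({"description": "Tune the learning rate."}), False, False),
]

for raw, supports_response_schema, workflow_check in samples:
    task_dict = json.loads(raw)
    if supports_response_schema:
        # Response-schema backends return the CodingSketch shape: {"sketch": str, ...}
        task_desc = task_dict.get("sketch", not_found_str)
    elif workflow_check:
        # json_object mode with a workflow check: {"task_design": {...}, "workflow_update": str}
        task_desc = task_dict.get("task_design", {}).get("description", not_found_str)
    else:
        # json_object mode without a workflow check: {"description": str}
        task_desc = task_dict.get("description", not_found_str)
    print(task_desc)
```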