Source code for crab.core.models.task

# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Callable, Literal
from uuid import uuid4

import networkx as nx
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    field_validator,
    model_serializer,
)

from .action import Action, ClosedAction
from .evaluator import Evaluator


class Task(BaseModel):
    """A task: a description, its evaluator graph and optional setup/teardown.

    Inputs are normalized by the validators below:

    * ``evaluator`` may be given as a ready-made ``nx.DiGraph`` or as a single
      ``Evaluator``; a single evaluator is wrapped into a one-node graph.
    * ``setup`` / ``teardown`` may be given as one ``ClosedAction`` or a list
      of them; a single action is wrapped into a one-element list.
    """

    # nx.DiGraph is not a pydantic-native type, so arbitrary types must be
    # allowed for the ``evaluator`` field to validate.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str
    description: str
    evaluator: nx.DiGraph | Evaluator
    # Mutable ``[]`` defaults are safe here: pydantic copies field defaults
    # per instance rather than sharing one list object.
    setup: list[ClosedAction] | ClosedAction = []
    teardown: list[ClosedAction] | ClosedAction = []
    extra_action: list[Action] = []

    @field_validator("evaluator")
    @classmethod
    def change_evaluator_to_graph(
        cls, evaluator: nx.DiGraph | Evaluator
    ) -> nx.DiGraph:
        """Return ``evaluator`` as a graph.

        A lone ``Evaluator`` becomes the single node of a new ``nx.DiGraph``.
        Any other value, i.e. an existing graph, is returned unchanged.
        """
        if isinstance(evaluator, Evaluator):
            graph = nx.DiGraph()
            graph.add_node(evaluator)
            return graph
        return evaluator

    @field_validator("setup", "teardown")
    @classmethod
    def to_list(cls, action: Action | list[Action]) -> list[Action]:
        """Wrap a single action in a list; return lists unchanged."""
        if isinstance(action, Action):
            return [action]
        return action
class SubTask(BaseModel):
    """A reusable sub-task template parameterized by named attributes.

    ``attribute_dict`` values may be a single string or a list of strings; a
    single string is normalized to a one-element list. As in ``Task``,
    ``setup`` / ``teardown`` may be given as one ``ClosedAction`` or a list of
    them and are always stored as a list.
    """

    id: str
    description: str
    attribute_dict: dict[str, list[str] | str]
    output_type: str
    # A callable producing the output string, the literal "manual", or None.
    output_generator: Callable[[Any], str] | Literal["manual"] | None = None
    # A callable producing an evaluator graph, or None.
    evaluator_generator: Callable[[Any], nx.DiGraph] | None = None
    # Mutable ``[]`` defaults are safe: pydantic copies defaults per instance.
    setup: list[ClosedAction] | ClosedAction = []
    teardown: list[ClosedAction] | ClosedAction = []
    extra_action: list[Action] = []

    def __hash__(self) -> int:
        # Sub-tasks are identified by their ``id`` for hashing.
        return hash(self.id)

    @field_validator("attribute_dict")
    @classmethod
    def expand_attribute_type(
        cls,
        attribute_dict: dict[str, list[str] | str],
    ) -> dict[str, list[str]]:
        """Return a new dict where every bare-string value is wrapped in a list.

        The input dict is not mutated. List values are kept as-is, and key
        order is preserved.
        """
        return {
            key: [value] if isinstance(value, str) else value
            for key, value in attribute_dict.items()
        }

    @field_validator("setup", "teardown")
    @classmethod
    def to_list(cls, action: Action | list[Action]) -> list[Action]:
        """Wrap a single action in a list; return lists unchanged.

        This mirrors ``Task.to_list``. Consumers can then always treat
        ``setup`` / ``teardown`` as lists, instead of risking iteration over a
        bare pydantic model, which yields ``(field, value)`` pairs.
        """
        if isinstance(action, Action):
            return [action]
        return action
class SubTaskInstance(BaseModel):
    """A ``SubTask`` bound to concrete attribute values, plus its output."""

    task: SubTask
    attribute: dict[str, Any]
    output: str | None = None
    # ``uuid4()`` returns a ``uuid.UUID``, and pydantic does not validate
    # defaults by default. The factory therefore converts to ``str`` itself,
    # so fresh instances match the annotation and instances loaded from JSON.
    id: str = Field(default_factory=lambda: str(uuid4()))

    def __hash__(self) -> int:
        # Instances are identified by their ``id`` for hashing.
        return hash(self.id)

    @model_serializer
    def dump_model(self) -> dict[str, Any]:
        """Serialize to a compact dict that references the sub-task by id.

        The output contains ``task`` (the ``SubTask.id``), ``attribute`` and
        ``output``. This instance's own ``id`` is not included.
        """
        return {
            "task": self.task.id,
            "attribute": self.attribute,
            "output": self.output,
        }
class GeneratedTask(BaseModel):
    """A task composed from ``SubTaskInstance`` objects."""

    description: str
    tasks: list[SubTaskInstance]
    # NOTE(review): presumably a networkx adjacency-list string describing how
    # ``tasks`` are connected — confirm against the producer/consumer.
    adjlist: str
    # ``uuid4()`` returns a ``uuid.UUID``, and pydantic does not validate
    # defaults by default. Convert to ``str`` so the value matches the
    # annotation.
    id: str = Field(default_factory=lambda: str(uuid4()))