import json
import os
from typing import Union

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

# ID-set fields that accept several input formats and can be overridden by a
# ``GROUP_HORSE_RACING_<FIELD>`` environment variable (see ``Config.__init__``).
_ID_SET_FIELDS = ("TESTERS", "TEST_GROUPS", "ALLOWED_GROUPS")


class Config(BaseSettings):
    """Settings for the group horse racing plugin.

    Scalar settings default to the matching ``GROUP_HORSE_RACING_*``
    environment variable, which is read once when this module is imported;
    values passed to the constructor take precedence over those defaults.

    The ID-set settings (``TESTERS``, ``TEST_GROUPS``, ``ALLOWED_GROUPS``)
    behave differently: a non-empty ``GROUP_HORSE_RACING_<NAME>`` environment
    variable is read at construction time and overrides any other value.
    """

    # Unknown keys in the input data are ignored rather than rejected.
    model_config = SettingsConfigDict(
        extra="ignore",
    )

    # Test mode settings.
    # Only the literal string "true" (case-insensitive) enables test mode.
    TEST_MODE: bool = os.getenv("GROUP_HORSE_RACING_TEST_MODE", "False").lower() == "true"
    TESTERS: set[int] = Field(default_factory=set)
    TEST_GROUPS: set[int] = Field(default_factory=set)
    ALLOWED_GROUPS: set[int] = Field(default_factory=set)

    # Reward, betting and race settings.
    PARTICIPANT_REWARD: int = int(os.getenv("GROUP_HORSE_RACING_PARTICIPANT_REWARD", "50"))
    CHAMPION_REWARD: int = int(os.getenv("GROUP_HORSE_RACING_CHAMPION_REWARD", "200"))
    MIN_BET: int = int(os.getenv("GROUP_HORSE_RACING_MIN_BET", "10"))
    MIN_ODDS: float = float(os.getenv("GROUP_HORSE_RACING_MIN_ODDS", "1.2"))
    RACE_DISTANCE: int = int(os.getenv("GROUP_HORSE_RACING_RACE_DISTANCE", "100"))
    RACE_TICK_INTERVAL: int = int(os.getenv("GROUP_HORSE_RACING_RACE_TICK_INTERVAL", "5"))

    # Message recall settings: message type -> recall delay.
    # NOTE(review): values are presumably seconds and 0 presumably means
    # "never recall" -- confirm against the code that consumes this mapping.
    MESSAGE_RECALL: dict[str, int] = Field(
        default_factory=lambda: {
            "race_update": 30,
            "registration": 180,
            "bet_confirm": 180,
            "cancel_confirm": 60,
            "error": 60,
            "race_result": 0,
            "leaderboard": 0,
            "help": 0,
            "odds_display": 0,
        }
    )

    # Database settings.
    RACE_DB_FILE: str = os.getenv("GROUP_HORSE_RACING_RACE_DB_FILE", "data/group_horse_racing/race.db")

    def __init__(self, **data):
        """Build the settings, then apply ID-set environment overrides.

        After normal validation, each ID-set field is replaced by the parsed
        value of ``GROUP_HORSE_RACING_<FIELD>`` when that variable is set to a
        non-empty string. This override wins over values passed in ``data``.
        """
        super().__init__(**data)
        for field_name in _ID_SET_FIELDS:
            raw = os.getenv(f"GROUP_HORSE_RACING_{field_name}", "")
            if raw:
                setattr(self, field_name, self._parse_id_set(raw))

    @staticmethod
    def _parse_id_set(v: str) -> set[int]:
        """Parse a set of integer IDs from a string.

        Accepted formats:
          * a JSON array, e.g. ``"[1424473282, 123]"``;
          * a bare number, e.g. ``"123"``;
          * a comma-separated list, e.g. ``"123, 456"``, optionally wrapped in
            square brackets (so near-JSON such as ``"[123, 456,]"`` works).

        Returns an empty set when the string matches none of these formats.
        """
        try:
            parsed = json.loads(v)
        except (ValueError, TypeError):  # JSONDecodeError subclasses ValueError
            parsed = None
        if isinstance(parsed, list):
            try:
                return {int(x) for x in parsed}
            except (ValueError, TypeError):
                pass  # fall through to the comma-separated parser

        # Comma-separated fallback. Brackets are stripped so that input which
        # is almost-but-not-quite JSON (e.g. a trailing comma) still parses.
        items = v.strip().removeprefix("[").removesuffix("]").split(",")
        try:
            return {int(item.strip()) for item in items if item.strip()}
        except ValueError:
            return set()

    @field_validator(*_ID_SET_FIELDS, mode="before")
    @classmethod
    def parse_id_sets(cls, v):
        """Normalise raw ID-set input before pydantic validates it as ``set[int]``.

        * ``set`` -- returned unchanged (pydantic coerces the elements).
        * ``str`` -- parsed with ``_parse_id_set``.
        * ``frozenset`` / ``list`` / ``tuple`` -- elements converted to ``int``;
          a non-numeric element raises ``ValueError`` (reported by pydantic as
          a validation error).
        * ``int`` -- treated as a single ID, e.g. when a value such as
          ``"123"`` has already been JSON-decoded to ``123``.
        * anything else, including ``None`` and ``bool`` -- an empty set.
        """
        if isinstance(v, set):
            return v
        if isinstance(v, str):
            return cls._parse_id_set(v)
        # bool is an int subclass; never treat True/False as an ID.
        if isinstance(v, bool):
            return set()
        if isinstance(v, int):
            return {v}
        if isinstance(v, (frozenset, list, tuple)):
            return {int(x) for x in v}
        return set()