diff --git a/danding_bot/plugins/danding_points/README.md b/danding_bot/plugins/danding_points/README.md index 6e4843c..e11a0a0 100644 --- a/danding_bot/plugins/danding_points/README.md +++ b/danding_bot/plugins/danding_points/README.md @@ -10,7 +10,6 @@ danding_points/ ├── __init__.py # 插件元数据 & points_api 单例导出 ├── config.py # 配置类 -├── database.py # SQLite 数据库操作 └── api.py # PointsAPI 核心 ``` @@ -20,10 +19,11 @@ danding_points/ | 配置项 | 类型 | 默认值 | 说明 | |--------|------|--------|------| -| `DANDING_POINTS_DB_FILE` | str | `data/danding_points/points.db` | 数据库文件路径 | -| `DANDING_POINTS_MAX_BALANCE` | int | `0` | 用户积分余额上限,`0` = 无限制 | -| `DANDING_POINTS_MAX_PER_OPERATION` | int | `0` | 单次操作积分上限,`0` = 无限制 | -| `DANDING_POINTS_LOG_RETENTION_DAYS` | int | `365` | 流水日志保留天数 | +| `DANDING_POINTS_API_HOST` | str | `https://api.danding.vip/bot/points` | xapi 积分 API 地址 | +| `DANDING_BOT_USER` | str | `1424473282` | xapi Bot 鉴权用户 | +| `DANDING_BOT_TOKEN` | str | 空 | xapi Bot 鉴权 Token,未设置时读取 `DANDING_API_TOKEN` / `BOT_TOKEN` | + +积分余额上限与单次操作上限由 xapi `BOT_POINTS_MAX_BALANCE` / `BOT_POINTS_MAX_PER_OPERATION` 控制,nonebot 本地不再持有阈值。 ## API 接口 @@ -170,36 +170,10 @@ async def admin_set(user_id: str, amount: int): | 插件 | source 值 | |------|-----------| -| onmyoji_gacha | `"onmyoji_gacha"` | -| danding_api | `"danding_api"` | -| shop | `"shop"` | -| sign_in | `"sign_in"` | +| onmyoji_gacha 签到 | `"gacha_sign"` | +| group_horse_racing | `"horse_race"` | +| 管理调整 | `"admin"` | ## 数据库 -使用 SQLite,数据文件位于 `data/danding_points/points.db`,无需额外配置。 - -### 表结构 - -**user_points** — 用户积分账户 - -| 字段 | 类型 | 说明 | -|------|------|------| -| user_id | TEXT PK | 用户 ID | -| points | INTEGER | 当前余额,>= 0 | -| total_earned | INTEGER | 累计获得 | -| total_spent | INTEGER | 累计消费 | -| created_at | TEXT | 创建时间 | -| updated_at | TEXT | 最后更新时间 | - -**point_transactions** — 积分变动流水 - -| 字段 | 类型 | 说明 | -|------|------|------| -| id | INTEGER PK | 自增 ID | -| user_id | TEXT | 用户 ID | -| amount | INTEGER | 变动数额(消费为负) | -| balance_after | INTEGER | 变动后余额 | -| source | TEXT | 来源标识 | -| reason | TEXT | 变动原因 | -| created_at | TEXT | 创建时间 | +本插件不再写入本地 SQLite。积分账户与流水由 xapi MySQL `bot_user_points` / `bot_point_transactions` 承载,nonebot 只通过 `/bot/points/*` HTTP API 读写。 diff --git a/danding_bot/plugins/danding_points/api.py b/danding_bot/plugins/danding_points/api.py index d510cf5..e9d63de 100644 --- a/danding_bot/plugins/danding_points/api.py +++ b/danding_bot/plugins/danding_points/api.py @@ -1,312 +1,181 @@ -import asyncio -import logging -import threading -from datetime import datetime -from typing import Tuple, List, Dict, Any -from .config import Config -from .database import PointsDatabase - -logger = logging.getLogger(__name__) +import asyncio +import aiohttp +import logging +from typing import Tuple, List, Dict, Any, Optional +from .config import Config + +logger = logging.getLogger(__name__) class PointsAPI: - """Points system API for managing user points.""" - - def __init__(self, config: Config): - self.config = config - self.db = PointsDatabase(config) - self._lock = threading.Lock() - - async def get_balance(self, user_id: str) -> int: - """Get user's current points balance.""" - return await asyncio.to_thread(self.db.get_user_balance, user_id) + """Points system API for managing user points.""" + + def __init__(self, config: Config): + self.config = config + + def _url(self, path: str) -> str: + """拼接 /bot/points 端点地址。""" + + return f"{self.config.POINTS_API_HOST}/{path.lstrip('/')}" + + def _auth(self) -> Dict[str, str]: + """生成 xapi Bot 鉴权参数。""" + + return { + "user": self.config.BOT_USER, + "token": self.config.BOT_TOKEN, + } + + async def _request( + self, + method: str, + path: str, + *, + payload: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + """调用 xapi /bot/points,并只向上层暴露 data。""" + + request_url = self._url(path) + timeout = aiohttp.ClientTimeout(total=10) + try: + async with aiohttp.ClientSession() as session: + if method == "GET": + request_params = {**self._auth(), **(params or {})} + async with session.get(request_url, params=request_params, timeout=timeout) as resp: + return await self._parse_response(resp, path) + request_payload = {**self._auth(), **(payload or {})} + async with session.post(request_url, json=request_payload, timeout=timeout) as resp: + return await self._parse_response(resp, path) + except aiohttp.ClientError as exc: + logger.error("points api request failed path=%s error=%s", path, exc) + return None + except asyncio.TimeoutError as exc: + logger.error("points api request timeout path=%s error=%s", path, exc) + return None + + async def _parse_response(self, resp: aiohttp.ClientResponse, path: str) -> Optional[Dict[str, Any]]: + """解析 xapi 统一响应,失败时返回 None 维持旧 API 失败语义。""" + + if resp.status != 200: + logger.error("points api bad status path=%s status=%s", path, resp.status) + return None + body = await resp.json() + if body.get("code") != 200: + logger.error("points api fail path=%s code=%s message=%s", path, body.get("code"), body.get("message")) + return None + data = body.get("data") + return data if isinstance(data, dict) else None + + async def get_balance(self, user_id: str) -> int: + """Get user's current points balance.""" + data = await self._request("GET", "balance", params={"user_id": user_id}) + if data is None: + return 0 + return int(data.get("balance", 0) or 0) async def add_points( self, user_id: str, amount: int, source: str, reason: str = None ) -> Tuple[bool, int]: - """Add points to user account. - - Returns: (success, new_balance) - """ - # Parameter validation - if not isinstance(amount, int) or amount <= 0: - return False, 0 - if not user_id or not source: - return False, 0 - - # Operation limit validation - if self.config.POINTS_MAX_PER_OPERATION > 0: - if amount > self.config.POINTS_MAX_PER_OPERATION: - return False, 0 - - def _add(): - with self._lock: - conn = self.db.get_connection() - try: - cursor = conn.cursor() - # Ensure user exists - self.db.ensure_user_exists(user_id, conn) - - # Get current balance - cursor.execute( - "SELECT points FROM user_points WHERE user_id = ?", - (user_id,), - ) - row = cursor.fetchone() - current_balance = row["points"] if row else 0 - - # Check balance limit - new_balance = current_balance + amount - if self.config.POINTS_MAX_BALANCE > 0: - if new_balance > self.config.POINTS_MAX_BALANCE: - conn.rollback() - return False, current_balance - - # Update balance and total_earned - now = datetime.now().isoformat() - cursor.execute( - """ - UPDATE user_points - SET points = ?, total_earned = total_earned + ?, updated_at = ? - WHERE user_id = ? - """, - (new_balance, amount, now, user_id), - ) - - # Write transaction log - cursor.execute( - """ - INSERT INTO point_transactions - (user_id, amount, balance_after, source, reason, created_at) - VALUES (?, ?, ?, ?, ?, ?) - """, - (user_id, amount, new_balance, source, reason, now), - ) - - conn.commit() - return True, new_balance - except Exception as e: - conn.rollback() - logger.error(f"add_points failed for {user_id}: {e}") - return False, 0 - finally: - conn.close() - - return await asyncio.to_thread(_add) + """Add points to user account. + + Returns: (success, new_balance) + """ + # 保留原 PointsAPI 的入参失败语义;限额校验由 xapi 承担。 + if not isinstance(amount, int) or amount <= 0: + return False, 0 + if not user_id or not source: + return False, 0 + + data = await self._request( + "POST", + "add", + payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason}, + ) + return self._change_result(data) async def spend_points( self, user_id: str, amount: int, source: str, reason: str = None ) -> Tuple[bool, int]: - """Spend points from user account. - - Returns: (success, new_balance) - """ - # Parameter validation - if not isinstance(amount, int) or amount <= 0: - return False, 0 - if not user_id or not source: - return False, 0 - - # Operation limit validation - if self.config.POINTS_MAX_PER_OPERATION > 0: - if amount > self.config.POINTS_MAX_PER_OPERATION: - return False, 0 - - def _spend(): - with self._lock: - conn = self.db.get_connection() - try: - cursor = conn.cursor() - # Ensure user exists - self.db.ensure_user_exists(user_id, conn) - - # Get current balance - cursor.execute( - "SELECT points FROM user_points WHERE user_id = ?", - (user_id,), - ) - row = cursor.fetchone() - current_balance = row["points"] if row else 0 - - # Check sufficient balance - if current_balance < amount: - conn.rollback() - return False, current_balance - - # Update balance and total_spent - new_balance = current_balance - amount - now = datetime.now().isoformat() - cursor.execute( - """ - UPDATE user_points - SET points = ?, total_spent = total_spent + ?, updated_at = ? - WHERE user_id = ? - """, - (new_balance, amount, now, user_id), - ) - - # Write transaction log (amount as negative) - cursor.execute( - """ - INSERT INTO point_transactions - (user_id, amount, balance_after, source, reason, created_at) - VALUES (?, ?, ?, ?, ?, ?) - """, - (user_id, -amount, new_balance, source, reason, now), - ) - - conn.commit() - return True, new_balance - except Exception as e: - conn.rollback() - logger.error(f"spend_points failed for {user_id}: {e}") - return False, 0 - finally: - conn.close() - - return await asyncio.to_thread(_spend) + """Spend points from user account. + + Returns: (success, new_balance) + """ + # 保留原 PointsAPI 的入参失败语义;余额与限额校验由 xapi 承担。 + if not isinstance(amount, int) or amount <= 0: + return False, 0 + if not user_id or not source: + return False, 0 + + data = await self._request( + "POST", + "spend", + payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason}, + ) + return self._change_result(data) async def set_points( self, user_id: str, amount: int, source: str, reason: str = None ) -> Tuple[bool, int]: - """Set user's points to exact amount. - - Returns: (success, new_balance) - """ - # Parameter validation - if not isinstance(amount, int) or amount < 0: - return False, 0 - if not user_id or not source: - return False, 0 - - def _set(): - with self._lock: - conn = self.db.get_connection() - try: - cursor = conn.cursor() - # Ensure user exists - self.db.ensure_user_exists(user_id, conn) - - # Get current balance - cursor.execute( - "SELECT points, total_earned FROM user_points WHERE user_id = ?", - (user_id,), - ) - row = cursor.fetchone() - current_balance = row["points"] if row else 0 - current_earned = row["total_earned"] if row else 0 - - # If new value equals old value, return without writing - if current_balance == amount: - conn.rollback() - return True, amount - - # Calculate difference for total_earned (only positive diff) - diff = amount - current_balance - earned_diff = max(0, diff) - - # Update balance and total_earned - now = datetime.now().isoformat() - cursor.execute( - """ - UPDATE user_points - SET points = ?, total_earned = total_earned + ?, updated_at = ? - WHERE user_id = ? - """, - (amount, earned_diff, now, user_id), - ) - - # Write transaction log - cursor.execute( - """ - INSERT INTO point_transactions - (user_id, amount, balance_after, source, reason, created_at) - VALUES (?, ?, ?, ?, ?, ?) - """, - (user_id, diff, amount, source, reason, now), - ) - - conn.commit() - return True, amount - except Exception as e: - conn.rollback() - logger.error(f"set_points failed for {user_id}: {e}") - return False, 0 - finally: - conn.close() - - return await asyncio.to_thread(_set) + """Set user's points to exact amount. + + Returns: (success, new_balance) + """ + # set 仍保持原 PointsAPI 行为:只校验非负,不做余额上限判断。 + if not isinstance(amount, int) or amount < 0: + return False, 0 + if not user_id or not source: + return False, 0 + + data = await self._request( + "POST", + "set", + payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason}, + ) + return self._change_result(data) async def get_transactions( self, user_id: str, limit: int = 20, offset: int = 0 ) -> List[Dict[str, Any]]: - """Get transaction history for a user. - - Returns: List of transaction dicts - """ - # Normalize parameters - limit = max(1, min(100, limit)) - offset = max(0, offset) - - def _get(): - conn = self.db.get_connection() - try: - cursor = conn.cursor() - cursor.execute( - """ - SELECT id, user_id, amount, balance_after, source, reason, created_at - FROM point_transactions - WHERE user_id = ? - ORDER BY id DESC - LIMIT ? OFFSET ? - """, - (user_id, limit, offset), - ) - rows = cursor.fetchall() - return [dict(row) for row in rows] - except Exception as e: - logger.error(f"get_transactions failed for {user_id}: {e}") - return [] - finally: - conn.close() - - return await asyncio.to_thread(_get) + """Get transaction history for a user. + + Returns: List of transaction dicts + """ + limit = max(1, min(100, limit)) + offset = max(0, offset) + data = await self._request( + "GET", + "transactions", + params={"user_id": user_id, "limit": limit, "offset": offset}, + ) + if data is None: + return [] + items = data.get("items", []) + return items if isinstance(items, list) else [] async def get_ranking( self, limit: int = 10, order_by: str = "points" ) -> List[Dict[str, Any]]: - """Get points ranking. - - Returns: List of ranking dicts with rank field - """ - # Normalize parameters - limit = max(1, min(100, limit)) - if order_by not in ("points", "total_earned"): - order_by = "points" - - def _get(): - conn = self.db.get_connection() - try: - cursor = conn.cursor() - order_column = "points" if order_by == "points" else "total_earned" - query = f""" - SELECT - RANK() OVER (ORDER BY {order_column} DESC) as rank, - user_id, - points, - total_earned, - total_spent - FROM user_points - ORDER BY {order_column} DESC, user_id ASC - LIMIT ? - """ - cursor.execute(query, (limit,)) - rows = cursor.fetchall() - return [dict(row) for row in rows] - except Exception as e: - logger.error(f"get_ranking failed: {e}") - return [] - finally: - conn.close() - - return await asyncio.to_thread(_get) + """Get points ranking. + + Returns: List of ranking dicts with rank field + """ + limit = max(1, min(100, limit)) + if order_by not in ("points", "total_earned"): + order_by = "points" + data = await self._request( + "GET", + "ranking", + params={"limit": limit, "order_by": order_by}, + ) + if data is None: + return [] + items = data.get("items", []) + return items if isinstance(items, list) else [] + + def _change_result(self, data: Optional[Dict[str, Any]]) -> Tuple[bool, int]: + """解析 add/spend/set 响应并维持旧失败返回值。""" + + if data is None: + return False, 0 + return bool(data.get("success")), int(data.get("balance", 0) or 0) diff --git a/danding_bot/plugins/danding_points/config.py b/danding_bot/plugins/danding_points/config.py index 4851aad..8e2cc76 100644 --- a/danding_bot/plugins/danding_points/config.py +++ b/danding_bot/plugins/danding_points/config.py @@ -1,6 +1,5 @@ from pydantic import field_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from pathlib import Path import os @@ -11,22 +10,24 @@ class Config(BaseSettings): extra="ignore", ) - # 数据库配置 - POINTS_DB_FILE: str = os.getenv("DANDING_POINTS_DB_FILE", "data/danding_points/points.db") - POINTS_MAX_BALANCE: int = int(os.getenv("DANDING_POINTS_MAX_BALANCE", "0")) - POINTS_MAX_PER_OPERATION: int = int(os.getenv("DANDING_POINTS_MAX_PER_OPERATION", "0")) - POINTS_LOG_RETENTION_DAYS: int = int(os.getenv("DANDING_POINTS_LOG_RETENTION_DAYS", "365")) + # xapi /bot/points 运行时 API 配置 + POINTS_API_HOST: str = os.getenv("DANDING_POINTS_API_HOST", "https://api.danding.vip/bot/points") + BOT_USER: str = os.getenv("DANDING_BOT_USER", "1424473282") + BOT_TOKEN: str = os.getenv( + "DANDING_BOT_TOKEN", + os.getenv("DANDING_API_TOKEN", os.getenv("BOT_TOKEN", "")), + ) - @field_validator("POINTS_MAX_BALANCE", "POINTS_MAX_PER_OPERATION", "POINTS_LOG_RETENTION_DAYS") + @field_validator("POINTS_API_HOST") @classmethod - def validate_non_negative(cls, v): - if v < 0: - raise ValueError("Value must be non-negative") - return v + def validate_api_host(cls, value): + if not value: + raise ValueError("POINTS_API_HOST cannot be empty") + return value.rstrip("/") - @field_validator("POINTS_DB_FILE") + @field_validator("BOT_USER") @classmethod - def validate_db_path(cls, v): - if not v: - raise ValueError("Database file path cannot be empty") - return v + def validate_bot_user(cls, value): + if not value: + raise ValueError("BOT_USER cannot be empty") + return value diff --git a/danding_bot/plugins/danding_points/database.py b/danding_bot/plugins/danding_points/database.py deleted file mode 100644 index 6c14093..0000000 --- a/danding_bot/plugins/danding_points/database.py +++ /dev/null @@ -1,106 +0,0 @@ -import sqlite3 -import os -from datetime import datetime -from typing import Optional, List, Dict, Any -from .config import Config - - -class PointsDatabase: - """SQLite database handler for points system.""" - - def __init__(self, config: Config): - self.config = config - self.db_path = config.POINTS_DB_FILE - self._ensure_db_dir() - self._init_db() - - def _ensure_db_dir(self): - """Create database directory if it doesn't exist.""" - db_dir = os.path.dirname(self.db_path) - if db_dir: - os.makedirs(db_dir, exist_ok=True) - - def _init_db(self): - """Initialize database tables.""" - conn = sqlite3.connect(self.db_path, timeout=5.0) - cursor = conn.cursor() - - # Create user_points table - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS user_points ( - user_id TEXT PRIMARY KEY, - points INTEGER NOT NULL DEFAULT 0 CHECK(points >= 0), - total_earned INTEGER NOT NULL DEFAULT 0, - total_spent INTEGER NOT NULL DEFAULT 0, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL - ) - """ - ) - - # Create point_transactions table - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS point_transactions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - amount INTEGER NOT NULL, - balance_after INTEGER NOT NULL, - source TEXT NOT NULL, - reason TEXT, - created_at TEXT NOT NULL - ) - """ - ) - - # Create indexes - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_transactions_user_id ON point_transactions(user_id)" - ) - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_transactions_source ON point_transactions(source)" - ) - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_transactions_created_at ON point_transactions(created_at)" - ) - - conn.commit() - conn.close() - - def get_connection(self) -> sqlite3.Connection: - """Get a database connection.""" - conn = sqlite3.connect(self.db_path, timeout=5.0) - conn.row_factory = sqlite3.Row - return conn - - def get_user_balance(self, user_id: str) -> int: - """Get user's current points balance.""" - conn = self.get_connection() - try: - cursor = conn.cursor() - cursor.execute("SELECT points FROM user_points WHERE user_id = ?", (user_id,)) - row = cursor.fetchone() - return row["points"] if row else 0 - finally: - conn.close() - - def ensure_user_exists(self, user_id: str, conn=None) -> None: - """Create user account if it doesn't exist. Reuses provided conn if given.""" - should_close = False - if conn is None: - conn = self.get_connection() - should_close = True - cursor = conn.cursor() - now = datetime.now().isoformat() - cursor.execute( - """ - INSERT OR IGNORE INTO user_points - (user_id, points, total_earned, total_spent, created_at, updated_at) - VALUES (?, 0, 0, 0, ?, ?) - """, - (user_id, now, now), - ) - if should_close: - conn.commit() - conn.close() diff --git a/danding_bot/plugins/group_horse_racing/commands/shared.py b/danding_bot/plugins/group_horse_racing/commands/shared.py index e3adef1..f7c707f 100644 --- a/danding_bot/plugins/group_horse_racing/commands/shared.py +++ b/danding_bot/plugins/group_horse_racing/commands/shared.py @@ -1,5 +1,7 @@ -import logging -import asyncio +import logging +import asyncio +from datetime import datetime +from uuid import uuid4 from nonebot.adapters.onebot.v11 import Bot, Event, GroupMessageEvent, Message, MessageSegment, PrivateMessageEvent @@ -27,8 +29,8 @@ async def _get_user_name(bot: Bot, scope: str, user_id: str) -> str: group_id = int(scope.split("_", 1)[1]) info = await bot.get_group_member_info(group_id=group_id, user_id=int(user_id)) return info.get("card") or info.get("nickname") or user_id - except Exception: - pass + except Exception as exc: + logger.debug("获取赛马用户昵称失败 scope=%s user_id=%s error=%s", scope, user_id, exc) return user_id @@ -142,13 +144,14 @@ async def _is_admin_or_owner(bot: Bot, event: Event) -> bool: user_id=int(event.get_user_id()), ) return member_info.get("role", "") in ("admin", "owner") - except Exception: - return False + except Exception as exc: + logger.debug("检查赛马管理员权限失败 user_id=%s error=%s", getattr(event, "user_id", ""), exc) + return False -def _build_point_changes(room: Room, odds: dict[str, float]) -> tuple[dict[str, int], dict[str, str]]: - point_changes: dict[str, int] = {} +def _build_point_changes(room: Room, odds: dict[str, float]) -> tuple[dict[str, int], dict[str, str]]: + point_changes: dict[str, int] = {} for horse in room.horses.values(): point_changes[horse.owner_id] = point_changes.get(horse.owner_id, 0) + config.PARTICIPANT_REWARD @@ -168,8 +171,23 @@ def _build_point_changes(room: Room, odds: dict[str, float]) -> tuple[dict[str, point_summaries = { user_id: _describe_points_delta(delta) for user_id, delta in point_changes.items() - } - return point_changes, point_summaries + } + return point_changes, point_summaries + + +def _build_participants_snapshot(room: Room) -> list[str]: + """生成赛果归档所需的参赛马名快照。""" + + return [horse.name for horse in _get_horses_in_order(room)] + + +def _build_bet_distribution(room: Room) -> dict[str, int]: + """按马名汇总下注分布,供 xapi 原样归档。""" + + distribution = {horse.name: 0 for horse in _get_horses_in_order(room)} + for bet in room.bets: + distribution[bet.horse_name] = distribution.get(bet.horse_name, 0) + bet.amount + return distribution async def _send_to_scope(bot: Bot, scope: str, message: str, message_type: str = "race_update", critical: bool = False): @@ -243,10 +261,10 @@ async def settle_race(room: Room) -> tuple[RaceResult, dict[str, float]] | None: for bet in room.bets: user_ids.add(bet.user_id) - # Record pre-balances - pre_balances: dict[str, int] = {} - for uid in user_ids: - pre_balances[uid] = points_service.get_balance(uid) + # Record pre-balances + pre_balances: dict[str, int] = {} + for uid in user_ids: + pre_balances[uid] = await points_service.get_balance(uid) # 1. Reward all participants for horse in room.horses.values(): @@ -271,10 +289,10 @@ async def settle_race(room: Room) -> tuple[RaceResult, dict[str, float]] | None: except Exception as e: logger.warning(f"payout_winnings failed for {bet.user_id}: {e}") - # Record post-balances and compute deltas - post_balances: dict[str, int] = {} - for uid in user_ids: - post_balances[uid] = points_service.get_balance(uid) + # Record post-balances and compute deltas + post_balances: dict[str, int] = {} + for uid in user_ids: + post_balances[uid] = await points_service.get_balance(uid) point_changes: dict[str, int] = {} for uid in user_ids: @@ -285,13 +303,20 @@ async def settle_race(room: Room) -> tuple[RaceResult, dict[str, float]] | None: # Build human-readable summaries _, point_change_summaries = _build_point_changes(room, odds) - result = RaceResult( - champion_name=room.champion_name, - champion_owner=champion.owner_id, - point_changes=point_changes, - point_change_summaries=point_change_summaries, - ) - return result, odds + result = RaceResult( + race_id=str(uuid4()), + scope=room.scope, + champion_name=room.champion_name, + champion_owner=champion.owner_id, + participants=_build_participants_snapshot(room), + bet_distribution=_build_bet_distribution(room), + duration_ticks=room.tick_count, + completed_at=datetime.now(), + point_changes=point_changes, + point_change_summaries=point_change_summaries, + odds_snapshot=odds, + ) + return result, odds async def run_race_with_settlement(bot: Bot, room: Room, scope: str): @@ -343,12 +368,15 @@ async def run_race_with_settlement(bot: Bot, room: Room, scope: str): if result: result_lines.extend(await _format_point_change_lines(room, result.point_changes, result.point_change_summaries, name_map)) - await message_service.recall_previous_of_type(bot, scope, "race_update") - await _send_to_scope(bot, scope, "\n".join(result_lines), "race_result") - - race_engine.stop_race(scope) - room_store.delete_room(scope) - message_service.clear_pending_recalls(scope) + await message_service.recall_previous_of_type(bot, scope, "race_update") + await _send_to_scope(bot, scope, "\n".join(result_lines), "race_result") + + if result: + await room_store.save_race_result(result) + + race_engine.stop_race(scope) + room_store.delete_room(scope) + message_service.clear_pending_recalls(scope) # Import and re-export access functions from access.py (canonical source) diff --git a/danding_bot/plugins/group_horse_racing/config.py b/danding_bot/plugins/group_horse_racing/config.py index 697d54c..aca28e3 100644 --- a/danding_bot/plugins/group_horse_racing/config.py +++ b/danding_bot/plugins/group_horse_racing/config.py @@ -1,6 +1,7 @@ -from pydantic import Field, field_validator -from pydantic_settings import BaseSettings, SettingsConfigDict -import json +from pydantic import Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict +import json +import os class Config(BaseSettings): @@ -43,8 +44,34 @@ class Config(BaseSettings): } ) - # 数据库配置 - RACE_DB_FILE: str = "data/group_horse_racing/race.db" + # 数据库配置 + RACE_DB_FILE: str = "data/group_horse_racing/race.db" + + # xapi /bot/race 运行时 API 配置 + RACE_API_HOST: str = os.getenv("DANDING_RACE_API_HOST", "https://api.danding.vip/bot/race") + BOT_USER: str = os.getenv("DANDING_BOT_USER", "1424473282") + BOT_TOKEN: str = os.getenv( + "DANDING_BOT_TOKEN", + os.getenv("DANDING_API_TOKEN", os.getenv("BOT_TOKEN", "")), + ) + + @field_validator("RACE_API_HOST") + @classmethod + def validate_race_api_host(cls, value): + """规范化 xapi 赛马运行时 API 地址。""" + + if not value: + raise ValueError("RACE_API_HOST cannot be empty") + return value.rstrip("/") + + @field_validator("BOT_USER") + @classmethod + def validate_bot_user(cls, value): + """Bot 鉴权用户不能为空。""" + + if not value: + raise ValueError("BOT_USER cannot be empty") + return value @field_validator("TESTERS", "TEST_GROUPS", "ALLOWED_GROUPS", mode="before") @classmethod diff --git a/danding_bot/plugins/group_horse_racing/models.py b/danding_bot/plugins/group_horse_racing/models.py index b929fd5..d8125d6 100644 --- a/danding_bot/plugins/group_horse_racing/models.py +++ b/danding_bot/plugins/group_horse_racing/models.py @@ -54,6 +54,7 @@ class RaceResult: race_id: str = "" scope: str = "" participants: list[str] = field(default_factory=list) - bet_distribution: dict[str, int] = field(default_factory=dict) - duration_ticks: int = 0 - completed_at: datetime = field(default_factory=datetime.now) + bet_distribution: dict[str, int] = field(default_factory=dict) + duration_ticks: int = 0 + completed_at: datetime = field(default_factory=datetime.now) + odds_snapshot: dict[str, float] = field(default_factory=dict) diff --git a/danding_bot/plugins/group_horse_racing/room_store.py b/danding_bot/plugins/group_horse_racing/room_store.py index 8a42a23..11928b7 100644 --- a/danding_bot/plugins/group_horse_racing/room_store.py +++ b/danding_bot/plugins/group_horse_racing/room_store.py @@ -1,15 +1,19 @@ -import asyncio -import aiosqlite -import json -from datetime import datetime -from pathlib import Path -from typing import Optional - -from .models import Room, RoomState, RaceResult -from .config import Config - - -class RoomStore: +import asyncio +import aiosqlite +import aiohttp +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Any, Optional + +from .models import Room, RoomState, RaceResult +from .config import Config + +logger = logging.getLogger(__name__) + + +class RoomStore: def __init__(self, config: Config): self.config = config self.rooms: dict[str, Room] = {} @@ -37,46 +41,7 @@ class RoomStore: ) """) - await db.execute(""" - CREATE TABLE IF NOT EXISTS race_history ( - race_id TEXT PRIMARY KEY, - scope TEXT NOT NULL, - champion_name TEXT NOT NULL, - champion_owner TEXT NOT NULL, - participants TEXT NOT NULL, - bet_distribution TEXT NOT NULL, - duration_ticks INTEGER NOT NULL, - completed_at TEXT NOT NULL, - point_changes TEXT DEFAULT '{}', - point_change_summaries TEXT DEFAULT '{}', - odds_snapshot TEXT DEFAULT '{}' - ) - """) - - await db.execute(""" - CREATE TABLE IF NOT EXISTS user_horse_names ( - user_id TEXT PRIMARY KEY, - horse_name TEXT NOT NULL - ) - """) - - # Add missing columns if they don't exist (for existing databases) - try: - await db.execute("SELECT point_changes FROM race_history LIMIT 1") - except aiosqlite.OperationalError: - await db.execute("ALTER TABLE race_history ADD COLUMN point_changes TEXT DEFAULT '{}'") - - try: - await db.execute("SELECT point_change_summaries FROM race_history LIMIT 1") - except aiosqlite.OperationalError: - await db.execute("ALTER TABLE race_history ADD COLUMN point_change_summaries TEXT DEFAULT '{}'") - - try: - await db.execute("SELECT odds_snapshot FROM race_history LIMIT 1") - except aiosqlite.OperationalError: - await db.execute("ALTER TABLE race_history ADD COLUMN odds_snapshot TEXT DEFAULT '{}'") - - await db.commit() + await db.commit() self._initialized = True @@ -95,12 +60,69 @@ class RoomStore: async def close(self): """Close database connection on shutdown.""" - if self._db is not None: - await self._db.close() - self._db = None - - async def load_rooms(self): - """Restore active rooms from DB snapshots on startup.""" + if self._db is not None: + await self._db.close() + self._db = None + + def _url(self, path: str) -> str: + """拼接 /bot/race 端点地址。""" + + return f"{self.config.RACE_API_HOST}/{path.lstrip('/')}" + + def _auth(self) -> dict[str, str]: + """生成 xapi Bot 鉴权参数。""" + + return { + "user": self.config.BOT_USER, + "token": self.config.BOT_TOKEN, + } + + async def _request( + self, + method: str, + path: str, + *, + payload: Optional[dict[str, Any]] = None, + params: Optional[dict[str, Any]] = None, + ) -> Optional[dict[str, Any]]: + """调用 xapi /bot/race,并只向上层暴露 data。""" + + request_url = self._url(path) + timeout = aiohttp.ClientTimeout(total=10) + try: + async with aiohttp.ClientSession() as session: + if method == "GET": + request_params = {**self._auth(), **(params or {})} + async with session.get(request_url, params=request_params, timeout=timeout) as resp: + return await self._parse_response(resp, path) + request_payload = {**self._auth(), **(payload or {})} + if method == "PUT": + async with session.put(request_url, json=request_payload, timeout=timeout) as resp: + return await self._parse_response(resp, path) + async with session.post(request_url, json=request_payload, timeout=timeout) as resp: + return await self._parse_response(resp, path) + except aiohttp.ClientError as exc: + logger.error("race api request failed path=%s error=%s", path, exc) + return None + except asyncio.TimeoutError as exc: + logger.error("race api request timeout path=%s error=%s", path, exc) + return None + + async def _parse_response(self, resp: aiohttp.ClientResponse, path: str) -> Optional[dict[str, Any]]: + """解析 xapi 统一响应,失败时返回 None 维持旧调用方失败语义。""" + + if resp.status != 200: + logger.error("race api bad status path=%s status=%s", path, resp.status) + return None + body = await resp.json() + if body.get("code") != 200: + logger.error("race api fail path=%s code=%s message=%s", path, body.get("code"), body.get("message")) + return None + data = body.get("data") + return data if isinstance(data, dict) else None + + async def load_rooms(self): + """Restore active rooms from DB snapshots on startup.""" await self.ensure_initialized() db = await self._get_db() cursor = await db.execute( @@ -167,24 +189,23 @@ class RoomStore: if scope in self.rooms: del self.rooms[scope] - async def get_last_horse_name(self, user_id: str) -> Optional[str]: - await self.ensure_initialized() - db = await self._get_db() - cursor = await db.execute( - "SELECT horse_name FROM user_horse_names WHERE user_id = ?", - (user_id,) - ) - row = await cursor.fetchone() - return row[0] if row else None - - async def set_last_horse_name(self, user_id: str, horse_name: str): - await self.ensure_initialized() - db = await self._get_db() - await db.execute( - "INSERT OR REPLACE INTO user_horse_names (user_id, horse_name) VALUES (?, ?)", - (user_id, horse_name), - ) - await db.commit() + async def get_last_horse_name(self, user_id: str) -> Optional[str]: + """从 xapi 读取用户最后使用马名。""" + + data = await self._request("GET", "horse-name", params={"user_id": user_id}) + if data is None: + return None + horse_name = data.get("horse_name") + return str(horse_name) if horse_name else None + + async def set_last_horse_name(self, user_id: str, horse_name: str): + """将用户最后使用马名写入 xapi。""" + + await self._request( + "PUT", + "horse-name", + payload={"user_id": user_id, "horse_name": horse_name}, + ) async def _save_snapshot(self, room: Room): """Save room snapshot to database.""" @@ -226,31 +247,28 @@ class RoomStore: )) await db.commit() - async def save_race_result(self, result: RaceResult): - """Save race result to history.""" - await self.ensure_initialized() - - db = await self._get_db() - await db.execute(""" - INSERT INTO race_history - (race_id, scope, champion_name, champion_owner, participants, - bet_distribution, duration_ticks, completed_at, - point_changes, point_change_summaries, odds_snapshot) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - result.race_id, - result.scope, - result.champion_name, - result.champion_owner, - json.dumps(result.participants), - json.dumps(result.bet_distribution), - result.duration_ticks, - result.completed_at.isoformat(), - json.dumps(getattr(result, 'point_changes', {})), - json.dumps(getattr(result, 'point_change_summaries', {})), - json.dumps(getattr(result, 'odds_snapshot', {})), - )) - await db.commit() + async def save_race_result(self, result: RaceResult): + """将完整赛果写入 xapi。""" + + data = await self._request( + "POST", + "history", + payload={ + "race_id": result.race_id, + "scope": result.scope, + "champion_name": result.champion_name, + "champion_owner": result.champion_owner, + "participants": result.participants, + "bet_distribution": result.bet_distribution, + "duration_ticks": result.duration_ticks, + "completed_at": result.completed_at.isoformat(), + "point_changes": result.point_changes, + "point_change_summaries": result.point_change_summaries, + "odds_snapshot": result.odds_snapshot, + }, + ) + if data is None: + raise RuntimeError(f"赛马赛果写入 xapi 失败: race_id={result.race_id}") # Module-level singleton instance diff --git a/danding_bot/plugins/onmyoji_gacha/__init__.py b/danding_bot/plugins/onmyoji_gacha/__init__.py index 7d4cd04..f57b5bb 100644 --- a/danding_bot/plugins/onmyoji_gacha/__init__.py +++ b/danding_bot/plugins/onmyoji_gacha/__init__.py @@ -1,7 +1,7 @@ -import os -import logging -import random -from nonebot import on_command, on_startswith +import os +import logging +import random +from nonebot import get_driver, on_command, on_startswith from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, MessageEvent, Message from nonebot.adapters.onebot.v11.message import MessageSegment from nonebot.typing import T_State @@ -29,12 +29,22 @@ ACHIEVEMENT_COMMANDS = config.ACHIEVEMENT_COMMANDS INTRO_COMMANDS = config.INTRO_COMMANDS DAILY_LIMIT = config.DAILY_LIMIT -gacha_system = GachaSystem() -logger = logging.getLogger(__name__) +gacha_system = GachaSystem() +logger = logging.getLogger(__name__) SIGN_IN_MIN_POINTS = 1 SIGN_IN_MAX_POINTS = 100 SIGN_IN_SOURCE = "gacha_sign" -SIGN_IN_REASON = "抽卡签到" +SIGN_IN_REASON = "抽卡签到" + + +@get_driver().on_startup +async def load_gacha_shikigami_data() -> None: + """启动时从 xapi 拉取式神基础数据缓存。""" + + try: + await gacha_system.data_manager.refresh_shikigami_data() + except Exception: + logger.exception("启动拉取抽卡式神缓存失败") # 检查是否允许使用功能的规则 def check_permission() -> Rule: @@ -52,14 +62,11 @@ def check_permission() -> Rule: return Rule(_checker) -async def try_handle_daily_sign_in(matcher, user_id: str, user_name: str) -> None: - """处理抽卡成功后的每日签到,不影响主流程""" - try: - if gacha_system.data_manager.has_signed_in_today(user_id): - return - - points = random.randint(SIGN_IN_MIN_POINTS, SIGN_IN_MAX_POINTS) - success, new_balance = await points_api.add_points( +async def try_handle_daily_sign_in(matcher, user_id: str, user_name: str) -> None: + """处理抽卡成功后的每日签到,不影响主流程""" + try: + points = random.randint(SIGN_IN_MIN_POINTS, SIGN_IN_MAX_POINTS) + success, new_balance = await points_api.add_points( user_id, points, SIGN_IN_SOURCE, @@ -69,13 +76,23 @@ async def try_handle_daily_sign_in(matcher, user_id: str, user_name: str) -> Non logger.error("抽卡签到积分发放失败 user_id=%s points=%s", user_id, points) return - if not gacha_system.data_manager.record_sign_in(user_id, points): - logger.warning("抽卡签到落库冲突,积分已发放但签到记录重复 user_id=%s", user_id) - return - - await matcher.send(format_sign_in_message(user_id, user_name, points, new_balance)) - except Exception: - logger.exception("处理抽卡签到失败 user_id=%s", user_id) + if not await gacha_system.data_manager.record_sign_in(user_id, points): + logger.warning("抽卡签到落库冲突,积分已发放但签到记录重复 user_id=%s", user_id) + return + + await matcher.send(format_sign_in_message(user_id, user_name, points, new_balance)) + except Exception: + logger.exception("处理抽卡签到失败 user_id=%s", user_id) + + +async def claim_achievement_after_reward(user_id: str, achievement_id: str, reward_success: bool) -> None: + """成就奖励发放成功后标记 xapi reward_claimed。""" + + if not reward_success: + return + claimed = await gacha_system.data_manager.claim_achievement_reward(user_id, achievement_id) + if not claimed: + logger.warning("成就奖励已发放但 claim 标记失败 user_id=%s achievement_id=%s", user_id, achievement_id) # 注册抽卡命令,添加权限检查规则 gacha_matcher = on_command("抽卡", aliases=set(GACHA_COMMANDS), priority=10, rule=check_permission()) @@ -86,7 +103,7 @@ async def handle_gacha(bot: Bot, event: MessageEvent, state: T_State): user_name = event.sender.card if isinstance(event, GroupMessageEvent) else event.sender.nickname # 执行抽卡 - result = gacha_system.draw(user_id) + result = await gacha_system.draw(user_id) if not result["success"]: await gacha_matcher.finish(format_user_mention(user_id, user_name) + " ❌ " + result["message"]) @@ -127,8 +144,9 @@ async def handle_gacha(bot: Bot, event: MessageEvent, state: T_State): has_manual_rewards = False for achievement_id in unlocked_achievements: - # 尝试自动发放成就奖励 - auto_success, reward_msg = await process_achievement_reward(user_id, achievement_id) + # 尝试自动发放成就奖励 + auto_success, reward_msg = await process_achievement_reward(user_id, achievement_id) + await claim_achievement_after_reward(user_id, achievement_id, auto_success) # 检查是否是重复奖励 if "_repeat_" in achievement_id: @@ -161,7 +179,7 @@ async def handle_gacha(bot: Bot, event: MessageEvent, state: T_State): msg.append("💰 未自动发放的奖励请联系管理员\n") # 添加成就进度提示 - achievement_data = gacha_system.get_user_achievements(user_id) + achievement_data = await gacha_system.get_user_achievements(user_id) if achievement_data["success"]: progress = achievement_data["progress"] consecutive_days = progress.get("consecutive_days", 0) @@ -226,13 +244,13 @@ async def handle_gacha(bot: Bot, event: MessageEvent, state: T_State): await try_handle_daily_sign_in(gacha_matcher, user_id, user_name) return -async def notify_admin(bot: Bot, message: str): - """通知管理员""" - admin_id = 2185330092 - try: - await bot.send_private_msg(user_id=admin_id, message=message) - except Exception as e: - pass # 忽略通知失败的错误 +async def notify_admin(bot: Bot, message: str): + """通知管理员""" + admin_id = 2185330092 + try: + await bot.send_private_msg(user_id=admin_id, message=message) + except Exception as e: + logger.debug("通知管理员失败: %s", e) # 注册查询命令,添加权限检查规则 stats_matcher = on_command("我的抽卡", aliases=set(STATS_COMMANDS), priority=5, rule=check_permission()) @@ -252,7 +270,7 @@ async def handle_stats(bot: Bot, event: MessageEvent, state: T_State): user_name = event.sender.card if isinstance(event, GroupMessageEvent) else event.sender.nickname # 获取用户统计 - stats = gacha_system.get_user_stats(user_id) + stats = await gacha_system.get_user_stats(user_id) if not stats["success"]: await stats_matcher.finish(format_user_mention(user_id, user_name) + " " + stats["message"]) @@ -294,7 +312,7 @@ async def handle_triple_gacha(bot: Bot, event: MessageEvent, state: T_State): user_name = event.sender.card or event.sender.nickname or "未知用户" # 执行三连抽 - result = gacha_system.triple_draw(user_id) + result = await gacha_system.triple_draw(user_id) if not result["success"]: await triple_gacha_matcher.finish(f"❌ {result['message']}") @@ -338,8 +356,9 @@ async def handle_triple_gacha(bot: Bot, event: MessageEvent, state: T_State): has_manual_rewards = False for achievement_id in unlocked_achievements: - # 尝试自动发放成就奖励 - auto_success, reward_msg = await process_achievement_reward(user_id, achievement_id) + # 尝试自动发放成就奖励 + auto_success, reward_msg = await process_achievement_reward(user_id, achievement_id) + await claim_achievement_after_reward(user_id, achievement_id, auto_success) # 检查是否是重复奖励 if "_repeat_" in achievement_id: @@ -372,7 +391,7 @@ async def handle_triple_gacha(bot: Bot, event: MessageEvent, state: T_State): msg.append("💰 未自动发放的奖励请联系管理员\n") # 添加成就进度提示 - achievement_data = gacha_system.get_user_achievements(user_id) + achievement_data = await gacha_system.get_user_achievements(user_id) if achievement_data["success"]: progress = achievement_data["progress"] consecutive_days = progress.get("consecutive_days", 0) @@ -442,7 +461,7 @@ async def handle_achievement(bot: Bot, event: MessageEvent, state: T_State): user_name = event.sender.card or event.sender.nickname or "未知用户" # 获取用户成就信息 - result = gacha_system.get_user_achievements(user_id) + result = await gacha_system.get_user_achievements(user_id) if not result["success"]: await achievement_matcher.finish(f"❌ {result['message']}") @@ -561,7 +580,7 @@ async def handle_query(bot: Bot, event: MessageEvent, state: T_State): target_user_name = event.sender.card if isinstance(event, GroupMessageEvent) else event.sender.nickname # 获取用户统计 - stats = gacha_system.get_user_stats(target_user_id) + stats = await gacha_system.get_user_stats(target_user_id) # 构建响应消息 msg = Message() @@ -623,7 +642,7 @@ rank_matcher = on_startswith(("抽卡排行","抽卡榜"), priority=1, rule=chec @rank_matcher.handle() async def handle_rank(bot: Bot, event: MessageEvent, state: T_State): # 获取排行榜数据 - rank_data = gacha_system.get_rank_list() + rank_data = await gacha_system.get_rank_list() if not rank_data: await rank_matcher.finish("暂无抽卡排行榜数据") @@ -668,7 +687,7 @@ async def handle_rank(bot: Bot, event: MessageEvent, state: T_State): @daily_stats_matcher.handle() async def handle_daily_stats(bot: Bot, event: MessageEvent, state: T_State): """处理今日抽卡统计命令""" - result = gacha_system.get_daily_stats() + result = await gacha_system.get_daily_stats() if not result["success"]: await daily_stats_matcher.finish(f"❌ {result['message']}") diff --git a/danding_bot/plugins/onmyoji_gacha/config.py b/danding_bot/plugins/onmyoji_gacha/config.py index da487f0..a01fade 100644 --- a/danding_bot/plugins/onmyoji_gacha/config.py +++ b/danding_bot/plugins/onmyoji_gacha/config.py @@ -1,7 +1,7 @@ -from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import model_validator -import os -import logging +from pydantic import field_validator, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict +import os +import logging logger = logging.getLogger("onmyoji_gacha") @@ -115,13 +115,24 @@ class Config(BaseSettings): WEB_ADMIN_TOKEN: str = os.getenv("WEB_ADMIN_TOKEN", "onmyoji_admin_token_2024") WEB_ADMIN_PORT: int = int(os.getenv("WEB_ADMIN_PORT", "8080")) - # 蛋定服务器对接配置 - DD_API_HOST: str = "https://api.danding.vip/DD/" - BOT_TOKEN: str = os.getenv("ONMYOJI_BOT_TOKEN", os.getenv("BOT_TOKEN", "")) # 必须设置 - BOT_USER_ID: str = "1424473282" + # 蛋定服务器对接配置 + DD_API_HOST: str = "https://api.danding.vip/DD/" + GACHA_API_HOST: str = os.getenv("DANDING_GACHA_API_HOST", "https://api.danding.vip/bot/gacha") + BOT_TOKEN: str = os.getenv( + "DANDING_BOT_TOKEN", + os.getenv("ONMYOJI_BOT_TOKEN", os.getenv("DANDING_API_TOKEN", os.getenv("BOT_TOKEN", ""))), + ) + BOT_USER_ID: str = os.getenv("DANDING_BOT_USER", "1424473282") # 时区 - TIMEZONE: str = "Asia/Shanghai" + TIMEZONE: str = "Asia/Shanghai" + + @field_validator("GACHA_API_HOST") + @classmethod + def validate_gacha_api_host(cls, value): + if not value: + raise ValueError("GACHA_API_HOST cannot be empty") + return value.rstrip("/") @model_validator(mode="after") def _warn_default_token(self): diff --git a/danding_bot/plugins/onmyoji_gacha/data_manager.py b/danding_bot/plugins/onmyoji_gacha/data_manager.py index cfea234..8ddda01 100644 --- a/danding_bot/plugins/onmyoji_gacha/data_manager.py +++ b/danding_bot/plugins/onmyoji_gacha/data_manager.py @@ -1,615 +1,291 @@ -""" -阴阳师抽卡插件 - 数据管理模块 - -管理抽卡数据持久化,包括: -- SQLite数据库操作 -- 用户抽卡记录管理 -- 每日签到记录 -- 统计查询 - -TODO(代码评审 2026-05-03): 本模块承担了数据文件IO + 缓存 + 业务规则三重职责, -后续应拆分为: data_io(纯文件读写) / data_cache(内存缓存层) / data_rules(业务规则校验)。 -当前拆分风险较大(影响面广),暂维持现状。 - -TODO(第二轮评审 2026-05-03): 补充建议拆分方案: -- achievement_manager.py: 成就定义加载 + 进度计算 + 奖励发放 (~150行) -- record_manager.py: 记录归档 + 统计查询 + 每日数据 (~100行) -- data_manager.py: 核心用户数据IO + 缓存管理 (~359行) -拆分为独立PR,不阻塞当前修复。 -""" - -import os -import json -import sqlite3 -import datetime -from typing import Dict, List, Any, Optional -import logging -from pathlib import Path - -from .config import Config - -# 创建Config实例 -config = Config() - -class DataManager: - """抽卡数据管理器,封装所有数据库操作""" - def __init__(self): - # 确保目录存在 - os.makedirs(os.path.dirname(config.DB_FILE), exist_ok=True) - - # 初始化数据库 - self._init_db() - - # 加载式神数据 - self.shikigami_data = self._load_shikigami_data() - - def _init_db(self): - """初始化数据库""" - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - - # 创建式神表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS shikigami ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - rarity TEXT NOT NULL, - image_path TEXT NOT NULL - ) - """) - - # 创建每日抽卡记录表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS daily_draws ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - date TEXT NOT NULL, - user_id TEXT NOT NULL, - rarity TEXT NOT NULL, - shikigami_id INTEGER NOT NULL, - timestamp TEXT NOT NULL, - FOREIGN KEY (shikigami_id) REFERENCES shikigami(id) - ) - """) - - # 创建用户统计表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS user_stats ( - user_id TEXT PRIMARY KEY, - total_draws INTEGER DEFAULT 0, - R_count INTEGER DEFAULT 0, - SR_count INTEGER DEFAULT 0, - SSR_count INTEGER DEFAULT 0, - SP_count INTEGER DEFAULT 0 - ) - """) - - # 创建抽卡历史表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS draw_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - date TEXT NOT NULL, - rarity TEXT NOT NULL, - shikigami_id INTEGER NOT NULL, - FOREIGN KEY (user_id) REFERENCES user_stats(user_id), - FOREIGN KEY (shikigami_id) REFERENCES shikigami(id) - ) - """) - - # 创建成就表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS achievements ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - achievement_id TEXT NOT NULL, - unlocked_date TEXT NOT NULL, - reward_claimed INTEGER DEFAULT 0, - UNIQUE(user_id, achievement_id) - ) - """) - - # 创建用户成就进度表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS user_achievement_progress ( - user_id TEXT PRIMARY KEY, - consecutive_days INTEGER DEFAULT 0, - last_draw_date TEXT DEFAULT '', - no_ssr_streak INTEGER DEFAULT 0, - total_consecutive_days INTEGER DEFAULT 0 - ) - """) - - self._init_sign_in_table(cursor) - - conn.commit() - - def _init_sign_in_table(self, cursor: sqlite3.Cursor) -> None: # OK - """创建每日签到表""" - cursor.execute(""" - CREATE TABLE IF NOT EXISTS daily_sign_in ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - sign_date TEXT NOT NULL, - points_awarded INTEGER NOT NULL, - created_at TEXT NOT NULL, - UNIQUE(user_id, sign_date) - ) - """) - - def update_achievement_progress(self, user_id: str, rarity: str) -> List[str]: # type: ignore[return] - """更新用户成就进度,返回新解锁的成就列表""" - today = self.get_today_date() - unlocked_achievements = [] - - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - - # 获取或创建用户成就进度 - cursor.execute( - "SELECT * FROM user_achievement_progress WHERE user_id = ?", - (user_id,) - ) - progress = cursor.fetchone() - - if not progress: - cursor.execute( - "INSERT INTO user_achievement_progress (user_id, last_draw_date) VALUES (?, ?)", - (user_id, today) - ) - consecutive_days = 1 - no_ssr_streak = 1 if rarity not in ["SSR", "SP"] else 0 - total_consecutive_days = 1 - else: - last_draw_date = progress[2] - consecutive_days = progress[1] - no_ssr_streak = progress[3] - total_consecutive_days = progress[4] - - # 更新连续抽卡天数 - if last_draw_date != today: - # 检查是否是连续的一天 - last_date = datetime.datetime.strptime(last_draw_date, "%Y-%m-%d") - current_date = datetime.datetime.strptime(today, "%Y-%m-%d") - days_diff = (current_date - last_date).days - - if days_diff == 1: - consecutive_days += 1 - total_consecutive_days += 1 - elif days_diff > 1: - consecutive_days = 1 - total_consecutive_days += 1 - # days_diff == 0 表示今天已经抽过卡了,不更新连续天数 - - # 更新无SSR连击数 - if rarity in ["SSR", "SP"]: - no_ssr_streak = 0 - else: - no_ssr_streak += 1 - - # 更新进度 - cursor.execute(""" - INSERT OR REPLACE INTO user_achievement_progress - (user_id, consecutive_days, last_draw_date, no_ssr_streak, total_consecutive_days) - VALUES (?, ?, ?, ?, ?) - """, (user_id, consecutive_days, today, no_ssr_streak, total_consecutive_days)) - - # 检查是否解锁新成就 - for achievement_id, achievement_config in config.ACHIEVEMENTS.items(): - # 对于可重复获得的成就(勤勤恳恳系列),需要特殊处理 - if achievement_config.get("repeatable", False) and achievement_config["type"] == "consecutive_days": - # 检查连续抽卡成就的升级逻辑 - if consecutive_days >= achievement_config["threshold"]: - # 检查是否已经解锁过这个等级 - cursor.execute( - "SELECT id FROM achievements WHERE user_id = ? AND achievement_id = ?", - (user_id, achievement_id) - ) - if not cursor.fetchone(): - # 解锁新等级的成就 - cursor.execute(""" - INSERT INTO achievements (user_id, achievement_id, unlocked_date) - VALUES (?, ?, ?) - """, (user_id, achievement_id, today)) - unlocked_achievements.append(achievement_id) - - # 如果是最高等级(Ⅴ),检查是否需要给重复奖励 - elif achievement_config["level"] == 5 and consecutive_days >= 150: - # 每30天给一次重复奖励 - days_over_150 = consecutive_days - 150 - if days_over_150 > 0 and days_over_150 % 30 == 0: - # 检查这个重复奖励是否已经给过 - repeat_id = f"{achievement_id}_repeat_{days_over_150//30}" - cursor.execute( - "SELECT id FROM achievements WHERE user_id = ? AND achievement_id = ?", - (user_id, repeat_id) - ) - if not cursor.fetchone(): - cursor.execute(""" - INSERT INTO achievements (user_id, achievement_id, unlocked_date) - VALUES (?, ?, ?) - """, (user_id, repeat_id, today)) - unlocked_achievements.append(achievement_id) - else: - # 非重复成就的原有逻辑 - # 检查是否已经解锁 - cursor.execute( - "SELECT id FROM achievements WHERE user_id = ? AND achievement_id = ?", - (user_id, achievement_id) - ) - if cursor.fetchone(): - continue - - # 检查成就条件 - unlocked = False - if achievement_config["type"] == "consecutive_days": - if consecutive_days >= achievement_config["threshold"]: - unlocked = True - elif achievement_config["type"] == "no_ssr_streak": - if no_ssr_streak >= achievement_config["threshold"]: - unlocked = True - - if unlocked: - cursor.execute(""" - INSERT INTO achievements (user_id, achievement_id, unlocked_date) - VALUES (?, ?, ?) - """, (user_id, achievement_id, today)) - unlocked_achievements.append(achievement_id) - - conn.commit() - - return unlocked_achievements - - def get_user_achievements(self, user_id: str) -> Dict[str, Any]: - """获取用户成就信息""" - with sqlite3.connect(config.DB_FILE) as conn: - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - # 获取已解锁的成就 - cursor.execute( - "SELECT achievement_id, unlocked_date, reward_claimed FROM achievements WHERE user_id = ?", - (user_id,) - ) - unlocked = {row["achievement_id"]: { - "unlocked_date": row["unlocked_date"], - "reward_claimed": bool(row["reward_claimed"]) - } for row in cursor.fetchall()} - - # 获取进度 - cursor.execute( - "SELECT * FROM user_achievement_progress WHERE user_id = ?", - (user_id,) - ) - progress_row = cursor.fetchone() - - if not progress_row: - progress = { - "consecutive_days": 0, - "no_ssr_streak": 0, - "total_consecutive_days": 0 - } - else: - progress = { - "consecutive_days": progress_row["consecutive_days"], - "no_ssr_streak": progress_row["no_ssr_streak"], - "total_consecutive_days": progress_row["total_consecutive_days"] - } - - return { - "unlocked": unlocked, - "progress": progress - } - - def claim_achievement_reward(self, user_id: str, achievement_id: str) -> bool: - """领取成就奖励""" - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - - cursor.execute(""" - UPDATE achievements - SET reward_claimed = 1 - WHERE user_id = ? AND achievement_id = ? AND reward_claimed = 0 - """, (user_id, achievement_id)) - - conn.commit() - return cursor.rowcount > 0 - - - def _load_shikigami_data(self) -> Dict[str, List[Dict[str, str]]]: - """加载式神数据到数据库""" - result = {"R": [], "SR": [], "SSR": [], "SP": []} - rarity_dirs = { - "R": "r", - "SR": "sr", - "SSR": "ssr", - "SP": "sp" - } - - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - - # 清空现有式神数据 - cursor.execute("DELETE FROM shikigami") - - for rarity, dir_name in rarity_dirs.items(): - dir_path = os.path.join(config.SHIKIGAMI_IMG_DIR, dir_name) - if os.path.exists(dir_path): - for file_name in os.listdir(dir_path): - if file_name.endswith(('.png', '.jpg', '.jpeg')): - name = os.path.splitext(file_name)[0] - image_path = os.path.join(dir_path, file_name) - - # 插入式神数据 - cursor.execute( - "INSERT INTO shikigami (name, rarity, image_path) VALUES (?, ?, ?)", - (name, rarity, image_path) - ) - - result[rarity].append({ - "name": name, - "image_url": image_path - }) - - conn.commit() - - return result - - def get_today_date(self) -> str: - """获取当前日期字符串""" - return datetime.datetime.now().strftime("%Y-%m-%d") - - def has_signed_in_today(self, user_id: str) -> bool: - """检查用户今天是否已签到""" - today = self.get_today_date() - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - cursor.execute( - "SELECT 1 FROM daily_sign_in WHERE user_id = ? AND sign_date = ? LIMIT 1", - (user_id, today), - ) - return cursor.fetchone() is not None - - def record_sign_in(self, user_id: str, points_awarded: int) -> bool: - """记录每日签到,重复签到返回False""" - today = self.get_today_date() - created_at = datetime.datetime.now().isoformat() - try: - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - cursor.execute(""" - INSERT INTO daily_sign_in (user_id, sign_date, points_awarded, created_at) - VALUES (?, ?, ?, ?) - """, (user_id, today, points_awarded, created_at)) - conn.commit() - return True - except sqlite3.IntegrityError: - return False - - def get_current_time(self) -> str: - """获取当前时间字符串""" - return datetime.datetime.now().strftime("%H:%M:%S") - - def get_daily_draws(self) -> Dict[str, Dict[str, List[Dict[str, str]]]]: - """获取每日抽卡记录""" - result = {} - today = self.get_today_date() - - with sqlite3.connect(config.DB_FILE) as conn: - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - # 先查询今日的抽卡记录 - cursor.execute(""" - SELECT date, user_id, rarity, shikigami_id, timestamp - FROM daily_draws - WHERE date = ? - ORDER BY timestamp - """, (today,)) - - rows = cursor.fetchall() - - # 获取所有涉及的式神ID - shikigami_ids = list(set(row["shikigami_id"] for row in rows)) - - # 查询式神信息 - shikigami_info = {} - if shikigami_ids: - placeholders = ','.join('?' * len(shikigami_ids)) - cursor.execute(f""" - SELECT id, name, rarity - FROM shikigami - WHERE id IN ({placeholders}) - """, shikigami_ids) - - for shikigami_row in cursor.fetchall(): - shikigami_info[shikigami_row["id"]] = { - "name": shikigami_row["name"], - "rarity": shikigami_row["rarity"] - } - - # 构建结果 - for row in rows: - date = row["date"] - user_id = row["user_id"] - shikigami_id = row["shikigami_id"] - - if date not in result: - result[date] = {} - - if user_id not in result[date]: - result[date][user_id] = [] - - # 如果找不到式神信息,使用daily_draws表中的稀有度和默认名称 - if shikigami_id in shikigami_info: - name = shikigami_info[shikigami_id]["name"] - rarity = shikigami_info[shikigami_id]["rarity"] - else: - name = f"式神{shikigami_id}" - rarity = row["rarity"] - - result[date][user_id].append({ - "rarity": rarity, - "name": name, - "timestamp": row["timestamp"] - }) - - return result - - def save_daily_draws(self, data: Dict[str, Dict[str, List[Dict[str, str]]]]): - """保存每日抽卡记录""" - # SQLite实现中此方法为空,因为记录时直接插入数据库 - pass - - def get_user_stats(self) -> Dict[str, Dict[str, Any]]: - """获取用户统计数据""" - result = {} - - with sqlite3.connect(config.DB_FILE) as conn: - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - - # 获取基础统计 - cursor.execute("SELECT * FROM user_stats") - user_stats = cursor.fetchall() - - for stat in user_stats: - user_id = stat["user_id"] - result[user_id] = { - "total_draws": stat["total_draws"], - "R_count": stat["R_count"], - "SR_count": stat["SR_count"], - "SSR_count": stat["SSR_count"], - "SP_count": stat["SP_count"], - "draw_history": [] - } - - # 获取抽卡历史 - cursor.execute(""" - SELECT draw_history.date, draw_history.rarity, shikigami.name - FROM draw_history - JOIN shikigami ON draw_history.shikigami_id = shikigami.id - WHERE draw_history.user_id = ? - ORDER BY draw_history.date DESC - LIMIT 100 - """, (user_id,)) - - history = cursor.fetchall() - result[user_id]["draw_history"] = [ - { - "date": row["date"], - "rarity": row["rarity"], - "name": row["name"] - } for row in history - ] - - return result - - def save_user_stats(self, data: Dict[str, Dict[str, Any]]): - """保存用户统计数据""" - # SQLite实现中此方法为空,因为统计时直接更新数据库 - pass - - def check_daily_limit(self, user_id: str) -> bool: - """检查用户是否达到每日抽卡限制""" - today = self.get_today_date() - - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - - cursor.execute(""" - SELECT COUNT(*) - FROM daily_draws - WHERE date = ? AND user_id = ? - """, (today, user_id)) - - count = cursor.fetchone()[0] - - return count < config.DAILY_LIMIT - - def get_draws_left(self, user_id: str) -> int: - """获取用户今日剩余抽卡次数""" - today = self.get_today_date() - - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - - cursor.execute(""" - SELECT COUNT(*) - FROM daily_draws - WHERE date = ? AND user_id = ? - """, (today, user_id)) - - count = cursor.fetchone()[0] - - return max(0, config.DAILY_LIMIT - count) - - def record_draw(self, user_id: str, rarity: str, shikigami_name: str) -> List[str]: - """记录一次抽卡,返回新解锁的成就列表""" - today = self.get_today_date() - current_time = self.get_current_time() - - with sqlite3.connect(config.DB_FILE) as conn: - cursor = conn.cursor() - - # 获取式神ID - cursor.execute( - "SELECT id FROM shikigami WHERE name = ? AND rarity = ?", - (shikigami_name, rarity) - ) - shikigami_id = cursor.fetchone() - - if not shikigami_id: - logging.error(f"找不到式神: {shikigami_name} ({rarity})") - return [] - - shikigami_id = shikigami_id[0] - - # 记录每日抽卡 - cursor.execute(""" - INSERT INTO daily_draws (date, user_id, rarity, shikigami_id, timestamp) - VALUES (?, ?, ?, ?, ?) - """, (today, user_id, rarity, shikigami_id, current_time)) - - # 更新用户统计 - cursor.execute(""" - INSERT OR IGNORE INTO user_stats (user_id) VALUES (?) - """, (user_id,)) - - cursor.execute(""" - UPDATE user_stats - SET total_draws = total_draws + 1, - R_count = R_count + ?, - SR_count = SR_count + ?, - SSR_count = SSR_count + ?, - SP_count = SP_count + ? - WHERE user_id = ? - """, ( - 1 if rarity == "R" else 0, - 1 if rarity == "SR" else 0, - 1 if rarity == "SSR" else 0, - 1 if rarity == "SP" else 0, - user_id - )) - - # 添加抽卡历史 - cursor.execute(""" - INSERT INTO draw_history (user_id, date, rarity, shikigami_id) - VALUES (?, ?, ?, ?) - """, (user_id, today, rarity, shikigami_id)) - - # 保持历史记录不超过100条 - cursor.execute(""" - DELETE FROM draw_history - WHERE user_id = ? AND id NOT IN ( - SELECT id FROM draw_history - WHERE user_id = ? - ORDER BY date DESC - LIMIT 100 - ) - """, (user_id, user_id)) - - conn.commit() - - # 更新成就进度 - unlocked_achievements = self.update_achievement_progress(user_id, rarity) - return unlocked_achievements \ No newline at end of file +""" +阴阳师抽卡插件 - xapi 数据管理模块。 + +本模块只负责调用 xapi /bot/gacha 运行时 API。抽卡概率、奖励发放和 QQ 消息编排 +仍由 nonebot 插件本地负责。 +""" + +from __future__ import annotations + +import asyncio +import datetime +import logging +from typing import Any, Dict, List, Optional + +import aiohttp + +from .config import Config + +logger = logging.getLogger(__name__) +config = Config() + + +class DataManager: + """抽卡数据管理器,封装 /bot/gacha HTTP 调用。""" + + def __init__(self): + self.shikigami_data: Dict[str, List[Dict[str, Any]]] = {"R": [], "SR": [], "SSR": [], "SP": []} + + def _url(self, path: str) -> str: + """拼接 /bot/gacha 端点地址。""" + + return f"{config.GACHA_API_HOST}/{path.lstrip('/')}" + + def _auth(self) -> Dict[str, str]: + """生成 xapi Bot 鉴权参数。""" + + return { + "user": config.BOT_USER_ID, + "token": config.BOT_TOKEN, + } + + async def _request( + self, + method: str, + path: str, + *, + payload: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + """调用 xapi /bot/gacha,并只向上层暴露 data。""" + + request_url = self._url(path) + timeout = aiohttp.ClientTimeout(total=10) + try: + async with aiohttp.ClientSession() as session: + if method == "GET": + request_params = {**self._auth(), **(params or {})} + async with session.get(request_url, params=request_params, timeout=timeout) as resp: + return await self._parse_response(resp, path) + request_payload = {**self._auth(), **(payload or {})} + async with session.post(request_url, json=request_payload, timeout=timeout) as resp: + return await self._parse_response(resp, path) + except aiohttp.ClientError as exc: + logger.error("gacha api request failed path=%s error=%s", path, exc) + return None + except asyncio.TimeoutError as exc: + logger.error("gacha api request timeout path=%s error=%s", path, exc) + return None + + async def _parse_response(self, resp: aiohttp.ClientResponse, path: str) -> Optional[Dict[str, Any]]: + """解析 xapi 统一响应,失败时返回 None 维持旧调用方失败语义。""" + + if resp.status != 200: + logger.error("gacha api bad status path=%s status=%s", path, resp.status) + return None + body = await resp.json() + if body.get("code") != 200: + logger.error("gacha api fail path=%s code=%s message=%s", path, body.get("code"), body.get("message")) + return None + data = body.get("data") + return data if isinstance(data, dict) else None + + async def refresh_shikigami_data(self) -> Dict[str, List[Dict[str, Any]]]: + """从 xapi 拉取式神基础数据并按稀有度缓存。""" + + data = await self._request("GET", "shikigami") + items = data.get("items", []) if data else [] + grouped: Dict[str, List[Dict[str, Any]]] = {"R": [], "SR": [], "SSR": [], "SP": []} + for item in items: + rarity = item.get("rarity") + if rarity not in grouped: + continue + image_path = item.get("image_path") or item.get("image_url") or "" + grouped[rarity].append( + { + "id": item.get("id"), + "name": item.get("name"), + "rarity": rarity, + "image_path": image_path, + "image_url": image_path, + } + ) + self.shikigami_data = grouped + return self.shikigami_data + + async def ensure_shikigami_data(self) -> Dict[str, List[Dict[str, Any]]]: + """确保式神缓存已加载。""" + + if not any(self.shikigami_data.values()): + await self.refresh_shikigami_data() + return self.shikigami_data + + def get_today_date(self) -> str: + """获取当前日期字符串。""" + + return datetime.datetime.now().strftime("%Y-%m-%d") + + def get_current_time(self) -> str: + """获取当前时间字符串。""" + + return datetime.datetime.now().strftime("%H:%M:%S") + + def _find_shikigami(self, rarity: str, shikigami_name: str) -> Optional[Dict[str, Any]]: + """从本地缓存查找 xapi 托管式神。""" + + for item in self.shikigami_data.get(rarity, []): + if item.get("name") == shikigami_name: + return item + return None + + async def get_draws_left(self, user_id: str) -> int: + """获取用户今日剩余抽卡次数。""" + + data = await self._request("GET", "draws-left", params={"user_id": user_id}) + if data is None: + return 0 + return int(data.get("draws_left", 0) or 0) + + async def check_daily_limit(self, user_id: str) -> bool: + """检查用户是否还有抽卡次数。""" + + return await self.get_draws_left(user_id) > 0 + + async def record_draw_result(self, user_id: str, rarity: str, shikigami: Dict[str, Any]) -> Dict[str, Any]: + """写入一次抽卡并返回 xapi 原始业务结果。""" + + data = await self._request( + "POST", + "draw", + payload={ + "user_id": user_id, + "shikigami_id": int(shikigami["id"]), + "rarity": rarity, + "name": shikigami["name"], + }, + ) + if data is None: + return {"success": False, "message": "抽卡记录写入失败"} + return data + + async def record_triple_draw_result(self, user_id: str, draws: List[Dict[str, Any]]) -> Dict[str, Any]: + """写入三连抽并返回 xapi 原始业务结果。""" + + payload_draws = [ + { + "shikigami_id": int(item["id"]), + "rarity": item["rarity"], + "name": item["name"], + } + for item in draws + ] + data = await self._request("POST", "draw/triple", payload={"user_id": user_id, "draws": payload_draws}) + if data is None: + return {"success": False, "message": "三连抽记录写入失败"} + return data + + async def record_draw(self, user_id: str, rarity: str, shikigami_name: str) -> List[str]: + """记录一次抽卡,返回新解锁的成就列表。""" + + await self.ensure_shikigami_data() + shikigami = self._find_shikigami(rarity, shikigami_name) + if not shikigami: + logger.error("找不到式神: %s (%s)", shikigami_name, rarity) + return [] + result = await self.record_draw_result(user_id, rarity, shikigami) + if not result.get("success"): + logger.error("抽卡记录写入失败 user_id=%s message=%s", user_id, result.get("message")) + return [] + return result.get("unlocked_achievements", []) + + async def record_sign_in(self, user_id: str, points_awarded: int) -> bool: + """记录每日签到,重复签到返回 False。""" + + data = await self._request( + "POST", + "sign-in", + payload={"user_id": user_id, "points_awarded": points_awarded}, + ) + if data is None: + return False + return bool(data.get("success")) and not bool(data.get("signed_already")) + + async def get_user_stats(self, user_id: str) -> Dict[str, Any]: + """获取用户抽卡统计。""" + + data = await self._request("GET", "user-stats", params={"user_id": user_id}) + return data or {"success": False, "message": "您还没有抽卡记录哦!"} + + async def get_daily_stats(self, date: Optional[str] = None) -> Dict[str, Any]: + """获取指定日期抽卡统计。""" + + params = {"date": date} if date else {} + data = await self._request("GET", "daily-stats", params=params) + return data or {"success": False, "message": "今日还没有人抽卡哦!"} + + async def get_rank(self, limit: int = 10) -> List[Dict[str, Any]]: + """获取抽卡排行榜。""" + + data = await self._request("GET", "rank", params={"limit": max(1, min(100, limit))}) + if data is None: + return [] + items = data.get("items", []) + return items if isinstance(items, list) else [] + + async def get_user_achievements(self, user_id: str) -> Dict[str, Any]: + """获取用户成就信息。""" + + data = await self._request("GET", f"achievements/{user_id}") + if data is None: + return { + "unlocked": {}, + "progress": { + "consecutive_days": 0, + "no_ssr_streak": 0, + "total_consecutive_days": 0, + }, + } + return { + "unlocked": data.get("achievements", {}), + "progress": data.get("progress", {}), + } + + async def claim_achievement_reward(self, user_id: str, achievement_id: str) -> bool: + """标记成就奖励已领取。""" + + data = await self._request("POST", f"achievements/{user_id}/claim", payload={"achievement_id": achievement_id}) + return bool(data and data.get("success")) + + async def get_daily_records(self, date: Optional[str] = None) -> Dict[str, Any]: + """获取每日详细抽卡记录。""" + + params = {"date": date} if date else {} + data = await self._request("GET", "records/daily", params=params) + return data or {"success": False, "date": date or self.get_today_date(), "records": [], "total_count": 0} + + async def get_daily_draws(self, date: Optional[str] = None) -> Dict[str, Dict[str, List[Dict[str, str]]]]: + """按旧结构返回每日抽卡记录。""" + + data = await self.get_daily_records(date) + result: Dict[str, Dict[str, List[Dict[str, str]]]] = {} + if not data.get("success"): + return result + target_date = data.get("date") or date or self.get_today_date() + result[target_date] = {} + for record in data.get("records", []): + user_id = record.get("user_id") + if not user_id: + continue + result[target_date].setdefault(user_id, []).append( + { + "rarity": record.get("rarity", ""), + "name": record.get("shikigami_name", ""), + "timestamp": record.get("draw_time", ""), + } + ) + return result + + async def has_signed_in_today(self, user_id: str) -> bool: + """保留旧方法名;当前无独立查询端点,签到去重由 xapi sign-in 写接口处理。""" + + return False + + def save_daily_draws(self, data: Dict[str, Dict[str, List[Dict[str, str]]]]) -> None: + """兼容旧空方法,运行时不写本地文件。""" + + return None + + def save_user_stats(self, data: Dict[str, Dict[str, Any]]) -> None: + """兼容旧空方法,运行时不写本地文件。""" + + return None diff --git a/danding_bot/plugins/onmyoji_gacha/gacha.py b/danding_bot/plugins/onmyoji_gacha/gacha.py index 912fa96..a99d709 100644 --- a/danding_bot/plugins/onmyoji_gacha/gacha.py +++ b/danding_bot/plugins/onmyoji_gacha/gacha.py @@ -24,44 +24,47 @@ class GachaSystem: def __init__(self): self.data_manager = data_manager - def draw(self, user_id: str) -> Dict[str, Any]: - """执行一次抽卡""" - # 检查抽卡限制 - if not self.data_manager.check_daily_limit(user_id): - draws_left = self.data_manager.get_draws_left(user_id) - return { - "success": False, - "message": f"您今日的抽卡次数已用完,每日限制{config.DAILY_LIMIT}次,明天再来吧!" - } + async def draw(self, user_id: str) -> Dict[str, Any]: + """执行一次抽卡""" + # 检查抽卡限制 + if not await self.data_manager.check_daily_limit(user_id): + draws_left = await self.data_manager.get_draws_left(user_id) + return { + "success": False, + "message": f"您今日的抽卡次数已用完,每日限制{config.DAILY_LIMIT}次,明天再来吧!" + } # 抽取稀有度(传递用户ID) rarity = self._draw_rarity(user_id) # 从该稀有度中抽取式神 - shikigami_data = self.data_manager.shikigami_data.get(rarity, []) - if not shikigami_data: - return { - "success": False, - "message": f"系统错误:{rarity}稀有度下没有可用式神" - } + await self.data_manager.ensure_shikigami_data() + shikigami_data = self.data_manager.shikigami_data.get(rarity, []) + if not shikigami_data: + return { + "success": False, + "message": f"系统错误:{rarity}稀有度下没有可用式神" + } # 随机选择式神 shikigami = random.choice(shikigami_data) - # 记录抽卡 - unlocked_achievements = self.data_manager.record_draw(user_id, rarity, shikigami["name"]) - - # 剩余次数 - draws_left = self.data_manager.get_draws_left(user_id) + # xapi 只负责写入抽卡数据;奖励副作用仍由 nonebot handler 编排。 + record_result = await self.data_manager.record_draw_result(user_id, rarity, shikigami) + if not record_result.get("success"): + return { + "success": False, + "message": record_result.get("message", "抽卡记录写入失败") + } return { - "success": True, - "rarity": rarity, - "name": shikigami["name"], - "image_url": shikigami["image_url"], - "draws_left": draws_left, - "unlocked_achievements": unlocked_achievements - } + "success": True, + "rarity": record_result.get("rarity", rarity), + "name": record_result.get("name", shikigami["name"]), + "image_url": record_result.get("image_url") or record_result.get("image_path") or shikigami["image_url"], + "draws_left": record_result.get("draws_left", 0), + "unlocked_achievements": record_result.get("unlocked_achievements", []) + } def _draw_rarity(self, user_id: str = None) -> str: """按概率抽取稀有度""" @@ -82,145 +85,39 @@ class GachaSystem: # 默认返回R,理论上不会执行到这里 return "R" - def get_user_stats(self, user_id: str) -> Dict: - """获取用户抽卡统计""" - user_stats = self.data_manager.get_user_stats() - - if user_id not in user_stats: - return { - "success": False, - "message": "您还没有抽卡记录哦!" - } - - stats = user_stats[user_id] - return { - "success": True, - "total_draws": stats["total_draws"], - "R_count": stats["R_count"], - "SR_count": stats["SR_count"], - "SSR_count": stats["SSR_count"], - "SP_count": stats["SP_count"], - "recent_draws": stats["draw_history"][-5:] if stats["draw_history"] else [] - } + async def get_user_stats(self, user_id: str) -> Dict: + """获取用户抽卡统计""" + return await self.data_manager.get_user_stats(user_id) def get_probability_text(self) -> str: """获取概率展示文本""" probs = config.RARITY_PROBABILITY return f"--- 系统概率 ---\nR: {probs['R']}% | SR: {probs['SR']}% | SSR: {probs['SSR']}% | SP: {probs['SP']}%" - def get_rank_list(self) -> List[Tuple[str, Dict[str, int]]]: - """获取抽卡排行榜数据""" - user_stats = self.data_manager.get_user_stats() - - # 过滤有SSR/SP记录的用户 - ranked_users = [ - (user_id, stats) - for user_id, stats in user_stats.items() - if stats.get("SSR_count", 0) > 0 or stats.get("SP_count", 0) > 0 - ] - - # 按SSR+SP总数降序排序 - ranked_users.sort( - key=lambda x: (x[1].get("SSR_count", 0) + x[1].get("SP_count", 0)), - reverse=True - ) - - return ranked_users - - def get_daily_stats(self) -> Dict: - """获取今日抽卡统计""" - daily_draws = self.data_manager.get_daily_draws() - today = self.data_manager.get_today_date() - - if not daily_draws or today not in daily_draws: - return { - "success": False, - "message": "今日还没有人抽卡哦!" - } - - today_stats = daily_draws[today] - total_stats = { - "total_users": len(today_stats), - "total_draws": 0, - "R_count": 0, - "SR_count": 0, - "SSR_count": 0, - "SP_count": 0, - "user_stats": [] - } - - # 统计每个用户的抽卡情况 - for user_id, draws in today_stats.items(): - user_stats = { - "user_id": user_id, - "total_draws": len(draws), - "R_count": sum(1 for d in draws if d["rarity"] == "R"), - "SR_count": sum(1 for d in draws if d["rarity"] == "SR"), - "SSR_count": sum(1 for d in draws if d["rarity"] == "SSR"), - "SP_count": sum(1 for d in draws if d["rarity"] == "SP") - } - - # 更新总统计 - total_stats["total_draws"] += user_stats["total_draws"] - total_stats["R_count"] += user_stats["R_count"] - total_stats["SR_count"] += user_stats["SR_count"] - total_stats["SSR_count"] += user_stats["SSR_count"] - total_stats["SP_count"] += user_stats["SP_count"] - - # 只记录抽到SSR或SP的用户 - if user_stats["SSR_count"] > 0 or user_stats["SP_count"] > 0: - total_stats["user_stats"].append(user_stats) - - # 按SSR+SP数量排序用户统计 - total_stats["user_stats"].sort( - key=lambda x: (x["SSR_count"] + x["SP_count"]), - reverse=True - ) - - # 构建稀有度统计 - rarity_stats = { - "R": total_stats["R_count"], - "SR": total_stats["SR_count"], - "SSR": total_stats["SSR_count"], - "SP": total_stats["SP_count"] - } - - # 构建排行榜数据 - top_users = [] - for user_stat in total_stats["user_stats"]: - top_users.append({ - "user_id": user_stat["user_id"], - "ssr_count": user_stat["SSR_count"] + user_stat["SP_count"] - }) - - final_stats = { - "total_users": total_stats["total_users"], - "total_draws": total_stats["total_draws"], - "rarity_stats": rarity_stats, - "top_users": top_users - } - - return { - "success": True, - "date": today, - "stats": final_stats - } - - def triple_draw(self, user_id: str) -> Dict: - """执行三连抽""" - # 检查是否有足够的抽卡次数 - draws_left = self.data_manager.get_draws_left(user_id) - if draws_left < 3: - return { - "success": False, + async def get_rank_list(self) -> List[Tuple[str, Dict[str, int]]]: + """获取抽卡排行榜数据""" + items = await self.data_manager.get_rank(limit=10) + return [(item["user_id"], item) for item in items] + + async def get_daily_stats(self) -> Dict: + """获取今日抽卡统计""" + return await self.data_manager.get_daily_stats() + + async def triple_draw(self, user_id: str) -> Dict: + """执行三连抽""" + # 检查是否有足够的抽卡次数 + draws_left = await self.data_manager.get_draws_left(user_id) + if draws_left < 3: + return { + "success": False, "message": f"抽卡次数不足,您今日还剩{draws_left}次抽卡机会,三连抽需要3次机会" } results = [] - all_unlocked_achievements = [] - - # 执行三次抽卡 - for i in range(3): + await self.data_manager.ensure_shikigami_data() + + # 执行三次本地概率抽取,统一提交 xapi 三连写入端点。 + for i in range(3): # 抽取稀有度(传递用户ID) rarity = self._draw_rarity(user_id) @@ -235,29 +132,30 @@ class GachaSystem: # 随机选择式神 shikigami = random.choice(shikigami_data) - # 记录抽卡 - unlocked_achievements = self.data_manager.record_draw(user_id, rarity, shikigami["name"]) - all_unlocked_achievements.extend(unlocked_achievements) - - results.append({ - "rarity": rarity, - "name": shikigami["name"], - "image_url": shikigami["image_url"] - }) - - # 剩余次数 - draws_left = self.data_manager.get_draws_left(user_id) - - return { - "success": True, - "results": results, - "draws_left": draws_left, - "unlocked_achievements": list(set(all_unlocked_achievements)) # 去重 - } - - def get_user_achievements(self, user_id: str) -> Dict: - """获取用户成就信息""" - achievement_data = self.data_manager.get_user_achievements(user_id) + results.append({ + "id": shikigami["id"], + "rarity": rarity, + "name": shikigami["name"], + "image_url": shikigami["image_url"] + }) + + record_result = await self.data_manager.record_triple_draw_result(user_id, results) + if not record_result.get("success"): + return { + "success": False, + "message": record_result.get("message", "三连抽记录写入失败") + } + + return { + "success": True, + "results": record_result.get("results", results), + "draws_left": record_result.get("draws_left", 0), + "unlocked_achievements": record_result.get("unlocked_achievements", []) + } + + async def get_user_achievements(self, user_id: str) -> Dict: + """获取用户成就信息""" + achievement_data = await self.data_manager.get_user_achievements(user_id) if not achievement_data["unlocked"] and all(v == 0 for v in achievement_data["progress"].values()): return { @@ -271,48 +169,6 @@ class GachaSystem: "progress": achievement_data["progress"] } - def get_daily_detailed_records(self, date: Optional[str] = None) -> Dict: - """获取每日详细抽卡记录""" - if not date: - date = self.data_manager.get_today_date() - - daily_draws = self.data_manager.get_daily_draws() - - if not daily_draws or date not in daily_draws: - return { - "success": False, - "message": f"{date} 没有抽卡记录" - } - - records = [] - for user_id, draws in daily_draws[date].items(): - for draw in draws: - # 检查这次抽卡是否解锁了成就 - unlocked_achievements = [] - draw_time = draw.get("timestamp", "未知时间") - - # 获取用户成就信息 - achievement_data = self.data_manager.get_user_achievements(user_id) - if achievement_data["unlocked"]: - # 检查是否有在抽卡时间之后解锁的成就 - for achievement_id, achievement_info in achievement_data["unlocked"].items(): - if achievement_info["unlocked_date"] == f"{date} {draw_time}": - unlocked_achievements.append(achievement_id) - - records.append({ - "user_id": user_id, - "draw_time": draw_time, - "shikigami_name": draw["name"], - "rarity": draw["rarity"], - "unlocked_achievements": unlocked_achievements - }) - - # 按时间排序 - records.sort(key=lambda x: x["draw_time"]) - - return { - "success": True, - "date": date, - "records": records, - "total_count": len(records) - } \ No newline at end of file + async def get_daily_detailed_records(self, date: Optional[str] = None) -> Dict: + """获取每日详细抽卡记录""" + return await self.data_manager.get_daily_records(date) diff --git a/danding_bot/plugins/onmyoji_gacha/web_api.py b/danding_bot/plugins/onmyoji_gacha/web_api.py index 242fd60..a3cf3f5 100644 --- a/danding_bot/plugins/onmyoji_gacha/web_api.py +++ b/danding_bot/plugins/onmyoji_gacha/web_api.py @@ -82,9 +82,9 @@ async def admin_page(request: Request): # API 端点 @router.get("/api/stats/daily", response_model=DailyStatsResponse, dependencies=[Depends(verify_admin_token)]) -async def get_daily_stats(): - """获取今日抽卡统计""" - result = gacha_system.get_daily_stats() +async def get_daily_stats(): + """获取今日抽卡统计""" + result = await gacha_system.get_daily_stats() if not result["success"]: return result @@ -95,9 +95,9 @@ async def get_daily_stats(): } @router.get("/api/stats/user/{user_id}", response_model=UserStatsResponse, dependencies=[Depends(verify_admin_token)]) -async def get_user_stats(user_id: str): - """获取用户抽卡统计""" - result = gacha_system.get_user_stats(user_id) +async def get_user_stats(user_id: str): + """获取用户抽卡统计""" + result = await gacha_system.get_user_stats(user_id) if not result["success"]: return { "success": False, @@ -122,9 +122,9 @@ async def get_user_stats(user_id: str): } @router.get("/api/stats/rank", response_model=RankListResponse, dependencies=[Depends(verify_admin_token)]) -async def get_rank_list(): - """获取抽卡排行榜""" - rank_data = gacha_system.get_rank_list() +async def get_rank_list(): + """获取抽卡排行榜""" + rank_data = await gacha_system.get_rank_list() # 转换数据格式 formatted_data = [] @@ -145,9 +145,9 @@ async def get_rank_list(): } @router.get("/api/achievements/{user_id}", response_model=AchievementResponse, dependencies=[Depends(verify_admin_token)]) -async def get_user_achievements(user_id: str): - """获取用户成就信息""" - result = gacha_system.get_user_achievements(user_id) +async def get_user_achievements(user_id: str): + """获取用户成就信息""" + result = await gacha_system.get_user_achievements(user_id) if not result["success"]: return { "success": False, @@ -164,9 +164,9 @@ async def get_user_achievements(user_id: str): } @router.get("/api/records/daily", response_model=DailyDetailedRecordsResponse, dependencies=[Depends(verify_admin_token)]) -async def get_daily_detailed_records(date: Optional[str] = None): - """获取每日详细抽卡记录""" - result = gacha_system.get_daily_detailed_records(date) +async def get_daily_detailed_records(date: Optional[str] = None): + """获取每日详细抽卡记录""" + result = await gacha_system.get_daily_detailed_records(date) if not result["success"]: return { "success": False, diff --git a/tests/test_danding_points_http_api.py b/tests/test_danding_points_http_api.py new file mode 100644 index 0000000..8c0544d --- /dev/null +++ b/tests/test_danding_points_http_api.py @@ -0,0 +1,217 @@ +"""danding_points HTTP PointsAPI 测试。""" + +from __future__ import annotations + +import aiohttp +import importlib.util +import pytest +import sys +import types +from pathlib import Path + + +def load_points_modules(): + """直接加载 danding_points 子模块,避免测试环境缺 nonebot 时执行插件元数据。""" + + plugin_dir = Path(__file__).resolve().parents[1] / "danding_bot" / "plugins" / "danding_points" + package_name = "_danding_points_under_test" + package = types.ModuleType(package_name) + package.__path__ = [str(plugin_dir)] + sys.modules[package_name] = package + + for module_name in ("config", "api"): + full_name = f"{package_name}.{module_name}" + spec = importlib.util.spec_from_file_location(full_name, plugin_dir / f"{module_name}.py") + module = importlib.util.module_from_spec(spec) + sys.modules[full_name] = module + assert spec and spec.loader + spec.loader.exec_module(module) + + return sys.modules[f"{package_name}.api"], sys.modules[f"{package_name}.config"] + + +api_module, config_module = load_points_modules() +PointsAPI = api_module.PointsAPI +Config = config_module.Config + + +class FakeResponse: + """模拟 aiohttp 响应上下文。""" + + def __init__(self, payload, status=200): + self.payload = payload + self.status = status + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def json(self): + return self.payload + + +class FakeSession: + """记录请求参数的 aiohttp ClientSession 替身。""" + + def __init__(self, responses, calls): + self.responses = responses + self.calls = calls + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def get(self, url, params=None, timeout=None): + self.calls.append({"method": "GET", "url": url, "params": params, "timeout": timeout}) + return FakeResponse(self.responses.pop(0)) + + def post(self, url, json=None, timeout=None): + self.calls.append({"method": "POST", "url": url, "json": json, "timeout": timeout}) + return FakeResponse(self.responses.pop(0)) + + +def make_points_api() -> PointsAPI: + return PointsAPI( + Config( + POINTS_API_HOST="http://xapi.test/bot/points/", + BOT_USER="robot", + BOT_TOKEN="secret", + ) + ) + + +@pytest.fixture +def fake_aiohttp(monkeypatch): + calls = [] + responses = [] + monkeypatch.setattr(api_module.aiohttp, "ClientSession", lambda: FakeSession(responses, calls)) + return responses, calls + + +@pytest.mark.asyncio +async def test_get_balance_sends_auth_in_query(fake_aiohttp): + responses, calls = fake_aiohttp + responses.append({"code": 200, "message": "", "data": {"balance": 88}}) + points = make_points_api() + + balance = await points.get_balance("10001") + + assert balance == 88 + assert not hasattr(points, "db") + assert calls[0]["method"] == "GET" + assert calls[0]["url"] == "http://xapi.test/bot/points/balance" + assert calls[0]["params"] == {"user": "robot", "token": "secret", "user_id": "10001"} + + +@pytest.mark.asyncio +async def test_add_spend_set_send_auth_in_post_body(fake_aiohttp): + responses, calls = fake_aiohttp + responses.extend( + [ + {"code": 200, "message": "", "data": {"success": True, "balance": 10}}, + {"code": 200, "message": "", "data": {"success": False, "balance": 7}}, + {"code": 200, "message": "", "data": {"success": True, "balance": 99}}, + ] + ) + points = make_points_api() + + add_result = await points.add_points("10001", 10, "gacha_sign", "签到") + spend_result = await points.spend_points("10001", 5, "horse_race", "下注") + set_result = await points.set_points("10001", 99, "admin", "调整") + + assert add_result == (True, 10) + assert spend_result == (False, 7) + assert set_result == (True, 99) + assert [call["method"] for call in calls] == ["POST", "POST", "POST"] + assert calls[0]["json"] == { + "user": "robot", + "token": "secret", + "user_id": "10001", + "amount": 10, + "source": "gacha_sign", + "reason": "签到", + } + assert calls[1]["url"].endswith("/spend") + assert calls[2]["url"].endswith("/set") + + +@pytest.mark.asyncio +async def test_transactions_and_ranking_return_items_unchanged(fake_aiohttp): + responses, calls = fake_aiohttp + tx_item = { + "id": 1, + "user_id": "10001", + "amount": 10, + "balance_after": 10, + "source": "gacha_sign", + "reason": "签到", + "created_at": "2026-06-20T12:00:00", + } + ranking_item = { + "rank": 1, + "user_id": "10001", + "points": 10, + "total_earned": 10, + "total_spent": 0, + } + responses.extend( + [ + {"code": 200, "message": "", "data": {"items": [tx_item]}}, + {"code": 200, "message": "", "data": {"items": [ranking_item]}}, + ] + ) + points = make_points_api() + + transactions = await points.get_transactions("10001", limit=5, offset=2) + ranking = await points.get_ranking(limit=3, order_by="unknown") + + assert transactions == [tx_item] + assert ranking == [ranking_item] + assert calls[0]["params"] == { + "user": "robot", + "token": "secret", + "user_id": "10001", + "limit": 5, + "offset": 2, + } + assert calls[1]["params"] == { + "user": "robot", + "token": "secret", + "limit": 3, + "order_by": "points", + } + + +@pytest.mark.asyncio +async def test_network_error_keeps_old_failure_returns(monkeypatch): + class FailingSession: + async def __aenter__(self): + raise aiohttp.ClientError("offline") + + async def __aexit__(self, exc_type, exc, tb): + return None + + monkeypatch.setattr(api_module.aiohttp, "ClientSession", lambda: FailingSession()) + points = make_points_api() + + assert await points.get_balance("10001") == 0 + assert await points.add_points("10001", 1, "gacha_sign") == (False, 0) + assert await points.spend_points("10001", 1, "horse_race") == (False, 0) + assert await points.set_points("10001", 1, "admin") == (False, 0) + assert await points.get_transactions("10001") == [] + assert await points.get_ranking() == [] + + +@pytest.mark.asyncio +async def test_invalid_change_request_does_not_call_http(fake_aiohttp): + _responses, calls = fake_aiohttp + points = make_points_api() + + assert await points.add_points("10001", 0, "gacha_sign") == (False, 0) + assert await points.spend_points("", 1, "horse_race") == (False, 0) + assert await points.set_points("10001", -1, "admin") == (False, 0) + assert calls == [] diff --git a/tests/test_group_horse_racing_runtime_api.py b/tests/test_group_horse_racing_runtime_api.py new file mode 100644 index 0000000..d564c2b --- /dev/null +++ b/tests/test_group_horse_racing_runtime_api.py @@ -0,0 +1,474 @@ +"""group_horse_racing 运行时 API 改造测试。""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from datetime import datetime +from pathlib import Path + +import aiohttp +import aiosqlite +import pytest + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +PLUGIN_DIR = PROJECT_ROOT / "danding_bot" / "plugins" / "group_horse_racing" + + +def _load_module(full_name: str, path: Path): + spec = importlib.util.spec_from_file_location(full_name, path) + module = importlib.util.module_from_spec(spec) + sys.modules[full_name] = module + assert spec and spec.loader + spec.loader.exec_module(module) + return module + + +def _install_nonebot_stubs() -> None: + """安装 shared.py 导入所需的最小 NoneBot 类型桩。""" + + nonebot = sys.modules.setdefault("nonebot", types.ModuleType("nonebot")) + adapters = sys.modules.setdefault("nonebot.adapters", types.ModuleType("nonebot.adapters")) + onebot = sys.modules.setdefault("nonebot.adapters.onebot", types.ModuleType("nonebot.adapters.onebot")) + v11 = types.ModuleType("nonebot.adapters.onebot.v11") + + class Bot: + async def get_group_member_info(self, **kwargs): + return {} + + class Event: + def get_user_id(self): + return str(getattr(self, "user_id", "")) + + class GroupMessageEvent(Event): + pass + + class PrivateMessageEvent(Event): + pass + + class Message(list): + pass + + class MessageSegment: + @staticmethod + def image(data): + return {"type": "image", "data": data} + + v11.Bot = Bot + v11.Event = Event + v11.GroupMessageEvent = GroupMessageEvent + v11.PrivateMessageEvent = PrivateMessageEvent + v11.Message = Message + v11.MessageSegment = MessageSegment + sys.modules["nonebot.adapters.onebot.v11"] = v11 + nonebot.adapters = adapters + adapters.onebot = onebot + onebot.v11 = v11 + + +def _install_qqpush_stubs() -> None: + qqpush_config = types.ModuleType("danding_bot.plugins.danding_qqpush.config") + qqpush_image = types.ModuleType("danding_bot.plugins.danding_qqpush.image_render") + + class QqPushConfig: + FontPaths = [] + + class ImageRenderer: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def render_to_base64(self, body, title=""): + return f"base64://{title}:{body}" + + qqpush_config.Config = QqPushConfig + qqpush_image.ImageRenderer = ImageRenderer + sys.modules["danding_bot.plugins.danding_qqpush.config"] = qqpush_config + sys.modules["danding_bot.plugins.danding_qqpush.image_render"] = qqpush_image + + +def _install_points_stub() -> None: + points_package = types.ModuleType("danding_bot.plugins.danding_points") + + class PointsApi: + async def add_points(self, user_id, amount, source, reason=None): + return True, 0 + + async def spend_points(self, user_id, amount, source, reason=None): + return True, 0 + + async def set_points(self, user_id, amount, source, reason=None): + return True, amount + + async def get_balance(self, user_id): + return 0 + + points_package.points_api = PointsApi() + sys.modules["danding_bot.plugins.danding_points"] = points_package + + +def load_room_store_modules(): + """直接加载 room_store 相关子模块,避免执行插件入口。""" + + package_name = "_horse_racing_room_store_under_test" + package = types.ModuleType(package_name) + package.__path__ = [str(PLUGIN_DIR)] + sys.modules[package_name] = package + + config_module = _load_module(f"{package_name}.config", PLUGIN_DIR / "config.py") + models_module = _load_module(f"{package_name}.models", PLUGIN_DIR / "models.py") + room_store_module = _load_module(f"{package_name}.room_store", PLUGIN_DIR / "room_store.py") + return config_module, models_module, room_store_module + + +def load_shared_modules(): + """加载 shared.py 及其依赖,同时隔离 NoneBot 和外部插件依赖。""" + + _install_nonebot_stubs() + _install_qqpush_stubs() + _install_points_stub() + + package_name = "_horse_racing_shared_under_test" + package = types.ModuleType(package_name) + package.__path__ = [str(PLUGIN_DIR)] + sys.modules[package_name] = package + + config_module = _load_module(f"{package_name}.config", PLUGIN_DIR / "config.py") + models_module = _load_module(f"{package_name}.models", PLUGIN_DIR / "models.py") + package.plugin_config = config_module.Config(RACE_RENDER_AS_IMAGE=False, RACE_TICK_INTERVAL=0, RACE_DISTANCE=1) + _load_module(f"{package_name}.room_store", PLUGIN_DIR / "room_store.py") + _load_module(f"{package_name}.points_service", PLUGIN_DIR / "points_service.py") + _load_module(f"{package_name}.race_engine", PLUGIN_DIR / "race_engine.py") + _load_module(f"{package_name}.message_service", PLUGIN_DIR / "message_service.py") + + commands_package_name = f"{package_name}.commands" + commands_package = types.ModuleType(commands_package_name) + commands_package.__path__ = [str(PLUGIN_DIR / "commands")] + sys.modules[commands_package_name] = commands_package + access_module = types.ModuleType(f"{commands_package_name}.access") + access_module.get_event_id = lambda event: str(getattr(event, "user_id", "")) + access_module.get_scope = lambda event: f"group_{getattr(event, 'group_id', '')}" + async def _check_access(bot, event): + return True + access_module.check_access = _check_access + sys.modules[f"{commands_package_name}.access"] = access_module + + shared_module = _load_module(f"{commands_package_name}.shared", PLUGIN_DIR / "commands" / "shared.py") + return config_module, models_module, shared_module + + +config_module, models_module, room_store_module = load_room_store_modules() +Config = config_module.Config +RoomStore = room_store_module.RoomStore +RaceResult = models_module.RaceResult + + +class FakeResponse: + """模拟 aiohttp 响应上下文。""" + + def __init__(self, payload, status=200): + self.payload = payload + self.status = status + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def json(self): + return self.payload + + +class FakeSession: + """记录赛马 HTTP 请求参数的 aiohttp ClientSession 替身。""" + + def __init__(self, responses, calls): + self.responses = responses + self.calls = calls + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def get(self, url, params=None, timeout=None): + self.calls.append({"method": "GET", "url": url, "params": params, "timeout": timeout}) + return FakeResponse(self.responses.pop(0)) + + def post(self, url, json=None, timeout=None): + self.calls.append({"method": "POST", "url": url, "json": json, "timeout": timeout}) + return FakeResponse(self.responses.pop(0)) + + def put(self, url, json=None, timeout=None): + self.calls.append({"method": "PUT", "url": url, "json": json, "timeout": timeout}) + return FakeResponse(self.responses.pop(0)) + + +def success_data(data): + return {"code": 200, "message": "", "data": data} + + +@pytest.fixture +def fake_aiohttp(monkeypatch): + calls = [] + responses = [] + monkeypatch.setattr(room_store_module.aiohttp, "ClientSession", lambda: FakeSession(responses, calls)) + return responses, calls + + +def make_store(tmp_path: Path) -> RoomStore: + return RoomStore( + Config( + RACE_DB_FILE=str(tmp_path / "race.db"), + RACE_API_HOST="http://xapi.test/bot/race/", + BOT_USER="robot", + BOT_TOKEN="secret", + ) + ) + + +@pytest.mark.asyncio +async def test_room_store_race_history_and_horse_names_use_xapi(fake_aiohttp, tmp_path): + responses, calls = fake_aiohttp + responses.extend( + [ + success_data({"user_id": "10001", "horse_name": "赤焰"}), + success_data({"success": True}), + success_data({"race_id": "race-001"}), + ] + ) + store = make_store(tmp_path) + result = RaceResult( + race_id="race-001", + scope="group_1000", + champion_name="赤焰", + champion_owner="10001", + participants=["赤焰", "青岚"], + bet_distribution={"赤焰": 30, "青岚": 10}, + duration_ticks=8, + completed_at=datetime(2026, 6, 20, 10, 0, 0), + point_changes={"10001": 170}, + point_change_summaries={"10001": "大赚特赚"}, + odds_snapshot={"赤焰": 1.5}, + ) + + horse_name = await store.get_last_horse_name("10001") + await store.set_last_horse_name("10001", "青岚") + await store.save_race_result(result) + + assert horse_name == "赤焰" + assert calls[0]["method"] == "GET" + assert calls[0]["url"] == "http://xapi.test/bot/race/horse-name" + assert calls[0]["params"] == {"user": "robot", "token": "secret", "user_id": "10001"} + assert calls[1]["method"] == "PUT" + assert calls[1]["json"] == {"user": "robot", "token": "secret", "user_id": "10001", "horse_name": "青岚"} + assert calls[2]["method"] == "POST" + assert calls[2]["url"] == "http://xapi.test/bot/race/history" + assert calls[2]["json"]["race_id"] == "race-001" + assert calls[2]["json"]["participants"] == ["赤焰", "青岚"] + assert calls[2]["json"]["odds_snapshot"] == {"赤焰": 1.5} + + +@pytest.mark.asyncio +async def test_room_snapshots_still_use_local_sqlite_without_race_http(fake_aiohttp, tmp_path): + _responses, calls = fake_aiohttp + store = make_store(tmp_path) + + room = await store.create_room("group_1000") + loaded = store.get_room("group_1000") + + assert loaded is room + assert calls == [] + async with aiosqlite.connect(store.db_path) as db: + cursor = await db.execute("SELECT scope, state FROM room_snapshots") + rows = await cursor.fetchall() + await store.close() + assert rows == [("group_1000", "waiting")] + + +@pytest.mark.asyncio +async def test_race_api_network_error_keeps_old_failure_shapes(monkeypatch, tmp_path): + class FailingSession: + async def __aenter__(self): + raise aiohttp.ClientError("offline") + + async def __aexit__(self, exc_type, exc, tb): + return None + + monkeypatch.setattr(room_store_module.aiohttp, "ClientSession", lambda: FailingSession()) + store = make_store(tmp_path) + + assert await store.get_last_horse_name("10001") is None + await store.set_last_horse_name("10001", "赤焰") + + +@pytest.mark.asyncio +async def test_save_race_result_raises_when_xapi_rejects(fake_aiohttp, tmp_path): + responses, _calls = fake_aiohttp + responses.append({"code": 500, "message": "写入失败", "data": None}) + store = make_store(tmp_path) + result = RaceResult( + race_id="race-rejected", + scope="group_1000", + champion_name="赤焰", + champion_owner="10001", + participants=["赤焰"], + bet_distribution={"赤焰": 0}, + duration_ticks=1, + completed_at=datetime(2026, 6, 20, 10, 0, 0), + point_changes={"10001": 170}, + point_change_summaries={"10001": "大赚特赚"}, + odds_snapshot={"赤焰": 1.2}, + ) + + with pytest.raises(RuntimeError, match="赛马赛果写入 xapi 失败"): + await store.save_race_result(result) + + +@pytest.mark.asyncio +async def test_settle_race_awaits_balances_and_builds_complete_result(): + _config_module, shared_models, shared = load_shared_modules() + Room = shared_models.Room + Horse = shared_models.Horse + HorseState = shared_models.HorseState + Bet = shared_models.Bet + + class FakePointsService: + def __init__(self): + self.balances = {"10001": 100, "10002": 50, "20001": 200} + + async def get_balance(self, user_id: str) -> int: + return self.balances[user_id] + + async def reward_participant(self, user_id: str): + self.balances[user_id] += shared.config.PARTICIPANT_REWARD + return True, self.balances[user_id] + + async def reward_champion(self, user_id: str): + self.balances[user_id] += shared.config.CHAMPION_REWARD + return True, self.balances[user_id] + + async def payout_winnings(self, user_id: str, amount: int, odds: float): + self.balances[user_id] += max(1, round(amount * odds)) + return True, self.balances[user_id] + + shared.points_service = FakePointsService() + room = Room(scope="group_1000") + room.horses = { + "赤焰": Horse(owner_id="10001", name="赤焰", index=1, state=HorseState.RACING), + "青岚": Horse(owner_id="10002", name="青岚", index=2, state=HorseState.RACING), + } + room.bets = [Bet(user_id="20001", horse_name="赤焰", amount=30)] + room.champion_name = "赤焰" + room.tick_count = 7 + + settlement = await shared.settle_race(room) + + assert settlement is not None + result, odds = settlement + assert result.race_id + assert result.scope == "group_1000" + assert result.participants == ["赤焰", "青岚"] + assert result.bet_distribution == {"赤焰": 30, "青岚": 0} + assert result.duration_ticks == 7 + assert result.odds_snapshot == odds + assert result.point_changes == {"10001": 170, "10002": 20, "20001": 36} + assert all(isinstance(value, int) for value in result.point_changes.values()) + + +@pytest.mark.asyncio +async def test_run_race_with_settlement_saves_result_before_deleting_room(): + _config_module, shared_models, shared = load_shared_modules() + Room = shared_models.Room + Horse = shared_models.Horse + HorseState = shared_models.HorseState + Bet = shared_models.Bet + + class FakeRaceEngine: + def __init__(self): + self.stopped = [] + + def tick(self, room): + room.tick_count = 3 + return [room.horses["赤焰"]] + + def format_progress(self, room): + return "progress" + + def determine_champion(self, horses): + return horses[0] + + def stop_race(self, scope): + self.stopped.append(scope) + + class FakePointsService: + def __init__(self): + self.balances = {"10001": 100, "10002": 50, "20001": 200} + + async def get_balance(self, user_id: str) -> int: + return self.balances[user_id] + + async def reward_participant(self, user_id: str): + self.balances[user_id] += shared.config.PARTICIPANT_REWARD + return True, self.balances[user_id] + + async def reward_champion(self, user_id: str): + self.balances[user_id] += shared.config.CHAMPION_REWARD + return True, self.balances[user_id] + + async def payout_winnings(self, user_id: str, amount: int, odds: float): + self.balances[user_id] += max(1, round(amount * odds)) + return True, self.balances[user_id] + + class FakeMessageService: + def __init__(self): + self.sent = [] + self.cleared = [] + + async def send_with_recall(self, bot, scope, message_type, message): + self.sent.append((scope, message_type, str(message))) + return "msg" + + async def recall_previous_of_type(self, bot, scope, message_type): + return None + + def clear_pending_recalls(self, scope): + self.cleared.append(scope) + + class FakeRoomStore: + def __init__(self): + self.saved = [] + self.deleted = [] + + async def save_race_result(self, result): + self.saved.append(result) + + def delete_room(self, scope): + self.deleted.append(scope) + + shared.race_engine = FakeRaceEngine() + shared.points_service = FakePointsService() + shared.message_service = FakeMessageService() + shared.room_store = FakeRoomStore() + shared.config.RACE_TICK_INTERVAL = 0 + + room = Room(scope="group_1000") + room.horses = { + "赤焰": Horse(owner_id="10001", name="赤焰", index=1, state=HorseState.RACING), + "青岚": Horse(owner_id="10002", name="青岚", index=2, state=HorseState.RACING), + } + room.bets = [Bet(user_id="20001", horse_name="赤焰", amount=30)] + + await shared.run_race_with_settlement(object(), room, "group_1000") + + assert len(shared.room_store.saved) == 1 + saved = shared.room_store.saved[0] + assert saved.scope == "group_1000" + assert saved.duration_ticks == 3 + assert saved.odds_snapshot == {"赤焰": 1.2, "青岚": 1.2} + assert shared.room_store.deleted == ["group_1000"] + assert shared.race_engine.stopped == ["group_1000"] + assert shared.message_service.cleared == ["group_1000"] diff --git a/tests/test_onmyoji_gacha_http_api.py b/tests/test_onmyoji_gacha_http_api.py new file mode 100644 index 0000000..326d57a --- /dev/null +++ b/tests/test_onmyoji_gacha_http_api.py @@ -0,0 +1,247 @@ +"""onmyoji_gacha HTTP DataManager 测试。""" + +from __future__ import annotations + +import aiohttp +import importlib.util +import inspect +import pytest +import sys +import types +from pathlib import Path + + +def load_gacha_modules(): + """直接加载 onmyoji_gacha 子模块,避免测试环境执行 nonebot 插件入口。""" + + plugin_dir = Path(__file__).resolve().parents[1] / "danding_bot" / "plugins" / "onmyoji_gacha" + package_name = "_onmyoji_gacha_under_test" + package = types.ModuleType(package_name) + package.__path__ = [str(plugin_dir)] + sys.modules[package_name] = package + + for module_name in ("config", "data_manager", "gacha"): + full_name = f"{package_name}.{module_name}" + spec = importlib.util.spec_from_file_location(full_name, plugin_dir / f"{module_name}.py") + module = importlib.util.module_from_spec(spec) + sys.modules[full_name] = module + assert spec and spec.loader + spec.loader.exec_module(module) + + return ( + sys.modules[f"{package_name}.config"], + sys.modules[f"{package_name}.data_manager"], + sys.modules[f"{package_name}.gacha"], + ) + + +config_module, data_manager_module, gacha_module = load_gacha_modules() +Config = config_module.Config +DataManager = data_manager_module.DataManager +GachaSystem = gacha_module.GachaSystem + + +class FakeResponse: + """模拟 aiohttp 响应上下文。""" + + def __init__(self, payload, status=200): + self.payload = payload + self.status = status + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def json(self): + return self.payload + + +class FakeSession: + """记录请求参数的 aiohttp ClientSession 替身。""" + + def __init__(self, responses, calls): + self.responses = responses + self.calls = calls + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def get(self, url, params=None, timeout=None): + self.calls.append({"method": "GET", "url": url, "params": params, "timeout": timeout}) + return FakeResponse(self.responses.pop(0)) + + def post(self, url, json=None, timeout=None): + self.calls.append({"method": "POST", "url": url, "json": json, "timeout": timeout}) + return FakeResponse(self.responses.pop(0)) + + +@pytest.fixture +def fake_aiohttp(monkeypatch): + calls = [] + responses = [] + monkeypatch.setattr(data_manager_module.aiohttp, "ClientSession", lambda: FakeSession(responses, calls)) + test_config = Config( + GACHA_API_HOST="http://xapi.test/bot/gacha/", + BOT_USER_ID="robot", + BOT_TOKEN="secret", + ) + monkeypatch.setattr(data_manager_module, "config", test_config) + monkeypatch.setattr(gacha_module, "config", test_config) + return responses, calls + + +def success_data(data): + return {"code": 200, "message": "", "data": data} + + +@pytest.mark.asyncio +async def test_shikigami_cache_and_draw_send_auth_to_xapi(fake_aiohttp): + responses, calls = fake_aiohttp + responses.extend( + [ + success_data( + { + "items": [ + {"id": 1, "name": "灯笼鬼", "rarity": "R", "image_path": "/r/灯笼鬼.png"}, + {"id": 3, "name": "茨木童子", "rarity": "SSR", "image_path": "/ssr/茨木童子.png"}, + ] + } + ), + success_data( + { + "success": True, + "rarity": "SSR", + "name": "茨木童子", + "image_path": "/ssr/茨木童子.png", + "image_url": "/ssr/茨木童子.png", + "draws_left": 2, + "unlocked_achievements": ["no_ssr_60"], + } + ), + ] + ) + manager = DataManager() + + grouped = await manager.refresh_shikigami_data() + draw_result = await manager.record_draw_result("10001", "SSR", grouped["SSR"][0]) + + assert grouped["SSR"][0]["image_url"] == "/ssr/茨木童子.png" + assert draw_result["unlocked_achievements"] == ["no_ssr_60"] + assert calls[0]["method"] == "GET" + assert calls[0]["url"] == "http://xapi.test/bot/gacha/shikigami" + assert calls[0]["params"] == {"user": "robot", "token": "secret"} + assert calls[1]["method"] == "POST" + assert calls[1]["json"] == { + "user": "robot", + "token": "secret", + "user_id": "10001", + "shikigami_id": 3, + "rarity": "SSR", + "name": "茨木童子", + } + + +@pytest.mark.asyncio +async def test_gacha_system_draw_keeps_probability_local_and_uses_xapi(fake_aiohttp, monkeypatch): + responses, calls = fake_aiohttp + responses.extend( + [ + success_data({"draws_left": 3, "daily_limit": 3}), + success_data({"items": [{"id": 3, "name": "茨木童子", "rarity": "SSR", "image_path": "/ssr/茨木童子.png"}]}), + success_data( + { + "success": True, + "rarity": "SSR", + "name": "茨木童子", + "image_path": "/ssr/茨木童子.png", + "image_url": "/ssr/茨木童子.png", + "draws_left": 2, + "unlocked_achievements": [], + } + ), + ] + ) + manager = DataManager() + monkeypatch.setattr(gacha_module, "data_manager", manager) + system = GachaSystem() + monkeypatch.setattr(system, "_draw_rarity", lambda user_id=None: "SSR") + + result = await system.draw("10001") + + assert result == { + "success": True, + "rarity": "SSR", + "name": "茨木童子", + "image_url": "/ssr/茨木童子.png", + "draws_left": 2, + "unlocked_achievements": [], + } + assert [call["url"].rsplit("/", 1)[-1] for call in calls] == ["draws-left", "shikigami", "draw"] + + +@pytest.mark.asyncio +async def test_triple_sign_in_claim_and_query_shapes(fake_aiohttp): + responses, calls = fake_aiohttp + responses.extend( + [ + success_data({"success": True, "results": [], "draws_left": 0, "unlocked_achievements": ["no_ssr_60"]}), + success_data({"success": True, "signed_already": True}), + success_data({"success": True, "reward_type": "天卡"}), + success_data({"success": True, "total_draws": 1, "R_count": 1, "SR_count": 0, "SSR_count": 0, "SP_count": 0, "recent_draws": []}), + success_data({"success": True, "date": "2026-06-20", "stats": {"total_users": 1}}), + success_data({"items": [{"user_id": "10001", "total_draws": 1, "R_count": 1, "SR_count": 0, "SSR_count": 0, "SP_count": 0, "ssr_sp_total": 0}]}), + success_data({"achievements": {"no_ssr_60": {"unlocked_date": "2026-06-20", "reward_claimed": False}}, "progress": {"no_ssr_streak": 60}}), + success_data({"success": True, "date": "2026-06-20", "records": [], "total_count": 0}), + ] + ) + manager = DataManager() + draws = [ + {"id": 1, "name": "灯笼鬼", "rarity": "R", "image_url": "/r/灯笼鬼.png"}, + {"id": 2, "name": "雪女", "rarity": "SR", "image_url": "/sr/雪女.png"}, + {"id": 3, "name": "茨木童子", "rarity": "SSR", "image_url": "/ssr/茨木童子.png"}, + ] + + triple = await manager.record_triple_draw_result("10001", draws) + signed = await manager.record_sign_in("10001", 20) + claimed = await manager.claim_achievement_reward("10001", "no_ssr_60") + stats = await manager.get_user_stats("10001") + daily = await manager.get_daily_stats("2026-06-20") + rank = await manager.get_rank() + achievements = await manager.get_user_achievements("10001") + records = await manager.get_daily_records("2026-06-20") + + assert triple["unlocked_achievements"] == ["no_ssr_60"] + assert signed is False + assert claimed is True + assert stats["success"] is True + assert daily["date"] == "2026-06-20" + assert rank[0]["user_id"] == "10001" + assert "unlocked" in achievements and "progress" in achievements + assert records["total_count"] == 0 + assert calls[0]["json"]["draws"][0] == {"shikigami_id": 1, "rarity": "R", "name": "灯笼鬼"} + assert calls[1]["json"]["points_awarded"] == 20 + assert calls[2]["json"]["achievement_id"] == "no_ssr_60" + assert calls[3]["params"]["user_id"] == "10001" + + +@pytest.mark.asyncio +async def test_network_error_keeps_failure_shapes_and_no_sqlite(monkeypatch): + class FailingSession: + async def __aenter__(self): + raise aiohttp.ClientError("offline") + + async def __aexit__(self, exc_type, exc, tb): + return None + + monkeypatch.setattr(data_manager_module.aiohttp, "ClientSession", lambda: FailingSession()) + manager = DataManager() + + assert await manager.get_draws_left("10001") == 0 + assert await manager.record_sign_in("10001", 20) is False + assert await manager.get_rank() == [] + assert "sqlite3" not in inspect.getsource(data_manager_module)