feat(bot): use runtime api for bot data

This commit is contained in:
2026-06-20 18:20:40 +08:00
parent f67f3ca1d6
commit 8d26c46323
16 changed files with 1803 additions and 1491 deletions

View File

@@ -10,7 +10,6 @@
danding_points/
├── __init__.py # 插件元数据 & points_api 单例导出
├── config.py # 配置类
├── database.py # SQLite 数据库操作
└── api.py # PointsAPI 核心
```
@@ -20,10 +19,11 @@ danding_points/
| 配置项 | 类型 | 默认值 | 说明 |
|--------|------|--------|------|
| `DANDING_POINTS_DB_FILE` | str | `data/danding_points/points.db` | 数据库文件路径 |
| `DANDING_POINTS_MAX_BALANCE` | int | `0` | 用户积分余额上限,`0` = 无限制 |
| `DANDING_POINTS_MAX_PER_OPERATION` | int | `0` | 单次操作积分上限,`0` = 无限制 |
| `DANDING_POINTS_LOG_RETENTION_DAYS` | int | `365` | 流水日志保留天数 |
| `DANDING_POINTS_API_HOST` | str | `https://api.danding.vip/bot/points` | xapi 积分 API 地址 |
| `DANDING_BOT_USER` | str | `1424473282` | xapi Bot 鉴权用户 |
| `DANDING_BOT_TOKEN` | str | 空 | xapi Bot 鉴权 Token未设置时读取 `DANDING_API_TOKEN` / `BOT_TOKEN` |
积分余额上限与单次操作上限由 xapi `BOT_POINTS_MAX_BALANCE` / `BOT_POINTS_MAX_PER_OPERATION` 控制nonebot 本地不再持有阈值。
## API 接口
@@ -170,36 +170,10 @@ async def admin_set(user_id: str, amount: int):
| 插件 | source 值 |
|------|-----------|
| onmyoji_gacha | `"onmyoji_gacha"` |
| danding_api | `"danding_api"` |
| shop | `"shop"` |
| sign_in | `"sign_in"` |
| onmyoji_gacha 签到 | `"gacha_sign"` |
| group_horse_racing | `"horse_race"` |
| 管理调整 | `"admin"` |
## 数据库
使用 SQLite数据文件位于 `data/danding_points/points.db`,无需额外配置
### 表结构
**user_points** — 用户积分账户
| 字段 | 类型 | 说明 |
|------|------|------|
| user_id | TEXT PK | 用户 ID |
| points | INTEGER | 当前余额,>= 0 |
| total_earned | INTEGER | 累计获得 |
| total_spent | INTEGER | 累计消费 |
| created_at | TEXT | 创建时间 |
| updated_at | TEXT | 最后更新时间 |
**point_transactions** — 积分变动流水
| 字段 | 类型 | 说明 |
|------|------|------|
| id | INTEGER PK | 自增 ID |
| user_id | TEXT | 用户 ID |
| amount | INTEGER | 变动数额(消费为负) |
| balance_after | INTEGER | 变动后余额 |
| source | TEXT | 来源标识 |
| reason | TEXT | 变动原因 |
| created_at | TEXT | 创建时间 |
本插件不再写入本地 SQLite。积分账户与流水由 xapi MySQL `bot_user_points` / `bot_point_transactions` 承载nonebot 只通过 `/bot/points/*` HTTP API 读写

View File

@@ -1,312 +1,181 @@
import asyncio
import logging
import threading
from datetime import datetime
from typing import Tuple, List, Dict, Any
from .config import Config
from .database import PointsDatabase
logger = logging.getLogger(__name__)
import asyncio
import aiohttp
import logging
from typing import Tuple, List, Dict, Any, Optional
from .config import Config
logger = logging.getLogger(__name__)
class PointsAPI:
"""Points system API for managing user points."""
def __init__(self, config: Config):
self.config = config
self.db = PointsDatabase(config)
self._lock = threading.Lock()
async def get_balance(self, user_id: str) -> int:
"""Get user's current points balance."""
return await asyncio.to_thread(self.db.get_user_balance, user_id)
"""Points system API for managing user points."""
def __init__(self, config: Config):
self.config = config
def _url(self, path: str) -> str:
"""拼接 /bot/points 端点地址。"""
return f"{self.config.POINTS_API_HOST}/{path.lstrip('/')}"
def _auth(self) -> Dict[str, str]:
"""生成 xapi Bot 鉴权参数。"""
return {
"user": self.config.BOT_USER,
"token": self.config.BOT_TOKEN,
}
async def _request(
self,
method: str,
path: str,
*,
payload: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""调用 xapi /bot/points并只向上层暴露 data。"""
request_url = self._url(path)
timeout = aiohttp.ClientTimeout(total=10)
try:
async with aiohttp.ClientSession() as session:
if method == "GET":
request_params = {**self._auth(), **(params or {})}
async with session.get(request_url, params=request_params, timeout=timeout) as resp:
return await self._parse_response(resp, path)
request_payload = {**self._auth(), **(payload or {})}
async with session.post(request_url, json=request_payload, timeout=timeout) as resp:
return await self._parse_response(resp, path)
except aiohttp.ClientError as exc:
logger.error("points api request failed path=%s error=%s", path, exc)
return None
except asyncio.TimeoutError as exc:
logger.error("points api request timeout path=%s error=%s", path, exc)
return None
async def _parse_response(self, resp: aiohttp.ClientResponse, path: str) -> Optional[Dict[str, Any]]:
"""解析 xapi 统一响应,失败时返回 None 维持旧 API 失败语义。"""
if resp.status != 200:
logger.error("points api bad status path=%s status=%s", path, resp.status)
return None
body = await resp.json()
if body.get("code") != 200:
logger.error("points api fail path=%s code=%s message=%s", path, body.get("code"), body.get("message"))
return None
data = body.get("data")
return data if isinstance(data, dict) else None
async def get_balance(self, user_id: str) -> int:
"""Get user's current points balance."""
data = await self._request("GET", "balance", params={"user_id": user_id})
if data is None:
return 0
return int(data.get("balance", 0) or 0)
async def add_points(
self, user_id: str, amount: int, source: str, reason: str = None
) -> Tuple[bool, int]:
"""Add points to user account.
Returns: (success, new_balance)
"""
# Parameter validation
if not isinstance(amount, int) or amount <= 0:
return False, 0
if not user_id or not source:
return False, 0
# Operation limit validation
if self.config.POINTS_MAX_PER_OPERATION > 0:
if amount > self.config.POINTS_MAX_PER_OPERATION:
return False, 0
def _add():
with self._lock:
conn = self.db.get_connection()
try:
cursor = conn.cursor()
# Ensure user exists
self.db.ensure_user_exists(user_id, conn)
# Get current balance
cursor.execute(
"SELECT points FROM user_points WHERE user_id = ?",
(user_id,),
)
row = cursor.fetchone()
current_balance = row["points"] if row else 0
# Check balance limit
new_balance = current_balance + amount
if self.config.POINTS_MAX_BALANCE > 0:
if new_balance > self.config.POINTS_MAX_BALANCE:
conn.rollback()
return False, current_balance
# Update balance and total_earned
now = datetime.now().isoformat()
cursor.execute(
"""
UPDATE user_points
SET points = ?, total_earned = total_earned + ?, updated_at = ?
WHERE user_id = ?
""",
(new_balance, amount, now, user_id),
)
# Write transaction log
cursor.execute(
"""
INSERT INTO point_transactions
(user_id, amount, balance_after, source, reason, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(user_id, amount, new_balance, source, reason, now),
)
conn.commit()
return True, new_balance
except Exception as e:
conn.rollback()
logger.error(f"add_points failed for {user_id}: {e}")
return False, 0
finally:
conn.close()
return await asyncio.to_thread(_add)
"""Add points to user account.
Returns: (success, new_balance)
"""
# 保留原 PointsAPI 的入参失败语义;限额校验由 xapi 承担。
if not isinstance(amount, int) or amount <= 0:
return False, 0
if not user_id or not source:
return False, 0
data = await self._request(
"POST",
"add",
payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason},
)
return self._change_result(data)
async def spend_points(
self, user_id: str, amount: int, source: str, reason: str = None
) -> Tuple[bool, int]:
"""Spend points from user account.
Returns: (success, new_balance)
"""
# Parameter validation
if not isinstance(amount, int) or amount <= 0:
return False, 0
if not user_id or not source:
return False, 0
# Operation limit validation
if self.config.POINTS_MAX_PER_OPERATION > 0:
if amount > self.config.POINTS_MAX_PER_OPERATION:
return False, 0
def _spend():
with self._lock:
conn = self.db.get_connection()
try:
cursor = conn.cursor()
# Ensure user exists
self.db.ensure_user_exists(user_id, conn)
# Get current balance
cursor.execute(
"SELECT points FROM user_points WHERE user_id = ?",
(user_id,),
)
row = cursor.fetchone()
current_balance = row["points"] if row else 0
# Check sufficient balance
if current_balance < amount:
conn.rollback()
return False, current_balance
# Update balance and total_spent
new_balance = current_balance - amount
now = datetime.now().isoformat()
cursor.execute(
"""
UPDATE user_points
SET points = ?, total_spent = total_spent + ?, updated_at = ?
WHERE user_id = ?
""",
(new_balance, amount, now, user_id),
)
# Write transaction log (amount as negative)
cursor.execute(
"""
INSERT INTO point_transactions
(user_id, amount, balance_after, source, reason, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(user_id, -amount, new_balance, source, reason, now),
)
conn.commit()
return True, new_balance
except Exception as e:
conn.rollback()
logger.error(f"spend_points failed for {user_id}: {e}")
return False, 0
finally:
conn.close()
return await asyncio.to_thread(_spend)
"""Spend points from user account.
Returns: (success, new_balance)
"""
# 保留原 PointsAPI 的入参失败语义;余额与限额校验由 xapi 承担。
if not isinstance(amount, int) or amount <= 0:
return False, 0
if not user_id or not source:
return False, 0
data = await self._request(
"POST",
"spend",
payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason},
)
return self._change_result(data)
async def set_points(
self, user_id: str, amount: int, source: str, reason: str = None
) -> Tuple[bool, int]:
"""Set user's points to exact amount.
Returns: (success, new_balance)
"""
# Parameter validation
if not isinstance(amount, int) or amount < 0:
return False, 0
if not user_id or not source:
return False, 0
def _set():
with self._lock:
conn = self.db.get_connection()
try:
cursor = conn.cursor()
# Ensure user exists
self.db.ensure_user_exists(user_id, conn)
# Get current balance
cursor.execute(
"SELECT points, total_earned FROM user_points WHERE user_id = ?",
(user_id,),
)
row = cursor.fetchone()
current_balance = row["points"] if row else 0
current_earned = row["total_earned"] if row else 0
# If new value equals old value, return without writing
if current_balance == amount:
conn.rollback()
return True, amount
# Calculate difference for total_earned (only positive diff)
diff = amount - current_balance
earned_diff = max(0, diff)
# Update balance and total_earned
now = datetime.now().isoformat()
cursor.execute(
"""
UPDATE user_points
SET points = ?, total_earned = total_earned + ?, updated_at = ?
WHERE user_id = ?
""",
(amount, earned_diff, now, user_id),
)
# Write transaction log
cursor.execute(
"""
INSERT INTO point_transactions
(user_id, amount, balance_after, source, reason, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(user_id, diff, amount, source, reason, now),
)
conn.commit()
return True, amount
except Exception as e:
conn.rollback()
logger.error(f"set_points failed for {user_id}: {e}")
return False, 0
finally:
conn.close()
return await asyncio.to_thread(_set)
"""Set user's points to exact amount.
Returns: (success, new_balance)
"""
# set 仍保持原 PointsAPI 行为:只校验非负,不做余额上限判断。
if not isinstance(amount, int) or amount < 0:
return False, 0
if not user_id or not source:
return False, 0
data = await self._request(
"POST",
"set",
payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason},
)
return self._change_result(data)
async def get_transactions(
self, user_id: str, limit: int = 20, offset: int = 0
) -> List[Dict[str, Any]]:
"""Get transaction history for a user.
Returns: List of transaction dicts
"""
# Normalize parameters
limit = max(1, min(100, limit))
offset = max(0, offset)
def _get():
conn = self.db.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
SELECT id, user_id, amount, balance_after, source, reason, created_at
FROM point_transactions
WHERE user_id = ?
ORDER BY id DESC
LIMIT ? OFFSET ?
""",
(user_id, limit, offset),
)
rows = cursor.fetchall()
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"get_transactions failed for {user_id}: {e}")
return []
finally:
conn.close()
return await asyncio.to_thread(_get)
"""Get transaction history for a user.
Returns: List of transaction dicts
"""
limit = max(1, min(100, limit))
offset = max(0, offset)
data = await self._request(
"GET",
"transactions",
params={"user_id": user_id, "limit": limit, "offset": offset},
)
if data is None:
return []
items = data.get("items", [])
return items if isinstance(items, list) else []
async def get_ranking(
self, limit: int = 10, order_by: str = "points"
) -> List[Dict[str, Any]]:
"""Get points ranking.
Returns: List of ranking dicts with rank field
"""
# Normalize parameters
limit = max(1, min(100, limit))
if order_by not in ("points", "total_earned"):
order_by = "points"
def _get():
conn = self.db.get_connection()
try:
cursor = conn.cursor()
order_column = "points" if order_by == "points" else "total_earned"
query = f"""
SELECT
RANK() OVER (ORDER BY {order_column} DESC) as rank,
user_id,
points,
total_earned,
total_spent
FROM user_points
ORDER BY {order_column} DESC, user_id ASC
LIMIT ?
"""
cursor.execute(query, (limit,))
rows = cursor.fetchall()
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"get_ranking failed: {e}")
return []
finally:
conn.close()
return await asyncio.to_thread(_get)
"""Get points ranking.
Returns: List of ranking dicts with rank field
"""
limit = max(1, min(100, limit))
if order_by not in ("points", "total_earned"):
order_by = "points"
data = await self._request(
"GET",
"ranking",
params={"limit": limit, "order_by": order_by},
)
if data is None:
return []
items = data.get("items", [])
return items if isinstance(items, list) else []
def _change_result(self, data: Optional[Dict[str, Any]]) -> Tuple[bool, int]:
"""解析 add/spend/set 响应并维持旧失败返回值。"""
if data is None:
return False, 0
return bool(data.get("success")), int(data.get("balance", 0) or 0)

View File

@@ -1,6 +1,5 @@
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pathlib import Path
import os
@@ -11,22 +10,24 @@ class Config(BaseSettings):
extra="ignore",
)
# 数据库配置
POINTS_DB_FILE: str = os.getenv("DANDING_POINTS_DB_FILE", "data/danding_points/points.db")
POINTS_MAX_BALANCE: int = int(os.getenv("DANDING_POINTS_MAX_BALANCE", "0"))
POINTS_MAX_PER_OPERATION: int = int(os.getenv("DANDING_POINTS_MAX_PER_OPERATION", "0"))
POINTS_LOG_RETENTION_DAYS: int = int(os.getenv("DANDING_POINTS_LOG_RETENTION_DAYS", "365"))
# xapi /bot/points 运行时 API 配置
POINTS_API_HOST: str = os.getenv("DANDING_POINTS_API_HOST", "https://api.danding.vip/bot/points")
BOT_USER: str = os.getenv("DANDING_BOT_USER", "1424473282")
BOT_TOKEN: str = os.getenv(
"DANDING_BOT_TOKEN",
os.getenv("DANDING_API_TOKEN", os.getenv("BOT_TOKEN", "")),
)
@field_validator("POINTS_MAX_BALANCE", "POINTS_MAX_PER_OPERATION", "POINTS_LOG_RETENTION_DAYS")
@field_validator("POINTS_API_HOST")
@classmethod
def validate_non_negative(cls, v):
if v < 0:
raise ValueError("Value must be non-negative")
return v
def validate_api_host(cls, value):
if not value:
raise ValueError("POINTS_API_HOST cannot be empty")
return value.rstrip("/")
@field_validator("POINTS_DB_FILE")
@field_validator("BOT_USER")
@classmethod
def validate_db_path(cls, v):
if not v:
raise ValueError("Database file path cannot be empty")
return v
def validate_bot_user(cls, value):
if not value:
raise ValueError("BOT_USER cannot be empty")
return value

View File

@@ -1,106 +0,0 @@
import sqlite3
import os
from datetime import datetime
from typing import Optional, List, Dict, Any
from .config import Config
class PointsDatabase:
"""SQLite database handler for points system."""
def __init__(self, config: Config):
self.config = config
self.db_path = config.POINTS_DB_FILE
self._ensure_db_dir()
self._init_db()
def _ensure_db_dir(self):
"""Create database directory if it doesn't exist."""
db_dir = os.path.dirname(self.db_path)
if db_dir:
os.makedirs(db_dir, exist_ok=True)
def _init_db(self):
"""Initialize database tables."""
conn = sqlite3.connect(self.db_path, timeout=5.0)
cursor = conn.cursor()
# Create user_points table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_points (
user_id TEXT PRIMARY KEY,
points INTEGER NOT NULL DEFAULT 0 CHECK(points >= 0),
total_earned INTEGER NOT NULL DEFAULT 0,
total_spent INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)
"""
)
# Create point_transactions table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS point_transactions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
amount INTEGER NOT NULL,
balance_after INTEGER NOT NULL,
source TEXT NOT NULL,
reason TEXT,
created_at TEXT NOT NULL
)
"""
)
# Create indexes
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_transactions_user_id ON point_transactions(user_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_transactions_source ON point_transactions(source)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_transactions_created_at ON point_transactions(created_at)"
)
conn.commit()
conn.close()
def get_connection(self) -> sqlite3.Connection:
"""Get a database connection."""
conn = sqlite3.connect(self.db_path, timeout=5.0)
conn.row_factory = sqlite3.Row
return conn
def get_user_balance(self, user_id: str) -> int:
"""Get user's current points balance."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT points FROM user_points WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
return row["points"] if row else 0
finally:
conn.close()
def ensure_user_exists(self, user_id: str, conn=None) -> None:
"""Create user account if it doesn't exist. Reuses provided conn if given."""
should_close = False
if conn is None:
conn = self.get_connection()
should_close = True
cursor = conn.cursor()
now = datetime.now().isoformat()
cursor.execute(
"""
INSERT OR IGNORE INTO user_points
(user_id, points, total_earned, total_spent, created_at, updated_at)
VALUES (?, 0, 0, 0, ?, ?)
""",
(user_id, now, now),
)
if should_close:
conn.commit()
conn.close()

View File

@@ -1,5 +1,7 @@
import logging
import asyncio
import logging
import asyncio
from datetime import datetime
from uuid import uuid4
from nonebot.adapters.onebot.v11 import Bot, Event, GroupMessageEvent, Message, MessageSegment, PrivateMessageEvent
@@ -27,8 +29,8 @@ async def _get_user_name(bot: Bot, scope: str, user_id: str) -> str:
group_id = int(scope.split("_", 1)[1])
info = await bot.get_group_member_info(group_id=group_id, user_id=int(user_id))
return info.get("card") or info.get("nickname") or user_id
except Exception:
pass
except Exception as exc:
logger.debug("获取赛马用户昵称失败 scope=%s user_id=%s error=%s", scope, user_id, exc)
return user_id
@@ -142,13 +144,14 @@ async def _is_admin_or_owner(bot: Bot, event: Event) -> bool:
user_id=int(event.get_user_id()),
)
return member_info.get("role", "") in ("admin", "owner")
except Exception:
return False
except Exception as exc:
logger.debug("检查赛马管理员权限失败 user_id=%s error=%s", getattr(event, "user_id", ""), exc)
return False
def _build_point_changes(room: Room, odds: dict[str, float]) -> tuple[dict[str, int], dict[str, str]]:
point_changes: dict[str, int] = {}
def _build_point_changes(room: Room, odds: dict[str, float]) -> tuple[dict[str, int], dict[str, str]]:
point_changes: dict[str, int] = {}
for horse in room.horses.values():
point_changes[horse.owner_id] = point_changes.get(horse.owner_id, 0) + config.PARTICIPANT_REWARD
@@ -168,8 +171,23 @@ def _build_point_changes(room: Room, odds: dict[str, float]) -> tuple[dict[str,
point_summaries = {
user_id: _describe_points_delta(delta)
for user_id, delta in point_changes.items()
}
return point_changes, point_summaries
}
return point_changes, point_summaries
def _build_participants_snapshot(room: Room) -> list[str]:
"""生成赛果归档所需的参赛马名快照。"""
return [horse.name for horse in _get_horses_in_order(room)]
def _build_bet_distribution(room: Room) -> dict[str, int]:
"""按马名汇总下注分布,供 xapi 原样归档。"""
distribution = {horse.name: 0 for horse in _get_horses_in_order(room)}
for bet in room.bets:
distribution[bet.horse_name] = distribution.get(bet.horse_name, 0) + bet.amount
return distribution
async def _send_to_scope(bot: Bot, scope: str, message: str, message_type: str = "race_update", critical: bool = False):
@@ -243,10 +261,10 @@ async def settle_race(room: Room) -> tuple[RaceResult, dict[str, float]] | None:
for bet in room.bets:
user_ids.add(bet.user_id)
# Record pre-balances
pre_balances: dict[str, int] = {}
for uid in user_ids:
pre_balances[uid] = points_service.get_balance(uid)
# Record pre-balances
pre_balances: dict[str, int] = {}
for uid in user_ids:
pre_balances[uid] = await points_service.get_balance(uid)
# 1. Reward all participants
for horse in room.horses.values():
@@ -271,10 +289,10 @@ async def settle_race(room: Room) -> tuple[RaceResult, dict[str, float]] | None:
except Exception as e:
logger.warning(f"payout_winnings failed for {bet.user_id}: {e}")
# Record post-balances and compute deltas
post_balances: dict[str, int] = {}
for uid in user_ids:
post_balances[uid] = points_service.get_balance(uid)
# Record post-balances and compute deltas
post_balances: dict[str, int] = {}
for uid in user_ids:
post_balances[uid] = await points_service.get_balance(uid)
point_changes: dict[str, int] = {}
for uid in user_ids:
@@ -285,13 +303,20 @@ async def settle_race(room: Room) -> tuple[RaceResult, dict[str, float]] | None:
# Build human-readable summaries
_, point_change_summaries = _build_point_changes(room, odds)
result = RaceResult(
champion_name=room.champion_name,
champion_owner=champion.owner_id,
point_changes=point_changes,
point_change_summaries=point_change_summaries,
)
return result, odds
result = RaceResult(
race_id=str(uuid4()),
scope=room.scope,
champion_name=room.champion_name,
champion_owner=champion.owner_id,
participants=_build_participants_snapshot(room),
bet_distribution=_build_bet_distribution(room),
duration_ticks=room.tick_count,
completed_at=datetime.now(),
point_changes=point_changes,
point_change_summaries=point_change_summaries,
odds_snapshot=odds,
)
return result, odds
async def run_race_with_settlement(bot: Bot, room: Room, scope: str):
@@ -343,12 +368,15 @@ async def run_race_with_settlement(bot: Bot, room: Room, scope: str):
if result:
result_lines.extend(await _format_point_change_lines(room, result.point_changes, result.point_change_summaries, name_map))
await message_service.recall_previous_of_type(bot, scope, "race_update")
await _send_to_scope(bot, scope, "\n".join(result_lines), "race_result")
race_engine.stop_race(scope)
room_store.delete_room(scope)
message_service.clear_pending_recalls(scope)
await message_service.recall_previous_of_type(bot, scope, "race_update")
await _send_to_scope(bot, scope, "\n".join(result_lines), "race_result")
if result:
await room_store.save_race_result(result)
race_engine.stop_race(scope)
room_store.delete_room(scope)
message_service.clear_pending_recalls(scope)
# Import and re-export access functions from access.py (canonical source)

View File

@@ -1,6 +1,7 @@
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
import json
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
import json
import os
class Config(BaseSettings):
@@ -43,8 +44,34 @@ class Config(BaseSettings):
}
)
# 数据库配置
RACE_DB_FILE: str = "data/group_horse_racing/race.db"
# 数据库配置
RACE_DB_FILE: str = "data/group_horse_racing/race.db"
# xapi /bot/race 运行时 API 配置
RACE_API_HOST: str = os.getenv("DANDING_RACE_API_HOST", "https://api.danding.vip/bot/race")
BOT_USER: str = os.getenv("DANDING_BOT_USER", "1424473282")
BOT_TOKEN: str = os.getenv(
"DANDING_BOT_TOKEN",
os.getenv("DANDING_API_TOKEN", os.getenv("BOT_TOKEN", "")),
)
@field_validator("RACE_API_HOST")
@classmethod
def validate_race_api_host(cls, value):
"""规范化 xapi 赛马运行时 API 地址。"""
if not value:
raise ValueError("RACE_API_HOST cannot be empty")
return value.rstrip("/")
@field_validator("BOT_USER")
@classmethod
def validate_bot_user(cls, value):
"""Bot 鉴权用户不能为空。"""
if not value:
raise ValueError("BOT_USER cannot be empty")
return value
@field_validator("TESTERS", "TEST_GROUPS", "ALLOWED_GROUPS", mode="before")
@classmethod

View File

@@ -54,6 +54,7 @@ class RaceResult:
race_id: str = ""
scope: str = ""
participants: list[str] = field(default_factory=list)
bet_distribution: dict[str, int] = field(default_factory=dict)
duration_ticks: int = 0
completed_at: datetime = field(default_factory=datetime.now)
bet_distribution: dict[str, int] = field(default_factory=dict)
duration_ticks: int = 0
completed_at: datetime = field(default_factory=datetime.now)
odds_snapshot: dict[str, float] = field(default_factory=dict)

View File

@@ -1,15 +1,19 @@
import asyncio
import aiosqlite
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
from .models import Room, RoomState, RaceResult
from .config import Config
class RoomStore:
import asyncio
import aiosqlite
import aiohttp
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
from .models import Room, RoomState, RaceResult
from .config import Config
logger = logging.getLogger(__name__)
class RoomStore:
def __init__(self, config: Config):
self.config = config
self.rooms: dict[str, Room] = {}
@@ -37,46 +41,7 @@ class RoomStore:
)
""")
await db.execute("""
CREATE TABLE IF NOT EXISTS race_history (
race_id TEXT PRIMARY KEY,
scope TEXT NOT NULL,
champion_name TEXT NOT NULL,
champion_owner TEXT NOT NULL,
participants TEXT NOT NULL,
bet_distribution TEXT NOT NULL,
duration_ticks INTEGER NOT NULL,
completed_at TEXT NOT NULL,
point_changes TEXT DEFAULT '{}',
point_change_summaries TEXT DEFAULT '{}',
odds_snapshot TEXT DEFAULT '{}'
)
""")
await db.execute("""
CREATE TABLE IF NOT EXISTS user_horse_names (
user_id TEXT PRIMARY KEY,
horse_name TEXT NOT NULL
)
""")
# Add missing columns if they don't exist (for existing databases)
try:
await db.execute("SELECT point_changes FROM race_history LIMIT 1")
except aiosqlite.OperationalError:
await db.execute("ALTER TABLE race_history ADD COLUMN point_changes TEXT DEFAULT '{}'")
try:
await db.execute("SELECT point_change_summaries FROM race_history LIMIT 1")
except aiosqlite.OperationalError:
await db.execute("ALTER TABLE race_history ADD COLUMN point_change_summaries TEXT DEFAULT '{}'")
try:
await db.execute("SELECT odds_snapshot FROM race_history LIMIT 1")
except aiosqlite.OperationalError:
await db.execute("ALTER TABLE race_history ADD COLUMN odds_snapshot TEXT DEFAULT '{}'")
await db.commit()
await db.commit()
self._initialized = True
@@ -95,12 +60,69 @@ class RoomStore:
async def close(self):
"""Close database connection on shutdown."""
if self._db is not None:
await self._db.close()
self._db = None
async def load_rooms(self):
"""Restore active rooms from DB snapshots on startup."""
if self._db is not None:
await self._db.close()
self._db = None
def _url(self, path: str) -> str:
"""拼接 /bot/race 端点地址。"""
return f"{self.config.RACE_API_HOST}/{path.lstrip('/')}"
def _auth(self) -> dict[str, str]:
"""生成 xapi Bot 鉴权参数。"""
return {
"user": self.config.BOT_USER,
"token": self.config.BOT_TOKEN,
}
async def _request(
self,
method: str,
path: str,
*,
payload: Optional[dict[str, Any]] = None,
params: Optional[dict[str, Any]] = None,
) -> Optional[dict[str, Any]]:
"""调用 xapi /bot/race并只向上层暴露 data。"""
request_url = self._url(path)
timeout = aiohttp.ClientTimeout(total=10)
try:
async with aiohttp.ClientSession() as session:
if method == "GET":
request_params = {**self._auth(), **(params or {})}
async with session.get(request_url, params=request_params, timeout=timeout) as resp:
return await self._parse_response(resp, path)
request_payload = {**self._auth(), **(payload or {})}
if method == "PUT":
async with session.put(request_url, json=request_payload, timeout=timeout) as resp:
return await self._parse_response(resp, path)
async with session.post(request_url, json=request_payload, timeout=timeout) as resp:
return await self._parse_response(resp, path)
except aiohttp.ClientError as exc:
logger.error("race api request failed path=%s error=%s", path, exc)
return None
except asyncio.TimeoutError as exc:
logger.error("race api request timeout path=%s error=%s", path, exc)
return None
async def _parse_response(self, resp: aiohttp.ClientResponse, path: str) -> Optional[dict[str, Any]]:
"""解析 xapi 统一响应,失败时返回 None 维持旧调用方失败语义。"""
if resp.status != 200:
logger.error("race api bad status path=%s status=%s", path, resp.status)
return None
body = await resp.json()
if body.get("code") != 200:
logger.error("race api fail path=%s code=%s message=%s", path, body.get("code"), body.get("message"))
return None
data = body.get("data")
return data if isinstance(data, dict) else None
async def load_rooms(self):
"""Restore active rooms from DB snapshots on startup."""
await self.ensure_initialized()
db = await self._get_db()
cursor = await db.execute(
@@ -167,24 +189,23 @@ class RoomStore:
if scope in self.rooms:
del self.rooms[scope]
async def get_last_horse_name(self, user_id: str) -> Optional[str]:
await self.ensure_initialized()
db = await self._get_db()
cursor = await db.execute(
"SELECT horse_name FROM user_horse_names WHERE user_id = ?",
(user_id,)
)
row = await cursor.fetchone()
return row[0] if row else None
async def set_last_horse_name(self, user_id: str, horse_name: str):
await self.ensure_initialized()
db = await self._get_db()
await db.execute(
"INSERT OR REPLACE INTO user_horse_names (user_id, horse_name) VALUES (?, ?)",
(user_id, horse_name),
)
await db.commit()
async def get_last_horse_name(self, user_id: str) -> Optional[str]:
"""从 xapi 读取用户最后使用马名。"""
data = await self._request("GET", "horse-name", params={"user_id": user_id})
if data is None:
return None
horse_name = data.get("horse_name")
return str(horse_name) if horse_name else None
async def set_last_horse_name(self, user_id: str, horse_name: str):
"""将用户最后使用马名写入 xapi。"""
await self._request(
"PUT",
"horse-name",
payload={"user_id": user_id, "horse_name": horse_name},
)
async def _save_snapshot(self, room: Room):
"""Save room snapshot to database."""
@@ -226,31 +247,28 @@ class RoomStore:
))
await db.commit()
async def save_race_result(self, result: RaceResult):
"""Save race result to history."""
await self.ensure_initialized()
db = await self._get_db()
await db.execute("""
INSERT INTO race_history
(race_id, scope, champion_name, champion_owner, participants,
bet_distribution, duration_ticks, completed_at,
point_changes, point_change_summaries, odds_snapshot)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
result.race_id,
result.scope,
result.champion_name,
result.champion_owner,
json.dumps(result.participants),
json.dumps(result.bet_distribution),
result.duration_ticks,
result.completed_at.isoformat(),
json.dumps(getattr(result, 'point_changes', {})),
json.dumps(getattr(result, 'point_change_summaries', {})),
json.dumps(getattr(result, 'odds_snapshot', {})),
))
await db.commit()
async def save_race_result(self, result: RaceResult):
"""将完整赛果写入 xapi。"""
data = await self._request(
"POST",
"history",
payload={
"race_id": result.race_id,
"scope": result.scope,
"champion_name": result.champion_name,
"champion_owner": result.champion_owner,
"participants": result.participants,
"bet_distribution": result.bet_distribution,
"duration_ticks": result.duration_ticks,
"completed_at": result.completed_at.isoformat(),
"point_changes": result.point_changes,
"point_change_summaries": result.point_change_summaries,
"odds_snapshot": result.odds_snapshot,
},
)
if data is None:
raise RuntimeError(f"赛马赛果写入 xapi 失败: race_id={result.race_id}")
# Module-level singleton instance

View File

@@ -1,7 +1,7 @@
import os
import logging
import random
from nonebot import on_command, on_startswith
import os
import logging
import random
from nonebot import get_driver, on_command, on_startswith
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, MessageEvent, Message
from nonebot.adapters.onebot.v11.message import MessageSegment
from nonebot.typing import T_State
@@ -29,12 +29,22 @@ ACHIEVEMENT_COMMANDS = config.ACHIEVEMENT_COMMANDS
INTRO_COMMANDS = config.INTRO_COMMANDS
DAILY_LIMIT = config.DAILY_LIMIT
gacha_system = GachaSystem()
logger = logging.getLogger(__name__)
gacha_system = GachaSystem()
logger = logging.getLogger(__name__)
SIGN_IN_MIN_POINTS = 1
SIGN_IN_MAX_POINTS = 100
SIGN_IN_SOURCE = "gacha_sign"
SIGN_IN_REASON = "抽卡签到"
SIGN_IN_REASON = "抽卡签到"
@get_driver().on_startup
async def load_gacha_shikigami_data() -> None:
"""启动时从 xapi 拉取式神基础数据缓存。"""
try:
await gacha_system.data_manager.refresh_shikigami_data()
except Exception:
logger.exception("启动拉取抽卡式神缓存失败")
# 检查是否允许使用功能的规则
def check_permission() -> Rule:
@@ -52,14 +62,11 @@ def check_permission() -> Rule:
return Rule(_checker)
async def try_handle_daily_sign_in(matcher, user_id: str, user_name: str) -> None:
"""处理抽卡成功后的每日签到,不影响主流程"""
try:
if gacha_system.data_manager.has_signed_in_today(user_id):
return
points = random.randint(SIGN_IN_MIN_POINTS, SIGN_IN_MAX_POINTS)
success, new_balance = await points_api.add_points(
async def try_handle_daily_sign_in(matcher, user_id: str, user_name: str) -> None:
"""处理抽卡成功后的每日签到,不影响主流程"""
try:
points = random.randint(SIGN_IN_MIN_POINTS, SIGN_IN_MAX_POINTS)
success, new_balance = await points_api.add_points(
user_id,
points,
SIGN_IN_SOURCE,
@@ -69,13 +76,23 @@ async def try_handle_daily_sign_in(matcher, user_id: str, user_name: str) -> Non
logger.error("抽卡签到积分发放失败 user_id=%s points=%s", user_id, points)
return
if not gacha_system.data_manager.record_sign_in(user_id, points):
logger.warning("抽卡签到落库冲突,积分已发放但签到记录重复 user_id=%s", user_id)
return
await matcher.send(format_sign_in_message(user_id, user_name, points, new_balance))
except Exception:
logger.exception("处理抽卡签到失败 user_id=%s", user_id)
if not await gacha_system.data_manager.record_sign_in(user_id, points):
logger.warning("抽卡签到落库冲突,积分已发放但签到记录重复 user_id=%s", user_id)
return
await matcher.send(format_sign_in_message(user_id, user_name, points, new_balance))
except Exception:
logger.exception("处理抽卡签到失败 user_id=%s", user_id)
async def claim_achievement_after_reward(user_id: str, achievement_id: str, reward_success: bool) -> None:
"""成就奖励发放成功后标记 xapi reward_claimed。"""
if not reward_success:
return
claimed = await gacha_system.data_manager.claim_achievement_reward(user_id, achievement_id)
if not claimed:
logger.warning("成就奖励已发放但 claim 标记失败 user_id=%s achievement_id=%s", user_id, achievement_id)
# 注册抽卡命令,添加权限检查规则
gacha_matcher = on_command("抽卡", aliases=set(GACHA_COMMANDS), priority=10, rule=check_permission())
@@ -86,7 +103,7 @@ async def handle_gacha(bot: Bot, event: MessageEvent, state: T_State):
user_name = event.sender.card if isinstance(event, GroupMessageEvent) else event.sender.nickname
# 执行抽卡
result = gacha_system.draw(user_id)
result = await gacha_system.draw(user_id)
if not result["success"]:
await gacha_matcher.finish(format_user_mention(user_id, user_name) + "" + result["message"])
@@ -127,8 +144,9 @@ async def handle_gacha(bot: Bot, event: MessageEvent, state: T_State):
has_manual_rewards = False
for achievement_id in unlocked_achievements:
# 尝试自动发放成就奖励
auto_success, reward_msg = await process_achievement_reward(user_id, achievement_id)
# 尝试自动发放成就奖励
auto_success, reward_msg = await process_achievement_reward(user_id, achievement_id)
await claim_achievement_after_reward(user_id, achievement_id, auto_success)
# 检查是否是重复奖励
if "_repeat_" in achievement_id:
@@ -161,7 +179,7 @@ async def handle_gacha(bot: Bot, event: MessageEvent, state: T_State):
msg.append("💰 未自动发放的奖励请联系管理员\n")
# 添加成就进度提示
achievement_data = gacha_system.get_user_achievements(user_id)
achievement_data = await gacha_system.get_user_achievements(user_id)
if achievement_data["success"]:
progress = achievement_data["progress"]
consecutive_days = progress.get("consecutive_days", 0)
@@ -226,13 +244,13 @@ async def handle_gacha(bot: Bot, event: MessageEvent, state: T_State):
await try_handle_daily_sign_in(gacha_matcher, user_id, user_name)
return
async def notify_admin(bot: Bot, message: str):
"""通知管理员"""
admin_id = 2185330092
try:
await bot.send_private_msg(user_id=admin_id, message=message)
except Exception as e:
pass # 忽略通知失败的错误
async def notify_admin(bot: Bot, message: str):
"""通知管理员"""
admin_id = 2185330092
try:
await bot.send_private_msg(user_id=admin_id, message=message)
except Exception as e:
logger.debug("通知管理员失败: %s", e)
# 注册查询命令,添加权限检查规则
stats_matcher = on_command("我的抽卡", aliases=set(STATS_COMMANDS), priority=5, rule=check_permission())
@@ -252,7 +270,7 @@ async def handle_stats(bot: Bot, event: MessageEvent, state: T_State):
user_name = event.sender.card if isinstance(event, GroupMessageEvent) else event.sender.nickname
# 获取用户统计
stats = gacha_system.get_user_stats(user_id)
stats = await gacha_system.get_user_stats(user_id)
if not stats["success"]:
await stats_matcher.finish(format_user_mention(user_id, user_name) + " " + stats["message"])
@@ -294,7 +312,7 @@ async def handle_triple_gacha(bot: Bot, event: MessageEvent, state: T_State):
user_name = event.sender.card or event.sender.nickname or "未知用户"
# 执行三连抽
result = gacha_system.triple_draw(user_id)
result = await gacha_system.triple_draw(user_id)
if not result["success"]:
await triple_gacha_matcher.finish(f"{result['message']}")
@@ -338,8 +356,9 @@ async def handle_triple_gacha(bot: Bot, event: MessageEvent, state: T_State):
has_manual_rewards = False
for achievement_id in unlocked_achievements:
# 尝试自动发放成就奖励
auto_success, reward_msg = await process_achievement_reward(user_id, achievement_id)
# 尝试自动发放成就奖励
auto_success, reward_msg = await process_achievement_reward(user_id, achievement_id)
await claim_achievement_after_reward(user_id, achievement_id, auto_success)
# 检查是否是重复奖励
if "_repeat_" in achievement_id:
@@ -372,7 +391,7 @@ async def handle_triple_gacha(bot: Bot, event: MessageEvent, state: T_State):
msg.append("💰 未自动发放的奖励请联系管理员\n")
# 添加成就进度提示
achievement_data = gacha_system.get_user_achievements(user_id)
achievement_data = await gacha_system.get_user_achievements(user_id)
if achievement_data["success"]:
progress = achievement_data["progress"]
consecutive_days = progress.get("consecutive_days", 0)
@@ -442,7 +461,7 @@ async def handle_achievement(bot: Bot, event: MessageEvent, state: T_State):
user_name = event.sender.card or event.sender.nickname or "未知用户"
# 获取用户成就信息
result = gacha_system.get_user_achievements(user_id)
result = await gacha_system.get_user_achievements(user_id)
if not result["success"]:
await achievement_matcher.finish(f"{result['message']}")
@@ -561,7 +580,7 @@ async def handle_query(bot: Bot, event: MessageEvent, state: T_State):
target_user_name = event.sender.card if isinstance(event, GroupMessageEvent) else event.sender.nickname
# 获取用户统计
stats = gacha_system.get_user_stats(target_user_id)
stats = await gacha_system.get_user_stats(target_user_id)
# 构建响应消息
msg = Message()
@@ -623,7 +642,7 @@ rank_matcher = on_startswith(("抽卡排行","抽卡榜"), priority=1, rule=chec
@rank_matcher.handle()
async def handle_rank(bot: Bot, event: MessageEvent, state: T_State):
# 获取排行榜数据
rank_data = gacha_system.get_rank_list()
rank_data = await gacha_system.get_rank_list()
if not rank_data:
await rank_matcher.finish("暂无抽卡排行榜数据")
@@ -668,7 +687,7 @@ async def handle_rank(bot: Bot, event: MessageEvent, state: T_State):
@daily_stats_matcher.handle()
async def handle_daily_stats(bot: Bot, event: MessageEvent, state: T_State):
"""处理今日抽卡统计命令"""
result = gacha_system.get_daily_stats()
result = await gacha_system.get_daily_stats()
if not result["success"]:
await daily_stats_matcher.finish(f"{result['message']}")

View File

@@ -1,7 +1,7 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator
import os
import logging
from pydantic import field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
import os
import logging
logger = logging.getLogger("onmyoji_gacha")
@@ -115,13 +115,24 @@ class Config(BaseSettings):
WEB_ADMIN_TOKEN: str = os.getenv("WEB_ADMIN_TOKEN", "onmyoji_admin_token_2024")
WEB_ADMIN_PORT: int = int(os.getenv("WEB_ADMIN_PORT", "8080"))
# 蛋定服务器对接配置
DD_API_HOST: str = "https://api.danding.vip/DD/"
BOT_TOKEN: str = os.getenv("ONMYOJI_BOT_TOKEN", os.getenv("BOT_TOKEN", "")) # 必须设置
BOT_USER_ID: str = "1424473282"
# 蛋定服务器对接配置
DD_API_HOST: str = "https://api.danding.vip/DD/"
GACHA_API_HOST: str = os.getenv("DANDING_GACHA_API_HOST", "https://api.danding.vip/bot/gacha")
BOT_TOKEN: str = os.getenv(
"DANDING_BOT_TOKEN",
os.getenv("ONMYOJI_BOT_TOKEN", os.getenv("DANDING_API_TOKEN", os.getenv("BOT_TOKEN", ""))),
)
BOT_USER_ID: str = os.getenv("DANDING_BOT_USER", "1424473282")
# 时区
TIMEZONE: str = "Asia/Shanghai"
TIMEZONE: str = "Asia/Shanghai"
@field_validator("GACHA_API_HOST")
@classmethod
def validate_gacha_api_host(cls, value):
if not value:
raise ValueError("GACHA_API_HOST cannot be empty")
return value.rstrip("/")
@model_validator(mode="after")
def _warn_default_token(self):

View File

@@ -1,615 +1,291 @@
"""
阴阳师抽卡插件 - 数据管理模块
管理抽卡数据持久化,包括:
- SQLite数据库操作
- 用户抽卡记录管理
- 每日签到记录
- 统计查询
TODO(代码评审 2026-05-03): 本模块承担了数据文件IO + 缓存 + 业务规则三重职责,
后续应拆分为: data_io(纯文件读写) / data_cache(内存缓存层) / data_rules(业务规则校验)。
当前拆分风险较大(影响面广),暂维持现状。
TODO(第二轮评审 2026-05-03): 补充建议拆分方案:
- achievement_manager.py: 成就定义加载 + 进度计算 + 奖励发放 (~150行)
- record_manager.py: 记录归档 + 统计查询 + 每日数据 (~100行)
- data_manager.py: 核心用户数据IO + 缓存管理 (~359行)
拆分为独立PR不阻塞当前修复。
"""
import os
import json
import sqlite3
import datetime
from typing import Dict, List, Any, Optional
import logging
from pathlib import Path
from .config import Config
# 创建Config实例
config = Config()
class DataManager:
"""抽卡数据管理器,封装所有数据库操作"""
def __init__(self):
# 确保目录存在
os.makedirs(os.path.dirname(config.DB_FILE), exist_ok=True)
# 初始化数据库
self._init_db()
# 加载式神数据
self.shikigami_data = self._load_shikigami_data()
def _init_db(self):
"""初始化数据库"""
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
# 创建式神表
cursor.execute("""
CREATE TABLE IF NOT EXISTS shikigami (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
rarity TEXT NOT NULL,
image_path TEXT NOT NULL
)
""")
# 创建每日抽卡记录表
cursor.execute("""
CREATE TABLE IF NOT EXISTS daily_draws (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT NOT NULL,
user_id TEXT NOT NULL,
rarity TEXT NOT NULL,
shikigami_id INTEGER NOT NULL,
timestamp TEXT NOT NULL,
FOREIGN KEY (shikigami_id) REFERENCES shikigami(id)
)
""")
# 创建用户统计表
cursor.execute("""
CREATE TABLE IF NOT EXISTS user_stats (
user_id TEXT PRIMARY KEY,
total_draws INTEGER DEFAULT 0,
R_count INTEGER DEFAULT 0,
SR_count INTEGER DEFAULT 0,
SSR_count INTEGER DEFAULT 0,
SP_count INTEGER DEFAULT 0
)
""")
# 创建抽卡历史表
cursor.execute("""
CREATE TABLE IF NOT EXISTS draw_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
date TEXT NOT NULL,
rarity TEXT NOT NULL,
shikigami_id INTEGER NOT NULL,
FOREIGN KEY (user_id) REFERENCES user_stats(user_id),
FOREIGN KEY (shikigami_id) REFERENCES shikigami(id)
)
""")
# 创建成就表
cursor.execute("""
CREATE TABLE IF NOT EXISTS achievements (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
achievement_id TEXT NOT NULL,
unlocked_date TEXT NOT NULL,
reward_claimed INTEGER DEFAULT 0,
UNIQUE(user_id, achievement_id)
)
""")
# 创建用户成就进度表
cursor.execute("""
CREATE TABLE IF NOT EXISTS user_achievement_progress (
user_id TEXT PRIMARY KEY,
consecutive_days INTEGER DEFAULT 0,
last_draw_date TEXT DEFAULT '',
no_ssr_streak INTEGER DEFAULT 0,
total_consecutive_days INTEGER DEFAULT 0
)
""")
self._init_sign_in_table(cursor)
conn.commit()
def _init_sign_in_table(self, cursor: sqlite3.Cursor) -> None: # OK
"""创建每日签到表"""
cursor.execute("""
CREATE TABLE IF NOT EXISTS daily_sign_in (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
sign_date TEXT NOT NULL,
points_awarded INTEGER NOT NULL,
created_at TEXT NOT NULL,
UNIQUE(user_id, sign_date)
)
""")
def update_achievement_progress(self, user_id: str, rarity: str) -> List[str]: # type: ignore[return]
"""更新用户成就进度,返回新解锁的成就列表"""
today = self.get_today_date()
unlocked_achievements = []
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
# 获取或创建用户成就进度
cursor.execute(
"SELECT * FROM user_achievement_progress WHERE user_id = ?",
(user_id,)
)
progress = cursor.fetchone()
if not progress:
cursor.execute(
"INSERT INTO user_achievement_progress (user_id, last_draw_date) VALUES (?, ?)",
(user_id, today)
)
consecutive_days = 1
no_ssr_streak = 1 if rarity not in ["SSR", "SP"] else 0
total_consecutive_days = 1
else:
last_draw_date = progress[2]
consecutive_days = progress[1]
no_ssr_streak = progress[3]
total_consecutive_days = progress[4]
# 更新连续抽卡天数
if last_draw_date != today:
# 检查是否是连续的一天
last_date = datetime.datetime.strptime(last_draw_date, "%Y-%m-%d")
current_date = datetime.datetime.strptime(today, "%Y-%m-%d")
days_diff = (current_date - last_date).days
if days_diff == 1:
consecutive_days += 1
total_consecutive_days += 1
elif days_diff > 1:
consecutive_days = 1
total_consecutive_days += 1
# days_diff == 0 表示今天已经抽过卡了,不更新连续天数
# 更新无SSR连击数
if rarity in ["SSR", "SP"]:
no_ssr_streak = 0
else:
no_ssr_streak += 1
# 更新进度
cursor.execute("""
INSERT OR REPLACE INTO user_achievement_progress
(user_id, consecutive_days, last_draw_date, no_ssr_streak, total_consecutive_days)
VALUES (?, ?, ?, ?, ?)
""", (user_id, consecutive_days, today, no_ssr_streak, total_consecutive_days))
# 检查是否解锁新成就
for achievement_id, achievement_config in config.ACHIEVEMENTS.items():
# 对于可重复获得的成就(勤勤恳恳系列),需要特殊处理
if achievement_config.get("repeatable", False) and achievement_config["type"] == "consecutive_days":
# 检查连续抽卡成就的升级逻辑
if consecutive_days >= achievement_config["threshold"]:
# 检查是否已经解锁过这个等级
cursor.execute(
"SELECT id FROM achievements WHERE user_id = ? AND achievement_id = ?",
(user_id, achievement_id)
)
if not cursor.fetchone():
# 解锁新等级的成就
cursor.execute("""
INSERT INTO achievements (user_id, achievement_id, unlocked_date)
VALUES (?, ?, ?)
""", (user_id, achievement_id, today))
unlocked_achievements.append(achievement_id)
# 如果是最高等级(Ⅴ),检查是否需要给重复奖励
elif achievement_config["level"] == 5 and consecutive_days >= 150:
# 每30天给一次重复奖励
days_over_150 = consecutive_days - 150
if days_over_150 > 0 and days_over_150 % 30 == 0:
# 检查这个重复奖励是否已经给过
repeat_id = f"{achievement_id}_repeat_{days_over_150//30}"
cursor.execute(
"SELECT id FROM achievements WHERE user_id = ? AND achievement_id = ?",
(user_id, repeat_id)
)
if not cursor.fetchone():
cursor.execute("""
INSERT INTO achievements (user_id, achievement_id, unlocked_date)
VALUES (?, ?, ?)
""", (user_id, repeat_id, today))
unlocked_achievements.append(achievement_id)
else:
# 非重复成就的原有逻辑
# 检查是否已经解锁
cursor.execute(
"SELECT id FROM achievements WHERE user_id = ? AND achievement_id = ?",
(user_id, achievement_id)
)
if cursor.fetchone():
continue
# 检查成就条件
unlocked = False
if achievement_config["type"] == "consecutive_days":
if consecutive_days >= achievement_config["threshold"]:
unlocked = True
elif achievement_config["type"] == "no_ssr_streak":
if no_ssr_streak >= achievement_config["threshold"]:
unlocked = True
if unlocked:
cursor.execute("""
INSERT INTO achievements (user_id, achievement_id, unlocked_date)
VALUES (?, ?, ?)
""", (user_id, achievement_id, today))
unlocked_achievements.append(achievement_id)
conn.commit()
return unlocked_achievements
def get_user_achievements(self, user_id: str) -> Dict[str, Any]:
"""获取用户成就信息"""
with sqlite3.connect(config.DB_FILE) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# 获取已解锁的成就
cursor.execute(
"SELECT achievement_id, unlocked_date, reward_claimed FROM achievements WHERE user_id = ?",
(user_id,)
)
unlocked = {row["achievement_id"]: {
"unlocked_date": row["unlocked_date"],
"reward_claimed": bool(row["reward_claimed"])
} for row in cursor.fetchall()}
# 获取进度
cursor.execute(
"SELECT * FROM user_achievement_progress WHERE user_id = ?",
(user_id,)
)
progress_row = cursor.fetchone()
if not progress_row:
progress = {
"consecutive_days": 0,
"no_ssr_streak": 0,
"total_consecutive_days": 0
}
else:
progress = {
"consecutive_days": progress_row["consecutive_days"],
"no_ssr_streak": progress_row["no_ssr_streak"],
"total_consecutive_days": progress_row["total_consecutive_days"]
}
return {
"unlocked": unlocked,
"progress": progress
}
def claim_achievement_reward(self, user_id: str, achievement_id: str) -> bool:
"""领取成就奖励"""
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
cursor.execute("""
UPDATE achievements
SET reward_claimed = 1
WHERE user_id = ? AND achievement_id = ? AND reward_claimed = 0
""", (user_id, achievement_id))
conn.commit()
return cursor.rowcount > 0
def _load_shikigami_data(self) -> Dict[str, List[Dict[str, str]]]:
"""加载式神数据到数据库"""
result = {"R": [], "SR": [], "SSR": [], "SP": []}
rarity_dirs = {
"R": "r",
"SR": "sr",
"SSR": "ssr",
"SP": "sp"
}
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
# 清空现有式神数据
cursor.execute("DELETE FROM shikigami")
for rarity, dir_name in rarity_dirs.items():
dir_path = os.path.join(config.SHIKIGAMI_IMG_DIR, dir_name)
if os.path.exists(dir_path):
for file_name in os.listdir(dir_path):
if file_name.endswith(('.png', '.jpg', '.jpeg')):
name = os.path.splitext(file_name)[0]
image_path = os.path.join(dir_path, file_name)
# 插入式神数据
cursor.execute(
"INSERT INTO shikigami (name, rarity, image_path) VALUES (?, ?, ?)",
(name, rarity, image_path)
)
result[rarity].append({
"name": name,
"image_url": image_path
})
conn.commit()
return result
def get_today_date(self) -> str:
"""获取当前日期字符串"""
return datetime.datetime.now().strftime("%Y-%m-%d")
def has_signed_in_today(self, user_id: str) -> bool:
"""检查用户今天是否已签到"""
today = self.get_today_date()
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT 1 FROM daily_sign_in WHERE user_id = ? AND sign_date = ? LIMIT 1",
(user_id, today),
)
return cursor.fetchone() is not None
def record_sign_in(self, user_id: str, points_awarded: int) -> bool:
"""记录每日签到重复签到返回False"""
today = self.get_today_date()
created_at = datetime.datetime.now().isoformat()
try:
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO daily_sign_in (user_id, sign_date, points_awarded, created_at)
VALUES (?, ?, ?, ?)
""", (user_id, today, points_awarded, created_at))
conn.commit()
return True
except sqlite3.IntegrityError:
return False
def get_current_time(self) -> str:
"""获取当前时间字符串"""
return datetime.datetime.now().strftime("%H:%M:%S")
def get_daily_draws(self) -> Dict[str, Dict[str, List[Dict[str, str]]]]:
"""获取每日抽卡记录"""
result = {}
today = self.get_today_date()
with sqlite3.connect(config.DB_FILE) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# 先查询今日的抽卡记录
cursor.execute("""
SELECT date, user_id, rarity, shikigami_id, timestamp
FROM daily_draws
WHERE date = ?
ORDER BY timestamp
""", (today,))
rows = cursor.fetchall()
# 获取所有涉及的式神ID
shikigami_ids = list(set(row["shikigami_id"] for row in rows))
# 查询式神信息
shikigami_info = {}
if shikigami_ids:
placeholders = ','.join('?' * len(shikigami_ids))
cursor.execute(f"""
SELECT id, name, rarity
FROM shikigami
WHERE id IN ({placeholders})
""", shikigami_ids)
for shikigami_row in cursor.fetchall():
shikigami_info[shikigami_row["id"]] = {
"name": shikigami_row["name"],
"rarity": shikigami_row["rarity"]
}
# 构建结果
for row in rows:
date = row["date"]
user_id = row["user_id"]
shikigami_id = row["shikigami_id"]
if date not in result:
result[date] = {}
if user_id not in result[date]:
result[date][user_id] = []
# 如果找不到式神信息使用daily_draws表中的稀有度和默认名称
if shikigami_id in shikigami_info:
name = shikigami_info[shikigami_id]["name"]
rarity = shikigami_info[shikigami_id]["rarity"]
else:
name = f"式神{shikigami_id}"
rarity = row["rarity"]
result[date][user_id].append({
"rarity": rarity,
"name": name,
"timestamp": row["timestamp"]
})
return result
def save_daily_draws(self, data: Dict[str, Dict[str, List[Dict[str, str]]]]):
"""保存每日抽卡记录"""
# SQLite实现中此方法为空因为记录时直接插入数据库
pass
def get_user_stats(self) -> Dict[str, Dict[str, Any]]:
"""获取用户统计数据"""
result = {}
with sqlite3.connect(config.DB_FILE) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# 获取基础统计
cursor.execute("SELECT * FROM user_stats")
user_stats = cursor.fetchall()
for stat in user_stats:
user_id = stat["user_id"]
result[user_id] = {
"total_draws": stat["total_draws"],
"R_count": stat["R_count"],
"SR_count": stat["SR_count"],
"SSR_count": stat["SSR_count"],
"SP_count": stat["SP_count"],
"draw_history": []
}
# 获取抽卡历史
cursor.execute("""
SELECT draw_history.date, draw_history.rarity, shikigami.name
FROM draw_history
JOIN shikigami ON draw_history.shikigami_id = shikigami.id
WHERE draw_history.user_id = ?
ORDER BY draw_history.date DESC
LIMIT 100
""", (user_id,))
history = cursor.fetchall()
result[user_id]["draw_history"] = [
{
"date": row["date"],
"rarity": row["rarity"],
"name": row["name"]
} for row in history
]
return result
def save_user_stats(self, data: Dict[str, Dict[str, Any]]):
"""保存用户统计数据"""
# SQLite实现中此方法为空因为统计时直接更新数据库
pass
def check_daily_limit(self, user_id: str) -> bool:
"""检查用户是否达到每日抽卡限制"""
today = self.get_today_date()
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT COUNT(*)
FROM daily_draws
WHERE date = ? AND user_id = ?
""", (today, user_id))
count = cursor.fetchone()[0]
return count < config.DAILY_LIMIT
def get_draws_left(self, user_id: str) -> int:
"""获取用户今日剩余抽卡次数"""
today = self.get_today_date()
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT COUNT(*)
FROM daily_draws
WHERE date = ? AND user_id = ?
""", (today, user_id))
count = cursor.fetchone()[0]
return max(0, config.DAILY_LIMIT - count)
def record_draw(self, user_id: str, rarity: str, shikigami_name: str) -> List[str]:
"""记录一次抽卡,返回新解锁的成就列表"""
today = self.get_today_date()
current_time = self.get_current_time()
with sqlite3.connect(config.DB_FILE) as conn:
cursor = conn.cursor()
# 获取式神ID
cursor.execute(
"SELECT id FROM shikigami WHERE name = ? AND rarity = ?",
(shikigami_name, rarity)
)
shikigami_id = cursor.fetchone()
if not shikigami_id:
logging.error(f"找不到式神: {shikigami_name} ({rarity})")
return []
shikigami_id = shikigami_id[0]
# 记录每日抽卡
cursor.execute("""
INSERT INTO daily_draws (date, user_id, rarity, shikigami_id, timestamp)
VALUES (?, ?, ?, ?, ?)
""", (today, user_id, rarity, shikigami_id, current_time))
# 更新用户统计
cursor.execute("""
INSERT OR IGNORE INTO user_stats (user_id) VALUES (?)
""", (user_id,))
cursor.execute("""
UPDATE user_stats
SET total_draws = total_draws + 1,
R_count = R_count + ?,
SR_count = SR_count + ?,
SSR_count = SSR_count + ?,
SP_count = SP_count + ?
WHERE user_id = ?
""", (
1 if rarity == "R" else 0,
1 if rarity == "SR" else 0,
1 if rarity == "SSR" else 0,
1 if rarity == "SP" else 0,
user_id
))
# 添加抽卡历史
cursor.execute("""
INSERT INTO draw_history (user_id, date, rarity, shikigami_id)
VALUES (?, ?, ?, ?)
""", (user_id, today, rarity, shikigami_id))
# 保持历史记录不超过100条
cursor.execute("""
DELETE FROM draw_history
WHERE user_id = ? AND id NOT IN (
SELECT id FROM draw_history
WHERE user_id = ?
ORDER BY date DESC
LIMIT 100
)
""", (user_id, user_id))
conn.commit()
# 更新成就进度
unlocked_achievements = self.update_achievement_progress(user_id, rarity)
return unlocked_achievements
"""
阴阳师抽卡插件 - xapi 数据管理模块
本模块只负责调用 xapi /bot/gacha 运行时 API。抽卡概率、奖励发放和 QQ 消息编排
仍由 nonebot 插件本地负责。
"""
from __future__ import annotations
import asyncio
import datetime
import logging
from typing import Any, Dict, List, Optional
import aiohttp
from .config import Config
logger = logging.getLogger(__name__)
config = Config()
class DataManager:
"""抽卡数据管理器,封装 /bot/gacha HTTP 调用。"""
def __init__(self):
self.shikigami_data: Dict[str, List[Dict[str, Any]]] = {"R": [], "SR": [], "SSR": [], "SP": []}
def _url(self, path: str) -> str:
"""拼接 /bot/gacha 端点地址。"""
return f"{config.GACHA_API_HOST}/{path.lstrip('/')}"
def _auth(self) -> Dict[str, str]:
"""生成 xapi Bot 鉴权参数。"""
return {
"user": config.BOT_USER_ID,
"token": config.BOT_TOKEN,
}
async def _request(
self,
method: str,
path: str,
*,
payload: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""调用 xapi /bot/gacha并只向上层暴露 data。"""
request_url = self._url(path)
timeout = aiohttp.ClientTimeout(total=10)
try:
async with aiohttp.ClientSession() as session:
if method == "GET":
request_params = {**self._auth(), **(params or {})}
async with session.get(request_url, params=request_params, timeout=timeout) as resp:
return await self._parse_response(resp, path)
request_payload = {**self._auth(), **(payload or {})}
async with session.post(request_url, json=request_payload, timeout=timeout) as resp:
return await self._parse_response(resp, path)
except aiohttp.ClientError as exc:
logger.error("gacha api request failed path=%s error=%s", path, exc)
return None
except asyncio.TimeoutError as exc:
logger.error("gacha api request timeout path=%s error=%s", path, exc)
return None
async def _parse_response(self, resp: aiohttp.ClientResponse, path: str) -> Optional[Dict[str, Any]]:
"""解析 xapi 统一响应,失败时返回 None 维持旧调用方失败语义。"""
if resp.status != 200:
logger.error("gacha api bad status path=%s status=%s", path, resp.status)
return None
body = await resp.json()
if body.get("code") != 200:
logger.error("gacha api fail path=%s code=%s message=%s", path, body.get("code"), body.get("message"))
return None
data = body.get("data")
return data if isinstance(data, dict) else None
async def refresh_shikigami_data(self) -> Dict[str, List[Dict[str, Any]]]:
"""从 xapi 拉取式神基础数据并按稀有度缓存。"""
data = await self._request("GET", "shikigami")
items = data.get("items", []) if data else []
grouped: Dict[str, List[Dict[str, Any]]] = {"R": [], "SR": [], "SSR": [], "SP": []}
for item in items:
rarity = item.get("rarity")
if rarity not in grouped:
continue
image_path = item.get("image_path") or item.get("image_url") or ""
grouped[rarity].append(
{
"id": item.get("id"),
"name": item.get("name"),
"rarity": rarity,
"image_path": image_path,
"image_url": image_path,
}
)
self.shikigami_data = grouped
return self.shikigami_data
async def ensure_shikigami_data(self) -> Dict[str, List[Dict[str, Any]]]:
"""确保式神缓存已加载。"""
if not any(self.shikigami_data.values()):
await self.refresh_shikigami_data()
return self.shikigami_data
def get_today_date(self) -> str:
"""获取当前日期字符串。"""
return datetime.datetime.now().strftime("%Y-%m-%d")
def get_current_time(self) -> str:
"""获取当前时间字符串。"""
return datetime.datetime.now().strftime("%H:%M:%S")
def _find_shikigami(self, rarity: str, shikigami_name: str) -> Optional[Dict[str, Any]]:
"""从本地缓存查找 xapi 托管式神。"""
for item in self.shikigami_data.get(rarity, []):
if item.get("name") == shikigami_name:
return item
return None
async def get_draws_left(self, user_id: str) -> int:
"""获取用户今日剩余抽卡次数。"""
data = await self._request("GET", "draws-left", params={"user_id": user_id})
if data is None:
return 0
return int(data.get("draws_left", 0) or 0)
async def check_daily_limit(self, user_id: str) -> bool:
"""检查用户是否还有抽卡次数。"""
return await self.get_draws_left(user_id) > 0
async def record_draw_result(self, user_id: str, rarity: str, shikigami: Dict[str, Any]) -> Dict[str, Any]:
"""写入一次抽卡并返回 xapi 原始业务结果。"""
data = await self._request(
"POST",
"draw",
payload={
"user_id": user_id,
"shikigami_id": int(shikigami["id"]),
"rarity": rarity,
"name": shikigami["name"],
},
)
if data is None:
return {"success": False, "message": "抽卡记录写入失败"}
return data
async def record_triple_draw_result(self, user_id: str, draws: List[Dict[str, Any]]) -> Dict[str, Any]:
"""写入三连抽并返回 xapi 原始业务结果。"""
payload_draws = [
{
"shikigami_id": int(item["id"]),
"rarity": item["rarity"],
"name": item["name"],
}
for item in draws
]
data = await self._request("POST", "draw/triple", payload={"user_id": user_id, "draws": payload_draws})
if data is None:
return {"success": False, "message": "三连抽记录写入失败"}
return data
async def record_draw(self, user_id: str, rarity: str, shikigami_name: str) -> List[str]:
"""记录一次抽卡,返回新解锁的成就列表。"""
await self.ensure_shikigami_data()
shikigami = self._find_shikigami(rarity, shikigami_name)
if not shikigami:
logger.error("找不到式神: %s (%s)", shikigami_name, rarity)
return []
result = await self.record_draw_result(user_id, rarity, shikigami)
if not result.get("success"):
logger.error("抽卡记录写入失败 user_id=%s message=%s", user_id, result.get("message"))
return []
return result.get("unlocked_achievements", [])
async def record_sign_in(self, user_id: str, points_awarded: int) -> bool:
"""记录每日签到,重复签到返回 False。"""
data = await self._request(
"POST",
"sign-in",
payload={"user_id": user_id, "points_awarded": points_awarded},
)
if data is None:
return False
return bool(data.get("success")) and not bool(data.get("signed_already"))
async def get_user_stats(self, user_id: str) -> Dict[str, Any]:
"""获取用户抽卡统计。"""
data = await self._request("GET", "user-stats", params={"user_id": user_id})
return data or {"success": False, "message": "您还没有抽卡记录哦!"}
async def get_daily_stats(self, date: Optional[str] = None) -> Dict[str, Any]:
"""获取指定日期抽卡统计。"""
params = {"date": date} if date else {}
data = await self._request("GET", "daily-stats", params=params)
return data or {"success": False, "message": "今日还没有人抽卡哦!"}
async def get_rank(self, limit: int = 10) -> List[Dict[str, Any]]:
"""获取抽卡排行榜。"""
data = await self._request("GET", "rank", params={"limit": max(1, min(100, limit))})
if data is None:
return []
items = data.get("items", [])
return items if isinstance(items, list) else []
async def get_user_achievements(self, user_id: str) -> Dict[str, Any]:
"""获取用户成就信息。"""
data = await self._request("GET", f"achievements/{user_id}")
if data is None:
return {
"unlocked": {},
"progress": {
"consecutive_days": 0,
"no_ssr_streak": 0,
"total_consecutive_days": 0,
},
}
return {
"unlocked": data.get("achievements", {}),
"progress": data.get("progress", {}),
}
async def claim_achievement_reward(self, user_id: str, achievement_id: str) -> bool:
"""标记成就奖励已领取。"""
data = await self._request("POST", f"achievements/{user_id}/claim", payload={"achievement_id": achievement_id})
return bool(data and data.get("success"))
async def get_daily_records(self, date: Optional[str] = None) -> Dict[str, Any]:
"""获取每日详细抽卡记录。"""
params = {"date": date} if date else {}
data = await self._request("GET", "records/daily", params=params)
return data or {"success": False, "date": date or self.get_today_date(), "records": [], "total_count": 0}
async def get_daily_draws(self, date: Optional[str] = None) -> Dict[str, Dict[str, List[Dict[str, str]]]]:
"""按旧结构返回每日抽卡记录。"""
data = await self.get_daily_records(date)
result: Dict[str, Dict[str, List[Dict[str, str]]]] = {}
if not data.get("success"):
return result
target_date = data.get("date") or date or self.get_today_date()
result[target_date] = {}
for record in data.get("records", []):
user_id = record.get("user_id")
if not user_id:
continue
result[target_date].setdefault(user_id, []).append(
{
"rarity": record.get("rarity", ""),
"name": record.get("shikigami_name", ""),
"timestamp": record.get("draw_time", ""),
}
)
return result
async def has_signed_in_today(self, user_id: str) -> bool:
"""保留旧方法名;当前无独立查询端点,签到去重由 xapi sign-in 写接口处理。"""
return False
def save_daily_draws(self, data: Dict[str, Dict[str, List[Dict[str, str]]]]) -> None:
"""兼容旧空方法,运行时不写本地文件。"""
return None
def save_user_stats(self, data: Dict[str, Dict[str, Any]]) -> None:
"""兼容旧空方法,运行时不写本地文件。"""
return None

View File

@@ -24,44 +24,47 @@ class GachaSystem:
def __init__(self):
self.data_manager = data_manager
def draw(self, user_id: str) -> Dict[str, Any]:
"""执行一次抽卡"""
# 检查抽卡限制
if not self.data_manager.check_daily_limit(user_id):
draws_left = self.data_manager.get_draws_left(user_id)
return {
"success": False,
"message": f"您今日的抽卡次数已用完,每日限制{config.DAILY_LIMIT}次,明天再来吧!"
}
async def draw(self, user_id: str) -> Dict[str, Any]:
"""执行一次抽卡"""
# 检查抽卡限制
if not await self.data_manager.check_daily_limit(user_id):
draws_left = await self.data_manager.get_draws_left(user_id)
return {
"success": False,
"message": f"您今日的抽卡次数已用完,每日限制{config.DAILY_LIMIT}次,明天再来吧!"
}
# 抽取稀有度传递用户ID
rarity = self._draw_rarity(user_id)
# 从该稀有度中抽取式神
shikigami_data = self.data_manager.shikigami_data.get(rarity, [])
if not shikigami_data:
return {
"success": False,
"message": f"系统错误:{rarity}稀有度下没有可用式神"
}
await self.data_manager.ensure_shikigami_data()
shikigami_data = self.data_manager.shikigami_data.get(rarity, [])
if not shikigami_data:
return {
"success": False,
"message": f"系统错误:{rarity}稀有度下没有可用式神"
}
# 随机选择式神
shikigami = random.choice(shikigami_data)
# 记录抽卡
unlocked_achievements = self.data_manager.record_draw(user_id, rarity, shikigami["name"])
# 剩余次数
draws_left = self.data_manager.get_draws_left(user_id)
# xapi 只负责写入抽卡数据;奖励副作用仍由 nonebot handler 编排。
record_result = await self.data_manager.record_draw_result(user_id, rarity, shikigami)
if not record_result.get("success"):
return {
"success": False,
"message": record_result.get("message", "抽卡记录写入失败")
}
return {
"success": True,
"rarity": rarity,
"name": shikigami["name"],
"image_url": shikigami["image_url"],
"draws_left": draws_left,
"unlocked_achievements": unlocked_achievements
}
"success": True,
"rarity": record_result.get("rarity", rarity),
"name": record_result.get("name", shikigami["name"]),
"image_url": record_result.get("image_url") or record_result.get("image_path") or shikigami["image_url"],
"draws_left": record_result.get("draws_left", 0),
"unlocked_achievements": record_result.get("unlocked_achievements", [])
}
def _draw_rarity(self, user_id: str = None) -> str:
"""按概率抽取稀有度"""
@@ -82,145 +85,39 @@ class GachaSystem:
# 默认返回R理论上不会执行到这里
return "R"
def get_user_stats(self, user_id: str) -> Dict:
"""获取用户抽卡统计"""
user_stats = self.data_manager.get_user_stats()
if user_id not in user_stats:
return {
"success": False,
"message": "您还没有抽卡记录哦!"
}
stats = user_stats[user_id]
return {
"success": True,
"total_draws": stats["total_draws"],
"R_count": stats["R_count"],
"SR_count": stats["SR_count"],
"SSR_count": stats["SSR_count"],
"SP_count": stats["SP_count"],
"recent_draws": stats["draw_history"][-5:] if stats["draw_history"] else []
}
async def get_user_stats(self, user_id: str) -> Dict:
"""获取用户抽卡统计"""
return await self.data_manager.get_user_stats(user_id)
def get_probability_text(self) -> str:
"""获取概率展示文本"""
probs = config.RARITY_PROBABILITY
return f"--- 系统概率 ---\nR: {probs['R']}% | SR: {probs['SR']}% | SSR: {probs['SSR']}% | SP: {probs['SP']}%"
def get_rank_list(self) -> List[Tuple[str, Dict[str, int]]]:
"""获取抽卡排行榜数据"""
user_stats = self.data_manager.get_user_stats()
# 过滤有SSR/SP记录的用户
ranked_users = [
(user_id, stats)
for user_id, stats in user_stats.items()
if stats.get("SSR_count", 0) > 0 or stats.get("SP_count", 0) > 0
]
# 按SSR+SP总数降序排序
ranked_users.sort(
key=lambda x: (x[1].get("SSR_count", 0) + x[1].get("SP_count", 0)),
reverse=True
)
return ranked_users
def get_daily_stats(self) -> Dict:
"""获取今日抽卡统计"""
daily_draws = self.data_manager.get_daily_draws()
today = self.data_manager.get_today_date()
if not daily_draws or today not in daily_draws:
return {
"success": False,
"message": "今日还没有人抽卡哦!"
}
today_stats = daily_draws[today]
total_stats = {
"total_users": len(today_stats),
"total_draws": 0,
"R_count": 0,
"SR_count": 0,
"SSR_count": 0,
"SP_count": 0,
"user_stats": []
}
# 统计每个用户的抽卡情况
for user_id, draws in today_stats.items():
user_stats = {
"user_id": user_id,
"total_draws": len(draws),
"R_count": sum(1 for d in draws if d["rarity"] == "R"),
"SR_count": sum(1 for d in draws if d["rarity"] == "SR"),
"SSR_count": sum(1 for d in draws if d["rarity"] == "SSR"),
"SP_count": sum(1 for d in draws if d["rarity"] == "SP")
}
# 更新总统计
total_stats["total_draws"] += user_stats["total_draws"]
total_stats["R_count"] += user_stats["R_count"]
total_stats["SR_count"] += user_stats["SR_count"]
total_stats["SSR_count"] += user_stats["SSR_count"]
total_stats["SP_count"] += user_stats["SP_count"]
# 只记录抽到SSR或SP的用户
if user_stats["SSR_count"] > 0 or user_stats["SP_count"] > 0:
total_stats["user_stats"].append(user_stats)
# 按SSR+SP数量排序用户统计
total_stats["user_stats"].sort(
key=lambda x: (x["SSR_count"] + x["SP_count"]),
reverse=True
)
# 构建稀有度统计
rarity_stats = {
"R": total_stats["R_count"],
"SR": total_stats["SR_count"],
"SSR": total_stats["SSR_count"],
"SP": total_stats["SP_count"]
}
# 构建排行榜数据
top_users = []
for user_stat in total_stats["user_stats"]:
top_users.append({
"user_id": user_stat["user_id"],
"ssr_count": user_stat["SSR_count"] + user_stat["SP_count"]
})
final_stats = {
"total_users": total_stats["total_users"],
"total_draws": total_stats["total_draws"],
"rarity_stats": rarity_stats,
"top_users": top_users
}
return {
"success": True,
"date": today,
"stats": final_stats
}
def triple_draw(self, user_id: str) -> Dict:
"""执行三连抽"""
# 检查是否有足够的抽卡次数
draws_left = self.data_manager.get_draws_left(user_id)
if draws_left < 3:
return {
"success": False,
async def get_rank_list(self) -> List[Tuple[str, Dict[str, int]]]:
"""获取抽卡排行榜数据"""
items = await self.data_manager.get_rank(limit=10)
return [(item["user_id"], item) for item in items]
async def get_daily_stats(self) -> Dict:
"""获取今日抽卡统计"""
return await self.data_manager.get_daily_stats()
async def triple_draw(self, user_id: str) -> Dict:
"""执行三连抽"""
# 检查是否有足够的抽卡次数
draws_left = await self.data_manager.get_draws_left(user_id)
if draws_left < 3:
return {
"success": False,
"message": f"抽卡次数不足,您今日还剩{draws_left}次抽卡机会三连抽需要3次机会"
}
results = []
all_unlocked_achievements = []
# 执行三次抽卡
for i in range(3):
await self.data_manager.ensure_shikigami_data()
# 执行三次本地概率抽取,统一提交 xapi 三连写入端点。
for i in range(3):
# 抽取稀有度传递用户ID
rarity = self._draw_rarity(user_id)
@@ -235,29 +132,30 @@ class GachaSystem:
# 随机选择式神
shikigami = random.choice(shikigami_data)
# 记录抽卡
unlocked_achievements = self.data_manager.record_draw(user_id, rarity, shikigami["name"])
all_unlocked_achievements.extend(unlocked_achievements)
results.append({
"rarity": rarity,
"name": shikigami["name"],
"image_url": shikigami["image_url"]
})
# 剩余次数
draws_left = self.data_manager.get_draws_left(user_id)
return {
"success": True,
"results": results,
"draws_left": draws_left,
"unlocked_achievements": list(set(all_unlocked_achievements)) # 去重
}
def get_user_achievements(self, user_id: str) -> Dict:
"""获取用户成就信息"""
achievement_data = self.data_manager.get_user_achievements(user_id)
results.append({
"id": shikigami["id"],
"rarity": rarity,
"name": shikigami["name"],
"image_url": shikigami["image_url"]
})
record_result = await self.data_manager.record_triple_draw_result(user_id, results)
if not record_result.get("success"):
return {
"success": False,
"message": record_result.get("message", "三连抽记录写入失败")
}
return {
"success": True,
"results": record_result.get("results", results),
"draws_left": record_result.get("draws_left", 0),
"unlocked_achievements": record_result.get("unlocked_achievements", [])
}
async def get_user_achievements(self, user_id: str) -> Dict:
"""获取用户成就信息"""
achievement_data = await self.data_manager.get_user_achievements(user_id)
if not achievement_data["unlocked"] and all(v == 0 for v in achievement_data["progress"].values()):
return {
@@ -271,48 +169,6 @@ class GachaSystem:
"progress": achievement_data["progress"]
}
def get_daily_detailed_records(self, date: Optional[str] = None) -> Dict:
"""获取每日详细抽卡记录"""
if not date:
date = self.data_manager.get_today_date()
daily_draws = self.data_manager.get_daily_draws()
if not daily_draws or date not in daily_draws:
return {
"success": False,
"message": f"{date} 没有抽卡记录"
}
records = []
for user_id, draws in daily_draws[date].items():
for draw in draws:
# 检查这次抽卡是否解锁了成就
unlocked_achievements = []
draw_time = draw.get("timestamp", "未知时间")
# 获取用户成就信息
achievement_data = self.data_manager.get_user_achievements(user_id)
if achievement_data["unlocked"]:
# 检查是否有在抽卡时间之后解锁的成就
for achievement_id, achievement_info in achievement_data["unlocked"].items():
if achievement_info["unlocked_date"] == f"{date} {draw_time}":
unlocked_achievements.append(achievement_id)
records.append({
"user_id": user_id,
"draw_time": draw_time,
"shikigami_name": draw["name"],
"rarity": draw["rarity"],
"unlocked_achievements": unlocked_achievements
})
# 按时间排序
records.sort(key=lambda x: x["draw_time"])
return {
"success": True,
"date": date,
"records": records,
"total_count": len(records)
}
async def get_daily_detailed_records(self, date: Optional[str] = None) -> Dict:
"""获取每日详细抽卡记录"""
return await self.data_manager.get_daily_records(date)

View File

@@ -82,9 +82,9 @@ async def admin_page(request: Request):
# API 端点
@router.get("/api/stats/daily", response_model=DailyStatsResponse, dependencies=[Depends(verify_admin_token)])
async def get_daily_stats():
"""获取今日抽卡统计"""
result = gacha_system.get_daily_stats()
async def get_daily_stats():
"""获取今日抽卡统计"""
result = await gacha_system.get_daily_stats()
if not result["success"]:
return result
@@ -95,9 +95,9 @@ async def get_daily_stats():
}
@router.get("/api/stats/user/{user_id}", response_model=UserStatsResponse, dependencies=[Depends(verify_admin_token)])
async def get_user_stats(user_id: str):
"""获取用户抽卡统计"""
result = gacha_system.get_user_stats(user_id)
async def get_user_stats(user_id: str):
"""获取用户抽卡统计"""
result = await gacha_system.get_user_stats(user_id)
if not result["success"]:
return {
"success": False,
@@ -122,9 +122,9 @@ async def get_user_stats(user_id: str):
}
@router.get("/api/stats/rank", response_model=RankListResponse, dependencies=[Depends(verify_admin_token)])
async def get_rank_list():
"""获取抽卡排行榜"""
rank_data = gacha_system.get_rank_list()
async def get_rank_list():
"""获取抽卡排行榜"""
rank_data = await gacha_system.get_rank_list()
# 转换数据格式
formatted_data = []
@@ -145,9 +145,9 @@ async def get_rank_list():
}
@router.get("/api/achievements/{user_id}", response_model=AchievementResponse, dependencies=[Depends(verify_admin_token)])
async def get_user_achievements(user_id: str):
"""获取用户成就信息"""
result = gacha_system.get_user_achievements(user_id)
async def get_user_achievements(user_id: str):
"""获取用户成就信息"""
result = await gacha_system.get_user_achievements(user_id)
if not result["success"]:
return {
"success": False,
@@ -164,9 +164,9 @@ async def get_user_achievements(user_id: str):
}
@router.get("/api/records/daily", response_model=DailyDetailedRecordsResponse, dependencies=[Depends(verify_admin_token)])
async def get_daily_detailed_records(date: Optional[str] = None):
"""获取每日详细抽卡记录"""
result = gacha_system.get_daily_detailed_records(date)
async def get_daily_detailed_records(date: Optional[str] = None):
"""获取每日详细抽卡记录"""
result = await gacha_system.get_daily_detailed_records(date)
if not result["success"]:
return {
"success": False,

View File

@@ -0,0 +1,217 @@
"""danding_points HTTP PointsAPI 测试。"""
from __future__ import annotations
import aiohttp
import importlib.util
import pytest
import sys
import types
from pathlib import Path
def load_points_modules():
"""直接加载 danding_points 子模块,避免测试环境缺 nonebot 时执行插件元数据。"""
plugin_dir = Path(__file__).resolve().parents[1] / "danding_bot" / "plugins" / "danding_points"
package_name = "_danding_points_under_test"
package = types.ModuleType(package_name)
package.__path__ = [str(plugin_dir)]
sys.modules[package_name] = package
for module_name in ("config", "api"):
full_name = f"{package_name}.{module_name}"
spec = importlib.util.spec_from_file_location(full_name, plugin_dir / f"{module_name}.py")
module = importlib.util.module_from_spec(spec)
sys.modules[full_name] = module
assert spec and spec.loader
spec.loader.exec_module(module)
return sys.modules[f"{package_name}.api"], sys.modules[f"{package_name}.config"]
api_module, config_module = load_points_modules()
PointsAPI = api_module.PointsAPI
Config = config_module.Config
class FakeResponse:
"""模拟 aiohttp 响应上下文。"""
def __init__(self, payload, status=200):
self.payload = payload
self.status = status
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def json(self):
return self.payload
class FakeSession:
"""记录请求参数的 aiohttp ClientSession 替身。"""
def __init__(self, responses, calls):
self.responses = responses
self.calls = calls
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
def get(self, url, params=None, timeout=None):
self.calls.append({"method": "GET", "url": url, "params": params, "timeout": timeout})
return FakeResponse(self.responses.pop(0))
def post(self, url, json=None, timeout=None):
self.calls.append({"method": "POST", "url": url, "json": json, "timeout": timeout})
return FakeResponse(self.responses.pop(0))
def make_points_api() -> PointsAPI:
return PointsAPI(
Config(
POINTS_API_HOST="http://xapi.test/bot/points/",
BOT_USER="robot",
BOT_TOKEN="secret",
)
)
@pytest.fixture
def fake_aiohttp(monkeypatch):
calls = []
responses = []
monkeypatch.setattr(api_module.aiohttp, "ClientSession", lambda: FakeSession(responses, calls))
return responses, calls
@pytest.mark.asyncio
async def test_get_balance_sends_auth_in_query(fake_aiohttp):
responses, calls = fake_aiohttp
responses.append({"code": 200, "message": "", "data": {"balance": 88}})
points = make_points_api()
balance = await points.get_balance("10001")
assert balance == 88
assert not hasattr(points, "db")
assert calls[0]["method"] == "GET"
assert calls[0]["url"] == "http://xapi.test/bot/points/balance"
assert calls[0]["params"] == {"user": "robot", "token": "secret", "user_id": "10001"}
@pytest.mark.asyncio
async def test_add_spend_set_send_auth_in_post_body(fake_aiohttp):
responses, calls = fake_aiohttp
responses.extend(
[
{"code": 200, "message": "", "data": {"success": True, "balance": 10}},
{"code": 200, "message": "", "data": {"success": False, "balance": 7}},
{"code": 200, "message": "", "data": {"success": True, "balance": 99}},
]
)
points = make_points_api()
add_result = await points.add_points("10001", 10, "gacha_sign", "签到")
spend_result = await points.spend_points("10001", 5, "horse_race", "下注")
set_result = await points.set_points("10001", 99, "admin", "调整")
assert add_result == (True, 10)
assert spend_result == (False, 7)
assert set_result == (True, 99)
assert [call["method"] for call in calls] == ["POST", "POST", "POST"]
assert calls[0]["json"] == {
"user": "robot",
"token": "secret",
"user_id": "10001",
"amount": 10,
"source": "gacha_sign",
"reason": "签到",
}
assert calls[1]["url"].endswith("/spend")
assert calls[2]["url"].endswith("/set")
@pytest.mark.asyncio
async def test_transactions_and_ranking_return_items_unchanged(fake_aiohttp):
responses, calls = fake_aiohttp
tx_item = {
"id": 1,
"user_id": "10001",
"amount": 10,
"balance_after": 10,
"source": "gacha_sign",
"reason": "签到",
"created_at": "2026-06-20T12:00:00",
}
ranking_item = {
"rank": 1,
"user_id": "10001",
"points": 10,
"total_earned": 10,
"total_spent": 0,
}
responses.extend(
[
{"code": 200, "message": "", "data": {"items": [tx_item]}},
{"code": 200, "message": "", "data": {"items": [ranking_item]}},
]
)
points = make_points_api()
transactions = await points.get_transactions("10001", limit=5, offset=2)
ranking = await points.get_ranking(limit=3, order_by="unknown")
assert transactions == [tx_item]
assert ranking == [ranking_item]
assert calls[0]["params"] == {
"user": "robot",
"token": "secret",
"user_id": "10001",
"limit": 5,
"offset": 2,
}
assert calls[1]["params"] == {
"user": "robot",
"token": "secret",
"limit": 3,
"order_by": "points",
}
@pytest.mark.asyncio
async def test_network_error_keeps_old_failure_returns(monkeypatch):
class FailingSession:
async def __aenter__(self):
raise aiohttp.ClientError("offline")
async def __aexit__(self, exc_type, exc, tb):
return None
monkeypatch.setattr(api_module.aiohttp, "ClientSession", lambda: FailingSession())
points = make_points_api()
assert await points.get_balance("10001") == 0
assert await points.add_points("10001", 1, "gacha_sign") == (False, 0)
assert await points.spend_points("10001", 1, "horse_race") == (False, 0)
assert await points.set_points("10001", 1, "admin") == (False, 0)
assert await points.get_transactions("10001") == []
assert await points.get_ranking() == []
@pytest.mark.asyncio
async def test_invalid_change_request_does_not_call_http(fake_aiohttp):
_responses, calls = fake_aiohttp
points = make_points_api()
assert await points.add_points("10001", 0, "gacha_sign") == (False, 0)
assert await points.spend_points("", 1, "horse_race") == (False, 0)
assert await points.set_points("10001", -1, "admin") == (False, 0)
assert calls == []

View File

@@ -0,0 +1,474 @@
"""group_horse_racing 运行时 API 改造测试。"""
from __future__ import annotations
import importlib.util
import sys
import types
from datetime import datetime
from pathlib import Path
import aiohttp
import aiosqlite
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGIN_DIR = PROJECT_ROOT / "danding_bot" / "plugins" / "group_horse_racing"
def _load_module(full_name: str, path: Path):
spec = importlib.util.spec_from_file_location(full_name, path)
module = importlib.util.module_from_spec(spec)
sys.modules[full_name] = module
assert spec and spec.loader
spec.loader.exec_module(module)
return module
def _install_nonebot_stubs() -> None:
"""安装 shared.py 导入所需的最小 NoneBot 类型桩。"""
nonebot = sys.modules.setdefault("nonebot", types.ModuleType("nonebot"))
adapters = sys.modules.setdefault("nonebot.adapters", types.ModuleType("nonebot.adapters"))
onebot = sys.modules.setdefault("nonebot.adapters.onebot", types.ModuleType("nonebot.adapters.onebot"))
v11 = types.ModuleType("nonebot.adapters.onebot.v11")
class Bot:
async def get_group_member_info(self, **kwargs):
return {}
class Event:
def get_user_id(self):
return str(getattr(self, "user_id", ""))
class GroupMessageEvent(Event):
pass
class PrivateMessageEvent(Event):
pass
class Message(list):
pass
class MessageSegment:
@staticmethod
def image(data):
return {"type": "image", "data": data}
v11.Bot = Bot
v11.Event = Event
v11.GroupMessageEvent = GroupMessageEvent
v11.PrivateMessageEvent = PrivateMessageEvent
v11.Message = Message
v11.MessageSegment = MessageSegment
sys.modules["nonebot.adapters.onebot.v11"] = v11
nonebot.adapters = adapters
adapters.onebot = onebot
onebot.v11 = v11
def _install_qqpush_stubs() -> None:
qqpush_config = types.ModuleType("danding_bot.plugins.danding_qqpush.config")
qqpush_image = types.ModuleType("danding_bot.plugins.danding_qqpush.image_render")
class QqPushConfig:
FontPaths = []
class ImageRenderer:
def __init__(self, **kwargs):
self.kwargs = kwargs
def render_to_base64(self, body, title=""):
return f"base64://{title}:{body}"
qqpush_config.Config = QqPushConfig
qqpush_image.ImageRenderer = ImageRenderer
sys.modules["danding_bot.plugins.danding_qqpush.config"] = qqpush_config
sys.modules["danding_bot.plugins.danding_qqpush.image_render"] = qqpush_image
def _install_points_stub() -> None:
points_package = types.ModuleType("danding_bot.plugins.danding_points")
class PointsApi:
async def add_points(self, user_id, amount, source, reason=None):
return True, 0
async def spend_points(self, user_id, amount, source, reason=None):
return True, 0
async def set_points(self, user_id, amount, source, reason=None):
return True, amount
async def get_balance(self, user_id):
return 0
points_package.points_api = PointsApi()
sys.modules["danding_bot.plugins.danding_points"] = points_package
def load_room_store_modules():
"""直接加载 room_store 相关子模块,避免执行插件入口。"""
package_name = "_horse_racing_room_store_under_test"
package = types.ModuleType(package_name)
package.__path__ = [str(PLUGIN_DIR)]
sys.modules[package_name] = package
config_module = _load_module(f"{package_name}.config", PLUGIN_DIR / "config.py")
models_module = _load_module(f"{package_name}.models", PLUGIN_DIR / "models.py")
room_store_module = _load_module(f"{package_name}.room_store", PLUGIN_DIR / "room_store.py")
return config_module, models_module, room_store_module
def load_shared_modules():
"""加载 shared.py 及其依赖,同时隔离 NoneBot 和外部插件依赖。"""
_install_nonebot_stubs()
_install_qqpush_stubs()
_install_points_stub()
package_name = "_horse_racing_shared_under_test"
package = types.ModuleType(package_name)
package.__path__ = [str(PLUGIN_DIR)]
sys.modules[package_name] = package
config_module = _load_module(f"{package_name}.config", PLUGIN_DIR / "config.py")
models_module = _load_module(f"{package_name}.models", PLUGIN_DIR / "models.py")
package.plugin_config = config_module.Config(RACE_RENDER_AS_IMAGE=False, RACE_TICK_INTERVAL=0, RACE_DISTANCE=1)
_load_module(f"{package_name}.room_store", PLUGIN_DIR / "room_store.py")
_load_module(f"{package_name}.points_service", PLUGIN_DIR / "points_service.py")
_load_module(f"{package_name}.race_engine", PLUGIN_DIR / "race_engine.py")
_load_module(f"{package_name}.message_service", PLUGIN_DIR / "message_service.py")
commands_package_name = f"{package_name}.commands"
commands_package = types.ModuleType(commands_package_name)
commands_package.__path__ = [str(PLUGIN_DIR / "commands")]
sys.modules[commands_package_name] = commands_package
access_module = types.ModuleType(f"{commands_package_name}.access")
access_module.get_event_id = lambda event: str(getattr(event, "user_id", ""))
access_module.get_scope = lambda event: f"group_{getattr(event, 'group_id', '')}"
async def _check_access(bot, event):
return True
access_module.check_access = _check_access
sys.modules[f"{commands_package_name}.access"] = access_module
shared_module = _load_module(f"{commands_package_name}.shared", PLUGIN_DIR / "commands" / "shared.py")
return config_module, models_module, shared_module
config_module, models_module, room_store_module = load_room_store_modules()
Config = config_module.Config
RoomStore = room_store_module.RoomStore
RaceResult = models_module.RaceResult
class FakeResponse:
"""模拟 aiohttp 响应上下文。"""
def __init__(self, payload, status=200):
self.payload = payload
self.status = status
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def json(self):
return self.payload
class FakeSession:
"""记录赛马 HTTP 请求参数的 aiohttp ClientSession 替身。"""
def __init__(self, responses, calls):
self.responses = responses
self.calls = calls
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
def get(self, url, params=None, timeout=None):
self.calls.append({"method": "GET", "url": url, "params": params, "timeout": timeout})
return FakeResponse(self.responses.pop(0))
def post(self, url, json=None, timeout=None):
self.calls.append({"method": "POST", "url": url, "json": json, "timeout": timeout})
return FakeResponse(self.responses.pop(0))
def put(self, url, json=None, timeout=None):
self.calls.append({"method": "PUT", "url": url, "json": json, "timeout": timeout})
return FakeResponse(self.responses.pop(0))
def success_data(data):
return {"code": 200, "message": "", "data": data}
@pytest.fixture
def fake_aiohttp(monkeypatch):
calls = []
responses = []
monkeypatch.setattr(room_store_module.aiohttp, "ClientSession", lambda: FakeSession(responses, calls))
return responses, calls
def make_store(tmp_path: Path) -> RoomStore:
return RoomStore(
Config(
RACE_DB_FILE=str(tmp_path / "race.db"),
RACE_API_HOST="http://xapi.test/bot/race/",
BOT_USER="robot",
BOT_TOKEN="secret",
)
)
@pytest.mark.asyncio
async def test_room_store_race_history_and_horse_names_use_xapi(fake_aiohttp, tmp_path):
responses, calls = fake_aiohttp
responses.extend(
[
success_data({"user_id": "10001", "horse_name": "赤焰"}),
success_data({"success": True}),
success_data({"race_id": "race-001"}),
]
)
store = make_store(tmp_path)
result = RaceResult(
race_id="race-001",
scope="group_1000",
champion_name="赤焰",
champion_owner="10001",
participants=["赤焰", "青岚"],
bet_distribution={"赤焰": 30, "青岚": 10},
duration_ticks=8,
completed_at=datetime(2026, 6, 20, 10, 0, 0),
point_changes={"10001": 170},
point_change_summaries={"10001": "大赚特赚"},
odds_snapshot={"赤焰": 1.5},
)
horse_name = await store.get_last_horse_name("10001")
await store.set_last_horse_name("10001", "青岚")
await store.save_race_result(result)
assert horse_name == "赤焰"
assert calls[0]["method"] == "GET"
assert calls[0]["url"] == "http://xapi.test/bot/race/horse-name"
assert calls[0]["params"] == {"user": "robot", "token": "secret", "user_id": "10001"}
assert calls[1]["method"] == "PUT"
assert calls[1]["json"] == {"user": "robot", "token": "secret", "user_id": "10001", "horse_name": "青岚"}
assert calls[2]["method"] == "POST"
assert calls[2]["url"] == "http://xapi.test/bot/race/history"
assert calls[2]["json"]["race_id"] == "race-001"
assert calls[2]["json"]["participants"] == ["赤焰", "青岚"]
assert calls[2]["json"]["odds_snapshot"] == {"赤焰": 1.5}
@pytest.mark.asyncio
async def test_room_snapshots_still_use_local_sqlite_without_race_http(fake_aiohttp, tmp_path):
_responses, calls = fake_aiohttp
store = make_store(tmp_path)
room = await store.create_room("group_1000")
loaded = store.get_room("group_1000")
assert loaded is room
assert calls == []
async with aiosqlite.connect(store.db_path) as db:
cursor = await db.execute("SELECT scope, state FROM room_snapshots")
rows = await cursor.fetchall()
await store.close()
assert rows == [("group_1000", "waiting")]
@pytest.mark.asyncio
async def test_race_api_network_error_keeps_old_failure_shapes(monkeypatch, tmp_path):
class FailingSession:
async def __aenter__(self):
raise aiohttp.ClientError("offline")
async def __aexit__(self, exc_type, exc, tb):
return None
monkeypatch.setattr(room_store_module.aiohttp, "ClientSession", lambda: FailingSession())
store = make_store(tmp_path)
assert await store.get_last_horse_name("10001") is None
await store.set_last_horse_name("10001", "赤焰")
@pytest.mark.asyncio
async def test_save_race_result_raises_when_xapi_rejects(fake_aiohttp, tmp_path):
responses, _calls = fake_aiohttp
responses.append({"code": 500, "message": "写入失败", "data": None})
store = make_store(tmp_path)
result = RaceResult(
race_id="race-rejected",
scope="group_1000",
champion_name="赤焰",
champion_owner="10001",
participants=["赤焰"],
bet_distribution={"赤焰": 0},
duration_ticks=1,
completed_at=datetime(2026, 6, 20, 10, 0, 0),
point_changes={"10001": 170},
point_change_summaries={"10001": "大赚特赚"},
odds_snapshot={"赤焰": 1.2},
)
with pytest.raises(RuntimeError, match="赛马赛果写入 xapi 失败"):
await store.save_race_result(result)
@pytest.mark.asyncio
async def test_settle_race_awaits_balances_and_builds_complete_result():
_config_module, shared_models, shared = load_shared_modules()
Room = shared_models.Room
Horse = shared_models.Horse
HorseState = shared_models.HorseState
Bet = shared_models.Bet
class FakePointsService:
def __init__(self):
self.balances = {"10001": 100, "10002": 50, "20001": 200}
async def get_balance(self, user_id: str) -> int:
return self.balances[user_id]
async def reward_participant(self, user_id: str):
self.balances[user_id] += shared.config.PARTICIPANT_REWARD
return True, self.balances[user_id]
async def reward_champion(self, user_id: str):
self.balances[user_id] += shared.config.CHAMPION_REWARD
return True, self.balances[user_id]
async def payout_winnings(self, user_id: str, amount: int, odds: float):
self.balances[user_id] += max(1, round(amount * odds))
return True, self.balances[user_id]
shared.points_service = FakePointsService()
room = Room(scope="group_1000")
room.horses = {
"赤焰": Horse(owner_id="10001", name="赤焰", index=1, state=HorseState.RACING),
"青岚": Horse(owner_id="10002", name="青岚", index=2, state=HorseState.RACING),
}
room.bets = [Bet(user_id="20001", horse_name="赤焰", amount=30)]
room.champion_name = "赤焰"
room.tick_count = 7
settlement = await shared.settle_race(room)
assert settlement is not None
result, odds = settlement
assert result.race_id
assert result.scope == "group_1000"
assert result.participants == ["赤焰", "青岚"]
assert result.bet_distribution == {"赤焰": 30, "青岚": 0}
assert result.duration_ticks == 7
assert result.odds_snapshot == odds
assert result.point_changes == {"10001": 170, "10002": 20, "20001": 36}
assert all(isinstance(value, int) for value in result.point_changes.values())
@pytest.mark.asyncio
async def test_run_race_with_settlement_saves_result_before_deleting_room():
_config_module, shared_models, shared = load_shared_modules()
Room = shared_models.Room
Horse = shared_models.Horse
HorseState = shared_models.HorseState
Bet = shared_models.Bet
class FakeRaceEngine:
def __init__(self):
self.stopped = []
def tick(self, room):
room.tick_count = 3
return [room.horses["赤焰"]]
def format_progress(self, room):
return "progress"
def determine_champion(self, horses):
return horses[0]
def stop_race(self, scope):
self.stopped.append(scope)
class FakePointsService:
def __init__(self):
self.balances = {"10001": 100, "10002": 50, "20001": 200}
async def get_balance(self, user_id: str) -> int:
return self.balances[user_id]
async def reward_participant(self, user_id: str):
self.balances[user_id] += shared.config.PARTICIPANT_REWARD
return True, self.balances[user_id]
async def reward_champion(self, user_id: str):
self.balances[user_id] += shared.config.CHAMPION_REWARD
return True, self.balances[user_id]
async def payout_winnings(self, user_id: str, amount: int, odds: float):
self.balances[user_id] += max(1, round(amount * odds))
return True, self.balances[user_id]
class FakeMessageService:
def __init__(self):
self.sent = []
self.cleared = []
async def send_with_recall(self, bot, scope, message_type, message):
self.sent.append((scope, message_type, str(message)))
return "msg"
async def recall_previous_of_type(self, bot, scope, message_type):
return None
def clear_pending_recalls(self, scope):
self.cleared.append(scope)
class FakeRoomStore:
def __init__(self):
self.saved = []
self.deleted = []
async def save_race_result(self, result):
self.saved.append(result)
def delete_room(self, scope):
self.deleted.append(scope)
shared.race_engine = FakeRaceEngine()
shared.points_service = FakePointsService()
shared.message_service = FakeMessageService()
shared.room_store = FakeRoomStore()
shared.config.RACE_TICK_INTERVAL = 0
room = Room(scope="group_1000")
room.horses = {
"赤焰": Horse(owner_id="10001", name="赤焰", index=1, state=HorseState.RACING),
"青岚": Horse(owner_id="10002", name="青岚", index=2, state=HorseState.RACING),
}
room.bets = [Bet(user_id="20001", horse_name="赤焰", amount=30)]
await shared.run_race_with_settlement(object(), room, "group_1000")
assert len(shared.room_store.saved) == 1
saved = shared.room_store.saved[0]
assert saved.scope == "group_1000"
assert saved.duration_ticks == 3
assert saved.odds_snapshot == {"赤焰": 1.2, "青岚": 1.2}
assert shared.room_store.deleted == ["group_1000"]
assert shared.race_engine.stopped == ["group_1000"]
assert shared.message_service.cleared == ["group_1000"]

View File

@@ -0,0 +1,247 @@
"""onmyoji_gacha HTTP DataManager 测试。"""
from __future__ import annotations
import aiohttp
import importlib.util
import inspect
import pytest
import sys
import types
from pathlib import Path
def load_gacha_modules():
"""直接加载 onmyoji_gacha 子模块,避免测试环境执行 nonebot 插件入口。"""
plugin_dir = Path(__file__).resolve().parents[1] / "danding_bot" / "plugins" / "onmyoji_gacha"
package_name = "_onmyoji_gacha_under_test"
package = types.ModuleType(package_name)
package.__path__ = [str(plugin_dir)]
sys.modules[package_name] = package
for module_name in ("config", "data_manager", "gacha"):
full_name = f"{package_name}.{module_name}"
spec = importlib.util.spec_from_file_location(full_name, plugin_dir / f"{module_name}.py")
module = importlib.util.module_from_spec(spec)
sys.modules[full_name] = module
assert spec and spec.loader
spec.loader.exec_module(module)
return (
sys.modules[f"{package_name}.config"],
sys.modules[f"{package_name}.data_manager"],
sys.modules[f"{package_name}.gacha"],
)
config_module, data_manager_module, gacha_module = load_gacha_modules()
Config = config_module.Config
DataManager = data_manager_module.DataManager
GachaSystem = gacha_module.GachaSystem
class FakeResponse:
"""模拟 aiohttp 响应上下文。"""
def __init__(self, payload, status=200):
self.payload = payload
self.status = status
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def json(self):
return self.payload
class FakeSession:
"""记录请求参数的 aiohttp ClientSession 替身。"""
def __init__(self, responses, calls):
self.responses = responses
self.calls = calls
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
def get(self, url, params=None, timeout=None):
self.calls.append({"method": "GET", "url": url, "params": params, "timeout": timeout})
return FakeResponse(self.responses.pop(0))
def post(self, url, json=None, timeout=None):
self.calls.append({"method": "POST", "url": url, "json": json, "timeout": timeout})
return FakeResponse(self.responses.pop(0))
@pytest.fixture
def fake_aiohttp(monkeypatch):
calls = []
responses = []
monkeypatch.setattr(data_manager_module.aiohttp, "ClientSession", lambda: FakeSession(responses, calls))
test_config = Config(
GACHA_API_HOST="http://xapi.test/bot/gacha/",
BOT_USER_ID="robot",
BOT_TOKEN="secret",
)
monkeypatch.setattr(data_manager_module, "config", test_config)
monkeypatch.setattr(gacha_module, "config", test_config)
return responses, calls
def success_data(data):
return {"code": 200, "message": "", "data": data}
@pytest.mark.asyncio
async def test_shikigami_cache_and_draw_send_auth_to_xapi(fake_aiohttp):
responses, calls = fake_aiohttp
responses.extend(
[
success_data(
{
"items": [
{"id": 1, "name": "灯笼鬼", "rarity": "R", "image_path": "/r/灯笼鬼.png"},
{"id": 3, "name": "茨木童子", "rarity": "SSR", "image_path": "/ssr/茨木童子.png"},
]
}
),
success_data(
{
"success": True,
"rarity": "SSR",
"name": "茨木童子",
"image_path": "/ssr/茨木童子.png",
"image_url": "/ssr/茨木童子.png",
"draws_left": 2,
"unlocked_achievements": ["no_ssr_60"],
}
),
]
)
manager = DataManager()
grouped = await manager.refresh_shikigami_data()
draw_result = await manager.record_draw_result("10001", "SSR", grouped["SSR"][0])
assert grouped["SSR"][0]["image_url"] == "/ssr/茨木童子.png"
assert draw_result["unlocked_achievements"] == ["no_ssr_60"]
assert calls[0]["method"] == "GET"
assert calls[0]["url"] == "http://xapi.test/bot/gacha/shikigami"
assert calls[0]["params"] == {"user": "robot", "token": "secret"}
assert calls[1]["method"] == "POST"
assert calls[1]["json"] == {
"user": "robot",
"token": "secret",
"user_id": "10001",
"shikigami_id": 3,
"rarity": "SSR",
"name": "茨木童子",
}
@pytest.mark.asyncio
async def test_gacha_system_draw_keeps_probability_local_and_uses_xapi(fake_aiohttp, monkeypatch):
responses, calls = fake_aiohttp
responses.extend(
[
success_data({"draws_left": 3, "daily_limit": 3}),
success_data({"items": [{"id": 3, "name": "茨木童子", "rarity": "SSR", "image_path": "/ssr/茨木童子.png"}]}),
success_data(
{
"success": True,
"rarity": "SSR",
"name": "茨木童子",
"image_path": "/ssr/茨木童子.png",
"image_url": "/ssr/茨木童子.png",
"draws_left": 2,
"unlocked_achievements": [],
}
),
]
)
manager = DataManager()
monkeypatch.setattr(gacha_module, "data_manager", manager)
system = GachaSystem()
monkeypatch.setattr(system, "_draw_rarity", lambda user_id=None: "SSR")
result = await system.draw("10001")
assert result == {
"success": True,
"rarity": "SSR",
"name": "茨木童子",
"image_url": "/ssr/茨木童子.png",
"draws_left": 2,
"unlocked_achievements": [],
}
assert [call["url"].rsplit("/", 1)[-1] for call in calls] == ["draws-left", "shikigami", "draw"]
@pytest.mark.asyncio
async def test_triple_sign_in_claim_and_query_shapes(fake_aiohttp):
responses, calls = fake_aiohttp
responses.extend(
[
success_data({"success": True, "results": [], "draws_left": 0, "unlocked_achievements": ["no_ssr_60"]}),
success_data({"success": True, "signed_already": True}),
success_data({"success": True, "reward_type": "天卡"}),
success_data({"success": True, "total_draws": 1, "R_count": 1, "SR_count": 0, "SSR_count": 0, "SP_count": 0, "recent_draws": []}),
success_data({"success": True, "date": "2026-06-20", "stats": {"total_users": 1}}),
success_data({"items": [{"user_id": "10001", "total_draws": 1, "R_count": 1, "SR_count": 0, "SSR_count": 0, "SP_count": 0, "ssr_sp_total": 0}]}),
success_data({"achievements": {"no_ssr_60": {"unlocked_date": "2026-06-20", "reward_claimed": False}}, "progress": {"no_ssr_streak": 60}}),
success_data({"success": True, "date": "2026-06-20", "records": [], "total_count": 0}),
]
)
manager = DataManager()
draws = [
{"id": 1, "name": "灯笼鬼", "rarity": "R", "image_url": "/r/灯笼鬼.png"},
{"id": 2, "name": "雪女", "rarity": "SR", "image_url": "/sr/雪女.png"},
{"id": 3, "name": "茨木童子", "rarity": "SSR", "image_url": "/ssr/茨木童子.png"},
]
triple = await manager.record_triple_draw_result("10001", draws)
signed = await manager.record_sign_in("10001", 20)
claimed = await manager.claim_achievement_reward("10001", "no_ssr_60")
stats = await manager.get_user_stats("10001")
daily = await manager.get_daily_stats("2026-06-20")
rank = await manager.get_rank()
achievements = await manager.get_user_achievements("10001")
records = await manager.get_daily_records("2026-06-20")
assert triple["unlocked_achievements"] == ["no_ssr_60"]
assert signed is False
assert claimed is True
assert stats["success"] is True
assert daily["date"] == "2026-06-20"
assert rank[0]["user_id"] == "10001"
assert "unlocked" in achievements and "progress" in achievements
assert records["total_count"] == 0
assert calls[0]["json"]["draws"][0] == {"shikigami_id": 1, "rarity": "R", "name": "灯笼鬼"}
assert calls[1]["json"]["points_awarded"] == 20
assert calls[2]["json"]["achievement_id"] == "no_ssr_60"
assert calls[3]["params"]["user_id"] == "10001"
@pytest.mark.asyncio
async def test_network_error_keeps_failure_shapes_and_no_sqlite(monkeypatch):
class FailingSession:
async def __aenter__(self):
raise aiohttp.ClientError("offline")
async def __aexit__(self, exc_type, exc, tb):
return None
monkeypatch.setattr(data_manager_module.aiohttp, "ClientSession", lambda: FailingSession())
manager = DataManager()
assert await manager.get_draws_left("10001") == 0
assert await manager.record_sign_in("10001", 20) is False
assert await manager.get_rank() == []
assert "sqlite3" not in inspect.getsource(data_manager_module)