182 lines
6.3 KiB
Python
182 lines
6.3 KiB
Python
import asyncio
|
||
import aiohttp
|
||
import logging
|
||
from typing import Tuple, List, Dict, Any, Optional
|
||
from .config import Config
|
||
|
||
# Module-scoped logger; every PointsAPI failure path reports through it.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class PointsAPI:
    """Points system API for managing user points via xapi ``/bot/points``.

    Every public method degrades to a failure value (``0``, ``[]`` or
    ``(False, 0)``) instead of raising when the remote call fails, which keeps
    the failure semantics of the previous PointsAPI implementation.
    """

    # Total per-request timeout in seconds; can be overridden on a subclass or
    # on an instance without touching the request code.
    REQUEST_TIMEOUT_SECONDS: float = 10

    def __init__(self, config: Config):
        """Keep the config that supplies POINTS_API_HOST, BOT_USER and BOT_TOKEN."""
        self.config = config

    def _url(self, path: str) -> str:
        """Join *path* onto ``config.POINTS_API_HOST`` with exactly one slash between them."""
        host = str(self.config.POINTS_API_HOST).rstrip("/")
        return f"{host}/{path.lstrip('/')}"

    def _auth(self) -> Dict[str, str]:
        """Return the xapi bot credentials that accompany every request."""
        return {
            "user": self.config.BOT_USER,
            "token": self.config.BOT_TOKEN,
        }

    async def _request(
        self,
        method: str,
        path: str,
        *,
        payload: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Call an xapi ``/bot/points`` endpoint and expose only its ``data`` object.

        Args:
            method: HTTP verb, case-insensitive. For ``GET``, the credentials and
                *params* travel in the query string. For any other verb, the
                credentials and *payload* are sent as a JSON body.
            path: Endpoint path relative to ``POINTS_API_HOST``.
            payload: Extra JSON fields for non-GET requests.
            params: Extra query parameters for GET requests.

        Returns:
            The response's ``data`` dict. Returns ``None`` on any transport
            error, timeout, bad HTTP status, malformed body or non-200
            business code.
        """
        request_url = self._url(path)
        timeout = aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_SECONDS)
        verb = method.upper()
        if verb == "GET":
            query = {**self._auth(), **(params or {})}
            # yarl cannot encode None in a query string and raises TypeError.
            # Drop such keys so xapi rejects the request instead of it crashing here.
            request_kwargs: Dict[str, Any] = {
                "params": {key: value for key, value in query.items() if value is not None}
            }
            # NOTE(review): the bot token ends up in the URL and may be logged
            # by proxies; confirm whether xapi also accepts it in a header.
        else:
            request_kwargs = {"json": {**self._auth(), **(payload or {})}}
        try:
            # NOTE(review): a fresh session per call forgoes connection pooling;
            # consider a shared, explicitly closed session if call volume grows.
            async with aiohttp.ClientSession() as session:
                async with session.request(
                    verb, request_url, timeout=timeout, **request_kwargs
                ) as resp:
                    return await self._parse_response(resp, path)
        except aiohttp.ClientError as exc:
            logger.error("points api request failed path=%s error=%s", path, exc)
            return None
        except asyncio.TimeoutError as exc:
            logger.error("points api request timeout path=%s error=%s", path, exc)
            return None

    async def _parse_response(self, resp: aiohttp.ClientResponse, path: str) -> Optional[Dict[str, Any]]:
        """Unwrap xapi's ``{"code", "message", "data"}`` envelope.

        Returns:
            The ``data`` dict when the HTTP status and the business ``code`` are
            both 200. Returns ``None`` (the legacy failure value) otherwise, and
            also when the body cannot be decoded or is not a JSON object.
        """
        if resp.status != 200:
            logger.error("points api bad status path=%s status=%s", path, resp.status)
            return None
        try:
            body = await resp.json()
        except (aiohttp.ContentTypeError, ValueError) as exc:
            # ValueError covers JSONDecodeError/UnicodeDecodeError. These are not
            # aiohttp.ClientError, so without this they would escape to callers.
            logger.error("points api undecodable body path=%s error=%s", path, exc)
            return None
        # aiohttp yields None for an empty body; a JSON array is also possible.
        if not isinstance(body, dict):
            logger.error(
                "points api unexpected body path=%s type=%s", path, type(body).__name__
            )
            return None
        if body.get("code") != 200:
            logger.error("points api fail path=%s code=%s message=%s", path, body.get("code"), body.get("message"))
            return None
        data = body.get("data")
        return data if isinstance(data, dict) else None

    @staticmethod
    def _to_int(value: Any) -> int:
        """Coerce a numeric API field to ``int``; missing, falsy or invalid values become 0."""
        try:
            return int(value or 0)
        except (TypeError, ValueError):
            logger.error("points api non-numeric value=%r", value)
            return 0

    @staticmethod
    def _items(data: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return ``data["items"]`` when it is a list, otherwise an empty list."""
        if data is None:
            return []
        items = data.get("items", [])
        return items if isinstance(items, list) else []

    async def get_balance(self, user_id: str) -> int:
        """Get the user's current points balance; 0 if the lookup fails."""
        data = await self._request("GET", "balance", params={"user_id": user_id})
        if data is None:
            return 0
        return self._to_int(data.get("balance"))

    async def _change_points(
        self,
        path: str,
        user_id: str,
        amount: int,
        source: str,
        reason: Optional[str],
        *,
        allow_zero: bool,
    ) -> Tuple[bool, int]:
        """Validate an add/spend/set call locally, then POST it to *path*.

        These checks keep the legacy PointsAPI argument-failure semantics:

        - *amount* must be an ``int``. ``bool`` is rejected even though it
          subclasses ``int``.
        - *amount* must be positive, or non-negative when *allow_zero* is True.
        - *user_id* and *source* must be non-empty.

        Limit and balance checks are left to xapi.
        """
        if isinstance(amount, bool) or not isinstance(amount, int):
            return False, 0
        if amount < 0 or (amount == 0 and not allow_zero):
            return False, 0
        if not user_id or not source:
            return False, 0

        data = await self._request(
            "POST",
            path,
            payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason},
        )
        return self._change_result(data)

    async def add_points(
        self, user_id: str, amount: int, source: str, reason: Optional[str] = None
    ) -> Tuple[bool, int]:
        """Add points to user account.

        Returns: (success, new_balance). Returns (False, 0) on invalid
        arguments or a failed request.
        """
        return await self._change_points(
            "add", user_id, amount, source, reason, allow_zero=False
        )

    async def spend_points(
        self, user_id: str, amount: int, source: str, reason: Optional[str] = None
    ) -> Tuple[bool, int]:
        """Spend points from user account.

        Returns: (success, new_balance). Returns (False, 0) on invalid
        arguments or a failed request. Insufficient balance is decided by xapi.
        """
        return await self._change_points(
            "spend", user_id, amount, source, reason, allow_zero=False
        )

    async def set_points(
        self, user_id: str, amount: int, source: str, reason: Optional[str] = None
    ) -> Tuple[bool, int]:
        """Set user's points to exact amount.

        Only non-negativity is checked locally; there is no upper-bound check.

        Returns: (success, new_balance). Returns (False, 0) on invalid
        arguments or a failed request.
        """
        return await self._change_points(
            "set", user_id, amount, source, reason, allow_zero=True
        )

    async def get_transactions(
        self, user_id: str, limit: int = 20, offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Get transaction history for a user.

        *limit* is clamped to [1, 100] and *offset* to >= 0 before the request.

        Returns: List of transaction dicts ([] on failure).
        """
        limit = max(1, min(100, limit))
        offset = max(0, offset)
        data = await self._request(
            "GET",
            "transactions",
            params={"user_id": user_id, "limit": limit, "offset": offset},
        )
        return self._items(data)

    async def get_ranking(
        self, limit: int = 10, order_by: str = "points"
    ) -> List[Dict[str, Any]]:
        """Get points ranking.

        *limit* is clamped to [1, 100]. An *order_by* other than "points" or
        "total_earned" falls back to "points".

        Returns: List of ranking dicts with rank field ([] on failure).
        """
        limit = max(1, min(100, limit))
        if order_by not in ("points", "total_earned"):
            order_by = "points"
        data = await self._request(
            "GET",
            "ranking",
            params={"limit": limit, "order_by": order_by},
        )
        return self._items(data)

    def _change_result(self, data: Optional[Dict[str, Any]]) -> Tuple[bool, int]:
        """Turn an add/spend/set ``data`` payload into ``(success, new_balance)``.

        ``None`` (failed request) maps to the legacy failure value ``(False, 0)``.
        """
        if data is None:
            return False, 0
        return bool(data.get("success")), self._to_int(data.get("balance"))
|