feat(bot): use runtime api for bot data
This commit is contained in:
@@ -10,7 +10,6 @@
|
||||
danding_points/
|
||||
├── __init__.py # 插件元数据 & points_api 单例导出
|
||||
├── config.py # 配置类
|
||||
├── database.py # SQLite 数据库操作
|
||||
└── api.py # PointsAPI 核心
|
||||
```
|
||||
|
||||
@@ -20,10 +19,11 @@ danding_points/
|
||||
|
||||
| 配置项 | 类型 | 默认值 | 说明 |
|
||||
|--------|------|--------|------|
|
||||
| `DANDING_POINTS_DB_FILE` | str | `data/danding_points/points.db` | 数据库文件路径 |
|
||||
| `DANDING_POINTS_MAX_BALANCE` | int | `0` | 用户积分余额上限,`0` = 无限制 |
|
||||
| `DANDING_POINTS_MAX_PER_OPERATION` | int | `0` | 单次操作积分上限,`0` = 无限制 |
|
||||
| `DANDING_POINTS_LOG_RETENTION_DAYS` | int | `365` | 流水日志保留天数 |
|
||||
| `DANDING_POINTS_API_HOST` | str | `https://api.danding.vip/bot/points` | xapi 积分 API 地址 |
|
||||
| `DANDING_BOT_USER` | str | `1424473282` | xapi Bot 鉴权用户 |
|
||||
| `DANDING_BOT_TOKEN` | str | 空 | xapi Bot 鉴权 Token,未设置时读取 `DANDING_API_TOKEN` / `BOT_TOKEN` |
|
||||
|
||||
积分余额上限与单次操作上限由 xapi `BOT_POINTS_MAX_BALANCE` / `BOT_POINTS_MAX_PER_OPERATION` 控制,nonebot 本地不再持有阈值。
|
||||
|
||||
## API 接口
|
||||
|
||||
@@ -170,36 +170,10 @@ async def admin_set(user_id: str, amount: int):
|
||||
|
||||
| 插件 | source 值 |
|
||||
|------|-----------|
|
||||
| onmyoji_gacha | `"onmyoji_gacha"` |
|
||||
| danding_api | `"danding_api"` |
|
||||
| shop | `"shop"` |
|
||||
| sign_in | `"sign_in"` |
|
||||
| onmyoji_gacha 签到 | `"gacha_sign"` |
|
||||
| group_horse_racing | `"horse_race"` |
|
||||
| 管理调整 | `"admin"` |
|
||||
|
||||
## 数据库
|
||||
|
||||
使用 SQLite,数据文件位于 `data/danding_points/points.db`,无需额外配置。
|
||||
|
||||
### 表结构
|
||||
|
||||
**user_points** — 用户积分账户
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| user_id | TEXT PK | 用户 ID |
|
||||
| points | INTEGER | 当前余额,>= 0 |
|
||||
| total_earned | INTEGER | 累计获得 |
|
||||
| total_spent | INTEGER | 累计消费 |
|
||||
| created_at | TEXT | 创建时间 |
|
||||
| updated_at | TEXT | 最后更新时间 |
|
||||
|
||||
**point_transactions** — 积分变动流水
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | INTEGER PK | 自增 ID |
|
||||
| user_id | TEXT | 用户 ID |
|
||||
| amount | INTEGER | 变动数额(消费为负) |
|
||||
| balance_after | INTEGER | 变动后余额 |
|
||||
| source | TEXT | 来源标识 |
|
||||
| reason | TEXT | 变动原因 |
|
||||
| created_at | TEXT | 创建时间 |
|
||||
本插件不再写入本地 SQLite。积分账户与流水由 xapi MySQL `bot_user_points` / `bot_point_transactions` 承载,nonebot 只通过 `/bot/points/*` HTTP API 读写。
|
||||
|
||||
@@ -1,312 +1,181 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from .config import Config
|
||||
from .database import PointsDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import logging
|
||||
from typing import Tuple, List, Dict, Any, Optional
|
||||
from .config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PointsAPI:
|
||||
"""Points system API for managing user points."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.db = PointsDatabase(config)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
async def get_balance(self, user_id: str) -> int:
|
||||
"""Get user's current points balance."""
|
||||
return await asyncio.to_thread(self.db.get_user_balance, user_id)
|
||||
"""Points system API for managing user points."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
"""拼接 /bot/points 端点地址。"""
|
||||
|
||||
return f"{self.config.POINTS_API_HOST}/{path.lstrip('/')}"
|
||||
|
||||
def _auth(self) -> Dict[str, str]:
|
||||
"""生成 xapi Bot 鉴权参数。"""
|
||||
|
||||
return {
|
||||
"user": self.config.BOT_USER,
|
||||
"token": self.config.BOT_TOKEN,
|
||||
}
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用 xapi /bot/points,并只向上层暴露 data。"""
|
||||
|
||||
request_url = self._url(path)
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if method == "GET":
|
||||
request_params = {**self._auth(), **(params or {})}
|
||||
async with session.get(request_url, params=request_params, timeout=timeout) as resp:
|
||||
return await self._parse_response(resp, path)
|
||||
request_payload = {**self._auth(), **(payload or {})}
|
||||
async with session.post(request_url, json=request_payload, timeout=timeout) as resp:
|
||||
return await self._parse_response(resp, path)
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("points api request failed path=%s error=%s", path, exc)
|
||||
return None
|
||||
except asyncio.TimeoutError as exc:
|
||||
logger.error("points api request timeout path=%s error=%s", path, exc)
|
||||
return None
|
||||
|
||||
async def _parse_response(self, resp: aiohttp.ClientResponse, path: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析 xapi 统一响应,失败时返回 None 维持旧 API 失败语义。"""
|
||||
|
||||
if resp.status != 200:
|
||||
logger.error("points api bad status path=%s status=%s", path, resp.status)
|
||||
return None
|
||||
body = await resp.json()
|
||||
if body.get("code") != 200:
|
||||
logger.error("points api fail path=%s code=%s message=%s", path, body.get("code"), body.get("message"))
|
||||
return None
|
||||
data = body.get("data")
|
||||
return data if isinstance(data, dict) else None
|
||||
|
||||
async def get_balance(self, user_id: str) -> int:
|
||||
"""Get user's current points balance."""
|
||||
data = await self._request("GET", "balance", params={"user_id": user_id})
|
||||
if data is None:
|
||||
return 0
|
||||
return int(data.get("balance", 0) or 0)
|
||||
|
||||
async def add_points(
|
||||
self, user_id: str, amount: int, source: str, reason: str = None
|
||||
) -> Tuple[bool, int]:
|
||||
"""Add points to user account.
|
||||
|
||||
Returns: (success, new_balance)
|
||||
"""
|
||||
# Parameter validation
|
||||
if not isinstance(amount, int) or amount <= 0:
|
||||
return False, 0
|
||||
if not user_id or not source:
|
||||
return False, 0
|
||||
|
||||
# Operation limit validation
|
||||
if self.config.POINTS_MAX_PER_OPERATION > 0:
|
||||
if amount > self.config.POINTS_MAX_PER_OPERATION:
|
||||
return False, 0
|
||||
|
||||
def _add():
|
||||
with self._lock:
|
||||
conn = self.db.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# Ensure user exists
|
||||
self.db.ensure_user_exists(user_id, conn)
|
||||
|
||||
# Get current balance
|
||||
cursor.execute(
|
||||
"SELECT points FROM user_points WHERE user_id = ?",
|
||||
(user_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
current_balance = row["points"] if row else 0
|
||||
|
||||
# Check balance limit
|
||||
new_balance = current_balance + amount
|
||||
if self.config.POINTS_MAX_BALANCE > 0:
|
||||
if new_balance > self.config.POINTS_MAX_BALANCE:
|
||||
conn.rollback()
|
||||
return False, current_balance
|
||||
|
||||
# Update balance and total_earned
|
||||
now = datetime.now().isoformat()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE user_points
|
||||
SET points = ?, total_earned = total_earned + ?, updated_at = ?
|
||||
WHERE user_id = ?
|
||||
""",
|
||||
(new_balance, amount, now, user_id),
|
||||
)
|
||||
|
||||
# Write transaction log
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO point_transactions
|
||||
(user_id, amount, balance_after, source, reason, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, amount, new_balance, source, reason, now),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return True, new_balance
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error(f"add_points failed for {user_id}: {e}")
|
||||
return False, 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return await asyncio.to_thread(_add)
|
||||
"""Add points to user account.
|
||||
|
||||
Returns: (success, new_balance)
|
||||
"""
|
||||
# 保留原 PointsAPI 的入参失败语义;限额校验由 xapi 承担。
|
||||
if not isinstance(amount, int) or amount <= 0:
|
||||
return False, 0
|
||||
if not user_id or not source:
|
||||
return False, 0
|
||||
|
||||
data = await self._request(
|
||||
"POST",
|
||||
"add",
|
||||
payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason},
|
||||
)
|
||||
return self._change_result(data)
|
||||
|
||||
async def spend_points(
|
||||
self, user_id: str, amount: int, source: str, reason: str = None
|
||||
) -> Tuple[bool, int]:
|
||||
"""Spend points from user account.
|
||||
|
||||
Returns: (success, new_balance)
|
||||
"""
|
||||
# Parameter validation
|
||||
if not isinstance(amount, int) or amount <= 0:
|
||||
return False, 0
|
||||
if not user_id or not source:
|
||||
return False, 0
|
||||
|
||||
# Operation limit validation
|
||||
if self.config.POINTS_MAX_PER_OPERATION > 0:
|
||||
if amount > self.config.POINTS_MAX_PER_OPERATION:
|
||||
return False, 0
|
||||
|
||||
def _spend():
|
||||
with self._lock:
|
||||
conn = self.db.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# Ensure user exists
|
||||
self.db.ensure_user_exists(user_id, conn)
|
||||
|
||||
# Get current balance
|
||||
cursor.execute(
|
||||
"SELECT points FROM user_points WHERE user_id = ?",
|
||||
(user_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
current_balance = row["points"] if row else 0
|
||||
|
||||
# Check sufficient balance
|
||||
if current_balance < amount:
|
||||
conn.rollback()
|
||||
return False, current_balance
|
||||
|
||||
# Update balance and total_spent
|
||||
new_balance = current_balance - amount
|
||||
now = datetime.now().isoformat()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE user_points
|
||||
SET points = ?, total_spent = total_spent + ?, updated_at = ?
|
||||
WHERE user_id = ?
|
||||
""",
|
||||
(new_balance, amount, now, user_id),
|
||||
)
|
||||
|
||||
# Write transaction log (amount as negative)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO point_transactions
|
||||
(user_id, amount, balance_after, source, reason, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, -amount, new_balance, source, reason, now),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return True, new_balance
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error(f"spend_points failed for {user_id}: {e}")
|
||||
return False, 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return await asyncio.to_thread(_spend)
|
||||
"""Spend points from user account.
|
||||
|
||||
Returns: (success, new_balance)
|
||||
"""
|
||||
# 保留原 PointsAPI 的入参失败语义;余额与限额校验由 xapi 承担。
|
||||
if not isinstance(amount, int) or amount <= 0:
|
||||
return False, 0
|
||||
if not user_id or not source:
|
||||
return False, 0
|
||||
|
||||
data = await self._request(
|
||||
"POST",
|
||||
"spend",
|
||||
payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason},
|
||||
)
|
||||
return self._change_result(data)
|
||||
|
||||
async def set_points(
|
||||
self, user_id: str, amount: int, source: str, reason: str = None
|
||||
) -> Tuple[bool, int]:
|
||||
"""Set user's points to exact amount.
|
||||
|
||||
Returns: (success, new_balance)
|
||||
"""
|
||||
# Parameter validation
|
||||
if not isinstance(amount, int) or amount < 0:
|
||||
return False, 0
|
||||
if not user_id or not source:
|
||||
return False, 0
|
||||
|
||||
def _set():
|
||||
with self._lock:
|
||||
conn = self.db.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# Ensure user exists
|
||||
self.db.ensure_user_exists(user_id, conn)
|
||||
|
||||
# Get current balance
|
||||
cursor.execute(
|
||||
"SELECT points, total_earned FROM user_points WHERE user_id = ?",
|
||||
(user_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
current_balance = row["points"] if row else 0
|
||||
current_earned = row["total_earned"] if row else 0
|
||||
|
||||
# If new value equals old value, return without writing
|
||||
if current_balance == amount:
|
||||
conn.rollback()
|
||||
return True, amount
|
||||
|
||||
# Calculate difference for total_earned (only positive diff)
|
||||
diff = amount - current_balance
|
||||
earned_diff = max(0, diff)
|
||||
|
||||
# Update balance and total_earned
|
||||
now = datetime.now().isoformat()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE user_points
|
||||
SET points = ?, total_earned = total_earned + ?, updated_at = ?
|
||||
WHERE user_id = ?
|
||||
""",
|
||||
(amount, earned_diff, now, user_id),
|
||||
)
|
||||
|
||||
# Write transaction log
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO point_transactions
|
||||
(user_id, amount, balance_after, source, reason, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, diff, amount, source, reason, now),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return True, amount
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error(f"set_points failed for {user_id}: {e}")
|
||||
return False, 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return await asyncio.to_thread(_set)
|
||||
"""Set user's points to exact amount.
|
||||
|
||||
Returns: (success, new_balance)
|
||||
"""
|
||||
# set 仍保持原 PointsAPI 行为:只校验非负,不做余额上限判断。
|
||||
if not isinstance(amount, int) or amount < 0:
|
||||
return False, 0
|
||||
if not user_id or not source:
|
||||
return False, 0
|
||||
|
||||
data = await self._request(
|
||||
"POST",
|
||||
"set",
|
||||
payload={"user_id": user_id, "amount": amount, "source": source, "reason": reason},
|
||||
)
|
||||
return self._change_result(data)
|
||||
|
||||
async def get_transactions(
|
||||
self, user_id: str, limit: int = 20, offset: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get transaction history for a user.
|
||||
|
||||
Returns: List of transaction dicts
|
||||
"""
|
||||
# Normalize parameters
|
||||
limit = max(1, min(100, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
def _get():
|
||||
conn = self.db.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, user_id, amount, balance_after, source, reason, created_at
|
||||
FROM point_transactions
|
||||
WHERE user_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(user_id, limit, offset),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"get_transactions failed for {user_id}: {e}")
|
||||
return []
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return await asyncio.to_thread(_get)
|
||||
"""Get transaction history for a user.
|
||||
|
||||
Returns: List of transaction dicts
|
||||
"""
|
||||
limit = max(1, min(100, limit))
|
||||
offset = max(0, offset)
|
||||
data = await self._request(
|
||||
"GET",
|
||||
"transactions",
|
||||
params={"user_id": user_id, "limit": limit, "offset": offset},
|
||||
)
|
||||
if data is None:
|
||||
return []
|
||||
items = data.get("items", [])
|
||||
return items if isinstance(items, list) else []
|
||||
|
||||
async def get_ranking(
|
||||
self, limit: int = 10, order_by: str = "points"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get points ranking.
|
||||
|
||||
Returns: List of ranking dicts with rank field
|
||||
"""
|
||||
# Normalize parameters
|
||||
limit = max(1, min(100, limit))
|
||||
if order_by not in ("points", "total_earned"):
|
||||
order_by = "points"
|
||||
|
||||
def _get():
|
||||
conn = self.db.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
order_column = "points" if order_by == "points" else "total_earned"
|
||||
query = f"""
|
||||
SELECT
|
||||
RANK() OVER (ORDER BY {order_column} DESC) as rank,
|
||||
user_id,
|
||||
points,
|
||||
total_earned,
|
||||
total_spent
|
||||
FROM user_points
|
||||
ORDER BY {order_column} DESC, user_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
cursor.execute(query, (limit,))
|
||||
rows = cursor.fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"get_ranking failed: {e}")
|
||||
return []
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return await asyncio.to_thread(_get)
|
||||
"""Get points ranking.
|
||||
|
||||
Returns: List of ranking dicts with rank field
|
||||
"""
|
||||
limit = max(1, min(100, limit))
|
||||
if order_by not in ("points", "total_earned"):
|
||||
order_by = "points"
|
||||
data = await self._request(
|
||||
"GET",
|
||||
"ranking",
|
||||
params={"limit": limit, "order_by": order_by},
|
||||
)
|
||||
if data is None:
|
||||
return []
|
||||
items = data.get("items", [])
|
||||
return items if isinstance(items, list) else []
|
||||
|
||||
def _change_result(self, data: Optional[Dict[str, Any]]) -> Tuple[bool, int]:
|
||||
"""解析 add/spend/set 响应并维持旧失败返回值。"""
|
||||
|
||||
if data is None:
|
||||
return False, 0
|
||||
return bool(data.get("success")), int(data.get("balance", 0) or 0)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
|
||||
@@ -11,22 +10,24 @@ class Config(BaseSettings):
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# 数据库配置
|
||||
POINTS_DB_FILE: str = os.getenv("DANDING_POINTS_DB_FILE", "data/danding_points/points.db")
|
||||
POINTS_MAX_BALANCE: int = int(os.getenv("DANDING_POINTS_MAX_BALANCE", "0"))
|
||||
POINTS_MAX_PER_OPERATION: int = int(os.getenv("DANDING_POINTS_MAX_PER_OPERATION", "0"))
|
||||
POINTS_LOG_RETENTION_DAYS: int = int(os.getenv("DANDING_POINTS_LOG_RETENTION_DAYS", "365"))
|
||||
# xapi /bot/points 运行时 API 配置
|
||||
POINTS_API_HOST: str = os.getenv("DANDING_POINTS_API_HOST", "https://api.danding.vip/bot/points")
|
||||
BOT_USER: str = os.getenv("DANDING_BOT_USER", "1424473282")
|
||||
BOT_TOKEN: str = os.getenv(
|
||||
"DANDING_BOT_TOKEN",
|
||||
os.getenv("DANDING_API_TOKEN", os.getenv("BOT_TOKEN", "")),
|
||||
)
|
||||
|
||||
@field_validator("POINTS_MAX_BALANCE", "POINTS_MAX_PER_OPERATION", "POINTS_LOG_RETENTION_DAYS")
|
||||
@field_validator("POINTS_API_HOST")
|
||||
@classmethod
|
||||
def validate_non_negative(cls, v):
|
||||
if v < 0:
|
||||
raise ValueError("Value must be non-negative")
|
||||
return v
|
||||
def validate_api_host(cls, value):
|
||||
if not value:
|
||||
raise ValueError("POINTS_API_HOST cannot be empty")
|
||||
return value.rstrip("/")
|
||||
|
||||
@field_validator("POINTS_DB_FILE")
|
||||
@field_validator("BOT_USER")
|
||||
@classmethod
|
||||
def validate_db_path(cls, v):
|
||||
if not v:
|
||||
raise ValueError("Database file path cannot be empty")
|
||||
return v
|
||||
def validate_bot_user(cls, value):
|
||||
if not value:
|
||||
raise ValueError("BOT_USER cannot be empty")
|
||||
return value
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
import sqlite3
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from .config import Config
|
||||
|
||||
|
||||
class PointsDatabase:
|
||||
"""SQLite database handler for points system."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.db_path = config.POINTS_DB_FILE
|
||||
self._ensure_db_dir()
|
||||
self._init_db()
|
||||
|
||||
def _ensure_db_dir(self):
|
||||
"""Create database directory if it doesn't exist."""
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
if db_dir:
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize database tables."""
|
||||
conn = sqlite3.connect(self.db_path, timeout=5.0)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create user_points table
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS user_points (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
points INTEGER NOT NULL DEFAULT 0 CHECK(points >= 0),
|
||||
total_earned INTEGER NOT NULL DEFAULT 0,
|
||||
total_spent INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Create point_transactions table
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS point_transactions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL,
|
||||
amount INTEGER NOT NULL,
|
||||
balance_after INTEGER NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
reason TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_transactions_user_id ON point_transactions(user_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_transactions_source ON point_transactions(source)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_transactions_created_at ON point_transactions(created_at)"
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def get_connection(self) -> sqlite3.Connection:
|
||||
"""Get a database connection."""
|
||||
conn = sqlite3.connect(self.db_path, timeout=5.0)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def get_user_balance(self, user_id: str) -> int:
|
||||
"""Get user's current points balance."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT points FROM user_points WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
return row["points"] if row else 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def ensure_user_exists(self, user_id: str, conn=None) -> None:
|
||||
"""Create user account if it doesn't exist. Reuses provided conn if given."""
|
||||
should_close = False
|
||||
if conn is None:
|
||||
conn = self.get_connection()
|
||||
should_close = True
|
||||
cursor = conn.cursor()
|
||||
now = datetime.now().isoformat()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO user_points
|
||||
(user_id, points, total_earned, total_spent, created_at, updated_at)
|
||||
VALUES (?, 0, 0, 0, ?, ?)
|
||||
""",
|
||||
(user_id, now, now),
|
||||
)
|
||||
if should_close:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
Reference in New Issue
Block a user