Files
DanDingNoneBot/tests/test_room_store_lock.py
Mr.Xia e94161e802 test: add unit tests for models, payout logic, and room lock
- test_models.py: 10 tests for Room/Horse/Bet/RaceResult dataclasses
- test_payout_logic.py: 12 tests for payout formula (max+round)
- test_room_store_lock.py: 5 tests for get_lock() setdefault pattern
- All 34 tests pass in 0.27s
2026-05-09 23:31:54 +08:00

80 lines
2.8 KiB
Python

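"""Tests for the per-scope lock pattern behind RoomStore._get_lock (the setdefault fix).

Per verify_sop, these tests must actually run and must include adversarial probing.
They exercise FakeRoomStore, a local replica that contrasts the fixed
``setdefault`` lookup with the old check-then-set (``if scope not in ...``) lookup.
"""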
import asyncio
import pytest
class FakeRoomStore:
    """Stand-in for RoomStore that keeps one asyncio.Lock per scope key.

    Only the lock-lookup logic is reproduced, so the tests can compare the
    current ``setdefault`` form with the previous check-then-set form.
    """

    def __init__(self):
        # scope -> lock; a scope's lock is created lazily on first request.
        self._locks: dict[str, asyncio.Lock] = dict()

    def get_lock(self, scope: str) -> asyncio.Lock:
        """Return the lock for *scope*, creating and storing one if absent.

        Lookup and insertion happen in a single ``dict.setdefault`` call
        rather than as two separate steps. A new ``asyncio.Lock`` is built on
        every call, even when *scope* already has a lock; in that case the
        spare one is simply discarded.
        """
        fresh = asyncio.Lock()
        return self._locks.setdefault(scope, fresh)

    def get_lock_old_buggy(self, scope: str) -> asyncio.Lock:
        """Return the lock for *scope* using the old check-then-set pattern.

        The membership test and the insertion are separate steps. If another
        thread ran between them, two different locks could be handed out for
        the same scope. This method contains no ``await``, so coroutines on a
        single event loop cannot interleave inside it.
        """
        locks = self._locks
        if scope in locks:
            return locks[scope]
        locks[scope] = asyncio.Lock()  # separate step from the check above
        return locks[scope]
class TestGetLock:
    """Behavioral tests for FakeRoomStore.get_lock."""

    def test_same_scope_returns_same_lock(self):
        """Repeated requests for one scope yield the identical lock object."""
        store = FakeRoomStore()
        lock1 = store.get_lock("scope1")
        lock2 = store.get_lock("scope1")
        assert lock1 is lock2

    def test_different_scopes_different_locks(self):
        """Distinct scopes get distinct locks, so they never block each other."""
        store = FakeRoomStore()
        lock1 = store.get_lock("scope1")
        lock2 = store.get_lock("scope2")
        assert lock1 is not lock2

    def test_lock_initially_not_locked(self):
        """A freshly created lock starts out released."""
        store = FakeRoomStore()
        lock = store.get_lock("scope1")
        assert not lock.locked()

    @pytest.mark.asyncio
    async def test_concurrent_get_lock_same_scope(self):
        """Adversarial: concurrent calls to get_lock for same scope must return same lock.

        The lock objects themselves are collected, not ``id(lock)``. Ids are
        only guaranteed unique among objects alive at the same time. If a
        broken get_lock returned a fresh, unstored lock per call, each lock
        could be freed when its coroutine finished. Its address could then be
        reused by the next one, making distinct locks look identical.

        NOTE: get_lock contains no ``await``, so under one event loop each
        call runs to completion before another coroutine resumes. This checks
        identity under gather(); it does not provoke a thread-level race.
        """
        store = FakeRoomStore()
        # Holding the objects keeps every lock alive until the assertions.
        results: list[asyncio.Lock] = []

        async def grab_lock(scope: str) -> None:
            results.append(store.get_lock(scope))

        await asyncio.gather(*(grab_lock("shared") for _ in range(100)))

        assert len(results) == 100
        # Counting ids is sound here because every object is still alive.
        distinct = len({id(lock) for lock in results})
        assert distinct == 1, f"Got {distinct} different lock objects!"
        # The shared lock must also be the one the store retained.
        assert results[0] is store.get_lock("shared")

    @pytest.mark.asyncio
    async def test_lock_prevents_concurrent_execution(self):
        """Verify the lock actually serializes access."""
        store = FakeRoomStore()
        execution_log: list[str] = []

        async def protected_operation(scope: str, op_id: str) -> None:
            lock = store.get_lock(scope)
            async with lock:
                execution_log.append(f"start_{op_id}")
                # Yield while holding the lock so the other task gets a chance
                # to run; it must block on acquire instead of entering.
                await asyncio.sleep(0.01)
                execution_log.append(f"end_{op_id}")

        await asyncio.gather(
            protected_operation("scope", "A"),
            protected_operation("scope", "B"),
        )
        # Operations must be serialized: start_A, end_A, start_B, end_B.
        # A enters first because gather() schedules its tasks in argument order.
        assert execution_log == ["start_A", "end_A", "start_B", "end_B"]
        # Both operations have exited the ``async with``, so the lock is free.
        assert not store.get_lock("scope").locked()