"""Test RoomStore._get_lock behavior - the setdefault fix. Per verify_sop: tests must run, with adversarial probing. This tests the get_lock concurrency fix (setdefault vs if-check). """ import asyncio import pytest class FakeRoomStore: """Minimal reproduction of RoomStore._get_lock for testing the lock pattern.""" def __init__(self): self._locks: dict[str, asyncio.Lock] = {} def get_lock(self, scope: str) -> asyncio.Lock: """The fixed pattern: setdefault (atomic)""" return self._locks.setdefault(scope, asyncio.Lock()) def get_lock_old_buggy(self, scope: str) -> asyncio.Lock: """The old buggy pattern: check-then-set (race condition)""" if scope not in self._locks: self._locks[scope] = asyncio.Lock() # NOT thread-safe return self._locks[scope] class TestGetLock: def test_same_scope_returns_same_lock(self): store = FakeRoomStore() lock1 = store.get_lock("scope1") lock2 = store.get_lock("scope1") assert lock1 is lock2 def test_different_scopes_different_locks(self): store = FakeRoomStore() lock1 = store.get_lock("scope1") lock2 = store.get_lock("scope2") assert lock1 is not lock2 def test_lock_initially_not_locked(self): store = FakeRoomStore() lock = store.get_lock("scope1") assert not lock.locked() @pytest.mark.asyncio async def test_concurrent_get_lock_same_scope(self): """Adversarial: concurrent calls to get_lock for same scope must return same lock.""" store = FakeRoomStore() results = [] async def grab_lock(scope): lock = store.get_lock(scope) results.append(id(lock)) tasks = [grab_lock("shared") for _ in range(100)] await asyncio.gather(*tasks) # ALL must be the same lock object assert len(set(results)) == 1, f"Got {len(set(results))} different lock objects!" @pytest.mark.asyncio async def test_lock_prevents_concurrent_execution(self): """Verify the lock actually serializes access.""" store = FakeRoomStore() execution_log = [] async def protected_operation(scope, op_id): lock = store.get_lock(scope) async with lock: execution_log.append(f"start_{op_id}") await asyncio.sleep(0.01) execution_log.append(f"end_{op_id}") await asyncio.gather( protected_operation("scope", "A"), protected_operation("scope", "B"), ) # Operations must be serialized: start_A, end_A, start_B, end_B assert execution_log == ["start_A", "end_A", "start_B", "end_B"]