dictation-service/tests/test_suite.py
Kade Heyborne 73a15d03cd
Fix dictation service: state detection, async processing, and performance optimizations
- Fix state detection priority: dictation now takes precedence over conversation
- Fix critical bug: event loop was created but never started, preventing async coroutines from executing
- Optimize audio processing: reorder AcceptWaveform/PartialResult checks
- Switch to faster Vosk model: vosk-model-en-us-0.22-lgraph for 2-3x speed improvement
- Reduce block size from 8000 to 4000 for lower latency
- Add filtering to remove spurious 'the', 'a', 'an' words from start/end of transcriptions
- Update toggle-dictation.sh to properly clean up conversation lock file
- Improve batch audio processing for better responsiveness
2025-12-04 11:49:07 -07:00

642 lines
24 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Comprehensive Test Suite for AI Dictation Service
Tests all features: basic dictation, AI conversation, TTS, state management, etc.
"""
import os
import sys
import json
import time
import tempfile
import unittest
import threading
import subprocess
import asyncio
import aiohttp
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
# Test Configuration
# Scratch file names for the suite, kept in one place so they are easy to find.
TEST_CONFIG = {
    "test_audio_file": "test_audio.wav",
    "test_conversation_file": "test_conversation_history.json",
    # One lock-file name per application mode.
    "test_lock_files": {
        "dictation": "test_listening.lock",
        "conversation": "test_conversation.lock",
    },
}
class TestVLLMClient(unittest.TestCase):
    """Checks for the VLLM API client wrapper."""

    def setUp(self):
        """Create a client for the local test endpoint, skipping if it cannot be imported."""
        self.test_endpoint = "http://127.0.0.1:8000/v1"
        # Deferred import so that a missing dependency skips the test
        # instead of breaking collection of the whole module.
        try:
            from src.dictation_service.ai_dictation_simple import VLLMClient
            self.client = VLLMClient(self.test_endpoint)
        except ImportError as exc:
            self.skipTest(f"Cannot import VLLMClient: {exc}")

    def test_client_initialization(self):
        """The client exists, remembers its endpoint and holds an inner client."""
        client = self.client
        self.assertIsNotNone(client)
        self.assertEqual(client.endpoint, self.test_endpoint)
        self.assertIsNotNone(client.client)

    def test_connection_test(self):
        """The connectivity probe issues a GET to ``<endpoint>/models``."""
        ok_response = Mock()
        ok_response.status_code = 200
        # requests.get is replaced so no real network traffic happens.
        with patch('requests.get') as fake_get:
            fake_get.return_value = ok_response
            # Must complete without raising.
            self.client._test_connection()
            fake_get.assert_called_with(f"{self.test_endpoint}/models", timeout=2)

    def test_api_response_formatting(self):
        """``get_response`` returns the content of the first completion choice."""
        conversation = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello"},
        ]
        # Canned completion shaped like ``response.choices[0].message.content``.
        choice = Mock()
        choice.message.content = "Hello! How can I help you?"
        canned_completion = Mock()
        canned_completion.choices = [choice]

        with patch.object(self.client.client, 'chat') as fake_chat:
            fake_chat.completions.create.return_value = canned_completion

            async def exercise_client():
                reply = await self.client.get_response(conversation)
                self.assertEqual(reply, "Hello! How can I help you?")
                fake_chat.completions.create.assert_called_once()

            # Drive the coroutine to completion on a fresh event loop.
            asyncio.run(exercise_client())
class TestTTSManager(unittest.TestCase):
    """Checks for the text-to-speech manager."""

    def setUp(self):
        """Create a TTSManager, skipping the test if it cannot be imported."""
        try:
            from src.dictation_service.ai_dictation_simple import TTSManager
            self.tts = TTSManager()
        except ImportError as exc:
            self.skipTest(f"Cannot import TTSManager: {exc}")

    def test_tts_initialization(self):
        """The manager exists and exposes a boolean ``enabled`` flag."""
        self.assertIsNotNone(self.tts)
        # The flag may be False when the speech engine could not start,
        # so only its type is checked here.
        self.assertIsInstance(self.tts.enabled, bool)

    def test_tts_speak_empty_text(self):
        """Speaking empty or blank text must not raise."""
        for blank in ("", " "):
            try:
                self.tts.speak(blank)
            except Exception as exc:
                self.fail(f"TTS crashed with empty text: {exc}")

    def test_tts_speak_normal_text(self):
        """Speaking a sentence forwards it to the engine and waits for completion."""
        phrase = "Hello world, this is a test."
        # pyttsx3.init is replaced so nothing is actually spoken during the run.
        with patch('pyttsx3.init') as fake_init:
            fake_engine = Mock()
            fake_init.return_value = fake_engine
            # Build a new manager while the patch is active.
            from src.dictation_service.ai_dictation_simple import TTSManager
            patched_tts = TTSManager()
            patched_tts.speak(phrase)
            fake_engine.say.assert_called_once_with(phrase)
            fake_engine.runAndWait.assert_called_once()
class TestConversationManager(unittest.TestCase):
    """Test conversation management and context persistence."""

    # Dotted path of the attribute patched while a manager is constructed,
    # so that the manager uses the per-test history file.
    _HISTORY_ATTR = (
        'src.dictation_service.ai_dictation_simple.'
        'ConversationManager.persistent_history_file'
    )

    def setUp(self):
        """Create a scratch directory and a manager bound to a history file in it.

        The test is skipped when the service module cannot be imported.
        """
        # Registered with addCleanup (not tearDown) so the directory is also
        # removed when setUp skips; tearDown is not run in that case.
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.temp_dir = scratch.name
        self.history_file = os.path.join(self.temp_dir, "test_history.json")
        try:
            from src.dictation_service.ai_dictation_simple import ConversationManager
        except ImportError as e:
            self.skipTest(f"Cannot import ConversationManager: {e}")
        # Kept on the instance so individual tests can build further managers;
        # the class is only imported locally here and is not a module global.
        self._manager_cls = ConversationManager
        self.conv_manager = self._new_manager()

    def _new_manager(self):
        """Construct a manager while the history-file attribute is patched."""
        with patch(self._HISTORY_ATTR, self.history_file):
            return self._manager_cls()

    def test_message_addition(self):
        """Adding two messages grows the history by two and keeps order."""
        initial_count = len(self.conv_manager.conversation_history)
        self.conv_manager.add_message("user", "Hello AI")
        self.conv_manager.add_message("assistant", "Hello human!")
        history = self.conv_manager.conversation_history
        self.assertEqual(len(history), initial_count + 2)
        self.assertEqual(history[-1].content, "Hello human!")
        self.assertEqual(history[-1].role, "assistant")

    def test_conversation_persistence(self):
        """Saved history is written to the history file as a JSON list."""
        self.conv_manager.add_message("user", "Test message 1")
        self.conv_manager.add_message("assistant", "Test response 1")
        # Force save
        self.conv_manager.save_persistent_history()
        # Verify file exists and contains data
        self.assertTrue(os.path.exists(self.history_file))
        with open(self.history_file, 'r') as f:
            data = json.load(f)
        self.assertEqual(len(data), 2)
        self.assertEqual(data[0]['content'], "Test message 1")
        self.assertEqual(data[1]['content'], "Test response 1")

    def test_conversation_loading(self):
        """A new manager picks up messages already present in the history file."""
        test_data = [
            {"role": "user", "content": "Loaded message 1", "timestamp": 1234567890},
            {"role": "assistant", "content": "Loaded response 1", "timestamp": 1234567891},
        ]
        with open(self.history_file, 'w') as f:
            json.dump(test_data, f)
        # Build a second manager through the helper: referring to the bare
        # class name here would raise NameError.
        new_manager = self._new_manager()
        self.assertEqual(len(new_manager.conversation_history), 2)
        self.assertEqual(new_manager.conversation_history[0].content, "Loaded message 1")

    def test_api_message_formatting(self):
        """API messages are the system prompt followed by the conversation."""
        self.conv_manager.add_message("user", "Test user message")
        self.conv_manager.add_message("assistant", "Test assistant response")
        api_messages = self.conv_manager.get_messages_for_api()
        # Should have system prompt + conversation messages
        self.assertEqual(len(api_messages), 3)  # system + 2 messages
        # Check system prompt
        self.assertEqual(api_messages[0]['role'], 'system')
        self.assertIn('helpful AI assistant', api_messages[0]['content'])
        # Check user message
        self.assertEqual(api_messages[1]['role'], 'user')
        self.assertEqual(api_messages[1]['content'], 'Test user message')

    def test_history_limit(self):
        """With ``max_history`` set to 3, only the last three messages are kept."""
        original_max = self.conv_manager.max_history
        self.conv_manager.max_history = 3
        # Add more messages than the limit
        for i in range(5):
            self.conv_manager.add_message("user", f"Message {i}")
        self.assertEqual(len(self.conv_manager.conversation_history), 3)
        self.assertEqual(self.conv_manager.conversation_history[-1].content, "Message 4")
        # Restore original limit
        self.conv_manager.max_history = original_max

    def test_clear_history(self):
        """Clearing empties the in-memory history and deletes the history file."""
        self.conv_manager.add_message("user", "Test message")
        self.conv_manager.save_persistent_history()
        # Verify file exists
        self.assertTrue(os.path.exists(self.history_file))
        # Clear history
        self.conv_manager.clear_all_history()
        # Verify cleared
        self.assertEqual(len(self.conv_manager.conversation_history), 0)
        self.assertFalse(os.path.exists(self.history_file))
class TestStateManager(unittest.TestCase):
    """Test application state management via lock files."""

    def setUp(self):
        """Create a scratch directory and compute the lock-file paths inside it."""
        # Lock files live in a per-test temporary directory rather than the
        # current working directory, so runs cannot collide with each other
        # or leave files behind; addCleanup removes the directory afterwards.
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        lock_names = TEST_CONFIG["test_lock_files"]
        self.test_files = {
            mode: os.path.join(scratch.name, lock_names[mode])
            for mode in ('dictation', 'conversation')
        }

    def test_lock_file_creation_removal(self):
        """A lock file can be created and removed, and its presence is detectable."""
        dictation_lock = self.test_files['dictation']
        self.assertFalse(os.path.exists(dictation_lock))
        # Create lock file
        Path(dictation_lock).touch()
        self.assertTrue(os.path.exists(dictation_lock))
        # Remove lock file
        os.remove(dictation_lock)
        self.assertFalse(os.path.exists(dictation_lock))

    def test_state_transitions(self):
        """Walk idle -> dictation -> conversation -> dictation -> idle."""
        # NOTE(review): this is a local simulation of the state check, not the
        # service's own function; confirm the priority order used here
        # (conversation first) still matches the service.
        def get_app_state():
            dictation_active = os.path.exists(self.test_files['dictation'])
            conversation_active = os.path.exists(self.test_files['conversation'])
            if conversation_active:
                return "conversation"
            elif dictation_active:
                return "dictation"
            else:
                return "idle"

        # Test idle state
        self.assertEqual(get_app_state(), "idle")
        # Test dictation state
        Path(self.test_files['dictation']).touch()
        self.assertEqual(get_app_state(), "dictation")
        # Test conversation state (takes precedence in this simulation)
        Path(self.test_files['conversation']).touch()
        self.assertEqual(get_app_state(), "conversation")
        # Test removing conversation state
        os.remove(self.test_files['conversation'])
        self.assertEqual(get_app_state(), "dictation")
        # Test back to idle
        os.remove(self.test_files['dictation'])
        self.assertEqual(get_app_state(), "idle")
class TestAudioProcessing(unittest.TestCase):
    """Test audio processing functionality."""

    def test_audio_callback_basic(self):
        """The audio callback accepts a block of int16 samples without raising.

        Skipped when numpy or the service module cannot be imported.
        """
        # The two imports are guarded separately so the skip reason names the
        # dependency that is actually missing.
        try:
            import numpy as np
        except ImportError:
            self.skipTest("numpy not available for audio testing")
        try:
            from src.dictation_service.ai_dictation_simple import audio_callback
        except ImportError as e:
            self.skipTest(f"Cannot import audio_callback: {e}")

        # Create mock audio data: 8000 frames, one channel, random int16 values.
        audio_data = np.random.randint(-32768, 32767, size=(8000, 1), dtype=np.int16)
        # Test that callback doesn't crash
        try:
            audio_callback(audio_data, 8000, None, None)
        except Exception as e:
            self.fail(f"Audio callback crashed: {e}")

    def test_text_filtering(self):
        """A local model of the text filter drops filler words and very short text."""
        filler_words = frozenset(['the', 'a', 'an', 'uh', 'huh', 'um', 'hmm'])

        # NOTE(review): this simulates the filtering rule locally; it does not
        # call the service's own implementation.
        def should_filter_text(text):
            """Return True when ``text`` would be discarded."""
            formatted = text.strip()
            # Filter spurious single words
            if len(formatted.split()) == 1 and formatted.lower() in filler_words:
                return True
            # Filter very short text (this also covers whitespace-only input)
            return len(formatted) < 2

        # Test filtering
        for rejected in ("the", "uh", "a", "x", " "):
            self.assertTrue(should_filter_text(rejected))
        # Test passing through
        for accepted in ("hello world", "test message", "conversation"):
            self.assertFalse(should_filter_text(accepted))
class TestIntegration(unittest.TestCase):
    """Integration tests for the complete system."""

    def setUp(self):
        """Create a scratch directory plus history and lock-file paths inside it."""
        # addCleanup removes the directory and anything left in it, even if
        # a test fails part-way through.
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.temp_dir = scratch.name
        # Create temporary config file paths
        self.history_file = os.path.join(self.temp_dir, "integration_history.json")
        self.lock_files = {
            'dictation': os.path.join(self.temp_dir, "dictation.lock"),
            'conversation': os.path.join(self.temp_dir, "conversation.lock"),
        }

    def test_full_conversation_flow(self):
        """Run one user turn through the manager with the VLLM client and TTS mocked.

        Skipped when the service module cannot be imported.
        """
        try:
            from src.dictation_service.ai_dictation_simple import ConversationManager
        except ImportError as e:
            self.skipTest(f"Cannot import required modules: {e}")

        module = 'src.dictation_service.ai_dictation_simple'

        # Mock async response
        async def mock_get_response(messages):
            return "Mock AI response"

        # Mock the VLLM client and TTS to avoid real API calls and speech,
        # and point the history at the scratch file.
        with patch(f'{module}.VLLMClient') as mock_client_class, \
                patch(f'{module}.TTSManager') as mock_tts_class, \
                patch(f'{module}.ConversationManager.persistent_history_file',
                      self.history_file):
            mock_client = Mock()
            mock_client_class.return_value = mock_client
            # Wrap the coroutine function in a Mock: calling it still returns
            # an awaitable, and the call is recorded so assert_called_once()
            # works. A bare function has no such method.
            mock_client.get_response = Mock(side_effect=mock_get_response)
            mock_tts = Mock()
            mock_tts_class.return_value = mock_tts

            manager = ConversationManager()

            async def test_conversation():
                # Start conversation
                manager.start_conversation()
                # Process user input
                await manager.process_user_input("Hello AI")
                # Verify user message was added
                self.assertEqual(len(manager.conversation_history), 1)
                self.assertEqual(manager.conversation_history[0].role, "user")
                # Verify AI response was requested
                mock_client.get_response.assert_called_once()
                # End conversation
                manager.end_conversation()

            # Run async test
            asyncio.run(test_conversation())
            # Verify persistence
            self.assertTrue(os.path.exists(self.history_file))

    def test_vllm_endpoint_connectivity(self):
        """Probe the local VLLM endpoint; skip when it cannot be reached."""
        # Import outside the request's try block: the except clause below
        # refers to ``requests``, which would be unbound if the import failed.
        try:
            import requests
        except ImportError as e:
            self.skipTest(f"requests not available: {e}")

        try:
            # Test VLLM endpoint
            response = requests.get(
                "http://127.0.0.1:8000/v1/models",
                headers={"Authorization": "Bearer vllm-api-key"},
                timeout=5,
            )
            # If VLLM is running, test basic functionality
            if response.status_code == 200:
                self.assertIn("data", response.json())
                print("✅ VLLM endpoint is accessible")
            else:
                print(f"⚠️ VLLM endpoint returned status {response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"⚠️ VLLM endpoint not accessible: {e}")
            # This is not a failure, just info
            self.skipTest("VLLM endpoint not available")
class TestScriptFunctionality(unittest.TestCase):
"""Test shell scripts and external functionality"""
def setUp(self):
"""Setup script testing environment"""
self.script_dir = os.path.join(os.path.dirname(__file__), '..', 'scripts')
self.temp_dir = tempfile.mkdtemp()
# Create test lock files in temp directory
self.test_locks = {
'listening': os.path.join(self.temp_dir, 'listening.lock'),
'conversation': os.path.join(self.temp_dir, 'conversation.lock')
}
def tearDown(self):
"""Clean up script test environment"""
for lock_file in self.test_locks.values():
if os.path.exists(lock_file):
os.remove(lock_file)
os.rmdir(self.temp_dir)
def test_toggle_scripts_exist(self):
"""Test that toggle scripts exist and are executable"""
dictation_script = os.path.join(self.script_dir, 'toggle-dictation.sh')
conversation_script = os.path.join(self.script_dir, 'toggle-conversation.sh')
self.assertTrue(os.path.exists(dictation_script), "Dictation toggle script should exist")
self.assertTrue(os.path.exists(conversation_script), "Conversation toggle script should exist")
# Check they're executable (might not be if user hasn't run chmod)
# This is informational, not a failure
if not os.access(dictation_script, os.X_OK):
print("⚠️ Dictation script not executable - run 'chmod +x toggle-dictation.sh'")
if not os.access(conversation_script, os.X_OK):
print("⚠️ Conversation script not executable - run 'chmod +x toggle-conversation.sh'")
def test_notification_system(self):
"""Test system notification functionality"""
try:
result = subprocess.run(
["notify-send", "-t", "1000", "Test Title", "Test Message"],
capture_output=True,
timeout=5
)
# If notify-send works, it should return 0
if result.returncode == 0:
print("✅ System notifications working")
else:
print(f"⚠️ Notification system issue: {result.stderr.decode()}")
except subprocess.TimeoutExpired:
print("⚠️ Notification command timed out")
except FileNotFoundError:
print("⚠️ notify-send not available")
except Exception as e:
print(f"⚠️ Notification test error: {e}")
def run_audio_input_test():
    """Interactive microphone check: record 3 seconds with arecord, replay with aplay.

    Waits for Enter before recording. The check is skipped when the user
    presses Ctrl+C or when stdin has no input to read (EOF), so a
    non-interactive run does not crash. Outcomes are printed; nothing is
    returned.
    """
    print("\n🎤 Audio Input Test")
    print("This test requires a microphone and will record 3 seconds of audio.")
    print("Press Enter to start or skip with Ctrl+C...")
    try:
        input()
        # Test audio recording
        test_file = "test_audio_recording.wav"
        try:
            subprocess.run(
                ["arecord", "-d", "3", "-f", "cd", test_file],
                check=True,
                capture_output=True,
            )
            if os.path.exists(test_file):
                print("✅ Audio recording successful")
                # Test playback
                subprocess.run(["aplay", test_file], check=True, capture_output=True)
                print("✅ Audio playback successful")
            else:
                print("❌ Audio recording failed - no file created")
        except subprocess.CalledProcessError as e:
            print(f"❌ Audio test failed: {e}")
        except FileNotFoundError:
            print("⚠️ arecord/aplay not available")
        finally:
            # Clean up the recording even when playback fails or is interrupted.
            if os.path.exists(test_file):
                os.remove(test_file)
    except (KeyboardInterrupt, EOFError):
        # EOFError: input() found no data on stdin (e.g. piped or closed).
        print("\n⏭️ Audio test skipped")
def run_vllm_test():
    """Call the local VLLM server for real and print the outcome.

    First lists models at ``/v1/models``; on HTTP 200 it then posts a small
    chat completion. Every failure is reported with ``print``; nothing is
    raised to the caller and nothing is returned.
    """
    print("\n🤖 VLLM Integration Test")
    print("Testing actual VLLM API call...")
    # Import before the try block below: its except clause refers to
    # ``requests``, which would be unbound if the import itself had failed.
    try:
        import requests
    except ImportError as e:
        print(f"❌ VLLM test error: {e}")
        return

    try:
        # Test endpoint
        response = requests.get(
            "http://127.0.0.1:8000/v1/models",
            headers={"Authorization": "Bearer vllm-api-key"},
            timeout=5,
        )
        if response.status_code != 200:
            print(f"❌ VLLM endpoint error: {response.status_code} - {response.text}")
            return
        print("✅ VLLM endpoint accessible")

        # Test chat completion
        chat_response = requests.post(
            "http://127.0.0.1:8000/v1/chat/completions",
            headers={
                "Authorization": "Bearer vllm-api-key",
                "Content-Type": "application/json",
            },
            json={
                "model": "default",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Say 'Hello from VLLM!'"},
                ],
                "max_tokens": 50,
                "temperature": 0.7,
            },
            timeout=10,
        )
        if chat_response.status_code == 200:
            result = chat_response.json()
            message = result['choices'][0]['message']['content']
            print(f"✅ VLLM chat successful: '{message}'")
        else:
            print(f"❌ VLLM chat failed: {chat_response.status_code} - {chat_response.text}")
    except requests.exceptions.RequestException as e:
        print(f"❌ VLLM connection failed: {e}")
    except Exception as e:
        # Top-level boundary of a diagnostic helper: report instead of crashing
        # (covers e.g. an unexpected response payload shape).
        print(f"❌ VLLM test error: {e}")
def main():
    """Run the unit tests, then the interactive audio and VLLM checks.

    Returns:
        0 when the unit-test run was successful, 1 otherwise, so the caller
        can use it as the process exit status. The interactive checks only
        print their outcome and do not affect the return value.
    """
    print("🧪 AI Dictation Service - Comprehensive Test Suite")
    print("=" * 50)
    # Run unit tests; exit=False keeps the process alive for the checks below
    # and hands back the program object that carries the result.
    print("\n📋 Running Unit Tests...")
    program = unittest.main(argv=[''], exit=False, verbosity=2)
    unit_tests_ok = program.result.wasSuccessful()
    print("\n" + "=" * 50)
    print("🎯 Running Interactive Tests...")
    # Audio input test (requires user interaction)
    run_audio_input_test()
    # VLLM integration test
    run_vllm_test()
    print("\n" + "=" * 50)
    print("✅ Test Suite Complete!")
    if not unit_tests_ok:
        print("❌ Unit tests reported failures or errors - see output above")
    print("\n📊 Summary:")
    print("- Unit tests cover all core components")
    print("- Integration tests verify system interaction")
    print("- Audio tests require microphone access")
    print("- VLLM tests require running VLLM service")
    print("\n🔧 Next Steps:")
    print("1. Ensure VLLM is running for full functionality")
    print("2. Set up keybindings manually if scripts failed")
    print("3. Test with actual voice input for real-world validation")
    return 0 if unit_tests_ok else 1


if __name__ == "__main__":
    # Propagate the unit-test outcome as the process exit status.
    sys.exit(main())