dictation-service/tests/test_vllm_integration.py
Kade Heyborne 73a15d03cd
Fix dictation service: state detection, async processing, and performance optimizations
- Fix state detection priority: dictation now takes precedence over conversation
- Fix critical bug: event loop was created but never started, preventing async coroutines from executing
- Optimize audio processing: reorder AcceptWaveform/PartialResult checks
- Switch to faster Vosk model: vosk-model-en-us-0.22-lgraph for 2-3x speed improvement
- Reduce block size from 8000 to 4000 for lower latency
- Add filtering to remove spurious 'the', 'a', 'an' words from start/end of transcriptions
- Update toggle-dictation.sh to properly clean up conversation lock file
- Improve batch audio processing for better responsiveness
2025-12-04 11:49:07 -07:00


#!/usr/bin/env python3
"""
VLLM Integration Test Suite
Comprehensive testing of VLLM endpoint connectivity and functionality
"""
import os
import sys
import json
import time
import asyncio
import requests
import subprocess
import unittest
from unittest.mock import Mock, patch, AsyncMock
# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))


class TestVLLMIntegration(unittest.TestCase):
    """Test VLLM endpoint integration"""

    def setUp(self):
        """Setup test environment"""
        self.vllm_endpoint = "http://127.0.0.1:8000/v1"
        self.api_key = "vllm-api-key"
        self.test_model = "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4"

    def test_vllm_endpoint_connectivity(self):
        """Test basic VLLM endpoint connectivity"""
        print("\n🔗 Testing VLLM Endpoint Connectivity...")
        try:
            response = requests.get(
                f"{self.vllm_endpoint}/models",
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=5
            )
            if response.status_code == 200:
                models_data = response.json()
                print("✅ VLLM endpoint is accessible")
                self.assertIn("data", models_data)
                if models_data["data"]:
                    print(f"📝 Available models: {len(models_data['data'])}")
                    for model in models_data["data"]:
                        print(f" - {model.get('id', 'unknown')}")
                else:
                    print("⚠️ No models available")
            else:
                print(f"❌ VLLM endpoint returned status {response.status_code}")
                print(f"Response: {response.text}")
        except requests.exceptions.ConnectionError:
            print("❌ Cannot connect to VLLM endpoint - is VLLM running?")
            self.skipTest("VLLM endpoint not accessible")
        except requests.exceptions.Timeout:
            print("❌ VLLM endpoint timeout")
            self.skipTest("VLLM endpoint timeout")
        except Exception as e:
            print(f"❌ VLLM connectivity test failed: {e}")
            self.skipTest(f"VLLM test error: {e}")

    def test_vllm_chat_completion(self):
        """Test VLLM chat completion API"""
        print("\n💬 Testing VLLM Chat Completion...")
        test_messages = [
            {"role": "system", "content": "You are a helpful assistant. Be concise."},
            {"role": "user", "content": "Say 'Hello from VLLM!' and nothing else."}
        ]
        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.test_model,
                    "messages": test_messages,
                    "max_tokens": 50,
                    "temperature": 0.7
                },
                timeout=10
            )
            if response.status_code == 200:
                result = response.json()
                self.assertIn("choices", result)
                self.assertTrue(len(result["choices"]) > 0)
                message = result["choices"][0]["message"]["content"]
                print(f"✅ VLLM Response: '{message}'")
                # Basic response validation
                self.assertIsInstance(message, str)
                self.assertTrue(len(message) > 0)
                # Check if response contains expected content
                self.assertIn("Hello", message, "Response should contain greeting")
                print("✅ Chat completion test passed")
            else:
                print(f"❌ Chat completion failed: {response.status_code}")
                print(f"Response: {response.text}")
                self.fail("VLLM chat completion failed")
        except requests.exceptions.RequestException as e:
            print(f"❌ Chat completion request failed: {e}")
            self.skipTest("VLLM request failed")

    def test_vllm_conversation_context(self):
        """Test VLLM maintains conversation context"""
        print("\n🧠 Testing VLLM Conversation Context...")
        conversation = [
            {"role": "system", "content": "You are a helpful assistant who remembers previous messages."},
            {"role": "user", "content": "My name is Alex."},
            {"role": "assistant", "content": "Hello Alex! Nice to meet you."},
            {"role": "user", "content": "What is my name?"}
        ]
        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.test_model,
                    "messages": conversation,
                    "max_tokens": 50,
                    "temperature": 0.7
                },
                timeout=10
            )
            if response.status_code == 200:
                result = response.json()
                message = result["choices"][0]["message"]["content"]
                print(f"✅ Context-aware response: '{message}'")
                # Check if AI remembers the name
                self.assertIn("Alex", message, "AI should remember the name 'Alex'")
                print("✅ Conversation context test passed")
            else:
                print(f"❌ Context test failed: {response.status_code}")
                self.fail("VLLM context test failed")
        except requests.exceptions.RequestException as e:
            print(f"❌ Context test request failed: {e}")
            self.skipTest("VLLM context test failed")

    def test_vllm_performance(self):
        """Test VLLM response performance"""
        print("\n⚡ Testing VLLM Performance...")
        test_message = [
            {"role": "user", "content": "Respond with just 'Performance test successful'."}
        ]
        times = []
        num_tests = 3
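        # Time a few sequential requests end-to-end and average the results.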
        for i in range(num_tests):
            try:
                start_time = time.time()
                response = requests.post(
                    f"{self.vllm_endpoint}/chat/completions",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    },
                    json={
                        "model": self.test_model,
                        "messages": test_message,
                        "max_tokens": 20,
                        "temperature": 0.1
                    },
                    timeout=15
                )
                end_time = time.time()
                if response.status_code == 200:
                    response_time = end_time - start_time
                    times.append(response_time)
                    print(f" Test {i+1}: {response_time:.2f}s")
                else:
                    print(f" Test {i+1}: Failed ({response.status_code})")
            except requests.exceptions.RequestException as e:
                print(f" Test {i+1}: Error - {e}")
        if times:
            avg_time = sum(times) / len(times)
            print(f"✅ Average response time: {avg_time:.2f}s")
            # Performance assertions
            self.assertLess(avg_time, 10.0, "Average response time should be under 10 seconds")
            print("✅ Performance test passed")
        else:
            print("❌ No successful performance tests")
            self.fail("All performance tests failed")

    def test_vllm_error_handling(self):
        """Test VLLM error handling"""
        print("\n🚨 Testing VLLM Error Handling...")
        # Test invalid model
        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "nonexistent-model",
                    "messages": [{"role": "user", "content": "test"}],
                    "max_tokens": 10
                },
                timeout=5
            )
            # Should handle error gracefully
            if response.status_code != 200:
                print(f"✅ Invalid model error handled: {response.status_code}")
            else:
                print("⚠️ Invalid model did not return error")
        except requests.exceptions.RequestException as e:
            print(f"✅ Error handling test: {e}")
        # Test invalid API key
        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": "Bearer invalid-key",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.test_model,
                    "messages": [{"role": "user", "content": "test"}],
                    "max_tokens": 10
                },
                timeout=5
            )
            if response.status_code == 401:
                print("✅ Invalid API key properly rejected")
            else:
                print(f"⚠️ Invalid API key response: {response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"✅ API key error handling: {e}")

    def test_vllm_streaming(self):
        """Test VLLM streaming capabilities (if supported)"""
        print("\n🌊 Testing VLLM Streaming...")
        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.test_model,
                    "messages": [{"role": "user", "content": "Count from 1 to 5"}],
                    "max_tokens": 50,
                    "stream": True
                },
                timeout=10,
                stream=True
            )
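            # With stream=True the OpenAI-compatible API replies with
            # server-sent events ("data: {...}" lines); iter_lines() below
            # yields them one at a time.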
            if response.status_code == 200:
                chunks_received = 0
                for line in response.iter_lines():
                    if line:
                        chunks_received += 1
                        if chunks_received >= 5:  # Test a few chunks
                            break
                if chunks_received > 0:
                    print(f"✅ Streaming working: {chunks_received} chunks received")
                else:
                    print("⚠️ Streaming enabled but no chunks received")
            else:
                print(f"⚠️ Streaming not supported or failed: {response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"⚠️ Streaming test failed: {e}")


class TestVLLMClientIntegration(unittest.TestCase):
    """Test VLLM client integration with AI dictation service"""

    def setUp(self):
        """Setup test environment"""
        try:
            from src.dictation_service.ai_dictation_simple import VLLMClient
            self.client = VLLMClient()
        except ImportError as e:
            self.skipTest(f"Cannot import VLLMClient: {e}")

    def test_client_initialization(self):
        """Test VLLM client initialization"""
        self.assertIsNotNone(self.client)
        self.assertIsNotNone(self.client.client)
        self.assertEqual(self.client.endpoint, "http://127.0.0.1:8000/v1")

    def test_client_message_formatting(self):
        """Test client message formatting for API calls"""
        # This would test the message formatting logic
        # Implementation depends on the actual VLLMClient structure
        pass


class TestConversationIntegration(unittest.TestCase):
    """Test conversation integration with VLLM"""

    def setUp(self):
        """Setup test environment"""
        self.temp_dir = os.path.join(os.getcwd(), "test_temp")
        os.makedirs(self.temp_dir, exist_ok=True)
        self.history_file = os.path.join(self.temp_dir, "test_history.json")

    def tearDown(self):
        """Clean up test environment"""
        if os.path.exists(self.history_file):
            os.remove(self.history_file)
        if os.path.exists(self.temp_dir):
            os.rmdir(self.temp_dir)

    def test_conversation_flow_simulation(self):
        """Simulate complete conversation flow with VLLM"""
        print("\n🔄 Testing Conversation Flow Simulation...")
        try:
            # Test actual VLLM call if endpoint is available
            response = requests.post(
                "http://127.0.0.1:8000/v1/chat/completions",
                headers={
                    "Authorization": "Bearer vllm-api-key",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "default",
                    "messages": [
                        {"role": "system", "content": "You are a helpful AI assistant for dictation service testing."},
                        {"role": "user", "content": "Say 'Hello! I'm ready to help with your dictation.'"}
                    ],
                    "max_tokens": 100,
                    "temperature": 0.7
                },
                timeout=10
            )
            if response.status_code == 200:
                result = response.json()
                ai_response = result["choices"][0]["message"]["content"]
                print(f"✅ Conversation test response: '{ai_response}'")
                # Basic validation
                self.assertIsInstance(ai_response, str)
                self.assertTrue(len(ai_response) > 0)
                print("✅ Conversation flow simulation passed")
            else:
                print(f"⚠️ Conversation simulation failed: {response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"⚠️ Conversation simulation failed: {e}")


def test_vllm_service_status():
    """Test VLLM service status and configuration"""
    print("\n🔍 VLLM Service Status Check...")
    # Check if VLLM process is running
    try:
        result = subprocess.run(
            ["ps", "aux"],
            capture_output=True,
            text=True
        )
        if "vllm" in result.stdout.lower():
            print("✅ VLLM process appears to be running")
            # Extract some info
            lines = result.stdout.split('\n')
            for line in lines:
                if 'vllm' in line.lower():
                    print(f" Process: {line[:80]}...")
        else:
            print("⚠️ VLLM process not detected")
    except Exception as e:
        print(f"⚠️ Could not check VLLM process status: {e}")
    # Check common VLLM ports
    common_ports = [8000, 8001, 8002]
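    # vLLM's OpenAI-compatible server exposes a /health endpoint; probe the
    # usual ports and report any that answer.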
    for port in common_ports:
        try:
            response = requests.get(f"http://127.0.0.1:{port}/health", timeout=2)
            if response.status_code == 200:
                print(f"✅ VLLM health check passed on port {port}")
        except requests.exceptions.RequestException:
            pass


def test_vllm_configuration():
    """Test VLLM configuration recommendations"""
    print("\n⚙️ VLLM Configuration Check...")
    config_checks = [
        ("Environment variable VLLM_ENDPOINT", os.getenv("VLLM_ENDPOINT")),
        ("Environment variable VLLM_API_KEY", "vllm-api-key" in str(os.getenv("VLLM_API_KEY", ""))),
        ("Network connectivity to localhost", "127.0.0.1"),
    ]
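    # Each entry pairs a label with a truthy/falsy check; falsy values are
    # flagged as not configured below.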
    for check_name, check_result in config_checks:
        if check_result:
            print(f"✅ {check_name}: Available")
        else:
            print(f"⚠️ {check_name}: Not configured")


def main():
    """Main VLLM test runner"""
    print("🤖 VLLM Integration Test Suite")
    print("=" * 50)
    # Service status checks
    test_vllm_service_status()
    test_vllm_configuration()
    # Run unit tests
    print("\n📋 Running VLLM Integration Tests...")
    unittest.main(argv=[''], exit=False, verbosity=2)
    print("\n" + "=" * 50)
    print("✅ VLLM Integration Tests Complete!")
    print("\n📊 Summary:")
    print("- VLLM endpoint connectivity tested")
    print("- Chat completion functionality verified")
    print("- Conversation context management tested")
    print("- Performance benchmarks conducted")
    print("- Error handling validated")
    print("\n🔧 VLLM Setup Status:")
    print("- Endpoint: http://127.0.0.1:8000/v1")
    print("- API Key: vllm-api-key")
    print("- Model: default")
    print("\n💡 Next Steps:")
    print("1. Ensure VLLM service is running for full functionality")
    print("2. Monitor response times for optimal user experience")
    print("3. Consider model selection based on accuracy vs speed requirements")


if __name__ == "__main__":
    main()