- Fix state detection priority: dictation now takes precedence over conversation
- Fix critical bug: the event loop was created but never started, preventing async coroutines from executing
- Optimize audio processing: reorder the AcceptWaveform/PartialResult checks
- Switch to a faster Vosk model (vosk-model-en-us-0.22-lgraph) for a 2-3x speed improvement
- Reduce the block size from 8000 to 4000 for lower latency
- Add filtering to remove spurious 'the', 'a', 'an' words from the start/end of transcriptions (illustrated below)
- Update toggle-dictation.sh to properly clean up the conversation lock file
- Improve batch audio processing for better responsiveness
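The transcript filtering mentioned above could look roughly like the following minimal sketch; the helper name strip_spurious_edges and the word set are assumptions for illustration, not the repository's actual implementation:

# Hypothetical sketch of the spurious-word filtering described in the commit
# message above; function name and word set are assumptions.
SPURIOUS_EDGE_WORDS = {"the", "a", "an"}

def strip_spurious_edges(text: str) -> str:
    """Drop stray articles that Vosk sometimes emits at the start or end of a phrase."""
    words = text.split()
    while words and words[0].lower() in SPURIOUS_EDGE_WORDS:
        words.pop(0)
    while words and words[-1].lower() in SPURIOUS_EDGE_WORDS:
        words.pop()
    return " ".join(words)

# Example: strip_spurious_edges("the open the terminal a") -> "open the terminal"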
#!/usr/bin/env python3
"""
VLLM Integration Test Suite
Comprehensive testing of VLLM endpoint connectivity and functionality
"""

import os
import sys
import json
import time
import asyncio
import requests
import subprocess
import unittest
from unittest.mock import Mock, patch, AsyncMock

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

class TestVLLMIntegration(unittest.TestCase):
    """Test VLLM endpoint integration"""

    def setUp(self):
        """Setup test environment"""
        self.vllm_endpoint = "http://127.0.0.1:8000/v1"
        self.api_key = "vllm-api-key"
        self.test_model = "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4"

    def test_vllm_endpoint_connectivity(self):
        """Test basic VLLM endpoint connectivity"""
        print("\n🔗 Testing VLLM Endpoint Connectivity...")

        try:
            response = requests.get(
                f"{self.vllm_endpoint}/models",
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=5
            )

            if response.status_code == 200:
                models_data = response.json()
                print("✅ VLLM endpoint is accessible")
                self.assertIn("data", models_data)

                if models_data["data"]:
                    print(f"📝 Available models: {len(models_data['data'])}")
                    for model in models_data["data"]:
                        print(f"  - {model.get('id', 'unknown')}")
                else:
                    print("⚠️ No models available")
            else:
                print(f"❌ VLLM endpoint returned status {response.status_code}")
                print(f"Response: {response.text}")

        except requests.exceptions.ConnectionError:
            print("❌ Cannot connect to VLLM endpoint - is VLLM running?")
            self.skipTest("VLLM endpoint not accessible")
        except requests.exceptions.Timeout:
            print("❌ VLLM endpoint timeout")
            self.skipTest("VLLM endpoint timeout")
        except Exception as e:
            print(f"❌ VLLM connectivity test failed: {e}")
            self.skipTest(f"VLLM test error: {e}")

    def test_vllm_chat_completion(self):
        """Test VLLM chat completion API"""
        print("\n💬 Testing VLLM Chat Completion...")

        test_messages = [
            {"role": "system", "content": "You are a helpful assistant. Be concise."},
            {"role": "user", "content": "Say 'Hello from VLLM!' and nothing else."}
        ]

        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.test_model,
                    "messages": test_messages,
                    "max_tokens": 50,
                    "temperature": 0.7
                },
                timeout=10
            )

            if response.status_code == 200:
                result = response.json()
                self.assertIn("choices", result)
                self.assertTrue(len(result["choices"]) > 0)

                message = result["choices"][0]["message"]["content"]
                print(f"✅ VLLM Response: '{message}'")

                # Basic response validation
                self.assertIsInstance(message, str)
                self.assertTrue(len(message) > 0)

                # Check if response contains expected content
                self.assertIn("Hello", message, "Response should contain greeting")
                print("✅ Chat completion test passed")
            else:
                print(f"❌ Chat completion failed: {response.status_code}")
                print(f"Response: {response.text}")
                self.fail("VLLM chat completion failed")

        except requests.exceptions.RequestException as e:
            print(f"❌ Chat completion request failed: {e}")
            self.skipTest("VLLM request failed")

    def test_vllm_conversation_context(self):
        """Test VLLM maintains conversation context"""
        print("\n🧠 Testing VLLM Conversation Context...")

        conversation = [
            {"role": "system", "content": "You are a helpful assistant who remembers previous messages."},
            {"role": "user", "content": "My name is Alex."},
            {"role": "assistant", "content": "Hello Alex! Nice to meet you."},
            {"role": "user", "content": "What is my name?"}
        ]

        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.test_model,
                    "messages": conversation,
                    "max_tokens": 50,
                    "temperature": 0.7
                },
                timeout=10
            )

            if response.status_code == 200:
                result = response.json()
                message = result["choices"][0]["message"]["content"]
                print(f"✅ Context-aware response: '{message}'")

                # Check if AI remembers the name
                self.assertIn("Alex", message, "AI should remember the name 'Alex'")
                print("✅ Conversation context test passed")
            else:
                print(f"❌ Context test failed: {response.status_code}")
                self.fail("VLLM context test failed")

        except requests.exceptions.RequestException as e:
            print(f"❌ Context test request failed: {e}")
            self.skipTest("VLLM context test failed")

    def test_vllm_performance(self):
        """Test VLLM response performance"""
        print("\n⚡ Testing VLLM Performance...")

        test_message = [
            {"role": "user", "content": "Respond with just 'Performance test successful'."}
        ]

        times = []
        num_tests = 3

        for i in range(num_tests):
            try:
                start_time = time.time()
                response = requests.post(
                    f"{self.vllm_endpoint}/chat/completions",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    },
                    json={
                        "model": self.test_model,
                        "messages": test_message,
                        "max_tokens": 20,
                        "temperature": 0.1
                    },
                    timeout=15
                )
                end_time = time.time()

                if response.status_code == 200:
                    response_time = end_time - start_time
                    times.append(response_time)
                    print(f"  Test {i+1}: {response_time:.2f}s")
                else:
                    print(f"  Test {i+1}: Failed ({response.status_code})")

            except requests.exceptions.RequestException as e:
                print(f"  Test {i+1}: Error - {e}")

        if times:
            avg_time = sum(times) / len(times)
            print(f"✅ Average response time: {avg_time:.2f}s")

            # Performance assertions
            self.assertLess(avg_time, 10.0, "Average response time should be under 10 seconds")
            print("✅ Performance test passed")
        else:
            print("❌ No successful performance tests")
            self.fail("All performance tests failed")

    def test_vllm_error_handling(self):
        """Test VLLM error handling"""
        print("\n🚨 Testing VLLM Error Handling...")

        # Test invalid model
        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "nonexistent-model",
                    "messages": [{"role": "user", "content": "test"}],
                    "max_tokens": 10
                },
                timeout=5
            )

            # Should handle error gracefully
            if response.status_code != 200:
                print(f"✅ Invalid model error handled: {response.status_code}")
            else:
                print("⚠️ Invalid model did not return error")

        except requests.exceptions.RequestException as e:
            print(f"✅ Error handling test: {e}")

        # Test invalid API key
        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": "Bearer invalid-key",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.test_model,
                    "messages": [{"role": "user", "content": "test"}],
                    "max_tokens": 10
                },
                timeout=5
            )

            if response.status_code == 401:
                print("✅ Invalid API key properly rejected")
            else:
                print(f"⚠️ Invalid API key response: {response.status_code}")

        except requests.exceptions.RequestException as e:
            print(f"✅ API key error handling: {e}")

    def test_vllm_streaming(self):
        """Test VLLM streaming capabilities (if supported)"""
        print("\n🌊 Testing VLLM Streaming...")

        try:
            response = requests.post(
                f"{self.vllm_endpoint}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.test_model,
                    "messages": [{"role": "user", "content": "Count from 1 to 5"}],
                    "max_tokens": 50,
                    "stream": True
                },
                timeout=10,
                stream=True
            )

            if response.status_code == 200:
                chunks_received = 0
                for line in response.iter_lines():
                    if line:
                        chunks_received += 1
                        if chunks_received >= 5:  # Test a few chunks
                            break

                if chunks_received > 0:
                    print(f"✅ Streaming working: {chunks_received} chunks received")
                else:
                    print("⚠️ Streaming enabled but no chunks received")
            else:
                print(f"⚠️ Streaming not supported or failed: {response.status_code}")

        except requests.exceptions.RequestException as e:
            print(f"⚠️ Streaming test failed: {e}")


class TestVLLMClientIntegration(unittest.TestCase):
    """Test VLLM client integration with AI dictation service"""

    def setUp(self):
        """Setup test environment"""
        try:
            # src/ is already on sys.path (see the insert at the top of this file)
            from dictation_service.ai_dictation_simple import VLLMClient
            self.client = VLLMClient()
        except ImportError as e:
            self.skipTest(f"Cannot import VLLMClient: {e}")

    def test_client_initialization(self):
        """Test VLLM client initialization"""
        self.assertIsNotNone(self.client)
        self.assertIsNotNone(self.client.client)
        self.assertEqual(self.client.endpoint, "http://127.0.0.1:8000/v1")

    def test_client_message_formatting(self):
        """Test client message formatting for API calls"""
        # This would test the message formatting logic
        # Implementation depends on the actual VLLMClient structure
        pass


class TestConversationIntegration(unittest.TestCase):
    """Test conversation integration with VLLM"""

    def setUp(self):
        """Setup test environment"""
        self.temp_dir = os.path.join(os.getcwd(), "test_temp")
        os.makedirs(self.temp_dir, exist_ok=True)
        self.history_file = os.path.join(self.temp_dir, "test_history.json")

    def tearDown(self):
        """Clean up test environment"""
        if os.path.exists(self.history_file):
            os.remove(self.history_file)
        if os.path.exists(self.temp_dir):
            os.rmdir(self.temp_dir)

    def test_conversation_flow_simulation(self):
        """Simulate complete conversation flow with VLLM"""
        print("\n🔄 Testing Conversation Flow Simulation...")

        try:
            # Test actual VLLM call if endpoint is available
            response = requests.post(
                "http://127.0.0.1:8000/v1/chat/completions",
                headers={
                    "Authorization": "Bearer vllm-api-key",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "default",
                    "messages": [
                        {"role": "system", "content": "You are a helpful AI assistant for dictation service testing."},
                        {"role": "user", "content": "Say 'Hello! I'm ready to help with your dictation.'"}
                    ],
                    "max_tokens": 100,
                    "temperature": 0.7
                },
                timeout=10
            )

            if response.status_code == 200:
                result = response.json()
                ai_response = result["choices"][0]["message"]["content"]
                print(f"✅ Conversation test response: '{ai_response}'")

                # Basic validation
                self.assertIsInstance(ai_response, str)
                self.assertTrue(len(ai_response) > 0)
                print("✅ Conversation flow simulation passed")
            else:
                print(f"⚠️ Conversation simulation failed: {response.status_code}")

        except requests.exceptions.RequestException as e:
            print(f"⚠️ Conversation simulation failed: {e}")


def test_vllm_service_status():
    """Test VLLM service status and configuration"""
    print("\n🔍 VLLM Service Status Check...")

    # Check if VLLM process is running
    try:
        result = subprocess.run(
            ["ps", "aux"],
            capture_output=True,
            text=True
        )

        if "vllm" in result.stdout.lower():
            print("✅ VLLM process appears to be running")

            # Extract some info
            lines = result.stdout.split('\n')
            for line in lines:
                if 'vllm' in line.lower():
                    print(f"  Process: {line[:80]}...")
        else:
            print("⚠️ VLLM process not detected")

    except Exception as e:
        print(f"⚠️ Could not check VLLM process status: {e}")

    # Check common VLLM ports
    common_ports = [8000, 8001, 8002]
    for port in common_ports:
        try:
            response = requests.get(f"http://127.0.0.1:{port}/health", timeout=2)
            if response.status_code == 200:
                print(f"✅ VLLM health check passed on port {port}")
        except requests.exceptions.RequestException:
            pass


def test_vllm_configuration():
    """Test VLLM configuration recommendations"""
    print("\n⚙️ VLLM Configuration Check...")

    config_checks = [
        ("Environment variable VLLM_ENDPOINT", os.getenv("VLLM_ENDPOINT")),
        ("Environment variable VLLM_API_KEY", "vllm-api-key" in str(os.getenv("VLLM_API_KEY", ""))),
        ("Network connectivity to localhost", "127.0.0.1"),
    ]

    for check_name, check_result in config_checks:
        if check_result:
            print(f"✅ {check_name}: Available")
        else:
            print(f"⚠️ {check_name}: Not configured")


def main():
    """Main VLLM test runner"""
    print("🤖 VLLM Integration Test Suite")
    print("=" * 50)

    # Service status checks
    test_vllm_service_status()
    test_vllm_configuration()

    # Run unit tests
    print("\n📋 Running VLLM Integration Tests...")
    unittest.main(argv=[''], exit=False, verbosity=2)

    print("\n" + "=" * 50)
    print("✅ VLLM Integration Tests Complete!")

    print("\n📊 Summary:")
    print("- VLLM endpoint connectivity tested")
    print("- Chat completion functionality verified")
    print("- Conversation context management tested")
    print("- Performance benchmarks conducted")
    print("- Error handling validated")

    print("\n🔧 VLLM Setup Status:")
    print("- Endpoint: http://127.0.0.1:8000/v1")
    print("- API Key: vllm-api-key")
    print("- Model: default")

    print("\n💡 Next Steps:")
    print("1. Ensure VLLM service is running for full functionality")
    print("2. Monitor response times for optimal user experience")
    print("3. Consider model selection based on accuracy vs speed requirements")


if __name__ == "__main__":
    main()