import asyncio
import websockets
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import sys
# Registry of selectable Qwen2.5-Coder instruct checkpoints, keyed by size label.
models = {
    size: {"model_name": f"Qwen/Qwen2.5-Coder-{size}-Instruct"}
    for size in ("0.5B", "1.5B", "3B")
}

# Size label used when the server starts (and tracked as the current choice).
default_model_size = "1.5B"

# Where Hugging Face downloads are cached on disk.
cache_dir = "d:/.cache/huggingface/hub"
def load_model(size):
    """Load the causal LM and its tokenizer for the given size key.

    Looks the checkpoint name up in the module-level ``models`` registry,
    downloads/caches it under ``cache_dir``, and moves the model onto the
    module-level ``device``. Returns a ``(model, tokenizer)`` pair.
    """
    checkpoint = models[size]["model_name"]
    lm = AutoModelForCausalLM.from_pretrained(checkpoint, cache_dir=cache_dir).to(device)
    tok = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir)
    return lm, tok
# Prefer the GPU when one is visible; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

# Eagerly load the default model so the first client request is served without
# waiting on a multi-GB download/initialization.
model, tokenizer = load_model(default_model_size)
# WebSocket handler
async def handler(websocket):
    """Serve one WebSocket connection.

    Protocol (one text message per request):
      * ``qwen use <size>``  — hot-swap the active model (size must be a key
        of ``models``); replies with a confirmation or "Unknown model size.".
      * ``qwen shutdown``    — acknowledge, close the connection, and stop
        the whole server process.
      * anything else        — treated as a user chat prompt; replies with
        the model's generated answer.
    """
    global model, tokenizer, default_model_size
    try:
        async for message in websocket:
            if message.startswith("qwen use"):
                size = message.split()[-1]
                if size in models:
                    default_model_size = size
                    model, tokenizer = load_model(size)
                    response = f"Model switched to {size} parameters. Using {device}"
                else:
                    response = "Unknown model size."
                await websocket.send(response)
            elif message == "qwen shutdown":
                response = "Server is shutting down..."
                await websocket.send(response)
                await websocket.close()
                await asyncio.sleep(1)  # Allow time for messages to be sent
                # SystemExit propagates out of asyncio.run() and terminates the process.
                sys.exit(0)
            else:
                # Wrap the raw message in the model's chat template so the
                # instruct checkpoint sees a properly formatted conversation.
                text = tokenizer.apply_chat_template(
                    [{"role": "user", "content": message}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                model_inputs = tokenizer([text], return_tensors="pt").to(device)
                with torch.no_grad():
                    generated_ids = model.generate(**model_inputs, max_new_tokens=1024)
                # BUG FIX: generate() returns prompt + completion. Decoding the
                # full ids echoed the chat template and the user's message back
                # to the client; slice off the input tokens so only the newly
                # generated completion is decoded and sent.
                generated_ids = [
                    output_ids[len(input_ids):]
                    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
                ]
                generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                response = generated_texts[0]
                await websocket.send(response)
    except websockets.ConnectionClosed as e:
        print(f"Connection closed: {e}")
# Main function to start the server
async def main():
    """Start the WebSocket server on localhost:8765 and run until terminated."""
    run_forever = asyncio.Future()
    async with websockets.serve(handler, "localhost", 8765):
        # This Future is never resolved, so the server stays up until the
        # process is killed (e.g. by the "qwen shutdown" command).
        await run_forever


if __name__ == "__main__":
    asyncio.run(main())