import socket
import threading
import json
import sqlite3
from datetime import datetime
import os
import traceback


class ChatServer:
    """Threaded TCP chat server with SQLite-backed user accounts.

    Each accepted connection is served by its own daemon thread running
    ``handle_client``. Clients talk to the server with UTF-8 encoded JSON
    objects carrying a ``type`` key; this class handles ``register`` and
    ``login`` and pushes ``user_list`` / ``group_list`` updates.

    NOTE(review): passwords are stored and compared as plaintext, and a
    default ``admin`` / ``admin123`` account is created on an empty
    database. Both are kept so existing databases and clients continue to
    work, but they should be replaced with salted hashes and a provisioned
    admin credential.
    """

    def __init__(self, host='0.0.0.0', port=5555, db_path='chat_server.db'):
        """Bind the listening socket and open the user database.

        Args:
            host: Interface address to bind to.
            port: TCP port to listen on.
            db_path: SQLite database file (``':memory:'`` also works).
        """
        self.host = host
        self.port = port
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.server.bind((host, port))
            self.server.listen(5)
            # username -> connected socket, for logged-in users only.
            self.clients = {}
            # group name -> set of member usernames.
            self.groups = {'General': set()}
            self.setup_database(db_path)
        except Exception:
            # Do not leave the listening socket open if start-up fails.
            self.server.close()
            raise

    def setup_database(self, db_path='chat_server.db'):
        """Open the SQLite database and create the schema if it is missing.

        Creates the ``users`` and ``messages`` tables and, when ``users`` is
        empty, inserts the default administrator account.

        Args:
            db_path: Path handed to ``sqlite3.connect``.
        """
        # The connection is shared by every client thread
        # (check_same_thread=False), so access is serialised with this lock.
        self.db_lock = threading.Lock()
        self.db = sqlite3.connect(db_path, check_same_thread=False)
        self.db.execute("PRAGMA foreign_keys = ON")
        cursor = self.db.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL CHECK(length(username) >= 3),
                password TEXT NOT NULL CHECK(length(password) >= 3),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                sender TEXT NOT NULL,
                receiver TEXT NOT NULL,
                message TEXT NOT NULL,
                timestamp DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
                is_group INTEGER NOT NULL DEFAULT 0,
                FOREIGN KEY(sender) REFERENCES users(username) ON DELETE CASCADE
            )
        ''')

        cursor.execute("SELECT COUNT(*) FROM users")
        if cursor.fetchone()[0] == 0:
            try:
                cursor.execute(
                    'INSERT INTO users (username, password) VALUES (?, ?)',
                    ('admin', 'admin123'),
                )
                print("[DB] 创建默认管理员账号: admin/admin123")
            except sqlite3.IntegrityError:
                # Another process created the account first; nothing to do.
                pass

        self.db.commit()
        print("[DB] 数据库初始化完成")

    def authenticate_user(self, username, password):
        """Return True when a ``users`` row matches both username and password.

        Database errors are logged and reported as a failed authentication.
        """
        try:
            with self.db_lock:
                cursor = self.db.cursor()
                cursor.execute('''
                    SELECT 1 FROM users
                    WHERE username=? AND password=?
                    LIMIT 1
                ''', (username, password))
                return cursor.fetchone() is not None
        except sqlite3.Error as e:
            print(f"[DB] 认证错误: {e}")
            return False

    def register_user(self, username, password):
        """Create a new account.

        Returns:
            A ``(success, message)`` tuple; ``message`` is the user-facing
            text sent back to the client.
        """
        if not isinstance(username, str) or not isinstance(password, str):
            return False, "用户名和密码至少需要3个字符"
        if len(username) < 3 or len(password) < 3:
            return False, "用户名和密码至少需要3个字符"

        try:
            # The connection context manager commits on success and rolls
            # back on error, so a failed INSERT leaves no open transaction.
            with self.db_lock, self.db:
                self.db.execute('''
                    INSERT INTO users (username, password)
                    VALUES (?, ?)
                ''', (username, password))
            return True, "注册成功"
        except sqlite3.IntegrityError:
            return False, "用户名已存在"
        except sqlite3.Error as e:
            print(f"[DB] 注册错误: {e}")
            return False, "数据库错误"

    @staticmethod
    def _send_json(client, payload):
        """Serialise *payload* to JSON and write all of it to *client*."""
        # sendall, not send: send() may transmit only part of the buffer.
        client.sendall(json.dumps(payload).encode('utf-8'))

    @staticmethod
    def _text_field(data, key):
        """Return ``data[key]`` stripped, or '' when missing or not a string."""
        value = data.get(key, '')
        return value.strip() if isinstance(value, str) else ''

    def send_user_list(self):
        """Send the current list of online users to every connected client."""
        message = {'type': 'user_list', 'users': list(self.clients)}
        # Iterate over a snapshot: remove_client mutates self.clients.
        for username, client in list(self.clients.items()):
            try:
                self._send_json(client, message)
            except OSError:
                self.remove_client(username)

    def send_group_list(self, username):
        """Send *username* the names of the groups they belong to."""
        client = self.clients.get(username)
        if client is None:
            return

        group_list = [
            group for group, members in self.groups.items()
            if username in members
        ]
        try:
            self._send_json(client, {'type': 'group_list', 'groups': group_list})
        except OSError:
            self.remove_client(username)

    def remove_client(self, username):
        """Close and forget *username*'s connection, then notify the others.

        Does nothing when the user is not currently registered as online.
        """
        # pop() instead of "in" + "del": two threads may try to remove the
        # same user, and the second one must not raise KeyError.
        client = self.clients.pop(username, None)
        if client is None:
            return

        try:
            client.close()
        except OSError:
            # Best effort: the socket may already be closed or broken.
            pass

        for members in self.groups.values():
            members.discard(username)
        self.send_user_list()

    def handle_client(self, client, address):
        """Serve one connection until the peer disconnects or an error occurs.

        Handles ``register`` and ``login`` requests. Messages with any other
        (or no) ``type`` are ignored. On exit the socket is closed and, if
        this connection holds a logged-in session, the user is removed.

        Args:
            client: Connected socket for this peer.
            address: Peer address, used only for logging.
        """
        print(f"[连接] 新客户端: {address}")
        # Bound only after a successful login on THIS connection. Names from
        # failed logins or register requests must never end up here, or the
        # cleanup below would disconnect somebody else's session.
        username = None

        try:
            while True:
                try:
                    # NOTE(review): each recv() is treated as exactly one JSON
                    # message. TCP does not preserve message boundaries, so
                    # large or back-to-back messages can be split or merged;
                    # a framed protocol would need matching client changes.
                    raw_data = client.recv(4096)
                    if not raw_data:
                        print(f"[连接] 客户端 {address} 断开连接")
                        break

                    try:
                        data = json.loads(raw_data.decode('utf-8'))
                    except ValueError:
                        # Covers json.JSONDecodeError and UnicodeDecodeError.
                        print(f"[错误] 无效JSON数据来自 {address}")
                        self._send_json(client, {
                            'type': 'error',
                            'message': 'Invalid JSON format'
                        })
                        continue

                    if not isinstance(data, dict):
                        print(f"[错误] 无效JSON数据来自 {address}")
                        self._send_json(client, {
                            'type': 'error',
                            'message': 'Message must be a JSON object'
                        })
                        continue

                    msg_type = data.get('type')
                    print(f"[请求] {address}: {msg_type}")

                    if msg_type == 'register':
                        req_user = self._text_field(data, 'username')
                        req_pass = self._text_field(data, 'password')

                        success, message = self.register_user(req_user, req_pass)
                        self._send_json(client, {
                            'type': 'register_response',
                            'success': success,
                            'message': message
                        })

                        if success:
                            print(f"[注册] 新用户: {req_user}")
                        else:
                            print(f"[注册] 失败: {req_user} - {message}")

                    elif msg_type == 'login':
                        req_user = self._text_field(data, 'username')
                        req_pass = self._text_field(data, 'password')

                        if self.authenticate_user(req_user, req_pass):
                            username = req_user
                            self.clients[username] = client
                            self.groups['General'].add(username)

                            self._send_json(client, {
                                'type': 'login_response',
                                'success': True,
                                'message': '登录成功'
                            })

                            self.send_user_list()
                            self.send_group_list(username)
                            print(f"[登录] 成功: {username}")
                        else:
                            self._send_json(client, {
                                'type': 'login_response',
                                'success': False,
                                'message': '用户名或密码错误'
                            })
                            print(f"[登录] 失败: {req_user}")

                except ConnectionResetError:
                    print(f"[连接] 客户端 {address} 强制断开")
                    break
                except Exception as e:
                    print(f"[错误] 处理客户端 {address}: {str(e)}")
                    traceback.print_exc()
                    break

        except Exception as e:
            print(f"[错误] 客户端处理异常: {str(e)}")
            traceback.print_exc()
        finally:
            # Only tear down the session if it still belongs to this socket;
            # a later login from another connection may have replaced it.
            if username and self.clients.get(username) is client:
                self.remove_client(username)
            client.close()

    def start(self):
        """Accept connections forever, one daemon thread per client.

        Ctrl+C stops the loop; all clients, the listening socket and the
        database are closed on the way out.
        """
        print(f"[启动] 服务器运行在 {self.host}:{self.port}")
        try:
            while True:
                client, address = self.server.accept()
                thread = threading.Thread(
                    target=self.handle_client,
                    args=(client, address),
                    daemon=True
                )
                thread.start()
        except KeyboardInterrupt:
            print("\n[关闭] 服务器正在停止...")
        finally:
            for username in list(self.clients):
                self.remove_client(username)
            self.server.close()
            self.db.close()
            print("[关闭] 服务器已停止")


if __name__ == "__main__":
    server = ChatServer()
    server.start()