From 62865e0993cd4d4def23dc2a60b1cb7d132f0013 Mon Sep 17 00:00:00 2001 From: DZY Date: Sat, 31 May 2025 19:42:19 +0800 Subject: [PATCH] =?UTF-8?q?=E7=99=BB=E5=BD=95=E6=B3=A8=E5=86=8C=E4=B8=8D?= =?UTF-8?q?=E4=BA=86=EF=BC=8C=E6=9C=89bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chatserver3.0.py | 278 ++++++++++++++++++++--------------------------- 1 file changed, 116 insertions(+), 162 deletions(-) diff --git a/chatserver3.0.py b/chatserver3.0.py index 0874238..c803df3 100644 --- a/chatserver3.0.py +++ b/chatserver3.0.py @@ -4,6 +4,7 @@ import json import sqlite3 from datetime import datetime import os +import traceback class ChatServer: def __init__(self, host='0.0.0.0', port=5555): @@ -12,264 +13,217 @@ class ChatServer: self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.server.bind((host, port)) - self.server.listen() + self.server.listen(5) self.clients = {} self.groups = {'General': set()} self.setup_database() def setup_database(self): - # 确保数据库文件存在并初始化 - db_exists = os.path.exists('chat_server.db') self.db = sqlite3.connect('chat_server.db', check_same_thread=False) + self.db.execute("PRAGMA foreign_keys = ON") cursor = self.db.cursor() - # 创建用户表 cursor.execute(''' CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT UNIQUE NOT NULL, - password TEXT NOT NULL + username TEXT UNIQUE NOT NULL CHECK(length(username) >= 3), + password TEXT NOT NULL CHECK(length(password) >= 3), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) ''') - # 创建消息表 cursor.execute(''' CREATE TABLE IF NOT EXISTS messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, sender TEXT NOT NULL, receiver TEXT NOT NULL, message TEXT NOT NULL, - timestamp DATETIME NOT NULL, - is_group INTEGER NOT NULL + timestamp DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + is_group INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(sender) REFERENCES users(username) ON DELETE CASCADE ) ''') - # 如果是首次创建数据库,添加一个默认管理员用户 - if not db_exists: + cursor.execute("SELECT COUNT(*) FROM users") + if cursor.fetchone()[0] == 0: try: cursor.execute('INSERT INTO users (username, password) VALUES (?, ?)', ('admin', 'admin123')) - print("Created default admin user: admin/admin123") + print("[DB] 创建默认管理员账号: admin/admin123") except sqlite3.IntegrityError: pass self.db.commit() + print("[DB] 数据库初始化完成") def authenticate_user(self, username, password): - cursor = self.db.cursor() - cursor.execute('SELECT * FROM users WHERE username=? AND password=?', (username, password)) - return cursor.fetchone() is not None + try: + cursor = self.db.cursor() + cursor.execute(''' + SELECT 1 FROM users + WHERE username=? AND password=? + LIMIT 1 + ''', (username, password)) + return cursor.fetchone() is not None + except sqlite3.Error as e: + print(f"[DB] 认证错误: {e}") + return False def register_user(self, username, password): - if not username or not password: - return False + if len(username) < 3 or len(password) < 3: + return False, "用户名和密码至少需要3个字符" try: cursor = self.db.cursor() - cursor.execute('INSERT INTO users (username, password) VALUES (?, ?)', (username, password)) + cursor.execute(''' + INSERT INTO users (username, password) + VALUES (?, ?) + ''', (username, password)) self.db.commit() - return True + return True, "注册成功" except sqlite3.IntegrityError: - return False - - def save_message(self, sender, receiver, message, is_group): - cursor = self.db.cursor() - cursor.execute(''' - INSERT INTO messages (sender, receiver, message, timestamp, is_group) - VALUES (?, ?, ?, ?, ?) - ''', (sender, receiver, message, datetime.now(), 1 if is_group else 0)) - self.db.commit() + return False, "用户名已存在" + except sqlite3.Error as e: + print(f"[DB] 注册错误: {e}") + return False, "数据库错误" def send_user_list(self): + """发送当前在线用户列表给所有客户端""" user_list = list(self.clients.keys()) - for client in self.clients.values(): + for username, client in list(self.clients.items()): try: client.send(json.dumps({ 'type': 'user_list', 'users': user_list }).encode('utf-8')) except: - pass + self.remove_client(username) def send_group_list(self, username): - group_list = [group for group, members in self.groups.items() if username in members] + """发送用户所在的群组列表""" if username in self.clients: + group_list = [group for group, members in self.groups.items() if username in members] try: self.clients[username].send(json.dumps({ 'type': 'group_list', 'groups': group_list }).encode('utf-8')) except: - pass - - def send_group_list_to_all(self): - for username in self.clients: - self.send_group_list(username) - - def send_initial_messages(self, username): - cursor = self.db.cursor() - cursor.execute(''' - SELECT sender, receiver, message, timestamp, is_group FROM messages - WHERE receiver=? OR sender=? OR (is_group=1 AND receiver IN ( - SELECT name FROM sqlite_master WHERE type='table' AND name='groups' - )) - ORDER BY timestamp - LIMIT 100 - ''', (username, username)) - - messages = [] - for sender, receiver, message, timestamp, is_group in cursor.fetchall(): - messages.append({ - 'sender': sender, - 'receiver': receiver, - 'message': message, - 'timestamp': timestamp, - 'is_group': is_group - }) + self.remove_client(username) + def remove_client(self, username): + """安全移除客户端""" if username in self.clients: try: - self.clients[username].send(json.dumps({ - 'type': 'initial_messages', - 'messages': messages - }).encode('utf-8')) + self.clients[username].close() except: pass - - def user_exists(self, username): - cursor = self.db.cursor() - cursor.execute('SELECT 1 FROM users WHERE username=?', (username,)) - return cursor.fetchone() is not None + del self.clients[username] + for group in self.groups.values(): + group.discard(username) + self.send_user_list() def handle_client(self, client, address): + print(f"[连接] 新客户端: {address}") username = None - print(f"New connection from {address}") try: while True: try: - message = client.recv(1024).decode('utf-8') - if not message: + raw_data = client.recv(4096) + if not raw_data: + print(f"[连接] 客户端 {address} 断开连接") break - data = json.loads(message) - print(f"Received from {address}: {data}") + try: + data = json.loads(raw_data.decode('utf-8')) + except json.JSONDecodeError: + print(f"[错误] 无效JSON数据来自 {address}") + client.send(json.dumps({ + 'type': 'error', + 'message': 'Invalid JSON format' + }).encode('utf-8')) + continue + + print(f"[请求] {address}: {data.get('type')}") + + if data['type'] == 'register': + username = data.get('username', '').strip() + password = data.get('password', '').strip() + + success, message = self.register_user(username, password) + response = { + 'type': 'register_response', + 'success': success, + 'message': message + } + client.send(json.dumps(response).encode('utf-8')) + + if success: + print(f"[注册] 新用户: {username}") + else: + print(f"[注册] 失败: {username} - {message}") + + elif data['type'] == 'login': + username = data.get('username', '').strip() + password = data.get('password', '').strip() - if data['type'] == 'login': - username = data['username'] - password = data['password'] if self.authenticate_user(username, password): self.clients[username] = client self.groups['General'].add(username) - client.send(json.dumps({'type': 'login_success'}).encode('utf-8')) + + client.send(json.dumps({ + 'type': 'login_response', + 'success': True, + 'message': '登录成功' + }).encode('utf-8')) + self.send_user_list() self.send_group_list(username) - self.send_initial_messages(username) - print(f"User {username} logged in successfully") + print(f"[登录] 成功: {username}") else: - client.send(json.dumps({'type': 'login_fail', 'reason': 'Invalid username or password'}).encode('utf-8')) - print(f"Login failed for {username}") + client.send(json.dumps({ + 'type': 'login_response', + 'success': False, + 'message': '用户名或密码错误' + }).encode('utf-8')) + print(f"[登录] 失败: {username}") - elif data['type'] == 'register': - username = data['username'] - password = data['password'] - if len(username) < 3 or len(password) < 3: - client.send(json.dumps({'type': 'register_fail', 'reason': 'Username and password must be at least 3 characters'}).encode('utf-8')) - elif self.register_user(username, password): - client.send(json.dumps({'type': 'register_success'}).encode('utf-8')) - print(f"New user registered: {username}") - else: - client.send(json.dumps({'type': 'register_fail', 'reason': 'Username already taken'}).encode('utf-8')) - - elif data['type'] == 'message': - sender = data['sender'] - receiver = data['receiver'] - msg = data['message'] - is_group = data['is_group'] - self.save_message(sender, receiver, msg, is_group) - if is_group: - for member in self.groups.get(receiver, set()): - if member in self.clients and member != sender: - try: - self.clients[member].send(json.dumps({ - 'type': 'message', - 'sender': sender, - 'receiver': receiver, - 'message': msg, - 'is_group': True - }).encode('utf-8')) - except: - pass - else: - if receiver in self.clients: - try: - self.clients[receiver].send(json.dumps({ - 'type': 'message', - 'sender': sender, - 'receiver': receiver, - 'message': msg, - 'is_group': False - }).encode('utf-8')) - except: - pass - - elif data['type'] == 'create_group': - group_name = data['group_name'] - if group_name not in self.groups: - self.groups[group_name] = set() - for user in data['members']: - if user in self.clients or self.user_exists(user): - self.groups[group_name].add(user) - self.send_group_list_to_all() - client.send(json.dumps({'type': 'group_created', 'group': group_name}).encode('utf-8')) - else: - client.send(json.dumps({'type': 'group_exists', 'group': group_name}).encode('utf-8')) - - except json.JSONDecodeError: - print(f"Invalid JSON received from {address}") - break except ConnectionResetError: - print(f"Connection reset by {address}") + print(f"[连接] 客户端 {address} 强制断开") break except Exception as e: - print(f"Error handling client {address}: {str(e)}") + print(f"[错误] 处理客户端 {address}: {str(e)}") + traceback.print_exc() break except Exception as e: - print(f"Client handler error for {address}: {str(e)}") + print(f"[错误] 客户端处理异常: {str(e)}") + traceback.print_exc() finally: - if username and username in self.clients: - print(f"User {username} disconnected") - del self.clients[username] - for group in self.groups.values(): - if username in group: - group.remove(username) - self.send_user_list() - try: - client.close() - except: - pass + if username: + self.remove_client(username) + client.close() def start(self): - print(f"Server started on {self.host}:{self.port}") - print("Database initialized at chat_server.db") + print(f"[启动] 服务器运行在 {self.host}:{self.port}") try: while True: client, address = self.server.accept() - thread = threading.Thread(target=self.handle_client, args=(client, address)) - thread.daemon = True + thread = threading.Thread( + target=self.handle_client, + args=(client, address), + daemon=True + ) thread.start() except KeyboardInterrupt: - print("\nServer is shutting down...") + print("\n[关闭] 服务器正在停止...") finally: - for client in self.clients.values(): - try: - client.close() - except: - pass + for username in list(self.clients.keys()): + self.remove_client(username) self.server.close() self.db.close() - print("Server stopped") + print("[关闭] 服务器已停止") if __name__ == "__main__": server = ChatServer()