import socket
import threading
import json
import sqlite3
from datetime import datetime
import os
import codecs
import hashlib
import hmac
import secrets

# Prefix marking a `users.password` value as a salted PBKDF2 hash. Values
# without it are legacy plaintext passwords written before hashing was added.
_HASH_SCHEME = 'pbkdf2_sha256'

# Decoder used to pull complete JSON values off the front of a receive buffer.
_JSON_DECODER = json.JSONDecoder()


class ChatServer:
    """Multi-threaded TCP chat server backed by a SQLite database.

    Each accepted connection is served by its own daemon thread running
    ``handle_client``. Clients send JSON objects whose ``type`` field selects
    the action (``login``, ``register``, ``message``, ``create_group``).

    Online users live in ``self.clients`` (username -> socket) and group
    membership in ``self.groups`` (group name -> set of usernames); both are
    guarded by ``self._state_lock``. The single SQLite connection is shared by
    all handler threads and guarded by ``self._db_lock``.
    """

    # PBKDF2 iteration count for newly created hashes. Stored hashes record
    # their own count, so raising this later does not invalidate old ones.
    PASSWORD_HASH_ITERATIONS = 200_000
    # Bytes requested per recv(); message boundaries do not depend on it.
    RECV_BUFFER_SIZE = 4096
    # Largest incomplete message (in characters) buffered before the peer is
    # treated as sending invalid data and disconnected.
    MAX_PENDING_CHARS = 1 << 20

    def __init__(self, host='0.0.0.0', port=5555, db_path='chat_server.db'):
        """Bind the listening socket and open (or create) the database.

        Args:
            host: Interface address to listen on.
            port: TCP port to listen on.
            db_path: SQLite database file to use.
        """
        self.host = host
        self.port = port
        self.db_path = db_path
        self.clients = {}
        self.groups = {'General': set()}
        self._state_lock = threading.RLock()
        self._db_lock = threading.Lock()
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.server.bind((host, port))
            self.server.listen()
            self.setup_database()
        except Exception:
            # Do not leak the listening socket if binding or DB setup fails.
            self.server.close()
            raise

    # ------------------------------------------------------------------
    # Database
    # ------------------------------------------------------------------

    def setup_database(self, db_path=None):
        """Open the SQLite database and create the schema if it is missing.

        When the database file did not exist beforehand, a default ``admin``
        account (password ``admin123``) is created.

        Args:
            db_path: Database file to open. Defaults to ``self.db_path``, or
                ``'chat_server.db'`` when that attribute is not set.
        """
        if db_path is None:
            db_path = getattr(self, 'db_path', 'chat_server.db')
        self.db_path = db_path
        db_exists = os.path.exists(db_path)
        # The connection is shared by every handler thread; all access goes
        # through self._db_lock.
        self.db = sqlite3.connect(db_path, check_same_thread=False)
        with self._db_lock, self.db:
            # User accounts. `password` holds a salted hash, or a legacy
            # plaintext value that is upgraded on the next successful login.
            self.db.execute('''
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT UNIQUE NOT NULL,
                    password TEXT NOT NULL
                )
            ''')
            # Chat history for both private (is_group=0) and group (is_group=1) messages.
            self.db.execute('''
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sender TEXT NOT NULL,
                    receiver TEXT NOT NULL,
                    message TEXT NOT NULL,
                    timestamp DATETIME NOT NULL,
                    is_group INTEGER NOT NULL
                )
            ''')
            # Brand-new database: seed a default administrator account.
            if not db_exists:
                try:
                    self.db.execute(
                        'INSERT INTO users (username, password) VALUES (?, ?)',
                        ('admin', self._hash_password('admin123')),
                    )
                    print("Created default admin user: admin/admin123")
                except sqlite3.IntegrityError:
                    pass

    def _hash_password(self, password):
        """Return a salted PBKDF2-SHA256 hash string for *password*.

        Format: ``pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>``.
        """
        iterations = self.PASSWORD_HASH_ITERATIONS
        salt = secrets.token_bytes(16)
        digest = hashlib.pbkdf2_hmac('sha256', password.encode('utf-8'), salt, iterations)
        return f'{_HASH_SCHEME}${iterations}${salt.hex()}${digest.hex()}'

    @staticmethod
    def _verify_password(stored, password):
        """Check *password* against a stored ``users.password`` value.

        Returns:
            ``(matches, is_legacy)``. *is_legacy* is True when *stored* is a
            plaintext value from before passwords were hashed.
        """
        if not isinstance(stored, str):
            return False, False
        parts = stored.split('$')
        if len(parts) == 4 and parts[0] == _HASH_SCHEME:
            try:
                iterations = int(parts[1])
                salt = bytes.fromhex(parts[2])
                expected = bytes.fromhex(parts[3])
                digest = hashlib.pbkdf2_hmac('sha256', password.encode('utf-8'), salt, iterations)
            except (ValueError, OverflowError):
                # Corrupt hash record: never treat it as a match.
                return False, False
            return hmac.compare_digest(digest, expected), False
        # Constant-time comparison even for legacy plaintext rows.
        return hmac.compare_digest(stored.encode('utf-8'), password.encode('utf-8')), True

    def authenticate_user(self, username, password):
        """Return True if *username* exists and *password* matches.

        A legacy plaintext password that matches is transparently replaced
        with a salted hash.
        """
        if not isinstance(password, str):
            return False
        with self._db_lock:
            row = self.db.execute(
                'SELECT password FROM users WHERE username=?', (username,)
            ).fetchone()
        if row is None:
            return False
        stored = row[0]
        matches, is_legacy = self._verify_password(stored, password)
        if matches and is_legacy:
            upgraded = self._hash_password(password)
            with self._db_lock, self.db:
                # Match on the old value so a concurrent change is not clobbered.
                self.db.execute(
                    'UPDATE users SET password=? WHERE username=? AND password=?',
                    (upgraded, username, stored),
                )
        return matches

    def register_user(self, username, password):
        """Create an account; return False if input is empty or the name is taken."""
        if not username or not password:
            return False
        hashed = self._hash_password(password)  # slow by design; keep outside the lock
        try:
            with self._db_lock, self.db:
                self.db.execute(
                    'INSERT INTO users (username, password) VALUES (?, ?)',
                    (username, hashed),
                )
        except sqlite3.IntegrityError:
            return False
        return True

    def save_message(self, sender, receiver, message, is_group):
        """Persist one chat message with the current local time."""
        # Same text the deprecated default sqlite3 datetime adapter produced.
        timestamp = datetime.now().isoformat(' ')
        with self._db_lock, self.db:
            self.db.execute('''
                INSERT INTO messages (sender, receiver, message, timestamp, is_group)
                VALUES (?, ?, ?, ?, ?)
            ''', (sender, receiver, message, timestamp, 1 if is_group else 0))

    def user_exists(self, username):
        """Return True if an account named *username* exists."""
        with self._db_lock:
            row = self.db.execute(
                'SELECT 1 FROM users WHERE username=?', (username,)
            ).fetchone()
        return row is not None

    def _recent_messages(self, username, groups, limit):
        """Return the newest *limit* history rows for *username*, oldest first.

        Selected rows are messages the user sent, private messages addressed
        to them, and group messages sent to any name in *groups*.
        """
        conditions = ['sender = ?', '(is_group = 0 AND receiver = ?)']
        params = [username, username]
        if groups:
            placeholders = ', '.join('?' for _ in groups)
            conditions.append(f'(is_group = 1 AND receiver IN ({placeholders}))')
            params.extend(groups)
        params.append(limit)
        where_clause = ' OR '.join(conditions)
        # Only fixed SQL fragments and '?' placeholders are interpolated;
        # every value is bound as a parameter.
        query = f'''
            SELECT sender, receiver, message, timestamp, is_group FROM (
                SELECT id, sender, receiver, message, timestamp, is_group
                FROM messages
                WHERE {where_clause}
                ORDER BY timestamp DESC, id DESC
                LIMIT ?
            )
            ORDER BY timestamp, id
        '''
        with self._db_lock:
            rows = self.db.execute(query, params).fetchall()
        return [
            {
                'sender': sender,
                'receiver': receiver,
                'message': message,
                'timestamp': timestamp,
                'is_group': is_group,
            }
            for sender, receiver, message, timestamp, is_group in rows
        ]

    # ------------------------------------------------------------------
    # Outgoing traffic
    # ------------------------------------------------------------------

    def _send_json(self, sock, payload, best_effort=False):
        """Serialize *payload* as JSON and write all of it to *sock*.

        With ``best_effort=True`` a socket error is swallowed and False is
        returned. Broadcasts use this so that one dead peer cannot break the
        sender's session. Otherwise the ``OSError`` propagates.
        """
        data = json.dumps(payload).encode('utf-8')
        # NOTE(review): several threads may write to the same socket at once
        # (e.g. a broadcast racing a direct reply), so payload bytes could
        # interleave; a per-connection send lock would close that gap.
        try:
            sock.sendall(data)
        except OSError:
            if not best_effort:
                raise
            return False
        return True

    def send_user_list(self):
        """Broadcast the list of online usernames to every online user."""
        with self._state_lock:
            users = list(self.clients.keys())
            recipients = list(self.clients.values())
        payload = {'type': 'user_list', 'users': users}
        for sock in recipients:
            self._send_json(sock, payload, best_effort=True)

    def send_group_list(self, username):
        """Send *username* the names of the groups they belong to, if online."""
        with self._state_lock:
            groups = [name for name, members in self.groups.items() if username in members]
            sock = self.clients.get(username)
        if sock is not None:
            self._send_json(sock, {'type': 'group_list', 'groups': groups}, best_effort=True)

    def send_group_list_to_all(self):
        """Send every online user their current group list."""
        with self._state_lock:
            usernames = list(self.clients)
        for username in usernames:
            self.send_group_list(username)

    def send_initial_messages(self, username, limit=100):
        """Send *username* their most recent chat history.

        History covers messages they sent, private messages addressed to
        them, and messages in groups they currently belong to. The newest
        *limit* messages are sent oldest-first as a single
        ``initial_messages`` payload. Nothing is sent if the user is offline.
        """
        with self._state_lock:
            sock = self.clients.get(username)
            member_of = [name for name, members in self.groups.items() if username in members]
        if sock is None:
            return
        messages = self._recent_messages(username, member_of, limit)
        self._send_json(sock, {'type': 'initial_messages', 'messages': messages}, best_effort=True)

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    def _remove_client(self, username, client):
        """Log *username* out if *client* is still their registered socket.

        The identity check stops a stale connection from evicting a newer
        login under the same name. Returns True if a removal happened.
        """
        with self._state_lock:
            if self.clients.get(username) is not client:
                return False
            del self.clients[username]
            for members in self.groups.values():
                members.discard(username)
        return True

    @staticmethod
    def _split_json_stream(text):
        """Split *text* into complete JSON values plus an unconsumed remainder.

        TCP delivers a byte stream, so one recv() may hold part of a message
        or several messages back to back. Values are decoded from the front
        of *text* until the rest no longer parses. That remainder, with
        leading whitespace removed, is returned so the caller can prepend it
        to the next chunk.

        Returns:
            ``(values, remainder)``.
        """
        values = []
        rest = text.lstrip()
        while rest:
            try:
                value, end = _JSON_DECODER.raw_decode(rest)
            except json.JSONDecodeError:
                break  # incomplete (or malformed) — wait for more data
            values.append(value)
            rest = rest[end:].lstrip()
        return values, rest

    @staticmethod
    def _redact(data):
        """Return *data* with any password masked, for logging."""
        if isinstance(data, dict) and 'password' in data:
            return {**data, 'password': '***'}
        return data

    def handle_client(self, client, address):
        """Serve one connection until it closes or misbehaves.

        Incoming bytes are decoded incrementally as UTF-8 and split into JSON
        objects, regardless of how TCP chunked them. The connection is
        dropped in three cases:

        - the peer resets it;
        - a handler raises;
        - the buffered remainder cannot start a JSON object, or grows past
          ``MAX_PENDING_CHARS``.

        On exit, the user is logged out only if this socket is still the one
        registered for them.
        """
        session = {'username': None}  # set only by a successful login here
        print(f"New connection from {address}")
        # Keeps a multi-byte character that is split across recv() calls intact.
        utf8 = codecs.getincrementaldecoder('utf-8')()
        pending = ''
        try:
            while True:
                try:
                    chunk = client.recv(self.RECV_BUFFER_SIZE)
                    if not chunk:
                        break
                    pending += utf8.decode(chunk)
                    messages, pending = self._split_json_stream(pending)
                    for data in messages:
                        print(f"Received from {address}: {self._redact(data)}")
                        self._dispatch(client, session, data)
                    # Every protocol message is a JSON object. Leftover text
                    # that cannot start one, or never completes, is garbage.
                    if pending and (not pending.startswith('{')
                                    or len(pending) > self.MAX_PENDING_CHARS):
                        print(f"Invalid JSON received from {address}")
                        break
                except ConnectionResetError:
                    print(f"Connection reset by {address}")
                    break
                except Exception as e:
                    print(f"Error handling client {address}: {str(e)}")
                    break
        except Exception as e:
            print(f"Client handler error for {address}: {str(e)}")
        finally:
            username = session['username']
            if username is not None and self._remove_client(username, client):
                print(f"User {username} disconnected")
                self.send_user_list()
            try:
                client.close()
            except OSError:
                pass

    def _dispatch(self, client, session, data):
        """Route one decoded request to its handler; unknown types are ignored."""
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        kind = data['type']
        if kind == 'login':
            self._handle_login(client, session, data)
        elif kind == 'register':
            self._handle_register(client, data)
        elif kind == 'message':
            self._handle_message(session['username'], data)
        elif kind == 'create_group':
            self._handle_create_group(client, session['username'], data)

    def _handle_login(self, client, session, data):
        """Authenticate, register the socket, and push the initial state."""
        username = data.get('username')
        password = data.get('password')
        if not (isinstance(username, str) and isinstance(password, str)
                and self.authenticate_user(username, password)):
            self._send_json(client, {'type': 'login_fail', 'reason': 'Invalid username or password'})
            print(f"Login failed for {username}")
            return
        previous = session['username']
        if previous is not None and previous != username:
            # Same connection logging in under another name: release the old one.
            self._remove_client(previous, client)
        with self._state_lock:
            self.clients[username] = client
            self.groups.setdefault('General', set()).add(username)
        # Record the login before any send, so a failed send still cleans up.
        session['username'] = username
        self._send_json(client, {'type': 'login_success'})
        self.send_user_list()
        self.send_group_list(username)
        self.send_initial_messages(username)
        print(f"User {username} logged in successfully")

    def _handle_register(self, client, data):
        """Create an account and reply with register_success/register_fail."""
        username = data.get('username')
        password = data.get('password')
        valid = (isinstance(username, str) and isinstance(password, str)
                 and len(username) >= 3 and len(password) >= 3)
        if not valid:
            self._send_json(client, {'type': 'register_fail', 'reason': 'Username and password must be at least 3 characters'})
        elif self.register_user(username, password):
            self._send_json(client, {'type': 'register_success'})
            print(f"New user registered: {username}")
        else:
            self._send_json(client, {'type': 'register_fail', 'reason': 'Username already taken'})

    def _handle_message(self, username, data):
        """Store a chat message and forward it to its online recipients.

        The sender is always the authenticated *username*; any ``sender``
        field in the request is ignored, so it cannot be spoofed. Requests
        from connections that are not logged in are dropped.
        """
        if username is None:
            print("Ignoring message from a connection that is not logged in")
            return
        receiver = data.get('receiver')
        text = data.get('message')
        is_group = bool(data.get('is_group'))
        if not isinstance(receiver, str) or not isinstance(text, str):
            print(f"Ignoring malformed message from {username}")
            return
        self.save_message(username, receiver, text, is_group)
        payload = {
            'type': 'message',
            'sender': username,
            'receiver': receiver,
            'message': text,
            'is_group': is_group,
        }
        with self._state_lock:
            if is_group:
                targets = [
                    self.clients[member]
                    for member in self.groups.get(receiver, ())
                    if member in self.clients and member != username
                ]
            else:
                targets = [self.clients[receiver]] if receiver in self.clients else []
        for sock in targets:
            self._send_json(sock, payload, best_effort=True)

    def _handle_create_group(self, client, username, data):
        """Create a group from ``data['group_name']`` and ``data['members']``.

        Only members that are online or have an account are added. On
        success, replies ``group_created`` and refreshes everyone's group
        list. If the name is taken, replies ``group_exists``. Requests from
        connections that are not logged in are dropped.
        """
        if username is None:
            print("Ignoring create_group from a connection that is not logged in")
            return
        group_name = data.get('group_name')
        if not isinstance(group_name, str):
            print(f"Ignoring create_group with invalid group name: {group_name!r}")
            return
        requested = data.get('members')
        if not isinstance(requested, list):
            requested = []
        with self._state_lock:
            online = set(self.clients)
        # Database lookups happen outside the state lock.
        members = {
            user for user in requested
            if isinstance(user, str) and (user in online or self.user_exists(user))
        }
        with self._state_lock:
            created = group_name not in self.groups
            if created:
                self.groups[group_name] = members
        if created:
            self.send_group_list_to_all()
            self._send_json(client, {'type': 'group_created', 'group': group_name})
        else:
            self._send_json(client, {'type': 'group_exists', 'group': group_name})

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------

    def start(self):
        """Accept connections until Ctrl+C, one daemon thread per client.

        On shutdown, closes logged-in client sockets, the listening socket
        and the database connection.
        """
        print(f"Server started on {self.host}:{self.port}")
        print(f"Database initialized at {self.db_path}")
        try:
            while True:
                client, address = self.server.accept()
                worker = threading.Thread(target=self.handle_client, args=(client, address))
                worker.daemon = True
                worker.start()
        except KeyboardInterrupt:
            print("\nServer is shutting down...")
        finally:
            # Snapshot first: handler threads remove entries as sockets close.
            with self._state_lock:
                open_sockets = list(self.clients.values())
            for sock in open_sockets:
                try:
                    sock.close()
                except OSError:
                    pass
            self.server.close()
            with self._db_lock:
                self.db.close()
            print("Server stopped")


if __name__ == "__main__":
    server = ChatServer()
    server.start()