import json
import os
import secrets
import socket
import sqlite3
import threading
import time
from contextlib import closing

from flask import Flask, jsonify, request

app = Flask(__name__)

# Listening socket for the raw-TCP chat protocol (bound and served by run_socket_server()).
socket_server = socket.socket()
socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# Connection bookkeeping, keyed by username.
active_connections = {}  # {username: {'conn': conn, 'ip': ip}}
chat_connections = []  # every live socket connection that receives broadcasts
tokens = {}  # session tokens: {token: {'username': str, 'timestamp': float}}

# Sliding expiry window for session tokens, in seconds (refreshed on every successful validation).
TOKEN_TTL_SECONDS = 3600

# Avatar returned for unknown users and for rows that have no avatar stored.
DEFAULT_AVATAR = "default_avatar.png"


def get_db_connection():
    """Open a new connection to ``usr.db`` whose rows are ``sqlite3.Row`` objects.

    The caller owns the returned connection and is responsible for closing it.
    """
    conn = sqlite3.connect("usr.db")
    conn.row_factory = sqlite3.Row
    return conn


def _query_one(sql, params):
    """Run a read-only query and return its first row (or None).

    The connection is closed even if the query raises.
    """
    with closing(get_db_connection()) as cn:
        return cn.execute(sql, params).fetchone()


def isuserxist(name):
    """Return True if a user with exactly this name exists in the ``users`` table."""
    return _query_one("SELECT 1 FROM users WHERE name = ?", (name,)) is not None


def ispsswdright(name, passwd):
    """Return True if *passwd* matches the password stored for user *name*.

    Returns False for unknown users, for non-string input, and when the stored
    password is NULL.

    NOTE(review): passwords are stored and compared in plaintext. Moving to a
    salted hash (e.g. hashlib.scrypt) requires migrating the existing rows.
    """
    row = _query_one("SELECT passwd FROM users WHERE name = ?", (name,))
    if row is None:
        return False
    stored = row[0]
    if not isinstance(stored, str) or not isinstance(passwd, str):
        return False
    # Constant-time comparison, so response timing does not reveal how many
    # leading characters matched.
    return secrets.compare_digest(
        stored.encode("utf-8", "surrogatepass"),
        passwd.encode("utf-8", "surrogatepass"),
    )


def get_avatar(username):
    """Return the avatar file name stored for *username*, or DEFAULT_AVATAR.

    Falls back to the default for unknown users and for rows with a NULL or empty avatar.
    """
    row = _query_one("SELECT avatar FROM users WHERE name = ?", (username,))
    if row is not None and row[0]:
        return row[0]
    return DEFAULT_AVATAR


def register_user(usr, pwd, avatar="default_avatar.png"):
    """Insert a new user row and return a protocol response dict.

    Returns a ``register_1`` dict on success. Returns a ``register_0`` dict when the
    username or password is missing or empty, when the username already exists, or
    on any other SQLite error (the error text becomes the message).
    A falsy *avatar* is stored as DEFAULT_AVATAR.
    """
    # The name column is a non-INTEGER PRIMARY KEY, which SQLite lets hold NULL,
    # so missing fields are rejected explicitly here.
    if not isinstance(usr, str) or not usr or not isinstance(pwd, str) or not pwd:
        return {"type": "register_0", "success": False,
                "message": "Username and password are required"}
    if not avatar:
        avatar = DEFAULT_AVATAR

    conn = get_db_connection()
    try:
        conn.execute("INSERT INTO users (name, passwd, avatar) VALUES (?, ?, ?)",
                     (usr, pwd, avatar))
        conn.commit()
    except sqlite3.IntegrityError:
        return {"type": "register_0", "success": False, "message": "Username already exists"}
    except sqlite3.Error as e:
        return {"type": "register_0", "success": False, "message": str(e)}
    finally:
        conn.close()
    return {"type": "register_1", "success": True, "message": "User registered successfully"}


def _purge_expired_tokens(now):
    """Remove every token whose last use is at least TOKEN_TTL_SECONDS before *now*."""
    # Iterate over a snapshot because entries are deleted inside the loop.
    for token, entry in list(tokens.items()):
        if now - entry['timestamp'] >= TOKEN_TTL_SECONDS:
            tokens.pop(token, None)


def generate_token(username):
    """Create, store and return a new random 32-hex-character session token for *username*.

    Expired tokens are swept at the same time, so abandoned sessions do not accumulate.
    """
    now = time.time()
    _purge_expired_tokens(now)
    token = secrets.token_hex(16)
    tokens[token] = {'username': username, 'timestamp': now}
    return token


def validate_token(token):
    """Return the username bound to a live *token* and extend its expiry, else None.

    An expired token is deleted when it is presented.
    """
    if not isinstance(token, str):
        return None
    entry = tokens.get(token)
    if entry is None:
        return None
    now = time.time()
    if now - entry['timestamp'] >= TOKEN_TTL_SECONDS:
        tokens.pop(token, None)
        return None
    entry['timestamp'] = now
    return entry['username']
def _json_body():
    """Return the request's JSON body if it is an object, else an empty dict.

    A missing body, malformed JSON or a non-object payload all yield an empty dict.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _registration_error(usr, pwd, avatar):
    """Return a ``register_0`` error dict if the registration fields are invalid, else None.

    Shared by the HTTP and socket registration paths, so both apply the same rules.
    """
    if not isinstance(usr, str) or not usr or not isinstance(pwd, str) or not pwd:
        return {"type": "register_0", "success": False,
                "message": "Username and password are required"}
    # Only .png / .jpg file names are accepted. A null or empty avatar means "use the default".
    if avatar and not (isinstance(avatar, str) and avatar.endswith(('.png', '.jpg'))):
        return {"type": "register_0", "success": False,
                "message": "Invalid avatar format. Only .png or .jpg allowed"}
    return None


@app.route("/api/register", methods=['POST'])
def register1():
    """HTTP registration endpoint.

    Responds 200 on success, 400 for invalid input, 403 if the username is taken,
    and 500 for other database errors.
    """
    vl = _json_body()
    usr = vl.get('username')
    pwd = vl.get('password')
    avatar = vl.get('avatar', 'default_avatar.png')

    error = _registration_error(usr, pwd, avatar)
    if error is not None:
        return jsonify(error), 400

    result = register_user(usr, pwd, avatar or 'default_avatar.png')
    if result['success']:
        return jsonify(result)
    status = 403 if result['message'] == "Username already exists" else 500
    return jsonify(result), status


@app.route("/api/login", methods=['POST'])
def login():
    """HTTP login endpoint. Returns a session token and the user's avatar on success.

    Responds 400 if the credentials are missing, 409 if the account currently has
    a socket session, and 401 for a wrong username or password.
    """
    data = _json_body()
    username = data.get('username')
    password = data.get('password')
    if not isinstance(username, str) or not isinstance(password, str):
        return jsonify({
            "type": "login_0",
            "status": "error",
            "message": "Username and password are required"
        }), 400

    # Reject the login while the account is logged in over the socket protocol.
    if username in active_connections:
        return jsonify({
            "type": "login_0",
            "status": "error",
            "message": "Account already logged in"
        }), 409  # Conflict

    if isuserxist(username) and ispsswdright(username, password):
        token = generate_token(username)
        avatar = get_avatar(username)
        return jsonify({
            "type": "login_1",
            "status": "success",
            "token": token,
            "avatar": avatar
        })
    return jsonify({
        "type": "login_0",
        "status": "error",
        "message": "Invalid credentials"
    }), 401


# The <username> URL variable is required. Without it Flask calls the view with no
# argument and every request fails with a TypeError (HTTP 500).
@app.route("/api/avatar/<username>", methods=['GET'])
def get_user_avatar(username):
    """Return ``{"username": ..., "avatar": ...}`` for *username* (the default avatar if unknown)."""
    avatar = get_avatar(username)
    return jsonify({"username": username, "avatar": avatar})


@app.route("/api/chat", methods=['POST'])
def chat():
    """Broadcast a chat message from the user identified by the Authorization token.

    The raw header value is used as the token. Responds 401 for a missing or expired
    token, and 400 if the body has no ``message`` field.
    """
    token = request.headers.get('Authorization')
    username = validate_token(token)
    if not username:
        return jsonify({"type": "chat", "status": "error"}), 401

    data = _json_body()
    if 'message' not in data:
        return jsonify({"type": "chat", "status": "error", "message": "Missing message"}), 400

    message = {
        "type": "chat",
        "user": username,
        "message": data['message'],
        "avatar": get_avatar(username)
    }
    broadcast_message(message)
    return jsonify({"type": "chat", "status": "success"})


def _unregister_connection(conn):
    """Forget every username mapped to *conn* and drop *conn* from the broadcast list.

    Returns the usernames that were removed. Safe to call more than once for the same connection.
    """
    removed = [uname for uname, info in list(active_connections.items())
               if info['conn'] is conn]
    for uname in removed:
        active_connections.pop(uname, None)
    try:
        chat_connections.remove(conn)
    except ValueError:
        pass  # already removed, e.g. by a concurrent broadcast or disconnect
    return removed


def broadcast_message(message, sender=None):
    """Send *message* as JSON to every connection in chat_connections.

    Connections whose send fails are unregistered.
    *sender* is accepted for API compatibility and is currently unused.
    """
    payload = json.dumps(message).encode()
    # Iterate over a snapshot: failed connections are removed from chat_connections
    # inside the loop, and mutating the list being iterated would skip the next client.
    for conn in list(chat_connections):
        try:
            conn.sendall(payload)
        except OSError:
            _unregister_connection(conn)


def _send_json(conn, payload):
    """Serialize *payload* as JSON and write it to socket *conn*."""
    conn.sendall(json.dumps(payload).encode())


def _socket_register(data, conn):
    """Handle a socket ``register`` request. Sends and returns the response dict."""
    usr = data.get('username')
    pwd = data.get('password')
    avatar = data.get('avatar', 'default_avatar.png')
    result = _registration_error(usr, pwd, avatar)
    if result is None:
        result = register_user(usr, pwd, avatar or 'default_avatar.png')
    _send_json(conn, result)
    return result


def _socket_login(data, addr, conn):
    """Handle a socket ``login`` request and bind *conn* to the user on success.

    Sends and returns the response dict.
    """
    username = data['username']
    password = data['password']

    # Each account may have only one live socket session.
    if username in active_connections:
        response = {
            "type": "login_0",
            "status": "error",
            "message": "Account already logged in"
        }
        _send_json(conn, response)
        return response

    if not (isuserxist(username) and ispsswdright(username, password)):
        response = {
            "type": "login_0",
            "status": "error",
            "message": "Invalid credentials"
        }
        _send_json(conn, response)
        return response

    # A connection maps to at most one account. Drop any account this socket was
    # previously logged in as, so it cannot linger as "already logged in".
    _unregister_connection(conn)
    active_connections[username] = {'conn': conn, 'ip': addr[0]}
    chat_connections.append(conn)

    token = generate_token(username)
    avatar = get_avatar(username)
    response = {
        "type": "login_1",
        "status": "success",
        "message": "Login successful",
        "token": token,
        "username": username,
        "avatar": avatar
    }
    _send_json(conn, response)
    return response


def _socket_chat(data, conn):
    """Handle a socket ``chat`` request.

    Broadcasts the message on success; the broadcast echo serves as the sender's acknowledgement.
    """
    username = validate_token(data.get('token', ''))
    if not username:
        response = {
            "type": "chat",
            "status": "error",
            "message": "Not authenticated"
        }
        _send_json(conn, response)
        return response

    if username not in active_connections:
        response = {
            "type": "chat",
            "status": "error",
            "message": "Not logged in"
        }
        _send_json(conn, response)
        return response

    message = {
        "type": "chat",
        "user": username,
        "message": data['message'],
        "avatar": get_avatar(username)
    }
    broadcast_message(message)
    return {"type": "chat", "status": "success"}


def handle_socket_message(data, addr, conn):
    """Dispatch one decoded JSON request from socket client *conn* at *addr*.

    Returns the response dict. Error responses (including unknown request types and
    unexpected exceptions) are also sent to the client, so it is never left waiting.
    """
    try:
        if not isinstance(data, dict):
            raise ValueError("Request must be a JSON object")
        action = data.get('type')
        if action == 'register':
            return _socket_register(data, conn)
        if action == 'login':
            return _socket_login(data, addr, conn)
        if action == 'chat':
            return _socket_chat(data, conn)
        response = {"type": "error", "status": "error",
                    "message": f"Unknown request type: {action}"}
        _send_json(conn, response)
        return response
    except Exception as e:
        response = {"status": "error", "message": str(e)}
        try:
            _send_json(conn, response)
        except OSError:
            # The connection is already broken. The read loop will notice on its
            # next recv() and clean up.
            pass
        return response


def _serve_socket_client(conn, addr):
    """Read requests from one client until it disconnects, then release its registrations."""
    try:
        while True:
            data = conn.recv(1024)
            if not data:
                break
            # NOTE(review): each recv() is assumed to hold exactly one complete JSON
            # message. Larger or coalesced messages need a framing protocol
            # (e.g. newline-delimited JSON), agreed with the clients.
            try:
                json_data = json.loads(data.decode())
            except (UnicodeDecodeError, json.JSONDecodeError):
                _send_json(conn, {
                    "type": "error",
                    "status": "error",
                    "message": "Invalid JSON"
                })
                continue
            handle_socket_message(json_data, addr, conn)
    except OSError:
        print(f"Client {addr} disconnected abruptly")
    finally:
        # Release every username bound to this connection, not just the first one.
        for username in _unregister_connection(conn):
            print(f"User {username} disconnected")
        try:
            conn.close()
        except OSError:
            pass
        print(f"Connection closed for {addr}")


def run_socket_server():
    """Accept chat clients on localhost:8889 forever, serving each one on its own thread."""
    socket_server.bind(("localhost", 8889))
    socket_server.listen()
    print("Socket server running on port 8889")
    while True:
        conn, addr = socket_server.accept()
        print(f"Socket client connected: {addr}")
        # One thread per client, so several users can be connected and receive
        # broadcasts at the same time. Serving clients inline would block every
        # other client until the current one disconnects.
        threading.Thread(target=_serve_socket_client, args=(conn, addr), daemon=True).start()


if __name__ == '__main__':
    conn = get_db_connection()
    try:
        with conn:  # commits the schema change on success
            conn.execute('''CREATE TABLE IF NOT EXISTS users
                            (name TEXT PRIMARY KEY, passwd TEXT, avatar TEXT)''')
    finally:
        # sqlite3's connection context manager only commits or rolls back; it does not close.
        conn.close()
    threading.Thread(target=run_socket_server, daemon=True).start()
    app.run(port=5001)