From cf611c11c3a6f909dc611e92dae093c6611c8783 Mon Sep 17 00:00:00 2001 From: DZY Date: Fri, 13 Jun 2025 20:19:55 +0800 Subject: [PATCH] =?UTF-8?q?=E6=88=91=E4=B8=8D=E5=88=B0=E5=95=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CS3.1.py | 266 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 188 insertions(+), 78 deletions(-) diff --git a/CS3.1.py b/CS3.1.py index 26d82d7..3849d7d 100644 --- a/CS3.1.py +++ b/CS3.1.py @@ -3,16 +3,18 @@ import json from flask import Flask, jsonify, request import sqlite3 import socket -import base64 import secrets import time +import os + app = Flask(__name__) socket_server = socket.socket() socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) -active_users = {} -chat_connections = [] -tokens = {} +# 修改数据结构:使用用户名作为主键 +active_connections = {} # {username: {'conn': conn, 'ip': ip}} +chat_connections = [] # 所有活跃连接列表 +tokens = {} # 令牌管理 def get_db_connection(): conn = sqlite3.connect("usr.db") @@ -25,42 +27,39 @@ def isuserxist(name): csr.execute('SELECT * FROM users WHERE name = ?', (name,)) rst = csr.fetchone() cn.close() - if rst is not None: - return True - else: - return False + return rst is not None -def ispsswdright(name,passwd): +def ispsswdright(name, passwd): cn = get_db_connection() csr = cn.cursor() - csr.execute("SELECT COUNT(*) FROM users WHERE name=?", (name,)) - row_count = csr.fetchone()[0] - password = None - if row_count > 0: - csr.execute("SELECT passwd FROM users WHERE name=?", (name,)) - password = csr.fetchone()[0] - if password == passwd: - return True - else: - return False + csr.execute("SELECT passwd FROM users WHERE name = ?", (name,)) + result = csr.fetchone() + cn.close() + return result and result[0] == passwd -def register_user(usr, pwd): +def get_avatar(username): + cn = get_db_connection() + csr = cn.cursor() + csr.execute("SELECT avatar FROM users WHERE name = ?", (username,)) + result = csr.fetchone() + cn.close() + return result[0] if result else "default_avatar.png" + +def register_user(usr, pwd, avatar="default_avatar.png"): conn = get_db_connection() - csr2 = conn.cursor() - csr2.execute('SELECT * FROM users WHERE name = ?', (usr,)) - result = csr2.fetchone() - if result is not None: + try: + cursor = conn.cursor() + # 添加avatar字段 + cursor.execute("INSERT INTO users (name, passwd, avatar) VALUES (?, ?, ?)", + (usr, pwd, avatar)) + conn.commit() + return {"type": "register_1", "success": True, "message": "User registered successfully"} + except sqlite3.IntegrityError: return {"type": "register_0", "success": False, "message": "Username already exists"} - else: - try: - cursor = conn.cursor() - cursor.execute("INSERT INTO users (name, passwd) VALUES (?, ?)", (usr, pwd)) - conn.commit() - return {"type": "register_1", "success": True, "message": "User registered successfully"} - except sqlite3.Error as e: - return {"type": "register_0", "success": False, "message": str(e)} - finally: - conn.close() + except sqlite3.Error as e: + return {"type": "register_0", "success": False, "message": str(e)} + finally: + conn.close() def generate_token(username): token = secrets.token_hex(16) @@ -79,7 +78,17 @@ def register1(): vl = request.get_json() usr = vl.get('username') pwd = vl.get('password') - result = register_user(usr, pwd) + avatar = vl.get('avatar', 'default_avatar.png') # 获取头像 + + # 头像验证 + if avatar and not (avatar.endswith('.png') or avatar.endswith('.jpg')): + return jsonify({ + "type": "register_0", + "success": False, + "message": "Invalid avatar format. Only .png or .jpg allowed" + }), 400 + + result = register_user(usr, pwd, avatar) if result['success']: return jsonify(result) else: @@ -88,11 +97,35 @@ def register1(): @app.route("/api/login", methods=['POST']) def login(): data = request.get_json() - if isuserxist(data['username']): - if ispsswdright(data['username'], data['password']): - token = generate_token(data['username']) - return jsonify({"type": "login_1", "status": "success", "token": token}) - return jsonify({"type": "login_0", "status": "error"}), 401 + username = data['username'] + + # 检查账号是否已登录 + if username in active_connections: + return jsonify({ + "type": "login_0", + "status": "error", + "message": "Account already logged in" + }), 409 # 冲突状态码 + + if isuserxist(username) and ispsswdright(username, data['password']): + token = generate_token(username) + avatar = get_avatar(username) + return jsonify({ + "type": "login_1", + "status": "success", + "token": token, + "avatar": avatar # 返回头像 + }) + return jsonify({ + "type": "login_0", + "status": "error", + "message": "Invalid credentials" + }), 401 + +@app.route("/api/avatar/", methods=['GET']) +def get_user_avatar(username): + avatar = get_avatar(username) + return jsonify({"username": username, "avatar": avatar}) @app.route("/api/chat", methods=['POST']) def chat(): @@ -100,11 +133,13 @@ def chat(): username = validate_token(token) if not username: return jsonify({"type": "chat", "status": "error"}), 401 + data = request.get_json() message = { "type": "chat", "user": username, - "message": data['message'] + "message": data['message'], + "avatar": get_avatar(username) # 添加头像信息 } broadcast_message(message) return jsonify({"type": "chat", "status": "success"}) @@ -114,37 +149,100 @@ def broadcast_message(message, sender=None): try: conn.sendall(json.dumps(message).encode()) except: - chat_connections.remove(conn) + # 连接异常时移除 + for uname, info in list(active_connections.items()): + if info['conn'] == conn: + del active_connections[uname] + break + if conn in chat_connections: + chat_connections.remove(conn) def handle_socket_message(data, addr, conn): try: action = data.get('type') if action == 'register': - result = register_user(data.get('username'), data.get('password')) - if result['success']: - return {"type": "register_1","status": "success", "message": result['message']} - else: - return {"type": "register_0","status": "error", "message": result['message']} + avatar = data.get('avatar', 'default_avatar.png') + result = register_user(data.get('username'), data.get('password'), avatar) + conn.sendall(json.dumps(result).encode()) + return result elif action == 'login': - if isuserxist(data['username']): - if ispsswdright(data['username'], data['password']): - active_users[addr[0]] = data['username'] - tk = base64.b64encode(data['username'].encode('utf-8')) - chat_connections.append(conn) - return {"type": "login_1", "status": "success", "message": "Login successful", "token": generate_token(data['username']), "username": data['username']} - return {"type": "login_0", "status": "error", "message": "Invalid credentials", "username": data['username']} - elif action == 'chat': - if addr[0] in active_users: - message = { - "type": "chat", - "user": active_users[addr[0]], - "message": data['message'] + username = data['username'] + password = data['password'] + + # 检查账号是否已登录 + if username in active_connections: + response = { + "type": "login_0", + "status": "error", + "message": "Account already logged in" } - broadcast_message(message) - #return {"type": "chat", "status": "success"} - else : - return {"type": "chat", "status": "error", "message": "Not logged in"} + conn.sendall(json.dumps(response).encode()) + return response + + if isuserxist(username) and ispsswdright(username, password): + # 移除旧连接(如果存在) + if username in active_connections: + old_conn = active_connections[username]['conn'] + if old_conn in chat_connections: + chat_connections.remove(old_conn) + del active_connections[username] + + # 添加新连接 + active_connections[username] = {'conn': conn, 'ip': addr[0]} + if conn not in chat_connections: + chat_connections.append(conn) + + token = generate_token(username) + avatar = get_avatar(username) + response = { + "type": "login_1", + "status": "success", + "message": "Login successful", + "token": token, + "username": username, + "avatar": avatar # 返回头像 + } + conn.sendall(json.dumps(response).encode()) + return response + else: + response = { + "type": "login_0", + "status": "error", + "message": "Invalid credentials" + } + conn.sendall(json.dumps(response).encode()) + return response + + elif action == 'chat': + username = validate_token(data.get('token', '')) + if not username: + response = { + "type": "chat", + "status": "error", + "message": "Not authenticated" + } + conn.sendall(json.dumps(response).encode()) + return response + + if username not in active_connections: + response = { + "type": "chat", + "status": "error", + "message": "Not logged in" + } + conn.sendall(json.dumps(response).encode()) + return response + + message = { + "type": "chat", + "user": username, + "message": data['message'], + "avatar": get_avatar(username) # 添加头像 + } + broadcast_message(message) + return {"type": "chat", "status": "success"} + except Exception as e: return {"status": "error", "message": str(e)} @@ -159,31 +257,43 @@ def run_socket_server(): while True: data = conn.recv(1024) if not data: - if addr[0] in active_users: - del active_users[addr[0]] break try: json_data = json.loads(data.decode()) response = handle_socket_message(json_data, addr, conn) - if response != None: - conn.sendall(json.dumps(response).encode()) except json.JSONDecodeError: - conn.sendall(json.dumps( - {"type": "register_0", "status": "error", "message": "Invalid JSON"} - ).encode()) - except ConnectionResetError: - if addr[0] in active_users: - del active_users[addr[0]] - print(f"Client {addr} disconnected") + response = { + "type": "error", + "status": "error", + "message": "Invalid JSON" + } + conn.sendall(json.dumps(response).encode()) + except (ConnectionResetError, BrokenPipeError): + print(f"Client {addr} disconnected abruptly") finally: + # 清理断开的连接 + for username, info in list(active_connections.items()): + if info['conn'] == conn: + del active_connections[username] + print(f"User {username} disconnected") + break + if conn in chat_connections: chat_connections.remove(conn) - conn.close() + + try: + conn.close() + except: + pass + print(f"Connection closed for {addr}") if __name__ == '__main__': with get_db_connection() as conn: + # 添加avatar字段 conn.execute('''CREATE TABLE IF NOT EXISTS users - (name TEXT, passwd TEXT)''') - threading.Thread(target=run_socket_server).start() + (name TEXT PRIMARY KEY, + passwd TEXT, + avatar TEXT)''') + threading.Thread(target=run_socket_server, daemon=True).start() app.run(port=5001) \ No newline at end of file