diff --git a/CS3.1.3.py b/CS3.1.3.py new file mode 100644 index 0000000..b7a8d4e --- /dev/null +++ b/CS3.1.3.py @@ -0,0 +1,468 @@ +import threading +import json +from flask import Flask, jsonify, request, send_from_directory +import sqlite3 +import socket +import secrets +import time +import os +import logging +import re # 添加正则表达式模块 + +# 在文件顶部添加常量定义 +AVATAR_BASE_DIR = "avatar" # 头像存储基础目录 +DEFAULT_AVATAR = "default_avatar.png" # 默认头像文件名 + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +app = Flask(__name__) +socket_server = socket.socket() +socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + +# 修改数据结构:使用用户名作为主键 +active_connections = {} # {username: {'conn': conn, 'ip': ip, 'last_active': timestamp}} +chat_connections = [] # 所有活跃连接列表 +tokens = {} # 令牌管理 + +def get_db_connection(): + conn = sqlite3.connect("usr.db") + conn.row_factory = sqlite3.Row + return conn + +def isuserxist(name): + cn = get_db_connection() + csr = cn.cursor() + csr.execute('SELECT * FROM users WHERE name = ?', (name,)) + rst = csr.fetchone() + cn.close() + return rst is not None + +def ispsswdright(name, passwd): + cn = get_db_connection() + csr = cn.cursor() + csr.execute("SELECT passwd FROM users WHERE name = ?", (name,)) + result = csr.fetchone() + cn.close() + return result and result[0] == passwd + +# 用户名验证函数 +def validate_username(username): + """验证用户名是否符合规则""" + # 长度在2-20个字符 + if len(username) < 2 or len(username) > 20: + return False + + # 只能包含特定字符:字母、数字、_!@#$%^&+-~? + if not re.match(r'^[a-zA-Z0-9_!@#$%^&+\-~?]+$', username): + return False + + # 不能以数字开头 + if username[0].isdigit(): + return False + + return True + +# 修改获取头像函数 +def get_avatar(username): + cn = get_db_connection() + csr = cn.cursor() + csr.execute("SELECT avatar FROM users WHERE name = ?", (username,)) + result = csr.fetchone() + cn.close() + if result: + avatar_file = result[0] + # 返回HTTP URL格式的头像路径 + return f"/avatar/{username}/{avatar_file}" + + # 返回默认头像URL + return f"/avatar/default/{DEFAULT_AVATAR}" + +def register_user(usr, pwd, avatar="default_avatar.png"): + conn = get_db_connection() + try: + cursor = conn.cursor() + # 创建用户专属头像目录 + user_avatar_dir = os.path.join(AVATAR_BASE_DIR, usr) + os.makedirs(user_avatar_dir, exist_ok=True) + cursor.execute("INSERT INTO users (name, passwd, avatar) VALUES (?, ?, ?)", + (usr, pwd, avatar)) + conn.commit() + return {"type": "register_1", "error": False, "message": "User registered successfully"} + except sqlite3.IntegrityError: + return {"type": "register_0", "success": False, "message": "Username already exists"} + except sqlite3.Error as e: + return {"type": "register_-1", "error": True, "message": str(e)} + finally: + conn.close() + +def generate_token(username): + token = secrets.token_hex(16) + tokens[token] = {'username': username, 'timestamp': time.time()} + return token + +def validate_token(token): + if token in tokens: + if time.time() - tokens[token]['timestamp'] < 3600: + tokens[token]['timestamp'] = time.time() + return tokens[token]['username'] + return None + +@app.route("/api/register", methods=['POST']) +def register1(): + vl = request.get_json() + usr = vl.get('username') + pwd = vl.get('password') + avatar = vl.get('avatar', 'default_avatar.png') + + # 验证用户名 + if not validate_username(usr): + return jsonify({ + "type": "register_-2", + "success": False, + "message": "Invalid username. Must be 2-20 characters, start with a letter or symbol, and contain only letters, numbers, or symbols: _!@#$%^&+-~?" + }), 400 + + if avatar and not (avatar.endswith('.png') or avatar.endswith('.jpg')): + return jsonify({ + "type": "register_0", + "success": False, + "message": "Invalid avatar format. Only .png or .jpg allowed" + }), 400 + + result = register_user(usr, pwd, avatar) + if result.get('success', True): # 成功注册时返回的字典有'success'键 + return jsonify(result) + else: + return jsonify(result), 403 if result['message'] == "Username already exists" else 500 + +@app.route("/api/login", methods=['POST']) +def login(): + data = request.get_json() + username = data['username'] + + # 检查账号是否已登录 + if username in active_connections: + # 检查连接是否仍然活跃 + conn_info = active_connections[username] + try: + # 发送测试消息检查连接是否有效 + conn_info['conn'].sendall(json.dumps({"type": "ping"}).encode('utf-8')) + logger.info(f"用户 {username} 的连接仍然活跃") + return jsonify({ + "type": "login", + "status": "error_0", + "message": "Account already logged in" + }), 409 + except: + # 连接已失效,清理旧连接 + logger.warning(f"清理无效连接: {username}") + if username in active_connections: + del active_connections[username] + if conn_info['conn'] in chat_connections: + chat_connections.remove(conn_info['conn']) + + if isuserxist(username) and ispsswdright(username, data['password']): + token = generate_token(username) + avatar = get_avatar(username) + return jsonify({ + "type": "login_1", + "status": "success", + "token": token, + "avatar": avatar + }) + return jsonify({ + "type": "login_0", + "status": "error", + "message": "Invalid credentials" + }), 401 + +# 添加头像静态路由 +@app.route("/avatar//", methods=['GET']) +def serve_avatar(username, filename): + try: + avatar_dir = os.path.join(AVATAR_BASE_DIR, username) + return send_from_directory(avatar_dir, filename) + except FileNotFoundError: + # 如果找不到头像,返回默认头像 + return send_from_directory(AVATAR_BASE_DIR, DEFAULT_AVATAR) + +@app.route("/api/chat", methods=['POST']) +def chat(): + token = request.headers.get('Authorization') + username = validate_token(token) + if not username: + return jsonify({"type": "chat", "status": "error"}), 401 + + data = request.get_json() + message = { + "type": "chat", + "user": username, + "message": data['message'], + "avatar": get_avatar(username) + } + broadcast_message(message) + return jsonify({"type": "chat", "status": "success"}) + +def broadcast_message(message, sender=None): + for conn in chat_connections[:]: # 使用副本迭代 + try: + conn.sendall(json.dumps(message).encode('utf-8')) + except: + # 连接异常时移除 + for uname, info in list(active_connections.items()): + if info['conn'] == conn: + logger.warning(f"广播时移除无效连接: {uname}") + del active_connections[uname] + break + if conn in chat_connections: + chat_connections.remove(conn) + +def handle_socket_message(data, addr, conn): + try: + action = data.get('type') + if action == 'register': + username = data.get('username') + password = data.get('password') + avatar = data.get('avatar', 'default_avatar.png') + + # 验证用户名 + if not validate_username(username): + response = { + "type": "register_-2", + "success": False, + "message": "Invalid username format" + } + conn.sendall(json.dumps(response).encode('utf-8')) + return response + + if avatar and not (avatar.endswith('.png') or avatar.endswith('.jpg')): + response = { + "type": "register_-3", + "success": False, + "message": "Invalid avatar format. Only .png or .jpg allowed" + } + conn.sendall(json.dumps(response).encode('utf-8')) + return response + + result = register_user(username, password, avatar) + conn.sendall(json.dumps(result).encode('utf-8')) + return result + + elif action == 'login': + username = data['username'] + password = data['password'] + + # 检查账号是否已登录且连接有效 + if username in active_connections: + conn_info = active_connections[username] + try: + # 测试连接是否仍然有效 + conn_info['conn'].sendall(json.dumps({"type": "ping"}).encode('utf-8')) + logger.info(f"用户 {username} 尝试登录但已有活跃连接") + response = { + "type": "login", + "status": "error_-1", + "message": "Account already logged in" + } + conn.sendall(json.dumps(response).encode('utf-8')) + return response + except: + # 连接已失效,清理旧连接 + logger.warning(f"清理无效连接后允许登录: {username}") + if username in active_connections: + del active_connections[username] + if conn_info['conn'] in chat_connections: + chat_connections.remove(conn_info['conn']) + + if isuserxist(username) and ispsswdright(username, password): + # 添加新连接 + active_connections[username] = {'conn': conn, 'ip': addr[0], 'last_active': time.time()} + if conn not in chat_connections: + chat_connections.append(conn) + + token = generate_token(username) + avatar = get_avatar(username) + response = { + "type": "login", + "status": "success", + "message": "Login successful", + "token": token, + "username": username, + "avatar": avatar + } + conn.sendall(json.dumps(response).encode('utf-8')) + logger.info(f"用户 {username} 登录成功") + return response + else: + response = { + "type": "login", + "status": "error_0", + "message": "Invalid credentials" + } + conn.sendall(json.dumps(response).encode('utf-8')) + return response + + elif action == 'chat': + token = data.get('token') + if not token: + response = { + "type": "chat", + "status": "error_Mt", + "message": "Missing token" + } + conn.sendall(json.dumps(response).encode('utf-8')) + return response + + username = validate_token(token) + if not username: + response = { + "type": "chat", + "status": "error_It", + "message": "Invalid token" + } + conn.sendall(json.dumps(response).encode('utf-8')) + return response + + if username not in active_connections: + response = { + "type": "chat", + "status": "error_Nli", + "message": "Not logged in" + } + conn.sendall(json.dumps(response).encode('utf-8')) + return response + + # 更新最后活跃时间 + active_connections[username]['last_active'] = time.time() + + message = { + "type": "chat", + "user": username, + "message": data['message'], + "avatar": get_avatar(username) + } + broadcast_message(message) + response = {"type": "chat", "status": "success"} + conn.sendall(json.dumps(response).encode('utf-8')) + return response + + elif action == 'heartbeat': + # 心跳检测 + token = data.get('token') + if token: + username = validate_token(token) + if username and username in active_connections: + # 更新最后活跃时间 + active_connections[username]['last_active'] = time.time() + response = {"type": "heartbeat", "status": "success"} + conn.sendall(json.dumps(response).encode('utf-8')) + return response + return {"type": "heartbeat", "status": "error"} + + except Exception as e: + logger.error(f"处理消息时出错: {str(e)}") + response = { + "status": "error", + "message": str(e) + } + try: + conn.sendall(json.dumps(response).encode('utf-8')) + except: + pass + return response + +def check_inactive_connections(): + """定期检查不活跃的连接并清理""" + while True: + time.sleep(60) # 每分钟检查一次 + current_time = time.time() + inactive_users = [] + + for username, info in list(active_connections.items()): + # 5分钟无活动视为不活跃 + if current_time - info['last_active'] > 300: + logger.warning(f"检测到不活跃用户: {username}, 最后活跃: {current_time - info['last_active']}秒前") + inactive_users.append(username) + + for username in inactive_users: + info = active_connections[username] + try: + info['conn'].close() + except: + pass + if username in active_connections: + del active_connections[username] + if info['conn'] in chat_connections: + chat_connections.remove(info['conn']) + logger.info(f"已清理不活跃用户: {username}") + +def run_socket_server(): + socket_server.bind(("0.0.0.0", 8889)) + socket_server.listen() + logger.info("Socket server running on port 8889") + + # 启动连接检查线程 + threading.Thread(target=check_inactive_connections, daemon=True).start() + + while True: + conn, addr = socket_server.accept() + logger.info(f"Socket client connected: {addr}") + try: + while True: + data = conn.recv(1024) + if not data: + break + + try: + decoded_data = data.decode('utf-8', errors='ignore') + json_data = json.loads(decoded_data) + response = handle_socket_message(json_data, addr, conn) + except json.JSONDecodeError: + response = { + "type": "error", + "status": "error", + "message": "Invalid JSON" + } + conn.sendall(json.dumps(response).encode('utf-8')) + except Exception as e: + logger.error(f"处理数据时出错: {str(e)}") + response = { + "type": "error", + "status": "error", + "message": str(e) + } + conn.sendall(json.dumps(response).encode('utf-8')) + except (ConnectionResetError, BrokenPipeError): + logger.warning(f"Client {addr} disconnected abruptly") + finally: + # 清理断开的连接 + for username, info in list(active_connections.items()): + if info['conn'] == conn: + del active_connections[username] + logger.info(f"用户 {username} 断开连接") + break + + if conn in chat_connections: + chat_connections.remove(conn) + + try: + conn.close() + except: + pass + logger.info(f"Connection closed for {addr}") + +if __name__ == '__main__': + with get_db_connection() as conn: + conn.execute('''CREATE TABLE IF NOT EXISTS users + (name TEXT PRIMARY KEY, + passwd TEXT, + avatar TEXT DEFAULT 'default_avatar.png')''') + # 确保头像目录存在 + os.makedirs(AVATAR_BASE_DIR, exist_ok=True) + threading.Thread(target=run_socket_server, daemon=True).start() + app.run(port=5001, host='0.0.0.0') \ No newline at end of file