chatserver/CS3.1.3.py
2025-06-14 03:25:05 +08:00

468 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import re  # regular expressions (username validation)
import secrets
import socket
import sqlite3
import threading
import time
import urllib.parse

from flask import Flask, jsonify, request, send_from_directory
# Module-level constants.
AVATAR_BASE_DIR = "avatar"  # base directory holding one avatar folder per user (relative path)
DEFAULT_AVATAR = "default_avatar.png"  # file name of the fallback avatar image
# Logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Flask application serving the HTTP API (run on port 5001 from the __main__ block).
app = Flask(__name__)
# Listening TCP socket used by run_socket_server(); SO_REUSEADDR allows the
# port to be re-bound right after a restart.
socket_server = socket.socket()
socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# Shared in-memory session state; socket sessions are keyed by username.
# NOTE(review): these containers are read and mutated both by Flask request
# handlers and by the socket-server thread, without any lock.
active_connections = {}  # {username: {'conn': conn, 'ip': ip, 'last_active': timestamp}}
chat_connections = []  # every socket connection that has logged in; broadcast targets
tokens = {}  # session tokens: {token: {'username': name, 'timestamp': last-use time}}
def get_db_connection():
    """Open a new connection to the ``usr.db`` SQLite database.

    Rows fetched through the returned connection are ``sqlite3.Row`` objects,
    so columns can be read by name as well as by position. The caller owns
    the connection and is responsible for closing it.
    """
    database_file = "usr.db"
    handle = sqlite3.connect(database_file)
    handle.row_factory = sqlite3.Row
    return handle
def isuserxist(name):
    """Return True if the ``users`` table contains a row for *name*.

    The connection is always closed, even when the query raises.
    """
    cn = get_db_connection()
    try:
        # Only existence matters, so fetch a constant instead of every column.
        row = cn.execute(
            'SELECT 1 FROM users WHERE name = ? LIMIT 1', (name,)
        ).fetchone()
    finally:
        cn.close()
    return row is not None
def ispsswdright(name, passwd):
    """Return True if *passwd* equals the stored password of user *name*.

    Always returns a bool. The result is False when:
    * the user does not exist;
    * the stored password is NULL;
    * *passwd* is not a string.
    Before, a NULL stored password matched a missing (None) password.

    SECURITY NOTE(review): passwords are stored and compared in plaintext.
    They should be salted and hashed (e.g. ``hashlib.scrypt``), which needs a
    data migration.
    """
    cn = get_db_connection()
    try:
        row = cn.execute(
            "SELECT passwd FROM users WHERE name = ?", (name,)
        ).fetchone()
    finally:
        cn.close()
    if row is None:
        return False
    stored = row[0]
    if not isinstance(stored, str) or not isinstance(passwd, str):
        return False
    # Constant-time comparison. Encode first because compare_digest only
    # accepts ASCII-only str; surrogatepass keeps odd JSON input from raising.
    return secrets.compare_digest(
        stored.encode('utf-8', 'surrogatepass'),
        passwd.encode('utf-8', 'surrogatepass'),
    )
# Allowed username alphabet: ASCII letters, digits and the symbols _!@#$%^&+-~?
_USERNAME_PATTERN = re.compile(r'[a-zA-Z0-9_!@#$%^&+\-~?]+')


def validate_username(username):
    """Return True if *username* satisfies the account naming rules.

    Rules:
    * it must be a ``str`` of 2-20 characters;
    * it may contain only ASCII letters, digits and ``_!@#$%^&+-~?``;
    * it must not start with a digit.

    Non-string input (e.g. ``None`` when a client omitted the field) is
    rejected instead of raising ``TypeError``.
    """
    if not isinstance(username, str):
        return False
    if not 2 <= len(username) <= 20:
        return False
    # fullmatch rather than match(r'^...$'): '$' also matches just before a
    # trailing newline, which let names such as "ab\n" through.
    if _USERNAME_PATTERN.fullmatch(username) is None:
        return False
    return not username[0].isdigit()
def get_avatar(username):
    """Return the URL path of *username*'s avatar image.

    The user's avatar is served as ``/avatar/<username>/<file>``. Both path
    segments are percent-encoded, because valid usernames may contain ``#``,
    ``?``, ``%``, ``&`` and ``+``, which would otherwise break the URL.

    If the user is unknown, or their avatar column is NULL or empty, this
    returns the default avatar URL ``/avatar/default/<DEFAULT_AVATAR>``.
    The DB connection is always closed.
    """
    cn = get_db_connection()
    try:
        row = cn.execute(
            "SELECT avatar FROM users WHERE name = ?", (username,)
        ).fetchone()
    finally:
        cn.close()
    avatar_file = row[0] if row is not None else None
    if avatar_file:
        user_segment = urllib.parse.quote(str(username), safe='')
        file_segment = urllib.parse.quote(str(avatar_file), safe='')
        return f"/avatar/{user_segment}/{file_segment}"
    # Unknown user or missing avatar: fall back to the default avatar URL.
    return f"/avatar/default/{DEFAULT_AVATAR}"
def register_user(usr, pwd, avatar="default_avatar.png"):
    """Create the account *usr* and its avatar directory.

    Callers are expected to have validated *usr* already; it is used as a
    directory name under AVATAR_BASE_DIR.

    Returns a response dict with "type" and "message":
    * ``register_1``: created;
    * ``register_0``: the name is already taken;
    * ``register_-1``: database or filesystem error.
    Each variant also carries a boolean "success" key, so callers no longer
    mistake an error for success. The older "error" flag is kept where it
    was present.
    """
    conn = get_db_connection()
    try:
        conn.execute("INSERT INTO users (name, passwd, avatar) VALUES (?, ?, ?)",
                     (usr, pwd, avatar))
        # Create the per-user directory only once the INSERT has succeeded.
        # If this raises, the INSERT is uncommitted and is discarded when the
        # connection is closed below.
        os.makedirs(os.path.join(AVATAR_BASE_DIR, usr), exist_ok=True)
        conn.commit()
        return {"type": "register_1", "success": True, "error": False,
                "message": "User registered successfully"}
    except sqlite3.IntegrityError:
        return {"type": "register_0", "success": False, "message": "Username already exists"}
    except (sqlite3.Error, OSError) as e:
        return {"type": "register_-1", "success": False, "error": True, "message": str(e)}
    finally:
        conn.close()
def generate_token(username):
    """Issue a fresh random session token for *username*.

    The token (32 hex characters) is stored in the module-level ``tokens``
    dict together with the current time, which is later used to decide
    whether the token has expired.
    """
    fresh = secrets.token_hex(16)
    tokens[fresh] = dict(username=username, timestamp=time.time())
    return fresh
def validate_token(token, ttl=3600):
    """Return the username bound to *token*, or None if it is unknown or expired.

    A token expires after *ttl* seconds without use. Each successful
    validation refreshes its timestamp (sliding expiry). Expired tokens are
    removed from ``tokens`` so the store does not grow forever. Non-string
    input returns None instead of raising, since issued tokens are always
    strings.
    """
    if not isinstance(token, str):
        return None
    entry = tokens.get(token)
    if entry is None:
        return None
    now = time.time()
    if now - entry['timestamp'] >= ttl:
        tokens.pop(token, None)
        return None
    entry['timestamp'] = now
    return entry['username']
@app.route("/api/register", methods=['POST'])
def register1():
    """HTTP endpoint: register a new account from a JSON body.

    Expected body: ``{"username": ..., "password": ..., "avatar": optional}``.

    Responses:
    * 200: registration succeeded;
    * 400: invalid username or avatar name;
    * 403: the username is already taken;
    * 500: database or filesystem error.
    """
    vl = request.get_json(silent=True)
    if not isinstance(vl, dict):
        vl = {}  # missing or non-object body: validation below rejects it
    usr = vl.get('username')
    pwd = vl.get('password')
    avatar = vl.get('avatar', 'default_avatar.png')
    if not validate_username(usr):
        return jsonify({
            "type": "register_-2",
            "success": False,
            "message": "Invalid username. Must be 2-20 characters, start with a letter or symbol, and contain only letters, numbers, or symbols: _!@#$%^&+-~?"
        }), 400
    if avatar and not (isinstance(avatar, str) and avatar.endswith(('.png', '.jpg'))):
        # register_-3 matches the socket API. register_0 is reserved for
        # "Username already exists".
        return jsonify({
            "type": "register_-3",
            "success": False,
            "message": "Invalid avatar format. Only .png or .jpg allowed"
        }), 400
    result = register_user(usr, pwd, avatar)
    result_type = result.get("type")
    # Decide on the result "type". register_user's error dicts have not
    # always carried a "success" key, so .get('success', True) misreported
    # DB errors as success.
    if result_type == "register_1":
        return jsonify(result)
    return jsonify(result), (403 if result_type == "register_0" else 500)
@app.route("/api/login", methods=['POST'])
def login():
    """HTTP endpoint: authenticate and issue a session token.

    Expected body: ``{"username": ..., "password": ...}``.

    Responses:
    * 401: bad credentials;
    * 409: the user still has a live socket session;
    * 200: success, with ``token`` and ``avatar``.

    This endpoint does not register a socket session; it only issues a token.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    # Authenticate first. Otherwise unauthenticated callers could learn
    # whether an account is online and trigger pings to its socket.
    if not (isuserxist(username) and ispsswdright(username, password)):
        return jsonify({
            "type": "login_0",
            "status": "error",
            "message": "Invalid credentials"
        }), 401
    conn_info = active_connections.get(username)
    if conn_info is not None:
        try:
            # Probe the existing socket; a failed send means it is dead.
            conn_info['conn'].sendall(json.dumps({"type": "ping"}).encode('utf-8'))
        except OSError:
            logger.warning(f"清理无效连接: {username}")
            active_connections.pop(username, None)
            if conn_info['conn'] in chat_connections:
                chat_connections.remove(conn_info['conn'])
        else:
            logger.info(f"用户 {username} 的连接仍然活跃")
            return jsonify({
                "type": "login",
                "status": "error_0",
                "message": "Account already logged in"
            }), 409
    token = generate_token(username)
    avatar = get_avatar(username)
    return jsonify({
        "type": "login_1",
        "status": "success",
        "token": token,
        "avatar": avatar
    })
def _is_safe_path_segment(segment):
    """Return True if *segment* is a single non-empty path component, not a dot segment."""
    return (
        isinstance(segment, str)
        and segment not in ("", ".", "..")
        and "/" not in segment
        and "\\" not in segment
        and "\x00" not in segment
    )


@app.route("/avatar/<username>/<filename>", methods=['GET'])
def serve_avatar(username, filename):
    """Serve ``AVATAR_BASE_DIR/<username>/<filename>``, or the default avatar.

    *username* is validated as a plain path component. send_from_directory()
    only guards *filename*; the directory argument is trusted. Without this
    check, a ``..`` username would expose every file beside the app, including
    ``usr.db``.

    Missing files fall back to ``AVATAR_BASE_DIR/DEFAULT_AVATAR``. The check
    is explicit because send_from_directory() signals a missing file with a
    404 HTTP exception, not FileNotFoundError.
    """
    if _is_safe_path_segment(username) and _is_safe_path_segment(filename):
        avatar_dir = os.path.join(AVATAR_BASE_DIR, username)
        # Flask resolves a relative directory against app.root_path, so the
        # existence check must look in the same place.
        if os.path.isfile(os.path.join(app.root_path, avatar_dir, filename)):
            return send_from_directory(avatar_dir, filename)
    return send_from_directory(AVATAR_BASE_DIR, DEFAULT_AVATAR)
@app.route("/api/chat", methods=['POST'])
def chat():
    """HTTP endpoint: broadcast a chat message to all socket clients.

    The session token is read from the ``Authorization`` header, and the body
    must be a JSON object with a ``message`` field.

    Responses:
    * 401: invalid or expired token;
    * 400: malformed body;
    * 200: the message was broadcast.
    """
    token = request.headers.get('Authorization')
    username = validate_token(token)
    if not username:
        return jsonify({"type": "chat", "status": "error"}), 401
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'message' not in data:
        # Previously a missing body or field surfaced as a 500 (TypeError/KeyError).
        return jsonify({"type": "chat", "status": "error", "message": "Missing message"}), 400
    message = {
        "type": "chat",
        "user": username,
        "message": data['message'],
        "avatar": get_avatar(username)
    }
    broadcast_message(message)
    return jsonify({"type": "chat", "status": "success"})
def broadcast_message(message, sender=None):
    """Send *message*, encoded as UTF-8 JSON, to every socket in ``chat_connections``.

    A connection whose send fails is removed from ``chat_connections``, and its
    ``active_connections`` entry is removed too. The socket is not closed
    here; the socket server's per-client loop closes connections when it
    finishes.

    *sender* is accepted for API compatibility but is currently unused.

    Raises TypeError/ValueError if *message* is not JSON-serializable.
    Serialization happens once, up front. Before, a bad message failed inside
    the per-connection ``except`` and deregistered every client.
    """
    payload = json.dumps(message).encode('utf-8')
    for conn in list(chat_connections):  # iterate a snapshot: we may remove entries
        try:
            conn.sendall(payload)
        except OSError:
            for uname, info in list(active_connections.items()):
                if info['conn'] == conn:
                    logger.warning(f"广播时移除无效连接: {uname}")
                    active_connections.pop(uname, None)
                    break
            if conn in chat_connections:
                chat_connections.remove(conn)
def _sock_send(conn, payload):
    """Serialize *payload* as UTF-8 JSON, write it to *conn*, and return it."""
    conn.sendall(json.dumps(payload).encode('utf-8'))
    return payload


def _sock_register(data, conn):
    """Handle a socket ``register`` request and reply on *conn*."""
    username = data.get('username')
    password = data.get('password')
    avatar = data.get('avatar', 'default_avatar.png')
    if not validate_username(username):
        return _sock_send(conn, {
            "type": "register_-2",
            "success": False,
            "message": "Invalid username format"
        })
    # Non-string avatars are rejected here instead of crashing on .endswith().
    if avatar and not (isinstance(avatar, str) and avatar.endswith(('.png', '.jpg'))):
        return _sock_send(conn, {
            "type": "register_-3",
            "success": False,
            "message": "Invalid avatar format. Only .png or .jpg allowed"
        })
    return _sock_send(conn, register_user(username, password, avatar))


def _sock_login(data, addr, conn):
    """Handle a socket ``login`` request: authenticate, then bind *conn* to the user."""
    username = data.get('username')
    password = data.get('password')
    # Authenticate before touching session state. Otherwise unauthenticated
    # callers could learn whether an account is online and trigger pings to
    # its socket.
    if not (isuserxist(username) and ispsswdright(username, password)):
        return _sock_send(conn, {
            "type": "login",
            "status": "error_0",
            "message": "Invalid credentials"
        })
    existing = active_connections.get(username)
    if existing is not None:
        try:
            # Probe the previous connection; a failed send means it is dead.
            existing['conn'].sendall(json.dumps({"type": "ping"}).encode('utf-8'))
        except OSError:
            logger.warning(f"清理无效连接后允许登录: {username}")
            active_connections.pop(username, None)
            if existing['conn'] in chat_connections:
                chat_connections.remove(existing['conn'])
        else:
            logger.info(f"用户 {username} 尝试登录但已有活跃连接")
            return _sock_send(conn, {
                "type": "login",
                "status": "error_-1",
                "message": "Account already logged in"
            })
    active_connections[username] = {'conn': conn, 'ip': addr[0], 'last_active': time.time()}
    if conn not in chat_connections:
        chat_connections.append(conn)
    response = {
        "type": "login",
        "status": "success",
        "message": "Login successful",
        "token": generate_token(username),
        "username": username,
        "avatar": get_avatar(username)
    }
    _sock_send(conn, response)
    logger.info(f"用户 {username} 登录成功")
    return response


def _sock_chat(data, conn):
    """Handle a socket ``chat`` request: check the token and session, then broadcast."""
    token = data.get('token')
    if not token:
        return _sock_send(conn, {
            "type": "chat",
            "status": "error_Mt",
            "message": "Missing token"
        })
    username = validate_token(token)
    if not username:
        return _sock_send(conn, {
            "type": "chat",
            "status": "error_It",
            "message": "Invalid token"
        })
    # Single lookup: avoids a KeyError if another thread drops the session
    # between a membership test and the update.
    session = active_connections.get(username)
    if session is None:
        return _sock_send(conn, {
            "type": "chat",
            "status": "error_Nli",
            "message": "Not logged in"
        })
    session['last_active'] = time.time()
    message = {
        "type": "chat",
        "user": username,
        "message": data['message'],
        "avatar": get_avatar(username)
    }
    broadcast_message(message)
    return _sock_send(conn, {"type": "chat", "status": "success"})


def _sock_heartbeat(data, conn):
    """Handle a socket ``heartbeat``.

    On success this refreshes the session's last-activity time and replies.
    On failure it returns an error dict without sending anything.
    """
    token = data.get('token')
    username = validate_token(token) if token else None
    session = active_connections.get(username) if username else None
    if session is not None:
        session['last_active'] = time.time()
        return _sock_send(conn, {"type": "heartbeat", "status": "success"})
    return {"type": "heartbeat", "status": "error"}


def handle_socket_message(data, addr, conn):
    """Dispatch one decoded JSON request from a socket client.

    *data* is the decoded request object, whose ``type`` selects
    register / login / chat / heartbeat. *addr* is the client's address
    tuple, and *conn* is its socket. Replies are written to *conn*.

    Returns the response dict, or None for an unknown ``type``; in that case
    nothing is sent. Any unexpected error is logged. A best-effort
    ``{"status": "error", ...}`` reply is sent and also returned.
    """
    try:
        action = data.get('type')
        if action == 'register':
            return _sock_register(data, conn)
        if action == 'login':
            return _sock_login(data, addr, conn)
        if action == 'chat':
            return _sock_chat(data, conn)
        if action == 'heartbeat':
            return _sock_heartbeat(data, conn)
        return None
    except Exception as e:
        logger.exception(f"处理消息时出错: {str(e)}")
        response = {
            "status": "error",
            "message": str(e)
        }
        try:
            _sock_send(conn, response)
        except OSError:
            pass  # best effort: the client may already be gone
        return response
def _evict_idle_sessions(idle_timeout):
    """Close and deregister every socket session idle for more than *idle_timeout* seconds."""
    now = time.time()
    # Snapshot (username, session) pairs so removals act on exactly the
    # session judged idle, even if the user has logged in again meanwhile.
    stale = []
    for username, info in list(active_connections.items()):
        idle_for = now - info['last_active']
        if idle_for > idle_timeout:
            logger.warning(f"检测到不活跃用户: {username}, 最后活跃: {idle_for}秒前")
            stale.append((username, info))
    for username, info in stale:
        conn = info['conn']
        try:
            # shutdown() wakes a thread blocked in recv() on this socket;
            # close() alone is not guaranteed to do that on every platform.
            conn.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # already disconnected
        try:
            conn.close()
        except OSError:
            pass
        if active_connections.get(username) is info:
            active_connections.pop(username, None)
        if conn in chat_connections:
            chat_connections.remove(conn)
        logger.info(f"已清理不活跃用户: {username}")


def check_inactive_connections(interval=60, idle_timeout=300):
    """Periodically evict socket sessions that have been idle too long.

    Every *interval* seconds this closes sessions whose ``last_active`` is
    more than *idle_timeout* seconds old. It loops forever and is meant to
    run as a daemon thread. Errors in one pass are logged and do not stop
    later passes. Before, a KeyError from a concurrent disconnect killed the
    thread for good.
    """
    while True:
        time.sleep(interval)
        try:
            _evict_idle_sessions(idle_timeout)
        except Exception:
            logger.exception("Error while checking inactive connections")
def _serve_socket_client(conn, addr):
    """Serve one socket client until it disconnects, then deregister and close it.

    Each ``recv`` chunk (up to 1024 bytes) is treated as one JSON request.
    NOTE(review): there is no message framing, so requests larger than the
    buffer, or several requests coalesced into one chunk, fail to parse.
    """
    try:
        while True:
            data = conn.recv(1024)
            if not data:
                break
            try:
                decoded_data = data.decode('utf-8', errors='ignore')
                json_data = json.loads(decoded_data)
                handle_socket_message(json_data, addr, conn)
            except json.JSONDecodeError:
                conn.sendall(json.dumps({
                    "type": "error",
                    "status": "error",
                    "message": "Invalid JSON"
                }).encode('utf-8'))
            except Exception as e:
                logger.error(f"处理数据时出错: {str(e)}")
                conn.sendall(json.dumps({
                    "type": "error",
                    "status": "error",
                    "message": str(e)
                }).encode('utf-8'))
    except OSError:
        # Covers reset, broken pipe, and sockets closed by the idle sweeper.
        logger.warning(f"Client {addr} disconnected abruptly")
    finally:
        # Drop every session bound to this socket. One socket can hold more
        # than one login.
        for username, info in list(active_connections.items()):
            if info['conn'] == conn:
                active_connections.pop(username, None)
                logger.info(f"用户 {username} 断开连接")
        try:
            chat_connections.remove(conn)
        except ValueError:
            pass  # not registered, or already removed by another thread
        try:
            conn.close()
        except OSError:
            pass
        logger.info(f"Connection closed for {addr}")


def run_socket_server(host="0.0.0.0", port=8889):
    """Accept socket clients forever, serving each one on its own daemon thread.

    Before, clients were served one at a time, so a second client was not
    read until the first disconnected. Any OSError other than reset or
    broken-pipe also stopped the whole server. The idle-session sweeper is
    started here as well.
    """
    socket_server.bind((host, port))
    socket_server.listen()
    logger.info(f"Socket server running on port {port}")
    threading.Thread(target=check_inactive_connections, daemon=True).start()
    while True:
        conn, addr = socket_server.accept()
        logger.info(f"Socket client connected: {addr}")
        threading.Thread(
            target=_serve_socket_client, args=(conn, addr), daemon=True
        ).start()
if __name__ == '__main__':
    # Create the users table if needed. Using a sqlite3 connection as a
    # context manager only commits or rolls back; it does not close the
    # connection, so it is closed explicitly.
    conn = get_db_connection()
    try:
        with conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS users
                (name TEXT PRIMARY KEY,
                passwd TEXT,
                avatar TEXT DEFAULT 'default_avatar.png')''')
    finally:
        conn.close()
    # Make sure the avatar base directory exists.
    os.makedirs(AVATAR_BASE_DIR, exist_ok=True)
    # The socket server runs in a daemon thread; Flask blocks the main thread.
    threading.Thread(target=run_socket_server, daemon=True).start()
    app.run(port=5001, host='0.0.0.0')