468 lines
17 KiB
Python
468 lines
17 KiB
Python
import threading
|
||
import json
|
||
from flask import Flask, jsonify, request, send_from_directory
|
||
import sqlite3
|
||
import socket
|
||
import secrets
|
||
import time
|
||
import os
|
||
import logging
|
||
import re # 添加正则表达式模块
|
||
|
||
# Avatar storage layout: every user gets AVATAR_BASE_DIR/<username>/, and the
# shared fallback image is named DEFAULT_AVATAR.
AVATAR_BASE_DIR = "avatar"
DEFAULT_AVATAR = "default_avatar.png"

# Root logging is configured once at import time; this module logs through
# its own named logger.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
||
|
||
# HTTP API application.
app = Flask(__name__)

# Listener for the raw TCP protocol. SO_REUSEADDR allows quick re-binding after
# a restart; the socket is bound and put into listen mode by run_socket_server().
socket_server = socket.socket()
socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# Process-wide shared state, mutated without a lock from the socket-server
# thread, the inactivity checker thread and the Flask request handlers (the
# Flask dev server is threaded by default — confirm for your deployment).
#
# username -> {'conn': <client socket>, 'ip': <peer IP>, 'last_active': <time.time()>}
active_connections = dict()
# Every socket that should receive chat broadcasts.
chat_connections = list()
# token -> {'username': <str>, 'timestamp': <time.time() of last use>}
tokens = dict()
|
||
|
||
def get_db_connection():
    """Open and return a new SQLite connection to ``usr.db``.

    The connection's ``row_factory`` is ``sqlite3.Row``, so fetched rows can be
    indexed by column name as well as by position. Every call opens a fresh
    connection; the caller is responsible for closing it.
    """
    db = sqlite3.connect("usr.db")
    db.row_factory = sqlite3.Row
    return db
|
||
|
||
def isuserxist(name):
    """Return True if a row named *name* exists in the ``users`` table.

    Only existence is needed, so the query selects a constant instead of whole
    rows. The connection is closed even if the query raises (the previous
    version leaked it in that case).
    """
    cn = get_db_connection()
    try:
        row = cn.execute(
            'SELECT 1 FROM users WHERE name = ? LIMIT 1', (name,)
        ).fetchone()
    finally:
        cn.close()
    return row is not None
|
||
|
||
def ispsswdright(name, passwd):
    """Return True if *passwd* equals the stored password of user *name*.

    Returns a real bool; it is False when:
      * the user does not exist;
      * either the stored value or *passwd* is not a string (e.g. a NULL
        column, or a null/number sent in JSON), so a NULL password can no
        longer be matched by a null input;
      * the passwords differ.

    The comparison is constant-time, so response timing does not reveal how
    many leading characters matched. The connection is closed even if the
    query raises.

    NOTE(review): passwords are stored and compared in plaintext. They should
    be hashed (e.g. hashlib.scrypt / pbkdf2_hmac), which requires a schema and
    data migration.
    """
    cn = get_db_connection()
    try:
        row = cn.execute(
            "SELECT passwd FROM users WHERE name = ?", (name,)
        ).fetchone()
    finally:
        cn.close()

    if row is None:
        return False
    stored = row[0]
    if not isinstance(stored, str) or not isinstance(passwd, str):
        return False
    # compare_digest only accepts ASCII str, so compare UTF-8 bytes instead.
    # 'surrogatepass' keeps lone surrogates (legal in JSON) from raising.
    return secrets.compare_digest(
        stored.encode('utf-8', 'surrogatepass'),
        passwd.encode('utf-8', 'surrogatepass'),
    )
|
||
|
||
# Username validation.
# Whole-string pattern: 2-20 characters, each an ASCII letter, a digit, or one
# of _!@#$%^&+-~?  (all allowed characters are single code points, so the
# {2,20} quantifier is exactly the length rule).
_USERNAME_PATTERN = re.compile(r'[a-zA-Z0-9_!@#$%^&+\-~?]{2,20}')


def validate_username(username):
    """Return True if *username* satisfies the registration rules.

    Rules:
      * it is a ``str`` of 2-20 characters;
      * every character is an ASCII letter, a digit, or one of ``_!@#$%^&+-~?``;
      * it does not start with a digit.

    Non-string input (e.g. ``None`` from a missing JSON field) returns False
    instead of raising TypeError.
    """
    if not isinstance(username, str):
        return False
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline, which let names like "ab\n" through.
    if _USERNAME_PATTERN.fullmatch(username) is None:
        return False
    return not username[0].isdigit()
|
||
|
||
# Look up the public avatar URL of a user.
def get_avatar(username):
    """Return the HTTP path of *username*'s avatar image.

    If the user exists and has a non-empty ``avatar`` column, the result is
    ``/avatar/<username>/<avatar_file>``. Otherwise (unknown user, or a NULL or
    empty column, which previously produced ``/avatar/<username>/None``) it is
    the shared ``/avatar/default/<DEFAULT_AVATAR>``.

    Both path segments are percent-encoded, because valid usernames may contain
    URL-significant characters such as ``#``, ``?``, ``%`` and ``&``. Plain
    letters, digits and ``_-~`` are left unchanged. The connection is closed
    even if the query raises.
    """
    from urllib.parse import quote

    cn = get_db_connection()
    try:
        row = cn.execute(
            "SELECT avatar FROM users WHERE name = ?", (username,)
        ).fetchone()
    finally:
        cn.close()

    if row and row[0]:
        return f"/avatar/{quote(username, safe='')}/{quote(row[0], safe='')}"

    # Unknown user or no avatar recorded: shared default image.
    return f"/avatar/default/{DEFAULT_AVATAR}"
|
||
|
||
def register_user(usr, pwd, avatar="default_avatar.png"):
    """Insert a new user row, then create the user's avatar directory.

    Returns a response dict. Every variant carries ``type``, ``success``,
    ``error`` and ``message``, so callers can rely on ``result["success"]``:
      * ``register_1``  - user created;
      * ``register_0``  - the name already exists (primary-key violation);
      * ``register_-1`` - any other SQLite error; ``message`` is its text.

    The avatar directory ``AVATAR_BASE_DIR/<usr>`` is created only after the
    row is committed, so a failed registration no longer leaves a folder
    behind. If creating the folder fails, a warning is logged and the
    registration still succeeds.

    NOTE(review): *pwd* is stored in plaintext; it should be hashed.
    """
    conn = get_db_connection()
    try:
        conn.execute(
            "INSERT INTO users (name, passwd, avatar) VALUES (?, ?, ?)",
            (usr, pwd, avatar),
        )
        conn.commit()
    except sqlite3.IntegrityError:
        return {"type": "register_0", "success": False, "error": False,
                "message": "Username already exists"}
    except sqlite3.Error as e:
        return {"type": "register_-1", "success": False, "error": True,
                "message": str(e)}
    finally:
        conn.close()

    try:
        os.makedirs(os.path.join(AVATAR_BASE_DIR, usr), exist_ok=True)
    except OSError as e:
        # Not fatal: the user row exists and avatar requests fall back to the
        # default image.
        logger.warning(f"Could not create avatar directory for {usr}: {e}")

    return {"type": "register_1", "success": True, "error": False,
            "message": "User registered successfully"}
|
||
|
||
def generate_token(username):
    """Create a new session token for *username*, record it, and return it.

    The token is 32 hexadecimal characters (16 bytes from ``secrets``). It is
    stored in the module-level ``tokens`` map with the current time as its
    timestamp.
    """
    new_token = secrets.token_hex(16)
    issued_at = time.time()
    tokens[new_token] = dict(username=username, timestamp=issued_at)
    return new_token
|
||
|
||
def validate_token(token, ttl=3600):
    """Return the username bound to *token*, or None if it is unknown or expired.

    Expiration is sliding: a successful lookup resets the token's timestamp, so
    a token lapses only after *ttl* seconds (default 3600) without use. A token
    found to be expired is removed from ``tokens``; previously expired entries
    were never deleted.

    Non-string input (a missing header, or a JSON list or number sent as the
    token) returns None instead of raising on an unhashable key.
    """
    if not isinstance(token, str):
        return None
    entry = tokens.get(token)
    if entry is None:
        return None

    now = time.time()
    if now - entry['timestamp'] < ttl:
        entry['timestamp'] = now
        return entry['username']

    # Expired: forget it. pop() tolerates a concurrent removal.
    tokens.pop(token, None)
    return None
|
||
|
||
@app.route("/api/register", methods=['POST'])
def register1():
    """HTTP endpoint: register a new user.

    Expects a JSON object ``{"username": str, "password": str, "avatar"?: str}``.
    A missing, null or empty avatar means the default image. Responses:
      * 200 ``register_1``  - created;
      * 400 ``register_-2`` - invalid username (also when the body is not a JSON object);
      * 400 ``register_-4`` - missing or empty password;
      * 400 ``register_-3`` - avatar is not a ``.png``/``.jpg`` file name
        (same code as the socket protocol; it was ``register_0`` here, which
        collided with "username exists");
      * 403 ``register_0``  - username already exists;
      * 500 ``register_-1`` - database error.
    """
    vl = request.get_json(silent=True)
    if not isinstance(vl, dict):
        vl = {}
    usr = vl.get('username')
    pwd = vl.get('password')
    avatar = vl.get('avatar') or 'default_avatar.png'

    # Validate the username.
    if not isinstance(usr, str) or not validate_username(usr):
        return jsonify({
            "type": "register_-2",
            "success": False,
            "message": "Invalid username. Must be 2-20 characters, start with a letter or symbol, and contain only letters, numbers, or symbols: _!@#$%^&+-~?"
        }), 400

    if not isinstance(pwd, str) or not pwd:
        return jsonify({
            "type": "register_-4",
            "success": False,
            "message": "Password is required"
        }), 400

    if not isinstance(avatar, str) or not avatar.endswith(('.png', '.jpg')):
        return jsonify({
            "type": "register_-3",
            "success": False,
            "message": "Invalid avatar format. Only .png or .jpg allowed"
        }), 400

    result = register_user(usr, pwd, avatar)
    # Identify success by type. The old `.get('success', True)` treated the
    # sqlite-error dict, which had no "success" key, as a success.
    if result.get('type') == 'register_1':
        return jsonify(result)
    status = 403 if result.get('message') == "Username already exists" else 500
    return jsonify(result), status
|
||
|
||
@app.route("/api/login", methods=['POST'])
def login():
    """HTTP endpoint: authenticate a user and issue a session token.

    Expects a JSON object ``{"username": str, "password": str}``. Responses:
      * 400 ``login_0`` - body is not a JSON object, or a field is missing or
        not a string (previously a 500);
      * 409 ``login`` / ``error_0`` - the user has a socket connection that
        still accepts data;
      * 401 ``login_0`` - wrong username or password;
      * 200 ``login_1`` - success, with ``token`` and ``avatar`` URL.

    If the user's recorded socket can no longer be written to, it is dropped
    from ``active_connections`` and ``chat_connections`` and the login proceeds.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    if not isinstance(username, str) or not isinstance(password, str):
        return jsonify({
            "type": "login_0",
            "status": "error",
            "message": "Missing username or password"
        }), 400

    # Is the account already logged in over the socket protocol?
    conn_info = active_connections.get(username)
    if conn_info is not None:
        try:
            # Probe the existing socket; a dead peer makes sendall raise.
            conn_info['conn'].sendall(json.dumps({"type": "ping"}).encode('utf-8'))
        except OSError:
            # Connection is gone: clean it up and allow this login.
            logger.warning(f"清理无效连接: {username}")
            active_connections.pop(username, None)
            try:
                chat_connections.remove(conn_info['conn'])
            except ValueError:
                pass  # already removed elsewhere
        else:
            logger.info(f"用户 {username} 的连接仍然活跃")
            return jsonify({
                "type": "login",
                "status": "error_0",
                "message": "Account already logged in"
            }), 409

    if isuserxist(username) and ispsswdright(username, password):
        token = generate_token(username)
        avatar = get_avatar(username)
        return jsonify({
            "type": "login_1",
            "status": "success",
            "token": token,
            "avatar": avatar
        })

    return jsonify({
        "type": "login_0",
        "status": "error",
        "message": "Invalid credentials"
    }), 401
|
||
|
||
# Static route serving avatar images.
@app.route("/avatar/<username>/<filename>", methods=['GET'])
def serve_avatar(username, filename):
    """Serve ``AVATAR_BASE_DIR/<username>/<filename>``, or the default avatar.

    Both URL segments form one relative path passed as the *path* argument of
    ``send_from_directory``, which safe-joins it and rejects ``..`` segments.
    The old code joined *username* into the *directory* argument, which Flask
    trusts and does not check, so a ``..`` username could escape
    AVATAR_BASE_DIR.

    A missing file (or a rejected path) makes Flask raise werkzeug's
    ``NotFound`` (HTTP 404), not ``FileNotFoundError``, so the old fallback
    never ran. Any 404 here now falls back to
    ``AVATAR_BASE_DIR/DEFAULT_AVATAR``. If that file is also missing, the
    resulting 404 propagates.
    """
    try:
        return send_from_directory(AVATAR_BASE_DIR, f"{username}/{filename}")
    except Exception as exc:
        # Match werkzeug's NotFound by its HTTP code, so no dependency beyond
        # Flask has to be imported. Any other error propagates unchanged.
        if getattr(exc, "code", None) != 404:
            raise
        return send_from_directory(AVATAR_BASE_DIR, DEFAULT_AVATAR)
|
||
|
||
@app.route("/api/chat", methods=['POST'])
def chat():
    """HTTP endpoint: broadcast a chat message to all socket clients.

    The ``Authorization`` header carries the session token, either bare (as
    before) or as ``Bearer <token>``. The body must be a JSON object with a
    ``message`` field. Responses:
      * 401 - missing, invalid or expired token;
      * 400 - body is not a JSON object or lacks ``message`` (previously a 500);
      * 200 - message broadcast.
    """
    token = request.headers.get('Authorization')
    # Tokens are hex strings, so a raw token never starts with "Bearer ".
    if token and token.startswith('Bearer '):
        token = token[len('Bearer '):].strip()
    username = validate_token(token)
    if not username:
        return jsonify({"type": "chat", "status": "error"}), 401

    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'message' not in data:
        return jsonify({
            "type": "chat",
            "status": "error",
            "message": "Missing message"
        }), 400

    message = {
        "type": "chat",
        "user": username,
        "message": data['message'],
        "avatar": get_avatar(username)
    }
    broadcast_message(message)
    return jsonify({"type": "chat", "status": "success"})
|
||
|
||
def broadcast_message(message, sender=None):
    """Send *message* as UTF-8 JSON to every socket in ``chat_connections``.

    The payload is serialized once, before the loop. A serialization error
    therefore propagates to the caller. Previously ``json.dumps`` sat inside
    the per-socket ``try``, so an unserializable message would disconnect
    every client.

    A socket whose send fails with OSError is removed from
    ``chat_connections``, and its user (if any) from ``active_connections``.

    *sender* is kept for API compatibility; it is not used, so the sender also
    receives the broadcast.
    """
    payload = json.dumps(message).encode('utf-8')
    for conn in list(chat_connections):  # snapshot: entries may be removed below
        try:
            conn.sendall(payload)
        except OSError:
            # Dead connection: forget its session and stop broadcasting to it.
            for uname, info in list(active_connections.items()):
                if info['conn'] is conn:
                    logger.warning(f"广播时移除无效连接: {uname}")
                    active_connections.pop(uname, None)
                    break
            try:
                chat_connections.remove(conn)
            except ValueError:
                pass  # already removed by another thread
|
||
|
||
def _send_json(conn, payload):
    """Write *payload* to *conn* as UTF-8 JSON and return *payload* unchanged."""
    conn.sendall(json.dumps(payload).encode('utf-8'))
    return payload


def _handle_register(data, addr, conn):
    """Socket ``register``: validate the fields, then create the user."""
    username = data.get('username')
    password = data.get('password')
    # A missing, null or empty avatar means the default image.
    avatar = data.get('avatar') or 'default_avatar.png'

    if not isinstance(username, str) or not validate_username(username):
        return _send_json(conn, {
            "type": "register_-2",
            "success": False,
            "message": "Invalid username format"
        })

    if not isinstance(password, str) or not password:
        return _send_json(conn, {
            "type": "register_-4",
            "success": False,
            "message": "Password is required"
        })

    if not isinstance(avatar, str) or not avatar.endswith(('.png', '.jpg')):
        return _send_json(conn, {
            "type": "register_-3",
            "success": False,
            "message": "Invalid avatar format. Only .png or .jpg allowed"
        })

    return _send_json(conn, register_user(username, password, avatar))


def _handle_login(data, addr, conn):
    """Socket ``login``: authenticate and subscribe *conn* to chat broadcasts."""
    username = data['username']
    password = data['password']

    # Refuse if the account already has a live connection.
    conn_info = active_connections.get(username)
    if conn_info is not None:
        try:
            # Only the probe is guarded. Previously the reply to *this* client
            # was also inside the try, so a failure to answer the requester
            # evicted the other, healthy session.
            conn_info['conn'].sendall(json.dumps({"type": "ping"}).encode('utf-8'))
        except OSError:
            # The old connection is dead: clean it up and allow the login.
            logger.warning(f"清理无效连接后允许登录: {username}")
            active_connections.pop(username, None)
            try:
                chat_connections.remove(conn_info['conn'])
            except ValueError:
                pass  # already removed elsewhere
        else:
            logger.info(f"用户 {username} 尝试登录但已有活跃连接")
            return _send_json(conn, {
                "type": "login",
                "status": "error_-1",
                "message": "Account already logged in"
            })

    if not (isuserxist(username) and ispsswdright(username, password)):
        return _send_json(conn, {
            "type": "login",
            "status": "error_0",
            "message": "Invalid credentials"
        })

    # Register the session and subscribe this socket to broadcasts.
    active_connections[username] = {'conn': conn, 'ip': addr[0], 'last_active': time.time()}
    if conn not in chat_connections:
        chat_connections.append(conn)

    token = generate_token(username)
    avatar = get_avatar(username)
    response = {
        "type": "login",
        "status": "success",
        "message": "Login successful",
        "token": token,
        "username": username,
        "avatar": avatar
    }
    _send_json(conn, response)
    logger.info(f"用户 {username} 登录成功")
    return response


def _handle_chat(data, addr, conn):
    """Socket ``chat``: check token and session, then broadcast the message."""
    token = data.get('token')
    if not token:
        return _send_json(conn, {
            "type": "chat",
            "status": "error_Mt",
            "message": "Missing token"
        })

    username = validate_token(token)
    if not username:
        return _send_json(conn, {
            "type": "chat",
            "status": "error_It",
            "message": "Invalid token"
        })

    # Single lookup, so a concurrent disconnect cannot raise KeyError between
    # the membership check and the timestamp update.
    session = active_connections.get(username)
    if session is None:
        return _send_json(conn, {
            "type": "chat",
            "status": "error_Nli",
            "message": "Not logged in"
        })
    session['last_active'] = time.time()

    broadcast_message({
        "type": "chat",
        "user": username,
        "message": data['message'],
        "avatar": get_avatar(username)
    })
    return _send_json(conn, {"type": "chat", "status": "success"})


def _handle_heartbeat(data, addr, conn):
    """Socket ``heartbeat``: refresh the session's last-active time.

    Failures are now also sent to the client; previously they were only
    returned to the caller, so the client waited for a reply that never came.
    """
    token = data.get('token')
    username = validate_token(token) if token else None
    session = active_connections.get(username) if username else None
    if session is not None:
        session['last_active'] = time.time()
        return _send_json(conn, {"type": "heartbeat", "status": "success"})
    return _send_json(conn, {"type": "heartbeat", "status": "error"})


# Message "type" -> handler(data, addr, conn).
_SOCKET_HANDLERS = {
    'register': _handle_register,
    'login': _handle_login,
    'chat': _handle_chat,
    'heartbeat': _handle_heartbeat,
}


def handle_socket_message(data, addr, conn):
    """Dispatch one decoded socket request and reply on *conn*.

    Args:
        data: the decoded JSON request; expected to be an object whose
            ``type`` is ``register``, ``login``, ``chat`` or ``heartbeat``.
        addr: the peer address as returned by ``accept()``; ``addr[0]`` is
            stored as the session IP on login.
        conn: the client socket; every handled request gets exactly one reply.

    Returns:
        The response dict that was sent.

    Non-object payloads and unknown types now get an explicit error reply;
    unknown types were previously ignored silently. Any exception is logged
    and reported to the client as ``{"type": "error", "status": "error", ...}``.
    """
    try:
        if not isinstance(data, dict):
            return _send_json(conn, {
                "type": "error",
                "status": "error",
                "message": "Message must be a JSON object"
            })
        handler = _SOCKET_HANDLERS.get(data.get('type'))
        if handler is None:
            return _send_json(conn, {
                "type": "error",
                "status": "error",
                "message": "Unknown message type"
            })
        return handler(data, addr, conn)
    except Exception as e:
        logger.error(f"处理消息时出错: {str(e)}")
        response = {
            "type": "error",
            "status": "error",
            "message": str(e)
        }
        try:
            _send_json(conn, response)
        except OSError:
            pass  # the client is gone; nothing more to report
        return response
|
||
|
||
def check_inactive_connections(interval=60, timeout=300):
    """Background loop that drops idle socket sessions; never returns.

    Every *interval* seconds (default 60), sessions whose ``last_active`` is
    more than *timeout* seconds old (default 300) are closed and removed from
    ``active_connections`` and ``chat_connections``.

    The loop tolerates sessions that change between the scan and the cleanup.
    Before, ``active_connections[username]`` could raise KeyError (killing this
    thread permanently), or close a *newer* session of the same user.
    """
    while True:
        time.sleep(interval)
        current_time = time.time()

        # Snapshot the idle sessions.
        stale = []
        for username, info in list(active_connections.items()):
            if current_time - info['last_active'] > timeout:
                logger.warning(f"检测到不活跃用户: {username}, 最后活跃: {current_time - info['last_active']}秒前")
                stale.append((username, info))

        for username, info in stale:
            # Skip if the user disconnected or logged in again on another socket.
            if active_connections.get(username) is not info:
                continue
            # Skip if a heartbeat/chat refreshed the session since the scan.
            if time.time() - info['last_active'] <= timeout:
                continue

            active_connections.pop(username, None)
            sock = info['conn']
            try:
                # shutdown() wakes a thread blocked in recv() on this socket;
                # on some platforms (e.g. Linux) close() alone may not.
                sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # already disconnected
            try:
                sock.close()
            except OSError:
                pass
            try:
                chat_connections.remove(sock)
            except ValueError:
                pass  # already removed elsewhere
            logger.info(f"已清理不活跃用户: {username}")
|
||
|
||
def _drop_socket_session(conn):
    """Remove *conn*'s user from ``active_connections`` and *conn* from ``chat_connections``."""
    for username, info in list(active_connections.items()):
        if info['conn'] is conn:
            active_connections.pop(username, None)
            logger.info(f"用户 {username} 断开连接")
            break
    try:
        chat_connections.remove(conn)
    except ValueError:
        pass  # not subscribed, or already removed by another thread


def _serve_socket_client(conn, addr, bufsize=1024):
    """Handle requests from one connected client until it disconnects, then clean up.

    Each ``recv(bufsize)`` is treated as one complete JSON document; the
    protocol has no message framing. Data that does not parse gets an
    ``Invalid JSON`` reply.

    Any OSError ends this client only. Besides reset and broken pipe, that
    includes ConnectionAbortedError and EBADF after the inactivity checker
    closed the socket. Previously such errors escaped and stopped the whole
    socket server.
    """
    try:
        while True:
            data = conn.recv(bufsize)
            if not data:
                break  # orderly shutdown by the peer (or by our own shutdown())

            try:
                decoded_data = data.decode('utf-8', errors='ignore')
                json_data = json.loads(decoded_data)
                handle_socket_message(json_data, addr, conn)
            except json.JSONDecodeError:
                response = {
                    "type": "error",
                    "status": "error",
                    "message": "Invalid JSON"
                }
                conn.sendall(json.dumps(response).encode('utf-8'))
            except Exception as e:
                logger.error(f"处理数据时出错: {str(e)}")
                response = {
                    "type": "error",
                    "status": "error",
                    "message": str(e)
                }
                conn.sendall(json.dumps(response).encode('utf-8'))
    except OSError:
        logger.warning(f"Client {addr} disconnected abruptly")
    finally:
        _drop_socket_session(conn)
        try:
            conn.close()
        except OSError:
            pass
        logger.info(f"Connection closed for {addr}")


def run_socket_server(host="0.0.0.0", port=8889):
    """Bind ``socket_server`` to (*host*, *port*) and accept clients forever.

    Also starts the inactivity-checker thread. Each accepted client is served
    on its own daemon thread, so several clients can be connected and receive
    broadcasts at once. The previous loop served clients one at a time and did
    not call ``accept()`` again until the current client disconnected.

    NOTE(review): the shared connection maps have no lock, and a socket can be
    written to by its own handler thread and by broadcasts at the same time;
    consider a lock if the interleaving proves to be a problem.
    """
    socket_server.bind((host, port))
    socket_server.listen()
    logger.info(f"Socket server running on port {port}")

    # Start the inactivity-checker thread.
    threading.Thread(target=check_inactive_connections, daemon=True).start()

    while True:
        conn, addr = socket_server.accept()
        logger.info(f"Socket client connected: {addr}")
        threading.Thread(
            target=_serve_socket_client, args=(conn, addr), daemon=True
        ).start()
|
||
|
||
if __name__ == '__main__':
    # Create the schema on first run. A sqlite3 connection's context manager
    # only commits or rolls back; it does not close the connection, so close it
    # explicitly (the previous version leaked it).
    conn = get_db_connection()
    try:
        with conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS users
                         (name TEXT PRIMARY KEY,
                          passwd TEXT,
                          avatar TEXT DEFAULT 'default_avatar.png')''')
    finally:
        conn.close()

    # Make sure the avatar directory exists.
    os.makedirs(AVATAR_BASE_DIR, exist_ok=True)

    # Raw-socket protocol runs on a daemon thread; Flask serves HTTP on the main thread.
    threading.Thread(target=run_socket_server, daemon=True).start()
    app.run(port=5001, host='0.0.0.0')