# API/CS3.1.py
# 2025-06-13 20:32:24 +08:00
#
# 326 lines, 11 KiB, Python
#
# Note from the code-hosting page: this file contains Unicode characters that
# might be confused with other characters. If that is intentional, the warning
# can be safely ignored.

import threading
import json
from flask import Flask, jsonify, request
import sqlite3
import socket
import secrets
import time
import os
# Flask application object; the HTTP routes below are registered on it.
app = Flask(__name__)
# Socket used by the raw-socket protocol; it is bound and put into listening
# mode in run_socket_server().
socket_server = socket.socket()
# SO_REUSEADDR is enabled before bind() — presumably so the server can be
# restarted quickly on the same port; confirm with the deployment setup.
socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# Revised data structure: the username is used as the primary key.
active_connections = {} # {username: {'conn': conn, 'ip': ip}}
chat_connections = [] # list of all active connections (broadcast targets)
tokens = {} # token management: {token: {'username': ..., 'timestamp': ...}}
def get_db_connection(db_path="usr.db"):
    """Open a SQLite connection whose rows can be read by column name.

    Args:
        db_path: Database to open. Defaults to ``"usr.db"``, the file every
            existing caller uses; pass another path (or ``":memory:"``) to
            work against a different database.

    Returns:
        A new ``sqlite3.Connection`` with ``row_factory`` set to
        ``sqlite3.Row``. The caller is responsible for closing it.
    """
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn
def isuserxist(name):
    """Return True when the ``users`` table has a row whose name is *name*.

    Args:
        name: Username to look up.

    Returns:
        ``True`` if a matching row exists, ``False`` otherwise.

    Raises:
        sqlite3.Error: If the query fails; the connection is still closed.
    """
    cn = get_db_connection()
    try:
        # Only existence matters, so avoid fetching the whole row.
        row = cn.execute('SELECT 1 FROM users WHERE name = ?', (name,)).fetchone()
    finally:
        # Close even when execute() raises so the connection is not leaked.
        cn.close()
    return row is not None
def ispsswdright(name, passwd):
    """Check whether *passwd* matches the password stored for user *name*.

    Args:
        name: Username to look up.
        passwd: Password supplied by the client.

    Returns:
        ``True`` if the user exists and the stored value equals *passwd*,
        ``False`` otherwise (including when the user does not exist).

    Raises:
        sqlite3.Error: If the query fails; the connection is still closed.
    """
    cn = get_db_connection()
    try:
        row = cn.execute("SELECT passwd FROM users WHERE name = ?", (name,)).fetchone()
    finally:
        # Close even when execute() raises so the connection is not leaked.
        cn.close()
    # NOTE(review): the stored value is compared directly with the supplied
    # password, i.e. passwords are kept in plain text. Switching to salted
    # hashes needs a coordinated change with register_user and a migration.
    return row is not None and row[0] == passwd
def get_avatar(username):
    """Return the avatar file name stored for *username*.

    Args:
        username: Username to look up.

    Returns:
        The stored avatar value, or ``"default_avatar.png"`` when the user
        does not exist or the stored avatar is NULL/empty.

    Raises:
        sqlite3.Error: If the query fails; the connection is still closed.
    """
    cn = get_db_connection()
    try:
        row = cn.execute("SELECT avatar FROM users WHERE name = ?", (username,)).fetchone()
    finally:
        # Close even when execute() raises so the connection is not leaked.
        cn.close()
    # A row with a NULL/empty avatar must not leak None to API consumers.
    if row is None or not row[0]:
        return "default_avatar.png"
    return row[0]
def register_user(usr, pwd, avatar="default_avatar.png"):
    """Insert a new user row and report the outcome as a response dict.

    Args:
        usr: Username (primary key of the ``users`` table).
        pwd: Password, stored as given.
        avatar: Avatar file name; a falsy value falls back to
            ``"default_avatar.png"``.

    Returns:
        A dict with ``type`` (``"register_1"`` on success, ``"register_0"``
        on failure), ``success`` (bool) and ``message`` (str). Database
        errors are reported through the dict rather than raised.
    """
    # SQLite accepts NULL in a TEXT PRIMARY KEY column, so missing credentials
    # would otherwise be stored as an unusable account; reject them up front.
    if usr in (None, "") or pwd in (None, ""):
        return {"type": "register_0", "success": False, "message": "Username and password are required"}
    # The column default only applies when the column is omitted from the
    # INSERT, so an explicit None must be replaced here.
    avatar = avatar or "default_avatar.png"
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("INSERT INTO users (name, passwd, avatar) VALUES (?, ?, ?)",
                       (usr, pwd, avatar))
        conn.commit()
        return {"type": "register_1", "success": True, "message": "User registered successfully"}
    except sqlite3.IntegrityError:
        # Primary-key violation: the name is already taken.
        return {"type": "register_0", "success": False, "message": "Username already exists"}
    except sqlite3.Error as e:
        return {"type": "register_0", "success": False, "message": str(e)}
    finally:
        conn.close()
def generate_token(username):
    """Create a random session token for *username* and register it.

    The token is 16 random bytes rendered as hex; it is stored in the
    module-level ``tokens`` mapping together with the username and the
    current time.

    Returns:
        The newly created token string.
    """
    fresh_token = secrets.token_hex(16)
    record = {'username': username, 'timestamp': time.time()}
    tokens[fresh_token] = record
    return fresh_token
def validate_token(token):
    """Return the username bound to *token*, refreshing its idle timer.

    A token is valid while fewer than 3600 seconds have passed since it was
    issued or last validated. Expired tokens are removed from ``tokens`` so
    the registry does not keep them forever.

    Args:
        token: Token string supplied by the client. Any non-string value
            (``None`` from a missing header, a list from a JSON body, ...)
            is treated as invalid.

    Returns:
        The username for a valid token, otherwise ``None``.
    """
    # Non-strings can never be issued tokens; rejecting them here also avoids
    # a TypeError from looking up an unhashable value in the dict.
    if not isinstance(token, str):
        return None
    entry = tokens.get(token)
    if entry is None:
        return None
    now = time.time()
    if now - entry['timestamp'] >= 3600:
        # Expired: evict it instead of leaving it in memory indefinitely.
        tokens.pop(token, None)
        return None
    entry['timestamp'] = now
    return entry['username']
@app.route("/api/register", methods=['POST'])
def register1():
    """HTTP endpoint that registers a new user from a JSON body.

    Expected body: ``{"username": str, "password": str, "avatar": str?}``.

    Returns:
        * 200 with the ``register_user`` result on success.
        * 400 when the body is not a JSON object, credentials are missing,
          or the avatar is not a ``.png``/``.jpg`` file name.
        * 403 when the username already exists, 500 for other failures.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so the failure can be reported in this API's own format.
    vl = request.get_json(silent=True)
    if not isinstance(vl, dict):
        return jsonify({
            "type": "register_0",
            "success": False,
            "message": "Request body must be a JSON object"
        }), 400
    usr = vl.get('username')
    pwd = vl.get('password')
    if not isinstance(usr, str) or not usr or not isinstance(pwd, str) or not pwd:
        return jsonify({
            "type": "register_0",
            "success": False,
            "message": "Username and password are required"
        }), 400
    # A missing, null or empty avatar falls back to the default image.
    avatar = vl.get('avatar') or 'default_avatar.png'
    # Avatar validation: must be a string ending in an allowed extension.
    if not isinstance(avatar, str) or not avatar.endswith(('.png', '.jpg')):
        return jsonify({
            "type": "register_0",
            "success": False,
            "message": "Invalid avatar format. Only .png or .jpg allowed"
        }), 400
    result = register_user(usr, pwd, avatar)
    if result['success']:
        return jsonify(result)
    status = 403 if result['message'] == "Username already exists" else 500
    return jsonify(result), status
@app.route("/api/login", methods=['POST'])
def login():
    """HTTP endpoint that authenticates a user and issues a session token.

    Expected body: ``{"username": str, "password": str}``.

    Returns:
        * 200 with ``token`` and ``avatar`` on success.
        * 400 when the body is not a JSON object with string credentials.
        * 401 when the credentials are wrong.
        * 409 when the (correctly authenticated) account already has an
          active socket connection.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get('username')
    password = data.get('password')
    if not isinstance(username, str) or not isinstance(password, str):
        return jsonify({
            "type": "login_0",
            "status": "error",
            "message": "Username and password are required"
        }), 400
    # Verify credentials first so that unauthenticated callers cannot learn
    # which accounts are currently online. ispsswdright() is already false
    # for unknown users, so no separate existence query is needed.
    if not ispsswdright(username, password):
        return jsonify({
            "type": "login_0",
            "status": "error",
            "message": "Invalid credentials"
        }), 401
    # Check whether the account is already logged in.
    if username in active_connections:
        return jsonify({
            "type": "login_0",
            "status": "error",
            "message": "Account already logged in"
        }), 409  # Conflict
    token = generate_token(username)
    avatar = get_avatar(username)
    return jsonify({
        "type": "login_1",
        "status": "success",
        "token": token,
        "avatar": avatar  # return the avatar with the login result
    })
@app.route("/api/avatar/<username>", methods=['GET'])
def get_user_avatar(username):
    """HTTP endpoint returning the avatar recorded for *username* as JSON."""
    payload = {"username": username, "avatar": get_avatar(username)}
    return jsonify(payload)
@app.route("/api/chat", methods=['POST'])
def chat():
    """HTTP endpoint that broadcasts a chat message to all socket clients.

    The session token is read from the ``Authorization`` header and the text
    from the ``message`` field of the JSON body.

    Returns:
        * 200 ``{"type": "chat", "status": "success"}`` after broadcasting.
        * 401 when the token is missing, unknown or expired.
        * 400 when the body is not a JSON object containing ``message``.
    """
    token = request.headers.get('Authorization')
    username = validate_token(token)
    if not username:
        return jsonify({"type": "chat", "status": "error"}), 401
    # silent=True: report malformed bodies in this API's own JSON format
    # instead of failing with an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'message' not in data:
        return jsonify({
            "type": "chat",
            "status": "error",
            "message": "Missing message"
        }), 400
    message = {
        "type": "chat",
        "user": username,
        "message": data['message'],
        "avatar": get_avatar(username)  # include the sender's avatar
    }
    broadcast_message(message)
    return jsonify({"type": "chat", "status": "success"})
def broadcast_message(message, sender=None):
    """Send *message* as UTF-8 JSON to every connection in ``chat_connections``.

    Connections whose send fails are dropped from both ``chat_connections``
    and ``active_connections``.

    Args:
        message: JSON-serialisable object to broadcast.
        sender: Accepted for backward compatibility; currently unused.

    Raises:
        TypeError: If *message* cannot be serialised to JSON.
    """
    # Serialise once; the payload is identical for every recipient.
    payload = json.dumps(message).encode('utf-8')
    # Iterate over a snapshot: removing from the list being iterated would
    # make the loop skip the connection that follows a failed one.
    for conn in list(chat_connections):
        try:
            conn.sendall(payload)
        except OSError:
            # The connection is broken: forget every username bound to it.
            for uname, info in list(active_connections.items()):
                if info['conn'] is conn:
                    active_connections.pop(uname, None)
            if conn in chat_connections:
                chat_connections.remove(conn)
def _send_json(conn, payload):
    """Send *payload* to *conn* as UTF-8 encoded JSON and return *payload*."""
    conn.sendall(json.dumps(payload).encode('utf-8'))
    return payload


def _socket_register(data, conn):
    """Handle a ``register`` message: create the user and send the result."""
    avatar = data.get('avatar', 'default_avatar.png')
    result = register_user(data.get('username'), data.get('password'), avatar)
    return _send_json(conn, result)


def _socket_login(data, addr, conn):
    """Handle a ``login`` message: verify credentials and bind *conn* to the user."""
    username = data.get('username')
    password = data.get('password')
    if not isinstance(username, str) or not isinstance(password, str):
        return _send_json(conn, {
            "type": "login_0",
            "status": "error",
            "message": "Missing username or password"
        })
    # Verify credentials before revealing whether the account is online.
    # ispsswdright() is already false for unknown users.
    if not ispsswdright(username, password):
        return _send_json(conn, {
            "type": "login_0",
            "status": "error",
            "message": "Invalid credentials"
        })
    # Check whether the account is already logged in.
    if username in active_connections:
        return _send_json(conn, {
            "type": "login_0",
            "status": "error",
            "message": "Account already logged in"
        })
    # A connection represents one user: drop any earlier binding of this
    # connection so the previous account is not left marked as logged in.
    for other, info in list(active_connections.items()):
        if info['conn'] is conn:
            active_connections.pop(other, None)
    # Register the new connection.
    active_connections[username] = {'conn': conn, 'ip': addr[0]}
    if conn not in chat_connections:
        chat_connections.append(conn)
    token = generate_token(username)
    avatar = get_avatar(username)
    return _send_json(conn, {
        "type": "login_1",
        "status": "success",
        "message": "Login successful",
        "token": token,
        "username": username,
        "avatar": avatar  # return the avatar with the login result
    })


def _socket_chat(data, conn):
    """Handle a ``chat`` message: authenticate the token and broadcast the text."""
    token = data.get('token')
    if not token:
        return _send_json(conn, {
            "type": "chat",
            "status": "error",
            "message": "Missing token"
        })
    username = validate_token(token)
    if not username:
        return _send_json(conn, {
            "type": "chat",
            "status": "error",
            "message": "Invalid token"
        })
    if username not in active_connections:
        return _send_json(conn, {
            "type": "chat",
            "status": "error",
            "message": "Not logged in"
        })
    if 'message' not in data:
        return _send_json(conn, {
            "type": "chat",
            "status": "error",
            "message": "Missing message"
        })
    broadcast_message({
        "type": "chat",
        "user": username,
        "message": data['message'],
        "avatar": get_avatar(username)  # include the sender's avatar
    })
    # Acknowledge the message to the sender.
    return _send_json(conn, {"type": "chat", "status": "success"})


def handle_socket_message(data, addr, conn):
    """Dispatch one decoded socket message and send the reply on *conn*.

    Args:
        data: Decoded JSON payload; expected to be a dict whose ``type`` is
            ``"register"``, ``"login"`` or ``"chat"``.
        addr: Peer address tuple; ``addr[0]`` is recorded as the client IP.
        conn: Connected socket-like object providing ``sendall``.

    Returns:
        The response dict that was sent to the client. Every message gets a
        reply, including unknown types and non-object payloads.

    Raises:
        OSError: If sending on *conn* fails; the caller owns connection
            cleanup, so transport errors are not swallowed here.
    """
    try:
        if not isinstance(data, dict):
            return _send_json(conn, {
                "type": "error",
                "status": "error",
                "message": "Message must be a JSON object"
            })
        action = data.get('type')
        if action == 'register':
            return _socket_register(data, conn)
        if action == 'login':
            return _socket_login(data, addr, conn)
        if action == 'chat':
            return _socket_chat(data, conn)
        # Unknown types previously got no reply at all, leaving the client
        # waiting; answer explicitly instead.
        return _send_json(conn, {
            "type": "error",
            "status": "error",
            "message": "Unknown message type"
        })
    except OSError:
        # The connection itself failed; replying on it would fail again.
        raise
    except Exception as e:
        return _send_json(conn, {
            "type": "error",
            "status": "error",
            "message": str(e)
        })
def _release_socket_client(conn, addr):
    """Forget every username bound to *conn*, drop it from broadcasts, close it."""
    # No early exit: remove all bindings of this connection, not just the first.
    for username, info in list(active_connections.items()):
        if info['conn'] is conn:
            active_connections.pop(username, None)
            print(f"User {username} disconnected")
    if conn in chat_connections:
        chat_connections.remove(conn)
    try:
        conn.close()
    except OSError:
        pass  # already closed or broken; nothing left to release
    print(f"Connection closed for {addr}")


def _serve_socket_client(conn, addr):
    """Read JSON messages from one client until it disconnects, then clean up."""
    try:
        while True:
            # NOTE(review): each recv() result is parsed as one complete JSON
            # document; messages larger than 1024 bytes or split/merged by
            # TCP will be reported as invalid JSON. Fixing this needs a
            # framing convention agreed with the clients.
            data = conn.recv(1024)
            if not data:
                break
            try:
                # Decode as UTF-8, ignoring undecodable bytes.
                decoded_data = data.decode('utf-8', errors='ignore')
                json_data = json.loads(decoded_data)
                handle_socket_message(json_data, addr, conn)
            except json.JSONDecodeError:
                response = {
                    "type": "error",
                    "status": "error",
                    "message": "Invalid JSON"
                }
                conn.sendall(json.dumps(response).encode('utf-8'))
            except OSError:
                # Transport failure: replying would fail too; go to cleanup.
                raise
            except Exception as e:
                response = {
                    "type": "error",
                    "status": "error",
                    "message": str(e)
                }
                conn.sendall(json.dumps(response).encode('utf-8'))
    except OSError:
        # Covers ConnectionResetError/BrokenPipeError and every other socket
        # failure, so one bad client cannot take the server down.
        print(f"Client {addr} disconnected abruptly")
    finally:
        # Clean up the disconnected client.
        _release_socket_client(conn, addr)


def run_socket_server():
    """Bind the module-level socket to localhost:8889 and serve clients forever.

    Each accepted connection is served on its own daemon thread. Handling
    clients inline would block ``accept()`` until the current client
    disconnects, so only one client could ever be connected and broadcasts
    could never reach anyone else.

    This function does not return; run it on a background thread.
    """
    socket_server.bind(("localhost", 8889))
    socket_server.listen()
    print("Socket server running on port 8889")
    while True:
        conn, addr = socket_server.accept()
        print(f"Socket client connected: {addr}")
        threading.Thread(
            target=_serve_socket_client,
            args=(conn, addr),
            daemon=True,
        ).start()
if __name__ == '__main__':
    # Make sure the users table (including the avatar column) exists before
    # any request is served.
    conn = get_db_connection()
    try:
        # The connection's context manager only commits or rolls back the
        # transaction; it does not close the connection, hence the finally.
        with conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS users
(name TEXT PRIMARY KEY,
passwd TEXT,
avatar TEXT DEFAULT 'default_avatar.png')''')
    finally:
        conn.close()
    # The socket server runs on a daemon thread so it stops with the process;
    # the Flask development server occupies the main thread.
    threading.Thread(target=run_socket_server, daemon=True).start()
    app.run(port=5001)