我不到啊

This commit is contained in:
DZY 2025-06-13 20:19:55 +08:00
parent a2cc22baa2
commit cf611c11c3

238
CS3.1.py
View File

@ -3,16 +3,18 @@ import json
from flask import Flask, jsonify, request from flask import Flask, jsonify, request
import sqlite3 import sqlite3
import socket import socket
import base64
import secrets import secrets
import time import time
import os
app = Flask(__name__) app = Flask(__name__)
socket_server = socket.socket() socket_server = socket.socket()
socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
active_users = {} # 修改数据结构:使用用户名作为主键
chat_connections = [] active_connections = {} # {username: {'conn': conn, 'ip': ip}}
tokens = {} chat_connections = [] # 所有活跃连接列表
tokens = {} # 令牌管理
def get_db_connection(): def get_db_connection():
conn = sqlite3.connect("usr.db") conn = sqlite3.connect("usr.db")
@ -25,38 +27,35 @@ def isuserxist(name):
csr.execute('SELECT * FROM users WHERE name = ?', (name,)) csr.execute('SELECT * FROM users WHERE name = ?', (name,))
rst = csr.fetchone() rst = csr.fetchone()
cn.close() cn.close()
if rst is not None: return rst is not None
return True
else:
return False
def ispsswdright(name,passwd): def ispsswdright(name, passwd):
cn = get_db_connection() cn = get_db_connection()
csr = cn.cursor() csr = cn.cursor()
csr.execute("SELECT COUNT(*) FROM users WHERE name=?", (name,)) csr.execute("SELECT passwd FROM users WHERE name = ?", (name,))
row_count = csr.fetchone()[0] result = csr.fetchone()
password = None cn.close()
if row_count > 0: return result and result[0] == passwd
csr.execute("SELECT passwd FROM users WHERE name=?", (name,))
password = csr.fetchone()[0]
if password == passwd:
return True
else:
return False
def register_user(usr, pwd): def get_avatar(username):
cn = get_db_connection()
csr = cn.cursor()
csr.execute("SELECT avatar FROM users WHERE name = ?", (username,))
result = csr.fetchone()
cn.close()
return result[0] if result else "default_avatar.png"
def register_user(usr, pwd, avatar="default_avatar.png"):
conn = get_db_connection() conn = get_db_connection()
csr2 = conn.cursor()
csr2.execute('SELECT * FROM users WHERE name = ?', (usr,))
result = csr2.fetchone()
if result is not None:
return {"type": "register_0", "success": False, "message": "Username already exists"}
else:
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("INSERT INTO users (name, passwd) VALUES (?, ?)", (usr, pwd)) # 添加avatar字段
cursor.execute("INSERT INTO users (name, passwd, avatar) VALUES (?, ?, ?)",
(usr, pwd, avatar))
conn.commit() conn.commit()
return {"type": "register_1", "success": True, "message": "User registered successfully"} return {"type": "register_1", "success": True, "message": "User registered successfully"}
except sqlite3.IntegrityError:
return {"type": "register_0", "success": False, "message": "Username already exists"}
except sqlite3.Error as e: except sqlite3.Error as e:
return {"type": "register_0", "success": False, "message": str(e)} return {"type": "register_0", "success": False, "message": str(e)}
finally: finally:
@ -79,7 +78,17 @@ def register1():
vl = request.get_json() vl = request.get_json()
usr = vl.get('username') usr = vl.get('username')
pwd = vl.get('password') pwd = vl.get('password')
result = register_user(usr, pwd) avatar = vl.get('avatar', 'default_avatar.png') # 获取头像
# 头像验证
if avatar and not (avatar.endswith('.png') or avatar.endswith('.jpg')):
return jsonify({
"type": "register_0",
"success": False,
"message": "Invalid avatar format. Only .png or .jpg allowed"
}), 400
result = register_user(usr, pwd, avatar)
if result['success']: if result['success']:
return jsonify(result) return jsonify(result)
else: else:
@ -88,11 +97,35 @@ def register1():
@app.route("/api/login", methods=['POST']) @app.route("/api/login", methods=['POST'])
def login(): def login():
data = request.get_json() data = request.get_json()
if isuserxist(data['username']): username = data['username']
if ispsswdright(data['username'], data['password']):
token = generate_token(data['username']) # 检查账号是否已登录
return jsonify({"type": "login_1", "status": "success", "token": token}) if username in active_connections:
return jsonify({"type": "login_0", "status": "error"}), 401 return jsonify({
"type": "login_0",
"status": "error",
"message": "Account already logged in"
}), 409 # 冲突状态码
if isuserxist(username) and ispsswdright(username, data['password']):
token = generate_token(username)
avatar = get_avatar(username)
return jsonify({
"type": "login_1",
"status": "success",
"token": token,
"avatar": avatar # 返回头像
})
return jsonify({
"type": "login_0",
"status": "error",
"message": "Invalid credentials"
}), 401
@app.route("/api/avatar/<username>", methods=['GET'])
def get_user_avatar(username):
avatar = get_avatar(username)
return jsonify({"username": username, "avatar": avatar})
@app.route("/api/chat", methods=['POST']) @app.route("/api/chat", methods=['POST'])
def chat(): def chat():
@ -100,11 +133,13 @@ def chat():
username = validate_token(token) username = validate_token(token)
if not username: if not username:
return jsonify({"type": "chat", "status": "error"}), 401 return jsonify({"type": "chat", "status": "error"}), 401
data = request.get_json() data = request.get_json()
message = { message = {
"type": "chat", "type": "chat",
"user": username, "user": username,
"message": data['message'] "message": data['message'],
"avatar": get_avatar(username) # 添加头像信息
} }
broadcast_message(message) broadcast_message(message)
return jsonify({"type": "chat", "status": "success"}) return jsonify({"type": "chat", "status": "success"})
@ -114,37 +149,100 @@ def broadcast_message(message, sender=None):
try: try:
conn.sendall(json.dumps(message).encode()) conn.sendall(json.dumps(message).encode())
except: except:
# 连接异常时移除
for uname, info in list(active_connections.items()):
if info['conn'] == conn:
del active_connections[uname]
break
if conn in chat_connections:
chat_connections.remove(conn) chat_connections.remove(conn)
def handle_socket_message(data, addr, conn): def handle_socket_message(data, addr, conn):
try: try:
action = data.get('type') action = data.get('type')
if action == 'register': if action == 'register':
result = register_user(data.get('username'), data.get('password')) avatar = data.get('avatar', 'default_avatar.png')
if result['success']: result = register_user(data.get('username'), data.get('password'), avatar)
return {"type": "register_1","status": "success", "message": result['message']} conn.sendall(json.dumps(result).encode())
else: return result
return {"type": "register_0","status": "error", "message": result['message']}
elif action == 'login': elif action == 'login':
if isuserxist(data['username']): username = data['username']
if ispsswdright(data['username'], data['password']): password = data['password']
active_users[addr[0]] = data['username']
tk = base64.b64encode(data['username'].encode('utf-8')) # 检查账号是否已登录
if username in active_connections:
response = {
"type": "login_0",
"status": "error",
"message": "Account already logged in"
}
conn.sendall(json.dumps(response).encode())
return response
if isuserxist(username) and ispsswdright(username, password):
# 移除旧连接(如果存在)
if username in active_connections:
old_conn = active_connections[username]['conn']
if old_conn in chat_connections:
chat_connections.remove(old_conn)
del active_connections[username]
# 添加新连接
active_connections[username] = {'conn': conn, 'ip': addr[0]}
if conn not in chat_connections:
chat_connections.append(conn) chat_connections.append(conn)
return {"type": "login_1", "status": "success", "message": "Login successful", "token": generate_token(data['username']), "username": data['username']}
return {"type": "login_0", "status": "error", "message": "Invalid credentials", "username": data['username']} token = generate_token(username)
avatar = get_avatar(username)
response = {
"type": "login_1",
"status": "success",
"message": "Login successful",
"token": token,
"username": username,
"avatar": avatar # 返回头像
}
conn.sendall(json.dumps(response).encode())
return response
else:
response = {
"type": "login_0",
"status": "error",
"message": "Invalid credentials"
}
conn.sendall(json.dumps(response).encode())
return response
elif action == 'chat': elif action == 'chat':
if addr[0] in active_users: username = validate_token(data.get('token', ''))
if not username:
response = {
"type": "chat",
"status": "error",
"message": "Not authenticated"
}
conn.sendall(json.dumps(response).encode())
return response
if username not in active_connections:
response = {
"type": "chat",
"status": "error",
"message": "Not logged in"
}
conn.sendall(json.dumps(response).encode())
return response
message = { message = {
"type": "chat", "type": "chat",
"user": active_users[addr[0]], "user": username,
"message": data['message'] "message": data['message'],
"avatar": get_avatar(username) # 添加头像
} }
broadcast_message(message) broadcast_message(message)
#return {"type": "chat", "status": "success"} return {"type": "chat", "status": "success"}
else :
return {"type": "chat", "status": "error", "message": "Not logged in"}
except Exception as e: except Exception as e:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@ -159,31 +257,43 @@ def run_socket_server():
while True: while True:
data = conn.recv(1024) data = conn.recv(1024)
if not data: if not data:
if addr[0] in active_users:
del active_users[addr[0]]
break break
try: try:
json_data = json.loads(data.decode()) json_data = json.loads(data.decode())
response = handle_socket_message(json_data, addr, conn) response = handle_socket_message(json_data, addr, conn)
if response != None:
conn.sendall(json.dumps(response).encode())
except json.JSONDecodeError: except json.JSONDecodeError:
conn.sendall(json.dumps( response = {
{"type": "register_0", "status": "error", "message": "Invalid JSON"} "type": "error",
).encode()) "status": "error",
except ConnectionResetError: "message": "Invalid JSON"
if addr[0] in active_users: }
del active_users[addr[0]] conn.sendall(json.dumps(response).encode())
print(f"Client {addr} disconnected") except (ConnectionResetError, BrokenPipeError):
print(f"Client {addr} disconnected abruptly")
finally: finally:
# 清理断开的连接
for username, info in list(active_connections.items()):
if info['conn'] == conn:
del active_connections[username]
print(f"User {username} disconnected")
break
if conn in chat_connections: if conn in chat_connections:
chat_connections.remove(conn) chat_connections.remove(conn)
try:
conn.close() conn.close()
except:
pass
print(f"Connection closed for {addr}")
if __name__ == '__main__': if __name__ == '__main__':
with get_db_connection() as conn: with get_db_connection() as conn:
# 添加avatar字段
conn.execute('''CREATE TABLE IF NOT EXISTS users conn.execute('''CREATE TABLE IF NOT EXISTS users
(name TEXT, passwd TEXT)''') (name TEXT PRIMARY KEY,
threading.Thread(target=run_socket_server).start() passwd TEXT,
avatar TEXT)''')
threading.Thread(target=run_socket_server, daemon=True).start()
app.run(port=5001) app.run(port=5001)