我不到啊
This commit is contained in:
parent
a2cc22baa2
commit
cf611c11c3
238
CS3.1.py
238
CS3.1.py
@ -3,16 +3,18 @@ import json
|
|||||||
from flask import Flask, jsonify, request
|
from flask import Flask, jsonify, request
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import socket
|
import socket
|
||||||
import base64
|
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
socket_server = socket.socket()
|
socket_server = socket.socket()
|
||||||
socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
|
||||||
active_users = {}
|
# 修改数据结构:使用用户名作为主键
|
||||||
chat_connections = []
|
active_connections = {} # {username: {'conn': conn, 'ip': ip}}
|
||||||
tokens = {}
|
chat_connections = [] # 所有活跃连接列表
|
||||||
|
tokens = {} # 令牌管理
|
||||||
|
|
||||||
def get_db_connection():
|
def get_db_connection():
|
||||||
conn = sqlite3.connect("usr.db")
|
conn = sqlite3.connect("usr.db")
|
||||||
@ -25,38 +27,35 @@ def isuserxist(name):
|
|||||||
csr.execute('SELECT * FROM users WHERE name = ?', (name,))
|
csr.execute('SELECT * FROM users WHERE name = ?', (name,))
|
||||||
rst = csr.fetchone()
|
rst = csr.fetchone()
|
||||||
cn.close()
|
cn.close()
|
||||||
if rst is not None:
|
return rst is not None
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def ispsswdright(name,passwd):
|
def ispsswdright(name, passwd):
|
||||||
cn = get_db_connection()
|
cn = get_db_connection()
|
||||||
csr = cn.cursor()
|
csr = cn.cursor()
|
||||||
csr.execute("SELECT COUNT(*) FROM users WHERE name=?", (name,))
|
csr.execute("SELECT passwd FROM users WHERE name = ?", (name,))
|
||||||
row_count = csr.fetchone()[0]
|
result = csr.fetchone()
|
||||||
password = None
|
cn.close()
|
||||||
if row_count > 0:
|
return result and result[0] == passwd
|
||||||
csr.execute("SELECT passwd FROM users WHERE name=?", (name,))
|
|
||||||
password = csr.fetchone()[0]
|
|
||||||
if password == passwd:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def register_user(usr, pwd):
|
def get_avatar(username):
|
||||||
|
cn = get_db_connection()
|
||||||
|
csr = cn.cursor()
|
||||||
|
csr.execute("SELECT avatar FROM users WHERE name = ?", (username,))
|
||||||
|
result = csr.fetchone()
|
||||||
|
cn.close()
|
||||||
|
return result[0] if result else "default_avatar.png"
|
||||||
|
|
||||||
|
def register_user(usr, pwd, avatar="default_avatar.png"):
|
||||||
conn = get_db_connection()
|
conn = get_db_connection()
|
||||||
csr2 = conn.cursor()
|
|
||||||
csr2.execute('SELECT * FROM users WHERE name = ?', (usr,))
|
|
||||||
result = csr2.fetchone()
|
|
||||||
if result is not None:
|
|
||||||
return {"type": "register_0", "success": False, "message": "Username already exists"}
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("INSERT INTO users (name, passwd) VALUES (?, ?)", (usr, pwd))
|
# 添加avatar字段
|
||||||
|
cursor.execute("INSERT INTO users (name, passwd, avatar) VALUES (?, ?, ?)",
|
||||||
|
(usr, pwd, avatar))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return {"type": "register_1", "success": True, "message": "User registered successfully"}
|
return {"type": "register_1", "success": True, "message": "User registered successfully"}
|
||||||
|
except sqlite3.IntegrityError:
|
||||||
|
return {"type": "register_0", "success": False, "message": "Username already exists"}
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
return {"type": "register_0", "success": False, "message": str(e)}
|
return {"type": "register_0", "success": False, "message": str(e)}
|
||||||
finally:
|
finally:
|
||||||
@ -79,7 +78,17 @@ def register1():
|
|||||||
vl = request.get_json()
|
vl = request.get_json()
|
||||||
usr = vl.get('username')
|
usr = vl.get('username')
|
||||||
pwd = vl.get('password')
|
pwd = vl.get('password')
|
||||||
result = register_user(usr, pwd)
|
avatar = vl.get('avatar', 'default_avatar.png') # 获取头像
|
||||||
|
|
||||||
|
# 头像验证
|
||||||
|
if avatar and not (avatar.endswith('.png') or avatar.endswith('.jpg')):
|
||||||
|
return jsonify({
|
||||||
|
"type": "register_0",
|
||||||
|
"success": False,
|
||||||
|
"message": "Invalid avatar format. Only .png or .jpg allowed"
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
result = register_user(usr, pwd, avatar)
|
||||||
if result['success']:
|
if result['success']:
|
||||||
return jsonify(result)
|
return jsonify(result)
|
||||||
else:
|
else:
|
||||||
@ -88,11 +97,35 @@ def register1():
|
|||||||
@app.route("/api/login", methods=['POST'])
|
@app.route("/api/login", methods=['POST'])
|
||||||
def login():
|
def login():
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if isuserxist(data['username']):
|
username = data['username']
|
||||||
if ispsswdright(data['username'], data['password']):
|
|
||||||
token = generate_token(data['username'])
|
# 检查账号是否已登录
|
||||||
return jsonify({"type": "login_1", "status": "success", "token": token})
|
if username in active_connections:
|
||||||
return jsonify({"type": "login_0", "status": "error"}), 401
|
return jsonify({
|
||||||
|
"type": "login_0",
|
||||||
|
"status": "error",
|
||||||
|
"message": "Account already logged in"
|
||||||
|
}), 409 # 冲突状态码
|
||||||
|
|
||||||
|
if isuserxist(username) and ispsswdright(username, data['password']):
|
||||||
|
token = generate_token(username)
|
||||||
|
avatar = get_avatar(username)
|
||||||
|
return jsonify({
|
||||||
|
"type": "login_1",
|
||||||
|
"status": "success",
|
||||||
|
"token": token,
|
||||||
|
"avatar": avatar # 返回头像
|
||||||
|
})
|
||||||
|
return jsonify({
|
||||||
|
"type": "login_0",
|
||||||
|
"status": "error",
|
||||||
|
"message": "Invalid credentials"
|
||||||
|
}), 401
|
||||||
|
|
||||||
|
@app.route("/api/avatar/<username>", methods=['GET'])
|
||||||
|
def get_user_avatar(username):
|
||||||
|
avatar = get_avatar(username)
|
||||||
|
return jsonify({"username": username, "avatar": avatar})
|
||||||
|
|
||||||
@app.route("/api/chat", methods=['POST'])
|
@app.route("/api/chat", methods=['POST'])
|
||||||
def chat():
|
def chat():
|
||||||
@ -100,11 +133,13 @@ def chat():
|
|||||||
username = validate_token(token)
|
username = validate_token(token)
|
||||||
if not username:
|
if not username:
|
||||||
return jsonify({"type": "chat", "status": "error"}), 401
|
return jsonify({"type": "chat", "status": "error"}), 401
|
||||||
|
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
message = {
|
message = {
|
||||||
"type": "chat",
|
"type": "chat",
|
||||||
"user": username,
|
"user": username,
|
||||||
"message": data['message']
|
"message": data['message'],
|
||||||
|
"avatar": get_avatar(username) # 添加头像信息
|
||||||
}
|
}
|
||||||
broadcast_message(message)
|
broadcast_message(message)
|
||||||
return jsonify({"type": "chat", "status": "success"})
|
return jsonify({"type": "chat", "status": "success"})
|
||||||
@ -114,37 +149,100 @@ def broadcast_message(message, sender=None):
|
|||||||
try:
|
try:
|
||||||
conn.sendall(json.dumps(message).encode())
|
conn.sendall(json.dumps(message).encode())
|
||||||
except:
|
except:
|
||||||
|
# 连接异常时移除
|
||||||
|
for uname, info in list(active_connections.items()):
|
||||||
|
if info['conn'] == conn:
|
||||||
|
del active_connections[uname]
|
||||||
|
break
|
||||||
|
if conn in chat_connections:
|
||||||
chat_connections.remove(conn)
|
chat_connections.remove(conn)
|
||||||
|
|
||||||
def handle_socket_message(data, addr, conn):
|
def handle_socket_message(data, addr, conn):
|
||||||
try:
|
try:
|
||||||
action = data.get('type')
|
action = data.get('type')
|
||||||
if action == 'register':
|
if action == 'register':
|
||||||
result = register_user(data.get('username'), data.get('password'))
|
avatar = data.get('avatar', 'default_avatar.png')
|
||||||
if result['success']:
|
result = register_user(data.get('username'), data.get('password'), avatar)
|
||||||
return {"type": "register_1","status": "success", "message": result['message']}
|
conn.sendall(json.dumps(result).encode())
|
||||||
else:
|
return result
|
||||||
return {"type": "register_0","status": "error", "message": result['message']}
|
|
||||||
|
|
||||||
elif action == 'login':
|
elif action == 'login':
|
||||||
if isuserxist(data['username']):
|
username = data['username']
|
||||||
if ispsswdright(data['username'], data['password']):
|
password = data['password']
|
||||||
active_users[addr[0]] = data['username']
|
|
||||||
tk = base64.b64encode(data['username'].encode('utf-8'))
|
# 检查账号是否已登录
|
||||||
|
if username in active_connections:
|
||||||
|
response = {
|
||||||
|
"type": "login_0",
|
||||||
|
"status": "error",
|
||||||
|
"message": "Account already logged in"
|
||||||
|
}
|
||||||
|
conn.sendall(json.dumps(response).encode())
|
||||||
|
return response
|
||||||
|
|
||||||
|
if isuserxist(username) and ispsswdright(username, password):
|
||||||
|
# 移除旧连接(如果存在)
|
||||||
|
if username in active_connections:
|
||||||
|
old_conn = active_connections[username]['conn']
|
||||||
|
if old_conn in chat_connections:
|
||||||
|
chat_connections.remove(old_conn)
|
||||||
|
del active_connections[username]
|
||||||
|
|
||||||
|
# 添加新连接
|
||||||
|
active_connections[username] = {'conn': conn, 'ip': addr[0]}
|
||||||
|
if conn not in chat_connections:
|
||||||
chat_connections.append(conn)
|
chat_connections.append(conn)
|
||||||
return {"type": "login_1", "status": "success", "message": "Login successful", "token": generate_token(data['username']), "username": data['username']}
|
|
||||||
return {"type": "login_0", "status": "error", "message": "Invalid credentials", "username": data['username']}
|
token = generate_token(username)
|
||||||
|
avatar = get_avatar(username)
|
||||||
|
response = {
|
||||||
|
"type": "login_1",
|
||||||
|
"status": "success",
|
||||||
|
"message": "Login successful",
|
||||||
|
"token": token,
|
||||||
|
"username": username,
|
||||||
|
"avatar": avatar # 返回头像
|
||||||
|
}
|
||||||
|
conn.sendall(json.dumps(response).encode())
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
response = {
|
||||||
|
"type": "login_0",
|
||||||
|
"status": "error",
|
||||||
|
"message": "Invalid credentials"
|
||||||
|
}
|
||||||
|
conn.sendall(json.dumps(response).encode())
|
||||||
|
return response
|
||||||
|
|
||||||
elif action == 'chat':
|
elif action == 'chat':
|
||||||
if addr[0] in active_users:
|
username = validate_token(data.get('token', ''))
|
||||||
|
if not username:
|
||||||
|
response = {
|
||||||
|
"type": "chat",
|
||||||
|
"status": "error",
|
||||||
|
"message": "Not authenticated"
|
||||||
|
}
|
||||||
|
conn.sendall(json.dumps(response).encode())
|
||||||
|
return response
|
||||||
|
|
||||||
|
if username not in active_connections:
|
||||||
|
response = {
|
||||||
|
"type": "chat",
|
||||||
|
"status": "error",
|
||||||
|
"message": "Not logged in"
|
||||||
|
}
|
||||||
|
conn.sendall(json.dumps(response).encode())
|
||||||
|
return response
|
||||||
|
|
||||||
message = {
|
message = {
|
||||||
"type": "chat",
|
"type": "chat",
|
||||||
"user": active_users[addr[0]],
|
"user": username,
|
||||||
"message": data['message']
|
"message": data['message'],
|
||||||
|
"avatar": get_avatar(username) # 添加头像
|
||||||
}
|
}
|
||||||
broadcast_message(message)
|
broadcast_message(message)
|
||||||
#return {"type": "chat", "status": "success"}
|
return {"type": "chat", "status": "success"}
|
||||||
else :
|
|
||||||
return {"type": "chat", "status": "error", "message": "Not logged in"}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
@ -159,31 +257,43 @@ def run_socket_server():
|
|||||||
while True:
|
while True:
|
||||||
data = conn.recv(1024)
|
data = conn.recv(1024)
|
||||||
if not data:
|
if not data:
|
||||||
if addr[0] in active_users:
|
|
||||||
del active_users[addr[0]]
|
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
json_data = json.loads(data.decode())
|
json_data = json.loads(data.decode())
|
||||||
response = handle_socket_message(json_data, addr, conn)
|
response = handle_socket_message(json_data, addr, conn)
|
||||||
if response != None:
|
|
||||||
conn.sendall(json.dumps(response).encode())
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
conn.sendall(json.dumps(
|
response = {
|
||||||
{"type": "register_0", "status": "error", "message": "Invalid JSON"}
|
"type": "error",
|
||||||
).encode())
|
"status": "error",
|
||||||
except ConnectionResetError:
|
"message": "Invalid JSON"
|
||||||
if addr[0] in active_users:
|
}
|
||||||
del active_users[addr[0]]
|
conn.sendall(json.dumps(response).encode())
|
||||||
print(f"Client {addr} disconnected")
|
except (ConnectionResetError, BrokenPipeError):
|
||||||
|
print(f"Client {addr} disconnected abruptly")
|
||||||
finally:
|
finally:
|
||||||
|
# 清理断开的连接
|
||||||
|
for username, info in list(active_connections.items()):
|
||||||
|
if info['conn'] == conn:
|
||||||
|
del active_connections[username]
|
||||||
|
print(f"User {username} disconnected")
|
||||||
|
break
|
||||||
|
|
||||||
if conn in chat_connections:
|
if conn in chat_connections:
|
||||||
chat_connections.remove(conn)
|
chat_connections.remove(conn)
|
||||||
|
|
||||||
|
try:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
print(f"Connection closed for {addr}")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
|
# 添加avatar字段
|
||||||
conn.execute('''CREATE TABLE IF NOT EXISTS users
|
conn.execute('''CREATE TABLE IF NOT EXISTS users
|
||||||
(name TEXT, passwd TEXT)''')
|
(name TEXT PRIMARY KEY,
|
||||||
threading.Thread(target=run_socket_server).start()
|
passwd TEXT,
|
||||||
|
avatar TEXT)''')
|
||||||
|
threading.Thread(target=run_socket_server, daemon=True).start()
|
||||||
app.run(port=5001)
|
app.run(port=5001)
|
Loading…
x
Reference in New Issue
Block a user