重写了

This commit is contained in:
DZY 2025-06-02 11:52:02 +08:00
parent 4c13e8bc83
commit e900da360d

View File

@ -1,3 +1,4 @@
import threading import threading
import json import json
from flask import Flask, jsonify, request from flask import Flask, jsonify, request
@ -7,10 +8,14 @@ app = Flask(__name__)
socket_server = socket.socket() socket_server = socket.socket()
socket_server.bind(("localhost", 8888)) socket_server.bind(("localhost", 8888))
active_users = {}
chat_connections = []
def get_db_connection(): def get_db_connection():
conn = sqlite3.connect("usr.db") conn = sqlite3.connect("usr.db")
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn
def isuserxist(name): def isuserxist(name):
cn = get_db_connection() cn = get_db_connection()
csr = cn.cursor() csr = cn.cursor()
@ -22,7 +27,6 @@ def isuserxist(name):
else: else:
return False return False
def ispsswdright(name,passwd): def ispsswdright(name,passwd):
cn = get_db_connection() cn = get_db_connection()
csr = cn.cursor() csr = cn.cursor()
@ -36,6 +40,7 @@ def ispsswdright(name,passwd):
return True return True
else: else:
return False return False
@app.route("/api/register", methods=['POST']) @app.route("/api/register", methods=['POST'])
def register(usr = None,pwd = None): def register(usr = None,pwd = None):
conn = get_db_connection() conn = get_db_connection()
@ -56,7 +61,6 @@ def register(usr = None,pwd = None):
try: try:
conn = get_db_connection() conn = get_db_connection()
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("INSERT INTO users (name, passwd) VALUES (?, ?)", (usr, pwd)) cursor.execute("INSERT INTO users (name, passwd) VALUES (?, ?)", (usr, pwd))
conn.commit() conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
@ -72,7 +76,15 @@ def register(usr = None,pwd = None):
}) })
finally: finally:
conn.close() conn.close()
def handle_socket_message(data):
def broadcast_message(message, sender=None):
for conn in chat_connections:
try:
conn.sendall(json.dumps(message).encode())
except:
chat_connections.remove(conn)
def handle_socket_message(data, addr, conn):
try: try:
action = data.get('type') action = data.get('type')
if action == 'register': if action == 'register':
@ -80,10 +92,23 @@ def handle_socket_message(data):
elif action == 'login': elif action == 'login':
if isuserxist(data['username']): if isuserxist(data['username']):
if ispsswdright(data['username'], data['password']): if ispsswdright(data['username'], data['password']):
active_users[addr[0]] = data['username']
chat_connections.append(conn)
return {"status": "success", "message": "Login successful"} return {"status": "success", "message": "Login successful"}
return {"status": "error", "message": "Invalid credentials"} return {"success": "error", "message": "Invalid credentials"}
elif action == 'chat':
if addr[0] in active_users:
message = {
"type": "chat",
"user": active_users[addr[0]],
"message": data['message']
}
broadcast_message(message)
return {"status": "success"}
return {"status": "error", "message": "Not logged in"}
except Exception as e: except Exception as e:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
def run_socket_server(): def run_socket_server():
socket_server.listen() socket_server.listen()
print("Socket server running on port 8888") print("Socket server running on port 8888")
@ -93,20 +118,28 @@ def run_socket_server():
try: try:
while True: while True:
data = conn.recv(1024) data = conn.recv(1024)
if not data: break if not data:
if addr[0] in active_users:
del active_users[addr[0]]
break
try: try:
json_data = json.loads(data.decode()) json_data = json.loads(data.decode())
response = handle_socket_message(json_data) response = handle_socket_message(json_data, addr, conn)
conn.sendall(json.dumps(response).encode()) conn.sendall(json.dumps(response).encode())
except json.JSONDecodeError: except json.JSONDecodeError:
conn.sendall(json.dumps( conn.sendall(json.dumps(
{"success": "error", "message": "Invalid JSON"} {"success": "error", "message": "Invalid JSON"}
).encode()) ).encode())
except ConnectionResetError: except ConnectionResetError:
if addr[0] in active_users:
del active_users[addr[0]]
print(f"Client {addr} disconnected") print(f"Client {addr} disconnected")
finally: finally:
if conn in chat_connections:
chat_connections.remove(conn)
conn.close() conn.close()
if __name__ == '__main__': if __name__ == '__main__':
threading.Thread(target=run_socket_server, daemon=True).start() threading.Thread(target=run_socket_server, daemon=True).start()
with get_db_connection() as conn: with get_db_connection() as conn: