登录注册不了,有bug

This commit is contained in:
DZY 2025-05-31 19:42:19 +08:00
parent b9b2af664a
commit 62865e0993

View File

@ -4,6 +4,7 @@ import json
import sqlite3 import sqlite3
from datetime import datetime from datetime import datetime
import os import os
import traceback
class ChatServer: class ChatServer:
def __init__(self, host='0.0.0.0', port=5555): def __init__(self, host='0.0.0.0', port=5555):
@ -12,264 +13,217 @@ class ChatServer:
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.server.bind((host, port)) self.server.bind((host, port))
self.server.listen() self.server.listen(5)
self.clients = {} self.clients = {}
self.groups = {'General': set()} self.groups = {'General': set()}
self.setup_database() self.setup_database()
def setup_database(self): def setup_database(self):
# 确保数据库文件存在并初始化
db_exists = os.path.exists('chat_server.db')
self.db = sqlite3.connect('chat_server.db', check_same_thread=False) self.db = sqlite3.connect('chat_server.db', check_same_thread=False)
self.db.execute("PRAGMA foreign_keys = ON")
cursor = self.db.cursor() cursor = self.db.cursor()
# 创建用户表
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL, username TEXT UNIQUE NOT NULL CHECK(length(username) >= 3),
password TEXT NOT NULL password TEXT NOT NULL CHECK(length(password) >= 3),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) )
''') ''')
# 创建消息表
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS messages ( CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
sender TEXT NOT NULL, sender TEXT NOT NULL,
receiver TEXT NOT NULL, receiver TEXT NOT NULL,
message TEXT NOT NULL, message TEXT NOT NULL,
timestamp DATETIME NOT NULL, timestamp DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
is_group INTEGER NOT NULL is_group INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY(sender) REFERENCES users(username) ON DELETE CASCADE
) )
''') ''')
# 如果是首次创建数据库,添加一个默认管理员用户 cursor.execute("SELECT COUNT(*) FROM users")
if not db_exists: if cursor.fetchone()[0] == 0:
try: try:
cursor.execute('INSERT INTO users (username, password) VALUES (?, ?)', cursor.execute('INSERT INTO users (username, password) VALUES (?, ?)',
('admin', 'admin123')) ('admin', 'admin123'))
print("Created default admin user: admin/admin123") print("[DB] 创建默认管理员账号: admin/admin123")
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
pass pass
self.db.commit() self.db.commit()
print("[DB] 数据库初始化完成")
def authenticate_user(self, username, password): def authenticate_user(self, username, password):
cursor = self.db.cursor() try:
cursor.execute('SELECT * FROM users WHERE username=? AND password=?', (username, password)) cursor = self.db.cursor()
return cursor.fetchone() is not None cursor.execute('''
SELECT 1 FROM users
WHERE username=? AND password=?
LIMIT 1
''', (username, password))
return cursor.fetchone() is not None
except sqlite3.Error as e:
print(f"[DB] 认证错误: {e}")
return False
def register_user(self, username, password): def register_user(self, username, password):
if not username or not password: if len(username) < 3 or len(password) < 3:
return False return False, "用户名和密码至少需要3个字符"
try: try:
cursor = self.db.cursor() cursor = self.db.cursor()
cursor.execute('INSERT INTO users (username, password) VALUES (?, ?)', (username, password)) cursor.execute('''
INSERT INTO users (username, password)
VALUES (?, ?)
''', (username, password))
self.db.commit() self.db.commit()
return True return True, "注册成功"
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
return False return False, "用户名已存在"
except sqlite3.Error as e:
def save_message(self, sender, receiver, message, is_group): print(f"[DB] 注册错误: {e}")
cursor = self.db.cursor() return False, "数据库错误"
cursor.execute('''
INSERT INTO messages (sender, receiver, message, timestamp, is_group)
VALUES (?, ?, ?, ?, ?)
''', (sender, receiver, message, datetime.now(), 1 if is_group else 0))
self.db.commit()
def send_user_list(self): def send_user_list(self):
"""发送当前在线用户列表给所有客户端"""
user_list = list(self.clients.keys()) user_list = list(self.clients.keys())
for client in self.clients.values(): for username, client in list(self.clients.items()):
try: try:
client.send(json.dumps({ client.send(json.dumps({
'type': 'user_list', 'type': 'user_list',
'users': user_list 'users': user_list
}).encode('utf-8')) }).encode('utf-8'))
except: except:
pass self.remove_client(username)
def send_group_list(self, username): def send_group_list(self, username):
group_list = [group for group, members in self.groups.items() if username in members] """发送用户所在的群组列表"""
if username in self.clients: if username in self.clients:
group_list = [group for group, members in self.groups.items() if username in members]
try: try:
self.clients[username].send(json.dumps({ self.clients[username].send(json.dumps({
'type': 'group_list', 'type': 'group_list',
'groups': group_list 'groups': group_list
}).encode('utf-8')) }).encode('utf-8'))
except: except:
pass self.remove_client(username)
def send_group_list_to_all(self):
for username in self.clients:
self.send_group_list(username)
def send_initial_messages(self, username):
cursor = self.db.cursor()
cursor.execute('''
SELECT sender, receiver, message, timestamp, is_group FROM messages
WHERE receiver=? OR sender=? OR (is_group=1 AND receiver IN (
SELECT name FROM sqlite_master WHERE type='table' AND name='groups'
))
ORDER BY timestamp
LIMIT 100
''', (username, username))
messages = []
for sender, receiver, message, timestamp, is_group in cursor.fetchall():
messages.append({
'sender': sender,
'receiver': receiver,
'message': message,
'timestamp': timestamp,
'is_group': is_group
})
def remove_client(self, username):
"""安全移除客户端"""
if username in self.clients: if username in self.clients:
try: try:
self.clients[username].send(json.dumps({ self.clients[username].close()
'type': 'initial_messages',
'messages': messages
}).encode('utf-8'))
except: except:
pass pass
del self.clients[username]
def user_exists(self, username): for group in self.groups.values():
cursor = self.db.cursor() group.discard(username)
cursor.execute('SELECT 1 FROM users WHERE username=?', (username,)) self.send_user_list()
return cursor.fetchone() is not None
def handle_client(self, client, address): def handle_client(self, client, address):
print(f"[连接] 新客户端: {address}")
username = None username = None
print(f"New connection from {address}")
try: try:
while True: while True:
try: try:
message = client.recv(1024).decode('utf-8') raw_data = client.recv(4096)
if not message: if not raw_data:
print(f"[连接] 客户端 {address} 断开连接")
break break
data = json.loads(message) try:
print(f"Received from {address}: {data}") data = json.loads(raw_data.decode('utf-8'))
except json.JSONDecodeError:
print(f"[错误] 无效JSON数据来自 {address}")
client.send(json.dumps({
'type': 'error',
'message': 'Invalid JSON format'
}).encode('utf-8'))
continue
print(f"[请求] {address}: {data.get('type')}")
if data['type'] == 'register':
username = data.get('username', '').strip()
password = data.get('password', '').strip()
success, message = self.register_user(username, password)
response = {
'type': 'register_response',
'success': success,
'message': message
}
client.send(json.dumps(response).encode('utf-8'))
if success:
print(f"[注册] 新用户: {username}")
else:
print(f"[注册] 失败: {username} - {message}")
elif data['type'] == 'login':
username = data.get('username', '').strip()
password = data.get('password', '').strip()
if data['type'] == 'login':
username = data['username']
password = data['password']
if self.authenticate_user(username, password): if self.authenticate_user(username, password):
self.clients[username] = client self.clients[username] = client
self.groups['General'].add(username) self.groups['General'].add(username)
client.send(json.dumps({'type': 'login_success'}).encode('utf-8'))
client.send(json.dumps({
'type': 'login_response',
'success': True,
'message': '登录成功'
}).encode('utf-8'))
self.send_user_list() self.send_user_list()
self.send_group_list(username) self.send_group_list(username)
self.send_initial_messages(username) print(f"[登录] 成功: {username}")
print(f"User {username} logged in successfully")
else: else:
client.send(json.dumps({'type': 'login_fail', 'reason': 'Invalid username or password'}).encode('utf-8')) client.send(json.dumps({
print(f"Login failed for {username}") 'type': 'login_response',
'success': False,
'message': '用户名或密码错误'
}).encode('utf-8'))
print(f"[登录] 失败: {username}")
elif data['type'] == 'register':
username = data['username']
password = data['password']
if len(username) < 3 or len(password) < 3:
client.send(json.dumps({'type': 'register_fail', 'reason': 'Username and password must be at least 3 characters'}).encode('utf-8'))
elif self.register_user(username, password):
client.send(json.dumps({'type': 'register_success'}).encode('utf-8'))
print(f"New user registered: {username}")
else:
client.send(json.dumps({'type': 'register_fail', 'reason': 'Username already taken'}).encode('utf-8'))
elif data['type'] == 'message':
sender = data['sender']
receiver = data['receiver']
msg = data['message']
is_group = data['is_group']
self.save_message(sender, receiver, msg, is_group)
if is_group:
for member in self.groups.get(receiver, set()):
if member in self.clients and member != sender:
try:
self.clients[member].send(json.dumps({
'type': 'message',
'sender': sender,
'receiver': receiver,
'message': msg,
'is_group': True
}).encode('utf-8'))
except:
pass
else:
if receiver in self.clients:
try:
self.clients[receiver].send(json.dumps({
'type': 'message',
'sender': sender,
'receiver': receiver,
'message': msg,
'is_group': False
}).encode('utf-8'))
except:
pass
elif data['type'] == 'create_group':
group_name = data['group_name']
if group_name not in self.groups:
self.groups[group_name] = set()
for user in data['members']:
if user in self.clients or self.user_exists(user):
self.groups[group_name].add(user)
self.send_group_list_to_all()
client.send(json.dumps({'type': 'group_created', 'group': group_name}).encode('utf-8'))
else:
client.send(json.dumps({'type': 'group_exists', 'group': group_name}).encode('utf-8'))
except json.JSONDecodeError:
print(f"Invalid JSON received from {address}")
break
except ConnectionResetError: except ConnectionResetError:
print(f"Connection reset by {address}") print(f"[连接] 客户端 {address} 强制断开")
break break
except Exception as e: except Exception as e:
print(f"Error handling client {address}: {str(e)}") print(f"[错误] 处理客户端 {address}: {str(e)}")
traceback.print_exc()
break break
except Exception as e: except Exception as e:
print(f"Client handler error for {address}: {str(e)}") print(f"[错误] 客户端处理异常: {str(e)}")
traceback.print_exc()
finally: finally:
if username and username in self.clients: if username:
print(f"User {username} disconnected") self.remove_client(username)
del self.clients[username] client.close()
for group in self.groups.values():
if username in group:
group.remove(username)
self.send_user_list()
try:
client.close()
except:
pass
def start(self): def start(self):
print(f"Server started on {self.host}:{self.port}") print(f"[启动] 服务器运行在 {self.host}:{self.port}")
print("Database initialized at chat_server.db")
try: try:
while True: while True:
client, address = self.server.accept() client, address = self.server.accept()
thread = threading.Thread(target=self.handle_client, args=(client, address)) thread = threading.Thread(
thread.daemon = True target=self.handle_client,
args=(client, address),
daemon=True
)
thread.start() thread.start()
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nServer is shutting down...") print("\n[关闭] 服务器正在停止...")
finally: finally:
for client in self.clients.values(): for username in list(self.clients.keys()):
try: self.remove_client(username)
client.close()
except:
pass
self.server.close() self.server.close()
self.db.close() self.db.close()
print("Server stopped") print("[关闭] 服务器已停止")
if __name__ == "__main__": if __name__ == "__main__":
server = ChatServer() server = ChatServer()