登录注册不了,有bug
This commit is contained in:
parent
b9b2af664a
commit
62865e0993
278
chatserver3.0.py
278
chatserver3.0.py
@ -4,6 +4,7 @@ import json
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import os
|
import os
|
||||||
|
import traceback
|
||||||
|
|
||||||
class ChatServer:
|
class ChatServer:
|
||||||
def __init__(self, host='0.0.0.0', port=5555):
|
def __init__(self, host='0.0.0.0', port=5555):
|
||||||
@ -12,264 +13,217 @@ class ChatServer:
|
|||||||
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
self.server.bind((host, port))
|
self.server.bind((host, port))
|
||||||
self.server.listen()
|
self.server.listen(5)
|
||||||
self.clients = {}
|
self.clients = {}
|
||||||
self.groups = {'General': set()}
|
self.groups = {'General': set()}
|
||||||
self.setup_database()
|
self.setup_database()
|
||||||
|
|
||||||
def setup_database(self):
|
def setup_database(self):
|
||||||
# 确保数据库文件存在并初始化
|
|
||||||
db_exists = os.path.exists('chat_server.db')
|
|
||||||
self.db = sqlite3.connect('chat_server.db', check_same_thread=False)
|
self.db = sqlite3.connect('chat_server.db', check_same_thread=False)
|
||||||
|
self.db.execute("PRAGMA foreign_keys = ON")
|
||||||
cursor = self.db.cursor()
|
cursor = self.db.cursor()
|
||||||
|
|
||||||
# 创建用户表
|
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
username TEXT UNIQUE NOT NULL,
|
username TEXT UNIQUE NOT NULL CHECK(length(username) >= 3),
|
||||||
password TEXT NOT NULL
|
password TEXT NOT NULL CHECK(length(password) >= 3),
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
)
|
)
|
||||||
''')
|
''')
|
||||||
|
|
||||||
# 创建消息表
|
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
sender TEXT NOT NULL,
|
sender TEXT NOT NULL,
|
||||||
receiver TEXT NOT NULL,
|
receiver TEXT NOT NULL,
|
||||||
message TEXT NOT NULL,
|
message TEXT NOT NULL,
|
||||||
timestamp DATETIME NOT NULL,
|
timestamp DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
is_group INTEGER NOT NULL
|
is_group INTEGER NOT NULL DEFAULT 0,
|
||||||
|
FOREIGN KEY(sender) REFERENCES users(username) ON DELETE CASCADE
|
||||||
)
|
)
|
||||||
''')
|
''')
|
||||||
|
|
||||||
# 如果是首次创建数据库,添加一个默认管理员用户
|
cursor.execute("SELECT COUNT(*) FROM users")
|
||||||
if not db_exists:
|
if cursor.fetchone()[0] == 0:
|
||||||
try:
|
try:
|
||||||
cursor.execute('INSERT INTO users (username, password) VALUES (?, ?)',
|
cursor.execute('INSERT INTO users (username, password) VALUES (?, ?)',
|
||||||
('admin', 'admin123'))
|
('admin', 'admin123'))
|
||||||
print("Created default admin user: admin/admin123")
|
print("[DB] 创建默认管理员账号: admin/admin123")
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
print("[DB] 数据库初始化完成")
|
||||||
|
|
||||||
def authenticate_user(self, username, password):
|
def authenticate_user(self, username, password):
|
||||||
cursor = self.db.cursor()
|
try:
|
||||||
cursor.execute('SELECT * FROM users WHERE username=? AND password=?', (username, password))
|
cursor = self.db.cursor()
|
||||||
return cursor.fetchone() is not None
|
cursor.execute('''
|
||||||
|
SELECT 1 FROM users
|
||||||
|
WHERE username=? AND password=?
|
||||||
|
LIMIT 1
|
||||||
|
''', (username, password))
|
||||||
|
return cursor.fetchone() is not None
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
print(f"[DB] 认证错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
def register_user(self, username, password):
|
def register_user(self, username, password):
|
||||||
if not username or not password:
|
if len(username) < 3 or len(password) < 3:
|
||||||
return False
|
return False, "用户名和密码至少需要3个字符"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cursor = self.db.cursor()
|
cursor = self.db.cursor()
|
||||||
cursor.execute('INSERT INTO users (username, password) VALUES (?, ?)', (username, password))
|
cursor.execute('''
|
||||||
|
INSERT INTO users (username, password)
|
||||||
|
VALUES (?, ?)
|
||||||
|
''', (username, password))
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return True
|
return True, "注册成功"
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
return False
|
return False, "用户名已存在"
|
||||||
|
except sqlite3.Error as e:
|
||||||
def save_message(self, sender, receiver, message, is_group):
|
print(f"[DB] 注册错误: {e}")
|
||||||
cursor = self.db.cursor()
|
return False, "数据库错误"
|
||||||
cursor.execute('''
|
|
||||||
INSERT INTO messages (sender, receiver, message, timestamp, is_group)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
''', (sender, receiver, message, datetime.now(), 1 if is_group else 0))
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
def send_user_list(self):
|
def send_user_list(self):
|
||||||
|
"""发送当前在线用户列表给所有客户端"""
|
||||||
user_list = list(self.clients.keys())
|
user_list = list(self.clients.keys())
|
||||||
for client in self.clients.values():
|
for username, client in list(self.clients.items()):
|
||||||
try:
|
try:
|
||||||
client.send(json.dumps({
|
client.send(json.dumps({
|
||||||
'type': 'user_list',
|
'type': 'user_list',
|
||||||
'users': user_list
|
'users': user_list
|
||||||
}).encode('utf-8'))
|
}).encode('utf-8'))
|
||||||
except:
|
except:
|
||||||
pass
|
self.remove_client(username)
|
||||||
|
|
||||||
def send_group_list(self, username):
|
def send_group_list(self, username):
|
||||||
group_list = [group for group, members in self.groups.items() if username in members]
|
"""发送用户所在的群组列表"""
|
||||||
if username in self.clients:
|
if username in self.clients:
|
||||||
|
group_list = [group for group, members in self.groups.items() if username in members]
|
||||||
try:
|
try:
|
||||||
self.clients[username].send(json.dumps({
|
self.clients[username].send(json.dumps({
|
||||||
'type': 'group_list',
|
'type': 'group_list',
|
||||||
'groups': group_list
|
'groups': group_list
|
||||||
}).encode('utf-8'))
|
}).encode('utf-8'))
|
||||||
except:
|
except:
|
||||||
pass
|
self.remove_client(username)
|
||||||
|
|
||||||
def send_group_list_to_all(self):
|
|
||||||
for username in self.clients:
|
|
||||||
self.send_group_list(username)
|
|
||||||
|
|
||||||
def send_initial_messages(self, username):
|
|
||||||
cursor = self.db.cursor()
|
|
||||||
cursor.execute('''
|
|
||||||
SELECT sender, receiver, message, timestamp, is_group FROM messages
|
|
||||||
WHERE receiver=? OR sender=? OR (is_group=1 AND receiver IN (
|
|
||||||
SELECT name FROM sqlite_master WHERE type='table' AND name='groups'
|
|
||||||
))
|
|
||||||
ORDER BY timestamp
|
|
||||||
LIMIT 100
|
|
||||||
''', (username, username))
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
for sender, receiver, message, timestamp, is_group in cursor.fetchall():
|
|
||||||
messages.append({
|
|
||||||
'sender': sender,
|
|
||||||
'receiver': receiver,
|
|
||||||
'message': message,
|
|
||||||
'timestamp': timestamp,
|
|
||||||
'is_group': is_group
|
|
||||||
})
|
|
||||||
|
|
||||||
|
def remove_client(self, username):
|
||||||
|
"""安全移除客户端"""
|
||||||
if username in self.clients:
|
if username in self.clients:
|
||||||
try:
|
try:
|
||||||
self.clients[username].send(json.dumps({
|
self.clients[username].close()
|
||||||
'type': 'initial_messages',
|
|
||||||
'messages': messages
|
|
||||||
}).encode('utf-8'))
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
del self.clients[username]
|
||||||
def user_exists(self, username):
|
for group in self.groups.values():
|
||||||
cursor = self.db.cursor()
|
group.discard(username)
|
||||||
cursor.execute('SELECT 1 FROM users WHERE username=?', (username,))
|
self.send_user_list()
|
||||||
return cursor.fetchone() is not None
|
|
||||||
|
|
||||||
def handle_client(self, client, address):
|
def handle_client(self, client, address):
|
||||||
|
print(f"[连接] 新客户端: {address}")
|
||||||
username = None
|
username = None
|
||||||
print(f"New connection from {address}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
message = client.recv(1024).decode('utf-8')
|
raw_data = client.recv(4096)
|
||||||
if not message:
|
if not raw_data:
|
||||||
|
print(f"[连接] 客户端 {address} 断开连接")
|
||||||
break
|
break
|
||||||
|
|
||||||
data = json.loads(message)
|
try:
|
||||||
print(f"Received from {address}: {data}")
|
data = json.loads(raw_data.decode('utf-8'))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"[错误] 无效JSON数据来自 {address}")
|
||||||
|
client.send(json.dumps({
|
||||||
|
'type': 'error',
|
||||||
|
'message': 'Invalid JSON format'
|
||||||
|
}).encode('utf-8'))
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"[请求] {address}: {data.get('type')}")
|
||||||
|
|
||||||
|
if data['type'] == 'register':
|
||||||
|
username = data.get('username', '').strip()
|
||||||
|
password = data.get('password', '').strip()
|
||||||
|
|
||||||
|
success, message = self.register_user(username, password)
|
||||||
|
response = {
|
||||||
|
'type': 'register_response',
|
||||||
|
'success': success,
|
||||||
|
'message': message
|
||||||
|
}
|
||||||
|
client.send(json.dumps(response).encode('utf-8'))
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print(f"[注册] 新用户: {username}")
|
||||||
|
else:
|
||||||
|
print(f"[注册] 失败: {username} - {message}")
|
||||||
|
|
||||||
|
elif data['type'] == 'login':
|
||||||
|
username = data.get('username', '').strip()
|
||||||
|
password = data.get('password', '').strip()
|
||||||
|
|
||||||
if data['type'] == 'login':
|
|
||||||
username = data['username']
|
|
||||||
password = data['password']
|
|
||||||
if self.authenticate_user(username, password):
|
if self.authenticate_user(username, password):
|
||||||
self.clients[username] = client
|
self.clients[username] = client
|
||||||
self.groups['General'].add(username)
|
self.groups['General'].add(username)
|
||||||
client.send(json.dumps({'type': 'login_success'}).encode('utf-8'))
|
|
||||||
|
client.send(json.dumps({
|
||||||
|
'type': 'login_response',
|
||||||
|
'success': True,
|
||||||
|
'message': '登录成功'
|
||||||
|
}).encode('utf-8'))
|
||||||
|
|
||||||
self.send_user_list()
|
self.send_user_list()
|
||||||
self.send_group_list(username)
|
self.send_group_list(username)
|
||||||
self.send_initial_messages(username)
|
print(f"[登录] 成功: {username}")
|
||||||
print(f"User {username} logged in successfully")
|
|
||||||
else:
|
else:
|
||||||
client.send(json.dumps({'type': 'login_fail', 'reason': 'Invalid username or password'}).encode('utf-8'))
|
client.send(json.dumps({
|
||||||
print(f"Login failed for {username}")
|
'type': 'login_response',
|
||||||
|
'success': False,
|
||||||
|
'message': '用户名或密码错误'
|
||||||
|
}).encode('utf-8'))
|
||||||
|
print(f"[登录] 失败: {username}")
|
||||||
|
|
||||||
elif data['type'] == 'register':
|
|
||||||
username = data['username']
|
|
||||||
password = data['password']
|
|
||||||
if len(username) < 3 or len(password) < 3:
|
|
||||||
client.send(json.dumps({'type': 'register_fail', 'reason': 'Username and password must be at least 3 characters'}).encode('utf-8'))
|
|
||||||
elif self.register_user(username, password):
|
|
||||||
client.send(json.dumps({'type': 'register_success'}).encode('utf-8'))
|
|
||||||
print(f"New user registered: {username}")
|
|
||||||
else:
|
|
||||||
client.send(json.dumps({'type': 'register_fail', 'reason': 'Username already taken'}).encode('utf-8'))
|
|
||||||
|
|
||||||
elif data['type'] == 'message':
|
|
||||||
sender = data['sender']
|
|
||||||
receiver = data['receiver']
|
|
||||||
msg = data['message']
|
|
||||||
is_group = data['is_group']
|
|
||||||
self.save_message(sender, receiver, msg, is_group)
|
|
||||||
if is_group:
|
|
||||||
for member in self.groups.get(receiver, set()):
|
|
||||||
if member in self.clients and member != sender:
|
|
||||||
try:
|
|
||||||
self.clients[member].send(json.dumps({
|
|
||||||
'type': 'message',
|
|
||||||
'sender': sender,
|
|
||||||
'receiver': receiver,
|
|
||||||
'message': msg,
|
|
||||||
'is_group': True
|
|
||||||
}).encode('utf-8'))
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if receiver in self.clients:
|
|
||||||
try:
|
|
||||||
self.clients[receiver].send(json.dumps({
|
|
||||||
'type': 'message',
|
|
||||||
'sender': sender,
|
|
||||||
'receiver': receiver,
|
|
||||||
'message': msg,
|
|
||||||
'is_group': False
|
|
||||||
}).encode('utf-8'))
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif data['type'] == 'create_group':
|
|
||||||
group_name = data['group_name']
|
|
||||||
if group_name not in self.groups:
|
|
||||||
self.groups[group_name] = set()
|
|
||||||
for user in data['members']:
|
|
||||||
if user in self.clients or self.user_exists(user):
|
|
||||||
self.groups[group_name].add(user)
|
|
||||||
self.send_group_list_to_all()
|
|
||||||
client.send(json.dumps({'type': 'group_created', 'group': group_name}).encode('utf-8'))
|
|
||||||
else:
|
|
||||||
client.send(json.dumps({'type': 'group_exists', 'group': group_name}).encode('utf-8'))
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"Invalid JSON received from {address}")
|
|
||||||
break
|
|
||||||
except ConnectionResetError:
|
except ConnectionResetError:
|
||||||
print(f"Connection reset by {address}")
|
print(f"[连接] 客户端 {address} 强制断开")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error handling client {address}: {str(e)}")
|
print(f"[错误] 处理客户端 {address}: {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Client handler error for {address}: {str(e)}")
|
print(f"[错误] 客户端处理异常: {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
if username and username in self.clients:
|
if username:
|
||||||
print(f"User {username} disconnected")
|
self.remove_client(username)
|
||||||
del self.clients[username]
|
client.close()
|
||||||
for group in self.groups.values():
|
|
||||||
if username in group:
|
|
||||||
group.remove(username)
|
|
||||||
self.send_user_list()
|
|
||||||
try:
|
|
||||||
client.close()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
print(f"Server started on {self.host}:{self.port}")
|
print(f"[启动] 服务器运行在 {self.host}:{self.port}")
|
||||||
print("Database initialized at chat_server.db")
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
client, address = self.server.accept()
|
client, address = self.server.accept()
|
||||||
thread = threading.Thread(target=self.handle_client, args=(client, address))
|
thread = threading.Thread(
|
||||||
thread.daemon = True
|
target=self.handle_client,
|
||||||
|
args=(client, address),
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
thread.start()
|
thread.start()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nServer is shutting down...")
|
print("\n[关闭] 服务器正在停止...")
|
||||||
finally:
|
finally:
|
||||||
for client in self.clients.values():
|
for username in list(self.clients.keys()):
|
||||||
try:
|
self.remove_client(username)
|
||||||
client.close()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
self.server.close()
|
self.server.close()
|
||||||
self.db.close()
|
self.db.close()
|
||||||
print("Server stopped")
|
print("[关闭] 服务器已停止")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
server = ChatServer()
|
server = ChatServer()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user