# Source listing metadata: 473 lines, 17 KiB, Python.
import contextlib
import hashlib
import json
import socket
import sqlite3
import threading
import time
import uuid
from datetime import datetime

from flask import Flask, request, jsonify
from werkzeug.serving import make_server

# Banner written to stdout when this module is executed or imported; it runs
# at module level, before any server object is created.
print("man !!! what can I say ! Manba out!!!!")


class IntegratedChatServer:
    """Chat server that runs a Flask HTTP API and a TCP socket server together.

    The HTTP API handles registration, login, user listing, message history
    and message recall. The socket server exchanges newline-delimited JSON
    frames with connected clients for presence updates, chat messages and
    recall notifications. Both sides share one SQLite database file.

    ``self.clients`` maps ``str(user_id)`` to ``(socket, username)`` and is
    guarded by ``self.lock``.
    """

    # Default database file; __init__ overrides it per instance.
    db_path = 'chat_server.db'
    # Row limit used by /api/messages when the request supplies none.
    DEFAULT_MESSAGE_LIMIT = 100
    # Largest unterminated frame buffered per connection before it is dropped.
    MAX_BUFFER_BYTES = 1024 * 1024

    def __init__(self, http_host='0.0.0.0', http_port=5000,
                 socket_host='0.0.0.0', socket_port=12345,
                 db_path='chat_server.db'):
        """Create the database schema, the Flask app and the listening socket.

        Args:
            http_host: Interface the HTTP server binds to.
            http_port: Port the HTTP server binds to.
            socket_host: Interface the socket server binds to.
            socket_port: Port the socket server binds to.
            db_path: Path of the SQLite database file.
        """
        self.http_host = http_host
        self.http_port = http_port
        self.socket_host = socket_host
        self.socket_port = socket_port
        self.db_path = db_path

        self.init_db()

        self.app = Flask(__name__)
        self.setup_http_routes()
        self.http_server = make_server(self.http_host, self.http_port, self.app)

        self.socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.clients = {}
        self.lock = threading.Lock()

        self.http_thread = None
        self.socket_thread = None

    # ------------------------------------------------------------------
    # Database helpers
    # ------------------------------------------------------------------

    def init_db(self):
        """Create the ``users`` and ``messages`` tables if they do not exist."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT UNIQUE NOT NULL,
                    password TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    last_login TEXT,
                    is_online INTEGER DEFAULT 0
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sender_id INTEGER NOT NULL,
                    receiver_id INTEGER NOT NULL,
                    content TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    is_recalled INTEGER DEFAULT 0,
                    FOREIGN KEY (sender_id) REFERENCES users(id),
                    FOREIGN KEY (receiver_id) REFERENCES users(id)
                )
            ''')
            conn.commit()

    def get_db_connection(self):
        """Return a new SQLite connection whose rows support access by column name.

        The caller is responsible for closing the connection.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def _set_online(self, user_id, online):
        """Write the ``is_online`` flag for one user."""
        with contextlib.closing(self.get_db_connection()) as conn:
            conn.execute('UPDATE users SET is_online = ? WHERE id = ?',
                         (1 if online else 0, user_id))
            conn.commit()

    @staticmethod
    def _hash_password(password):
        """Return the hex SHA-256 digest of ``password``.

        NOTE(review): an unsalted fast hash is weak for password storage. It
        is kept because login compares this digest directly with the stored
        value, so changing the scheme would lock out existing accounts.
        """
        return hashlib.sha256(password.encode()).hexdigest()

    def _recall_message(self, message_id, user_id):
        """Mark a message as recalled on behalf of ``user_id`` and notify both parties.

        Shared by the HTTP endpoint and the socket handler.

        Returns:
            ``'invalid'`` if ``user_id`` is not an integer, ``'not_found'`` if
            the message does not exist, ``'forbidden'`` if ``user_id`` is not
            the sender, ``'recalled'`` on success.
        """
        try:
            requester_id = int(user_id)
        except (TypeError, ValueError):
            return 'invalid'

        with contextlib.closing(self.get_db_connection()) as conn:
            row = conn.execute(
                'SELECT sender_id, receiver_id FROM messages WHERE id = ?',
                (message_id,),
            ).fetchone()
            if row is None:
                return 'not_found'
            if row['sender_id'] != requester_id:
                return 'forbidden'
            conn.execute('UPDATE messages SET is_recalled = 1 WHERE id = ?', (message_id,))
            conn.commit()
            receiver_id = row['receiver_id']

        recall_notification = {
            'type': 'message_recalled',
            'message_id': message_id,
            'sender_id': user_id,
            'receiver_id': receiver_id,
            'timestamp': datetime.now().isoformat(),
        }
        self.notify_clients(recall_notification, receiver_id, user_id)
        return 'recalled'

    # ------------------------------------------------------------------
    # HTTP API
    # ------------------------------------------------------------------

    @staticmethod
    def _json_body():
        """Return the request's JSON object, or ``{}`` if it is missing or not an object."""
        data = request.get_json(silent=True)
        return data if isinstance(data, dict) else {}

    def setup_http_routes(self):
        """Register the HTTP endpoints on ``self.app``."""

        @self.app.route('/api/register', methods=['POST'])
        def register():
            data = self._json_body()
            username = data.get('username')
            password = data.get('password')

            if not username or not password or not isinstance(password, str):
                return jsonify({'success': False, 'message': 'Username and password are required'}), 400

            hashed_password = self._hash_password(password)
            created_at = datetime.now().isoformat()

            try:
                with contextlib.closing(self.get_db_connection()) as conn:
                    conn.execute('''
                        INSERT INTO users (username, password, created_at)
                        VALUES (?, ?, ?)
                    ''', (username, hashed_password, created_at))
                    conn.commit()
            except sqlite3.IntegrityError:
                # The UNIQUE constraint on username rejected the insert.
                return jsonify({'success': False, 'message': 'Username already exists'}), 400
            return jsonify({'success': True, 'message': 'User registered successfully'}), 201

        @self.app.route('/api/login', methods=['POST'])
        def login():
            data = self._json_body()
            username = data.get('username')
            password = data.get('password')

            if not username or not password or not isinstance(password, str):
                return jsonify({'success': False, 'message': 'Username and password are required'}), 400

            hashed_password = self._hash_password(password)

            try:
                with contextlib.closing(self.get_db_connection()) as conn:
                    user = conn.execute(
                        'SELECT * FROM users WHERE username = ? AND password = ?',
                        (username, hashed_password),
                    ).fetchone()
                    if user is None:
                        return jsonify({'success': False, 'message': 'Invalid username or password'}), 401

                    last_login = datetime.now().isoformat()
                    conn.execute(
                        'UPDATE users SET last_login = ?, is_online = 1 WHERE id = ?',
                        (last_login, user['id']),
                    )
                    conn.commit()

                    # NOTE(review): the token is returned to the client but not
                    # stored anywhere in this class, so it cannot be verified later.
                    token = str(uuid.uuid4())
                    return jsonify({
                        'success': True,
                        'message': 'Login successful',
                        'token': token,
                        'user_id': user['id'],
                        'username': user['username'],
                    }), 200
            except Exception as e:
                return jsonify({'success': False, 'message': str(e)}), 500

        @self.app.route('/api/users', methods=['GET'])
        def get_users():
            with contextlib.closing(self.get_db_connection()) as conn:
                users = conn.execute(
                    'SELECT id, username, is_online, last_login FROM users'
                ).fetchall()

            users_list = [dict(user) for user in users]
            return jsonify({'success': True, 'users': users_list}), 200

        @self.app.route('/api/messages', methods=['GET'])
        def get_messages():
            user_id = request.args.get('user_id')
            other_id = request.args.get('other_id')

            if not user_id or not other_id:
                return jsonify({'success': False, 'message': 'user_id and other_id are required'}), 400

            # Query-string values are text; reject anything that is not an
            # integer instead of letting SQLite fail on the LIMIT clause.
            try:
                limit = int(request.args.get('limit', self.DEFAULT_MESSAGE_LIMIT))
            except ValueError:
                return jsonify({'success': False, 'message': 'limit must be an integer'}), 400

            with contextlib.closing(self.get_db_connection()) as conn:
                messages = conn.execute('''
                    SELECT m.id, m.sender_id, m.receiver_id, m.content, m.timestamp, m.is_recalled,
                           u1.username as sender_name, u2.username as receiver_name
                    FROM messages m
                    JOIN users u1 ON m.sender_id = u1.id
                    JOIN users u2 ON m.receiver_id = u2.id
                    WHERE (m.sender_id = ? AND m.receiver_id = ?) OR (m.sender_id = ? AND m.receiver_id = ?)
                    ORDER BY m.timestamp DESC
                    LIMIT ?
                ''', (user_id, other_id, other_id, user_id, limit)).fetchall()

            messages_list = []
            for msg in messages:
                msg_dict = dict(msg)
                if msg_dict['is_recalled']:
                    # Hide the original text of recalled messages.
                    msg_dict['content'] = '[消息已撤回]'
                messages_list.append(msg_dict)

            return jsonify({'success': True, 'messages': messages_list}), 200

        @self.app.route('/api/recall_message', methods=['POST'])
        def recall_message():
            data = self._json_body()
            message_id = data.get('message_id')
            user_id = data.get('user_id')

            if not message_id or not user_id:
                return jsonify({'success': False, 'message': 'message_id and user_id are required'}), 400

            failures = {
                'invalid': ('user_id must be an integer', 400),
                'not_found': ('Message not found', 404),
                'forbidden': ('You can only recall your own messages', 403),
            }
            status = self._recall_message(message_id, user_id)
            if status in failures:
                text, code = failures[status]
                return jsonify({'success': False, 'message': text}), code

            return jsonify({'success': True, 'message': 'Message recalled successfully'}), 200

    # ------------------------------------------------------------------
    # Socket server: delivery and client bookkeeping
    # ------------------------------------------------------------------

    @staticmethod
    def _encode_frame(message):
        """Serialize ``message`` as one newline-terminated UTF-8 JSON frame."""
        return (json.dumps(message) + '\n').encode('utf-8')

    def broadcast(self, message, exclude_user_id=None):
        """Send ``message`` to every connected client except ``exclude_user_id``.

        Clients whose socket fails are removed after the lock is released,
        because ``remove_client`` takes the same non-reentrant lock and
        mutates ``self.clients``.
        """
        payload = self._encode_frame(message)
        excluded = None if exclude_user_id is None else str(exclude_user_id)
        dead = []
        with self.lock:
            for user_id, (client_socket, _) in self.clients.items():
                if user_id == excluded:
                    continue
                try:
                    client_socket.sendall(payload)
                except OSError:
                    dead.append(user_id)
        for user_id in dead:
            self.remove_client(user_id)

    def remove_client(self, user_id):
        """Drop a client, close its socket, mark it offline and announce it.

        Does nothing if ``user_id`` is not currently registered.
        """
        key = str(user_id)
        with self.lock:
            entry = self.clients.pop(key, None)
        if entry is None:
            return

        try:
            entry[0].close()
        except OSError:
            pass  # best effort: the socket may already be closed

        self._set_online(key, False)

        # Sent after the lock is released; broadcast() acquires it itself.
        self.broadcast({
            'type': 'user_offline',
            'user_id': user_id,
            'timestamp': datetime.now().isoformat(),
        })

    def notify_clients(self, message, *user_ids):
        """Send ``message`` to each of ``user_ids`` that is currently connected.

        Ids are matched against ``self.clients`` by their string form. Clients
        whose socket fails are removed after the lock is released.
        """
        payload = self._encode_frame(message)
        dead = []
        with self.lock:
            for user_id in user_ids:
                entry = self.clients.get(str(user_id))
                if entry is None:
                    continue
                try:
                    entry[0].sendall(payload)
                except OSError:
                    dead.append(str(user_id))
        for user_id in dead:
            self.remove_client(user_id)

    def _handle_frame(self, raw, client_socket, address, complete=True):
        """Decode one frame and dispatch it.

        Args:
            raw: Frame bytes without the newline delimiter.
            client_socket: Socket the frame arrived on.
            address: Peer address, used only in log output.
            complete: False when ``raw`` is the unterminated tail of the
                receive buffer and may still be missing bytes.

        Returns:
            True if the frame was consumed, False if it should stay buffered.
        """
        if not raw.strip():
            return True
        try:
            message = json.loads(raw.decode('utf-8'))
        except (UnicodeDecodeError, json.JSONDecodeError):
            if not complete:
                return False
            print(f"Invalid JSON from {address}: {raw.decode('utf-8', errors='replace')}")
            return True
        if not isinstance(message, dict):
            if not complete:
                return False
            print(f"Invalid JSON from {address}: {raw.decode('utf-8', errors='replace')}")
            return True
        self.process_message(message, client_socket)
        return True

    def handle_client(self, client_socket, address):
        """Read newline-delimited JSON frames from one client until it disconnects.

        Bytes are buffered across ``recv`` calls so a frame, or a multi-byte
        UTF-8 character, split between chunks is reassembled. A trailing
        frame without a newline is still accepted once it parses as a JSON
        object, so clients that omit the delimiter keep working.
        """
        print(f"New connection from {address}")

        buffer = b''
        try:
            while True:
                data = client_socket.recv(4096)
                if not data:
                    break

                buffer += data
                *frames, buffer = buffer.split(b'\n')
                for frame in frames:
                    self._handle_frame(frame, client_socket, address)

                if buffer and self._handle_frame(buffer, client_socket, address, complete=False):
                    buffer = b''

                if len(buffer) > self.MAX_BUFFER_BYTES:
                    print(f"Oversized frame from {address}; closing connection")
                    buffer = b''
                    break

            # Report whatever unparseable data was left when the peer closed.
            if buffer.strip():
                self._handle_frame(buffer, client_socket, address)
        except ConnectionResetError:
            print(f"Client {address} disconnected abruptly")
        except OSError as exc:
            print(f"Socket error for {address}: {exc}")
        finally:
            # Find which user, if any, was logged in on this socket.
            with self.lock:
                user_id_to_remove = next(
                    (uid for uid, (sock, _) in self.clients.items() if sock is client_socket),
                    None,
                )

            if user_id_to_remove is not None:
                self.remove_client(user_id_to_remove)

            client_socket.close()
            print(f"Connection from {address} closed")

    def process_message(self, message, client_socket):
        """Dispatch a decoded frame by its ``type`` field; unknown types are ignored."""
        msg_type = message.get('type')

        if msg_type == 'login':
            self.handle_login(message, client_socket)
        elif msg_type == 'message':
            self.handle_chat_message(message)
        elif msg_type == 'recall':
            self.handle_recall_message(message)

    def handle_login(self, message, client_socket):
        """Register a socket for a user, mark the user online and announce it.

        NOTE(review): the ``token`` field of the frame is not checked, so any
        client can claim any ``user_id``.
        """
        user_id = message.get('user_id')
        username = message.get('username')

        if not user_id or not username:
            return

        # Keys are stored as strings so lookups in notify_clients() match
        # whether the client sent the id as a number or as text.
        with self.lock:
            self.clients[str(user_id)] = (client_socket, username)

        self._set_online(user_id, True)

        self.broadcast({
            'type': 'user_online',
            'user_id': user_id,
            'username': username,
            'timestamp': datetime.now().isoformat(),
        })

        with self.lock:
            online_users = [
                {'user_id': uid, 'username': uname}
                for uid, (_, uname) in self.clients.items()
            ]
            # Newline-terminated like every other frame, and sent under the
            # lock so it cannot interleave with a concurrent broadcast.
            client_socket.sendall(self._encode_frame({
                'type': 'online_users',
                'users': online_users,
                'timestamp': datetime.now().isoformat(),
            }))

    def handle_chat_message(self, message):
        """Store a chat message and deliver it to the receiver and the sender.

        NOTE(review): ``sender_id`` comes from the frame and is not checked
        against the user logged in on the socket.
        """
        sender_id = message.get('sender_id')
        receiver_id = message.get('receiver_id')
        content = message.get('content')

        if not sender_id or not receiver_id or not content:
            return

        # One timestamp, so the stored row and the delivered frame agree.
        timestamp = datetime.now().isoformat()

        with contextlib.closing(self.get_db_connection()) as conn:
            cursor = conn.execute('''
                INSERT INTO messages (sender_id, receiver_id, content, timestamp)
                VALUES (?, ?, ?, ?)
            ''', (sender_id, receiver_id, content, timestamp))
            conn.commit()
            message_id = cursor.lastrowid

        message_to_send = {
            'type': 'message',
            'message_id': message_id,
            'sender_id': sender_id,
            'receiver_id': receiver_id,
            'content': content,
            'timestamp': timestamp,
            'is_recalled': False,
        }

        self.notify_clients(message_to_send, receiver_id, sender_id)

    def handle_recall_message(self, message):
        """Handle a recall request received over the socket.

        Invalid, unknown or unauthorized requests are ignored silently.
        """
        message_id = message.get('message_id')
        user_id = message.get('user_id')

        if not message_id or not user_id:
            return

        self._recall_message(message_id, user_id)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start_http_server(self):
        """Serve the HTTP API until ``self.http_server`` is shut down."""
        print(f"HTTP server running on http://{self.http_host}:{self.http_port}")
        self.http_server.serve_forever()

    def start_socket_server(self):
        """Bind the listening socket and start one daemon thread per connection."""
        self.socket_server.bind((self.socket_host, self.socket_port))
        self.socket_server.listen(5)
        print(f"Socket server listening on {self.socket_host}:{self.socket_port}")

        try:
            while True:
                try:
                    client_socket, address = self.socket_server.accept()
                except OSError as exc:
                    # accept() fails once the listening socket is closed,
                    # which run() does on shutdown; stop accepting either way.
                    print(f"Socket server stopped accepting connections: {exc}")
                    break
                client_thread = threading.Thread(
                    target=self.handle_client,
                    args=(client_socket, address),
                    daemon=True,
                )
                client_thread.start()
        except KeyboardInterrupt:
            print("Shutting down socket server...")
        finally:
            self.socket_server.close()

    def run(self):
        """Start both servers in daemon threads and block until Ctrl-C."""
        self.http_thread = threading.Thread(target=self.start_http_server, daemon=True)
        self.http_thread.start()
        self.socket_thread = threading.Thread(target=self.start_socket_server, daemon=True)
        self.socket_thread.start()

        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("Shutting down servers...")
            self.http_server.shutdown()
            self.socket_server.close()


if __name__ == '__main__':
    # Script entry point: build a server with the default hosts and ports,
    # then block inside run() until it is interrupted.
    IntegratedChatServer().run()