.BANCHO. .FIX. Add user stats cache and user stats request packet

This commit is contained in:
Nyo 2016-06-16 13:38:17 +02:00
parent b5806bdfbf
commit 7035743362
15 changed files with 203 additions and 63 deletions

View File

@ -4,7 +4,7 @@ from helpers import packetHelper
from constants import slotStatuses from constants import slotStatuses
""" General packets """ """ Users listing packets """
def userActionChange(stream): def userActionChange(stream):
return packetHelper.readPacketData(stream, return packetHelper.readPacketData(stream,
[ [
@ -15,6 +15,9 @@ def userActionChange(stream):
["gameMode", dataTypes.byte] ["gameMode", dataTypes.byte]
]) ])
def userStatsRequest(stream):
return packetHelper.readPacketData(stream, [["users", dataTypes.intList]])
""" Client chat packets """ """ Client chat packets """

View File

@ -10,3 +10,4 @@ sInt64 = 6
string = 7 string = 7
ffloat = 8 # because float is a keyword ffloat = 8 # because float is a keyword
bbytes = 9 bbytes = 9
intList = 10 # TODO: Maybe there are some packets that still use uInt16 + uInt32 thing somewhere.

View File

@ -90,16 +90,17 @@ def userLogout(userID):
def userPanel(userID): def userPanel(userID):
# Get user data # Get user data
userToken = glob.tokens.getTokenFromUserID(userID) userToken = glob.tokens.getTokenFromUserID(userID)
username = userHelper.getUsername(userID) username = userToken.username
timezone = 24 # TODO: Timezone timezone = 24 # TODO: Timezone
country = userToken.getCountry() country = userToken.country
gameRank = userHelper.getGameRank(userID, userToken.gameMode) gameRank = userToken.gameRank
latitude = userToken.getLatitude() latitude = userToken.getLatitude()
longitude = userToken.getLongitude() longitude = userToken.getLongitude()
# Get username color according to rank # Get username color according to rank
# Only admins and normal users are currently supported # Only admins and normal users are currently supported
rank = userHelper.getRankPrivileges(userID) #rank = userHelper.getRankPrivileges(userID)
rank = userToken.rank
if username == "FokaBot": if username == "FokaBot":
userRank = userRanks.MOD userRank = userRanks.MOD
elif rank == 4: elif rank == 4:
@ -111,7 +112,6 @@ def userPanel(userID):
else: else:
userRank = userRanks.NORMAL userRank = userRanks.NORMAL
return packetHelper.buildPacket(packetIDs.server_userPanel, return packetHelper.buildPacket(packetIDs.server_userPanel,
[ [
[userID, dataTypes.sInt32], [userID, dataTypes.sInt32],
@ -128,16 +128,15 @@ def userPanel(userID):
def userStats(userID): def userStats(userID):
# Get userID's token from tokens list # Get userID's token from tokens list
userToken = glob.tokens.getTokenFromUserID(userID) userToken = glob.tokens.getTokenFromUserID(userID)
if userToken == None:
# Get stats from DB return bytes() # NOTE: ???
# TODO: Caching system # Stats are cached in token object
rankedScore = userHelper.getRankedScore(userID, userToken.gameMode) #rankedScore = userHelper.getRankedScore(userID, userToken.gameMode)
accuracy = userHelper.getAccuracy(userID, userToken.gameMode)/100 #accuracy = userHelper.getAccuracy(userID, userToken.gameMode)/100
playcount = userHelper.getPlaycount(userID, userToken.gameMode) #playcount = userHelper.getPlaycount(userID, userToken.gameMode)
totalScore = userHelper.getTotalScore(userID, userToken.gameMode) #totalScore = userHelper.getTotalScore(userID, userToken.gameMode)
gameRank = userHelper.getGameRank(userID, userToken.gameMode) #gameRank = userHelper.getGameRank(userID, userToken.gameMode)
pp = int(userHelper.getPP(userID, userToken.gameMode)) #pp = int(userHelper.getPP(userID, userToken.gameMode))
return packetHelper.buildPacket(packetIDs.server_userStats, return packetHelper.buildPacket(packetIDs.server_userStats,
[ [
[userID, dataTypes.uInt32], [userID, dataTypes.uInt32],
@ -147,12 +146,12 @@ def userStats(userID):
[userToken.actionMods, dataTypes.sInt32], [userToken.actionMods, dataTypes.sInt32],
[userToken.gameMode, dataTypes.byte], [userToken.gameMode, dataTypes.byte],
[0, dataTypes.sInt32], [0, dataTypes.sInt32],
[rankedScore, dataTypes.uInt64], [userToken.rankedScore, dataTypes.uInt64],
[accuracy, dataTypes.ffloat], [userToken.accuracy, dataTypes.ffloat],
[playcount, dataTypes.uInt32], [userToken.playcount, dataTypes.uInt32],
[totalScore, dataTypes.uInt64], [userToken.totalScore, dataTypes.uInt64],
[gameRank, dataTypes.uInt32], [userToken.gameRank, dataTypes.uInt32],
[pp, dataTypes.uInt16] [userToken.pp, dataTypes.uInt16]
]) ])

View File

@ -18,12 +18,38 @@ def handle(userToken, packetData):
# Change action packet # Change action packet
packetData = clientPackets.userActionChange(packetData) packetData = clientPackets.userActionChange(packetData)
# Update our action id, text and md5 # Update cached stats if our pp changedm if we've just submitted a score or we've changed gameMode
if (userToken.actionID == actions.playing or userToken.actionID == actions.multiplaying) or (userToken.pp != userHelper.getPP(userID, userToken.gameMode)) or (userToken.gameMode != packetData["gameMode"]):
log.debug("!!!! UPDATING CACHED STATS !!!!")
# Always update game mode, or we'll cache stats from the wrong game mode if we've changed it
userToken.gameMode = packetData["gameMode"]
userToken.updateCachedStats()
# Always update action id, text and md5
userToken.actionID = packetData["actionID"] userToken.actionID = packetData["actionID"]
userToken.actionText = packetData["actionText"] userToken.actionText = packetData["actionText"]
userToken.actionMd5 = packetData["actionMd5"] userToken.actionMd5 = packetData["actionMd5"]
userToken.actionMods = packetData["actionMods"] userToken.actionMods = packetData["actionMods"]
userToken.gameMode = packetData["gameMode"]
# Enqueue our new user panel and stats to us and our spectators
recipients = [userID]
if len(userToken.spectators) > 0:
recipients += userToken.spectators
for i in recipients:
if i == userID:
# Save some loops
token = userToken
else:
token = glob.tokens.getTokenFromUserID(i)
if token != None:
token.enqueue(serverPackets.userPanel(userID))
token.enqueue(serverPackets.userStats(userID))
# TODO: Enqueue all if we've changed game mode, (maybe not needed because it's cached)
#glob.tokens.enqueueAll(serverPackets.userPanel(userID))
#glob.tokens.enqueueAll(serverPackets.userStats(userID))
# Send osu!direct alert if needed # Send osu!direct alert if needed
# NOTE: Remove this when osu!direct will be fixed # NOTE: Remove this when osu!direct will be fixed
@ -31,9 +57,6 @@ def handle(userToken, packetData):
userToken.osuDirectAlert = True userToken.osuDirectAlert = True
userToken.enqueue(serverPackets.sendMessage("FokaBot", userToken.username, "Sup! osu!direct works, kinda. To download a beatmap, you have to click the \"View listing\" button (the last one) instead of \"Download\". However, if you are on the stable (fallback) branch, it should work also with the \"Download\" button. We'll fix that bug as soon as possibleTM.")) userToken.enqueue(serverPackets.sendMessage("FokaBot", userToken.username, "Sup! osu!direct works, kinda. To download a beatmap, you have to click the \"View listing\" button (the last one) instead of \"Download\". However, if you are on the stable (fallback) branch, it should work also with the \"Download\" button. We'll fix that bug as soon as possibleTM."))
# Enqueue our new user panel and stats to everyone
glob.tokens.enqueueAll(serverPackets.userPanel(userID))
glob.tokens.enqueueAll(serverPackets.userStats(userID))
# Console output # Console output
log.info("{} changed action: {} [{}][{}]".format(username, str(userToken.actionID), userToken.actionText, userToken.actionMd5)) log.info("{} changed action: {} [{}][{}]".format(username, str(userToken.actionID), userToken.actionText, userToken.actionMd5))

View File

@ -101,7 +101,6 @@ def handle(tornadoRequest):
# Channel info end (before starting!?! wtf bancho?) # Channel info end (before starting!?! wtf bancho?)
responseToken.enqueue(serverPackets.channelInfoEnd()) responseToken.enqueue(serverPackets.channelInfoEnd())
# Default opened channels # Default opened channels
# TODO: Configurable default channels # TODO: Configurable default channels
channelJoinEvent.joinChannel(responseToken, "#osu") channelJoinEvent.joinChannel(responseToken, "#osu")
@ -125,9 +124,9 @@ def handle(tornadoRequest):
# Get everyone else userpanel # Get everyone else userpanel
# TODO: Better online users handling # TODO: Better online users handling
for key, value in glob.tokens.tokens.items(): #for key, value in glob.tokens.tokens.items():
responseToken.enqueue(serverPackets.userPanel(value.userID)) # responseToken.enqueue(serverPackets.userPanel(value.userID))
responseToken.enqueue(serverPackets.userStats(value.userID)) # responseToken.enqueue(serverPackets.userStats(value.userID))
# Send online users IDs array # Send online users IDs array
responseToken.enqueue(serverPackets.onlineUsers()) responseToken.enqueue(serverPackets.onlineUsers())

View File

@ -0,0 +1,11 @@
from constants import clientPackets
from constants import serverPackets
from helpers import userHelper
from helpers import logHelper as log
def handle(userToken, packetData):
log.debug("Requested status update")
# Update cache and send new stats
userToken.updateCachedStats()
userToken.enqueue(serverPackets.userStats(userToken.userID))

View File

@ -0,0 +1,22 @@
from constants import clientPackets
from constants import serverPackets
from helpers import logHelper as log
def handle(userToken, packetData):
# Read userIDs list
packetData = clientPackets.userStatsRequest(packetData)
# Process lists with length <= 32
if len(packetData) > 32:
log.warning("Received userStatsRequest with length > 32")
return
for i in packetData["users"]:
log.debug("Sending stats for user {}".format(i))
# Skip our stats
if i == userToken.userID:
continue
# Enqueue stats packets relative to this user
userToken.enqueue(serverPackets.userStats(i))

View File

@ -44,6 +44,8 @@ from events import matchTransferHostEvent
from events import matchFailedEvent from events import matchFailedEvent
from events import matchInviteEvent from events import matchInviteEvent
from events import matchChangeTeamEvent from events import matchChangeTeamEvent
from events import userStatsRequestEvent
from events import requestStatusUpdateEvent
# Exception tracking # Exception tracking
import tornado.web import tornado.web
@ -147,7 +149,9 @@ class handler(SentryMixin, requestHelper.asyncRequestHandler):
packetIDs.client_matchTransferHost: handleEvent(matchTransferHostEvent), packetIDs.client_matchTransferHost: handleEvent(matchTransferHostEvent),
packetIDs.client_matchFailed: handleEvent(matchFailedEvent), packetIDs.client_matchFailed: handleEvent(matchFailedEvent),
packetIDs.client_invite: handleEvent(matchInviteEvent), packetIDs.client_invite: handleEvent(matchInviteEvent),
packetIDs.client_matchChangeTeam: handleEvent(matchChangeTeamEvent) packetIDs.client_matchChangeTeam: handleEvent(matchChangeTeamEvent),
packetIDs.client_userStatsRequest: handleEvent(userStatsRequestEvent),
packetIDs.client_requestStatusUpdate: handleEvent(requestStatusUpdateEvent),
} }
if packetID != 4: if packetID != 4:
@ -205,8 +209,9 @@ class handler(SentryMixin, requestHelper.asyncRequestHandler):
log.error("Unknown error!\n```\n{}\n{}```".format(sys.exc_info(), traceback.format_exc())) log.error("Unknown error!\n```\n{}\n{}```".format(sys.exc_info(), traceback.format_exc()))
if glob.sentry: if glob.sentry:
yield tornado.gen.Task(self.captureException, exc_info=True) yield tornado.gen.Task(self.captureException, exc_info=True)
finally: #finally:
self.finish() # self.finish()
@tornado.web.asynchronous @tornado.web.asynchronous
@tornado.gen.engine @tornado.gen.engine

View File

@ -52,7 +52,7 @@ class db:
__params -- array with params. Optional __params -- array with params. Optional
""" """
log.debug(query)
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
try: try:
# Bind params if needed # Bind params if needed
@ -77,7 +77,7 @@ class db:
return -- dictionary with result data or False if failed return -- dictionary with result data or False if failed
""" """
log.debug(query)
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
try: try:
# Bind params if needed # Bind params if needed

View File

@ -1,5 +1,6 @@
import MySQLdb import MySQLdb
import threading import threading
from helpers import logHelper as log
class mysqlWorker: class mysqlWorker:
""" """
@ -66,6 +67,7 @@ class db:
query -- Query to execute. You can bind parameters with %s query -- Query to execute. You can bind parameters with %s
params -- Parameters list. First element replaces first %s and so on. Optional. params -- Parameters list. First element replaces first %s and so on. Optional.
""" """
log.debug(query)
# Get a worker and acquire its lock # Get a worker and acquire its lock
worker = self.getWorker() worker = self.getWorker()
worker.lock.acquire() worker.lock.acquire()
@ -89,6 +91,7 @@ class db:
params -- Parameters list. First element replaces first %s and so on. Optional. params -- Parameters list. First element replaces first %s and so on. Optional.
all -- Fetch one or all values. Used internally. Use fetchAll if you want to fetch all values. all -- Fetch one or all values. Used internally. Use fetchAll if you want to fetch all values.
""" """
log.debug(query)
# Get a worker and acquire its lock # Get a worker and acquire its lock
worker = self.getWorker() worker = self.getWorker()
worker.lock.acquire() worker.lock.acquire()

View File

@ -167,36 +167,36 @@ def buildPacket(__packet, __packetData = []):
return packetBytes return packetBytes
def readPacketID(__stream): def readPacketID(stream):
""" """
Read packetID from __stream (0-1 bytes) Read packetID from stream (0-1 bytes)
__stream -- data stream stream -- data stream
return -- packet ID (int) return -- packet ID (int)
""" """
return unpackData(__stream[0:2], dataTypes.uInt16) return unpackData(stream[0:2], dataTypes.uInt16)
def readPacketLength(__stream): def readPacketLength(stream):
""" """
Read packet length from __stream (3-4-5-6 bytes) Read packet length from stream (3-4-5-6 bytes)
__stream -- data stream stream -- data stream
return -- packet length (int) return -- packet length (int)
""" """
return unpackData(__stream[3:7], dataTypes.uInt32) return unpackData(stream[3:7], dataTypes.uInt32)
def readPacketData(__stream, __structure = [], __hasFirstBytes = True): def readPacketData(stream, structure = [], hasFirstBytes = True):
""" """
Read packet data from __stream according to __structure Read packet data from stream according to structure
__stream -- data stream stream -- data stream
__structure -- [[name, dataType], [name, dataType], ...] structure -- [[name, dataType], [name, dataType], ...]
__hasFirstBytes -- if True, __stream has packetID and length bytes. hasFirstBytes -- if True, stream has packetID and length bytes.
if False, __stream has only packetData. if False, stream has only packetData.
Optional. Default: True Optional. Default: True
return -- dictionary. key: name, value: read data return -- dictionary. key: name, value: read data
""" """
@ -205,7 +205,7 @@ def readPacketData(__stream, __structure = [], __hasFirstBytes = True):
data = {} data = {}
# Skip packet ID and packet length if needed # Skip packet ID and packet length if needed
if __hasFirstBytes == True: if hasFirstBytes == True:
end = 7 end = 7
start = 7 start = 7
else: else:
@ -213,26 +213,41 @@ def readPacketData(__stream, __structure = [], __hasFirstBytes = True):
start = 0 start = 0
# Read packet # Read packet
for i in __structure: for i in structure:
start = end start = end
unpack = True unpack = True
if i[1] == dataTypes.string: if i[1] == dataTypes.intList:
# sInt32 list.
# Unpack manually with for loop
unpack = False
# Read length (uInt16)
length = unpackData(stream[start:start+2], dataTypes.uInt16)
# Read all int inside list
data[i[0]] = []
for j in range(0,length):
data[i[0]].append(unpackData(stream[start+2+(4*j):start+2+(4*(j+1))], dataTypes.sInt32))
# Update end
end = start+2+(4*length)
elif i[1] == dataTypes.string:
# String, don't unpack # String, don't unpack
unpack = False unpack = False
# Check empty string # Check empty string
if __stream[start] == 0: if stream[start] == 0:
# Empty string # Empty string
data[i[0]] = "" data[i[0]] = ""
end = start+1 end = start+1
else: else:
# Non empty string # Non empty string
# Read length and calculate end # Read length and calculate end
length = uleb128Decode(__stream[start+1:]) length = uleb128Decode(stream[start+1:])
end = start+length[0]+length[1]+1 end = start+length[0]+length[1]+1
# Read bytes # Read bytes
data[i[0]] = ''.join(chr(j) for j in __stream[start+1+length[1]:end]) data[i[0]] = ''.join(chr(j) for j in stream[start+1+length[1]:end])
elif i[1] == dataTypes.byte: elif i[1] == dataTypes.byte:
end = start+1 end = start+1
elif i[1] == dataTypes.uInt16 or i[1] == dataTypes.sInt16: elif i[1] == dataTypes.uInt16 or i[1] == dataTypes.sInt16:
@ -244,6 +259,6 @@ def readPacketData(__stream, __structure = [], __hasFirstBytes = True):
# Unpack if needed # Unpack if needed
if unpack == True: if unpack == True:
data[i[0]] = unpackData(__stream[start:end], i[1]) data[i[0]] = unpackData(stream[start:end], i[1])
return data return data

View File

@ -20,6 +20,8 @@ class asyncRequestHandler(tornado.web.RequestHandler):
yield tornado.gen.Task(runBackground, (self.asyncGet, tuple(args), dict(kwargs))) yield tornado.gen.Task(runBackground, (self.asyncGet, tuple(args), dict(kwargs)))
except Exception as e: except Exception as e:
yield tornado.gen.Task(self.captureException, exc_info=True) yield tornado.gen.Task(self.captureException, exc_info=True)
finally:
self.finish()
@tornado.web.asynchronous @tornado.web.asynchronous
@tornado.gen.engine @tornado.gen.engine
@ -28,6 +30,8 @@ class asyncRequestHandler(tornado.web.RequestHandler):
yield tornado.gen.Task(runBackground, (self.asyncPost, tuple(args), dict(kwargs))) yield tornado.gen.Task(runBackground, (self.asyncPost, tuple(args), dict(kwargs)))
except Exception as e: except Exception as e:
yield tornado.gen.Task(self.captureException, exc_info=True) yield tornado.gen.Task(self.captureException, exc_info=True)
finally:
self.finish()
def asyncGet(self, *args, **kwargs): def asyncGet(self, *args, **kwargs):
self.send_error(405) self.send_error(405)

View File

@ -346,3 +346,32 @@ def check2FA(userID, ip):
result = glob.db.fetch("SELECT id FROM ip_user WHERE userid = %s AND ip = %s", [userID, ip]) result = glob.db.fetch("SELECT id FROM ip_user WHERE userid = %s AND ip = %s", [userID, ip])
return True if result is None else False return True if result is None else False
def getUserStats(userID, gameMode):
"""
Get all user stats relative to gameMode with only two queries
userID --
gameMode -- gameMode number
return -- dictionary with results
"""
modeForDB = gameModes.getGameModeForDB(gameMode)
# Get stats
stats = glob.db.fetch("""SELECT
ranked_score_{gm} AS rankedScore,
avg_accuracy_{gm} AS accuracy,
playcount_{gm} AS playcount,
total_score_{gm} AS totalScore,
pp_{gm} AS pp
FROM users_stats WHERE id = %s LIMIT 1""".format(gm=modeForDB), [userID])
# Get game rank
result = glob.db.fetch("SELECT position FROM leaderboard_{} WHERE user = %s LIMIT 1".format(modeForDB), [userID])
if result == None:
stats["gameRank"] = 0
else:
stats["gameRank"] = result["position"]
# Return stats + game rank
return stats

View File

@ -57,11 +57,6 @@ class token:
self.spectating = 0 self.spectating = 0
self.location = [0,0] self.location = [0,0]
self.joinedChannels = [] self.joinedChannels = []
self.actionID = actions.idle
self.actionText = ""
self.actionMd5 = ""
self.actionMods = 0
self.gameMode = gameModes.std
self.ip = ip self.ip = ip
self.country = 0 self.country = 0
self.location = [0,0] self.location = [0,0]
@ -77,12 +72,29 @@ class token:
self.spamRate = 0 self.spamRate = 0
#self.lastMessagetime = 0 #self.lastMessagetime = 0
# Stats cache
self.actionID = actions.idle
self.actionText = ""
self.actionMd5 = ""
self.actionMods = 0
self.gameMode = gameModes.std
self.rankedScore = 0
self.accuracy = 0.0
self.playcount = 0
self.totalScore = 0
self.gameRank = 0
self.pp = 0
# Generate/set token # Generate/set token
if token != None: if token != None:
self.token = token self.token = token
else: else:
self.token = str(uuid.uuid4()) self.token = str(uuid.uuid4())
# Set stats
self.updateCachedStats()
# If we have a valid ip, save bancho session in DB so we can cache LETS logins # If we have a valid ip, save bancho session in DB so we can cache LETS logins
if ip != "": if ip != "":
userHelper.saveBanchoSession(self.userID, self.ip) userHelper.saveBanchoSession(self.userID, self.ip)
@ -270,3 +282,17 @@ class token:
return -- silence seconds left return -- silence seconds left
""" """
return max(0, self.silenceEndTime-int(time.time())) return max(0, self.silenceEndTime-int(time.time()))
def updateCachedStats(self):
"""Update all cached stats for this token"""
stats = userHelper.getUserStats(self.userID, self.gameMode)
log.debug(str(stats))
if stats == None:
log.warning("Stats query returned None")
return
self.rankedScore = stats["rankedScore"]
self.accuracy = stats["accuracy"]/100
self.playcount = stats["playcount"]
self.totalScore = stats["totalScore"]
self.gameRank = stats["gameRank"]
self.pp = stats["pp"]

View File

@ -180,7 +180,7 @@ class tokenList:
Reset spam rate every 10 seconds. Reset spam rate every 10 seconds.
CALL THIS FUNCTION ONLY ONCE! CALL THIS FUNCTION ONLY ONCE!
""" """
log.debug("Resetting spam protection...") #log.debug("Resetting spam protection...")
# Reset spamRate for every token # Reset spamRate for every token
for _, value in self.tokens.items(): for _, value in self.tokens.items():