Move to fasthttp for improved performance

This commit is contained in:
Morgan Bazalgette 2017-02-02 13:40:28 +01:00
parent ace2fded7e
commit 85e6dc7e5e
26 changed files with 448 additions and 380 deletions

View File

@ -1,17 +1,11 @@
// Package internals has methods that suit none of the API packages. // Package internals has methods that suit none of the API packages.
package internals package internals
import ( import "github.com/valyala/fasthttp"
"github.com/gin-gonic/gin"
)
type statusResponse struct { var statusResp = []byte(`{ "status": 1 }`)
Status int `json:"status"`
}
// Status is used for checking the API is up by the ripple website, on the status page. // Status is used for checking the API is up by the ripple website, on the status page.
func Status(c *gin.Context) { func Status(c *fasthttp.RequestCtx) {
c.JSON(200, statusResponse{ c.Write(statusResp)
Status: 1,
})
} }

View File

@ -1,52 +1,41 @@
package app package app
import ( import (
"crypto/md5"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"net"
"regexp" "regexp"
"strings" "unsafe"
"github.com/valyala/fasthttp"
"zxq.co/ripple/rippleapi/common" "zxq.co/ripple/rippleapi/common"
"github.com/gin-gonic/gin"
) )
// Method wraps an API method to a HandlerFunc. // Method wraps an API method to a HandlerFunc.
func Method(f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) gin.HandlerFunc { func Method(f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) fasthttp.RequestHandler {
return func(c *gin.Context) { return func(c *fasthttp.RequestCtx) {
initialCaretaker(c, f, privilegesNeeded...) initialCaretaker(c, f, privilegesNeeded...)
} }
} }
func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) { func initialCaretaker(c *fasthttp.RequestCtx, f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) {
rateLimiter()
var doggoTags []string var doggoTags []string
data, err := ioutil.ReadAll(c.Request.Body) qa := c.Request.URI().QueryArgs()
if err != nil { var token string
c.Error(err)
}
token := ""
switch { switch {
case c.Request.Header.Get("X-Ripple-Token") != "": case len(c.Request.Header.Peek("X-Ripple-Token")) > 0:
token = c.Request.Header.Get("X-Ripple-Token") token = string(c.Request.Header.Peek("X-Ripple-Token"))
case c.Query("token") != "": case len(qa.Peek("token")) > 0:
token = c.Query("token") token = string(qa.Peek("token"))
case c.Query("k") != "": case len(qa.Peek("k")) > 0:
token = c.Query("k") token = string(qa.Peek("k"))
default: default:
token, _ = c.Cookie("rt") token = string(c.Request.Header.Cookie("rt"))
} }
c.Set("token", fmt.Sprintf("%x", md5.Sum([]byte(token))))
md := common.MethodData{ md := common.MethodData{
DB: db, DB: db,
RequestData: data, Ctx: c,
C: c,
Doggo: doggo, Doggo: doggo,
R: red, R: red,
} }
@ -58,25 +47,8 @@ func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMe
} }
} }
var ip string // log into datadog that this is an hanayo request
if requestIP, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)); err != nil { if b2s(c.Request.Header.Peek("H-Key")) == cf.HanayoKey && b2s(c.UserAgent()) == "hanayo" {
panic(err)
} else {
// if requestIP is not 127.0.0.1, means no reverse proxy is being used => direct request.
if requestIP != "127.0.0.1" {
ip = requestIP
}
}
// means we're using reverse-proxy, so X-Real-IP
if ip == "" {
ip = c.ClientIP()
}
// requests from hanayo should not be rate limited.
if !(c.Request.Header.Get("H-Key") == cf.HanayoKey && c.Request.UserAgent() == "hanayo") {
perUserRequestLimiter(md.ID(), c.ClientIP())
} else {
doggoTags = append(doggoTags, "hanayo") doggoTags = append(doggoTags, "hanayo")
} }
@ -89,21 +61,22 @@ func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMe
} }
} }
if missingPrivileges != 0 { if missingPrivileges != 0 {
c.IndentedJSON(401, common.SimpleResponse(401, "You don't have the privilege(s): "+common.Privileges(missingPrivileges).String()+".")) c.SetStatusCode(401)
mkjson(c, common.SimpleResponse(401, "You don't have the privilege(s): "+common.Privileges(missingPrivileges).String()+"."))
return return
} }
resp := f(md) resp := f(md)
if md.HasQuery("pls200") { if md.HasQuery("pls200") {
c.Writer.WriteHeader(200) c.SetStatusCode(200)
} else { } else {
c.Writer.WriteHeader(resp.GetCode()) c.SetStatusCode(resp.GetCode())
} }
if md.HasQuery("callback") { if md.HasQuery("callback") {
c.Header("Content-Type", "application/javascript; charset=utf-8") c.Response.Header.Add("Content-Type", "application/javascript; charset=utf-8")
} else { } else {
c.Header("Content-Type", "application/json; charset=utf-8") c.Response.Header.Add("Content-Type", "application/json; charset=utf-8")
} }
mkjson(c, resp) mkjson(c, resp)
@ -113,22 +86,31 @@ func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMe
var callbackJSONP = regexp.MustCompile(`^[a-zA-Z_\$][a-zA-Z0-9_\$]*$`) var callbackJSONP = regexp.MustCompile(`^[a-zA-Z_\$][a-zA-Z0-9_\$]*$`)
// mkjson auto indents json, and wraps json into a jsonp callback if specified by the request. // mkjson auto indents json, and wraps json into a jsonp callback if specified by the request.
// then writes to the gin.Context the data. // then writes to the RequestCtx the data.
func mkjson(c *gin.Context, data interface{}) { func mkjson(c *fasthttp.RequestCtx, data interface{}) {
exported, err := json.MarshalIndent(data, "", "\t") exported, err := json.MarshalIndent(data, "", "\t")
if err != nil { if err != nil {
c.Error(err) fmt.Println(err)
exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong." }`) exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong." }`)
} }
cb := c.Query("callback") cb := string(c.URI().QueryArgs().Peek("callback"))
willcb := cb != "" && willcb := cb != "" &&
len(cb) < 100 && len(cb) < 100 &&
callbackJSONP.MatchString(cb) callbackJSONP.MatchString(cb)
if willcb { if willcb {
c.Writer.Write([]byte("/**/ typeof " + cb + " === 'function' && " + cb + "(")) c.Write([]byte("/**/ typeof " + cb + " === 'function' && " + cb + "("))
} }
c.Writer.Write(exported) c.Write(exported)
if willcb { if willcb {
c.Writer.Write([]byte(");")) c.Write([]byte(");"))
} }
} }
// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
// Note it may break if string and/or slice header will change
// in the future go versions.
func b2s(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}

View File

@ -4,32 +4,32 @@ import (
"strconv" "strconv"
"strings" "strings"
"zxq.co/ripple/rippleapi/common"
"github.com/gin-gonic/gin"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/thehowl/go-osuapi" "github.com/thehowl/go-osuapi"
"github.com/valyala/fasthttp"
"zxq.co/ripple/rippleapi/common"
) )
// GetBeatmap retrieves general beatmap information. // GetBeatmap retrieves general beatmap information.
func GetBeatmap(c *gin.Context, db *sqlx.DB) { func GetBeatmap(c *fasthttp.RequestCtx, db *sqlx.DB) {
var whereClauses []string var whereClauses []string
var params []interface{} var params []interface{}
limit := strconv.Itoa(common.InString(1, c.Query("limit"), 500, 500)) limit := strconv.Itoa(common.InString(1, query(c, "limit"), 500, 500))
// since value is not stored, silently ignore // since value is not stored, silently ignore
if c.Query("s") != "" { if query(c, "s") != "" {
whereClauses = append(whereClauses, "beatmaps.beatmapset_id = ?") whereClauses = append(whereClauses, "beatmaps.beatmapset_id = ?")
params = append(params, c.Query("s")) params = append(params, query(c, "s"))
} }
if c.Query("b") != "" { if query(c, "b") != "" {
whereClauses = append(whereClauses, "beatmaps.beatmap_id = ?") whereClauses = append(whereClauses, "beatmaps.beatmap_id = ?")
params = append(params, c.Query("b")) params = append(params, query(c, "b"))
// b is unique, so change limit to 1 // b is unique, so change limit to 1
limit = "1" limit = "1"
} }
// creator is not stored, silently ignore u and type // creator is not stored, silently ignore u and type
if c.Query("m") != "" { if query(c, "m") != "" {
m := genmode(c.Query("m")) m := genmode(query(c, "m"))
if m == "std" { if m == "std" {
// Since STD beatmaps are converted, all of the diffs must be != 0 // Since STD beatmaps are converted, all of the diffs must be != 0
for _, i := range modes { for _, i := range modes {
@ -37,14 +37,14 @@ func GetBeatmap(c *gin.Context, db *sqlx.DB) {
} }
} else { } else {
whereClauses = append(whereClauses, "beatmaps.difficulty_"+m+" != 0") whereClauses = append(whereClauses, "beatmaps.difficulty_"+m+" != 0")
if c.Query("a") == "1" { if query(c, "a") == "1" {
whereClauses = append(whereClauses, "beatmaps.difficulty_std = 0") whereClauses = append(whereClauses, "beatmaps.difficulty_std = 0")
} }
} }
} }
if c.Query("h") != "" { if query(c, "h") != "" {
whereClauses = append(whereClauses, "beatmaps.beatmap_md5 = ?") whereClauses = append(whereClauses, "beatmaps.beatmap_md5 = ?")
params = append(params, c.Query("h")) params = append(params, query(c, "h"))
} }
where := strings.Join(whereClauses, " AND ") where := strings.Join(whereClauses, " AND ")
@ -61,8 +61,8 @@ func GetBeatmap(c *gin.Context, db *sqlx.DB) {
FROM beatmaps `+where+" ORDER BY id DESC LIMIT "+limit, FROM beatmaps `+where+" ORDER BY id DESC LIMIT "+limit,
params...) params...)
if err != nil { if err != nil {
c.Error(err) common.Err(c, err)
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
return return
} }
@ -82,7 +82,7 @@ FROM beatmaps `+where+" ORDER BY id DESC LIMIT "+limit,
&rawLastUpdate, &rawLastUpdate,
) )
if err != nil { if err != nil {
c.Error(err) common.Err(c, err)
continue continue
} }
bm.TotalLength = bm.HitLength bm.TotalLength = bm.HitLength
@ -103,7 +103,7 @@ FROM beatmaps `+where+" ORDER BY id DESC LIMIT "+limit,
bms = append(bms, bm) bms = append(bms, bm)
} }
c.JSON(200, bms) json(c, 200, bms)
} }
var rippleToOsuRankedStatus = map[int]osuapi.ApprovedStatus{ var rippleToOsuRankedStatus = map[int]osuapi.ApprovedStatus{

View File

@ -2,12 +2,12 @@ package peppy
import ( import (
"database/sql" "database/sql"
_json "encoding/json"
"strconv" "strconv"
"zxq.co/ripple/rippleapi/common"
"github.com/gin-gonic/gin"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/valyala/fasthttp"
"zxq.co/ripple/rippleapi/common"
) )
var modes = []string{"std", "taiko", "ctb", "mania"} var modes = []string{"std", "taiko", "ctb", "mania"}
@ -30,25 +30,25 @@ func rankable(m string) bool {
return x == 0 || x == 3 return x == 0 || x == 3
} }
func genUser(c *gin.Context, db *sqlx.DB) (string, string) { func genUser(c *fasthttp.RequestCtx, db *sqlx.DB) (string, string) {
var whereClause string var whereClause string
var p string var p string
// used in second case of switch // used in second case of switch
s, err := strconv.Atoi(c.Query("u")) s, err := strconv.Atoi(query(c, "u"))
switch { switch {
// We know for sure that it's an username. // We know for sure that it's an username.
case c.Query("type") == "string": case query(c, "type") == "string":
whereClause = "users.username_safe = ?" whereClause = "users.username_safe = ?"
p = common.SafeUsername(c.Query("u")) p = common.SafeUsername(query(c, "u"))
// It could be an user ID, so we look for an user with that username first. // It could be an user ID, so we look for an user with that username first.
case err == nil: case err == nil:
err = db.QueryRow("SELECT id FROM users WHERE id = ? LIMIT 1", s).Scan(&p) err = db.QueryRow("SELECT id FROM users WHERE id = ? LIMIT 1", s).Scan(&p)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// If no user with that userID were found, assume username. // If no user with that userID were found, assume username.
whereClause = "users.username_safe = ?" whereClause = "users.username_safe = ?"
p = common.SafeUsername(c.Query("u")) p = common.SafeUsername(query(c, "u"))
} else { } else {
// An user with that userID was found. Thus it's an userID. // An user with that userID was found. Thus it's an userID.
whereClause = "users.id = ?" whereClause = "users.id = ?"
@ -56,7 +56,20 @@ func genUser(c *gin.Context, db *sqlx.DB) (string, string) {
// u contains letters, so it's an username. // u contains letters, so it's an username.
default: default:
whereClause = "users.username_safe = ?" whereClause = "users.username_safe = ?"
p = common.SafeUsername(c.Query("u")) p = common.SafeUsername(query(c, "u"))
} }
return whereClause, p return whereClause, p
} }
func query(c *fasthttp.RequestCtx, s string) string {
return string(c.QueryArgs().Peek(s))
}
func json(c *fasthttp.RequestCtx, code int, data interface{}) {
c.SetStatusCode(code)
d, err := _json.Marshal(data)
if err != nil {
panic(err)
}
c.Write(d)
}

View File

@ -2,11 +2,11 @@
package peppy package peppy
import ( import (
"github.com/gin-gonic/gin"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/valyala/fasthttp"
) )
// GetMatch retrieves general match information. // GetMatch retrieves general match information.
func GetMatch(c *gin.Context, db *sqlx.DB) { func GetMatch(c *fasthttp.RequestCtx, db *sqlx.DB) {
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
} }

View File

@ -7,43 +7,43 @@ import (
"zxq.co/ripple/rippleapi/common" "zxq.co/ripple/rippleapi/common"
"zxq.co/x/getrank"
"github.com/gin-gonic/gin"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/valyala/fasthttp"
"gopkg.in/thehowl/go-osuapi.v1" "gopkg.in/thehowl/go-osuapi.v1"
"zxq.co/x/getrank"
) )
// GetScores retrieve information about the top 100 scores of a specified beatmap. // GetScores retrieve information about the top 100 scores of a specified beatmap.
func GetScores(c *gin.Context, db *sqlx.DB) { func GetScores(c *fasthttp.RequestCtx, db *sqlx.DB) {
if c.Query("b") == "" { if query(c, "b") == "" {
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
return return
} }
var beatmapMD5 string var beatmapMD5 string
err := db.Get(&beatmapMD5, "SELECT beatmap_md5 FROM beatmaps WHERE beatmap_id = ? LIMIT 1", c.Query("b")) err := db.Get(&beatmapMD5, "SELECT beatmap_md5 FROM beatmaps WHERE beatmap_id = ? LIMIT 1", query(c, "b"))
switch { switch {
case err == sql.ErrNoRows: case err == sql.ErrNoRows:
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
return return
case err != nil: case err != nil:
c.Error(err) common.Err(c, err)
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
return return
} }
var sb = "scores.score" var sb = "scores.score"
if rankable(c.Query("m")) { if rankable(query(c, "m")) {
sb = "scores.pp" sb = "scores.pp"
} }
var ( var (
extraWhere string extraWhere string
extraParams []interface{} extraParams []interface{}
) )
if c.Query("u") != "" { if query(c, "u") != "" {
w, p := genUser(c, db) w, p := genUser(c, db)
extraWhere = "AND " + w extraWhere = "AND " + w
extraParams = append(extraParams, p) extraParams = append(extraParams, p)
} }
mods := common.Int(c.Query("mods")) mods := common.Int(query(c, "mods"))
rows, err := db.Query(` rows, err := db.Query(`
SELECT SELECT
scores.id, scores.score, users.username, scores.300_count, scores.100_count, scores.id, scores.score, users.username, scores.300_count, scores.100_count,
@ -58,11 +58,11 @@ WHERE scores.completed = '3'
AND scores.play_mode = ? AND scores.play_mode = ?
AND scores.mods & ? = ? AND scores.mods & ? = ?
`+extraWhere+` `+extraWhere+`
ORDER BY `+sb+` DESC LIMIT `+strconv.Itoa(common.InString(1, c.Query("limit"), 100, 50)), ORDER BY `+sb+` DESC LIMIT `+strconv.Itoa(common.InString(1, query(c, "limit"), 100, 50)),
append([]interface{}{beatmapMD5, genmodei(c.Query("m")), mods, mods}, extraParams...)...) append([]interface{}{beatmapMD5, genmodei(query(c, "m")), mods, mods}, extraParams...)...)
if err != nil { if err != nil {
c.Error(err) common.Err(c, err)
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
return return
} }
var results []osuapi.GSScore var results []osuapi.GSScore
@ -82,17 +82,17 @@ ORDER BY `+sb+` DESC LIMIT `+strconv.Itoa(common.InString(1, c.Query("limit"), 1
) )
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
c.Error(err) common.Err(c, err)
} }
continue continue
} }
s.FullCombo = osuapi.OsuBool(fullcombo) s.FullCombo = osuapi.OsuBool(fullcombo)
s.Mods = osuapi.Mods(mods) s.Mods = osuapi.Mods(mods)
s.Date = osuapi.MySQLDate(date) s.Date = osuapi.MySQLDate(date)
s.Rank = strings.ToUpper(getrank.GetRank(osuapi.Mode(genmodei(c.Query("m"))), s.Mods, s.Rank = strings.ToUpper(getrank.GetRank(osuapi.Mode(genmodei(query(c, "m"))), s.Mods,
accuracy, s.Count300, s.Count100, s.Count50, s.CountMiss)) accuracy, s.Count300, s.Count100, s.Count50, s.CountMiss))
results = append(results, s) results = append(results, s)
} }
c.JSON(200, results) json(c, 200, results)
return return
} }

View File

@ -5,23 +5,24 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"zxq.co/ripple/ocl"
"github.com/gin-gonic/gin"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/thehowl/go-osuapi" "github.com/thehowl/go-osuapi"
"github.com/valyala/fasthttp"
"zxq.co/ripple/ocl"
"zxq.co/ripple/rippleapi/common"
) )
// GetUser retrieves general user information. // GetUser retrieves general user information.
func GetUser(c *gin.Context, db *sqlx.DB) { func GetUser(c *fasthttp.RequestCtx, db *sqlx.DB) {
if c.Query("u") == "" { if query(c, "u") == "" {
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
return return
} }
var user osuapi.User var user osuapi.User
whereClause, p := genUser(c, db) whereClause, p := genUser(c, db)
whereClause = "WHERE " + whereClause whereClause = "WHERE " + whereClause
mode := genmode(c.Query("m")) mode := genmode(query(c, "m"))
var lbpos *int var lbpos *int
err := db.QueryRow(fmt.Sprintf( err := db.QueryRow(fmt.Sprintf(
@ -43,9 +44,9 @@ func GetUser(c *gin.Context, db *sqlx.DB) {
&user.Country, &user.Country,
) )
if err != nil { if err != nil {
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
c.Error(err) common.Err(c, err)
} }
return return
} }
@ -54,5 +55,5 @@ func GetUser(c *gin.Context, db *sqlx.DB) {
} }
user.Level = ocl.GetLevelPrecise(user.TotalScore) user.Level = ocl.GetLevelPrecise(user.TotalScore)
c.JSON(200, []osuapi.User{user}) json(c, 200, []osuapi.User{user})
} }

View File

@ -4,32 +4,32 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/jmoiron/sqlx"
"github.com/valyala/fasthttp"
"gopkg.in/thehowl/go-osuapi.v1"
"zxq.co/ripple/rippleapi/common" "zxq.co/ripple/rippleapi/common"
"zxq.co/x/getrank" "zxq.co/x/getrank"
"github.com/gin-gonic/gin"
"github.com/jmoiron/sqlx"
"gopkg.in/thehowl/go-osuapi.v1"
) )
// GetUserRecent retrieves an user's recent scores. // GetUserRecent retrieves an user's recent scores.
func GetUserRecent(c *gin.Context, db *sqlx.DB) { func GetUserRecent(c *fasthttp.RequestCtx, db *sqlx.DB) {
getUserX(c, db, "ORDER BY scores.time DESC", common.InString(1, c.Query("limit"), 50, 10)) getUserX(c, db, "ORDER BY scores.time DESC", common.InString(1, query(c, "limit"), 50, 10))
} }
// GetUserBest retrieves an user's best scores. // GetUserBest retrieves an user's best scores.
func GetUserBest(c *gin.Context, db *sqlx.DB) { func GetUserBest(c *fasthttp.RequestCtx, db *sqlx.DB) {
var sb string var sb string
if rankable(c.Query("m")) { if rankable(query(c, "m")) {
sb = "scores.pp" sb = "scores.pp"
} else { } else {
sb = "scores.score" sb = "scores.score"
} }
getUserX(c, db, "AND completed = '3' ORDER BY "+sb+" DESC", common.InString(1, c.Query("limit"), 100, 10)) getUserX(c, db, "AND completed = '3' ORDER BY "+sb+" DESC", common.InString(1, query(c, "limit"), 100, 10))
} }
func getUserX(c *gin.Context, db *sqlx.DB, orderBy string, limit int) { func getUserX(c *fasthttp.RequestCtx, db *sqlx.DB, orderBy string, limit int) {
whereClause, p := genUser(c, db) whereClause, p := genUser(c, db)
query := fmt.Sprintf( sqlQuery := fmt.Sprintf(
`SELECT `SELECT
beatmaps.beatmap_id, scores.score, scores.max_combo, beatmaps.beatmap_id, scores.score, scores.max_combo,
scores.300_count, scores.100_count, scores.50_count, scores.300_count, scores.100_count, scores.50_count,
@ -44,11 +44,11 @@ func getUserX(c *gin.Context, db *sqlx.DB, orderBy string, limit int) {
LIMIT %d`, whereClause, orderBy, limit, LIMIT %d`, whereClause, orderBy, limit,
) )
scores := make([]osuapi.GUSScore, 0, limit) scores := make([]osuapi.GUSScore, 0, limit)
m := genmodei(c.Query("m")) m := genmodei(query(c, "m"))
rows, err := db.Query(query, p, m) rows, err := db.Query(sqlQuery, p, m)
if err != nil { if err != nil {
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
c.Error(err) common.Err(c, err)
return return
} }
for rows.Next() { for rows.Next() {
@ -68,8 +68,8 @@ func getUserX(c *gin.Context, db *sqlx.DB, orderBy string, limit int) {
&curscore.PP, &acc, &curscore.PP, &acc,
) )
if err != nil { if err != nil {
c.JSON(200, defaultResponse) json(c, 200, defaultResponse)
c.Error(err) common.Err(c, err)
return return
} }
if bid == nil { if bid == nil {
@ -91,5 +91,5 @@ func getUserX(c *gin.Context, db *sqlx.DB, orderBy string, limit int) {
)) ))
scores = append(scores, curscore) scores = append(scores, curscore)
} }
c.JSON(200, scores) json(c, 200, scores)
} }

View File

@ -1,16 +1,13 @@
package app package app
import ( import (
"github.com/gin-gonic/gin"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/valyala/fasthttp"
) )
// PeppyMethod generates a method for the peppyapi // PeppyMethod generates a method for the peppyapi
func PeppyMethod(a func(c *gin.Context, db *sqlx.DB)) gin.HandlerFunc { func PeppyMethod(a func(c *fasthttp.RequestCtx, db *sqlx.DB)) fasthttp.RequestHandler {
return func(c *gin.Context) { return func(c *fasthttp.RequestCtx) {
rateLimiter()
perUserRequestLimiter(0, c.ClientIP())
doggo.Incr("requests.peppy", nil, 1) doggo.Incr("requests.peppy", nil, 1)
// I have no idea how, but I manged to accidentally string the first 4 // I have no idea how, but I manged to accidentally string the first 4

View File

@ -1,36 +0,0 @@
package app
import (
"strconv"
"time"
"zxq.co/ripple/rippleapi/limit"
)
const reqsPerSecond = 5000
const sleepTime = time.Second / reqsPerSecond
var limiter = make(chan struct{}, reqsPerSecond)
func setUpLimiter() {
for i := 0; i < reqsPerSecond; i++ {
limiter <- struct{}{}
}
go func() {
for {
limiter <- struct{}{}
time.Sleep(sleepTime)
}
}()
}
func rateLimiter() {
<-limiter
}
func perUserRequestLimiter(uid int, ip string) {
if uid == 0 {
limit.Request("ip:"+ip, 200)
} else {
limit.Request("user:"+strconv.Itoa(uid), 3000)
}
}

View File

@ -4,9 +4,8 @@ import (
"fmt" "fmt"
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
fhr "github.com/buaazp/fasthttprouter"
"github.com/getsentry/raven-go" "github.com/getsentry/raven-go"
"github.com/gin-gonic/contrib/gzip"
"github.com/gin-gonic/gin"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/serenize/snaker" "github.com/serenize/snaker"
"gopkg.in/redis.v5" "gopkg.in/redis.v5"
@ -29,7 +28,7 @@ var commonClusterfucks = map[string]string{
} }
// Start begins taking HTTP connections. // Start begins taking HTTP connections.
func Start(conf common.Conf, dbO *sqlx.DB) *gin.Engine { func Start(conf common.Conf, dbO *sqlx.DB) *fhr.Router {
db = dbO db = dbO
cf = conf cf = conf
@ -40,10 +39,10 @@ func Start(conf common.Conf, dbO *sqlx.DB) *gin.Engine {
return snaker.CamelToSnake(s) return snaker.CamelToSnake(s)
}) })
setUpLimiter() r := fhr.New()
// TODO: add back gzip
r := gin.Default() // TODO: add logging
r.Use(gzip.Gzip(gzip.DefaultCompression)) // TODO: add sentry panic recovering
// sentry // sentry
if conf.SentryDSN != "" { if conf.SentryDSN != "" {
@ -52,7 +51,8 @@ func Start(conf common.Conf, dbO *sqlx.DB) *gin.Engine {
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
r.Use(Recovery(ravenClient, false)) // r.Use(Recovery(ravenClient, false))
common.RavenClient = ravenClient
} }
} }
@ -63,9 +63,9 @@ func Start(conf common.Conf, dbO *sqlx.DB) *gin.Engine {
fmt.Println(err) fmt.Println(err)
} }
doggo.Namespace = "api." doggo.Namespace = "api."
r.Use(func(c *gin.Context) { // r.Use(func(c *gin.Context) {
doggo.Incr("requests", nil, 1) // doggo.Incr("requests", nil, 1)
}) // })
// redis // redis
red = redis.NewClient(&redis.Options{ red = redis.NewClient(&redis.Options{
@ -77,94 +77,93 @@ func Start(conf common.Conf, dbO *sqlx.DB) *gin.Engine {
// token updater // token updater
go tokenUpdater(db) go tokenUpdater(db)
api := r.Group("/api") // peppyapi
{ {
p := api.Group("/") r.GET("/api/get_user", PeppyMethod(peppy.GetUser))
{ r.GET("/api/get_match", PeppyMethod(peppy.GetMatch))
p.GET("/get_user", PeppyMethod(peppy.GetUser)) r.GET("/api/get_user_recent", PeppyMethod(peppy.GetUserRecent))
p.GET("/get_match", PeppyMethod(peppy.GetMatch)) r.GET("/api/get_user_best", PeppyMethod(peppy.GetUserBest))
p.GET("/get_user_recent", PeppyMethod(peppy.GetUserRecent)) r.GET("/api/get_scores", PeppyMethod(peppy.GetScores))
p.GET("/get_user_best", PeppyMethod(peppy.GetUserBest)) r.GET("/api/get_beatmaps", PeppyMethod(peppy.GetBeatmap))
p.GET("/get_scores", PeppyMethod(peppy.GetScores))
p.GET("/get_beatmaps", PeppyMethod(peppy.GetBeatmap))
} }
gv1 := api.Group("/v1") // v1 API
{ {
gv1.POST("/tokens", Method(v1.TokenNewPOST)) r.POST("/api/v1/tokens", Method(v1.TokenNewPOST))
gv1.POST("/tokens/new", Method(v1.TokenNewPOST)) r.POST("/api/v1/tokens/new", Method(v1.TokenNewPOST))
gv1.POST("/tokens/self/delete", Method(v1.TokenSelfDeletePOST)) r.POST("/api/v1/tokens/self/delete", Method(v1.TokenSelfDeletePOST))
// Auth-free API endpoints (public data) // Auth-free API endpoints (public data)
gv1.GET("/ping", Method(v1.PingGET)) r.GET("/api/v1/ping", Method(v1.PingGET))
gv1.GET("/surprise_me", Method(v1.SurpriseMeGET)) r.GET("/api/v1/surprise_me", Method(v1.SurpriseMeGET))
gv1.GET("/doc", Method(v1.DocGET)) r.GET("/api/v1/doc", Method(v1.DocGET))
gv1.GET("/doc/content", Method(v1.DocContentGET)) r.GET("/api/v1/doc/content", Method(v1.DocContentGET))
gv1.GET("/doc/rules", Method(v1.DocRulesGET)) r.GET("/api/v1/doc/rules", Method(v1.DocRulesGET))
gv1.GET("/users", Method(v1.UsersGET)) r.GET("/api/v1/users", Method(v1.UsersGET))
gv1.GET("/users/whatid", Method(v1.UserWhatsTheIDGET)) r.GET("/api/v1/users/whatid", Method(v1.UserWhatsTheIDGET))
gv1.GET("/users/full", Method(v1.UserFullGET)) r.GET("/api/v1/users/full", Method(v1.UserFullGET))
gv1.GET("/users/userpage", Method(v1.UserUserpageGET)) r.GET("/api/v1/users/userpage", Method(v1.UserUserpageGET))
gv1.GET("/users/lookup", Method(v1.UserLookupGET)) r.GET("/api/v1/users/lookup", Method(v1.UserLookupGET))
gv1.GET("/users/scores/best", Method(v1.UserScoresBestGET)) r.GET("/api/v1/users/scores/best", Method(v1.UserScoresBestGET))
gv1.GET("/users/scores/recent", Method(v1.UserScoresRecentGET)) r.GET("/api/v1/users/scores/recent", Method(v1.UserScoresRecentGET))
gv1.GET("/badges", Method(v1.BadgesGET)) r.GET("/api/v1/badges", Method(v1.BadgesGET))
gv1.GET("/beatmaps", Method(v1.BeatmapGET)) r.GET("/api/v1/beatmaps", Method(v1.BeatmapGET))
gv1.GET("/leaderboard", Method(v1.LeaderboardGET)) r.GET("/api/v1/leaderboard", Method(v1.LeaderboardGET))
gv1.GET("/tokens", Method(v1.TokenGET)) r.GET("/api/v1/tokens", Method(v1.TokenGET))
gv1.GET("/users/self", Method(v1.UserSelfGET)) r.GET("/api/v1/users/self", Method(v1.UserSelfGET))
gv1.GET("/tokens/self", Method(v1.TokenSelfGET)) r.GET("/api/v1/tokens/self", Method(v1.TokenSelfGET))
gv1.GET("/blog/posts", Method(v1.BlogPostsGET)) r.GET("/api/v1/blog/posts", Method(v1.BlogPostsGET))
gv1.GET("/scores", Method(v1.ScoresGET)) r.GET("/api/v1/scores", Method(v1.ScoresGET))
gv1.GET("/beatmaps/rank_requests/status", Method(v1.BeatmapRankRequestsStatusGET)) r.GET("/api/v1/beatmaps/rank_requests/status", Method(v1.BeatmapRankRequestsStatusGET))
// ReadConfidential privilege required // ReadConfidential privilege required
gv1.GET("/friends", Method(v1.FriendsGET, common.PrivilegeReadConfidential)) r.GET("/api/v1/friends", Method(v1.FriendsGET, common.PrivilegeReadConfidential))
gv1.GET("/friends/with", Method(v1.FriendsWithGET, common.PrivilegeReadConfidential)) r.GET("/api/v1/friends/with", Method(v1.FriendsWithGET, common.PrivilegeReadConfidential))
gv1.GET("/users/self/donor_info", Method(v1.UsersSelfDonorInfoGET, common.PrivilegeReadConfidential)) r.GET("/api/v1/users/self/donor_info", Method(v1.UsersSelfDonorInfoGET, common.PrivilegeReadConfidential))
gv1.GET("/users/self/favourite_mode", Method(v1.UsersSelfFavouriteModeGET, common.PrivilegeReadConfidential)) r.GET("/api/v1/users/self/favourite_mode", Method(v1.UsersSelfFavouriteModeGET, common.PrivilegeReadConfidential))
gv1.GET("/users/self/settings", Method(v1.UsersSelfSettingsGET, common.PrivilegeReadConfidential)) r.GET("/api/v1/users/self/settings", Method(v1.UsersSelfSettingsGET, common.PrivilegeReadConfidential))
// Write privilege required // Write privilege required
gv1.POST("/friends/add", Method(v1.FriendsAddPOST, common.PrivilegeWrite)) r.POST("/api/v1/friends/add", Method(v1.FriendsAddPOST, common.PrivilegeWrite))
gv1.POST("/friends/del", Method(v1.FriendsDelPOST, common.PrivilegeWrite)) r.POST("/api/v1/friends/del", Method(v1.FriendsDelPOST, common.PrivilegeWrite))
gv1.POST("/users/self/settings", Method(v1.UsersSelfSettingsPOST, common.PrivilegeWrite)) r.POST("/api/v1/users/self/settings", Method(v1.UsersSelfSettingsPOST, common.PrivilegeWrite))
gv1.POST("/users/self/userpage", Method(v1.UserSelfUserpagePOST, common.PrivilegeWrite)) r.POST("/api/v1/users/self/userpage", Method(v1.UserSelfUserpagePOST, common.PrivilegeWrite))
gv1.POST("/beatmaps/rank_requests", Method(v1.BeatmapRankRequestsSubmitPOST, common.PrivilegeWrite)) r.POST("/api/v1/beatmaps/rank_requests", Method(v1.BeatmapRankRequestsSubmitPOST, common.PrivilegeWrite))
// Admin: beatmap // Admin: beatmap
gv1.POST("/beatmaps/set_status", Method(v1.BeatmapSetStatusPOST, common.PrivilegeBeatmap)) r.POST("/api/v1/beatmaps/set_status", Method(v1.BeatmapSetStatusPOST, common.PrivilegeBeatmap))
gv1.GET("/beatmaps/ranked_frozen_full", Method(v1.BeatmapRankedFrozenFullGET, common.PrivilegeBeatmap)) r.GET("/api/v1/beatmaps/ranked_frozen_full", Method(v1.BeatmapRankedFrozenFullGET, common.PrivilegeBeatmap))
// Admin: user managing // Admin: user managing
gv1.POST("/users/manage/set_allowed", Method(v1.UserManageSetAllowedPOST, common.PrivilegeManageUser)) r.POST("/api/v1/users/manage/set_allowed", Method(v1.UserManageSetAllowedPOST, common.PrivilegeManageUser))
// M E T A // M E T A
// E T "wow thats so meta" // E T "wow thats so meta"
// T E -- the one who said "wow thats so meta" // T E -- the one who said "wow thats so meta"
// A T E M // A T E M
gv1.GET("/meta/restart", Method(v1.MetaRestartGET, common.PrivilegeAPIMeta)) r.GET("/api/v1/meta/restart", Method(v1.MetaRestartGET, common.PrivilegeAPIMeta))
gv1.GET("/meta/kill", Method(v1.MetaKillGET, common.PrivilegeAPIMeta)) r.GET("/api/v1/meta/kill", Method(v1.MetaKillGET, common.PrivilegeAPIMeta))
gv1.GET("/meta/up_since", Method(v1.MetaUpSinceGET, common.PrivilegeAPIMeta)) r.GET("/api/v1/meta/up_since", Method(v1.MetaUpSinceGET, common.PrivilegeAPIMeta))
gv1.GET("/meta/update", Method(v1.MetaUpdateGET, common.PrivilegeAPIMeta)) r.GET("/api/v1/meta/update", Method(v1.MetaUpdateGET, common.PrivilegeAPIMeta))
// User Managing + meta // User Managing + meta
gv1.POST("/tokens/fix_privileges", Method(v1.TokenFixPrivilegesPOST, r.POST("/api/v1/tokens/fix_privileges", Method(v1.TokenFixPrivilegesPOST,
common.PrivilegeManageUser, common.PrivilegeAPIMeta)) common.PrivilegeManageUser, common.PrivilegeAPIMeta))
}
// in the new osu-web, the old endpoints are also in /v1 it seems. So /shrug // in the new osu-web, the old endpoints are also in /v1 it seems. So /shrug
gv1.GET("/get_user", PeppyMethod(peppy.GetUser)) {
gv1.GET("/get_match", PeppyMethod(peppy.GetMatch)) r.GET("/api/v1/get_user", PeppyMethod(peppy.GetUser))
gv1.GET("/get_user_recent", PeppyMethod(peppy.GetUserRecent)) r.GET("/api/v1/get_match", PeppyMethod(peppy.GetMatch))
gv1.GET("/get_user_best", PeppyMethod(peppy.GetUserBest)) r.GET("/api/v1/get_user_recent", PeppyMethod(peppy.GetUserRecent))
gv1.GET("/get_scores", PeppyMethod(peppy.GetScores)) r.GET("/api/v1/get_user_best", PeppyMethod(peppy.GetUserBest))
gv1.GET("/get_beatmaps", PeppyMethod(peppy.GetBeatmap)) r.GET("/api/v1/get_scores", PeppyMethod(peppy.GetScores))
r.GET("/api/v1/get_beatmaps", PeppyMethod(peppy.GetBeatmap))
} }
api.GET("/status", internals.Status) r.GET("/api/status", internals.Status)
}
r.NoRoute(v1.Handle404) r.NotFound = v1.Handle404
return r return r
} }

View File

@ -1,8 +1,10 @@
package v1 package v1
import ( import (
"encoding/json"
"github.com/valyala/fasthttp"
"zxq.co/ripple/rippleapi/common" "zxq.co/ripple/rippleapi/common"
"github.com/gin-gonic/gin"
) )
type response404 struct { type response404 struct {
@ -11,12 +13,17 @@ type response404 struct {
} }
// Handle404 handles requests with no implemented handlers. // Handle404 handles requests with no implemented handlers.
func Handle404(c *gin.Context) { func Handle404(c *fasthttp.RequestCtx) {
c.Header("X-Real-404", "yes") c.Response.Header.Add("X-Real-404", "yes")
c.IndentedJSON(404, response404{ data, err := json.MarshalIndent(response404{
ResponseBase: common.ResponseBase{ ResponseBase: common.ResponseBase{
Code: 404, Code: 404,
}, },
Cats: surpriseMe(), Cats: surpriseMe(),
}) }, "", "\t")
if err != nil {
panic(err)
}
c.SetStatusCode(404)
c.Write(data)
} }

View File

@ -2,8 +2,6 @@ package v1
import ( import (
"database/sql" "database/sql"
"fmt"
"net/url"
"zxq.co/ripple/rippleapi/common" "zxq.co/ripple/rippleapi/common"
) )
@ -51,10 +49,10 @@ type beatmapSetStatusData struct {
// the beatmap ranked status is frozen. Or freezed. Freezed best meme 2k16 // the beatmap ranked status is frozen. Or freezed. Freezed best meme 2k16
func BeatmapSetStatusPOST(md common.MethodData) common.CodeMessager { func BeatmapSetStatusPOST(md common.MethodData) common.CodeMessager {
var req beatmapSetStatusData var req beatmapSetStatusData
md.RequestData.Unmarshal(&req) md.Unmarshal(&req)
var miss []string var miss []string
if req.BeatmapsetID == 0 && req.BeatmapID == 0 { if req.BeatmapsetID <= 0 && req.BeatmapID <= 0 {
miss = append(miss, "beatmapset_id or beatmap_id") miss = append(miss, "beatmapset_id or beatmap_id")
} }
if len(miss) != 0 { if len(miss) != 0 {
@ -84,32 +82,14 @@ func BeatmapSetStatusPOST(md common.MethodData) common.CodeMessager {
SET ranked = ?, ranked_status_freezed = ? SET ranked = ?, ranked_status_freezed = ?
WHERE beatmapset_id = ?`, req.RankedStatus, req.Frozen, param) WHERE beatmapset_id = ?`, req.RankedStatus, req.Frozen, param)
var x = make(map[string]interface{}, 1) if req.BeatmapID > 0 {
if req.BeatmapID != 0 { md.Ctx.Request.URI().QueryArgs().SetUint("bb", req.BeatmapID)
x["bb"] = req.BeatmapID
} else { } else {
x["s"] = req.BeatmapsetID md.Ctx.Request.URI().QueryArgs().SetUint("s", req.BeatmapsetID)
} }
md.C.Request.URL = genURL(x)
return getMultipleBeatmaps(md) return getMultipleBeatmaps(md)
} }
func genURL(d map[string]interface{}) *url.URL {
var s string
for k, v := range d {
if s != "" {
s += "&"
}
s += k + "=" + url.QueryEscape(fmt.Sprintf("%v", v))
}
u := new(url.URL)
if len(d) == 0 {
return u
}
u.RawQuery = s
return u
}
// BeatmapGET retrieves a beatmap. // BeatmapGET retrieves a beatmap.
func BeatmapGET(md common.MethodData) common.CodeMessager { func BeatmapGET(md common.MethodData) common.CodeMessager {
beatmapID := common.Int(md.Query("b")) beatmapID := common.Int(md.Query("b"))

View File

@ -78,7 +78,7 @@ type submitRequestData struct {
// BeatmapRankRequestsSubmitPOST submits a new beatmap for ranking approval. // BeatmapRankRequestsSubmitPOST submits a new beatmap for ranking approval.
func BeatmapRankRequestsSubmitPOST(md common.MethodData) common.CodeMessager { func BeatmapRankRequestsSubmitPOST(md common.MethodData) common.CodeMessager {
var d submitRequestData var d submitRequestData
err := md.RequestData.Unmarshal(&d) err := md.Unmarshal(&d)
if err != nil { if err != nil {
return ErrBadJSON return ErrBadJSON
} }
@ -91,9 +91,6 @@ func BeatmapRankRequestsSubmitPOST(md common.MethodData) common.CodeMessager {
if !limit.NonBlockingRequest("rankrequest:u:"+strconv.Itoa(md.ID()), 5) { if !limit.NonBlockingRequest("rankrequest:u:"+strconv.Itoa(md.ID()), 5) {
return common.SimpleResponse(429, "You may only try to request 5 beatmaps per minute.") return common.SimpleResponse(429, "You may only try to request 5 beatmaps per minute.")
} }
if !limit.NonBlockingRequest("rankrequest:ip:"+md.C.ClientIP(), 8) {
return common.SimpleResponse(429, "You may only try to request 8 beatmaps per minute from the same IP.")
}
// find out from BeatmapRankRequestsStatusGET if we can submit beatmaps. // find out from BeatmapRankRequestsStatusGET if we can submit beatmaps.
statusRaw := BeatmapRankRequestsStatusGET(md) statusRaw := BeatmapRankRequestsStatusGET(md)

View File

@ -134,7 +134,7 @@ func FriendsAddPOST(md common.MethodData) common.CodeMessager {
var u struct { var u struct {
User int `json:"user"` User int `json:"user"`
} }
md.RequestData.Unmarshal(&u) md.Unmarshal(&u)
return addFriend(md, u.User) return addFriend(md, u.User)
} }
@ -183,7 +183,7 @@ func FriendsDelPOST(md common.MethodData) common.CodeMessager {
var u struct { var u struct {
User int `json:"user"` User int `json:"user"`
} }
md.RequestData.Unmarshal(&u) md.Unmarshal(&u)
return delFriend(md, u.User) return delFriend(md, u.User)
} }

View File

@ -14,7 +14,7 @@ type setAllowedData struct {
// UserManageSetAllowedPOST allows to set the allowed status of an user. // UserManageSetAllowedPOST allows to set the allowed status of an user.
func UserManageSetAllowedPOST(md common.MethodData) common.CodeMessager { func UserManageSetAllowedPOST(md common.MethodData) common.CodeMessager {
data := setAllowedData{} data := setAllowedData{}
if err := md.RequestData.Unmarshal(&data); err != nil { if err := md.Unmarshal(&data); err != nil {
return ErrBadJSON return ErrBadJSON
} }
if data.Allowed < 0 || data.Allowed > 2 { if data.Allowed < 0 || data.Allowed > 2 {

View File

@ -62,7 +62,7 @@ type userSettingsData struct {
// UsersSelfSettingsPOST allows to modify information about the current user. // UsersSelfSettingsPOST allows to modify information about the current user.
func UsersSelfSettingsPOST(md common.MethodData) common.CodeMessager { func UsersSelfSettingsPOST(md common.MethodData) common.CodeMessager {
var d userSettingsData var d userSettingsData
md.RequestData.Unmarshal(&d) md.Unmarshal(&d)
// input sanitisation // input sanitisation
*d.UsernameAKA = common.SanitiseString(*d.UsernameAKA) *d.UsernameAKA = common.SanitiseString(*d.UsernameAKA)

View File

@ -8,10 +8,10 @@ import (
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"golang.org/x/crypto/bcrypt"
"zxq.co/ripple/rippleapi/common" "zxq.co/ripple/rippleapi/common"
"zxq.co/ripple/rippleapi/limit" "zxq.co/ripple/rippleapi/limit"
"zxq.co/ripple/schiavolib" "zxq.co/ripple/schiavolib"
"golang.org/x/crypto/bcrypt"
) )
type tokenNewInData struct { type tokenNewInData struct {
@ -37,7 +37,7 @@ type tokenNewResponse struct {
func TokenNewPOST(md common.MethodData) common.CodeMessager { func TokenNewPOST(md common.MethodData) common.CodeMessager {
var r tokenNewResponse var r tokenNewResponse
data := tokenNewInData{} data := tokenNewInData{}
err := md.RequestData.Unmarshal(&data) err := md.Unmarshal(&data)
if err != nil { if err != nil {
return ErrBadJSON return ErrBadJSON
} }
@ -80,7 +80,7 @@ func TokenNewPOST(md common.MethodData) common.CodeMessager {
} }
privileges := common.UserPrivileges(privilegesRaw) privileges := common.UserPrivileges(privilegesRaw)
if !limit.NonBlockingRequest(fmt.Sprintf("loginattempt:%d:%s", r.ID, md.C.ClientIP()), 5) { if !limit.NonBlockingRequest(fmt.Sprintf("loginattempt:%d:%s", r.ID, md.ClientIP()), 5) {
return common.SimpleResponse(429, "You've made too many login attempts. Try again later.") return common.SimpleResponse(429, "You've made too many login attempts. Try again later.")
} }

View File

@ -5,9 +5,9 @@ import (
"database/sql" "database/sql"
"strconv" "strconv"
"strings" "strings"
"unicode"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"zxq.co/ripple/ocl" "zxq.co/ripple/ocl"
"zxq.co/ripple/rippleapi/common" "zxq.co/ripple/rippleapi/common"
) )
@ -71,8 +71,7 @@ type userPutsMultiUserData struct {
} }
func userPutsMulti(md common.MethodData) common.CodeMessager { func userPutsMulti(md common.MethodData) common.CodeMessager {
q := md.C.Request.URL.Query() pm := md.Ctx.Request.URI().QueryArgs().PeekMulti
// query composition // query composition
wh := common. wh := common.
Where("users.username_safe = ?", common.SafeUsername(md.Query("nname"))). Where("users.username_safe = ?", common.SafeUsername(md.Query("nname"))).
@ -83,10 +82,10 @@ func userPutsMulti(md common.MethodData) common.CodeMessager {
Where("users_stats.country = ?", md.Query("country")). Where("users_stats.country = ?", md.Query("country")).
Where("users_stats.username_aka = ?", md.Query("name_aka")). Where("users_stats.username_aka = ?", md.Query("name_aka")).
Where("privileges_groups.name = ?", md.Query("privilege_group")). Where("privileges_groups.name = ?", md.Query("privilege_group")).
In("users.id", q["ids"]...). In("users.id", pm("ids")...).
In("users.username_safe", safeUsernameBulk(q["names"])...). In("users.username_safe", safeUsernameBulk(pm("names"))...).
In("users_stats.username_aka", q["names_aka"]...). In("users_stats.username_aka", pm("names_aka")...).
In("users_stats.country", q["countries"]...) In("users_stats.country", pm("countries")...)
var extraJoin string var extraJoin string
if md.Query("privilege_group") != "" { if md.Query("privilege_group") != "" {
@ -130,13 +129,19 @@ func userPutsMulti(md common.MethodData) common.CodeMessager {
// UserSelfGET is a shortcut for /users/id/self. (/users/self) // UserSelfGET is a shortcut for /users/id/self. (/users/self)
func UserSelfGET(md common.MethodData) common.CodeMessager { func UserSelfGET(md common.MethodData) common.CodeMessager {
md.C.Request.URL.RawQuery = "id=self&" + md.C.Request.URL.RawQuery md.Ctx.Request.URI().SetQueryString("id=self")
return UsersGET(md) return UsersGET(md)
} }
func safeUsernameBulk(us []string) []string { func safeUsernameBulk(us [][]byte) [][]byte {
for i, u := range us { for _, u := range us {
us[i] = common.SafeUsername(u) for idx, v := range u {
if v == ' ' {
u[idx] = '_'
continue
}
u[idx] = byte(unicode.ToLower(rune(v)))
}
} }
return us return us
} }
@ -341,7 +346,7 @@ func UserSelfUserpagePOST(md common.MethodData) common.CodeMessager {
var d struct { var d struct {
Data *string `json:"data"` Data *string `json:"data"`
} }
md.RequestData.Unmarshal(&d) md.Unmarshal(&d)
if d.Data == nil { if d.Data == nil {
return ErrMissingField("data") return ErrMissingField("data")
} }
@ -350,7 +355,7 @@ func UserSelfUserpagePOST(md common.MethodData) common.CodeMessager {
if err != nil { if err != nil {
md.Err(err) md.Err(err)
} }
md.C.Request.URL.RawQuery += "&id=" + strconv.Itoa(md.ID()) md.Ctx.URI().SetQueryString("id=self")
return UserUserpageGET(md) return UserUserpageGET(md)
} }

29
common/conversions.go Normal file
View File

@ -0,0 +1,29 @@
package common
import (
"reflect"
"unsafe"
)
// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
// Note it may break if string and/or slice header will change
// in the future go versions.
func b2s(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// s2b converts string to a byte slice without memory allocation.
//
// Note it may break if string and/or slice header will change
// in the future go versions.
func s2b(s string) []byte {
sh := (*reflect.StringHeader)(unsafe.Pointer(&s))
bh := reflect.SliceHeader{
Data: sh.Data,
Len: sh.Len,
Cap: sh.Len,
}
return *(*[]byte)(unsafe.Pointer(&bh))
}

View File

@ -2,26 +2,132 @@ package common
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strconv"
"strings"
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
"github.com/gin-gonic/gin" "github.com/getsentry/raven-go"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/valyala/fasthttp"
"gopkg.in/redis.v5" "gopkg.in/redis.v5"
) )
// RavenClient is the raven client to which report errors happening.
// If nil, errors will just be fmt.Println'd
var RavenClient *raven.Client
// MethodData is a struct containing the data passed over to an API method. // MethodData is a struct containing the data passed over to an API method.
type MethodData struct { type MethodData struct {
User Token User Token
DB *sqlx.DB DB *sqlx.DB
RequestData RequestData
C *gin.Context
Doggo *statsd.Client Doggo *statsd.Client
R *redis.Client R *redis.Client
Ctx *fasthttp.RequestCtx
} }
// Err logs an error into gin. // ClientIP implements a best effort algorithm to return the real client IP, it parses
// X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
func (md MethodData) ClientIP() string {
clientIP := strings.TrimSpace(string(md.Ctx.Request.Header.Peek("X-Real-Ip")))
if len(clientIP) > 0 {
return clientIP
}
clientIP = string(md.Ctx.Request.Header.Peek("X-Forwarded-For"))
if index := strings.IndexByte(clientIP, ','); index >= 0 {
clientIP = clientIP[0:index]
}
clientIP = strings.TrimSpace(clientIP)
if len(clientIP) > 0 {
return clientIP
}
return md.Ctx.RemoteIP().String()
}
// Err logs an error. If RavenClient is set, it will use the client to report
// the error to sentry, otherwise it will just write the error to stdout.
func (md MethodData) Err(err error) { func (md MethodData) Err(err error) {
md.C.Error(err) if RavenClient == nil {
fmt.Println("ERROR!!!!")
fmt.Println(err)
return
}
// Create stacktrace
st := raven.NewStacktrace(0, 3, []string{"zxq.co/ripple", "git.zxq.co/ripple"})
// Generate tags for error
tags := map[string]string{
"endpoint": b2s(md.Ctx.RequestURI()),
"token": md.User.Value,
}
RavenClient.CaptureError(
err,
tags,
st,
generateRavenHTTP(md.Ctx),
&raven.User{
ID: strconv.Itoa(md.User.UserID),
Username: md.User.Value,
IP: md.Ctx.RemoteAddr().String(),
},
)
}
// Err for peppy API calls
func Err(c *fasthttp.RequestCtx, err error) {
if RavenClient == nil {
fmt.Println("ERROR!!!!")
fmt.Println(err)
return
}
// Create stacktrace
st := raven.NewStacktrace(0, 3, []string{"zxq.co/ripple", "git.zxq.co/ripple"})
// Generate tags for error
tags := map[string]string{
"endpoint": b2s(c.RequestURI()),
}
RavenClient.CaptureError(
err,
tags,
st,
generateRavenHTTP(c),
)
}
func generateRavenHTTP(ctx *fasthttp.RequestCtx) *raven.Http {
// build uri
uri := ctx.URI()
// safe to use b2s because a new string gets allocated eventually for
// concatenation
sURI := b2s(uri.Scheme()) + "://" + b2s(uri.Host()) + b2s(uri.Path())
// build header map
// using ctx.Request.Header.Len would mean calling .VisitAll two times
// which can be quite expensive since it means iterating over all the
// headers, so we give a rough estimate of the number of headers we expect
// to have
m := make(map[string]string, 16)
ctx.Request.Header.VisitAll(func(k, v []byte) {
// not using b2s because we mustn't keep references to the underlying
// k and v
m[string(k)] = string(v)
})
return &raven.Http{
URL: sURI,
// Not using b2s because raven sending is concurrent and may happen
// AFTER the request, meaning that values could potentially be replaced
// by new ones.
Method: string(ctx.Method()),
Query: string(uri.QueryString()),
Cookies: string(ctx.Request.Header.Peek("Cookie")),
Headers: m,
}
} }
// ID retrieves the Token's owner user ID. // ID retrieves the Token's owner user ID.
@ -31,23 +137,16 @@ func (md MethodData) ID() int {
// Query is shorthand for md.C.Query. // Query is shorthand for md.C.Query.
func (md MethodData) Query(q string) string { func (md MethodData) Query(q string) string {
return md.C.Query(q) return b2s(md.Ctx.QueryArgs().Peek(q))
} }
// HasQuery returns true if the parameter is encountered in the querystring. // HasQuery returns true if the parameter is encountered in the querystring.
// It returns true even if the parameter is "" (the case of ?param&etc=etc) // It returns true even if the parameter is "" (the case of ?param&etc=etc)
func (md MethodData) HasQuery(q string) bool { func (md MethodData) HasQuery(q string) bool {
_, has := md.C.GetQuery(q) return md.Ctx.QueryArgs().Has(q)
return has
} }
// RequestData is the body of a request. It is wrapped into this type // Unmarshal unmarshals a request's JSON body into an interface.
// to implement the Unmarshal function, which is just a shorthand to func (md MethodData) Unmarshal(into interface{}) error {
// json.Unmarshal. return json.Unmarshal(md.Ctx.PostBody(), into)
type RequestData []byte
// Unmarshal json-decodes Requestdata into a value. Basically a
// shorthand to json.Unmarshal.
func (r RequestData) Unmarshal(into interface{}) error {
return json.Unmarshal([]byte(r), into)
} }

View File

@ -19,8 +19,8 @@ func Sort(md MethodData, config SortConfiguration) string {
config.Table += "." config.Table += "."
} }
var sortBy string var sortBy string
for _, s := range md.C.Request.URL.Query()["sort"] { for _, s := range md.Ctx.Request.URI().QueryArgs().PeekMulti("sort") {
sortParts := strings.Split(strings.ToLower(s), ",") sortParts := strings.Split(strings.ToLower(b2s(s)), ",")
if contains(config.Allowed, sortParts[0]) { if contains(config.Allowed, sortParts[0]) {
if sortBy != "" { if sortBy != "" {
sortBy += ", " sortBy += ", "

View File

@ -51,15 +51,15 @@ func (w *WhereClause) And() *WhereClause {
// initial is the initial part, e.g. "users.id". // initial is the initial part, e.g. "users.id".
// Fields are the possible values. // Fields are the possible values.
// Sample output: users.id IN ('1', '2', '3') // Sample output: users.id IN ('1', '2', '3')
func (w *WhereClause) In(initial string, fields ...string) *WhereClause { func (w *WhereClause) In(initial string, fields ...[]byte) *WhereClause {
if len(fields) == 0 { if len(fields) == 0 {
return w return w
} }
w.addWhere() w.addWhere()
w.Clause += initial + " IN (" + generateQuestionMarks(len(fields)) + ")" w.Clause += initial + " IN (" + generateQuestionMarks(len(fields)) + ")"
fieldsInterfaced := make([]interface{}, 0, len(fields)) fieldsInterfaced := make([]interface{}, len(fields))
for _, i := range fields { for k, f := range fields {
fieldsInterfaced = append(fieldsInterfaced, interface{}(i)) fieldsInterfaced[k] = string(f)
} }
w.Params = append(w.Params, fieldsInterfaced...) w.Params = append(w.Params, fieldsInterfaced...)
return w return w

View File

@ -62,5 +62,5 @@ func main() {
engine := app.Start(conf, db) engine := app.Start(conf, db)
startuato(engine) startuato(engine.Handler)
} }

View File

@ -3,19 +3,18 @@
package main package main
import ( import (
"fmt"
"log" "log"
"net" "net"
"net/http"
"fmt"
"time" "time"
"zxq.co/ripple/schiavolib"
"zxq.co/ripple/rippleapi/common"
"github.com/gin-gonic/gin"
"github.com/rcrowley/goagain" "github.com/rcrowley/goagain"
"github.com/valyala/fasthttp"
"zxq.co/ripple/rippleapi/common"
"zxq.co/ripple/schiavolib"
) )
func startuato(engine *gin.Engine) { func startuato(hn fasthttp.RequestHandler) {
conf, _ := common.Load() conf, _ := common.Load()
// Inherit a net.Listener from our parent process or listen anew. // Inherit a net.Listener from our parent process or listen anew.
l, err := goagain.Listener() l, err := goagain.Listener()
@ -35,13 +34,12 @@ func startuato(engine *gin.Engine) {
schiavo.Bunker.Send(fmt.Sprint("LISTENINGU STARTUATO ON ", l.Addr())) schiavo.Bunker.Send(fmt.Sprint("LISTENINGU STARTUATO ON ", l.Addr()))
// Accept connections in a new goroutine. // Accept connections in a new goroutine.
go http.Serve(l, engine) go fasthttp.Serve(l, hn)
} else { } else {
// Resume accepting connections in a new goroutine. // Resume accepting connections in a new goroutine.
schiavo.Bunker.Send(fmt.Sprint("LISTENINGU RESUMINGU ON ", l.Addr())) schiavo.Bunker.Send(fmt.Sprint("LISTENINGU RESUMINGU ON ", l.Addr()))
go http.Serve(l, engine) go fasthttp.Serve(l, hn)
// Kill the parent, now that the child has started successfully. // Kill the parent, now that the child has started successfully.
if err := goagain.Kill(); nil != err { if err := goagain.Kill(); nil != err {

View File

@ -1,23 +1,26 @@
// +build windows // +build windows
// The Ripple API on Windows is not officially supported and you're probably
// gonna swear a lot if you intend to use it on Windows. Caveat emptor.
package main package main
import ( import (
"net"
"log" "log"
"net/http" "net"
"github.com/gin-gonic/gin" "github.com/valyala/fasthttp"
"zxq.co/ripple/rippleapi/common" "zxq.co/ripple/rippleapi/common"
) )
func startuato(engine *gin.Engine) { func startuato(hn fasthttp.RequestHandler) {
conf, _ := common.Load() conf, _ := common.Load()
var ( var (
l net.Listener l net.Listener
err error err error
) )
// Listen on a TCP or a UNIX domain socket (TCP here).
// Listen on a TCP or a UNIX domain socket.
if conf.Unix { if conf.Unix {
l, err = net.Listen("unix", conf.ListenTo) l, err = net.Listen("unix", conf.ListenTo)
} else { } else {
@ -27,5 +30,5 @@ func startuato(engine *gin.Engine) {
log.Fatalln(err) log.Fatalln(err)
} }
http.Serve(l, engine) fasthttp.Serve(l, hn)
} }