diff --git a/app/v1/beatmap.go b/app/v1/beatmap.go index 0187829..60e101f 100644 --- a/app/v1/beatmap.go +++ b/app/v1/beatmap.go @@ -82,23 +82,16 @@ func BeatmapSetStatusPOST(md common.MethodData) common.CodeMessager { SET ranked = ?, ranked_status_freezed = ? WHERE beatmapset_id = ?`, req.RankedStatus, req.Frozen, param) - return getSet(md, param) + return getMultipleBeatmaps(md) } // BeatmapGET retrieves a beatmap. func BeatmapGET(md common.MethodData) common.CodeMessager { - if md.Query("s") == "" && md.Query("b") == "" { - return common.SimpleResponse(400, "Must pass either querystring param 'b' or 's'") - } - setID := common.Int(md.Query("s")) - if setID != 0 { - return getSet(md, setID) - } beatmapID := common.Int(md.Query("b")) if beatmapID != 0 { - return getBeatmap(md, beatmapID) + return getBeatmapSingle(md, beatmapID) } - return common.SimpleResponse(400, "Please pass either a valid beatmapset ID or a valid beatmap ID") + return getMultipleBeatmaps(md) } const baseBeatmapSelect = ` @@ -111,8 +104,33 @@ SELECT FROM beatmaps ` -func getSet(md common.MethodData, setID int) common.CodeMessager { - rows, err := md.DB.Query(baseBeatmapSelect+"WHERE beatmapset_id = ?", setID) +func getMultipleBeatmaps(md common.MethodData) common.CodeMessager { + sort := common.Sort(md, common.SortConfiguration{ + Allowed: []string{ + "beatmapset_id", + "beatmap_id", + "id", + "ar", + "od", + "difficulty_std", + "difficulty_taiko", + "difficulty_ctb", + "difficulty_mania", + "max_combo", + "latest_update", + "playcount", + "passcount", + }, + Default: "id DESC", + Table: "beatmaps", + }) + where := common.Where("beatmapsetid = ?", md.Query("s")). + Where("song_name = ?", md.Query("song_name")). + Where("ranked_status_freezed = ?", md.Query("ranked_status_frozen"), "0", "1") + + rows, err := md.DB.Query(baseBeatmapSelect+ + where.Clause+" "+sort+" "+ + common.Paginate(md.Query("p"), md.Query("l"), 50), where.Params...) if err != nil { md.Err(err) return Err500 @@ -138,7 +156,7 @@ func getSet(md common.MethodData, setID int) common.CodeMessager { return r } -func getBeatmap(md common.MethodData, beatmapID int) common.CodeMessager { +func getBeatmapSingle(md common.MethodData, beatmapID int) common.CodeMessager { var b beatmap err := md.DB.QueryRow(baseBeatmapSelect+"WHERE beatmap_id = ? LIMIT 1", beatmapID).Scan( &b.BeatmapID, &b.BeatmapsetID, &b.BeatmapMD5, diff --git a/common/where.go b/common/where.go new file mode 100644 index 0000000..96e2bc2 --- /dev/null +++ b/common/where.go @@ -0,0 +1,33 @@ +package common + +// WhereClause is a struct representing a where clause. +// This is made to easily create WHERE clauses from parameters passed from a request. +type WhereClause struct { + Clause string + Params []interface{} +} + +// Where adds a new WHERE clause to the WhereClause. +func (w *WhereClause) Where(clause, passedParam string, allowedValues ...string) *WhereClause { + if passedParam == "" { + return w + } + if len(allowedValues) != 0 && !contains(allowedValues, passedParam) { + return w + } + // checks passed, if string is empty add "WHERE" + if w.Clause == "" { + w.Clause += "WHERE " + } else { + w.Clause += " AND " + } + w.Clause += clause + w.Params = append(w.Params, passedParam) + return w +} + +// Where is the same as WhereClause.Where, but creates a new WhereClause. +func Where(clause, passedParam string, allowedValues ...string) *WhereClause { + w := new(WhereClause) + return w.Where(clause, passedParam, allowedValues...) +}