replace zxq.co/ripple/hanayo

This commit is contained in:
Alicia
2019-02-23 13:29:15 +00:00
commit c3d206c173
5871 changed files with 1353715 additions and 0 deletions

View File

@@ -0,0 +1,65 @@
package common
import (
"fmt"
"github.com/thehowl/conf"
)
// Version is the git hash of the application. Do not edit. This is
// automatically set using -ldflags during build time.
var Version string
// Conf is the configuration file data for the ripple API.
// Conf uses https://github.com/thehowl/conf
type Conf struct {
DatabaseType string `description:"At the moment, 'mysql' is the only supported database type."`
DSN string `description:"The Data Source Name for the database. More: https://github.com/go-sql-driver/mysql#dsn-data-source-name"`
ListenTo string `description:"The IP/Port combination from which to take connections, e.g. :8080"`
Unix bool `description:"Bool indicating whether ListenTo is a UNIX socket or an address."`
SentryDSN string `description:"thing for sentry whatever"`
HanayoKey string
BeatmapRequestsPerUser int
RankQueueSize int
OsuAPIKey string
RedisAddr string
RedisPassword string
RedisDB int
}
var cachedConf *Conf
// Load creates a new Conf, using the data in the file "api.conf".
func Load() (c Conf, halt bool) {
if cachedConf != nil {
c = *cachedConf
return
}
err := conf.Load(&c, "api.conf")
halt = err == conf.ErrNoFile
if halt {
conf.MustExport(Conf{
DatabaseType: "mysql",
DSN: "root@/ripple",
ListenTo: ":40001",
Unix: false,
HanayoKey: "Potato",
BeatmapRequestsPerUser: 2,
RankQueueSize: 25,
RedisAddr: "localhost:6379",
}, "api.conf")
fmt.Println("Please compile the configuration file (api.conf).")
}
cachedConf = &c
return
}
// GetConf returns the cachedConf.
func GetConf() *Conf {
if cachedConf == nil {
return nil
}
// so that the cachedConf cannot actually get modified
c := *cachedConf
return &c
}

View File

@@ -0,0 +1,29 @@
package common
import (
"reflect"
"unsafe"
)
// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
// Note it may break if string and/or slice header will change
// in the future go versions.
func b2s(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// s2b converts string to a byte slice without memory allocation.
//
// Note it may break if string and/or slice header will change
// in the future go versions.
func s2b(s string) []byte {
sh := (*reflect.StringHeader)(unsafe.Pointer(&s))
bh := reflect.SliceHeader{
Data: sh.Data,
Len: sh.Len,
Cap: sh.Len,
}
return *(*[]byte)(unsafe.Pointer(&bh))
}

View File

@@ -0,0 +1,8 @@
package common
// These are the flags an user can have. Mostly settings or things like whether
// the user has verified their email address.
const (
FlagEmailVerified = 1 << iota
FlagCountry2FA
)

View File

@@ -0,0 +1,25 @@
package common
import "strconv"
// In picks x if y < x, picks z if y > z, or if none of the previous
// conditions is satisfies, it simply picks y.
func In(x, y, z int) int {
switch {
case y < x:
return x
case y > z:
return z
}
return y
}
// InString takes y as a string, also allows for a default value should y be
// invalid as a number.
func InString(x int, y string, z, def int) int {
num, err := strconv.Atoi(y)
if err != nil {
return def
}
return In(x, num, z)
}

View File

@@ -0,0 +1,9 @@
package common
import "strconv"
// Int converts s to an int. If s in an invalid int, it defaults to 0.
func Int(s string) int {
r, _ := strconv.Atoi(s)
return r
}

View File

@@ -0,0 +1,166 @@
package common
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/DataDog/datadog-go/statsd"
"github.com/getsentry/raven-go"
"github.com/jmoiron/sqlx"
"github.com/valyala/fasthttp"
"gopkg.in/redis.v5"
)
// RavenClient is the raven client to which report errors happening.
// If nil, errors will just be fmt.Println'd
var RavenClient *raven.Client
// MethodData is a struct containing the data passed over to an API method.
type MethodData struct {
User Token
DB *sqlx.DB
Doggo *statsd.Client
R *redis.Client
Ctx *fasthttp.RequestCtx
}
// ClientIP implements a best effort algorithm to return the real client IP, it parses
// X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
func (md MethodData) ClientIP() string {
clientIP := strings.TrimSpace(string(md.Ctx.Request.Header.Peek("X-Real-Ip")))
if len(clientIP) > 0 {
return clientIP
}
clientIP = string(md.Ctx.Request.Header.Peek("X-Forwarded-For"))
if index := strings.IndexByte(clientIP, ','); index >= 0 {
clientIP = clientIP[0:index]
}
clientIP = strings.TrimSpace(clientIP)
if len(clientIP) > 0 {
return clientIP
}
return md.Ctx.RemoteIP().String()
}
// Err logs an error. If RavenClient is set, it will use the client to report
// the error to sentry, otherwise it will just write the error to stdout.
func (md MethodData) Err(err error) {
user := &raven.User{
ID: strconv.Itoa(md.User.UserID),
Username: md.User.Value,
IP: md.Ctx.RemoteAddr().String(),
}
// Generate tags for error
tags := map[string]string{
"endpoint": string(md.Ctx.RequestURI()),
"token": md.User.Value,
}
_err(err, tags, user, md.Ctx)
}
// Err for peppy API calls
func Err(c *fasthttp.RequestCtx, err error) {
// Generate tags for error
tags := map[string]string{
"endpoint": string(c.RequestURI()),
}
_err(err, tags, nil, c)
}
// WSErr is the error function for errors happening in the websockets.
func WSErr(err error) {
_err(err, map[string]string{
"endpoint": "/api/v1/ws",
}, nil, nil)
}
// GenericError is just an error. Can't make a good description.
func GenericError(err error) {
_err(err, nil, nil, nil)
}
func _err(err error, tags map[string]string, user *raven.User, c *fasthttp.RequestCtx) {
if RavenClient == nil {
fmt.Println("ERROR!!!!")
fmt.Println(err)
return
}
// Create stacktrace
st := raven.NewStacktrace(0, 3, []string{"zxq.co/ripple", "git.zxq.co/ripple"})
ifaces := []raven.Interface{st, generateRavenHTTP(c)}
if user != nil {
ifaces = append(ifaces, user)
}
RavenClient.CaptureError(
err,
tags,
ifaces...,
)
}
func generateRavenHTTP(ctx *fasthttp.RequestCtx) *raven.Http {
if ctx == nil {
return nil
}
// build uri
uri := ctx.URI()
// safe to use b2s because a new string gets allocated eventually for
// concatenation
sURI := b2s(uri.Scheme()) + "://" + b2s(uri.Host()) + b2s(uri.Path())
// build header map
// using ctx.Request.Header.Len would mean calling .VisitAll two times
// which can be quite expensive since it means iterating over all the
// headers, so we give a rough estimate of the number of headers we expect
// to have
m := make(map[string]string, 16)
ctx.Request.Header.VisitAll(func(k, v []byte) {
// not using b2s because we mustn't keep references to the underlying
// k and v
m[string(k)] = string(v)
})
return &raven.Http{
URL: sURI,
// Not using b2s because raven sending is concurrent and may happen
// AFTER the request, meaning that values could potentially be replaced
// by new ones.
Method: string(ctx.Method()),
Query: string(uri.QueryString()),
Cookies: string(ctx.Request.Header.Peek("Cookie")),
Headers: m,
}
}
// ID retrieves the Token's owner user ID.
func (md MethodData) ID() int {
return md.User.UserID
}
// Query is shorthand for md.C.Query.
func (md MethodData) Query(q string) string {
return b2s(md.Ctx.QueryArgs().Peek(q))
}
// HasQuery returns true if the parameter is encountered in the querystring.
// It returns true even if the parameter is "" (the case of ?param&etc=etc)
func (md MethodData) HasQuery(q string) bool {
return md.Ctx.QueryArgs().Has(q)
}
// Unmarshal unmarshals a request's JSON body into an interface.
func (md MethodData) Unmarshal(into interface{}) error {
return json.Unmarshal(md.Ctx.PostBody(), into)
}
// IsBearer tells whether the current token is a Bearer (oauth) token.
func (md MethodData) IsBearer() bool {
return md.User.ID == -1
}

View File

@@ -0,0 +1,22 @@
package common
import "fmt"
// Paginate creates an additional SQL LIMIT clause for paginating.
func Paginate(page, limit string, maxLimit int) string {
var (
p = Int(page)
l = Int(limit)
)
if p < 1 {
p = 1
}
if l < 1 {
l = 50
}
if l > maxLimit {
l = maxLimit
}
start := uint(p-1) * uint(l)
return fmt.Sprintf(" LIMIT %d,%d ", start, l)
}

View File

@@ -0,0 +1,49 @@
package common
import "testing"
func TestPaginate(t *testing.T) {
type args struct {
page string
limit string
maxLimit int
}
tests := []struct {
name string
args args
want string
}{
{
"1",
args{
"10",
"",
100,
},
" LIMIT 450,50 ",
},
{
"2",
args{
"-5",
"-15",
100,
},
" LIMIT 0,50 ",
},
{
"3",
args{
"2",
"150",
100,
},
" LIMIT 100,100 ",
},
}
for _, tt := range tests {
if got := Paginate(tt.args.page, tt.args.limit, tt.args.maxLimit); got != tt.want {
t.Errorf("%q. Paginate() = %v, want %v", tt.name, got, tt.want)
}
}
}

View File

@@ -0,0 +1,94 @@
package common
import "strings"
// These are the various privileges a token can have.
const (
PrivilegeRead = 1 << iota // used to be to fetch public data, such as user information etc. this is deprecated.
PrivilegeReadConfidential // (eventual) private messages, reports... of self
PrivilegeWrite // change user information, write into confidential stuff...
PrivilegeManageBadges // can change various users' badges.
PrivilegeBetaKeys // can add, remove, upgrade/downgrade, make public beta keys.
PrivilegeManageSettings // maintainance, set registrations, global alerts, bancho settings
PrivilegeViewUserAdvanced // can see user email, and perhaps warnings in the future, basically.
PrivilegeManageUser // can change user email, allowed status, userpage, rank, username...
PrivilegeManageRoles // translates as admin, as they can basically assign roles to anyone, even themselves
PrivilegeManageAPIKeys // admin permission to manage user permission, not only self permissions. Only ever do this if you completely trust the application, because this essentially means to put the entire ripple database in the hands of a (potentially evil?) application.
PrivilegeBlog // can do pretty much anything to the blog, and the documentation.
PrivilegeAPIMeta // can do /meta API calls. basically means they can restart the API server.
PrivilegeBeatmap // rank/unrank beatmaps. also BAT when implemented
)
// Privileges is a bitwise enum of the privileges of an user's API key.
type Privileges uint64
var privilegeString = [...]string{
"Read",
"ReadConfidential",
"Write",
"ManageBadges",
"BetaKeys",
"ManageSettings",
"ViewUserAdvanced",
"ManageUser",
"ManageRoles",
"ManageAPIKeys",
"Blog",
"APIMeta",
"Beatmap",
}
func (p Privileges) String() string {
var pvs []string
for i, v := range privilegeString {
if uint64(p)&uint64(1<<uint(i)) != 0 {
pvs = append(pvs, v)
}
}
return strings.Join(pvs, ", ")
}
var privilegeMustBe = [...]UserPrivileges{
1 << 30, // read is deprecated, and should be given out to no-one.
UserPrivilegeNormal,
UserPrivilegeNormal,
AdminPrivilegeAccessRAP | AdminPrivilegeManageBadges,
AdminPrivilegeAccessRAP | AdminPrivilegeManageBetaKey,
AdminPrivilegeAccessRAP | AdminPrivilegeManageSetting,
AdminPrivilegeAccessRAP,
AdminPrivilegeAccessRAP | AdminPrivilegeManageUsers | AdminPrivilegeBanUsers,
AdminPrivilegeAccessRAP | AdminPrivilegeManageUsers | AdminPrivilegeManagePrivilege,
AdminPrivilegeAccessRAP | AdminPrivilegeManageUsers | AdminPrivilegeManageServer,
AdminPrivilegeChatMod, // temporary?
AdminPrivilegeManageServer,
AdminPrivilegeAccessRAP | AdminPrivilegeManageBeatmap,
}
// CanOnly removes any privilege that the user has requested to have, but cannot have due to their rank.
func (p Privileges) CanOnly(userPrivs UserPrivileges) Privileges {
newPrivilege := 0
for i, v := range privilegeMustBe {
wants := p&1 == 1
can := userPrivs&v == v
if wants && can {
newPrivilege |= 1 << uint(i)
}
p >>= 1
}
return Privileges(newPrivilege)
}
var privilegeMap = map[string]Privileges{
"read_confidential": PrivilegeReadConfidential,
"write": PrivilegeWrite,
}
// OAuthPrivileges returns the equivalent in Privileges of a space-separated
// list of scopes.
func OAuthPrivileges(scopes string) Privileges {
var p Privileges
for _, x := range strings.Split(scopes, " ") {
p |= privilegeMap[x]
}
return p
}

View File

@@ -0,0 +1,33 @@
package common
import (
"math/rand"
"time"
)
const letterBytes = "0123456789abcdef"
const (
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
)
var randSrc = rand.NewSource(time.Now().UnixNano())
// RandomString generates a random string.
func RandomString(n int) string {
b := make([]byte, n)
// A randSrc.Int63() generates 63 random bits, enough for letterIdxMax characters!
for i, cache, remain := n-1, randSrc.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = randSrc.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}
return string(b)
}

View File

@@ -0,0 +1,36 @@
package common
// ResponseBase is the data that is always returned with an API request.
type ResponseBase struct {
Code int `json:"code"`
Message string `json:"message,omitempty"`
}
// GetCode retrieves the response code.
func (r ResponseBase) GetCode() int {
return r.Code
}
// SetCode changes the response code.
func (r *ResponseBase) SetCode(i int) {
r.Code = i
}
// GetMessage retrieves the response message.
func (r ResponseBase) GetMessage() string {
return r.Message
}
// CodeMessager is something that has the Code() and Message() methods.
type CodeMessager interface {
GetMessage() string
GetCode() int
}
// SimpleResponse returns the most basic response.
func SimpleResponse(code int, message string) CodeMessager {
return ResponseBase{
Code: code,
Message: message,
}
}

View File

@@ -0,0 +1,16 @@
package common
import (
"unicode"
)
// SanitiseString removes all control codes from a string.
func SanitiseString(s string) string {
n := make([]rune, 0, len(s))
for _, c := range s {
if c == '\n' || !unicode.Is(unicode.Other, c) {
n = append(n, c)
}
}
return string(n)
}

View File

@@ -0,0 +1,41 @@
package common
import "testing"
const pen = "I trattori di palmizio 나는 펜이있다. 私はリンゴを持っています。" +
"啊! 苹果笔。 у меня есть ручка, Tôi có dứa. අන්නාසි පෑන"
func TestSanitiseString(t *testing.T) {
tests := []struct {
name string
arg string
want string
}{
{
"Normal",
pen,
pen,
},
{
"Arabic (rtl)",
"أناناس",
"أناناس",
},
{
"Null",
"A\x00B",
"AB",
},
}
for _, tt := range tests {
if got := SanitiseString(tt.arg); got != tt.want {
t.Errorf("%q. SanitiseString() = %v, want %v", tt.name, got, tt.want)
}
}
}
func BenchmarkSanitiseString(b *testing.B) {
for i := 0; i < b.N; i++ {
SanitiseString(pen)
}
}

View File

@@ -0,0 +1,52 @@
package common
import "strings"
// SortConfiguration is the configuration of Sort.
type SortConfiguration struct {
Allowed []string // Allowed parameters
Default string
DefaultSorting string // if empty, DESC
Table string
}
// Sort allows the request to modify how the query is sorted.
func Sort(md MethodData, config SortConfiguration) string {
if config.DefaultSorting == "" {
config.DefaultSorting = "DESC"
}
if config.Table != "" {
config.Table += "."
}
var sortBy string
for _, s := range md.Ctx.Request.URI().QueryArgs().PeekMulti("sort") {
sortParts := strings.Split(strings.ToLower(b2s(s)), ",")
if contains(config.Allowed, sortParts[0]) {
if sortBy != "" {
sortBy += ", "
}
sortBy += config.Table + sortParts[0] + " "
if len(sortParts) > 1 && contains([]string{"asc", "desc"}, sortParts[1]) {
sortBy += sortParts[1]
} else {
sortBy += config.DefaultSorting
}
}
}
if sortBy == "" {
sortBy = config.Default
}
if sortBy == "" {
return ""
}
return "ORDER BY " + sortBy
}
func contains(a []string, s string) bool {
for _, el := range a {
if s == el {
return true
}
}
return false
}

View File

@@ -0,0 +1,23 @@
package common
import "fmt"
// Token is an API token.
type Token struct {
ID int
Value string
UserID int
TokenPrivileges Privileges
UserPrivileges UserPrivileges
}
// OnlyUserPublic returns a string containing "(user.privileges & 1 = 1 OR users.id = <userID>)"
// if the user does not have the UserPrivilege AdminManageUsers, and returns "1" otherwise.
func (t Token) OnlyUserPublic(userManagerSeesEverything bool) string {
if userManagerSeesEverything &&
t.UserPrivileges&AdminPrivilegeManageUsers == AdminPrivilegeManageUsers {
return "1"
}
// It's safe to use sprintf directly even if it's a query, because UserID is an int.
return fmt.Sprintf("(users.privileges & 1 = 1 OR users.id = '%d')", t.UserID)
}

View File

@@ -0,0 +1,55 @@
package common
import (
"errors"
"strconv"
"time"
)
// UnixTimestamp is simply a time.Time, but can be used to convert an
// unix timestamp in the database into a native time.Time.
type UnixTimestamp time.Time
// Scan decodes src into an unix timestamp.
func (u *UnixTimestamp) Scan(src interface{}) error {
if u == nil {
return errors.New("rippleapi/common: UnixTimestamp is nil")
}
switch src := src.(type) {
case int64:
*u = UnixTimestamp(time.Unix(src, 0))
case float64:
*u = UnixTimestamp(time.Unix(int64(src), 0))
case string:
return u._string(src)
case []byte:
return u._string(string(src))
case nil:
// Nothing, leave zero value on timestamp
default:
return errors.New("rippleapi/common: unhandleable type")
}
return nil
}
func (u *UnixTimestamp) _string(s string) error {
ts, err := strconv.Atoi(s)
if err != nil {
return err
}
*u = UnixTimestamp(time.Unix(int64(ts), 0))
return nil
}
// MarshalJSON -> time.Time.MarshalJSON
func (u UnixTimestamp) MarshalJSON() ([]byte, error) {
return time.Time(u).MarshalJSON()
}
// UnmarshalJSON -> time.Time.UnmarshalJSON
func (u *UnixTimestamp) UnmarshalJSON(x []byte) error {
t := new(time.Time)
err := t.UnmarshalJSON(x)
*u = UnixTimestamp(*t)
return err
}

View File

@@ -0,0 +1,32 @@
package common
import (
"reflect"
"strings"
)
// UpdateQuery is simply an SQL update query,
// that can be built upon passed parameters.
type UpdateQuery struct {
fields []string
Parameters []interface{}
}
// Add adds a new field with correspective value to UpdateQuery
func (u *UpdateQuery) Add(field string, value interface{}) *UpdateQuery {
val := reflect.ValueOf(value)
if val.Kind() == reflect.Ptr && val.IsNil() {
return u
}
if s, ok := value.(string); ok && s == "" {
return u
}
u.fields = append(u.fields, field+" = ?")
u.Parameters = append(u.Parameters, value)
return u
}
// Fields retrieves the fields joined by a comma.
func (u *UpdateQuery) Fields() string {
return strings.Join(u.fields, ", ")
}

View File

@@ -0,0 +1,68 @@
package common
import "strings"
// user/admin privileges
const (
UserPrivilegePublic UserPrivileges = 1 << iota
UserPrivilegeNormal
UserPrivilegeDonor
AdminPrivilegeAccessRAP
AdminPrivilegeManageUsers
AdminPrivilegeBanUsers
AdminPrivilegeSilenceUsers
AdminPrivilegeWipeUsers
AdminPrivilegeManageBeatmap
AdminPrivilegeManageServer
AdminPrivilegeManageSetting
AdminPrivilegeManageBetaKey
AdminPrivilegeManageReport
AdminPrivilegeManageDocs
AdminPrivilegeManageBadges
AdminPrivilegeViewRAPLogs
AdminPrivilegeManagePrivilege
AdminPrivilegeSendAlerts
AdminPrivilegeChatMod
AdminPrivilegeKickUsers
UserPrivilegePendingVerification
UserPrivilegeTournamentStaff
AdminPrivilegeCaker
)
// UserPrivileges represents a bitwise enum of the privileges of an user.
type UserPrivileges uint64
var userPrivilegeString = [...]string{
"UserPublic",
"UserNormal",
"UserDonor",
"AdminAccessRAP",
"AdminManageUsers",
"AdminBanUsers",
"AdminSilenceUsers",
"AdminWipeUsers",
"AdminManageBeatmap",
"AdminManageServer",
"AdminManageSetting",
"AdminManageBetaKey",
"AdminManageReport",
"AdminManageDocs",
"AdminManageBadges",
"AdminViewRAPLogs",
"AdminManagePrivilege",
"AdminSendAlerts",
"AdminChatMod",
"AdminKickUsers",
"UserPendingVerification",
"UserTournamentStaff",
}
func (p UserPrivileges) String() string {
var pvs []string
for i, v := range userPrivilegeString {
if uint64(p)&uint64(1<<uint(i)) != 0 {
pvs = append(pvs, v)
}
}
return strings.Join(pvs, ", ")
}

View File

@@ -0,0 +1,11 @@
package common
import (
"strings"
)
// SafeUsername makes a string lowercase and replaces all spaces with
// underscores.
func SafeUsername(s string) string {
return strings.Replace(strings.ToLower(s), " ", "_", -1)
}

View File

@@ -0,0 +1,20 @@
package common
import "testing"
func TestSafeUsername(t *testing.T) {
tests := []struct {
name string
arg string
want string
}{
{"noChange", "no_change", "no_change"},
{"toLower", "Change_Me", "change_me"},
{"complete", "La_M a m m a_putt na", "la_m_a_m_m_a_putt_na"},
}
for _, tt := range tests {
if got := SafeUsername(tt.arg); got != tt.want {
t.Errorf("%q. SafeUsername() = %v, want %v", tt.name, got, tt.want)
}
}
}

View File

@@ -0,0 +1,91 @@
package common
// WhereClause is a struct representing a where clause.
// This is made to easily create WHERE clauses from parameters passed from a request.
type WhereClause struct {
Clause string
Params []interface{}
useOr bool
}
// Where adds a new WHERE clause to the WhereClause.
func (w *WhereClause) Where(clause, passedParam string, allowedValues ...string) *WhereClause {
if passedParam == "" {
return w
}
if len(allowedValues) != 0 && !contains(allowedValues, passedParam) {
return w
}
w.addWhere()
w.Clause += clause
w.Params = append(w.Params, passedParam)
return w
}
func (w *WhereClause) addWhere() {
// if string is empty add "WHERE", else add AND
if w.Clause == "" {
w.Clause += "WHERE "
} else {
if w.useOr {
w.Clause += " OR "
return
}
w.Clause += " AND "
}
}
// Or enables using OR instead of AND
func (w *WhereClause) Or() *WhereClause {
w.useOr = true
return w
}
// And enables using AND instead of OR
func (w *WhereClause) And() *WhereClause {
w.useOr = false
return w
}
// In generates an IN clause.
// initial is the initial part, e.g. "users.id".
// Fields are the possible values.
// Sample output: users.id IN ('1', '2', '3')
func (w *WhereClause) In(initial string, fields ...[]byte) *WhereClause {
if len(fields) == 0 {
return w
}
w.addWhere()
w.Clause += initial + " IN (" + generateQuestionMarks(len(fields)) + ")"
fieldsInterfaced := make([]interface{}, len(fields))
for k, f := range fields {
fieldsInterfaced[k] = string(f)
}
w.Params = append(w.Params, fieldsInterfaced...)
return w
}
func generateQuestionMarks(x int) (qm string) {
for i := 0; i < x-1; i++ {
qm += "?, "
}
if x > 0 {
qm += "?"
}
return qm
}
// ClauseSafe returns the clause, always containing something. If w.Clause is
// empty, it returns "WHERE 1".
func (w *WhereClause) ClauseSafe() string {
if w.Clause == "" {
return "WHERE 1"
}
return w.Clause
}
// Where is the same as WhereClause.Where, but creates a new WhereClause.
func Where(clause, passedParam string, allowedValues ...string) *WhereClause {
w := new(WhereClause)
return w.Where(clause, passedParam, allowedValues...)
}

View File

@@ -0,0 +1,97 @@
package common
import (
"reflect"
"testing"
)
func Test_generateQuestionMarks(t *testing.T) {
type args struct {
x int
}
tests := []struct {
name string
args args
wantQm string
}{
{"-1", args{-1}, ""},
{"0", args{0}, ""},
{"1", args{1}, "?"},
{"2", args{2}, "?, ?"},
}
for _, tt := range tests {
if gotQm := generateQuestionMarks(tt.args.x); gotQm != tt.wantQm {
t.Errorf("%q. generateQuestionMarks() = %v, want %v", tt.name, gotQm, tt.wantQm)
}
}
}
func TestWhereClause_In(t *testing.T) {
type args struct {
initial string
fields []string
}
tests := []struct {
name string
fields *WhereClause
args args
want *WhereClause
}{
{
"simple",
&WhereClause{},
args{"users.id", []string{"1", "2", "3"}},
&WhereClause{"WHERE users.id IN (?, ?, ?)", []interface{}{"1", "2", "3"}, false},
},
{
"withExisting",
Where("users.username = ?", "Howl").Where("users.xd > ?", "6"),
args{"users.id", []string{"1"}},
&WhereClause{
"WHERE users.username = ? AND users.xd > ? AND users.id IN (?)",
[]interface{}{"Howl", "6", "1"},
false,
},
},
}
for _, tt := range tests {
w := tt.fields
if got := w.In(tt.args.initial, tt.args.fields...); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q. WhereClause.In() = %v, want %v", tt.name, got, tt.want)
}
}
}
func TestWhere(t *testing.T) {
type args struct {
clause string
passedParam string
allowedValues []string
}
tests := []struct {
name string
args args
want *WhereClause
}{
{
"simple",
args{"users.id = ?", "5", nil},
&WhereClause{"WHERE users.id = ?", []interface{}{"5"}, false},
},
{
"allowed",
args{"users.id = ?", "5", []string{"1", "3", "5"}},
&WhereClause{"WHERE users.id = ?", []interface{}{"5"}, false},
},
{
"notAllowed",
args{"users.id = ?", "5", []string{"0"}},
&WhereClause{},
},
}
for _, tt := range tests {
if got := Where(tt.args.clause, tt.args.passedParam, tt.args.allowedValues...); !reflect.DeepEqual(got, tt.want) {
t.Errorf("%q. Where() = %#v, want %#v", tt.name, got, tt.want)
}
}
}