214 lines
4.1 KiB
Go

package gql
import (
"fmt"
"reflect"
"strings"
"github.com/graphql-go/graphql"
)
// TAG is the struct tag key read to obtain a field's GraphQL name. The "json"
// key is reused, so the name before the first comma of a field's json tag is
// also the name it is bound under.
const TAG = "json"
// BindFields builds a graphql.Fields map from the struct type t (or a pointer
// to a struct type), using each field's TAG struct tag as the GraphQL name.
//
// Binding rules:
//   - unexported, non-embedded fields are skipped;
//   - fields tagged `private:"true"` or named "-" are skipped;
//   - struct (or pointer-to-struct) fields without a name are flattened into
//     the parent; with a name they become a nested object called
//     "<Parent>_<Field>";
//   - any other field without a name is skipped;
//   - everything else is mapped through getGraphType.
//
// Every bound field gets a resolver that looks its value up by name on
// p.Source. A type that is not a struct yields an empty map.
//
// BindFields cannot handle self-referential types, e.g.
//
//	type Person struct {
//		Friends []Person
//	}
//
// because binding recurses into the element type without end and overflows
// the stack.
func BindFields(t reflect.Type) graphql.Fields {
	fields := make(map[string]*graphql.Field)
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	// Only structs have fields; NumField panics on every other kind.
	if t.Kind() != reflect.Struct {
		return fields
	}
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		// A non-empty PkgPath marks an unexported field. Its value cannot be
		// turned back into an interface{} at resolve time, so do not expose
		// it. Embedded fields are kept so their exported members still get
		// flattened below.
		if field.PkgPath != "" && !field.Anonymous {
			continue
		}
		// Skip fields explicitly marked private.
		if field.Tag.Get("private") == "true" {
			continue
		}
		// tag is declared per iteration, so each resolver closure below
		// captures its own copy.
		tag := extractTag(field.Tag)
		if tag == "-" {
			continue
		}
		fieldType := field.Type
		if fieldType.Kind() == reflect.Ptr {
			fieldType = fieldType.Elem()
		}
		var graphType graphql.Output
		if fieldType.Kind() == reflect.Struct {
			structFields := BindFields(fieldType)
			if tag == "" {
				// Unnamed struct: hoist its fields into the parent.
				fields = appendFields(fields, structFields)
				continue
			}
			graphType = graphql.NewObject(graphql.ObjectConfig{
				Name:   t.Name() + "_" + field.Name,
				Fields: structFields,
			})
		}
		// Non-struct fields without a name have nothing to be bound under.
		if tag == "" {
			continue
		}
		if graphType == nil {
			graphType = getGraphType(fieldType)
		}
		fields[tag] = &graphql.Field{
			Type: graphType,
			Resolve: func(p graphql.ResolveParams) (interface{}, error) {
				return extractValue(tag, p.Source), nil
			},
		}
	}
	return fields
}
// getGraphType maps a Go type to the GraphQL output type used to expose it:
// strings to String, signed integers to Int, floats to Float, bools to
// Boolean and slices to a List built by getGraphList. Every other kind falls
// back to String.
func getGraphType(tipe reflect.Type) graphql.Output {
	switch tipe.Kind() {
	case reflect.String:
		return graphql.String
	// Int16 belongs with the other signed integer kinds; without it int16
	// fields silently fell through to the String default.
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return graphql.Int
	case reflect.Float32, reflect.Float64:
		return graphql.Float
	case reflect.Bool:
		return graphql.Boolean
	case reflect.Slice:
		return getGraphList(tipe)
	}
	return graphql.String
}
// getGraphList returns the GraphQL list type for the slice type tipe.
//
// A pointer element type is dereferenced once before it is classified.
// Scalar elements map to a list of the matching GraphQL scalar, struct
// elements to a list of objects whose fields come from BindFields, and any
// other element kind (nested slices, unsigned integers, maps, ...) to a list
// of whatever getGraphType picks for it.
func getGraphList(tipe reflect.Type) *graphql.List {
	elem := tipe.Elem()
	// Look through one pointer level so []*T is bound like []T. This also
	// keeps the leading "*" of a pointer type's printed form, which is not a
	// valid GraphQL name character, out of the object name below.
	if elem.Kind() == reflect.Ptr {
		elem = elem.Elem()
	}
	switch elem.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return graphql.NewList(graphql.Int)
	case reflect.Bool:
		return graphql.NewList(graphql.Boolean)
	case reflect.Float32, reflect.Float64:
		return graphql.NewList(graphql.Float)
	case reflect.String:
		return graphql.NewList(graphql.String)
	case reflect.Struct:
		// The object is named after the element type, with the package
		// separator replaced because "." is not allowed in the name.
		name := strings.Replace(fmt.Sprint(elem), ".", "_", -1)
		obj := graphql.NewObject(graphql.ObjectConfig{
			Name:   name,
			Fields: BindFields(elem),
		})
		return graphql.NewList(obj)
	}
	// Anything else has no fields to bind; handing it to BindFields would
	// panic, so let getGraphType choose the element type instead.
	return graphql.NewList(getGraphType(elem))
}
// appendFields copies every entry of origin into dest, replacing entries that
// share a key, and returns dest.
func appendFields(dest, origin graphql.Fields) graphql.Fields {
	for name := range origin {
		dest[name] = origin[name]
	}
	return dest
}
// extractValue returns the value of the field of obj whose tag name (as
// reported by extractTag) equals originTag, or nil when there is none.
//
// obj may be a struct or a pointer to one; nil and every other kind yield
// nil. A matching pointer field is dereferenced once, and a nil pointer
// yields nil. Struct fields without a tag name, by value or by pointer, are
// searched recursively, which is how BindFields flattens them. Struct fields
// that carry a name are not searched, so their members cannot shadow a field
// of the same name on obj.
func extractValue(originTag string, obj interface{}) interface{} {
	return findTaggedValue(originTag, reflect.ValueOf(obj))
}

// findTaggedValue is the reflect.Value based search behind extractValue.
func findTaggedValue(originTag string, val reflect.Value) interface{} {
	// Indirect turns a nil pointer (and leaves a nil interface) into the
	// zero Value, whose kind is Invalid and is rejected just below.
	val = reflect.Indirect(val)
	if val.Kind() != reflect.Struct {
		return nil
	}
	typ := val.Type()
	for j := 0; j < typ.NumField(); j++ {
		tag := extractTag(typ.Field(j).Tag)
		fieldVal := reflect.Indirect(val.Field(j))
		// Recursing on the reflect.Value instead of going through
		// Interface() keeps unexported struct fields from panicking.
		if tag == "" && fieldVal.Kind() == reflect.Struct {
			if res := findTaggedValue(originTag, fieldVal); res != nil {
				return res
			}
		}
		if tag != originTag {
			continue
		}
		// A nil pointer leaves an invalid Value and an unexported field
		// cannot be converted; Interface() would panic on either.
		if !fieldVal.IsValid() || !fieldVal.CanInterface() {
			return nil
		}
		return fieldVal.Interface()
	}
	return nil
}
// extractTag returns the name part of the TAG entry in tag: everything before
// the first comma, with options such as ",omitempty" dropped. The result is
// empty when the entry is missing or has no name.
func extractTag(tag reflect.StructTag) string {
	name := tag.Get(TAG)
	if comma := strings.Index(name, ","); comma >= 0 {
		return name[:comma]
	}
	return name
}
// // BindArg is a lazy way of binding args
// func BindArg(obj interface{}, tags ...string) graphql.FieldConfigArgument {
// v := reflect.Indirect(reflect.ValueOf(obj))
// var config = make(graphql.FieldConfigArgument)
// for i := 0; i < v.NumField(); i++ {
// field := v.Type().Field(i)
// mytag := extractTag(field.Tag)
// if inArray(tags, mytag) {
// config[mytag] = &graphql.ArgumentConfig{
// Type: getGraphType(field.Type),
// }
// }
// }
// return config
// }
// func inArray(slice interface{}, item interface{}) bool {
// s := reflect.ValueOf(slice)
// if s.Kind() != reflect.Slice {
// panic("inArray() given a non-slice type")
// }
// for i := 0; i < s.Len(); i++ {
// if reflect.DeepEqual(item, s.Index(i).Interface()) {
// return true
// }
// }
// return false
// }