Extract the query fields in graphql-go

I have been using graphql-go for one of my projects at work, and I was looking for a way to extract the fields requested in the query to optimize the response.

The GraphQL API I am working on is actually a pass-through layer for another legacy API. Some of the fields are very expensive to query in the underlying API, so we query for them only if the user asks for them.

Sounds like the perfect use case for GraphQL, right? The request already lists the fields the user is interested in. Except that it is not easy to get at those fields from inside a query or mutation resolver.

I had to debug and inspect the library code before I could write a helper that extracts the fields the user asked for.

The set of fields is hidden behind what is known as an Abstract Syntax Tree, or AST for short. I do not claim to fully understand ASTs (or GraphQL internals), but here is an excellent post that explains the internals: https://medium.com/@cjoudrey/life-of-a-graphql-query-lexing-parsing-ca7c5045fad8.

For the purposes of the demo, consider the following schema:

type Person {
  id: ID!
  name: String!
  age: Int!
}

type Query {
  person(id: ID!): Person
}

Suppose that the age field is very expensive to calculate, so we want to fetch it only if the user actually requests it in the query.

Here is the graphql-go code that serves the above schema:

package main

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/graphql-go/graphql"
	"github.com/graphql-go/graphql/language/ast"
)

type Person struct {
	id   string
	name string
	age  int
}

var personFields = graphql.Fields{
	"id": &graphql.Field{
		Type: graphql.NewNonNull(graphql.ID),
		Resolve: func(p graphql.ResolveParams) (interface{}, error) {
			if person, ok := p.Source.(Person); ok {
				return person.id, nil
			}
			return nil, fmt.Errorf("expected Person")
		},
	},
	"name": &graphql.Field{
		Type: graphql.NewNonNull(graphql.String),
		Resolve: func(p graphql.ResolveParams) (interface{}, error) {
			if person, ok := p.Source.(Person); ok {
				return person.name, nil
			}
			return nil, fmt.Errorf("expected Person")
		},
	},
	"age": &graphql.Field{
		Type: graphql.NewNonNull(graphql.Int),
		Resolve: func(p graphql.ResolveParams) (interface{}, error) {
			if person, ok := p.Source.(Person); ok {
				return person.age, nil
			}
			return nil, fmt.Errorf("expected Person")
		},
	},
}

var personQuery = &graphql.Field{
	Type: graphql.NewObject(graphql.ObjectConfig{
		Name:   "Person",
		Fields: personFields,
	}),
	Args: graphql.FieldConfigArgument{
		"id": &graphql.ArgumentConfig{
			Type: graphql.NewNonNull(graphql.ID),
		},
	},
	Resolve: func(p graphql.ResolveParams) (interface{}, error) {
		id, ok := p.Args["id"].(string)
		if !ok {
			return nil, fmt.Errorf("expected string")
		}
		queriedFields := getQueriedFields(&p)
		if queriedFields["age"] {
			// rest of the code here
			// query the database with age, otherwise skip it
		}
		return Person{id: id}, nil
	},
}

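func main() {
	schema, err := graphql.NewSchema(graphql.SchemaConfig{
		Query: graphql.NewObject(graphql.ObjectConfig{
			Name: "Query",
			Fields: graphql.Fields{
				"person": personQuery,
			},
		}),
	})
	if err != nil {
		// a broken schema would otherwise only show up as errors at query time
		panic(err)
	}
	handler := func(w http.ResponseWriter, r *http.Request) {
		// the query is read from the ?query= URL parameter
		result := executeQuery(r.URL.Query().Get("query"), schema)
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(result)
	}
	http.HandleFunc("/graphql", handler)
	if err := http.ListenAndServe(":8080", nil); err != nil {
		panic(err)
	}
}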

func executeQuery(query string, schema graphql.Schema) *graphql.Result {
	result := graphql.Do(graphql.Params{
		Schema:        schema,
		RequestString: query,
	})
	if len(result.Errors) > 0 {
		fmt.Printf("unexpected errors: %v\n", result.Errors)
	}
	return result
}

The missing piece is the function that extracts the queried fields from the graphql.ResolveParams object:

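// getQueriedFields returns the names of the fields selected directly under
// the field being resolved, for a query or a mutation.
func getQueriedFields(p *graphql.ResolveParams) map[string]bool {
	fields := map[string]bool{}
	// FieldASTs holds the AST nodes for the field being resolved;
	// ranging over a nil slice is a no-op, so no nil check is needed.
	for _, fieldAST := range p.Info.FieldASTs {
		if fieldAST.SelectionSet == nil {
			continue
		}
		for _, sel := range fieldAST.SelectionSet.Selections {
			field, ok := sel.(*ast.Field)
			if !ok {
				// fragment spreads and inline fragments are skipped, not
				// expanded; keep scanning the remaining selections
				continue
			}
			fields[field.Name.Value] = true
		}
	}
	return fields
}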
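
One thing to watch for when walking a selection set: a selection is not always a plain field. It can also be a fragment spread such as ...details, whose fields live in a separate fragment definition. The sketch below shows one way to collect field names while expanding named fragments. It deliberately uses simplified stand-in types (Selection, queriedFields and collect are hypothetical names, not part of graphql-go) so that it runs on its own. In graphql-go itself, to my understanding, the counterparts would be *ast.FragmentSpread and the fragment map on p.Info.

```go
package main

import (
	"fmt"
	"sort"
)

// Selection is a simplified stand-in for one entry in a selection set.
// Exactly one of Field or Spread is expected to be set.
type Selection struct {
	Field  string // name of a plain field, e.g. "age"
	Spread string // name of a fragment spread, e.g. "details" for "...details"
}

// collect adds the field names found in sels to out, expanding fragment
// spreads through the fragments map. visited prevents expanding the same
// fragment more than once.
func collect(sels []Selection, fragments map[string][]Selection, visited, out map[string]bool) {
	for _, sel := range sels {
		switch {
		case sel.Field != "":
			out[sel.Field] = true
		case sel.Spread != "":
			if visited[sel.Spread] {
				continue
			}
			visited[sel.Spread] = true
			collect(fragments[sel.Spread], fragments, visited, out)
		}
	}
}

// queriedFields returns the sorted, de-duplicated field names in sels.
func queriedFields(sels []Selection, fragments map[string][]Selection) []string {
	set := map[string]bool{}
	collect(sels, fragments, map[string]bool{}, set)
	names := make([]string, 0, len(set))
	for name := range set {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

func main() {
	// Models: { person(id: "1") { id ...details name } }
	//         fragment details on Person { age }
	fragments := map[string][]Selection{
		"details": {{Field: "age"}},
	}
	sels := []Selection{{Field: "id"}, {Spread: "details"}, {Field: "name"}}
	fmt.Println(queriedFields(sels, fragments)) // prints [age id name]
}
```

The point of the sketch is that a non-field selection should be either expanded or skipped, never treated as the end of the list, otherwise fields that come after a fragment are silently missed.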

So by calling getQueriedFields we get the set of fields the user requested in the query. That set can then be passed to the backend, for example to select only those columns in the SQL query.
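
As a rough illustration of the SQL case, here is a minimal sketch that turns the map returned by getQueriedFields into a column list. The table name persons, the columnFor mapping, the buildSelect helper and the ? placeholder style are all assumptions for the example; adjust them to your own schema and driver. Field names are looked up in an allowlist, so only known column names ever reach the SQL string.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// columnFor maps GraphQL field names to column names (hypothetical schema).
// It doubles as an allowlist: unknown field names are ignored.
var columnFor = map[string]string{
	"id":   "id",
	"name": "name",
	"age":  "age",
}

// buildSelect returns a SELECT statement limited to the requested fields.
// The id column is always selected so the row can be identified.
func buildSelect(fields map[string]bool) string {
	cols := []string{"id"}
	for name, wanted := range fields {
		col, ok := columnFor[name]
		if !wanted || !ok || col == "id" {
			continue
		}
		cols = append(cols, col)
	}
	// map iteration order is random, so sort for a stable statement
	sort.Strings(cols[1:])
	return "SELECT " + strings.Join(cols, ", ") + " FROM persons WHERE id = ?"
}

func main() {
	fmt.Println(buildSelect(map[string]bool{"name": true}))
	// prints: SELECT id, name FROM persons WHERE id = ?
	fmt.Println(buildSelect(map[string]bool{"name": true, "age": true, "bogus": true}))
	// prints: SELECT id, age, name FROM persons WHERE id = ?
}
```

In the resolver, the result of getQueriedFields(&p) could be passed straight to a helper like this, so the expensive age column is only read when it was requested.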

Technically, this also works for a mutation, when fetching the record after the mutation completes, but it makes more sense in a query.
