Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable users to load schemas from GraphQL files. #215

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions graphql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
# Parse and operate on GraphQL language source files.
from .language.base import ( # no import order
Source,
FileSource,
get_location,
# Parse
parse,
Expand Down Expand Up @@ -223,6 +224,7 @@
"BREAK",
"ParallelVisitor",
"Source",
"FileSource",
"TypeInfoVisitor",
"get_location",
"parse",
Expand Down
1 change: 1 addition & 0 deletions graphql/language/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,7 @@ def __repr__(self):
"name={self.name!r}"
", arguments={self.arguments!r}"
", type={self.type!r}"
", directives={self.directives!r}"
")"
).format(self=self)

Expand Down
3 changes: 2 additions & 1 deletion graphql/language/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .location import get_location
from .parser import parse, parse_value
from .printer import print_ast
from .source import Source
from .source import FileSource, Source
from .visitor import BREAK, ParallelVisitor, TypeInfoVisitor, visit

__all__ = [
Expand All @@ -12,6 +12,7 @@
"parse_value",
"print_ast",
"Source",
"FileSource",
"BREAK",
"ParallelVisitor",
"TypeInfoVisitor",
Expand Down
45 changes: 44 additions & 1 deletion graphql/language/source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = ["Source"]
import os

__all__ = ["Source", "FileSource"]

class Source(object):
__slots__ = "body", "name"
Expand All @@ -15,3 +16,45 @@ def __eq__(self, other):
and self.body == other.body
and self.name == other.name
)

class FileSource(Source):
__slots__ = "body", "name"

def __init__(self, *args, **kwargs):
"""Create a Source using the specified GraphQL files' contents."""
name = kwargs.get("name", "GraphQL")

# From the specified list of paths, first identify all files. Then, load
# their contents into a single, newline delimited string.
file_contents = []
file_paths = self.__get_file_paths__(args)
for fp in file_paths:
with open(fp) as f:
file_contents.append(f.read())
body = '\n'.join(file_contents)

super(FileSource, self).__init__(body, name)

def __get_file_paths__(self, paths):
"""Get the paths to all files in the given list of paths. This means
filtering out invalid paths and recursively walking a given directory
path to gather the paths of all files that it contains."""
all_file_paths = []

# Filter out invalid paths.
valid_paths = [p for p in paths if os.path.exists(p)]

# Add all paths pointing to a file to all_file_paths.
all_file_paths += [p for p in valid_paths if os.path.isfile(p)]

# For each path referring to a directory, walk that directory's structure
# recursively, and add its constituent files' paths to all_file_paths.
all_file_paths += [
os.path.join(dir_name, file_name)
for p in valid_paths
if os.path.isdir(p)
for dir_name, _, files_in_dir in os.walk(p)
for file_name in files_in_dir
]

return all_file_paths
4 changes: 4 additions & 0 deletions graphql/language/tests/graphql_schemas/models/Person.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
type Person {
name: ID!
age: Int
}
9 changes: 9 additions & 0 deletions graphql/language/tests/graphql_schemas/models/Skill.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
type Skill {
name: ID!
level: Int
possessors: [Person!]
}

extend type Person {
skills: [Skill!]
}
8 changes: 8 additions & 0 deletions graphql/language/tests/graphql_schemas/schema.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
type Query {
person(name: ID!): Person
skill(name: ID!): Skill
}

schema {
query: Query
}
136 changes: 133 additions & 3 deletions graphql/language/tests/test_schema_parser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from pytest import raises

from graphql import Source, parse
from graphql import FileSource, Source, parse
from graphql.error import GraphQLSyntaxError
from graphql.language import ast
from graphql.language.parser import Loc

import os

from pytest import raises

from typing import Callable


Expand Down Expand Up @@ -567,6 +570,133 @@ def test_parses_simple_input_object():
assert doc == expected


def test_parses_schema_files():
test_graphql_schemas_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "graphql_schemas")
doc = parse(FileSource(test_graphql_schemas_dir))
expected = ast.Document(
definitions=[
ast.ObjectTypeDefinition(
name=ast.Name(value="Query"),
interfaces=[],
fields=[
ast.FieldDefinition(
name=ast.Name(value="person"),
arguments=[
ast.InputValueDefinition(
name=ast.Name(value="name"),
type=ast.NonNullType(
type=ast.NamedType(name=ast.Name(value="ID"))
),
default_value=None,
directives=[]
)
],
type=ast.NamedType(name=ast.Name(value="Person")),
directives=[]
),
ast.FieldDefinition(
name=ast.Name(value="skill"),
arguments=[
ast.InputValueDefinition(
name=ast.Name(value="name"),
type=ast.NonNullType(
type=ast.NamedType(name=ast.Name(value="ID"))
),
default_value=None,
directives=[]
)
],
type=ast.NamedType(name=ast.Name(value="Skill")),
directives=[]
)
],
directives=[]
),
ast.SchemaDefinition(
operation_types=[
ast.OperationTypeDefinition(
operation="query",
type=ast.NamedType(name=ast.Name(value="Query"))
)
],
directives=[]
),
ast.ObjectTypeDefinition(
name=ast.Name(value="Person"),
interfaces=[],
fields=[
ast.FieldDefinition(
name=ast.Name(value="name"),
arguments=[],
type=ast.NonNullType(
type=ast.NamedType(name=ast.Name(value="ID"))
),
directives=[]
),
ast.FieldDefinition(
name=ast.Name(value="age"),
arguments=[],
type=ast.NamedType(name=ast.Name(value="Int")),
directives=[]
)
],
directives=[]
),
ast.ObjectTypeDefinition(
name=ast.Name(value="Skill"),
interfaces=[],
fields=[
ast.FieldDefinition(
name=ast.Name(value="name"),
arguments=[],
type=ast.NonNullType(
type=ast.NamedType(name=ast.Name(value="ID"))
),
directives=[]
),
ast.FieldDefinition(
name=ast.Name(value="level"),
arguments=[],
type=ast.NamedType(name=ast.Name(value="Int")),
directives=[]
),
ast.FieldDefinition(
name=ast.Name(value="possessors"),
arguments=[],
type=ast.ListType(
type=ast.NonNullType(
type=ast.NamedType(name=ast.Name(value="Person"))
)
),
directives=[]
)
],
directives=[]
),
ast.TypeExtensionDefinition(
definition=ast.ObjectTypeDefinition(
name=ast.Name(value="Person"),
interfaces=[],
fields=[
ast.FieldDefinition(
name=ast.Name(value="skills"),
arguments=[],
type=ast.ListType(
type=ast.NonNullType(
type=ast.NamedType(name=ast.Name(value="Skill"))
)
),
directives=[]
)
],
directives=[]
)
)
]
)
assert doc == expected


def test_parsing_simple_input_object_with_args_should_fail():
# type: () -> None
body = """
Expand Down