
Commit 5501e43

Refactor schema reading in header
1 parent 0f51542 commit 5501e43

File tree

7 files changed, +130 -112 lines changed


bulk_insert/bulk_insert.py

Lines changed: 19 additions & 20 deletions
@@ -10,12 +10,19 @@
 import module_vars
 
 
-# For each node input file, validate contents and convert to binary format.
-# If any buffer limits have been reached, flush all enqueued inserts to Redis.
-def process_entity_csvs(cls, csvs, separator):
-    for in_csv in csvs:
+def parse_schemas(cls, csvs):
+    schemas = [None] * len(csvs)
+    for idx, in_csv in enumerate(csvs):
         # Build entity descriptor from input CSV
-        entity = cls(in_csv, separator)
+        schemas[idx] = cls(in_csv)
+    return schemas
+
+
+# For each input file, validate contents and convert to binary format.
+# If any buffer limits have been reached, flush all enqueued inserts to Redis.
+def process_entities(entities):
+    for entity in entities:
+        entity.process_entities()
         added_size = entity.binary_size
         # Check to see if the addition of this data will exceed the buffer's capacity
         if (module_vars.QUERY_BUF.buffer_size + added_size >= module_vars.CONFIGS.max_buffer_size
@@ -42,24 +49,16 @@ def process_entity_csvs(cls, csvs, separator):
 @click.option('--max-buffer-size', '-b', default=2048, help='max buffer size in megabytes (default 2048)')
 @click.option('--max-token-size', '-t', default=500, help='max size of each token in megabytes (default 500, max 512)')
 @click.option('--quote', '-q', default=3, help='the quoting format used in the CSV file. QUOTE_MINIMAL=0,QUOTE_ALL=1,QUOTE_NONNUMERIC=2,QUOTE_NONE=3')
-@click.option('--field-types', '-f', default=None, help='json to set explicit types for each field, format {<label>:[<col1 type>, <col2 type> ...]} where type can be 0(null),1(bool),2(numeric),3(string)')
 @click.option('--skip-invalid-nodes', '-s', default=False, is_flag=True, help='ignore nodes that use previously defined IDs')
 @click.option('--skip-invalid-edges', '-e', default=False, is_flag=True, help='ignore invalid edges, print an error message and continue loading (True), or stop loading after an edge loading failure (False)')
-@click.option('--enforce-schema', '-S', default=False, is_flag=True, help='header line introduces property schema')
-def bulk_insert(graph, host, port, password, nodes, relations, separator, max_token_count, max_buffer_size, max_token_size, quote, field_types, skip_invalid_nodes, skip_invalid_edges, enforce_schema):
+def bulk_insert(graph, host, port, password, nodes, relations, separator, max_token_count, max_buffer_size, max_token_size, quote, skip_invalid_nodes, skip_invalid_edges):
     if sys.version_info[0] < 3:
         raise Exception("Python 3 is required for the RedisGraph bulk loader.")
 
-    if field_types is not None:
-        try:
-            module_vars.FIELD_TYPES = json.loads(field_types)
-        except:
-            raise Exception("Problem parsing field-types. Use the format {<label>:[<col1 type>, <col2 type> ...]} where type can be 0(null),1(bool),2(numeric),3(string) ")
-
     module_vars.QUOTING = int(quote)
 
     module_vars.TOP_NODE_ID = 0 # reset global ID variable (in case we are calling bulk_insert from unit tests)
-    module_vars.CONFIGS = Configs(max_token_count, max_buffer_size, max_token_size, skip_invalid_nodes, skip_invalid_edges, enforce_schema)
+    module_vars.CONFIGS = Configs(max_token_count, max_buffer_size, max_token_size, skip_invalid_nodes, skip_invalid_edges, separator)
 
     start_time = timer()
     # Attempt to connect to Redis server
@@ -85,9 +84,9 @@ def bulk_insert(graph, host, port, password, nodes, relations, separator, max_to
         print("Graph with name '%s', could not be created, as Redis key '%s' already exists." % (graph, graph))
         sys.exit(1)
 
-    # If we're enforcing a schema, validate the headers in each file?
-    if enforce_schema:
-        pass
+    # Read the header rows of each input CSV and save its schema.
+    labels = parse_schemas(Label, nodes)
+    reltypes = parse_schemas(RelationType, relations)
 
     module_vars.QUERY_BUF = QueryBuffer(graph, client)
 
@@ -97,10 +96,10 @@ def bulk_insert(graph, host, port, password, nodes, relations, separator, max_to
     else:
         module_vars.NODE_DICT = None
 
-    process_entity_csvs(Label, nodes, separator)
+    process_entities(labels)
 
     if relations:
-        process_entity_csvs(RelationType, relations, separator)
+        process_entities(reltypes)
 
     # Send all remaining tokens to Redis
    module_vars.QUERY_BUF.send_buffer()
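
To make the new two-phase flow concrete, here is a minimal usage sketch (not part of the commit). The file names and option values are invented, and it assumes the remaining module-level state that bulk_insert() normally prepares (QUOTING, NODE_DICT, QUERY_BUF) has already been set up.

    # Illustrative sketch only -- file names and option values are hypothetical.
    import module_vars
    from configs import Configs
    from label import Label

    # The separator now lives on CONFIGS instead of being threaded through each call.
    module_vars.CONFIGS = Configs(1024 * 1023, 2048, 500, False, False, ',')

    # Phase 1: read only the header row of each node CSV and build its schema.
    labels = parse_schemas(Label, ['Person.csv', 'Country.csv'])

    # Phase 2: stream the data rows into binary tokens, flushing to Redis as buffers fill.
    process_entities(labels)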

bulk_insert/configs.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 # User-configurable thresholds for when to send queries to Redis
 class Configs(object):
-    def __init__(self, max_token_count, max_buffer_size, max_token_size, skip_invalid_nodes, skip_invalid_edges, enforce_schema):
+    def __init__(self, max_token_count, max_buffer_size, max_token_size, skip_invalid_nodes, skip_invalid_edges, separator):
         # Maximum number of tokens per query
         # 1024 * 1024 is the hard-coded Redis maximum. We'll set a slightly lower limit so
         # that we can safely ignore tokens that aren't binary strings
@@ -15,4 +15,4 @@ def __init__(self, max_token_count, max_buffer_size, max_token_size, skip_invali
         self.skip_invalid_nodes = skip_invalid_nodes
         self.skip_invalid_edges = skip_invalid_edges
 
-        self.enforce_schema = enforce_schema
+        self.separator = separator

bulk_insert/entity_file.py

Lines changed: 69 additions & 32 deletions
@@ -1,114 +1,151 @@
 import os
 import io
 import csv
+import math
 import struct
 import module_vars
-from exceptions import CSVError
-import schema
+import configs
+from exceptions import CSVError, SchemaError
+from schema import Type, convert_schema_type
 
 
 # Convert a single CSV property field into a binary stream.
 # Supported property types are string, numeric, boolean, and NULL.
 # type is either Type.DOUBLE, Type.BOOL or Type.STRING, and explicitly sets the value to this type if possible
-def prop_to_binary(prop_val, type):
-    # All format strings start with an unsigned char to represent our Type enum
+def prop_to_binary(prop_val, prop_type):
+    # All format strings start with an unsigned char to represent our prop_type enum
     format_str = "=B"
     if prop_val is None:
         # An empty field indicates a NULL property
         return struct.pack(format_str, Type.NULL)
 
     # If field can be cast to a float, allow it
-    if type == None or type == Type.DOUBLE:
+    if prop_type is None or prop_type == Type.DOUBLE:
         try:
             numeric_prop = float(prop_val)
             if not math.isnan(numeric_prop) and not math.isinf(numeric_prop): # Don't accept non-finite values.
                 return struct.pack(format_str + "d", Type.DOUBLE, numeric_prop)
         except:
-            pass
+            raise SchemaError("Could not parse '%s' as a double" % prop_val)
 
-    if type == None or type == Type.BOOL:
+    if prop_type is None or prop_type == Type.BOOL:
         # If field is 'false' or 'true', it is a boolean
         if prop_val.lower() == 'false':
             return struct.pack(format_str + '?', Type.BOOL, False)
         elif prop_val.lower() == 'true':
            return struct.pack(format_str + '?', Type.BOOL, True)
 
-    if type == None or type == Type.STRING:
+    if prop_type is None or prop_type == Type.STRING:
         # If we've reached this point, the property is a string
         encoded_str = str.encode(prop_val) # struct.pack requires bytes objects as arguments
         # Encoding len+1 adds a null terminator to the string
         format_str += "%ds" % (len(encoded_str) + 1)
-        return struct.pack(format_str, schema.Type.STRING, encoded_str)
+        return struct.pack(format_str, Type.STRING, encoded_str)
 
+    if prop_type in (Type.LABEL, Type.TYPE, Type.ID): # TODO tmp, treat as string for testing
+        encoded_str = str.encode(prop_val) # struct.pack requires bytes objects as arguments
+        # Encoding len+1 adds a null terminator to the string
+        format_str += "%ds" % (len(encoded_str) + 1)
+        return struct.pack(format_str, Type.STRING, encoded_str)
+
+    import ipdb
+    ipdb.set_trace()
     # If it hasn't returned by this point, it is trying to set it to a type that it can't adopt
     raise Exception("unable to parse [" + prop_val + "] with type ["+repr(type)+"]")
 
 
 # Superclass for label and relation CSV files
 class EntityFile(object):
-    def __init__(self, filename, separator):
+    def __init__(self, filename):
         # The label or relation type string is the basename of the file
         self.entity_str = os.path.splitext(os.path.basename(filename))[0]
         # Input file handling
         self.infile = io.open(filename, 'rt')
         # Initialize CSV reader that ignores leading whitespace in each field
         # and does not modify input quote characters
-        self.reader = csv.reader(self.infile, delimiter=separator, skipinitialspace=True, quoting=module_vars.QUOTING)
-
-        self.prop_offset = 0 # Starting index of properties in row
-        self.prop_count = 0 # Number of properties per entity
+        self.reader = csv.reader(self.infile, delimiter=module_vars.CONFIGS.separator, skipinitialspace=True, quoting=module_vars.QUOTING)
 
         self.packed_header = b''
         self.binary_entities = []
         self.binary_size = 0 # size of binary token
+
+        # Extract data from header row.
+        self.convert_header()
+
         self.count_entities() # number of entities/row in file.
+        next(self.reader) # Skip header for next read.
 
     # Count number of rows in file.
     def count_entities(self):
         self.entities_count = 0
         self.entities_count = sum(1 for line in self.infile)
-        # discard header row
-        self.entities_count -= 1
         # seek back
         self.infile.seek(0)
         return self.entities_count
 
     # Simple input validations for each row of a CSV file
-    def validate_row(self, expected_col_count, row):
+    def validate_row(self, row):
         # Each row should have the same number of fields
-        if len(row) != expected_col_count:
+        if len(row) != self.column_count:
             raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
-                           % (self.infile.name, self.reader.line_num, expected_col_count, len(row), ','.join(row)))
+                           % (self.infile.name, self.reader.line_num, self.column_count, len(row), configs.separator.join(row)))
 
     # If part of a CSV file was sent to Redis, delete the processed entities and update the binary size
     def reset_partial_binary(self):
         self.binary_entities = []
         self.binary_size = len(self.packed_header)
 
     # Convert property keys from a CSV file header into a binary string
-    def pack_header(self, header):
-        prop_count = len(header) - self.prop_offset
+    def pack_header(self):
         # String format
         entity_bytes = self.entity_str.encode()
         fmt = "=%dsI" % (len(entity_bytes) + 1) # Unaligned native, entity name, count of properties
-        args = [entity_bytes, prop_count]
-        for p in header[self.prop_offset:]:
-            prop = p.encode()
+        args = [entity_bytes, self.prop_count]
+        for idx in range(self.column_count):
+            if self.skip_offsets[idx]:
+                continue
+            prop = self.column_names[idx].encode()
             fmt += "%ds" % (len(prop) + 1) # encode string with a null terminator
             args.append(prop)
         return struct.pack(fmt, *args)
 
+    # Extract column names and types from a header row
+    def convert_header(self):
+        header = next(self.reader)
+        self.column_count = len(header)
+        self.column_names = [None] * self.column_count # Property names of every column.
+        self.types = [None] * self.column_count # Value type of every column.
+        self.skip_offsets = [False] * self.column_count # Whether column at any offset should not be stored as a property.
+
+        for idx, field in enumerate(header):
+            pair = field.split(':')
+            if len(pair) > 2:
+                raise CSVError("Field '%s' had %d colons" % field, len(field))
+            elif len(pair) < 2:
+                self.types[idx] = convert_schema_type(pair[0].casefold())
+                self.skip_offsets[idx] = True
+                if self.types[idx] not in (Type.ID, Type.START_ID, Type.END_ID, Type.IGNORE):
+                    # Any other field should have 2 elements
+                    raise SchemaError("Each property in the header should be a colon-separated pair")
+            else:
+                self.column_names[idx] = pair[0]
+                self.types[idx] = convert_schema_type(pair[1].casefold())
+                if self.types[idx] in (Type.START_ID, Type.END_ID, Type.IGNORE):
+                    self.skip_offsets[idx] = True
+
+        # The number of properties is equal to the number of non-skipped columns.
+        self.prop_count = self.skip_offsets.count(False)
+        self.packed_header = self.pack_header()
+        self.binary_size += len(self.packed_header)
+
     # Convert a list of properties into a binary string
     def pack_props(self, line):
         props = []
-        for num, field in enumerate(line[self.prop_offset:]):
-            field_type_idx = self.prop_offset+num
-            try:
-                module_vars.FIELD_TYPES[self.entity_str][field_type_idx]
-            except:
-                props.append(prop_to_binary(field, None))
-            else:
-                props.append(prop_to_binary(field, module_vars.FIELD_TYPES[self.entity_str][field_type_idx]))
+        for idx, field in enumerate(line):
+            if self.skip_offsets[idx]:
+                continue
+            if self.column_names[idx]:
+                props.append(prop_to_binary(field, self.types[idx]))
         return b''.join(p for p in props)
 
     def to_binary(self):
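
As a standalone illustration of the header rules introduced in convert_header() (a sketch, not from the commit; the concrete type tokens such as 'id', 'string', 'double', and 'ignore' are assumptions about what schema.convert_schema_type accepts): a field with no colon is read as a bare type and is never stored as a property, while name:type pairs become properties unless the type is START_ID, END_ID, or IGNORE.

    # Standalone sketch of the classification performed by convert_header().
    # The type spellings below are assumed; only the split/skip logic mirrors the diff.
    header = ['id', 'name:string', 'age:double', 'notes:ignore']

    column_names = []
    skip_offsets = []
    for field in header:
        pair = field.split(':')
        if len(pair) < 2:
            # Bare type such as 'id': used for its role, never packed as a property.
            column_names.append(None)
            skip_offsets.append(True)
        else:
            # 'name:type' pair: the left side becomes the property key.
            column_names.append(pair[0])
            skip_offsets.append(pair[1] in ('start_id', 'end_id', 'ignore'))

    print(column_names)              # [None, 'name', 'age', 'notes']
    print(skip_offsets)              # [True, False, False, True]
    print(skip_offsets.count(False)) # 2 properties ('name' and 'age') will be packed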

bulk_insert/label.py

Lines changed: 11 additions & 32 deletions
@@ -4,46 +4,24 @@
 from configs import Configs
 from exceptions import SchemaError
 import module_vars
-import schema
+from schema import Type
+from schema import convert_schema_type
 
 
 # Handler class for processing label csv files.
 class Label(EntityFile):
-    def __init__(self, infile, separator):
-        super(Label, self).__init__(infile, separator)
-        expected_col_count = self.process_header()
-        self.process_entities(expected_col_count)
-        self.infile.close()
-
-    def process_header_schema(self, header):
-        prop_count = len(header)
-        self.types = [None] * prop_count
-        for i, prop in enumerate(header):
-            pair = prop.split(':')
-            if len(pair) != 2:
-                raise SchemaError("Each header entry should be a colon-separated pair")
-            self.types[i] = schema.convert_schema_type(pair[1].casefold())
-
-    def process_header(self):
-        # Header format:
-        # node identifier (which may be a property key), then all other property keys
-        header = next(self.reader)
-        expected_col_count = len(header)
+    def __init__(self, infile):
+        super(Label, self).__init__(infile)
+        # Verify that exactly one field is labeled ID.
+        if self.types.count(Type.ID) != 1:
+            raise SchemaError("Node file '%s' should have exactly one ID column."
+                              % (infile.name))
 
-        if module_vars.CONFIGS.enforce_schema:
-            self.process_header_schema(header)
-        # If identifier field begins with an underscore, don't add it as a property.
-        if header[0][0] == '_':
-            self.prop_offset = 1
-        self.packed_header = self.pack_header(header)
-        self.binary_size += len(self.packed_header)
-        return expected_col_count
-
-    def process_entities(self, expected_col_count):
+    def process_entities(self):
         entities_created = 0
         with click.progressbar(self.reader, length=self.entities_count, label=self.entity_str) as reader:
             for row in reader:
-                self.validate_row(expected_col_count, row)
+                self.validate_row(row)
                 # Add identifier->ID pair to dictionary if we are building relations
                 if module_vars.NODE_DICT is not None:
                     if row[0] in module_vars.NODE_DICT:
@@ -69,4 +47,5 @@ def process_entities(self, expected_col_count):
                 self.binary_size += row_binary_len
                 self.binary_entities.append(row_binary)
         module_vars.QUERY_BUF.labels.append(self.to_binary())
+        self.infile.close()
        print("%d nodes created with label '%s'" % (entities_created, self.entity_str))

bulk_insert/module_vars.py

Lines changed: 0 additions & 1 deletion
@@ -6,4 +6,3 @@
 TOP_NODE_ID = 0 # next ID to assign to a node
 QUERY_BUF = None # Buffer for query being constructed
 QUOTING = None
-FIELD_TYPES = None

0 commit comments
