class DataLoaderSchema(RelatedConfigMixin):
    """Describes the dataloader output schema.

    Properties:
      - we allow classes that also contain dictionaries -> a leaf can be an
        - array
        - scalar
        - custom dictionary (recursive data-type)
        - SpecialType (say ArrayRanges, BatchArrayRanges, which will
          effectively be a dictionary of scalars)
    """
    inputs = NestedMappingField(ArraySchema, keyword="shape", key="name")
    targets = NestedMappingField(ArraySchema, keyword="shape", key="name",
                                 required=False)
    metadata = NestedMappingField(MetadataStruct, keyword="doc", key="name",
                                  required=False)

    def compatible_with_batch(self, batch, verbose=True):
        """Validate whether a batch of data complies with the schema.

        Checks performed:
          - the nested structure is the same (i.e. dictionary keys, list
            lengths etc.)
          - leaves are checked by the schema objects themselves
            (`ArraySchema` for inputs/targets, `MetadataStruct` for metadata)

        # Arguments
            batch: a batch of data returned by one iteration of the
                dataloader's batch_iter. Must be a dict whose keys are a
                subset of {"inputs", "targets", "metadata"}; "inputs" is
                mandatory.
            verbose: if True, print the reason why the batch doesn't match

        # Returns
            bool: True only if everything is ok
        """
        # The `collections.Mapping` / `collections.Sequence` aliases were
        # removed in Python 3.10; the ABCs live in `collections.abc`.
        from collections.abc import Mapping, Sequence

        def print_msg(msg):
            if verbose:
                print(msg)

        # the batch itself has to be a dictionary
        if not isinstance(batch, dict):
            print_msg("not isinstance(batch, dict)")
            return False

        # contains only the three specified fields
        if not set(batch.keys()).issubset({"inputs", "targets", "metadata"}):
            print_msg(
                'not set(batch.keys()).issubset({"inputs", "targets", "metadata"})'
            )
            return False

        def compatible_nestedmapping(sub_batch, descr, cls, verbose=True):
            """Recursively check that `sub_batch` matches the description.

            Leaves (instances of `cls`) are delegated to
            `descr.compatible_with_batch`. Dicts must have exactly the same
            keys and lists/tuples exactly the same length. Any other type
            combination is reported as a mismatch.
            """
            # Recursion stop: the schema object validates the leaf itself
            if isinstance(descr, cls):
                return descr.compatible_with_batch(sub_batch, verbose=verbose)
            if isinstance(sub_batch, Mapping) and isinstance(descr, Mapping):
                if not set(sub_batch.keys()) == set(descr.keys()):
                    print_msg("The dictionary keys don't match:")
                    print_msg("batch: {0}".format(sub_batch.keys()))
                    print_msg("descr: {0}".format(descr.keys()))
                    return False
                # a list (not a generator) so that every mismatch is reported
                return all([
                    compatible_nestedmapping(sub_batch[key], descr[key], cls,
                                             verbose)
                    for key in sub_batch
                ])
            if isinstance(sub_batch, Sequence) and isinstance(descr, Sequence):
                if not len(sub_batch) == len(descr):
                    print_msg("Lengths dont match:")
                    print_msg("len(batch): {0}".format(len(sub_batch)))
                    print_msg("len(descr): {0}".format(len(descr)))
                    return False
                return all([
                    compatible_nestedmapping(sub_batch[i], descr[i], cls,
                                             verbose)
                    for i in range(len(sub_batch))
                ])
            print_msg("Invalid types:")
            print_msg("type(batch): {0}".format(type(sub_batch)))
            print_msg("type(descr): {0}".format(type(descr)))
            return False

        # inputs always need to be present
        if "inputs" not in batch:
            print_msg('not "inputs" in batch')
            return False
        if not compatible_nestedmapping(batch["inputs"], self.inputs,
                                        ArraySchema, verbose):
            return False

        # targets are optional: a missing, None or empty entry means that the
        # batch carries no targets (len(None) would otherwise raise TypeError)
        batch_targets = batch.get("targets")
        if batch_targets is not None and len(batch_targets) != 0:
            if self.targets is None:
                # targets need to be specified if we want to use them
                print_msg('self.targets is None')
                return False
            if not compatible_nestedmapping(batch_targets, self.targets,
                                            ArraySchema, verbose):
                return False

        # metadata needs to be present if and only if it is defined in the
        # description
        if self.metadata is not None:
            if "metadata" not in batch:
                print_msg('not "metadata" in batch')
                return False
            if not compatible_nestedmapping(batch["metadata"], self.metadata,
                                            MetadataStruct, verbose):
                return False
        elif "metadata" in batch:
            print_msg('"metadata" in batch')
            return False
        return True
class ModelSchema(RelatedConfigMixin):
    """Describes the model schema: the inputs and targets a model expects."""
    # can be a dictionary, list or a single array
    inputs = NestedMappingField(ArraySchema, keyword="shape", key="name")
    targets = NestedMappingField(ArraySchema, keyword="shape", key="name")

    def compatible_with_schema(self, dataloader_schema, verbose=True):
        """Check the compatibility: model.schema <-> dataloader.output_schema

        Checks performed:
          - the dataloader provides at least the fields the model requires:
            dict keys of the model must be a subset of the dataloader's keys,
            and lists of the model may not be longer than the dataloader's
          - a model list can also be matched against a dataloader dict using
            the `.name` of each model entry
          - leaves are checked by `ArraySchema.compatible_with_schema`

        # Arguments
            dataloader_schema: the dataloader's output schema, an object with
                `inputs` and `targets` attributes (e.g. a DataLoaderSchema)
            verbose: if True, print the reason why things don't match

        # Returns
            bool: True only if everything is ok
        """
        # The `collections.Mapping` / `collections.Sequence` aliases were
        # removed in Python 3.10; the ABCs live in `collections.abc`.
        from collections.abc import Mapping, Sequence

        def print_msg(msg):
            if verbose:
                print(msg)

        def compatible_nestedmapping(dschema, descr, cls, verbose=True):
            """Recursively check that the dataloader schema `dschema`
            satisfies the model description `descr`.
            """
            if isinstance(descr, cls):
                # Recursion stop: the schema object compares the leaves
                return descr.compatible_with_schema(dschema,
                                                    name_self="Model",
                                                    name_schema="Dataloader",
                                                    verbose=verbose)
            if isinstance(dschema, Mapping) and isinstance(descr, Mapping):
                if not set(descr.keys()).issubset(set(dschema.keys())):
                    print_msg(
                        "Dataloader doesn't provide all the fields required by the model:"
                    )
                    print_msg("dataloader fields: {0}".format(dschema.keys()))
                    print_msg("model fields: {0}".format(descr.keys()))
                    return False
                # a list (not a generator) so that every mismatch is reported
                return all([
                    compatible_nestedmapping(dschema[key], descr[key], cls,
                                             verbose)
                    for key in descr
                ])
            if isinstance(dschema, Sequence) and isinstance(descr, Sequence):
                if not len(descr) <= len(dschema):
                    print_msg(
                        "Dataloader doesn't provide all the fields required by the model:"
                    )
                    print_msg("len(dataloader): {0}".format(len(dschema)))
                    print_msg("len(model): {0}".format(len(descr)))
                    return False
                return all([
                    compatible_nestedmapping(dschema[i], descr[i], cls,
                                             verbose)
                    for i in range(len(descr))
                ])
            if isinstance(dschema, Mapping) and isinstance(descr, Sequence):
                # model entries are looked up in the dataloader dict by name
                if not len(descr) <= len(dschema):
                    print_msg(
                        "Dataloader doesn't provide all the fields required by the model:"
                    )
                    print_msg("len(dataloader): {0}".format(len(dschema)))
                    print_msg("len(model): {0}".format(len(descr)))
                    return False
                compatible = []
                for entry in descr:
                    if entry.name not in dschema:
                        print_msg(
                            "Model array name: {0} not found in dataloader keys: {1}"
                            .format(entry.name, list(dschema.keys())))
                        return False
                    compatible.append(
                        compatible_nestedmapping(dschema[entry.name], entry,
                                                 cls, verbose))
                return all(compatible)
            print_msg("Invalid types:")
            print_msg("type(Dataloader schema): {0}".format(type(dschema)))
            print_msg("type(Model schema): {0}".format(type(descr)))
            return False

        if not compatible_nestedmapping(dataloader_schema.inputs, self.inputs,
                                        ArraySchema, verbose):
            return False

        # Targets are only compared when the dataloader actually provides
        # them. The dataloader's targets may be None (the field is optional),
        # which previously crashed in len(None).
        dl_targets = dataloader_schema.targets
        has_targets = isinstance(dl_targets, ArraySchema) or (
            dl_targets is not None and len(dl_targets) > 0)
        if has_targets and not compatible_nestedmapping(
                dl_targets, self.targets, ArraySchema, verbose):
            return False
        return True