def load_sequences_and_mappings(dataset_file, auxiliary_file, quantization, entity_features=True, edge_features=True):
    """Load a flat-array sequence dataset plus optional feature mappings.

    Parameters
    ----------
    dataset_file : path to the ``.npy`` dataset loaded via memory-mapping.
    auxiliary_file : path to the gzipped stats pickle; if ``None``, derived
        from ``dataset_file`` by swapping the extension for ``.stats.pkl.gz``.
    quantization : dict with ``'angle'`` and ``'length'`` entries, each either
        an already-built ``dataset.QuantizationMap`` or a level count passed to
        ``QuantizationMap.from_counter``.
    entity_features, edge_features : toggle construction of the respective
        feature mappings (``None`` is returned for a disabled mapping).

    Returns
    -------
    dict with keys ``'sequences'`` (shared-memory tensor),
    ``'entity_feature_mapping'``, ``'edge_feature_mapping'`` and ``'weights'``
    (per-sequence lengths).
    """
    data = flat_array.load_dictionary_flat(np.load(dataset_file, mmap_mode='r'))

    if auxiliary_file is None:
        # Default: stats file lives next to the dataset file.
        root, _ = os.path.splitext(dataset_file)
        auxiliary_file = root + '.stats.pkl.gz'

    # The stats pickle is only needed (and only read) when some mapping is requested.
    if entity_features or edge_features:
        with gzip.open(auxiliary_file, 'rb') as f:
            auxiliary_dict = pickle.load(f)

    entity_feature_mapping = None
    if entity_features:
        entity_feature_mapping = dataset.EntityFeatureMapping(auxiliary_dict['node'])

    edge_feature_mapping = None
    if edge_features:
        # Accept either a ready-made QuantizationMap or a spec to build one from counters.
        angle_map = quantization['angle']
        if not isinstance(angle_map, dataset.QuantizationMap):
            angle_map = dataset.QuantizationMap.from_counter(
                auxiliary_dict['edge']['angle'], angle_map)

        length_map = quantization['length']
        if not isinstance(length_map, dataset.QuantizationMap):
            length_map = dataset.QuantizationMap.from_counter(
                auxiliary_dict['edge']['length'], length_map)

        edge_feature_mapping = dataset.EdgeFeatureMapping(angle_map, length_map)

    return {
        'sequences': data['sequences'].share_memory_(),
        'entity_feature_mapping': entity_feature_mapping,
        'edge_feature_mapping': edge_feature_mapping,
        'weights': data['sequence_lengths'],
    }
def load_dataset_and_weights(dataset_file, auxiliary_file, quantization, seed=None, entity_features=True, edge_features=True, force_entity_categorical_features=False):
    """Build a ``dataset.GraphDataset`` and return it with per-sequence weights.

    Thin wrapper around ``load_sequences_and_mappings``; when entity features
    were disabled but ``force_entity_categorical_features`` is set, a default
    ``EntityFeatureMapping`` is substituted so only the categorical features
    (i.e. isConstruction and clockwise) are computed.

    Returns a ``(GraphDataset, weights)`` tuple.
    """
    data = load_sequences_and_mappings(
        dataset_file, auxiliary_file, quantization, entity_features, edge_features)

    entity_mapping = data['entity_feature_mapping']
    if entity_mapping is None and force_entity_categorical_features:
        # Categorical-only entity features requested despite entity_features=False.
        entity_mapping = dataset.EntityFeatureMapping()

    graph_ds = dataset.GraphDataset(
        data['sequences'], entity_mapping, data['edge_feature_mapping'], seed=seed)
    return graph_ds, data['weights']
def load_saved_model(model_state_path):
    """Restore a model and its feature mappings from a saved checkpoint.

    Reads the torch checkpoint at ``model_state_path`` and the ``args.txt``
    JSON file expected in the same directory, reconstructs the entity/edge
    feature mappings stored in the checkpoint (either may be absent), builds
    the model via ``make_model_with_arguments`` and loads its weights.

    Parameters
    ----------
    model_state_path : path to the checkpoint file produced by training.

    Returns
    -------
    ``(model, node_feature_mapping, edge_feature_mapping)`` — the mappings are
    ``None`` when the checkpoint did not contain them.
    """
    state = torch.load(model_state_path, map_location='cpu')

    args_path = os.path.join(os.path.dirname(model_state_path), 'args.txt')
    with open(args_path) as fh:
        model_args = json.load(fh)

    if state['node_feature_mapping'] is not None:
        node_feature_mapping = dataset.EntityFeatureMapping()
        node_feature_mapping.load_state_dict(state['node_feature_mapping'])
    else:
        node_feature_mapping = None

    if state['edge_feature_mapping'] is not None:
        edge_feature_mapping = dataset.EdgeFeatureMapping()
        edge_feature_mapping.load_state_dict(state['edge_feature_mapping'])
    else:
        edge_feature_mapping = None

    # Bug fix: the original dereferenced `.feature_dimensions` unconditionally,
    # raising AttributeError whenever either mapping was None — even though the
    # None case is explicitly handled above. Fall back to an empty dict instead.
    feature_dimensions = {
        **(node_feature_mapping.feature_dimensions if node_feature_mapping is not None else {}),
        **(edge_feature_mapping.feature_dimensions if edge_feature_mapping is not None else {}),
    }

    model = make_model_with_arguments(feature_dimensions, model_args)

    # Strip a leading "module." (added by DataParallel-style wrappers) from
    # checkpoint keys so they match the bare model's state dict.
    new_state_dict = {
        (key[7:] if key.startswith('module.') else key): value
        for key, value in state['model'].items()
    }
    model.load_state_dict(new_state_dict)

    return model, node_feature_mapping, edge_feature_mapping
def node_feature_mapping():
    """Build an entity feature mapping from the bundled test statistics file."""
    stats_path = 'tests/testdata/sg_t16.stats.pkl.gz'
    with gzip.open(stats_path, 'rb') as f:
        stats = pickle.load(f)
    return graph_dataset.EntityFeatureMapping(stats['node'])