def test_compute_losses(sketches, node_feature_mapping, edge_feature_mapping):
    """End-to-end loss computation: build a batch, run the model, and check
    that losses/metrics have the expected structure.

    Checks that per-graph losses and averaged losses are dicts, and that
    prediction/target tensors in the metric pairs agree in shape.
    """
    sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
    dataset = graph_dataset.GraphDataset(
        sequences, node_feature_mapping, edge_feature_mapping, seed=36)
    batch = [dataset[i] for i in range(10)]
    batch_input = graph_dataset.collate(
        batch, node_feature_mapping, edge_feature_mapping)

    # Model readouts cover both node and edge feature dimensions.
    feature_dimensions = {
        **node_feature_mapping.feature_dimensions,
        **edge_feature_mapping.feature_dimensions
    }

    model = graph_model.make_graph_model(32, feature_dimensions)
    model_output = model(batch_input)

    losses, _, edge_metrics, node_metrics = graph_model.compute_losses(
        model_output, batch_input, feature_dimensions)
    assert isinstance(losses, dict)

    avg_losses = graph_model.compute_average_losses(
        losses, batch_input['graph_counts'])
    assert isinstance(avg_losses, dict)

    for t in edge_metrics:
        assert edge_metrics[t][0].shape == edge_metrics[t][1].shape
    # BUG FIX: the original iterated over edge_metrics keys here and indexed
    # node_metrics with them; iterate over node_metrics' own keys instead.
    for t in node_metrics:
        assert node_metrics[t][0].shape == node_metrics[t][1].shape
def test_dataset(sketches, node_feature_mapping, edge_feature_mapping):
    """A dataset item exposes edge-distance targets and node feature tensors."""
    sequences = [data_sequence.sketch_to_sequence(s) for s in sketches]
    dataset = graph_dataset.GraphDataset(
        sequences, node_feature_mapping, edge_feature_mapping, seed=12)

    graph, target = dataset[0]

    assert graph_dataset.TargetType.EdgeDistance in graph.sparse_edge_features
    assert graph.node_features is not None
    assert graph.sparse_node_features is not None
def test_collate_some(sketches, node_feature_mapping, edge_feature_mapping):
    """Collating a small batch yields numerical edge targets and node features."""
    sequences = [data_sequence.sketch_to_sequence(s) for s in sketches]
    dataset = graph_dataset.GraphDataset(
        sequences, node_feature_mapping, edge_feature_mapping, seed=42)

    samples = [dataset[idx] for idx in range(5)]
    batch_info = graph_dataset.collate(
        samples, node_feature_mapping, edge_feature_mapping)

    assert 'edge_numerical' in batch_info
    assert batch_info['graph'].node_features is not None
def load_dataset_and_weights_with_mapping(dataset_file, node_feature_mapping,
                                          edge_feature_mapping, seed=None):
    """Load a saved flat-array dataset and wrap it in a ``GraphDataset``.

    Args:
        dataset_file: Path to an ``.npz``/``.npy`` file loadable by ``np.load``,
            containing a flat dictionary with ``sequences`` and
            ``sequence_lengths`` entries.
        node_feature_mapping: Mapping used to featurize nodes (may be None,
            depending on GraphDataset's contract — not visible here).
        edge_feature_mapping: Mapping used to featurize edges.
        seed: Optional RNG seed forwarded to the dataset.

    Returns:
        Tuple ``(graph_dataset, sequence_lengths)``.
    """
    # mmap the file so large datasets are not read eagerly into memory.
    data = flat_array.load_dictionary_flat(np.load(dataset_file, mmap_mode='r'))

    sequences = data['sequences']
    # Move backing storage into shared memory so DataLoader workers can
    # access it without copying.
    sequences.share_memory_()

    # CONSISTENCY FIX: pass seed as a keyword argument, matching how
    # GraphDataset is constructed elsewhere in this file (seed=seed).
    graph_ds = dataset.GraphDataset(
        sequences, node_feature_mapping, edge_feature_mapping, seed=seed)
    return graph_ds, data['sequence_lengths']
def test_compute_model_no_entity_features(sketches, edge_feature_mapping):
    """The model runs without entity (node) feature readouts enabled."""
    sequences = [graph_dataset.sketch_to_sequence(s) for s in sketches]
    dataset = graph_dataset.GraphDataset(
        sequences, None, edge_feature_mapping, seed=36)

    samples = [dataset[idx] for idx in range(10)]
    batch_input = graph_dataset.collate(samples, None, edge_feature_mapping)

    # Only edge feature dimensions — entity readout is disabled.
    model = graph_model.make_graph_model(
        32,
        {**edge_feature_mapping.feature_dimensions},
        readout_entity_features=False)

    result = model(batch_input)
    assert 'node_embedding' in result
def load_dataset_and_weights(dataset_file, auxiliary_file, quantization, seed=None,
                             entity_features=True, edge_features=True,
                             force_entity_categorical_features=False):
    """Load sequences and feature mappings, returning a dataset and weights.

    Returns:
        Tuple ``(graph_dataset, weights)``.
    """
    data = load_sequences_and_mappings(
        dataset_file, auxiliary_file, quantization, entity_features, edge_features)

    if data['entity_feature_mapping'] is None and force_entity_categorical_features:
        # Create an entity mapping which only computes the categorical
        # features (i.e. isConstruction and clockwise).
        data['entity_feature_mapping'] = dataset.EntityFeatureMapping()

    graph_ds = dataset.GraphDataset(
        data['sequences'],
        data['entity_feature_mapping'],
        data['edge_feature_mapping'],
        seed=seed)
    return graph_ds, data['weights']
def test_compute_model_subnode(sketches):
    """Model with all feature readouts disabled still embeds nodes and
    produces well-formed losses."""
    sequences = [graph_dataset.sketch_to_sequence(s) for s in sketches]
    dataset = graph_dataset.GraphDataset(sequences, seed=36)

    samples = [dataset[idx] for idx in range(10)]
    batch_input = graph_dataset.collate(samples)

    # No entity or edge readouts: the model only produces embeddings and
    # the structural (subnode) targets.
    model = graph_model.make_graph_model(
        32,
        feature_dimensions={},
        readout_entity_features=False,
        readout_edge_features=False)

    result = model(batch_input)
    assert 'node_embedding' in result

    losses, _, edge_metrics, node_metrics = graph_model.compute_losses(
        result, batch_input, {})
    assert isinstance(losses, dict)

    avg_losses = graph_model.compute_average_losses(
        losses, batch_input['graph_counts'])
    assert isinstance(avg_losses, dict)