Ejemplo n.º 1
0
def test_compute_losses(sketches, node_feature_mapping, edge_feature_mapping):
    """Check that the graph model's losses and metrics are well-formed.

    Builds a 10-element batch from ``sketches``, runs a freshly constructed
    graph model on it, and verifies that ``compute_losses`` and
    ``compute_average_losses`` return dictionaries and that, for every edge
    and every node metric, entries ``[0]`` and ``[1]`` have identical shapes.

    ``sketches``, ``node_feature_mapping`` and ``edge_feature_mapping`` are
    presumably pytest fixtures supplied elsewhere (not visible here).
    """
    sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
    dataset = graph_dataset.GraphDataset(sequences,
                                         node_feature_mapping,
                                         edge_feature_mapping,
                                         seed=36)

    batch = [dataset[i] for i in range(10)]
    batch_input = graph_dataset.collate(batch, node_feature_mapping,
                                        edge_feature_mapping)

    # Merge node and edge feature dimensions into one mapping for the model
    # and the loss computation.
    feature_dimensions = {
        **node_feature_mapping.feature_dimensions,
        **edge_feature_mapping.feature_dimensions
    }

    model = graph_model.make_graph_model(32, feature_dimensions)
    model_output = model(batch_input)

    losses, _, edge_metrics, node_metrics = graph_model.compute_losses(
        model_output, batch_input, feature_dimensions)
    assert isinstance(losses, dict)

    avg_losses = graph_model.compute_average_losses(
        losses, batch_input['graph_counts'])
    assert isinstance(avg_losses, dict)

    for t in edge_metrics:
        assert edge_metrics[t][0].shape == edge_metrics[t][1].shape

    # Iterate node_metrics' own keys so every node target is checked and
    # edge-only keys cannot trigger a KeyError here.
    for t in node_metrics:
        assert node_metrics[t][0].shape == node_metrics[t][1].shape
Ejemplo n.º 2
0
def test_dataset(sketches, node_feature_mapping, edge_feature_mapping):
    """Smoke-test indexing into a ``GraphDataset`` built from ``sketches``.

    The first sample must carry an ``EdgeDistance`` entry among its sparse
    edge features and expose non-``None`` dense and sparse node features.
    """
    dataset = graph_dataset.GraphDataset(
        [data_sequence.sketch_to_sequence(sketch) for sketch in sketches],
        node_feature_mapping,
        edge_feature_mapping,
        seed=12)

    first_graph, _target = dataset[0]
    edge_distance = graph_dataset.TargetType.EdgeDistance
    assert edge_distance in first_graph.sparse_edge_features
    assert first_graph.node_features is not None
    assert first_graph.sparse_node_features is not None
Ejemplo n.º 3
0
def test_collate_some(sketches, node_feature_mapping, edge_feature_mapping):
    """Check that collating the first five samples yields a usable batch.

    The collated result must contain an ``'edge_numerical'`` key, and its
    ``'graph'`` entry must have non-``None`` node features.
    """
    sequences = [data_sequence.sketch_to_sequence(sketch) for sketch in sketches]
    dataset = graph_dataset.GraphDataset(sequences,
                                         node_feature_mapping,
                                         edge_feature_mapping,
                                         seed=42)

    samples = [dataset[idx] for idx in range(5)]

    collated = graph_dataset.collate(samples, node_feature_mapping,
                                     edge_feature_mapping)
    assert 'edge_numerical' in collated
    assert collated['graph'].node_features is not None
Ejemplo n.º 4
0
def load_dataset_and_weights_with_mapping(dataset_file,
                                          node_feature_mapping,
                                          edge_feature_mapping,
                                          seed=None):
    """Load a flat-array dataset file and wrap its sequences in a GraphDataset.

    Parameters
    ----------
    dataset_file
        Passed to ``np.load`` with ``mmap_mode='r'``; the result is decoded
        with ``flat_array.load_dictionary_flat``.
    node_feature_mapping, edge_feature_mapping
        Feature mappings forwarded to ``dataset.GraphDataset``.
    seed
        Seed forwarded to ``dataset.GraphDataset``.

    Returns
    -------
    tuple
        ``(graph_dataset, data['sequence_lengths'])``. Despite the function
        name, the second element is the stored ``sequence_lengths`` array
        (presumably used by callers as sampling weights; confirm).
    """
    data = flat_array.load_dictionary_flat(np.load(dataset_file,
                                                   mmap_mode='r'))
    seqs = data['sequences']
    # NOTE(review): assumes ``seqs`` is a torch-like object whose
    # ``share_memory_`` moves its storage to shared memory (e.g. for
    # DataLoader workers). TODO confirm.
    seqs.share_memory_()

    # Pass seed by keyword, as every other GraphDataset call site does, so it
    # cannot silently bind to a different positional parameter.
    ds = dataset.GraphDataset(seqs, node_feature_mapping, edge_feature_mapping,
                              seed=seed)

    return ds, data['sequence_lengths']
Ejemplo n.º 5
0
def test_compute_model_no_entity_features(sketches, edge_feature_mapping):
    """Run the graph model on a batch built without a node/entity mapping.

    ``None`` is used as the node feature mapping for both the dataset and
    ``collate``, and the model is created with ``readout_entity_features=False``.
    The model output must still contain a ``'node_embedding'`` entry.
    """
    dataset = graph_dataset.GraphDataset(
        [graph_dataset.sketch_to_sequence(sketch) for sketch in sketches],
        None,
        edge_feature_mapping,
        seed=36)

    samples = [dataset[idx] for idx in range(10)]
    batch_input = graph_dataset.collate(samples, None, edge_feature_mapping)

    edge_dimensions = {**edge_feature_mapping.feature_dimensions}
    model = graph_model.make_graph_model(32, edge_dimensions,
                                         readout_entity_features=False)
    output = model(batch_input)
    assert 'node_embedding' in output
Ejemplo n.º 6
0
def load_dataset_and_weights(dataset_file,
                             auxiliary_file,
                             quantization,
                             seed=None,
                             entity_features=True,
                             edge_features=True,
                             force_entity_categorical_features=False):
    """Load sequences and feature mappings; return a dataset and its weights.

    Parameters
    ----------
    dataset_file, auxiliary_file, quantization, entity_features, edge_features
        Forwarded positionally to ``load_sequences_and_mappings``.
    seed
        Seed forwarded to ``dataset.GraphDataset``.
    force_entity_categorical_features
        When true and no entity feature mapping was loaded, a
        default-constructed ``dataset.EntityFeatureMapping()`` is used instead.

    Returns
    -------
    tuple
        ``(GraphDataset, data['weights'])``.
    """
    loaded = load_sequences_and_mappings(dataset_file, auxiliary_file,
                                         quantization, entity_features,
                                         edge_features)

    missing_entity_mapping = loaded['entity_feature_mapping'] is None
    if missing_entity_mapping and force_entity_categorical_features:
        # Fallback mapping that only computes the categorical features
        # (i.e. isConstruction and clockwise).
        loaded['entity_feature_mapping'] = dataset.EntityFeatureMapping()

    graph_ds = dataset.GraphDataset(loaded['sequences'],
                                    loaded['entity_feature_mapping'],
                                    loaded['edge_feature_mapping'],
                                    seed=seed)
    return graph_ds, loaded['weights']
Ejemplo n.º 7
0
def test_compute_model_subnode(sketches):
    """Exercise the model and loss pipeline with all feature readouts off.

    The dataset and batch are built without feature mappings. The model is
    created with an empty ``feature_dimensions`` and with both entity and edge
    readouts disabled. Both ``compute_losses`` and
    ``compute_average_losses`` must still return dictionaries.
    """
    dataset = graph_dataset.GraphDataset(
        [graph_dataset.sketch_to_sequence(sketch) for sketch in sketches],
        seed=36)
    batch_input = graph_dataset.collate([dataset[idx] for idx in range(10)])

    model = graph_model.make_graph_model(32,
                                         feature_dimensions={},
                                         readout_entity_features=False,
                                         readout_edge_features=False)
    output = model(batch_input)
    assert 'node_embedding' in output

    # compute_losses returns a 4-tuple; only the losses dict is checked here.
    losses, _, _edge_metrics, _node_metrics = graph_model.compute_losses(
        output, batch_input, {})
    assert isinstance(losses, dict)

    averaged = graph_model.compute_average_losses(losses,
                                                  batch_input['graph_counts'])
    assert isinstance(averaged, dict)