Exemplo n.º 1
0
def test_compute_losses(sketches, node_feature_mapping, edge_feature_mapping):
    """Smoke-test loss computation on a small collated batch.

    Builds a ``GraphDataset`` from ``sketches``, collates its first ten
    items, runs a freshly made graph model on the batch and then checks:

    * ``compute_losses`` returns a dict of losses,
    * ``compute_average_losses`` returns a dict,
    * every entry of the edge metrics and of the node metrics holds two
      leading elements with identical ``.shape``.

    Args:
        sketches: iterable of sketches convertible by
            ``graph_dataset.sketch_to_sequence`` (presumably a pytest
            fixture -- confirm against conftest).
        node_feature_mapping: object exposing ``feature_dimensions``; also
            passed to the dataset and to ``collate``.
        edge_feature_mapping: object exposing ``feature_dimensions``; also
            passed to the dataset and to ``collate``.
    """
    sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
    dataset = graph_dataset.GraphDataset(sequences,
                                         node_feature_mapping,
                                         edge_feature_mapping,
                                         seed=36)

    batch = [dataset[i] for i in range(10)]
    batch_input = graph_dataset.collate(batch, node_feature_mapping,
                                        edge_feature_mapping)

    # The model and the loss both take the union of node and edge feature
    # dimensions.
    feature_dimensions = {
        **node_feature_mapping.feature_dimensions,
        **edge_feature_mapping.feature_dimensions
    }

    model = graph_model.make_graph_model(32, feature_dimensions)
    model_output = model(batch_input)

    losses, _, edge_metrics, node_metrics = graph_model.compute_losses(
        model_output, batch_input, feature_dimensions)
    assert isinstance(losses, dict)

    avg_losses = graph_model.compute_average_losses(
        losses, batch_input['graph_counts'])
    assert isinstance(avg_losses, dict)

    for t in edge_metrics:
        assert edge_metrics[t][0].shape == edge_metrics[t][1].shape

    # Iterate over the node metrics' own keys: walking the edge keys here
    # would skip node-only targets and fail with KeyError on edge-only ones.
    for t in node_metrics:
        assert node_metrics[t][0].shape == node_metrics[t][1].shape
Exemplo n.º 2
0
def make_model_with_arguments(feature_dimensions, args):
    """Create a graph model configured from an argument mapping.

    Args:
        feature_dimensions: forwarded unchanged to
            ``graph_model.make_graph_model``.
        args: mapping of options. ``args['hidden_size']`` is looked up
            directly and must be present. The optional flags
            ``disable_entity_features``,
            ``force_entity_categorical_features``,
            ``disable_edge_features``, ``disable_readin_entity`` and
            ``disable_readin_edge`` are read with ``.get`` and count as
            ``False`` when absent.

    Returns:
        The value returned by ``graph_model.make_graph_model``.
    """
    hidden_size = args['hidden_size']

    # Entity readout stays on unless it is disabled; when disabled, the
    # "force categorical" flag's value is passed through instead.
    if args.get('disable_entity_features', False):
        entity_readout = args.get('force_entity_categorical_features', False)
    else:
        entity_readout = True

    edge_readout = not args.get('disable_edge_features', False)

    # Read-in switches are either explicitly False or left as None.
    # NOTE(review): None presumably lets make_graph_model choose its own
    # default -- confirm against its signature.
    entity_readin = None
    if args.get('disable_readin_entity', False):
        entity_readin = False

    edge_readin = None
    if args.get('disable_readin_edge', False):
        edge_readin = False

    return graph_model.make_graph_model(
        hidden_size,
        feature_dimensions,
        readout_entity_features=entity_readout,
        readout_edge_features=edge_readout,
        readin_entity_features=entity_readin,
        readin_edge_features=edge_readin)
Exemplo n.º 3
0
def test_compute_model_no_entity_features(sketches, edge_feature_mapping):
    """Check the model runs when no node feature mapping is supplied.

    The dataset and ``collate`` both receive ``None`` in place of a node
    feature mapping, and the model is built with only the edge feature
    dimensions and ``readout_entity_features=False``. The forward pass must
    yield a result containing ``'node_embedding'``.
    """
    converted = [graph_dataset.sketch_to_sequence(sk) for sk in sketches]
    data = graph_dataset.GraphDataset(
        converted, None, edge_feature_mapping, seed=36)

    # Collate the first ten examples into a single model input.
    examples = [data[idx] for idx in range(10)]
    model_input = graph_dataset.collate(examples, None, edge_feature_mapping)

    edge_only_dimensions = {**edge_feature_mapping.feature_dimensions}
    net = graph_model.make_graph_model(
        32, edge_only_dimensions, readout_entity_features=False)

    output = net(model_input)
    assert 'node_embedding' in output
Exemplo n.º 4
0
def test_compute_model_subnode(sketches):
    """Check model and loss computation with no feature mappings at all.

    The dataset and ``collate`` are used without feature mappings, and the
    model is built with empty feature dimensions and both entity and edge
    feature readout switched off. The forward result must contain
    ``'node_embedding'``, and both ``compute_losses`` and
    ``compute_average_losses`` must produce dicts.
    """
    converted = [graph_dataset.sketch_to_sequence(sk) for sk in sketches]
    data = graph_dataset.GraphDataset(converted, seed=36)

    # Collate the first ten examples into a single model input.
    examples = [data[idx] for idx in range(10)]
    model_input = graph_dataset.collate(examples)

    net = graph_model.make_graph_model(
        32,
        feature_dimensions={},
        readout_entity_features=False,
        readout_edge_features=False)
    output = net(model_input)
    assert 'node_embedding' in output

    # Only the losses are inspected; the other three returned values are
    # still unpacked so the four-element return shape is enforced.
    loss_dict, _, _, _ = graph_model.compute_losses(output, model_input, {})
    assert isinstance(loss_dict, dict)

    mean_losses = graph_model.compute_average_losses(
        loss_dict, model_input['graph_counts'])
    assert isinstance(mean_losses, dict)