def test_compute_losses(sketches, node_feature_mapping, edge_feature_mapping):
    """Check that ``compute_losses`` yields well-formed losses and metrics.

    Builds a 10-example batch from the sketch fixtures, runs a freshly
    constructed graph model over it, then verifies that:

    * the losses and their per-graph averages are dicts, and
    * for every edge metric and every node metric, entries ``[0]`` and
      ``[1]`` have matching shapes (presumably prediction vs. target —
      confirm against ``compute_losses``).
    """
    sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
    dataset = graph_dataset.GraphDataset(sequences,
                                         node_feature_mapping,
                                         edge_feature_mapping,
                                         seed=36)

    batch = [dataset[i] for i in range(10)]
    batch_input = graph_dataset.collate(batch, node_feature_mapping,
                                        edge_feature_mapping)

    # The model and the loss need the union of node and edge feature dims.
    feature_dimensions = {
        **node_feature_mapping.feature_dimensions,
        **edge_feature_mapping.feature_dimensions
    }

    model = graph_model.make_graph_model(32, feature_dimensions)
    model_output = model(batch_input)

    losses, _, edge_metrics, node_metrics = graph_model.compute_losses(
        model_output, batch_input, feature_dimensions)
    assert isinstance(losses, dict)

    avg_losses = graph_model.compute_average_losses(
        losses, batch_input['graph_counts'])
    assert isinstance(avg_losses, dict)

    for t in edge_metrics:
        assert edge_metrics[t][0].shape == edge_metrics[t][1].shape

    # Iterate node_metrics' own keys so that every node metric is checked
    # and no edge-only key is looked up in node_metrics.
    for t in node_metrics:
        assert node_metrics[t][0].shape == node_metrics[t][1].shape
def test_collate_some(sketches, node_feature_mapping, edge_feature_mapping):
    """Collate a 5-example batch and check the expected entries are present.

    The collated result must contain an ``'edge_numerical'`` key, and its
    ``'graph'`` entry must carry non-``None`` ``node_features``.
    """
    seqs = [data_sequence.sketch_to_sequence(s) for s in sketches]
    ds = graph_dataset.GraphDataset(
        seqs, node_feature_mapping, edge_feature_mapping, seed=42)

    examples = [ds[idx] for idx in range(5)]

    collated = graph_dataset.collate(
        examples, node_feature_mapping, edge_feature_mapping)
    assert 'edge_numerical' in collated
    assert collated['graph'].node_features is not None
def test_compute_model_no_entity_features(sketches, edge_feature_mapping):
    """Run the model end to end without a node (entity) feature mapping.

    Both the dataset and ``collate`` receive ``None`` where a node feature
    mapping would go, and the model is built from the edge feature
    dimensions alone with ``readout_entity_features=False``. The forward
    pass must still produce a ``'node_embedding'`` entry.
    """
    seqs = [graph_dataset.sketch_to_sequence(s) for s in sketches]
    ds = graph_dataset.GraphDataset(seqs, None, edge_feature_mapping, seed=36)

    examples = [ds[idx] for idx in range(10)]
    collated = graph_dataset.collate(examples, None, edge_feature_mapping)

    edge_dims = {**edge_feature_mapping.feature_dimensions}
    net = graph_model.make_graph_model(
        32, edge_dims, readout_entity_features=False)
    output = net(collated)
    assert 'node_embedding' in output
def batch_from_example(seq, node_feature_mapping, edge_feature_mapping):
    """Collate all (prefix graph, next op) pairs of ``seq`` into one batch.

    Every position ``i > 0`` whose op is not rejected by
    ``dataset._is_subnode_edge`` becomes one example. That example pairs the
    graph built from ``seq[:i]`` (via ``dataset.graph_info_from_sequence``)
    with ``seq[i]`` as its target. The example list is then passed to
    ``dataset.collate`` with the same feature mappings.

    NOTE(review): ``dataset`` here is a module-level name defined outside
    this view (presumably the graph dataset module) — confirm the import.
    """
    # First select which steps become targets, then build their graphs.
    selected = [(idx, op) for idx, op in enumerate(seq)
                if idx > 0 and not dataset._is_subnode_edge(op)]

    examples = [
        (dataset.graph_info_from_sequence(seq[:idx],
                                          node_feature_mapping,
                                          edge_feature_mapping), op)
        for idx, op in selected
    ]

    return dataset.collate(examples, node_feature_mapping,
                           edge_feature_mapping)
def test_compute_model_subnode(sketches):
    """Run the model and losses with no feature mappings at all.

    The dataset and ``collate`` are used without node or edge feature
    mappings, and the model is built with empty feature dimensions and both
    readouts (entity and edge features) disabled. The forward pass must
    produce ``'node_embedding'``, and the losses and their per-graph
    averages must both be dicts.
    """
    seqs = [graph_dataset.sketch_to_sequence(s) for s in sketches]
    ds = graph_dataset.GraphDataset(seqs, seed=36)

    examples = [ds[idx] for idx in range(10)]
    collated = graph_dataset.collate(examples)

    net = graph_model.make_graph_model(
        32,
        feature_dimensions={},
        readout_entity_features=False,
        readout_edge_features=False,
    )
    output = net(collated)
    assert 'node_embedding' in output

    # Only the loss dict is checked here; the metric outputs are unused.
    losses, _, _, _ = graph_model.compute_losses(output, collated, {})
    assert isinstance(losses, dict)

    averaged = graph_model.compute_average_losses(
        losses, collated['graph_counts'])
    assert isinstance(averaged, dict)
def test_collate_empty(node_feature_mapping, edge_feature_mapping):
    """Check that collating an empty batch still yields ``'edge_numerical'``."""
    empty_batch = []
    collated = graph_dataset.collate(
        empty_batch, node_feature_mapping, edge_feature_mapping)
    assert 'edge_numerical' in collated