Esempio n. 1
0
def to_skip_connection_model(model):
    """Return a model derived from *model* with a randomly placed skip connection.

    Two convolutional layers (as judged by ``is_conv_layer``) are drawn at
    random from ``model.layers`` and ordered so the lower-indexed one is the
    source ``a`` and the other the target ``b``. The connection is then built
    with either ``Graph.to_add_skip_model`` or ``Graph.to_concat_skip_model``.

    Args:
        model: the model from which we get skip_connected model. Its
            ``layers`` attribute is scanned for convolutional layers.

    Returns:
        Whatever ``Graph.to_add_skip_model`` / ``Graph.to_concat_skip_model``
        returns for the chosen layer pair.

    Raises:
        ValueError: if ``model.layers`` contains no convolutional layer.
    """
    graph = Graph(model)
    weighted_layers = [layer for layer in model.layers if is_conv_layer(layer)]
    if not weighted_layers:
        # Without this guard randint(0, -1) fails with an opaque
        # "empty range" error; report the real cause instead.
        raise ValueError(
            "to_skip_connection_model: model has no convolutional layers "
            "to connect"
        )

    last = len(weighted_layers) - 1
    # Draw both indices independently (same draw order as before), then order
    # them so that a never comes after b.
    # NOTE(review): the two draws may coincide, giving a skip from a layer to
    # itself -- confirm Graph handles (or should forbid) that case.
    index_a, index_b = sorted((randint(0, last), randint(0, last)))
    a = weighted_layers[index_a]
    b = weighted_layers[index_b]

    if a.input.shape == b.output.shape:
        # Shapes match, so the add variant is always chosen here.
        return graph.to_add_skip_model(a, b)
    # NOTE(review): when the shapes differ this still picks the add variant
    # half the time; an element-wise add presumably needs matching shapes, so
    # confirm to_add_skip_model reconciles them, or invert this choice.
    if random() < 0.5:
        return graph.to_add_skip_model(a, b)
    return graph.to_concat_skip_model(a, b)
Esempio n. 2
0
def test_legal_graph2():
    """A concat skip is legal once, but adding the same one again is not."""
    # Endpoint identifiers passed to to_concat_skip_model
    # (presumably layer ids in the pooling model -- confirm against Graph).
    skip_start, skip_end = 2, 6
    pooling_graph = Graph(get_pooling_model(), False)

    # First connection: the graph must still be legal.
    pooling_graph.to_concat_skip_model(skip_start, skip_end)
    assert legal_graph(pooling_graph)

    # Repeating the identical connection must make legal_graph reject it.
    pooling_graph.to_concat_skip_model(skip_start, skip_end)
    assert not legal_graph(pooling_graph)