Example #1
0
 def test_layer_from_shape(self, layer_from_shape, get_output_shapes):
     """Check both mapping lookups for the ``layer_from_shape`` fixture.

     A mapping keyed by the layer itself must be answered with the mapped
     value (by identity); a mapping keyed by ``None`` must be forwarded to
     the layer's ``get_output_shapes_for``.
     """
     layer = layer_from_shape

     # Case 1: the queried layer is a key of the mapping.
     keyed_by_layer = {layer: ((4, 5, 6), )}
     result = get_output_shapes(layer, keyed_by_layer)
     assert result is keyed_by_layer[layer]

     # Case 2: the mapping only has a ``None`` key, so the layer's own
     # shape computation (replaced by a Mock here) has to be consulted.
     keyed_by_none = {None: ((4, 5, 6), )}
     layer.get_output_shapes_for = Mock()
     result = get_output_shapes(layer, keyed_by_none)
     assert result is layer.get_output_shapes_for.return_value
     expected_call_arg = (keyed_by_none[None], )
     layer.get_output_shapes_for.assert_called_with(expected_call_arg)
Example #2
0
 def test_get_output_shape_with_single_argument_fails(
         self, layers, get_output_shapes):
     """A lone shape tuple must be rejected with ``ValueError``.

     The ``layers`` fixture is unpacked into three entries and only the
     last one is queried.
     """
     # Only the last layer is used; the unpacking still verifies that the
     # fixture yields exactly three entries.
     _, _, l3 = layers
     shp = (4, 5, 6)
     # expected to fail: only gave one shape tuple for two input layers
     with pytest.raises(ValueError):
         # The return value is deliberately not bound: the call must raise,
         # so an assignment here could never complete in a passing test.
         get_output_shapes(l3, shp)
Example #3
0
 def test_get_output_shape_input_is_a_mapping(self, layers,
                                              get_output_shapes):
     """A mapping entry for the queried layer is handed back unchanged."""
     _, middle, last = layers
     shape_map = {last: ((4, 5, 6), )}

     result = get_output_shapes(last, shape_map)

     # The very object stored in the mapping has to come back.
     assert result is shape_map[last]
     # No shape computation may have been triggered on either layer.
     for lyr in (last, middle):
         assert lyr.get_output_shapes_for.call_count == 0
Example #4
0
 def test_get_output_shape_input_is_a_mapping_no_key(
         self, layers, get_output_shapes):
     """An empty mapping yields the layer's own ``output_shapes``."""
     _, middle, last = layers

     result = get_output_shapes(last, {})

     # With no usable key the result must equal ``last.output_shapes``.
     assert result == last.output_shapes
     # Neither layer's ``get_output_shapes_for`` may have been invoked.
     for lyr in (last, middle):
         assert lyr.get_output_shapes_for.call_count == 0
Example #5
0
 def test_get_output_shape_with_single_argument(self, layers,
                                                get_output_shapes):
     """A single shape tuple is propagated through both layers.

     Expected call chain:
     ``l3.get_output_shapes_for(l2.get_output_shapes_for(shape))``,
     with each argument wrapped in a one-element tuple.
     """
     _, middle, last = layers
     single_shape = (3, 4, 5)

     result = get_output_shapes(last, single_shape)

     assert result is last.get_output_shapes_for.return_value
     # The last layer receives the middle layer's result ...
     middle_result = middle.get_output_shapes_for.return_value
     last.get_output_shapes_for.assert_called_with((middle_result, ))
     # ... and the middle layer receives the shape that was passed in.
     middle.get_output_shapes_for.assert_called_with((single_shape, ))
Example #6
0
 def test_get_output_shape_without_arguments(self, layers,
                                             get_output_shapes):
     """Without input shapes the layer's ``output_shapes`` is returned."""
     _, middle, last = layers

     result = get_output_shapes(last)

     # The result must be the very object stored as ``last.output_shapes``.
     assert result is last.output_shapes
     # Neither layer's ``get_output_shapes_for`` may have been invoked.
     for lyr in (last, middle):
         assert lyr.get_output_shapes_for.call_count == 0
Example #7
0
 def test_get_output_shape_input_is_a_mapping_for_input_layer(
         self, layers, get_output_shapes):
     """A mapping keyed by the first layer is propagated up the stack.

     Expected call chain:
     ``l3.get_output_shapes_for(l2.get_output_shapes_for(shapes))``,
     with each argument wrapped in a one-element tuple.
     """
     first, middle, last = layers
     given_shapes = ((4, 5, 6), )
     shape_map = {first: given_shapes}

     result = get_output_shapes(last, shape_map)

     assert result is last.get_output_shapes_for.return_value
     # The last layer is fed the middle layer's result ...
     middle_result = middle.get_output_shapes_for.return_value
     last.get_output_shapes_for.assert_called_with((middle_result, ))
     # ... and the middle layer is fed the shapes mapped to the first layer.
     middle.get_output_shapes_for.assert_called_with((given_shapes, ))
Example #8
0
 def test_get_output_shape_input_is_a_mapping_for_layer(
         self, layers, get_output_shapes):
     """A mapping keyed by the middle layer short-circuits that layer.

     Expected call: ``l3.get_output_shapes_for((shape, ))``; the middle
     layer's own computation must be skipped.
     """
     _, middle, last = layers
     given_shape = (4, 5, 6)
     shape_map = {middle: given_shape}

     result = get_output_shapes(last, shape_map)

     assert result is last.get_output_shapes_for.return_value
     last.get_output_shapes_for.assert_called_with((given_shape, ))
     # The mapped layer's shape computation must not have run.
     assert middle.get_output_shapes_for.call_count == 0
Example #9
0
 def test_get_output_shape_input_is_a_mapping_no_key(
         self, layers, get_output_shapes):
     """An empty mapping yields the top layer's own ``output_shapes``.

     Here the second entry of ``layers`` is indexable and both of its
     first two members are checked.
     """
     _, middle, top = layers

     result = get_output_shapes(top, {})

     # With no usable key the result must equal ``top.output_shapes``.
     assert result == top.output_shapes
     # No ``get_output_shapes_for`` may have been invoked, neither on the
     # top layer nor on either of the two middle layers.
     assert top.get_output_shapes_for.call_count == 0
     for idx in (0, 1):
         assert middle[idx].get_output_shapes_for.call_count == 0
Example #10
0
 def test_get_output_shape_input_is_a_mapping_for_layer(
         self, layers, get_output_shapes):
     """A mapping keyed by ``l2[0]`` replaces only that branch.

     Expected call: the top layer receives the mapped shape concatenated
     with the result of ``l2[1].get_output_shapes_for``, which in turn is
     called with ``l1[1].shape``.
     """
     bottom, middle, top = layers
     shp = (4, 5, 6)
     shape_map = {middle[0]: (shp, )}

     result = get_output_shapes(top, shape_map)

     assert result == top.get_output_shapes_for.return_value
     # Mapped shape first, then whatever the second middle layer produced.
     second_result = middle[1].get_output_shapes_for.return_value
     merged = (shp, ) + second_result
     top.get_output_shapes_for.assert_called_with((merged, ))
     middle[1].get_output_shapes_for.assert_called_with((bottom[1].shape, ))
     # The mapped layer's shape computation must not have run.
     assert middle[0].get_output_shapes_for.call_count == 0
Example #11
0
 def test_get_output_shape_input_is_a_mapping_for_input_layer(
         self, layers, get_output_shapes):
     """A mapping keyed by ``l1[0]`` feeds the first branch only.

     Expected call: the top layer receives the concatenated results of
     ``l2[0].get_output_shapes_for((shape, ))`` and
     ``l2[1].get_output_shapes_for((l1[1].shape, ))``.
     """
     bottom, middle, top = layers
     shp = (4, 5, 6)
     shape_map = {bottom[0]: shp}

     result = get_output_shapes(top, shape_map)

     assert result is top.get_output_shapes_for.return_value
     # Results of both middle layers, joined in branch order.
     merged = (middle[0].get_output_shapes_for.return_value +
               middle[1].get_output_shapes_for.return_value)
     top.get_output_shapes_for.assert_called_with((merged, ))
     # First branch sees the mapped shape, second branch its own input.
     middle[0].get_output_shapes_for.assert_called_with((shp, ))
     middle[1].get_output_shapes_for.assert_called_with((bottom[1].shape, ))
Example #12
0
 def test_get_output_shape_input_is_a_mapping(self, layer,
                                              get_output_shapes):
     """A mapping entry for the layer determines the returned shapes."""
     shape_map = {layer: (4, 5, 6)}
     result = get_output_shapes(layer, shape_map)
     # Equality (not identity) with the mapped value is what is required.
     assert result == shape_map[layer]
Example #13
0
 def test_get_output_shape_input_is_tuple(self, layer, get_output_shapes):
     """Shapes passed directly as a tuple come back equal to the input."""
     given_shapes = ((4, 5, 6), )
     result = get_output_shapes(layer, given_shapes)
     assert result == given_shapes
Example #14
0
 def test_get_output_shape_without_arguments(self, layer,
                                             get_output_shapes):
     """Without input shapes the result equals ``((3, 2), )``.

     NOTE(review): presumably this is the shape the ``layer`` fixture was
     built with -- confirm against the fixture definition.
     """
     expected = ((3, 2), )
     result = get_output_shapes(layer)
     assert result == expected