def test_layer_from_shape(self, layer_from_shape, get_output_shapes): layer = layer_from_shape input_shapes = {layer: ((4, 5, 6), )} assert get_output_shapes(layer, input_shapes) is input_shapes[layer] input_shapes = {None: ((4, 5, 6), )} layer.get_output_shapes_for = Mock() assert (get_output_shapes(layer, input_shapes) is layer.get_output_shapes_for.return_value) layer.get_output_shapes_for.assert_called_with((input_shapes[None], ))
def test_get_output_shape_with_single_argument_fails(
        self, layers, get_output_shapes):
    """A single bare shape tuple must be rejected with ``ValueError``.

    Per the original author's note, the graph ending at ``l3`` has two input
    layers. Supplying only one shape tuple is therefore expected to make
    ``get_output_shapes`` raise.
    """
    # The unpack is kept so the fixture's arity is still checked; only l3 is used.
    _, _, l3 = layers
    shp = (4, 5, 6)
    # expected to fail: only gave one shape tuple for two input layers
    with pytest.raises(ValueError):
        # The result is never produced, so it is not bound to a name.
        get_output_shapes(l3, shp)
def test_get_output_shape_input_is_a_mapping(self, layers, get_output_shapes): l1, l2, l3 = layers input_shapes = {l3: ((4, 5, 6), )} # expected: input_shapes[l3] assert get_output_shapes(l3, input_shapes) is input_shapes[l3] # l3.get_output_shapes_for, l2.get_output_shapes_for # should not have been called assert l3.get_output_shapes_for.call_count == 0 assert l2.get_output_shapes_for.call_count == 0
def test_get_output_shape_input_is_a_mapping_no_key( self, layers, get_output_shapes): l1, l2, l3 = layers output_shape = get_output_shapes(l3, {}) # expected: l3.output_shape assert output_shape == l3.output_shapes # l3.get_output_shapes_for, l2.get_output_shapes_for should # not have been called assert l3.get_output_shapes_for.call_count == 0 assert l2.get_output_shapes_for.call_count == 0
def test_get_output_shape_with_single_argument(self, layers, get_output_shapes): l1, l2, l3 = layers shp = (3, 4, 5) output_shape = get_output_shapes(l3, shp) # expected: l3.get_output_shape_for(l2.get_output_shape_for(shp)) assert output_shape is l3.get_output_shapes_for.return_value l3.get_output_shapes_for.assert_called_with( (l2.get_output_shapes_for.return_value, )) l2.get_output_shapes_for.assert_called_with((shp, ))
def test_get_output_shape_without_arguments(self, layers, get_output_shapes): l1, l2, l3 = layers output_shape = get_output_shapes(l3) # expected: l3.output_shape assert output_shape is l3.output_shapes # l3.get_output_shape_for, l2.get_output_shape_for should not have been # called assert l3.get_output_shapes_for.call_count == 0 assert l2.get_output_shapes_for.call_count == 0
def test_get_output_shape_input_is_a_mapping_for_input_layer( self, layers, get_output_shapes): l1, l2, l3 = layers shp = ((4, 5, 6), ) input_shapes = {l1: shp} output_shape = get_output_shapes(l3, input_shapes) # expected: l3.get_output_shapes_for(l2.get_output_shapes_for(shp)) assert output_shape is l3.get_output_shapes_for.return_value l3.get_output_shapes_for.assert_called_with( (l2.get_output_shapes_for.return_value, )) l2.get_output_shapes_for.assert_called_with((shp, ))
def test_get_output_shape_input_is_a_mapping_for_layer( self, layers, get_output_shapes): l1, l2, l3 = layers shp = (4, 5, 6) input_shapes = {l2: shp} output_shape = get_output_shapes(l3, input_shapes) # expected: l3.get_output_shapes_for(shp) assert output_shape is l3.get_output_shapes_for.return_value l3.get_output_shapes_for.assert_called_with((shp, )) # l2.get_output_shapes_for should not have been called assert l2.get_output_shapes_for.call_count == 0
def test_get_output_shape_input_is_a_mapping_no_key( self, layers, get_output_shapes): l1, l2, l3 = layers output_shape = get_output_shapes(l3, {}) # expected: l3.get_output_shapes_for( # [shp, l2[1].get_output_shapes_for(l1[1].shape)]) # expected: l3.output_shape assert output_shape == l3.output_shapes # l3.get_output_shape_for, l2[*].get_output_shape_for should not have # been called assert l3.get_output_shapes_for.call_count == 0 assert l2[0].get_output_shapes_for.call_count == 0 assert l2[1].get_output_shapes_for.call_count == 0
def test_get_output_shape_input_is_a_mapping_for_layer( self, layers, get_output_shapes): l1, l2, l3 = layers shp = (4, 5, 6) input_shapes = {l2[0]: (shp, )} output = get_output_shapes(l3, input_shapes) # expected: l3.get_output_shapes_for( # [shp, l2[1].get_output_shapes_for(l1[1].shape)]) assert output == l3.get_output_shapes_for.return_value args = (shp, ) + l2[1].get_output_shapes_for.return_value l3.get_output_shapes_for.assert_called_with((args, )) l2[1].get_output_shapes_for.assert_called_with((l1[1].shape, )) # l2[0].get_output_shapes_for should not have been called assert l2[0].get_output_shapes_for.call_count == 0
def test_get_output_shape_input_is_a_mapping_for_input_layer( self, layers, get_output_shapes): l1, l2, l3 = layers shp = (4, 5, 6) input_shapes = {l1[0]: shp} output = get_output_shapes(l3, input_shapes) # expected: l3.get_output_shapes_for( # [l2[0].get_output_shapes_for(shp), # l2[1].get_output_shapes_for(l1[1].shape)]) assert output is l3.get_output_shapes_for.return_value args = l2[0].get_output_shapes_for.return_value + \ l2[1].get_output_shapes_for.return_value l3.get_output_shapes_for.assert_called_with((args, )) l2[0].get_output_shapes_for.assert_called_with((shp, )) l2[1].get_output_shapes_for.assert_called_with((l1[1].shape, ))
def test_get_output_shape_input_is_a_mapping(self, layer, get_output_shapes): input_shapes = {layer: (4, 5, 6)} assert get_output_shapes(layer, input_shapes) == input_shapes[layer]
def test_get_output_shape_input_is_tuple(self, layer, get_output_shapes): shp = ((4, 5, 6), ) assert get_output_shapes(layer, shp) == shp
def test_get_output_shape_without_arguments(self, layer, get_output_shapes): assert get_output_shapes(layer) == ((3, 2), )