def check_if_connections_compatible(connections):
    """
    Make sure that all connections share the same input and output shapes.

    Parameters
    ----------
    connections : iterable
        Objects that expose ``input_shape`` and ``output_shape``
        attributes. The iterable is consumed exactly once, so
        generators are accepted.

    Raises
    ------
    ValueError
        If no connections were provided, or if the connections have
        different input shapes or different output shapes.
    """
    input_shapes = []
    output_shapes = []

    # Single pass so that one-shot iterables (e.g. generators) work.
    for connection in connections:
        input_shapes.append(connection.input_shape)
        output_shapes.append(connection.output_shape)

    # Without this guard, an empty input would fail inside all_equal
    # with a message that says nothing about the networks.
    if not input_shapes:
        raise ValueError("Expected at least one network, got none")

    if not all_equal(input_shapes):
        raise ValueError("Networks have different input shapes: {}"
                         "".format(input_shapes))

    if not all_equal(output_shapes):
        raise ValueError("Networks have different output shapes: {}"
                         "".format(output_shapes))
def validate(self, input_shapes):
    """
    Verify that the incoming shapes can be combined by the gating layer.

    The input at position ``self.gating_layer_index`` is treated as the
    gating input; every other input is a network output to be merged.

    Parameters
    ----------
    input_shapes : sequence
        Shapes of all inputs, gating input included.

    Raises
    ------
    LayerConnectionError
        If the gating index is out of range, the gating shape does not
        have exactly one dimension, its size differs from the number of
        remaining inputs, or the remaining inputs' shapes are not all
        equal.
    """
    total_inputs = len(input_shapes)
    gate_index = self.gating_layer_index

    try:
        gate_shape = input_shapes[gate_index]
    except IndexError:
        raise LayerConnectionError(
            "Invalid index for gating layer. Number of input "
            "layers: {}. Gating layer index: {}"
            "".format(total_inputs, gate_index))

    merged_shapes = exclude_index(input_shapes, gate_index)

    # The gating input must be one-dimensional: one weight per network.
    if len(gate_shape) != 1:
        rank_message = (
            "Output from the gating network should be vector. Output "
            "shape from gating layer: {!r}".format(gate_shape))
        raise LayerConnectionError(rank_message)

    weight_count = gate_shape[0]
    # Every input except the gating one is a network to be combined.
    expected_networks = total_inputs - 1

    if weight_count != expected_networks:
        count_message = (
            "Gating layer can work only for combining only {} networks, "
            "got {} networks instead."
            "".format(weight_count, expected_networks))
        raise LayerConnectionError(count_message)

    if not all_equal(merged_shapes):
        shape_message = (
            "Output layer that has to be merged expect to have the "
            "same shapes. Shapes: {!r}".format(merged_shapes))
        raise LayerConnectionError(shape_message)
def test_all_equal(self):
    # Sequences whose elements are all identical, including a
    # single-element sequence, must be reported as equal.
    equal_cases = [
        [1] * 10,
        [(1, 5)] * 10,
        [0.1] * 2,
        [5],
    ]
    for values in equal_cases:
        self.assertTrue(all_equal(values))

    # Any differing element breaks equality, even a difference of 1e-8.
    unequal_cases = [
        [1, 2, 3, 4, 5],
        [2, 2, 2, 2, 1],
        [5, 5 - 1e-8],
    ]
    for values in unequal_cases:
        self.assertFalse(all_equal(values))
def test_all_equal_exception(self):
    # An empty sequence has nothing to compare, so all_equal is
    # expected to reject it with ValueError.
    self.assertRaises(ValueError, all_equal, [])