def __init__(self, shape, name=None):
    """
    Initialize the input layer and store its normalized shape.

    Parameters
    ----------
    shape : object
        Shape of the input. A ``tf.TensorShape`` is first converted
        with ``tf_utils.shape_to_tuple``; whatever results is then
        normalized through ``as_tuple`` and stored in ``self.shape``.

    name : str or None
        Layer name, forwarded to the parent constructor.
    """
    super(Input, self).__init__(name=name)

    is_tensor_shape = isinstance(shape, tf.TensorShape)
    raw_shape = tf_utils.shape_to_tuple(shape) if is_tensor_shape else shape

    self.shape = as_tuple(raw_shape)
def fail_if_shape_invalid(self, input_shapes):
    """
    Validate the shapes of the layers connected to the gating layer.

    Parameters
    ----------
    input_shapes : list
        Shapes of all incoming layers, including the gating layer's
        output at position ``self.gate_index``. The non-gating shapes
        must support ``is_compatible_with`` and ``merge_with``
        (presumably ``tf.TensorShape`` instances; confirm with callers).

    Raises
    ------
    LayerConnectionError
        If ``self.gate_index`` is out of range, if the gating output is
        not 2-dimensional, if its last dimension differs from the number
        of non-gating inputs, or if the non-gating shapes are not
        mutually compatible.
    """
    n_input_layers = len(input_shapes)

    try:
        gate_shape = input_shapes[self.gate_index]
    except IndexError:
        raise LayerConnectionError(
            "Invalid index for gating layer. Number of input "
            "layers: {}. Gating layer index: {}"
            "".format(n_input_layers, self.gate_index))

    other_shapes = exclude_index(input_shapes, self.gate_index)

    # The rank check is skipped when ``gate_shape`` is falsy
    # (presumably a TensorShape with unknown rank; verify).
    if gate_shape and len(gate_shape) != 2:
        raise LayerConnectionError(
            "Output from the gating network should be 2-dimensional. "
            "Output shape from gating layer: {!r}"
            "".format(gate_shape))

    n_expected_networks = gate_shape[-1]

    # Note: -1 from all layers in order to exclude gating layer
    if n_expected_networks != (n_input_layers - 1):
        raise LayerConnectionError(
            "Gating layer can work only for combining only {} networks, "
            "got {} networks instead."
            "".format(n_expected_networks, (n_input_layers - 1)))

    if not other_shapes:
        return

    # Compare each shape against everything merged so far, not only the
    # first shape. Compatibility is not transitive: (None,) is compatible
    # with both (2,) and (3,), but those two are not compatible with each
    # other.
    merged_shape = other_shapes[0]

    for shape in other_shapes:
        if not shape.is_compatible_with(merged_shape):
            raise LayerConnectionError(
                "Output layer that has to be merged expect to "
                "have the same shapes. Shapes: {!r}"
                "".format(tf_utils.shape_to_tuple(other_shapes)))

        merged_shape = merged_shape.merge_with(shape)
def __reduce__(self):
    """
    Return pickling instructions for this object.

    The object is rebuilt by calling its class with ``self.network`` and
    the parameters from ``self.get_params(with_network=False)``. In those
    parameters, the ``'target'`` entry is replaced with the tuple form of
    its ``.shape``.
    """
    params = self.get_params(with_network=False)

    # Only the shape of the target placeholder is needed to reconstruct
    # it, so keep a plain tuple instead of the placeholder object.
    target_shape = params['target'].shape
    params['target'] = tf_utils.shape_to_tuple(target_shape)

    constructor_args = (self.network, params)
    return self.__class__, constructor_args
def check_if_networks_compatible(networks):
    """
    Check that all networks have mutually compatible input and output
    shapes.

    Parameters
    ----------
    networks : list
        Objects exposing ``input_shape`` and ``output_shape`` attributes.
        These must support ``is_compatible_with`` and ``merge_with``
        (presumably ``tf.TensorShape`` instances; confirm with callers).

    Raises
    ------
    ValueError
        If the input shapes, or the output shapes, are not mutually
        compatible.
    """
    input_shapes = [network.input_shape for network in networks]
    output_shapes = [network.output_shape for network in networks]

    for kind, shapes in (("input", input_shapes), ("output", output_shapes)):
        if not shapes:
            continue

        # Compare each shape against everything merged so far, not only
        # the first shape. Compatibility is not transitive: (None,) is
        # compatible with both (2,) and (3,), but those two are not
        # compatible with each other.
        merged_shape = shapes[0]

        for shape in shapes:
            if not shape.is_compatible_with(merged_shape):
                raise ValueError(
                    "Networks have incompatible {} shapes. Shapes: {}"
                    "".format(kind, tf_utils.shape_to_tuple(shapes)))

            merged_shape = merged_shape.merge_with(shape)
def get_output_shape(self, *input_shapes):
    """
    Infer the output shape of a layer that combines several inputs.

    Parameters
    ----------
    *input_shapes
        Shapes of the incoming layers. Each one is converted with
        ``tf.TensorShape``.

    Returns
    -------
    tf.TensorShape
        An unknown shape if any input has unknown rank. Otherwise, the
        shape of the first input.

    Raises
    ------
    LayerConnectionError
        If fewer than two inputs are given, or if the input shapes are
        not mutually compatible.
    """
    input_shapes = [tf.TensorShape(shape) for shape in input_shapes]

    # Validate the count before indexing input_shapes[0]. With no inputs
    # at all, indexing first would raise a bare IndexError instead of
    # the intended LayerConnectionError.
    if len(input_shapes) < 2:
        raise LayerConnectionError(
            "Layer `{}` expected multiple inputs. Input shapes: {}"
            "".format(self.name, tf_utils.shape_to_tuple(input_shapes)))

    if any(shape.ndims is None for shape in input_shapes):
        return tf.TensorShape(None)

    first_shape = input_shapes[0]

    # Compare each shape against everything merged so far, not only the
    # first shape. Compatibility is not transitive: (None,) is compatible
    # with both (2,) and (3,), but those two are not compatible with each
    # other.
    merged_shape = first_shape

    for shape in input_shapes:
        if not shape.is_compatible_with(merged_shape):
            formatted_shapes = tf_utils.shape_to_tuple(input_shapes)
            raise LayerConnectionError(
                "Input shapes to the `{}` layer have incompatible shapes. "
                "Input shapes: {}, Layer: {}"
                "".format(self.name, formatted_shapes, self))

        merged_shape = merged_shape.merge_with(shape)

    # Return the first input's shape, as before, so callers see the
    # same output shape for valid inputs.
    return first_shape
def targets(self):
    """
    Create one ``tf.float32`` target placeholder per output layer.

    Each placeholder takes its shape from the layer's ``output_shape``
    (converted with ``tf_utils.shape_to_tuple``). It is named
    ``placeholder/target/<layer name>``. The list, in
    ``self.output_layers`` order, is passed through
    ``make_one_if_possible``; presumably this unwraps a single-element
    list (confirm against its definition).
    """
    placeholders = [
        tf.placeholder(
            tf.float32,
            shape=tf_utils.shape_to_tuple(layer.output_shape),
            name="placeholder/target/{}".format(layer.name),
        )
        for layer in self.output_layers
    ]
    return make_one_if_possible(placeholders)