def __call__(self, inputs, initial_state=None, initial_readout=None, ground_truth=None, **kwargs):
    req_num_inputs = 1 + self.num_states
    inputs = _to_list(inputs)
    inputs = inputs[:]  # copy, so the caller's list is not mutated
    if len(inputs) == 1:
        # Assemble the full input list: the main input followed by the
        # optional initial states, readout and ground truth, substituting
        # placeholders for anything the caller did not provide.
        if initial_state is not None:
            if isinstance(initial_state, list):
                inputs += initial_state
            else:
                inputs.append(initial_state)
        else:
            if self.readout:
                # With readout enabled, the readout tensor takes up one of
                # the state slots, so one fewer placeholder is needed.
                initial_state = self._get_optional_input_placeholder('initial_state', self.num_states - 1)
            else:
                initial_state = self._get_optional_input_placeholder('initial_state', self.num_states)
            inputs += _to_list(initial_state)
        if self.readout:
            if initial_readout is None:
                initial_readout = self._get_optional_input_placeholder('initial_readout')
            inputs.append(initial_readout)
        if self.teacher_force:
            req_num_inputs += 1
            if ground_truth is None:
                ground_truth = self._get_optional_input_placeholder('ground_truth')
            inputs.append(ground_truth)
    assert len(inputs) == req_num_inputs, "Required %d inputs, received %d." % (req_num_inputs, len(inputs))
    with K.name_scope(self.name):
        if not self.built:
            self.build(K.int_shape(inputs[0]))
            if self._initial_weights is not None:
                self.set_weights(self._initial_weights)
                self._initial_weights = None
        previous_mask = _collect_previous_mask(inputs[:1])
        user_kwargs = kwargs.copy()
        if not _is_all_none(previous_mask):
            # Propagate the previous layer's mask if call() accepts one.
            if 'mask' in inspect.getfullargspec(self.call).args:
                if 'mask' not in kwargs:
                    kwargs['mask'] = previous_mask
        input_shape = _collect_input_shape(inputs)
        output = self.call(inputs, **kwargs)
        output_mask = self.compute_mask(inputs[0], previous_mask)
        output_shape = self.compute_output_shape(input_shape[0])
        self._add_inbound_node(input_tensors=inputs, output_tensors=output,
                               input_masks=previous_mask, output_masks=output_mask,
                               input_shapes=input_shape, output_shapes=output_shape,
                               arguments=user_kwargs)
        if hasattr(self, 'activity_regularizer') and self.activity_regularizer is not None:
            regularization_losses = [self.activity_regularizer(x) for x in _to_list(output)]
            self.add_loss(regularization_losses, _to_list(inputs))
    return output
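The branching above is easier to follow in isolation. Below is a minimal, framework-free sketch of the same optional-input assembly; the strings stand in for tensors, and placeholder() stands in for _get_optional_input_placeholder, so all names here are illustrative only.

def assemble_inputs(main_input, num_states, initial_state=None,
                    readout=False, initial_readout=None,
                    teacher_force=False, ground_truth=None):
    # Framework-free sketch of the assembly above; strings stand in
    # for tensors.
    def placeholder(name, n=None):
        if n is None:
            return '%s_ph' % name
        return ['%s_ph_%d' % (name, i) for i in range(n)]

    required = 1 + num_states
    inputs = [main_input]
    if initial_state is not None:
        inputs += initial_state if isinstance(initial_state, list) else [initial_state]
    else:
        # With readout, the readout tensor takes up one state slot.
        inputs += placeholder('initial_state',
                              num_states - 1 if readout else num_states)
    if readout:
        inputs.append(initial_readout if initial_readout is not None
                      else placeholder('initial_readout'))
    if teacher_force:
        required += 1
        inputs.append(ground_truth if ground_truth is not None
                      else placeholder('ground_truth'))
    assert len(inputs) == required
    return inputs


print(assemble_inputs('x', num_states=2))
# -> ['x', 'initial_state_ph_0', 'initial_state_ph_1']
print(assemble_inputs('x', num_states=2, readout=True, teacher_force=True))
# -> ['x', 'initial_state_ph_0', 'initial_readout_ph', 'ground_truth_ph']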
Example #2
    def call(self, inputs, states, constants, training=None):
        """Complete attentive cell transformation.
        """
        attended = to_list(constants, allow_tuple=True)
        # NOTE: `K.rnn` will pass constants as a tuple and `_collect_previous_mask`
        # returns `None` if passed a tuple of tensors, hence `to_list` above!
        # We also make `attended` and `attended_mask` always lists for uniformity:
        attended_mask = to_list(_collect_previous_mask(attended))
        cell_states = states[:self._num_wrapped_states]
        attention_states = states[self._num_wrapped_states:]

        if self.attend_after:
            call = self._call_attend_after
        else:
            call = self._call_attend_before

        return call(inputs=inputs,
                    cell_states=cell_states,
                    attended=attended,
                    attention_states=attention_states,
                    attended_mask=attended_mask,
                    training=training)
Example #3
    def call(self, inputs, states, constants, training=None):
        """Run the complete attentive cell transformation."""
        attended = constants
        attended_mask = _collect_previous_mask(attended)
        # Keep `attended` and `attended_mask` as lists for uniformity:
        if not isinstance(attended_mask, list):
            attended_mask = [attended_mask]
        cell_states = states[:self._num_wrapped_states]
        attention_states = states[self._num_wrapped_states:]

        if self.attend_after:
            call = self._call_attend_after
        else:
            call = self._call_attend_before

        return call(inputs=inputs,
                    cell_states=cell_states,
                    attended=attended,
                    attention_states=attention_states,
                    attended_mask=attended_mask,
                    training=training)
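Both variants above receive the attended tensors through the `constants` mechanism of Keras RNNs. The toy cell below is a minimal sketch of that route under the Keras 2 API; `BiasFromConstantCell` and every shape in it are made up for illustration and are not part of the original code.

import numpy as np
from keras import backend as K
from keras.layers import Input, Layer, RNN
from keras.models import Model


class BiasFromConstantCell(Layer):
    # Illustrative cell: each step mixes the timestep input, the recurrent
    # state, and a per-sample constant -- the same route an attended
    # sequence takes in the wrappers above.
    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(BiasFromConstantCell, self).__init__(**kwargs)

    def build(self, input_shape):
        # With constants, Keras hands build() [step_shape, constant_shape].
        step_shape, const_shape = input_shape
        self.kernel = self.add_weight('kernel',
                                      (step_shape[-1], self.units),
                                      initializer='glorot_uniform')
        self.const_kernel = self.add_weight('const_kernel',
                                            (const_shape[-1], self.units),
                                            initializer='glorot_uniform')
        self.built = True

    def call(self, inputs, states, constants):
        # `constants` may arrive as a tuple -- hence the to_list() note above.
        attended = list(constants)[0]
        h = (K.dot(inputs, self.kernel)
             + K.dot(attended, self.const_kernel)
             + states[0])
        return h, [h]


x = Input((5, 8))   # input sequence
c = Input((3,))     # per-sample "attended" constant
y = RNN(BiasFromConstantCell(4))(x, constants=c)
model = Model([x, c], y)
print(model.predict([np.zeros((2, 5, 8)), np.zeros((2, 3))]).shape)  # (2, 4)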
Example #4
    def __call__(self, inputs, **kwargs):
        """
        Overriding the Layer's __call__ method for graph wrappers. 
        The overriding is mostly based on the __call__ code from 
        'keras/engine/base_layer.py' for the moment, with some changes 
        to handle graphs. 
        """
        # If arguments are 'keras-style' arguments, let Keras do the job
        if K.is_tensor(inputs) or (isinstance(inputs, list)
                                   and not isinstance(inputs, GraphWrapper)):
            output = super(GraphLayer, self).__call__(inputs, **kwargs)
        else:
            if isinstance(inputs,
                          list) and not isinstance(inputs, GraphWrapper):
                inputs = inputs[:]

            with K.name_scope(self.name):
                # Build the layer
                if not self.built:
                    # Check the compatibility of each input (some may
                    # be GraphWrappers)
                    if isinstance(inputs, list) and not isinstance(
                            inputs, GraphWrapper):
                        for i in inputs:
                            self.assert_input_compatibility(i)

                    # Collect input shapes to build layer.
                    input_shapes = []

                    if isinstance(inputs, GraphWrapper):
                        inputs = [inputs]

                    for x_elem in inputs:
                        if hasattr(x_elem, '_keras_shape'):
                            # For a GraphWrapper, _keras_shape is a GraphShape
                            # object
                            input_shapes.append(x_elem._keras_shape)
                        elif hasattr(K, 'int_shape'):
                            input_shapes.append(K.int_shape(x_elem))
                        else:
                            raise ValueError(
                                'You tried to call layer "' + self.name +
                                '". This layer has no information'
                                ' about its expected input shape, '
                                'and thus cannot be built. '
                                'You can build it manually via: '
                                '`layer.build(batch_input_shape)`')

                    self.build(unpack_singleton(input_shapes))
                    self._built = True

                    # Load weights that were specified at layer instantiation.
                    if self._initial_weights is not None:
                        self.set_weights(self._initial_weights)

                # Raise exceptions in case the input is not compatible
                # with the input_spec set at build time.
                if isinstance(inputs,
                              list) and not isinstance(inputs, GraphWrapper):
                    for i in inputs:
                        self.assert_input_compatibility(i)

                # Handle mask propagation.
                previous_mask = _collect_previous_mask(inputs)
                user_kwargs = copy.copy(kwargs)
                if not is_all_none(previous_mask):
                    # The previous layer generated a mask.
                    if has_arg(self.call, 'mask'):
                        if 'mask' not in kwargs:
                            # If mask is explicitly passed to __call__,
                            # we should override the default mask.
                            kwargs['mask'] = previous_mask
                # Handle automatic shape inference (only useful for Theano).
                input_shape = _collect_input_shape(inputs)

                # Actually call the layer,
                # collecting output(s), mask(s), and shape(s).
                # Note that inputs can hold graph wrappers now.
                output = self.call(unpack_singleton(inputs), **kwargs)
                output_mask = self.compute_mask(inputs, previous_mask)

                # If the layer returns tensors from its inputs, unmodified,
                # we copy them to avoid loss of tensor metadata.
                if isinstance(output, GraphWrapper):
                    output_ls = [output]
                else:
                    output_ls = to_list(output)
                # Unpack wrappers
                inputs_ls = list(
                    chain.from_iterable(
                        [i for i in inputs if isinstance(i, GraphWrapper)]))
                inputs_ls += [
                    i for i in inputs if not isinstance(i, GraphWrapper)
                ]

                # Unpack adjacency and nodes
                inpadj_ls = list(
                    chain.from_iterable([
                        to_list(i.adjacency) for i in inputs
                        if isinstance(i, GraphWrapper)
                    ]))
                inpnod_ls = to_list(
                    [i.nodes for i in inputs if isinstance(i, GraphWrapper)])

                output_ls_copy = []
                for x in output_ls:
                    if K.is_tensor(x) and x in inputs_ls:
                        x = K.identity(x)
                    # Apply adjacency-wise and node-wise identity
                    elif isinstance(x, GraphWrapper):
                        # Copy adjacency matrices taken unchanged from the
                        # inputs; keep modified ones as they are
                        adj_ls = []
                        for adj in to_list(x.adjacency):
                            if adj in inpadj_ls:
                                adj_ls.append(K.identity(adj))
                            else:
                                adj_ls.append(adj)
                        # Assign to the output graph
                        x.adjacency = unpack_singleton(adj_ls)
                        # Copy nodes taken unchanged from the inputs
                        if x.nodes in inpnod_ls:
                            x.nodes = K.identity(x.nodes)
                    output_ls_copy.append(x)

                output = unpack_singleton(output_ls_copy)

                # Inferring the output shape is only relevant for Theano.
                if all([s is not None for s in to_list(input_shape)]):
                    output_shape = self.compute_output_shape(input_shape)
                else:
                    if isinstance(input_shape, list):
                        output_shape = [None for _ in input_shape]
                    else:
                        output_shape = None

                if (not isinstance(output_mask, (list, tuple))
                        and len(output_ls) > 1):
                    # Augment the mask to match the length of the output.
                    output_mask = [output_mask] * len(output_ls)

                # Add an inbound node to the layer, so that it keeps track
                # of the call and of all new variables created during the call.
                # This also updates the layer history of the output tensor(s).
                # If the input tensor(s) had not previous Keras history,
                # this does nothing.
                self._add_inbound_node(input_tensors=inputs_ls,
                                       output_tensors=unpack_singleton(output),
                                       input_masks=previous_mask,
                                       output_masks=output_mask,
                                       input_shapes=input_shape,
                                       output_shapes=output_shape,
                                       arguments=user_kwargs)

                # Apply activity regularizer if any:
                if (hasattr(self, 'activity_regularizer')
                        and self.activity_regularizer is not None):
                    with K.name_scope('activity_regularizer'):
                        regularization_losses = [
                            self.activity_regularizer(x)
                            for x in to_list(output)
                        ]
                    self.add_loss(regularization_losses,
                                  inputs=to_list(inputs))
        return output
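The isinstance checks above imply that GraphWrapper subclasses list (a wrapper that iterates over its tensors). Under that assumption, the input-unpacking step can be exercised on its own; `MiniGraphWrapper` below is a made-up stand-in, not the real class.

from itertools import chain


class MiniGraphWrapper(list):
    # Made-up stand-in for GraphWrapper: a list of tensors carrying
    # adjacency/nodes attributes (strings stand in for tensors here).
    def __init__(self, adjacency, nodes):
        super(MiniGraphWrapper, self).__init__([adjacency, nodes])
        self.adjacency = adjacency
        self.nodes = nodes


inputs = [MiniGraphWrapper('A', 'X'), 'plain_tensor']

# Same flattening as in __call__ above: wrapper contents first,
# then the remaining plain tensors.
inputs_ls = list(chain.from_iterable(
    i for i in inputs if isinstance(i, MiniGraphWrapper)))
inputs_ls += [i for i in inputs if not isinstance(i, MiniGraphWrapper)]
print(inputs_ls)  # -> ['A', 'X', 'plain_tensor']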