class TFiLM(Layer):
    def __init__(self, block_size, **kwargs):
        self.block_size = block_size
        super(TFiLM, self).__init__(**kwargs)

    def make_normalizer(self, x_in):
        """ Pools to downsample along 'temporal' dimension and then 
            runs LSTM to generate normalization weights.
        """
        x_in_down = MaxPooling1D(pool_size=self.block_size,
                                 padding='valid')(x_in)
        x_rnn = self.rnn(x_in_down)
        return x_rnn

    def apply_normalizer(self, x_in, x_norm):
        """
        Applies normalization weights by multiplying them into their respective blocks.
        """

        n_blocks = K.shape(x_in)[1] // self.block_size
        n_filters = K.shape(x_in)[2]

        # reshape input into blocks
        x_norm = K.reshape(x_norm, shape=(-1, n_blocks, 1, n_filters))
        x_in = K.reshape(x_in,
                         shape=(-1, n_blocks, self.block_size, n_filters))

        # multiply
        x_out = x_norm * x_in

        # return to original shape
        x_out = K.reshape(x_out,
                          shape=(-1, n_blocks * self.block_size, n_filters))

        return x_out

    def build(self, input_shape):
        self.rnn = LSTM(units=input_shape[2],
                        return_sequences=True,
                        trainable=True)
        self.rnn.build(input_shape)
        self._trainable_weights = self.rnn.trainable_weights
        super(TFiLM,
              self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        assert len(x.shape) == 3, \
            'Input should be a tensor with dimensions (batch_size, steps, num_features).'

        assert x.shape[1] % self.block_size == 0, \
            'Number of steps must be a multiple of the block size.'

        x_norm = self.make_normalizer(x)
        x = self.apply_normalizer(x, x_norm)
        return x

    def compute_output_shape(self, input_shape):
        return input_shape
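A minimal usage sketch (not part of the original listing): it assumes TFiLM and its dependencies (Layer, LSTM, MaxPooling1D and keras.backend as K) come from tensorflow.keras, and that the number of steps is a multiple of block_size. The Conv1D front end and all sizes below are illustrative.

from tensorflow.keras.layers import Input, Conv1D
from tensorflow.keras.models import Model

# 128 steps split into 8 blocks of 16; TFiLM builds an LSTM with 32 units to match the feature dimension.
inputs = Input(shape=(128, 32))
features = Conv1D(32, kernel_size=3, padding='same')(inputs)
modulated = TFiLM(block_size=16)(features)
model = Model(inputs, modulated)
model.summary()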
Example #2
class NSE(Layer):
    '''
    Simple Neural Semantic Encoder.
    '''
    def __init__(self,
                 output_dim,
                 input_length=None,
                 composer_activation='linear',
                 return_mode='last_output',
                 weights=None,
                 **kwargs):
        '''
        Arguments:
        output_dim (int)
        input_length (int)
        composer_activation (str): activation used in the MLP
        return_mode (str): One of last_output, all_outputs, output_and_memory
            This is analogous to the return_sequences flag in Keras' Recurrent.
            last_output returns only the last h_t
            all_outputs returns the whole sequence of h_ts
            output_and_memory returns the last output and the last memory concatenated
                (needed if this layer is followed by a MMA-NSE)
        weights (list): Initial weights
        '''
        self.output_dim = output_dim
        self.input_dim = output_dim  # Equation 2 in the paper makes this assumption.
        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=3)]
        self.input_length = input_length
        self.composer_activation = composer_activation
        super(NSE, self).__init__(**kwargs)
        self.reader = LSTM(self.output_dim,
                           return_sequences=True,
                           name="{}_reader".format(self.name))
        # TODO: Let the writer use parameter dropout and any consume_less mode.
        # Setting dropout to 0 here to eliminate the need for constants.
        # Setting consume_less to mem to eliminate need for preprocessing
        self.writer = LSTM(self.output_dim,
                           dropout_W=0.0,
                           dropout_U=0.0,
                           consume_less="mem",
                           name="{}_writer".format(self.name))
        self.composer = Dense(self.output_dim * 2,
                              activation=self.composer_activation,
                              name="{}_composer".format(self.name))
        if return_mode not in [
                "last_output", "all_outputs", "output_and_memory"
        ]:
            raise Exception("Unrecognized return mode: %s" % (return_mode))
        self.return_mode = return_mode

    def get_output_shape_for(self, input_shape):
        input_length = input_shape[1]
        if self.return_mode == "last_output":
            return (input_shape[0], self.output_dim)
        elif self.return_mode == "all_outputs":
            return (input_shape[0], input_length, self.output_dim)
        else:
            # return_mode is output_and_memory. Output will be concatenated to memory.
            return (input_shape[0], input_length + 1, self.output_dim)

    def compute_mask(self, input, mask):
        if mask is None or self.return_mode == "last_output":
            return None
        elif self.return_mode == "all_outputs":
            return mask  # (batch_size, input_length)
        else:
            # Return mode is output_and_memory
            # Mask memory corresponding to all the inputs that are masked, and do not mask the output
            # (batch_size, input_length + 1)
            return K.cast(K.concatenate([K.zeros_like(mask[:, :1]), mask]),
                          'uint8')

    def get_composer_input_shape(self, input_shape):
        # Takes concatenation of output and memory summary
        return (input_shape[0], self.output_dim * 2)

    def get_reader_input_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        input_dim = input_shape[-1]
        assert self.reader.return_sequences, "The reader has to return sequences!"
        reader_input_shape = self.get_reader_input_shape(input_shape)
        print("NSE reader input shape:", reader_input_shape, file=sys.stderr)
        # Will process one timestep at a time.
        writer_input_shape = (input_shape[0], 1, self.output_dim * 2)
        print("NSE writer input shape:", writer_input_shape, file=sys.stderr)
        composer_input_shape = self.get_composer_input_shape(input_shape)
        print("NSE composer input shape:", composer_input_shape, file=sys.stderr)
        self.reader.build(reader_input_shape)
        self.writer.build(writer_input_shape)
        self.composer.build(composer_input_shape)

        # Aggregate weights of individual components for this layer.
        reader_weights = self.reader.trainable_weights
        writer_weights = self.writer.trainable_weights
        composer_weights = self.composer.trainable_weights
        self.trainable_weights = reader_weights + writer_weights + composer_weights

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def read(self, nse_input, input_mask=None):
        '''
        This method produces the 'read' output (equation 1 in the paper) for all timesteps
        and initializes the memory slot mem_0.

        Input: nse_input (batch_size, input_length, input_dim)
        Outputs:
            o (batch_size, input_length, output_dim)
            flattened_mem_0 (batch_size, input_length * output_dim)
 
        While this method simply copies input to mem_0, variants that inherit from this class can do
        something fancier.
        '''
        input_to_read = nse_input
        mem_0 = input_to_read
        flattened_mem_0 = K.batch_flatten(mem_0)
        o = self.reader.call(input_to_read, input_mask)
        o_mask = self.reader.compute_mask(input_to_read, input_mask)
        return o, [flattened_mem_0], o_mask

    @staticmethod
    def summarize_memory(o_t, mem_tm1):
        '''
        This method selects the relevant parts of the memory given the read output and summarizes the
        memory. Implements Equations 2-3 or 8-11 in the paper.
        '''
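        # In symbols: z_t[j] = softmax_j(o_t . mem_tm1[:, j, :]) is an attention weight over memory slots,
        # and m_rt = sum_j z_t[j] * mem_tm1[:, j, :] is the attended memory summary.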
        # Selecting relevant memory slots, Equation 2
        z_t = K.softmax(K.sum(K.expand_dims(o_t, dim=1) * mem_tm1,
                              axis=2))  # (batch_size, input_length)
        # Summarizing memory, Equation 3
        m_rt = K.sum(K.expand_dims(z_t, dim=2) * mem_tm1,
                     axis=1)  # (batch_size, output_dim)
        return z_t, m_rt

    def compose_memory_and_output(self, output_memory_list):
        '''
        This method takes a list of tensors and applies the composition function on their concatenation.
        Implements equation 4 or 12 in the paper.
        '''
        # Composition, Equation 4
        c_t = self.composer.call(
            K.concatenate(output_memory_list))  # (batch_size, output_dim)
        return c_t

    def update_memory(self, z_t, h_t, mem_tm1):
        '''
        This method takes the attention vector (z_t), writer output (h_t) and previous timestep's memory (mem_tm1)
        and updates the memory. Implements equations 6, 14 or 15.
        '''
        tiled_z_t = K.tile(
            K.expand_dims(z_t),
            (self.output_dim))  # (batch_size, input_length, output_dim)
        input_length = K.shape(mem_tm1)[1]
        # (batch_size, input_length, output_dim)
        tiled_h_t = K.permute_dimensions(
            K.tile(K.expand_dims(h_t), (input_length)), (0, 2, 1))
        # Updating memory. First term in summation corresponds to selective forgetting and the second term to
        # selective addition. Equation 6.
        mem_t = mem_tm1 * (
            1 - tiled_z_t
        ) + tiled_h_t * tiled_z_t  # (batch_size, input_length, output_dim)
        return mem_t

    def compose_and_write_step(self, o_t, states):
        '''
        This method is a step function that updates the memory at each time step and produces
        a new output vector (Equations 2 to 6 in the paper).
        The memory_state is flattened because K.rnn requires all states to be of the same shape as the output,
        because it uses the same mask for the output and the states.
        Inputs:
            o_t (batch_size, output_dim)
            states (list[Tensor])
                flattened_mem_tm1 (batch_size, input_length * output_dim)
                writer_h_tm1 (batch_size, output_dim)
                writer_c_tm1 (batch_size, output_dim)

        Outputs:
            h_t (batch_size, output_dim)
            flattened_mem_t (batch_size, input_length * output_dim)
        '''
        flattened_mem_tm1, writer_h_tm1, writer_c_tm1 = states
        input_mem_shape = K.shape(flattened_mem_tm1)
        mem_tm1_shape = (input_mem_shape[0],
                         input_mem_shape[1] // self.output_dim, self.output_dim)
        mem_tm1 = K.reshape(
            flattened_mem_tm1,
            mem_tm1_shape)  # (batch_size, input_length, output_dim)
        z_t, m_rt = self.summarize_memory(o_t, mem_tm1)
        c_t = self.compose_memory_and_output([o_t, m_rt])
        # Collecting the necessary variables to directly call writer's step function.
        writer_constants = self.writer.get_constants(
            c_t)  # returns dropouts for W and U (all 1s, see init)
        writer_states = [writer_h_tm1, writer_c_tm1] + writer_constants
        # Making a call to writer's step function, Equation 5
        h_t, [_, writer_c_t] = self.writer.step(
            c_t, writer_states)  # h_t, writer_c_t: (batch_size, output_dim)
        mem_t = self.update_memory(z_t, h_t, mem_tm1)
        flattened_mem_t = K.batch_flatten(mem_t)
        return h_t, [flattened_mem_t, h_t, writer_c_t]

    def call(self, x, mask=None):
        # input_shape = (batch_size, input_length, input_dim). This needs to be defined in build.
        read_output, initial_memory_states, output_mask = self.read(x, mask)
        initial_write_states = self.writer.get_initial_states(
            read_output)  # h_0 and c_0 of the writer LSTM
        initial_states = initial_memory_states + initial_write_states
        # last_output: (batch_size, output_dim)
        # all_outputs: (batch_size, input_length, output_dim)
        # last_states:
        #       last_memory_state: (batch_size, input_length, output_dim)
        #       last_output
        #       last_writer_ct
        last_output, all_outputs, last_states = K.rnn(
            self.compose_and_write_step,
            read_output,
            initial_states,
            mask=output_mask)
        last_memory = last_states[0]
        if self.return_mode == "last_output":
            return last_output
        elif self.return_mode == "all_outputs":
            return all_outputs
        else:
            # return mode is output_and_memory
            expanded_last_output = K.expand_dims(
                last_output, dim=1)  # (batch_size, 1, output_dim)
            # (batch_size, 1+input_length, output_dim)
            return K.concatenate([expanded_last_output, last_memory], axis=1)

    def get_config(self):
        config = {
            'output_dim': self.output_dim,
            'input_length': self.input_length,
            'composer_activation': self.composer_activation,
            'return_mode': self.return_mode
        }
        base_config = super(NSE, self).get_config()
        config.update(base_config)
        return config
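A minimal usage sketch for this example, assuming the Keras 1.x API the layer is written against (get_output_shape_for, dropout_W/dropout_U, consume_less and LSTM.get_constants/step do not exist in Keras 2). The vocabulary size, sequence length and embedding size are illustrative; output_dim must equal the input feature dimension, as noted in __init__.

from keras.layers import Input, Embedding
from keras.models import Model

word_ids = Input(shape=(20,), dtype='int32')
embedded = Embedding(input_dim=10000, output_dim=50)(word_ids)      # (batch_size, 20, 50)
encoded = NSE(output_dim=50, return_mode='last_output')(embedded)   # (batch_size, 50)
model = Model(input=word_ids, output=encoded)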
Example #3
class NSE(Layer):
    '''
    Simple Neural Semantic Encoder.
    '''
    def __init__(self,
                 output_dim,
                 input_length=None,
                 composer_activation='linear',
                 return_mode='last_output',
                 weights=None,
                 **kwargs):
        '''
        Arguments:
        output_dim (int)
        input_length (int)
        composer_activation (str): activation used in the MLP
        return_mode (str): One of last_output, all_outputs, output_and_memory
            This is analogous to the return_sequences flag in Keras' Recurrent.
            last_output returns only the last h_t
            all_outputs returns the whole sequence of h_ts
            output_and_memory returns the last output and the last memory concatenated
                (needed if this layer is followed by a MMA-NSE)
        weights (list): Initial weights
        '''
        self.output_dim = output_dim
        self.input_dim = output_dim  # Equation 2 in the paper makes this assumption.
        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=3)]
        self.input_length = input_length
        self.composer_activation = composer_activation
        super(NSE, self).__init__(**kwargs)
        self.reader = LSTM(self.output_dim,
                           dropout_W=0.0,
                           dropout_U=0.0,
                           consume_less="gpu",
                           name="{}_reader".format(self.name))
        # TODO: Let the writer use parameter dropout and any consume_less mode.
        # Setting dropout to 0 here to eliminate the need for constants.
        # Setting consume_less to gpu to eliminate need for preprocessing
        self.writer = LSTM(self.output_dim,
                           dropout_W=0.0,
                           dropout_U=0.0,
                           consume_less="gpu",
                           name="{}_writer".format(self.name))
        self.composer = Dense(self.output_dim * 2,
                              activation=self.composer_activation,
                              name="{}_composer".format(self.name))
        if return_mode not in [
                "last_output", "all_outputs", "output_and_memory"
        ]:
            raise Exception("Unrecognized return mode: %s" % (return_mode))
        print("vj golden NSE.__init__ return_mode is {}".format(return_mode))
        self.return_mode = return_mode

    def get_output_shape_for(self, input_shape):
        input_length = input_shape[1]
        if self.return_mode == "last_output":
            return (input_shape[0], self.output_dim)
        elif self.return_mode == "all_outputs":
            return (input_shape[0], input_length, self.output_dim)
        else:
            # return_mode is output_and_memory. Output will be concatenated to memory.
            return (input_shape[0], input_length + 1, self.output_dim)

    def compute_mask(self, input, mask):
        if mask is None or self.return_mode == "last_output":
            return None
        elif self.return_mode == "all_outputs":
            return mask  # (batch_size, input_length)
        else:
            # Return mode is output_and_memory
            # Mask memory corresponding to all the inputs that are masked, and do not mask the output
            # (batch_size, input_length + 1)
            return K.cast(K.concatenate([K.zeros_like(mask[:, :1]), mask]),
                          'uint8')

    def get_composer_input_shape(self, input_shape):
        # Takes concatenation of output and memory summary
        return (input_shape[0], self.output_dim * 2)

    def get_reader_input_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        input_dim = input_shape[-1]
        reader_input_shape = self.get_reader_input_shape(input_shape)
        print("NSE reader input shape:", reader_input_shape, file=sys.stderr)
        # Will process one timestep at a time.
        writer_input_shape = (input_shape[0], 1, self.output_dim * 2)
        print("NSE writer input shape:", writer_input_shape, file=sys.stderr)
        composer_input_shape = self.get_composer_input_shape(input_shape)
        print("NSE composer input shape:", composer_input_shape, file=sys.stderr)
        self.reader.build(reader_input_shape)
        self.writer.build(writer_input_shape)
        self.composer.build(composer_input_shape)

        # Aggregate weights of individual components for this layer.
        reader_weights = self.reader.trainable_weights
        writer_weights = self.writer.trainable_weights
        composer_weights = self.composer.trainable_weights
        self.trainable_weights = reader_weights + writer_weights + composer_weights

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def get_initial_states(self, nse_input, input_mask=None):
        '''
        This method produces the initial states of the reader LSTM
        and initializes the memory slot mem_0.

        Input: nse_input (batch_size, input_length, input_dim)
        Output: list[Tensors]:
                h_0 (batch_size, output_dim)
                c_0 (batch_size, output_dim)
                flattened_mem_0 (batch_size, input_length * output_dim)
 
        While this method simply copies input to mem_0, variants that inherit from this class can do
        something fancier.
        '''
        input_to_read = nse_input
        mem_0 = input_to_read
        flattened_mem_0 = K.batch_flatten(mem_0)
        flattened_mem_0 = TF_PRINT(flattened_mem_0,
                                   "get_initial_states.flattened_mem_0",
                                   expected_shape=[BATCH, LENGTH * DIM])
        initial_states = self.reader.get_initial_states(nse_input)

        initial_states += [flattened_mem_0]

        return initial_states

    @staticmethod
    def summarize_memory(o_t, mem_tm1):
        '''
        This method selects the relevant parts of the memory given the read output and summarizes the
        memory. Implements Equations 2-3 or 8-11 in the paper.
        '''
        # Selecting relevant memory slots, Equation 2
        z_t = K.softmax(K.sum(K.expand_dims(o_t, dim=1) * mem_tm1,
                              axis=2))  # (batch_size, input_length)
        z_t = TF_PRINT(z_t,
                       "summarize_memory.z_t",
                       expected_shape=[BATCH, LENGTH])

        # Summarizing memory, Equation 3
        m_rt = K.sum(K.expand_dims(z_t, dim=2) * mem_tm1,
                     axis=1)  # (batch_size, output_dim)
        m_rt = TF_PRINT(m_rt,
                        "summarize_memory.m_rt",
                        expected_shape=[BATCH, DIM])
        return z_t, m_rt

    def compose_memory_and_output(self, output_memory_list):
        '''
        This method takes a list of tensors and applies the composition function on their concatenation.
        Implements equation 4 or 12 in the paper.
        '''
        # Composition, Equation 4
        c_t = self.composer.call(
            K.concatenate(output_memory_list))  # (batch_size, output_dim)
        c_t = TF_PRINT(c_t,
                       "compose_memory_and_output.c_t",
                       expected_shape=[BATCH, DIM])
        return c_t

    def update_memory(self, z_t, h_t, mem_tm1):
        '''
        This method takes the attention vector (z_t), writer output (h_t) and previous timestep's memory (mem_tm1)
        and updates the memory. Implements equations 6, 14 or 15.
        '''
        """ 
        The following is written assuming the equations in the paper are implemented as they are written:
        tiled_z_t_trans = K.tile(K.expand_dims(z_t,1), [1,self.output_dim,1])  # (batch_size, input_length, output_dim)
        input_length = K.shape(mem_tm1)[1]
        # (batch_size, input_length, output_dim)
#        tiled_h_t = K.permute_dimensions(K.tile(K.expand_dims(h_t, -1), [1,input_length]), (0, 2, 1))
        tiled_h_t = K.tile(K.expand_dims(h_t, -1), [1,1, input_length])
# Updating memory. First term in summation corresponds to selective forgetting and the second term to
        # selective addition. Equation 6.
        mem_t = mem_tm1 * (1 - tiled_z_t_trans) + tiled_h_t * tiled_z_t_trans  # (batch_size, input_length, output_dim)
        """
        """ 
        The following code assumes that mem_t is actually the transpose of what is in the paper.
        Implemented by simply wrapping a K.permute_dimensions(_, (0, 2, 1)) call around the original value.
        """
        tiled_z_t = K.permute_dimensions(
            K.tile(K.expand_dims(z_t, 1), [1, self.output_dim, 1]),
            (0, 2, 1))  # (batch_size, input_length, output_dim)
        input_length = K.shape(mem_tm1)[1]
        # (batch_size, input_length, output_dim)
        # tiled_h_t = K.permute_dimensions(K.tile(K.expand_dims(h_t, -1), [1, input_length]), (0, 2, 1))
        tiled_h_t = K.permute_dimensions(
            K.tile(K.expand_dims(h_t, -1), [1, 1, input_length]), (0, 2, 1))

        # Updating memory. First term in summation corresponds to selective forgetting and the second term to
        # selective addition. Equation 6.
        mem_t = mem_tm1 * (
            1 - tiled_z_t
        ) + tiled_h_t * tiled_z_t  # (batch_size, input_length, output_dim)
        mem_t = TF_PRINT(mem_t,
                         "update_memory.mem_t",
                         expected_shape=[BATCH, LENGTH, DIM])

        return mem_t

    @staticmethod
    def split_states(states):
        # This method is a helper for the step function to split the states into reader states, memory and
        # writer states.
        return states[:2], states[2], states[3:]

    def step(self, input_t, states):
        '''
        This method is a step function that updates the memory at each time step and produces
        a new output vector (Equations 1 to 6 in the paper).
        The memory_state is flattened because K.rnn requires all states to be of the same shape as the output,
        because it uses the same mask for the output and the states.
        Inputs:
            input_t (batch_size, input_dim)
            states (list[Tensor])
                reader_h_tm1 (batch_size, output_dim)
                reader_c_tm1 (batch_size, output_dim)
                flattened_mem_tm1 (batch_size, input_length * output_dim)
                writer_h_tm1 (batch_size, output_dim)
                writer_c_tm1 (batch_size, output_dim)

        Outputs:
            h_t (batch_size, output_dim)
            flattened_mem_t (batch_size, input_length * output_dim)
        '''
        input_t = TF_PRINT(input_t,
                           "step.input_t",
                           expected_shape=[BATCH, DIM])

        reader_states, flattened_mem_tm1, writer_states = self.split_states(
            states)
        input_mem_shape = K.shape(flattened_mem_tm1)
        mem_tm1_shape = (input_mem_shape[0],
                         input_mem_shape[1] // self.output_dim, self.output_dim)

        mem_tm1 = K.reshape(
            flattened_mem_tm1,
            mem_tm1_shape)  # (batch_size, input_length, output_dim)
        mem_tm1 = TF_PRINT(mem_tm1,
                           "step.mem_tm1",
                           expected_shape=[BATCH, LENGTH, DIM])

        reader_constants = self.reader.get_constants(
            input_t)  # Does not depend on input_t, see init.
        reader_states = reader_states[:2] + tuple(
            reader_constants) + reader_states[2:]
        o_t, [_, reader_c_t] = self.reader.step(
            input_t,
            reader_states)  # o_t, reader_c_t: (batch_size, output_dim)

        o_t = TF_PRINT(o_t, "step.o_t", expected_shape=[BATCH, DIM])
        reader_c_t = TF_PRINT(reader_c_t,
                              "step.reader_c_t",
                              expected_shape=[BATCH, DIM])

        z_t, m_rt = self.summarize_memory(o_t, mem_tm1)
        c_t = self.compose_memory_and_output([o_t, m_rt])

        # Collecting the necessary variables to directly call writer's step function.
        writer_constants = self.writer.get_constants(
            c_t)  # returns dropouts for W and U (all 1s, see init)
        writer_states += tuple(writer_constants)

        # Making a call to writer's step function, Equation 5
        h_t, [_, writer_c_t] = self.writer.step(
            c_t, writer_states)  # h_t, writer_c_t: (batch_size, output_dim)

        h_t = TF_PRINT(h_t, "step.h_t", expected_shape=[BATCH, DIM])
        writer_c_t = TF_PRINT(writer_c_t,
                              "step.writer_c_t",
                              expected_shape=[BATCH, DIM])

        mem_t = self.update_memory(z_t, h_t, mem_tm1)

        flattened_mem_t = K.batch_flatten(mem_t)
        flattened_mem_t = TF_PRINT(flattened_mem_t,
                                   "step.flattened_mem_t",
                                   expected_shape=[BATCH, LENGTH * DIM])

        return h_t, [o_t, reader_c_t, flattened_mem_t, h_t, writer_c_t]

    def loop(self, x, initial_states, mask):
        # This is a separate method because Ontoaware variants will have to override this to make a call
        # to changingdim rnn.

        last_output, all_outputs, last_states = K.rnn(self.step,
                                                      x,
                                                      initial_states,
                                                      mask=mask)
        last_output = TF_PRINT(last_output, "loop.last_output")
        all_outputs = TF_PRINT(all_outputs, "loop.all_outputs")
        #        last_states = TF_PRINT(last_states, "loop.last_states")
        return last_output, all_outputs, last_states

    def call(self, x, mask=None):
        # input_shape = (batch_size, input_length, input_dim). This needs to be defined in build.
        if mask is not None:
            print("vj golden call.mask ={}. Being set to None.".format(mask))
            mask = None
        initial_read_states = self.get_initial_states(x, mask)

        fake_writer_input = K.expand_dims(initial_read_states[0],
                                          dim=1)  # (batch_size, 1, output_dim)
        fake_writer_input = TF_PRINT(fake_writer_input,
                                     "call.fake_writer_input",
                                     expected_shape=[BATCH, 1, DIM])

        initial_write_states = self.writer.get_initial_states(
            fake_writer_input)  # h_0 and c_0 of the writer LSTM
        initial_states = initial_read_states + initial_write_states

        # last_output: (batch_size, output_dim)
        # all_outputs: (batch_size, input_length, output_dim)
        # last_states:
        #       last_memory_state: (batch_size, input_length, output_dim)
        #       last_output
        #       last_writer_ct
        last_output, all_outputs, last_states = self.loop(
            x, initial_states, mask)
        last_memory = last_states[0]

        if self.return_mode == "last_output":
            return last_output
        elif self.return_mode == "all_outputs":
            return all_outputs
        else:
            # return mode is output_and_memory
            expanded_last_output = K.expand_dims(
                last_output, dim=1)  # (batch_size, 1, output_dim)
            expanded_last_output = TF_PRINT(expanded_last_output,
                                            "call.expanded_last_output",
                                            expected_shape=[BATCH, 1, DIM])
            # (batch_size, 1+input_length, output_dim)
            result = K.concatenate([expanded_last_output, last_memory], axis=1)
            result = TF_PRINT(result,
                              "call.result",
                              expected_shape=[BATCH, 1 + LENGTH, DIM])
            return result

    def get_config(self):
        config = {
            'output_dim': self.output_dim,
            'input_length': self.input_length,
            'composer_activation': self.composer_activation,
            'return_mode': self.return_mode
        }
        base_config = super(NSE, self).get_config()
        config.update(base_config)
        return config
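TF_PRINT, BATCH, LENGTH and DIM are used throughout this example but are not defined in the listing. A possible stand-in, assuming a TensorFlow 1.x backend (the helper in the source project may behave differently), is a pass-through debug print that returns its input unchanged:

import tensorflow as tf

# Illustrative values for the symbolic sizes referenced in the expected_shape annotations above.
BATCH, LENGTH, DIM = 32, 20, 50

def TF_PRINT(tensor, label, expected_shape=None):
    # Print the tensor's runtime shape next to the label, then return the tensor unchanged
    # so the helper can wrap tensors inline, as in the methods above.
    return tf.Print(tensor, [tf.shape(tensor)], message=label + " shape: ")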