class TFiLM(Layer):

    def __init__(self, block_size, **kwargs):
        self.block_size = block_size
        super(TFiLM, self).__init__(**kwargs)

    def make_normalizer(self, x_in):
        """ Pools to downsample along the 'temporal' dimension and then
            runs an LSTM to generate normalization weights.
        """
        x_in_down = MaxPooling1D(pool_size=self.block_size,
                                 padding='valid')(x_in)
        x_rnn = self.rnn(x_in_down)
        return x_rnn

    def apply_normalizer(self, x_in, x_norm):
        """ Applies normalization weights by multiplying them into their
            respective blocks.
        """
        n_blocks = K.shape(x_in)[1] // self.block_size
        n_filters = K.shape(x_in)[2]

        # reshape input into blocks
        x_norm = K.reshape(x_norm, shape=(-1, n_blocks, 1, n_filters))
        x_in = K.reshape(x_in, shape=(-1, n_blocks, self.block_size, n_filters))

        # multiply
        x_out = x_norm * x_in

        # return to original shape
        x_out = K.reshape(x_out, shape=(-1, n_blocks * self.block_size, n_filters))

        return x_out

    def build(self, input_shape):
        self.rnn = LSTM(units=input_shape[2],
                        return_sequences=True,
                        trainable=True)
        self.rnn.build(input_shape)
        self._trainable_weights = self.rnn.trainable_weights
        super(TFiLM, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        assert len(x.shape) == 3, ('Input should be a tensor with dimensions '
                                   '(batch_size, steps, num_features).')
        assert x.shape[1] % self.block_size == 0, ('Number of steps must be a '
                                                   'multiple of the block size.')
        x_norm = self.make_normalizer(x)
        x = self.apply_normalizer(x, x_norm)
        return x

    def compute_output_shape(self, input_shape):
        return input_shape
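
# --- Usage sketch (not part of the original code) --------------------------
# A minimal example of dropping TFiLM into a Keras functional model, assuming
# the Keras 2 API this layer targets. The Input/Conv1D imports, layer sizes
# and the helper name are illustrative assumptions only. The layer expects
# inputs of shape (batch_size, steps, num_features) with `steps` an exact
# multiple of `block_size`.

def _tfilm_usage_example():
    from keras.layers import Input, Conv1D
    from keras.models import Model

    steps, num_features, block_size = 256, 32, 16  # illustrative sizes
    x_in = Input(shape=(steps, num_features))
    # A convolutional feature extractor in front of TFiLM, as the layer is
    # typically applied to feature maps rather than raw input.
    features = Conv1D(filters=num_features, kernel_size=9, padding='same')(x_in)
    normalized = TFiLM(block_size=block_size)(features)
    return Model(inputs=x_in, outputs=normalized)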
class NSE(Layer):
    '''
    Simple Neural Semantic Encoder.
    '''

    def __init__(self,
                 output_dim,
                 input_length=None,
                 composer_activation='linear',
                 return_mode='last_output',
                 weights=None,
                 **kwargs):
        '''
        Arguments:
            output_dim (int)
            input_length (int)
            composer_activation (str): activation used in the MLP
            return_mode (str): One of last_output, all_outputs, output_and_memory
                This is analogous to the return_sequences flag in Keras' Recurrent.
                last_output returns only the last h_t
                all_outputs returns the whole sequence of h_ts
                output_and_memory returns the last output and the last memory concatenated
                    (needed if this layer is followed by a MMA-NSE)
            weights (list): Initial weights
        '''
        self.output_dim = output_dim
        self.input_dim = output_dim  # Equation 2 in the paper makes this assumption.
        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=3)]
        self.input_length = input_length
        self.composer_activation = composer_activation
        super(NSE, self).__init__(**kwargs)
        self.reader = LSTM(self.output_dim,
                           return_sequences=True,
                           name="{}_reader".format(self.name))
        # TODO: Let the writer use parameter dropout and any consume_less mode.
        # Setting dropout to 0 here to eliminate the need for constants.
        # Setting consume_less to mem to eliminate need for preprocessing
        self.writer = LSTM(self.output_dim,
                           dropout_W=0.0,
                           dropout_U=0.0,
                           consume_less="mem",
                           name="{}_writer".format(self.name))
        self.composer = Dense(self.output_dim * 2,
                              activation=self.composer_activation,
                              name="{}_composer".format(self.name))
        if return_mode not in ["last_output", "all_outputs", "output_and_memory"]:
            raise Exception("Unrecognized return mode: %s" % (return_mode))
        self.return_mode = return_mode

    def get_output_shape_for(self, input_shape):
        input_length = input_shape[1]
        if self.return_mode == "last_output":
            return (input_shape[0], self.output_dim)
        elif self.return_mode == "all_outputs":
            return (input_shape[0], input_length, self.output_dim)
        else:
            # return_mode is output_and_memory. Output will be concatenated to memory.
            return (input_shape[0], input_length + 1, self.output_dim)

    def compute_mask(self, input, mask):
        if mask is None or self.return_mode == "last_output":
            return None
        elif self.return_mode == "all_outputs":
            return mask  # (batch_size, input_length)
        else:
            # Return mode is output_and_memory
            # Mask memory corresponding to all the inputs that are masked, and do not mask the output
            # (batch_size, input_length + 1)
            return K.cast(K.concatenate([K.zeros_like(mask[:, :1]), mask]), 'uint8')

    def get_composer_input_shape(self, input_shape):
        # Takes concatenation of output and memory summary
        return (input_shape[0], self.output_dim * 2)

    def get_reader_input_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        input_dim = input_shape[-1]
        assert self.reader.return_sequences, "The reader has to return sequences!"
        reader_input_shape = self.get_reader_input_shape(input_shape)
        print >> sys.stderr, "NSE reader input shape:", reader_input_shape
        writer_input_shape = (input_shape[0], 1, self.output_dim * 2)  # Will process one timestep at a time
        print >> sys.stderr, "NSE writer input shape:", writer_input_shape
        composer_input_shape = self.get_composer_input_shape(input_shape)
        print >> sys.stderr, "NSE composer input shape:", composer_input_shape
        self.reader.build(reader_input_shape)
        self.writer.build(writer_input_shape)
        self.composer.build(composer_input_shape)

        # Aggregate weights of individual components for this layer.
        reader_weights = self.reader.trainable_weights
        writer_weights = self.writer.trainable_weights
        composer_weights = self.composer.trainable_weights
        self.trainable_weights = reader_weights + writer_weights + composer_weights

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def read(self, nse_input, input_mask=None):
        '''
        This method produces the 'read' output (equation 1 in the paper) for all timesteps
        and initializes the memory slot mem_0.

        Input: nse_input (batch_size, input_length, input_dim)
        Outputs:
            o (batch_size, input_length, output_dim)
            flattened_mem_0 (batch_size, input_length * output_dim)

        While this method simply copies input to mem_0, variants that inherit from this class
        can do something fancier.
        '''
        input_to_read = nse_input
        mem_0 = input_to_read
        flattened_mem_0 = K.batch_flatten(mem_0)
        o = self.reader.call(input_to_read, input_mask)
        o_mask = self.reader.compute_mask(input_to_read, input_mask)
        return o, [flattened_mem_0], o_mask

    @staticmethod
    def summarize_memory(o_t, mem_tm1):
        '''
        This method selects the relevant parts of the memory given the read output and summarizes
        the memory. Implements Equations 2-3 or 8-11 in the paper.
        '''
        # Selecting relevant memory slots, Equation 2
        z_t = K.softmax(K.sum(K.expand_dims(o_t, dim=1) * mem_tm1, axis=2))  # (batch_size, input_length)
        # Summarizing memory, Equation 3
        m_rt = K.sum(K.expand_dims(z_t, dim=2) * mem_tm1, axis=1)  # (batch_size, output_dim)
        return z_t, m_rt

    def compose_memory_and_output(self, output_memory_list):
        '''
        This method takes a list of tensors and applies the composition function on their
        concatenation. Implements equation 4 or 12 in the paper.
        '''
        # Composition, Equation 4
        c_t = self.composer.call(K.concatenate(output_memory_list))  # (batch_size, output_dim)
        return c_t

    def update_memory(self, z_t, h_t, mem_tm1):
        '''
        This method takes the attention vector (z_t), writer output (h_t) and previous timestep's
        memory (mem_tm1) and updates the memory. Implements equations 6, 14 or 15.
        '''
        tiled_z_t = K.tile(K.expand_dims(z_t),
                           (self.output_dim))  # (batch_size, input_length, output_dim)
        input_length = K.shape(mem_tm1)[1]
        # (batch_size, input_length, output_dim)
        tiled_h_t = K.permute_dimensions(K.tile(K.expand_dims(h_t), (input_length)), (0, 2, 1))
        # Updating memory. First term in summation corresponds to selective forgetting and the
        # second term to selective addition. Equation 6.
        mem_t = mem_tm1 * (1 - tiled_z_t) + tiled_h_t * tiled_z_t  # (batch_size, input_length, output_dim)
        return mem_t

    def compose_and_write_step(self, o_t, states):
        '''
        This method is a step function that updates the memory at each time step and produces a new
        output vector (Equations 2 to 6 in the paper).
        The memory_state is flattened because K.rnn requires all states to be of the same shape as
        the output, because it uses the same mask for the output and the states.

        Inputs:
            o_t (batch_size, output_dim)
            states (list[Tensor])
                flattened_mem_tm1 (batch_size, input_length * output_dim)
                writer_h_tm1 (batch_size, output_dim)
                writer_c_tm1 (batch_size, output_dim)

        Outputs:
            h_t (batch_size, output_dim)
            flattened_mem_t (batch_size, input_length * output_dim)
        '''
        flattened_mem_tm1, writer_h_tm1, writer_c_tm1 = states
        input_mem_shape = K.shape(flattened_mem_tm1)
        mem_tm1_shape = (input_mem_shape[0], input_mem_shape[1] // self.output_dim, self.output_dim)
        mem_tm1 = K.reshape(flattened_mem_tm1, mem_tm1_shape)  # (batch_size, input_length, output_dim)
        z_t, m_rt = self.summarize_memory(o_t, mem_tm1)
        c_t = self.compose_memory_and_output([o_t, m_rt])
        # Collecting the necessary variables to directly call writer's step function.
        writer_constants = self.writer.get_constants(c_t)  # returns dropouts for W and U (all 1s, see init)
        writer_states = [writer_h_tm1, writer_c_tm1] + writer_constants
        # Making a call to writer's step function, Equation 5
        h_t, [_, writer_c_t] = self.writer.step(c_t, writer_states)  # h_t, writer_c_t: (batch_size, output_dim)
        mem_t = self.update_memory(z_t, h_t, mem_tm1)
        flattened_mem_t = K.batch_flatten(mem_t)
        return h_t, [flattened_mem_t, h_t, writer_c_t]

    def call(self, x, mask=None):
        # input_shape = (batch_size, input_length, input_dim). This needs to be defined in build.
        read_output, initial_memory_states, output_mask = self.read(x, mask)
        initial_write_states = self.writer.get_initial_states(read_output)  # h_0 and c_0 of the writer LSTM
        initial_states = initial_memory_states + initial_write_states
        # last_output: (batch_size, output_dim)
        # all_outputs: (batch_size, input_length, output_dim)
        # last_states:
        #       last_memory_state: (batch_size, input_length, output_dim)
        #       last_output
        #       last_writer_ct
        last_output, all_outputs, last_states = K.rnn(self.compose_and_write_step, read_output,
                                                      initial_states, mask=output_mask)
        last_memory = last_states[0]
        if self.return_mode == "last_output":
            return last_output
        elif self.return_mode == "all_outputs":
            return all_outputs
        else:
            # return mode is output_and_memory
            expanded_last_output = K.expand_dims(last_output, dim=1)  # (batch_size, 1, output_dim)
            # (batch_size, 1+input_length, output_dim)
            return K.concatenate([expanded_last_output, last_memory], axis=1)

    def get_config(self):
        config = {'output_dim': self.output_dim,
                  'input_length': self.input_length,
                  'composer_activation': self.composer_activation,
                  'return_mode': self.return_mode}
        base_config = super(NSE, self).get_config()
        config.update(base_config)
        return config
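
# --- Usage sketch (not part of the original code) --------------------------
# A minimal sketch of wiring NSE into a model as a sentence encoder over
# embedded tokens, assuming the Keras 1.x API this class targets
# (Model(input=..., output=...), Embedding with mask_zero). The vocabulary,
# embedding size and helper name are illustrative assumptions. Note that NSE
# assumes input_dim == output_dim (see __init__), so the embedding dimension
# must equal the encoder's output_dim.

def _nse_usage_example():
    from keras.layers import Input, Embedding
    from keras.models import Model

    vocab_size, embed_dim, sent_length = 10000, 300, 40  # illustrative sizes
    word_ids = Input(shape=(sent_length,), dtype='int32')
    embedded = Embedding(vocab_size, embed_dim, mask_zero=True)(word_ids)
    encoded = NSE(output_dim=embed_dim, return_mode='last_output')(embedded)
    return Model(input=word_ids, output=encoded)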
class NSE(Layer):
    '''
    Simple Neural Semantic Encoder.
    '''

    def __init__(self,
                 output_dim,
                 input_length=None,
                 composer_activation='linear',
                 return_mode='last_output',
                 weights=None,
                 **kwargs):
        '''
        Arguments:
            output_dim (int)
            input_length (int)
            composer_activation (str): activation used in the MLP
            return_mode (str): One of last_output, all_outputs, output_and_memory
                This is analogous to the return_sequences flag in Keras' Recurrent.
                last_output returns only the last h_t
                all_outputs returns the whole sequence of h_ts
                output_and_memory returns the last output and the last memory concatenated
                    (needed if this layer is followed by a MMA-NSE)
            weights (list): Initial weights
        '''
        self.output_dim = output_dim
        self.input_dim = output_dim  # Equation 2 in the paper makes this assumption.
        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=3)]
        self.input_length = input_length
        self.composer_activation = composer_activation
        super(NSE, self).__init__(**kwargs)
        self.reader = LSTM(self.output_dim,
                           dropout_W=0.0,
                           dropout_U=0.0,
                           consume_less="gpu",
                           name="{}_reader".format(self.name))
        # TODO: Let the writer use parameter dropout and any consume_less mode.
        # Setting dropout to 0 here to eliminate the need for constants.
        # Setting consume_less to gpu to eliminate need for preprocessing
        self.writer = LSTM(self.output_dim,
                           dropout_W=0.0,
                           dropout_U=0.0,
                           consume_less="gpu",
                           name="{}_writer".format(self.name))
        self.composer = Dense(self.output_dim * 2,
                              activation=self.composer_activation,
                              name="{}_composer".format(self.name))
        if return_mode not in ["last_output", "all_outputs", "output_and_memory"]:
            raise Exception("Unrecognized return mode: %s" % (return_mode))
        print("vj golden NSE.__init__ return_mode is {}".format(return_mode))
        self.return_mode = return_mode

    def get_output_shape_for(self, input_shape):
        input_length = input_shape[1]
        if self.return_mode == "last_output":
            return (input_shape[0], self.output_dim)
        elif self.return_mode == "all_outputs":
            return (input_shape[0], input_length, self.output_dim)
        else:
            # return_mode is output_and_memory. Output will be concatenated to memory.
            return (input_shape[0], input_length + 1, self.output_dim)

    def compute_mask(self, input, mask):
        if mask is None or self.return_mode == "last_output":
            return None
        elif self.return_mode == "all_outputs":
            return mask  # (batch_size, input_length)
        else:
            # Return mode is output_and_memory
            # Mask memory corresponding to all the inputs that are masked, and do not mask the output
            # (batch_size, input_length + 1)
            return K.cast(K.concatenate([K.zeros_like(mask[:, :1]), mask]), 'uint8')

    def get_composer_input_shape(self, input_shape):
        # Takes concatenation of output and memory summary
        return (input_shape[0], self.output_dim * 2)

    def get_reader_input_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        input_dim = input_shape[-1]
        reader_input_shape = self.get_reader_input_shape(input_shape)
        print >> sys.stderr, "NSE reader input shape:", reader_input_shape
        writer_input_shape = (input_shape[0], 1, self.output_dim * 2)  # Will process one timestep at a time
        print >> sys.stderr, "NSE writer input shape:", writer_input_shape
        composer_input_shape = self.get_composer_input_shape(input_shape)
        print >> sys.stderr, "NSE composer input shape:", composer_input_shape
        self.reader.build(reader_input_shape)
        self.writer.build(writer_input_shape)
        self.composer.build(composer_input_shape)

        # Aggregate weights of individual components for this layer.
        reader_weights = self.reader.trainable_weights
        writer_weights = self.writer.trainable_weights
        composer_weights = self.composer.trainable_weights
        self.trainable_weights = reader_weights + writer_weights + composer_weights

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def get_initial_states(self, nse_input, input_mask=None):
        '''
        This method produces the reader's initial states for all timesteps and initializes the
        memory slot mem_0.

        Input: nse_input (batch_size, input_length, input_dim)
        Output: list[Tensors]:
            h_0 (batch_size, output_dim)
            c_0 (batch_size, output_dim)
            flattened_mem_0 (batch_size, input_length * output_dim)

        While this method simply copies input to mem_0, variants that inherit from this class
        can do something fancier.
        '''
        input_to_read = nse_input
        mem_0 = input_to_read
        flattened_mem_0 = K.batch_flatten(mem_0)
        flattened_mem_0 = TF_PRINT(flattened_mem_0, "get_initial_states.flattened_mem_0",
                                   expected_shape=[BATCH, LENGTH * DIM])
        initial_states = self.reader.get_initial_states(nse_input)
        initial_states += [flattened_mem_0]
        return initial_states

    @staticmethod
    def summarize_memory(o_t, mem_tm1):
        '''
        This method selects the relevant parts of the memory given the read output and summarizes
        the memory. Implements Equations 2-3 or 8-11 in the paper.
        '''
        # Selecting relevant memory slots, Equation 2
        z_t = K.softmax(K.sum(K.expand_dims(o_t, dim=1) * mem_tm1, axis=2))  # (batch_size, input_length)
        z_t = TF_PRINT(z_t, "summarize_memory.z_t", expected_shape=[BATCH, LENGTH])
        # Summarizing memory, Equation 3
        m_rt = K.sum(K.expand_dims(z_t, dim=2) * mem_tm1, axis=1)  # (batch_size, output_dim)
        m_rt = TF_PRINT(m_rt, "summarize_memory.m_rt", expected_shape=[BATCH, DIM])
        return z_t, m_rt

    def compose_memory_and_output(self, output_memory_list):
        '''
        This method takes a list of tensors and applies the composition function on their
        concatenation. Implements equation 4 or 12 in the paper.
        '''
        # Composition, Equation 4
        c_t = self.composer.call(K.concatenate(output_memory_list))  # (batch_size, output_dim)
        c_t = TF_PRINT(c_t, "compose_memory_and_output.c_t", expected_shape=[BATCH, DIM])
        return c_t

    def update_memory(self, z_t, h_t, mem_tm1):
        '''
        This method takes the attention vector (z_t), writer output (h_t) and previous timestep's
        memory (mem_tm1) and updates the memory. Implements equations 6, 14 or 15.
        '''
        """
        The following is written assuming the equations in the paper are implemented as they are
        written:
            tiled_z_t_trans = K.tile(K.expand_dims(z_t, 1),
                                     [1, self.output_dim, 1])  # (batch_size, input_length, output_dim)
            input_length = K.shape(mem_tm1)[1]
            # (batch_size, input_length, output_dim)
            # tiled_h_t = K.permute_dimensions(K.tile(K.expand_dims(h_t, -1), [1, input_length]), (0, 2, 1))
            tiled_h_t = K.tile(K.expand_dims(h_t, -1), [1, 1, input_length])
            # Updating memory. First term in summation corresponds to selective forgetting and the
            # second term to selective addition. Equation 6.
            mem_t = mem_tm1 * (1 - tiled_z_t_trans) + tiled_h_t * tiled_z_t_trans  # (batch_size, input_length, output_dim)
        """
        """
        The following code assumes that mem_t is actually the transpose of what is in the paper.
        Implemented by simply wrapping a K.permute_dimensions(_, (0, 2, 1)) call around the
        original value.
        """
        tiled_z_t = K.permute_dimensions(
            K.tile(K.expand_dims(z_t, 1), [1, self.output_dim, 1]),
            (0, 2, 1))  # (batch_size, input_length, output_dim)
        input_length = K.shape(mem_tm1)[1]
        # (batch_size, input_length, output_dim)
        # tiled_h_t = K.permute_dimensions(K.tile(K.expand_dims(h_t, -1), [1, input_length]), (0, 2, 1))
        tiled_h_t = K.permute_dimensions(
            K.tile(K.expand_dims(h_t, -1), [1, 1, input_length]), (0, 2, 1))
        # Updating memory. First term in summation corresponds to selective forgetting and the
        # second term to selective addition. Equation 6.
        mem_t = mem_tm1 * (1 - tiled_z_t) + tiled_h_t * tiled_z_t  # (batch_size, input_length, output_dim)
        mem_t = TF_PRINT(mem_t, "update_memory.mem_t", expected_shape=[BATCH, LENGTH, DIM])
        return mem_t

    @staticmethod
    def split_states(states):
        # This method is a helper for the step function to split the states into reader states,
        # memory and writer states.
        return states[:2], states[2], states[3:]

    def step(self, input_t, states):
        '''
        This method is a step function that updates the memory at each time step and produces a new
        output vector (Equations 1 to 6 in the paper).
        The memory_state is flattened because K.rnn requires all states to be of the same shape as
        the output, because it uses the same mask for the output and the states.

        Inputs:
            input_t (batch_size, input_dim)
            states (list[Tensor])
                flattened_mem_tm1 (batch_size, input_length * output_dim)
                writer_h_tm1 (batch_size, output_dim)
                writer_c_tm1 (batch_size, output_dim)

        Outputs:
            h_t (batch_size, output_dim)
            flattened_mem_t (batch_size, input_length * output_dim)
        '''
        input_t = TF_PRINT(input_t, "step.input_t", expected_shape=[BATCH, DIM])
        reader_states, flattened_mem_tm1, writer_states = self.split_states(states)
        input_mem_shape = K.shape(flattened_mem_tm1)
        mem_tm1_shape = (input_mem_shape[0], input_mem_shape[1] // self.output_dim, self.output_dim)
        mem_tm1 = K.reshape(flattened_mem_tm1, mem_tm1_shape)  # (batch_size, input_length, output_dim)
        mem_tm1 = TF_PRINT(mem_tm1, "step.mem_tm1", expected_shape=[BATCH, LENGTH, DIM])
        reader_constants = self.reader.get_constants(input_t)  # Does not depend on input_t, see init.
        reader_states = reader_states[:2] + tuple(reader_constants) + reader_states[2:]
        o_t, [_, reader_c_t] = self.reader.step(input_t, reader_states)  # o_t, reader_c_t: (batch_size, output_dim)
        o_t = TF_PRINT(o_t, "step.o_t", expected_shape=[BATCH, DIM])
        reader_c_t = TF_PRINT(reader_c_t, "step.reader_c_t", expected_shape=[BATCH, DIM])
        z_t, m_rt = self.summarize_memory(o_t, mem_tm1)
        c_t = self.compose_memory_and_output([o_t, m_rt])
        # Collecting the necessary variables to directly call writer's step function.
        writer_constants = self.writer.get_constants(c_t)  # returns dropouts for W and U (all 1s, see init)
        writer_states += tuple(writer_constants)
        # Making a call to writer's step function, Equation 5
        h_t, [_, writer_c_t] = self.writer.step(c_t, writer_states)  # h_t, writer_c_t: (batch_size, output_dim)
        h_t = TF_PRINT(h_t, "step.h_t", expected_shape=[BATCH, DIM])
        writer_c_t = TF_PRINT(writer_c_t, "step.writer_c_t", expected_shape=[BATCH, DIM])
        mem_t = self.update_memory(z_t, h_t, mem_tm1)
        flattened_mem_t = K.batch_flatten(mem_t)
        flattened_mem_t = TF_PRINT(flattened_mem_t, "step.flattened_mem_t",
                                   expected_shape=[BATCH, LENGTH * DIM])
        return h_t, [o_t, reader_c_t, flattened_mem_t, h_t, writer_c_t]

    def loop(self, x, initial_states, mask):
        # This is a separate method because Ontoaware variants will have to override this to make
        # a call to changingdim rnn.
        last_output, all_outputs, last_states = K.rnn(self.step, x, initial_states, mask=mask)
        last_output = TF_PRINT(last_output, "loop.last_output")
        all_outputs = TF_PRINT(all_outputs, "loop.all_outputs")
        # last_states = TF_PRINT(last_states, "loop.last_states")
        return last_output, all_outputs, last_states

    def call(self, x, mask=None):
        # input_shape = (batch_size, input_length, input_dim). This needs to be defined in build.
        if mask is not None:
            print("vj golden call.mask = {}. Being set to None.".format(mask))
            mask = None
        initial_read_states = self.get_initial_states(x, mask)
        fake_writer_input = K.expand_dims(initial_read_states[0], dim=1)  # (batch_size, 1, output_dim)
        fake_writer_input = TF_PRINT(fake_writer_input, "call.fake_writer_input",
                                     expected_shape=[BATCH, 1, DIM])
        initial_write_states = self.writer.get_initial_states(fake_writer_input)  # h_0 and c_0 of the writer LSTM
        initial_states = initial_read_states + initial_write_states
        # last_output: (batch_size, output_dim)
        # all_outputs: (batch_size, input_length, output_dim)
        # last_states:
        #       last_memory_state: (batch_size, input_length, output_dim)
        #       last_output
        #       last_writer_ct
        last_output, all_outputs, last_states = self.loop(x, initial_states, mask)
        last_memory = last_states[0]
        if self.return_mode == "last_output":
            return last_output
        elif self.return_mode == "all_outputs":
            return all_outputs
        else:
            # return mode is output_and_memory
            expanded_last_output = K.expand_dims(last_output, dim=1)  # (batch_size, 1, output_dim)
            expanded_last_output = TF_PRINT(expanded_last_output, "call.expanded_last_output",
                                            expected_shape=[BATCH, 1, DIM])
            # (batch_size, 1+input_length, output_dim)
            result = K.concatenate([expanded_last_output, last_memory], axis=1)
            result = TF_PRINT(result, "call.result", expected_shape=[BATCH, 1 + LENGTH, DIM])
            return result

    def get_config(self):
        config = {'output_dim': self.output_dim,
                  'input_length': self.input_length,
                  'composer_activation': self.composer_activation,
                  'return_mode': self.return_mode}
        base_config = super(NSE, self).get_config()
        config.update(base_config)
        return config
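
# --- Debug-helper sketch (not part of the original code) -------------------
# The instrumented NSE variant above calls TF_PRINT and refers to BATCH,
# LENGTH and DIM, none of which are defined in this excerpt. The definitions
# below are a hypothetical reconstruction for illustration only: they assume
# a TensorFlow 1.x backend and use tf.Print to log a tensor's runtime shape
# next to the annotated expected shape, returning the tensor unchanged so the
# call can wrap any expression. The concrete BATCH/LENGTH/DIM values are
# placeholders for a particular debugging run.

import tensorflow as tf

# Illustrative sizes for a single debugging run (hypothetical values).
BATCH, LENGTH, DIM = 32, 40, 300


def TF_PRINT(tensor, label, expected_shape=None):
    """Log `label`, the annotated expected shape and the runtime shape of `tensor`."""
    message = "{} expected_shape={} actual_shape=".format(label, expected_shape)
    return tf.Print(tensor, [tf.shape(tensor)], message=message)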