import numpy as np
import tensorflow as tf

# NTMCell and the head_pieces_tuple_to_dict helper are assumed to be
# importable from a sibling module of this project; the module name here
# is illustrative.
from ntm_cell import NTMCell, head_pieces_tuple_to_dict


class NTM(object):
    '''
    Performs several operations relevant to the NTM:
      - builds the computation graph
      - trains the model
      - tests the model
    '''

    def __init__(self, mem_size, input_size, output_size, session,
                 num_heads=1, shift_range=3, name="NTM"):
        '''
        Builds the computation graph for the Neural Turing Machine.

        The tasks from the original paper call for the NTM to take in a
        sequence of arrays and produce some output. Let B = batch size,
        T = sequence length, and L = array length; then a single input
        sequence is a matrix of size [TxL], and a batch of these input
        sequences has size [BxTxL].

        Arguments:
          mem_size - Tuple of integers corresponding to the number of
            storage locations and the dimension of each storage location
            (in the paper the memory matrix is NxM; mem_size refers to
            (N, M)).
          input_size - Integer number of elements in a single input
            vector (the value L).
          output_size - Integer number of elements in a single output
            vector.
          session - The TensorFlow session object that refers to the
            current computation graph.
          num_heads - The integer number of write heads the NTM uses
            (future feature).
          shift_range - The integer number of shift values that the
            read/write heads can perform, which corresponds to the
            direction and magnitude of the allowable shifts. Shift
            ranges and corresponding available shift
            directions/magnitudes:
              3 => [-1, 0, 1]
              4 => [-2, -1, 0, 1]
              5 => [-2, -1, 0, 1, 2]
          name - A string name for the variable scope, for
            troubleshooting.
        '''
        self.num_heads = 1
        self.sess = session
        self.S = shift_range
        self.N, self.M = mem_size
        self.in_size = input_size
        self.out_size = output_size
        num_lstm_units = 100
        self.dt = tf.float32

        dt = self.dt
        N = self.N
        M = self.M
        S = self.S
        num_heads = self.num_heads

        with tf.variable_scope(name):
            self.feed_in = tf.placeholder(
                dtype=dt, shape=(None, None, input_size))
            self.feed_out = tf.placeholder(
                dtype=dt, shape=(None, None, output_size))
            self.feed_learning_rate = tf.placeholder(dtype=dt, shape=())

            batch_size = tf.shape(self.feed_in)[0]
            seq_length = tf.shape(self.feed_in)[1]

            head_raw = self.controller(self.feed_in, batch_size, seq_length)

            self.ntm_cell = NTMCell(mem_size=(N, M), num_shifts=S)

            write_head, read_head = NTMCell.head_pieces(
                head_raw, mem_size=(N, M), num_shifts=S, axis=2)

            self.write_head, self.read_head = \
                head_pieces_tuple_to_dict(write_head, read_head)

            self.ntm_init_state = tuple(
                [tf.placeholder(dtype=dt, shape=(None, s))
                 for s in self.ntm_cell.state_size])

            self.ntm_reads, self.ntm_last_state = tf.nn.dynamic_rnn(
                cell=self.ntm_cell, initial_state=self.ntm_init_state,
                inputs=head_raw, dtype=dt)

            self.w_read = self.ntm_last_state[-2]
            self.w_write = self.ntm_last_state[-1]

            ntm_reads_flat = tf.reshape(self.ntm_reads, [-1, M])

            L = tf.Variable(tf.random_normal([M, output_size]))
            b_L = tf.Variable(tf.random_normal([output_size,]))

            logits_flat = tf.matmul(ntm_reads_flat, L) + b_L
            targets_flat = tf.reshape(self.feed_out, [-1, output_size])

            self.error = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=targets_flat, logits=logits_flat))

            self.predictions = tf.sigmoid(
                tf.reshape(logits_flat,
                           [batch_size, seq_length, output_size]))

            optimizer = tf.train.RMSPropOptimizer(
                learning_rate=self.feed_learning_rate, momentum=0.9)

            grads_and_vars = optimizer.compute_gradients(self.error)
            capped_grads = [(tf.clip_by_value(grad, -10., 10.), var)
                            for grad, var in grads_and_vars]

            self.train_op = optimizer.apply_gradients(capped_grads)

    def controller(self, inputs, batch_size, seq_length, num_units=100):
        '''
        Builds a single-layer LSTM controller that manipulates values in
        the memory matrix and helps produce output. This method should
        only be used by the class.

        Arguments:
          inputs - TF tensor containing data that is passed to the
            controller.
          batch_size - The number of sequences in a given training batch.
          seq_length - The length of the sequence being passed to the
            controller.
          num_units - The number of units inside of the LSTM controller.
        '''
        N = self.N
        M = self.M
        S = self.S
        dt = self.dt
        num_heads = self.num_heads

        self.lstm_cell = tf.contrib.rnn.BasicLSTMCell(
            num_units=num_units, forget_bias=1.0)

        self.lstm_init_state = tuple(
            [tf.placeholder(dtype=dt, shape=(None, s))
             for s in self.lstm_cell.state_size])

        lstm_init_state = tf.contrib.rnn.LSTMStateTuple(
            self.lstm_init_state[0], self.lstm_init_state[1])

        lstm_out_raw, self.lstm_last_state = tf.nn.dynamic_rnn(
            cell=self.lstm_cell, initial_state=lstm_init_state,
            inputs=inputs, dtype=dt)

        lstm_out = tf.tanh(lstm_out_raw)
        lstm_out_flat = tf.reshape(lstm_out, [-1, num_units])

        # The number of nodes on the controller's output is determined by
        #   1. the number of allowable shifts (S), and
        #   2. the width of the columns in the memory matrix (M).
        # Per the paper, the read head emits a key (M), key strength (1),
        # interpolation gate (1), shift weighting (S), and sharpening (1),
        # while the write head additionally emits erase (M) and add (M)
        # vectors, giving (M + S + 3) + (3M + S + 3) = 4M + 2S + 6 nodes.
        head_nodes = 4*M + 2*S + 6

        head_W = tf.Variable(
            tf.random_normal([num_units, num_heads*head_nodes]),
            name='head_W')
        head_b_W = tf.Variable(
            tf.random_normal([num_heads*head_nodes,]), name='head_b_W')

        head_raw_flat = tf.matmul(lstm_out_flat, head_W) + head_b_W
        head_raw = tf.reshape(head_raw_flat,
                              [batch_size, seq_length, head_nodes])

        return head_raw

    def train_batch(self, batch_x, batch_y, learning_rate=1e-4):
        '''
        Trains the model on a batch of inputs and their corresponding
        outputs. Returns the error obtained by training the NTM on the
        input sequence that is provided as an argument.

        Arguments:
          batch_x - The batch of input training sequences [BxTxL1]. Note
            that the first two dimensions (batch size and sequence
            length) of both batches MUST be the same. Numpy array.
          batch_y - The batch of output training sequences [BxTxL2]. The
            output sequences are the desired outputs after the NTM has
            been presented with the training input, batch_x. Numpy array.

        Outputs:
          error - The amount of error (float) produced from this
            particular training sequence. The error operation is defined
            in the constructor.
        '''
        lr = learning_rate
        batch_size = batch_x.shape[0]

        ntm_init_state = self.ntm_cell.bias_state(batch_size)
        lstm_init_state = tuple(
            [np.zeros((batch_size, s))
             for s in self.lstm_cell.state_size])

        fetches = [self.error, self.train_op]
        feeds = {
            self.feed_in: batch_x,
            self.feed_out: batch_y,
            self.feed_learning_rate: lr
        }

        for i in range(len(ntm_init_state)):
            feeds[self.ntm_init_state[i]] = ntm_init_state[i]

        for i in range(len(lstm_init_state)):
            feeds[self.lstm_init_state[i]] = lstm_init_state[i]

        error, _ = self.sess.run(fetches, feeds)

        return error

    def run_once(self, test_x):
        '''
        Passes a single input sequence to the NTM and produces an output
        according to what it has learned. Returns a tuple of items of
        interest for troubleshooting purposes (the read/write vectors and
        output).

        Arguments:
          test_x - A batch of input sequences [BxTxL1] that the NTM will
            use to produce a batch of output sequences [BxTxL2]. Numpy
            array.

        Outputs:
          output_b - A numpy array representing the output of the NTM
            after being presented with the input batch [BxTxL2].
          w_read_b - A numpy array of "read" locations that the NTM used.
            From the paper, these locations are normalized vectors that
            allow the NTM to focus on rows of the memory matrix.
          w_write_b - A numpy array of "write" locations that the NTM
            used.
          g_read_b - A numpy array of scalar values indicating whether
            the NTM used the previous read location or associative recall
            to determine the read location at each timestep.
          g_write_b - A numpy array of scalar values indicating whether
            the NTM used the previous write location or associative
            recall to determine the write location at each timestep.
          s_read_b - A numpy array of vectors describing the magnitude
            and direction of the shifting operation that was applied to
            the read head.
          s_write_b - A numpy array of vectors describing the magnitude
            and direction of the shifting operation that was applied to
            the write head.
        '''
        batch_size = test_x.shape[0]
        num_seq = test_x.shape[1]
        sequences = np.split(test_x, num_seq, axis=1)

        ntm_init_state = self.ntm_cell.bias_state(batch_size)
        lstm_init_state = tuple(
            [np.zeros((batch_size, s))
             for s in self.lstm_cell.state_size])

        outputs = []
        w_read = []
        w_write = []
        g_read = []
        g_write = []
        s_read = []
        s_write = []

        for seq in sequences:
            fetches = [self.predictions, self.ntm_last_state,
                       self.lstm_last_state, self.read_head,
                       self.write_head]
            feeds = {self.feed_in: seq}

            for i in range(len(ntm_init_state)):
                feeds[self.ntm_init_state[i]] = ntm_init_state[i]

            for i in range(len(lstm_init_state)):
                feeds[self.lstm_init_state[i]] = lstm_init_state[i]

            output, ntm_init_state, lstm_init_state, \
                read_head, write_head = self.sess.run(fetches, feeds)

            outputs.append(output[0].copy())
            w_read.append(ntm_init_state[-2][0].copy())
            w_write.append(ntm_init_state[-1][0].copy())
            g_read.append(read_head['g'][0, 0, :].copy())
            g_write.append(write_head['g'][0, 0, :].copy())
            s_read.append(read_head['shift'][0, 0, :].copy())
            s_write.append(write_head['shift'][0, 0, :].copy())

        output_b = np.squeeze(np.array(outputs))
        w_read_b = np.array(w_read)
        w_write_b = np.array(w_write)
        g_read_b = np.array(g_read)
        g_write_b = np.array(g_write)
        s_read_b = np.array(s_read)
        s_write_b = np.array(s_write)

        return output_b, w_read_b, w_write_b, g_read_b, \
            g_write_b, s_read_b, s_write_b
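

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the class). The memory size, sequence
# shapes, and training hyperparameters below are illustrative assumptions,
# not values taken from the original experiments; the snippet only shows how
# the constructor, train_batch(), and run_once() fit together for a
# copy-style task with random binary data standing in for a real generator.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, seq_length, bits = 4, 10, 8

    with tf.Session() as sess:
        # One extra input channel is assumed for a start/stop delimiter.
        ntm = NTM(mem_size=(128, 20), input_size=bits + 1,
                  output_size=bits, session=sess, shift_range=3)
        sess.run(tf.global_variables_initializer())

        batch_x = np.random.randint(
            0, 2, (batch_size, seq_length, bits + 1)).astype(np.float32)
        batch_y = np.random.randint(
            0, 2, (batch_size, seq_length, bits)).astype(np.float32)

        for step in range(100):
            error = ntm.train_batch(batch_x, batch_y, learning_rate=1e-4)

        # run_once() returns the predictions plus per-timestep diagnostics.
        output, w_read, w_write, g_read, g_write, s_read, s_write = \
            ntm.run_once(batch_x)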


class NTM(object):
    '''
    Variant of the NTM class above: the same graph construction, but with
    parallel_iterations set on the dynamic_rnn calls and an in-progress
    conversion to multiple read/write heads. See the documented version
    above for argument descriptions.
    '''

    def __init__(self, mem_size, input_size, output_size, session,
                 num_heads=1, shift_range=3, name="NTM"):
        self.num_heads = 1
        self.sess = session
        self.S = shift_range
        self.N, self.M = mem_size
        self.in_size = input_size
        self.out_size = output_size
        num_lstm_units = 100
        self.dt = tf.float32
        self.pi = 64

        pi = self.pi
        dt = self.dt
        N = self.N
        M = self.M
        S = self.S
        num_heads = self.num_heads

        with tf.variable_scope(name):
            self.feed_in = tf.placeholder(
                dtype=dt, shape=(None, None, input_size))
            self.feed_out = tf.placeholder(
                dtype=dt, shape=(None, None, output_size))
            self.feed_learning_rate = tf.placeholder(dtype=dt, shape=())

            batch_size = tf.shape(self.feed_in)[0]
            seq_length = tf.shape(self.feed_in)[1]

            head_raw = self.controller(self.feed_in, batch_size, seq_length)

            self.ntm_cell = NTMCell(mem_size=(N, M), shift_range=S)

            self.write_head, self.read_head = NTMCell.head_pieces(
                head_raw, mem_size=(N, M), shift_range=S, axis=2,
                style='dict')

            self.ntm_init_state = tuple(
                [tf.placeholder(dtype=dt, shape=(None, s))
                 for s in self.ntm_cell.state_size])

            self.ntm_reads, self.ntm_last_state = tf.nn.dynamic_rnn(
                cell=self.ntm_cell, initial_state=self.ntm_init_state,
                inputs=head_raw, dtype=dt, parallel_iterations=pi)

            # Started conversion to the multi-head output here, still have
            # lots to do.
            self.w_read = self.ntm_last_state[N:N + num_heads]
            self.w_write = \
                self.ntm_last_state[N + num_heads:N + 2*num_heads]

            # Flatten and concatenate the per-head read vectors so the
            # output projection below can handle more than one head; this
            # assumes self.ntm_reads is a list of [batch, time, M] tensors
            # (with num_heads = 1 it reduces to the single-head case).
            ntm_reads_flat = tf.concat(
                [tf.reshape(r, [-1, M]) for r in self.ntm_reads], axis=1)

            L = tf.Variable(tf.random_normal([num_heads*M, output_size]))
            b_L = tf.Variable(tf.random_normal([output_size,]))

            logits_flat = tf.matmul(ntm_reads_flat, L) + b_L
            targets_flat = tf.reshape(self.feed_out, [-1, output_size])

            self.error = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=targets_flat, logits=logits_flat))

            self.predictions = tf.sigmoid(
                tf.reshape(logits_flat,
                           [batch_size, seq_length, output_size]))

            optimizer = tf.train.RMSPropOptimizer(
                learning_rate=self.feed_learning_rate, momentum=0.9)

            grads_and_vars = optimizer.compute_gradients(self.error)
            capped_grads = [(tf.clip_by_value(grad, -10., 10.), var)
                            for grad, var in grads_and_vars]

            self.train_op = optimizer.apply_gradients(capped_grads)

    def controller(self, inputs, batch_size, seq_length, num_units=100):
        N = self.N
        M = self.M
        S = self.S
        pi = self.pi
        dt = self.dt
        num_heads = self.num_heads

        self.lstm_cell = tf.contrib.rnn.BasicLSTMCell(
            num_units=num_units, forget_bias=1.0)

        self.lstm_init_state = tuple(
            [tf.placeholder(dtype=dt, shape=(None, s))
             for s in self.lstm_cell.state_size])

        lstm_init_state = tf.contrib.rnn.LSTMStateTuple(
            self.lstm_init_state[0], self.lstm_init_state[1])

        lstm_out_raw, self.lstm_last_state = tf.nn.dynamic_rnn(
            cell=self.lstm_cell, initial_state=lstm_init_state,
            inputs=inputs, dtype=dt, parallel_iterations=pi)

        lstm_out = tf.tanh(lstm_out_raw)
        lstm_out_flat = tf.reshape(lstm_out, [-1, num_units])

        head_nodes = 4*M + 2*S + 6

        head_W = tf.Variable(
            tf.random_normal([num_units, num_heads*head_nodes]),
            name='head_W')
        head_b_W = tf.Variable(
            tf.random_normal([num_heads*head_nodes,]), name='head_b_W')

        head_raw_flat = tf.matmul(lstm_out_flat, head_W) + head_b_W
        head_raw = tf.reshape(head_raw_flat,
                              [batch_size, seq_length, head_nodes])

        return head_raw

    def train_batch(self, batch_x, batch_y, learning_rate=1e-4):
        lr = learning_rate
        batch_size = batch_x.shape[0]

        ntm_init_state = self.ntm_cell.bias_state(batch_size)
        lstm_init_state = tuple(
            [np.zeros((batch_size, s))
             for s in self.lstm_cell.state_size])

        fetches = [self.error, self.train_op]
        feeds = {
            self.feed_in: batch_x,
            self.feed_out: batch_y,
            self.feed_learning_rate: lr
        }

        for i in range(len(ntm_init_state)):
            feeds[self.ntm_init_state[i]] = ntm_init_state[i]

        for i in range(len(lstm_init_state)):
            feeds[self.lstm_init_state[i]] = lstm_init_state[i]

        error, _ = self.sess.run(fetches, feeds)

        return error

    def run_once(self, test_x):
        batch_size = test_x.shape[0]
        num_seq = test_x.shape[1]
        sequences = np.split(test_x, num_seq, axis=1)

        ntm_init_state = self.ntm_cell.bias_state(batch_size)
        lstm_init_state = tuple(
            [np.zeros((batch_size, s))
             for s in self.lstm_cell.state_size])

        outputs = []
        w_read = []
        w_write = []
        g_read = []
        g_write = []
        s_read = []
        s_write = []

        for seq in sequences:
            fetches = [self.predictions, self.ntm_last_state,
                       self.lstm_last_state, self.read_head,
                       self.write_head]
            feeds = {self.feed_in: seq}

            for i in range(len(ntm_init_state)):
                feeds[self.ntm_init_state[i]] = ntm_init_state[i]

            for i in range(len(lstm_init_state)):
                feeds[self.lstm_init_state[i]] = lstm_init_state[i]

            output, ntm_init_state, lstm_init_state, \
                read_head, write_head = self.sess.run(fetches, feeds)

            outputs.append(output[0].copy())
            w_read.append(ntm_init_state[-2][0].copy())
            w_write.append(ntm_init_state[-1][0].copy())
            g_read.append(read_head['g'][0, 0, :].copy())
            g_write.append(write_head['g'][0, 0, :].copy())
            s_read.append(read_head['shift'][0, 0, :].copy())
            s_write.append(write_head['shift'][0, 0, :].copy())

        output_b = np.squeeze(np.array(outputs))
        w_read_b = np.array(w_read)
        w_write_b = np.array(w_write)
        g_read_b = np.array(g_read)
        g_write_b = np.array(g_write)
        s_read_b = np.array(s_read)
        s_write_b = np.array(s_write)

        return output_b, w_read_b, w_write_b, g_read_b, \
            g_write_b, s_read_b, s_write_b
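

# ---------------------------------------------------------------------------
# Illustrative troubleshooting helper (not part of the original model code).
# run_once() exposes the per-timestep read/write weightings, so a small
# matplotlib sketch like this can make head behavior visible; matplotlib is
# assumed to be available. Each row is a timestep and each column a memory
# location, so the plots show where the heads focused over the sequence.
# ---------------------------------------------------------------------------
def plot_head_weightings(w_read_b, w_write_b):
    '''Plot the read/write weightings returned by NTM.run_once().'''
    import matplotlib.pyplot as plt

    fig, (ax_read, ax_write) = plt.subplots(1, 2, figsize=(10, 4))
    ax_read.imshow(w_read_b, aspect='auto', cmap='gray')
    ax_read.set_title('Read weightings')
    ax_read.set_xlabel('Memory location')
    ax_read.set_ylabel('Timestep')
    ax_write.imshow(w_write_b, aspect='auto', cmap='gray')
    ax_write.set_title('Write weightings')
    ax_write.set_xlabel('Memory location')
    ax_write.set_ylabel('Timestep')
    plt.show()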