def predict_on_batch(self, sess, inputs_batch, targets_batch=None):
    """Run the inference graph on one batch and return predictions (and dev metrics).

    Args:
        sess: tf.Session() used to evaluate the graph.
        inputs_batch: sequence of encoder input sequences, one per sample.
            They are padded/truncated to ``self.config.max_length_x`` by
            ``padded_batch_lr``.
        targets_batch: optional sequence of target sequences, one per sample.
            When given, decoder inputs, decoder targets and a mask are built
            from it so the dev loss/accuracy can be evaluated.

    Returns:
        A 5-tuple ``(preds, dev_loss, dev_acc, dev_loss_summ, dev_acc_summ)``.
        ``preds`` is ``np.argmax`` of the fetched ``self.infer_pred`` over
        axis 2. The other entries are the fetched values of
        ``self.dev_loss``, ``self.dev_accuracy``, ``self.dev_loss_summary``
        and ``self.dev_acc_summary``. When ``targets_batch`` is None no
        target-related tensors are fed, so only the predictions are computed
        and the other four entries are None.
    """
    max_x = self.config.max_length_x
    voc = self.config.voc

    inputs_batch_padded, _ = padded_batch_lr(inputs_batch, max_x, voc)
    # Per-sample input lengths, clipped to the padded width.
    length_inputs_batch = np.asarray([min(max_x, len(item)) for item in inputs_batch])

    if targets_batch is None:
        feed = self.create_feed_dict(inputs_batch_padded, length_inputs_batch)
        # Nothing target-related was fed, so only fetch the predictions;
        # the dev loss/accuracy tensors have no targets to compare against.
        preds = sess.run(self.infer_pred, feed_dict=feed)
        return np.argmax(preds, 2), None, None, None, None

    max_y = self.config.max_length_y
    decoder_batch_padded, _ = padded_batch_lr(
        targets_batch, max_y, voc, option='decoder_inputs')
    targets_batch_padded, mask_batch = padded_batch_lr(
        targets_batch, max_y, voc, option='decoder_targets')
    # NOTE(review): the +1 presumably accounts for an extra token added by
    # padded_batch_lr for decoder sequences — confirm against that helper.
    length_decoder_batch = np.asarray(
        [min(max_y, len(item) + 1) for item in targets_batch])
    feed = self.create_feed_dict(
        inputs_batch_padded, length_inputs_batch, mask_batch,
        length_decoder_batch, decoder_batch_padded, targets_batch_padded)

    preds, dev_loss, dev_acc, dev_loss_summ, dev_acc_summ = sess.run(
        [
            self.infer_pred,
            self.dev_loss,
            self.dev_accuracy,
            self.dev_loss_summary,
            self.dev_acc_summary,
        ],
        feed_dict=feed,
    )
    preds = np.argmax(preds, 2)
    return preds, dev_loss, dev_acc, dev_loss_summ, dev_acc_summ
def train_on_batch(self, sess, inputs_batch, targets_batch):
    """Perform one step of gradient descent on the provided batch of data.

    Args:
        sess: tf.Session() used to run the training op.
        inputs_batch: sequence of encoder input sequences, one per sample.
            They are padded/truncated to ``self.config.max_length_x`` by
            ``padded_batch_lr``.
        targets_batch: sequence of target sequences, one per sample
            (required). Decoder inputs, decoder targets and a mask are built
            from it.

    Returns:
        A 5-tuple ``(preds, loss, acc, loss_summ, acc_summ)``. ``preds`` is
        ``np.argmax`` of the fetched ``self.train_pred`` over axis 2. The
        other entries are the fetched values of ``self.loss``,
        ``self.accuracy``, ``self.loss_summary`` and ``self.acc_summary``.

    Raises:
        ValueError: if ``targets_batch`` is None.
    """
    if targets_batch is None:
        # A training step needs targets: train_op and loss are fetched below,
        # and create_feed_dict takes dropout only after the four
        # target-related positional arguments, so no valid feed exists
        # without them.
        raise ValueError("train_on_batch requires a non-None targets_batch")

    max_x = self.config.max_length_x
    max_y = self.config.max_length_y
    voc = self.config.voc

    inputs_batch_padded, _ = padded_batch_lr(inputs_batch, max_x, voc)
    # Per-sample input lengths, clipped to the padded width.
    length_inputs_batch = np.asarray([min(max_x, len(item)) for item in inputs_batch])

    decoder_batch_padded, _ = padded_batch_lr(
        targets_batch, max_y, voc, option='decoder_inputs')
    targets_batch_padded, mask_batch = padded_batch_lr(
        targets_batch, max_y, voc, option='decoder_targets')
    # NOTE(review): the +1 presumably accounts for an extra token added by
    # padded_batch_lr for decoder sequences — confirm against that helper.
    length_decoder_batch = np.asarray(
        [min(max_y, len(item) + 1) for item in targets_batch])
    feed = self.create_feed_dict(
        inputs_batch_padded, length_inputs_batch, mask_batch,
        length_decoder_batch, decoder_batch_padded, targets_batch_padded,
        self.config.pdrop)

    _, preds, loss, acc, loss_summ, acc_summ = sess.run(
        [
            self.train_op,
            self.train_pred,
            self.loss,
            self.accuracy,
            self.loss_summary,
            self.acc_summary,
        ],
        feed_dict=feed,
    )
    preds = np.argmax(preds, 2)
    return preds, loss, acc, loss_summ, acc_summ