def __init__(self, graph, batch_size=32, en_maxlen=None, maxlen=None):
  """Wraps a Graph instance as an RNN sequence model.

  Args:
    graph: a Graph object that owns the TF session, saver and the named
      tensors/ops this model fetches and feeds.
    batch_size: number of examples per batch when running the graph.
    en_maxlen: maximum encoder sequence length (None for unbounded).
    maxlen: maximum decoder sequence length (None for unbounded).
  """
  # Basic configuration and handles shared with the underlying graph.
  self.graph = graph
  self.session = graph.session
  self.saver = graph.saver
  self.batch_size = batch_size
  self.en_maxlen = en_maxlen
  self.maxlen = maxlen
  self.meta_learn = graph.meta_learn

  # Graph nodes fetched when stepping / decoding.
  self._outputs = graph.outputs
  self._final_state = graph.final_state
  self._n_examples = graph.n_examples
  self._predictions = graph.predictions
  self._probs = graph.prediction_probs
  self._samples = graph.samples

  # Fetch names used by the training loop.
  self._loss = 'loss'
  self._train = 'train'
  self._meta_train = 'meta_train'
  self._count = 'n'
  self._policy_ent = 'ent_reg'

  # Batch converters/aggregators for single-step evaluation.
  self._step_bc = data_utils.BatchConverter(
      tuple_keys=['initial_state'],
      seq_keys=['inputs', 'encoded_context'])
  self._step_ba = data_utils.BatchAggregator(
      tuple_keys=[self._final_state], seq_keys=[self._outputs])

  # Batch converters/aggregators for training epochs.
  self._train_bc = data_utils.BatchConverter(
      ['initial_state'], ['inputs', 'targets', 'context'])
  self._train_ba = data_utils.BatchAggregator(
      num_keys=[self._loss, self._policy_ent, self._count])
def __init__(self, graph, batch_size=32, en_maxlen=None, maxlen=None):
  """Wraps a Graph instance as an RNN seq2seq (encoder-decoder) model.

  Args:
    graph: a Graph object exposing encoder/decoder nodes in addition to
      the base-model nodes.
    batch_size: number of examples per batch when running the graph.
    en_maxlen: maximum encoder sequence length (None for unbounded).
    maxlen: maximum decoder sequence length (None for unbounded).
  """
  # Base-class setup (session, saver, training converters, etc.).
  super(RNNSeq2seqModel, self).__init__(
      graph, batch_size=batch_size, en_maxlen=en_maxlen, maxlen=maxlen)

  # Encoder-specific graph nodes.
  self._en_outputs = graph.en_outputs
  self._initial_state = graph.initial_state
  self._en_initial_state = graph.en_initial_state

  # Batch converter/aggregator for the encoding pass: feed the encoder
  # initial state and context, collect the decoder initial state and the
  # encoder outputs.
  self._encode_bc = data_utils.BatchConverter(
      tuple_keys=[self._en_initial_state], seq_keys=['context'])
  self._encode_ba = data_utils.BatchAggregator(
      tuple_keys=[self._initial_state], seq_keys=[self._en_outputs])
def _predict(self, cell_outputs, predictions_node, temperature=1.0):
  """Runs `predictions_node` on the given RNN cell outputs.

  Args:
    cell_outputs: sequence of cell outputs to feed into the graph.
    predictions_node: graph node (fetch key) producing the predictions.
    temperature: softmax temperature passed through to the graph.

  Returns:
    The aggregated predictions for the whole input.
  """
  # Feed the cell outputs in; fetch only the requested predictions node.
  converter = data_utils.BatchConverter(
      seq_keys=[self._outputs], maxlen=self.maxlen)
  aggregator = data_utils.BatchAggregator(seq_keys=[predictions_node])
  result_dict = self.run_epoch(
      [predictions_node],
      {self._outputs: cell_outputs},
      converter,
      aggregator,
      parameters={'temperature': temperature})
  return result_dict[predictions_node]
def compute_step_logprobs(
    self, inputs, targets, context=None, initial_state=None, parameters=None):
  """Computes per-step log probabilities of `targets` given `inputs`.

  Args:
    inputs: decoder input sequences.
    targets: target sequences whose step log-probs are computed.
    context: optional encoder context.
    initial_state: optional initial RNN state.
    parameters: optional extra run parameters forwarded to run_epoch.

  Returns:
    The aggregated 'step_logprobs' values (empty list if absent).
  """
  feeds = {
      'initial_state': initial_state,
      'inputs': inputs,
      'targets': targets,
      'context': context,
  }
  aggregator = data_utils.BatchAggregator(seq_keys=['step_logprobs'])
  result_dict = self.run_epoch(
      ['step_logprobs'], feeds, self._train_bc, aggregator,
      parameters=parameters)
  return result_dict.get('step_logprobs', [])
def compute_probs(
    self, inputs, targets, context=None, initial_state=None, parameters=None):
  """Computes whole-sequence probabilities of `targets` given `inputs`.

  Args:
    inputs: decoder input sequences.
    targets: target sequences whose probabilities are computed.
    context: optional encoder context.
    initial_state: optional initial RNN state.
    parameters: optional extra run parameters forwarded to run_epoch.

  Returns:
    A list with one probability per example.
  """
  feeds = {
      'initial_state': initial_state,
      'inputs': inputs,
      'targets': targets,
      'context': context,
  }
  aggregator = data_utils.BatchAggregator(tuple_keys=['sequence_probs'])
  result_dict = self.run_epoch(
      ['sequence_probs'], feeds, self._train_bc, aggregator,
      parameters=parameters)
  # Each aggregated entry is a tuple; the probability is its first element.
  return [entry[0] for entry in result_dict.get('sequence_probs', [])]
def compute_scores(
    self, inputs, targets, context=None, initial_state=None, parameters=None):
  """Computes the scores for the attention-based score function.

  Args:
    inputs: decoder input sequences.
    targets: target sequences to score.
    context: optional encoder context.
    initial_state: optional initial RNN state.
    parameters: optional extra run parameters forwarded to run_epoch.

  Returns:
    The aggregated 'scores' values (empty list if absent).
  """
  feeds = {
      'initial_state': initial_state,
      'inputs': inputs,
      'targets': targets,
      'context': context,
  }
  aggregator = data_utils.BatchAggregator(keep_keys=['scores'])
  result_dict = self.run_epoch(
      ['scores'], feeds, self._train_bc, aggregator, parameters=parameters)
  return result_dict.get('scores', [])