def __init__(self, graph, batch_size=32):
    """Creates a RNN sequence model for a given Graph instance.

    Args:
      graph: Graph instance. The constructor reads its `session`, `saver`,
        `outputs`, `final_state`, `n_examples`, `predictions`,
        `prediction_probs` and `samples` attributes.
      batch_size: Stored unchanged as `self.batch_size`.
    """
    # Public handles taken straight from the graph wrapper.
    self.graph = graph
    self.session = graph.session
    self.saver = graph.saver
    self.batch_size = batch_size

    # Private references to objects exposed by the graph.
    self._outputs = graph.outputs
    self._final_state = graph.final_state
    self._n_examples = graph.n_examples
    self._predictions = graph.predictions
    self._probs = graph.prediction_probs
    self._samples = graph.samples

    # Plain string keys (presumably looked up by name elsewhere -- confirm).
    self._loss, self._train = 'loss', 'train'
    self._count, self._policy_ent = 'n', 'ent_reg'

    # Converter/aggregator pair built for single-step use.
    step_seq_keys = ['inputs', 'encoded_context']
    self._step_bc = data_utils.BatchConverter(
        tuple_keys=['initial_state'], seq_keys=step_seq_keys)
    self._step_ba = data_utils.BatchAggregator(
        tuple_keys=[self._final_state], seq_keys=[self._outputs])

    # Converter/aggregator pair built for training. The converter arguments
    # stay positional, exactly as in the original call.
    train_seq_keys = ['inputs', 'targets', 'weights', 'context']
    self._train_bc = data_utils.BatchConverter(
        ['initial_state'], train_seq_keys)
    train_num_keys = [self._loss, self._policy_ent, self._count]
    self._train_ba = data_utils.BatchAggregator(num_keys=train_num_keys)
def __init__(self, graph, batch_size=32):
    """Creates a RNN seq2seq model for a given Graph object.

    Runs the parent constructor first, then stores the encoder-related
    attributes of `graph` and builds the converter/aggregator pair kept as
    `self._encode_bc` / `self._encode_ba`.

    Args:
      graph: Graph object. In addition to whatever the parent constructor
        reads, `en_outputs`, `initial_state` and `en_initial_state` are
        read here.
      batch_size: Forwarded to the parent constructor.
    """
    super(RNNSeq2seqModel, self).__init__(graph, batch_size=batch_size)

    self._en_outputs = graph.en_outputs
    self._initial_state = graph.initial_state
    self._en_initial_state = graph.en_initial_state

    # Converter keyed on the encoder's initial state plus the 'context' key.
    converter_tuple_keys = [self._en_initial_state]
    self._encode_bc = data_utils.BatchConverter(
        tuple_keys=converter_tuple_keys, seq_keys=['context'])

    # Aggregator keyed on `initial_state` and the encoder outputs.
    aggregator_tuple_keys = [self._initial_state]
    aggregator_seq_keys = [self._en_outputs]
    self._encode_ba = data_utils.BatchAggregator(
        tuple_keys=aggregator_tuple_keys, seq_keys=aggregator_seq_keys)
def _predict(self, cell_outputs, predictions_node, temperature=1.0):
    """Evaluates `predictions_node` on `cell_outputs` through `run_epoch`.

    Args:
      cell_outputs: Value fed under the `self._outputs` key.
      predictions_node: The single fetch requested; also the key used to
        read the result back.
      temperature: Forwarded to `run_epoch` as
        `parameters={'temperature': temperature}`.

    Returns:
      The entry stored under `predictions_node` in the dict returned by
      `run_epoch`.
    """
    converter = data_utils.BatchConverter(seq_keys=[self._outputs])
    aggregator = data_utils.BatchAggregator(seq_keys=[predictions_node])
    results = self.run_epoch(
        [predictions_node],
        {self._outputs: cell_outputs},
        converter,
        aggregator,
        parameters={'temperature': temperature})
    return results[predictions_node]
def compute_step_logprobs(self, inputs, targets, context=None,
                          initial_state=None, parameters=None):
    """Fetches 'step_logprobs' for the given inputs and targets.

    The weights are built with `data_utils.constant_struct_like(targets, 1.0)`
    (presumably a structure shaped like `targets` filled with 1.0 -- confirm
    against `data_utils`).

    Args:
      inputs: Fed under the 'inputs' key.
      targets: Fed under the 'targets' key; also the template for weights.
      context: Fed under the 'context' key.
      initial_state: Fed under the 'initial_state' key.
      parameters: Forwarded unchanged to `run_epoch`.

    Returns:
      The 'step_logprobs' entry of the dict returned by `run_epoch`, or an
      empty list when that key is absent.
    """
    weights = data_utils.constant_struct_like(targets, 1.0)
    feed_dict = dict(
        initial_state=initial_state,
        inputs=inputs,
        targets=targets,
        weights=weights,
        context=context)
    ba = data_utils.BatchAggregator(seq_keys=['step_logprobs'])
    # The former t1/t2 = time.time() locals were never read, so the dead
    # timing calls are dropped.
    result_dict = self.run_epoch(
        ['step_logprobs'], feed_dict, self._train_bc, ba,
        parameters=parameters)
    return result_dict.get('step_logprobs', [])
def compute_probs(self, inputs, targets, context=None,
                  initial_state=None, parameters=None):
    """Computes one probability per fetched 'sequence_loss' entry.

    Each probability is `np.exp(-entry[0])`, where `entry` is one element of
    the 'sequence_loss' result (presumably the first element is the
    sequence's negative log-likelihood -- confirm against the graph).

    The weights are built with `data_utils.constant_struct_like(targets, 1.0)`
    (presumably a structure shaped like `targets` filled with 1.0 -- confirm
    against `data_utils`).

    Args:
      inputs: Fed under the 'inputs' key.
      targets: Fed under the 'targets' key; also the template for weights.
      context: Fed under the 'context' key.
      initial_state: Fed under the 'initial_state' key.
      parameters: Forwarded unchanged to `run_epoch`.

    Returns:
      A list of `np.exp(-entry[0])` values, empty when `run_epoch` returns no
      'sequence_loss' key.
    """
    weights = data_utils.constant_struct_like(targets, 1.0)
    feed_dict = dict(
        initial_state=initial_state,
        inputs=inputs,
        targets=targets,
        weights=weights,
        context=context)
    ba = data_utils.BatchAggregator(tuple_keys=['sequence_loss'])
    # The former t1/t2 = time.time() locals were never read, so the dead
    # timing calls are dropped.
    result_dict = self.run_epoch(
        ['sequence_loss'], feed_dict, self._train_bc, ba,
        parameters=parameters)
    seq_losses = result_dict.get('sequence_loss', [])
    return [np.exp(-seq_loss[0]) for seq_loss in seq_losses]