Example #1
0
    def __init__(self, graph, batch_size=32):
        """Creates an RNN sequence model for a given Graph instance.

        Args:
          graph: Graph object. Its session, saver and the nodes exposed as
            outputs, final_state, n_examples, predictions, prediction_probs
            and samples are cached as attributes of this model.
          batch_size: Stored unchanged as ``self.batch_size``.
        """
        g = graph
        self.graph = g
        self.session = g.session
        self.saver = g.saver
        self.batch_size = batch_size

        # Graph nodes cached for later fetches/feeds.
        self._outputs = g.outputs
        self._final_state = g.final_state
        self._n_examples = g.n_examples
        self._predictions = g.predictions
        self._probs = g.prediction_probs
        self._samples = g.samples

        # String keys for training statistics; 'ent_reg' is presumably the
        # policy-entropy regularizer (per the attribute name) -- confirm.
        self._loss, self._train, self._count, self._policy_ent = (
            'loss', 'train', 'n', 'ent_reg')

        # "step" path: initial_state is batched as a tuple key, inputs and
        # encoded_context as sequence keys; the aggregator collects the final
        # state (tuple) and the outputs (sequence).
        step_seq_keys = ['inputs', 'encoded_context']
        self._step_bc = data_utils.BatchConverter(
            tuple_keys=['initial_state'], seq_keys=step_seq_keys)
        self._step_ba = data_utils.BatchAggregator(
            seq_keys=[self._outputs], tuple_keys=[self._final_state])

        # Training path: feeds initial_state plus input/target/weight/context
        # keys; aggregates loss, entropy term and example count as numbers.
        train_seq_keys = ['inputs', 'targets', 'weights', 'context']
        self._train_bc = data_utils.BatchConverter(
            ['initial_state'], train_seq_keys)
        self._train_ba = data_utils.BatchAggregator(
            num_keys=[self._loss, self._policy_ent, self._count])
Example #2
0
 def __init__(self, graph, batch_size=32):
     """Creates an RNN seq2seq model for a given Graph object.

     Runs the base-class constructor, then caches the encoder-related graph
     handles and builds the converter/aggregator pair used for encoding.

     Args:
       graph: Graph object; in addition to what the base class reads, its
         en_outputs, initial_state and en_initial_state attributes are used.
       batch_size: Forwarded unchanged to the base-class constructor.
     """
     super(RNNSeq2seqModel, self).__init__(graph, batch_size=batch_size)

     en_outputs = graph.en_outputs
     initial_state = graph.initial_state
     en_initial_state = graph.en_initial_state
     self._en_outputs = en_outputs
     self._initial_state = initial_state
     self._en_initial_state = en_initial_state

     # Encoding feeds: en_initial_state batched as a tuple key, 'context'
     # batched as a sequence key.
     self._encode_bc = data_utils.BatchConverter(
         seq_keys=['context'], tuple_keys=[en_initial_state])
     # Encoding fetches: initial_state aggregated as a tuple, en_outputs as a
     # sequence.
     self._encode_ba = data_utils.BatchAggregator(
         seq_keys=[en_outputs], tuple_keys=[initial_state])
Example #3
0
    def _predict(self, cell_outputs, predictions_node, temperature=1.0):
        """Feeds ``cell_outputs`` and fetches ``predictions_node``.

        Args:
          cell_outputs: Values fed under ``self._outputs``, batched as a
            sequence key.
          predictions_node: Node/key to fetch; its results are aggregated as a
            sequence.
          temperature: Forwarded to ``run_epoch`` via ``parameters``.

        Returns:
          The aggregated value stored under ``predictions_node`` in the
          ``run_epoch`` result (a missing key raises ``KeyError``).
        """
        converter = data_utils.BatchConverter(seq_keys=[self._outputs])
        aggregator = data_utils.BatchAggregator(seq_keys=[predictions_node])
        results = self.run_epoch(
            [predictions_node],
            {self._outputs: cell_outputs},
            converter,
            aggregator,
            parameters={'temperature': temperature})
        return results[predictions_node]
Example #4
0
 def compute_step_logprobs(self,
                           inputs,
                           targets,
                           context=None,
                           initial_state=None,
                           parameters=None):
     """Runs the graph and returns the fetched per-step log-probabilities.

     All targets are weighted with 1.0, and the batch is split with the same
     converter used for training (``self._train_bc``).

     Args:
       inputs: Values fed under the 'inputs' key.
       targets: Values fed under 'targets'; also the template for the weights.
       context: Optional values fed under 'context'.
       initial_state: Optional values fed under 'initial_state'.
       parameters: Optional extra parameters forwarded to ``run_epoch``.

     Returns:
       The aggregated 'step_logprobs' results, or ``[]`` when ``run_epoch``
       returns no such key.
     """
     # Constant 1.0 weights built from `targets` (structure presumably
     # mirrors `targets`, per the helper's name -- confirm in data_utils).
     weights = data_utils.constant_struct_like(targets, 1.0)
     feed_dict = dict(initial_state=initial_state,
                      inputs=inputs,
                      targets=targets,
                      weights=weights,
                      context=context)
     ba = data_utils.BatchAggregator(seq_keys=['step_logprobs'])
     fetch_list = ['step_logprobs']
     result_dict = self.run_epoch(fetch_list,
                                  feed_dict,
                                  self._train_bc,
                                  ba,
                                  parameters=parameters)
     return result_dict.get('step_logprobs', [])
Example #5
0
 def compute_probs(self,
                   inputs,
                   targets,
                   context=None,
                   initial_state=None,
                   parameters=None):
     """Runs the graph and returns ``exp(-loss)`` for each sequence loss.

     All targets are weighted with 1.0, and the batch is split with the same
     converter used for training (``self._train_bc``).

     Args:
       inputs: Values fed under the 'inputs' key.
       targets: Values fed under 'targets'; also the template for the weights.
       context: Optional values fed under 'context'.
       initial_state: Optional values fed under 'initial_state'.
       parameters: Optional extra parameters forwarded to ``run_epoch``.

     Returns:
       A list with ``np.exp(-loss[0])`` for every aggregated 'sequence_loss'
       entry; empty when ``run_epoch`` returns no such key.
     """
     # Constant 1.0 weights built from `targets` (structure presumably
     # mirrors `targets`, per the helper's name -- confirm in data_utils).
     weights = data_utils.constant_struct_like(targets, 1.0)
     feed_dict = dict(initial_state=initial_state,
                      inputs=inputs,
                      targets=targets,
                      weights=weights,
                      context=context)
     ba = data_utils.BatchAggregator(tuple_keys=['sequence_loss'])
     fetch_list = ['sequence_loss']
     result_dict = self.run_epoch(fetch_list,
                                  feed_dict,
                                  self._train_bc,
                                  ba,
                                  parameters=parameters)
     seq_losses = result_dict.get('sequence_loss', [])
     # NOTE(review): presumably 'sequence_loss' is a negative log-likelihood,
     # so exp(-loss) is the sequence probability -- confirm in the graph.
     # Each aggregated entry is indexed with [0] (tuple-key aggregation).
     probs = [np.exp(-loss[0]) for loss in seq_losses]
     return probs