Example #1
0
  def __init__(self, graph, batch_size=32, en_maxlen=None, maxlen=None):
    """Builds an RNN sequence model on top of an existing Graph instance.

    Args:
      graph: Graph whose session, saver and named nodes this model reuses.
      batch_size: Stored as `self.batch_size`.
      en_maxlen: Optional value stored as `self.en_maxlen`.
      maxlen: Optional value stored as `self.maxlen`.
    """
    g = graph
    self.graph = g
    # Handles copied straight off the graph (read in the original order).
    self.session, self.saver = g.session, g.saver
    self.batch_size = batch_size
    self._outputs, self._final_state = g.outputs, g.final_state
    self._n_examples, self._predictions = g.n_examples, g.predictions
    self._probs, self._samples = g.prediction_probs, g.samples
    self.en_maxlen, self.maxlen = en_maxlen, maxlen
    self.meta_learn = g.meta_learn

    # String names used as fetch / aggregation keys.
    self._loss, self._train = 'loss', 'train'
    self._meta_train = 'meta_train'
    self._count, self._policy_ent = 'n', 'ent_reg'

    # Converter/aggregator pair stored as the `_step_*` attributes.
    step_seq_keys = ['inputs', 'encoded_context']
    self._step_bc = data_utils.BatchConverter(
        tuple_keys=['initial_state'], seq_keys=step_seq_keys)
    self._step_ba = data_utils.BatchAggregator(
        tuple_keys=[self._final_state], seq_keys=[self._outputs])

    # Converter/aggregator pair stored as the `_train_*` attributes; the
    # converter is also reused by the compute_* helpers.
    train_seq_keys = ['inputs', 'targets', 'context']
    self._train_bc = data_utils.BatchConverter(['initial_state'],
                                               train_seq_keys)
    self._train_ba = data_utils.BatchAggregator(
        num_keys=[self._loss, self._policy_ent, self._count])
Example #2
0
 def __init__(self, graph, batch_size=32, en_maxlen=None, maxlen=None):
   """Builds the seq2seq model: base RNN setup plus encoder-side handles.

   Args:
     graph: Graph providing `en_outputs`, `initial_state` and
       `en_initial_state` in addition to what the base class reads.
     batch_size: Forwarded to the base class.
     en_maxlen: Forwarded to the base class.
     maxlen: Forwarded to the base class.
   """
   super(RNNSeq2seqModel, self).__init__(
       graph, batch_size=batch_size, en_maxlen=en_maxlen, maxlen=maxlen)
   en_outputs = graph.en_outputs
   init_state = graph.initial_state
   en_init_state = graph.en_initial_state
   self._en_outputs = en_outputs
   self._initial_state = init_state
   self._en_initial_state = en_init_state
   # The encode converter feeds the 'context' sequence plus the encoder's
   # initial state; the aggregator collects the initial state and encoder
   # outputs.
   self._encode_bc = data_utils.BatchConverter(
       tuple_keys=[en_init_state], seq_keys=['context'])
   self._encode_ba = data_utils.BatchAggregator(
       tuple_keys=[init_state], seq_keys=[en_outputs])
Example #3
0
  def _predict(self, cell_outputs, predictions_node, temperature=1.0):
    """Evaluates `predictions_node` with `cell_outputs` fed as `self._outputs`.

    Args:
      cell_outputs: Sequence data fed under the `self._outputs` key.
      predictions_node: Node to fetch and aggregate as a sequence.
      temperature: Passed to `run_epoch` as parameters['temperature'].

    Returns:
      The aggregated value for `predictions_node` from `run_epoch`.
    """
    converter = data_utils.BatchConverter(
        seq_keys=[self._outputs], maxlen=self.maxlen)
    aggregator = data_utils.BatchAggregator(seq_keys=[predictions_node])
    results = self.run_epoch(
        [predictions_node], {self._outputs: cell_outputs}, converter,
        aggregator, parameters={'temperature': temperature})
    return results[predictions_node]
Example #4
0
 def compute_step_logprobs(self,
                           inputs,
                           targets,
                           context=None,
                           initial_state=None,
                           parameters=None):
   """Fetches the 'step_logprobs' node for a batch of sequences.

   Args:
     inputs: Fed under the 'inputs' key.
     targets: Fed under the 'targets' key.
     context: Optional; fed under the 'context' key.
     initial_state: Optional; fed under the 'initial_state' key.
     parameters: Optional; forwarded to `run_epoch` unchanged.

   Returns:
     The aggregated 'step_logprobs' value, or [] if `run_epoch` returned
     no entry for it.
   """
   key = 'step_logprobs'
   feeds = {
       'initial_state': initial_state,
       'inputs': inputs,
       'targets': targets,
       'context': context,
   }
   aggregator = data_utils.BatchAggregator(seq_keys=[key])
   results = self.run_epoch(
       [key], feeds, self._train_bc, aggregator, parameters=parameters)
   return results.get(key, [])
Example #5
0
 def compute_probs(self,
                   inputs,
                   targets,
                   context=None,
                   initial_state=None,
                   parameters=None):
   """Fetches 'sequence_probs' and returns the first field of each entry.

   Args:
     inputs: Fed under the 'inputs' key.
     targets: Fed under the 'targets' key.
     context: Optional; fed under the 'context' key.
     initial_state: Optional; fed under the 'initial_state' key.
     parameters: Optional; forwarded to `run_epoch` unchanged.

   Returns:
     A list holding element [0] of every aggregated 'sequence_probs'
     entry; empty if `run_epoch` returned no entry for that key.
   """
   key = 'sequence_probs'
   feeds = {
       'initial_state': initial_state,
       'inputs': inputs,
       'targets': targets,
       'context': context,
   }
   # Aggregated as a tuple key, so each entry is indexable; only the first
   # component is kept.
   aggregator = data_utils.BatchAggregator(tuple_keys=[key])
   results = self.run_epoch(
       [key], feeds, self._train_bc, aggregator, parameters=parameters)
   return [entry[0] for entry in results.get(key, [])]
Example #6
0
 def compute_scores(self,
                    inputs,
                    targets,
                    context=None,
                    initial_state=None,
                    parameters=None):
   """Fetches 'scores' from the attention-based score function.

   Args:
     inputs: Fed under the 'inputs' key.
     targets: Fed under the 'targets' key.
     context: Optional; fed under the 'context' key.
     initial_state: Optional; fed under the 'initial_state' key.
     parameters: Optional; forwarded to `run_epoch` unchanged.

   Returns:
     The aggregated 'scores' value (collected via `keep_keys`), or [] if
     `run_epoch` returned no entry for it.
   """
   key = 'scores'
   feeds = {
       'initial_state': initial_state,
       'inputs': inputs,
       'targets': targets,
       'context': context,
   }
   aggregator = data_utils.BatchAggregator(keep_keys=[key])
   results = self.run_epoch(
       [key], feeds, self._train_bc, aggregator, parameters=parameters)
   return results.get(key, [])