示例#1
0
    def __init__(self, graph, batch_size=32):
        """Creates a RNN sequence model for a given Graph instance.

        Args:
          graph: Graph object; its ``session``, ``saver`` and output nodes
            are cached on this model.
          batch_size: Value stored as ``self.batch_size``.
        """
        # Handles to the underlying graph and its runtime objects.
        self.graph = graph
        self.session, self.saver = graph.session, graph.saver
        self.batch_size = batch_size

        # Graph nodes referenced by the model's run methods.
        self._outputs = graph.outputs
        self._final_state = graph.final_state
        self._n_examples = graph.n_examples
        self._predictions = graph.predictions
        self._probs = graph.prediction_probs
        self._samples = graph.samples

        # String names for training-time fetches; three of them are
        # aggregated as numeric keys by ``_train_ba`` below.
        self._loss, self._train, self._count, self._policy_ent = (
            'loss', 'train', 'n', 'ent_reg')

        # Batch converter/aggregator pair for step runs.
        self._step_bc = data_utils.BatchConverter(
            seq_keys=['inputs', 'encoded_context'],
            tuple_keys=['initial_state'])
        self._step_ba = data_utils.BatchAggregator(
            seq_keys=[self._outputs],
            tuple_keys=[self._final_state])

        # Batch converter/aggregator pair for training runs.
        train_seq_keys = ['inputs', 'targets', 'weights', 'context']
        self._train_bc = data_utils.BatchConverter(
            ['initial_state'], train_seq_keys)
        self._train_ba = data_utils.BatchAggregator(
            num_keys=[self._loss, self._policy_ent, self._count])
示例#2
0
 def __init__(self, graph, batch_size=32):
     """Creates a MemorySeq2seqModel for a given Graph object.

     Args:
       graph: Graph object passed to the parent constructor; its
         ``config['core_config']`` mapping also supplies the size settings
         cached on this model.
       batch_size: Forwarded to the parent constructor.
     """
     super(MemorySeq2seqModel, self).__init__(graph, batch_size=batch_size)

     # Size settings, read from a single lookup of the core config.
     core_config = graph.config['core_config']
     self.max_n_valid_indices = core_config['max_n_valid_indices']
     self.n_mem = core_config['n_mem']
     self.hidden_size = core_config['hidden_size']
     self.value_embedding_size = core_config['value_embedding_size']

     # Key groups shared by several converters. Each converter gets its own
     # freshly built list, as in the original literal-per-call form.
     constant_keys = (
         'n_constants', 'constant_spans', 'constant_value_embeddings')
     encoder_seq_keys = ('en_inputs', 'en_input_features')
     preprocess = self._preprocess

     self._encode_bc = data_utils.BatchConverter(
         tuple_keys=['en_initial_state'] + list(constant_keys),
         seq_keys=list(encoder_seq_keys),
         preprocess_fn=preprocess)
     self._step_bc = data_utils.BatchConverter(
         tuple_keys=['initial_state'],
         seq_keys=['encoded_context'],
         preprocess_fn=preprocess)
     self._train_bc = data_utils.BatchConverter(
         tuple_keys=list(constant_keys),
         seq_keys=['targets', 'weights'] + list(encoder_seq_keys),
         preprocess_fn=preprocess)
示例#3
0
 def __init__(self, graph, batch_size=32):
     """Creates a RNN seq2seq model for a given Graph object.

     Args:
       graph: Graph object passed to the parent constructor; its
         ``en_outputs``, ``initial_state`` and ``en_initial_state`` nodes are
         also cached on this model.
       batch_size: Forwarded to the parent constructor.
     """
     super(RNNSeq2seqModel, self).__init__(graph, batch_size=batch_size)

     en_outputs = graph.en_outputs
     initial_state = graph.initial_state
     en_initial_state = graph.en_initial_state
     self._en_outputs = en_outputs
     self._initial_state = initial_state
     self._en_initial_state = en_initial_state

     # Encoding converter: ``en_initial_state`` is a tuple key and
     # ``context`` is a sequence key.
     self._encode_bc = data_utils.BatchConverter(
         seq_keys=['context'], tuple_keys=[en_initial_state])
     # Encoding aggregator: ``initial_state`` is collected as a tuple key
     # and ``en_outputs`` as a sequence key.
     self._encode_ba = data_utils.BatchAggregator(
         seq_keys=[en_outputs], tuple_keys=[initial_state])
示例#4
0
    def _predict(self, cell_outputs, predictions_node, temperature=1.0):
        """Fetches ``predictions_node`` for ``cell_outputs`` via ``run_epoch``.

        Args:
          cell_outputs: Value fed for ``self._outputs``, which the converter
            treats as a sequence key.
          predictions_node: Node to fetch; the aggregator treats it as a
            sequence key.
          temperature: Forwarded to ``run_epoch`` inside ``parameters``.

        Returns:
          The entry for ``predictions_node`` in the dict returned by
          ``run_epoch``.
        """
        converter = data_utils.BatchConverter(seq_keys=[self._outputs])
        aggregator = data_utils.BatchAggregator(seq_keys=[predictions_node])
        results = self.run_epoch(
            [predictions_node],
            {self._outputs: cell_outputs},
            converter,
            aggregator,
            parameters={'temperature': temperature})
        return results[predictions_node]