コード例 #1
0
 def _make_train_function(self):
     """Build ``self.train_function`` on first call; no-op afterwards.

     The function is fed the model inputs followed by the targets, plus
     ``F.learning_phase()`` when ``self.uses_learning_phase`` is set. It
     outputs the loss followed by ``self.metric_tensors``. It runs the
     optimizer's parameter updates together with whatever ops are in the
     graph's UPDATE_OPS collection.

     Side effect: if ``self.optimizer`` has no ``get_updates`` attribute,
     it is replaced in place by ``Optimizer(optimizer=<old>, ...)``.
     """
     self._assert_compiled()
     if self.train_function is not None:
         return

     logging.info("=>Creating training function...")
     feeds = self._feed_inputs + self._feed_targets  # new list, safe to extend
     if self.uses_learning_phase:
         feeds.append(F.learning_phase())

     with ops.name_scope('training'):
         # The inner scope is named after the optimizer class as it was
         # *before* any wrapping below.
         with ops.name_scope(self.optimizer.__class__.__name__):
             if not hasattr(self.optimizer, 'get_updates'):
                 # Optimizers without get_updates() get wrapped by the
                 # project's ``Optimizer`` class (presumably an adapter
                 # exposing get_updates — confirm).
                 self.optimizer = Optimizer(
                     optimizer=self.optimizer,
                     global_step=training_util.get_global_step())
             # Extra graph-registered updates (e.g. slim.batch_norm moving stats).
             extra_updates = fops.get_collection(fops.GraphKeys.UPDATE_OPS)
             param_updates = self.optimizer.get_updates(
                 params=list(self.trainable_weights), loss=self.loss)
         self.train_function = Function(
             inputs=feeds,
             outputs=[self.loss] + self.metric_tensors,
             updates=param_updates + extra_updates,
             name='train_function',
             hooks=self.train_hooks,
             **self._function_kwargs)
     logging.info("=>Finish creating training function...")
コード例 #2
0
    def forward(self, inputs, training=None):
        """Apply dropout to ``inputs`` when ``training`` is truthy.

        Args:
            inputs: the tensor to apply dropout to.
            training: branch selector passed to ``F.smart_cond``; when None,
                ``F.learning_phase()`` is used instead.

        Returns:
            ``nn.dropout(inputs, ...)`` (using ``self.rate``, ``self.seed``
            and ``self._get_noise_shape(inputs)``) on the training branch,
            otherwise ``array_ops.identity(inputs)``.
        """
        phase = F.learning_phase() if training is None else training

        def _drop():
            # Noise shape is computed lazily, only if this branch is built.
            return nn.dropout(
                inputs,
                rate=self.rate,
                seed=self.seed,
                noise_shape=self._get_noise_shape(inputs))

        def _keep():
            return array_ops.identity(inputs)

        return F.smart_cond(phase, _drop, _keep)
コード例 #3
0
 def _make_predict_function(self):
     """Build ``self.predict_function`` on first call; no-op afterwards.

     The function is fed the model inputs only (no targets), plus
     ``F.learning_phase()`` when ``self.uses_learning_phase`` is set. It
     returns ``self.outputs``.
     """
     self._assert_compiled()
     if self.predict_function is None:
         logging.info("=>Creating predict function...")
         # Copy before extending. ``inputs = self._feed_inputs`` would alias
         # the instance list, and appending the learning-phase tensor would
         # mutate ``self._feed_inputs`` permanently. The train/eval builders
         # read that same list, so their feeds would be corrupted.
         inputs = list(self._feed_inputs)
         if self.uses_learning_phase:
             inputs.append(F.learning_phase())
         with ops.name_scope('predict'):
             self.predict_function = Function(
                 inputs=inputs,
                 outputs=self.outputs,
                 hooks=self._predict_hooks,
                 name='predict_function',
                 **self._function_kwargs)
         logging.info("=>Finish creating predict function...")
コード例 #4
0
 def _make_eval_function(self):
     """Build ``self.eval_function`` on first call; no-op afterwards.

     The function is fed the model inputs followed by the targets, plus
     ``F.learning_phase()`` when ``self.uses_learning_phase`` is set. It
     outputs the loss followed by ``self.metric_tensors``. Unlike the
     training function, no updates are attached.
     """
     self._assert_compiled()
     if self.eval_function is not None:
         return

     logging.info("=>Creating evaluation function...")
     feeds = self._feed_inputs + self._feed_targets  # new list, safe to extend
     if self.uses_learning_phase:
         feeds.append(F.learning_phase())

     with ops.name_scope('evaluation'):
         self.eval_function = Function(
             inputs=feeds,
             outputs=[self.loss] + self.metric_tensors,
             name='eval_function',
             hooks=self.val_hooks,
             **self._function_kwargs)
     logging.info("=>Finish creating evaluation function...")