def _make_train_function(self):
    """Lazily build `self.train_function` (loss minimization + metrics).

    No-op if the function already exists. Requires a prior `compile`
    (enforced by `_assert_compiled`). Feeds inputs and targets, plus the
    backend learning-phase placeholder when the graph uses one.
    """
    self._assert_compiled()
    if self.train_function is None:
        logging.info("=>Creating training function...")
        # `+` builds a fresh list, so the `+=` below cannot mutate
        # self._feed_inputs / self._feed_targets.
        inputs = self._feed_inputs + self._feed_targets
        if self.uses_learning_phase:
            inputs += [F.learning_phase()]
        with ops.name_scope('training'):
            with ops.name_scope(self.optimizer.__class__.__name__):
                # Wrap raw (e.g. tf.train-style) optimizers that lack the
                # Keras-like `get_updates` API in the project's adapter.
                if not hasattr(self.optimizer, 'get_updates'):
                    self.optimizer = Optimizer(
                        optimizer=self.optimizer,
                        global_step=training_util.get_global_step())
                # extra updates (e.g. slim.batch_norm)
                update_ops = fops.get_collection(fops.GraphKeys.UPDATE_OPS)
                training_updates = self.optimizer.get_updates(
                    params=list(self.trainable_weights), loss=self.loss)
                # Optimizer updates run before the collected UPDATE_OPS;
                # NOTE(review): relies on this concatenation order — confirm
                # downstream Function preserves it.
                self.train_function = Function(
                    inputs=inputs,
                    outputs=[self.loss] + self.metric_tensors,
                    updates=training_updates + update_ops,
                    name='train_function',
                    hooks=self.train_hooks,
                    **self._function_kwargs)
        logging.info("=>Finish creating training function...")
def _make_predict_function(self):
    """Lazily build `self.predict_function` for inference forward passes.

    No-op if the function already exists. Requires a prior `compile`
    (enforced by `_assert_compiled`). Feeds only the model inputs, plus the
    backend learning-phase placeholder when the graph uses one.
    """
    self._assert_compiled()
    if self.predict_function is None:
        logging.info("=>Creating predict function...")
        # BUG FIX: the original `inputs = self._feed_inputs` aliased the
        # attribute, so the in-place `+=` below appended the learning-phase
        # placeholder directly into self._feed_inputs, corrupting it for
        # every other consumer. Copy first (the sibling train/eval builders
        # are safe because they build a fresh list via `+`).
        inputs = list(self._feed_inputs)
        if self.uses_learning_phase:
            inputs += [F.learning_phase()]
        with ops.name_scope('predict'):
            self.predict_function = Function(
                inputs=inputs,
                outputs=self.outputs,
                hooks=self._predict_hooks,
                name='predict_function',
                **self._function_kwargs)
        logging.info("=>Finish creating predict function...")
def _make_eval_function(self):
    """Lazily build `self.eval_function` (loss + metric outputs, no updates).

    No-op if the function already exists. Requires a prior `compile`
    (enforced by `_assert_compiled`).
    """
    self._assert_compiled()
    if self.eval_function is not None:
        return
    logging.info("=>Creating evaluation function...")
    # Feed both inputs and targets; append the learning-phase placeholder
    # when the graph distinguishes train/eval behavior.
    feeds = self._feed_inputs + self._feed_targets
    if self.uses_learning_phase:
        feeds = feeds + [F.learning_phase()]
    eval_outputs = [self.loss]
    eval_outputs.extend(self.metric_tensors)
    with ops.name_scope('evaluation'):
        self.eval_function = Function(
            inputs=feeds,
            outputs=eval_outputs,
            name='eval_function',
            hooks=self.val_hooks,
            **self._function_kwargs)
    logging.info("=>Finish creating evaluation function...")