def get_training_op(self, graph, loss):
    """Build the op that applies gradients to all trainable variables.

    Subclasses that need custom gradient handling (clipping,
    accumulation, etc.) should override this method.

    Parameters
    ----------
    graph: tf.Graph
      Graph in which to create the op.
    loss: Tensor
      Scalar loss to minimize.

    Returns
    -------
    A training op named 'train'.
    """
    with graph.as_default():
      optimizer = model_ops.optimizer(
          self.optimizer, self.learning_rate, self.momentum)
      train_op = optimizer.minimize(loss, name='train')
    return train_op
def get_task_training_op(self, graph, losses, task):
    """Build the op that applies gradients for a single task.

    Only the trainable variables in that task's collection
    ("task%d_ops") are updated. Subclasses that need custom gradient
    handling should override this method.

    Parameters
    ----------
    graph: tf.Graph
      Graph in which to create the op.
    losses: dict
      Dictionary mapping task index to its loss tensor.
    task: int
      Index of the task whose variables should be trained.

    Returns
    -------
    A training op named 'train' restricted to this task's variables.
    """
    with graph.as_default():
      # Restrict minimization to the variables scoped under this task.
      collection_name = "task%d_ops" % task
      trainable_vars = tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, collection_name)
      optimizer = model_ops.optimizer(
          self.optimizer, self.learning_rate, self.momentum)
      train_op = optimizer.minimize(
          losses[task], name='train', var_list=trainable_vars)
    return train_op