Example #1
    def build_gradient_merge_and_update(self, global_step, losses,
                                        device_grads):
        """Merges per-device gradients and adds the update ops to the graph."""
        fetches = {}

        apply_gradient_devices = self.devices
        gradient_state = device_grads

        training_ops = []

        # In TRAIN mode, merge the per-device gradients; gradient_state
        # then holds the merged result.
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            apply_gradient_devices, gradient_state = self.variable_mgr.preprocess_device_grads(
                device_grads, self.params.independent_replica)

        for d, device in enumerate(apply_gradient_devices):
            with tf.device(device):
                if self.mode != tf.contrib.learn.ModeKeys.INFER:
                    average_loss = (losses[d]
                                    if self.params.independent_replica else
                                    tf.reduce_sum(losses))
                if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
                    avg_grads = self.variable_mgr.get_gradients_to_apply(
                        d, gradient_state)
                    # Gradient clipping was moved to
                    # add_forward_pass_and_gradients.
                    self.learning_rate = tf.constant(self.params.learning_rate)
                    opt = tf.train.AdamOptimizer(self.learning_rate)
                    # Automatic loss scaling is disabled here; the params
                    # object is still required by append_apply_gradients_ops.
                    loss_scale_params = variable_mgr_util.AutoLossScaleParams(
                        enable_auto_loss_scale=False,
                        loss_scale=None,
                        loss_scale_normal_steps=None,
                        inc_loss_scale_every_n=1000,
                        is_chief=True,
                    )
                    # Append this device's optimizer ops to training_ops.
                    self.variable_mgr.append_apply_gradients_ops(
                        gradient_state, opt, avg_grads, training_ops,
                        loss_scale_params)
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            fetches["train_op"] = tf.group(training_ops)
        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            fetches["average_loss"] = (average_loss
                                       if self.params.independent_replica else
                                       average_loss /
                                       tf.to_float(self.batch_size))
        return fetches
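
    # A minimal sketch of how the fetches returned above might be consumed
    # in a TF1 session loop. `model` and `num_steps` are hypothetical names
    # for illustration, not part of the original example:
    #
    #     fetches = model.build_gradient_merge_and_update(
    #         global_step, losses, device_grads)
    #     with tf.Session() as sess:
    #         sess.run(tf.global_variables_initializer())
    #         for _ in range(num_steps):
    #             _, loss = sess.run(
    #                 [fetches["train_op"], fetches["average_loss"]])
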
    def testAppendGradientsWithLossScaleForNonChiefWorker(self):
        v = tf.Variable(0)
        training_ops = []
        get_apply_gradients_ops_func = lambda: [tf.assign(v, v + 1)]
        loss_scale_params = variable_mgr_util.AutoLossScaleParams(
            enable_auto_loss_scale=True,
            loss_scale=tf.Variable(4),
            loss_scale_normal_steps=tf.Variable(10),
            inc_loss_scale_every_n=10,
            is_chief=False)  # Non-chief
        variable_mgr_util.append_gradients_with_loss_scale(
            training_ops,
            get_apply_gradients_ops_func,
            loss_scale_params,
            grad_has_inf_nan=False)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(training_ops)
            self.assertEqual(sess.run(v), 1)
            # Non-chief workers apply the gradients but leave the loss
            # scale and its step counter untouched.
            self.assertEqual(sess.run(loss_scale_params.loss_scale), 4)
            self.assertEqual(
                sess.run(loss_scale_params.loss_scale_normal_steps), 10)

    def testAppendGradientsWithLossScaleWithoutNan(self):
        v = tf.Variable(0)
        training_ops = []
        get_apply_gradients_ops_func = lambda: [tf.assign(v, v + 1)]
        loss_scale_params = variable_mgr_util.AutoLossScaleParams(
            enable_auto_loss_scale=True,
            loss_scale=tf.Variable(4, dtype=tf.float32),
            loss_scale_normal_steps=tf.Variable(10),
            inc_loss_scale_every_n=10,
            is_chief=True)
        variable_mgr_util.append_gradients_with_loss_scale(
            training_ops,
            get_apply_gradients_ops_func,
            loss_scale_params,
            grad_has_inf_nan=tf.constant(False))

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(training_ops)
            self.assertEqual(sess.run(v), 1)
            # Ten non-NaN steps have accumulated, so the loss scale doubles
            # (4 -> 8) and the normal-step counter resets to 0.
            self.assertEqual(sess.run(loss_scale_params.loss_scale), 8)
            self.assertEqual(
                sess.run(loss_scale_params.loss_scale_normal_steps), 0)
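
    # For completeness, a sketch of the complementary NaN/Inf case. This
    # test is not part of the original example; it assumes the usual
    # behavior of append_gradients_with_loss_scale under automatic loss
    # scaling: a NaN/Inf gradient skips the variable update, halves the
    # loss scale, and resets the normal-step counter.
    def testAppendGradientsWithLossScaleWithNan(self):
        v = tf.Variable(0)
        training_ops = []
        get_apply_gradients_ops_func = lambda: [tf.assign(v, v + 1)]
        loss_scale_params = variable_mgr_util.AutoLossScaleParams(
            enable_auto_loss_scale=True,
            loss_scale=tf.Variable(4, dtype=tf.float32),
            loss_scale_normal_steps=tf.Variable(10),
            inc_loss_scale_every_n=10,
            is_chief=True)
        variable_mgr_util.append_gradients_with_loss_scale(
            training_ops,
            get_apply_gradients_ops_func,
            loss_scale_params,
            grad_has_inf_nan=tf.constant(True))

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(training_ops)
            self.assertEqual(sess.run(v), 0)  # Update skipped.
            self.assertEqual(sess.run(loss_scale_params.loss_scale), 2)
            self.assertEqual(
                sess.run(loss_scale_params.loss_scale_normal_steps), 0)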