Example #1
    def construct(self, step, print_step):
        '''Run one MD step: refresh the neighbor list, compute force and energy, and integrate with the Liu-Jian leap-frog.'''
        self.last_crd = self.crd
        if step == 0:
            res = self.neighbor_list_update_init(
                self.atom_numbers_in_grid_bucket, self.bucket, self.crd,
                self.box_length, self.grid_N, self.grid_length_inverse,
                self.atom_in_grid_serial, self.old_crd,
                self.crd_to_uint_crd_cof, self.uint_crd, self.pointer,
                self.nl_atom_numbers, self.nl_atom_serial,
                self.uint_dr_to_dr_cof, self.excluded_list_start,
                self.excluded_list, self.excluded_numbers,
                self.need_refresh_flag, self.refresh_count)
            self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
            self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
            self.uint_dr_to_dr_cof = F.depend(self.uint_dr_to_dr_cof, res)
            self.old_crd = F.depend(self.old_crd, res)
            self.atom_numbers_in_grid_bucket = F.depend(
                self.atom_numbers_in_grid_bucket, res)
            self.bucket = F.depend(self.bucket, res)
            self.atom_in_grid_serial = F.depend(self.atom_in_grid_serial, res)
            self.pointer = F.depend(self.pointer, res)
            uint_crd = F.depend(self.uint_crd, res)

            force = self.Simulation_Caculate_Force(uint_crd,
                                                   self.uint_dr_to_dr_cof,
                                                   self.nl_atom_numbers,
                                                   self.nl_atom_serial)
            bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, nb14_cf_energy_sum, \
            lj_energy_sum, ee_ene, total_energy = self.Simulation_Caculate_Energy(uint_crd, self.uint_dr_to_dr_cof)
            temperature = self.Simulation_Temperature()
            self.rand_state = self.setup_random_state()
            self.velocity, self.crd, _ = self.Simulation_MDIterationLeapFrog_Liujian(
                self.mass_inverse, self.sqrt_mass, self.crd, force,
                self.rand_state, self.random_force)

            res = self.neighbor_list_update(
                self.atom_numbers_in_grid_bucket, self.bucket, self.crd,
                self.box_length, self.grid_N, self.grid_length_inverse,
                self.atom_in_grid_serial, self.old_crd,
                self.crd_to_uint_crd_cof, self.uint_crd, self.pointer,
                self.nl_atom_numbers, self.nl_atom_serial,
                self.uint_dr_to_dr_cof, self.excluded_list_start,
                self.excluded_list, self.excluded_numbers,
                self.need_refresh_flag, self.refresh_count)
            self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
            self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
        else:
            uint_crd = self.Simulation_Beforce_Caculate_Force()
            force = self.Simulation_Caculate_Force(uint_crd,
                                                   self.uint_dr_to_dr_cof,
                                                   self.nl_atom_numbers,
                                                   self.nl_atom_serial)
            if print_step == 0:
                bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, nb14_cf_energy_sum, \
                lj_energy_sum, ee_ene, total_energy = self.Simulation_Caculate_Energy(
                    uint_crd, self.uint_dr_to_dr_cof)
            else:
                bond_energy_sum = self.zero_fp_tensor
                angle_energy_sum = self.zero_fp_tensor
                dihedral_energy_sum = self.zero_fp_tensor
                nb14_lj_energy_sum = self.zero_fp_tensor
                nb14_cf_energy_sum = self.zero_fp_tensor
                lj_energy_sum = self.zero_fp_tensor
                ee_ene = self.zero_fp_tensor
                total_energy = self.zero_fp_tensor
            temperature = self.Simulation_Temperature()
            self.velocity, self.crd, _ = self.Simulation_MDIterationLeapFrog_Liujian(
                self.mass_inverse, self.sqrt_mass, self.crd, force,
                self.rand_state, self.random_force)
            res = self.neighbor_list_update(
                self.atom_numbers_in_grid_bucket, self.bucket, self.crd,
                self.box_length, self.grid_N, self.grid_length_inverse,
                self.atom_in_grid_serial, self.old_crd,
                self.crd_to_uint_crd_cof, self.uint_crd, self.pointer,
                self.nl_atom_numbers, self.nl_atom_serial,
                self.uint_dr_to_dr_cof, self.excluded_list_start,
                self.excluded_list, self.excluded_numbers,
                self.need_refresh_flag, self.refresh_count)
            self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
            self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
        return temperature, total_energy, bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, \
               nb14_cf_energy_sum, lj_energy_sum, ee_ene, res
Example #2
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids, input_mask, token_type_id,
                            next_sentence_labels, masked_lm_positions,
                            masked_lm_ids, masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one,
                                      self.one)
        self.loss = self.select(is_accu_step, self.loss + loss, loss)
        mean_loss = self.loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network,
                          weights)(input_ids, input_mask, token_type_id,
                                   next_sentence_labels, masked_lm_positions,
                                   masked_lm_ids, masked_lm_weights,
                                   self.cast(scaling_sens, mstype.float32))

        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0, ))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(
            self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow,
                                         self.zero)

        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            if self.enable_global_norm:
                grads = ClipByGlobalNorm()(grads)
            else:
                grads = self.hyper_map(
                    F.partial(clip_grad, GRADIENT_CLIP_TYPE,
                              GRADIENT_CLIP_VALUE), grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
Example #3
 def construct(self):
     weights = self.weights
     loss = self.network()
     sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
     grads = self.grad(self.network, weights)(sens)
     return F.depend(loss, self.optimizer(grads))
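For context, a minimal self-contained sketch of the wrapper Cell such a construct would sit in, assuming the usual MindSpore train-one-step wiring (the class name and every attribute not visible in the construct above are illustrative, not taken from the source):

import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class TrainOneStepSketch(nn.Cell):
    """Hypothetical wrapper: forward pass, backward pass, optimizer update."""
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepSketch, self).__init__(auto_prefix=False)
        self.network = network                      # loss network taking no inputs
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self):
        weights = self.weights
        loss = self.network()
        # Fill a tensor shaped like the loss with the sensitivity value.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(sens)
        # Tie the optimizer update to the returned loss so it is not pruned.
        return F.depend(loss, self.optimizer(grads))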
Example #4
 def construct(self, x):
     x = F.depend(x, self.assign(self.s1, x + self.s1))
     self.s1 = self.sub(self.s1, x)
     self.s2 = self.sub(self.s2, x)
     return x
Example #5
    def construct(self,
                  source_eos_ids,
                  source_eos_mask,
                  target_sos_ids,
                  target_sos_mask,
                  target_eos_ids,
                  target_eos_mask,
                  sens=None):
        """
        Construct network.

        Args:
            source_eos_ids (Tensor): Source sentence.
            source_eos_mask (Tensor): Source padding mask.
            target_sos_ids (Tensor): Target sentence.
            target_sos_mask (Tensor): Target padding mask.
            target_eos_ids (Tensor): Prediction sentence.
            target_eos_mask (Tensor): Prediction padding mask.
            sens (Tensor): Loss scaling sensitivity.

        Returns:
            Tuple[Tensor, Tensor, Tensor], the loss, the overflow flag, and the loss scaling sensitivity.
        """
        source_ids = source_eos_ids
        source_mask = source_eos_mask
        target_ids = target_sos_ids
        target_mask = target_sos_mask
        label_ids = target_eos_ids
        label_weights = target_eos_mask

        weights = self.weights
        loss = self.network(source_ids,
                            source_mask,
                            target_ids,
                            target_mask,
                            label_ids,
                            label_weights)
        # Alloc status.
        init = self.alloc_status()
        # Clear overflow buffer.
        self.clear_before_grad(init)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
                                                 target_mask,
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        if self.reducer_flag:
            # Apply grad reducer on grads.
            grads = self.grad_reducer(grads)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))

        if self.is_distributed:
            # Sum overflow flag over devices.
            flag_reduce = self.all_reduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)

        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)

        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
Example #6
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
    """Apply sgd optimizer to the weight parameter using Tensor."""
    success = True
    success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
    return success
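Per-parameter kernels like this one are normally registered on a MultitypeFuncGraph and broadcast over all parameters with hyper_map inside the optimizer's construct; a minimal sketch under that assumption (the graph name and the hyper_map call in the trailing comment are illustrative, not from the source):

from mindspore.ops import composite as C
from mindspore.ops import functional as F

# Hypothetical graph name; the real optimizer defines its own.
_sgd_opt = C.MultitypeFuncGraph("sgd_opt")

@_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
    """Apply the SGD primitive and tie it to the boolean success flag."""
    success = True
    success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
    return success

# Inside the optimizer's construct the registered kernel is mapped over every
# parameter, e.g. (illustrative):
#   success = self.hyper_map(
#       F.partial(_sgd_opt, self.opt, self.momentum, lr),
#       gradients, self.parameters, self.accum, self.stat)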
Example #7
def _tensor_run_opt_ext(opt, weight_decay, scale, momentum, learning_rate, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter using Tensor."""
    success = F.depend(True, opt(weight_decay, scale, weight, moment, learning_rate, gradient, momentum))
    return success
Example #8
def _tensor_run_opt(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
    """Apply proximal_ada_grad optimizer to the weight parameter."""
    success = True
    success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient))
    return success
Example #9
    def construct(self,
                  source_eos_ids,
                  source_eos_mask,
                  target_sos_ids,
                  target_sos_mask,
                  target_eos_ids,
                  target_eos_mask,
                  sens=None):
        """Defines the computation performed."""
        source_ids = source_eos_ids
        source_mask = source_eos_mask
        target_ids = target_sos_ids
        target_mask = target_sos_mask
        label_ids = target_eos_ids
        label_weights = target_eos_mask

        weights = self.weights
        loss = self.network(source_ids,
                            source_mask,
                            target_ids,
                            target_mask,
                            label_ids,
                            label_weights)
        # alloc status
        init = self.alloc_status()
        # clear overflow buffer
        self.clear_before_grad(init)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
                                                 target_mask,
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))

        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)

        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)

        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
Example #10
def _tensor_run_opt(opt, learning_rate, weight, accum, gradient):
    """Apply ada_grad optimizer to the weight parameter."""
    success = True
    success = F.depend(success, opt(weight, accum, learning_rate, gradient))
    return success
Example #11
 def construct(self, data, sens):
     weights = self.weights
     loss = self.network(data)
     grads = self.grad(self.network, weights)(data, sens)
     return F.depend(loss, self.optimizer(grads))
Example #12
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v,
                   gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the updated parameter when `optim_filter` is True; otherwise the original gradient.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        cond = op_cast(
            F.fill(mstype.int32, op_shape(m_fp32), 1) *
            op_reshape(overflow, (())), mstype.bool_)
        next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\
                op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))

        next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\
                op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(
            cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        next_param = F.depend(
            next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param,
                              F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param,
                              F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
Example #13
    def construct(self, x, y):
        add_res = self.add(x, y)
        F.depend(add_res, F.assign(self.param, add_res))

        return add_res
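A self-contained sketch of a Cell around this construct (the class name, parameter initialization, and add primitive are assumptions); here the F.depend result is bound back to add_res so the assign is explicitly ordered before the value is returned:

import mindspore.nn as nn
from mindspore import Parameter, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class AssignThroughDepend(nn.Cell):
    """Hypothetical Cell: store x + y into a Parameter and return the sum."""
    def __init__(self):
        super(AssignThroughDepend, self).__init__()
        self.add = P.TensorAdd()
        self.param = Parameter(Tensor(0.0, mstype.float32), name="param")

    def construct(self, x, y):
        add_res = self.add(x, y)
        # Bind the depend result so the assign happens before add_res is used.
        add_res = F.depend(add_res, F.assign(self.param, add_res))
        return add_res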
Example #14
    def construct(self, gradients):
        params = self.params
        moments = self.moments
        if self.thor:
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            matrix_A_max_allreduce = ()
            matrix_G_max_allreduce = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                A_max = self.A_inv_max[i]
                G_max = self.G_inv_max[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                A_max = F.depend(A_max, g)
                G_max = F.depend(G_max, g)
                matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
                matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
                matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
                matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
            matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
            matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
            matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
            matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                temp_a = matrix_A_allreduce[i]
                temp_g = matrix_G_allreduce[i]
                temp_a = self.cast(temp_a, mstype.float32)
                temp_g = self.cast(temp_g, mstype.float32)
                matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
                matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
                matrix_A_inv_max = self.exp(matrix_A_inv_max)
                temp_a = self.mul(temp_a, matrix_A_inv_max)
                matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
                matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
                matrix_G_inv_max = self.exp(matrix_G_inv_max)
                temp_g = self.mul(temp_g, matrix_G_inv_max)
                temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
                temp_max = self.mul(temp_max, self.feature_map[i])
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                if i == 53:
                    g = self.cube_matmul_left_fc(temp_g, g)
                    g = self.cube_matmul_right_fc(g, temp_a, temp_max)
                else:
                    g = self.cube_matmul_left(temp_g, g)
                    g = self.cube_matmul_right_mul(g, temp_a, temp_max)
                fake_A = self.assign(self.matrix_A[i], temp_a)
                fake_G = self.assign(self.matrix_G[i], temp_g)
                fake_max = self.assign(self.matrix_max_inv[i], temp_max)
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                g = F.depend(g, fake_max)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads
        else:
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_max = self.matrix_max_inv[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                matrix_max = F.depend(matrix_max, g)
                if i == 53:
                    g = self.cube_matmul_left_fc(matrix_G, g)
                    g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g,)
                else:
                    g = self.cube_matmul_left(matrix_G, g)
                    g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads

        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
        return success
Example #15
def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter."""
    success = True
    success = F.depend(success,
                       opt(weight, moment, learning_rate, gradient, momentum))
    return success
Example #16
def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
    """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
    success = True
    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values, gradient.indices))
    return success
Example #17
    def construct(self, data, label, sens=None):
        """
        Construct the compute flow: per-micro-batch gradients are clipped, summed, optionally noised, then scaled and applied when no overflow occurs.
        """
        init = False
        if not self.gpu_target:
            # init overflow buffer
            init = self.alloc_status()
            # clear overflow buffer
            self.clear_status(init)

        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # DP clip
        weights = self.weights
        record_datas = self._split(data)
        record_labels = self._split(label)
        # first index
        loss = self.network(record_datas[0], record_labels[0])
        scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
        record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
            loss = self.network(record_datas[i], record_labels[i])
            scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
            record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.TensorAdd()(total_loss, loss)
        loss = P.Div()(total_loss, self._micro_float)

        if self._mech is not None:
            grad_noise = self._hyper_map(self._mech, grads)
            grads = self._tuple_add(grads, grad_noise)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)

        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        # get the overflow buffer
        if not self.gpu_target:
            self.get_status(init)
            # sum overflow buffer elements, 0:not overflow , >0:overflow
            flag_sum = self.reduce_sum(init, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            # convert flag_sum to scalar
            flag_sum = self.reshape(flag_sum, (()))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        # if there is no overflow, do optimize
        if overflow:
            opt = False
        else:
            opt = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, opt)
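Example #17 sums per-micro-batch gradient tuples with a `_tuple_add` helper; a minimal sketch of such a helper, assuming the usual HyperMap pattern (the class name and wiring are illustrative, not from the source):

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P

class _TupleAdd(nn.Cell):
    """Hypothetical helper: element-wise add of two tuples of tensors."""
    def __init__(self):
        super(_TupleAdd, self).__init__()
        self.add = P.TensorAdd()
        self.hyper_map = C.HyperMap()

    def construct(self, input1, input2):
        # Map the binary add over the paired elements of both tuples.
        out = self.hyper_map(self.add, input1, input2)
        return out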
Example #18
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking,
                         use_nesterov, target, beta1_power, beta2_power, beta1,
                         beta2, eps, lr, gradient, param, m, v, ps_parameter,
                         cache_enable):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr),
                  op_shape(beta1), op_shape(beta2), op_shape(eps),
                  op_shape(values), op_shape(indices))
        success = F.depend(
            success,
            pull(
                push((beta1_power, beta2_power, lr, beta1, beta2, eps, values,
                      indices), shapes), param))
        return success

    if not target:
        success = F.depend(
            success,
            sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                       eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        success = F.depend(success, F.assign(m, op_mul(beta1, m)))
        success = F.depend(success, F.assign(v, op_mul(beta2, v)))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(
            m, grad_indices,
            op_mul(F.tuple_to_array((1.0, )) - beta1, grad_value))

        next_v = scatter_add(
            v, grad_indices,
            op_mul(F.tuple_to_array((1.0, )) - beta2, op_square(grad_value)))

        if use_nesterov:
            m_temp = next_m * _scaler_ten
            F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(
                m, op_mul(grad_indices, _scaler_one),
                op_mul(F.tuple_to_array((1.0, )) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)
            F.assign(m, m_temp / _scaler_ten)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
        next_param = param - lr_t * param_update

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
Example #19
def _centered_rmsprop_opt_(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad):
    """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
    success = True
    success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
    return success
Example #20
    def construct(self,
                  data,
                  coord_mask,
                  conf_pos_mask,
                  conf_neg_mask,
                  cls_mask,
                  t_coord,
                  t_conf,
                  t_cls,
                  gt_list,
                  coord_mask_1,
                  conf_pos_mask_1,
                  conf_neg_mask_1,
                  cls_mask_1,
                  t_coord_1,
                  t_conf_1,
                  t_cls_1,
                  gt_list_1,
                  coord_mask_2,
                  conf_pos_mask_2,
                  conf_neg_mask_2,
                  cls_mask_2,
                  t_coord_2,
                  t_conf_2,
                  t_cls_2,
                  gt_list_2,
                  sens=None):
        '''Compute the loss, scaled gradients, and overflow flag for one training step.'''

        weights = self.weights
        loss = self.network(data, coord_mask, conf_pos_mask, conf_neg_mask,
                            cls_mask, t_coord, t_conf, t_cls, gt_list,
                            coord_mask_1, conf_pos_mask_1, conf_neg_mask_1,
                            cls_mask_1, t_coord_1, t_conf_1, t_cls_1,
                            gt_list_1, coord_mask_2, conf_pos_mask_2,
                            conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2,
                            t_cls_2, gt_list_2)

        # init overflow buffer
        init = self.alloc_status()

        # clear overflow buffer
        self.clear_status(init)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        grads = self.grad(self.network, weights)(
            data, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord,
            t_conf, t_cls, gt_list, coord_mask_1, conf_pos_mask_1,
            conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1, t_cls_1,
            gt_list_1, coord_mask_2, conf_pos_mask_2, conf_neg_mask_2,
            cls_mask_2, t_coord_2, t_conf_2, t_cls_2, gt_list_2,
            F.cast(scaling_sens, F.dtype(loss)))

        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)

        # get the overflow buffer
        self.get_status(init)

        # sum overflow buffer elements, 0:not overflow , >0:overflow
        flag_sum = self.reduce_sum(init, (0, ))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)

        opt = self.optimizer(grads)

        ret = (loss, cond, scaling_sens)
        return F.depend(ret, opt)
Example #21
 def construct(self, gradients):
     params = self.params
     moments = self.moments
     gradients = self.scale_grad(gradients)
     new_grads = ()
     if self.thor:
         matrix_A_allreduce = ()
         matrix_G_allreduce = ()
         for i in range(54):
             g = gradients[i * 3]
             matrix_A = self.matrix_A[i]
             matrix_G = self.matrix_G[i]
             matrix_A = F.depend(matrix_A, g)
             matrix_G = F.depend(matrix_G, g)
             matrix_A = self.mul(matrix_A, self.feature_map_new[i])
             matrix_G = self.mul(matrix_G, self.feature_map_new[i])
             matrix_A_allreduce = matrix_A_allreduce + (matrix_A, )
             matrix_G_allreduce = matrix_G_allreduce + (matrix_G, )
         matrix_A_allreduce = self.grad_reducer_thorA(matrix_A_allreduce)
         matrix_G_allreduce = self.grad_reducer_thorG(matrix_G_allreduce)
         for i in range(54):
             g = gradients[i * 3]
             g_shape = self.shape(g)
             g = self.reshape(g, (g_shape[0], -1))
             matrix_A = matrix_A_allreduce[i]
             matrix_G = matrix_G_allreduce[i]
             g = self.update_gradient(matrix_G, g, matrix_A)
             fake_A = self.assign(self.matrix_A[i], matrix_A)
             fake_G = self.assign(self.matrix_G[i], matrix_G)
             g = F.depend(g, fake_A)
             g = F.depend(g, fake_G)
             if i == 53:
                 new_grads = new_grads + (g, )
             else:
                 g = self.reshape(g, g_shape)
                 new_grads = new_grads + (g, gradients[i * 3 + 1],
                                          gradients[i * 3 + 2])
     else:
         for i in range(54):
             g = gradients[i * 3]
             g_shape = self.shape(g)
             g = self.reshape(g, (g_shape[0], -1))
             matrix_A = self.matrix_A[i]
             matrix_G = self.matrix_G[i]
             matrix_A = F.depend(matrix_A, g)
             matrix_G = F.depend(matrix_G, g)
             g = self.update_gradient(matrix_G, g, matrix_A)
             if i == 53:
                 new_grads = new_grads + (g, )
             else:
                 g = self.reshape(g, g_shape)
                 new_grads = new_grads + (g, gradients[i * 3 + 1],
                                          gradients[i * 3 + 2])
     gradients = new_grads
     if self.weight_decay > 0:
         gradients = self.hyper_map(
             F.partial(apply_decay, self.weight_decay), self.decay_flags,
             params, gradients)
     lr = self.get_lr()
     success = self.hyper_map(
         F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients,
         params, moments)
     return success
Example #22
def _clear_grad_sum(grad_sum, zero):
    """Apply zero to clear grad_sum."""
    success = True
    success = F.depend(success, F.assign(grad_sum, zero))
    return success
Example #23
def _update_accu_grads(accu_grad, grad):
    """Accumulate the float32-cast gradient into accu_grad."""
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
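Together with the `_reset_accu_grads` helper further below, this implements gradient accumulation; a sketch of how such helpers are typically registered and applied with hyper_map (graph names, the module-level cast/zeroslike ops, and the commented call sites are assumptions):

from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

cast = P.Cast()
zeroslike = P.ZerosLike()

update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")

@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    """Add the float32-cast micro-batch gradient into the accumulator."""
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))

reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")

@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    """Zero the accumulator after the optimizer step has been taken."""
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))

# Illustrative call sites inside a gradient-accumulation training cell:
#   accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
#   mean_loss = F.depend(mean_loss, accu_succ)
#   ...
#   accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)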
Example #24
    def construct(self, gradients):
        """construct of THOR"""
        params = self.params
        moments = self.moments
        encoder_layers_num = 16
        if self.thor:
            new_grads = ()
            # process embedding layer
            for em_idx in range(3):
                g = gradients[em_idx]
                matrix_idx = em_idx
                temp_a_ori = self.matrix_A[matrix_idx]
                temp_g = self.matrix_G[matrix_idx]
                temp_a_ori = F.depend(temp_a_ori, g)
                temp_g = F.depend(temp_g, g)
                temp_a = self.expand(temp_a_ori, 1)
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.mul(temp_a, g)
                g = self.matmul(g, temp_g)
                g = self.cast(g, mstype.float32)
                fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori)
                fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                new_grads = new_grads + (g, )
            # process bert_embedding_postprocessor.layernorm
            grad_idx = 3
            beta_grad = gradients[grad_idx]
            gamma_grad = gradients[grad_idx + 1]
            normalizer = self.batch_size
            normalizer = self.cast(normalizer, mstype.float32)
            damping_step = self.gather(self.damping, self.cov_step, 0)
            damping_step = self.cast(damping_step, mstype.float32)
            self.cov_step = self.cov_step + self.one
            damping = self.sqrt(damping_step)
            beta = self.square(beta_grad)
            beta_cov = self.mul(beta, 1.0 / normalizer)
            beta_cov = beta_cov + damping
            beta_inv = self.inv(beta_cov)
            gamma = self.square(gamma_grad)
            gamma_cov = self.mul(gamma, 1.0 / normalizer)
            gamma_cov = gamma_cov + damping
            gamma_inv = self.inv(gamma_cov)
            beta = self.mul(beta_inv, beta_grad)
            gamma = self.mul(gamma_inv, gamma_grad)
            new_grads = new_grads + (beta, gamma)

            for i in range(self.num_hidden_layers):
                encoder_begin_idx = encoder_layers_num * i + 5
                for j in range(0, encoder_layers_num, 2):
                    grad_idx = encoder_begin_idx + j
                    if j in (8, 14):
                        # process layernorm layer
                        beta_grad = gradients[grad_idx]
                        gamma_grad = gradients[grad_idx + 1]
                        normalizer = self.batch_size
                        normalizer = self.cast(normalizer, mstype.float32)
                        beta = self.square(beta_grad)
                        beta_cov = self.mul(beta, 1.0 / normalizer)
                        beta_cov = beta_cov + damping
                        beta_inv = self.inv(beta_cov)
                        gamma = self.square(gamma_grad)
                        gamma_cov = self.mul(gamma, 1.0 / normalizer)
                        gamma_cov = gamma_cov + damping
                        gamma_inv = self.inv(gamma_cov)
                        beta = self.mul(beta_inv, beta_grad)
                        gamma = self.mul(gamma_inv, gamma_grad)
                        new_grads = new_grads + (beta, gamma)
                    else:
                        g = gradients[grad_idx]
                        offset_idx = 0
                        if j in (0, 2, 4, 6):
                            offset_idx = j // 2
                        elif j in (10, 12):
                            offset_idx = j // 2 - 1
                        matrix_idx = 6 * i + offset_idx + 3
                        temp_a = self.matrix_A[matrix_idx]
                        temp_g = self.matrix_G[matrix_idx]
                        temp_a = F.depend(temp_a, g)
                        temp_g = F.depend(temp_g, g)
                        temp_a = self.cast(temp_a, mstype.float16)
                        temp_g = self.cast(temp_g, mstype.float16)
                        g = self.cast(g, mstype.float16)
                        g = self.matmul(temp_g, g)
                        g = self.matmul(g, temp_a)
                        g = self.cast(g, mstype.float32)
                        fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
                        fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
                        g = F.depend(g, fake_A)
                        g = F.depend(g, fake_G)
                        new_grads = new_grads + (g, )
                        new_grads = new_grads + (gradients[grad_idx + 1], )

            # process pooler layer
            pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
            matrix_idx = self.num_hidden_layers * 6 + 3
            g = gradients[pooler_layer_idx]
            pooler_bias = gradients[pooler_layer_idx + 1]
            temp_a = self.matrix_A[matrix_idx]
            temp_g = self.matrix_G[matrix_idx]
            temp_a = F.depend(temp_a, g)
            temp_g = F.depend(temp_g, g)
            temp_a = self.cast(temp_a, mstype.float16)
            temp_g = self.cast(temp_g, mstype.float16)
            g = self.cast(g, mstype.float16)
            g = self.matmul(temp_g, g)
            g = self.matmul(g, temp_a)
            g = self.cast(g, mstype.float32)
            fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
            fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
            g = F.depend(g, fake_A)
            g = F.depend(g, fake_G)
            new_grads = new_grads + (g, pooler_bias)

            # cls1 fully connected layer for the masked language model (MLM)
            mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
            matrix_idx = self.num_hidden_layers * 6 + 4
            g = gradients[mlm_fc_idx]
            mlm_bias = gradients[mlm_fc_idx + 1]
            temp_a = self.matrix_A[matrix_idx]
            temp_g = self.matrix_G[matrix_idx]
            temp_a = F.depend(temp_a, g)
            temp_g = F.depend(temp_g, g)
            temp_a = self.cast(temp_a, mstype.float16)
            temp_g = self.cast(temp_g, mstype.float16)
            g = self.cast(g, mstype.float16)
            g = self.matmul(temp_g, g)
            g = self.matmul(g, temp_a)
            g = self.cast(g, mstype.float32)
            # add bert.cls1.output_bias grad
            fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
            fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
            g = F.depend(g, fake_A)
            g = F.depend(g, fake_G)
            new_grads = new_grads + (gradients[mlm_fc_idx - 1], )
            new_grads = new_grads + (g, mlm_bias)
            # add bert.cls1.layernorm grad
            begin_idx = mlm_fc_idx + 2
            end_idx = mlm_fc_idx + 4
            new_grads = new_grads + gradients[begin_idx:end_idx]

            length = len(gradients)
            new_grads = new_grads + gradients[length - 2:length]
            gradients = new_grads
            gradients = self.grad_reducer_g(gradients)
        else:
            new_grads = ()
            # process embedding layer
            for em_idx in range(3):
                g = gradients[em_idx]
                matrix_idx = em_idx
                temp_a = self.matrix_A[matrix_idx]
                temp_g = self.matrix_G[matrix_idx]
                temp_a = F.depend(temp_a, g)
                temp_g = F.depend(temp_g, g)
                temp_a = self.expand(temp_a, 1)
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                g = self.cast(g, mstype.float16)
                g = self.mul(temp_a, g)
                g = self.matmul(g, temp_g)
                g = self.cast(g, mstype.float32)
                new_grads = new_grads + (g, )
            # process bert_embedding_postprocessor.layernorm
            grad_idx = 3
            beta_grad = gradients[grad_idx]
            gamma_grad = gradients[grad_idx + 1]
            normalizer = self.batch_size
            normalizer = self.cast(normalizer, mstype.float32)
            damping_step = self.gather(self.damping, self.cov_step, 0)
            damping_step = self.cast(damping_step, mstype.float32)
            self.cov_step = self.cov_step + self.one
            damping = self.sqrt(damping_step)
            beta = self.square(beta_grad)
            beta_cov = self.mul(beta, 1.0 / normalizer)
            beta_cov = beta_cov + damping
            beta_inv = self.inv(beta_cov)
            gamma = self.square(gamma_grad)
            gamma_cov = self.mul(gamma, 1.0 / normalizer)
            gamma_cov = gamma_cov + damping
            gamma_inv = self.inv(gamma_cov)
            beta = self.mul(beta_inv, beta_grad)
            gamma = self.mul(gamma_inv, gamma_grad)
            new_grads = new_grads + (beta, gamma)

            for i in range(self.num_hidden_layers):
                encoder_begin_idx = encoder_layers_num * i + 5
                for j in range(0, encoder_layers_num, 2):
                    grad_idx = encoder_begin_idx + j
                    if j in (8, 14):
                        # process layernorm layer
                        beta_grad = gradients[grad_idx]
                        gamma_grad = gradients[grad_idx + 1]
                        normalizer = self.batch_size
                        normalizer = self.cast(normalizer, mstype.float32)
                        beta = self.square(beta_grad)
                        beta_cov = self.mul(beta, 1.0 / normalizer)
                        beta_cov = beta_cov + damping
                        beta_inv = self.inv(beta_cov)
                        gamma = self.square(gamma_grad)
                        gamma_cov = self.mul(gamma, 1.0 / normalizer)
                        gamma_cov = gamma_cov + damping
                        gamma_inv = self.inv(gamma_cov)
                        beta = self.mul(beta_inv, beta_grad)
                        gamma = self.mul(gamma_inv, gamma_grad)
                        new_grads = new_grads + (beta, gamma)
                    else:
                        g = gradients[grad_idx]
                        offset_idx = 0
                        if j in (0, 2, 4, 6):
                            offset_idx = j // 2
                        elif j in (10, 12):
                            offset_idx = j // 2 - 1
                        matrix_idx = 6 * i + offset_idx + 3
                        temp_a = self.matrix_A[matrix_idx]
                        temp_g = self.matrix_G[matrix_idx]
                        temp_a = F.depend(temp_a, g)
                        temp_g = F.depend(temp_g, g)
                        temp_a = self.cast(temp_a, mstype.float16)
                        temp_g = self.cast(temp_g, mstype.float16)
                        g = self.cast(g, mstype.float16)
                        g = self.matmul(temp_g, g)
                        g = self.matmul(g, temp_a)
                        g = self.cast(g, mstype.float32)
                        new_grads = new_grads + (g, )
                        new_grads = new_grads + (gradients[grad_idx + 1], )

            # process pooler layer
            pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
            matrix_idx = self.num_hidden_layers * 6 + 3
            g = gradients[pooler_layer_idx]
            pooler_bias = gradients[pooler_layer_idx + 1]
            temp_a = self.matrix_A[matrix_idx]
            temp_g = self.matrix_G[matrix_idx]
            temp_a = F.depend(temp_a, g)
            temp_g = F.depend(temp_g, g)
            temp_a = self.cast(temp_a, mstype.float16)
            temp_g = self.cast(temp_g, mstype.float16)
            g = self.cast(g, mstype.float16)
            g = self.matmul(temp_g, g)
            g = self.matmul(g, temp_a)
            g = self.cast(g, mstype.float32)
            new_grads = new_grads + (g, pooler_bias)

            # cls1 fully connected layer for the masked language model (MLM)
            mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
            matrix_idx = self.num_hidden_layers * 6 + 4
            g = gradients[mlm_fc_idx]
            mlm_bias = gradients[mlm_fc_idx + 1]
            temp_a = self.matrix_A[matrix_idx]
            temp_g = self.matrix_G[matrix_idx]
            temp_a = F.depend(temp_a, g)
            temp_g = F.depend(temp_g, g)
            temp_a = self.cast(temp_a, mstype.float16)
            temp_g = self.cast(temp_g, mstype.float16)
            g = self.cast(g, mstype.float16)
            g = self.matmul(temp_g, g)
            g = self.matmul(g, temp_a)
            g = self.cast(g, mstype.float32)
            # add bert.cls1.output_bias grad
            new_grads = new_grads + (gradients[mlm_fc_idx - 1], )
            new_grads = new_grads + (g, mlm_bias)
            # add bert.cls1.layernorm grad
            begin_idx = mlm_fc_idx + 2
            end_idx = mlm_fc_idx + 4
            new_grads = new_grads + gradients[begin_idx:end_idx]

            length = len(gradients)
            new_grads = new_grads + gradients[length - 2:length]
            gradients = new_grads
            gradients = self.grad_reducer_g(gradients)

        if self.weight_decay > 0:
            gradients = self.hyper_map(
                F.partial(apply_decay, self.weight_decay), self.decay_flags,
                params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(
            F.partial(momentum_opt, self.opt, lr, self.momentum), gradients,
            params, moments)
        return success
Example #25
def _reset_accu_grads(accu_grad):
    """Reset accu_grad to zeros."""
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
Example #26
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step,
                   param, m, v, gradient, decay_flag):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0.
        global_step (Tensor): Global step.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Specifies whether to apply weight decay during the parameter update.

    Returns:
        Tensor, the new value of v after updating.
    """
    op_mul = P.Mul()
    op_sqrt = P.Sqrt()
    op_rsqrt = P.Rsqrt()
    op_square = P.Square()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()
    op_pow = P.Pow()
    op_norm = layer.Norm()
    op_select = P.Select()
    op_greater = P.Greater()
    op_fill = P.Fill()
    op_dtype = P.DType()

    param = op_cast(param, mstype.float32)
    m = op_cast(m, mstype.float32)
    v = op_cast(v, mstype.float32)
    gradient = op_cast(gradient, mstype.float32)

    next_m = op_mul(beta1, m) + op_mul(
        op_cast(num_one, mstype.float32) - beta1, gradient)

    next_v = op_mul(beta2, v) + op_mul(
        op_cast(num_one, mstype.float32) - beta2, op_square(gradient))

    next_mm = next_m / (op_cast(num_one, mstype.float32) - op_pow(
        beta1, op_cast(global_step + num_one, mstype.float32)))
    next_vv = next_v / (op_cast(num_one, mstype.float32) - op_pow(
        beta2, op_cast(global_step + num_one, mstype.float32)))
    w_norm = op_norm(param)
    g_norm = op_norm(gradient)

    g_norm_hat = op_norm(
        op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param)
    zeros = F.zeros_like_tensor(w_norm)
    ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
    trust_ratio = op_select(
        op_greater(w_norm, zeros),
        op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), ones)
    tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
    trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
    update = next_mm / (op_sqrt(next_vv) + eps)

    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param)

    update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

    next_param = param - op_reshape(update_with_lr, op_shape(param))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))

    return next_v
Example #27
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """Apply AdamOffload optimizer to the weight parameter using Tensor."""
    success = True
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    success = F.depend(success, F.assign_add(param, delta_param))
    return success