Example #1
 def construct(self, x_in, z_rec, stage):
     # Generate images with the generator
     x_rec_list = self.generator(z_rec)
     # Compute the reconstruction loss
     g_rec = nn.MSELoss()(x_rec_list[-1], x_in)
     # calculate rmse for each scale
     rmse_list = [1.0]
     for rmseidx in range(1, stage + 1):
         rmse = mindspore.ops.Sqrt()(nn.MSELoss()(
             x_rec_list[rmseidx], self.x_in_list[rmseidx]))
         rmse_list.append(rmse)
     pad = mindspore.ops.Pad(((0, 0), (0, 0), (5, 5), (5, 5)))
     z_list = [pad(
         rmse_list[z_idx] *
         Tensor(np.random.randn(self.args.batch_size, 3,
                                self.args.size_list[z_idx],
                                self.args.size_list[z_idx]).astype(np.float32))
     ) for z_idx in range(stage + 1)]
     x_fake_list = self.generator(z_list)
     g_fake_logit = self.discriminator(x_fake_list[-1])
     ones = mindspore.ops.OnesLike()(g_fake_logit)
     g_loss = self.lossCreator(x_in, g_fake_logit, g_rec, ones)
     # Compute gradients for the backward pass
     grad = ops.GradOperation(get_by_list=True)(self.lossCell, self.weight)
     grads_g = grad(x_in, g_fake_logit, g_rec, ones)
     return g_loss, rmse_list, z_list, ops.depend(g_loss, self.opts(grads_g))
Example #2
 def __init__(self, teacher_config, teacher_ckpt, student_config, is_training, use_one_hot_embeddings=False,
              is_att_fit=True, is_rep_fit=True):
     super(BertNetworkWithLoss_gd, self).__init__()
     # load teacher model
     self.teacher = BertModel(teacher_config, False, use_one_hot_embeddings)
     param_dict = load_checkpoint(teacher_ckpt)
     new_param_dict = {}
     for key, value in param_dict.items():
         new_key = re.sub('^bert.bert.', 'teacher.', key)
         new_param_dict[new_key] = value
     load_param_into_net(self.teacher, new_param_dict)
     # no_grad
     self.teacher.set_train(False)
     params = self.teacher.trainable_params()
     for param in params:
         param.requires_grad = False
     # student model
     self.bert = TinyBertModel(student_config, is_training, use_one_hot_embeddings)
     self.cast = P.Cast()
     self.fit_dense = nn.Dense(student_config.hidden_size,
                               teacher_config.hidden_size).to_float(teacher_config.compute_type)
     self.teacher_layers_num = teacher_config.num_hidden_layers
     self.student_layers_num = student_config.num_hidden_layers
     self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num)
     self.is_att_fit = is_att_fit
     self.is_rep_fit = is_rep_fit
     self.loss_mse = nn.MSELoss()
     self.select = P.Select()
     self.zeroslike = P.ZerosLike()
     self.dtype = teacher_config.dtype
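
A note on the remapping above: re.sub rewrites each checkpoint parameter name so it resolves inside the teacher submodule. A tiny standalone illustration of the substitution (the example key is hypothetical):

import re
key = 'bert.bert.encoder.layers.0.attention.query.weight'
print(re.sub('^bert.bert.', 'teacher.', key))
# prints: teacher.encoder.layers.0.attention.query.weight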
Example #3
 def __init__(self, net_config):
     super(CenterNetMultiPoseLossCell, self).__init__()
     self.network = GatherMultiPoseFeatureCell(net_config)
     self.reduce_sum = ops.ReduceSum()
     self.crit = FocalLoss()
     self.crit_hm_hp = nn.MSELoss() if net_config.mse_loss else self.crit
     self.crit_kp = RegWeightedL1Loss(
     ) if not net_config.dense_hp else nn.L1Loss(reduction='sum')
     self.crit_reg = RegLoss(net_config.reg_loss)
     self.hm_weight = net_config.hm_weight
     self.hm_hp_weight = net_config.hm_hp_weight
     self.hp_weight = net_config.hp_weight
     self.wh_weight = net_config.wh_weight
     self.off_weight = net_config.off_weight
     self.hm_hp = net_config.hm_hp
     self.dense_hp = net_config.dense_hp
     self.reg_offset = net_config.reg_offset
     self.reg_hp_offset = net_config.reg_hp_offset
     self.hm_hp_ind = 3 if self.hm_hp else 2
     self.reg_ind = self.hm_hp_ind + 1 if self.reg_offset else self.hm_hp_ind
     self.reg_hp_ind = self.reg_ind + 1 if self.reg_hp_offset else self.reg_ind
     # used only for debug checks
     self.print = ops.Print()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
Example #4
 def __init__(self, use_target_weight):
     super(JointsMSELoss, self).__init__()
     self.criterion = nn.MSELoss(reduction='mean')
     self.use_target_weight = use_target_weight
     self.reshape = P.Reshape()
     self.squeeze = P.Squeeze(1)
     self.mul = P.Mul()
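
The snippet above only declares the primitives; a minimal construct sketch in the usual joints-MSE style (the signature and the per-joint loop are assumptions, not part of the original snippet):

 def construct(self, output, target, target_weight):
     # flatten heatmaps to (batch, joints, H*W) and average the per-joint MSE,
     # optionally weighting each joint by target_weight
     batch_size = output.shape[0]
     num_joints = output.shape[1]
     pred = self.reshape(output, (batch_size, num_joints, -1))
     gt = self.reshape(target, (batch_size, num_joints, -1))
     loss = 0.0
     for idx in range(num_joints):
         if self.use_target_weight:
             loss += self.criterion(self.mul(pred[:, idx], target_weight[:, idx]),
                                    self.mul(gt[:, idx], target_weight[:, idx]))
         else:
             loss += self.criterion(pred[:, idx], gt[:, idx])
     return loss / num_joints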
Example #5
 def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
     super(BertReg, self).__init__()
     self.bert = BertRegressionModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
     self.loss = nn.MSELoss()
     self.is_training = is_training
     self.sigmoid = P.Sigmoid()
     self.cast = P.Cast()
     self.mul = P.Mul()
Example #6
def test_amp_o0_loss():
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    net = NetNoLoss(16, 16)
    loss = nn.MSELoss()
    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_network = amp.build_train_network(net, optimizer, loss)
    output = train_network(inputs, label)
Example #7
    def __init__(self, teacher_config, teacher_ckpt, student_config, student_ckpt,
                 is_training, task_type, num_labels, use_one_hot_embeddings=False,
                 is_predistill=True, is_att_fit=True, is_rep_fit=True,
                 temperature=1.0, dropout_prob=0.1):
        super(BertNetworkWithLoss_td, self).__init__()
        # load teacher model
        self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
                                    use_one_hot_embeddings, "teacher")
        param_dict = load_checkpoint(teacher_ckpt)
        new_param_dict = {}
        for key, value in param_dict.items():
            new_key = re.sub('^bert.', 'teacher.', key)
            new_param_dict[new_key] = value
        load_param_into_net(self.teacher, new_param_dict)

        # no_grad
        self.teacher.set_train(False)
        params = self.teacher.trainable_params()
        for param in params:
            param.requires_grad = False
        # load student model
        self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
                                 use_one_hot_embeddings, "student")
        param_dict = load_checkpoint(student_ckpt)
        if is_predistill:
            new_param_dict = {}
            for key, value in param_dict.items():
                new_key = re.sub('tinybert_', 'bert_', 'bert.' + key)
                new_param_dict[new_key] = value
            load_param_into_net(self.bert, new_param_dict)
        else:
            new_param_dict = {}
            for key, value in param_dict.items():
                new_key = re.sub('tinybert_', 'bert_', key)
                new_param_dict[new_key] = value
            load_param_into_net(self.bert, new_param_dict)
        self.cast = P.Cast()
        self.fit_dense = nn.Dense(student_config.hidden_size,
                                  teacher_config.hidden_size).to_float(teacher_config.compute_type)
        self.teacher_layers_num = teacher_config.num_hidden_layers
        self.student_layers_num = student_config.num_hidden_layers
        self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num)
        self.is_predistill = is_predistill
        self.is_att_fit = is_att_fit
        self.is_rep_fit = is_rep_fit
        self.task_type = task_type
        self.temperature = temperature
        self.loss_mse = nn.MSELoss()
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.num_labels = num_labels
        self.dtype = teacher_config.dtype
        self.soft_cross_entropy = SoftCrossEntropy()
Example #8
    def __init__(self):
        super(CriterionsFaceQA, self).__init__()
        self.gatherv2 = P.GatherV2()
        self.squeeze = P.Squeeze(axis=1)
        self.shape = P.Shape()
        self.reshape = P.Reshape()

        self.euler_label_list = Tensor([0, 1, 2], dtype=mstype.int32)
        self.mse_loss = nn.MSELoss(reduction='sum')

        self.kp_label_list = Tensor([3, 4, 5, 6, 7], dtype=mstype.int32)
        self.kps_loss = CEWithIgnoreIndex3D()
Example #9
 def __init__(self, mode="lsgan", reduction='mean'):
     super(GANLoss, self).__init__()
     self.loss = None
     self.ones = ops.OnesLike()
     if mode == "lsgan":
         self.loss = nn.MSELoss(reduction)
     elif mode == "vanilla":
         self.loss = BCEWithLogits(reduction)
     else:
         raise NotImplementedError(
             f'GANLoss mode {mode} is not recognized; supported modes are lsgan and vanilla.'
         )
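
Since the snippet stops at __init__, here is a hedged construct sketch consistent with the fields declared above (the signature is an assumption):

 def construct(self, prediction, target_is_real):
     # compare the discriminator output against an all-ones (real) or
     # all-zeros (fake) target of the same shape
     if target_is_real:
         target = self.ones(prediction)
     else:
         target = self.ones(prediction) * 0
     return self.loss(prediction, target)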
Example #10
 def __init__(self, **kwargs):
     for key, value in kwargs.items():
         setattr(self, key, value)
     self.policy_net = DQN(self.state_space_dim, 256, self.action_space_dim)
     self.target_net = DQN(self.state_space_dim, 256, self.action_space_dim)
     self.optimizer = nn.RMSProp(self.policy_net.trainable_params(),
                                 learning_rate=self.lr)
     loss_fn = nn.MSELoss()
     loss_q_net = WithLossCell(self.policy_net, loss_fn)
     self.policy_net_train = nn.TrainOneStepCell(loss_q_net, self.optimizer)
     self.policy_net_train.set_train(mode=True)
     self.buffer = []
     self.steps = 0
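
With the loss and optimizer wrapped into nn.TrainOneStepCell, one optimisation step reduces to a single call. A usage sketch (state_batch and q_target are hypothetical Tensors of shape [batch, state_space_dim] and [batch, action_space_dim]):

 # inside a hypothetical learn() step:
 loss = self.policy_net_train(state_batch, q_target)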
Example #11
def test_compile_model_train_O2():
    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))

    dataset = MindDataSet(dataset_types, dataset_shapes)

    net = NetNoLoss(16, 16)
    loss = nn.MSELoss()
    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
    model.train(2, dataset, dataset_sink_mode=False)
    with pytest.raises(ValueError):
        # not an actual run: the metrics step is expected to fail; this only checks that compilation succeeds.
        model.eval(dataset)
Example #12
    def construct(self, x_in):
        # gradients w.r.t. x_in for the gradient penalty are computed inside
        # compute_grad_gp; no PyTorch-style requires_grad flag is needed here
        x_fake_list = self.G(self.z_list)

        d_fake_logit = self.D(x_fake_list[-1])
        d_real_logit = self.D(x_in)

        ones = mindspore.ops.OnesLike()(d_real_logit)
        zeros = mindspore.ops.ZerosLike()(d_fake_logit)

        # Choose the loss formulation according to the selected GAN type
        if self.args.gantype == 'wgangp':
            # wgan gp
            d_fake = mindspore.ops.ReduceMean()(d_fake_logit, (2, 3))
            d_real = -mindspore.ops.ReduceMean()(d_real_logit, (2, 3))
            d_gp = compute_grad_gp_wgan(
                self.D, x_in, x_fake_list[-1], self.args.gpu)
            d_loss = d_real + d_fake + 0.1 * d_gp
        elif self.args.gantype == 'zerogp':
            # zero centered GP
            d_fake = mindspore.ops.BinaryCrossEntropy(
                reduction='none')(d_fake_logit, zeros, None).mean()
            d_real = mindspore.ops.BinaryCrossEntropy(
                reduction='none')(d_real_logit, ones, None).mean()
            d_gp = compute_grad_gp(
                d_real_logit.mean((2, 3)), x_in)
            d_loss = d_real + d_fake + 10.0 * d_gp
        elif self.args.gantype == 'lsgan':
            # lsgan
            d_fake = nn.MSELoss()(mindspore.ops.ReduceMean()(
                d_fake_logit, (2, 3)), zeros)
            d_real = nn.MSELoss()(mindspore.ops.ReduceMean()(
                d_real_logit, (2, 3)), 0.9 * ones)
            d_loss = d_real + d_fake

        return d_loss, None, None
Example #13
 def __init__(self, net_config):
     super(CenterNetLossCell, self).__init__()
     self.network = GatherDetectionFeatureCell(net_config)
     self.net_config = net_config
     self.reduce_sum = ops.ReduceSum()
     self.Sigmoid = Sigmoid()
     self.FocalLoss = FocalLoss()
     self.crit = nn.MSELoss() if net_config.mse_loss else self.FocalLoss
     self.crit_reg = RegLoss(net_config.reg_loss)
     self.crit_wh = RegLoss(net_config.reg_loss)
     self.num_stacks = net_config.num_stacks
     self.wh_weight = net_config.wh_weight
     self.hm_weight = net_config.hm_weight
     self.off_weight = net_config.off_weight
     self.reg_offset = net_config.reg_offset
     self.not_enable_mse_loss = not net_config.mse_loss
     self.Print = ops.Print()
Example #14
 def construct(self, x_in, g_fake_logit, g_rec, ones):
     # Choose the loss formulation according to the selected GAN type
     if self.args.gantype == 'wgangp':
         # wgan gp
         g_fake = -mindspore.ops.ReduceMean()(g_fake_logit, (2, 3))
         g_loss = g_fake + 10.0 * g_rec
     elif self.args.gantype == 'zerogp':
         # zero centered GP
         g_fake = mindspore.ops.BinaryCrossEntropy(
             reduction='none')(g_fake_logit, ones, None).mean()
         g_loss = g_fake + 100.0 * g_rec
     elif self.args.gantype == 'lsgan':
         # lsgan
         g_fake = nn.MSELoss()(mindspore.ops.ReduceMean()(
             g_fake_logit, (2, 3)), 0.9 * ones)
         g_loss = g_fake + 50.0 * g_rec
     return g_loss
Example #15
    def __init__(self, config: Dict, model_path: str = None) -> None:
        super().__init__(config)

        self.num_labels = self.config.num_labels
        self.bert = MSBertModel(self.config)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(
            self.config.hidden_size,
            self.num_labels,
            weight_init=TruncatedNormal(self.config.initializer_range),
        )

        if self.num_labels == 1:
            self.loss_fct = nn.MSELoss()
        else:
            self.loss_fct = nn.SoftmaxCrossEntropyWithLogits(
                sparse=True, reduction="mean")

        if model_path is not None:
            self._load_weights(model_path)
Example #16
def test_compile_model_train_O2_parallel():
    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))
    context.set_auto_parallel_context(global_rank=0,
                                      device_num=8,
                                      mirror_mean=True,
                                      parameter_broadcast=True,
                                      parallel_mode=ParallelMode.DATA_PARALLEL)

    dataset = MindDataSet(dataset_types, dataset_shapes)

    net = NetNoLoss(16, 16)
    loss = nn.MSELoss()
    optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)

    init()

    model = Model(net,
                  loss_fn=loss,
                  optimizer=optimizer,
                  metrics={"acc"},
                  amp_level="O2")
    model.train(2, dataset, dataset_sink_mode=False)
Example #17
import logging

import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model

log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv(x)
        x = self.relu(x)
        out = self.flatten(x)
        return out


loss = nn.MSELoss()


def test_build():
    input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
    input_label = Tensor(np.random.randint(0, 10, [1, 10]))
    net = Net()
    opt = Momentum(net.get_parameters(), learning_rate=0.1, momentum=0.9)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics=None)
Example #18
def test_MSELoss():
    loss = nn.MSELoss()
    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 2]]).astype(np.float32))
    target_data = Tensor(np.array([[0, 0, 5], [1, 2, 3]]).astype(np.float32))
    loss(input_data, target_data)
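
The expected value is easy to check by hand: the element-wise differences are [[1, 2, -2], [1, 1, -1]], so with the default reduction='mean' the loss is (1 + 4 + 4 + 1 + 1 + 1) / 6 = 2.0. A variant of the test with an explicit assertion (a sketch, not part of the original suite):

def test_MSELoss_value():
    loss = nn.MSELoss()
    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 2]]).astype(np.float32))
    target_data = Tensor(np.array([[0, 0, 5], [1, 2, 3]]).astype(np.float32))
    output = loss(input_data, target_data)
    # mean of squared differences: (1 + 4 + 4 + 1 + 1 + 1) / 6 == 2.0
    assert abs(float(output.asnumpy()) - 2.0) < 1e-6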
Example #19
 def __init__(self, in_features, out_features):
     super(Net, self).__init__()
     self.dense = nn.Dense(in_features, out_features)
     self.loss = nn.MSELoss()
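
The class above pairs a Dense layer with its loss but omits construct; a minimal forward-pass sketch (an assumption consistent with the fields declared in __init__):

 def construct(self, x, label):
     output = self.dense(x)
     return self.loss(output, label)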
Example #20
 def __init__(self):
     super(VaeGanLoss, self).__init__()
     self.zeros = P.ZerosLike()
     self.mse = nn.MSELoss(reduction='sum')
     self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')
Example #21
    def __init__(self,
                 teacher_config,
                 teacher_ckpt,
                 student_config,
                 student_ckpt,
                 is_training,
                 task_type,
                 num_labels,
                 use_one_hot_embeddings=False,
                 temperature=1.0,
                 dropout_prob=0.1):
        super(BertNetworkWithLoss, self).__init__()
        # load teacher model
        self.teacher = BertModelCLS(teacher_config, False, num_labels,
                                    dropout_prob, use_one_hot_embeddings,
                                    "teacher")
        param_dict = load_checkpoint(teacher_ckpt)
        new_param_dict = {}
        for key, value in param_dict.items():
            new_key = 'teacher.' + key
            new_param_dict[new_key] = value
        load_param_into_net(self.teacher, new_param_dict)

        # no_grad
        self.teacher.set_train(False)
        params = self.teacher.trainable_params()
        for param in params:
            param.requires_grad = False
        # load student model
        self.bert = BertModelCLS(student_config, is_training, num_labels,
                                 dropout_prob, use_one_hot_embeddings,
                                 "student")
        param_dict = load_checkpoint(student_ckpt)
        new_param_dict = {}
        for key, value in param_dict.items():
            new_key = 'bert.' + key
            new_param_dict[new_key] = value
        load_param_into_net(self.bert, new_param_dict)
        self.cast = P.Cast()
        self.teacher_layers_num = teacher_config.num_hidden_layers
        self.student_layers_num = student_config.num_hidden_layers
        self.layers_per_block = int(self.teacher_layers_num /
                                    self.student_layers_num)
        self.is_att_fit = student_config.is_att_fit
        self.is_rep_fit = student_config.is_rep_fit
        self.is_lgt_fit = student_config.is_lgt_fit
        self.task_type = task_type
        self.temperature = temperature
        self.loss_mse = nn.MSELoss()
        self.lgt_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True,
                                                        reduction='mean')
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.dtype = student_config.dtype
        self.num_labels = num_labels
        self.soft_cross_entropy = SoftmaxCrossEntropy()
        self.compute_type = student_config.compute_type
        self.embedding_bits = student_config.embedding_bits
        self.weight_bits = student_config.weight_bits
        self.weight_clip_value = student_config.weight_clip_value
        self.reshape = P.Reshape()
Example #22
    parser.add_argument('--device_target',
                        type=str,
                        default='GPU',
                        choices=("GPU"),
                        help="Device target, support GPU.")
    args, _ = parser.parse_known_args()

    if args.device_target == "GPU":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target,
                            save_graphs=False)
    else:
        raise ValueError("Unsupported device target.")

    eval_ds = create_eval_dataset(args.dataset_path)

    net = SRCNN()
    lr = Tensor(config.lr, ms.float32)
    opt = nn.Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)
    loss = nn.MSELoss(reduction='mean')
    param_dict = load_checkpoint(args.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    model = Model(net,
                  loss_fn=loss,
                  optimizer=opt,
                  metrics={'PSNR': SRCNNpsnr()})

    res = model.eval(eval_ds, dataset_sink_mode=False)
    print("result ", res)
Example #23
"""lenet_train_export."""

import sys
import numpy as np
from train_utils import save_inout, train_wrap
from official.cv.lenet.src.lenet import LeNet5
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export

context.set_context(mode=context.PYNATIVE_MODE,
                    device_target="GPU",
                    save_graphs=False)

n = LeNet5()
loss_fn = nn.MSELoss()
optimizer = nn.Adam(n.trainable_params(),
                    learning_rate=1e-2,
                    beta1=0.5,
                    beta2=0.7,
                    eps=1e-2,
                    use_locking=True,
                    use_nesterov=False,
                    weight_decay=0.0,
                    loss_scale=0.3)
net = train_wrap(n, loss_fn, optimizer)

x = Tensor(np.random.randn(32, 1, 32, 32), mstype.float32)
label = Tensor(np.zeros([32, 10]).astype(np.float32))
export(net, x, label, file_name="mindir/lenet_train", file_format='MINDIR')
Example #24
 def __init__(self):
     super(VaeGanLoss, self).__init__()
     self.zeros = P.ZerosLike()
     self.mse = nn.MSELoss(reduction='sum')
Example #25
def trainSinGAN(data_loader, networks, opts, stage, args, additional):
    # avg meter
    d_losses = AverageMeter()
    g_losses = AverageMeter()

    # set nets
    D = networks[0]
    G = networks[1]
    # set opts
    d_opt = opts['d_opt']
    g_opt = opts['g_opt']
    # switch to train mode
    D.train()
    G.train()
    # summary writer
    # writer = additional[0]
    train_it = iter(data_loader)
    # total_iter = 2000 * (args.num_scale - stage + 1)
    # decay_lr = 1600 * (args.num_scale - stage + 1)
    total_iter = 2000
    decay_lr = 1600

    d_iter = 3
    g_iter = 3

    t_train = trange(0, total_iter, initial=0, total=total_iter)

    z_rec = additional['z_rec']

    x_in = next(train_it)
    x_org = x_in
    x_in = F.interpolate(x_in, (args.size_list[stage], args.size_list[stage]),
                         mode='bilinear',
                         align_corners=True)
    vutils.save_image(x_in.detach(),
                      os.path.join(args.res_dir,
                                   'ORGTRAIN_{}.png'.format(stage)),
                      nrow=1,
                      normalize=True)

    x_in_list = [x_in]
    for xidx in range(1, stage + 1):
        x_tmp = F.interpolate(x_org,
                              (args.size_list[xidx], args.size_list[xidx]),
                              mode='bilinear',
                              align_corners=True)
        x_in_list.append(x_tmp)

    for i in t_train:
        if i == decay_lr:
            for param_group in d_opt.param_groups:
                param_group['lr'] *= 0.1
                print("DISCRIMINATOR LEARNING RATE UPDATE TO :",
                      param_group['lr'])
            for param_group in g_opt.param_groups:
                param_group['lr'] *= 0.1
                print("GENERATOR LEARNING RATE UPDATE TO :", param_group['lr'])

        for _ in range(g_iter):
            # MSP: not sure whether this is actually needed
            # g_opt.zero_grad()

            x_rec_list = G(z_rec)

            g_rec = nn.MSELoss()(x_rec_list[-1], x_in)
            # calculate rmse for each scale
            rmse_list = [1.0]
            for rmseidx in range(1, stage + 1):
                rmse = mindspore.ops.Sqrt()(nn.MSELoss()(x_rec_list[rmseidx],
                                                         x_in_list[rmseidx]))
                rmse_list.append(rmse)

            z_list = [
                mindspore.ops.Pad(((0, 0), (0, 0), (5, 5), (5, 5)))(
                    rmse_list[z_idx] * mindspore.ops.StandardNormal()(
                        (args.batch_size, 3, args.size_list[z_idx],
                         args.size_list[z_idx])))
                for z_idx in range(stage + 1)
            ]

            x_fake_list = G(z_list)

            g_fake_logit = D(x_fake_list[-1])

            ones = mindspore.ops.OnesLike()(g_fake_logit)

            if args.gantype == 'wgangp':
                # wgan gp
                g_fake = -mindspore.ops.ReduceMean()(g_fake_logit, (2, 3))
                g_loss = g_fake + 10.0 * g_rec
            elif args.gantype == 'zerogp':
                # zero centered GP
                g_fake = mindspore.ops.BinaryCrossEntropy(
                    reduction='none')(g_fake_logit, ones, None).mean()
                g_loss = g_fake + 100.0 * g_rec

            elif args.gantype == 'lsgan':
                # lsgan
                g_fake = nn.MSELoss()(mindspore.ops.ReduceMean()(
                    g_fake_logit, (2, 3)), 0.9 * ones)
                g_loss = g_fake + 50.0 * g_rec

            g_loss.backward()
            g_opt.step()

            g_losses.update(g_loss.item(), x_in.size(0))

        # Update discriminator
        for _ in range(d_iter):
            x_in.requires_grad = True

            d_opt.zero_grad()
            x_fake_list = G(z_list)

            d_fake_logit = D(x_fake_list[-1].detach())
            d_real_logit = D(x_in)

            ones = mindspore.ops.OnesLike()(d_real_logit)
            zeros = mindspore.ops.ZerosLike()(d_fake_logit)

            if args.gantype == 'wgangp':
                # wgan gp
                d_fake = mindspore.ops.ReduceMean()(d_fake_logit, (2, 3))
                d_real = -mindspore.ops.ReduceMean()(d_real_logit, (2, 3))
                d_gp = compute_grad_gp_wgan(D, x_in, x_fake_list[-1], args.gpu)
                d_loss = d_real + d_fake + 0.1 * d_gp
            elif args.gantype == 'zerogp':
                # zero centered GP
                # d_fake = F.binary_cross_entropy_with_logits(torch.mean(d_fake_logit, (2, 3)), zeros)
                d_fake = mindspore.ops.BinaryCrossEntropy(
                    reduction='none')(d_fake_logit, zeros, None).mean()
                # d_real = F.binary_cross_entropy_with_logits(torch.mean(d_real_logit, (2, 3)), ones)
                d_real = mindspore.ops.BinaryCrossEntropy(
                    reduction='none')(d_real_logit, ones, None).mean()
                d_gp = compute_grad_gp(
                    mindspore.ops.ReduceMean()(d_real_logit, (2, 3)), x_in)
                d_loss = d_real + d_fake + 10.0 * d_gp

            elif args.gantype == 'lsgan':
                # lsgan
                d_fake = nn.MSELoss()(mindspore.ops.ReduceMean()(
                    d_fake_logit, (2, 3)), zeros)
                d_real = nn.MSELoss()(mindspore.ops.ReduceMean()(
                    d_real_logit, (2, 3)), 0.9 * ones)
                d_loss = d_real + d_fake

            d_loss.backward()
            d_opt.step()

            d_losses.update(d_loss.item(), x_in.size(0))

        t_train.set_description(
            'Stage: [{}/{}] Avg Loss: D[{d_losses.avg:.3f}] G[{g_losses.avg:.3f}] RMSE[{rmse:.3f}]'
            .format(stage,
                    args.num_scale,
                    d_losses=d_losses,
                    g_losses=g_losses,
                    rmse=rmse_list[-1]))