# Example 1 (score: 0)
 def __init__(self, reduction="mean"):
     """Generator loss: adversarial sigmoid-CE term plus a weighted L1 term."""
     super(G_Loss, self).__init__(reduction)
     # Loss-term weights; NOTE(review): `args` is read from an enclosing /
     # module scope, not a parameter — confirm it is defined at import time.
     self.LAMBDA_GAN = args.LAMBDA_GAN
     self.LAMBDA_L1 = args.LAMBDA_L1
     # Adversarial criterion on raw discriminator logits.
     self.sig = SigmoidCrossEntropyWithLogits()
     # Pixel-wise reconstruction criterion.
     self.l1_loss = nn.L1Loss()
     # Used to build all-ones "real" target tensors.
     self.ones = ops.OnesLike()
# Example 2 (score: 0)
 def __init__(self, net_config):
     """Multi-pose CenterNet loss cell.

     Builds the heatmap / keypoint / regression criteria and reads the
     per-term weights and head-layout flags from `net_config`.
     """
     super(CenterNetMultiPoseLossCell, self).__init__()
     self.network = GatherMultiPoseFeatureCell(net_config)
     self.reduce_sum = ops.ReduceSum()
     # Main heatmap criterion; the hm_hp head either reuses it or uses MSE.
     self.crit = FocalLoss()
     if net_config.mse_loss:
         self.crit_hm_hp = nn.MSELoss()
     else:
         self.crit_hm_hp = self.crit
     # Keypoint criterion switches on the dense_hp flag.
     if net_config.dense_hp:
         self.crit_kp = nn.L1Loss(reduction='sum')
     else:
         self.crit_kp = RegWeightedL1Loss()
     self.crit_reg = RegLoss(net_config.reg_loss)
     # Per-term loss weights.
     self.hm_weight = net_config.hm_weight
     self.hm_hp_weight = net_config.hm_hp_weight
     self.hp_weight = net_config.hp_weight
     self.wh_weight = net_config.wh_weight
     self.off_weight = net_config.off_weight
     # Head-layout flags.
     self.hm_hp = net_config.hm_hp
     self.dense_hp = net_config.dense_hp
     self.reg_offset = net_config.reg_offset
     self.reg_hp_offset = net_config.reg_hp_offset
     # Indices of the optional heads within the gathered feature tuple.
     self.hm_hp_ind = 3 if self.hm_hp else 2
     if self.reg_offset:
         self.reg_ind = self.hm_hp_ind + 1
     else:
         self.reg_ind = self.hm_hp_ind
     if self.reg_hp_offset:
         self.reg_hp_ind = self.reg_ind + 1
     else:
         self.reg_hp_ind = self.reg_ind
     # Debug-only helpers.
     self.print = ops.Print()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
# Example 3 (score: 0)
 def __init__(self, net, contrast_w=0, neg_num=0):
     """Wrap `net` with L1 and contrastive losses for CSD-style training.

     contrast_w: weight of the contrastive term (0 disables its influence).
     neg_num: number of negative samples used by the contrastive loss.
     """
     super(NetWithCSDLossCell, self).__init__()
     self.net = net
     self.contrast_w = contrast_w
     self.neg_num = neg_num
     self.l1_loss = nn.L1Loss()
     self.contrast_loss = ContrastLoss()
# Example 4 (score: 0)
 def __init__(self, args, D_A, D_B):
     """Discriminator-side loss holding both discriminators and criteria."""
     super(DiscriminatorLoss, self).__init__()
     # Constant boolean label tensors reused as real/fake targets.
     self.false = Tensor(False, mstype.bool_)
     self.true = Tensor(True, mstype.bool_)
     self.D_A = D_A
     self.D_B = D_B
     # Adversarial and reconstruction criteria.
     self.dis_loss = GANLoss(args.gan_mode)
     self.rec_loss = nn.L1Loss("mean")
# Example 5 (score: 0)
 def __init__(self):
     """Summed-L1 loss cell with shape/cast helpers for masked regression."""
     super(SmoothL1LossNewCMask, self).__init__()
     # Summed L1; nn.SmoothL1Loss() would be a drop-in alternative here.
     self.smooth_l1_loss = nn.L1Loss(reduction='sum')
     self.transpose = P.Transpose()
     self.shape = P.Shape()
     self.expand_dims = P.ExpandDims()
     self.sum = P.ReduceSum()
     self.cast = P.Cast()
# Example 6 (score: 0)
 def __init__(self, mode='l1'):
     """Regression loss wrapper; `mode` picks 'l1' (summed L1) or 'sl1'
     (SmoothL1). Any other mode leaves the criterion unset (None)."""
     super(RegLoss, self).__init__()
     self.reduce_sum = ops.ReduceSum()
     self.cast = ops.Cast()
     self.expand_dims = ops.ExpandDims()
     self.reshape = ops.Reshape()
     self.gather_feature = TransposeGatherFeature()
     # Default: unknown modes keep self.loss as None (fails at use time).
     self.loss = None
     if mode == 'l1':
         self.loss = nn.L1Loss(reduction='sum')
     elif mode == 'sl1':
         self.loss = nn.SmoothL1Loss()
# Example 7 (score: 0)
 def __init__(self, args, loader, my_model):
     """Set up training state: model, criteria, optimizer and one-step cell.

     NOTE(review): no super().__init__() call — presumably this class does
     not inherit from nn.Cell; confirm against the class declaration.
     """
     self.args = args
     self.scale = args.scale
     self.trainloader = loader
     self.model = my_model
     # Switch the network to training mode before wrapping it.
     self.model.set_train()
     # Pixel-wise criterion plus supervised-contrastive auxiliary loss.
     self.criterion = nn.L1Loss()
     self.con_loss = SupConLoss()
     # Fixed loss scale of 1024.0, mirrored in the one-step cell below.
     self.optimizer = nn.Adam(self.model.trainable_params(),
                              learning_rate=args.lr,
                              loss_scale=1024.0)
     # use_con toggles the contrastive term inside the training graph.
     self.train_net = MyTrain(self.model,
                              self.criterion,
                              self.con_loss,
                              use_con=args.con_loss)
     self.bp = MyTrainOneStepCell(self.train_net, self.optimizer, 1024.0)
# Example 8 (score: 0)
 def __init__(self, args, generator, D_A, D_B):
     """Generator-side CycleGAN loss.

     Combines adversarial (GANLoss) and L1 reconstruction terms; when
     args.kd is set, additionally loads two teacher generators from
     checkpoints — presumably for knowledge distillation (name `kd`).
     """
     super(GeneratorLoss, self).__init__()
     # Cycle / identity weighting factors.
     self.lambda_A = args.lambda_A
     self.lambda_B = args.lambda_B
     self.lambda_idt = args.lambda_idt
     # Identity term is active only for a strictly positive weight.
     self.use_identity = args.lambda_idt > 0
     self.dis_loss = GANLoss(args.gan_mode)
     self.rec_loss = nn.L1Loss("mean")
     self.generator = generator
     self.D_A = D_A
     self.D_B = D_B
     # Constant "real" label tensor for the adversarial target.
     self.true = Tensor(True, mstype.bool_)
     self.kd = args.kd
     if self.kd:
         # Load teacher generators and remap their parameter prefixes.
         self.GT_A = get_generator(args, True)
         load_teacher_ckpt(self.GT_A, args.GT_A_ckpt, "GT_A", "G_A")
         self.GT_B = get_generator(args, True)
         load_teacher_ckpt(self.GT_B, args.GT_B_ckpt, "GT_B", "G_B")
         # NOTE(review): teachers are put in train mode — confirm whether
         # eval mode was intended for fixed teacher networks.
         self.GT_A.set_train(True)
         self.GT_B.set_train(True)
# Example 9 (score: 0)
def test_L1Loss():
    """L1Loss smoke test: verify the mean absolute error on a 2x3 batch.

    Defect fixed: the original computed the loss but asserted nothing,
    so the test could never fail on a wrong result.
    """
    loss = nn.L1Loss()
    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
    output = loss(input_data, target_data)
    # mean(|input - target|) = (1 + 0 + 2 + 1 + 2 + 3) / 6 = 1.5
    assert np.allclose(output.asnumpy(), 1.5)
# Example 10 (score: 0)
 def __init__(self):
     """Weighted summed-L1 loss over features gathered at target indices."""
     super(RegWeightedL1Loss, self).__init__()
     self.cast = ops.Cast()
     self.reduce_sum = ops.ReduceSum()
     self.gather_feature = TransposeGatherFeature()
     self.l1_loss = nn.L1Loss(reduction='sum')
# Example 11 (score: 0)
def test_L1Loss():
    """Calling construct() directly on the loss must raise NotImplementedError."""
    net = nn.L1Loss()
    x = Tensor(np.array([1, 2, 3]))
    y = Tensor(np.array([1, 2, 2]))
    with pytest.raises(NotImplementedError):
        net.construct(x, y)
# Example 12 (score: 0)
 def __init__(self, net):
     """Pair `net` with a mean-reduced L1 criterion."""
     super(NetWithLossCell, self).__init__()
     self.l1_loss = nn.L1Loss()
     self.net = net
# Example 13 (score: 0)
 def __init__(self):
     """Perceptual contrast loss: VGG19 features compared with L1."""
     super(ContrastLoss, self).__init__()
     self.l1 = nn.L1Loss()
     self.vgg = Vgg19()
     # Per-stage weights for the five VGG feature levels, shallow to deep.
     self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
# Example 14 (score: 0)
def train():
    """Train EDSR on DIV2K with periodic evaluation and checkpointing.

    Reads configuration from the module-level `args` and the launcher's
    DEVICE_ID / RANK_ID / RANK_SIZE environment variables for optional
    data-parallel training.
    """
    set_seed(1)
    # Device/rank topology comes from the distributed launcher environment.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    rank_id = int(os.getenv('RANK_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    # context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False, device_id=device_id)
    context.set_context(mode=context.PYNATIVE_MODE,
                        device_target="GPU",
                        save_graphs=False,
                        device_id=device_id)

    # Multi-device run: init collective comms and data-parallel context.
    if device_num > 1:
        init()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            device_num=device_num,
            global_rank=device_id,
            gradients_mean=True)
    # ModelArts cloud mode: stage the dataset into the local cache first.
    if args.modelArts_mode:
        import moxing as mox
        local_data_url = '/cache/data'
        mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)

    train_dataset = DIV2K(args,
                          name=args.data_train,
                          train=True,
                          benchmark=False)
    train_dataset.set_scale(args.task_id)
    print(len(train_dataset))
    # Shard the training data across devices; shuffle per epoch.
    train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"],
                                           num_shards=device_num,
                                           shard_id=rank_id,
                                           shuffle=True)
    train_de_dataset = train_de_dataset.batch(args.batch_size,
                                              drop_remainder=True)

    # Benchmark evaluation set: fixed order, batch size 1.
    eval_dataset = SRData(args,
                          name=args.data_test,
                          train=False,
                          benchmark=True)
    print(len(eval_dataset))
    eval_ds = ds.GeneratorDataset(eval_dataset, ['LR', 'HR'], shuffle=False)
    eval_ds = eval_ds.batch(1, drop_remainder=True)

    # net_m = RCAN(args)
    net_m = EDSR(args)
    print("Init net weights successfully")

    # NOTE(review): the guard tests args.ckpt_path but the load uses
    # args.pth_path — confirm these two attributes are intentionally
    # different (likely one of them is a leftover).
    if args.ckpt_path:
        param_dict = load_checkpoint(args.pth_path)
        load_param_into_net(net_m, param_dict)
        print("Load net weight successfully")
    step_size = train_de_dataset.get_dataset_size()
    # Step-wise LR schedule: the base LR is halved every 200 epochs.
    lr = []
    for i in range(0, args.epochs):
        cur_lr = args.lr / (2**((i + 1) // 200))
        lr.extend([cur_lr] * step_size)
    opt = nn.Adam(net_m.trainable_params(),
                  learning_rate=lr,
                  loss_scale=args.loss_scale)
    loss = nn.L1Loss()
    # Dynamic loss scaling for mixed-precision stability.
    loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale, \
             scale_factor=2, scale_window=1000)

    eval_net = net_m
    model = Model(net_m,
                  loss_fn=loss,
                  optimizer=opt,
                  loss_scale_manager=loss_scale_manager)

    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    metrics = {
        "psnr": PSNR(rgb_range=args.rgb_range, shave=True),
    }
    # NOTE(review): EvalCallBack receives step_size / args.batch_size as a
    # positional argument — verify the expected unit against its signature.
    eval_cb = EvalCallBack(eval_net,
                           eval_ds,
                           args.test_every,
                           step_size / args.batch_size,
                           metrics=metrics,
                           rank_id=rank_id)
    cb = [time_cb, loss_cb, eval_cb]
    config_ck = CheckpointConfig(
        save_checkpoint_steps=args.ckpt_save_interval * step_size,
        keep_checkpoint_max=args.ckpt_save_max)
    ckpt_cb = ModelCheckpoint(prefix=args.filename,
                              directory=args.ckpt_save_path,
                              config=config_ck)
    # NOTE(review): checkpoints saved only on device_id 0; multi-node runs
    # usually guard on rank_id instead — confirm intended behavior.
    if device_id == 0:
        cb += [ckpt_cb]
    model.train(args.epochs,
                train_de_dataset,
                callbacks=cb,
                dataset_sink_mode=True)