Example #1
    # Excerpt: constructor of an SSIM-based regularization term; `ssim` and
    # `utils` are modules from the surrounding project.
    def __init__(self, fix_im, **kwargs):
        super(SSIMRegularization, self).__init__(fix_im)

        # use the caller-supplied SSIM window size if given, else the default
        if 'window_size' in kwargs:
            self.ssim_instance = ssim.SSIM(window_size=kwargs['window_size'])
        else:
            self.ssim_instance = ssim.SSIM()

        # let the caller force GPU use on or off; otherwise autodetect
        manual_gpu = kwargs.get('manual_gpu', None)
        if manual_gpu is not None:
            self.use_gpu = manual_gpu
        else:
            self.use_gpu = utils.use_gpu()
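A minimal sketch of the window-size handling above in isolation, assuming the
`ssim` module in the excerpt is the pytorch_ssim package (whose constructor is
SSIM(window_size=11, size_average=True)):

import torch
import pytorch_ssim

kwargs = {'window_size': 7}  # hypothetical caller-supplied option
if 'window_size' in kwargs:
    ssim_instance = pytorch_ssim.SSIM(window_size=kwargs['window_size'])
else:
    ssim_instance = pytorch_ssim.SSIM()

a, b = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
print(ssim_instance(a, b))  # SSIM score; low for unrelated noise images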
Example #2

# Per-frame SSIM for video tensors of shape (batch, frames, C, H, W):
# average the SSIM of each frame over the time dimension.
def SSIM(output, target):
    ssim_module = pytorch_ssim.SSIM(window_size=11)  # built once per call
    total_ssim = 0.
    n_frames = target.shape[1]
    for f in range(n_frames):
        total_ssim += ssim_module(output[:, f], target[:, f])
    return total_ssim / n_frames
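A minimal usage sketch for the function above, assuming it is in scope and the
pytorch_ssim package is installed:

import torch
import pytorch_ssim

output = torch.rand(2, 4, 3, 64, 64)  # (batch, frames, channels, H, W)
target = torch.rand(2, 4, 3, 64, 64)
print(SSIM(output, target))  # mean SSIM over the 4 frames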
Example #3
    # Excerpt: reconstruction-loss factory; selects a loss function by name.
    def __init__(self, recon_loss_name):

        if recon_loss_name == "L1":
            # per-image L1 summed over (C, H, W), then averaged over the batch
            self.recon_loss_func = lambda x, y: torch.mean(
                torch.sum(torch.abs(x - y), dim=(1, 2, 3)), dim=0)

        elif recon_loss_name == "MSE":
            # per-image squared error summed over (C, H, W), then batch-averaged
            self.recon_loss_func = lambda x, y: torch.mean(
                torch.sum((x - y) ** 2, dim=(1, 2, 3)), dim=0)

        elif recon_loss_name == "BCE":
            raise NotImplementedError

        elif recon_loss_name == "SSIM":
            # pytorch_ssim.SSIM returns a scalar mean by default, so there are
            # no spatial dims left to sum over; 1 - SSIM is the dissimilarity
            ssim = pytorch_ssim.SSIM(window_size=3).cuda()
            self.recon_loss_func = lambda x, y: 1 - ssim(x, y)

        elif recon_loss_name == "custom":
            raise NotImplementedError
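A quick stand-alone check of the "L1" branch above on random tensors
(illustrative only; NCHW shapes are assumed):

import torch

recon_l1 = lambda x, y: torch.mean(
    torch.sum(torch.abs(x - y), dim=(1, 2, 3)), dim=0)
x = torch.rand(4, 3, 32, 32)
y = torch.rand(4, 3, 32, 32)
print(recon_l1(x, y))  # scalar: per-image L1 summed over C,H,W, batch-averaged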
Example #4
def main(args):
    # use gpu
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # after masking visible devices, the one remaining GPU is index 0
    cur_device = torch.device('cuda:0')
    if args.loss == 'bayes':
        root = '/home/datamining/Datasets/CrowdCounting/sha_bayes_512/'
        train_path = root + 'train/'
        test_path = root + 'test/'
    elif args.bn:
        root = '/home/datamining/Datasets/CrowdCounting/sha_512_a/'
        train_path = root + 'train/'
        test_path = root + 'test/'
    else:
        if args.dataset == 'sha':
            root = '/home/datamining/Datasets/CrowdCounting/shanghaitech/part_A_final/'
            train_path = root + 'train_data/images'
            test_path = root + 'test_data/images/'
        elif args.dataset == 'shb':
            root = '/home/datamining/Datasets/CrowdCounting/shb_1024_f15/'
            train_path = root + 'train/'
            test_path = root + 'test/'
        elif args.dataset == 'qnrf':
            root = '/home/datamining/Datasets/CrowdCounting/qnrf_1024_a/'
            train_path = root + 'train/'
            test_path = root + 'test/'

    downsample_ratio = args.downsample
    train_loader, test_loader, train_img_paths, test_img_paths = get_loader(
        train_path, test_path, downsample_ratio, args)

    model_dict = {
        'VGG16_13': M_CSRNet,
        'DefCcNet': DefCcNet,
        'Res50_back3': Res50,
        'InceptionV3': Inception3CC,
        'CAN': CANNet
    }
    model_name = args.model
    dataset_name = args.dataset
    net = model_dict[model_name](downsample=args.downsample,
                                 bn=args.bn > 0,
                                 objective=args.objective,
                                 sp=(args.sp > 0),
                                 se=(args.se > 0),
                                 NL=args.nl)
    net.cuda()
    if args.bn > 0:
        # fields: model, dataset, bn flag, patch size, loss
        save_name = '{}_{}_bn{}_ps{}_{}'.format(model_name, dataset_name,
                                                str(int(args.bn)),
                                                str(args.crop_size), args.loss)
    else:
        save_name = '{}_d{}{}{}{}{}_{}_{}_cr{}_{}{}{}{}{}{}'.format(
            model_name, str(args.downsample), '_sp' if args.sp else '',
            '_se' if args.se else '',
            '_' + args.nl if args.nl != 'relu' else '',
            '_vp' if args.val_patch else '', dataset_name, args.crop_mode,
            str(args.crop_scale), args.loss, '_wu' if args.warm_up else '',
            '_cl' if args.curriculum == 'W' else '', '_v' +
            str(int(args.value_factor)) if args.value_factor != 1 else '',
            '_amp' + str(args.amp_k) if args.objective == 'dmp+amp' else '',
            '_bg' if args.use_bg else '')
    save_path = "/home/datamining/Models/CrowdCounting/" + save_name + ".pth"
    logger = get_logger('logs/' + save_name + '.txt')
    for k, v in args.__dict__.items():  # save args
        logger.info("{}: {}".format(k, v))
    if os.path.exists(save_path) and args.resume:
        net.load_state_dict(torch.load(save_path))
        print('{} loaded!'.format(save_path))

    value_factor = args.value_factor
    freq = 100

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.decay)
    elif args.optimizer == 'SGD':
        # tends not to converge
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=args.lr,
                                    momentum=0.95,
                                    weight_decay=args.decay)

    if args.loss == 'bayes':
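        # Post_Prob converts the annotated head points into per-pixel posterior
        # probabilities; Bay_Loss then scores the predicted density map against
        # them (sigma=8.0 and background_ratio=0.15 are hard-coded here)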
        bayes_criterion = Bay_Loss(True, cur_device)
        post_prob = Post_Prob(sigma=8.0,
                              c_size=args.crop_size,
                              stride=1,
                              background_ratio=0.15,
                              use_background=True,
                              device=cur_device)
    else:
        mse_criterion = nn.MSELoss().cuda()

    if args.scheduler == 'plt':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.9,
                                                   patience=10,
                                                   verbose=True)
    elif args.scheduler == 'cos':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=50,
                                                   eta_min=0)
    elif args.scheduler == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)
    elif args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    elif args.scheduler == 'cyclic' and args.optimizer == 'SGD':
        scheduler = lr_scheduler.CyclicLR(
            optimizer,
            base_lr=args.lr * 0.01,
            max_lr=args.lr,
            step_size_up=25,
        )
    elif args.scheduler == 'None':
        scheduler = None
    else:
        print('scheduler name error!')
        scheduler = None  # fall back to no scheduler rather than leave it undefined

    if args.val_patch:
        best_mae, best_rmse = val_patch(net, test_loader, value_factor)
    elif args.loss == 'bayes':
        best_mae, best_rmse = val_bayes(net, test_loader, value_factor)
    else:
        best_mae, best_rmse = val(net, test_loader, value_factor)
    if args.scheduler == 'plt':
        scheduler.step(best_mae)
    ssim_loss = pytorch_ssim.SSIM(window_size=11)
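    # reused by the 'ssim' and 'mse+ssim' branches in the training loop below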
    for epoch in range(args.epochs):
        if args.crop_mode == 'curriculum':
            # every 20% of the epochs, rebuild the dataset for the next curriculum stage
            if (epoch + 1) % (args.epochs // 5) == 0:
                print('change dataset')
                single_dataset = RawDataset(
                    train_img_paths, transform, args.crop_mode,
                    downsample_ratio, args.crop_scale,
                    (epoch + 1.0 + args.epochs // 5) / args.epochs)
                train_loader = torch.utils.data.DataLoader(single_dataset,
                                                           shuffle=True,
                                                           batch_size=1,
                                                           num_workers=8)

        train_loss = 0.0
        if args.loss == 'bayes':
            epoch_mae = AverageMeter()
            epoch_mse = AverageMeter()
        net.train()
        if args.warm_up and epoch < args.warm_up_steps:
            linear_warm_up_lr(optimizer, epoch, args.warm_up_steps, args.lr)
        for it, data in enumerate(train_loader):
            if args.loss == 'bayes':
                inputs, points, targets, st_sizes = data
                img = inputs.to(cur_device)
                st_sizes = st_sizes.to(cur_device)
                gd_count = np.array([len(p) for p in points], dtype=np.float32)
                points = [p.to(cur_device) for p in points]
                targets = [t.to(cur_device) for t in targets]
            else:
                img, target, _, amp_gt = data
                img = img.cuda()
                target = value_factor * target.float().unsqueeze(1).cuda()
                amp_gt = amp_gt.cuda()
            #print(img.shape)
            optimizer.zero_grad()

            #print(target.shape)
            if args.objective == 'dmp+amp':
                output, amp = net(img)
                output = output * amp
            else:
                output = net(img)

            if args.loss == 'bayes':
                mse_loss = None  # no density-map MSE target in bayes mode
            elif args.curriculum == 'W':
                # curriculum weighting: the threshold T grows linearly with the
                # epoch, and W = T / max(T, output) down-weights pixels whose
                # prediction exceeds T, so low-density regions dominate early on
                delta = (output - target)**2
                k_w = 2e-3 * args.value_factor * args.downsample**2
                b_w = 5e-3 * args.value_factor * args.downsample**2
                T = torch.ones_like(target,
                                    dtype=torch.float32) * epoch * k_w + b_w
                W = T / torch.max(T, output)
                delta = delta * W
                mse_loss = torch.mean(delta)
            else:
                mse_loss = mse_criterion(output, target)

            if args.loss == 'mse+lc':
                loss = mse_loss + 1e2 * cal_lc_loss(output,
                                                    target) * args.downsample
            elif args.loss == 'ssim':
                loss = 1 - ssim_loss(output, target)
            elif args.loss == 'mse+ssim':
                loss = 100 * mse_loss + 1e-2 * (1 - ssim_loss(output, target))
            elif args.loss == 'mse+la':
                loss = mse_loss + cal_spatial_abstraction_loss(output, target)
            elif args.loss == 'la':
                loss = cal_spatial_abstraction_loss(output, target)
            elif args.loss == 'ms-ssim':
                raise NotImplementedError  # TODO: multi-scale SSIM loss
            elif args.loss == 'adversial':  # sic: must match the CLI choice
                raise NotImplementedError  # TODO: adversarial loss
            elif args.loss == 'bayes':
                prob_list = post_prob(points, st_sizes)
                loss = bayes_criterion(prob_list, targets, output)
            else:
                loss = mse_loss

            # add the binary cross-entropy loss for the attention map
            if args.objective == 'dmp+amp':
                cross_entropy = (amp_gt * torch.log(amp) +
                                 (1 - amp_gt) * torch.log(1 - amp)) * -1
                cross_entropy_loss = torch.mean(cross_entropy)
                loss = loss + cross_entropy_loss * args.amp_k

            loss.backward()
            optimizer.step()
            data_loss = loss.item()
            train_loss += data_loss
            if args.loss == 'bayes':
                N = inputs.size(0)
                pre_count = torch.sum(output.view(N, -1),
                                      dim=1).detach().cpu().numpy()
                res = pre_count - gd_count
                epoch_mse.update(np.mean(res * res), N)
                epoch_mae.update(np.mean(abs(res)), N)

            if args.loss != 'bayes' and it % freq == 0:
                print(
                    '[ep:{}], [it:{}], [loss:{:.8f}], [output:{:.2f}, target:{:.2f}]'
                    .format(epoch + 1, it, data_loss, output[0].sum().item(),
                            target[0].sum().item()))
        if args.val_patch:
            mae, rmse = val_patch(net, test_loader, value_factor)
        elif args.loss == 'bayes':
            mae, rmse = val_bayes(net, test_loader, value_factor)
        else:
            mae, rmse = val(net, test_loader, value_factor)
        if not (args.warm_up and epoch < args.warm_up_steps):
            if args.scheduler == 'plt':
                scheduler.step(best_mae)
            elif args.scheduler != 'None':
                scheduler.step()

        if mae + 0.1 * rmse < best_mae + 0.1 * best_rmse:
            best_mae, best_rmse = mae, rmse
            torch.save(net.state_dict(), save_path)

        if args.loss == 'bayes':
            logger.info(
                '{} Epoch {}/{} Loss:{:.8f},MAE:{:.2f},RMSE:{:.2f} lr:{:.8f}, [CUR]:{mae:.1f}, {rmse:.1f}, [Best]:{b_mae:.1f}, {b_rmse:.1f}'
                .format(model_name,
                        epoch + 1,
                        args.epochs,
                        train_loss / len(train_loader),
                        epoch_mae.get_avg(),
                        np.sqrt(epoch_mse.get_avg()),
                        optimizer.param_groups[0]['lr'],
                        mae=mae,
                        rmse=rmse,
                        b_mae=best_mae,
                        b_rmse=best_rmse))
        else:
            logger.info(
                '{} Epoch {}/{} Loss:{:.8f}, lr:{:.8f}, [CUR]:{mae:.1f}, {rmse:.1f}, [Best]:{b_mae:.1f}, {b_rmse:.1f}'
                .format(model_name,
                        epoch + 1,
                        args.epochs,
                        train_loss / len(train_loader),
                        optimizer.param_groups[0]['lr'],
                        mae=mae,
                        rmse=rmse,
                        b_mae=best_mae,
                        b_rmse=best_rmse))
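The 'mse+ssim' branch above mixes pixel-wise MSE with a structural term. A
minimal stand-alone sketch of that combination, assuming the pytorch_ssim
package and single-channel density maps:

import torch
import torch.nn as nn
import pytorch_ssim

mse_criterion = nn.MSELoss()
ssim_loss = pytorch_ssim.SSIM(window_size=11)

output = torch.rand(2, 1, 64, 64)  # predicted density map
target = torch.rand(2, 1, 64, 64)  # ground-truth density map
loss = 100 * mse_criterion(output, target) + 1e-2 * (1 - ssim_loss(output, target))
print(loss)  # same weighting as the 'mse+ssim' branch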
Example #5
print('Preparing data done.')

# net
print('==> Building model..')
net = network.UNet(channels=args.channels)
net = net.to(device)

writer = SummaryWriter('runs/eventcamera_experiment_' + str(args.channels) +
                       ('_fixed' if args.fixed else ''))
print('Building model done.')

test_output_image = np.zeros((len(test_label), 180, 240), dtype='float')
if not os.path.exists('result'):
    os.mkdir('result')

criterion = pytorch_ssim.SSIM(window_size=11)  # SSIM as the training criterion (higher is better)
optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=0.001)

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('checkpoint/ckpt' + str(args.channels) +
                            ('_fixed' if args.fixed else '') + '.pth')
    net.load_state_dict(checkpoint['net_params'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_psnr = checkpoint['psnr']
    best_ssim = checkpoint['ssim']
    start_epoch = checkpoint['epoch']
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[20, 35])
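For symmetry with the resume block above, a hedged sketch of how such a
checkpoint might be written, assuming the same dictionary keys and the same
path scheme:

import os
import torch

def save_checkpoint(net, optimizer, psnr, ssim_val, epoch, channels, fixed):
    # hypothetical save counterpart to the resume block above
    os.makedirs('checkpoint', exist_ok=True)
    torch.save({
        'net_params': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'psnr': psnr,
        'ssim': ssim_val,
        'epoch': epoch,
    }, 'checkpoint/ckpt' + str(channels) + ('_fixed' if fixed else '') + '.pth')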
Example #6
def train(args):
    writer = SummaryWriter(comment=args.writer)
    os.makedirs(args.checkpoint_save_path, exist_ok=True)

    argsDict = args.__dict__
    for k, v in argsDict.items():
        writer.add_text('hyperparameter', '{} : {}'.format(str(k), str(v)))

    print_freq = args.print_freq
    test_freq = 1
    global device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)

    input_shape = (3, args.img_height, args.img_width)
    net = dispnetcorr(args.maxdisp)
    G_AB = GeneratorResNet(input_shape, 2)
    G_BA = GeneratorResNet(input_shape, 2)
    D_A = Discriminator(3)
    D_B = Discriminator(3)

    if args.load_checkpoints:
        if args.load_from_mgpus_model:
            if args.load_dispnet_path:
                net = load_multi_gpu_checkpoint(net, args.load_dispnet_path,
                                                'model')
            else:
                net.apply(weights_init_normal)
            G_AB = load_multi_gpu_checkpoint(G_AB, args.load_gan_path, 'G_AB')
            G_BA = load_multi_gpu_checkpoint(G_BA, args.load_gan_path, 'G_BA')
            D_A = load_multi_gpu_checkpoint(D_A, args.load_gan_path, 'D_A')
            D_B = load_multi_gpu_checkpoint(D_B, args.load_gan_path, 'D_B')
        else:
            if args.load_dispnet_path:
                net = load_checkpoint(net, args.load_checkpoint_path, device)
            else:
                net.apply(weights_init_normal)
            G_AB = load_checkpoint(G_AB, args.load_gan_path, 'G_AB')
            G_BA = load_checkpoint(G_BA, args.load_gan_path, 'G_BA')
            D_A = load_checkpoint(D_A, args.load_gan_path, 'D_A')
            D_B = load_checkpoint(D_B, args.load_gan_path, 'D_B')
    else:
        net.apply(weights_init_normal)
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    # optimizer = optim.SGD(params, momentum=0.9)
    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr_rate,
                           betas=(0.9, 0.999))
    optimizer_G = optim.Adam(itertools.chain(G_AB.parameters(),
                                             G_BA.parameters()),
                             lr=args.lr_gan,
                             betas=(0.5, 0.999))
    optimizer_D_A = optim.Adam(D_A.parameters(),
                               lr=args.lr_gan,
                               betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(D_B.parameters(),
                               lr=args.lr_gan,
                               betas=(0.5, 0.999))

    if args.use_multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net, device_ids=list(range(args.use_multi_gpu)))
        G_AB = nn.DataParallel(G_AB,
                               device_ids=list(range(args.use_multi_gpu)))
        G_BA = nn.DataParallel(G_BA,
                               device_ids=list(range(args.use_multi_gpu)))
        D_A = nn.DataParallel(D_A, device_ids=list(range(args.use_multi_gpu)))
        D_B = nn.DataParallel(D_B, device_ids=list(range(args.use_multi_gpu)))

    net.to(device)
    G_AB.to(device)
    G_BA.to(device)
    D_A.to(device)
    D_B.to(device)

    criterion_GAN = torch.nn.MSELoss().cuda()
    criterion_identity = torch.nn.L1Loss().cuda()
    ssim_loss = pytorch_ssim.SSIM()

    # data loader
    if args.source_dataset == 'driving':
        dataset = ImageDataset(height=args.img_height, width=args.img_width)
    elif args.source_dataset == 'synthia':
        dataset = ImageDataset2(height=args.img_height, width=args.img_width)
    else:
        raise "No suportive dataset"
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)
    valdataset = ValJointImageDataset()
    valloader = torch.utils.data.DataLoader(valdataset,
                                            batch_size=args.test_batch_size,
                                            shuffle=False,
                                            num_workers=1)

    train_loss_meter = AverageMeter()
    val_loss_meter = AverageMeter()

    ## debug only
    #with torch.no_grad():
    #    l1_test_loss, out_val = val(valloader, net, G_AB, None, writer, epoch=0, board_save=True)
    #    val_loss_meter.update(l1_test_loss)
    #    print('Val epoch[{}/{}] loss: {}'.format(0, args.total_epochs, l1_test_loss))

    print('begin training...')
    best_val_d1 = 1.
    best_val_epe = 100.
    for epoch in range(args.total_epochs):
        #net.train()
        #G_AB.train()

        n_iter = 0
        running_loss = 0.
        t = time.time()
        # custom lr decay, or warm-up
        lr = args.lr_rate
        if epoch >= int(args.lrepochs.split(':')[0]):
            lr = lr / int(args.lrepochs.split(':')[1])
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for i, batch in enumerate(trainloader):
            n_iter += 1
            leftA = batch['leftA'].to(device)
            rightA = batch['rightA'].to(device)
            leftB = batch['leftB'].to(device)
            rightB = batch['rightB'].to(device)
            dispA = batch['dispA'].unsqueeze(1).float().to(device)
            dispB = batch['dispB'].to(device)
            out_shape = (leftA.size(0), 1, args.img_height // 16,
                         args.img_width // 16)
            valid = torch.ones(out_shape, device=device)   # real-label targets
            fake = torch.zeros(out_shape, device=device)   # fake-label targets

            if i % args.train_ratio_gan == 0:
                # train generators
                G_AB.train()
                G_BA.train()
                net.eval()
                optimizer_G.zero_grad()

                # Identity loss
                loss_id_A = (criterion_identity(G_BA(leftA), leftA) +
                             criterion_identity(G_BA(rightA), rightA)) / 2
                loss_id_B = (criterion_identity(G_AB(leftB), leftB) +
                             criterion_identity(G_AB(rightB), rightB)) / 2
                loss_id = (loss_id_A + loss_id_B) / 2

                if args.lambda_warp_inv:
                    fake_leftB, fake_leftB_feats = G_AB(leftA,
                                                        extract_feat=True)
                    fake_leftA, fake_leftA_feats = G_BA(leftB,
                                                        extract_feat=True)
                else:
                    fake_leftB = G_AB(leftA)
                    fake_leftA = G_BA(leftB)
                if args.lambda_warp:
                    fake_rightB, fake_rightB_feats = G_AB(rightA,
                                                          extract_feat=True)
                    fake_rightA, fake_rightA_feats = G_BA(rightB,
                                                          extract_feat=True)
                else:
                    fake_rightB = G_AB(rightA)
                    fake_rightA = G_BA(rightB)
                loss_GAN_AB = criterion_GAN(D_B(fake_leftB), valid)
                loss_GAN_BA = criterion_GAN(D_A(fake_leftA), valid)
                loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

                if args.lambda_warp_inv:
                    rec_leftA, rec_leftA_feats = G_BA(fake_leftB,
                                                      extract_feat=True)
                else:
                    rec_leftA = G_BA(fake_leftB)
                if args.lambda_warp:
                    rec_rightA, rec_rightA_feats = G_BA(fake_rightB,
                                                        extract_feat=True)
                else:
                    rec_rightA = G_BA(fake_rightB)
                rec_leftB = G_AB(fake_leftA)
                rec_rightB = G_AB(fake_rightA)
                loss_cycle_A = (criterion_identity(rec_leftA, leftA) +
                                criterion_identity(rec_rightA, rightA)) / 2
                loss_ssim_A = 1. - (ssim_loss(rec_leftA, leftA) +
                                    ssim_loss(rec_rightA, rightA)) / 2
                loss_cycle_B = (criterion_identity(rec_leftB, leftB) +
                                criterion_identity(rec_rightB, rightB)) / 2
                loss_ssim_B = 1. - (ssim_loss(rec_leftB, leftB) +
                                    ssim_loss(rec_rightB, rightB)) / 2
                loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
                loss_ssim = (loss_ssim_A + loss_ssim_B) / 2

                # mode seeking loss
                if args.lambda_ms:
                    loss_ms = G_AB(leftA, zx=True, zx_relax=True).mean()
                else:
                    loss_ms = 0

                # warping loss
                if args.lambda_warp_inv:
                    fake_leftB_warp, loss_warp_inv_feat1 = G_AB(
                        rightA, -dispA, True,
                        [x.detach() for x in fake_leftB_feats])
                    rec_leftA_warp, loss_warp_inv_feat2 = G_BA(
                        fake_rightB, -dispA, True,
                        [x.detach() for x in rec_leftA_feats])
                    loss_warp_inv1 = warp_loss(
                        [(G_BA(fake_leftB_warp[0]), fake_leftB_warp[1])],
                        [leftA],
                        weights=[1])
                    loss_warp_inv2 = warp_loss([rec_leftA_warp], [leftA],
                                               weights=[1])
                    loss_warp_inv = (loss_warp_inv1 + loss_warp_inv2
                                     + loss_warp_inv_feat1.mean()
                                     + loss_warp_inv_feat2.mean())
                else:
                    loss_warp_inv = 0

                if args.lambda_warp:
                    fake_rightB_warp, loss_warp_feat1 = G_AB(
                        leftA, dispA, True,
                        [x.detach() for x in fake_rightB_feats])
                    rec_rightA_warp, loss_warp_feat2 = G_BA(
                        fake_leftB, dispA, True,
                        [x.detach() for x in rec_rightA_feats])
                    loss_warp1 = warp_loss(
                        [(G_BA(fake_rightB_warp[0]), fake_rightB_warp[1])],
                        [rightA],
                        weights=[1])
                    loss_warp2 = warp_loss([rec_rightA_warp], [rightA],
                                           weights=[1])
                    loss_warp = (loss_warp1 + loss_warp2
                                 + loss_warp_feat1.mean()
                                 + loss_warp_feat2.mean())
                else:
                    loss_warp = 0

                # corr loss
                if args.lambda_corr:
                    corrB = net(leftB, rightB, extract_feat=True)
                    corrB1 = net(leftB, rec_rightB, extract_feat=True)
                    corrB2 = net(rec_leftB, rightB, extract_feat=True)
                    corrB3 = net(rec_leftB, rec_rightB, extract_feat=True)
                    loss_corr = (criterion_identity(corrB1, corrB) +
                                 criterion_identity(corrB2, corrB) +
                                 criterion_identity(corrB3, corrB)) / 3
                else:
                    loss_corr = 0.

                lambda_ms = args.lambda_ms * (args.total_epochs -
                                              epoch) / args.total_epochs
                loss_G = (loss_GAN
                          + args.lambda_cycle * (args.alpha_ssim * loss_ssim
                                                 + (1 - args.alpha_ssim) * loss_cycle)
                          + args.lambda_id * loss_id
                          + args.lambda_warp * loss_warp
                          + args.lambda_warp_inv * loss_warp_inv
                          + args.lambda_corr * loss_corr
                          + lambda_ms * loss_ms)
                loss_G.backward()
                optimizer_G.step()

                # train discriminators. A: real, B: syn
                optimizer_D_A.zero_grad()
                loss_real_A = criterion_GAN(D_A(leftA), valid)
                fake_leftA.detach_()
                loss_fake_A = criterion_GAN(D_A(fake_leftA), fake)
                loss_D_A = (loss_real_A + loss_fake_A) / 2
                loss_D_A.backward()
                optimizer_D_A.step()

                optimizer_D_B.zero_grad()
                #loss_real_B = criterion_GAN(D_B(torch.cat([syn_left_img, syn_right_img], 0)), valid)
                #fake_syn_left.detach_()
                #fake_syn_right.detach_()
                #loss_fake_B = criterion_GAN(D_B(torch.cat([fake_syn_left, fake_syn_right], 0)), fake)
                loss_real_B = criterion_GAN(D_B(leftB), valid)
                fake_leftB.detach_()
                loss_fake_B = criterion_GAN(D_B(fake_leftB), fake)
                loss_D_B = (loss_real_B + loss_fake_B) / 2
                loss_D_B.backward()
                optimizer_D_B.step()

            # train disp net
            net.train()
            G_AB.eval()
            G_BA.eval()
            optimizer.zero_grad()
            disp_ests = net(G_AB(leftA), G_AB(rightA))
            mask = (dispA < args.maxdisp) & (dispA > 0)
            loss0 = model_loss0(disp_ests, dispA, mask)

            if args.lambda_disp_warp_inv:
                disp_warp = [-disp_ests[i] for i in range(3)]
                loss_disp_warp_inv = G_BA(
                    rightB, disp_warp, True,
                    [x.detach() for x in fake_leftA_feats])
                loss_disp_warp_inv = loss_disp_warp_inv.mean()
            else:
                loss_disp_warp_inv = 0

            if args.lambda_disp_warp:
                disp_warp = [disp_ests[i] for i in range(3)]
                loss_disp_warp = G_BA(leftB, disp_warp, True,
                                      [x.detach() for x in fake_rightA_feats])
                loss_disp_warp = loss_disp_warp.mean()
            else:
                loss_disp_warp = 0

            loss = (loss0 + args.lambda_disp_warp * loss_disp_warp
                    + args.lambda_disp_warp_inv * loss_disp_warp_inv)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  # accumulate for the averaged log below

            if i % print_freq == print_freq - 1:
                print('epoch[{}/{}]  step[{}/{}]  loss: {}'.format(
                    epoch, args.total_epochs, i, len(trainloader),
                    loss.item()))
                train_loss_meter.update(running_loss / print_freq)
                running_loss = 0.
                #writer.add_scalar('loss/trainloss avg_meter', train_loss_meter.val, train_loss_meter.count * print_freq)
                writer.add_scalar('loss/loss_disp', loss0,
                                  train_loss_meter.count * print_freq)
                writer.add_scalar('loss/loss_disp_warp', loss_disp_warp,
                                  train_loss_meter.count * print_freq)
                writer.add_scalar('loss/loss_disp_warp_inv',
                                  loss_disp_warp_inv,
                                  train_loss_meter.count * print_freq)

                if i % args.train_ratio_gan == 0:
                    writer.add_scalar('loss/loss_G', loss_G,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_gan', loss_GAN,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_cycle', loss_cycle,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_id', loss_id,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_warp', loss_warp,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_warp_inv', loss_warp_inv,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_corr', loss_corr,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_ms', loss_ms,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_D_A', loss_D_A,
                                      train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_D_B', loss_D_B,
                                      train_loss_meter.count * print_freq)

                    imgA_visual = vutils.make_grid(leftA[:4, :, :, :],
                                                   nrow=1,
                                                   normalize=True,
                                                   scale_each=True)
                    fakeB_visual = vutils.make_grid(fake_leftB[:4, :, :, :],
                                                    nrow=1,
                                                    normalize=True,
                                                    scale_each=True)
                    recA_visual = vutils.make_grid(rec_leftA[:4, :, :, :],
                                                   nrow=1,
                                                   normalize=True,
                                                   scale_each=True)
                    rightA_visual = vutils.make_grid(rightA[:4, :, :, :],
                                                     nrow=1,
                                                     normalize=True,
                                                     scale_each=True)
                    fakeB_R_visual = vutils.make_grid(fake_rightB[:4, :, :, :],
                                                      nrow=1,
                                                      normalize=True,
                                                      scale_each=True)
                    recA_R_visual = vutils.make_grid(rec_rightA[:4, :, :, :],
                                                     nrow=1,
                                                     normalize=True,
                                                     scale_each=True)

                    imgB_visual = vutils.make_grid(leftB[:4, :, :, :],
                                                   nrow=1,
                                                   normalize=True,
                                                   scale_each=True)
                    fakeA_visual = vutils.make_grid(fake_leftA[:4, :, :, :],
                                                    nrow=1,
                                                    normalize=True,
                                                    scale_each=True)
                    recB_visual = vutils.make_grid(rec_leftB[:4, :, :, :],
                                                   nrow=1,
                                                   normalize=True,
                                                   scale_each=True)
                    rightB_visual = vutils.make_grid(rightB[:4, :, :, :],
                                                     nrow=1,
                                                     normalize=True,
                                                     scale_each=True)
                    fakeA_R_visual = vutils.make_grid(fake_rightA[:4, :, :, :],
                                                      nrow=1,
                                                      normalize=True,
                                                      scale_each=True)
                    recB_R_visual = vutils.make_grid(rec_rightB[:4, :, :, :],
                                                     nrow=1,
                                                     normalize=True,
                                                     scale_each=True)

                    writer.add_image('ABA_L/imgA', imgA_visual, i)
                    writer.add_image('ABA_L/fakeB', fakeB_visual, i)
                    writer.add_image('ABA_L/recA', recA_visual, i)
                    writer.add_image('ABA_R/imgA', rightA_visual, i)
                    writer.add_image('ABA_R/fakeB', fakeB_R_visual, i)
                    writer.add_image('ABA_R/recA', recA_R_visual, i)
                    writer.add_image('BAB_L/imgB', imgB_visual, i)
                    writer.add_image('BAB_L/fakeA', fakeA_visual, i)
                    writer.add_image('BAB_L/recB', recB_visual, i)
                    writer.add_image('BAB_R/imgB', rightB_visual, i)
                    writer.add_image('BAB_R/fakeA', fakeA_R_visual, i)
                    writer.add_image('BAB_R/recB', recB_R_visual, i)

                if args.lambda_warp_inv:
                    recA_warp_visual = vutils.make_grid(
                        rec_leftA_warp[0][:4, :, :, :],
                        nrow=1,
                        normalize=True,
                        scale_each=True)
                    fakeB_warp_visual = vutils.make_grid(
                        fake_leftB_warp[0][:4, :, :, :],
                        nrow=1,
                        normalize=True,
                        scale_each=True)
                    writer.add_image('warp/recA_L_warp', recA_warp_visual, i)
                    writer.add_image('warp/fakeB_L_warp', fakeB_warp_visual, i)
                if args.lambda_warp:
                    # build the image grids first, then log them
                    recA_warp_R_visual = vutils.make_grid(
                        rec_rightA_warp[0][:4, :, :, :],
                        nrow=1,
                        normalize=True,
                        scale_each=True)
                    fakeB_warp_R_visual = vutils.make_grid(
                        fake_rightB_warp[0][:4, :, :, :],
                        nrow=1,
                        normalize=True,
                        scale_each=True)
                    writer.add_image('warp/recA_R_warp', recA_warp_R_visual, i)
                    writer.add_image('warp/fakeB_R_warp', fakeB_warp_R_visual,
                                     i)

        with torch.no_grad():
            EPE, D1 = val(valloader, net, writer, epoch=epoch, board_save=True)

        t1 = time.time()
        print('epoch:{}, D1:{:.6f}, EPE:{:.6f}, cost time:{} '.format(
            epoch, D1, EPE, t1 - t))

        if (epoch % args.save_interval
                == 0) or D1 < best_val_d1 or EPE < best_val_epe:
            # interval saves should not overwrite better past metrics
            best_val_d1 = min(best_val_d1, D1)
            best_val_epe = min(best_val_epe, EPE)
            torch.save(
                {
                    'epoch': epoch,
                    'G_AB': G_AB.state_dict(),
                    'G_BA': G_BA.state_dict(),
                    'D_A': D_A.state_dict(),
                    'D_B': D_B.state_dict(),
                    'model': net.state_dict(),
                    'optimizer_DA_state_dict': optimizer_D_A.state_dict(),
                    'optimizer_DB_state_dict': optimizer_D_B.state_dict(),
                    'optimizer_G_state_dict': optimizer_G.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, args.checkpoint_save_path + '/ep' + str(epoch) +
                '_D1_{:.4f}_EPE{:.4f}'.format(D1, EPE) + '.pth.rar')
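A hedged sketch of restoring training state from a checkpoint written by the
save call above (same dictionary keys assumed; the path is a placeholder):

import torch

def load_training_state(path, net, G_AB, G_BA, D_A, D_B, optimizer):
    # counterpart to the torch.save(...) call above
    ckpt = torch.load(path, map_location='cpu')
    net.load_state_dict(ckpt['model'])
    G_AB.load_state_dict(ckpt['G_AB'])
    G_BA.load_state_dict(ckpt['G_BA'])
    D_A.load_state_dict(ckpt['D_A'])
    D_B.load_state_dict(ckpt['D_B'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    return ckpt['epoch']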