Example No. 1
import argparse
import os
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Project-specific names used below (Net, DBreader_DynTex, DBreader_SynTex,
# DBreader_BVItexture, Sampler, HomTex, L1_Charbonnier_loss,
# adjust_learning_rate, and the globals device_id, mode, model_save_path,
# delta) are assumed to be imported from elsewhere in the repository.


def main():
    parser = argparse.ArgumentParser(
        description='fine tune GDConvNet on texture databases')
    parser.add_argument('--texture', type=str,
                        help='Name of the texture sequence to fine-tune on.')
    parser.add_argument('--out_dir', type=str, help='Output directory.')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int,
                        help='Number of fine-tuning epochs.')
    parser.add_argument('--lr', type=float, help='Initial learning rate.')

    args = parser.parse_args()

    learning_rate = args.lr
    num_epochs = args.epochs

    args.out_dir = os.path.join(args.out_dir,
                                'finetune_{}'.format(args.texture))
    result_dir = os.path.join(args.out_dir, 'result')
    ckpt_dir = os.path.join(args.out_dir, 'checkpoint')

    # makedirs with exist_ok creates args.out_dir and both subdirectories
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    # Choose GPU device (device_id is a list of GPU indices assumed to be
    # defined at module level)
    device_ids = device_id
    device = torch.device(
        "cuda:{}".format(device_ids[0]) if torch.cuda.is_available() else "cpu")

    # Build model
    net = Net(nf=144, growth_rate=2, mode=mode)

    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # wrap for DataParallel, restricted here to the first GPU
    net = net.to(device)
    net = nn.DataParallel(net, device_ids=device_ids[:1])

    # count all trainable parameters in the network
    pytorch_total_params = sum(p.numel() for p in net.parameters()
                               if p.requires_grad)
    print("Total_params: {}".format(pytorch_total_params))

    dyntex_dir = '/mnt/storage/home/mt20523/scratch/DynTex'
    syntex_dir = '/mnt/storage/home/mt20523/scratch/SynTex'
    bvitexture_dir = '/mnt/storage/home/mt20523/scratch/BVI-Texture'
    homtex_dir = '/mnt/storage/home/mt20523/scratch/HomTex'
    dataset_dyntex = DBreader_DynTex(dyntex_dir,
                                     args.texture,
                                     random_crop=(256, 256))
    dataset_syntex = DBreader_SynTex(syntex_dir,
                                     args.texture,
                                     random_crop=(256, 256))
    dataset_bvitexture = DBreader_BVItexture(bvitexture_dir,
                                             args.texture,
                                             random_crop=(256, 256))
    # Sampler (project-specific) aggregates the three texture datasets into a
    # single Dataset for the DataLoader below
    sampler = Sampler([dataset_dyntex, dataset_syntex, dataset_bvitexture])

    train_loader = DataLoader(dataset=sampler,
                              batch_size=args.batch_size,
                              shuffle=True)
    # HomTex (project-specific) exposes its own Test() routine, called after
    # each checkpoint below
    test_loader = HomTex(homtex_dir, texture='mixed')

    print(len(train_loader))

    # Load Network weight
    net.load_state_dict(torch.load(model_save_path + 'net_best_weight'),
                        strict=True)
    print("Best weight has been loaded")

    cb_loss = L1_Charbonnier_loss()
    cb_loss = cb_loss.to(device)

    # start training
    for epoch in range(0, num_epochs):
        start_time = time.time()
        adjust_learning_rate(optimizer, epoch)

        # epoch training start
        for batch_id, train_data in enumerate(train_loader):
            # zero the parameter gradients (net.zero_grad and
            # optimizer.zero_grad overlap here)
            net.train()
            net.zero_grad()
            optimizer.zero_grad()

            img1, img2, img4, img5, gt = train_data
            img1 = img1.to(device)
            img2 = img2.to(device)
            img4 = img4.to(device)
            img5 = img5.to(device)
            gt = gt.to(device)

            # forward + backward + optimize
            oup, mid_oup = net(img1, img2, img4, img5)

            oup_loss = cb_loss(oup, gt)
            mid_oup_loss = cb_loss(mid_oup, gt)

            # perceptual_loss = loss_network(oup, gt)
            loss = oup_loss + delta * mid_oup_loss
            loss.backward()
            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            optimizer.step()

            # log every 10 batches, skipping the duplicate batch-0 line after
            # the first epoch
            if not (batch_id % 10):
                if batch_id == 0 and epoch != 0:
                    continue
                print('Epoch:{0}, Iteration:{1}'.format(epoch, batch_id))

        # report per-epoch training time (this variant does not track PSNR)
        train_one_epoch_time = time.time() - start_time
        print("Training one epoch costs {}s".format(train_one_epoch_time))

        # every 50 epochs, checkpoint the model and run the texture test in
        # eval mode
        if (epoch + 1) % 50 == 0:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': net.state_dict()
            }, ckpt_dir + '/model_epoch' + str(epoch + 1).zfill(3) + '.pth')
            net.eval()
            test_loader.Test(net, epoch + 1, result_dir)
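
Both this example and Example No. 3 rely on an L1_Charbonnier_loss module defined elsewhere in the repository. For reference, here is a minimal sketch of the standard Charbonnier loss (a smooth, differentiable variant of L1); the eps default is an assumption, not necessarily the repository's value:

import torch
import torch.nn as nn


class L1_Charbonnier_loss(nn.Module):
    """Charbonnier loss: mean(sqrt(diff^2 + eps^2)), a smooth L1 variant."""

    def __init__(self, eps=1e-6):  # eps value is an assumed default
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        diff = pred - target
        return torch.mean(torch.sqrt(diff * diff + self.eps * self.eps))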
Example No. 2
# Excerpt from a domain-adaptation training loop; the encoders, decoders,
# discriminators, optimizer lists, and schedule constants are defined earlier
# in the script.
enc_t.train()
dec_s.train()
dec_t.train()
dis_s2t.train()
dis_t2s.train()

best_iou = 0
best_iter = 0
for i_iter in range(num_steps):
    print(i_iter)
    sys.stdout.flush()

    enc_shared.train()
    adjust_learning_rate(seg_opt_list,
                         base_lr=learning_rate_seg,
                         i_iter=i_iter,
                         max_iter=num_steps,
                         power=power)
    adjust_learning_rate(dclf_opt_list,
                         base_lr=learning_rate_d,
                         i_iter=i_iter,
                         max_iter=num_steps,
                         power=power)
    adjust_learning_rate(rec_opt_list,
                         base_lr=learning_rate_rec,
                         i_iter=i_iter,
                         max_iter=num_steps,
                         power=power)
    adjust_learning_rate(dis_opt_list,
                         base_lr=learning_rate_dis,
                         i_iter=i_iter,
                         max_iter=num_steps,
                         power=power)
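
The adjust_learning_rate helper is not shown in this excerpt. Given its base_lr/i_iter/max_iter/power signature, it presumably implements the conventional polynomial ("poly") decay schedule; a minimal sketch under that assumption:

def adjust_learning_rate(opt_list, base_lr, i_iter, max_iter, power):
    # poly schedule: lr = base_lr * (1 - i_iter / max_iter) ** power
    lr = base_lr * (1 - i_iter / max_iter) ** power
    for opt in opt_list:
        for param_group in opt.param_groups:
            param_group['lr'] = lr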
Example No. 3
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Project-specific names used below (Net, TrainData, ValData, validation,
# to_psnr, print_log, findLastCheckpoint, L1_Charbonnier_loss,
# adjust_learning_rate, and the globals device_id, mode, learning_rate,
# num_epochs, train_batch_size, val_batch_size, model_save_path, delta,
# save_freq) are assumed to be defined elsewhere in the repository.


def main():
    # Choose GPU device (device_id is a list of GPU indices assumed to be
    # defined at module level)
    device_ids = device_id
    device = torch.device(
        "cuda:{}".format(device_ids[0]) if torch.cuda.is_available() else "cpu")

    # Build model
    net = Net(nf=144, growth_rate=2, mode=mode)

    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # multi-GPU
    net = net.to(device)
    net = nn.DataParallel(net, device_ids=device_ids)

    # count all trainable parameters in the network
    pytorch_total_params = sum(p.numel() for p in net.parameters()
                               if p.requires_grad)
    print("Total_params: {}".format(pytorch_total_params))

    train_data_loader = DataLoader(TrainData(),
                                   batch_size=train_batch_size,
                                   shuffle=True,
                                   num_workers=12)
    val_data_loader_full = DataLoader(ValData(),
                                      batch_size=val_batch_size,
                                      shuffle=False,
                                      num_workers=12)
    val_data_loader_mini = DataLoader(ValData("mini"),
                                      batch_size=val_batch_size,
                                      shuffle=False,
                                      num_workers=12)

    print(len(train_data_loader), len(val_data_loader_full))

    # Load network weights; fall back to the current initialization if the
    # checkpoint is missing or incompatible (a bare except would also swallow
    # KeyboardInterrupt, so catch Exception instead)
    try:
        net.load_state_dict(torch.load(model_save_path + 'net_best_weight'),
                            strict=True)
        print("Best weight has been loaded")
    except Exception:
        print('loading best weight failed')
    # old validation PSNR
    # start_time = time.time()
    pre_val_psnr, pre_val_ssim = validation(net, val_data_loader_mini, device,
                                            False)
    print('old_val_psnr:{0:.2f}, old_val_ssim:{1:.4f}'.format(
        pre_val_psnr, pre_val_ssim))
    # end_time = time.time()
    # print(end_time - start_time)
    # reset so the first validated checkpoint always becomes the new best
    pre_val_psnr, pre_val_ssim = 0, 0

    # load the latest model
    initial_epoch = findLastCheckpoint(save_dir=model_save_path)
    if initial_epoch > 0:
        net.load_state_dict(torch.load(model_save_path +
                                       "net_epoch_{}".format(initial_epoch)),
                            strict=False)
        print("resuming by loading epoch {}".format(initial_epoch))

    cb_loss = L1_Charbonnier_loss()
    cb_loss = cb_loss.to(device)

    # start training
    for epoch in range(initial_epoch, num_epochs):
        start_time = time.time()
        adjust_learning_rate(optimizer, epoch)
        current_psnr_list = []
        psnr_list = []

        # epoch training start
        for batch_id, train_data in enumerate(train_data_loader):
            # zero the parameter gradients (net.zero_grad and
            # optimizer.zero_grad overlap here)
            net.train()
            net.zero_grad()
            optimizer.zero_grad()

            img1, img2, img4, img5, gt = train_data
            img1 = img1.to(device)
            img2 = img2.to(device)
            img4 = img4.to(device)
            img5 = img5.to(device)
            gt = gt.to(device)

            # forward + backward + optimize
            oup, mid_oup = net(img1, img2, img4, img5)

            oup_loss = cb_loss(oup, gt)
            mid_oup_loss = cb_loss(mid_oup, gt)

            # perceptual_loss = loss_network(oup, gt)
            loss = oup_loss + delta * mid_oup_loss

            # per-image PSNR for this batch: current_psnr_list feeds the
            # interval log below, psnr_list the epoch average (extending from
            # the batch directly avoids duplicates on a short final batch)
            batch_psnr = to_psnr(oup, gt)
            current_psnr_list.extend(batch_psnr)
            psnr_list.extend(batch_psnr)

            loss.backward()
            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            optimizer.step()

            # log every 100 batches, skipping the duplicate batch-0 line
            # after the first epoch
            if not (batch_id % 100):
                if batch_id == 0 and epoch != 0:
                    continue
                print(
                    'Epoch:{0}, Iteration:{1}, central_psnr:{2:.2f}, oup_loss:{3:.4f}, mid_oup_loss:{4:.4f}'
                    .format(epoch, batch_id,
                            sum(current_psnr_list) / len(current_psnr_list),
                            oup_loss, mid_oup_loss))
                current_psnr_list = []

        # average PSNR over this epoch's training data
        train_psnr = sum(psnr_list) / len(psnr_list)
        train_one_epoch_time = time.time() - start_time
        print("Training one epoch costs {}s".format(train_one_epoch_time))

        # every save_freq epochs, run full validation and checkpoint the model
        if (epoch + 1) % save_freq == 0:
            start_time = time.time()
            oup_psnr, oup_ssim = validation(net, val_data_loader_full, device)
            val_time = time.time() - start_time
            torch.save(net.state_dict(),
                       model_save_path + "net_epoch_{}".format(epoch + 1))
            print_log(epoch + 1, num_epochs,
                      train_one_epoch_time * save_freq + val_time, train_psnr,
                      oup_psnr, oup_ssim, 'full')

            # keep a separate copy of the best-scoring weights so far
            if oup_psnr >= pre_val_psnr:
                torch.save(net.state_dict(),
                           model_save_path + 'net_best_weight')
                pre_val_psnr = oup_psnr
        else:
            print_log(epoch + 1, num_epochs, train_one_epoch_time, train_psnr,
                      0, 0, 'full')
            torch.save(net.state_dict(),
                       model_save_path + "net_epoch_{}".format(epoch + 1))