コード例 #1
0
        # --- Generator (netG) update: adversarial + weighted L1 loss ---------
        netG.zero_grad()
        # Generator wants the discriminator to classify its output as real.
        label.data.fill_(real_label)
        output = netD(fake_AB)
        errGAN = criterion(output, label)
        # L1 term pulls the generated image toward the ground truth.
        errL1 = criterionL1(fake_B, real_B)
        errG = errGAN + opt.lamb * errL1  # opt.lamb weights L1 vs. adversarial loss

        errG.backward()

        optimizerG.step()

        ########### Logging ##########
        # NOTE(review): `.data[0]` indexing on 0-dim tensors was removed in
        # modern PyTorch; `.item()` is the current equivalent — confirm the
        # targeted PyTorch version.
        if (i % 50 == 0):
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_L1: %.4f' %
                  (epoch, opt.niter, i, len(train_loader), errD.data[0],
                   errGAN.data[0], errL1.data[0]))
    print('Time: %.4f' % (time.time() - nowtime))

    ########## Visualize #########
    # Dump each channel of the first generated sample as a PNG every epoch.
    if (epoch % 1 == 0):
        f_B = fake_B.cpu().data.numpy()
        for n, pic in enumerate(f_B[0]):
            misc.imsave('%s/%d_%d.png' % (opt.outf, epoch, n), pic)
    # Periodic checkpoint; the file name is fixed, so it is overwritten each time.
    if (epoch % 10 == 0):
        print('save model:', epoch)
        torch.save(netG.state_dict(), '%s/netG_1d.pth' % (opt.outf))
        torch.save(netD.state_dict(), '%s/netD_1d.pth' % (opt.outf))

# Final save after the training loop finishes (same fixed file names).
torch.save(netG.state_dict(), '%s/netG_1d.pth' % (opt.outf))
torch.save(netD.state_dict(), '%s/netD_1d.pth' % (opt.outf))
コード例 #2
0
            # Accumulate per-epoch running totals of the individual loss terms.
            seq_l2_loss_e = seq_l2_loss_e + l2_loss.item()
            seq_l1_loss_e = seq_l1_loss_e + l1_loss.item()
            seq_preceptual_e = seq_preceptual_e + preceptual.item()

            print(
                "===> Epoch[{}]({}/{}): d_fake: {:.4f}, d_real: {:.4f}, g: {:.4f}, image_recon: {:.4f}"
                .format(epoch, i, len(trainloader), D_loss_fake.item(),
                        D_loss_real.item(), G_loss.item(),
                        img_recon_loss.item()))
            # break

        # Periodic checkpoint; `== cp_freq - 1` means e.g. freq=10 saves at
        # epochs 9, 19, 29, ...
        if epoch % opt.cp_freq == opt.cp_freq - 1:
            torch.save(model.state_dict(),
                       test_folder + '/model_epoch_' + str(epoch) + '.pth')
            torch.save(discriminator.state_dict(),
                       test_folder + '/D_epoch_' + str(epoch) + '.pth')

        ####testing
        # Switch to eval mode (affects dropout/batch-norm) for the test pass.
        model.eval()

        skip = len(testloader) // 9  # save images every skip iters

        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            for i, batch in enumerate(testloader):

                if i == 0:
                    image, depth = batch['image'], batch['depth']

                    real_image, real_depth = image.cuda(), depth.cuda()
                    # NOTE(review): the "rand" tensors are identical to the
                    # "real" ones here — confirm this is intentional.
                    rand_image, rand_depth = image.cuda(), depth.cuda()
コード例 #3
0
File: main.py  Project: jungwon-choi/WGAN-pytorch
def main(args):
    """Run the full WGAN/DCGAN training pipeline.

    Builds the data loader, generator/discriminator, optimizers and step
    counter, optionally resumes from a checkpoint, trains for
    ``args.epochs`` epochs, and saves per-epoch results (pickle) plus a
    rolling checkpoint.

    Relies on module-level names defined elsewhere in this file:
    ``RESULTS_PATH``, ``CHECKPOINT_PATH``, ``CelebA_Dataloader``,
    ``Generator``, ``Discriminator``, ``StepCounter``, ``train``, ``val``
    and ``start`` (wall-clock origin for elapsed-time reporting).

    Raises:
        ValueError: on an unsupported ``args.dataset`` or ``args.model``.
        FileNotFoundError: when ``args.resume`` is set but no checkpoint
            file exists.
    """
    #===========================================================================
    # Set the file name format: model_dataset_epochs_objstep_batch_lr<flag>
    FILE_NAME_FORMAT = "{0}_{1}_{2:d}_{3:d}_{4:d}_{5:f}{6}".format(
        args.model, args.dataset, args.epochs, args.obj_step, args.batch_size,
        args.lr, args.flag)

    # Set the results file path
    RESULT_FILE_NAME = FILE_NAME_FORMAT + '_results.pkl'
    RESULT_FILE_PATH = os.path.join(RESULTS_PATH, RESULT_FILE_NAME)
    # Set the checkpoint file path
    CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '.ckpt'
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, CHECKPOINT_FILE_NAME)
    BEST_CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '_best.ckpt'
    BEST_CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH,
                                             BEST_CHECKPOINT_FILE_NAME)

    # Set the random seed same for reproducibility
    random.seed(190811)
    torch.manual_seed(190811)
    torch.cuda.manual_seed_all(190811)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Step1 ====================================================================
    # Load dataset
    if args.dataset == 'CelebA':
        dataloader = CelebA_Dataloader()
    else:
        # raise (not `assert`): assertions are stripped under `python -O`.
        raise ValueError("Please select the proper dataset.")

    train_loader = dataloader.get_train_loader(batch_size=args.batch_size,
                                               num_workers=args.num_workers)
    print('==> DataLoader ready.')

    # Step2 ====================================================================
    # Make the model (with or without BatchNorm depending on the variant)
    if args.model in ['WGAN', 'DCGAN']:
        generator = Generator(BN=True)
        discriminator = Discriminator(BN=True)
    elif args.model in ['WGAN_noBN', 'DCGAN_noBN']:
        generator = Generator(BN=False)
        discriminator = Discriminator(BN=False)
    else:
        raise ValueError("Please select the proper model.")

    # Check DataParallel available
    if torch.cuda.device_count() > 1:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    # Check CUDA available
    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()
    print('==> Model ready.')

    # Step3 ====================================================================
    # Set loss function and optimizer.
    # DCGAN variants use BCE; WGAN variants pass criterion=None (the
    # Wasserstein objective is computed inside train()).
    if args.model in ['DCGAN', 'DCGAN_noBN']:
        criterion = nn.BCELoss()
    else:
        criterion = None
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=args.lr)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=args.lr)
    step_counter = StepCounter(args.obj_step)
    print('==> Criterion and optimizer ready.')

    # Step4 ====================================================================
    # Train and validate the model
    start_epoch = 0
    best_metric = float("inf")
    # Fixed noise so validation images are comparable across steps.
    validate_noise = torch.randn(args.batch_size, 100, 1, 1)

    # Initialize the result lists
    train_loss_G = []
    train_loss_D = []
    train_distance = []

    if args.resume:
        # Restore full training state: weights, optimizers, counters, and
        # loss histories.
        if not os.path.exists(CHECKPOINT_FILE_PATH):
            raise FileNotFoundError('No checkpoint file!')
        checkpoint = torch.load(CHECKPOINT_FILE_PATH)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        start_epoch = checkpoint['epoch']
        step_counter.current_step = checkpoint['current_step']
        train_loss_G = checkpoint['train_loss_G']
        train_loss_D = checkpoint['train_loss_D']
        train_distance = checkpoint['train_distance']
        best_metric = checkpoint['best_metric']

    # Save the training information
    result_data = {}
    result_data['model'] = args.model
    result_data['dataset'] = args.dataset
    result_data['target_epoch'] = args.epochs
    result_data['batch_size'] = args.batch_size

    # Ensure the output directories exist (exist_ok avoids the
    # check-then-create race of the exists()/makedirs() pattern).
    os.makedirs(os.path.dirname(RESULT_FILE_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH), exist_ok=True)

    print('==> Train ready.')

    # Validate before training (step 0)
    val(generator, validate_noise, step_counter, FILE_NAME_FORMAT)

    # Start after the checkpoint epoch when resuming (replaces the old
    # `continue`-based skip over already-completed epochs).
    for epoch in range(start_epoch, args.epochs):
        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train the model (+ validate the model)
        tloss_G, tloss_D, tdist = train(generator, discriminator, train_loader,
                                        criterion, optimizer_G, optimizer_D,
                                        args.clipping, args.num_critic,
                                        step_counter, validate_noise,
                                        FILE_NAME_FORMAT)
        train_loss_G.extend(tloss_G)
        train_loss_D.extend(tloss_D)
        train_distance.extend(tdist)
        #=======================================================================
        current = time.time()

        # Calculate average loss
        avg_loss_G = sum(tloss_G) / len(tloss_G)
        avg_loss_D = sum(tloss_D) / len(tloss_D)
        avg_distance = sum(tdist) / len(tdist)

        # Save the current result
        result_data['current_epoch'] = epoch
        result_data['train_loss_G'] = train_loss_G
        result_data['train_loss_D'] = train_loss_D
        result_data['train_distance'] = train_distance

        # Save result_data as pkl file
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data,
                        pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # Save the best checkpoint (currently disabled; BEST_CHECKPOINT_FILE_PATH
        # is kept for when this is re-enabled).
        # if avg_distance < best_metric:
        #     best_metric = avg_distance
        #     torch.save({
        #         'epoch': epoch+1,
        #         'generator_state_dict': generator.state_dict(),
        #         'discriminator_state_dict': discriminator.state_dict(),
        #         'optimizer_G_state_dict': optimizer_G.state_dict(),
        #         'optimizer_D_state_dict': optimizer_D.state_dict(),
        #         'current_step': step_counter.current_step,
        #         'best_metric': best_metric,
        #         }, BEST_CHECKPOINT_FILE_PATH)

        # Save the current checkpoint
        torch.save(
            {
                'epoch': epoch + 1,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'current_step': step_counter.current_step,
                'train_loss_G': train_loss_G,
                'train_loss_D': train_loss_D,
                'train_distance': train_distance,
                'best_metric': best_metric,
            }, CHECKPOINT_FILE_PATH)

        # Print the information on the console
        print("model                : {}".format(args.model))
        print("dataset              : {}".format(args.dataset))
        print("batch_size           : {}".format(args.batch_size))
        print("current step         : {:d}".format(step_counter.current_step))
        print("current lrate        : {:f}".format(args.lr))
        print("gen/disc loss        : {:f}/{:f}".format(
            avg_loss_G, avg_loss_D))
        print("distance metric      : {:f}".format(avg_distance))
        print("epoch time           : {0:.3f} sec".format(current -
                                                          epoch_time))
        # NOTE(review): `start` is a module-level timestamp defined elsewhere
        # in this file.
        print("Current elapsed time : {0:.3f} sec".format(current - start))

        # If iteration step has been satisfied
        if step_counter.exit_signal:
            break

    print('==> Train done.')

    print(' '.join(['Results have been saved at', RESULT_FILE_PATH]))
    print(' '.join(['Checkpoints have been saved at', CHECKPOINT_FILE_PATH]))
コード例 #4
0
    # reconstruction loss (cycle-consistency: A->B->A and B->A->B round trips)
    l_rec_ABA = criterionMSE(ABA, real_A)
    l_rec_BAB = criterionMSE(BAB, real_B)

    # Total generator objective: adversarial terms + MSE reconstruction terms.
    errGAN = l_BA + l_AB
    errMSE = l_rec_ABA + l_rec_BAB
    errG = errGAN + errMSE
    errG.backward()

    optimizerG.step()

    ###########   Logging   ############
    # NOTE(review): this condition is truthy for every iteration that is NOT a
    # multiple of opt.log_step — almost certainly meant
    # `iteration % opt.log_step == 0`. Also, `.data[0]` on 0-dim tensors was
    # removed in modern PyTorch; `.item()` is the current equivalent.
    if (iteration % opt.log_step):
        print('[%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_MSE: %.4f' %
              (iteration, opt.niter, errD.data[0], errGAN.data[0],
               errMSE.data[0]))
    ########## Visualize #########
    if (iteration % 1000 == 0):
        test(iteration)

    # Periodic checkpoints, tagged with the iteration number.
    if iteration % opt.save_step == 0:
        torch.save(G_AB.state_dict(),
                   '{}/G_AB_{}.pth'.format(opt.outf, iteration))
        torch.save(G_BA.state_dict(),
                   '{}/G_BA_{}.pth'.format(opt.outf, iteration))
        torch.save(D_A.state_dict(),
                   '{}/D_A_{}.pth'.format(opt.outf, iteration))
        torch.save(D_B.state_dict(),
                   '{}/D_B_{}.pth'.format(opt.outf, iteration))
コード例 #5
0
        # Discriminator loss on fakes; detach() blocks gradients into the
        # generator during the D update.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()

        # Total D loss (gradients from both passes already accumulated).
        errD = errD_fake + errD_real
        optimizerD.step()

        ########### fGx ###########
        # Generator update: maximize log(D(G(z))) by labeling fakes as real.
        netG.zero_grad()
        label.data.fill_(real_label)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        ########### Logging #########
        # NOTE(review): `.data[0]` on 0-dim tensors was removed in modern
        # PyTorch; `.item()` is the current equivalent.
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f ' %
              (epoch, opt.niter, i, len(loader), errD.data[0], errG.data[0]))

        ########## Visualize #########
        # Every 50 iterations: dump a sample grid and checkpoint both nets
        # (file names keyed by epoch, so later iterations overwrite).
        if (i % 50 == 0):
            vutils.save_image(fake.data,
                              '%s/fake_samples_epoch_%03d.png' %
                              (opt.outf, epoch),
                              normalize=True)

            torch.save(netG.state_dict(),
                       '%s/netG_epoch_%03d.pth' % (opt.outf, epoch))
            torch.save(netD.state_dict(),
                       '%s/netD_epoch_%03d.pth' % (opt.outf, epoch))
コード例 #6
0
def train(**kwargs):
    """Adversarially train a Faster R-CNN "generator" against a feature
    discriminator (Perceptual-GAN-style).

    Large-object ROI features serve as the discriminator's "real" batch and
    small-object (PCGAN-generated) ROI features as the "fake" batch; the
    generator branch is additionally trained with the detection losses.
    Evaluates on a small-object test set each epoch and checkpoints the
    best-mAP generator and discriminator.

    NOTE(review): depends on many module-level names defined elsewhere in
    this file (opt, DatasetAugmented, SmallImageTestDataset, data_,
    FasterRCNNVGG16_GAN, FasterRCNNVGG16, FasterRCNNTrainer, Discriminator,
    weights_init, at, adjust_learning_rate, clip_gradient, eval, save_map,
    tqdm, optim, nn, torch, time).
    """
    opt._parse(kwargs)

    # --- "Large"-object dataset: real samples for the discriminator -------
    id_file_dir = 'ImageSets/Main/trainval_big_64.txt'
    img_dir = 'JPEGImages'
    anno_dir = 'AnnotationsBig'
    large_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_large = data_.DataLoader(large_dataset, \
                                        batch_size=1, \
                                        shuffle=True, \
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    # --- "Small"-object (PCGAN-generated) dataset: fake samples ------------
    id_file_dir = 'ImageSets/Main/trainval_pcgan_generated_small.txt'
    img_dir = 'JPEGImagesPCGANGenerated'
    anno_dir = 'AnnotationsPCGANGenerated'

    small_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_small = data_.DataLoader(small_dataset, \
                                        batch_size=1, \
                                        shuffle=True, \
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    # Test set used for the per-epoch mAP evaluation below.
    small_test_dataset = SmallImageTestDataset(opt)
    dataloader_small_test = data_.DataLoader(small_test_dataset, \
                                             batch_size=1, \
                                             shuffle=True, \
                                             pin_memory=True,
                                             num_workers=opt.test_num_workers)

    print('{:d} roidb large entries'.format(len(dataloader_large)))
    print('{:d} roidb small entries'.format(len(dataloader_small)))
    print('{:d} roidb small test entries'.format(len(dataloader_small_test)))

    # GAN-augmented detector (the generator) and a plain detector used only
    # to donate pretrained weights.
    faster_rcnn = FasterRCNNVGG16_GAN()
    faster_rcnn_ = FasterRCNNVGG16()

    print('model construct completed')
    trainer_ = FasterRCNNTrainer(faster_rcnn_).cuda()

    netD = Discriminator()
    netD.apply(weights_init)

    faster_rcnn_.cuda()
    netD.cuda()

    # Per-parameter optimizer options for D: biases get doubled lr and no
    # weight decay (common Faster R-CNN convention).
    lr = opt.LEARNING_RATE
    params_D = []
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params_D += [{'params': [value], 'lr': lr * 2, \
                              'weight_decay': 0}]
            else:
                params_D += [{'params': [value], 'lr': lr, 'weight_decay': opt.weight_decay}]

    optimizerD = optim.SGD(params_D, momentum=0.9)
    # optimizerG = optim.Adam(faster_rcnn.parameters(), lr=lr, betas=(0.5, 0.999))

    if not opt.gan_load_path:
        # Bootstrap the GAN detector from a pretrained plain Faster R-CNN:
        # copy every overlapping weight into the GAN model's state dict.
        trainer_.load(opt.load_path)
        print('load pretrained faster rcnn model from %s' % opt.load_path)

        # optimizer_ = trainer_.optimizer
        state_dict_ = faster_rcnn_.state_dict()
        state_dict = faster_rcnn.state_dict()

        # for k, i in state_dict_.items():
        #     icpu = i.cpu()
        #     b = icpu.data.numpy()
        #     sz = icpu.data.numpy().shape
        #     state_dict[k] = state_dict_[k]
        state_dict.update(state_dict_)
        faster_rcnn.load_state_dict(state_dict)
        faster_rcnn.cuda()

    trainer = FasterRCNNTrainer(faster_rcnn).cuda()

    if opt.gan_load_path:
        # Resume the generator (and its optimizer) from a previous GAN run.
        trainer.load(opt.gan_load_path, load_optimizer=True)
        print('load pretrained generator model from %s' % opt.gan_load_path)

    if opt.disc_load_path:
        # Resume the discriminator and its optimizer state.
        state_dict_d = torch.load(opt.disc_load_path)
        netD.load_state_dict(state_dict_d['model'])
        optimizerD.load_state_dict(state_dict_d['optimizer'])
        print('load pretrained discriminator model from %s' % opt.disc_load_path)

    # BCE targets for the discriminator.
    real_label = 1
    fake_label = 0

    # rpn_loc_loss = []
    # rpn_cls_loss = []
    # roi_loc_loss = []
    # roi_cls_loss = []
    # total_loss = []
    test_map_list = []

    criterion = nn.BCELoss()
    # One "iteration" consumes a batch from each loader, so an epoch is
    # bounded by the shorter one.
    iters_per_epoch = min(len(dataloader_large), len(dataloader_small))
    best_map = 0
    # NOTE(review): labels are created on cuda:2 while the models use
    # .cuda() (default device) — confirm CUDA_VISIBLE_DEVICES makes these
    # refer to the same physical device, otherwise the BCE call would mix
    # devices.
    device = torch.device("cuda:2" if (torch.cuda.is_available()) else "cpu")

    for epoch in range(1, opt.gan_epoch + 1):
        trainer.reset_meters()

        loss_temp_G = 0
        loss_temp_D = 0
        # Step-decay both optimizers' learning rates.
        if epoch % (opt.lr_decay_step + 1) == 0:
            adjust_learning_rate(trainer.optimizer, opt.LEARNING_RATE_DECAY_GAMMA)
            adjust_learning_rate(optimizerD, opt.LEARNING_RATE_DECAY_GAMMA)
            lr *= opt.LEARNING_RATE_DECAY_GAMMA

        data_iter_large = iter(dataloader_large)
        data_iter_small = iter(dataloader_small)
        for step in tqdm(range(iters_per_epoch)):
            #####(1) Update Perceptual branch + generator(zero mapping)
            ####     Discriminator network: maximize log(D(x))+ log(1-D(G(z)))

            ##### Train with all_real batch
            ##### Format batch
            netD.zero_grad()
            data_large = next(data_iter_large)
            img, bbox_, label_, scale_ = data_large
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()

            ##### Forward pass real batch through D
            # faster_rcnn.zero_grad()
            # trainer.optimizer.zero_grad()
            # trainer.optimizer.zero_grad()

            losses, pooled_feat, rois_label, conv1_feat = trainer.train_step_gan(img, bbox, label, scale)

            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/large_orig_%s' % str(epoch))
            #     custom_viz(pooled_feat.cpu().detach(), 'results-gan/features/large_scaled_%s' % str(epoch))

            # Keep only foreground ROIs (label 0 is background).
            keep = rois_label != 0
            pooled_feat = pooled_feat[keep]

            real_b_size = pooled_feat.size(0)
            real_labels = torch.full((real_b_size,), real_label, device=device)

            # detach(): the D update must not backprop into the detector.
            output = netD(pooled_feat.detach()).view(-1)
            # print(output)

            ##### Calculate loss on all-real batch

            errD_real = criterion(output, real_labels)
            errD_real.backward()
            D_x = output.mean().item()

            ##### Train with all_fake batch
            # Generate batch of fake images with G
            data_small = next(data_iter_small)
            img, bbox_, label_, scale_ = data_small
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            trainer.optimizer.zero_grad()

            losses, fake_pooled_feat, rois_label, conv1_feat = trainer.train_step_gan_second(img, bbox, label, scale)

            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/small_orig_%s' % str(epoch))
            #     custom_viz(fake_pooled_feat.cpu().detach(), 'results-gan/features/small_scaled_%s' % str(epoch))

            # select fg rois
            keep = rois_label != 0
            fake_pooled_feat = fake_pooled_feat[keep]
            # print(fake_pooled_feat)
            # print(torch.nonzero(torch.isnan(fake_pooled_feat.view(-1))))

            fake_b_size = fake_pooled_feat.size(0)
            fake_labels = torch.full((fake_b_size,), fake_label, device=device)

            # optimizerD.zero_grad()
            output = netD(fake_pooled_feat.detach()).view(-1)

            # calculate D's loss on the all_fake batch
            errD_fake = criterion(output, fake_labels)
            errD_fake.backward(retain_graph=True)
            D_G_Z1 = output.mean().item()
            # add the gradients from the all-real and all-fake batches
            errD = errD_fake + errD_real
            # Update D
            optimizerD.step()

            ################################################
            #####(2) Update G network: maximize log(D(G(z)))
            ################################################
            faster_rcnn.zero_grad()

            # Reuse the label buffer: the generator is trained against
            # "real" targets.
            fake_labels.fill_(real_label)

            # No detach here: gradients must flow through D into the
            # generator's feature branch.
            output = netD(fake_pooled_feat).view(-1)

            # calculate gradients for G
            errG = criterion(output, fake_labels)
            # Add the detection losses so the generator remains a good detector.
            errG += losses.total_loss
            errG.backward()
            D_G_Z2 = output.mean().item()

            clip_gradient(faster_rcnn, 10.)

            trainer.optimizer.step()

            loss_temp_G += errG.item()
            loss_temp_D += errD.item()

            if step % opt.plot_every == 0:
                if step > 0:
                    loss_temp_G /= (opt.plot_every + 1)
                    loss_temp_D /= (opt.plot_every + 1)

                # losses_dict = trainer.get_meter_data()
                #
                # rpn_loc_loss.append(losses_dict['rpn_loc_loss'])
                # roi_loc_loss.append(losses_dict['roi_loc_loss'])
                # rpn_cls_loss.append(losses_dict['rpn_cls_loss'])
                # roi_cls_loss.append(losses_dict['roi_cls_loss'])
                # total_loss.append(losses_dict['total_loss'])
                #
                # save_losses('rpn_loc_loss', rpn_loc_loss, epoch)
                # save_losses('roi_loc_loss', roi_loc_loss, epoch)
                # save_losses('rpn_cls_loss', rpn_cls_loss, epoch)
                # save_losses('total_loss', total_loss, epoch)
                # save_losses('roi_cls_loss', roi_cls_loss, epoch)

                print("[epoch %2d] lossG: %.4f lossD: %.4f, lr: %.2e"
                      % (epoch, loss_temp_G, loss_temp_D, lr))
                print("\t\t\trcnn_cls: %.4f, rcnn_box %.4f"
                      % (losses.roi_cls_loss, losses.roi_loc_loss))

                print("\t\t\trpn_cls: %.4f, rpn_box %.4f"
                      % (losses.rpn_cls_loss, losses.rpn_loc_loss))

                print('\t\t\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (D_x, D_G_Z1, D_G_Z2))
                loss_temp_D = 0
                loss_temp_G = 0

        # End-of-epoch evaluation on the small-object test set.
        # NOTE(review): `eval` here shadows the builtin — presumably a
        # project-level evaluation helper imported elsewhere in this file.
        eval_result = eval(dataloader_small_test, faster_rcnn, test_num=opt.test_num)
        test_map_list.append(eval_result['map'])
        save_map(test_map_list, epoch)

        lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
        log_info = 'lr:{}, map:{}'.format(str(lr_),
                                                  str(eval_result['map']))
        print(log_info)

        # Checkpoint generator + discriminator whenever test mAP improves.
        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            timestr = time.strftime('%m%d%H%M')
            trainer.save(best_map=best_map, save_path='checkpoints-pcgan-generated/gan_fasterrcnn_%s' % timestr)

            save_dict = dict()

            save_dict['model'] = netD.state_dict()

            save_dict['optimizer'] = optimizerD.state_dict()
            save_path = 'checkpoints-pcgan-generated/discriminator_%s' % timestr
            torch.save(save_dict, save_path)