Example #1
import os
from glob import glob

import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision.utils import save_image

# Project-local modules are assumed to provide the ESRGAN generator, the
# Discriminator, and the VGG-based PerceptualLoss used below.
class Trainer:
    def __init__(self, config, data_loader):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.num_epoch = config.num_epoch
        self.epoch = config.epoch  # starting epoch (non-zero when resuming from a checkpoint)
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.checkpoint_dir = config.checkpoint_dir
        self.batch_size = config.batch_size
        self.sample_dir = config.sample_dir
        self.nf = config.nf
        self.scale_factor = config.scale_factor

        # Each training stage uses its own learning rate, loss weights, and
        # LR-decay interval.
        if config.is_perceptual_oriented:
            self.lr = config.p_lr
            self.content_loss_factor = config.p_content_loss_factor
            self.perceptual_loss_factor = config.p_perceptual_loss_factor
            self.adversarial_loss_factor = config.p_adversarial_loss_factor
            self.decay_batch_size = config.p_decay_batch_size
        else:
            self.lr = config.g_lr
            self.content_loss_factor = config.g_content_loss_factor
            self.perceptual_loss_factor = config.g_perceptual_loss_factor
            self.adversarial_loss_factor = config.g_adversarial_loss_factor
            self.decay_batch_size = config.g_decay_batch_size

        self.build_model()
        self.optimizer_generator = Adam(self.generator.parameters(), lr=self.lr, betas=(config.b1, config.b2),
                                        weight_decay=config.weight_decay)
        self.optimizer_discriminator = Adam(self.discriminator.parameters(), lr=self.lr, betas=(config.b1, config.b2),
                                            weight_decay=config.weight_decay)

        self.lr_scheduler_generator = torch.optim.lr_scheduler.StepLR(
            self.optimizer_generator, step_size=self.decay_batch_size)
        self.lr_scheduler_discriminator = torch.optim.lr_scheduler.StepLR(
            self.optimizer_discriminator, step_size=self.decay_batch_size)
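        # Note: train() steps both schedulers once per batch, so
        # decay_batch_size counts iterations, not epochs (StepLR's default
        # gamma of 0.1 applies).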

    def train(self):
        total_step = len(self.data_loader)
        adversarial_criterion = nn.BCEWithLogitsLoss().to(self.device)
        content_criterion = nn.L1Loss().to(self.device)
        perception_criterion = PerceptualLoss().to(self.device)
        self.generator.train()
        self.discriminator.train()

        for epoch in range(self.epoch, self.num_epoch):
            os.makedirs(os.path.join(self.sample_dir, str(epoch)), exist_ok=True)

            for step, image in enumerate(self.data_loader):
                low_resolution = image['lr'].to(self.device)
                high_resolution = image['hr'].to(self.device)

                real_labels = torch.ones((high_resolution.size(0), 1)).to(self.device)
                fake_labels = torch.zeros((high_resolution.size(0), 1)).to(self.device)

                ##########################
                #   training generator   #
                ##########################
                self.optimizer_generator.zero_grad()
                fake_high_resolution = self.generator(low_resolution)

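                # Relativistic average GAN (RaGAN): each batch is scored
                # relative to the mean score of the opposite batch.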
                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution)
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(discriminator_rf, fake_labels)
                adversarial_loss_fr = adversarial_criterion(discriminator_fr, real_labels)
                adversarial_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2

                perceptual_loss = perception_criterion(high_resolution, fake_high_resolution)
                content_loss = content_criterion(fake_high_resolution, high_resolution)

                generator_loss = adversarial_loss * self.adversarial_loss_factor + \
                                 perceptual_loss * self.perceptual_loss_factor + \
                                 content_loss * self.content_loss_factor

                generator_loss.backward()
                self.optimizer_generator.step()

                ##########################
                # training discriminator #
                ##########################

                self.optimizer_discriminator.zero_grad()

                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution.detach())
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(discriminator_rf, real_labels)
                adversarial_loss_fr = adversarial_criterion(discriminator_fr, fake_labels)
                discriminator_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2

                discriminator_loss.backward()
                self.optimizer_discriminator.step()

                self.lr_scheduler_generator.step()
                self.lr_scheduler_discriminator.step()
                if step % 1000 == 0:
                    print(f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                          f"[D loss {discriminator_loss.item():.4f}] [G loss {generator_loss.item():.4f}] "
                          f"[adversarial loss {adversarial_loss.item() * self.adversarial_loss_factor:.4f}]"
                          f"[perceptual loss {perceptual_loss.item() * self.perceptual_loss_factor:.4f}]"
                          f"[content loss {content_loss.item() * self.content_loss_factor:.4f}]"
                          f"")
                    if step % 5000 == 0:
                        result = torch.cat((high_resolution, fake_high_resolution), dim=2)  # stack HR above SR along the height axis
                        save_image(result, os.path.join(self.sample_dir, str(epoch), f"SR_{step}.png"))

            torch.save(self.generator.state_dict(), os.path.join(self.checkpoint_dir, f"generator_{epoch}.pth"))
            torch.save(self.discriminator.state_dict(), os.path.join(self.checkpoint_dir, f"discriminator_{epoch}.pth"))

    def build_model(self):
        self.generator = ESRGAN(3, 3, 64, scale_factor=self.scale_factor).to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.load_model()

    def load_model(self):
        print(f"[*] Load model from {self.checkpoint_dir}")
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        if not os.listdir(self.checkpoint_dir):
            print(f"[!] No checkpoint in {self.checkpoint_dir}")
            return

        generator = glob(os.path.join(self.checkpoint_dir, f'generator_{self.epoch - 1}.pth'))
        discriminator = glob(os.path.join(self.checkpoint_dir, f'discriminator_{self.epoch - 1}.pth'))

        if not generator:
            print(f"[!] No checkpoint in epoch {self.epoch - 1}")
            return

        self.generator.load_state_dict(torch.load(generator[0], map_location=self.device))
        self.discriminator.load_state_dict(torch.load(discriminator[0], map_location=self.device))
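
# A standalone sketch of the relativistic average (RaGAN) loss used by the
# trainer above. It needs only torch; shapes and values are illustrative.
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
score_real = torch.randn(4, 1)   # discriminator logits on real images
score_fake = torch.randn(4, 1)   # discriminator logits on generated images
real_labels, fake_labels = torch.ones(4, 1), torch.zeros(4, 1)

# Each batch is scored relative to the mean score of the opposite batch.
d_rf = score_real - score_fake.mean()
d_fr = score_fake - score_real.mean()

# Discriminator: real-relative scores toward 1, fake-relative toward 0.
d_loss = (criterion(d_rf, real_labels) + criterion(d_fr, fake_labels)) / 2
# Generator: the same terms with flipped targets.
g_loss = (criterion(d_rf, fake_labels) + criterion(d_fr, real_labels)) / 2
print(f"D loss {d_loss.item():.4f}, G loss {g_loss.item():.4f}")
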
Example #2
import os
import pickle
import random
import time

import torch
import torch.nn as nn

# Project-local names assumed by main(): the RESULTS_PATH / CHECKPOINT_PATH
# constants, CelebA_Dataloader, Generator, Discriminator, StepCounter, the
# train() / val() helpers, and a module-level `start` timestamp.
def main(args):
    #===========================================================================
    # Set the file name format
    FILE_NAME_FORMAT = "{0}_{1}_{2:d}_{3:d}_{4:d}_{5:f}{6}".format(
        args.model, args.dataset, args.epochs, args.obj_step, args.batch_size,
        args.lr, args.flag)

    # Set the results file path
    RESULT_FILE_NAME = FILE_NAME_FORMAT + '_results.pkl'
    RESULT_FILE_PATH = os.path.join(RESULTS_PATH, RESULT_FILE_NAME)
    # Set the checkpoint file path
    CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '.ckpt'
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, CHECKPOINT_FILE_NAME)
    BEST_CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '_best.ckpt'
    BEST_CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH,
                                             BEST_CHECKPOINT_FILE_NAME)

    # Fix the random seeds and cuDNN flags for reproducibility
    random.seed(190811)
    torch.manual_seed(190811)
    torch.cuda.manual_seed_all(190811)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Step1 ====================================================================
    # Load dataset
    if args.dataset == 'CelebA':
        dataloader = CelebA_Dataloader()
    else:
        assert False, "Please select the proper dataset."

    train_loader = dataloader.get_train_loader(batch_size=args.batch_size,
                                               num_workers=args.num_workers)
    print('==> DataLoader ready.')

    # Step2 ====================================================================
    # Make the model
    if args.model in ['WGAN', 'DCGAN']:
        generator = Generator(BN=True)
        discriminator = Discriminator(BN=True)
    elif args.model in ['WGAN_noBN', 'DCGAN_noBN']:
        generator = Generator(BN=False)
        discriminator = Discriminator(BN=False)
    else:
        assert False, "Please select the proper model."

    # Check DataParallel available
    if torch.cuda.device_count() > 1:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    # Check CUDA available
    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()
    print('==> Model ready.')

    # Step3 ====================================================================
    # Set loss function and optimizer
    if args.model in ['DCGAN', 'DCGAN_noBN']:
        criterion = nn.BCELoss()
    else:
        criterion = None
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=args.lr)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=args.lr)
    step_counter = StepCounter(args.obj_step)
    print('==> Criterion and optimizer ready.')

    # Step4 ====================================================================
    # Train and validate the model
    start_epoch = 0
    best_metric = float("inf")
    validate_noise = torch.randn(args.batch_size, 100, 1, 1)
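    # Fixed noise, reused at every validation call so generated samples are
    # directly comparable across training.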

    # Initialize the result lists
    train_loss_G = []
    train_loss_D = []
    train_distance = []

    if args.resume:
        assert os.path.exists(CHECKPOINT_FILE_PATH), 'No checkpoint file!'
        checkpoint = torch.load(CHECKPOINT_FILE_PATH)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        start_epoch = checkpoint['epoch']
        step_counter.current_step = checkpoint['current_step']
        train_loss_G = checkpoint['train_loss_G']
        train_loss_D = checkpoint['train_loss_D']
        train_distance = checkpoint['train_distance']
        best_metric = checkpoint['best_metric']

    # Save the training information
    result_data = {}
    result_data['model'] = args.model
    result_data['dataset'] = args.dataset
    result_data['target_epoch'] = args.epochs
    result_data['batch_size'] = args.batch_size

    # Check the directory of the file path
    if not os.path.exists(os.path.dirname(RESULT_FILE_PATH)):
        os.makedirs(os.path.dirname(RESULT_FILE_PATH))
    if not os.path.exists(os.path.dirname(CHECKPOINT_FILE_PATH)):
        os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH))

    print('==> Train ready.')

    # Validate before training (step 0)
    val(generator, validate_noise, step_counter, FILE_NAME_FORMAT)

    for epoch in range(args.epochs):
        # skip epochs already completed before the checkpoint
        if epoch < start_epoch:
            continue
        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train the model (+ validate the model)
        tloss_G, tloss_D, tdist = train(generator, discriminator, train_loader,
                                        criterion, optimizer_G, optimizer_D,
                                        args.clipping, args.num_critic,
                                        step_counter, validate_noise,
                                        FILE_NAME_FORMAT)
        train_loss_G.extend(tloss_G)
        train_loss_D.extend(tloss_D)
        train_distance.extend(tdist)
        #=======================================================================
        current = time.time()

        # Calculate average loss
        avg_loss_G = sum(tloss_G) / len(tloss_G)
        avg_loss_D = sum(tloss_D) / len(tloss_D)
        avg_distance = sum(tdist) / len(tdist)

        # Save the current result
        result_data['current_epoch'] = epoch
        result_data['train_loss_G'] = train_loss_G
        result_data['train_loss_D'] = train_loss_D
        result_data['train_distance'] = train_distance

        # Save result_data as pkl file
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data,
                        pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # Save the best checkpoint
        # if avg_distance < best_metric:
        #     best_metric = avg_distance
        #     torch.save({
        #         'epoch': epoch+1,
        #         'generator_state_dict': generator.state_dict(),
        #         'discriminator_state_dict': discriminator.state_dict(),
        #         'optimizer_G_state_dict': optimizer_G.state_dict(),
        #         'optimizer_D_state_dict': optimizer_D.state_dict(),
        #         'current_step': step_counter.current_step,
        #         'best_metric': best_metric,
        #         }, BEST_CHECKPOINT_FILE_PATH)

        # Save the current checkpoint
        torch.save(
            {
                'epoch': epoch + 1,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'current_step': step_counter.current_step,
                'train_loss_G': train_loss_G,
                'train_loss_D': train_loss_D,
                'train_distance': train_distance,
                'best_metric': best_metric,
            }, CHECKPOINT_FILE_PATH)

        # Print the information on the console
        print("model                : {}".format(args.model))
        print("dataset              : {}".format(args.dataset))
        print("batch_size           : {}".format(args.batch_size))
        print("current step         : {:d}".format(step_counter.current_step))
        print("current lrate        : {:f}".format(args.lr))
        print("gen/disc loss        : {:f}/{:f}".format(
            avg_loss_G, avg_loss_D))
        print("distance metric      : {:f}".format(avg_distance))
        print("epoch time           : {0:.3f} sec".format(current -
                                                          epoch_time))
        print("Current elapsed time : {0:.3f} sec".format(current - start))

        # If iteration step has been satisfied
        if step_counter.exit_signal:
            break

    print('==> Train done.')

    print(' '.join(['Results have been saved at', RESULT_FILE_PATH]))
    print(' '.join(['Checkpoints have been saved at', CHECKPOINT_FILE_PATH]))
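
# A minimal, self-contained sketch of the save/resume round trip main() relies
# on; the tiny module and the file name are placeholders.
import torch
import torch.nn as nn

generator = nn.Linear(4, 4)
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=5e-5)

torch.save({'epoch': 1,
            'generator_state_dict': generator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict()},
           'demo.ckpt')

checkpoint = torch.load('demo.ckpt')
generator.load_state_dict(checkpoint['generator_state_dict'])
optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
start_epoch = checkpoint['epoch']  # training resumes from this epoch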

# Faster R-CNN + GAN fine-tuning (assumed project-local imports: opt,
# DatasetAugmented, SmallImageTestDataset, FasterRCNNVGG16,
# FasterRCNNVGG16_GAN, FasterRCNNTrainer, Discriminator, weights_init,
# adjust_learning_rate, clip_gradient, eval, save_map, plus
# torch.utils.data as data_, torch.optim as optim, tqdm, and the `at`
# array-tool helpers).
def train(**kwargs):
    opt._parse(kwargs)

    id_file_dir = 'ImageSets/Main/trainval_big_64.txt'
    img_dir = 'JPEGImages'
    anno_dir = 'AnnotationsBig'
    large_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_large = data_.DataLoader(large_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    id_file_dir = 'ImageSets/Main/trainval_pcgan_generated_small.txt'
    img_dir = 'JPEGImagesPCGANGenerated'
    anno_dir = 'AnnotationsPCGANGenerated'

    small_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_small = data_.DataLoader(small_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    small_test_dataset = SmallImageTestDataset(opt)
    dataloader_small_test = data_.DataLoader(small_test_dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=opt.test_num_workers)

    print('{:d} roidb large entries'.format(len(dataloader_large)))
    print('{:d} roidb small entries'.format(len(dataloader_small)))
    print('{:d} roidb small test entries'.format(len(dataloader_small_test)))

    faster_rcnn = FasterRCNNVGG16_GAN()
    faster_rcnn_ = FasterRCNNVGG16()

    print('model construct completed')
    trainer_ = FasterRCNNTrainer(faster_rcnn_).cuda()

    netD = Discriminator()
    netD.apply(weights_init)

    faster_rcnn_.cuda()
    netD.cuda()

    lr = opt.LEARNING_RATE
    params_D = []
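    # Caffe-style parameter groups: biases get twice the base learning rate
    # and no weight decay (a common Faster R-CNN convention).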
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params_D += [{'params': [value], 'lr': lr * 2,
                              'weight_decay': 0}]
            else:
                params_D += [{'params': [value], 'lr': lr, 'weight_decay': opt.weight_decay}]

    optimizerD = optim.SGD(params_D, momentum=0.9)
    # optimizerG = optim.Adam(faster_rcnn.parameters(), lr=lr, betas=(0.5, 0.999))

    if not opt.gan_load_path:
        trainer_.load(opt.load_path)
        print('load pretrained faster rcnn model from %s' % opt.load_path)

        # optimizer_ = trainer_.optimizer
        state_dict_ = faster_rcnn_.state_dict()
        state_dict = faster_rcnn.state_dict()

        # for k, i in state_dict_.items():
        #     icpu = i.cpu()
        #     b = icpu.data.numpy()
        #     sz = icpu.data.numpy().shape
        #     state_dict[k] = state_dict_[k]
        state_dict.update(state_dict_)
        faster_rcnn.load_state_dict(state_dict)
        faster_rcnn.cuda()

    trainer = FasterRCNNTrainer(faster_rcnn).cuda()

    if opt.gan_load_path:
        trainer.load(opt.gan_load_path, load_optimizer=True)
        print('load pretrained generator model from %s' % opt.gan_load_path)

    if opt.disc_load_path:
        state_dict_d = torch.load(opt.disc_load_path)
        netD.load_state_dict(state_dict_d['model'])
        optimizerD.load_state_dict(state_dict_d['optimizer'])
        print('load pretrained discriminator model from %s' % opt.disc_load_path)

    real_label = 1.  # float labels: torch.full and BCELoss expect float targets
    fake_label = 0.

    # rpn_loc_loss = []
    # rpn_cls_loss = []
    # roi_loc_loss = []
    # roi_cls_loss = []
    # total_loss = []
    test_map_list = []

    criterion = nn.BCELoss()
    iters_per_epoch = min(len(dataloader_large), len(dataloader_small))
    best_map = 0
    device = torch.device("cuda:2" if (torch.cuda.is_available()) else "cpu")

    for epoch in range(1, opt.gan_epoch + 1):
        trainer.reset_meters()

        loss_temp_G = 0
        loss_temp_D = 0
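        # Decay both learning rates every (lr_decay_step + 1) epochs.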
        if epoch % (opt.lr_decay_step + 1) == 0:
            adjust_learning_rate(trainer.optimizer, opt.LEARNING_RATE_DECAY_GAMMA)
            adjust_learning_rate(optimizerD, opt.LEARNING_RATE_DECAY_GAMMA)
            lr *= opt.LEARNING_RATE_DECAY_GAMMA

        data_iter_large = iter(dataloader_large)
        data_iter_small = iter(dataloader_small)
        for step in tqdm(range(iters_per_epoch)):
            ##### (1) Update the discriminator network:
            #####     maximize log(D(x)) + log(1 - D(G(z)))

            ##### Train with all_real batch
            ##### Format batch
            netD.zero_grad()
            data_large = next(data_iter_large)
            img, bbox_, label_, scale_ = data_large
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()

            ##### Forward pass real batch through D
            # faster_rcnn.zero_grad()
            # trainer.optimizer.zero_grad()
            # trainer.optimizer.zero_grad()

            losses, pooled_feat, rois_label, conv1_feat = trainer.train_step_gan(img, bbox, label, scale)

            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/large_orig_%s' % str(epoch))
            #     custom_viz(pooled_feat.cpu().detach(), 'results-gan/features/large_scaled_%s' % str(epoch))

            keep = rois_label != 0
            pooled_feat = pooled_feat[keep]

            real_b_size = pooled_feat.size(0)
            real_labels = torch.full((real_b_size,), real_label, device=device)

            output = netD(pooled_feat.detach()).view(-1)
            # print(output)

            ##### Calculate loss on all-real batch

            errD_real = criterion(output, real_labels)
            errD_real.backward()
            D_x = output.mean().item()

            ##### Train with all_fake batch
            # Generate batch of fake images with G
            data_small = next(data_iter_small)
            img, bbox_, label_, scale_ = data_small
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            trainer.optimizer.zero_grad()

            losses, fake_pooled_feat, rois_label, conv1_feat = trainer.train_step_gan_second(img, bbox, label, scale)

            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/small_orig_%s' % str(epoch))
            #     custom_viz(fake_pooled_feat.cpu().detach(), 'results-gan/features/small_scaled_%s' % str(epoch))

            # select fg rois
            keep = rois_label != 0
            fake_pooled_feat = fake_pooled_feat[keep]
            # print(fake_pooled_feat)
            # print(torch.nonzero(torch.isnan(fake_pooled_feat.view(-1))))

            fake_b_size = fake_pooled_feat.size(0)
            fake_labels = torch.full((fake_b_size,), fake_label, device=device)

            # optimizerD.zero_grad()
            output = netD(fake_pooled_feat.detach()).view(-1)

            # calculate D's loss on the all_fake batch
            errD_fake = criterion(output, fake_labels)
            errD_fake.backward()  # the features were detached, so no graph needs retaining
            D_G_Z1 = output.mean().item()
            # add the gradients from the all-real and all-fake batches
            errD = errD_fake + errD_real
            # Update D
            optimizerD.step()

            ################################################
            #####(2) Update G network: maximize log(D(G(z)))
            ################################################
            faster_rcnn.zero_grad()

            fake_labels.fill_(real_label)

            output = netD(fake_pooled_feat).view(-1)

            # calculate gradients for G
            errG = criterion(output, fake_labels)
            errG += losses.total_loss
            errG.backward()
            D_G_Z2 = output.mean().item()

            clip_gradient(faster_rcnn, 10.)

            trainer.optimizer.step()

            loss_temp_G += errG.item()
            loss_temp_D += errD.item()

            if step % opt.plot_every == 0:
                if step > 0:
                    loss_temp_G /= (opt.plot_every + 1)
                    loss_temp_D /= (opt.plot_every + 1)

                # losses_dict = trainer.get_meter_data()
                #
                # rpn_loc_loss.append(losses_dict['rpn_loc_loss'])
                # roi_loc_loss.append(losses_dict['roi_loc_loss'])
                # rpn_cls_loss.append(losses_dict['rpn_cls_loss'])
                # roi_cls_loss.append(losses_dict['roi_cls_loss'])
                # total_loss.append(losses_dict['total_loss'])
                #
                # save_losses('rpn_loc_loss', rpn_loc_loss, epoch)
                # save_losses('roi_loc_loss', roi_loc_loss, epoch)
                # save_losses('rpn_cls_loss', rpn_cls_loss, epoch)
                # save_losses('total_loss', total_loss, epoch)
                # save_losses('roi_cls_loss', roi_cls_loss, epoch)

                print("[epoch %2d] lossG: %.4f lossD: %.4f, lr: %.2e"
                      % (epoch, loss_temp_G, loss_temp_D, lr))
                print("\t\t\trcnn_cls: %.4f, rcnn_box %.4f"
                      % (losses.roi_cls_loss, losses.roi_loc_loss))

                print("\t\t\trpn_cls: %.4f, rpn_box %.4f"
                      % (losses.rpn_cls_loss, losses.rpn_loc_loss))

                print('\t\t\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (D_x, D_G_Z1, D_G_Z2))
                loss_temp_D = 0
                loss_temp_G = 0

        eval_result = eval(dataloader_small_test, faster_rcnn, test_num=opt.test_num)
        test_map_list.append(eval_result['map'])
        save_map(test_map_list, epoch)

        lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
        log_info = 'lr:{}, map:{}'.format(str(lr_), str(eval_result['map']))
        print(log_info)

        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            timestr = time.strftime('%m%d%H%M')
            trainer.save(best_map=best_map, save_path='checkpoints-pcgan-generated/gan_fasterrcnn_%s' % timestr)

            save_dict = dict()

            save_dict['model'] = netD.state_dict()

            save_dict['optimizer'] = optimizerD.state_dict()
            save_path = 'checkpoints-pcgan-generated/discriminator_%s' % timestr
            torch.save(save_dict, save_path)
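
# A minimal sketch of the two-phase D/G update pattern used in the loop above
# (detach for the discriminator pass, no detach for the generator pass); the
# modules and data are toy placeholders.
import torch
import torch.nn as nn

G = nn.Linear(8, 8)                               # stand-in generator branch
D = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())  # stand-in discriminator
opt_G = torch.optim.SGD(G.parameters(), lr=0.01)
opt_D = torch.optim.SGD(D.parameters(), lr=0.01)
criterion = nn.BCELoss()

real = torch.randn(16, 8)
fake = G(torch.randn(16, 8))

# (1) Update D: real batch toward 1, detached fake batch toward 0.
opt_D.zero_grad()
errD = criterion(D(real).view(-1), torch.ones(16)) + \
       criterion(D(fake.detach()).view(-1), torch.zeros(16))
errD.backward()
opt_D.step()

# (2) Update G: fake batch toward 1; gradients now flow through G.
opt_G.zero_grad()
errG = criterion(D(fake).view(-1), torch.ones(16))
errG.backward()
opt_G.step()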