Code example #1
import torch
import torch.nn as nn
from itertools import chain

# Discriminator, Generator and weights_init are defined elsewhere in this project
ngf = 64   # base number of generator feature maps
ndf = 64   # base number of discriminator feature maps (assumed symmetric with ngf)
input_nc = 3
output_nc = 3
fineSize = 128
batchSize = 1
cuda = 1

D_A = Discriminator(input_nc,ndf)
D_B = Discriminator(output_nc,ndf)
G_AB = Generator(input_nc, output_nc, ngf)
G_BA = Generator(output_nc, input_nc, ngf)

G_AB.apply(weights_init)
G_BA.apply(weights_init)

D_A.apply(weights_init)
D_B.apply(weights_init)

if cuda:
    D_A.cuda()
    D_B.cuda()
    G_AB.cuda()
    G_BA.cuda()

###########   LOSS & OPTIMIZER   ##########
criterionMSE = nn.L1Loss()    # reconstruction / cycle-consistency loss
criterion = nn.MSELoss()      # least-squares adversarial (GAN) loss
# chain is used to update the two generators simultaneously with a single optimizer
optimizerD_A = torch.optim.Adam(D_A.parameters(),lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)
optimizerD_B = torch.optim.Adam(D_B.parameters(),lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)
optimizerG = torch.optim.Adam(chain(G_AB.parameters(),G_BA.parameters()),lr=0.0002, betas=(0.5, 0.999))
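
The snippet stops before the training loop, but the chained optimizer is what lets a single backward pass update both generators. A minimal sketch of a CycleGAN-style generator step, assuming real_A and real_B are image batches from the two domains and lambda_cyc is a hypothetical cycle-loss weight (none of these names appear in the original snippet):

# Sketch of one generator update (real_A, real_B, lambda_cyc are assumed names).
optimizerG.zero_grad()

fake_B = G_AB(real_A)
fake_A = G_BA(real_B)

# Adversarial terms: each generator tries to fool the opposite-domain discriminator.
pred_fake_B = D_B(fake_B)
pred_fake_A = D_A(fake_A)
loss_gan = (criterion(pred_fake_B, torch.ones_like(pred_fake_B))
            + criterion(pred_fake_A, torch.ones_like(pred_fake_A)))

# Cycle-consistency terms: A -> B -> A and B -> A -> B should reconstruct the inputs.
loss_cyc = criterionMSE(G_BA(fake_B), real_A) + criterionMSE(G_AB(fake_A), real_B)

loss_G = loss_gan + lambda_cyc * loss_cyc
loss_G.backward()
optimizerG.step()   # a single step updates the parameters of both G_AB and G_BA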
Code example #2
import torch
import torch.nn as nn

# Standard DCGAN-style initialisation, applied below via net.apply(weights_init):
# conv weights ~ N(0, 0.02), batch-norm weights ~ N(1, 0.02), batch-norm biases = 0.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


ndf = opt.ndf
ngf = opt.ngf
nc = 3

netD = Discriminator(opt.input_nc, opt.output_nc, ndf)
netG = Generator(opt.input_nc, opt.output_nc, opt.ngf)
if opt.cuda:
    netD.cuda()
    netG.cuda()

netG.apply(weights_init)
netD.apply(weights_init)
print(netD)
print(netG)

###########   LOSS & OPTIMIZER   ##########
criterion = nn.BCELoss()
criterionL1 = nn.L1Loss()
optimizerD = torch.optim.Adam(netD.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))

###########   GLOBAL VARIABLES   ###########
input_nc = opt.input_nc
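
Given the BCE adversarial loss and the L1 term above, a pix2pix-style iteration could look like the sketch below. It assumes real_A (input) and real_B (target) batches, an L1 weight lambda_A, and that netD scores the input and output concatenated along the channel dimension; these names are assumptions, not part of the original snippet.

# Sketch of one iteration (real_A, real_B, lambda_A are assumed names).
fake_B = netG(real_A)

# --- discriminator: real (input, target) pairs vs. generated pairs ---
optimizerD.zero_grad()
pred_real = netD(torch.cat((real_A, real_B), 1))
pred_fake = netD(torch.cat((real_A, fake_B.detach()), 1))   # detach: no gradient into G here
loss_D = 0.5 * (criterion(pred_real, torch.ones_like(pred_real))
                + criterion(pred_fake, torch.zeros_like(pred_fake)))
loss_D.backward()
optimizerD.step()

# --- generator: fool the discriminator and stay close to the target in L1 ---
optimizerG.zero_grad()
pred_fake = netD(torch.cat((real_A, fake_B), 1))
loss_G = criterion(pred_fake, torch.ones_like(pred_fake)) + lambda_A * criterionL1(fake_B, real_B)
loss_G.backward()
optimizerG.step()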
Code example #3
def train(**kwargs):
    opt._parse(kwargs)

    id_file_dir = 'ImageSets/Main/trainval_big_64.txt'
    img_dir = 'JPEGImages'
    anno_dir = 'AnnotationsBig'
    large_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_large = data_.DataLoader(large_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    id_file_dir = 'ImageSets/Main/trainval_pcgan_generated_small.txt'
    img_dir = 'JPEGImagesPCGANGenerated'
    anno_dir = 'AnnotationsPCGANGenerated'

    small_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_small = data_.DataLoader(small_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    small_test_dataset = SmallImageTestDataset(opt)
    dataloader_small_test = data_.DataLoader(small_test_dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=opt.test_num_workers)

    print('{:d} roidb large entries'.format(len(dataloader_large)))
    print('{:d} roidb small entries'.format(len(dataloader_small)))
    print('{:d} roidb small test entries'.format(len(dataloader_small_test)))

    faster_rcnn = FasterRCNNVGG16_GAN()
    faster_rcnn_ = FasterRCNNVGG16()

    print('model construct completed')
    trainer_ = FasterRCNNTrainer(faster_rcnn_).cuda()

    netD = Discriminator()
    netD.apply(weights_init)

    faster_rcnn_.cuda()
    netD.cuda()

    lr = opt.LEARNING_RATE
    params_D = []
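    # Faster R-CNN convention below: bias parameters get twice the learning rate and no weight decay.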
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params_D += [{'params': [value], 'lr': lr * 2, \
                              'weight_decay': 0}]
            else:
                params_D += [{'params': [value], 'lr': lr, 'weight_decay': opt.weight_decay}]

    optimizerD = optim.SGD(params_D, momentum=0.9)
    # optimizerG = optim.Adam(faster_rcnn.parameters(), lr=lr, betas=(0.5, 0.999))

    if not opt.gan_load_path:
        trainer_.load(opt.load_path)
        print('load pretrained faster rcnn model from %s' % opt.load_path)

        # optimizer_ = trainer_.optimizer
        state_dict_ = faster_rcnn_.state_dict()
        state_dict = faster_rcnn.state_dict()

        # for k, i in state_dict_.items():
        #     icpu = i.cpu()
        #     b = icpu.data.numpy()
        #     sz = icpu.data.numpy().shape
        #     state_dict[k] = state_dict_[k]
        state_dict.update(state_dict_)
        faster_rcnn.load_state_dict(state_dict)
        faster_rcnn.cuda()

    trainer = FasterRCNNTrainer(faster_rcnn).cuda()

    if opt.gan_load_path:
        trainer.load(opt.gan_load_path, load_optimizer=True)
        print('load pretrained generator model from %s' % opt.gan_load_path)

    if opt.disc_load_path:
        state_dict_d = torch.load(opt.disc_load_path)
        netD.load_state_dict(state_dict_d['model'])
        optimizerD.load_state_dict(state_dict_d['optimizer'])
        print('load pretrained discriminator model from %s' % opt.disc_load_path)

    real_label = 1.   # float labels so torch.full below builds float tensors for BCELoss
    fake_label = 0.

    # rpn_loc_loss = []
    # rpn_cls_loss = []
    # roi_loc_loss = []
    # roi_cls_loss = []
    # total_loss = []
    test_map_list = []

    criterion = nn.BCELoss()  # discriminator loss: large-object ROI features labelled real, generated small-object features fake
    iters_per_epoch = min(len(dataloader_large), len(dataloader_small))
    best_map = 0
    device = torch.device("cuda:2" if (torch.cuda.is_available()) else "cpu")

    for epoch in range(1, opt.gan_epoch + 1):
        trainer.reset_meters()

        loss_temp_G = 0
        loss_temp_D = 0
        if epoch % (opt.lr_decay_step + 1) == 0:
            adjust_learning_rate(trainer.optimizer, opt.LEARNING_RATE_DECAY_GAMMA)
            adjust_learning_rate(optimizerD, opt.LEARNING_RATE_DECAY_GAMMA)
            lr *= opt.LEARNING_RATE_DECAY_GAMMA

        data_iter_large = iter(dataloader_large)
        data_iter_small = iter(dataloader_small)
        for step in tqdm(range(iters_per_epoch)):
            ##### (1) Update Perceptual branch + generator (zero mapping)
            #####     Discriminator network: maximize log(D(x)) + log(1 - D(G(z)))

            ##### Train with all_real batch
            ##### Format batch
            netD.zero_grad()
            data_large = next(data_iter_large)
            img, bbox_, label_, scale_ = data_large
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()

            ##### Forward pass real batch through D
            # faster_rcnn.zero_grad()
            # trainer.optimizer.zero_grad()
            # trainer.optimizer.zero_grad()

            losses, pooled_feat, rois_label, conv1_feat = trainer.train_step_gan(img, bbox, label, scale)

            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/large_orig_%s' % str(epoch))
            #     custom_viz(pooled_feat.cpu().detach(), 'results-gan/features/large_scaled_%s' % str(epoch))

            keep = rois_label != 0
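            # rois_label != 0 marks foreground ROIs; only those are scored by the discriminator.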
            pooled_feat = pooled_feat[keep]

            real_b_size = pooled_feat.size(0)
            real_labels = torch.full((real_b_size,), real_label, device=device)

            output = netD(pooled_feat.detach()).view(-1)
            # print(output)

            ##### Calculate loss on all-real batch

            errD_real = criterion(output, real_labels)
            errD_real.backward()
            D_x = output.mean().item()

            ##### Train with all_fake batch
            # Generate batch of fake images with G
            data_small = next(data_iter_small)
            img, bbox_, label_, scale_ = data_small
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            trainer.optimizer.zero_grad()

            losses, fake_pooled_feat, rois_label, conv1_feat = trainer.train_step_gan_second(img, bbox, label, scale)

            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/small_orig_%s' % str(epoch))
            #     custom_viz(fake_pooled_feat.cpu().detach(), 'results-gan/features/small_scaled_%s' % str(epoch))

            # select fg rois
            keep = rois_label != 0
            fake_pooled_feat = fake_pooled_feat[keep]
            # print(fake_pooled_feat)
            # print(torch.nonzero(torch.isnan(fake_pooled_feat.view(-1))))

            fake_b_size = fake_pooled_feat.size(0)
            fake_labels = torch.full((fake_b_size,), fake_label, device=device)

            # optimizerD.zero_grad()
            output = netD(fake_pooled_feat.detach()).view(-1)

            # calculate D's loss on the all_fake batch
            errD_fake = criterion(output, fake_labels)
            errD_fake.backward(retain_graph=True)
            D_G_Z1 = output.mean().item()
            # add the gradients from the all-real and all-fake batches
            errD = errD_fake + errD_real
            # Update D
            optimizerD.step()

            ################################################
            #####(2) Update G network: maximize log(D(G(z)))
            ################################################
            faster_rcnn.zero_grad()

            fake_labels.fill_(real_label)  # for the generator step, generated features should be judged real

            output = netD(fake_pooled_feat).view(-1)

            # calculate gradients for G
            errG = criterion(output, fake_labels)
            errG += losses.total_loss
            errG.backward()
            D_G_Z2 = output.mean().item()

            clip_gradient(faster_rcnn, 10.)

            trainer.optimizer.step()

            loss_temp_G += errG.item()
            loss_temp_D += errD.item()

            if step % opt.plot_every == 0:
                if step > 0:
                    loss_temp_G /= (opt.plot_every + 1)
                    loss_temp_D /= (opt.plot_every + 1)

                # losses_dict = trainer.get_meter_data()
                #
                # rpn_loc_loss.append(losses_dict['rpn_loc_loss'])
                # roi_loc_loss.append(losses_dict['roi_loc_loss'])
                # rpn_cls_loss.append(losses_dict['rpn_cls_loss'])
                # roi_cls_loss.append(losses_dict['roi_cls_loss'])
                # total_loss.append(losses_dict['total_loss'])
                #
                # save_losses('rpn_loc_loss', rpn_loc_loss, epoch)
                # save_losses('roi_loc_loss', roi_loc_loss, epoch)
                # save_losses('rpn_cls_loss', rpn_cls_loss, epoch)
                # save_losses('total_loss', total_loss, epoch)
                # save_losses('roi_cls_loss', roi_cls_loss, epoch)

                print("[epoch %2d] lossG: %.4f lossD: %.4f, lr: %.2e"
                      % (epoch, loss_temp_G, loss_temp_D, lr))
                print("\t\t\trcnn_cls: %.4f, rcnn_box %.4f"
                      % (losses.roi_cls_loss, losses.roi_loc_loss))

                print("\t\t\trpn_cls: %.4f, rpn_box %.4f"
                      % (losses.rpn_cls_loss, losses.rpn_loc_loss))

                print('\t\t\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (D_x, D_G_Z1, D_G_Z2))
                loss_temp_D = 0
                loss_temp_G = 0

        eval_result = eval(dataloader_small_test, faster_rcnn, test_num=opt.test_num)
        test_map_list.append(eval_result['map'])
        save_map(test_map_list, epoch)

        lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
        log_info = 'lr:{}, map:{}'.format(str(lr_), str(eval_result['map']))
        print(log_info)

        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            timestr = time.strftime('%m%d%H%M')
            trainer.save(best_map=best_map, save_path='checkpoints-pcgan-generated/gan_fasterrcnn_%s' % timestr)

            save_dict = dict()

            save_dict['model'] = netD.state_dict()

            save_dict['optimizer'] = optimizerD.state_dict()
            save_path = 'checkpoints-pcgan-generated/discriminator_%s' % timestr
            torch.save(save_dict, save_path)
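
Code example #3 calls two helpers, adjust_learning_rate and clip_gradient, that are not shown in the snippet. A minimal sketch of what such helpers could look like (assumed implementations, not the project's own definitions):

import torch

def adjust_learning_rate(optimizer, decay):
    # scale the learning rate of every parameter group by the decay factor
    for param_group in optimizer.param_groups:
        param_group['lr'] *= decay

def clip_gradient(model, clip_norm):
    # cap the total gradient norm of all trainable parameters
    torch.nn.utils.clip_grad_norm_(
        (p for p in model.parameters() if p.requires_grad), clip_norm)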