netB.cuda()
netB.eval()
for param in netB.parameters():  # freeze netB
    param.requires_grad = False

netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                         output_nc=4,
                         n_blocks1=args.n_blocks1,
                         n_blocks2=args.n_blocks2)
netG.apply(conv_init)
netG = nn.DataParallel(netG)
netG.cuda()
torch.backends.cudnn.benchmark = True

netD = MultiscaleDiscriminator(input_nc=3,
                               num_D=1,
                               norm_layer=nn.InstanceNorm2d,
                               ndf=64)
netD.apply(conv_init)
netD = nn.DataParallel(netD)
netD.cuda()

# Loss
l1_loss = alpha_loss()
c_loss = compose_loss()
g_loss = alpha_gradient_loss()
GAN_loss = GANloss()

optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

log_writer = SummaryWriter(tb_dir)
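
Both networks above are initialized via a conv_init helper that this page does not show. A minimal sketch of what such an initializer might look like, assuming the common Xavier-for-convolutions / constant-for-norms pattern (the actual Background-Matting helper may differ):

import torch.nn as nn

def conv_init(m):
    # Hypothetical sketch, not the repository's definition: Xavier init for
    # convolution weights, zeros for biases, identity-like init for norms.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

Passing it to netG.apply(conv_init) then visits every submodule of the generator.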
Example #2
    if (args.networkG_type == "global"):
        model_G = GlobalGenerator(input_nc=3, output_nc=3,
                                  ngf=args.n_fmaps).to(device)
    elif (args.networkG_type == "unet"):  # assumed flag value for the UNet generator
        model_G = Pix2PixUNetGenerator(
            n_in_channels=3,
            n_out_channels=3,
            n_fmaps=args.n_fmaps,
        ).to(device)
    else:
        raise NotImplementedError('networkG_type %s not implemented' %
                                  args.networkG_type)

    # Discriminator
    model_D = MultiscaleDiscriminator(n_in_channels=3,
                                      n_fmaps=args.n_fmaps,
                                      n_dis=3).to(device)

    if (args.debug):
        print("model_G :\n", model_G)
        print("model_D :\n", model_D)

    # Load the models
    if args.load_checkpoints_dir != '' and os.path.exists(
            args.load_checkpoints_dir):
        load_checkpoint(
            model_G, device,
            os.path.join(args.load_checkpoints_dir, "G", "G_final.pth"))
        load_checkpoint(
            model_D, device,
            os.path.join(args.load_checkpoints_dir, "D", "D_final.pth"))
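
The load_checkpoint helper used here is not defined on this page. A minimal sketch under the assumption that each .pth file holds a plain state_dict (the real helper may also strip nn.DataParallel's "module." prefixes or restore optimizer state):

import torch

def load_checkpoint(model, device, checkpoint_path):
    # Hypothetical sketch: load saved weights onto the requested device.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)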
Example #3
    def __init__(self, device=None, jit=True):
        self.device = device
        self.jit = jit
        self.opt = Namespace(
            **{
                'n_blocks1': 7,
                'n_blocks2': 3,
                'batch_size': 1,
                'resolution': 512,
                'name': 'Real_fixed'
            })

        scriptdir = os.path.dirname(os.path.realpath(__file__))
        csv_file = "Video_data_train_processed.csv"
        with open("Video_data_train.csv", "r") as r:
            with open(csv_file, "w") as w:
                w.write(r.read().format(scriptdir=scriptdir))
        data_config_train = {
            'reso': (self.opt.resolution, self.opt.resolution)
        }
        traindata = VideoData(csv_file=csv_file,
                              data_config=data_config_train,
                              transform=None)
        self.train_loader = torch.utils.data.DataLoader(
            traindata,
            batch_size=self.opt.batch_size,
            shuffle=True,
            num_workers=self.opt.batch_size,
            collate_fn=_collate_filter_none)

        netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                                 output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        if self.device == 'cuda':
            netB.cuda()
        netB.eval()
        for param in netB.parameters():  # freeze netB
            param.requires_grad = False
        self.netB = netB

        netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                                 output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        netG.apply(conv_init)
        self.netG = netG

        if self.device == 'cuda':
            self.netG.cuda()
            # TODO(asuhan): is this needed?
            torch.backends.cudnn.benchmark = True

        netD = MultiscaleDiscriminator(input_nc=3,
                                       num_D=1,
                                       norm_layer=nn.InstanceNorm2d,
                                       ndf=64)
        netD.apply(conv_init)
        netD = nn.DataParallel(netD)
        self.netD = netD
        if self.device == 'cuda':
            self.netD.cuda()

        self.l1_loss = alpha_loss()
        self.c_loss = compose_loss()
        self.g_loss = alpha_gradient_loss()
        self.GAN_loss = GANloss()

        self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
        self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

        self.log_writer = SummaryWriter(scriptdir)
        self.model_dir = scriptdir

        self._maybe_trace()
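
The DataLoader above passes collate_fn=_collate_filter_none, which is not shown. A plausible minimal sketch, assuming its job is to drop samples the dataset returned as None before the default collation runs:

from torch.utils.data.dataloader import default_collate

def _collate_filter_none(batch):
    # Hypothetical sketch: filter out failed samples (returned as None),
    # then collate the rest as usual.
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)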
Example #4
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters["lr"]
        self.newsize = hyperparameters["crop_image_height"]
        self.semantic_w = hyperparameters["semantic_w"] > 0
        self.recon_mask = hyperparameters["recon_mask"] == 1
        self.dann_scheduler = None
        self.full_adaptation = hyperparameters["adaptation"][
            "full_adaptation"] == 1
        dim = hyperparameters["gen"]["dim"]
        n_downsample = hyperparameters["gen"]["n_downsample"]
        latent_dim = dim * (2**n_downsample)

        if "domain_adv_w" in hyperparameters.keys():
            self.domain_classif_ab = hyperparameters["domain_adv_w"] > 0
        else:
            self.domain_classif_ab = False

        if hyperparameters["adaptation"]["dfeat_lambda"] > 0:
            self.use_classifier_sr = True
        else:
            self.use_classifier_sr = False

        if hyperparameters["adaptation"]["sem_seg_lambda"] > 0:
            self.train_seg = True
        else:
            self.train_seg = False

        if hyperparameters["adaptation"]["output_classifier_lambda"] > 0:
            self.use_output_classifier_sr = True
        else:
            self.use_output_classifier_sr = False

        self.gen = SpadeGen(hyperparameters["input_dim_a"],
                            hyperparameters["gen"])

        # Note: the "+1" is for the masks
        if hyperparameters["dis"]["type"] == "patchgan":
            print("Using patchgan discrminator...")
            self.dis_a = MultiscaleDiscriminator(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b = MultiscaleDiscriminator(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b
            self.instancenorm = nn.InstanceNorm2d(512, affine=False)

            self.dis_a_masked = MultiscaleDiscriminator(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b_masked = MultiscaleDiscriminator(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b

        else:
            self.dis_a = MsImageDis(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b = MsImageDis(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b
            self.instancenorm = nn.InstanceNorm2d(512, affine=False)
            self.dis_a_masked = MsImageDis(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b_masked = MsImageDis(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b

        # fix the noise used in sampling
        display_size = int(hyperparameters["display_size"])
        # Setup the optimizers
        beta1 = hyperparameters["beta1"]
        beta2 = hyperparameters["beta2"]
        dis_params = (list(self.dis_a.parameters()) +
                      list(self.dis_b.parameters()) +
                      list(self.dis_a_masked.parameters()) +
                      list(self.dis_b_masked.parameters()))

        gen_params = list(self.gen.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters["init"]))
        self.dis_a.apply(weights_init("gaussian"))
        self.dis_b.apply(weights_init("gaussian"))
        self.dis_a_masked.apply(weights_init("gaussian"))
        self.dis_b_masked.apply(weights_init("gaussian"))

        # Load VGG model if needed
        if hyperparameters["vgg_w"] > 0:
            self.criterionVGG = VGGLoss()

        # Load semantic segmentation model if needed
        if "semantic_w" in hyperparameters.keys(
        ) and hyperparameters["semantic_w"] > 0:
            self.segmentation_model = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"], 19)
            self.segmentation_model.eval()
            for param in self.segmentation_model.parameters():
                param.requires_grad = False

        # Load domain classifier if needed
        if "domain_adv_w" in hyperparameters.keys(
        ) and hyperparameters["domain_adv_w"] > 0:
            self.domain_classifier_ab = domainClassifier(input_dim=latent_dim,
                                                         dim=256)
            dann_params = list(self.domain_classifier_ab.parameters())
            self.dann_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier_ab.apply(weights_init("gaussian"))
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters)

        # Load classifier on features for syn, real adaptation
        if self.use_classifier_sr:
            #! Hardcoded
            self.domain_classifier_sr_b = domainClassifier(
                input_dim=latent_dim, dim=256)
            self.domain_classifier_sr_a = domainClassifier(
                input_dim=latent_dim, dim=256)

            dann_params = list(
                self.domain_classifier_sr_a.parameters()) + list(
                    self.domain_classifier_sr_b.parameters())
            self.classif_opt_sr = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier_sr_a.apply(weights_init("gaussian"))
            self.domain_classifier_sr_b.apply(weights_init("gaussian"))
            self.classif_sr_scheduler = get_scheduler(self.classif_opt_sr,
                                                      hyperparameters)

        if self.use_output_classifier_sr:
            if hyperparameters["dis"]["type"] == "patchgan":
                self.output_classifier_sr_a = MultiscaleDiscriminator(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain a,sr
                self.output_classifier_sr_b = MultiscaleDiscriminator(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain b,sr

            else:
                self.output_classifier_sr_a = MsImageDis(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain a,sr
                self.output_classifier_sr_b = MsImageDis(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain b,sr

            dann_params = list(
                self.output_classifier_sr_a.parameters()) + list(
                    self.output_classifier_sr_b.parameters())
            self.output_classif_opt_sr = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.output_classifier_sr_b.apply(weights_init("gaussian"))
            self.output_classifier_sr_a.apply(weights_init("gaussian"))
            self.output_scheduler_sr = get_scheduler(
                self.output_classif_opt_sr, hyperparameters)

        if self.train_seg:
            pretrained = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"], 19)
            last_layer = nn.Conv2d(512, 10, kernel_size=1)
            model = torch.nn.Sequential(
                *list(pretrained.resnet34_8s.children())[7:-1],
                last_layer.cuda())
            self.segmentation_head = model

            for param in self.segmentation_head.parameters():
                param.requires_grad = True

            dann_params = list(self.segmentation_head.parameters())
            self.segmentation_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.scheduler_seg = get_scheduler(self.segmentation_opt,
                                               hyperparameters)
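
get_scheduler comes from the surrounding MUNIT-style codebase and is not shown here. A minimal sketch of the usual pattern, assuming 'lr_policy', 'step_size', and 'gamma' entries in the hyperparameter dict (the project's real helper may support more policies):

from torch.optim import lr_scheduler

def get_scheduler(optimizer, hyperparameters):
    # Hypothetical sketch: no scheduler for a constant policy, otherwise
    # step decay driven by the hyperparameter dict.
    policy = hyperparameters.get("lr_policy", "constant")
    if policy == "constant":
        return None
    if policy == "step":
        return lr_scheduler.StepLR(optimizer,
                                   step_size=hyperparameters["step_size"],
                                   gamma=hyperparameters["gamma"])
    raise NotImplementedError("lr policy [%s] is not implemented" % policy)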
Example #5
def main():
    # CUDA

    # os.environ["CUDA_VISIBLE_DEVICES"]="4"
    # print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])
    print(f'Is CUDA available: {torch.cuda.is_available()}')
    """Parses arguments."""
    parser = argparse.ArgumentParser(
        description='Training Background Matting on Adobe Dataset')
    parser.add_argument('-n',
                        '--name',
                        type=str,
                        help='Name of tensorboard and model saving folders')
    parser.add_argument('-bs', '--batch_size', type=int, help='Batch Size')
    parser.add_argument('-res',
                        '--reso',
                        type=int,
                        help='Input image resolution')
    parser.add_argument('-init_model',
                        '--init_model',
                        type=str,
                        help='Initial model file')

    parser.add_argument('-w',
                        '--workers',
                        type=int,
                        default=None,
                        help='Number of workers to load data')
    parser.add_argument('-ep',
                        '--epochs',
                        type=int,
                        default=15,
                        help='Maximum Epoch')
    parser.add_argument(
        '-n_blocks1',
        '--n_blocks1',
        type=int,
        default=7,
        help='Number of residual blocks after Context Switching')
    parser.add_argument('-n_blocks2',
                        '--n_blocks2',
                        type=int,
                        default=3,
                        help='Number of residual blocks for Fg and alpha each')

    args = parser.parse_args()
    if args.workers is None:
        args.workers = args.batch_size

    ## Directories
    tb_dir = f'tb_summary/{args.name}'
    model_dir = f'models/{args.name}'

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if not os.path.exists(tb_dir):
        os.makedirs(tb_dir)

    ## Input list
    data_config_train = {
        'reso': (args.reso, args.reso)
    }  # if trimap is true, rcnn is used

    # DATA LOADING
    print('\n[Phase 1] : Data Preparation')

    # Original Data
    traindata = VideoData(
        csv_file='Video_data_train.csv',
        data_config=data_config_train,
        transform=None
    )  # Write a dataloader function that can read the database provided by .csv file

    train_loader = DataLoader(traindata,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              collate_fn=collate_filter_none)

    print('\n[Phase 2] : Initialization')

    netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=args.n_blocks1,
                             n_blocks2=args.n_blocks2)
    netB = nn.DataParallel(netB)
    netB.load_state_dict(torch.load(args.init_model))
    netB.cuda()
    netB.eval()
    for param in netB.parameters():  # freeze netB
        param.requires_grad = False

    netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=args.n_blocks1,
                             n_blocks2=args.n_blocks2)
    netG.apply(conv_init)
    netG = nn.DataParallel(netG)
    netG.cuda()
    torch.backends.cudnn.benchmark = True

    netD = MultiscaleDiscriminator(input_nc=3,
                                   num_D=1,
                                   norm_layer=nn.InstanceNorm2d,
                                   ndf=64)
    netD.apply(conv_init)
    netD = nn.DataParallel(netD)
    netD.cuda()

    # Loss
    l1_loss = alpha_loss()
    c_loss = compose_loss()
    g_loss = alpha_gradient_loss()
    GAN_loss = GANloss()

    optimizerG = Adam(netG.parameters(), lr=1e-4)
    optimizerD = Adam(netD.parameters(), lr=1e-5)

    log_writer = SummaryWriter(tb_dir)

    print('Starting Training')
    step = 50

    KK = len(train_loader)

    wt = 1
    for epoch in range(0, args.epochs):

        netG.train()
        netD.train()

        lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

        t0 = get_time()

        for i, data in enumerate(train_loader):
            # Initiating
            bg = data['bg'].cuda()
            image = data['image'].cuda()
            seg = data['seg'].cuda()
            multi_fr = data['multi_fr'].cuda()
            seg_gt = data['seg-gt'].cuda()
            back_rnd = data['back-rnd'].cuda()

            mask0 = torch.ones(seg.shape).cuda()

            tr0 = get_time()

            # pseudo-supervision
            alpha_pred_sup, fg_pred_sup = netB(image, bg, seg, multi_fr)
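            # alpha values are in [-1, 1] (see the (alpha_pred + 1) / 2 rescale
            # below), so this mask keeps regions the teacher does not call pure
            # background; it restricts the foreground loss further down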
            mask = (alpha_pred_sup > -0.98).type(torch.FloatTensor).cuda()

            mask1 = (seg_gt > 0.95).type(torch.FloatTensor).cuda()

            ## Train Generator

            alpha_pred, fg_pred = netG(image, bg, seg, multi_fr)

            ## Pseudo-supervised losses
            al_loss = l1_loss(
                alpha_pred_sup, alpha_pred,
                mask0) + 0.5 * g_loss(alpha_pred_sup, alpha_pred, mask0)
            fg_loss = l1_loss(fg_pred_sup, fg_pred, mask)

            # compose into same background
            comp_loss = c_loss(image, alpha_pred, fg_pred, bg, mask1)

            # randomly permute the background
            perm = torch.LongTensor(np.random.permutation(bg.shape[0]))
            bg_sh = bg[perm, :, :, :]

            al_mask = (alpha_pred > 0.95).type(torch.FloatTensor).cuda()

            # Choose the target background for composition
            # back_rnd: contains separate set of background videos captured
            # bg_sh: contains randomly permuted captured background from the same minibatch
            if np.random.random_sample() > 0.5:
                bg_sh = back_rnd

            image_sh = compose_image_withshift(
                alpha_pred, image * al_mask + fg_pred * (1 - al_mask), bg_sh,
                seg)

            fake_response = netD(image_sh)

            loss_ganG = GAN_loss(fake_response, label_type=True)

            lossG = loss_ganG + wt * (0.05 * comp_loss + 0.05 * al_loss +
                                      0.05 * fg_loss)

            optimizerG.zero_grad()

            lossG.backward()
            optimizerG.step()

            # Train Discriminator

            fake_response = netD(image_sh)
            real_response = netD(image)

            loss_ganD_fake = GAN_loss(fake_response, label_type=False)
            loss_ganD_real = GAN_loss(real_response, label_type=True)

            lossD = (loss_ganD_real + loss_ganD_fake) * 0.5

            # Update the discriminator once every 5 generator updates
            if i % 5 == 0:
                optimizerD.zero_grad()
                lossD.backward()
                optimizerD.step()

            lG += lossG.data
            lD += lossD.data
            GenL += loss_ganG.data
            DisL_r += loss_ganD_real.data
            DisL_f += loss_ganD_fake.data

            alL += al_loss.data
            fgL += fg_loss.data
            compL += comp_loss.data

            log_writer.add_scalar('Generator Loss', lossG.data,
                                  epoch * KK + i + 1)
            log_writer.add_scalar('Discriminator Loss', lossD.data,
                                  epoch * KK + i + 1)
            log_writer.add_scalar('Generator Loss: Fake', loss_ganG.data,
                                  epoch * KK + i + 1)
            log_writer.add_scalar('Discriminator Loss: Real',
                                  loss_ganD_real.data, epoch * KK + i + 1)
            log_writer.add_scalar('Discriminator Loss: Fake',
                                  loss_ganD_fake.data, epoch * KK + i + 1)

            log_writer.add_scalar('Generator Loss: Alpha', al_loss.data,
                                  epoch * KK + i + 1)
            log_writer.add_scalar('Generator Loss: Fg', fg_loss.data,
                                  epoch * KK + i + 1)
            log_writer.add_scalar('Generator Loss: Comp', comp_loss.data,
                                  epoch * KK + i + 1)

            t1 = get_time()

            elapse += t1 - t0
            elapse_run += t1 - tr0
            t0 = t1

            if i % step == (step - 1):
                print(f'[{epoch + 1}, {i + 1:5d}] '
                      f'Gen-loss: {lG / step:.4f} '
                      f'Disc-loss: {lD / step:.4f} '
                      f'Alpha-loss: {alL / step:.4f} '
                      f'Fg-loss: {fgL / step:.4f} '
                      f'Comp-loss: {compL / step:.4f} '
                      f'Time-all: {elapse / step:.4f} '
                      f'Time-fwbw: {elapse_run / step:.4f}')
                lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

                write_tb_log(image, 'image', log_writer, i)
                write_tb_log(seg, 'seg', log_writer, i)
                write_tb_log(alpha_pred_sup, 'alpha-sup', log_writer, i)
                write_tb_log(alpha_pred, 'alpha_pred', log_writer, i)
                write_tb_log(fg_pred_sup * mask, 'fg-pred-sup', log_writer, i)
                write_tb_log(fg_pred * mask, 'fg_pred', log_writer, i)

                # composition
                alpha_pred = (alpha_pred + 1) / 2
                comp = fg_pred * alpha_pred + (1 - alpha_pred) * bg
                write_tb_log(comp, 'composite-same', log_writer, i)
                write_tb_log(image_sh, 'composite-diff', log_writer, i)

                del comp

            del bg, image, seg, multi_fr, seg_gt, back_rnd
            del mask0, alpha_pred_sup, fg_pred_sup, mask, mask1
            del alpha_pred, fg_pred, al_loss, fg_loss, comp_loss
            del bg_sh, image_sh, fake_response, real_response
            del lossG, lossD, loss_ganD_real, loss_ganD_fake, loss_ganG

        if epoch % 2 == 0:
            ep = epoch + 1
            torch.save(netG.state_dict(), f'{model_dir}/netG_epoch_{ep}.pth')
            torch.save(optimizerG.state_dict(),
                       f'{model_dir}/optimG_epoch_{ep}.pth')
            torch.save(netD.state_dict(), f'{model_dir}/netD_epoch_{ep}.pth')
            torch.save(optimizerD.state_dict(),
                       f'{model_dir}/optimD_epoch_{ep}.pth')

            # Halve the weight every 2 epochs to put more stress on the discriminator and less on pseudo-supervision
            wt = wt / 2
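
The GANloss object used above takes the raw multiscale-discriminator response plus a label_type flag. A minimal sketch of such a loss, assuming LSGAN-style targets over the list of per-scale outputs that MultiscaleDiscriminator returns (the repository's actual criterion may use different targets or reductions):

import torch
import torch.nn as nn

class GANloss(nn.Module):
    # Hypothetical sketch: least-squares GAN loss over every scale's output.
    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss()

    def forward(self, predictions, label_type):
        target_value = 1.0 if label_type else 0.0
        loss = 0.0
        for scale_outputs in predictions:
            pred = scale_outputs[-1]  # final map of this scale
            loss += self.criterion(pred, torch.full_like(pred, target_value))
        return loss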
Example #6
def main():
    parser = argparse.ArgumentParser(description='chainer pix2pixHD')
    parser.add_argument('--batchsize', '-b', type=int, default=1)
    parser.add_argument('--epoch', '-e', type=int, default=200)
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--dataset',
                        '-i',
                        default="/mnt/sakuradata10-striped/gao/cityscapes")
    parser.add_argument(
        '--out',
        '-o',
        default='/mnt/sakuradata10-striped/gao/results/pix2pixHD')
    parser.add_argument('--resume', '-r', default='')
    parser.add_argument('--snapshot_interval', type=int, default=10000)
    parser.add_argument('--display_interval', type=int, default=10)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--no_one_hot', action='store_false')  # defaults to True; the flag turns one-hot labels off
    parser.add_argument('--ins_norm', action='store_true')
    parser.add_argument('--vis_num', type=int, default=4)
    parser.add_argument('--vis_interval', type=int, default=100)
    parser.add_argument('--model_num', '-n', default='')
    parser.add_argument('--generator',
                        '-G',
                        default='Global',
                        choices=['Global', 'Local'])
    parser.add_argument('--fix_global_num_epochs', type=int, default=10)
    parser.add_argument('--global_model_path', default='')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    size = [args.size, args.size * 2]

    if args.generator == 'Global':
        gen = GlobalGenerator(ins_norm=args.ins_norm, input_size=size)
    else:
        gen = LocalEnhancer(args.global_model_path,
                            ins_norm=args.ins_norm,
                            input_size=size)
    dis = MultiscaleDiscriminator()
    if args.model_num:
        chainer.serializers.load_npz(
            os.path.join(args.out, 'gen_iter_' + args.model_num + '.npz'), gen)
        # chainer.serializers.load_npz(os.path.join(args.out, 'gen_dis_iter_' + args.model_num + '.npz'), dis)

    train = Pix2PixHDDataset(root=args.dataset,
                             one_hot=args.no_one_hot,
                             size=size)
    test = Pix2PixHDDataset(root=args.dataset,
                            one_hot=args.no_one_hot,
                            size=size,
                            test=True)
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test,
                                                 args.batchsize,
                                                 shuffle=False)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # Setup optimizer parameters.
    opt = optimizers.Adam(alpha=0.0002)
    opt.setup(gen)
    opt_d = optimizers.Adam(alpha=0.0002)
    opt_d.setup(dis)

    # Set up a trainer
    updater = Updater(models=(gen, dis),
                      iterator={
                          'main': train_iter,
                          'test': test_iter
                      },
                      optimizer={
                          'gen': opt,
                          'dis': opt_d
                      },
                      device=args.gpu,
                      size=size,
                      fix_global_num_epochs=args.fix_global_num_epochs)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    if args.resume:
        chainer.serializers.load_npz(
            os.path.join(args.out, 'snapshot_iter_' + args.resume + '.npz'),
            trainer)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    vis_interval = (args.vis_interval, 'iteration')
    trainer.extend(extensions.dump_graph('gen/loss_GAN'))
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen, 'gen_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(
        extensions.LogReport(trigger=(args.display_interval, 'iteration'), ))
    report = [
        'epoch', 'iteration', 'gen/loss_GAN', 'gen/loss_FM', 'dis/loss_GAN'
    ]
    trainer.extend(extensions.PrintReport(report))
    trainer.extend(
        extensions.ProgressBar(update_interval=args.display_interval))
    trainer.extend(train.visualizer(n=args.vis_num, one_hot=args.no_one_hot),
                   trigger=vis_interval)

    trainer.run()

    # Save the trained model
    chainer.serializers.save_npz(os.path.join(args.out, 'model_final'), gen)
    chainer.serializers.save_npz(os.path.join(args.out, 'optimizer_final'),
                                 opt)
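
Every example on this page instantiates some variant of MultiscaleDiscriminator. For reference, a minimal PyTorch sketch of the pix2pixHD-style idea behind these classes: several small patch discriminators run on progressively downsampled copies of the input, and all of their outputs feed the loss. This is an illustrative reimplementation, not any of the repositories' actual code:

import torch.nn as nn

class MultiscaleDiscriminatorSketch(nn.Module):
    # Hypothetical sketch: num_D patch discriminators at different scales.
    def __init__(self, input_nc=3, ndf=64, num_D=3):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [self._patch_discriminator(input_nc, ndf) for _ in range(num_D)])
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1,
                                       count_include_pad=False)

    @staticmethod
    def _patch_discriminator(input_nc, ndf):
        # A deliberately tiny stand-in for the usual 4-layer PatchGAN.
        return nn.Sequential(
            nn.Conv2d(input_nc, ndf, 4, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, 1, 4, stride=1, padding=2))

    def forward(self, x):
        outputs = []
        for d in self.discriminators:
            outputs.append([d(x)])  # keep the list-of-lists shape like pix2pixHD
            x = self.downsample(x)
        return outputs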
Example #7
    def __init__(self, device=None, jit=True):
        self.device = device
        self.jit = jit
        self.opt = Namespace(
            **{
                'n_blocks1': 7,
                'n_blocks2': 3,
                'batch_size': 1,
                'resolution': 512,
                'name': 'Real_fixed'
            })

        data_config_train = {
            'reso': (self.opt.resolution, self.opt.resolution)
        }
        traindata = VideoData(csv_file='Video_data_train.csv',
                              data_config=data_config_train,
                              transform=None)
        self.train_loader = torch.utils.data.DataLoader(
            traindata,
            batch_size=self.opt.batch_size,
            shuffle=True,
            num_workers=self.opt.batch_size,
            collate_fn=_collate_filter_none)

        netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                                 output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        if self.device == 'cuda':
            netB.cuda()
        netB.eval()
        for param in netB.parameters():  # freeze netB
            param.requires_grad = False
        self.netB = netB

        netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                                 output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        netG.apply(conv_init)
        self.netG = netG

        if self.device == 'cuda':
            self.netG.cuda()
            # TODO(asuhan): is this needed?
            torch.backends.cudnn.benchmark = True

        netD = MultiscaleDiscriminator(input_nc=3,
                                       num_D=1,
                                       norm_layer=nn.InstanceNorm2d,
                                       ndf=64)
        netD.apply(conv_init)
        netD = nn.DataParallel(netD)
        self.netD = netD
        if self.device == 'cuda':
            self.netD.cuda()

        self.l1_loss = alpha_loss()
        self.c_loss = compose_loss()
        self.g_loss = alpha_gradient_loss()
        self.GAN_loss = GANloss()

        self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
        self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

        tb_dir = '/home/circleci/project/benchmark/models/Background-Matting/TB_Summary/' + self.opt.name
        if not os.path.exists(tb_dir):
            os.makedirs(tb_dir)
        self.log_writer = SummaryWriter(tb_dir)
        self.model_dir = '/home/circleci/project/benchmark/models/Background-Matting/Models/' + self.opt.name
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        self._maybe_trace()
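
_maybe_trace belongs to the benchmark harness and is not shown on this page. A hypothetical sketch of what it might do, assuming jit=True means the generator should be traced with one batch from the loader (the data keys follow the dict used in Example #5):

    def _maybe_trace(self):
        # Hypothetical sketch: JIT-trace netG on one example batch so later
        # benchmark iterations run through TorchScript.
        if not self.jit:
            return
        data = next(iter(self.train_loader))
        inputs = (data['image'], data['bg'], data['seg'], data['multi_fr'])
        if self.device == 'cuda':
            inputs = tuple(t.cuda() for t in inputs)
        self.netG = torch.jit.trace(self.netG, inputs)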