Code Example #1
File: test_sync_batchnorm.py Project: GAIMJKP/GAN-1
    def testSyncBatchNormSyncEval(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
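        # DataParallelWithCallback replicates the module across the listed GPUs
        # and fires a replication callback that SynchronizedBatchNorm uses to
        # set up cross-replica communication of batch statistics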

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn,
                                   sync_bn,
                                   torch.rand(16, 10),
                                   False,
                                   cuda=True)
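
The `_checkBatchNormResult` helper is defined elsewhere in this test class and is not shown in these snippets. A minimal sketch of what such a check plausibly does, assuming it only runs one batch through both normalizers and compares the results (the body below is a guess, not the project's code):

import torch

def _check_batchnorm_result(bn, sync_bn, x, is_train, cuda=False):
    # run the same batch through the reference and synchronized modules and
    # assert they agree on outputs and, in training mode, on running stats
    bn.train(is_train)
    sync_bn.train(is_train)
    if cuda:
        bn, sync_bn, x = bn.cuda(), sync_bn.cuda(), x.cuda()
    y_ref, y_sync = bn(x), sync_bn(x)
    assert torch.allclose(y_ref.cpu(), y_sync.cpu(), atol=1e-5)
    if is_train:
        inner = getattr(sync_bn, 'module', sync_bn)  # unwrap DataParallelWithCallback
        assert torch.allclose(bn.running_mean.cpu(), inner.running_mean.cpu(), atol=1e-5)
        assert torch.allclose(bn.running_var.cpu(), inner.running_var.cpu(), atol=1e-5)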
Code Example #2
File: test_sync_batchnorm.py Project: GAIMJKP/GAN-1
    def testSyncBatchNorm2DSyncTrain(self):
        bn = nn.BatchNorm2d(10)
        sync_bn = SynchronizedBatchNorm2d(10)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn,
                                   sync_bn,
                                   torch.rand(16, 10, 16, 16),
                                   True,
                                   cuda=True)
Code Example #3
def main(args):
    # parameters
    img_size = args.img_size
    z_dim = 128
    lamb_obj = 1.0
    lamb_img = 0.1
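    # dataset vocabularies: COCO uses 184 classes and at most 8 objects per
    # layout; the non-COCO branch (presumably Visual Genome) uses 179 and 31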
    num_classes = 184 if args.dataset == 'coco' else 179
    num_obj = 8 if args.dataset == 'coco' else 31

    args.out_path = os.path.join(args.out_path, args.dataset,
                                 str(args.img_size))

    num_gpus = torch.cuda.device_count()
    num_workers = 2
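    # scale the effective batch size and loader workers linearly with the GPU count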
    if num_gpus > 1:
        parallel = True
        args.batch_size = args.batch_size * num_gpus
        num_workers = num_workers * num_gpus
    else:
        parallel = False

    # data loader
    train_data = get_dataset(args.dataset, img_size)

    dataloader = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             drop_last=True,
                                             shuffle=True,
                                             num_workers=num_workers)

    # Load model
    device = torch.device('cuda')
    netG = context_aware_generator(num_classes=num_classes,
                                   output_dim=3).to(device)
    netD = CombineDiscriminator128(num_classes=num_classes).to(device)

    # honor the multi-GPU flag computed above instead of forcing DataParallel
    if parallel:
        netG = DataParallelWithCallback(netG)
        netD = nn.DataParallel(netD)

    g_lr, d_lr = args.g_lr, args.d_lr
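    # per-parameter LR groups: any generator parameter whose name contains
    # 'mapping' trains at one tenth of the base generator learning rate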
    gen_parameters = []
    for key, value in dict(netG.named_parameters()).items():
        if value.requires_grad:
            if 'mapping' in key:
                gen_parameters += [{'params': [value], 'lr': g_lr * 0.1}]
            else:
                gen_parameters += [{'params': [value], 'lr': g_lr}]

    g_optimizer = torch.optim.Adam(gen_parameters, betas=(0, 0.999))

    dis_parameters = []
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            dis_parameters += [{'params': [value], 'lr': d_lr}]
    d_optimizer = torch.optim.Adam(dis_parameters, betas=(0, 0.999))

    # make dirs
    if not os.path.exists(args.out_path):
        os.makedirs(args.out_path)
    if not os.path.exists(os.path.join(args.out_path, 'model/')):
        os.makedirs(os.path.join(args.out_path, 'model/'))
    writer = SummaryWriter(os.path.join(args.out_path, 'log'))

    logger = setup_logger("lostGAN", args.out_path, 0)
    logger.info(netG)
    logger.info(netD)

    start_time = time.time()
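    # reconstruction terms: an L1 pixel loss plus VGGLoss (presumably a
    # VGG-feature perceptual loss), wrapped in DataParallel so they run on
    # the same GPUs as the models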
    vgg_loss = VGGLoss()
    vgg_loss = nn.DataParallel(vgg_loss)
    l1_loss = nn.DataParallel(nn.L1Loss())
    for epoch in range(args.total_epoch):
        netG.train()
        netD.train()

        for idx, data in enumerate(tqdm(dataloader)):
            real_images, label, bbox = data
            # print(real_images.shape)
            # print(label.shape)
            # print(bbox.shape)
            # move images and labels to the GPU once; bbox stays on the CPU
            # so DataParallel can scatter it alongside the images
            real_images = real_images.to(device)
            label = label.long().to(device).unsqueeze(-1)
            bbox = bbox.float()

            # update D network
            netD.zero_grad()
            d_out_real, d_out_robj = netD(real_images, bbox, label)
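            # hinge loss, real side: push D's image and object scores above +1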
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            d_loss_robj = torch.nn.ReLU()(1.0 - d_out_robj).mean()

            # one z_dim-dimensional latent code per object in the layout
            z = torch.randn(real_images.size(0), num_obj, z_dim).to(device)
            fake_images = netG(z, bbox, y=label.squeeze(dim=-1))
            d_out_fake, d_out_fobj = netD(fake_images.detach(), bbox, label)
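            # hinge loss, fake side: push D's scores on the detached fakes below -1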
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            d_loss_fobj = torch.nn.ReLU()(1.0 + d_out_fobj).mean()

            d_loss = lamb_obj * (d_loss_robj + d_loss_fobj) + lamb_img * (
                d_loss_real + d_loss_fake)
            d_loss.backward()
            d_optimizer.step()

            # update G network
            if (idx % 1) == 0:  # always true: the generator updates every iteration
                netG.zero_grad()
                g_out_fake, g_out_obj = netD(fake_images, bbox, label)
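                # non-saturating hinge G loss: raise D's scores on the fakes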
                g_loss_fake = -g_out_fake.mean()
                g_loss_obj = -g_out_obj.mean()

                pixel_loss = l1_loss(fake_images, real_images).mean()
                feat_loss = vgg_loss(fake_images, real_images).mean()

                g_loss = g_loss_obj * lamb_obj + g_loss_fake * lamb_img + pixel_loss + feat_loss
                g_loss.backward()
                g_optimizer.step()

            if (idx + 1) % 10 == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                logger.info("Time Elapsed: [{}]".format(elapsed))
                logger.info(
                    "Epoch[{}], Step[{}], d_out_real: {:.4f}, d_out_fake: {:.4f}, g_out_fake: {:.4f}"
                    .format(epoch + 1, idx + 1, d_loss_real.item(),
                            d_loss_fake.item(), g_loss_fake.item()))
                logger.info(
                    "             d_obj_real: {:.4f}, d_obj_fake: {:.4f}, g_obj_fake: {:.4f} "
                    .format(d_loss_robj.item(), d_loss_fobj.item(),
                            g_loss_obj.item()))
                logger.info(
                    "             pixel_loss: {:.4f}, feat_loss: {:.4f}".
                    format(pixel_loss.item(), feat_loss.item()))

                writer.add_image(
                    "real images",
                    make_grid(real_images.cpu().data * 0.5 + 0.5, nrow=4),
                    epoch * len(dataloader) + idx + 1)
                writer.add_image(
                    "fake images",
                    make_grid(fake_images.cpu().data * 0.5 + 0.5, nrow=4),
                    epoch * len(dataloader) + idx + 1)

                step = epoch * len(dataloader) + idx + 1
                writer.add_scalars(
                    "D_loss_real", {
                        "real": d_loss_real.item(),
                        "robj": d_loss_robj.item(),
                        "loss": d_loss.item()
                    }, step)
                writer.add_scalars("D_loss_fake", {
                    "fake": d_loss_fake.item(),
                    "fobj": d_loss_fobj.item()
                }, step)
                writer.add_scalars(
                    "G_loss", {
                        "fake": g_loss_fake.item(),
                        "obj": g_loss_obj.item(),
                        "loss": g_loss.item()
                    }, step)

        # save model
        if (epoch + 1) % 5 == 0:
            torch.save(
                netG.state_dict(),
                os.path.join(args.out_path, 'model/',
                             'G_%d.pth' % (epoch + 1)))
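
Each `main(args)` in these snippets expects an argparse-style namespace that is not shown. A minimal harness sketch: the flag names mirror the attributes the function reads (`dataset`, `img_size`, `out_path`, `batch_size`, `g_lr`, `d_lr`, `total_epoch`), but the defaults are illustrative assumptions, not values from the source:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='coco')  # 'coco' or the alternative dataset
    parser.add_argument('--img_size', type=int, default=128)
    parser.add_argument('--out_path', default='./outputs')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--g_lr', type=float, default=1e-4)
    parser.add_argument('--d_lr', type=float, default=1e-4)
    parser.add_argument('--total_epoch', type=int, default=200)
    main(parser.parse_args())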
Code Example #4
def main(args):
    # parameters
    img_size = 128
    z_dim = 128
    lamb_obj = 1.0
    lamb_app = 1.0
    lamb_img = 0.1
    num_classes = 184 if args.dataset == 'coco' else 179
    num_obj = 8 if args.dataset == 'coco' else 31

    # args.out_path = os.path.join(args.out_path, args.dataset + '_1gpu', str(img_size))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
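    # restrict visible GPUs before the first CUDA call so that
    # torch.cuda.device_count() below only counts the selected devices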

    # data loader
    train_data = get_dataset(args.dataset, img_size)
    print(f"{args.dataset.title()} datasets with {len(train_data)} samples has been created!")

    num_gpus = torch.cuda.device_count()
    num_workers = 20
    if num_gpus > 1:
        parallel = True
        args.batch_size = args.batch_size * num_gpus
        num_workers = num_workers * num_gpus
    else:
        parallel = False

    print("{} GPUs, {} workers are used".format(num_gpus, num_workers))
    dataloader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        drop_last=True, shuffle=True, num_workers=num_workers)

    # Load model
    device = torch.device('cuda')
    netG = ResnetGenerator128(num_classes=num_classes, output_dim=3).to(device)
    netD = CombineDiscriminator128_app(num_classes=num_classes).to(device)

    # if os.path.isfile(args.checkpoint):
    #     state_dict = torch.load(args.checkpoint)

    # parallel = True
    if parallel:
        netG = DataParallelWithCallback(netG)
        netD = nn.DataParallel(netD)

    g_lr, d_lr = args.g_lr, args.d_lr
    gen_parameters = []
    for key, value in dict(netG.named_parameters()).items():
        if value.requires_grad:
            if 'mapping' in key:
                gen_parameters += [{'params': [value], 'lr': g_lr * 0.1}]
            else:
                gen_parameters += [{'params': [value], 'lr': g_lr}]

    g_optimizer = torch.optim.Adam(gen_parameters, betas=(0, 0.999))

    dis_parameters = []
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            dis_parameters += [{'params': [value], 'lr': d_lr}]
    d_optimizer = torch.optim.Adam(dis_parameters, betas=(0, 0.999))

    # make dirs
    if not os.path.exists(args.out_path):
        os.makedirs(args.out_path)
    if not os.path.exists(os.path.join(args.out_path, 'model/')):
        os.makedirs(os.path.join(args.out_path, 'model/'))
    writer = SummaryWriter(os.path.join(args.out_path, 'log'))
    # writer = None
    logger = setup_logger("lostGAN", args.out_path, 0)
    for arg in vars(args):
        logger.info("%15s : %-15s" % (arg, str(getattr(args, arg))))

    logger.info(netG)
    logger.info(netD)

    start_time = time.time()
    vgg_loss = VGGLoss()
    vgg_loss = nn.DataParallel(vgg_loss)
    l1_loss = nn.DataParallel(nn.L1Loss())
    for epoch in range(args.total_epoch):
        netG.train()
        netD.train()
        print("Epoch {}/{}".format(epoch, args.total_epoch))
        for idx, data in enumerate(tqdm(dataloader)):
            real_images, label, bbox = data
            # move images and labels to the GPU once; bbox stays on the CPU for scattering
            real_images = real_images.to(device)
            label = label.long().to(device).unsqueeze(-1)
            bbox = bbox.float()

            # update D network
            netD.zero_grad()
            d_out_real, d_out_robj, d_out_robj_app = netD(real_images, bbox, label)
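            # the `_app` discriminator variant returns a third, per-object
            # appearance score alongside the image and object scores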
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            d_loss_robj = torch.nn.ReLU()(1.0 - d_out_robj).mean()
            d_loss_robj_app = torch.nn.ReLU()(1.0 - d_out_robj_app).mean()
            # print(d_loss_robj)
            # print(d_loss_robj_app)

            z = torch.randn(real_images.size(0), num_obj, z_dim).to(device)
            fake_images = netG(z, bbox, y=label.squeeze(dim=-1))
            d_out_fake, d_out_fobj, d_out_fobj_app = netD(fake_images.detach(), bbox, label)
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            d_loss_fobj = torch.nn.ReLU()(1.0 + d_out_fobj).mean()
            d_loss_fobj_app = torch.nn.ReLU()(1.0 + d_out_fobj_app).mean()

            d_loss = lamb_obj * (d_loss_robj + d_loss_fobj) + lamb_img * (d_loss_real + d_loss_fake) + lamb_app * (d_loss_robj_app + d_loss_fobj_app)
            d_loss.backward()
            d_optimizer.step()

            # update G network
            if (idx % 1) == 0:
                netG.zero_grad()
                g_out_fake, g_out_obj, g_out_obj_app = netD(fake_images, bbox, label)
                g_loss_fake = - g_out_fake.mean()
                g_loss_obj = - g_out_obj.mean()
                g_loss_obj_app = - g_out_obj_app.mean()

                pixel_loss = l1_loss(fake_images, real_images).mean()
                feat_loss = vgg_loss(fake_images, real_images).mean()

                g_loss = g_loss_obj * lamb_obj + g_loss_fake * lamb_img + pixel_loss + feat_loss + lamb_app * g_loss_obj_app
                g_loss.backward()
                g_optimizer.step()
            # print("d_loss_real={:.3f}, d_loss_robj={:.3f}, d_loss_robj_app={:.3f}".format(d_loss_real.item(), d_loss_robj.item(), d_loss_robj_app.item()))
            # print("d_loss_fake={:.3f}, d_loss_fobj={:.3f}, d_loss_fobj_app={:.3f}".format(d_loss_fake.item(), d_loss_fobj.item(), d_loss_fobj_app.item()))
            # print("g_loss_fake={:.3f}, g_loss_obj={:.3f}, g_loss_obj_app={:.3f}".format(g_loss_fake.item(), g_loss_obj.item(), g_loss_obj_app.item()))
            if (idx + 1) % 100 == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                logger.info("Time Elapsed: [{}]".format(elapsed))
                logger.info("Step[{}/{}], d_out_real: {:.4f}, d_out_fake: {:.4f}, g_out_fake: {:.4f} ".format(epoch + 1,
                                                                                                              idx + 1,
                                                                                                              d_loss_real.item(),
                                                                                                              d_loss_fake.item(),
                                                                                                              g_loss_fake.item()))
                logger.info("             d_obj_real: {:.4f}, d_obj_fake: {:.4f}, g_obj_fake: {:.4f} ".format(
                    d_loss_robj.item(),
                    d_loss_fobj.item(),
                    g_loss_obj.item()))
                logger.info("             d_obj_real_app: {:.4f}, d_obj_fake_app: {:.4f}, g_obj_fake_app: {:.4f} ".format(
                    d_loss_robj_app.item(),
                    d_loss_fobj_app.item(),
                    g_loss_obj_app.item()))

                logger.info("             pixel_loss: {:.4f}, feat_loss: {:.4f}".format(pixel_loss.item(), feat_loss.item()))
                if writer is not None:
                    step = epoch * len(dataloader) + idx + 1
                    writer.add_image("real images", make_grid(real_images.cpu().data * 0.5 + 0.5, nrow=4), step)
                    writer.add_image("fake images", make_grid(fake_images.cpu().data * 0.5 + 0.5, nrow=4), step)

                    writer.add_scalars("D_loss_real", {"real": d_loss_real.item(),
                                                       "robj": d_loss_robj.item(),
                                                       "robj_app": d_loss_robj_app.item(),
                                                       "loss": d_loss.item()}, step)
                    writer.add_scalars("D_loss_fake", {"fake": d_loss_fake.item(),
                                                       "fobj": d_loss_fobj.item(),
                                                       "fobj_app": d_loss_fobj_app.item()}, step)
                    writer.add_scalars("G_loss", {"fake": g_loss_fake.item(),
                                                  "obj_app": g_loss_obj_app.item(),
                                                  "obj": g_loss_obj.item(),
                                                  "loss": g_loss.item()}, step)

        # save model
        if (epoch + 1) % 5 == 0:
            torch.save(netG.state_dict(), os.path.join(args.out_path, 'model/', 'G_%d.pth' % (epoch + 1)))
            torch.save(netD.state_dict(), os.path.join(args.out_path, 'model/', 'D_%d.pth' % (epoch + 1)))
Code Example #5
File: train.py Project: kroniidvul/LostGANs-mnist
def main(args):
    # parameters
    img_size = 128
    z_dim = 128
    lamb_obj = 1.1  # 0.5, 1, 1, 0
    lamb_img = 0.1  # 1, 1, 0.5, 1
    num_classes = 184 if args.dataset == 'coco' else 10  # 179
    num_obj = 8 if args.dataset == 'coco' else 9  # 31
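    # mnist variant: 10 digit classes and up to 9 objects per image
    # (a 3x3 grid; cf. the log note below)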

    # data loader
    train_data = get_dataset(args.dataset, img_size)

    dataloader = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             drop_last=True,
                                             shuffle=True,
                                             num_workers=8)

    # Load model
    netG = ResnetGenerator128(num_classes=num_classes, output_dim=3).cuda()
    netD = CombineDiscriminator128(num_classes=num_classes).cuda()

    parallel = True
    if parallel:
        netG = DataParallelWithCallback(netG)
        netD = nn.DataParallel(netD)

    g_lr, d_lr = args.g_lr, args.d_lr
    gen_parameters = []
    for key, value in dict(netG.named_parameters()).items():
        if value.requires_grad:
            if 'mapping' in key:
                gen_parameters += [{'params': [value], 'lr': g_lr * 0.1}]
            else:
                gen_parameters += [{'params': [value], 'lr': g_lr}]

    g_optimizer = torch.optim.Adam(gen_parameters, betas=(0, 0.999))

    dis_parameters = []
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            dis_parameters += [{'params': [value], 'lr': d_lr}]
    d_optimizer = torch.optim.Adam(dis_parameters, betas=(0, 0.999))

    # get most recent commit
    commit_obj = get_recent_commit()
    current_time = time.strftime("%H_%M_%dd_%mm", time.localtime())

    # make dirs
    if not os.path.exists(args.out_path):
        os.makedirs(args.out_path)
    if not os.path.exists(os.path.join(args.out_path, 'model/')):
        os.makedirs(os.path.join(args.out_path, 'model/'))
    writer = SummaryWriter(os.path.join(args.out_path, 'log'))

    logger = setup_logger("lostGAN",
                          args.out_path,
                          0,
                          filename=current_time + '_log.txt')
    logger.info('Commit Tag: ' + commit_obj)
    logger.info('Time: %s', time.asctime())
    logger.info('No rotations. No resizing. Ob=1.1, im=0.1. 3x3 grid.')
    logger.info(netG)
    logger.info(netD)

    # labelencoder = LabelEncoder()
    # enc = OneHotEncoder()

    total_steps = len(dataloader)
    start_time = time.time()
    for epoch in range(args.total_epoch):
        netG.train()
        netD.train()
        print('epoch ', epoch)

        for idx, data in enumerate(dataloader):
            real_images, label, bbox = data
            # move images and labels to the GPU once; bbox stays on the CPU
            real_images = real_images.float().cuda()
            label = label.long().cuda().unsqueeze(-1)
            bbox = bbox.float()
            # rotation handling from the mnist experiments, kept for reference:
            # rotation = rotation.float()  # .unsqueeze(-1)
            # rotation = labelencoder.fit_transform(rotation.view(-1)).reshape(-1, 16)
            # rotation = torch.from_numpy(rotation).float().cuda()
            # rotation = enc.fit_transform(rotation).toarray()
            # rotation = torch.from_numpy(rotation).float().cuda()  # [bs*16, 4]

            # update D network
            netD.zero_grad()
            d_out_real, d_out_robj = netD(real_images, bbox, label)
            # rotation=rotation)  # d_out_robj: [460, 1]

            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            d_loss_robj = torch.nn.ReLU()(1.0 - d_out_robj).mean()
            # d_loss_rot = torch.nn.ReLU()(1.0 - d_out_rot).mean()

            z = torch.randn(real_images.size(0), num_obj, z_dim).cuda()

            fake_images = netG(z, bbox, y=label.squeeze(dim=-1))  # rm rot

            d_out_fake, d_out_fobj = netD(fake_images.detach(), bbox,
                                          label)  # rm rot

            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            d_loss_fobj = torch.nn.ReLU()(1.0 + d_out_fobj).mean()
            # d_loss_frot = torch.nn.ReLU()(1.0 + d_out_frot).mean()

            d_loss = lamb_obj * (d_loss_robj + d_loss_fobj) + \
                     lamb_img * (d_loss_real + d_loss_fake)

            d_loss.backward()
            d_optimizer.step()

            # update G network
            if (idx % 1) == 0:
                netG.zero_grad()
                g_out_fake, g_out_obj = netD(fake_images, bbox,
                                             label)  # rm rot
                g_loss_fake = -g_out_fake.mean()
                g_loss_obj = -g_out_obj.mean()
                # g_loss_rot = - g_out_rot.mean()

                g_loss = g_loss_obj * lamb_obj + g_loss_fake * lamb_img
                g_loss.backward()
                g_optimizer.step()

            if (idx + 1) % 100 == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                logger.info("Time Elapsed: [{}]".format(elapsed))
                logger.info(
                    "Epoch[{}/{}], Step[{}/{}], "
                    "d_out_real: {:.4f}, d_out_fake: {:.4f}, g_out_fake: {:.4f} "
                    .format(epoch + 1, args.total_epoch, idx + 1, total_steps,
                            d_loss_real.item(), d_loss_fake.item(),
                            g_loss_fake.item()))

                logger.info(
                    "d_obj_real: {:.4f}, d_obj_fake: {:.4f}, g_obj_fake: {:.4f} "
                    .format(d_loss_robj.item(), d_loss_fobj.item(),
                            g_loss_obj.item()))

                writer.add_image(
                    "real images",
                    make_grid(real_images.cpu().data * 0.5 + 0.5, nrow=4),
                    epoch * len(dataloader) + idx + 1)
                writer.add_image(
                    "fake images",
                    make_grid(fake_images.cpu().data * 0.5 + 0.5, nrow=4),
                    epoch * len(dataloader) + idx + 1)
                writer.add_scalar('DLoss', d_loss,
                                  epoch * len(dataloader) + idx + 1)
                writer.add_scalar('DLoss/real_images', d_loss_real,
                                  epoch * len(dataloader) + idx + 1)
                writer.add_scalar('DLoss/fake_images', d_loss_fake,
                                  epoch * len(dataloader) + idx + 1)
                writer.add_scalar('DLoss/real_objects', d_loss_robj,
                                  epoch * len(dataloader) + idx + 1)
                writer.add_scalar('DLoss/fake_objects', d_loss_fobj,
                                  epoch * len(dataloader) + idx + 1)

                writer.add_scalar('GLoss', g_loss,
                                  epoch * len(dataloader) + idx + 1)
                writer.add_scalar('GLoss/fake_images', g_loss_fake,
                                  epoch * len(dataloader) + idx + 1)
                writer.add_scalar('GLoss/fake_objects', g_loss_obj,
                                  epoch * len(dataloader) + idx + 1)

        # save model
        if (epoch + 1) % 5 == 0:
            torch.save(
                netG.state_dict(),
                os.path.join(args.out_path, 'model/',
                             'G_%d.pth' % (epoch + 1)))
Code Example #6
def main(args):
    # parameters
    img_size = 128
    z_dim = 128
    lamb_obj = 1.0
    lamb_app = 1.0
    lamb_img = 0.1
    num_classes = 184 if args.dataset == 'coco' else 179
    num_obj = 8 if args.dataset == 'coco' else 31

    args.out_path = os.path.join(args.out_path, args.dataset, str(img_size))

    num_gpus = torch.cuda.device_count()
    num_workers = 2
    if num_gpus > 1:
        parallel = True
        args.batch_size = args.batch_size * num_gpus
        num_workers = num_workers * num_gpus
    else:
        parallel = False

    # data loader
    train_data = get_dataset(args.dataset, img_size)

    dataloader = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             drop_last=True,
                                             shuffle=True,
                                             num_workers=num_workers)

    # Load model
    device = torch.device('cuda')
    netG = ResnetGenerator128_context(num_classes=num_classes,
                                      output_dim=3).to(device)
    netD = CombineDiscriminator128_app(num_classes=num_classes).to(device)

    if args.checkpoint_epoch is not None:
        load_G = args.out_path + '/model/G_{}.pth'.format(
            args.checkpoint_epoch)
        load_D = args.out_path + '/model/D_{}.pth'.format(
            args.checkpoint_epoch)
        if not os.path.isfile(load_G) or not os.path.isfile(load_D):
            return

        # load generator
        state_dict = torch.load(load_G)

        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # strip the 'module.' prefix added by DataParallel (see the helper sketch after this example)
            new_state_dict[name] = v

        model_dict = netG.state_dict()
        pretrained_dict = {
            k: v
            for k, v in new_state_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        netG.load_state_dict(model_dict)
        netG.cuda()

        # load discriminator
        state_dict = torch.load(load_D)

        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # strip the 'module.' prefix added by DataParallel
            new_state_dict[name] = v

        model_dict = netD.state_dict()
        pretrained_dict = {
            k: v
            for k, v in new_state_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        netD.load_state_dict(model_dict)
        netD.cuda()
    else:
        args.checkpoint_epoch = 0

    # parallel = False
    if parallel:
        netG = DataParallelWithCallback(netG)
        netD = nn.DataParallel(netD)

    g_lr, d_lr = args.g_lr, args.d_lr
    gen_parameters = []
    for key, value in dict(netG.named_parameters()).items():
        if value.requires_grad:
            if 'mapping' in key:
                gen_parameters += [{'params': [value], 'lr': g_lr * 0.1}]
            else:
                gen_parameters += [{'params': [value], 'lr': g_lr}]

    g_optimizer = torch.optim.Adam(gen_parameters, betas=(0, 0.999))

    dis_parameters = []
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            dis_parameters += [{'params': [value], 'lr': d_lr}]
    d_optimizer = torch.optim.Adam(dis_parameters, betas=(0, 0.999))

    # make dirs
    if not os.path.exists(args.out_path):
        os.makedirs(args.out_path)
    if not os.path.exists(os.path.join(args.out_path, 'model/')):
        os.makedirs(os.path.join(args.out_path, 'model/'))
    writer = SummaryWriter(os.path.join(args.out_path, 'log'))

    logger = setup_logger("lostGAN", args.out_path, 0)
    logger.info(netG)
    logger.info(netD)

    start_time = time.time()
    vgg_loss = VGGLoss()
    vgg_loss = nn.DataParallel(vgg_loss)
    l1_loss = nn.DataParallel(nn.L1Loss())
    for epoch in range(args.checkpoint_epoch, args.total_epoch):
        netG.train()
        netD.train()
        print("Epoch {}/{}".format(epoch, args.total_epoch))
        for idx, data in enumerate(tqdm(dataloader)):
            real_images, label, bbox = data
            # print(real_images.shape)
            # print(label.shape)
            # print(bbox.shape)
            # move images and labels to the GPU once; bbox stays on the CPU for scattering
            real_images = real_images.to(device)
            label = label.long().to(device).unsqueeze(-1)
            bbox = bbox.float()

            # update D network
            netD.zero_grad()
            d_out_real, d_out_robj, d_out_robj_app = netD(
                real_images, bbox, label)
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            d_loss_robj = torch.nn.ReLU()(1.0 - d_out_robj).mean()
            d_loss_robj_app = torch.nn.ReLU()(1.0 - d_out_robj_app).mean()
            # print(d_loss_robj)
            # print(d_loss_robj_app)

            z = torch.randn(real_images.size(0), num_obj, z_dim).to(device)
            fake_images = netG(z, bbox, y=label.squeeze(dim=-1))
            d_out_fake, d_out_fobj, d_out_fobj_app = netD(
                fake_images.detach(), bbox, label)
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            d_loss_fobj = torch.nn.ReLU()(1.0 + d_out_fobj).mean()
            d_loss_fobj_app = torch.nn.ReLU()(1.0 + d_out_fobj_app).mean()

            d_loss = lamb_obj * (d_loss_robj + d_loss_fobj) + lamb_img * (
                d_loss_real + d_loss_fake) + lamb_app * (d_loss_robj_app +
                                                         d_loss_fobj_app)
            d_loss.backward()
            d_optimizer.step()

            # update G network
            if (idx % 1) == 0:
                netG.zero_grad()
                g_out_fake, g_out_obj, g_out_obj_app = netD(
                    fake_images, bbox, label)
                g_loss_fake = -g_out_fake.mean()
                g_loss_obj = -g_out_obj.mean()
                g_loss_obj_app = -g_out_obj_app.mean()

                pixel_loss = l1_loss(fake_images, real_images).mean()
                feat_loss = vgg_loss(fake_images, real_images).mean()

                g_loss = g_loss_obj * lamb_obj + g_loss_fake * lamb_img + pixel_loss + feat_loss + lamb_app * g_loss_obj_app
                g_loss.backward()
                g_optimizer.step()

            if (idx + 1) % 500 == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                logger.info("Time Elapsed: [{}]".format(elapsed))
                logger.info(
                    "Epoch[{}], Step[{}], d_out_real: {:.4f}, d_out_fake: {:.4f}, g_out_fake: {:.4f}"
                    .format(epoch + 1, idx + 1, d_loss_real.item(),
                            d_loss_fake.item(), g_loss_fake.item()))
                logger.info(
                    "             d_obj_real: {:.4f}, d_obj_fake: {:.4f}, g_obj_fake: {:.4f} "
                    .format(d_loss_robj.item(), d_loss_fobj.item(),
                            g_loss_obj.item()))
                logger.info(
                    "             d_obj_real_app: {:.4f}, d_obj_fake_app: {:.4f}, g_obj_fake_app: {:.4f} "
                    .format(d_loss_robj_app.item(), d_loss_fobj_app.item(),
                            g_loss_obj_app.item()))

                logger.info(
                    "             pixel_loss: {:.4f}, feat_loss: {:.4f}".
                    format(pixel_loss.item(), feat_loss.item()))

                writer.add_image(
                    "real images",
                    make_grid(real_images.cpu().data * 0.5 + 0.5, nrow=4),
                    epoch * len(dataloader) + idx + 1)
                writer.add_image(
                    "fake images",
                    make_grid(fake_images.cpu().data * 0.5 + 0.5, nrow=4),
                    epoch * len(dataloader) + idx + 1)

        # save model
        if (epoch + 1) % 5 == 0:
            torch.save(
                netG.state_dict(),
                os.path.join(args.out_path, 'model/',
                             'G_%d.pth' % (epoch + 1)))
            torch.save(
                netD.state_dict(),
                os.path.join(args.out_path, 'model/',
                             'D_%d.pth' % (epoch + 1)))
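
The two checkpoint-loading blocks in this example repeat the same `module.`-prefix stripping loop. A small helper they could share, as a sketch (the name `strip_module_prefix` is ours, not from the project):

from collections import OrderedDict

def strip_module_prefix(state_dict):
    # checkpoints saved from nn.DataParallel-wrapped models prefix every key
    # with 'module.'; remove it so the weights load into an unwrapped module
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())

With it, the generator load reduces to `netG.load_state_dict(strip_module_prefix(torch.load(load_G)), strict=False)`, where `strict=False` stands in for the key filtering done in the loop above.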