Example #1
# sigmoid=nn.Sigmoid()
# discriminator=models.Inception3(num_classes=1,aux_logits=False)
# discriminator=discriminator.to(device)
# discriminator.apply(weights_init)

#snresnet64

discriminator = SNResNetProjectionDiscriminator(num_classes=0)
discriminator = discriminator.to(device)

#losses

ssim = pytorch_ssim.SSIM()
bce_loss = nn.BCELoss()
mse_loss = nn.MSELoss()
gen_criterion = L.GenLoss("hinge", False)
dis_criterion = L.DisLoss("hinge", False)

optimizerG = optim.Adam(dncnn.parameters(), lr=learning_rate)
optimizerD = optim.Adam(discriminator.parameters(), lr=3 * learning_rate)

network_loss = []
ssim_ = []
mse_ = []

#loading checkpoint
if resume:
    print('loading params')
    start_epoch = params.start
    path = model_dir + '/' + str(start_epoch - 1) + '.pth.tar'  # load the last epoch's params
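    # A minimal sketch of applying the checkpoint whose path was just built,
    # assuming it was saved as a dict with 'model_state' and 'optimizer_state'
    # keys (hypothetical names; the real keys depend on how it was saved):
    checkpoint = torch.load(path, map_location=device)
    dncnn.load_state_dict(checkpoint['model_state'])
    optimizerG.load_state_dict(checkpoint['optimizer_state'])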
Example #2
def main():
    args = get_args()
    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    def _rescale(img):
        return img * 2.0 - 1.0

    def _noise_adder(img):
        return torch.empty_like(img, dtype=img.dtype).uniform_(0.0,
                                                               1 / 128.0) + img
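
    # _rescale maps ToTensor's [0, 1] output to [-1, 1], the range a tanh
    # generator emits; _noise_adder adds uniform noise in [0, 1/128) so the
    # discriminator never sees exactly quantized 8-bit pixel values (a
    # common dequantization trick).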

    # dataset
    train_dataset = datasets.ImageFolder(
        os.path.join(args.data_root, 'train'),
        transforms.Compose([
            transforms.ToTensor(),
            _rescale,
            _noise_adder,
        ]))
    train_loader = iter(
        data.DataLoader(train_dataset,
                        args.batch_size,
                        sampler=InfiniteSamplerWrapper(train_dataset),
                        num_workers=args.num_workers,
                        pin_memory=True))
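    # Wrapping the DataLoader in iter() with InfiniteSamplerWrapper yields an
    # endless stream of batches, so sample_from_data() below can be called
    # without tracking epoch boundaries.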
    if args.calc_FID:
        eval_dataset = datasets.ImageFolder(
            os.path.join(args.data_root, 'val'),
            transforms.Compose([
                transforms.ToTensor(),
                _rescale,
            ]))
        eval_loader = iter(
            data.DataLoader(eval_dataset,
                            args.batch_size,
                            sampler=InfiniteSamplerWrapper(eval_dataset),
                            num_workers=args.num_workers,
                            pin_memory=True))
    else:
        eval_loader = None
    num_classes = len(train_dataset.classes)
    print(' prepared datasets...')
    print(' Number of training images: {}'.format(len(train_dataset)))
    # Prepare directories.
    args.num_classes = num_classes
    args, writer = prepare_results_dir(args)
    # initialize models.
    _n_cls = num_classes if args.cGAN else 0
    gen = ResNetGenerator(args.gen_num_features,
                          args.gen_dim_z,
                          args.gen_bottom_width,
                          activation=F.relu,
                          num_classes=_n_cls,
                          distribution=args.gen_distribution).to(device)
    if args.dis_arch_concat:
        dis = SNResNetConcatDiscriminator(args.dis_num_features, _n_cls,
                                          F.relu, args.dis_emb).to(device)
    else:
        dis = SNResNetProjectionDiscriminator(args.dis_num_features, _n_cls,
                                              F.relu).to(device)
    inception_model = inception.InceptionV3().to(
        device) if args.calc_FID else None

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    # gen_criterion = getattr(L, 'gen_{}'.format(args.loss_type))
    # dis_criterion = getattr(L, 'dis_{}'.format(args.loss_type))
    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)
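    # With loss_type='hinge' (the usual SNGAN choice) these reduce to
    #   L_D = E[relu(1 - D(real))] + E[relu(1 + D(fake))]
    #   L_G = -E[D(fake)],
    # while the relativistic variants score D(real) - D(fake) differences.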
    print(' Initialized models...\n')

    if args.args_path is not None:
        print(' Load weights...\n')
        prev_args, gen, opt_gen, dis, opt_dis = utils.resume_from_args(
            args.args_path, args.gen_ckpt_path, args.dis_ckpt_path)

    # Training loop
    for n_iter in tqdm.tqdm(range(1, args.max_iteration + 1)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)
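            # decay_lr presumably anneals the learning rate from args.lr
            # toward zero as n_iter approaches max_iteration (a linear
            # schedule in the reference SNGAN setup).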

        # ==================== Beginning of 1 iteration. ====================
        _l_g = .0
        cumulative_loss_dis = .0
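        # One generator step per iteration (taken when i == 0) against
        # args.n_dis discriminator steps, the usual SNGAN update ratio.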
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, device, num_classes,
                                                    gen)
                dis_fake = dis(fake, pseudo_y)
                if args.relativistic_loss:
                    real, y = sample_from_data(args, device, train_loader)
                    dis_real = dis(real, y)
                else:
                    dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()
                if n_iter % 10 == 0 and writer is not None:
                    writer.add_scalar('gen', _l_g, n_iter)

            fake, pseudo_y, _ = sample_from_gen(args, device, num_classes, gen)
            real, y = sample_from_data(args, device, train_loader)

            dis_fake, dis_real = dis(fake, pseudo_y), dis(real, y)
            loss_dis = dis_criterion(dis_fake, dis_real)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
            if n_iter % 10 == 0 and i == args.n_dis - 1 and writer is not None:
                # Average the accumulated loss over the n_dis steps before logging.
                cumulative_loss_dis /= args.n_dis
                writer.add_scalar('dis', cumulative_loss_dis, n_iter)
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'.
                format(n_iter, args.max_iteration, _l_g, cumulative_loss_dis))
            if not args.no_image:
                writer.add_image(
                    'fake',
                    torchvision.utils.make_grid(fake,
                                                nrow=4,
                                                normalize=True,
                                                scale_each=True))
                writer.add_image(
                    'real',
                    torchvision.utils.make_grid(real,
                                                nrow=4,
                                                normalize=True,
                                                scale_each=True))
            # Save previews
            utils.save_images(n_iter, n_iter // args.checkpoint_interval,
                              args.results_root, args.train_image_root, fake,
                              real)
        if n_iter % args.checkpoint_interval == 0:
            # Save checkpoints!
            utils.save_checkpoints(args, n_iter,
                                   n_iter // args.checkpoint_interval, gen,
                                   opt_gen, dis, opt_dis)
        if n_iter % args.eval_interval == 0:
            # TODO (crcrpar): implement Inception score, FID, and Geometry score
            # Once these criteria are prepared, val_loader will be used.
            fid_score = evaluation.evaluate(args, n_iter, gen, device,
                                            inception_model, eval_loader)
            tqdm.tqdm.write(
                '[Eval] iteration: {:07d}/{:07d}, FID: {:07f}'.format(
                    n_iter, args.max_iteration, fid_score))
            if writer is not None:
                writer.add_scalar("FID", fid_score, n_iter)
                # Project embedding weights if exists.
                embedding_layer = getattr(dis, 'l_y', None)
                if embedding_layer is not None:
                    writer.add_embedding(embedding_layer.weight.data,
                                         list(range(args.num_classes)),
                                         global_step=n_iter)
    if args.test:
        shutil.rmtree(args.results_root)
Example #3
def main():
    args = get_args()
    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    def _rescale(img):
        return img * 2.0 - 1.0

    def _noise_adder(img):
        return torch.empty_like(img, dtype=img.dtype).uniform_(0.0,
                                                               1 / 128.0) + img

    eval_dataset = cifar10.CIFAR10(root=args.data_root,
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(64),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))
                                   ]),
                                   minority_classes=None,
                                   keep_ratio=None)
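    # This cifar10.CIFAR10 appears to be a project-local subclass of
    # torchvision's CIFAR10 with optional class subsampling; both knobs are
    # disabled here, so the full test split is used.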
    eval_loader = iter(
        torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size,
            sampler=InfiniteSamplerWrapper(eval_dataset),
            num_workers=args.num_workers,
            pin_memory=True))

    print(' prepared datasets...')

    # Prepare directories.
    num_classes = len(eval_dataset.classes)
    args.num_classes = num_classes

    # initialize models.
    _n_cls = num_classes if args.cGAN else 0
    gen = ResNetGenerator(args.gen_num_features,
                          args.gen_dim_z,
                          args.gen_bottom_width,
                          activation=F.relu,
                          num_classes=_n_cls,
                          distribution=args.gen_distribution).to(device)
    if args.dis_arch_concat:
        dis = SNResNetConcatDiscriminator(args.dis_num_features, _n_cls,
                                          F.relu, args.dis_emb).to(device)
    else:
        dis = SNResNetProjectionDiscriminator(args.dis_num_features, _n_cls,
                                              F.relu,
                                              args.transform_space).to(device)
    inception_model = inception.InceptionV3().to(
        device) if args.calc_FID else None

    gen = torch.nn.DataParallel(gen)
    # dis = torch.nn.DataParallel(dis)

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    # gen_criterion = getattr(L, 'gen_{}'.format(args.loss_type))
    # dis_criterion = getattr(L, 'dis_{}'.format(args.loss_type))
    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)

    print(' Initialized models...\n')

    if args.args_path is None:
        raise SystemExit("Please specify weights to load")
    else:
        print(' Load weights...\n')

        prev_args, gen, opt_gen, dis, opt_dis = utils.resume_from_args(
            args.args_path, args.gen_ckpt_path, args.dis_ckpt_path)
    args.n_fid_batches = args.n_eval_batches
    fid_score = evaluation.evaluate(args,
                                    0,
                                    gen,
                                    device,
                                    inception_model,
                                    eval_loader,
                                    to_save=False)
    print(fid_score)
Example #4
def main():
    args = get_args()
    weight_path, img_path = directory_path(args)

    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    dev = torch.device('cuda')  # device handle used by sample_from_* below

    # dataloading
    train_loader, s_dlen, _n_cls = data_loader2(args)

    fixed_z = torch.randn(200, 10, 128)
    fixed_img_list, fixed_label_list = pick_fixed_img(args, train_loader, 200)
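    # A fixed latent bank and fixed real images keep preview grids comparable
    # across checkpoints (the shape suggests 200 slots of 10 samples with
    # z_dim=128, though that is only a guess from the call).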

    # initialize model
    gen, dis = select_model(args, _n_cls)

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)

    criterion = nn.CrossEntropyLoss()
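    # The discriminator here exposes two auxiliary softmax heads: dis_c is
    # trained on real labels and dis_mi on the labels of fakes. The
    # generator's (loss_c - loss_mi) term below rewards matching the class
    # head while starving the mutual-information head, a TAC-GAN-style
    # objective.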

    # Training loop
    for n_iter in tqdm.tqdm(range(0, args.max_iteration)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = .0
        cumulative_loss_dis = .0
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
                dis_fake, dis_mi, dis_c = dis(fake, pseudo_y)
                dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)

                ##################################################
                loss_mi = criterion(dis_mi, pseudo_y)
                loss_c = criterion(dis_c, pseudo_y)

                loss_gen = loss_gen + args.lambda_c * (loss_c - loss_mi)
                ##################################################

                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()

            fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
            real, y = sample_from_data(args, dev, train_loader)

            dis_fake, dis_fake_mi, dis_fake_c = dis(fake, pseudo_y)
            dis_real, dis_real_mi, dis_real_c = dis(real, y)

            ######################################################
            loss_dis_mi = criterion(dis_fake_mi, pseudo_y)
            loss_dis_c = criterion(dis_real_c, y)
            ######################################################

            loss_dis = dis_criterion(dis_fake, dis_real)
            loss_dis = loss_dis + args.lambda_c * (loss_dis_mi + loss_dis_c)
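            # Both auxiliary heads train here: the 'mi' head on fakes and the
            # 'c' head on reals, each weighted by args.lambda_c.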

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'
                ' loss mi {:05f}, loss c {:05f}'.format(
                    n_iter, args.max_iteration, _l_g, cumulative_loss_dis,
                    args.lambda_c * loss_dis_mi, args.lambda_c * loss_dis_c))

        if n_iter % args.checkpoint_interval == 0:
            #Save checkpoints!
            utils.save_checkpoints(args, n_iter, gen, opt_gen, dis, opt_dis,
                                   weight_path)
            if args.dataset == "omniglot":
                utils.save_img(fixed_img_list,
                               fixed_label_list,
                               fixed_z,
                               gen,
                               32,
                               28,
                               img_path,
                               n_iter,
                               device=dev)
            elif args.dataset == "vgg" or args.dataset == "animal":
                utils.save_img(fixed_img_list,
                               fixed_label_list,
                               fixed_z,
                               gen,
                               84,
                               64,
                               img_path,
                               n_iter,
                               device=dev)
            elif args.dataset == "cub":
                utils.save_img(fixed_img_list,
                               fixed_label_list,
                               fixed_z,
                               gen,
                               72,
                               64,
                               img_path,
                               n_iter,
                               device=dev)
            else:
                raise ValueError("dataset must be one of: omniglot, vgg, animal, cub")

    if args.test:
        shutil.rmtree(args.results_root)
Example #5
def main():
    args = get_args()
    weight_path, img_path = directory_path(args)

    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    dev = torch.device('cuda')  # device handle used by sample_from_* below
    # DiffAugment's policy string is assumed here; 'color,translation,cutout'
    # is the standard setting from the DiffAugment paper for small datasets.
    policy = 'color,translation,cutout'

    # dataloading
    train_loader, s_dlen, _n_cls = data_loader2(args)

    fixed_z = torch.randn(200, 10, 128)
    fixed_img_list, fixed_label_list = pick_fixed_img(args, train_loader, 200)

    # initialize model
    gen, dis, start = select_model(args, _n_cls)
    
    gen, dis = gen.cuda(), dis.cuda()
    
    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)

    # Training loop
    for n_iter in tqdm.tqdm(range(start, args.max_iteration)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = .0
        cumulative_loss_dis = .0
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
                fake = DiffAugment(fake, policy=policy)
                dis_fake = dis(fake, pseudo_y)
                dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()

            fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
            real, y = sample_from_data(args, dev, train_loader)

            fake = DiffAugment(fake, policy=policy)
            real = DiffAugment(real, policy=policy)
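            # Real and fake batches pass through identical differentiable
            # augmentations, so the discriminator cannot separate them by
            # augmentation artifacts and gradients still reach the generator.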

            dis_fake, dis_real = dis(fake, pseudo_y), dis(real, y)
            loss_dis = dis_criterion(dis_fake, dis_real)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'.format(
                    n_iter, args.max_iteration, _l_g, cumulative_loss_dis))

        if n_iter % args.checkpoint_interval == 0:
            #Save checkpoints!
            utils.save_checkpoints(args, n_iter, gen, opt_gen, dis, opt_dis, weight_path)
            utils.save_img(fixed_img_list, fixed_label_list, fixed_z, gen,
                           32, 28, img_path, n_iter, device=dev)
    if args.test:
        shutil.rmtree(args.results_root)