Example #1
0
def train(double, seed, constant, experience_dir, train_path, valid_path,
          batch_size, epochs, drift, adwd, method):
    """Train the sketch-colorization GAN (WGAN-style critic + content loss).

    Each loop iteration runs a critic step, a generator step and
    (optionally) a "sketcher" step, each consuming its own batch from the
    train loader.  Checkpoints go to ``<experience_dir>/model`` after every
    epoch; training resumes from the last checkpointed epoch when present.

    Args:
        double: if True, also train a color->sketch ``sketcher`` network
            whose reconstruction loss ties the colorized output back to
            the input sketch.
        seed: seed applied to numpy and torch (CPU + CUDA) RNGs.
        constant: capacity constant forwarded to Generator/Discriminator.
        experience_dir: run directory; ``model/`` (checkpoints) and
            ``vizu/`` (sample grids) are created inside it.
        train_path: training dataset location.
        valid_path: validation dataset location (used for FID).
        batch_size: batch size for both loaders.
        epochs: total number of epochs to reach.
        drift: weight of the squared drift penalty on real critic scores.
        adwd: weight of the adversarial term in the generator loss.
        method: forwarded to ``CreateTrainLoader``.
    """
    # makedirs(exist_ok=True) replaces the exists()/mkdir pairs: race-free
    # and it creates experience_dir itself when missing.
    os.makedirs(os.path.join(experience_dir, 'model'), exist_ok=True)
    os.makedirs(os.path.join(experience_dir, 'vizu'), exist_ok=True)

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    generator = nn.DataParallel(
        Generator(constant, in_channels=1, out_channels=3)).cuda()
    discriminator = nn.DataParallel(Discriminator(constant)).cuda()
    features = nn.DataParallel(Features()).cuda()
    i2v = nn.DataParallel(Illustration2Vec()).cuda()
    sketcher = nn.DataParallel(
        Generator(constant, in_channels=3,
                  out_channels=1)).cuda() if double else None

    # One optimizer drives the generator (plus the sketcher when enabled)
    # so that a single checkpoint entry restores both.
    optimizerG = optim.Adam(
        list(generator.parameters()) + list(sketcher.parameters())
        if double
        else generator.parameters(),
        lr=1e-4, betas=(0.5, 0.9)
    )
    optimizerD = optim.Adam(discriminator.parameters(),
                            lr=1e-4,
                            betas=(0.5, 0.9))

    # Single 0.1 LR decay at step 125000 for both networks.
    lr_schedulerG = StepLRScheduler(optimizerG, [125000], [0.1], 1e-4, 0, 0)
    lr_schedulerD = StepLRScheduler(optimizerD, [125000], [0.1], 1e-4, 0, 0)

    gen_ckpt = os.path.join(experience_dir, 'model/generator.pth')
    disc_ckpt = os.path.join(experience_dir, 'model/discriminator.pth')

    # Resume from checkpoints when present.  The resume epoch is read here
    # too, which removes the second torch.load of the same file that the
    # original code performed further down.
    start_epoch = 0
    if os.path.isfile(gen_ckpt):
        checkG = torch.load(gen_ckpt)
        generator.load_state_dict(checkG['generator'])
        if double:
            sketcher.load_state_dict(checkG['sketcher'])
        optimizerG.load_state_dict(checkG['optimizer'])
        start_epoch = checkG['epoch'] + 1

    if os.path.isfile(disc_ckpt):
        checkD = torch.load(disc_ckpt)
        discriminator.load_state_dict(checkD['discriminator'])
        optimizerD.load_state_dict(checkD['optimizer'])

    # The feature extractor is a fixed perceptual metric; freeze it.
    for param in features.parameters():
        param.requires_grad = False
    mse = nn.MSELoss().cuda()
    grad_penalty = GradPenalty(10).cuda()

    img_size = (512, 512)
    # Color hints are generated at quarter resolution.
    mask_gen = Hint((img_size[0] // 4, img_size[1] // 4), 120, (1, 4), 5,
                    (10, 10))
    dataloader = CreateTrainLoader(train_path,
                                   batch_size,
                                   mask_gen,
                                   img_size,
                                   method=method)
    iterator = iter(dataloader)

    def _next_batch():
        """Return the next (colored, sketch, hint) batch on the GPU,
        restarting the loader transparently when it is exhausted."""
        nonlocal iterator
        try:
            # next(it) instead of the Python-2-only it.next(): modern
            # DataLoader iterators only implement __next__.
            colored, sketch, hint = next(iterator)
        except StopIteration:
            iterator = iter(dataloader)
            colored, sketch, hint = next(iterator)
        return colored.cuda(), sketch.cuda(), hint.cuda()

    valloader = CreateValidLoader(valid_path, batch_size, img_size)
    config = {
        'target_npz': './res/model/fid_stats_color.npz',
        'corps': 2,
        'image_size': img_size[0]
    }
    evaluate = Evaluate(generator, i2v, valloader,
                        namedtuple('Config', config.keys())(*config.values()))

    for epoch in range(start_epoch, epochs):
        batch_id = -1

        pbar = tqdm(total=len(dataloader),
                    desc=f'Epoch [{(epoch + 1):06d}/{epochs}]')
        step = 0  # per-epoch iteration counter (was 'id', shadowing a builtin)
        dgp = 0.  # running critic loss (real/drift term + gradient penalty)
        gc = 0.   # running generator loss (adversarial + content)
        s = 0. if double else None  # running sketcher reconstruction loss

        while batch_id < len(dataloader):
            lr_schedulerG.step(epoch)
            lr_schedulerD.step(epoch)

            generator.train()
            discriminator.train()
            if double:
                sketcher.train()

            # =============
            # DISCRIMINATOR
            # =============
            # Critic update: unfreeze critic, freeze generator.
            for p in discriminator.parameters():
                p.requires_grad = True
            for p in generator.parameters():
                p.requires_grad = False
            optimizerD.zero_grad()
            optimizerG.zero_grad()

            batch_id += 1
            real_colored, real_sketch, hint = _next_batch()

            with torch.no_grad():
                feat_sketch = i2v(real_sketch).detach()
                fake_colored = generator(real_sketch, hint,
                                         feat_sketch).detach()

            # WGAN-style critic: push fake scores down ...
            errD_fake = discriminator(fake_colored,
                                      feat_sketch).mean(0).view(1)
            errD_fake.backward(retain_graph=True)

            # ... and real scores up; the squared term is a drift penalty
            # that keeps critic outputs from wandering.
            errD_real = discriminator(real_colored,
                                      feat_sketch).mean(0).view(1)
            errD_realer = -1 * errD_real + errD_real.pow(2) * drift
            errD_realer.backward(retain_graph=True)

            gp = grad_penalty(discriminator, real_colored, fake_colored,
                              feat_sketch)
            gp.backward()

            optimizerD.step()
            pbar.update(1)

            dgp += errD_realer.item() + gp.item()

            # =============
            # GENERATOR
            # =============
            # Generator update on a fresh batch: roles reversed.
            for p in generator.parameters():
                p.requires_grad = True
            for p in discriminator.parameters():
                p.requires_grad = False
            optimizerD.zero_grad()
            optimizerG.zero_grad()

            batch_id += 1
            real_colored, real_sketch, hint = _next_batch()

            with torch.no_grad():
                feat_sketch = i2v(real_sketch).detach()

            fake_colored = generator(real_sketch, hint, feat_sketch)

            # Adversarial term, weighted by adwd.
            errD = discriminator(fake_colored, feat_sketch)
            errG = -1 * errD.mean() * adwd
            errG.backward(retain_graph=True)

            # Perceptual content loss in the frozen feature space.
            feat1 = features(fake_colored)
            with torch.no_grad():
                feat2 = features(real_colored)

            contentLoss = mse(feat1, feat2)
            contentLoss.backward()

            optimizerG.step()
            pbar.update(1)

            gc += errG.item() + contentLoss.item()

            # =============
            # SKETCHER
            # =============
            if double:
                # Cycle step on a third batch: the sketcher must recover
                # the input sketch from the generated colorization.
                for p in generator.parameters():
                    p.requires_grad = True
                for p in discriminator.parameters():
                    p.requires_grad = False
                optimizerD.zero_grad()
                optimizerG.zero_grad()

                batch_id += 1
                real_colored, real_sketch, hint = _next_batch()

                with torch.no_grad():
                    feat_sketch = i2v(real_sketch).detach()

                fake_colored = generator(real_sketch, hint, feat_sketch)
                fake_sketch = sketcher(fake_colored, hint, feat_sketch)
                errS = mse(fake_sketch, real_sketch)
                errS.backward()

                optimizerG.step()
                pbar.update(1)

                s += errS.item()

            pbar.set_postfix(
                **{
                    'dgp': dgp / (step + 1),
                    'gc': gc / (step + 1),
                    's': s / (step + 1) if double else None
                })

            # =============
            # PLOT
            # =============
            generator.eval()
            discriminator.eval()

            # Dump a visualization grid every 20 iterations.
            if step % 20 == 0:
                tensors2vizu(
                    img_size,
                    os.path.join(experience_dir,
                                 f'vizu/out_{epoch}_{step}_{batch_id}.jpg'), **{
                                     'strokes': hint[:, :3, ...],
                                     'col': real_colored,
                                     'fcol': fake_colored,
                                     'sketch': real_sketch,
                                     'fsketch': fake_sketch if double else None
                                 })

            step += 1

        pbar.close()

        torch.save(
            {
                'generator': generator.state_dict(),
                'sketcher': sketcher.state_dict() if double else None,
                'optimizer': optimizerG.state_dict(),
                'double': double,
                'epoch': epoch
            }, gen_ckpt)

        torch.save(
            {
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizerD.state_dict(),
                'epoch': epoch
            }, disc_ckpt)

        # Periodic FID evaluation, appended to fid.csv.
        if epoch % 20 == 0:
            fid, fid_var = evaluate()
            print(
                f'\n===================\nFID = {fid} +- {fid_var}\n===================\n'
            )

            with open(os.path.join(experience_dir, 'fid.csv'), 'a+') as f:
                f.write(f'{epoch};{fid};{fid_var}\n')
Example #2
0
# Single-image super-resolution inference script.
parser = argparse.ArgumentParser(description='Test Single Image')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
parser.add_argument('--image_name', type=str, help='test low resolution image name')
parser.add_argument('--model_name', default='new_train_G.pth', type=str, help='generator model epoch name')

opt = parser.parse_args()

UPSCALE_FACTOR = opt.upscale_factor
# The comparison already yields a bool; no 'True if ... else False' needed.
TEST_MODE = opt.test_mode == 'GPU'
# NOTE(review): the CLI --image_name is ignored; the path is hard-coded below.
# IMAGE_NAME = opt.image_name
IMAGE_NAME = './data/test_input/4.png'
MODEL_NAME = opt.model_name

model = Generator(UPSCALE_FACTOR).eval()
if TEST_MODE:
    model.cuda()
    model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
else:
    # map_location keeps CUDA-saved weights loadable on a CPU-only host.
    model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))

image = Image.open(IMAGE_NAME)
# Variable(..., volatile=True) is deprecated; torch.no_grad() below is the
# supported way to disable autograd for inference.
image = ToTensor()(image).unsqueeze(0)
if TEST_MODE:
    image = image.cuda()

# time.clock() was removed in Python 3.8; perf_counter() is the replacement.
start = time.perf_counter()
with torch.no_grad():
    out = model(image)
elapsed = time.perf_counter() - start
print('cost' + str(elapsed) + 's')
    NUM_GPU = torch.cuda.device_count()
else:
    NUM_GPU = 1

# DataLoaders
# Training pairs come from fixed-size crops; validation runs one image at
# a time (batch_size=1).  Crop/downscale specifics live in the dataset
# classes -- confirm there for exact preprocessing.
train_set = TrainDatasetFromFolder('data/train',
                                   crop_size=CROP_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE_TRAIN,
                          shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

# Networks and Loss
netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:',
      sum(param.numel() for param in netG.parameters()))
# The discriminator is optional: without it, only the generator-side
# losses are used (no adversarial training).
if USE_DISCRIMINATOR:
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))

# Combined generator objective; the weights balance the perceptual,
# adversarial and per-pixel terms.
generator_criterion = GeneratorLoss(weight_perception=WEIGHT_PERCEPTION,
                                    weight_adversarial=WEIGHT_ADVERSARIAL,
                                    weight_image=WEIGHT_IMAGE,
                                    network=NETWORK)

if USE_CUDA:
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
Example #4
0
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)],
                         [0.5 for _ in range(CHANNELS_IMG)]),
])

#dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
#comment mnist and uncomment below if you want to train on CelebA dataset
dataset = datasets.ImageFolder(root="../data/color/", transform=transforms)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print('loaded dataset')

# initialize gen and disc/critic
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

print('loaded networks')

# initialize optimizers (RMSprop, as in the original WGAN recipe --
# presumably; confirm against the training loop)
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# for tensorboard plotting: a fixed latent batch so the logged samples
# are comparable across training steps
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0  # global tensorboard step counter
        mean = noise.mean()
        std = noise.std()

        noise.data.add_(-mean).div_(std)


def make_image(tensor):
    """Convert a batch of images in [-1, 1] to HWC uint8 numpy arrays.

    Args:
        tensor: float tensor of shape (N, C, H, W) with values nominally
            in [-1, 1]; out-of-range values are clamped.

    Returns:
        numpy uint8 array of shape (N, H, W, C) with values in [0, 255].
    """
    return (
        tensor.detach()
        # Out-of-place clamp: detach() shares storage with the input, so
        # the original in-place clamp_ silently mutated the caller's tensor.
        .clamp(min=-1, max=1)
        .add(1)        # [-1, 1] -> [0, 2]  (add() allocates a new tensor,
        .div_(2)       # so the in-place ops below are safe)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)  # NCHW -> NHWC
        .to("cpu")
        .numpy()
    )


device = args.device

# generate images
## load model
# Two generator checkpoints at (possibly) different resolutions; weights
# saved on GPU are mapped to cuda:0 first, then moved to the target device.
# strict=False tolerates missing/extra keys in the checkpoints.
g_ema1 = Generator(args.size1, 512, 8)
g_ema1.load_state_dict(torch.load(args.model1, map_location='cuda:0')["g_ema"],
                       strict=False)
g_ema1.eval()
g_ema1 = g_ema1.to(device)

g_ema2 = Generator(args.size2, 512, 8)
g_ema2.load_state_dict(torch.load(args.model2, map_location='cuda:0')["g_ema"],
                       strict=False)
g_ema2.eval()
g_ema2 = g_ema2.to(device)
noises_single = g_ema2.make_noise()
noises = []
for noise in noises_single:
 def register(self, trainer):
     # Hook called with the owning trainer: builds the Generator from the
     # trainer's model and CUDA flag and keeps it for later generation.
     self.generate = Generator(trainer.model.model, trainer.cuda)
Example #7
0
image_list = os.listdir(image_path)

outdir = "./output_train"
if not os.path.exists(outdir):
    os.mkdir(outdir)

# Build a fixed test batch by sampling images just past the training
# split.  NOTE(review): randint(Ntrain + 1, Ntrain + 100) assumes the
# directory holds at least Ntrain + 100 files -- confirm.
test_box = []
for i in range(testsize):
    rnd = np.random.randint(Ntrain + 1, Ntrain + 100)
    image_name = image_path + image_list[rnd]
    _, sr = prepare_dataset(image_name)
    test_box.append(sr)

x_test = chainer.as_variable(xp.array(test_box).astype(xp.float32))

generator = Generator()
generator.to_gpu()
gen_opt = set_optimizer(generator)
#serializers.load_npz("./generator_pretrain.model",generator)

discriminator = Discriminator()
discriminator.to_gpu()
dis_opt = set_optimizer(discriminator)

# VGG is used as a fixed feature extractor: its base layers are excluded
# from updates below.
vgg = VGG()
vgg.to_gpu()
vgg_opt = set_optimizer(vgg)
vgg.base.disable_update()

for epoch in range(epochs):
    sum_gen_loss = 0
Example #8
0
def weight_init2(m):
    """Xavier-normal initialization for Linear layers (pass to ``module.apply``).

    Parameter initialization; could be switched to another Xavier variant.
    Non-Linear modules are left untouched; biases (when present) are set
    to the small constant 0.01.
    """
    if isinstance(m, nn.Linear):
        # xavier_normal / constant are the deprecated pre-0.4 spellings;
        # the trailing-underscore in-place variants are the supported API.
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)


def weights_init(m):
    """Orthogonal initialization for Linear layers (pass to ``module.apply``).

    Non-Linear modules are left untouched; biases (when present) are set
    to the small constant 0.01.
    """
    if isinstance(m, nn.Linear):
        # orthogonal / constant are the deprecated pre-0.4 spellings; use
        # the in-place underscore variants.
        nn.init.orthogonal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)


# Generator/discriminator pair plus their Adam optimizers; sizes and
# learning rates come from module-level config.
G = Generator(input_size=g_input_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_size,
                  hidden_size=d_hidden_size,
                  output_size=d_output_size)

criterion = nn.BCELoss(
)  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
criterion2 = nn.CosineSimilarity(dim=1, eps=1e-6)

d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(),
                         lr=g_learning_rate,
                         betas=optim_betas,
                         weight_decay=0)

# NOTE(review): only G is moved to the GPU here; D presumably moves
# elsewhere (or runs on CPU) -- verify before training.
G.cuda()
Example #9
0
from modules import *
import os, codecs
from tqdm import tqdm
from utils import *
from model import Generator, Discriminator

if __name__ == '__main__':
    # load gen data
    # Generator batches for both training and test splits.
    gen_user, gen_card, gen_card_idx, gen_item_cand, gen_item_pos, gen_num_batch \
        = get_gen_batch_data(is_training=True)
    gen_user_test, gen_card_test, _, gen_item_cand_test, gen_item_pos_test, gen_num_batch_test \
        = get_gen_batch_data(is_training=False)

    # Construct graph (TF1-style): the variable counts are printed after
    # each sub-graph to show how many variables it added.
    with tf.name_scope('Generator'):
        g = Generator(is_training=True)
    print(len(tf.get_variable_scope().global_variables()))
    with tf.name_scope('Discriminator'):
        d = Discriminator(is_training=True, is_testing=False)
    print(len(tf.get_variable_scope().global_variables()))

    # Inference/test graphs reuse the variables created above rather than
    # allocating fresh ones.
    tf.get_variable_scope().reuse_variables()
    with tf.name_scope('DiscriminatorInfer'):
        d_infer = Discriminator(is_training=False, is_testing=False)
    with tf.name_scope('DiscriminatorTest'):
        d_test = Discriminator(is_training=False, is_testing=True)
    with tf.name_scope('GeneratorInfer'):
        g_infer = Generator(is_training=False)

    print("Graph loaded")
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # Load the target images into one (N, C, H, W) batch on the device.
    imgs = []

    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)
    clip_loss = CLIPLoss()

    # Estimate the latent distribution: map random z samples through the
    # style network and take their mean and (scalar) std for later use.
    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)
        print(latent_out.shape)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
Example #11
0
                                                 batch_size=params.batch_size,
                                                 shuffle=False)
test_data_B = DatasetFromFolder(data_dir,
                                subfolder='testB',
                                transform=transform)
test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B,
                                                 batch_size=params.batch_size,
                                                 shuffle=False)

# Get specific test images
# Fixed samples (indices 11 and 91) used for qualitative monitoring.
test_real_A_data = test_data_A.__getitem__(11).unsqueeze(
    0)  # Convert to 4d tensor (BxNxHxW)
test_real_B_data = test_data_B.__getitem__(91).unsqueeze(0)

# Models
# CycleGAN-style pairs: G_A/G_B map between the two domains, D_A/D_B
# discriminate each domain.  All weights start from N(0, 0.02).
G_A = Generator(3, params.ngf, 3, params.num_resnet)
G_B = Generator(3, params.ngf, 3, params.num_resnet)
D_A = Discriminator(3, params.ndf, 1)
D_B = Discriminator(3, params.ndf, 1)
G_A.normal_weight_init(mean=0.0, std=0.02)
G_B.normal_weight_init(mean=0.0, std=0.02)
D_A.normal_weight_init(mean=0.0, std=0.02)
D_B.normal_weight_init(mean=0.0, std=0.02)
G_A.cuda()
G_B.cuda()
D_A.cuda()
D_B.cuda()

# Set the logger
D_A_log_dir = save_dir + 'D_A_logs'
D_B_log_dir = save_dir + 'D_B_logs'
Example #12
0
                                           batch_size=1,
                                           num_workers=0,
                                           shuffle=True,
                                           pin_memory=True)

validation_loader = torch.utils.data.DataLoader(validation_dataset,
                                                batch_size=4,
                                                num_workers=0,
                                                pin_memory=True)

# Use the GPU if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

in_channels, out_channels = 3, 3
# Instantiate generator and discriminator
generator = Generator(in_channels, out_channels)
discriminator = Discriminator(in_channels, out_channels)

# If model is specified load model.
# Checkpoints are loaded onto the CPU first, then both networks are moved
# to the selected device below; the epoch counter is restored for resuming.
ini_epoch = 0
if args.model:
    checkpoint_file = torch.load(args.model, map_location='cpu')
    generator.load_state_dict(checkpoint_file['generator'])
    discriminator.load_state_dict(checkpoint_file['discriminator'])
    ini_epoch = checkpoint_file['epoch']

generator.to(device)
discriminator.to(device)

# Optimizer.
gen_opt = torch.optim.Adam(generator.parameters(),
Example #13
0
import torch
import pickle
import numpy as np
from hparams import hparams
from utils import pad_seq_to_2
from utils import quantize_f0_numpy
from model import Generator_3 as Generator
from model import Generator_6 as F0_Converter
import matplotlib.pyplot as plt
import os
import glob

# Inference setup: load the trained generator and the first metadata
# entry's speaker embedding plus its mel-spectrogram.
out_len = 408  # target sequence length -- presumably padding length; confirm

device = 'cuda:1'
G = Generator(hparams).eval().to(device)
# map_location keeps the CUDA-saved checkpoint loadable on any host.
g_checkpoint = torch.load('run/models/234000-G.ckpt',
                          map_location=lambda storage, loc: storage)
G.load_state_dict(g_checkpoint['model'])

# 'with' closes the pickle file deterministically (the original left the
# handle dangling via pickle.load(open(...))).
with open('/hd0/speechsplit/preprocessed/spmel/train.pkl', "rb") as f:
    metadata = pickle.load(f)

sbmt_i = metadata[0]
# sbmt_i layout: [?, speaker embedding, spectrogram filename] per the
# indices used here -- confirm against the preprocessing code.
emb_org = torch.from_numpy(sbmt_i[1]).unsqueeze(0).to(device)

root_dir = "/hd0/speechsplit/preprocessed/spmel"
feat_dir = "/hd0/speechsplit/preprocessed/raptf0"

# mel-spectrogram, f0 contour load
x_org = np.load(os.path.join(root_dir, sbmt_i[2]))
Example #14
0
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    # Fixed StyleGAN-style hyperparameters: 512-dim latent, 8-layer
    # mapping network.
    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    # Default the sample count to the batch size when unset.
    if args.n_sample == 0:
        args.n_sample = args.batch

    generator = Generator(
        args.size,
        args.latent,
        args.n_mlp,
        args.p_categories,
        args.c_categories,
        args.p_size,
        channel_multiplier=args.channel_multiplier).to(device)

    # g_ema tracks an exponential moving average of the generator weights;
    # accumulate(..., 0) copies the initial weights so both start equal.
    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      args.p_categories,
                      args.c_categories,
                      args.p_size,
                      channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    # netD0 = Discriminator(
Example #15
0
    def train(self, src_emb, tgt_emb, evaluator, **kwargs):
        """Adversarially train a linear mapping between embedding spaces.

        Three GAN variants are supported, selected by the ``gan_model``
        kwarg (default 3):
          1. unidirectional source->target mapping;
          2. bidirectional mapping with a second, source-side
             discriminator;
          3. adversarial autoencoder: unidirectional mapping plus a
             cosine reconstruction loss weighted by ``params.lambda_r``.

        Args:
            src_emb: source-language embedding layer (its ``weight`` is
                mapped for evaluation/checkpointing).
            tgt_emb: target-language embedding layer.
            evaluator: provides ``get_all_precisions`` for evaluation at
                checkpoint time.
            **kwargs: ``gan_model`` selects the variant above.

        Returns:
            The trained generator ``g``.
        """
        params = self.params
        # Load data
        if not path.exists(params.data_dir):
            # Bug fix: the original raised a bare string, which is itself
            # a TypeError in Python 3; raise a real exception instead.
            raise FileNotFoundError("Data path doesn't exists: %s" %
                                    params.data_dir)

        gan_model = kwargs.get('gan_model', 3)
        dir_model = params.model_dir

        # Embeddings
        src = src_emb
        tgt = tgt_emb

        # Define samplers: words are drawn with downsampled frequency
        # weights when available, else uniformly over the vocabulary.
        vocab_size = params.most_frequent_sampling_size
        try:
            if params.uniform_sampling:
                raise FileNotFoundError
            weights_src, weights_tgt = get_frequencies()
            self.logger.info('Use frequencies for sampling')
            weights_src = downsample_frequent_words(weights_src)
            weights_tgt = downsample_frequent_words(weights_tgt)
            # Restrict sampling to the most frequent vocab_size words.
            weights_src[vocab_size:] = 0.0
            weights_tgt[vocab_size:] = 0.0
            weights_src /= weights_src.sum()
            weights_tgt /= weights_tgt.sum()
        except FileNotFoundError:
            weights_src = np.ones(vocab_size) / vocab_size
            weights_tgt = np.ones(vocab_size) / vocab_size

        sampler_src = WeightedSampler(weights_src)
        sampler_tgt = WeightedSampler(weights_tgt)
        # NOTE(review): mini_batch_size (and max_iters further down) are
        # bare names, presumably module-level config -- confirm.
        iter_src = sampler_src.get_iterator(mini_batch_size)
        iter_tgt = sampler_tgt.get_iterator(mini_batch_size)

        # Create models
        g = Generator(input_size=params.g_input_size,
                      output_size=params.g_output_size)
        d_tgt = Discriminator(  # Discriminator (target-side)
            input_size=params.g_output_size,
            hidden_size=params.d_hidden_size)
        d_src = None  # Discriminator (source-side), Model 2 only
        lambda_r = 0.0  # Coefficient of reconstruction loss
        # Consistency fix: use the local gan_model (derived from kwargs
        # above, and already used by the elif branch) instead of
        # params.gan_model.
        if gan_model == 1:  # Model 1: Undirectional Transformation
            self.logger.info('Model 1')
        elif gan_model == 2:  # Model 2: Bidirectional Transformation
            d_src = Discriminator(
                # Consistency fix: these sizes were bare names; take them
                # from params like the d_tgt construction above.
                input_size=params.g_input_size,
                hidden_size=params.d_hidden_size)
            self.logger.info('Model 2')
        else:  # Model 3: Adversarial Autoencoder
            lambda_r = params.lambda_r
            self.logger.info('Model 3: lambda = {}'.format(lambda_r))

        # Define loss function and optimizers
        # Consistency fix: learning rates are read from params throughout
        # (the original mixed params.d_learning_rate with bare names).
        g_optimizer = optim.Adam(g.parameters(), lr=params.g_learning_rate)
        d_tgt_optimizer = optim.Adam(d_tgt.parameters(),
                                     lr=params.d_learning_rate)
        if d_src is not None:
            d_src_optimizer = optim.Adam(d_src.parameters(),
                                         lr=params.d_learning_rate)

        # Bug fix: is_available is a function -- the original tested the
        # function object itself, which is always truthy.
        if torch.cuda.is_available():
            # Move the network and the optimizer to the GPU
            g = g.cuda()
            d_tgt = d_tgt.cuda()
            if d_src is not None:
                d_src = d_src.cuda()

        lowest_loss = 10000000  # lowest loss value (standard for saving checkpoint)

        for itr, (batch_src, batch_tgt) in enumerate(zip(iter_src, iter_tgt),
                                                     start=1):
            if src.weight.is_cuda:
                batch_src = batch_src.cuda()
                batch_tgt = batch_tgt.cuda()
            embs_src = src(batch_src)
            embs_tgt = tgt(batch_tgt)

            # Generator: fool the target-side discriminator (and, for
            # Model 2, the source-side one on the reverse mapping).
            embs_tgt_mapped = g(
                embs_src)  # target embs mapped from source embs
            g_loss = -(d_tgt(embs_tgt_mapped, inject_noise=False) +
                       1e-16).log().mean()  # discriminate in the trg side
            if d_src is not None:  # Model 2
                embs_src_mapped = g(
                    embs_tgt, tgt2src=True)  # src embs mapped from trg embs
                g_loss += -(d_src(embs_src_mapped, inject_noise=False) +
                            1e-16).log().mean()  # target-to-source
            if lambda_r > 0:  # Model 3
                embs_src_r = g(embs_tgt_mapped,
                               tgt2src=True)  # reconstructed src embs
                g_loss_r = 1.0 - F.cosine_similarity(embs_src,
                                                     embs_src_r).mean()
                g_loss += lambda_r * g_loss_r

            ## Update
            g_optimizer.zero_grad()  # reset the gradients
            g_loss.backward()
            g_grad_norm = float(g.W._grad.norm(2, dim=1).mean())
            g_optimizer.step()

            # Discriminator (target-side): real embeddings vs detached
            # mapped embeddings.
            preds = d_tgt(embs_tgt)
            hits = int(sum(preds >= 0.5))
            d_tgt_loss = -(preds + 1e-16).log().sum()  # -log(D(Y))
            preds = d_tgt(embs_tgt_mapped.detach())
            hits += int(sum(preds < 0.5))
            d_tgt_loss += -(1.0 - preds +
                            1e-16).log().sum()  # -log(1 - D(G(X)))
            d_tgt_loss /= embs_tgt.size(0) + embs_tgt_mapped.size(0)
            d_tgt_acc = hits / float(
                embs_tgt.size(0) + embs_tgt_mapped.size(0))

            ## Update
            d_tgt_optimizer.zero_grad()
            d_tgt_loss.backward()
            d_tgt_optimizer.step()

            # Discriminator (source-side), Model 2 only.
            if d_src is not None:
                preds = d_src(embs_src)
                hits = int(sum(preds >= 0.5))
                d_src_loss = -(preds + 1e-16).log().sum()  # -log(D(X))
                preds = d_src(embs_src_mapped.detach())
                # Bug fix: accumulate (+=) as the target-side branch does;
                # the accuracy below divides by both batch sizes, so
                # overwriting dropped the real-example hits.
                hits += int(sum(preds < 0.5))
                d_src_loss += -(1.0 - preds +
                                1e-16).log().sum()  # -log(1 - D(G(Y)))
                d_src_loss /= embs_src.size(0) + embs_src_mapped.size(0)
                d_src_acc = hits / float(
                    embs_src.size(0) + embs_src_mapped.size(0))

                ## Update
                d_src_optimizer.zero_grad()
                d_src_loss.backward()
                d_src_optimizer.step()

            # Periodic logging: accuracies, losses, generator gradient
            # norm, and how far W is from orthogonal.
            if itr % 100 == 0:
                itr_template = '[{:' + str(int(np.log10(max_iters)) +
                                           1) + 'd}]'
                status = [itr_template.format(itr)]
                status.append('{:.2f}'.format(d_tgt_acc * 100))
                status.append('{:.3f}'.format(float(d_tgt_loss.data)))
                if d_src is not None:
                    status.append('{:.3f}'.format(d_src_acc))
                    status.append('{:.3f}'.format(float(d_src_loss.data)))
                status.append('{:.3f}'.format(float(g_loss.data)))
                if lambda_r > 0:
                    status.append('{:.3f}'.format(float(g_loss_r.data)))
                status.append('{:.3f}'.format(g_grad_norm))
                WW = g.W.t().matmul(g.W)
                if WW.is_cuda:
                    WW = WW.cpu()
                orth_score = np.linalg.norm(WW.data.numpy() -
                                            np.identity(WW.size(0)))
                status.append('{:.2f}'.format(orth_score))
                self.logger.info(' '.join(status))

                if itr % 1000 == 0:
                    filename = path.join(
                        dir_model, 'g{}_checkpoint.pth'.format(gan_model))
                    self.logger.info('Save a model to ' + filename)
                    g.save(filename)
                    all_precisions = evaluator.get_all_precisions(
                        g(src_emb.weight).data)
                    print(json.dumps(all_precisions))
            if itr > max_iters:
                break

            if dir_model is None:
                continue

            # Save checkpoint whenever the generator loss improves on the
            # best seen so far (only after a burn-in of 10000 iterations).
            if itr > 10000 and lowest_loss > float(g_loss.data):
                filename = path.join(dir_model,
                                     'g{}_best.pth'.format(gan_model))
                self.logger.info('Save a model to ' + filename)
                lowest_loss = float(g_loss.data)
                g.save(filename)
                all_precisions = evaluator.get_all_precisions(
                    g(src_emb.weight).data)
                print(json.dumps(all_precisions))
        # NOTE(review): dir_model may be None (guarded inside the loop)
        # yet is used unconditionally for the final save -- confirm.
        filename = path.join(dir_model, 'g{}_final.pth'.format(gan_model))
        self.logger.info('Save a model to ' + filename)
        g.save(filename)
        return g
Example #16
0
    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_threads)

    word_embedding = None

    # pretrained text embedding model
    # Only the text branch of the embedding model is kept, and it is
    # frozen: it conditions the GAN but is not trained here.
    print('Loading a pretrained text embedding model...')
    txt_encoder = VisualSemanticEmbedding(args.embed_ndim)
    txt_encoder.load_state_dict(torch.load(args.text_embedding_model))
    txt_encoder = txt_encoder.txt_encoder
    for param in txt_encoder.parameters():
        param.requires_grad = False

    G = Generator(use_vgg=args.use_vgg)
    D = Discriminator()

    if not args.no_cuda:
        txt_encoder.cuda()
        G.cuda()
        D.cuda()

    # Optimizers see only trainable parameters (the frozen encoder is
    # excluded by the requires_grad filter).
    g_optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad,
                                          G.parameters()),
                                   lr=args.learning_rate,
                                   betas=(args.momentum, 0.999))
    d_optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad,
                                          D.parameters()),
                                   lr=args.learning_rate,
                                   betas=(args.momentum, 0.999))
Example #17
0
def test_benchmark(upscale_factor, epoch_num):
    """Evaluate a trained SRGAN generator on the standard SR benchmark sets.

    Loads the generator checkpoint for ``upscale_factor`` from ``epochs/``,
    runs it over every image in ``data/test``, saves a per-image comparison
    grid (bicubic restore / HR / SR) and writes the mean PSNR/SSIM of each
    benchmark dataset to ``statistics/srf_<factor>_test_results.csv``.

    Args:
        upscale_factor: super-resolution scale factor (e.g. 2, 4, 8); selects
            the checkpoint file name and the output directories.
        epoch_num: currently unused — the checkpoint name hard-codes epoch
            100. TODO(review): confirm whether the name should be
            'netG_epoch_{}_{}.pth'.format(upscale_factor, epoch_num).

    Returns:
        pandas.DataFrame with one row per benchmark dataset and columns
        'psnr'/'ssim' (the string 'N/A' when a dataset produced no images).
    """
    model_name = 'netG_epoch_{}_100.pth'.format(upscale_factor)

    # Per-dataset metric accumulators; results are routed by the dataset
    # prefix of each test image's filename (e.g. 'Set5_...').
    results = {
        name: {'psnr': [], 'ssim': []}
        for name in ('Set5', 'Set14', 'BSD100', 'Urban100', 'SunHays80')
    }

    model = Generator(upscale_factor).eval()
    if torch.cuda.is_available():
        model = model.cuda()
    model.load_state_dict(
        torch.load('epochs/' + model_name, map_location=torch.device('cpu')))

    test_set = TestDatasetFromFolder('data/test',
                                     upscale_factor=upscale_factor)
    test_loader = DataLoader(dataset=test_set,
                             num_workers=4,
                             batch_size=1,
                             shuffle=False)
    test_bar = tqdm(test_loader, desc='[testing benchmark datasets]')

    out_path = 'benchmark_results/SRF_' + str(upscale_factor) + '/'
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(out_path, exist_ok=True)

    for image_name, lr_image, hr_restore_img, hr_image in test_bar:
        image_name = image_name[0]

        # BUGFIX: the forward pass and metrics must run inside no_grad —
        # the original only created the (no-op) Variables under no_grad and
        # then tracked gradients through the model during evaluation,
        # wasting memory for no benefit. Outputs are unchanged.
        with torch.no_grad():
            lr_image = Variable(lr_image)
            hr_image = Variable(hr_image)
            if torch.cuda.is_available():
                lr_image = lr_image.cuda()
                hr_image = hr_image.cuda()

            sr_image = model(lr_image)
            mse = ((hr_image - sr_image)**2).data.mean()
            psnr = 10 * log10(1 / mse)
            ssim = pytorch_ssim.ssim(sr_image, hr_image).data.item()

        test_images = torch.stack([
            display_transform()(hr_restore_img.squeeze(0)),
            display_transform()(hr_image.data.cpu().squeeze(0)),
            display_transform()(sr_image.data.cpu().squeeze(0))
        ])
        image = utils.make_grid(test_images, nrow=3, padding=5)
        utils.save_image(image,
                         out_path + image_name.split('.')[0] +
                         '_psnr_%.4f_ssim_%.4f.' % (psnr, ssim) +
                         image_name.split('.')[-1],
                         padding=5)

        # Accumulate per-dataset psnr/ssim keyed by filename prefix.
        results[image_name.split('_')[0]]['psnr'].append(psnr)
        results[image_name.split('_')[0]]['ssim'].append(ssim)

    out_path = 'statistics/'
    saved_results = {'psnr': [], 'ssim': []}
    for item in results.values():
        psnr = np.array(item['psnr'])
        ssim = np.array(item['ssim'])
        if (len(psnr) == 0) or (len(ssim) == 0):
            # Keep the original 'N/A' placeholder for empty datasets.
            psnr = 'N/A'
            ssim = 'N/A'
        else:
            psnr = psnr.mean()
            ssim = ssim.mean()
        saved_results['psnr'].append(psnr)
        saved_results['ssim'].append(ssim)

    data_frame = pd.DataFrame(saved_results, results.keys())
    data_frame.to_csv(out_path + 'srf_' + str(upscale_factor) +
                      '_test_results.csv',
                      index_label='DataSet')
    return data_frame
Example #18
0
 def register_generate(self, model, cuda):
     """Create a Generator from *model* and cache it on the instance.

     Args:
         model: handle passed straight through to ``Generator`` —
             presumably a checkpoint path or loaded weights; TODO confirm
             against Generator's constructor.
         cuda: forwarded to ``Generator``; presumably a device flag —
             verify with the Generator implementation.
     """
     self.generate = Generator(model, cuda)
    def init_model(self):
        """Build the complete TF1 training graph for this GAN.

        Creates the input placeholders, the Generator/Discriminator
        sub-graphs, the discriminator losses (adversarial + classification
        + WGAN-GP gradient penalty), the generator losses (adversarial +
        classification + cycle reconstruction), their Adam train ops, a
        label-classification accuracy metric, and a tf.Session that is
        either freshly initialized or restored from the newest checkpoint
        found in ``self.modelPath``.
        """
        #if not os.path.exists(self.modelPath) or os.listdir(self.modelPath) == []:
        # Create the whole training graph
        # Placeholders: real images, their labels (flat and spatially tiled
        # one-hot), target fake labels, and the gradient-penalty mix factor.
        self.realX = tf.placeholder(tf.float32, [None, self.imgDim, self.imgDim, 3], name="realX")
        self.realLabels = tf.placeholder(tf.float32, [None, self.numClass], name="realLabels")
        self.realLabelsOneHot = tf.placeholder(tf.float32, [None, self.imgDim, self.imgDim, self.numClass], name="realLabelsOneHot")
        self.fakeLabels = tf.placeholder(tf.float32, [None, self.numClass], name="fakeLabels")
        self.fakeLabelsOneHot = tf.placeholder(tf.float32, [None, self.imgDim, self.imgDim, self.numClass], name="fakeLabelsOneHot")
        self.alphagp = tf.placeholder(tf.float32, [], name="alphagp")


        # Initialize the generator and discriminator
        self.Gen = Generator()
        self.Dis = Discriminator()



        # -----------------------------------------------------------------------------------------
        # -----------------------------------Create D training pipeline----------------------------
        # -----------------------------------------------------------------------------------------

        # Create fake image
        # G translates realX toward fakeLabels; D scores both real and fake.
        self.fakeX = self.Gen.recForward(self.realX, self.fakeLabelsOneHot)
        YSrc_real, YCls_real = self.Dis.forward(self.realX)
        YSrc_fake, YCls_fake = self.Dis.forward(self.fakeX)

        YCls_real = tf.squeeze(YCls_real)  # remove void dimensions
        # WGAN-style critic losses plus a per-batch multi-label
        # classification loss on the real images.
        self.d_loss_real = - tf.reduce_mean(YSrc_real)
        self.d_loss_fake = tf.reduce_mean(YSrc_fake)
        self.d_loss_cls = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=self.realLabels,logits=YCls_real, name="d_loss_cls")) / self.batchSize




        #TOTAL LOSS
        self.d_loss = self.d_loss_real + self.d_loss_fake + self.lambdaCls * self.d_loss_cls #+ self.d_loss_gp
        # NOTE(review): `vars` shadows the builtin; left as-is (doc-only pass).
        vars = tf.trainable_variables()
        self.d_params = [v for v in vars if v.name.startswith('Discriminator/')]
        train_D = tf.train.AdamOptimizer(learning_rate=self.learningRateD, beta1=0.5, beta2=0.999)
        self.train_D_loss = train_D.minimize(self.d_loss, var_list=self.d_params)
        # gvs = self.train_D.compute_gradients(self.d_loss, var_list=self.d_params)
        # capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs]
        # self.train_D_loss = self.train_D.apply_gradients(capped_gvs)

        #-------------GRADIENT PENALTY---------------------------
        # WGAN-GP: penalize the critic's gradient norm at random
        # interpolations between real and fake images (target norm 1).
        # It is optimized by a separate train op rather than folded into
        # d_loss above.
        interpolates = self.alphagp * self.realX + (1 - self.alphagp) * self.fakeX
        out,_ = self.Dis.forward(interpolates)
        gradients = tf.gradients(out, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1,2,3]))
        _gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0))
        self.d_loss_gp   = self.lambdaGp * _gradient_penalty
        self.train_D_gp = train_D.minimize(self.d_loss_gp, var_list=self.d_params)
        # gvs = self.train_D.compute_gradients(self.d_loss_gp)
        # capped_gvs = [(ClipIfNotNone(grad), var) for grad, var in gvs]
        # self.train_D_gp = self.train_D.apply_gradients(capped_gvs)
        #-------------------------------------------------------------------------------

        #-----------------accuracy--------------------------------------------------------------
        # Per-class accuracy (%) of D's label head on real images,
        # thresholding the sigmoid outputs at 0.5.
        YCls_real_sigmoid = tf.sigmoid(YCls_real)
        predicted = tf.to_int32(YCls_real_sigmoid > 0.5)
        labels = tf.to_int32(self.realLabels)
        correct = tf.to_float(tf.equal(predicted, labels))
        hundred = tf.constant(100.0)
        self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), axis=0) * hundred
        #--------------------------------------------------------------------------------------


        #CLIP D WEIGHTS
        #self.clip_D = [p.assign(tf.clip_by_value(p, -self.clipD, self.clipD)) for p in self.d_params]


        # -----------------------------------------------------------------------------------------
        # ----------------------------Create G training pipeline-----------------------------------
        # -----------------------------------------------------------------------------------------
        #original to target and target to original domain
        #self.fakeX = self.Gen.recForward(self.realX, self.fakeLabelsOneHot)
        # Cycle: translate the fake image back to the original label domain.
        rec_x = self.Gen.recForward(self.fakeX,self.realLabelsOneHot)

        # compute losses
        #out_src, out_cls = self.Dis.forward(self.fakeX)
        # Adversarial + classification (on the fake labels) + L1
        # reconstruction losses; reuses YSrc_fake/YCls_fake from above.
        self.g_loss_adv = - tf.reduce_mean(YSrc_fake)
        self.g_loss_cls = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=self.fakeLabels,logits=tf.squeeze(YCls_fake))) / self.batchSize

        self.g_loss_rec = tf.reduce_mean(tf.abs(self.realX - rec_x))
        # total G loss and optimize
        self.g_loss = self.g_loss_adv + self.lambdaCls * self.g_loss_cls + self.lambdaRec * self.g_loss_rec
        train_G = tf.train.AdamOptimizer(learning_rate=self.learningRateG, beta1=0.5, beta2=0.999)
        self.g_params = [v for v in vars if v.name.startswith('Generator/')]

        self.train_G_loss = train_G.minimize(self.g_loss, var_list=self.g_params)
        # gvs = self.train_G.compute_gradients(self.g_loss)
        # capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs]
        # self.train_G_loss = self.train_G.apply_gradients(capped_gvs)

        #TF session

        self.saver = tf.train.Saver()
        self.init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(self.init)


        #restore model if it exists
        # NOTE(review): the init/Session created just above are re-created in
        # both branches below (the first Session object is abandoned); left
        # as-is in this doc-only pass.
        if os.listdir(self.modelPath) == []:
            self.init = tf.global_variables_initializer()
            self.sess = tf.Session()
            self.sess.run(self.init)
            self.epoch_index = 1
            self.picture = 0


        else:
            self.sess = tf.Session()
            #self.saver = tf.train.import_meta_graph(os.path.join(self.modelPath, "model49999_1.meta"))
            checkpoint = tf.train.latest_checkpoint(self.modelPath)
            self.saver.restore(self.sess, checkpoint)
            #-------------------------------------------------------------------------------------------

            # Recover picture index and epoch from a checkpoint name of the
            # form ".../model/model<picture>_<epoch>".
            model_info = checkpoint.split("model/model",1)[1].split("_",1)
            self.picture = int(model_info[0])
            self.epoch_index = int(model_info[1])
Example #20
0
 def register_generate(self, model, model_cnnseq2sample, cuda):
     """Build the CNNSeq2Sample generation pipeline and cache it.

     Wraps *model* in a ``Generator`` and composes it with the
     cnnseq2sample model into a ``GeneratorCNNSeq2Sample`` stored on the
     instance.

     Args:
         model: handle forwarded to ``Generator`` — presumably a checkpoint
             path or loaded weights; TODO confirm against its constructor.
         model_cnnseq2sample: forwarded to ``GeneratorCNNSeq2Sample``.
         cuda: forwarded to both constructors; presumably a device flag —
             verify against their implementations.
     """
     generator = Generator(model, cuda)
     self.generate_cnnseq2sample = GeneratorCNNSeq2Sample(
         generator, model_cnnseq2sample, cuda)
    models_dir = os.path.join(
        '/fast-stripe/workspaces/deval/synthetic-data/cgan_gen_learning_new/gen_eval',
        opt.expt, 'models')
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(videos_dir, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)

    #Loss functions
    if opt.arch == 'mlp':
        adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    if opt.arch == 'mlp':
        generator = Generator(latent_dim=opt.latent_dim,
                              img_shape=img_shape,
                              n_classes=opt.n_classes)
        discriminator = Discriminator(img_shape=img_shape,
                                      n_classes=opt.n_classes)
    elif opt.arch == 'cnn':
        generator = Generator_cnn(latent_dim=opt.latent_dim,
                                  DIM=256,
                                  classes=opt.n_classes)

    # HCN pre-trained model
    if opt.temporal_length == 25:
        fc7_dim = 2048
    elif opt.temporal_length == 60:
        fc7_dim = 4096
    if opt.num_per == 1:
        hcn_model = HCN4(in_channel=3,
import numpy as np
import torch
from torchvision.utils import save_image
from tqdm import tqdm

from model import Generator
from sampling import sample_truncated_normal, linear_interpolation, n_classes

device = 'cuda'
truncated = 0.8

model = Generator(n_feat=36, codes_dim=24, n_classes=n_classes).to(device)
model.load_state_dict(torch.load('generator.pth'))

z_points = [sample_truncated_normal(120, truncated, device) for _ in range(16)]
zs = []
for i in range(5 - 1):
    zs.append(linear_interpolation(z_points[i], z_points[i + 1], 24 * 4))
zs = torch.cat(zs, 0).float()

for j in tqdm(range(zs.size()[0])):
    z = zs[j].unsqueeze(0)

    aux_labels = np.full(1, 0)
    aux_labels_ohe = np.eye(n_classes)[aux_labels]
    aux_labels_ohe = torch.from_numpy(aux_labels_ohe[:, :]).float().cuda()

    out = model(z, aux_labels_ohe, get_feature_maps=True)

    frame = []
Example #23
0
    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    encoder = Encoder(args.channel).to(device)
    generator = Generator(args.channel).to(device)

    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier).to(device)
    cooccur = CooccurDiscriminator(args.channel).to(device)

    e_ema = Encoder(args.channel).to(device)
    g_ema = Generator(args.channel).to(device)
    e_ema.eval()
    g_ema.eval()
    accumulate(e_ema, encoder, 0)
    accumulate(g_ema, generator, 0)

    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
Example #24
0
    def train(self, args):
        """Train a DCGAN on MNIST, logging losses to visdom.

        Alternates discriminator and generator updates (every
        ``args.train_d_every`` / ``args.train_g_every`` batches), saves the
        model weights every ``args.save_per_epoch`` epochs and a sample image
        grid every ``args.plot_per_epoch`` epochs under ``checkpoint/``.

        Args:
            args: namespace with dataset, env, dim, channel, alpha, g_lr,
                d_lr, beta1, beta2, batch_size, epochs, train_d_every,
                train_g_every, save_per_epoch and plot_per_epoch.
        """
        model_path = 'checkpoint/'
        vis = visdom.Visdom(env=args.env)
        trainset = torchvision.datasets.MNIST(
            root=args.dataset, train=True,
            download=True, transform=self.train_transform)
        # BUGFIX: was batch_size=len(trainset), which made D's output batch
        # dimension (60000) mismatch the args.batch_size-sized label tensors
        # below and fail inside BCELoss. drop_last guards the final partial
        # batch against the same size mismatch.
        dataloader = torch.utils.data.DataLoader(trainset,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 drop_last=True)
        G = Generator(args.dim, args.channel)
        D = Discriminator(args.channel, args.alpha)

        G_optimizer = optim.Adam(G.parameters(), lr=args.g_lr,
                                 betas=(args.beta1, args.beta2))
        D_optimizer = optim.Adam(D.parameters(), lr=args.d_lr,
                                 betas=(args.beta1, args.beta2))
        criterion = torch.nn.BCELoss()

        true_labels = Variable(torch.ones(args.batch_size))
        fake_labels = Variable(torch.zeros(args.batch_size))

        # Fixed noise vector reused for the periodic sample grids so that
        # progress is comparable across epochs.
        gen_vector = Variable(torch.randn(args.batch_size, args.dim, 1, 1))

        train_d_times = 0
        train_g_times = 0

        for epoch in range(args.epochs):
            print('epoch {0}'.format(epoch + 1))

            for batch_ix, (img, _) in enumerate(dataloader):
                print("a new batch!")
                images = Variable(img)

                if batch_ix % args.train_d_every == 0:
                    # --- Discriminator step: real up, fake (detached) down ---
                    D_optimizer.zero_grad()

                    output = D(images)
                    real_loss = criterion(output, true_labels)
                    real_loss.backward()

                    noise = Variable(torch.randn(args.batch_size, args.dim, 1, 1))
                    fake_images = G(noise).detach()  # no grads into G here
                    output = D(fake_images)
                    fake_loss = criterion(output, fake_labels)
                    fake_loss.backward()

                    D_optimizer.step()

                    loss = real_loss + fake_loss

                    #visualize error
                    vis.line(win='D_error',
                             X=torch.Tensor([train_d_times]),
                             Y=loss.data,
                             update=None if train_d_times == 0 else 'append')
                    train_d_times += 1

                if batch_ix % args.train_g_every == 0:
                    # --- Generator step: make D label fakes as real ---
                    G_optimizer.zero_grad()

                    noise = Variable(torch.randn(args.batch_size, args.dim, 1, 1))
                    fake_images = G(noise)
                    output = D(fake_images)
                    fake_loss = criterion(output, true_labels)
                    fake_loss.backward()

                    G_optimizer.step()
                    # BUGFIX: was X=torch.Tensor(train_g_times), which treats
                    # the int as a *size* (cf. the correct [train_d_times]
                    # call above).
                    vis.line(win='G_error',
                             X=torch.Tensor([train_g_times]),
                             Y=fake_loss.data,
                             update=None if train_g_times == 0 else 'append')
                    train_g_times += 1

            # BUGFIX: checkpointing/plotting were inside the batch loop and
            # re-ran (overwriting the same files) once per batch; they are
            # per-epoch actions.
            if (epoch + 1) % args.save_per_epoch == 0:
                torch.save(D.state_dict(), model_path + 'D_epoch_{0}.pth'.format(epoch))
                torch.save(G.state_dict(), model_path + 'G_epoch_{0}.pth'.format(epoch))

            if (epoch + 1) % args.plot_per_epoch == 0:
                images = G(gen_vector)
                n_row = int(math.sqrt(args.batch_size))
                # BUGFIX: was '%d_.png'.format(epoch) — str.format on a
                # %-style template substitutes nothing, so every epoch wrote
                # the literal file '%d_.png'.
                torchvision.utils.save_image(images.data[:n_row * n_row],
                                             model_path + '{0}_.png'.format(epoch),
                                             normalize=True, range=(-1, 1),
                                             nrow=n_row)
        print("train finished")
Example #25
0
def main():
    """Evaluate a pretrained DeFiAN super-resolution model on benchmark sets.

    Parses CLI options, builds the Generator (DeFiAN_S or DeFiAN_L depending
    on --n_modules), reports its parameter count and FLOPs, loads the
    matching checkpoint from ``checkpoints/`` and runs validation over each
    test dataset, printing PSNR/SSIM.

    Raises:
        Exception: if --cuda is requested but no GPU is available.
        ValueError: if --n_modules is neither 5 nor 10.
    """
    global model
    parser = argparse.ArgumentParser(description='DeFiAN')
    # NOTE(review): default=True combined with store_true makes --cuda always
    # truthy; kept as-is to preserve the CLI contract.
    parser.add_argument("--cuda",
                        default=True,
                        action="store_true",
                        help="Use cuda?")
    parser.add_argument('--n_GPUs',
                        type=int,
                        default=1,
                        help='parallel training with multiple GPUs')
    parser.add_argument('--GPU_ID',
                        type=int,
                        default=0,
                        help='parallel training with multiple GPUs')
    parser.add_argument('--threads',
                        type=int,
                        default=4,
                        help='number of threads for data loading')
    parser.add_argument('--seed', type=int, default=1, help='random seed')

    parser.add_argument('--scale', type=int, default=2, help='scale factor')
    parser.add_argument('--attention', default=True, help='True for DeFiAN')
    parser.add_argument('--n_modules',
                        type=int,
                        default=10,
                        help='num of DeFiAM: 10 for DeFiAN_L; 5 for DeFiAN_S')
    parser.add_argument('--n_blocks',
                        type=int,
                        default=20,
                        help='num of RCABs: 20 for DeFiAN_L; 10 for DeFiAN_S')
    parser.add_argument(
        '--n_channels',
        type=int,
        default=64,
        help='num of channels: 64 for DeFiAN_L; 32 for DeFiAN_S')
    # NOTE(review): an nn.Module instance as an argparse default is unusual
    # but works because this option is never given on the command line.
    parser.add_argument('--activation',
                        default=nn.ReLU(True),
                        help='activation function')
    args = parser.parse_args()

    if args.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        if args.n_GPUs == 1:
            torch.cuda.set_device(args.GPU_ID)
    cudnn.benchmark = True

    # Select checkpoint and result directory by model size.
    model_path = 'checkpoints/'
    if args.n_modules == 5:
        model_path = model_path + 'DeFiAN_S_x' + str(args.scale)
        result_pathes = 'DeFiAN_S/'
    elif args.n_modules == 10:
        model_path = model_path + 'DeFiAN_L_x' + str(args.scale)
        result_pathes = 'DeFiAN_L/'
    else:
        # BUGFIX: was `raise InterruptedError` — that exception class is
        # reserved for OS-level interrupts; a bad CLI value is a ValueError.
        raise ValueError('--n_modules must be 5 (DeFiAN_S) or 10 (DeFiAN_L)')

    print("===> Building model")
    model = Generator(args.n_channels,
                      args.n_blocks,
                      args.n_modules,
                      args.activation,
                      attention=args.attention,
                      scale=[args.scale])

    print("===> Calculating NumParams & FLOPs")
    input_size = (3, 480 // args.scale, 360 // args.scale)
    flops, params = get_model_complexity_info(model,
                                              input_size,
                                              as_strings=False,
                                              print_per_layer_stat=False)
    print('\tParam = {:.3f}K\n\tFLOPs = {:.3f}G on {}'.format(
        params * (1e-3), flops * (1e-9), input_size))

    # strict=False tolerates keys the current architecture doesn't define.
    cpk = torch.load(model_path + '.pth',
                     map_location={'cuda:1': 'cuda:0'})["state_dict"]
    model.load_state_dict(cpk, strict=False)
    model = model.cuda()

    data_valid = [
        'Set5_LR_bicubic', 'Urban100_LR_bicubic', 'Manga109_LR_bicubic'
    ]
    print('====>Testing...')
    for dataset in data_valid:
        result_path = result_pathes + dataset + '_x' + str(args.scale)
        valid_path = '/mnt/Datasets/Test/' + dataset
        # exist_ok avoids the check-then-create race of the original.
        os.makedirs(result_path, exist_ok=True)
        valid_psnr, valid_ssim = validation(valid_path, result_path, model,
                                            args.scale)
        print('\t {} --- PSNR = {:.4f} SSIM = {:.4f}'.format(
            dataset, valid_psnr, valid_ssim))
Example #26
0
def train(device, args):
    """Train a vanilla GAN on MNIST with BCE adversarial losses.

    Builds the generator/discriminator, optionally resumes from
    ``args.resume``, then alternates G and D updates over the MNIST training
    set, printing throughput every ``args.print_freq`` batches and saving a
    checkpoint (and optionally a sample grid) after every epoch.

    Args:
        device: torch.device to run on; CUDA tensors are used when
            ``device.type == 'cuda'``.
        args: namespace with channels, img_size, latent_dim, data,
            batch_size, workers, lr, b1, b2, resume, start_epoch, epochs,
            print_freq and save_images.
    """
    Tensor = torch.cuda.FloatTensor if device.type == 'cuda' else torch.FloatTensor
    img_shape = (args.channels, args.img_size, args.img_size)
    adversarial_loss = torch.nn.BCELoss().to(device)
    generator = Generator(args.latent_dim, img_shape).to(device)
    discriminator = Discriminator(img_shape).to(device)
    train_dataset = datasets.MNIST(args.data,
                                   train=True,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, ), (0.5, ))
                                   ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))

    # Optionally resume models, optimizers and epoch counter.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            generator.load_state_dict(checkpoint['generator'])
            discriminator.load_state_dict(checkpoint['discriminator'])
            optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    for epoch in range(args.start_epoch, args.epochs):
        start_time = time.time()
        for batch_idx, (imgs, _) in enumerate(train_loader):

            # Adversarial ground truths, sized to the (possibly partial)
            # current batch.
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0),
                             requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0),
                            requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs).to(device)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (imgs.shape[0], args.latent_dim))))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from
            # generated samples; detach so no gradients reach the generator.
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),
                                         fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()
            if batch_idx % args.print_freq == 0:
                end_time = time.time()
                # First report covers a single batch; later reports cover
                # print_freq batches since the previous report.
                if batch_idx == 0:
                    throughput = args.batch_size / (end_time - start_time)
                else:
                    throughput = args.batch_size * args.print_freq / (
                        end_time - start_time)
                print(
                    'Train Epoch: {}/{} [{:>5d}/{} ({:>3.0f}%)] D Loss: {:.6f} G Loss: {:.6f} {:>8.1f} imgs/sec  '
                    .format(epoch, args.epochs, batch_idx * len(imgs),
                            len(train_loader.dataset),
                            100. * batch_idx / len(train_loader),
                            d_loss.item(), g_loss.item(), throughput))
                start_time = time.time()
        # Save the last generated batch of this epoch, if requested.
        if args.save_images:
            os.makedirs(args.save_images, exist_ok=True)
            p = os.path.join(args.save_images, "gan_{}.png".format(epoch))
            print("=> saving image '{}'".format(p))
            save_image(gen_imgs.data[:25], p, nrow=5, normalize=True)
        save_checkpoint({
            'epoch': epoch + 1,
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D': optimizer_D.state_dict(),
        })
Example #27
0
                        default="models/classifiers")

    args = parser.parse_args()

    args.latent = 512
    args.n_mlp = 8

    yaml_config = {}
    with open(args.config, 'r') as stream:
        try:
            yaml_config = yaml.load(stream)
        except yaml.YAMLError as exc:
            print(exc)

    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    new_state_dict = g_ema.state_dict()
    checkpoint = torch.load(args.ckpt)

    ext_state_dict = torch.load(args.ckpt)['g_ema']
    g_ema.load_state_dict(checkpoint['g_ema'])
    new_state_dict.update(ext_state_dict)
    g_ema.load_state_dict(new_state_dict)
    g_ema.eval()
    g_ema.to(device)

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
Example #28
0
    def train(self, src_emb, tgt_emb):
        params = self.params
        # Load data
        if not os.path.exists(params.data_dir):
            raise "Data path doesn't exists: %s" % params.data_dir

        en = src_emb
        it = tgt_emb

        params = _get_eval_params(params)
        #eval = Evaluator(params, src_emb.weight.data, tgt_emb.weight.data, use_cuda=True)

        for _ in range(params.num_random_seeds):
            # Create models
            g = Generator(input_size=params.g_input_size,
                          output_size=params.g_output_size)
            d = Discriminator(input_size=params.d_input_size,
                              hidden_size=params.d_hidden_size,
                              output_size=params.d_output_size)

            g.apply(self.weights_init3)
            seed = random.randint(0, 1000)
            # init_xavier(g)
            # init_xavier(d)
            self.initialize_exp(seed)

            # Define loss function and optimizers
            loss_fn = torch.nn.BCELoss()
            loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
            d_optimizer = optim.SGD(d.parameters(), lr=params.d_learning_rate)
            g_optimizer = optim.SGD(g.parameters(), lr=params.g_learning_rate)

            if torch.cuda.is_available():
                # Move the network and the optimizer to the GPU
                g = g.cuda()
                d = d.cuda()
                loss_fn = loss_fn.cuda()
                loss_fn2 = loss_fn2.cuda()
            # true_dict = get_true_dict(params.data_dir)
            d_acc_epochs = []
            g_loss_epochs = []
            #d_losses = []
            #g_losses = []
            #w_losses = []
            acc_epochs = []
            csls_epochs = []
            best_valid_metric = -1
            lowest_loss = 1e5
            # logs for plotting later
            log_file = open(
                "log_src_tgt.txt",
                "w")  # Being overwritten in every loop, not really required
            log_file.write("epoch, dis_loss, dis_acc, g_loss\n")

            try:
                for epoch in range(params.num_epochs):
                    hit = 0
                    total = 0
                    start_time = timer()
                    d_losses = []
                    g_losses = []
                    w_losses = []

                    for mini_batch in range(
                            0,
                            params.iters_in_epoch // params.mini_batch_size):
                        for d_index in range(params.d_steps):
                            d_optimizer.zero_grad()  # Reset the gradients
                            d.train()
                            src_batch, tgt_batch = self.get_batch_data_fast_new(
                                en, it)
                            fake, _ = g(src_batch)
                            fake = fake.detach()
                            real = tgt_batch

                            input = torch.cat([real, fake], 0)
                            output = to_variable(
                                torch.FloatTensor(
                                    2 * params.mini_batch_size).zero_())

                            output[:params.
                                   mini_batch_size] = 1 - params.smoothing
                            output[params.mini_batch_size:] = params.smoothing

                            pred = d(input)
                            d_loss = loss_fn(pred, output)
                            d_loss.backward(
                            )  # compute/store gradients, but don't change params
                            d_losses.append(d_loss.data.cpu().numpy())
                            discriminator_decision = pred.data.cpu().numpy()
                            hit += np.sum(
                                discriminator_decision[:params.mini_batch_size]
                                >= 0.5)
                            hit += np.sum(
                                discriminator_decision[params.mini_batch_size:]
                                < 0.5)
                            d_optimizer.step(
                            )  # Only optimizes D's parameters; changes based on stored gradients from backward()

                            # Clip weights
                            _clip(d, params.clip_value)

                            sys.stdout.write(
                                "[%d/%d] :: Discriminator Loss: %f \r" %
                                (mini_batch, params.iters_in_epoch //
                                 params.mini_batch_size,
                                 np.asscalar(np.mean(d_losses))))
                            sys.stdout.flush()

                        total += 2 * params.mini_batch_size * params.d_steps

                        for g_index in range(params.g_steps):
                            # 2. Train G on D's response (but DO NOT train D on these labels)
                            g_optimizer.zero_grad()
                            d.eval()
                            # input, output = self.get_batch_data_fast(en, it, g, detach=False)
                            src_batch, tgt_batch = self.get_batch_data_fast_new(
                                en, it)
                            fake, recon = g(src_batch)
                            real = tgt_batch
                            # input = torch.cat([fake, real], 0)
                            # input = torch.cat([real, fake], 0)
                            output = to_variable(
                                torch.FloatTensor(
                                    2 * params.mini_batch_size).zero_())
                            output[:params.
                                   mini_batch_size] = 1 - params.smoothing
                            output[params.mini_batch_size:] = params.smoothing

                            # pred = d(input)
                            pred = d(fake)
                            output2 = to_variable(
                                torch.FloatTensor(
                                    params.mini_batch_size).zero_())
                            output2 = output2 + 1 - params.smoothing
                            # g_loss = loss_fn(pred, 1 - output)
                            # g_loss = loss_fn(pred, 1- output) +  1.0 - torch.mean(loss_fn2(src_batch,recon))
                            g_loss = loss_fn(pred, output2)
                            #g_loss = loss_fn(pred, output2) + params.rwcon_weight*(1.0 -
                            #torch.mean(loss_fn2(src_batch,recon)))

                            g_loss.backward()
                            g_losses.append(g_loss.data.cpu().numpy())
                            g_optimizer.step()  # Only optimizes G's parameters

                            # Orthogonalize
                            self.orthogonalize(g.map1.weight.data)

                            sys.stdout.write(
                                "[%d/%d] ::                                     Generator Loss: %f \r"
                                % (mini_batch, params.iters_in_epoch //
                                   params.mini_batch_size,
                                   np.asscalar(np.mean(g_losses))))
                            sys.stdout.flush()

                        if epoch > params.threshold:
                            if lowest_loss > float(g_loss.data):
                                lowest_loss = float(g_loss.data)
                                W = g.map1.weight.data.cpu().numpy()
                                w_losses.append(
                                    np.linalg.norm(
                                        np.dot(W.T, W) -
                                        np.identity(params.g_input_size)))
                                for method in ['nn']:
                                    results = get_word_translation_accuracy(
                                        'en',
                                        self.src_ids,
                                        g(src_emb.weight)[0].data,
                                        'it',
                                        self.tgt_ids,
                                        tgt_emb.weight.data,
                                        method=method,
                                        path=params.validation_file)
                                    acc = results[0][1]
                                torch.save(
                                    g.state_dict(),
                                    'tune0/thu/g_seed_{}_epoch_{}_batch_{}_p@1_{:.3f}.t7'
                                    .format(seed, epoch, mini_batch, acc))
                    '''for each epoch'''
                    d_acc_epochs.append(hit / total)
                    g_loss_epochs.append(np.asscalar(np.mean(g_losses)))
                    print(
                        "Epoch {} : Discriminator Loss: {:.5f}, Discriminator Accuracy: {:.5f}, Generator Loss: {:.5f}, Time elapsed {:.2f} mins"
                        .format(epoch, np.asscalar(np.mean(d_losses)),
                                hit / total, np.asscalar(np.mean(g_losses)),
                                (timer() - start_time) / 60))

                    if (epoch + 1) % params.print_every == 0:
                        # No need for discriminator weights
                        # torch.save(d.state_dict(), 'discriminator_weights_en_es_{}.t7'.format(epoch))
                        mstart_time = timer()
                        #for method in ['csls_knn_10']:
                        for method in ['nn']:
                            results = get_word_translation_accuracy(
                                'en',
                                self.src_ids,
                                g(src_emb.weight)[0].data,
                                'it',
                                self.tgt_ids,
                                tgt_emb.weight.data,
                                method=method,
                                path=params.validation_file)
                            acc = results[0][1]
                            print('{} takes {:.2f}s'.format(
                                method,
                                timer() - mstart_time))

                        # all_precisions = eval.get_all_precisions(g(src_emb.weight)[0].data)
                        csls = 0
                        # csls = eval.calc_unsupervised_criterion(g(src_emb.weight)[0].data)
                        #csls = eval.dist_mean_cosine(g(src_emb.weight)[0].data, tgt_emb.weight.data)
                        # print(json.dumps(all_precisions))
                        # p_1 = all_precisions['validation']['adv']['without-ref']['nn'][1]

                        log_file.write("{},{:.5f},{:.5f},{:.5f}\n".format(
                            epoch + 1, np.asscalar(np.mean(d_losses)),
                            hit / total, np.asscalar(np.mean(g_losses))))
                        # log_file.write(str(all_precisions) + "\n")
                        print('Method:csls_knn_10 score:{:.4f}'.format(acc))
                        # Saving generator weights
                        torch.save(
                            g.state_dict(),
                            'tune0/generator_weights_src_tgt_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'
                            .format(seed, params.most_frequent_sampling_size,
                                    params.g_learning_rate, acc))
                        if csls > best_valid_metric:
                            best_valid_metric = csls
                            torch.save(
                                g.state_dict(),
                                'tune0/best/generator_weights_src_tgt_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'
                                .format(seed,
                                        params.most_frequent_sampling_size,
                                        params.g_learning_rate, acc))

                        acc_epochs.append(acc)
                        csls_epochs.append(csls)

                # Save the plot for discriminator accuracy and generator loss
                fig = plt.figure()
                plt.plot(range(0, params.num_epochs),
                         d_acc_epochs,
                         color='b',
                         label='discriminator')
                plt.plot(range(0, params.num_epochs),
                         g_loss_epochs,
                         color='r',
                         label='generator')
                plt.ylabel('accuracy/loss')
                plt.xlabel('epochs')
                plt.legend()
                fig.savefig('d_g.png')

            except KeyboardInterrupt:
                print("Interrupted.. saving model !!!")
                torch.save(g.state_dict(), 'g_model_interrupt.t7')
                torch.save(d.state_dict(), 'd_model_interrupt.t7')
                log_file.close()
                exit()

            log_file.close()

        return g
# Example #29
# 0
# Weight Initialization
######################################################################
def weights_init_normal(m):
    """DCGAN-style weight initializer, intended for use with Module.apply.

    Conv-family layers (Conv2d, ConvTranspose2d, ...) get weights drawn
    from N(0, 0.02). BatchNorm2d layers get weights from N(1, 0.02) and a
    zeroed bias. Any other module type is left untouched.
    """
    name = type(m).__name__
    if "Conv" in name:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif "BatchNorm2d" in name:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


######################################################################
# Models
######################################################################
# Create the generator and discriminator (classes defined elsewhere in
# this file/project) and initialize every submodule's weights
# recursively via Module.apply.
netG = Generator()
netD = Discriminator()
# Initialize weights
netG.apply(weights_init_normal)
netD.apply(weights_init_normal)
# Print the module summaries as a quick architecture sanity check.
print(netG)
print(netD)

######################################################################
# Loss Functions and Optimizers
######################################################################
# Loss functions: BCE for the real/fake adversarial decision and
# cross-entropy for a class prediction (presumably an AC-GAN-style
# auxiliary classifier — TODO confirm against the Discriminator's
# outputs, which are not visible here).
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()
# Create batch of latent vectors that we will use to visualize
# Example #30
# 0
import torch
from torch import nn
import torchvision
from torchvision import datasets
from model import Generator, Discriminator
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Generator/Discriminator are project classes imported from `model` above.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# Convert single-channel images to tensors in [0, 1], then normalize
# with mean 0.5 / std 0.5 so pixel values land in [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5), (0.5))])
# NOTE(review): the variable is named train_mnist but the dataset being
# loaded is FashionMNIST — confirm which dataset is intended.
train_mnist = datasets.FashionMNIST(root='.',
                                    train=True,
                                    download=True,
                                    transform=transform)

train_loader = DataLoader(train_mnist, batch_size=32, shuffle=True)

# Grab one batch up front (presumably for visualization/debugging later
# in the script — verify against downstream use, which is out of view).
train_x, train_y = next(iter(train_loader))

# Training hyperparameters.
epochs = 50
batch_size = 32
lr = 3e-4
# Separate Adam optimizers so G and D can be stepped independently.
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

# Binary cross-entropy for the real/fake discriminator objective.
loss_function = nn.BCELoss()