def get_network(LEARNING_RATE: float, device: str):
    """Create a generator/discriminator pair plus their loss and optimizers.

    Both networks are moved to ``device`` and each gets its own Adam
    optimizer with learning rate ``LEARNING_RATE``.

    Returns:
        Tuple ``(G, D, criterion, d_optimizer, g_optimizer)`` where
        ``criterion`` is a fresh ``nn.BCELoss``.
    """
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    def adam_for(network):
        # One Adam optimizer per network, same learning rate for both.
        return optim.Adam(network.parameters(), lr=LEARNING_RATE)

    # Evaluated left to right: loss, then D's optimizer, then G's optimizer.
    return (generator, discriminator, nn.BCELoss(),
            adam_for(discriminator), adam_for(generator))
Exemplo n.º 2
0
def train(opt, dataloader_m=None):
    """Adversarially train a TinyNet generator against a Discriminator.

    Args:
        opt: options namespace; reads ``batch_size``, ``lr``, ``n_epochs`` and
            ``sample_interval``.
        dataloader_m: optional data loader to train on. Each batch is unpacked
            as ``(imgs, imgs_ns, labels)``: ``imgs`` is fed to the generator,
            ``imgs_ns`` is added to the generator output, and ``labels`` are
            used as the discriminator's real samples. When omitted, an MNIST
            loader is built instead. NOTE(review): torchvision MNIST batches
            are ``(image, label)`` pairs, so that fallback cannot satisfy the
            3-way unpacking below -- confirm which loader is intended.

    Side effects: writes sample grids to ``images/`` and generator weights to
    the path produced by ``checkpoint_path.format(...)`` (a name resolved
    outside this function).
    """
    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = TinyNet()
    discriminator = Discriminator()

    cuda = torch.cuda.is_available()
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()
    # Tensor constructor on the same device as the models; every batch is
    # converted through it so CPU-only machines work too.
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # Configure data loader: prefer the caller-supplied one.
    if dataloader_m is not None:
        dataloader = dataloader_m
    else:
        os.makedirs('./data/mnist', exist_ok=True)
        dataloader = torch.utils.data.DataLoader(
            datasets.MNIST(
                './data/mnist',
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    # MNIST is single-channel; a 3-channel mean/std cannot be
                    # applied in place to a (1, H, W) tensor.
                    transforms.Normalize((0.5,), (0.5,)),
                ])),
            batch_size=opt.batch_size,
            shuffle=True)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(0.9, 0.99))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(0.9, 0.99))

    # Make sure the sample directory exists before save_image writes into it.
    os.makedirs('images', exist_ok=True)

    # ----------
    #  Training
    # ----------
    for epoch in range(opt.n_epochs):
        for i, (imgs, imgs_ns, labels) in enumerate(dataloader):

            # Configure input: the targets act as the "real" samples for D.
            real_imgs = Variable(labels.type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0),
                             requires_grad=False)
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0),
                            requires_grad=False)

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Generator output plus the noise input, both on the model device.
            gen_imgs = generator(imgs.type(Tensor)) + imgs_ns.type(Tensor)
            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),
                                         fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(),
                   g_loss.item()))

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25],
                           'images/%d.png' % batches_done,
                           nrow=5,
                           normalize=True)
                # Pass a Python float so numeric format specs in
                # checkpoint_path work.
                weights_path = checkpoint_path.format(epoch=batches_done,
                                                      loss=g_loss.item())
                print(
                    'saving generator weights file to {}'.format(weights_path))
                torch.save(generator.state_dict(), weights_path)
Exemplo n.º 3
0
# The encoder, decoder and discriminator all share one RMSprop configuration.
_rmsprop_settings = {
    'lr': lr,
    'alpha': 0.9,
    'eps': 1e-8,
    'weight_decay': 0,
    'momentum': 0,
    'centered': False,
}
optimizer_encorder = optim.RMSprop(params=NetE.parameters(), **_rmsprop_settings)
optimizer_decoder = optim.RMSprop(params=NetG.parameters(), **_rmsprop_settings)
optimizer_discriminator = optim.RMSprop(params=NetD.parameters(), **_rmsprop_settings)

# Keep the first batch as a fixed reference set and write it out as an image.
data, _ = next(iter(dataloader))
fixed_batch = Variable(data).to(device)
vutils.save_image(fixed_batch, '%s/real_samples.png' % opt.outf, normalize=True)

# NOTE(review): margin and equilibrium are not used in this snippet; their
# role is defined by the training loop elsewhere -- document them there.
margin = 0.6
equilibrium = 0.68
Exemplo n.º 4
0
# Contour discriminator with (3 + 1) input channels: a generated image
# (presumably 3-channel) concatenated with the 1-channel contour map, as in
# the dummy-input check below.
Dc = Discriminator(3 + 1).to(device)

if args.dummy_input:  # debug purpose: shape-check G, Ds and Dc on random tensors, then exit
    BATCH_SIZE = 16
    s1 = torch.randn(BATCH_SIZE, 3, args.resolution, args.resolution).to(device)
    contour = torch.randn(BATCH_SIZE, 1, args.resolution, args.resolution).to(device)
    fake = G(s1, contour)
    print('fake.shape', fake.shape)
    ds_out = Ds(torch.cat([fake, s1], dim=1))
    dc_out = Dc(torch.cat([fake, contour], dim=1))
    print('Ds_out.shape', ds_out.shape)
    print('Dc_out.shape', dc_out.shape)
    sys.exit(0)

optimG = optim.Adam(G.parameters(), lr=args.lrG, betas=(0, 0.999))
optimDc = optim.Adam(Dc.parameters(), lr=args.lrD, betas=(0, 0.999))
optimDs = optim.Adam(Ds.parameters(), lr=args.lrD, betas=(0, 0.999))

# prepare dataset
dataset = IconDataset(root=args.dataset, resolution=args.resolution, pad_ratio=args.pad_ratio)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

# Sample one batch of fixed inputs. With drop_last=True a dataset smaller than
# batch_size yields no batches at all, so fail loudly here instead of leaving
# s1/s2/s3/contour unbound for a confusing NameError later.
try:
    s1, s2, s3, contour = next(iter(dataloader))
except StopIteration:
    raise RuntimeError(
        'dataloader produced no batches: dataset has %d samples but '
        'batch_size=%d with drop_last=True' % (len(dataset), args.batch_size)
    ) from None
Exemplo n.º 5
0
                        batch_size=args.batchsize,
                        num_workers=4 * args.ngpu,
                        worker_init_fn=worker_init_fn)
print("Training Datas:", len(dataset))

# makedirs(exist_ok=True) replaces the exists-then-mkdir pair: it also creates
# missing parent directories and does not fail if another process creates the
# folder between the check and the creation.
os.makedirs(args.weight_folder, exist_ok=True)
os.makedirs(args.result_folder, exist_ok=True)

G = Generator(config)
D = Discriminator(config)

# Adam with betas (0.5, 0.9) for both networks.
G_optim = optim.Adam(G.parameters(), lr=args.lr, betas=(0.5, 0.9))
D_optim = optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.9))

writer = SummaryWriter()

print("Running On:", device)
train_handler = Train_Handler(args.dataset_name,
                              args.epochs,
                              G_optim,
                              D_optim,
                              dataloader,
                              device=device,
                              writer=writer,
                              save_per_epoch=args.save_per_epoch,
                              print_per_iteration=args.print_per_iteration,
                              plot_per_iteration=args.plot_per_iteration,
                              weight_folder=args.weight_folder,
Exemplo n.º 6
0
            flatten_test_index]
    else:
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
    torch_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_train),
                                                   torch.LongTensor(y_train))
    data_loader = torch.utils.data.DataLoader(torch_dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              drop_last=True)
    loss_classifier_list = []
    for epoch in range(1, num_epochs + 1):
        learning_rate = base_lr * math.pow(0.9, epoch / lr_step)
        optimizer = torch.optim.Adam([
            {
                'params': discriminator.parameters()
            },
        ],
                                     lr=learning_rate,
                                     weight_decay=l2_decay)

        discriminator.train()
        iter_data = iter(data_loader)
        num_iter = len(data_loader)
        total_clas_loss = 0
        num_batches = 0
        for it in range(0, num_iter):
            data, label = iter_data.next()
            if it % len(data_loader) == 0:
                iter_data = iter(data_loader)
            data = Variable(torch.FloatTensor(data))
Exemplo n.º 7
0
        shuffle=True,
        num_workers=0,
        pin_memory=True)

    # G = torch.nn.DataParallel( Generator(zdim = config.G['zdim'], use_batchnorm = config.G['use_batchnorm'] , use_residual_block = config.G['use_residual_block'] , num_classes = config.G['num_classes'])).cuda()
    # D = torch.nn.DataParallel( Discriminator(use_batchnorm = config.D['use_batchnorm'])).cuda()
    G = Generator(zdim=config.G['zdim'],
                  use_batchnorm=config.G['use_batchnorm'],
                  use_residual_block=config.G['use_residual_block'],
                  num_classes=config.G['num_classes']).cuda()
    D = Discriminator(use_batchnorm=config.D['use_batchnorm']).cuda()
    optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                          G.parameters()),
                                   lr=1e-4)
    optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                          D.parameters()),
                                   lr=1e-5)
    last_epoch = -1

    if config.train['resume_model'] is not None:
        e1 = resume_model(G, config.train['resume_model'])
        e2 = resume_model(D, config.train['resume_model'])
        assert e1 == e2
        last_epoch = e1

    if config.train['resume_optimizer'] is not None:
        e3 = resume_optimizer(optimizer_G, G, config.train['resume_optimizer'])
        e4 = resume_optimizer(optimizer_D, D, config.train['resume_optimizer'])
        assert e1 == e2 and e2 == e3 and e3 == e4
        last_epoch = e1
Exemplo n.º 8
0
# MNIST training split; reshuffled every epoch, incomplete last batch dropped.
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True, drop_last=True)

# Generator is built before the discriminator (tuple evaluates left to right).
G, D = Generator().to(device), Discriminator().to(device)

criterion = nn.BCELoss()
# Asymmetric optimizers: Adam for G; plain SGD (tiny weight decay, no momentum) for D.
optimizerG = torch.optim.Adam(G.parameters(), betas=(0.5, 0.999), lr=learning_rate)
optimizerD = torch.optim.SGD(D.parameters(), momentum=0, weight_decay=1e-7, lr=1e-3)

# Reusable target buffer, refilled with real/fake label values inside the loop.
label = torch.FloatTensor(batch_size)
real_label = 1
fake_label = 0

# Training loop. NOTE(review): the visible body only runs D on real images;
# no loss/backward/step follows here -- the rest of the loop looks truncated.
for epoch in range(train_epoch):
    for i, (imgs, _) in  enumerate(dataloader):
        # fix G, train D
        optimizerD.zero_grad()
        # load imgs to cuda
        imgs = imgs.to(device)
        # Let D output to 1
        output = D(imgs)
        # Fill the shared target buffer with the "real" label value.
        label.data.fill_(real_label)
        # Moves the buffer to `device`; after the first iteration it is already
        # there and .to returns the same tensor.
        label = label.to(device)
        # Reduce D's output over dim 1 by taking the max -- presumably to get
        # one score per sample; confirm against D's output shape.
        output = torch.max(output, dim=1).values
Exemplo n.º 9
0
    G_B2A.load_state_dict(torch.load('pretrained/G_B2A.pth'))
    D_A.load_state_dict(torch.load('pretrained/D_A.pth'))
    D_B.load_state_dict(torch.load('pretrained/D_B.pth'))
else:
    G_A2B.apply(weights_init)
    G_B2A.apply(weights_init)
    D_A.apply(weights_init)
    D_B.apply(weights_init)

# Adversarial term is least-squares (MSE); cycle and identity terms are L1.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle, criterion_identity = torch.nn.L1Loss(), torch.nn.L1Loss()

# One optimizer per side, each covering both translation directions.
G_params = [*G_A2B.parameters(), *G_B2A.parameters()]
D_params = [*D_A.parameters(), *D_B.parameters()]
optimizer_G = optim.Adam(G_params, lr=0.0002, betas=[0.5, 0.999])
optimizer_D = optim.Adam(D_params, lr=0.0002, betas=[0.5, 0.999])

# Presumably pools of previously generated images for the discriminators.
fake_A_buffer, fake_B_buffer = utils.ReplayBuffer(), utils.ReplayBuffer()
logger = utils.Logger(opt.n_epochs, len(loader_A), vis)

step = 0

# Running loss values, all starting at 0.
losses = dict.fromkeys(
    ('loss_G', 'loss_G_identity', 'loss_G_GAN', 'loss_G_cycle', 'loss_D'), 0)

# Epochs are numbered from 1 to opt.n_epochs inclusive.
for epoch in range(1, opt.n_epochs+1):
    # zip stops at the shorter loader; A and B are (batch_index, batch) pairs
    # produced by enumerate. NOTE(review): the visible loop body only advances
    # the global step counter -- the rest appears truncated.
    for (A, B) in zip(enumerate(loader_A), enumerate(loader_B)):

        step = step + 1
Exemplo n.º 10
0
# Move the discriminator's parameters and buffers to the current CUDA device.
discriminator.cuda()

def weights_init_normal(m):
    """Initialise one module's weights; meant to be passed to ``nn.Module.apply``.

    * Any module whose class name contains ``"Conv"`` (including transposed
      convolutions): weights drawn from N(0, 0.02); bias left untouched.
    * Any module whose class name contains ``"BatchNorm2d"``: weights drawn
      from N(1, 0.02), biases set to 0.
    * Every other module is left unchanged.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

# Re-initialise every layer of both networks.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

criterion = nn.BCELoss()
criterion.cuda()

# Identical Adam settings for both networks.
optimizer_disc = optim.Adam(discriminator.parameters(), betas=(0.5, 0.999), lr=learning_rate)
optimizer_gen = optim.Adam(generator.parameters(), betas=(0.5, 0.999), lr=learning_rate)

# Tensor constructor matching the `cuda` flag.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

print("Training...")

#TRAINING
# NOTE(review): the visible loop body stops right after the generator forward
# pass; the losses and optimizer steps appear truncated.
for epoch in range(n_epoch):
    for i, imgs in enumerate(training_data_loader):
        # Smoothed targets on both sides: 0.9 for real, 0.1 for fake.
        real = Variable(Tensor(imgs.shape[0], 1).fill_(0.9), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.1), requires_grad=False)
        real_imgs = Variable(imgs.type(Tensor))
        # Latent noise of shape (batch, 100, 1, 1). Allocated on 'cuda'
        # regardless of the `cuda` flag used to pick `Tensor`.
        z_vec = torch.randn(imgs.shape[0], 100, 1, 1, device='cuda')
        optimizer_gen.zero_grad()
        generator_images = generator(z_vec)
Exemplo n.º 11
0
class MUNIT(nn.Module):
    """Two-domain MUNIT model: two generators, two discriminators and their optimizers.

    Each generator exposes ``encode(x) -> (content, style)`` and
    ``decode(content, style) -> image`` for its own domain.

    Args:
        opt: options namespace. Reads ``ngf``, ``ndf``, ``style_dim``,
            ``mlp_dim``, ``display_size``, ``lr``, ``beta1``, optional
            ``beta2`` (default 0.999), ``weight_delay`` (passed to Adam as
            ``weight_decay``; NOTE(review): confirm the option's spelling),
            the loss weights ``gan_w``, ``x_w``, ``c_w``, ``s_w``, ``x_cyc_w``
            and ``gan_type``.
    """

    def __init__(self, opt):
        super(MUNIT, self).__init__()
        # Keep the options so the training methods do not rely on a global `opt`.
        self.opt = opt

        # generators and discriminators
        self.gen_a = Generator(opt.ngf, opt.style_dim, opt.mlp_dim)
        self.gen_b = Generator(opt.ngf, opt.style_dim, opt.mlp_dim)
        self.dis_a = Discriminator(opt.ndf)
        self.dis_b = Discriminator(opt.ndf)
        # Fixed random style codes of shape (display_size, style_dim, 1, 1)
        # used by forward(). They are never optimized, so no grad tracking.
        self.s_a = torch.randn(opt.display_size, opt.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(opt.display_size, opt.style_dim, 1, 1).cuda()

        # Optimizers. torch.optim.Adam takes `betas` (a pair) and
        # `weight_decay`; unknown keywords raise TypeError.
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        betas = (opt.beta1, getattr(opt, 'beta2', 0.999))
        self.dis_opt = torch.optim.Adam(dis_params,
                                        lr=opt.lr,
                                        betas=betas,
                                        weight_decay=opt.weight_delay)
        self.gen_opt = torch.optim.Adam(gen_params,
                                        lr=opt.lr,
                                        betas=betas,
                                        weight_decay=opt.weight_delay)

        # Network weight initialization: kaiming everywhere, then the
        # discriminators are re-initialised with the gaussian scheme.
        self.apply(weight_init('kaiming'))
        self.dis_a.apply(weight_init('gaussian'))
        self.dis_b.apply(weight_init('gaussian'))

    def _random_style(self, batch_size):
        """Draw ``batch_size`` style codes of shape (style_dim, 1, 1) from N(0, I) on the GPU."""
        return torch.randn(batch_size, self.opt.style_dim, 1, 1).cuda()

    def forward(self, x_a, x_b):
        """Translate x_a into domain b and x_b into domain a with the fixed style codes.

        NOTE(review): the fixed codes have batch size ``opt.display_size``;
        presumably inputs are expected to have that batch size -- confirm.
        """
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_a2b = self.gen_b.decode(c_a, self.s_b)
        x_b2a = self.gen_a.decode(c_b, self.s_a)
        return x_a2b, x_b2a

    def backward_G(self, x_a, x_b):
        """Compute the total generator loss and backpropagate it.

        Terms: within-domain image reconstruction, content/style
        reconstruction after translation, optional cycle reconstruction
        (when ``opt.x_cyc_w > 0``) and the adversarial term.

        NOTE(review): assumes ``recon_loss(output, target)`` -- confirm against
        its definition.

        Returns:
            The weighted total generator loss (already backpropagated).
        """
        opt = self.opt
        s_a = self._random_style(x_a.size(0))
        s_b = self._random_style(x_b.size(0))

        # Encode each domain into (content, style).
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # Within-domain reconstruction with the image's own style.
        x_a_rec = self.gen_a.decode(c_a, s_a_prime)
        x_b_rec = self.gen_b.decode(c_b, s_b_prime)
        # Cross-domain translation with randomly sampled styles.
        x_a2b = self.gen_b.decode(c_a, s_b)
        x_b2a = self.gen_a.decode(c_b, s_a)
        # Re-encode the translations to recover content and the sampled style.
        c_a_rec, s_b_rec = self.gen_b.encode(x_a2b)
        c_b_rec, s_a_rec = self.gen_a.encode(x_b2a)
        x_a_fake = self.dis_a(x_b2a)
        x_b_fake = self.dis_b(x_a2b)

        # loss functions
        self.loss_xa = recon_loss(x_a_rec, x_a)
        self.loss_xb = recon_loss(x_b_rec, x_b)
        self.loss_ca = recon_loss(c_a_rec, c_a)
        self.loss_cb = recon_loss(c_b_rec, c_b)
        self.loss_sa = recon_loss(s_a_rec, s_a)
        self.loss_sb = recon_loss(s_b_rec, s_b)
        if opt.x_cyc_w > 0:
            # Cycle back into the source domain: the content recovered from a
            # translation is decoded by the *source* domain's generator.
            x_a2b2a = self.gen_a.decode(c_a_rec, s_a_prime)
            x_b2a2b = self.gen_b.decode(c_b_rec, s_b_prime)
            self.loss_cyc_a2b = recon_loss(x_a2b2a, x_a)
            self.loss_cyc_b2a = recon_loss(x_b2a2b, x_b)
        else:
            self.loss_cyc_a2b = 0
            self.loss_cyc_b2a = 0
        self.loss_gen_a = gen_loss(x_a_fake, opt.gan_type)
        self.loss_gen_b = gen_loss(x_b_fake, opt.gan_type)

        self.gen_total_loss = (
            opt.gan_w * (self.loss_gen_a + self.loss_gen_b)
            + opt.x_w * (self.loss_xa + self.loss_xb)
            + opt.c_w * (self.loss_ca + self.loss_cb)
            + opt.s_w * (self.loss_sa + self.loss_sb)
            + opt.x_cyc_w * (self.loss_cyc_a2b + self.loss_cyc_b2a))
        self.gen_total_loss.backward()
        return self.gen_total_loss

    def backward_D(self, x_a, x_b):
        """Compute the total discriminator loss and backpropagate it.

        Translations are produced under ``torch.no_grad()`` so this loss only
        yields gradients for the discriminators. NOTE(review): assumes
        ``dis_loss(real_output, fake_output, gan_type)`` takes discriminator
        outputs, as ``gen_loss`` does -- confirm against its definition.

        Returns:
            The weighted total discriminator loss (already backpropagated).
        """
        opt = self.opt
        s_a = self._random_style(x_a.size(0))
        s_b = self._random_style(x_b.size(0))
        with torch.no_grad():
            c_a, _ = self.gen_a.encode(x_a)
            c_b, _ = self.gen_b.encode(x_b)
            x_a2b = self.gen_b.decode(c_a, s_b)
            x_b2a = self.gen_a.decode(c_b, s_a)
        x_a_real = self.dis_a(x_a.detach())
        x_a_fake = self.dis_a(x_b2a)
        x_b_real = self.dis_b(x_b.detach())
        x_b_fake = self.dis_b(x_a2b)

        self.loss_dis_a = dis_loss(x_a_real, x_a_fake, opt.gan_type)
        self.loss_dis_b = dis_loss(x_b_real, x_b_fake, opt.gan_type)
        self.dis_total_loss = opt.gan_w * (self.loss_dis_a + self.loss_dis_b)

        self.dis_total_loss.backward()
        return self.dis_total_loss

    # Backward-compatible alias for the previously misspelled method name.
    barkward_D = backward_D

    def optimize_parameters(self, x_a, x_b):
        """Run one generator update followed by one discriminator update.

        The discriminator's gradients are zeroed after the generator step, so
        gradients of the generator loss w.r.t. the discriminator (produced by
        backward_G) never reach dis_opt.
        """
        self.gen_opt.zero_grad()
        self.backward_G(x_a, x_b)
        self.gen_opt.step()

        self.dis_opt.zero_grad()
        self.backward_D(x_a, x_b)
        self.dis_opt.step()

    def sample(self, x_a, x_b, num_style):
        """Translate every image in both batches with ``num_style`` random styles.

        NOTE(review): iterates over ``x_a.size(0)`` for both inputs, so x_b is
        assumed to have at least as many images as x_a.

        Returns:
            (x_a2b, x_b2a), each of shape (batch_size * num_style, C, H, W),
            grouped by input image.
        """
        x_a2b = []
        x_b2a = []
        for i in range(x_a.size(0)):
            s_a = self._random_style(num_style)
            s_b = self._random_style(num_style)
            c_a_i, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b_i, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_i_a2b = []
            x_i_b2a = []
            for j in range(num_style):
                # [1, opt.style_dim, 1, 1]
                s_a_j = s_a[j].unsqueeze(0)
                s_b_j = s_b[j].unsqueeze(0)
                x_i_a2b.append(self.gen_b.decode(c_a_i, s_b_j))
                x_i_b2a.append(self.gen_a.decode(c_b_i, s_a_j))
            # [num_style, c, h, w]
            x_a2b.append(torch.cat(x_i_a2b, dim=0))
            x_b2a.append(torch.cat(x_i_b2a, dim=0))
        # [batch_size, num_style, c, h, w]
        x_a2b = torch.stack(x_a2b)
        x_b2a = torch.stack(x_b2a)
        # [batch_size*num_style, c, h, w]; the trailing sizes must be unpacked.
        x_a2b = x_a2b.view(-1, *x_a2b.size()[2:])
        x_b2a = x_b2a.view(-1, *x_b2a.size()[2:])
        return x_a2b, x_b2a

    def test(self, input, direction, style):
        """Translate a single image once per given style code.

        Args:
            input: one image without a batch dimension (one is added here).
            direction: ``'a2b'`` translates a -> b; any other value b -> a.
            style: style codes indexed along dim 0; each is decoded separately.

        Returns:
            The translations concatenated along dim 0, one per style code.
        """
        if direction == 'a2b':
            encoder, decoder = self.gen_a.encode, self.gen_b.decode
        else:
            encoder, decoder = self.gen_b.encode, self.gen_a.decode
        content, _ = encoder(input.unsqueeze(0))
        output = [decoder(content, style[i].unsqueeze(0))
                  for i in range(style.size(0))]
        return torch.cat(output, dim=0)

    def save(self, save_dir, iterations):
        """Write generator, discriminator and optimizer state dicts into ``save_dir``.

        Generator/discriminator files are tagged with ``iterations + 1``,
        zero-padded to 8 digits; ``optimizer.pt`` is overwritten on each call.
        """
        save_gen = os.path.join(save_dir, 'gen_%08d.pt' % (iterations + 1))
        save_dis = os.path.join(save_dir, 'dis_%08d.pt' % (iterations + 1))
        save_opt = os.path.join(save_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, save_gen)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, save_dis)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, save_opt)
Exemplo n.º 12
0
def train(args):
    """Train a CycleGAN-style model: generators A->B and B->A plus one discriminator per domain.

    Args:
        args: options namespace; reads ``gpu`` (CUDA device index, or None to
            run on CPU), ``dataroot`` and ``use_identity``.

    Also reads the module-level constants INITIAL_LR, IMAGE_SIZE, BATCH_SIZE,
    NUM_TEST_SAMPLES, EPOCHS and CYCLIC_WEIGHT.

    Side effects: writes preview images under ``visualization/``, scalar and
    image summaries through ``Logger('./logs')``, and per-epoch state dicts
    to ``models/<name>_epoch_<e>.model``.
    """
    # set the logger
    logger = Logger('./logs')

    # Device selection; fall back to CPU tensors when no GPU index is given so
    # `use_cuda` and `dtype` are always bound.
    use_cuda = args.gpu is not None
    if use_cuda:
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(args.gpu)
        print("Current device: %s" % torch.cuda.get_device_name(args.gpu))
    else:
        dtype = torch.FloatTensor

    # define networks
    g_AtoB = Generator().type(dtype)
    g_BtoA = Generator().type(dtype)
    d_A = Discriminator().type(dtype)
    d_B = Discriminator().type(dtype)

    # optimizers: one shared by both generators, one per discriminator
    optimizer_generators = Adam(
        list(g_AtoB.parameters()) + list(g_BtoA.parameters()), INITIAL_LR)
    optimizer_d_A = Adam(d_A.parameters(), INITIAL_LR)
    optimizer_d_B = Adam(d_B.parameters(), INITIAL_LR)

    # loss criterion: MSE for the adversarial terms, L1 for cycle/identity
    criterion_mse = torch.nn.MSELoss()
    criterion_l1 = torch.nn.L1Loss()

    # get training data
    dataset_transform = transforms.Compose([
        transforms.Resize(int(IMAGE_SIZE * 1),
                          Image.BICUBIC),  # scale shortest side to IMAGE_SIZE
        transforms.RandomCrop(
            (IMAGE_SIZE, IMAGE_SIZE)),  # random IMAGE_SIZE x IMAGE_SIZE crop
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))  # normalize
    ])
    dataloader = DataLoader(ImgPairDataset(args.dataroot, dataset_transform,
                                           'train'),
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    # get some test data to display periodically
    test_dataset = ImgPairDataset(args.dataroot, dataset_transform, 'test')
    test_samples_A = []
    test_samples_B = []
    for i in range(NUM_TEST_SAMPLES):
        # Fetch each sample once: the transform contains RandomCrop, so
        # indexing again can yield a different crop. Reusing one fetch keeps
        # the saved "original" identical to what the generators are fed.
        sample = test_dataset[i]
        test_samples_A.append(sample['A'].type(dtype).unsqueeze(0))
        test_samples_B.append(sample['B'].type(dtype).unsqueeze(0))

        os.makedirs('visualization/test_%d/%s/' % (i, 'B_inStyleofA'),
                    exist_ok=True)
        os.makedirs('visualization/test_%d/%s/' % (i, 'A_inStyleofB'),
                    exist_ok=True)

        fileStrA = 'visualization/test_original_%s_%04d.png' % ('A', i)
        fileStrB = 'visualization/test_original_%s_%04d.png' % ('B', i)
        utils.save_image(fileStrA, sample['A'].data)
        utils.save_image(fileStrB, sample['B'].data)
    test_data_A = torch.cat(test_samples_A, dim=0)
    test_data_B = torch.cat(test_samples_B, dim=0)

    # replay buffers
    replayBufferA = utils.ReplayBuffer(50)
    replayBufferB = utils.ReplayBuffer(50)

    # training loop
    step = 0
    for e in range(EPOCHS):
        startTime = time.time()
        for idx, batch in enumerate(dataloader):
            real_A = batch['A'].type(dtype)
            real_B = batch['B'].type(dtype)

            # some examples seem to have only 1 color channel instead of 3
            if real_A.shape[1] != 3 or real_B.shape[1] != 3:
                continue

            # -----------------
            #  train generators
            # -----------------
            optimizer_generators.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS,
                                      optimizer_generators)

            # GAN loss
            fake_A = g_BtoA(real_B)
            disc_fake_A = d_A(fake_A)
            fake_B = g_AtoB(real_A)
            disc_fake_B = d_B(fake_B)

            # Store graph-free copies of the fakes for the discriminator steps.
            replayBufferA.push(fake_A.detach().clone())
            replayBufferB.push(fake_B.detach().clone())

            target_real = Variable(torch.ones_like(disc_fake_A)).type(dtype)
            target_fake = Variable(torch.zeros_like(disc_fake_A)).type(dtype)

            loss_gan_AtoB = criterion_mse(disc_fake_B, target_real)
            loss_gan_BtoA = criterion_mse(disc_fake_A, target_real)
            loss_gan = loss_gan_AtoB + loss_gan_BtoA

            # cyclic reconstruction loss
            cyclic_A = g_BtoA(fake_B)
            cyclic_B = g_AtoB(fake_A)
            loss_cyclic_AtoBtoA = criterion_l1(cyclic_A,
                                               real_A) * CYCLIC_WEIGHT
            loss_cyclic_BtoAtoB = criterion_l1(cyclic_B,
                                               real_B) * CYCLIC_WEIGHT
            loss_cyclic = loss_cyclic_AtoBtoA + loss_cyclic_BtoAtoB

            # identity loss (same truthiness test as the logging code below)
            loss_identity = 0
            loss_identity_A = 0
            loss_identity_B = 0
            if args.use_identity:
                identity_A = g_BtoA(real_A)
                identity_B = g_AtoB(real_B)
                loss_identity_A = criterion_l1(identity_A,
                                               real_A) * 0.5 * CYCLIC_WEIGHT
                loss_identity_B = criterion_l1(identity_B,
                                               real_B) * 0.5 * CYCLIC_WEIGHT
                loss_identity = loss_identity_A + loss_identity_B

            loss_generators = loss_gan + loss_cyclic + loss_identity
            loss_generators.backward()
            optimizer_generators.step()

            # -----------------
            #  train discriminators
            # -----------------
            optimizer_d_A.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS, optimizer_d_A)

            # NOTE(review): the meaning of sample(1) depends on ReplayBuffer;
            # confirm its output matches target_fake's batch shape.
            fake_A = replayBufferA.sample(1).detach()
            disc_fake_A = d_A(fake_A)
            disc_real_A = d_A(real_A)
            loss_d_A = 0.5 * (criterion_mse(disc_real_A, target_real) +
                              criterion_mse(disc_fake_A, target_fake))

            loss_d_A.backward()
            optimizer_d_A.step()

            optimizer_d_B.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS, optimizer_d_B)

            fake_B = replayBufferB.sample(1).detach()
            disc_fake_B = d_B(fake_B)
            disc_real_B = d_B(real_B)
            loss_d_B = 0.5 * (criterion_mse(disc_real_B, target_real) +
                              criterion_mse(disc_fake_B, target_fake))

            loss_d_B.backward()
            optimizer_d_B.step()

            # log info and save sample images
            if (idx % 250) == 0:
                # Evaluate on the fixed test images without building a graph.
                g_AtoB.eval()
                g_BtoA.eval()
                with torch.no_grad():
                    test_B_hat = g_AtoB(test_data_A).cpu()
                    test_A_hat = g_BtoA(test_data_B).cpu()

                for i in range(NUM_TEST_SAMPLES):
                    fileStrA = 'visualization/test_%d/%s/%03d_%04d.png' % (
                        i, 'B_inStyleofA', e, idx)
                    fileStrB = 'visualization/test_%d/%s/%03d_%04d.png' % (
                        i, 'A_inStyleofB', e, idx)
                    utils.save_image(fileStrA, test_A_hat[i].data)
                    utils.save_image(fileStrB, test_B_hat[i].data)

                g_AtoB.train()
                g_BtoA.train()

                endTime = time.time()
                timeForIntervalIterations = endTime - startTime
                startTime = endTime

                print(
                    'Epoch [{:3d}/{:3d}], Training [{:4d}/{:4d}], Time Spent (s): [{:4.4f}], Losses: [G_GAN: {:4.4f}][G_CYC: {:4.4f}][G_IDT: {:4.4f}][D_A: {:4.4f}][D_B: {:4.4f}]'
                    .format(e, EPOCHS, idx, len(dataloader),
                            timeForIntervalIterations, loss_gan, loss_cyclic,
                            loss_identity, loss_d_A, loss_d_B))

                # tensorboard logging
                info = {
                    'loss_generators':
                    loss_generators.item(),
                    'loss_gan_AtoB':
                    loss_gan_AtoB.item(),
                    'loss_gan_BtoA':
                    loss_gan_BtoA.item(),
                    'loss_cyclic_AtoBtoA':
                    loss_cyclic_AtoBtoA.item(),
                    'loss_cyclic_BtoAtoB':
                    loss_cyclic_BtoAtoB.item(),
                    'loss_cyclic':
                    loss_cyclic.item(),
                    'loss_d_A':
                    loss_d_A.item(),
                    'loss_d_B':
                    loss_d_B.item(),
                    'lr_optimizer_generators':
                    optimizer_generators.param_groups[0]['lr'],
                    'lr_optimizer_d_A':
                    optimizer_d_A.param_groups[0]['lr'],
                    'lr_optimizer_d_B':
                    optimizer_d_B.param_groups[0]['lr'],
                }
                if args.use_identity:
                    info['loss_identity_A'] = loss_identity_A.item()
                    info['loss_identity_B'] = loss_identity_B.item()
                for tag, value in info.items():
                    logger.scalar_summary(tag, value, step)

                info = {
                    'test_A_hat':
                    test_A_hat.data.numpy().transpose(0, 2, 3, 1),
                    'test_B_hat':
                    test_B_hat.data.numpy().transpose(0, 2, 3, 1),
                }
                for tag, images in info.items():
                    logger.image_summary(tag, images, step)

            step += 1

        # Save after every epoch. CPU copies of the state dicts are written,
        # so the live models stay on their device and in training mode for
        # the next epoch (state_dict contents do not depend on train/eval).
        os.makedirs("models", exist_ok=True)
        checkpoints = {'g_AtoB': g_AtoB, 'g_BtoA': g_BtoA,
                       'd_A': d_A, 'd_B': d_B}
        for name, model in checkpoints.items():
            state = model.state_dict()
            for key in state:
                state[key] = state[key].cpu()
            torch.save(state,
                       "models/" + name + "_epoch_" + str(e) + ".model")
Exemplo n.º 13
0
def main():
    """Train a CycleGAN with generators G: A->B, F: B->A and discriminators Dx (A), Dy (B).

    All hyper-parameters and paths come from the module-level ``opt``
    namespace.  Training optionally resumes from ``opt.checkpoint`` (via
    ``load_ckp``) and a checkpoint is written to ``opt.save_model`` after
    every epoch.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    start_epoch = 0
    train_image_dataset = image_preprocessing(opt.dataset, 'train')
    data_loader = DataLoader(train_image_dataset, batch_size=opt.batch_size,
                             shuffle=True, num_workers=opt.num_workers)
    criterion = least_squares
    euclidean_l1 = nn.L1Loss()

    G = Generator(ResidualBlock, layer_count=9)
    F = Generator(ResidualBlock, layer_count=9)
    Dx = Discriminator()
    Dy = Discriminator()

    G_optimizer = optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    F_optimizer = optim.Adam(F.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    Dx_optimizer = optim.Adam(Dx.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    Dy_optimizer = optim.Adam(Dy.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))

    if torch.cuda.is_available():
        G = nn.DataParallel(G).cuda()
        F = nn.DataParallel(F).cuda()
        Dx = nn.DataParallel(Dx).cuda()
        Dy = nn.DataParallel(Dy).cuda()

    if opt.checkpoint is not None:
        (G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer,
         start_epoch) = load_ckp(opt.checkpoint, G, F, Dx, Dy, G_optimizer,
                                 F_optimizer, Dx_optimizer, Dy_optimizer)

    print('[Start] : Cycle GAN Training')

    logger = Logger(opt.epochs, len(data_loader), image_step=10)

    # Epoch numbers continue after a resumed checkpoint: start_epoch+1 .. start_epoch+opt.epochs.
    for epoch in range(start_epoch + 1, start_epoch + opt.epochs + 1):
        print("Epoch[{epoch}] : Start".format(epoch=epoch))

        for step, data in enumerate(data_loader):
            real_A = to_variable(data['A'])
            real_B = to_variable(data['B'])

            fake_B = G(real_A)
            fake_A = F(real_B)

            # Train Dx. The fake is detached so the discriminator's backward
            # pass does not traverse (and need not retain) F's graph.
            Dx_optimizer.zero_grad()

            Dx_real = Dx(real_A)
            Dx_fake = Dx(fake_A.detach())

            Dx_loss = patch_loss(criterion, Dx_real, True) + patch_loss(criterion, Dx_fake, 0)

            Dx_loss.backward()
            Dx_optimizer.step()

            # Train Dy (same scheme, detached fake from G).
            Dy_optimizer.zero_grad()

            Dy_real = Dy(real_B)
            Dy_fake = Dy(fake_B.detach())

            Dy_loss = patch_loss(criterion, Dy_real, True) + patch_loss(criterion, Dy_fake, 0)

            Dy_loss.backward()
            Dy_optimizer.step()

            # Train G & F jointly.
            G_optimizer.zero_grad()
            F_optimizer.zero_grad()

            # Adversarial terms: fakes should be classified as real by the
            # freshly-updated discriminators.
            G_loss = patch_loss(criterion, Dy(fake_B), True)
            F_loss = patch_loss(criterion, Dx(fake_A), True)

            # Identity loss: a generator fed an image that is already in its
            # output domain should leave it unchanged (G maps A->B, so
            # G(real_B) ~ real_B; F maps B->A, so F(real_A) ~ real_A).
            loss_identity = euclidean_l1(G(real_B), real_B) + euclidean_l1(F(real_A), real_A)

            # Cycle consistency: A -> B -> A and B -> A -> B reconstruct the input.
            loss_cycle = euclidean_l1(F(fake_B), real_A) + euclidean_l1(G(fake_A), real_B)

            # Identity term weighted by 0.5 * lamda relative to the cycle term.
            loss = G_loss + F_loss + opt.lamda * loss_cycle + opt.lamda * loss_identity * (0.5)

            loss.backward()
            G_optimizer.step()
            F_optimizer.step()

            if (step + 1) % opt.save_step == 0:
                print("Epoch[{epoch}]| Step [{now}/{total}]| Dx Loss: {Dx_loss}, Dy_Loss: {Dy_loss}, G_Loss: {G_loss}, F_Loss: {F_loss}".format(
                    epoch=epoch, now=step + 1, total=len(data_loader), Dx_loss=Dx_loss.item(), Dy_loss=Dy_loss.item(),
                    G_loss=G_loss.item(), F_loss=F_loss.item()))
                # 2x2 grid: top row real_A | real_B, bottom row fake_A | fake_B.
                batch_image = torch.cat((torch.cat((real_A, real_B), 3), torch.cat((fake_A, fake_B), 3)), 2)

                torchvision.utils.save_image(denorm(batch_image[0]), opt.training_result + 'result_{result_name}_ep{epoch}_{step}.jpg'.format(result_name=opt.result_name, epoch=epoch, step=(step + 1) * opt.batch_size))

            # http://localhost:8097
            logger.log(
                losses={
                    'loss_G': G_loss,
                    'loss_F': F_loss,
                    'loss_identity': loss_identity,
                    'loss_cycle': loss_cycle,
                    'total_G_loss': loss,
                    'loss_Dx': Dx_loss,
                    'loss_Dy': Dy_loss,
                    'total_D_loss': (Dx_loss + Dy_loss),
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_A,
                    'fake_B': fake_B,
                },
            )

        torch.save({
            'epoch': epoch,
            'G_model': G.state_dict(),
            'G_optimizer': G_optimizer.state_dict(),
            'F_model': F.state_dict(),
            'F_optimizer': F_optimizer.state_dict(),
            'Dx_model': Dx.state_dict(),
            'Dx_optimizer': Dx_optimizer.state_dict(),
            'Dy_model': Dy.state_dict(),
            'Dy_optimizer': Dy_optimizer.state_dict(),
        }, opt.save_model + 'model_{result_name}_CycleGAN_ep{epoch}.ckp'.format(result_name=opt.result_name, epoch=epoch))
Exemplo n.º 14
0
def train_network(args):
    """Train a pix2pix-style conditional GAN translating domain A images into domain B.

    The discriminator is conditioned on the source image (``discriminator(x, item_A)``).
    The generator loss is the adversarial BCE term plus ``args.l1_weight`` times
    the L1 reconstruction error.  A checkpoint is overwritten at
    ``args.save_path + 'check_point.pth'`` after every epoch.

    Returns:
        ``(discriminator, generator)`` after the final epoch.
    """
    device = torch.device('cuda' if args.gpu_no >= 0 else 'cpu')

    generator = Generator(unet_flag=args.unet_flag).to(device).train()
    discriminator = Discriminator().to(device).train()

    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta1, args.beta2))
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta1, args.beta2))

    bce_criterion = nn.BCELoss()
    l1_criterion = nn.L1Loss()

    dataset = FacadeFolder(args.data_A, args.data_B, args.imsize,
                           args.cropsize, args.cencrop)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    loss_seq = {'d_real': [], 'd_fake': [], 'g_gan': [], 'g_l1': []}
    # Defined up front so the checkpoint below is valid even for an empty loader.
    iteration = 0
    for epoch in range(args.epoch):
        for iteration, (item_A, item_B) in enumerate(dataloader, 1):
            item_A, item_B = item_A.to(device), item_B.to(device)

            # ---- Train discriminator ----
            fake_B = generator(item_A)
            # Detach so the discriminator update does not backprop into the generator.
            fake_discrimination = discriminator(fake_B.detach(), item_A)
            # *_like already allocates on the input's device.
            fake_loss = bce_criterion(fake_discrimination,
                                      torch.zeros_like(fake_discrimination))

            real_discrimination = discriminator(item_B, item_A)
            real_loss = bce_criterion(real_discrimination,
                                      torch.ones_like(real_discrimination))

            discriminator_loss = (real_loss +
                                  fake_loss) * args.discriminator_weight
            loss_seq['d_real'].append(real_loss.item())
            loss_seq['d_fake'].append(fake_loss.item())

            optimizer_d.zero_grad()
            discriminator_loss.backward()
            optimizer_d.step()

            # ---- Train generator ----
            fake_B = generator(item_A)
            fake_discrimination = discriminator(fake_B, item_A)
            fake_loss = bce_criterion(fake_discrimination,
                                      torch.ones_like(fake_discrimination))
            l1_loss = l1_criterion(fake_B, item_B)

            generator_loss = fake_loss + l1_loss * args.l1_weight
            loss_seq['g_gan'].append(fake_loss.item())
            loss_seq['g_l1'].append(l1_loss.item())

            optimizer_g.zero_grad()
            generator_loss.backward()
            optimizer_g.step()

            # ---- Report training loss ----
            if iteration % args.check_iter == 0:
                check_str = "%s: epoch:[%d/%d]\titeration:[%d/%d]" % (
                    time.ctime(), epoch, args.epoch, iteration,
                    len(dataloader))
                for key, value in loss_seq.items():
                    check_str += "\t%s: %2.2f" % (
                        key, lastest_arverage_value(value))
                print(check_str)
                imsave(torch.cat([item_A, item_B, fake_B], dim=0),
                       args.save_path + 'training_image.jpg')

        # Save a checkpoint (overwritten every epoch).
        torch.save(
            {
                'epoch': epoch,
                'iteration': iteration,
                'g_state_dict': generator.state_dict(),
                'd_state_dict': discriminator.state_dict(),
                'loss_seq': loss_seq
            }, args.save_path + 'check_point.pth')

    return discriminator, generator
Exemplo n.º 15
0
# Binary cross-entropy between the discriminator's output and its target label.
criterion = nn.BCELoss()

# One fixed batch of 64 latent vectors, reused for every visual snapshot so the
# generator's progress is always shown on identical inputs.  Either drawn from
# a truncated normal (project helper) or from a standard normal.
fixed_noise = (
    torch.from_numpy(truncated_z_sample(64, args.nz, args.tc_th, args.manualSeed)).float().to(device)
    if args.truncnorm
    else torch.randn(64, args.nz, 1, 1, device=device)
)

# Target values for real and generated samples.
real_label, fake_label = 1, 0

# Adam optimizers for discriminator and generator, sharing lr and beta1.
optimizerD = optim.Adam(netD.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=args.lr, betas=(args.beta1, 0.999))


# ---- Training loop state ----
img_list, G_losses, D_losses = [], [], []
iters = 0

writer = SummaryWriter()
print("Starting Training Loop...")
# For each epoch
for epoch in range(args.num_epochs):
            gamma, batch_size, device, epochs, load_checkpoint, mutil_gpu,
            device_ids))

if __name__ == "__main__":
    img_size = (32, 32, 3)
    latent_dim = 128

    writer = SummaryWriter()

    discriminator = Discriminator(img_size=img_size)
    generator = Generator(latent_dim=latent_dim, output_size=img_size)

    optimizer_G = getattr(optim, optimizer_name)(generator.parameters(),
                                                 lr=lr,
                                                 betas=(0.5, 0.999))
    optimizer_D = getattr(optim, optimizer_name)(discriminator.parameters(),
                                                 lr=lr,
                                                 betas=(0.5, 0.999))
    # lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    # if load_checkpoint:
    #     load_model("model_checkpoint/check_point.pkl", model, optimizer, lr_scheduler)

    if mutil_gpu:
        discriminator = nn.DataParallel(discriminator, device_ids, device)
        generator = nn.DataParallel(generator, device_ids, device)
    discriminator = discriminator.to(device=device)
    generator = generator.to(device=device)

    mnist = torchvision.datasets.CIFAR10(root="cifar10",
                                         train=True,
                                         download=True,
Exemplo n.º 17
0
class GAN:
    """Unconditional DCGAN-style trainer.

    Bundles a ``Generator``/``Discriminator`` pair, their Adam optimizers, a
    BCE loss and the image-folder data pipeline.  All hyper-parameters
    (device, learning rates, batch size, paths, ...) are read from the
    module-level ``Params`` object.
    """

    def __init__(self):
        self.generator = Generator()
        self.generator.to(Params.Device)

        self.discriminator = Discriminator()
        self.discriminator.to(Params.Device)

        self.loss_fn = nn.BCELoss()

        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), Params.LearningRateG, betas=(Params.Beta, 0.999))
        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), Params.LearningRateD, betas=(Params.Beta, 0.999))

        # Fixed batch of 64 latent vectors rendered at every checkpoint so that
        # samples from different epochs are directly comparable.
        self.exemplar_latent_vectors = torch.randn( (64, Params.LatentVectorSize), device=Params.Device )


    def load_image_set(self):
        """Create ``self.dataset`` and ``self.loader`` from the image folder at ``Params.ImagePath``.

        Images are resized and center-cropped to ``Params.ImageSize``, converted
        to tensors and normalised per channel with mean 0.5 and std 0.5.
        """
        transforms = torchvision.transforms.Compose( [
            torchvision.transforms.Resize(Params.ImageSize),
            torchvision.transforms.CenterCrop(Params.ImageSize),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize( (0.5,0.5,0.5), (0.5,0.5,0.5) )
        ] )
        self.dataset = torchvision.datasets.ImageFolder(
            Params.ImagePath,
            transforms
        )
        self.loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size = Params.BatchSize,
            shuffle=True,
            num_workers=Params.NumWorkers
        )


    def train(self, start_epoch=0):
        """Train from ``start_epoch`` (0-based) up to ``Params.NumEpochs``.

        A checkpoint named ``epochNNN`` (NNN = number of completed epochs) is
        written every ``Params.CheckpointEvery`` completed epochs, and a
        final checkpoint is written when training ends.
        """
        self.generator.train()
        self.discriminator.train()

        self.loss_record = {
            'D': [],
            'G': []
        }

        for cur_epoch in range(start_epoch, Params.NumEpochs):
            tic = time.perf_counter()

            self.train_epoch(cur_epoch)

            toc = time.perf_counter()
            print( f"Last epoch took: {toc - tic:.0f} seconds" )

            # Test the completed-epoch count (cur_epoch + 1), which is also the
            # number used in the checkpoint name, so that checkpoints land on
            # epochs N, 2N, ... rather than 1, N+1, 2N+1, ...
            if (cur_epoch + 1) % Params.CheckpointEvery == 0:
                self.save_checkpoint(f"epoch{cur_epoch+1:03d}")

        self.save_checkpoint(f"Final ({Params.NumEpochs} epochs)")


    def train_epoch(self, cur_epoch):
        """Run one pass over ``self.loader`` with alternating D and G updates.

        Per-batch D and G losses are appended to ``self.loss_record``, which
        ``train`` initialises.
        """
        for (i, batch) in enumerate(self.loader):
            # ---- Discriminator step: real images -> 1, generated images -> 0 ----
            self.discriminator.zero_grad()
            # The loader yields (images, class_labels); only the images are used.
            real_images = batch[0].to(Params.Device)
            batch_size = real_images.size(0)
            labels = torch.full(
                (batch_size,),
                1,
                dtype=torch.float, device=Params.Device
            )
            D_real = self.discriminator(real_images).view(-1)
            D_loss_real = self.loss_fn(D_real, labels)
            D_loss_real.backward()

            latent_vectors = torch.randn( (batch_size, Params.LatentVectorSize), device=Params.Device )
            fake_images = self.generator(latent_vectors)
            # Detached: this update must not reach the generator's parameters.
            D_fake = self.discriminator(fake_images.detach()).view(-1)
            labels.fill_(0)
            D_loss_fake = self.loss_fn(D_fake, labels)
            D_loss_fake.backward()

            self.optimizer_d.step()

            # ---- Generator step: make the updated D label the fakes as real ----
            self.generator.zero_grad()
            labels.fill_(1)
            D_fake = self.discriminator(fake_images).view(-1)
            G_loss_t = self.loss_fn(D_fake, labels)
            G_loss_t.backward()

            self.optimizer_g.step()

            D_loss = D_loss_real.item() + D_loss_fake.item()
            G_loss = G_loss_t.item()

            if (i % Params.ReportEvery == 0) and i>0:
                print( f"Epoch[{cur_epoch}/{Params.NumEpochs}] Batch[{i}/{len(self.loader)}]" )
                print( f"\tD Loss: {D_loss}\tG Loss: {G_loss}\n" )

            self.loss_record['D'].append(D_loss)
            self.loss_record['G'].append(G_loss)


    def save_checkpoint(self, checkpoint_name):
        """Save model/optimizer states and an 8-column sample grid under ``Params.CheckpointPath/checkpoint_name``."""
        dir_path = os.path.join(Params.CheckpointPath, checkpoint_name)
        os.makedirs(dir_path, exist_ok=True)
        print( f"Saving snapshot to: {dir_path}" )

        save_state = dict()
        save_state['g_state'] = self.generator.state_dict()
        save_state['g_optimizer_state'] = self.optimizer_g.state_dict()
        save_state['d_state'] = self.discriminator.state_dict()
        save_state['d_optimizer_state'] = self.optimizer_d.state_dict()

        torch.save(
            save_state,
            os.path.join(dir_path, 'model_params.pt')
        )

        with torch.no_grad():
            # Sample in eval mode, then restore whatever mode the generator was in.
            was_training = self.generator.training
            self.generator.eval()
            gen_images = self.generator(self.exemplar_latent_vectors)
            self.generator.train(was_training)
            torchvision.utils.save_image(
                gen_images,
                os.path.join(dir_path, 'gen_images.png'),
                nrow = 8,
                normalize=True
            )
Exemplo n.º 18
0
class tag2pix(object):
    """Trainer / tester for the tag2pix sketch-colorization GAN.

    From ``args`` it builds three models:
    - the generator variant selected by ``args.model``;
    - a discriminator whose forward returns (adversarial score, CIT output,
      CVT output);
    - an optional frozen pretrained SE-ResNeXt sketch feature extractor.

    It also handles training, per-epoch visualisation, checkpointing,
    resuming and testing.
    """

    def __init__(self, args):
        # Select the Generator implementation for the requested model variant.
        if args.model == 'tag2pix':
            from network import Generator
        elif args.model == 'senet':
            from model.GD_senet import Generator
        elif args.model == 'resnext':
            from model.GD_resnext import Generator
        elif args.model == 'catconv':
            from model.GD_cat_conv import Generator
        elif args.model == 'catall':
            from model.GD_cat_all import Generator
        elif args.model == 'adain':
            from model.GD_adain import Generator
        elif args.model == 'seadain':
            from model.GD_seadain import Generator
        else:
            raise Exception('invalid model name: {}'.format(args.model))

        self.args = args
        self.epoch = args.epoch
        self.batch_size = args.batch_size

        self.gpu_mode = not args.cpu
        self.input_size = args.input_size
        self.color_revert = ColorSpace2RGB(args.color_space)
        self.layers = args.layers
        [self.cit_weight, self.cvt_weight] = args.cit_cvt_weight

        # Resume from a checkpoint only when a non-empty path was given.
        # Compare by value; ``is not ""`` tested object identity.
        self.load_dump = (args.load != "")

        self.load_path = Path(args.load)

        self.l1_lambda = args.l1_lambda
        self.guide_beta = args.guide_beta
        self.adv_lambda = args.adv_lambda
        self.save_freq = args.save_freq

        self.two_step_epoch = args.two_step_epoch
        self.brightness_epoch = args.brightness_epoch
        self.save_all_epoch = args.save_all_epoch

        self.iv_dict, self.cv_dict, self.id_to_name = get_tag_dict(
            args.tag_dump)

        cvt_class_num = len(self.cv_dict.keys())
        cit_class_num = len(self.iv_dict.keys())
        self.class_num = cvt_class_num + cit_class_num

        self.start_epoch = 1

        #### load dataset
        if not args.test:
            self.train_data_loader, self.test_data_loader = get_dataset(args)
            self.result_path = Path(args.result_dir) / time.strftime(
                '%y%m%d-%H%M%S', time.localtime())

            if not self.result_path.exists():
                self.result_path.mkdir()

            self.test_images = self.get_test_data(self.test_data_loader,
                                                  args.test_image_count)
        else:
            self.test_data_loader = get_dataset(args)
            self.result_path = Path(args.result_dir)

        ##### initialize network
        self.net_opt = {
            'guide': not args.no_guide,
            'relu': args.use_relu,
            'bn': not args.no_bn,
            'cit': not args.no_cit
        }

        if self.net_opt['cit']:
            self.Pretrain_ResNeXT = se_resnext_half(
                dump_path=args.pretrain_dump,
                num_classes=cit_class_num,
                input_channels=1)
        else:
            self.Pretrain_ResNeXT = nn.Sequential()

        self.G = Generator(input_size=args.input_size,
                           layers=args.layers,
                           cv_class_num=cvt_class_num,
                           iv_class_num=cit_class_num,
                           net_opt=self.net_opt)
        self.D = Discriminator(input_dim=3,
                               output_dim=1,
                               input_size=self.input_size,
                               cv_class_num=cvt_class_num,
                               iv_class_num=cit_class_num)

        # The pretrained feature extractor is always frozen; G and D are
        # frozen too in test mode.
        for param in self.Pretrain_ResNeXT.parameters():
            param.requires_grad = False
        if args.test:
            for param in self.G.parameters():
                param.requires_grad = False
            for param in self.D.parameters():
                param.requires_grad = False

        self.Pretrain_ResNeXT = nn.DataParallel(self.Pretrain_ResNeXT)
        self.G = nn.DataParallel(self.G)
        self.D = nn.DataParallel(self.D)

        self.G_optimizer = optim.Adam(self.G.parameters(),
                                      lr=args.lrG,
                                      betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(),
                                      lr=args.lrD,
                                      betas=(args.beta1, args.beta2))

        self.BCE_loss = nn.BCELoss()
        self.CE_loss = nn.CrossEntropyLoss()
        self.L1Loss = nn.L1Loss()

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        print("gpu mode: ", self.gpu_mode)
        print("device: ", self.device)
        print(torch.cuda.device_count(), "GPUS!")

        if self.gpu_mode:
            self.Pretrain_ResNeXT.to(self.device)
            self.G.to(self.device)
            self.D.to(self.device)
            self.BCE_loss.to(self.device)
            self.CE_loss.to(self.device)
            self.L1Loss.to(self.device)

    def train(self):
        """Train from ``self.start_epoch`` to ``self.end_epoch`` inclusive.

        After each epoch, results are visualised and the loss curve is
        plotted.  Checkpoints are saved per ``save_all_epoch`` /
        ``save_freq``, plus a final one.
        """
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        # Fixed real/fake targets of shape (batch_size, 1).  Partial batches
        # are skipped in the loop below, so the shapes always match.
        self.y_real_, self.y_fake_ = torch.ones(self.batch_size,
                                                1), torch.zeros(
                                                    self.batch_size, 1)

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.to(
                self.device), self.y_fake_.to(self.device)

        if self.load_dump:
            # load() sets start_epoch and end_epoch from the checkpoint.
            self.load(self.load_path)
            print("continue training!!!!")
        else:
            self.end_epoch = self.epoch

        self.print_params()

        self.D.train()
        print('training start!!')
        start_time = time.time()

        for epoch in range(self.start_epoch, self.end_epoch + 1):
            print("EPOCH: {}".format(epoch))

            self.G.train()
            epoch_start_time = time.time()

            if epoch == self.brightness_epoch:
                print('changing brightness ...')
                self.train_data_loader.dataset.enhance_brightness(
                    self.input_size)

            max_iter = self.train_data_loader.dataset.__len__(
            ) // self.batch_size

            for iteration, (original_, sketch_, iv_tag_, cv_tag_) in enumerate(
                    tqdm(self.train_data_loader, ncols=80)):
                # Stop before a trailing partial batch (see y_real_/y_fake_ shapes).
                if iteration >= max_iter:
                    break

                if self.gpu_mode:
                    sketch_, original_, iv_tag_, cv_tag_ = sketch_.to(
                        self.device), original_.to(self.device), iv_tag_.to(
                            self.device), cv_tag_.to(self.device)

                # update D network
                self.D_optimizer.zero_grad()

                with torch.no_grad():
                    feature_tensor = self.Pretrain_ResNeXT(sketch_)
                if self.gpu_mode:
                    feature_tensor = feature_tensor.to(self.device)

                D_real, CIT_real, CVT_real = self.D(original_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_f, _ = self.G(sketch_, feature_tensor, cv_tag_)
                if self.gpu_mode:
                    G_f = G_f.to(self.device)

                # Detached: the D update must not backprop through G.  G is
                # re-run below for its own update.
                D_f_fake, CIT_f_fake, CVT_f_fake = self.D(G_f.detach())
                D_f_fake_loss = self.BCE_loss(D_f_fake, self.y_fake_)

                # Tag-classification losses are enabled from two_step_epoch
                # onward (or always when two_step_epoch == 0).
                if self.two_step_epoch == 0 or epoch >= self.two_step_epoch:
                    CIT_real_loss = self.BCE_loss(
                        CIT_real, iv_tag_) if self.net_opt['cit'] else 0
                    CVT_real_loss = self.BCE_loss(CVT_real, cv_tag_)

                    C_real_loss = self.cvt_weight * CVT_real_loss + self.cit_weight * CIT_real_loss

                    CIT_f_fake_loss = self.BCE_loss(
                        CIT_f_fake, iv_tag_) if self.net_opt['cit'] else 0
                    CVT_f_fake_loss = self.BCE_loss(CVT_f_fake, cv_tag_)

                    C_f_fake_loss = self.cvt_weight * CVT_f_fake_loss + self.cit_weight * CIT_f_fake_loss
                else:
                    C_real_loss = 0
                    C_f_fake_loss = 0

                D_loss = self.adv_lambda * (D_real_loss + D_f_fake_loss) + (
                    C_real_loss + C_f_fake_loss)

                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_f, G_g = self.G(sketch_, feature_tensor, cv_tag_)

                if self.gpu_mode:
                    G_f, G_g = G_f.to(self.device), G_g.to(self.device)

                D_f_fake, CIT_f_fake, CVT_f_fake = self.D(G_f)

                D_f_fake_loss = self.BCE_loss(D_f_fake, self.y_real_)

                if self.two_step_epoch == 0 or epoch >= self.two_step_epoch:
                    CIT_f_fake_loss = self.BCE_loss(
                        CIT_f_fake, iv_tag_) if self.net_opt['cit'] else 0
                    CVT_f_fake_loss = self.BCE_loss(CVT_f_fake, cv_tag_)

                    C_f_fake_loss = self.cvt_weight * CVT_f_fake_loss + self.cit_weight * CIT_f_fake_loss
                else:
                    C_f_fake_loss = 0

                L1_D_f_fake_loss = self.L1Loss(G_f, original_)
                L1_D_g_fake_loss = self.L1Loss(
                    G_g, original_) if self.net_opt['guide'] else 0

                G_loss = (D_f_fake_loss + C_f_fake_loss) + \
                         (L1_D_f_fake_loss + L1_D_g_fake_loss * self.guide_beta) * self.l1_lambda

                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

                if ((iteration + 1) % 100) == 0:
                    print(
                        "Epoch: [{:2d}] [{:4d}/{:4d}] D_loss: {:.8f}, G_loss: {:.8f}"
                        .format(epoch, (iteration + 1), max_iter, D_loss.item(),
                                G_loss.item()))

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)

            with torch.no_grad():
                self.visualize_results(epoch)
                utils.loss_plot(self.train_hist, self.result_path, epoch)

            if epoch >= self.save_all_epoch > 0:
                self.save(epoch)
            elif self.save_freq > 0 and epoch % self.save_freq == 0:
                self.save(epoch)

        print("Training finish!... save training results")

        # Save the final epoch unless the loop above already saved it.
        if self.save_freq == 0 or epoch % self.save_freq != 0:
            if self.save_all_epoch <= 0 or epoch < self.save_all_epoch:
                self.save(epoch)

        self.train_hist['total_time'].append(time.time() - start_time)
        print(
            "Avg one epoch time: {:.2f}, total {} epochs time: {:.2f}".format(
                np.mean(self.train_hist['per_epoch_time']), self.epoch,
                self.train_hist['total_time'][0]))

    def test(self):
        """Colorize every sketch in the test loader and save one PNG per input index."""
        self.load_test(self.args.load)

        self.D.eval()
        self.G.eval()

        load_path = self.load_path
        result_path = self.result_path / load_path.stem

        if not result_path.exists():
            result_path.mkdir()

        with torch.no_grad():
            for sketch_, index_, _, cv_tag_ in tqdm(self.test_data_loader,
                                                    ncols=80):
                if self.gpu_mode:
                    sketch_, cv_tag_ = sketch_.to(self.device), cv_tag_.to(
                        self.device)

                feature_tensor = self.Pretrain_ResNeXT(sketch_)

                if self.gpu_mode:
                    feature_tensor = feature_tensor.to(self.device)

                G_f, _ = self.G(sketch_, feature_tensor, cv_tag_)
                G_f = self.color_revert(G_f.cpu())

                for ind, result in zip(index_.cpu().numpy(), G_f):
                    # Avoid overwriting: fall back to '{ind}_{i}.png' (up to 100 tries).
                    save_path = result_path / f'{ind}.png'
                    if save_path.exists():
                        for i in range(100):
                            save_path = result_path / f'{ind}_{i}.png'
                            if not save_path.exists():
                                break
                    img = Image.fromarray(result)
                    img.save(save_path)

    def visualize_results(self, epoch, fix=True):
        """Save image grids of the G_f and G_g generator outputs for the fixed test images.

        Leaves ``self.G`` in eval mode; ``train`` switches it back at the
        start of each epoch.
        """
        if not self.result_path.exists():
            self.result_path.mkdir()

        self.G.eval()

        # test_data_loader
        original_, sketch_, iv_tag_, cv_tag_ = self.test_images
        image_frame_dim = int(np.ceil(np.sqrt(len(original_))))

        # Sketch features from the pretrained extractor.
        with torch.no_grad():
            feature_tensor = self.Pretrain_ResNeXT(sketch_)

            if self.gpu_mode:
                original_, sketch_, iv_tag_, cv_tag_, feature_tensor = original_.to(
                    self.device), sketch_.to(self.device), iv_tag_.to(
                        self.device), cv_tag_.to(
                            self.device), feature_tensor.to(self.device)

            G_f, G_g = self.G(sketch_, feature_tensor, cv_tag_)

            if self.gpu_mode:
                G_f = G_f.cpu()
                G_g = G_g.cpu()

            G_f = self.color_revert(G_f)
            G_g = self.color_revert(G_g)

        utils.save_images(
            G_f[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_epoch{:03d}_G_f.png'.format(epoch))
        utils.save_images(
            G_g[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_epoch{:03d}_G_g.png'.format(epoch))

    def save(self, save_epoch):
        """Write the arguments, model/optimizer checkpoint and loss history for ``save_epoch``."""
        if not self.result_path.exists():
            self.result_path.mkdir()

        with (self.result_path / 'arguments.txt').open('w') as f:
            f.write(pprint.pformat(self.args.__dict__))

        save_dir = self.result_path

        torch.save(
            {
                'G': self.G.state_dict(),
                'D': self.D.state_dict(),
                'G_optimizer': self.G_optimizer.state_dict(),
                'D_optimizer': self.D_optimizer.state_dict(),
                'finish_epoch': save_epoch,
                'result_path': str(save_dir)
            }, str(save_dir / 'tag2pix_{}_epoch.pkl'.format(save_epoch)))

        with (save_dir /
              'tag2pix_{}_history.pkl'.format(save_epoch)).open('wb') as f:
            pickle.dump(self.train_hist, f)

        print("============= save success =============")
        print("epoch from {} to {}".format(self.start_epoch, save_epoch))
        print("save result path is {}".format(str(self.result_path)))

    def load_test(self, checkpoint_path):
        """Load only the generator weights from ``checkpoint_path`` (for inference)."""
        checkpoint = torch.load(str(checkpoint_path))
        self.G.load_state_dict(checkpoint['G'])

    def load(self, checkpoint_path):
        """Restore G/D and their optimizers, then schedule ``args.epoch`` more epochs."""
        checkpoint = torch.load(str(checkpoint_path))
        self.G.load_state_dict(checkpoint['G'])
        self.D.load_state_dict(checkpoint['D'])
        self.G_optimizer.load_state_dict(checkpoint['G_optimizer'])
        self.D_optimizer.load_state_dict(checkpoint['D_optimizer'])
        self.start_epoch = checkpoint['finish_epoch'] + 1

        self.finish_epoch = self.args.epoch + self.start_epoch - 1
        # train() loops up to end_epoch.  It was previously set only on the
        # non-resume path, so resuming raised AttributeError.
        self.end_epoch = self.finish_epoch

        print("============= load success =============")
        print("epoch start from {} to {}".format(self.start_epoch,
                                                 self.finish_epoch))
        print("previous result path is {}".format(checkpoint['result_path']))

    def get_test_data(self, test_data_loader, count):
        """Collect at least ``count`` test samples (whole batches), save their tags and preview grids.

        Returns:
            ``(original_, sketch_, iv_tag_, cv_tag_)`` concatenated along dim 0.
        """
        test_count = 0
        original_, sketch_, iv_tag_, cv_tag_ = [], [], [], []
        for orig, sket, ivt, cvt in test_data_loader:
            original_.append(orig)
            sketch_.append(sket)
            iv_tag_.append(ivt)
            cv_tag_.append(cvt)

            test_count += len(orig)
            if test_count >= count:
                break

        original_ = torch.cat(original_, 0)
        sketch_ = torch.cat(sketch_, 0)
        iv_tag_ = torch.cat(iv_tag_, 0)
        cv_tag_ = torch.cat(cv_tag_, 0)

        self.save_tag_tensor_name(iv_tag_, cv_tag_,
                                  self.result_path / "test_image_tags.txt")

        image_frame_dim = int(np.ceil(np.sqrt(len(original_))))

        if self.gpu_mode:
            original_ = original_.cpu()
        sketch_np = sketch_.data.numpy().transpose(0, 2, 3, 1)
        original_np = self.color_revert(original_)

        utils.save_images(
            original_np[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_original.png')
        utils.save_images(
            sketch_np[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_sketch.png')

        return original_, sketch_, iv_tag_, cv_tag_

    def save_tag_tensor_name(self, iv_tensor, cv_tensor, save_file_path):
        '''Write the human-readable tag names of each sample to ``save_file_path``.

        iv_tensor, cv_tensor: batched one-hot tag tensors (truthy entries mark
        present tags; the position is mapped back through iv_dict / cv_dict).
        '''
        iv_dict_inverse = {
            tag_index: tag_id
            for (tag_id, tag_index) in self.iv_dict.items()
        }
        cv_dict_inverse = {
            tag_index: tag_id
            for (tag_id, tag_index) in self.cv_dict.items()
        }

        with open(save_file_path, 'w') as f:
            f.write("CIT tags\n")

            for tensor_i, batch_unit in enumerate(iv_tensor):
                f.write(f'{tensor_i} : ')

                for i, is_tag in enumerate(batch_unit):
                    if is_tag:
                        tag_name = self.id_to_name[iv_dict_inverse[i]]
                        f.write(f"{tag_name}, ")
                f.write("\n")

            f.write("\nCVT tags\n")

            for tensor_i, batch_unit in enumerate(cv_tensor):
                f.write(f'{tensor_i} : ')

                for i, is_tag in enumerate(batch_unit):
                    if is_tag:
                        tag_name = self.id_to_name[cv_dict_inverse[i]]
                        f.write(f"{tag_name}, ")
                f.write("\n")

    def print_params(self):
        """Print the parameter counts of G, D and the pretrained feature extractor."""
        params_cnt = [0, 0, 0]
        for param in self.G.parameters():
            params_cnt[0] += param.numel()
        for param in self.D.parameters():
            params_cnt[1] += param.numel()
        for param in self.Pretrain_ResNeXT.parameters():
            params_cnt[2] += param.numel()
        print(
            f'Parameter #: G - {params_cnt[0]} / D - {params_cnt[1]} / Pretrain - {params_cnt[2]}'
        )
Exemplo n.º 19
0
def train_network(args):
    """Train a CycleGAN-style model: two generators and two discriminators.

    ``G_A2B``/``G_B2A`` translate between domains A and B and ``D_A``/``D_B``
    score each domain. The discriminators use a least-squares (MSE) loss
    against 1 for real and 0 for fake. Fakes are scored through a
    ``BufferN(50)`` buffer of previously generated images (presumably the last
    50 pushes; see BufferN). The generators use the same LSGAN loss plus an L1
    cycle-consistency loss weighted by ``args.cyc_weight``.

    Every ``args.check_iter`` iterations the recorded losses, summarised by
    ``lastest_arverage_value``, are printed, and sample images are written
    under ``args.save_path``. After every epoch, a checkpoint holding the
    weights of all four networks and the loss history is written to
    ``args.save_path + 'check_point.pth'``.

    Args:
        args: namespace providing gpu_no, lr, beta1, beta2, data_A, data_B,
            imsize, cropsize, cencrop, batch_size, num_workers, epoch,
            discriminator_weight, cyc_weight, check_iter and save_path.

    Returns:
        None
    """
    # Honour the requested GPU index; a bare 'cuda' device would always use
    # the current default GPU regardless of args.gpu_no.
    device = torch.device(f'cuda:{args.gpu_no}' if args.gpu_no >= 0 else 'cpu')

    G_A2B = Generator().to(device).train()
    G_B2A = Generator().to(device).train()
    print("Load Generator", G_A2B)

    D_A = Discriminator().to(device).train()
    D_B = Discriminator().to(device).train()
    print("Load Discriminator", D_A)

    betas = (args.beta1, args.beta2)
    optimizer_G_A2B = torch.optim.Adam(G_A2B.parameters(), lr=args.lr, betas=betas)
    optimizer_G_B2A = torch.optim.Adam(G_B2A.parameters(), lr=args.lr, betas=betas)
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=args.lr, betas=betas)
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=args.lr, betas=betas)

    mse_criterion = nn.MSELoss()
    l1_criterion = nn.L1Loss()

    def lsgan_loss(prediction, target_is_real):
        """MSE between ``prediction`` and an all-ones (real) or all-zeros (fake) target."""
        # *_like already allocates on prediction's device, no .to() needed.
        target = (torch.ones_like(prediction) if target_is_real
                  else torch.zeros_like(prediction))
        return mse_criterion(prediction, target)

    dataset = FacadeFolder(args.data_A, args.data_B, args.imsize,
                           args.cropsize, args.cencrop)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    loss_seq = {'G_A2B': [], 'G_B2A': [], 'D_A': [], 'D_B': [], 'Cycle': []}

    buffer_fake_A = BufferN(50)
    buffer_fake_B = BufferN(50)
    # Bound up front so the checkpoint below is valid even when the
    # dataloader yields no batches.
    iteration = 0
    for epoch in range(args.epoch):
        # NOTE: no learning-rate schedule is applied; lr stays at args.lr.
        for iteration, (item_A, item_B) in enumerate(dataloader, 1):
            item_A, item_B = item_A.to(device), item_B.to(device)

            # ---- train discriminators ----
            with torch.no_grad():
                fake_B = G_A2B(item_A)
                fake_A = G_B2A(item_B)

            buffer_fake_B.push(fake_B.detach())
            N_fake_B_discrimination = D_B(
                torch.cat(buffer_fake_B.get_buffer(), dim=0))
            real_B_discrimination = D_B(item_B)

            buffer_fake_A.push(fake_A.detach())
            N_fake_A_discrimination = D_A(
                torch.cat(buffer_fake_A.get_buffer(), dim=0))
            real_A_discrimination = D_A(item_A)

            loss_D_B = (lsgan_loss(N_fake_B_discrimination, False) +
                        lsgan_loss(real_B_discrimination, True)
                        ) * args.discriminator_weight
            loss_seq['D_B'].append(loss_D_B.item())

            loss_D_A = (lsgan_loss(N_fake_A_discrimination, False) +
                        lsgan_loss(real_A_discrimination, True)
                        ) * args.discriminator_weight
            loss_seq['D_A'].append(loss_D_A.item())

            optimizer_D_B.zero_grad()
            loss_D_B.backward()
            optimizer_D_B.step()

            optimizer_D_A.zero_grad()
            loss_D_A.backward()
            optimizer_D_A.step()

            # ---- train generators ----
            fake_B = G_A2B(item_A)
            fake_A = G_B2A(item_B)
            fake_B_discrimination = D_B(fake_B)
            fake_A_discrimination = D_A(fake_A)
            B_from_fake_A = G_A2B(fake_A)
            A_from_fake_B = G_B2A(fake_B)

            # generators try to make the discriminators output "real" (1)
            loss_G_A2B_gan = lsgan_loss(fake_B_discrimination, True)
            loss_G_B2A_gan = lsgan_loss(fake_A_discrimination, True)
            loss_seq['G_A2B'].append(loss_G_A2B_gan.item())
            loss_seq['G_B2A'].append(loss_G_B2A_gan.item())

            # cycle consistency: A -> B -> A and B -> A -> B should reconstruct
            loss_cyc = l1_criterion(B_from_fake_A, item_B) + l1_criterion(
                A_from_fake_B, item_A)
            loss_seq['Cycle'].append(loss_cyc.item())

            loss_G = loss_G_A2B_gan + loss_G_B2A_gan + args.cyc_weight * loss_cyc

            optimizer_G_A2B.zero_grad()
            optimizer_G_B2A.zero_grad()
            loss_G.backward()
            optimizer_G_A2B.step()
            optimizer_G_B2A.step()

            # ---- check training loss ----
            if iteration % args.check_iter == 0:
                check_str = "%s: epoch:[%d/%d]\titeration:[%d/%d]" % (
                    time.ctime(), epoch, args.epoch, iteration,
                    len(dataloader))
                for key, value in loss_seq.items():
                    check_str += "\t%s: %2.2f" % (
                        key, lastest_arverage_value(value))
                print(check_str)
                imsave(torch.cat([item_A, item_B, fake_B], dim=0),
                       args.save_path + 'training_image_A2B.jpg')
                imsave(torch.cat([item_B, item_A, fake_A], dim=0),
                       args.save_path + 'training_image_B2A.jpg')

        # save networks
        torch.save(
            {
                'epoch': epoch,
                'iteration': iteration,
                'G_A2B_state_dict': G_A2B.state_dict(),
                'G_B2A_state_dict': G_B2A.state_dict(),
                'D_A_state_dict': D_A.state_dict(),
                'D_B_state_dict': D_B.state_dict(),
                'loss_seq': loss_seq
            }, args.save_path + 'check_point.pth')

    return None
Exemplo n.º 20
0
# shift models to cuda if possible (Module.cuda() moves parameters in place)
if torch.cuda.is_available():
    G12.cuda()
    G21.cuda()
    D1.cuda()
    D2.cuda()

# %%
# optimizer and loss
LGAN = MSELoss()
LCYC = L1Loss()
LIdentity = L1Loss()

optimizer_G = Adam(itertools.chain(G12.parameters(), G21.parameters()),
                   lr=0.001)
optimizer_D1 = Adam(D1.parameters(), lr=0.001)
optimizer_D2 = Adam(D2.parameters(), lr=0.001)

# %%
# train models
# GAN targets: floating point (MSELoss expects floating targets; an integer
# fill value would give int64 labels) and on the same device as the models,
# so the script also runs on CPU-only machines.
# NOTE(review): size is fixed at 32; a smaller final batch will not match.
label_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
real_label = torch.full((32, ), 1.0, device=label_device)
false_label = torch.full((32, ), 0.0, device=label_device)
bufD1 = Buffer(50)
bufD2 = Buffer(50)

num_epochs = 100
learning_rate = 0.01  # NOTE(review): the optimizers above use lr=0.001
for epoch in range(num_epochs):
    for i, (realA, realB) in enumerate(dataloader):
        if torch.cuda.is_available():
            # Tensor.cuda() returns a copy; it must be rebound to take effect.
            realA = realA.cuda()
            realB = realB.cuda()
Exemplo n.º 21
0
    pass
    net = net.to(device)
    print(net)

# Side length of the spatial noise map: the image size halved nDep times.
NZ = opt.imageSize // 2**nDep

# Discriminator target values for real and generated samples.
real_label, fake_label = 1, 0

# Noise buffer of shape (batchSize, nz, NZ, NZ). torch.FloatTensor(*sizes)
# does not initialise its memory, so it is presumably filled before use
# elsewhere - TODO confirm. Moved to `device` together with the fixed noise
# and the target mosaic.
noise = torch.FloatTensor(opt.batchSize, nz, NZ, NZ).to(device)
fixnoise = fixnoise.to(device)
targetMosaic = targetMosaic.to(device)

# setup optimizer: one Adam for the discriminator and one shared Adam over
# the parameters of every network in Gnets.
optimizerD = optim.Adam(netD.parameters(),
                        lr=opt.lr,
                        betas=(opt.beta1, 0.999))
optimizerU = optim.Adam(
    [p for g_net in Gnets for p in g_net.parameters()],
    lr=opt.lr,
    betas=(opt.beta1, 0.999))


def ganGeneration(content, noise, templates=None, bVis=False):
    """Run the mixing network ``netMix`` on ``(content, noise)``.

    ``templates`` is accepted for call compatibility but is not used. When
    ``bVis`` is truthy, the output is returned as ``(x, 0, 0, 0)``, presumably
    to match a visualisation tuple expected by callers (TODO confirm).
    Otherwise the network output is returned directly.
    """
    generated = netMix(content, noise)
    return (generated, 0, 0, 0) if bVis else generated


buf = []
Exemplo n.º 22
0
def _wgan_gradient_penalty(discriminator, real, fake, index, alpha):
    """Return the unweighted WGAN-GP term E[(||grad_x D(x_hat)||_2 - 1)^2].

    ``x_hat`` is a random per-sample interpolation between ``real`` and
    ``fake``. The (N, 1, 1, 1) mixing weights assume 4-D image batches.
    """
    mix = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    # x_hat must be detached from the generator graph and be a leaf that
    # requires grad, so D's gradient with respect to it can be taken.
    x_hat = (mix * real + (1 - mix) * fake.detach()).requires_grad_(True)
    out = discriminator(x_hat, index, alpha)
    # create_graph=True keeps the gradient differentiable, so the penalty
    # itself can be back-propagated into D's parameters.
    grad = torch.autograd.grad(outputs=out,
                               inputs=x_hat,
                               grad_outputs=torch.ones_like(out),
                               create_graph=True)[0]
    grad_norm = grad.reshape(grad.size(0), -1).norm(2, dim=1)
    return torch.mean((grad_norm - 1)**2)


def train(data_path,
          crop_size=128,
          final_size=64,
          batch_size=16,
          alternating_step=10000,
          ncritic=1,
          lambda_gp=0.1,
          debug_step=100):
    """Progressively train a WGAN-GP generator/critic pair, block by block.

    For each generator block the 'fade' phase (skipped for block 0) ramps the
    blend weight ``alpha`` from 0 towards 1. The 'stabilize' phase then trains
    with ``alpha = 1``. Each phase lasts ``alternating_step`` iterations.

    Args:
        data_path: dataset location passed to ``get_loader``.
        crop_size: crop size passed to ``get_loader``.
        final_size: output resolution for Generator / Discriminator.
        batch_size: loader batch size and number of latents per generator step.
        alternating_step: iterations per phase.
        ncritic: number of critic updates per generator update.
        lambda_gp: weight of the gradient penalty.
        debug_step: write debug images every this many iterations.

    Uses the module-level ``device`` and sets the module-level ``upsample``.
    """
    # define networks; move them to `device` before building the optimizers,
    # since all batches and latents below are created on `device`
    generator = Generator(final_size=final_size)
    generator.generate_network()
    generator.to(device)
    g_optimizer = Adam(generator.parameters())

    discriminator = Discriminator(final_size=final_size)
    discriminator.generate_network()
    discriminator.to(device)
    d_optimizer = Adam(discriminator.parameters())

    num_channels = min(generator.num_channels, generator.max_channels)

    # get debugging vectors
    N = (5, 10)
    debug_vectors = torch.randn(N[0] * N[1], num_channels, 1, 1).to(device)
    global upsample
    upsample = [
        Upsample(scale_factor=2**i)
        for i in reversed(range(generator.num_blocks))
    ]

    # get loader
    loader = get_loader(data_path, crop_size, batch_size)

    # training loop
    for index in range(generator.num_blocks):
        loader.dataset.set_transform_by_index(index)
        data_iterator = iter(loader)
        for phase in ('fade', 'stabilize'):
            if index == 0 and phase == 'fade':
                continue
            print("index: {}, size: {}x{}, phase: {}".format(
                index, 2**(index + 2), 2**(index + 2), phase))
            for i in range(alternating_step):
                try:
                    batch = next(data_iterator)
                except StopIteration:
                    # loader exhausted: start another pass over the data
                    data_iterator = iter(loader)
                    batch = next(data_iterator)

                alpha = i / alternating_step if phase == "fade" else 1.0

                batch = batch.to(device)

                # ---- critic update: WGAN loss + gradient penalty in one step ----
                d_loss_real = -torch.mean(discriminator(batch, index, alpha))

                # Match the real batch size (a final batch may be smaller),
                # otherwise the real/fake interpolation cannot broadcast.
                latent = torch.randn(batch.size(0), num_channels, 1,
                                     1).to(device)
                fake_batch = generator(latent, index, alpha).detach()
                d_loss_fake = torch.mean(
                    discriminator(fake_batch, index, alpha))

                d_loss_gp = lambda_gp * _wgan_gradient_penalty(
                    discriminator, batch, fake_batch, index, alpha)

                d_loss = d_loss_real + d_loss_fake + d_loss_gp
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

                # ---- generator update every `ncritic` iterations ----
                if (i + 1) % ncritic == 0:
                    latent = torch.randn(batch_size, num_channels, 1,
                                         1).to(device)
                    fake_batch = generator(latent, index, alpha)
                    g_loss = -torch.mean(
                        discriminator(fake_batch, index, alpha))
                    g_optimizer.zero_grad()
                    g_loss.backward()
                    g_optimizer.step()

                # print debugging images
                if (i + 1) % debug_step == 0:
                    print_debugging_images(generator, debug_vectors, N, index,
                                           alpha, i)
Exemplo n.º 23
0
        print (e,"weightinit")
    pass
    net=net.to(device)
    print(net)

# Side length of the spatial noise map: the image size halved nDep times.
NZ = opt.imageSize // 2**nDep

# Discriminator target values for real and generated samples.
real_label, fake_label = 1, 0

# Noise buffer of shape (batchSize, nz, NZ, NZ). torch.FloatTensor(*sizes)
# leaves it uninitialised (presumably filled before use elsewhere - TODO
# confirm). Placed on `device` together with the fixed noise.
noise = torch.FloatTensor(opt.batchSize, nz, NZ, NZ).to(device)
fixnoise = fixnoise.to(device)

# setup optimizer: Adam for the discriminator, and one shared Adam over
# every network in Gnets.
optimizerD = optim.Adam(netD.parameters(),
                        lr=opt.lr,
                        betas=(opt.beta1, 0.999))
optimizerU = optim.Adam([p for g_net in Gnets for p in g_net.parameters()],
                        lr=opt.lr,
                        betas=(opt.beta1, 0.999))

def famosGeneration(content, noise, templatePatch, bVis = False):
    """Build an image by mixing template patches and blending with netMix output.

    ``netMix`` produces ``x``. Its last 5 channels (``a5``) are split into 3
    channels (squashed with tanh), an ``alpha`` channel and a ``beta`` channel
    (both squashed with sigmoid). The remaining leading channels are turned
    into softmax weights ``A`` over dim 1, which ``getTemplateMixImage``
    applies to ``templatePatch``. ``blend`` then combines the tanh channels,
    the mixed template image, alpha and beta into ``fake``.

    NOTE(review): as shown, the function computes ``fake`` but has no return
    statement (so it returns None) and never reads ``bVis``. It looks
    truncated; confirm against the full source.
    NOTE(review): nn.functional.tanh / nn.functional.sigmoid may emit
    deprecation warnings in favour of torch.tanh / torch.sigmoid.
    """
    if opt.multiScale >0:
        # multi-scale mode: the mixer also receives the template patches
        x = netMix(content, noise,templatePatch)
    else:
        x = netMix(content,noise)
    # last 5 channels: [0:3] tanh image part, [3] alpha, [4] beta
    a5=x[:,-5:]
    # leading channels: logits for the template mixture, bounded to [-4, 4]
    A =4*nn.functional.tanh(x[:,:-5])##smooths probs somehow
    # Subtracting the (detached, global) max does not change the softmax but
    # keeps exp() numerically stable; softmax over dim=1 gives mixing weights.
    A = nn.functional.softmax(1*(A - A.detach().max()), dim=1)
    mixed = getTemplateMixImage(A, templatePatch)
    alpha = nn.functional.sigmoid(a5[:,3:4])
    beta = nn.functional.sigmoid(a5[:, 4:5])
    fake = blend(nn.functional.tanh(a5[:,:3]),mixed,alpha,beta)
Exemplo n.º 24
0
class Solver(object):
    """Progressive-growing WGAN-GP trainer with checkpoint save / resume.

    All hyper-parameters and directories are read from ``configuration`` in
    ``__init__``. ``train`` grows the generator one block at a time. Each
    block is trained in a 'fade' phase (skipped for block 0) and then a
    'stabilize' phase, each lasting ``alternating_step`` iterations.
    """

    # Order in which the phases of every resolution block are trained.
    _PHASES = ('fade', 'stabilize')

    def __init__(self, configuration):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # retrieve configuration variables
        self.data_path = configuration.data_path
        self.crop_size = configuration.crop_size
        self.final_size = configuration.final_size
        self.batch_size = configuration.batch_size
        self.alternating_step = configuration.alternating_step
        self.ncritic = configuration.ncritic
        self.lambda_gp = configuration.lambda_gp
        self.debug_step = configuration.debug_step
        self.save_step = configuration.save_step
        self.max_checkpoints = configuration.max_checkpoints
        self.log_step = configuration.log_step
        # Scalar logger (e.g. Logger(configuration.log_dir)). It is None until
        # one is attached; train() then skips scalar logging instead of
        # failing with AttributeError.
        self.tflogger = None
        ## directories
        self.train_dir = configuration.train_dir
        self.img_dir = configuration.img_dir
        self.models_dir = configuration.models_dir
        ## variables
        # weight of the drift term added to the critic loss in train()
        self.eps_drift = 0.001

        self.resume_training = configuration.resume_training

        self._initialise_networks()

    def _initialise_networks(self):
        """Build G and D, move them to ``self.device`` and create their optimizers."""
        self.generator = Generator(final_size=self.final_size)
        self.generator.generate_network()
        # Move before creating the optimizer: batches and latents are created
        # on self.device in train(), so the weights must live there too.
        self.generator.to(self.device)
        self.g_optimizer = Adam(self.generator.parameters(), lr=0.001, betas=(0, 0.99))

        self.discriminator = Discriminator(final_size=self.final_size)
        self.discriminator.generate_network()
        self.discriminator.to(self.device)
        self.d_optimizer = Adam(self.discriminator.parameters(), lr=0.001, betas=(0, 0.99))

        self.num_channels = min(self.generator.num_channels,
                                self.generator.max_channels)
        # upsample[index] scales a block-`index` output by 2**(num_blocks-1-index)
        self.upsample = [Upsample(scale_factor=2**i)
                         for i in reversed(range(self.generator.num_blocks))]

    def print_debugging_images(self, generator, latent_vectors, shape, index,
                               alpha, iteration):
        """Save a shape[0] x shape[1] grid of samples to img_dir/debug_<index>_<iteration>.png."""
        with torch.no_grad():
            columns = []
            for i in range(shape[0]):
                row = []
                for j in range(shape[1]):
                    img_ij = generator(
                        latent_vectors[i * shape[1] + j].unsqueeze(0),
                        index, alpha)
                    img_ij = self.upsample[index](img_ij)
                    row.append(img_ij)
                columns.append(torch.cat(row, dim=3))
            debugging_image = torch.cat(columns, dim=2)
        # denorm from [-1, 1] to [0, 1]
        debugging_image = (debugging_image + 1) / 2
        debugging_image.clamp_(0, 1)
        save_image(debugging_image.data,
                   os.path.join(self.img_dir, "debug_{}_{}.png".format(index,
                                                                      iteration)))

    def save_trained_networks(self, block_index, phase, step):
        """Save G and D weights and record them in models_dir/models.json.

        Only the newest ``max_checkpoints`` entries are kept; the files of the
        oldest entry are deleted when the limit is exceeded.
        """
        models_file = os.path.join(self.models_dir, "models.json")
        if os.path.isfile(models_file):
            with open(models_file, 'r') as file:
                models_config = json.load(file)
        else:
            models_config = {"checkpoints": []}

        generator_save_name = "generator_{}_{}_{}.pth".format(
            block_index, phase, step)
        torch.save(self.generator.state_dict(),
                   os.path.join(self.models_dir, generator_save_name))

        discriminator_save_name = "discriminator_{}_{}_{}.pth".format(
            block_index, phase, step)
        torch.save(self.discriminator.state_dict(),
                   os.path.join(self.models_dir, discriminator_save_name))

        models_config["checkpoints"].append(OrderedDict({
            "block_index": block_index,
            "phase": phase,
            "step": step,
            "generator": generator_save_name,
            "discriminator": discriminator_save_name
        }))
        if len(models_config["checkpoints"]) > self.max_checkpoints:
            old_save = models_config["checkpoints"][0]
            os.remove(os.path.join(self.models_dir, old_save["generator"]))
            os.remove(os.path.join(self.models_dir, old_save["discriminator"]))
            models_config["checkpoints"] = models_config["checkpoints"][1:]
        with open(models_file, 'w') as file:
            json.dump(models_config, file, indent=4)

    def load_trained_networks(self):
        """Load the newest checkpoint listed in models.json.

        Returns:
            (block_index, phase, step) of the loaded checkpoint.
        Raises:
            FileNotFoundError: if models_dir has no models.json.
        """
        models_file = os.path.join(self.models_dir, "models.json")
        if os.path.isfile(models_file):
            with open(models_file, 'r') as file:
                models_config = json.load(file)
        else:
            raise FileNotFoundError("File 'models.json' not found in {"
                                    "}".format(self.models_dir))

        last_checkpoint = models_config["checkpoints"][-1]
        block_index = last_checkpoint["block_index"]
        phase = last_checkpoint["phase"]
        step = last_checkpoint["step"]
        generator_save_name = os.path.join(
            self.models_dir, last_checkpoint["generator"])
        discriminator_save_name = os.path.join(
            self.models_dir, last_checkpoint["discriminator"])

        # map_location lets a checkpoint written on GPU be resumed on CPU.
        self.generator.load_state_dict(
            torch.load(generator_save_name, map_location=self.device))
        self.discriminator.load_state_dict(
            torch.load(discriminator_save_name, map_location=self.device))

        return block_index, phase, step

    def _step_range(self, index, phase, start_index, start_phase, start_step):
        """Return the iterations to run for (index, phase), or None to skip it.

        Block 0 has no 'fade' phase. When resuming, phases of the start block
        that precede ``start_phase`` are skipped. The start phase itself
        continues after ``start_step``, which was already completed when the
        checkpoint was saved.
        """
        if index == 0 and phase == 'fade':
            return None
        if self.resume_training and index == start_index:
            if self._PHASES.index(phase) < self._PHASES.index(start_phase):
                return None
            if phase == start_phase:
                return range(start_step + 1, self.alternating_step)
        return range(self.alternating_step)

    def _gradient_penalty(self, real, fake, index, alpha):
        """Return the unweighted WGAN-GP term E[(||grad_x D(x_hat)||_2 - 1)^2]."""
        mix = torch.rand(real.size(0), 1, 1, 1, device=real.device)
        # x_hat must be detached from the generator graph and be a leaf that
        # requires grad, so D's gradient with respect to it can be taken.
        x_hat = (mix * real + (1 - mix) * fake.detach()).requires_grad_(True)
        out = self.discriminator(x_hat, index, alpha)
        # create_graph=True keeps the gradient differentiable so the penalty
        # can be back-propagated into D's parameters.
        grad = torch.autograd.grad(outputs=out,
                                   inputs=x_hat,
                                   grad_outputs=torch.ones_like(out),
                                   create_graph=True)[0]
        grad_norm = grad.reshape(grad.size(0), -1).norm(2, dim=1)
        return torch.mean((grad_norm - 1) ** 2)

    def train(self):
        """Run (or resume) progressive training over all generator blocks.

        Each iteration does one critic update: WGAN loss plus a drift term
        plus the gradient penalty, in a single optimizer step. Every
        ``ncritic`` iterations it also does one generator update. Losses are
        printed every ``log_step`` iterations, debug images are written every
        ``debug_step``, and a checkpoint is saved every ``save_step``.
        """
        # get debugging vectors
        N = (5, 10)
        debug_vectors = torch.randn(N[0] * N[1], self.num_channels, 1,
                                    1).to(self.device)

        # get loader
        loader = get_loader(self.data_path, self.crop_size, self.batch_size)

        # latest losses as Python floats (None until first computed)
        losses = {
            "d_loss_real": None,
            "d_loss_fake": None,
            "g_loss": None
        }

        # resume training if needed
        if self.resume_training:
            start_index, start_phase, start_step = self.load_trained_networks()
        else:
            start_index, start_phase, start_step = (0, "fade", 0)

        # training loop
        start_time = time.time()
        absolute_step = -1
        for index in range(start_index, self.generator.num_blocks):
            loader.dataset.set_transform_by_index(index)
            data_iterator = iter(loader)
            for phase in self._PHASES:
                step_range = self._step_range(index, phase, start_index,
                                              start_phase, start_step)
                if step_range is None:
                    continue
                print("index: {}, size: {}x{}, phase: {}".format(
                    index, 2 ** (index + 2), 2 ** (index + 2), phase))
                for i in step_range:
                    absolute_step += 1
                    try:
                        batch = next(data_iterator)
                    except StopIteration:
                        # loader exhausted: start another pass over the data
                        data_iterator = iter(loader)
                        batch = next(data_iterator)

                    alpha = i / self.alternating_step if phase == "fade" else 1.0

                    batch = batch.to(self.device)

                    # ---- critic update ----
                    d_loss_real = - torch.mean(
                        self.discriminator(batch, index, alpha))

                    latent = torch.randn(
                        batch.size(0), self.num_channels, 1, 1).to(self.device)
                    fake_batch = self.generator(latent, index, alpha).detach()
                    d_loss_fake = torch.mean(
                        self.discriminator(fake_batch, index, alpha))

                    # drift factor
                    # NOTE(review): this squares the batch means; ProGAN's
                    # drift term is the mean of squared critic outputs on
                    # real data - confirm which is intended.
                    drift = d_loss_real.pow(2) + d_loss_fake.pow(2)

                    # WGAN-GP: the penalty belongs to the same critic objective
                    # and the same optimizer step as the adversarial terms.
                    d_loss_gp = self.lambda_gp * self._gradient_penalty(
                        batch, fake_batch, index, alpha)

                    d_loss = (d_loss_real + d_loss_fake +
                              self.eps_drift * drift + d_loss_gp)
                    self.d_optimizer.zero_grad()
                    d_loss.backward()
                    self.d_optimizer.step()
                    losses["d_loss_real"] = d_loss_real.item()
                    losses["d_loss_fake"] = d_loss_fake.item()

                    # ---- generator update ----
                    if (i + 1) % self.ncritic == 0:
                        latent = torch.randn(
                            self.batch_size, self.num_channels, 1, 1).to(self.device)
                        fake_batch = self.generator(latent, index, alpha)
                        g_loss = - torch.mean(self.discriminator(
                                                    fake_batch, index, alpha))
                        self.g_optimizer.zero_grad()
                        g_loss.backward()
                        self.g_optimizer.step()
                        losses["g_loss"] = g_loss.item()

                    # logging
                    if (i + 1) % self.log_step == 0:
                        elapsed = time.time() - start_time
                        elapsed = str(datetime.timedelta(seconds=elapsed))
                        print("{}:{}:{}/{} time {}, d_loss_real {}, "
                              "d_loss_fake {}, "
                              "g_loss {}, alpha {}".format(
                                  index, phase, i, self.alternating_step,
                                  elapsed, losses["d_loss_real"],
                                  losses["d_loss_fake"], losses["g_loss"],
                                  alpha))
                        tflogger = getattr(self, 'tflogger', None)
                        if tflogger is not None:
                            for name, value in losses.items():
                                if value is not None:
                                    tflogger.scalar_summary(name, value,
                                                            absolute_step)

                    # print debugging images
                    if (i + 1) % self.debug_step == 0:
                        self.print_debugging_images(
                            self.generator, debug_vectors, N, index, alpha, i)

                    # save trained networks
                    if (i + 1) % self.save_step == 0:
                        self.save_trained_networks(index, phase, i)
Exemplo n.º 25
0
def _save_latent_traversals(G, epoch, it):
    """Save two 10x10 sample grids of G to output/<epoch>_<it>_<mode>.png.

    Mode 0: row i sets discrete code i and sweeps the first continuous code
    over linspace(-1, 1, 10). Mode 1: cell (i, j) sets discrete codes i and j.
    All other latent entries are zero.
    """
    Z_sample = torch.zeros([1, Z_dim])
    with torch.no_grad():  # sampling only; no autograd graph needed
        for mode in range(2):
            gs = gridspec.GridSpec(10, 10)
            fig = plt.figure(figsize=(10, 10))
            plt.suptitle(f"epoch: {epoch}, iter: {it}\n{get_time()}")
            gs.update(wspace=0.05, hspace=0.05)

            # Running subplot index, kept separate from `mode` so all 100
            # cells are filled and both modes are actually drawn.
            cell = 0
            for i in range(10):
                j_space = np.linspace(-1, 1, 10) if mode == 0 else range(10)
                for j in j_space:
                    C_sample_disc = torch.zeros([1, C_dim_disc])
                    # fresh continuous code per cell so mode 1 does not keep
                    # the last value of the mode-0 sweep
                    C_sample_cont = torch.zeros([1, C_dim_cont])
                    C_sample_disc[0, i] = 1
                    if mode == 0:
                        C_sample_cont[0, 0] = float(j)
                    else:
                        C_sample_disc[0, j] = 1
                    G_sample = G(
                        torch.cat([C_sample_disc, C_sample_cont, Z_sample],
                                  dim=1))[0]

                    ax = plt.subplot(gs[cell])
                    cell += 1
                    plt.axis("off")
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect("equal")
                    plt.imshow(G_sample.reshape(28, 28).numpy(),
                               cmap="Greys",
                               vmin=-1,
                               vmax=1)

            fig.subplots_adjust(top=0.92)
            plt.tight_layout()
            plt.savefig(f"output/{epoch:02d}_{it:04d}_{mode}.png",
                        bbox_inches="tight")
            plt.close(fig)


def main():
    """Train an InfoGAN on the data from ``load_data()`` and save G and D.

    G takes [discrete code, continuous code, noise]. D returns a real/fake
    probability and, with ``cal_Q=True``, its estimates of the codes (the Q
    head). Every 200 iterations losses are printed and latent-traversal grids
    are saved. The final weights go to model/G.pkl and model/D.pkl.
    """
    data_loader = load_data()

    G = Generator(C_dim_disc + C_dim_cont + Z_dim)
    D = Discriminator(C_dim_disc=C_dim_disc, C_dim_cont=C_dim_cont)

    G_optimizer = optim.Adam(G.parameters(), G_lr, [G_beta1, G_beta2])
    # NOTE(review): D is built with G_beta1 rather than a D-specific beta1.
    D_optimizer = optim.Adam(D.parameters(), D_lr, [G_beta1, D_beta2])

    def q_loss(C_disc, C_cont, C_disc_given_x, C_cont_given_x):
        """InfoGAN code-reconstruction loss for the discrete and continuous codes."""
        # The discrete term assumes C_disc_given_x holds log-probabilities
        # (then it is a cross-entropy) - confirm against Discriminator.
        return -torch.mean(torch.sum(C_disc * C_disc_given_x, 1)) + torch.mean(
            (((C_cont - C_cont_given_x) / 0.5)**2))

    for epoch in range(epochs):
        for it, (image, _label) in enumerate(data_loader):
            # Sample data
            X_numpy = np.array(image)
            # NOTE(review): this maps [0, 1] pixels to [-2, 0]; if [-1, 1] was
            # intended (matching vmin/vmax in the plots) use X_numpy * 2 - 1.
            X_numpy = -X_numpy * 2
            X = Variable(torch.from_numpy(X_numpy))

            C_disc, C_cont, Z = (
                sample_multinomial(size=[batch_size], n=C_dim_disc),
                sample_normal([batch_size, C_dim_cont], scale=0.5),
                sample_normal([batch_size, Z_dim], scale=0.5),
            )
            C = torch.cat([C_disc, C_cont], dim=1)

            G.train(True)

            G_sample = G(torch.cat([C, Z], dim=1))

            # Discriminator (+ Q head) update. The fake batch is detached, so
            # this step builds no gradients for G.
            D_real = D(X)
            D_fake, (C_disc_given_x, C_cont_given_x) = D(G_sample.detach(),
                                                         cal_Q=True)
            Q_loss = q_loss(C_disc, C_cont, C_disc_given_x, C_cont_given_x)

            D_loss = -torch.mean(
                torch.log(D_real + 1e-8) +
                torch.log(1 - D_fake + 1e-8)) + Q_loss

            D_optimizer.zero_grad()
            D_loss.backward()
            D_optimizer.step()

            # Generator update. D is re-run after its optimizer step: reusing
            # the pre-step graph would back-propagate through D weights that
            # the optimizer has just modified in place, which autograd rejects.
            D_fake, (C_disc_given_x, C_cont_given_x) = D(G_sample, cal_Q=True)
            G_loss = -torch.mean(torch.log(D_fake + 1e-8)) + q_loss(
                C_disc, C_cont, C_disc_given_x, C_cont_given_x)

            G_optimizer.zero_grad()
            G_loss.backward()
            G_optimizer.step()

            if it % 200 == 0:
                G.train(False)
                print(f"epoch: {epoch}, iter: {it}")
                print(
                    f"D_loss: {D_loss.item():.6f}, G_loss: {G_loss.item():.6f}, Prob diff: {(D_real.mean() - D_fake.mean()).item():.6f}"
                )
                _save_latent_traversals(G, epoch, it)

    torch.save(G.state_dict(), "model/G.pkl")
    torch.save(D.state_dict(), "model/D.pkl")
Exemplo n.º 26
0
criterion = torch.nn.CrossEntropyLoss()
models = []
# StratifiedKFold.split only needs one row per sample from its first
# argument; the real features are indexed from X below.
zzz = np.arange(len(AllLabel) * 2).reshape((len(AllLabel), 2))
X = FinalData
y = np.array(AllLabel)
skf = StratifiedKFold(n_splits=5, shuffle=True,
                      random_state=random.randint(0, 99))
for train_index, test_index in skf.split(zzz, y):
    # NOTE(review): `discriminator` is not re-created inside this loop, so
    # later folds start from weights trained on other folds' data - confirm
    # this is intended.
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    torch_dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X_train), torch.LongTensor(y_train))
    data_loader = torch.utils.data.DataLoader(
        torch_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loss_classifier_list = []
    for epoch in range(1, num_epochs + 1):
        # exponential decay: base_lr * 0.9 ** (epoch / lr_step)
        learning_rate = base_lr * math.pow(0.9, epoch / lr_step)
        gamma_rate = 2 / (1 + math.exp(-10 * (epoch) / num_epochs)) - 1
        # NOTE: rebuilding Adam every epoch applies the new learning rate but
        # also discards its moment estimates.
        optimizer = torch.optim.Adam(
            [{'params': discriminator.parameters()}, ],
            lr=learning_rate, weight_decay=l2_decay)

        discriminator.train()
        num_iter = len(data_loader)
        total_clas_loss = 0
        num_batches = 0
        # One uninterrupted pass over the shuffled loader per epoch. Iterator
        # objects have no .next() method in Python 3 (use next()/for).
        for it, (data, label) in enumerate(data_loader):
            data = Variable(torch.FloatTensor(data))
            label = Variable(torch.LongTensor(label))
            Disc_a = discriminator(data)
Exemplo n.º 27
0
def main(params):
    """Build dataset, networks, optimizers and trainer from ``params``, then train.

    ``params`` is a nested configuration dict. Besides per-component sub-dicts
    (``Generator``, ``Discriminator``, ``Adam``, ``Trainer``, ``DepthManager``,
    ``SaverPlugin``, ``OutputGenerator`` and one entry per name listed in
    ``postprocessors``), it supplies:
    - dataset selection (``load_dataset`` / ``dataset_class`` / ``save_dataset``);
    - output locations (``result_dir``, ``exp_name``, ``checkpoints_dir``);
    - learning-rate settings and WGAN-GP loss settings;
    - the training length ``total_kimg``.

    The dataset is closed once training ends, including when setup or training
    raises.

    Raises:
        ValueError: if neither ``load_dataset`` nor ``dataset_class`` is set.
    """
    if params['load_dataset']:
        dataset = load_pkl(params['load_dataset'])
    elif params['dataset_class']:
        # The dataset class is looked up by name; its kwargs live under the same key.
        dataset = globals()[params['dataset_class']](**params[params['dataset_class']])
        if params['save_dataset']:
            save_pkl(params['save_dataset'], dataset)
    else:
        raise ValueError('One of either load_dataset (path to pkl) or dataset_class needs to be specified.')

    # Everything below may raise; always release the dataset afterwards.
    try:
        result_dir = create_result_subdir(params['result_dir'], params['exp_name'])

        losses = ['G_loss', 'D_loss', 'D_real', 'D_fake']
        stats_to_log = ['tick_stat', 'kimg_stat']
        if params['progressive_growing']:
            stats_to_log.extend(['depth', 'alpha', 'lod', 'minibatch_size'])
        stats_to_log.extend(['time', 'sec.tick', 'sec.kimg'] + losses)
        logger = TeeLogger(os.path.join(result_dir, 'log.txt'), stats_to_log, [(1, 'epoch')])
        logger.log(params_to_str(params))

        if params['resume_network']:
            G, D = load_models(params['resume_network'], params['result_dir'], logger)
        else:
            G = Generator(dataset.shape, **params['Generator'])
            D = Discriminator(dataset.shape, **params['Discriminator'])
        if params['progressive_growing']:
            assert G.max_depth == D.max_depth
        G.cuda()
        D.cuda()
        latent_size = params['Generator']['latent_size']

        # numel() also copes with 0-dim parameters, where reduce() over an empty
        # size() would raise TypeError.
        logger.log(str(G))
        logger.log('Total number of parameters in Generator: {}'.format(
            sum(p.numel() for p in G.parameters())))
        logger.log(str(D))
        logger.log('Total number of parameters in Discriminator: {}'.format(
            sum(p.numel() for p in D.parameters())))

        def get_dataloader(minibatch_size):
            """Return a DataLoader over ``dataset`` driven by InfiniteRandomSampler."""
            return DataLoader(dataset, minibatch_size, sampler=InfiniteRandomSampler(dataset),
                              num_workers=params['num_data_workers'], pin_memory=False, drop_last=True)

        def rl(bs):
            """Return a zero-argument callable that draws ``bs`` random latents."""
            return lambda: random_latents(bs, latent_size)

        # Setting up learning rate and optimizers
        opt_g = Adam(G.parameters(), params['G_lr_max'], **params['Adam'])
        opt_d = Adam(D.parameters(), params['D_lr_max'], **params['Adam'])

        def rampup(cur_nimg):
            """LR multiplier exp(-5 * p**2), p = 1 - cur_nimg / (lr_rampup_kimg * 1000).

            Rises from exp(-5) toward 1 during the ramp-up, then stays at 1.0.
            """
            # NOTE(review): LambdaLR passes its own step counter; this assumes the
            # LRScheduler plugin steps the schedulers by images seen — verify.
            if cur_nimg < params['lr_rampup_kimg'] * 1000:
                p = max(0.0, 1 - cur_nimg / (params['lr_rampup_kimg'] * 1000))
                return np.exp(-p * p * 5.0)
            return 1.0

        lr_scheduler_d = LambdaLR(opt_d, rampup)
        lr_scheduler_g = LambdaLR(opt_g, rampup)

        mb_def = params['minibatch_size']
        D_loss_fun = partial(wgan_gp_D_loss, return_all=True, iwass_lambda=params['iwass_lambda'],
                             iwass_epsilon=params['iwass_epsilon'], iwass_target=params['iwass_target'])
        G_loss_fun = wgan_gp_G_loss
        trainer = Trainer(D, G, D_loss_fun, G_loss_fun,
                          opt_d, opt_g, dataset, iter(get_dataloader(mb_def)), rl(mb_def), **params['Trainer'])

        # plugins
        if params['progressive_growing']:
            max_depth = min(G.max_depth, D.max_depth)
            trainer.register_plugin(DepthManager(get_dataloader, rl, max_depth, **params['DepthManager']))
        for i, loss_name in enumerate(losses):
            trainer.register_plugin(EfficientLossMonitor(i, loss_name))

        checkpoints_dir = params['checkpoints_dir'] or result_dir
        trainer.register_plugin(SaverPlugin(checkpoints_dir, **params['SaverPlugin']))

        def _substitute_samples_path(d):
            """Copy ``d`` with its ``samples_path`` value made relative to result_dir."""
            return {k: (os.path.join(result_dir, v) if k == 'samples_path' else v) for k, v in d.items()}

        postprocessors = [globals()[x](**_substitute_samples_path(params[x])) for x in params['postprocessors']]
        trainer.register_plugin(OutputGenerator(lambda x: random_latents(x, latent_size),
                                                postprocessors, **params['OutputGenerator']))
        trainer.register_plugin(AbsoluteTimeMonitor(params['resume_time']))
        trainer.register_plugin(LRScheduler(lr_scheduler_d, lr_scheduler_g))
        trainer.register_plugin(logger)
        init_comet(params, trainer)
        trainer.run(params['total_kimg'])
    finally:
        dataset.close()
Exemplo n.º 28
0
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set up the GAN.
discriminator_model = Discriminator().to(DEVICE)
generator_model = Generator().to(DEVICE)

# Load pre-trained models if they are provided. map_location remaps tensors saved
# on another device (e.g. a GPU checkpoint on a CPU-only machine) onto DEVICE.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
if args.load_discriminator_model_path:
    discriminator_model.load_state_dict(
        torch.load(args.load_discriminator_model_path, map_location=DEVICE))

if args.load_generator_model_path:
    generator_model.load_state_dict(
        torch.load(args.load_generator_model_path, map_location=DEVICE))

# Set up Adam optimizers (beta1 = 0, beta2 = 0.9) for both models.
discriminator_optimizer = optim.Adam(discriminator_model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0, 0.9))
generator_optimizer = optim.Adam(generator_model.parameters(),
                                 lr=args.learning_rate,
                                 betas=(0, 0.9))

# Fixed batch of 64 latent vectors, reused to visualize the generator's progress.
fixed_latent_space_vectors = torch.randn(64, 512, 1, 1, device=DEVICE)

# Load and preprocess images.
images = load_images(args.data_dir, args.training_set_size)

# Add network architectures for Discriminator and Generator to TensorBoard.
if not args.dry_run:
    WRITER.add_graph(discriminator_model,
Exemplo n.º 29
0
n_classes = 10
init_epoch = 0
max_epoch = 10
batch_size = 128
save_epoch = 1
# Generator input width: latent size plus class conditioning (presumably one-hot — confirm).
in_channels = z_dim + n_classes
out_channels = n_classes + 1  # for fake class
lr_init = 0.001

train_loader = get_dataset(root, batch_size)

gen = Generator(in_channels).to(device)
dis = Discriminator(out_channels).to(device)

g_optimizer = torch.optim.RMSprop(gen.parameters(), lr=lr_init, alpha=0.9)
d_optimizer = torch.optim.RMSprop(dis.parameters(), lr=lr_init, alpha=0.9)

gen_loss = []
dis_loss = []

for index in range(init_epoch, max_epoch):
    # lr reduced by a factor of 10 from epoch index 8 onwards. This is set once
    # per epoch instead of being rewritten on every batch; the lr is identical.
    if index >= 8:
        for g in g_optimizer.param_groups:
            g['lr'] = lr_init / 10
        for g in d_optimizer.param_groups:
            g['lr'] = lr_init / 10
    for i, data in enumerate(train_loader):
        # Define input variables
        x_n = torch.tensor(np.array(data[0], dtype=np.float32), device=device)
        y = torch.tensor(np.array(data[1], dtype=np.float32), device=device)
Exemplo n.º 30
0
# Losses: an MSE-based adversarial loss and an L1 cycle-consistency loss.
gan_loss = torch.nn.MSELoss()
cycle_consistency_loss = torch.nn.L1Loss()

# One Adam optimizer per network (both generators, then both discriminators),
# all sharing the learning rate and beta1 taken from `opt`.
optimizerG_A2B, optimizerG_B2A, optimizerD_A, optimizerD_B = (
    optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    for net in (netG_A2B, netG_B2A, netD_A, netD_B))

# Generator LR schedules. Each gets its own updateLr instance, whose `update`
# method serves as the LambdaLR multiplier.
lr_scheduler_G_A2B, lr_scheduler_G_B2A = (
    optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=updateLr(opt.num_epochs, opt.decay_epoch).update)
    for optimizer in (optimizerG_A2B, optimizerG_B2A))
lr_scheduler_D_A = optim.lr_scheduler.LambdaLR(optimizerD_A,
Exemplo n.º 31
0
def main(params):
    """Build dataset, networks, optimizers and trainer from ``params``, then train.

    ``params`` is a nested configuration dict. Besides per-component sub-dicts
    (``Generator``, ``Discriminator``, ``Adam``, ``Trainer``, ``DepthManager``,
    ``SaverPlugin``, ``OutputGenerator`` and one entry per name listed in
    ``postprocessors``), it supplies:
    - dataset selection (``load_dataset`` / ``dataset_class`` / ``save_dataset``);
    - output locations (``result_dir``, ``exp_name``, ``checkpoints_dir``);
    - learning-rate settings and WGAN-GP loss settings;
    - the training length ``total_kimg``.

    The dataset is closed once training ends, including when setup or training
    raises.

    Raises:
        ValueError: if neither ``load_dataset`` nor ``dataset_class`` is set.
    """
    if params['load_dataset']:
        dataset = load_pkl(params['load_dataset'])
    elif params['dataset_class']:
        # The dataset class is looked up by name; its kwargs live under the
        # same key.
        dataset = globals()[params['dataset_class']](
            **params[params['dataset_class']])
        if params['save_dataset']:
            save_pkl(params['save_dataset'], dataset)
    else:
        raise ValueError(
            'One of either load_dataset (path to pkl) or dataset_class needs to be specified.'
        )

    # Everything below may raise; always release the dataset afterwards.
    try:
        result_dir = create_result_subdir(params['result_dir'],
                                          params['exp_name'])

        losses = ['G_loss', 'D_loss', 'D_real', 'D_fake']
        stats_to_log = ['tick_stat', 'kimg_stat']
        if params['progressive_growing']:
            stats_to_log.extend(['depth', 'alpha', 'lod', 'minibatch_size'])
        stats_to_log.extend(['time', 'sec.tick', 'sec.kimg'] + losses)
        logger = TeeLogger(os.path.join(result_dir, 'log.txt'), stats_to_log,
                           [(1, 'epoch')])
        logger.log(params_to_str(params))

        if params['resume_network']:
            G, D = load_models(params['resume_network'], params['result_dir'],
                               logger)
        else:
            G = Generator(dataset.shape, **params['Generator'])
            D = Discriminator(dataset.shape, **params['Discriminator'])
        if params['progressive_growing']:
            assert G.max_depth == D.max_depth
        G.cuda()
        D.cuda()
        latent_size = params['Generator']['latent_size']

        # numel() also copes with 0-dim parameters, where reduce() over an
        # empty size() would raise TypeError.
        logger.log(str(G))
        logger.log('Total number of parameters in Generator: {}'.format(
            sum(p.numel() for p in G.parameters())))
        logger.log(str(D))
        logger.log('Total number of parameters in Discriminator: {}'.format(
            sum(p.numel() for p in D.parameters())))

        def get_dataloader(minibatch_size):
            """Return a DataLoader over ``dataset`` driven by InfiniteRandomSampler."""
            return DataLoader(dataset,
                              minibatch_size,
                              sampler=InfiniteRandomSampler(dataset),
                              num_workers=params['num_data_workers'],
                              pin_memory=False,
                              drop_last=True)

        def rl(bs):
            """Return a zero-argument callable that draws ``bs`` random latents."""
            return lambda: random_latents(bs, latent_size)

        # Setting up learning rate and optimizers
        opt_g = Adam(G.parameters(), params['G_lr_max'], **params['Adam'])
        opt_d = Adam(D.parameters(), params['D_lr_max'], **params['Adam'])

        def rampup(cur_nimg):
            """LR multiplier exp(-5 * p**2), p = 1 - cur_nimg / (lr_rampup_kimg * 1000).

            Rises from exp(-5) toward 1 during the ramp-up, then stays at 1.0.
            """
            # NOTE(review): LambdaLR passes its own step counter; this assumes
            # the LRScheduler plugin steps the schedulers by images seen — verify.
            if cur_nimg < params['lr_rampup_kimg'] * 1000:
                p = max(0.0, 1 - cur_nimg / (params['lr_rampup_kimg'] * 1000))
                return np.exp(-p * p * 5.0)
            return 1.0

        lr_scheduler_d = LambdaLR(opt_d, rampup)
        lr_scheduler_g = LambdaLR(opt_g, rampup)

        mb_def = params['minibatch_size']
        D_loss_fun = partial(wgan_gp_D_loss,
                             return_all=True,
                             iwass_lambda=params['iwass_lambda'],
                             iwass_epsilon=params['iwass_epsilon'],
                             iwass_target=params['iwass_target'])
        G_loss_fun = wgan_gp_G_loss
        trainer = Trainer(D, G, D_loss_fun, G_loss_fun, opt_d, opt_g, dataset,
                          iter(get_dataloader(mb_def)), rl(mb_def),
                          **params['Trainer'])

        # plugins
        if params['progressive_growing']:
            max_depth = min(G.max_depth, D.max_depth)
            trainer.register_plugin(
                DepthManager(get_dataloader, rl, max_depth,
                             **params['DepthManager']))
        for i, loss_name in enumerate(losses):
            trainer.register_plugin(EfficientLossMonitor(i, loss_name))

        checkpoints_dir = params['checkpoints_dir'] or result_dir
        trainer.register_plugin(
            SaverPlugin(checkpoints_dir, **params['SaverPlugin']))

        def _substitute_samples_path(d):
            """Copy ``d`` with its ``samples_path`` value made relative to result_dir."""
            return {
                k: (os.path.join(result_dir, v) if k == 'samples_path' else v)
                for k, v in d.items()
            }

        postprocessors = [
            globals()[x](**_substitute_samples_path(params[x]))
            for x in params['postprocessors']
        ]
        trainer.register_plugin(
            OutputGenerator(lambda x: random_latents(x, latent_size),
                            postprocessors, **params['OutputGenerator']))
        trainer.register_plugin(AbsoluteTimeMonitor(params['resume_time']))
        trainer.register_plugin(LRScheduler(lr_scheduler_d, lr_scheduler_g))
        trainer.register_plugin(logger)
        init_comet(params, trainer)
        trainer.run(params['total_kimg'])
    finally:
        dataset.close()