Example #1
 def test_glow(self):
     net = Glow(width=12, depth=3, n_levels=3)
     x = torch.randn(args.batch_size, 3, 32, 32)
     zs, logd = net(x)
     recon_x, inv_logd = net.inverse(zs)
     y, _ = net.inverse(batch_size=args.batch_size)
     d_data = (recon_x - x).norm()
     d_data_y = (x - y).norm()
     d_logd = (logd + inv_logd).norm()
     assert d_data < 1e-3, 'Data reconstruction fail - norm of difference = {}.'.format(d_data)
     # assert d_data_y < 1e-3, 'Data reconstruction (inv > base > inv) fail - norm of difference = {}.'.format(d_data_y)
     assert d_logd < 1e-3, 'Log determinant inversion fail - norm of difference = {}'.format(d_logd)
Example #2
class BeautyGlow(nn.Module):

    def __init__(self):
        super().__init__()
        self.w = nn.Linear(128, 128, bias=False)
        self.glow = Glow(3, 32, 4, affine=True, conv_lu=True)

    def forward(self, reference, source, l_x, l_y):
        l_ref = self.glow.reverse(reference)
        l_source = self.glow.reverse(source)
        f_ref = self.w(l_ref)
        f_source = self.w(l_source)
        m_ref = F.linear(l_ref, torch.eye(128) - self.w.weight)
        m_source = F.linear(l_source, torch.eye(128) - self.w.weight)
        l_source_y = m_ref + f_source
        result = self.glow(l_source)

        perceptual_loss = F.mse_loss(f_ref, l_source)

        makeup_loss = F.mse_loss(m_ref, l_y - l_x)

        intra_domain_loss = F.mse_loss(f_ref, l_x) + F.mse_loss(l_source, l_y)

        l2_norm_f = F.mse_loss(f_ref, torch.zeros_like(f_ref)) * \
            F.mse_loss(l_y, torch.zeros_like(l_y))
        sim_f = torch.sum(f_ref * l_y) / l2_norm_f
        l2_norm_l = F.mse_loss(l_source, torch.zeros_like(l_source)) * \
            F.mse_loss(l_x, torch.zeros_like(l_x))
        sim_l = torch.sum(l_source * l_x) / l2_norm_l
        inter_domain_loss = 1 + sim_f + 1 + sim_l

        cycle_f = F.mse_loss(self.w(l_source_y), f_source)
        cycle_m = F.mse_loss(F.linear(l_source_y, torch.eye(128) - self.w.weight), m_ref)
        cycle_consistency_loss = cycle_f + cycle_m

        perceptual = 0.01
        cycle = 0.001
        makeup = 0.1
        intra = 0.1
        inter = 1000

        loss = perceptual * perceptual_loss + cycle * cycle_consistency_loss + makeup * makeup_loss\
            + intra * intra_domain_loss + inter * inter_domain_loss

        return result, loss
Example #3
 def __init__(self):
     super().__init__()
     self.w = nn.Linear(128, 128, bias=False)
     self.glow = Glow(3, 32, 4, affine=True, conv_lu=True)
Example #4
 def test_glow_3_3(self):
     model = Glow(width=24, depth=3, n_levels=3)
     self._train(model, 3)
Example #5
 def test_glow_1_1(self):
     model = Glow(width=12, depth=1, n_levels=1)
     self._train(model, 3)
Example #6
 def test_glow_depth_2_levels_2(self):
     # 1. sample data; 2. run model forward and reverse; 3. reconstruct data;
     # 4. measure KL between a Gaussian fitted to the data and the base distribution
     self.test_kl(Glow(width=12, depth=2, n_levels=2))
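The test above delegates to a test_kl helper that is not shown in this listing. As a rough orientation, the KL step its comment describes (fitting a Gaussian to the latents produced by the forward pass and comparing it with the standard-normal base distribution) can be sketched as follows; the helper name kl_to_standard_normal and its calling convention are assumptions made for illustration, not part of the original test suite.

import torch

def kl_to_standard_normal(z):
    # Fit a single Gaussian N(mu, sigma^2) to the flattened latents and return the
    # closed-form KL(N(mu, sigma^2) || N(0, 1)) against the standard-normal base.
    # Sketch only -- the real test_kl in the original suite may differ.
    z = z.flatten()
    mu, sigma = z.mean(), z.std()
    return torch.log(1.0 / sigma) + (sigma ** 2 + mu ** 2) / 2.0 - 0.5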
Example #7
    def train(self, params):

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        batch_size = params["batch_size"]
        learning_rate = params["learning_rate"]
        max_epoch = params["max_epoch"]
        interval = params["interval"]
        image_shape = params["image_shape"]
        dataset_name = params["dataset_name"]

        max_grad_clip = params["max_grad_clip"]
        max_grad_norm = params["max_grad_norm"]

        train_dataset = util.ImageDataset(params)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      num_workers=4,
                                      shuffle=True,
                                      drop_last=True)

        dt_now = datetime.now()
        dt_seq = dt_now.strftime("%y%m%d_%H%M")
        result_dir = os.path.join("./result", f"{dt_seq}_{dataset_name}")
        weight_dir = os.path.join(result_dir, "weights")
        sample_dir = os.path.join(result_dir, "sample")
        os.makedirs(result_dir, exist_ok=True)
        os.makedirs(weight_dir, exist_ok=True)
        os.makedirs(sample_dir, exist_ok=True)

        glow = Glow(params).to(device)

        optimizer = Adam(glow.parameters(), lr=learning_rate)

        initialized = False
        for epoch in range(max_epoch):
            for i, batch in enumerate(train_dataloader):
                batch = batch.to(device)
                if not initialized:
                    glow.initialize_actnorm(batch)
                    initialized = True
                z, nll = glow.inference(batch)

                loss_generative = torch.mean(nll)

                optimizer.zero_grad()
                loss_generative.backward()
                torch.nn.utils.clip_grad_value_(glow.parameters(),
                                                max_grad_clip)
                torch.nn.utils.clip_grad_norm_(glow.parameters(),
                                               max_grad_norm)
                optimizer.step()

                print(
                    f"epoch {epoch} {i}/{len(train_dataloader)}, loss: {loss_generative.item():.4f}"
                )

            if epoch % interval == 0:
                torch.save(glow.state_dict(), f"{weight_dir}/{epoch}_glow.pth")
                torch.save(optimizer.state_dict(),
                           f"{weight_dir}/{epoch}_opt.pth")
                filename = f"{epoch}_glow.png"
                with torch.no_grad():
                    img = glow.generate(z, eps_std=0.5)
                    util.save_samples(img,
                                      sample_dir,
                                      filename,
                                      image_shape,
                                      num_tiles=4)
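The train() method above only reads a fixed set of keys from params. A minimal sketch of such a dictionary is shown below; the concrete values are placeholders, the (channels, height, width) layout of image_shape is an assumption, and util.ImageDataset as well as the Glow constructor may require additional keys that are not visible in this snippet.

params = {
    "batch_size": 32,
    "learning_rate": 1e-4,
    "max_epoch": 100,
    "interval": 10,              # save weights and samples every 10 epochs
    "image_shape": (3, 64, 64),  # assumed (channels, height, width) layout
    "dataset_name": "celeba",    # placeholder name
    "max_grad_clip": 5.0,
    "max_grad_norm": 100.0,
}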
Example #8
def train(cfg):
    date_today = date.today().strftime("%b-%d-%Y")
    summary_writer = SummaryWriter(cfg.log_dir,
                                   flush_secs=5,
                                   filename_suffix=date_today)
    train_data = mx.gluon.data.vision.MNIST(
        train=True).transform_first(data_xform)
    train_loader = mx.gluon.data.DataLoader(train_data,
                                            shuffle=True,
                                            batch_size=cfg.batch_size)
    image_shape = train_data[0][0].shape

    # No initialization. Custom blocks encapsulate initialization and setting of data.
    net = Glow(image_shape, cfg.K, cfg.L, cfg.affine, cfg.filter_size,
               cfg.temp, cfg.n_bits)
    ctx = get_context(cfg.use_gpu)
    net = set_context(net, ctx)

    trainer = mx.gluon.Trainer(net.collect_params(), 'adam',
                               {'learning_rate': cfg.lr})
    n_samples = len(train_loader)
    update_interval = n_samples // 2  # store the loss with summary writer twice
    loss_buffer = LossBuffer()
    global_step = 1

    for epoch in range(1, cfg.n_epochs + 1):
        for idx, (batch, label) in enumerate(train_loader):
            print(f'Epoch {epoch} - Batch {idx}/{n_samples}', end='\r')

            data = mx.gluon.utils.split_and_load(batch, ctx)
            with mx.autograd.record():
                for X in data:
                    z_list, nll, bpd = net(X)
                    prev_loss = loss_buffer.new_loss(bpd.mean())

            loss_buffer.loss.backward()
            trainer.step(1)

            if prev_loss is not None and global_step % update_interval == 0:
                loss = prev_loss.asscalar()
                summary_writer.add_scalar(tag='bpd',
                                          value=loss,
                                          global_step=global_step)

            global_step += 1

        # Sample from latent space to generate random digit and reverse from latent
        if (epoch % cfg.plot_interval) == 0:
            x_generate = net.reverse()[0]
            x_generate = x_generate.reshape(1, *x_generate.shape)
            x_recon = net.reverse(z_list[-1])[0]
            x_recon = x_recon.reshape(1, *x_recon.shape)
            x_real = data[0][0].reshape(1, *data[0][0].shape)
            minim = -0.5
            maxim = 0.5
            x_generate = x_generate.clip(minim, maxim)
            x_generate += -minim
            x_recon = x_recon.clip(minim, maxim)
            x_recon += -minim
            x_real += -minim

            img = mx.nd.concatenate([x_real, x_generate, x_recon],
                                    axis=0).asnumpy()
            summary_writer.add_image(tag='generations',
                                     image=img,
                                     global_step=global_step)

    summary_writer.close()
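For orientation, cfg in this MXNet example is only accessed by attribute, so any object exposing the fields the function reads will do. The sketch below is a stand-in with illustrative placeholder values; the meaning of K and L is assumed to follow the usual Glow convention (flow steps per level and number of multi-scale levels), which this snippet does not spell out.

from types import SimpleNamespace

cfg = SimpleNamespace(
    log_dir="./logs",   # SummaryWriter output directory
    batch_size=64,
    K=16,               # presumably flow steps per level
    L=2,                # presumably number of multi-scale levels
    affine=True,
    filter_size=512,
    temp=1.0,
    n_bits=8,
    use_gpu=False,
    lr=1e-3,
    n_epochs=10,
    plot_interval=1,
)
train(cfg)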
Example #9
opt = parser.parse_args()
print(opt)

if torch.cuda.is_available() and not opt.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

###### Definition of variables ######
# Networks
if opt.generator == "baseline":
    generator = CycleConsistentGenerator(opt.input_nc, opt.output_nc)
    generator.apply(weights_init_normal)
elif opt.generator == "glow":
    generator = Glow(16, opt.input_nc, 256, squeeze=4)

netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

if opt.cuda:
    generator.cuda()
    netD_A.cuda()
    netD_B.cuda()

netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()