Beispiel #1
0
 def part2(self):
     """Return how many distinct coordinates two alternating walkers visit.

     The move list is split by position: santa takes the even-indexed moves
     and robo takes the odd-indexed ones.  Each walker's moves are run through
     ``util.accumulate`` with ``self._add_coords``, and the coordinates both
     produce are pooled into one set whose size is returned.
     """
     visited = set()
     # offset 0 -> santa (even indices), offset 1 -> robo (odd indices)
     for offset in (0, 1):
         walker_moves = self._moves[offset::2]
         visited.update(util.accumulate(walker_moves, self._add_coords))
     return len(visited)
Beispiel #2
0
    def print_solution(self, network, solution):
        total_cost = network.calculate_cost()
        adjustment = 0
        avbs = network.find_nodes_satisfying(lambda x: x.key()[0] == 'avb' or x.key()[0] == 'rvbp')
        for n in avbs:
            for a in n.arcs():
                if a.cost() < 0:
                    adjustment -= a.cost() * a.total_flow()

        za = VbMapProblem.find_flow_between_nodes_in_network(network,
                                                             'prev_avb', 'avb',
                                                             lambda n1, n2: n1.key()[2])
        for color, zac in za.items():
            print "active moves for color:", color
            print twod_array_to_string(array=zac, with_indices=True, delimiter='\t', total=True)

        zr = VbMapProblem.find_flow_between_nodes_in_network(network,
                                                             'rvb', 'prev_rvb',
                                                             lambda n1, n2: n2.key()[2])
        for color, zrc in zr.items():
            print "replica moves for color:", color
            print twod_array_to_string(array=zrc, with_indices=True, delimiter='\t', total=True)

        za_total = util.accumulate(za.viewvalues(), util.add_to)
        print "total active vbucket moves:"
        print twod_array_to_string(array=za_total, with_indices=True, delimiter='\t', total=True)
        zr_total = util.accumulate(zr.viewvalues(), util.add_to)
        print "total replica vbucket moves:"
        print twod_array_to_string(array=zr_total, with_indices=True, delimiter='\t', total=True)

        x = [[0 for _ in range(self.node_count)] for _ in range(self.node_count)]
        for p, f in solution['flows']:
            from_node = p[0].to_node().key()[1]
            to_node = p[1].to_node().key()[1]
            x[from_node][to_node] += f
        print "solution result:", solution['result']
        print "total cost:", total_cost + adjustment
        print twod_array_to_string(array=x, with_indices=True, delimiter='\t')
        for p, f in solution['flows']:
            print p.sum_costs(), ",", f, ": ", p
Beispiel #3
0
    def test_accumulate_matrices(self):
        """Accumulating two matrices with util.add_to yields a 3x4 matrix of 3s.

        A 3x4 matrix of ones and a 5x4 matrix of twos are combined through
        ``util.accumulate`` using ``util.add_to``; the result must have three
        rows of four cells, each equal to 3.
        """
        ones = [[1 for _ in range(4)] for _ in range(3)]
        # NOTE(review): this operand has 5 rows while the first has 3, yet the
        # assertions expect the 3x4 shape of the first operand -- confirm that
        # util.add_to is meant to ignore the surplus rows.
        twos = [[2 for _ in range(4)] for _ in range(5)]

        # The previous version also defined a local ``add_to`` wrapper that was
        # never passed to util.accumulate; that dead helper has been removed.
        result = util.accumulate([ones, twos], util.add_to)

        self.assertEqual(len(result), 3)
        for row in result:
            self.assertEqual(len(row), 4)
            for cell in row:
                self.assertEqual(cell, 3)
Beispiel #4
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    """Run the adversarial fine-tuning loop; log and checkpoint on rank 0.

    Each iteration performs, in this order:

    1. a discriminator step on ``d_logistic_loss`` (with optional
       augmentation of both real and fake images),
    2. every ``args.d_reg_every`` iterations, an extra discriminator step on
       ``d_r1_loss``,
    3. a generator step on ``g_nonsaturating_loss``,
    4. every ``args.g_reg_every`` iterations, an extra generator step on the
       loss returned by ``g_path_regularize``,
    5. ``accumulate(g_ema, g_module, accum)``,
    6. loss reduction via ``reduce_loss_dict`` and, on rank 0 only, progress
       bar / wandb logging, periodic sample images and ``g_ema`` checkpoints.

    Args:
        args: Option namespace. Fields read here include ``iter``,
            ``start_iter``, ``batch``, ``latent``, ``mixing``, ``n_sample``,
            ``augment``, ``augment_p``, ``ada_target``, ``ada_length``,
            ``d_reg_every``, ``g_reg_every``, ``r1``, ``path_regularize``,
            ``path_batch_shrink``, ``distributed``, ``wandb``, ``save_every``,
            ``style`` and ``model_path``.
        loader: Data source; wrapped by ``sample_data`` and then consumed with
            ``next()`` once per iteration.
        generator: Generator network; ``.module`` is used when
            ``args.distributed`` is set.
        discriminator: Discriminator network; ``.module`` is used when
            ``args.distributed`` is set.
        g_optim: Optimizer stepped after each generator backward pass.
        d_optim: Optimizer stepped after each discriminator backward pass.
        g_ema: Generator copy updated through ``accumulate``; it renders the
            sample images and is the only state written to checkpoints.
        device: Device on which tensors created here are allocated.

    Returns:
        None.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)

    # Only rank 0 gets a tqdm progress bar; other ranks iterate the plain range.
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    ncols=140,
                    dynamic_ncols=False,
                    smoothing=0.01)

    mean_path_length = 0

    # Defaults so the logging code below has values on iterations where the
    # corresponding regularization step is skipped.
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    # d_module is only referenced by the commented-out checkpoint entries at
    # the end of this function.
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Decay factor handed to accumulate() on every iteration.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    # augment_p == 0 with augmentation enabled selects the adaptive schedule:
    # ada_aug_p is then re-tuned from real_pred after every discriminator step.
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    # Fixed latents reused for every saved sample image.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        # NOTE(review): with start_iter > 0 the strict '>' still runs the
        # iteration where i == args.iter -- confirm this is intended.
        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        # ---- periodic discriminator regularization ----
        if d_regularize:
            # d_r1_loss receives real_img itself, so it must track gradients.
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # The "0 * real_pred[0]" term is numerically zero but keeps
            # real_pred in the graph that backward() walks (presumably for
            # distributed training -- TODO confirm).
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        # On non-regularized iterations this re-logs the last computed value.
        loss_dict["r1"] = r1_loss

        # ---- generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        # ---- periodic generator path regularization ----
        if g_regularize:
            # Uses a reduced batch, but never fewer than one sample.
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            # Zero-valued term that ties fake_img into the backward graph.
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Fold the current generator weights into g_ema using decay `accum`.
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        # ---- rank-0 logging, sample images and checkpoints ----
        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:05d}; d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                # NOTE(review): "Mean Path Length" logs the raw
                # mean_path_length rather than the reduced
                # mean_path_length_avg shown in the progress bar -- confirm.
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    sample = F.interpolate(sample, 256)
                    # NOTE(review): if `utils` is torchvision.utils, newer
                    # releases renamed `range=` to `value_range=` -- confirm
                    # against the installed version.
                    utils.save_image(
                        sample,
                        f"log/%s/finetune-%06d.jpg" % (args.style, i),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if (i + 1) % args.save_every == 0 or (i + 1) == args.iter:
                # Only the g_ema weights are persisted; the other entries are
                # intentionally left disabled.
                # NOTE(review): the checkpoint prefix is spelled "fintune"
                # while the image prefix above is "finetune" -- check which
                # name downstream loaders expect before renaming.
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                        #"args": args,
                        #"ada_aug_p": ada_aug_p,
                    },
                    f"%s/%s/fintune-%06d.pt" %
                    (args.model_path, args.style, i + 1),
                )
Beispiel #5
0
    #elif args.arch == 'swagan':
    #from swagan import Generator, Discriminator

    # Build the trainable generator/discriminator pair on the target device.
    generator = Generator(
        args.size,
        args.latent,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier).to(device)
    # g_ema is a second Generator built with the same constructor arguments and
    # kept in eval mode.
    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    # Decay 0: presumably makes g_ema start as an exact copy of generator --
    # TODO confirm against accumulate()'s definition.
    accumulate(g_ema, generator, 0)

    # Ratios reg_every / (reg_every + 1) rescale each optimizer's learning rate
    # and betas below.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    # 0**ratio is 0 for any positive ratio, so the first beta is 0.
    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )
Beispiel #6
0
    # Build the problem from the command-line options and the previous state,
    # then generate its replica networks before solving.
    problem = vbmap.VbMapProblem(args.n, args.r, args.s, args.working, prev)
    problem.set_use_existing_solution(args.existing)
    problem.generate_replica_networks()

    # The 'custom' solver only runs the min-cost-flow solve; every other
    # solver value generates the colored vbmap and prints the full report.
    if args.solver == 'custom':
        problem.solve_min_cost_flow()
    else:
        problem.generate_vbmap_with_colors()
        problem.print_result()

        print "active moves: ", problem.get_total_active_vbucket_moves()
        print "replica moves: ", problem.get_total_replica_vbucket_moves()

        print "color count:", problem.previous.color_count
        plan = problem.make_plan()
        for p in plan:
            print p

        # Each section below collapses a collection of matrices with
        # util.accumulate / util.add_to and prints the combined matrix.
        print "active vbuckets"
        avb = util.accumulate(problem.get_active_vbucket_moves(), util.add_to)
        print vbmap.twod_array_to_string(array=avb, with_indices=True, delimiter='\t')

        # NOTE(review): the two calls below pass (matrix, True, '', '\t')
        # positionally while the call above uses keywords -- confirm that the
        # third and fourth positions map to the intended parameters.
        print "flows"
        x = problem.get_colored_replication_map()
        x_agg = util.accumulate(x, util.add_to)
        print vbmap.twod_array_to_string(x_agg, True, '', '\t')

        print "replica vbuckets"
        rvb = util.accumulate(problem.get_replica_vbucket_moves(), util.add_to)
        print vbmap.twod_array_to_string(rvb, True, '', '\t')
Beispiel #7
0
 def test_basic_collapse(self):
     l = [1, 2, 3]
     result = util.accumulate(l, lambda x, y: x+y, 0)
     print "sum:", result
     self.assertEqual(6, result)
Beispiel #8
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          instyles, Simgs, exstyles, vggloss, id_loss, device):
    """Run the dual-style adversarial training loop; checkpoint on rank 0.

    ``which = i % args.subspace_freq`` selects the data used by an iteration:

    * ``which == 0``: the extrinsic style comes from ``get_paired_data`` and
      the intrinsic style from ``mixing_noise`` (``z_plus_latent=False``).
    * otherwise: extrinsic style, intrinsic style and the real image all come
      from ``get_paired_data`` (``z_plus_latent=True``), replacing the image
      drawn from ``loader``.

    Each iteration performs, in this order: a discriminator step on
    ``d_logistic_loss``; every ``args.d_reg_every`` iterations a step on
    ``d_r1_loss``; a generator step on the sum of the adversarial, perceptual
    (``gr``), style (``sty``), L2 (``l2``) and identity (``id``) losses; every
    ``args.g_reg_every`` iterations a step on the ``g_path_regularize`` loss;
    ``accumulate(g_ema.res, g_module.res, accum)``; and finally loss reduction
    plus rank-0 progress output, sample images and ``g_ema`` checkpoints.

    Args:
        args: Option namespace. Besides the iteration, batch, augmentation and
            regularization settings, this reads the loss weights ``CX_loss``,
            ``style_loss``, ``id_loss``, ``perc_loss`` and ``L2_reg_loss``, as
            well as ``subspace_freq``, ``save_begin``, ``save_every``,
            ``style``, ``model_path`` and ``model_name``.
        loader: Data source; wrapped by ``sample_data`` and consumed with
            ``next()`` once per iteration.
        generator: Generator called as ``generator(instyle, exstyle, ...)``;
            ``.module`` is used when ``args.distributed`` is set.
        discriminator: Discriminator network.
        g_optim: Optimizer stepped after each generator backward pass.
        d_optim: Optimizer stepped after each discriminator backward pass.
        g_ema: Generator copy whose ``res`` submodule is updated through
            ``accumulate``; it renders sample images and is the only state
            written to checkpoints.
        instyles: Forwarded unchanged to ``get_paired_data``.
        Simgs: Forwarded unchanged to ``get_paired_data``.
        exstyles: Forwarded unchanged to ``get_paired_data``.
        vggloss: Callable returning an indexable sequence of feature tensors
            for an image batch.
        id_loss: Callable comparing the generated image with the content
            image; its result is scaled by ``args.id_loss``.
        device: Device on which tensors created here are allocated.

    Returns:
        None.
    """
    loader = sample_data(loader)
    # Per-feature weights for the perceptual loss; with these values only
    # feature indices 1 and 2 contribute.
    vgg_weights = [0.0, 0.5, 1.0, 0.0, 0.0]
    pbar = range(args.iter)

    # Only rank 0 gets a tqdm progress bar; other ranks iterate the plain range.
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    smoothing=0.01,
                    ncols=180,
                    dynamic_ncols=False)

    mean_path_length = 0

    # Defaults so the logging code below has values on iterations where the
    # corresponding regularization step is skipped.
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    # d_module is only referenced by the commented-out checkpoint entries at
    # the end of this function.
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Decay factor handed to accumulate() on every iteration.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    # augment_p == 0 with augmentation enabled selects the adaptive schedule:
    # ada_aug_p is then re-tuned from real_pred after every discriminator step.
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    # Fixed intrinsic/extrinsic styles reused for every saved sample image.
    sample_instyle = torch.randn(args.n_sample, args.latent, device=device)
    sample_exstyle, _, _ = get_paired_data(instyles,
                                           Simgs,
                                           exstyles,
                                           batch_size=args.n_sample,
                                           random_ind=8)
    sample_exstyle = sample_exstyle.to(device)

    for idx in pbar:
        i = idx + args.start_iter

        which = i % args.subspace_freq  # defines whether we use paired data

        # NOTE(review): with start_iter > 0 the strict '>' still runs the
        # iteration where i == args.iter -- confirm this is intended.
        if i > args.iter:
            print("Done!")
            break

        # sample S
        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        if which == 0:
            # sample z^+_e, z for Lsty, Lcon and Ladv
            exstyle, _, _ = get_paired_data(instyles,
                                            Simgs,
                                            exstyles,
                                            batch_size=args.batch,
                                            random_ind=8)
            exstyle = exstyle.to(device)
            instyle = mixing_noise(args.batch, args.latent, args.mixing,
                                   device)
            z_plus_latent = False
        else:
            # sample z^+_e, z^+_i and S for Eq. (4)
            # The paired real image replaces the one drawn from the loader.
            exstyle, instyle, real_img = get_paired_data(instyles,
                                                         Simgs,
                                                         exstyles,
                                                         batch_size=args.batch,
                                                         random_ind=8)
            exstyle = exstyle.to(device)
            instyle = [instyle.to(device)]
            real_img = real_img.to(device)
            z_plus_latent = True

        fake_img, _ = generator(instyle,
                                exstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss  # Ladv
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        # ---- periodic discriminator regularization ----
        if d_regularize:
            # d_r1_loss receives real_img itself, so it must track gradients.
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # The "0 * real_pred[0]" term is numerically zero but keeps
            # real_pred in the graph that backward() walks (presumably for
            # distributed training -- TODO confirm).
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        # On non-regularized iterations this re-logs the last computed value.
        loss_dict["r1"] = r1_loss

        # ---- generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        # Fresh samples are drawn for the generator step; unlike the
        # discriminator step, the which == 0 branch here also takes the paired
        # real image.
        if which == 0:
            # sample z^+_e, z for Lsty, Lcon and Ladv
            exstyle, _, real_img = get_paired_data(instyles,
                                                   Simgs,
                                                   exstyles,
                                                   batch_size=args.batch,
                                                   random_ind=8)
            real_img = real_img.to(device)
            exstyle = exstyle.to(device)
            instyle = mixing_noise(args.batch, args.latent, args.mixing,
                                   device)
            z_plus_latent = False
        else:
            # sample z^+_e, z^+_i and S for Eq. (4)
            exstyle, instyle, real_img = get_paired_data(instyles,
                                                         Simgs,
                                                         exstyles,
                                                         batch_size=args.batch,
                                                         random_ind=8)
            exstyle = exstyle.to(device)
            instyle = [instyle.to(device)]
            real_img = real_img.to(device)
            z_plus_latent = True

        fake_img, _ = generator(instyle,
                                exstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        # Targets for the style/content losses are computed without gradients.
        with torch.no_grad():
            real_img_256 = F.adaptive_avg_pool2d(real_img, 256).detach()
            real_feats = vggloss(real_img_256)
            real_styles = [
                F.adaptive_avg_pool2d(real_feat, output_size=1).detach()
                for real_feat in real_feats
            ]
            # Content reference: the same intrinsic style rendered with no
            # extrinsic style and use_res=False.
            real_content, _ = generator(instyle,
                                        None,
                                        use_res=False,
                                        z_plus_latent=z_plus_latent)
            real_content_256 = F.adaptive_avg_pool2d(real_content,
                                                     256).detach()

        fake_img_256 = F.adaptive_avg_pool2d(fake_img, 256)
        fake_feats = vggloss(fake_img_256)
        fake_styles = [
            F.adaptive_avg_pool2d(fake_feat, output_size=1)
            for fake_feat in fake_feats
        ]
        # Style loss: contextual loss on feature index 2 (skipped when
        # args.CX_loss is 0) plus, when args.style_loss > 0, the MSE between
        # pooled features at indices 1 and 2.
        sty_loss = (torch.tensor(0.0).to(device) if args.CX_loss == 0 else
                    FCX.contextual_loss(fake_feats[2],
                                        real_feats[2].detach(),
                                        band_width=0.2,
                                        loss_type='cosine') * args.CX_loss)
        if args.style_loss > 0:
            sty_loss += ((F.mse_loss(fake_styles[1], real_styles[1]) +
                          F.mse_loss(fake_styles[2], real_styles[2])) *
                         args.style_loss)

        ID_loss = (torch.tensor(0.0).to(device) if args.id_loss == 0 else
                   id_loss(fake_img_256, real_content_256) * args.id_loss)

        # The perceptual loss is applied only on paired-data iterations.
        gr_loss = torch.tensor(0.0).to(device)
        if which > 0:
            for ii, weight in enumerate(vgg_weights):
                if weight * args.perc_loss > 0:
                    gr_loss += F.l1_loss(
                        fake_feats[ii],
                        real_feats[ii].detach()) * weight * args.perc_loss

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        # Penalizes the norms of the parameters of the generator's `res`
        # submodule.
        l2_reg_loss = sum(
            torch.norm(p)
            for p in g_module.res.parameters()) * args.L2_reg_loss

        # "g" is recorded before the other terms are added, so the logged value
        # is the adversarial part only.
        loss_dict["g"] = g_loss  # Ladv
        loss_dict["gr"] = gr_loss  # Lperc
        loss_dict["l2"] = l2_reg_loss  # Lreg in Lcon
        loss_dict["id"] = ID_loss  # LID in Lcon
        loss_dict["sty"] = sty_loss  # Lsty
        g_loss = g_loss + gr_loss + sty_loss + l2_reg_loss + ID_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        # ---- periodic generator path regularization ----
        if g_regularize:
            # Uses a reduced batch, but never fewer than one sample.
            path_batch_size = max(1, args.batch // args.path_batch_shrink)

            instyle = mixing_noise(path_batch_size, args.latent, args.mixing,
                                   device)
            exstyle, _, _ = get_paired_data(instyles,
                                            Simgs,
                                            exstyles,
                                            batch_size=path_batch_size,
                                            random_ind=8)
            exstyle = exstyle.to(device)

            fake_img, latents = generator(instyle,
                                          exstyle,
                                          return_latents=True,
                                          use_res=True,
                                          z_plus_latent=False)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            # Zero-valued term that ties fake_img into the backward graph.
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Only the `res` submodule is folded into g_ema.
        accumulate(g_ema.res, g_module.res, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        sty_loss_val = loss_reduced["sty"].mean().item()
        l2_loss_val = loss_reduced["l2"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        id_loss_val = loss_reduced["id"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        # ---- rank-0 progress output, sample images and checkpoints ----
        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:d}; d: {d_loss_val:.3f}; g: {g_loss_val:.3f}; gr: {gr_loss_val:.3f}; sty: {sty_loss_val:.3f}; l2: {l2_loss_val:.3f}; id: {id_loss_val:.3f}; "
                f"r1: {r1_val:.3f}; path: {path_loss_val:.3f}; mean path: {mean_path_length_avg:.3f}; "
                f"augment: {ada_aug_p:.4f};"))

            if i % 100 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_instyle],
                                      sample_exstyle,
                                      use_res=True)
                    sample = F.interpolate(sample, 256)
                    # NOTE(review): if `utils` is torchvision.utils, newer
                    # releases renamed `range=` to `value_range=` -- confirm
                    # against the installed version.
                    utils.save_image(
                        sample,
                        f"log/%s/dualstylegan-%06d.jpg" % (args.style, i),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            # Periodic checkpoints start once i + 1 reaches args.save_begin;
            # the final iteration is always saved.
            if ((i + 1) >= args.save_begin and
                (i + 1) % args.save_every == 0) or (i + 1) == args.iter:
                # Only the g_ema weights are persisted; the other entries are
                # intentionally left disabled.
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                        #"args": args,
                        #"ada_aug_p": ada_aug_p,
                    },
                    f"%s/%s/%s-%06d.pt" %
                    (args.model_path, args.style, args.model_name, i + 1),
                )
def pretrain(args,
             loader,
             generator,
             discriminator,
             g_optim,
             d_optim,
             g_ema,
             encoder,
             vggloss,
             device,
             inject_index=5,
             savemodel=True):
    loader = sample_data(loader)
    vgg_weights = [0.5, 0.5, 0.5, 0.0, 0.0]
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    ncols=140,
                    dynamic_ncols=False,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    sample_zs = mixing_noise(args.n_sample, args.latent, 1.0, device)
    with torch.no_grad():
        source_img, _ = generator([sample_zs[0]],
                                  None,
                                  input_is_latent=False,
                                  z_plus_latent=False,
                                  use_res=False)
        source_img = source_img.detach()
        target_img, _ = generator(sample_zs,
                                  None,
                                  input_is_latent=False,
                                  z_plus_latent=False,
                                  inject_index=inject_index,
                                  use_res=False)
        target_img = target_img.detach()
        style_img, _ = generator([sample_zs[1]],
                                 None,
                                 input_is_latent=False,
                                 z_plus_latent=False,
                                 use_res=False)
        _, sample_style = encoder(F.adaptive_avg_pool2d(style_img, 256),
                                  randomize_noise=False,
                                  return_latents=True,
                                  z_plus_latent=True,
                                  return_z_plus_latent=False)
        sample_style = sample_style.detach()
        if get_rank() == 0:
            utils.save_image(F.adaptive_avg_pool2d(source_img, 256),
                             f"log/%s-instyle.jpg" % (args.model_name),
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))
            utils.save_image(F.adaptive_avg_pool2d(target_img, 256),
                             f"log/%s-target.jpg" % (args.model_name),
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))
            utils.save_image(F.adaptive_avg_pool2d(style_img, 256),
                             f"log/%s-exstyle.jpg" % (args.model_name),
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))

    for idx in pbar:
        i = idx + args.start_iter

        which = i % args.subspace_freq

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # real_zs contains z1 and z2
        real_zs = mixing_noise(args.batch, args.latent, 1.0, device)
        with torch.no_grad():
            # g(z^+_l) with l=inject_index
            target_img, _ = generator(real_zs,
                                      None,
                                      input_is_latent=False,
                                      z_plus_latent=False,
                                      inject_index=inject_index,
                                      use_res=False)
            target_img = target_img.detach()
            # g(z2)
            style_img, _ = generator([real_zs[1]],
                                     None,
                                     input_is_latent=False,
                                     z_plus_latent=False,
                                     use_res=False)
            style_img = style_img.detach()
            # E(g(z2))
            _, pspstyle = encoder(F.adaptive_avg_pool2d(style_img, 256),
                                  randomize_noise=False,
                                  return_latents=True,
                                  z_plus_latent=True,
                                  return_z_plus_latent=False)
            pspstyle = pspstyle.detach()

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        if which > 0:
            # set z~_2 = z2
            noise = [real_zs[0]]
            externalstyle = g_module.get_latent(real_zs[1]).detach()
            z_plus_latent = False
        else:
            # set z~_2 = E(g(z2))
            noise = [real_zs[0].unsqueeze(1).repeat(1, g_module.n_latent, 1)]
            externalstyle = pspstyle
            z_plus_latent = True

        fake_img, _ = generator(noise,
                                externalstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred) * 0.1

        loss_dict["d"] = d_loss  # Ladv
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        if which > 0:
            # set z~_2 = z2
            noise = [real_zs[0]]
            externalstyle = g_module.get_latent(real_zs[1]).detach()
            z_plus_latent = False
        else:
            # set z~_2 = E(g(z2))
            noise = [real_zs[0].unsqueeze(1).repeat(1, g_module.n_latent, 1)]
            externalstyle = pspstyle
            z_plus_latent = True

        fake_img, _ = generator(noise,
                                externalstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        real_feats = vggloss(F.adaptive_avg_pool2d(target_img, 256).detach())
        fake_feats = vggloss(F.adaptive_avg_pool2d(fake_img, 256))
        gr_loss = torch.tensor(0.0).to(device)
        for ii, weight in enumerate(vgg_weights):
            if weight > 0:
                gr_loss += F.l1_loss(fake_feats[ii],
                                     real_feats[ii].detach()) * weight

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred) * 0.1

        loss_dict["g"] = g_loss  # Ladv
        loss_dict["gr"] = gr_loss  # L_perc

        g_loss += gr_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)

            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            externalstyle = torch.randn(path_batch_size, 512, device=device)
            externalstyle = g_module.get_latent(externalstyle).detach()
            fake_img, latents = generator(noise,
                                          externalstyle,
                                          return_latents=True,
                                          use_res=True,
                                          z_plus_latent=False)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema.res, g_module.res, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:d}; d: {d_loss_val:.3f}; g: {g_loss_val:.3f}; gr: {gr_loss_val:.3f}; r1: {r1_val:.3f}; "
                f"path: {path_loss_val:.3f}; mean path: {mean_path_length_avg:.3f}; "
                f"augment: {ada_aug_p:.1f}"))

            if i % 300 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([
                        sample_zs[0].unsqueeze(1).repeat(
                            1, g_module.n_latent, 1)
                    ],
                                      sample_style,
                                      use_res=True,
                                      z_plus_latent=True)
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/%s-%06d.jpg" % (args.model_name, i),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if savemodel and ((i + 1) % args.save_every == 0 or
                              (i + 1) == args.iter):
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"%s/%s-%06d.pt" %
                    (args.model_path, args.model_name, i + 1),
                )
Beispiel #10
0
 def part2(self):
     """Return the index of the first accumulated value equal to -1.

     The moves from ``self._get_moves()`` are passed through
     ``util.accumulate`` and scanned in order. The index is zero-based
     (as produced by ``enumerate``). Returns ``None`` when no accumulated
     value equals -1.

     NOTE(review): presumably -1 denotes the basement floor -- confirm
     against ``util.accumulate`` and the puzzle statement.
     """
     running_floors = util.accumulate(self._get_moves())
     for index, floor in enumerate(running_floors):
         # stop at the first time the running total hits -1
         if floor == -1:
             return index
     return None
Beispiel #11
0
 def part1(self):
     """Return how many distinct values the accumulated moves produce.

     ``self._moves`` is folded with ``self._add_coords`` via
     ``util.accumulate``; duplicates are collapsed with a set before
     counting.
     """
     visited = {*util.accumulate(self._moves, self._add_coords)}
     return len(visited)
Beispiel #12
0
 def hist_to_cdf(histogram):
     """Return the accumulation of the normalized ``histogram`` as a list.

     NOTE(review): presumably ``normalize`` scales the bins to sum to 1 and
     ``accumulate`` yields running sums, making the result a CDF -- confirm
     against those helpers.
     """
     normalized = normalize(histogram)
     running = accumulate(normalized)
     return list(running)
Beispiel #13
0
 def get_replica_vbuckets(self):
     """Return the model's "rvb" variable, reduced when it is two-dimensional.

     The variable is fetched from ``self.vbmap_model``. If
     ``util.dimension_count`` reports exactly 2 dimensions, the result of
     ``util.accumulate(..., util.add_to)`` is returned; otherwise the
     variable is returned unchanged.
     """
     rvb = self.vbmap_model.get_variable("rvb")
     if util.dimension_count(rvb) != 2:
         return rvb
     return util.accumulate(rvb, util.add_to)