def main(args):
    result_dir = setup_run(args.run_name,
                           create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    seq_length = 32

    dataset = ByteLevelTextDataset(args.dataset, seq_length)

    depth = math.log2(seq_length)

    assert int(depth) == depth

    depth = int(depth)

    vocab_size = dataset.vocab_size

    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    inter_dim = 8

    embedding = nn.Embedding(vocab_size, inter_dim, max_norm=1.0).to(device)
    embedding.weight.requires_grad = False
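    # The frozen, max-norm-1 embedding acts as a fixed random codebook: the
    # generator produces vectors in this 8-dim space and `decode` below maps
    # them back to token ids via nearest-neighbour search over the rows.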

    G = Generator(args.latent_size, [256, 128, 64, 32],
                  out_dim=inter_dim).to(device)
    D = UnetDiscriminator(32, depth=4, in_dim=inter_dim).to(device)

    # G.apply(apply_spectral_norm)
    # D.apply(apply_spectral_norm)

    G.apply(init_weights)
    D.apply(init_weights)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
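        # beta=0.0 copies G's weights, so the EMA copy starts from the
        # current generator state before the slow 0.999 updates below.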
    else:
        G_sample = G

    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.0, 0.999))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.0, 0.999))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    loss_f = WGAN_GP(D)

    def decode(embeds):
        flatten = embeds.transpose(1, 2)
        flatten = flatten.reshape(-1, flatten.size(-1))

        dist = (flatten.pow(2).sum(1, keepdim=True) -
                2 * flatten @ embedding.weight.T +
                embedding.weight.T.pow(2).sum(0, keepdim=True))
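        # Squared Euclidean distance to every embedding row, expanded as
        # ||x||^2 - 2 x.e + ||e||^2; the nearest row gives the token id.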

        _, samples = (-dist).max(1)
        return samples.view(embeds.size(0), -1)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0

            p_fake_g_sum = 0
            p_fake_l_sum = 0

            p_real_g_sum = 0
            p_real_l_sum = 0

            start_time = time.time()

            cur_step = 0

            for step, reals in enumerate(batches):
                reals = reals.to(device)
                reals_embed = embedding(reals).transpose(1, 2)
                reals_embed += torch.randn_like(reals_embed) * 0.01
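                # Small Gaussian "instance noise" on the real embeddings, a
                # common trick to stabilise the discriminator early on.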

                batch_size = reals.size(0)

                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_out = G(z)

                D_opt.zero_grad()

                d_loss, p_real_g, p_real_l, p_fake_g, p_fake_l = loss_f.loss_d(
                    reals_embed, fake_out.detach())
                d_loss.backward()

                D_opt.step()

                # Optimize generator
                fake_out = G(z)

                G_opt.zero_grad()

                g_loss = loss_f.loss_g(reals_embed, fake_out)
                g_loss.backward()

                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(g_loss)
                d_loss_sum += float(d_loss)

                p_fake_g_sum += float(p_fake_g)
                p_fake_l_sum += float(p_fake_l)

                p_real_g_sum += float(p_real_g)
                p_real_l_sum += float(p_real_l)

                if global_step % args.log_every == 0:
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)

                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] '
                        +
                        #f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, '
                        +
                        f'p_fake_g: {p_fake_g_sum / cur_step:.5f}, p_fake_l: {p_fake_l_sum / cur_step:.5f}, '
                        +
                        f'p_real_g: {p_real_g_sum / cur_step:.5f}, p_real_l: {p_real_l_sum / cur_step:.5f}, '
                        + f'batches/s: {batches_per_sec:02.2f}')

                    g_loss_sum = d_loss_sum = 0

                    p_fake_g_sum = 0
                    p_fake_l_sum = 0

                    p_real_g_sum = 0
                    p_real_l_sum = 0

                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    samples_embeds = G_sample(z_sample)
                    samples = decode(samples_embeds)

                    reals_decode = decode(reals_embed)

                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n'.join(dataset.seq_to_text(samples)))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n'.join(dataset.seq_to_text(reals_decode)))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
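
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: `update_average` used
# above is assumed to maintain an exponential moving average of the generator
# parameters. Hypothetical helper name so nothing imported gets shadowed.
def _ema_update_sketch(target, source, beta):
    # beta=0.0 copies the source weights; beta=0.999 gives a slow-moving EMA.
    with torch.no_grad():
        for p_tgt, p_src in zip(target.parameters(), source.parameters()):
            p_tgt.copy_(beta * p_tgt + (1.0 - beta) * p_src)
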
def main(args):
    result_dir = setup_run(args.run_name,
                           create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    seq_length = 32

    from bpemb import BPEmb

    lines = Path(args.dataset).read_text().split('\n')[:2_500_000]

    bpe = BPEmb(lang='de', vs=50000, dim=100, add_pad_emb=True)

    data = torch.full((len(lines), seq_length),
                      bpe.vocab_size,
                      dtype=torch.long)

    for i, encoded_sample in enumerate(bpe.encode_ids_with_bos_eos(lines)):
        l = min(seq_length, len(encoded_sample))
        data[i, :l] = torch.tensor(encoded_sample, dtype=torch.long)[:l]
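    # Rows shorter than seq_length keep the pad id fill (bpe.vocab_size, the
    # extra pad embedding); longer rows are truncated to seq_length.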

    #dataset = ByteLevelTextDataset(args.dataset, seq_length)

    depth = math.log2(seq_length)

    assert int(depth) == depth

    depth = int(depth)

    vocab_size = bpe.vocab_size + 1

    batches = DataLoader(data,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    inter_dim = bpe.dim

    embedding = nn.Embedding(vocab_size,
                             inter_dim,
                             _weight=torch.tensor(
                                 bpe.vectors, dtype=torch.float)).to(device)
    embedding.weight.requires_grad = False
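    # Pretrained BPEmb vectors (plus the appended pad row) serve as a fixed,
    # non-trainable target embedding space for the generator.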
    # embedding = nn.Embedding(vocab_size, inter_dim, max_norm=1.0).to(device)

    # spiegel model
    G = Generator(args.latent_size, [256, 256, 128, 64, 64],
                  out_dim=inter_dim).to(device)
    D = UnetDiscriminator(64, max_channel=256, depth=5,
                          in_dim=inter_dim).to(device)

    # G = Generator(args.latent_size, inter_dim, 256).to(device)
    # D = Discriminator(inter_dim, 256).to(device)

    G.apply(apply_spectral_norm)
    D.apply(apply_spectral_norm)

    G.apply(init_weights)
    D.apply(init_weights)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    G_orig = G
    D_orig = D

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    D_params = list(D.parameters())
    #D_params += list(embedding.parameters())

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    D_opt = torch.optim.Adam(D_params, lr=args.d_lr, betas=(0.5, 0.999))

    z_sample = torch.randn(seq_length, args.batch_size,
                           args.latent_size).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    #loss_f = GANLoss(D)
    loss_f = WGAN_GP(D)

    def decode(embeds):
        flatten = embeds.transpose(1, 2)
        flatten = flatten.reshape(-1, flatten.size(-1))

        dist = (flatten.pow(2).sum(1, keepdim=True) -
                2 * flatten @ embedding.weight.T +
                embedding.weight.T.pow(2).sum(0, keepdim=True))

        _, ids = (-dist).max(1)
        ids = ids.view(embeds.size(0), -1)

        decoded = []
        for seq in ids:
            seq = list(seq.detach().cpu().numpy())
            seq = list(filter(lambda x: x != vocab_size - 1, seq))
            dec = bpe.decode_ids(np.array(seq))
            decoded.append(dec or '')

        return decoded

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0

            p_fake_sum = 0
            p_real_sum = 0

            start_time = time.time()

            cur_step = 0

            for step, reals in enumerate(batches):
                reals = reals.to(device)
                reals_embed = embedding(reals).permute(1, 0, 2)
                #reals_embed += torch.normal(0, 0.05, size=reals_embed.shape, device=device)

                batch_size = reals.size(0)

                z = torch.randn(seq_length, batch_size,
                                args.latent_size).to(device)

                # Optimize the discriminator
                fake_out = G(z)

                D_opt.zero_grad()

                d_loss, p_real, p_fake = loss_f.loss_d(reals_embed,
                                                       fake_out.detach())
                d_loss.backward()

                D_opt.step()

                # Optimize generator
                fake_out = G(z)

                G_opt.zero_grad()

                g_loss = loss_f.loss_g(reals_embed, fake_out)
                g_loss.backward()

                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(g_loss)
                d_loss_sum += float(d_loss)

                p_fake_sum += float(p_fake)
                p_real_sum += float(p_real)

                if global_step % args.log_every == 0:
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)

                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] '
                        +
                        #f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, '
                        +
                        f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, '
                        +
                        #f'G_attn_gamma: {G_attn_sum / cur_step:.2f}, D_attn_gamma: {D_attn_sum / cur_step:.2f}, '
                        f'batches/s: {batches_per_sec:02.2f}')

                    g_loss_sum = d_loss_sum = 0

                    p_fake_sum = 0
                    p_real_sum = 0

                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    samples_embeds = G_sample(z_sample).permute(1, 2, 0)
                    samples = decode(samples_embeds)

                    reals_decode = decode(reals_embed.permute(1, 2, 0))

                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n'.join(samples))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n'.join(reals_decode))

                    # (sample_dir / f'fakes_{global_step:06d}.txt').write_text('\n'.join(dataset.seq_to_text(samples)))
                    # (sample_dir / f'reals_{global_step:06d}.txt').write_text('\n'.join(dataset.seq_to_text(reals_decode)))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
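
# ---------------------------------------------------------------------------
# Minimal sketch, an assumption rather than the repository's WGAN_GP class:
# the critic loss used above is taken to be the standard WGAN objective with
# a gradient penalty on real/fake interpolates, written here for a critic
# that returns one score per sample (UnetDiscriminator additionally returns
# per-position scores).
def _wgan_gp_critic_loss_sketch(D, reals, fakes, gp_weight=10.0):
    fakes = fakes.detach()
    # Random interpolation points between real and fake samples.
    eps = torch.rand(reals.size(0), *([1] * (reals.dim() - 1)),
                     device=reals.device)
    interp = (eps * reals + (1.0 - eps) * fakes).requires_grad_(True)
    grad, = torch.autograd.grad(D(interp).sum(), interp, create_graph=True)
    penalty = ((grad.flatten(1).norm(2, dim=1) - 1.0)**2).mean()
    # The critic maximises D(real) - D(fake); the loss minimises its negation.
    return D(fakes).mean() - D(reals).mean() + gp_weight * penalty
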
def main(args):
    result_dir = setup_run(args.run_name,
                           create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    decode = BPEDataset(args.original_dataset).seq_to_text

    vq_model = torch.load(args.vq_model).to(device)

    depth = vq_model.depth
    num_classes = vq_model.quantize[0].n_embed

    dataset = LMDBDataset(args.vq_dataset)
    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    G = Generator(args.latent_size, [128, 128, 128, 128],
                  num_classes,
                  attn=args.attn).to(device)
    D = Discriminator([128, 128, 128, 128], num_classes,
                      attn=args.attn).to(device)

    if args.attn:
        D_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), D.modules())))
        G_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), G.modules())))
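        # gamma is the learned residual weight of each SelfAttention block;
        # it is collected here so the logs can show how strongly attention
        # contributes as training progresses.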

    #G.apply(init_weights)
    #G.apply(apply_spectral_norm)

    #D.apply(init_weights)
    #D.apply(apply_spectral_norm)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.9))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.9))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    loss_f = WGAN_GP(D)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0

            D_gammas_sum = OrderedDict()
            G_gammas_sum = OrderedDict()

            start_time = time.time()

            cur_step = 0

            for step, codes in enumerate(batches):
                codes = [code.to(device) for code in codes]
                codes_one_hot = [
                    F.one_hot(code, num_classes=num_classes).transpose(
                        1, 2).type(torch.float) for code in codes
                ]
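                # Each VQ code level becomes a dense (batch, n_embed, length)
                # one-hot tensor so the discriminator can consume it directly.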

                batch_size = codes[0].size(0)

                # code_noise.p = 0.3 * (1.0 - min(1.0, interpol * 2))
                # code = code_noise(code)

                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_logits = G(z)
                fake_probs = [
                    torch.softmax(logits, dim=1).detach()
                    for logits in fake_logits
                ]

                D_opt.zero_grad()

                loss_d = loss_f.d_loss(codes_one_hot, fake_probs[::-1])
                loss_d.backward()

                D_opt.step()

                # Optimize generator
                fake_logits = G(z)
                fake_probs = [
                    torch.softmax(logits, dim=1) for logits in fake_logits
                ]

                G_opt.zero_grad()

                loss_g = loss_f.g_loss(codes_one_hot, fake_probs[::-1])
                loss_g.backward()

                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(loss_g)
                d_loss_sum += float(loss_d)

                # p_fake_sum += float(p_fake)
                # p_real_sum += float(p_real)

                if args.attn:
                    for i, (d_gamma,
                            g_gamma) in enumerate(zip(D_gammas, G_gammas)):
                        D_gammas_sum[i] = D_gammas_sum.get(i, 0) + d_gamma
                        G_gammas_sum[i] = G_gammas_sum.get(i, 0) + g_gamma

                if global_step % args.log_every == 0:
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)

                    if args.attn:
                        D_gammas_avg = repr_list([
                            gamma / cur_step
                            for gamma in D_gammas_sum.values()
                        ])
                        G_gammas_avg = repr_list([
                            gamma / cur_step
                            for gamma in G_gammas_sum.values()
                        ])

                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] '
                        +
                        #f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, '
                        +
                        #f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, ' +
                        (f'd_attn_gammas: [{D_gammas_avg}], g_attn_gammas: [{G_gammas_avg}], '
                         if args.attn else '') +
                        f'batches/s: {batches_per_sec:02.2f}')

                    g_loss_sum = d_loss_sum = 0

                    D_gammas_sum = OrderedDict()
                    G_gammas_sum = OrderedDict()

                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    sample_codes = [
                        logits.argmax(1) for logits in G_sample(z_sample)
                    ]
                    sample_logits = [
                        vq_model.decode_code(sample_code, depth - 1 - i)
                        for i, sample_code in enumerate(sample_codes)
                    ]
                    samples_decoded = [
                        decode(logits.argmax(-1)) for logits in sample_logits
                    ]

                    reals_logits = [
                        vq_model.decode_code(code[:args.n_sample], i)
                        for i, code in enumerate(codes)
                    ]
                    reals_decoded = [
                        decode(logits.argmax(-1)) for logits in reals_logits
                    ]

                    (sample_dir /
                     f'fakes_{global_step:06d}.txt').write_text('\n\n'.join(
                         map(lambda g: '\n'.join(g), zip(*samples_decoded))))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n\n'.join(
                            map(lambda g: '\n'.join(g), zip(*reals_decoded))))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    result_dir = setup_run(args.run_name, create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    decode = BPEDataset(args.original_dataset).seq_to_text

    vq_model = torch.load(args.vq_model).to(device)

    quantizers = vq_model.quantize
    for q in quantizers:
        q.eval()

    depth = vq_model.depth
    num_classes = vq_model.quantize[0].n_embed
    quant_dim = vq_model.quantize[0].dim

    dataset = LMDBDataset(args.vq_dataset)
    batches = DataLoader(dataset, args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers)

    if args.attn:
        G_attn = [False, False, True, False]
        D_attn = [False, True, False, False]
    else:
        G_attn = False
        D_attn = False


    G = Generator(args.latent_size, [512, 512, 256, 128], quant_dim, attn=G_attn).to(device)
    D = Discriminator([128, 256, 512, 512], quant_dim, attn=D_attn).to(device)

    if args.attn:
        D_gammas = list(map(attrgetter('gamma'), filter(lambda m: isinstance(m, SelfAttention), D.modules())))
        G_gammas = list(map(attrgetter('gamma'), filter(lambda m: isinstance(m, SelfAttention), G.modules())))

    #G.apply(init_weights)
    #G.apply(apply_spectral_norm)

    #D.apply(init_weights)
    #D.apply(apply_spectral_norm)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    loss_f = WGAN_GP(D)
    #loss_f = RelativisticAverageHingeLoss(D)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0

            p_fake_sum = 0
            p_real_sum = 0

            vq_diffs_sum = [0] * depth

            D_gammas_sum = OrderedDict()
            G_gammas_sum = OrderedDict()

            start_time = time.time()

            cur_step = 0

            for step, reals in enumerate(batches):
                #reals_code = [code.to(device) for code in reals_code]
                #reals_embed = [q.embed_code(c).transpose(1, 2) for q, c in zip(quantizers, reals_code)]

                reals = [real.to(device) for real in reals]

                batch_size = reals[0].size(0)

                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_out = G(z)
                fake_out = [t.detach() for t in fake_out]
                # fake_embeds = [q(
                #     o.transpose(1, 2)
                # )[0].transpose(1, 2).detach() for q, o in zip(quantizers, fake_out)]

                D_opt.zero_grad()

                loss_d, p_fake, p_real = loss_f.d_loss(reals, fake_out[::-1])
                loss_d.backward()

                D_opt.step()

                # Optimize generator
                fake_out = G(z)
                _, vq_diffs, fake_codes = list(zip(*[q(
                    o.transpose(1, 2))
                    for q, o in zip(quantizers, fake_out)
                ]))
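                # Re-quantize the generator output through the frozen VQ
                # codebooks; the resulting vq_diffs are only logged because
                # the 0.01 * sum(vq_diffs) generator penalty is commented out.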
                #fake_out = [t.transpose(1, 2) for t in fake_out]

                G_opt.zero_grad()

                loss_g = loss_f.g_loss(reals, fake_out[::-1])
                #loss_g += 0.01 * sum(vq_diffs)
                loss_g.backward()

                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(loss_g)
                d_loss_sum += float(loss_d)

                p_fake_sum += float(p_fake)
                p_real_sum += float(p_real)

                vq_diffs_sum = [v_old + float(v_new) for v_old, v_new in zip(vq_diffs_sum, vq_diffs)]

                if args.attn:
                    for i, (d_gamma, g_gamma) in enumerate(zip(D_gammas, G_gammas)):
                        D_gammas_sum[i] = D_gammas_sum.get(i, 0) + d_gamma
                        G_gammas_sum[i] = G_gammas_sum.get(i, 0) + g_gamma

                if global_step % args.log_every == 0:
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)

                    if args.attn:
                        D_gammas_avg = repr_list([gamma / cur_step for gamma in D_gammas_sum.values()])
                        G_gammas_avg = repr_list([gamma / cur_step for gamma in G_gammas_sum.values()])

                    vq_diffs_avg = repr_list([diff / cur_step for diff in vq_diffs_sum])

                    logging.info(f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] ' +
                                 # f'grow_index: {current_grow_index}/{depth - 1}, ' +
                                 f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, ' +
                                 f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, ' +
                                 (
                                     f'd_attn_gammas: [{D_gammas_avg}], g_attn_gammas: [{G_gammas_avg}], ' if args.attn else '') +
                                 f'vq_diffs: [{vq_diffs_avg}], ' +
                                 f'batches/s: {batches_per_sec:02.2f}')

                    g_loss_sum = d_loss_sum = 0
                    p_fake_sum = p_real_sum = 0

                    vq_diffs_sum = [0] * depth

                    D_gammas_sum = OrderedDict()
                    G_gammas_sum = OrderedDict()

                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    sample_out = G_sample(z_sample)
                    sample_codes = [q(
                        o.transpose(1, 2)
                    )[2] for q, o in zip(quantizers, sample_out)]
                    sample_logits = [vq_model.decode_code(sample_code, depth - 1 - i) for i, sample_code in
                                     enumerate(sample_codes)]
                    samples_decoded = [decode(logits.argmax(-1)) for logits in sample_logits]

                    real_codes = [q(
                        o.transpose(1, 2)
                    )[2] for q, o in zip(quantizers, reals)]
                    reals_logits = [vq_model.decode_code(code[:args.n_sample], i) for i, code in enumerate(real_codes)]
                    reals_decoded = [decode(logits.argmax(-1)) for logits in reals_logits]

                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n\n'.join(map(lambda g: '\n'.join(g), zip(*samples_decoded))))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n\n'.join(map(lambda g: '\n'.join(g), zip(*reals_decoded))))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    result_dir = setup_run(args.run_name,
                           create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    dataset = BPEDataset(args.original_dataset)

    depth = math.log2(dataset.seq_length)

    assert int(depth) == depth

    depth = int(depth)

    vocab_size = dataset.vocab_size

    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    inter_dim = 32

    extract_dims = [inter_dim] * 4 + [vocab_size]
    inject_dims = [vocab_size] + [inter_dim] * 4

    G = Generator(args.latent_size, [128, 128, 128, 128, 128],
                  extract_dims,
                  attn=args.attn).to(device)
    D = Discriminator([128, 128, 128, 128, 128], inject_dims,
                      attn=args.attn).to(device)

    T = TransformNetwork(vocab_size, [64, 64, 64, 64], inter_dim).to(device)
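    # T produces intermediate feature maps from the one-hot input; they are
    # fed to the discriminator alongside the raw one-hot sequence and T is
    # optimised jointly with D (its parameters are added to D_opt below).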

    if args.attn:
        D_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), D.modules())))
        G_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), G.modules())))

    #G.apply(init_weights)
    #G.apply(apply_spectral_norm)

    #D.apply(init_weights)
    #D.apply(apply_spectral_norm)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.9))
    D_opt = torch.optim.Adam(list(D.parameters()) + list(T.parameters()),
                             lr=args.d_lr,
                             betas=(0.5, 0.9))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    loss_f = WGAN_GP(D)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0

            p_fake_sum = 0
            p_real_sum = 0

            D_gammas_sum = OrderedDict()
            G_gammas_sum = OrderedDict()

            start_time = time.time()

            cur_step = 0

            for step, reals in enumerate(batches):
                reals = reals.to(device)
                reals_one_hot = F.one_hot(reals,
                                          num_classes=vocab_size).transpose(
                                              1, 2).type(torch.float)
                batch_size = reals.size(0)

                reals_t = T(reals_one_hot)
                reals_input = [reals_one_hot] + reals_t

                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_out = G(z)
                fake_probs = torch.softmax(fake_out[-1], dim=1)
                fake_input = (fake_out[:-1] + [fake_probs])[::-1]
                fake_input = [t.detach() for t in fake_input]
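                # Only the final vocab-sized output is softmaxed; the
                # intermediate feature maps go to D as-is, and the list is
                # reversed to line up with the ordering of reals_input.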

                D_opt.zero_grad()

                loss_d, p_fake, p_real = loss_f.d_loss(reals_input, fake_input)
                loss_d.backward()

                D_opt.step()

                # Optimize generator
                fake_out = G(z)
                fake_probs = torch.softmax(fake_out[-1], dim=1)
                fake_input = (fake_out[:-1] + [fake_probs])[::-1]

                G_opt.zero_grad()

                reals_input = [t.detach() for t in reals_input]

                loss_g = loss_f.g_loss(reals_input, fake_input)
                loss_g.backward()

                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(loss_g)
                d_loss_sum += float(loss_d)

                p_fake_sum += float(p_fake)
                p_real_sum += float(p_real)

                if args.attn:
                    for i, (d_gamma,
                            g_gamma) in enumerate(zip(D_gammas, G_gammas)):
                        D_gammas_sum[i] = D_gammas_sum.get(i, 0) + d_gamma
                        G_gammas_sum[i] = G_gammas_sum.get(i, 0) + g_gamma

                if global_step % args.log_every == 0:
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)

                    if args.attn:
                        D_gammas_avg = repr_list([
                            gamma / cur_step
                            for gamma in D_gammas_sum.values()
                        ])
                        G_gammas_avg = repr_list([
                            gamma / cur_step
                            for gamma in G_gammas_sum.values()
                        ])

                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] '
                        +
                        #f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, '
                        +
                        f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, '
                        +
                        (f'd_attn_gammas: [{D_gammas_avg}], g_attn_gammas: [{G_gammas_avg}], '
                         if args.attn else '') +
                        f'batches/s: {batches_per_sec:02.2f}')

                    g_loss_sum = d_loss_sum = 0
                    p_fake_sum = p_real_sum = 0

                    D_gammas_sum = OrderedDict()
                    G_gammas_sum = OrderedDict()

                    start_time = time.time()

                if global_step % args.sample_every == 0:

                    samples_decoded = dataset.seq_to_text(
                        G_sample(z_sample)[-1].argmax(1))
                    reals_decoded = dataset.seq_to_text(reals[:args.n_sample])

                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n'.join(samples_decoded))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n'.join(reals_decoded))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    result_dir = setup_run(args.run_name,
                           create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    seq_length = 32

    dataset = ByteLevelTextDataset(args.dataset, seq_len=seq_length)

    vocab_size = dataset.vocab_size

    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    inter_dim = vocab_size

    # spiegel model
    # G = Generator(args.latent_size, [256, 256, 128, 64, 64], out_dim=inter_dim).to(device)
    # D = UnetDiscriminator(64, max_channel=256, depth=5, in_dim=inter_dim).to(device)

    G = Generator(args.latent_size, [512, 512, 512, 512],
                  out_dim=vocab_size).to(device)
    D = UnetDiscriminator(512, max_channel=512, depth=4,
                          in_dim=vocab_size).to(device)

    G.apply(apply_spectral_norm)
    D.apply(apply_spectral_norm)

    G.apply(init_weights)
    D.apply(init_weights)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    G_orig = G
    D_orig = D

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    #loss_f = GANLoss(D)
    loss_f = WGAN_GP(D)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0

            p_fake_g_sum = 0
            p_fake_l_sum = 0

            p_real_g_sum = 0
            p_real_l_sum = 0

            G_attn_sum = 0
            D_attn_sum = 0

            start_time = time.time()

            cur_step = 0

            for step, reals in enumerate(batches):
                reals = reals.to(device)
                reals_one_hot = F.one_hot(reals,
                                          num_classes=vocab_size).transpose(
                                              1, 2).to(torch.float)

                batch_size = reals.size(0)

                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                tau = 0.1
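                # Low-temperature softmax (tau = 0.1) sharpens the generator
                # output towards one-hot vectors while keeping it
                # differentiable for the discriminator.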

                # Optimize the discriminator
                fake_out = G(z)
                fake_out = torch.softmax(fake_out / tau, dim=1)

                D_opt.zero_grad()

                d_loss, p_real_g, p_real_l, p_fake_g, p_fake_l = loss_f.loss_d(
                    reals_one_hot, fake_out.detach())
                d_loss.backward()

                D_opt.step()

                # Optimize generator
                fake_out = G(z)
                fake_out = torch.softmax(fake_out / tau, dim=1)

                G_opt.zero_grad()

                g_loss = loss_f.loss_g(reals_one_hot, fake_out)
                g_loss.backward()

                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(g_loss)
                d_loss_sum += float(d_loss)

                p_fake_g_sum += float(p_fake_g)
                p_fake_l_sum += float(p_fake_l)

                p_real_g_sum += float(p_real_g)
                p_real_l_sum += float(p_real_l)

                #G_attn_sum += float(G_orig.attn.gamma)
                #D_attn_sum += float(D_orig.attn.gamma)

                if global_step % args.log_every == 0:
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)

                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] '
                        +
                        #f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, '
                        +
                        f'p_fake_g: {p_fake_g_sum / cur_step:.5f}, p_fake_l: {p_fake_l_sum / cur_step:.5f}, '
                        +
                        f'p_real_g: {p_real_g_sum / cur_step:.5f}, p_real_l: {p_real_l_sum / cur_step:.5f}, '
                        +
                        #f'G_attn_gamma: {G_attn_sum / cur_step:.2f}, D_attn_gamma: {D_attn_sum / cur_step:.2f}, '
                        f'batches/s: {batches_per_sec:02.2f}')

                    g_loss_sum = d_loss_sum = 0

                    p_fake_g_sum = 0
                    p_fake_l_sum = 0

                    p_real_g_sum = 0
                    p_real_l_sum = 0

                    G_attn_sum = 0
                    D_attn_sum = 0

                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    samples = G_sample(z_sample).argmax(1)

                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n'.join(dataset.seq_to_text(samples)))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n'.join(dataset.seq_to_text(reals[:args.n_sample])))

                    # (sample_dir / f'fakes_{global_step:06d}.txt').write_text('\n'.join(dataset.seq_to_text(samples)))
                    # (sample_dir / f'reals_{global_step:06d}.txt').write_text('\n'.join(dataset.seq_to_text(reals_decode)))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    result_dir = setup_run(args.run_name,
                           create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    decode = BPEDataset(args.original_dataset).seq_to_text

    vq_model = torch.load(args.vq_model).to(device)

    depth = vq_model.depth
    num_classes = vq_model.quantize[0].n_embed

    dataset = LMDBDataset(args.vq_dataset)
    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    G = Generator(args.latent_size, [256, 256, 256, 256],
                  num_classes,
                  attn=args.attn).to(device)
    D = Discriminator([256, 256, 256, 256],
                      num_classes,
                      use_embeddings=False,
                      attn=args.attn).to(device)

    if args.attn:
        D_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), D.modules())))
        G_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), G.modules())))

    G.apply(init_weights)
    G.apply(apply_spectral_norm)

    D.apply(init_weights)
    D.apply(apply_spectral_norm)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    code_noise = CategoricalNoise(num_classes, 0.0)
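    # CategoricalNoise is assumed to resample a fraction p of code tokens at
    # random; p is annealed inside the loop but the corrupting call itself is
    # commented out, so the codes reach D unchanged.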

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0

            p_fake_sum = 0
            p_real_sum = 0

            D_gammas_sum = OrderedDict()
            G_gammas_sum = OrderedDict()

            start_time = time.time()

            cur_step = 0

            for step, codes in enumerate(batches):
                #current_grow_index = min(global_step // args.steps_per_stage, depth - 1)
                current_grow_index = 3
                interpol = (global_step %
                            args.steps_per_stage) / args.steps_per_stage
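                # Progressive growing is effectively disabled: the grow index
                # is pinned at 3 instead of advancing every steps_per_stage
                # steps (the original schedule is commented out above).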

                code = codes[-(current_grow_index + 1)]
                code = code.to(device)

                code_noise.p = 0.3 * (1.0 - min(1.0, interpol * 2))
                #code = code_noise(code)

                z = torch.randn(code.size(0), args.latent_size, 1).to(device)

                fake_logits = G(z, extract_at_grow_index=current_grow_index)

                G_opt.zero_grad()
                D_opt.zero_grad()

                loss_d, loss_g, p_fake, p_real = multinomial_bgan_loss(
                    D,
                    fake_logits,
                    code,
                    n_samples=args.n_mc_samples,
                    tau=args.mc_sample_tau)
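                # multinomial_bgan_loss presumably follows the boundary-seeking
                # GAN recipe for discrete outputs: it draws n_mc_samples token
                # samples from softmax(fake_logits / tau) and reweights them
                # with the discriminator's scores, so both losses come from a
                # single pass and are backpropagated together below.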

                torch.autograd.backward([loss_d, loss_g])

                G_opt.step()
                D_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(loss_g)
                d_loss_sum += float(loss_d)

                p_fake_sum += float(p_fake)
                p_real_sum += float(p_real)

                if args.attn:
                    for i, (d_gamma,
                            g_gamma) in enumerate(zip(D_gammas, G_gammas)):
                        D_gammas_sum[i] = D_gammas_sum.get(i, 0) + d_gamma
                        G_gammas_sum[i] = G_gammas_sum.get(i, 0) + g_gamma

                if global_step % args.log_every == 0:
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)

                    if args.attn:
                        D_gammas_avg = repr_list([
                            gamma / cur_step
                            for gamma in D_gammas_sum.values()
                        ])
                        G_gammas_avg = repr_list([
                            gamma / cur_step
                            for gamma in G_gammas_sum.values()
                        ])

                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] '
                        + f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, '
                        +
                        f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, '
                        +
                        (f'd_attn_gammas: [{D_gammas_avg}], g_attn_gammas: [{G_gammas_avg}], '
                         if args.attn else '') +
                        f'batches/s: {batches_per_sec:02.2f}, code_noise_p: {code_noise.p:.2f}'
                    )

                    g_loss_sum = d_loss_sum = 0
                    p_fake_sum = p_real_sum = 0

                    D_gammas_sum = OrderedDict()
                    G_gammas_sum = OrderedDict()

                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    current_depth = depth - 1 - current_grow_index

                    samples_codes = G_sample(
                        z_sample,
                        extract_at_grow_index=current_grow_index).argmax(1)
                    samples_logits = vq_model.decode_code(
                        samples_codes, current_depth)
                    samples_decoded = decode(samples_logits.argmax(-1))

                    reals_logits = vq_model.decode_code(
                        code[:args.n_sample], current_depth)
                    reals_decoded = decode(reals_logits.argmax(-1))

                    (sample_dir /
                     f'fakes_{current_grow_index}_{global_step:06d}.txt'
                     ).write_text('\n'.join(samples_decoded))
                    (sample_dir /
                     f'reals_{current_grow_index}_{global_step:06d}.txt'
                     ).write_text('\n'.join(reals_decoded))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass