def main(args):
    """Train a GAN on byte-level text mapped into a small, frozen embedding space.

    ``args.dataset`` is wrapped in a ``ByteLevelTextDataset`` with sequence
    length 32. Every token is looked up in a non-trainable 8-dimensional
    ``nn.Embedding`` (``max_norm=1.0``). A ``Generator`` is trained against a
    ``UnetDiscriminator`` on those embeddings. For sampling, embeddings are
    mapped back to token ids by nearest-neighbour search in the embedding table.

    Writes ``log.txt``, model summaries (``G.txt``/``D.txt``), periodic text
    samples and per-epoch checkpoints below the run directory returned by
    ``setup_run``. A ``KeyboardInterrupt`` stops training quietly.

    Args:
        args: parsed command-line namespace. Attributes read here: ``run_name``,
            ``device``, ``dataset``, ``batch_size``, ``num_workers``,
            ``latent_size``, ``use_ema``, ``data_parallel``, ``g_lr``, ``d_lr``,
            ``n_sample``, ``epochs``, ``log_every`` and ``sample_every``.
    """
    result_dir = setup_run(args.run_name, create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    seq_length = 32

    dataset = ByteLevelTextDataset(args.dataset, seq_length)

    # The sequence length must be a power of two.
    depth = math.log2(seq_length)
    assert int(depth) == depth
    depth = int(depth)

    vocab_size = dataset.vocab_size

    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    inter_dim = 8

    # Frozen token embedding: real data is embedded with it and generated
    # embeddings are decoded against it.
    embedding = nn.Embedding(vocab_size, inter_dim, max_norm=1.0).to(device)
    embedding.weight.requires_grad = False

    G = Generator(args.latent_size, [256, 128, 64, 32], out_dim=inter_dim).to(device)
    D = UnetDiscriminator(32, depth=4, in_dim=inter_dim).to(device)

    # G.apply(apply_spectral_norm)
    # D.apply(apply_spectral_norm)

    G.apply(init_weights)
    D.apply(init_weights)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        # Moving-average copy of G (updated after every G step), used for sampling.
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    # Unwrapped generator handle for the EMA update (G may become DataParallel).
    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.0, 0.999))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.0, 0.999))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    # NOTE(review): loss_f is constructed but never used below; the training
    # steps call the free names loss_d / loss_g, which must exist at module
    # level -- confirm this is intended.
    loss_f = WGAN_GP(D)

    def decode(embeds):
        """Return nearest-embedding token ids for ``embeds``.

        ``embeds`` is laid out as (batch, inter_dim, seq), like
        ``embedding(...).transpose(1, 2)`` below. The result has shape
        (batch, seq) and holds the index of the embedding row with the smallest
        squared Euclidean distance at every position.
        """
        flatten = embeds.transpose(1, 2)
        flatten = flatten.reshape(-1, flatten.size(-1))
        # ||x||^2 - 2 x.e + ||e||^2 for every (position, vocabulary entry) pair.
        dist = (flatten.pow(2).sum(1, keepdim=True) -
                2 * flatten @ embedding.weight.T +
                embedding.weight.T.pow(2).sum(0, keepdim=True))
        _, ids = (-dist).max(1)
        # Regroup by the batch size of *this* input. The previous code used the
        # enclosing `samples_embeds`, which is wrong when decoding the reals.
        return ids.view(embeds.size(0), -1)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0
            p_fake_g_sum = 0
            p_fake_l_sum = 0
            p_real_g_sum = 0
            p_real_l_sum = 0
            start_time = time.time()
            cur_step = 0

            for step, reals in enumerate(batches):
                reals = reals.to(device)
                reals_embed = embedding(reals).transpose(1, 2)
                # Add small Gaussian noise (std 0.01) to the real embeddings.
                reals_embed += torch.randn_like(reals_embed) * 0.01
                batch_size = reals.size(0)
                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_out = G(z)
                D_opt.zero_grad()
                d_loss, p_real_g, p_real_l, p_fake_g, p_fake_l = loss_d(
                    D, reals_embed, fake_out.detach())
                d_loss.backward()
                D_opt.step()

                # Optimize generator
                fake_out = G(z)
                G_opt.zero_grad()
                g_loss = loss_g(D, reals_embed, fake_out)
                g_loss.backward()
                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(g_loss)
                d_loss_sum += float(d_loss)
                p_fake_g_sum += float(p_fake_g)
                p_fake_l_sum += float(p_fake_l)
                p_real_g_sum += float(p_real_g)
                p_real_l_sum += float(p_real_l)

                if global_step % args.log_every == 0:
                    # Steps accumulated since the last reset (the sums are also
                    # reset at the start of every epoch).
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)
                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] ' +
                        #f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, ' +
                        f'p_fake_g: {p_fake_g_sum / cur_step:.5f}, p_fake_l: {p_fake_l_sum / cur_step:.5f}, ' +
                        f'p_real_g: {p_real_g_sum / cur_step:.5f}, p_real_l: {p_real_l_sum / cur_step:.5f}, ' +
                        f'batches/s: {batches_per_sec:02.2f}')
                    g_loss_sum = d_loss_sum = 0
                    p_fake_g_sum = 0
                    p_fake_l_sum = 0
                    p_real_g_sum = 0
                    p_real_l_sum = 0
                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    # Sampling needs no gradients. Without no_grad the graph of this
                    # forward pass stays referenced until the next sampling step.
                    with torch.no_grad():
                        samples_embeds = G_sample(z_sample)
                        samples = decode(samples_embeds)
                        reals_decode = decode(reals_embed)
                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n'.join(dataset.seq_to_text(samples)))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n'.join(dataset.seq_to_text(reals_decode)))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    """Train a GAN on BPE-tokenised German text in pretrained BPEmb vector space.

    Reads up to 2,500,000 lines from ``args.dataset`` and encodes them with
    ``BPEmb(lang='de', vs=50000, dim=100, add_pad_emb=True)``, adding BOS/EOS.
    Each line is truncated or padded to 32 ids, with padding id
    ``bpe.vocab_size``. The frozen BPEmb vectors are the embedding table. A
    ``Generator`` is trained against a ``UnetDiscriminator`` on those embeddings
    with a WGAN-GP loss. Samples are decoded by nearest-neighbour lookup in the
    embedding table followed by ``bpe.decode_ids``.

    Writes ``log.txt``, model summaries, periodic text samples and per-epoch
    checkpoints below the run directory returned by ``setup_run``. A
    ``KeyboardInterrupt`` stops training quietly.

    Args:
        args: parsed command-line namespace. Attributes read here: ``run_name``,
            ``device``, ``dataset``, ``batch_size``, ``num_workers``,
            ``latent_size``, ``use_ema``, ``data_parallel``, ``g_lr``, ``d_lr``,
            ``epochs``, ``log_every`` and ``sample_every``.
    """
    result_dir = setup_run(args.run_name, create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    seq_length = 32

    from bpemb import BPEmb

    # Read explicitly as UTF-8: the platform-default encoding may not decode
    # German text correctly.
    lines = Path(args.dataset).read_text(encoding='utf-8').split('\n')[:2_500_000]
    bpe = BPEmb(lang='de', vs=50000, dim=100, add_pad_emb=True)

    # Every row starts out as padding (id == bpe.vocab_size) and is then
    # overwritten with as many encoded ids as fit.
    data = torch.full((len(lines), seq_length), bpe.vocab_size, dtype=torch.long)
    for i, encoded_sample in enumerate(bpe.encode_ids_with_bos_eos(lines)):
        length = min(seq_length, len(encoded_sample))
        data[i, :length] = torch.tensor(encoded_sample, dtype=torch.long)[:length]

    #dataset = ByteLevelTextDataset(args.dataset, seq_length)

    # The sequence length must be a power of two.
    depth = math.log2(seq_length)
    assert int(depth) == depth
    depth = int(depth)

    # +1 for the padding embedding added by add_pad_emb=True.
    vocab_size = bpe.vocab_size + 1
    pad_id = vocab_size - 1

    batches = DataLoader(data,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    inter_dim = bpe.dim

    embedding = nn.Embedding(vocab_size,
                             inter_dim,
                             _weight=torch.tensor(bpe.vectors, dtype=torch.float)).to(device)
    embedding.weight.requires_grad = False
    # embedding = nn.Embedding(vocab_size, inter_dim, max_norm=1.0).to(device)

    # spiegel model
    G = Generator(args.latent_size, [256, 256, 128, 64, 64], out_dim=inter_dim).to(device)
    D = UnetDiscriminator(64, max_channel=256, depth=5, in_dim=inter_dim).to(device)
    # G = Generator(args.latent_size, inter_dim, 256).to(device)
    # D = Discriminator(inter_dim, 256).to(device)

    G.apply(apply_spectral_norm)
    D.apply(apply_spectral_norm)

    G.apply(init_weights)
    D.apply(init_weights)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        # Moving-average copy of G (updated after every G step), used for sampling.
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    # Unwrapped generator handle for the EMA update (G may become DataParallel).
    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    D_params = list(D.parameters())
    #D_params += list(embedding.parameters())

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    D_opt = torch.optim.Adam(D_params, lr=args.d_lr, betas=(0.5, 0.999))

    # Latents are laid out (seq, batch, latent), matching z in the loop below.
    z_sample = torch.randn(seq_length, args.batch_size, args.latent_size).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    #loss_f = GANLoss(D)
    loss_f = WGAN_GP(D)

    def decode(embeds):
        """Decode (batch, inter_dim, seq) embeddings into a list of strings.

        Every position is mapped to the nearest embedding-table row (squared
        Euclidean distance). Padding ids are dropped, and each row is turned
        into text with ``bpe.decode_ids``. An empty result becomes ``''``.
        """
        flatten = embeds.transpose(1, 2)
        flatten = flatten.reshape(-1, flatten.size(-1))
        # ||x||^2 - 2 x.e + ||e||^2 for every (position, vocabulary entry) pair.
        dist = (flatten.pow(2).sum(1, keepdim=True) -
                2 * flatten @ embedding.weight.T +
                embedding.weight.T.pow(2).sum(0, keepdim=True))
        _, ids = (-dist).max(1)
        # One device->host transfer for the whole batch instead of one per row.
        ids = ids.view(embeds.size(0), -1).cpu().numpy()
        decoded = []
        for seq in ids:
            dec = bpe.decode_ids(seq[seq != pad_id])
            decoded.append(dec or '')
        return decoded

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0
            p_fake_sum = 0
            p_real_sum = 0
            start_time = time.time()
            cur_step = 0

            for step, reals in enumerate(batches):
                reals = reals.to(device)
                # (batch, seq, dim) -> (seq, batch, dim)
                reals_embed = embedding(reals).permute(1, 0, 2)
                #reals_embed += torch.normal(0, 0.05, size=reals_embed.shape, device=device)
                batch_size = reals.size(0)
                z = torch.randn(seq_length, batch_size, args.latent_size).to(device)

                # Optimize the discriminator
                fake_out = G(z)
                D_opt.zero_grad()
                d_loss, p_real, p_fake = loss_f.loss_d(reals_embed, fake_out.detach())
                d_loss.backward()
                D_opt.step()

                # Optimize generator
                fake_out = G(z)
                G_opt.zero_grad()
                g_loss = loss_f.loss_g(reals_embed, fake_out)
                g_loss.backward()
                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(g_loss)
                d_loss_sum += float(d_loss)
                p_fake_sum += float(p_fake)
                p_real_sum += float(p_real)

                if global_step % args.log_every == 0:
                    # Steps accumulated since the last reset (the sums are also
                    # reset at the start of every epoch).
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)
                    # The previous labels printed p_fake_sum as "p_fake_g" and
                    # p_real_sum as "p_fake_l".
                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, ' +
                        f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, ' +
                        f'batches/s: {batches_per_sec:02.2f}')
                    g_loss_sum = d_loss_sum = 0
                    p_fake_sum = 0
                    p_real_sum = 0
                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    # Sampling needs no gradients. Without no_grad the graph of this
                    # forward pass stays referenced until the next sampling step.
                    with torch.no_grad():
                        samples_embeds = G_sample(z_sample).permute(1, 2, 0)
                        samples = decode(samples_embeds)
                        reals_decode = decode(reals_embed.permute(1, 2, 0))
                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n'.join(samples))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n'.join(reals_decode))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    """Train a GAN that generates the per-level discrete codes of a pretrained VQ model.

    Real codes for every level come from ``LMDBDataset(args.vq_dataset)`` and
    are one-hot encoded over the first quantizer's ``n_embed`` classes. The
    generator emits one logits tensor per level. Their softmax probabilities,
    in reversed level order, are compared with the real one-hot codes under a
    WGAN-GP loss. Samples are decoded with ``vq_model.decode_code`` and
    ``BPEDataset(args.original_dataset).seq_to_text``.

    Writes ``log.txt``, model summaries, periodic text samples and per-epoch
    checkpoints below the run directory returned by ``setup_run``. A
    ``KeyboardInterrupt`` stops training quietly.

    Args:
        args: parsed command-line namespace. Attributes read here: ``run_name``,
            ``device``, ``original_dataset``, ``vq_model``, ``vq_dataset``,
            ``batch_size``, ``num_workers``, ``latent_size``, ``attn``,
            ``use_ema``, ``data_parallel``, ``g_lr``, ``d_lr``, ``n_sample``,
            ``epochs``, ``log_every`` and ``sample_every``.
    """
    result_dir = setup_run(args.run_name, create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    decode = BPEDataset(args.original_dataset).seq_to_text

    vq_model = torch.load(args.vq_model).to(device)
    depth = vq_model.depth
    num_classes = vq_model.quantize[0].n_embed

    dataset = LMDBDataset(args.vq_dataset)
    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    G = Generator(args.latent_size, [128, 128, 128, 128], num_classes,
                  attn=args.attn).to(device)
    D = Discriminator([128, 128, 128, 128], num_classes, attn=args.attn).to(device)

    if args.attn:
        # `gamma` of every SelfAttention module, collected for logging only.
        D_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), D.modules())))
        G_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), G.modules())))

    #G.apply(init_weights)
    #G.apply(apply_spectral_norm)
    #D.apply(init_weights)
    #D.apply(apply_spectral_norm)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        # Moving-average copy of G (updated after every G step), used for sampling.
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    # Unwrapped generator handle for the EMA update (G may become DataParallel).
    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.9))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.9))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    loss_f = WGAN_GP(D)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0
            D_gammas_sum = OrderedDict()
            G_gammas_sum = OrderedDict()
            start_time = time.time()
            cur_step = 0

            for step, codes in enumerate(batches):
                codes = [code.to(device) for code in codes]
                # One-hot over the class axis, moved to dim 1.
                codes_one_hot = [
                    F.one_hot(code, num_classes=num_classes).transpose(1, 2).type(torch.float)
                    for code in codes
                ]
                batch_size = codes[0].size(0)

                # code_noise.p = 0.3 * (1.0 - min(1.0, interpol * 2))
                # code = code_noise(code)

                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_logits = G(z)
                fake_probs = [torch.softmax(logits, dim=1).detach() for logits in fake_logits]
                D_opt.zero_grad()
                loss_d = loss_f.d_loss(codes_one_hot, fake_probs[::-1])
                loss_d.backward()
                D_opt.step()

                # Optimize generator
                fake_logits = G(z)
                fake_probs = [torch.softmax(logits, dim=1) for logits in fake_logits]
                G_opt.zero_grad()
                loss_g = loss_f.g_loss(codes_one_hot, fake_probs[::-1])
                loss_g.backward()
                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(loss_g)
                d_loss_sum += float(loss_d)
                # p_fake_sum += float(p_fake)
                # p_real_sum += float(p_real)

                if args.attn:
                    for i, (d_gamma, g_gamma) in enumerate(zip(D_gammas, G_gammas)):
                        # Detach: summing the parameters directly would chain an
                        # autograd graph across every step until the next reset.
                        D_gammas_sum[i] = D_gammas_sum.get(i, 0) + d_gamma.detach()
                        G_gammas_sum[i] = G_gammas_sum.get(i, 0) + g_gamma.detach()

                if global_step % args.log_every == 0:
                    # Steps accumulated since the last reset (the sums are also
                    # reset at the start of every epoch).
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)
                    if args.attn:
                        D_gammas_avg = repr_list(
                            [gamma / cur_step for gamma in D_gammas_sum.values()])
                        G_gammas_avg = repr_list(
                            [gamma / cur_step for gamma in G_gammas_sum.values()])
                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] ' +
                        #f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, ' +
                        #f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, ' +
                        (f'd_attn_gammas: [{D_gammas_avg}], g_attn_gammas: [{G_gammas_avg}], '
                         if args.attn else '') +
                        f'batches/s: {batches_per_sec:02.2f}')
                    g_loss_sum = d_loss_sum = 0
                    D_gammas_sum = OrderedDict()
                    G_gammas_sum = OrderedDict()
                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    # Sampling and VQ decoding need no gradients. Without no_grad
                    # the logits lists keep their graphs until the next sampling step.
                    with torch.no_grad():
                        sample_codes = [logits.argmax(1) for logits in G_sample(z_sample)]
                        # Generator outputs are in reversed level order.
                        sample_logits = [
                            vq_model.decode_code(sample_code, depth - 1 - i)
                            for i, sample_code in enumerate(sample_codes)
                        ]
                        samples_decoded = [decode(logits.argmax(-1)) for logits in sample_logits]
                        reals_logits = [
                            vq_model.decode_code(code[:args.n_sample], i)
                            for i, code in enumerate(codes)
                        ]
                        reals_decoded = [decode(logits.argmax(-1)) for logits in reals_logits]
                    # One paragraph per sample, one line per level.
                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text('\n\n'.join(
                        map(lambda g: '\n'.join(g), zip(*samples_decoded))))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text('\n\n'.join(
                        map(lambda g: '\n'.join(g), zip(*reals_decoded))))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    """Train a GAN on per-level continuous vectors in a pretrained VQ model's code space.

    Real per-level tensors come from ``LMDBDataset(args.vq_dataset)``. The
    generator produces one tensor per level in the quantizers' embedding
    dimension (``quant_dim``), in reversed level order. The two are compared
    under a WGAN-GP loss. The VQ model's quantizers are put in eval mode and
    used to:

    * log the quantization diff of the generated tensors, and
    * turn real and generated tensors into codes for text samples
      (``vq_model.decode_code`` + ``BPEDataset(...).seq_to_text``).

    Writes ``log.txt``, model summaries, periodic text samples and per-epoch
    checkpoints below the run directory returned by ``setup_run``. A
    ``KeyboardInterrupt`` stops training quietly.

    Args:
        args: parsed command-line namespace. Attributes read here: ``run_name``,
            ``device``, ``original_dataset``, ``vq_model``, ``vq_dataset``,
            ``batch_size``, ``num_workers``, ``latent_size``, ``attn``,
            ``use_ema``, ``data_parallel``, ``g_lr``, ``d_lr``, ``n_sample``,
            ``epochs``, ``log_every`` and ``sample_every``.
    """
    result_dir = setup_run(args.run_name, create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    decode = BPEDataset(args.original_dataset).seq_to_text

    vq_model = torch.load(args.vq_model).to(device)
    quantizers = vq_model.quantize
    for q in quantizers:
        q.eval()
    # quantizers[i] belongs to level i: reals[i] is quantized with it and then
    # decoded at level i. The generator emits levels in reverse order (see
    # `fake_out[::-1]` in the losses and `depth - 1 - i` when decoding samples),
    # so fake_out[j] needs quantizers[depth - 1 - j]. Pairing fakes with
    # quantizers[j] would quantize them with the wrong codebook.
    fake_quantizers = list(quantizers)[::-1]

    depth = vq_model.depth
    num_classes = vq_model.quantize[0].n_embed
    quant_dim = vq_model.quantize[0].dim

    dataset = LMDBDataset(args.vq_dataset)
    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    if args.attn:
        G_attn = [False, False, True, False]
        D_attn = [False, True, False, False]
    else:
        G_attn = False
        D_attn = False

    G = Generator(args.latent_size, [512, 512, 256, 128], quant_dim, attn=G_attn).to(device)
    D = Discriminator([128, 256, 512, 512], quant_dim, attn=D_attn).to(device)

    if args.attn:
        # `gamma` of every SelfAttention module, collected for logging only.
        D_gammas = list(map(attrgetter('gamma'),
                            filter(lambda m: isinstance(m, SelfAttention), D.modules())))
        G_gammas = list(map(attrgetter('gamma'),
                            filter(lambda m: isinstance(m, SelfAttention), G.modules())))

    #G.apply(init_weights)
    #G.apply(apply_spectral_norm)
    #D.apply(init_weights)
    #D.apply(apply_spectral_norm)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        # Moving-average copy of G (updated after every G step), used for sampling.
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    # Unwrapped generator handle for the EMA update (G may become DataParallel).
    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    loss_f = WGAN_GP(D)
    #loss_f = RelativisticAverageHingeLoss(D)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0
            p_fake_sum = 0
            p_real_sum = 0
            vq_diffs_sum = [0] * depth
            D_gammas_sum = OrderedDict()
            G_gammas_sum = OrderedDict()
            start_time = time.time()
            cur_step = 0

            for step, reals in enumerate(batches):
                #reals_code = [code.to(device) for code in reals_code]
                #reals_embed = [q.embed_code(c).transpose(1, 2) for q, c in zip(quantizers, reals_code)]
                reals = [real.to(device) for real in reals]
                batch_size = reals[0].size(0)
                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_out = G(z)
                fake_out = [t.detach() for t in fake_out]
                D_opt.zero_grad()
                loss_d, p_fake, p_real = loss_f.d_loss(reals, fake_out[::-1])
                loss_d.backward()
                D_opt.step()

                # Optimize generator
                fake_out = G(z)
                # Quantizer outputs are (quantized, diff, codes); only the diff is
                # used, for logging (the loss term below is disabled).
                _, vq_diffs, _ = zip(*[
                    q(o.transpose(1, 2)) for q, o in zip(fake_quantizers, fake_out)
                ])
                #fake_out = [t.transpose(1, 2) for t in fake_out]
                G_opt.zero_grad()
                loss_g = loss_f.g_loss(reals, fake_out[::-1])
                #loss_g += 0.01 * sum(vq_diffs)
                loss_g.backward()
                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(loss_g)
                d_loss_sum += float(loss_d)
                p_fake_sum += float(p_fake)
                p_real_sum += float(p_real)
                vq_diffs_sum = [old + float(new) for old, new in zip(vq_diffs_sum, vq_diffs)]

                if args.attn:
                    for i, (d_gamma, g_gamma) in enumerate(zip(D_gammas, G_gammas)):
                        # Detach: summing the parameters directly would chain an
                        # autograd graph across every step until the next reset.
                        D_gammas_sum[i] = D_gammas_sum.get(i, 0) + d_gamma.detach()
                        G_gammas_sum[i] = G_gammas_sum.get(i, 0) + g_gamma.detach()

                if global_step % args.log_every == 0:
                    # Steps accumulated since the last reset (the sums are also
                    # reset at the start of every epoch).
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)
                    if args.attn:
                        D_gammas_avg = repr_list([gamma / cur_step for gamma in D_gammas_sum.values()])
                        G_gammas_avg = repr_list([gamma / cur_step for gamma in G_gammas_sum.values()])
                    vq_diffs_avg = repr_list([diff / cur_step for diff in vq_diffs_sum])
                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] ' +
                        # f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, ' +
                        f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, ' +
                        (f'd_attn_gammas: [{D_gammas_avg}], g_attn_gammas: [{G_gammas_avg}], '
                         if args.attn else '') +
                        f'vq_diffs: [{vq_diffs_avg}], ' +
                        f'batches/s: {batches_per_sec:02.2f}')
                    g_loss_sum = d_loss_sum = 0
                    p_fake_sum = p_real_sum = 0
                    vq_diffs_sum = [0] * depth
                    D_gammas_sum = OrderedDict()
                    G_gammas_sum = OrderedDict()
                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    # Sampling, quantization and VQ decoding need no gradients.
                    with torch.no_grad():
                        sample_out = G_sample(z_sample)
                        sample_codes = [
                            q(o.transpose(1, 2))[2]
                            for q, o in zip(fake_quantizers, sample_out)
                        ]
                        sample_logits = [
                            vq_model.decode_code(sample_code, depth - 1 - i)
                            for i, sample_code in enumerate(sample_codes)
                        ]
                        samples_decoded = [decode(logits.argmax(-1)) for logits in sample_logits]
                        real_codes = [
                            q(o.transpose(1, 2))[2] for q, o in zip(quantizers, reals)
                        ]
                        reals_logits = [
                            vq_model.decode_code(code[:args.n_sample], i)
                            for i, code in enumerate(real_codes)
                        ]
                        reals_decoded = [decode(logits.argmax(-1)) for logits in reals_logits]
                    # One paragraph per sample, one line per level.
                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n\n'.join(map(lambda g: '\n'.join(g), zip(*samples_decoded))))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n\n'.join(map(lambda g: '\n'.join(g), zip(*reals_decoded))))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    """Train a multi-scale GAN on BPE token sequences.

    Real sequences from ``BPEDataset(args.original_dataset)`` are one-hot
    encoded. They are also passed through a trainable ``TransformNetwork`` ``T``
    whose outputs supply the discriminator's remaining per-scale inputs. ``T``
    is optimised together with ``D`` by ``D_opt``. The generator emits
    per-scale features plus final token logits; the logits are softmaxed, and
    the list is reversed before it reaches ``D``. Losses come from ``WGAN_GP``.

    Writes ``log.txt``, model summaries, periodic text samples and per-epoch
    checkpoints of ``G`` and ``D`` below the run directory returned by
    ``setup_run``. A ``KeyboardInterrupt`` stops training quietly.

    Args:
        args: parsed command-line namespace. Attributes read here: ``run_name``,
            ``device``, ``original_dataset``, ``batch_size``, ``num_workers``,
            ``latent_size``, ``attn``, ``use_ema``, ``data_parallel``,
            ``g_lr``, ``d_lr``, ``n_sample``, ``epochs``, ``log_every`` and
            ``sample_every``.
    """
    result_dir = setup_run(args.run_name, create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    dataset = BPEDataset(args.original_dataset)

    # The dataset's sequence length must be a power of two.
    depth = math.log2(dataset.seq_length)
    assert int(depth) == depth
    depth = int(depth)

    vocab_size = dataset.vocab_size

    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    inter_dim = 32
    # Generator output widths per scale (final scale = vocabulary logits) and
    # the matching discriminator input widths in reversed order.
    extract_dims = [inter_dim] * 4 + [vocab_size]
    inject_dims = [vocab_size] + [inter_dim] * 4

    G = Generator(args.latent_size, [128, 128, 128, 128, 128], extract_dims,
                  attn=args.attn).to(device)
    D = Discriminator([128, 128, 128, 128, 128], inject_dims, attn=args.attn).to(device)
    T = TransformNetwork(vocab_size, [64, 64, 64, 64], inter_dim).to(device)

    if args.attn:
        # `gamma` of every SelfAttention module, collected for logging only.
        D_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), D.modules())))
        G_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), G.modules())))

    #G.apply(init_weights)
    #G.apply(apply_spectral_norm)
    #D.apply(init_weights)
    #D.apply(apply_spectral_norm)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        # Moving-average copy of G (updated after every G step), used for sampling.
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    # Unwrapped generator handle for the EMA update (G may become DataParallel).
    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.9))
    # T is trained with the discriminator's optimizer.
    D_opt = torch.optim.Adam(list(D.parameters()) + list(T.parameters()),
                             lr=args.d_lr,
                             betas=(0.5, 0.9))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    loss_f = WGAN_GP(D)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0
            p_fake_sum = 0
            p_real_sum = 0
            D_gammas_sum = OrderedDict()
            G_gammas_sum = OrderedDict()
            start_time = time.time()
            cur_step = 0

            for step, reals in enumerate(batches):
                reals = reals.to(device)
                reals_one_hot = F.one_hot(reals, num_classes=vocab_size).transpose(
                    1, 2).type(torch.float)
                batch_size = reals.size(0)
                reals_t = T(reals_one_hot)
                reals_input = [reals_one_hot] + reals_t
                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_out = G(z)
                fake_probs = torch.softmax(fake_out[-1], dim=1)
                fake_input = (fake_out[:-1] + [fake_probs])[::-1]
                fake_input = [t.detach() for t in fake_input]
                D_opt.zero_grad()
                loss_d, p_fake, p_real = loss_f.d_loss(reals_input, fake_input)
                loss_d.backward()
                D_opt.step()

                # Optimize generator
                fake_out = G(z)
                fake_probs = torch.softmax(fake_out[-1], dim=1)
                fake_input = (fake_out[:-1] + [fake_probs])[::-1]
                G_opt.zero_grad()
                # T's graph was consumed by loss_d.backward(); the generator step
                # must not backpropagate into T.
                reals_input = [t.detach() for t in reals_input]
                loss_g = loss_f.g_loss(reals_input, fake_input)
                loss_g.backward()
                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(loss_g)
                d_loss_sum += float(loss_d)
                p_fake_sum += float(p_fake)
                p_real_sum += float(p_real)

                if args.attn:
                    for i, (d_gamma, g_gamma) in enumerate(zip(D_gammas, G_gammas)):
                        # Detach: summing the parameters directly would chain an
                        # autograd graph across every step until the next reset.
                        D_gammas_sum[i] = D_gammas_sum.get(i, 0) + d_gamma.detach()
                        G_gammas_sum[i] = G_gammas_sum.get(i, 0) + g_gamma.detach()

                if global_step % args.log_every == 0:
                    # Steps accumulated since the last reset (the sums are also
                    # reset at the start of every epoch).
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)
                    if args.attn:
                        D_gammas_avg = repr_list(
                            [gamma / cur_step for gamma in D_gammas_sum.values()])
                        G_gammas_avg = repr_list(
                            [gamma / cur_step for gamma in G_gammas_sum.values()])
                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] ' +
                        #f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, ' +
                        f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, ' +
                        (f'd_attn_gammas: [{D_gammas_avg}], g_attn_gammas: [{G_gammas_avg}], '
                         if args.attn else '') +
                        f'batches/s: {batches_per_sec:02.2f}')
                    g_loss_sum = d_loss_sum = 0
                    p_fake_sum = p_real_sum = 0
                    D_gammas_sum = OrderedDict()
                    G_gammas_sum = OrderedDict()
                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    # Sampling needs no gradients.
                    with torch.no_grad():
                        samples_decoded = dataset.seq_to_text(
                            G_sample(z_sample)[-1].argmax(1))
                    reals_decoded = dataset.seq_to_text(reals[:args.n_sample])
                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n'.join(samples_decoded))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n'.join(reals_decoded))

                cur_step += 1
                global_step += 1

            # NOTE(review): T is trained jointly with D but not checkpointed here;
            # confirm whether it should be saved alongside D.
            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    """Train a GAN directly on one-hot byte-level text.

    ``args.dataset`` is wrapped in a ``ByteLevelTextDataset`` with sequence
    length 32, and each batch is one-hot encoded over the vocabulary. Generator
    outputs are turned into token distributions with a temperature softmax
    (``tau = 0.1``) and compared with the one-hot reals by a
    ``UnetDiscriminator`` under a WGAN-GP loss. Both networks get spectral norm
    and ``init_weights``.

    Writes ``log.txt``, model summaries, periodic text samples and per-epoch
    checkpoints below the run directory returned by ``setup_run``. A
    ``KeyboardInterrupt`` stops training quietly.

    Args:
        args: parsed command-line namespace. Attributes read here: ``run_name``,
            ``device``, ``dataset``, ``batch_size``, ``num_workers``,
            ``latent_size``, ``use_ema``, ``data_parallel``, ``g_lr``, ``d_lr``,
            ``n_sample``, ``epochs``, ``log_every`` and ``sample_every``.
    """
    result_dir = setup_run(args.run_name, create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    seq_length = 32

    dataset = ByteLevelTextDataset(args.dataset, seq_len=seq_length)

    vocab_size = dataset.vocab_size

    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    # spiegel model
    # G = Generator(args.latent_size, [256, 256, 128, 64, 64], out_dim=inter_dim).to(device)
    # D = UnetDiscriminator(64, max_channel=256, depth=5, in_dim=inter_dim).to(device)

    G = Generator(args.latent_size, [512, 512, 512, 512], out_dim=vocab_size).to(device)
    D = UnetDiscriminator(512, max_channel=512, depth=4, in_dim=vocab_size).to(device)

    G.apply(apply_spectral_norm)
    D.apply(apply_spectral_norm)

    G.apply(init_weights)
    D.apply(init_weights)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        # Moving-average copy of G (updated after every G step), used for sampling.
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    # Unwrapped generator handle for the EMA update (G may become DataParallel).
    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    #loss_f = RelativisticAverageHingeLoss(D)
    #loss_f = GANLoss(D)
    loss_f = WGAN_GP(D)

    # Softmax temperature applied to the generator's logits (loop invariant).
    tau = 0.1

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0
            p_fake_g_sum = 0
            p_fake_l_sum = 0
            p_real_g_sum = 0
            p_real_l_sum = 0
            start_time = time.time()
            cur_step = 0

            for step, reals in enumerate(batches):
                reals = reals.to(device)
                reals_one_hot = F.one_hot(reals, num_classes=vocab_size).transpose(
                    1, 2).to(torch.float)
                batch_size = reals.size(0)
                z = torch.randn(batch_size, args.latent_size, 1).to(device)

                # Optimize the discriminator
                fake_out = G(z)
                fake_out = torch.softmax(fake_out / tau, dim=1)
                D_opt.zero_grad()
                d_loss, p_real_g, p_real_l, p_fake_g, p_fake_l = loss_f.loss_d(
                    reals_one_hot, fake_out.detach())
                d_loss.backward()
                D_opt.step()

                # Optimize generator
                fake_out = G(z)
                fake_out = torch.softmax(fake_out / tau, dim=1)
                G_opt.zero_grad()
                g_loss = loss_f.loss_g(reals_one_hot, fake_out)
                g_loss.backward()
                G_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(g_loss)
                d_loss_sum += float(d_loss)
                p_fake_g_sum += float(p_fake_g)
                p_fake_l_sum += float(p_fake_l)
                p_real_g_sum += float(p_real_g)
                p_real_l_sum += float(p_real_l)

                if global_step % args.log_every == 0:
                    # Steps accumulated since the last reset (the sums are also
                    # reset at the start of every epoch).
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)
                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, ' +
                        f'p_fake_g: {p_fake_g_sum / cur_step:.5f}, p_fake_l: {p_fake_l_sum / cur_step:.5f}, ' +
                        f'p_real_g: {p_real_g_sum / cur_step:.5f}, p_real_l: {p_real_l_sum / cur_step:.5f}, ' +
                        f'batches/s: {batches_per_sec:02.2f}')
                    g_loss_sum = d_loss_sum = 0
                    p_fake_g_sum = 0
                    p_fake_l_sum = 0
                    p_real_g_sum = 0
                    p_real_l_sum = 0
                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    # Sampling needs no gradients.
                    with torch.no_grad():
                        samples = G_sample(z_sample).argmax(1)
                    (sample_dir / f'fakes_{global_step:06d}.txt').write_text(
                        '\n'.join(dataset.seq_to_text(samples)))
                    (sample_dir / f'reals_{global_step:06d}.txt').write_text(
                        '\n'.join(dataset.seq_to_text(reals[:args.n_sample])))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass
def main(args):
    """Train a GAN on a single level of a pretrained VQ model's codes with ``multinomial_bgan_loss``.

    Codes come from ``LMDBDataset(args.vq_dataset)``. The level used is chosen
    by ``current_grow_index``, which is currently fixed at 3; the progressive
    schedule is commented out. Each step runs one generator forward pass.
    ``multinomial_bgan_loss`` returns both losses, which are backpropagated
    together before both optimizers step. Samples are decoded with
    ``vq_model.decode_code`` and ``BPEDataset(args.original_dataset).seq_to_text``.

    Writes ``log.txt``, model summaries, periodic text samples and per-epoch
    checkpoints below the run directory returned by ``setup_run``. A
    ``KeyboardInterrupt`` stops training quietly.

    Args:
        args: parsed command-line namespace. Attributes read here: ``run_name``,
            ``device``, ``original_dataset``, ``vq_model``, ``vq_dataset``,
            ``batch_size``, ``num_workers``, ``latent_size``, ``attn``,
            ``use_ema``, ``data_parallel``, ``g_lr``, ``d_lr``, ``n_sample``,
            ``epochs``, ``steps_per_stage``, ``n_mc_samples``,
            ``mc_sample_tau``, ``log_every`` and ``sample_every``.
    """
    result_dir = setup_run(args.run_name, create_dirs=['checkpoints', 'samples'])
    setup_logging(result_dir / 'log.txt')

    logging.info(args)

    device = get_default_device(args.device)

    sample_dir = result_dir / 'samples'
    checkpoint_dir = result_dir / 'checkpoints'

    decode = BPEDataset(args.original_dataset).seq_to_text

    vq_model = torch.load(args.vq_model).to(device)
    depth = vq_model.depth
    num_classes = vq_model.quantize[0].n_embed

    dataset = LMDBDataset(args.vq_dataset)
    batches = DataLoader(dataset,
                         args.batch_size,
                         shuffle=True,
                         pin_memory=True,
                         num_workers=args.num_workers)

    G = Generator(args.latent_size, [256, 256, 256, 256], num_classes,
                  attn=args.attn).to(device)
    D = Discriminator([256, 256, 256, 256], num_classes, use_embeddings=False,
                      attn=args.attn).to(device)

    if args.attn:
        # `gamma` of every SelfAttention module, collected for logging only.
        D_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), D.modules())))
        G_gammas = list(
            map(attrgetter('gamma'),
                filter(lambda m: isinstance(m, SelfAttention), G.modules())))

    G.apply(init_weights)
    G.apply(apply_spectral_norm)
    D.apply(init_weights)
    D.apply(apply_spectral_norm)

    G.train()
    D.train()

    (result_dir / 'G.txt').write_text(str(G))
    (result_dir / 'D.txt').write_text(str(D))

    if args.use_ema:
        # Moving-average copy of G (updated after every step), used for sampling.
        G_shadow = copy.deepcopy(G)
        G_sample = G_shadow
        update_average(G_shadow, G, beta=0.0)
    else:
        G_sample = G

    # Unwrapped generator handle for the EMA update (G may become DataParallel).
    G_orig = G

    if args.data_parallel:
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)

    G_opt = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    D_opt = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))

    z_sample = torch.randn(args.n_sample, args.latent_size, 1).to(device)

    # NOTE(review): code_noise.p is updated and logged every step, but the
    # noise itself is never applied (the call below is commented out).
    code_noise = CategoricalNoise(num_classes, 0.0)

    try:
        global_step = 0
        for epoch in range(args.epochs):
            g_loss_sum = 0
            d_loss_sum = 0
            p_fake_sum = 0
            p_real_sum = 0
            D_gammas_sum = OrderedDict()
            G_gammas_sum = OrderedDict()
            start_time = time.time()
            cur_step = 0

            for step, codes in enumerate(batches):
                #current_grow_index = min(global_step // args.steps_per_stage, depth - 1)
                # NOTE(review): fixed grow index; codes[-4] and the sampling depth
                # below assume depth >= 4.
                current_grow_index = 3
                interpol = (global_step % args.steps_per_stage) / args.steps_per_stage

                code = codes[-(current_grow_index + 1)]
                code = code.to(device)

                code_noise.p = 0.3 * (1.0 - min(1.0, interpol * 2))
                #code = code_noise(code)

                z = torch.randn(code.size(0), args.latent_size, 1).to(device)

                fake_logits = G(z, extract_at_grow_index=current_grow_index)

                G_opt.zero_grad()
                D_opt.zero_grad()
                loss_d, loss_g, p_fake, p_real = multinomial_bgan_loss(
                    D,
                    fake_logits,
                    code,
                    n_samples=args.n_mc_samples,
                    tau=args.mc_sample_tau)
                # One backward pass for both losses.
                # NOTE(review): whether loss_d can also push gradients into G (or
                # loss_g into D) depends on multinomial_bgan_loss -- confirm.
                torch.autograd.backward([loss_d, loss_g])
                G_opt.step()
                D_opt.step()

                if args.use_ema:
                    update_average(G_shadow, G_orig, beta=0.999)

                g_loss_sum += float(loss_g)
                d_loss_sum += float(loss_d)
                p_fake_sum += float(p_fake)
                p_real_sum += float(p_real)

                if args.attn:
                    for i, (d_gamma, g_gamma) in enumerate(zip(D_gammas, G_gammas)):
                        # Detach: summing the parameters directly would chain an
                        # autograd graph across every step until the next reset.
                        D_gammas_sum[i] = D_gammas_sum.get(i, 0) + d_gamma.detach()
                        G_gammas_sum[i] = G_gammas_sum.get(i, 0) + g_gamma.detach()

                if global_step % args.log_every == 0:
                    # Steps accumulated since the last reset (the sums are also
                    # reset at the start of every epoch).
                    cur_step = min(step + 1, args.log_every)
                    batches_per_sec = cur_step / (time.time() - start_time)
                    if args.attn:
                        D_gammas_avg = repr_list(
                            [gamma / cur_step for gamma in D_gammas_sum.values()])
                        G_gammas_avg = repr_list(
                            [gamma / cur_step for gamma in G_gammas_sum.values()])
                    logging.info(
                        f'[EPOCH {epoch + 1:03d}] [{step:05d} / {len(batches):05d}] ' +
                        f'grow_index: {current_grow_index}/{depth - 1}, ' +
                        f'loss_d: {d_loss_sum / cur_step:.5f}, loss_g: {g_loss_sum / cur_step:.5f}, ' +
                        f'p_fake: {p_fake_sum / cur_step:.5f}, p_real: {p_real_sum / cur_step:.5f}, ' +
                        (f'd_attn_gammas: [{D_gammas_avg}], g_attn_gammas: [{G_gammas_avg}], '
                         if args.attn else '') +
                        f'batches/s: {batches_per_sec:02.2f}, code_noise_p: {code_noise.p:.2f}'
                    )
                    g_loss_sum = d_loss_sum = 0
                    p_fake_sum = p_real_sum = 0
                    D_gammas_sum = OrderedDict()
                    G_gammas_sum = OrderedDict()
                    start_time = time.time()

                if global_step % args.sample_every == 0:
                    current_depth = depth - 1 - current_grow_index
                    # Sampling and VQ decoding need no gradients. Without no_grad
                    # the decoded logits keep their graphs until the next sampling step.
                    with torch.no_grad():
                        samples_codes = G_sample(
                            z_sample, extract_at_grow_index=current_grow_index).argmax(1)
                        samples_logits = vq_model.decode_code(samples_codes, current_depth)
                        samples_decoded = decode(samples_logits.argmax(-1))
                        reals_logits = vq_model.decode_code(code[:args.n_sample], current_depth)
                        reals_decoded = decode(reals_logits.argmax(-1))
                    (sample_dir / f'fakes_{current_grow_index}_{global_step:06d}.txt'
                     ).write_text('\n'.join(samples_decoded))
                    (sample_dir / f'reals_{current_grow_index}_{global_step:06d}.txt'
                     ).write_text('\n'.join(reals_decoded))

                cur_step += 1
                global_step += 1

            torch.save(G, str(checkpoint_dir / f'G_{global_step:06d}.pth'))
            torch.save(D, str(checkpoint_dir / f'D_{global_step:06d}.pth'))
    except KeyboardInterrupt:
        pass