def __init__(self, _device):
    """Build the GAN pair (discriminator + generator) on the given device.

    Loss criteria and optimizers are left as ``None`` here; they are
    expected to be attached later by the training code.
    """
    self.device = _device
    self.batch_size = 64
    self.resolution = 28

    # Filled in later by the caller / training setup.
    self.d_criterion = None
    self.d_optimizer = None
    self.g_criterion = None
    self.g_optimizer = None

    disc_config = dict(
        num_layers=5,
        activations=["relu", "relu", "relu", "sigmoid"],
        device=_device,
        num_nodes=[1, 64, 128, 64, 1],
        kernels=[5, 5, 3],
        strides=[2, 2, 2],
        dropouts=[.25, .25, 0],
        batch_size=64,
    )
    self.discriminator = Discriminator(**disc_config)

    # Push one random dummy batch through the discriminator so that its
    # lazily-built output layer gets initialized before training starts.
    dummy_batch = torch.rand(
        size=[self.batch_size, 1, self.resolution, self.resolution])
    self.discriminator(dummy_batch)

    gen_config = dict(
        num_layers=6,
        activations=["relu", "relu", "relu", "relu", "tanh"],
        num_nodes=[1, 64, 128, 64, 64, 1],
        kernels=[3, 3, 3, 3],
        strides=[1, 1, 1, 1],
        batch_norms=[1, 1, 1, 0],
        upsamples=[1, 1, 0, 0],
        dropouts=[.25, .25, 0],
    )
    self.generator = Generator(**gen_config)
def __init__(self, verbosity=True, latent_dim=100):
    """Assemble the GAN: a standalone discriminator plus a combined
    generator->discriminator model for adversarial training.

    Args:
        verbosity: forwarded to the sub-model builders (model summaries).
        latent_dim: size of the generator's noise input vector.

    Fix: the original set ``discriminator.trainable = False`` *before*
    compiling the discriminator. In Keras the ``trainable`` flag is baked
    in at compile time, so the discriminator could never train at all.
    The canonical pattern is: compile the discriminator while trainable,
    then freeze it only for the combined model.
    """
    img_shape = (128, 128, 3)
    the_disc = Discriminator()
    the_gen = Generator()
    self.discriminator = the_disc.define_discriminator(
        verb=verbosity, sample_shape=img_shape)
    self.generator = the_gen.define_generator(verb=verbosity,
                                              sample_shape=img_shape,
                                              latent_dim=latent_dim)

    optimizer = Adam(0.0002, 0.5)

    # Compile the discriminator FIRST, while its weights are trainable,
    # so its own training step actually updates it.
    self.discriminator.compile(
        loss=['binary_crossentropy', 'categorical_crossentropy'],
        loss_weights=[0.5, 0.5],
        optimizer=optimizer,
        metrics=['accuracy'])

    # Freeze the discriminator only inside the combined model: when the
    # generator trains, discriminator weights must stay fixed.
    self.discriminator.trainable = False

    noise = Input(shape=(latent_dim, ))
    img = self.generator(noise)
    # Discriminator has two outputs (validity, class); only validity
    # drives the generator's adversarial loss.
    valid, _ = self.discriminator(img)
    self.combined = Model(noise, valid)
    self.combined.compile(loss=['binary_crossentropy'],
                          optimizer=optimizer)
def __init__(self, damsm, device=DEVICE):
    """Create generator/discriminator and wrap a frozen, pre-trained DAMSM.

    The DAMSM text/image encoders are used only to compute losses, so
    they are put in eval mode and excluded from optimization.
    """
    self.device = device
    self.gen = Generator(device)
    self.disc = Discriminator(device)

    self.damsm = damsm.to(device)
    for encoder in (self.damsm.txt_enc, self.damsm.img_enc):
        encoder.eval()
        freeze_params_(encoder)

    self.gen.apply(init_weights)
    self.disc.apply(init_weights)

    adam = torch.optim.Adam
    self.gen_optimizer = adam(self.gen.parameters(),
                              lr=GENERATOR_LR,
                              betas=(0.5, 0.999))

    # One optimizer per resolution-specific discriminator head.
    self.discriminators = [self.disc.d64, self.disc.d128, self.disc.d256]
    self.disc_optimizers = [
        adam(d.parameters(), lr=DISCRIMINATOR_LR, betas=(0.5, 0.999))
        for d in self.discriminators
    ]
def main():
    """SeqGAN training driver: pretrain G and D, then adversarial training
    with policy-gradient rollouts.

    Fix: the pretrain-generator training-loss print used ``epoch`` while
    every other progress print in this function uses ``epoch + 1``; made
    the numbering consistent. Also simplified the redundant
    ``True if ... else False`` boolean.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda',
                        default=False,
                        action='store_true',
                        help='Enable CUDA')
    args = parser.parse_args()
    # CUDA only when requested AND actually available.
    use_cuda = args.cuda and torch.cuda.is_available()

    random.seed(SEED)
    np.random.seed(SEED)

    netG = Generator(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, G_LR, use_cuda)
    netD = Discriminator(VOCAB_SIZE, D_EMB_SIZE, D_NUM_CLASSES,
                         D_FILTER_SIZES, D_NUM_FILTERS, DROPOUT, D_LR,
                         D_L2_REG, use_cuda)
    oracle = Oracle(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)

    # generating synthetic data (disabled: REAL_FILE is assumed to exist)
    # print('Generating data...')
    # generate_samples(oracle, BATCH_SIZE, GENERATED_NUM, REAL_FILE)

    # ---- pretrain generator (MLE) ----
    gen_set = GeneratorDataset(REAL_FILE)
    genloader = DataLoader(dataset=gen_set,
                           batch_size=BATCH_SIZE,
                           shuffle=True)
    print('\nPretraining generator...\n')
    for epoch in range(PRE_G_EPOCHS):
        loss = netG.pretrain(genloader)
        print('Epoch {} pretrain generator training loss: {}'.format(
            epoch + 1, loss))
        # Validate against the oracle on freshly generated samples.
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} pretrain generator val loss: {}'.format(
            epoch + 1, loss))

    # ---- pretrain discriminator ----
    print('\nPretraining discriminator...\n')
    for epoch in range(D_STEPS):
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
        dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
        disloader = DataLoader(dataset=dis_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        for _ in range(K_STEPS):
            loss = netD.dtrain(disloader)
        print('Epoch {} pretrain discriminator training loss: {}'.format(
            epoch + 1, loss))

    # ---- adversarial training with rollout-based rewards ----
    rollout = Rollout(netG,
                      update_rate=ROLLOUT_UPDATE_RATE,
                      rollout_num=ROLLOUT_NUM)
    print('\n#####################################################')
    print('Adversarial training...\n')
    for epoch in range(TOTAL_EPOCHS):
        for _ in range(G_STEPS):
            netG.pgtrain(BATCH_SIZE, SEQUENCE_LEN, rollout, netD)
        for d_step in range(D_STEPS):
            # train discriminator on fresh fakes
            generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
            dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
            disloader = DataLoader(dataset=dis_set,
                                   batch_size=BATCH_SIZE,
                                   shuffle=True)
            for k_step in range(K_STEPS):
                loss = netD.dtrain(disloader)
                print(
                    'D_step {}, K-step {} adversarial discriminator training loss: {}'
                    .format(d_step + 1, k_step + 1, loss))
        rollout.update_params()
        # Oracle validation for this adversarial epoch.
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} adversarial generator val loss: {}'.format(
            epoch + 1, loss))
class AttnGAN:
    """Attention-driven text-to-image GAN wrapper.

    Couples a multi-stage Generator (64/128/256 px outputs) with three
    resolution-specific discriminators, and uses a frozen pre-trained
    DAMSM (text + image encoders) for word/sentence matching losses.
    """

    def __init__(self, damsm, device=DEVICE):
        self.gen = Generator(device)
        self.disc = Discriminator(device)
        # DAMSM is only used for loss computation: eval mode + frozen params.
        self.damsm = damsm.to(device)
        self.damsm.txt_enc.eval(), self.damsm.img_enc.eval()
        freeze_params_(self.damsm.txt_enc), freeze_params_(self.damsm.img_enc)
        self.device = device
        self.gen.apply(init_weights), self.disc.apply(init_weights)
        self.gen_optimizer = torch.optim.Adam(self.gen.parameters(),
                                              lr=GENERATOR_LR,
                                              betas=(0.5, 0.999))
        # One optimizer per resolution-specific discriminator head.
        self.discriminators = [self.disc.d64, self.disc.d128, self.disc.d256]
        self.disc_optimizers = [
            torch.optim.Adam(d.parameters(),
                             lr=DISCRIMINATOR_LR,
                             betas=(0.5, 0.999))
            for d in self.discriminators
        ]

    def train(self, dataset, epoch, batch_size=GAN_BATCH,
              test_sample_every=5, hist_avg=False, evaluator=None):
        """Run the adversarial training loop for `epoch` epochs.

        Returns a `metrics` dict of per-epoch losses/accuracies (and
        IS/FID scores when an `evaluator` factory is supplied).
        """
        start_time = time.strftime("%Y-%m-%d-%H-%M", time.gmtime())
        os.makedirs(f'{OUT_DIR}/{start_time}')
        if hist_avg:
            # Moving average of generator weights (historical averaging).
            avg_g_params = deepcopy(list(p.data for p in self.gen.parameters()))
        loader_config = {
            'batch_size': batch_size,
            'shuffle': True,
            'drop_last': True,
            'collate_fn': dataset.collate_fn
        }
        train_loader = DataLoader(dataset.train, **loader_config)
        metrics = {
            'IS': [],
            'FID': [],
            'loss': {
                'g': [],
                'd': []
            },
            'accuracy': {
                'real': [],
                'fake': [],
                'mismatched': [],
                'unconditional_real': [],
                'unconditional_fake': []
            }
        }
        if evaluator is not None:
            # `evaluator` is a factory; instantiate it with this run's config.
            evaluator = evaluator(dataset, self.damsm.img_enc.inception_model,
                                  batch_size, self.device)
        # Noise buffer reused (re-randomized) every batch.
        noise = torch.FloatTensor(batch_size, D_Z).to(self.device)
        gen_updates = 0
        self.disc.train()
        for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True):
            self.gen.train(), self.disc.train()
            g_loss = 0
            w_loss = 0
            s_loss = 0
            kl_loss = 0
            # Per-stage accumulators (one slot per 64/128/256 discriminator).
            g_stage_loss = np.zeros(3, dtype=float)
            d_loss = np.zeros(3, dtype=float)
            real_acc = np.zeros(3, dtype=float)
            fake_acc = np.zeros(3, dtype=float)
            mismatched_acc = np.zeros(3, dtype=float)
            uncond_real_acc = np.zeros(3, dtype=float)
            uncond_fake_acc = np.zeros(3, dtype=float)
            disc_skips = np.zeros(3, dtype=int)
            train_pbar = tqdm(train_loader,
                              desc='Training',
                              leave=False,
                              dynamic_ncols=True)
            for batch in train_pbar:
                real_imgs = [batch['img64'], batch['img128'], batch['img256']]
                # Text encoding is frozen DAMSM: no gradients needed.
                with torch.no_grad():
                    word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                # Mask out padding positions (END_TOKEN) in attention.
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]
                # Generate images
                noise.data.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)
                # Discriminator loss (with label smoothing)
                batch_d_loss, batch_real_acc, batch_fake_acc, batch_mismatched_acc, batch_uncond_real_acc, batch_uncond_fake_acc, batch_disc_skips = self.discriminator_step(
                    real_imgs, generated, sent_embs, 0.1)
                d_grad_norm = [grad_norm(d) for d in self.discriminators]
                d_loss += batch_d_loss
                real_acc += batch_real_acc
                fake_acc += batch_fake_acc
                mismatched_acc += batch_mismatched_acc
                uncond_real_acc += batch_uncond_real_acc
                uncond_fake_acc += batch_uncond_fake_acc
                disc_skips += batch_disc_skips
                # Generator loss
                batch_g_losses = self.generator_step(generated, word_embs,
                                                     sent_embs, mu, logvar,
                                                     batch['label'])
                g_total, batch_g_stage_loss, batch_w_loss, batch_s_loss, batch_kl_loss = batch_g_losses
                g_stage_loss += batch_g_stage_loss
                w_loss += batch_w_loss
                s_loss += batch_s_loss
                kl_loss += batch_kl_loss
                gen_updates += 1
                avg_g_loss = g_total.item() / batch_size
                g_loss += avg_g_loss
                if hist_avg:
                    # NOTE: two-arg in-place add_ (alpha, tensor) is the
                    # legacy torch signature for avg = 0.999*avg + 0.001*p.
                    for p, avg_p in zip(self.gen.parameters(), avg_g_params):
                        avg_p.mul_(0.999).add_(0.001, p.data)
                    if gen_updates % 1000 == 0:
                        tqdm.write(
                            'Replacing generator weights with their moving average'
                        )
                        for p, avg_p in zip(self.gen.parameters(),
                                            avg_g_params):
                            p.data.copy_(avg_p)
                train_pbar.set_description(
                    f'Training (G: {grad_norm(self.gen):.2f} '
                    f'D64: {d_grad_norm[0]:.2f} '
                    f'D128: {d_grad_norm[1]:.2f} '
                    f'D256: {d_grad_norm[2]:.2f})')
            # Convert epoch sums into per-batch averages.
            batches = len(train_loader)
            g_loss /= batches
            g_stage_loss /= batches
            w_loss /= batches
            s_loss /= batches
            kl_loss /= batches
            d_loss /= batches
            real_acc /= batches
            fake_acc /= batches
            mismatched_acc /= batches
            uncond_real_acc /= batches
            uncond_fake_acc /= batches
            metrics['loss']['g'].append(g_loss)
            metrics['loss']['d'].append(d_loss)
            metrics['accuracy']['real'].append(real_acc)
            metrics['accuracy']['fake'].append(fake_acc)
            metrics['accuracy']['mismatched'].append(mismatched_acc)
            metrics['accuracy']['unconditional_real'].append(uncond_real_acc)
            metrics['accuracy']['unconditional_fake'].append(uncond_fake_acc)
            sep = '_' * 10
            tqdm.write(f'{sep}Epoch {e}{sep}')
            if e % test_sample_every == 0:
                # Periodically sample the test set and (optionally) score.
                self.gen.eval()
                generated_samples = [
                    resolution.unsqueeze(0)
                    for resolution in self.sample_test_set(dataset)
                ]
                self._save_generated(generated_samples, e,
                                     f'{OUT_DIR}/{start_time}')
                if evaluator is not None:
                    scores = evaluator.evaluate(self)
                    for k, v in scores.items():
                        metrics[k].append(v)
                        tqdm.write(f'{k}: {v:.2f}')
            tqdm.write(
                f'Generator avg loss: total({g_loss:.3f}) '
                f'stage0({g_stage_loss[0]:.3f}) stage1({g_stage_loss[1]:.3f}) stage2({g_stage_loss[2]:.3f}) '
                f'w({w_loss:.3f}) s({s_loss:.3f}) kl({kl_loss:.3f})')
            for i, _ in enumerate(self.discriminators):
                tqdm.write(f'Discriminator{i} avg: '
                           f'loss({d_loss[i]:.3f}) '
                           f'r-acc({real_acc[i]:.3f}) '
                           f'f-acc({fake_acc[i]:.3f}) '
                           f'm-acc({mismatched_acc[i]:.3f}) '
                           f'ur-acc({uncond_real_acc[i]:.3f}) '
                           f'uf-acc({uncond_fake_acc[i]:.3f}) '
                           f'skips({disc_skips[i]})')
        return metrics

    def sample_test_set(self, dataset, nb_samples=8, nb_captions=2,
                        noise_variations=2):
        """Build one image grid per resolution from random test captions.

        Rows are samples; columns are captions x noise variants.
        """
        subset = dataset.test
        sample_indices = np.random.choice(len(subset), nb_samples,
                                          replace=False)
        # assumes 10 captions per sample (caption_0..caption_9) — TODO confirm
        cap_indices = np.random.choice(10, nb_captions, replace=False)
        texts = [
            subset.data[f'caption_{cap_idx}'].iloc[sample_idx]
            for sample_idx in sample_indices for cap_idx in cap_indices
        ]
        generated_samples = [
            self.generate_from_text(texts, dataset)
            for _ in range(noise_variations)
        ]
        combined_img64 = torch.FloatTensor()
        combined_img128 = torch.FloatTensor()
        combined_img256 = torch.FloatTensor()
        for noise_variant in generated_samples:
            noise_var_img64 = torch.FloatTensor()
            noise_var_img128 = torch.FloatTensor()
            noise_var_img256 = torch.FloatTensor()
            for i in range(nb_samples):
                # rows: samples, columns: captions * noise variants
                row64 = torch.cat([
                    noise_variant[0][i * nb_captions + j]
                    for j in range(nb_captions)
                ], dim=-1).cpu()
                row128 = torch.cat([
                    noise_variant[1][i * nb_captions + j]
                    for j in range(nb_captions)
                ], dim=-1).cpu()
                row256 = torch.cat([
                    noise_variant[2][i * nb_captions + j]
                    for j in range(nb_captions)
                ], dim=-1).cpu()
                noise_var_img64 = torch.cat([noise_var_img64, row64], dim=-2)
                noise_var_img128 = torch.cat([noise_var_img128, row128],
                                             dim=-2)
                noise_var_img256 = torch.cat([noise_var_img256, row256],
                                             dim=-2)
            combined_img64 = torch.cat([combined_img64, noise_var_img64],
                                       dim=-1)
            combined_img128 = torch.cat([combined_img128, noise_var_img128],
                                        dim=-1)
            combined_img256 = torch.cat([combined_img256, noise_var_img256],
                                        dim=-1)
        return combined_img64, combined_img128, combined_img256

    @staticmethod
    def KL_loss(mu, logvar):
        """KL divergence of N(mu, exp(logvar)) against the unit Gaussian."""
        loss = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        loss = torch.mean(loss).mul_(-0.5)
        return loss

    def generator_step(self, generated_imgs, word_embs, sent_embs, mu,
                       logvar, class_labels):
        """One generator update: DAMSM word/sentence loss + KL + per-stage
        adversarial losses; returns total loss and per-component averages."""
        self.gen.zero_grad()
        avg_stage_g_loss = [0, 0, 0]
        # DAMSM matching losses use only the highest-resolution output.
        local_features, global_features = self.damsm.img_enc(
            generated_imgs[-1])
        batch_size = sent_embs.size(0)
        match_labels = torch.LongTensor(range(batch_size)).to(self.device)
        w1_loss, w2_loss, _ = self.damsm.words_loss(local_features, word_embs,
                                                    class_labels,
                                                    match_labels)
        w_loss = (w1_loss + w2_loss) * LAMBDA
        s1_loss, s2_loss = self.damsm.sentence_loss(global_features,
                                                    sent_embs, class_labels,
                                                    match_labels)
        s_loss = (s1_loss + s2_loss) * LAMBDA
        kl_loss = self.KL_loss(mu, logvar)
        g_total = w_loss + s_loss + kl_loss
        for i, d in enumerate(self.discriminators):
            features = d(generated_imgs[i])
            fake_logits = d.logit(features, sent_embs)
            # Generator wants fakes classified as real.
            real_labels = torch.ones_like(fake_logits).to(self.device)
            disc_error = F.binary_cross_entropy_with_logits(
                fake_logits, real_labels)
            uncond_fake_logits = d.logit(features)
            uncond_disc_error = F.binary_cross_entropy_with_logits(
                uncond_fake_logits, real_labels)
            stage_loss = disc_error + uncond_disc_error
            avg_stage_g_loss[i] = stage_loss.item() / batch_size
            g_total += stage_loss
        g_total.backward()
        self.gen_optimizer.step()
        return g_total, avg_stage_g_loss, w_loss.item(
        ) / batch_size, s_loss.item() / batch_size, kl_loss.item()

    def discriminator_step(self, real_imgs, generated_imgs, sent_embs,
                           label_smoothing, skip_acc_threshold=0.9,
                           p_flip=0.05, halting=False):
        """One update per discriminator head: real/fake/mismatched and
        unconditional terms. With `halting`, a head whose accuracy already
        exceeds `skip_acc_threshold` skips its update. Returns per-head
        losses, accuracies and skip flags."""
        self.disc.zero_grad()
        batch_size = sent_embs.size(0)
        avg_d_loss = [0, 0, 0]
        real_accuracy = [0, 0, 0]
        fake_accuracy = [0, 0, 0]
        mismatched_accuracy = [0, 0, 0]
        uncond_real_accuracy = [0, 0, 0]
        uncond_fake_accuracy = [0, 0, 0]
        skipped = [0, 0, 0]
        for i, d in enumerate(self.discriminators):
            real_features = d(real_imgs[i].to(self.device))
            # Detach so discriminator gradients don't flow into the generator.
            fake_features = d(generated_imgs[i].detach())
            real_logits = d.logit(real_features, sent_embs)
            # One-sided label smoothing for real labels.
            real_labels = torch.full_like(real_logits,
                                          1 - label_smoothing).to(self.device)
            fake_labels = torch.zeros_like(real_logits,
                                           dtype=torch.float).to(self.device)
            # Label flipping (disabled); p_flip is unused while commented out.
            # flip_mask = torch.Tensor(real_labels.size()).bernoulli_(p_flip).type(torch.bool)
            # real_labels[flip_mask], fake_labels[flip_mask] = fake_labels[flip_mask], real_labels[flip_mask]
            real_error = F.binary_cross_entropy_with_logits(
                real_logits, real_labels)
            # Real images should be classified as real
            real_accuracy[i] = (real_logits >=
                                0).sum().item() / real_logits.numel()
            fake_logits = d.logit(fake_features, sent_embs)
            fake_error = F.binary_cross_entropy_with_logits(
                fake_logits, fake_labels)
            # Generated images should be classified as fake
            fake_accuracy[i] = (fake_logits <
                                0).sum().item() / fake_logits.numel()
            # Pair real images with rotated (wrong) sentence embeddings.
            mismatched_logits = d.logit(real_features,
                                        rotate_tensor(sent_embs, 1))
            mismatched_error = F.binary_cross_entropy_with_logits(
                mismatched_logits, fake_labels)
            # Images with mismatched descriptions should be classified as fake
            mismatched_accuracy[i] = (mismatched_logits < 0).sum().item(
            ) / mismatched_logits.numel()
            uncond_real_logits = d.logit(real_features)
            uncond_real_error = F.binary_cross_entropy_with_logits(
                uncond_real_logits, real_labels)
            uncond_real_accuracy[i] = (uncond_real_logits >= 0).sum().item(
            ) / uncond_real_logits.numel()
            uncond_fake_logits = d.logit(fake_features)
            uncond_fake_error = F.binary_cross_entropy_with_logits(
                uncond_fake_logits, fake_labels)
            uncond_fake_accuracy[i] = (uncond_fake_logits < 0).sum().item(
            ) / uncond_fake_logits.numel()
            # Weighted combination: 2 real-ish terms, 3 fake-ish terms.
            error = (real_error + uncond_real_error) / 2 + (
                fake_error + uncond_fake_error + mismatched_error) / 3
            if not halting or fake_accuracy[i] + real_accuracy[
                    i] < skip_acc_threshold * 2:
                error.backward()
                self.disc_optimizers[i].step()
            else:
                skipped[i] = 1
            avg_d_loss[i] = error.item() / batch_size
        return avg_d_loss, real_accuracy, fake_accuracy, mismatched_accuracy, uncond_real_accuracy, uncond_fake_accuracy, skipped

    def generate_from_text(self, texts, dataset, noise=None):
        """Encode raw caption strings and generate images from them."""
        encoded = [dataset.train.encode_text(t) for t in texts]
        generated = self.generate_from_encoded_text(encoded, dataset, noise)
        return generated

    def generate_from_encoded_text(self, encoded, dataset, noise=None):
        """Generate images from already-encoded captions (no gradients).

        When `noise` is None a fresh standard-normal batch is drawn.
        """
        with torch.no_grad():
            w_emb, s_emb = self.damsm.txt_enc(encoded)
            attn_mask = torch.tensor(encoded).to(
                self.device) == dataset.vocab[END_TOKEN]
            if noise is None:
                noise = torch.FloatTensor(len(encoded), D_Z).to(self.device)
                noise.data.normal_(0, 1)
            generated, att, mu, logvar = self.gen(noise, s_emb, w_emb,
                                                  attn_mask)
        return generated

    def _save_generated(self, generated, epoch, out_dir=OUT_DIR):
        """Write each sample's 64/128/256 px image grids as JPEGs under
        `out_dir/epoch_NNN` (images assumed in [-1, 1])."""
        nb_samples = generated[0].size(0)
        save_dir = f'{out_dir}/epoch_{epoch:03}'
        os.makedirs(save_dir)
        for i in range(nb_samples):
            save_image(generated[0][i],
                       f'{save_dir}/{i}_64.jpg',
                       normalize=True,
                       range=(-1, 1))
            save_image(generated[1][i],
                       f'{save_dir}/{i}_128.jpg',
                       normalize=True,
                       range=(-1, 1))
            save_image(generated[2][i],
                       f'{save_dir}/{i}_256.jpg',
                       normalize=True,
                       range=(-1, 1))

    def save(self, name, save_dir=GAN_MODEL_DIR, metrics=None):
        """Persist generator/discriminator state dicts (and metrics JSON)."""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.gen.state_dict(), f'{save_dir}/{name}_generator.pt')
        torch.save(self.disc.state_dict(),
                   f'{save_dir}/{name}_discriminator.pt')
        if metrics is not None:
            with open(f'{save_dir}/{name}_metrics.json', 'w') as f:
                metrics = pre_json_metrics(metrics)
                json.dump(metrics, f)

    def load_(self, name, load_dir=GAN_MODEL_DIR):
        """Load weights in place and switch both networks to eval mode."""
        self.gen.load_state_dict(torch.load(f'{load_dir}/{name}_generator.pt'))
        self.disc.load_state_dict(
            torch.load(f'{load_dir}/{name}_discriminator.pt'))
        self.gen.eval(), self.disc.eval()

    @staticmethod
    def load(name, damsm, load_dir=GAN_MODEL_DIR, device=DEVICE):
        """Alternate constructor: build an AttnGAN and load saved weights."""
        attngan = AttnGAN(damsm, device=device)
        attngan.load_(name, load_dir)
        return attngan

    def validate_test_set(self, dataset, batch_size=GAN_BATCH,
                          save_dir=f'{OUT_DIR}/test_samples'):
        """Generate one 256 px image per test caption and save them all."""
        os.makedirs(save_dir, exist_ok=True)
        loader = DataLoader(dataset.test,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=False,
                            collate_fn=dataset.collate_fn)
        loader = tqdm(loader,
                      dynamic_ncols=True,
                      leave=True,
                      desc='Generating samples for test set')
        self.gen.eval()
        with torch.no_grad():
            i = 0
            for batch in loader:
                word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]
                noise = torch.FloatTensor(len(batch['caption']),
                                          D_Z).to(self.device)
                noise.data.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)
                for img in generated[-1]:
                    save_image(img,
                               f'{save_dir}/{i}.jpg',
                               normalize=True,
                               range=(-1, 1))
                    i += 1

    def get_d_score(self, imgs, sent_embs):
        """Conditional 256 px discriminator logits for the given images."""
        d = self.disc.d256
        features = d(imgs.to(self.device))
        scores = d.logit(features, sent_embs)
        return scores

    def accept_prob(self, score1, score2):
        """Metropolis-Hastings acceptance ratio from two discriminator
        probabilities (odds ratio, clamped to 1)."""
        return min(1, (1 / score1 - 1) / (1 / score2 - 1))

    def d_scores_test(self, dataset):
        """Sigmoid scores of the unconditional 256 px head on real test images."""
        with torch.no_grad():
            loader = DataLoader(dataset.test,
                                batch_size=20,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            scores = []
            d = self.disc.d256
            for b in loader:
                img = b['img256'].to(self.device)
                f = d(img)
                l = d.logit(f)
                scores.append(torch.sigmoid(l))
            # Flatten per-batch tensors into one list of floats.
            scores = [x.item() for s in scores for x in s.reshape(-1)]
            return scores

    def z_test(self, scores, labels):
        """Z-statistic comparing predicted probabilities against labels."""
        labels = np.array(labels)
        scores = np.array(scores)
        num = np.sum(labels - scores)
        denom = np.sqrt(np.sum(scores * (1 - scores)))
        return num / denom

    def d_scores_gen(self, dataset):
        """Sigmoid scores of the unconditional 256 px head on generated
        images (one batch of fakes per test-caption batch)."""
        with torch.no_grad():
            loader = DataLoader(dataset.test,
                                batch_size=20,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            scores = []
            d = self.disc.d256
            for b in loader:
                noise = torch.FloatTensor(len(b['caption']),
                                          D_Z).to(self.device)
                noise.data.normal_(0, 1)
                word_embs, sent_embs = self.damsm.txt_enc(b['caption'])
                attn_mask = torch.tensor(b['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]
                generated, _, _, _ = self.gen(noise, sent_embs, word_embs,
                                              attn_mask)
                f = d(generated[-1])
                l = d.logit(f)
                scores.append(torch.sigmoid(l))
            scores = [x.item() for s in scores for x in s.reshape(-1)]
            return scores

    def mh_sample(self, dataset, k, save_dir='test_samples', batch=GAN_BATCH):
        """Metropolis-Hastings GAN sampling: calibrate discriminator scores
        with logistic regression, then pick one image per caption from a
        chain of `k` candidates; reports FID of the accepted set."""
        evaluator = IS_FID_Evaluator(dataset,
                                     self.damsm.img_enc.inception_model,
                                     batch, self.device)
        # self.disc.d256.train()
        with torch.no_grad():
            l = len(dataset.test)
            score_real = self.d_scores_test(dataset)
            score_gen = self.d_scores_gen(dataset)
            print(np.mean(score_real))
            print(np.mean(score_gen))
            # Last l//5 scores are held out for calibration; the rest for the z-test.
            portion = -l // 5
            score_test = score_real[:portion] + score_gen[:portion]
            label_test = [1] * (len(score_test) // 2) + [0] * (
                len(score_test) // 2)
            print('Z test before calibration: ',
                  self.z_test(torch.tensor(score_test), label_test))
            score_real_calib = score_real[portion:]
            score_gen_calib = score_gen[portion:]
            # score_calib = score_real_calib + score_gen_calib
            score_calib = score_gen_calib + score_real_calib
            label_calib = len(score_gen_calib) * [0] + len(
                score_real_calib) * [1]
            # Platt-style calibration of raw sigmoid scores.
            cal_clf = LogisticRegression()
            cal_clf.fit(np.array(score_calib).reshape(-1, 1), label_calib)
            score_pred = cal_clf.predict_proba(
                np.array(score_test).reshape(-1, 1))[:, 1]
            print('Score pred avg: ', np.mean(score_pred))
            test_pred = cal_clf.predict(np.array(score_test).reshape(-1, 1))
            print('Z test after calibration: ',
                  self.z_test(score_pred, label_test))
            print('Accuracy: ',
                  sum((test_pred == label_test)) / len(test_pred))
            os.makedirs(save_dir, exist_ok=True)
            loader = DataLoader(dataset.test,
                                batch_size=1,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            loader = tqdm(loader,
                          dynamic_ncols=True,
                          leave=True,
                          desc='Generating samples for test set')
            imgs = []
            true_probs = 0
            noaccept = 0
            for i, sample in enumerate(loader):
                # Skip the calibration tail of the test set.
                if i > l - (l // 10):
                    continue
                word_embs, sent_embs = self.damsm.txt_enc(sample['caption'])
                attn_mask = torch.tensor(sample['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]
                # Build a chain of k candidate images for this caption.
                img_chain = []
                while len(img_chain) < k:
                    noise = torch.FloatTensor(batch, D_Z).to(self.device)
                    noise.data.normal_(0, 1)
                    generated, _, _, _ = self.gen(
                        noise, sent_embs.repeat(batch, 1),
                        word_embs.repeat(batch, 1, 1),
                        attn_mask.repeat(batch, 1))
                    for img in generated[-1]:
                        img_chain.append(img)
                img_chain = img_chain[:k]
                img_chain = torch.stack(img_chain).to(self.device)
                score_chain = []
                d_loader = DataLoader(img_chain,
                                      batch_size=batch,
                                      shuffle=False,
                                      drop_last=False)
                for d_batch in d_loader:
                    scores = self.get_d_score(d_batch,
                                              sent_embs.repeat(batch, 1))
                    scores = scores.reshape(-1, 1).cpu().numpy()
                    scores = cal_clf.predict_proba(scores)[:, 1]
                    for s in scores:
                        score_chain.append(s)
                # MH walk over the chain using calibrated probabilities.
                chosen = 0
                for j, s in enumerate(score_chain[1:], 1):
                    alpha = self.accept_prob(score_chain[chosen], s)
                    if np.random.rand() < alpha:
                        chosen = j
                if chosen == 0:
                    # Nothing accepted: fall back to the best-scoring candidate.
                    imgs.append(img_chain[torch.tensor(
                        score_chain[1:]).argmax()].cpu())
                    noaccept += 1
                else:
                    imgs.append(img_chain[chosen].cpu())
                true_probs += score_chain[0]
            print(noaccept)
            print(true_probs / len(dataset.test))
            mu_real, sig_real = evaluator.mu_real, evaluator.sig_real
            mu_fake, sig_fake = activation_statistics(
                self.damsm.img_enc.inception_model, imgs)
            print('FID: ', frechet_dist(mu_real, sig_real, mu_fake, sig_fake))
            return imgs
def main():
    """Pretrain a dialogue discriminator against a frozen generator.

    Each epoch: one pass over the training set (discriminator updates on
    real vs. generated replies), then a validation pass collecting losses,
    rewards, accuracy, and BLEU/DIST statistics of the generated replies.
    """
    args = parse_args()
    device = torch.device("cuda")
    # Generator is pre-trained and frozen (eval mode, never updated here).
    generator = Generator.from_file(args.generator_path).to(device)
    generator.eval()
    discriminator = Discriminator(tokenizer=generator.tokenizer).to(device)
    train_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "train/dialogues_train.txt"),
        tokenizer=generator.tokenizer,
    )
    valid_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "validation/dialogues_validation.txt"),
        tokenizer=generator.tokenizer,
    )
    print(len(train_dataset), len(valid_dataset))
    optimizer = AdamW(discriminator.parameters(), lr=args.lr)
    for epoch in tqdm(range(args.num_epochs)):
        train_loss, valid_loss = [], []
        rewards_real, rewards_fake, accuracy = [], [], []
        # ---- training pass (random order over the train set) ----
        discriminator.train()
        for ind in np.random.permutation(len(train_dataset)):
            optimizer.zero_grad()
            # NOTE: `sample_dialouge` (sic) is the dataset's API name.
            context, real_reply = train_dataset.sample_dialouge(ind)
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)
            loss, _, _ = discriminator.get_loss(context, real_reply,
                                                fake_reply)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        # ---- validation pass ----
        discriminator.eval()
        real_replies, fake_replies = [], []
        for ind in range(len(valid_dataset)):
            context, real_reply = valid_dataset[ind]
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)
            with torch.no_grad():
                loss, reward_real, reward_fake = discriminator.get_loss(
                    context, real_reply, fake_reply)
            valid_loss.append(loss.item())
            rewards_real.append(reward_real)
            rewards_fake.append(reward_fake)
            # Correct when real replies score > 0.5 and fakes < 0.5.
            accuracy.extend([reward_real > 0.5, reward_fake < 0.5])
            # Decode for text-level metrics (BLEU / DIST).
            real_reply, fake_reply = (
                generator.tokenizer.decode(real_reply[0]),
                generator.tokenizer.decode(fake_reply[0]),
            )
            real_replies.append(real_reply)
            fake_replies.append(fake_reply)
        train_loss, valid_loss = np.mean(train_loss), np.mean(valid_loss)
        print(
            f"Epoch {epoch + 1}, Train Loss: {train_loss:.2f}, Valid Loss: {valid_loss:.2f}, Reward real: {np.mean(rewards_real):.2f}, Reward fake: {np.mean(rewards_fake):.2f}, Accuracy: {np.mean(accuracy):.2f}"
        )
    # NOTE(review): nesting reconstructed from collapsed source — these
    # summary metrics use the last epoch's validation replies; confirm
    # whether they belong inside the epoch loop instead.
    print(f"Adversarial accuracy, {np.mean(accuracy):.2f}")
    for order in range(1, 5):
        print(
            f"BLEU-{order}: {bleuscore(real_replies, fake_replies, order=order)}"
        )
    print(f"DIST-1: {dist1(fake_replies)}")
    print(f"DIST-2: {dist2unbiased(fake_replies)}")
def main():
    """Adversarially fine-tune a dialogue generator with REINFORCE.

    Alternates discriminator updates (real vs. sampled replies) with
    policy-gradient generator updates using the discriminator reward and
    a moving-average baseline; optional REGS partial rewards and teacher
    forcing. Saves the best-reward checkpoint plus a final "all_" one.

    Fix: renamed the loop variable ``iter`` -> ``iteration`` — it shadowed
    the ``iter`` builtin (local rename only; no behavior change).
    """
    args = parse_args()
    device = torch.device("cuda")
    generator = Generator.from_file(args.generator_path).to(device)
    if args.freeze:
        # Train only the embedding table and the last decoder block.
        for name, param in generator.named_parameters():
            if ("shared" not in name) and ("decoder.block.5" not in name):
                param.requires_grad = False
    discriminator = Discriminator.from_file(
        args.discriminator_path, tokenizer=generator.tokenizer
    ).to(device)
    if args.freeze:
        for name, param in discriminator.named_parameters():
            if ("shared" not in name) and ("decoder.block.5" not in name):
                param.requires_grad = False
    train_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "train/dialogues_train.txt"),
        tokenizer=generator.tokenizer,
        debug=args.debug,
    )
    valid_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "validation/dialogues_validation.txt"),
        tokenizer=generator.tokenizer,
        debug=args.debug,
    )
    print(len(train_dataset), len(valid_dataset))
    generator_optimizer = AdamW(generator.parameters(), lr=args.lr)
    discriminator_optimizer = AdamW(discriminator.parameters(), lr=args.lr)

    # Rolling windows sized to one logging interval.
    rewards = deque([], maxlen=args.log_every * args.generator_steps)
    rewards_real = deque([], maxlen=args.log_every * args.generator_steps)
    generator_loss = deque([], maxlen=args.log_every * args.generator_steps)
    discriminator_loss = deque(
        [], maxlen=args.log_every * args.discriminator_steps)
    best_reward = 0
    generator.train()
    discriminator.train()
    for iteration in tqdm(range(args.num_iterations)):
        # ---- discriminator steps ----
        for _ in range(args.discriminator_steps):
            discriminator_optimizer.zero_grad()
            context, real_reply = train_dataset.sample()
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)
            if args.regs:
                # REGS: train the discriminator on random reply prefixes.
                split_real = random.randint(1, real_reply.size(1))
                real_reply = real_reply[:, :split_real]
                split_fake = random.randint(1, fake_reply.size(1))
                fake_reply = fake_reply[:, :split_fake]
            loss, _, _ = discriminator.get_loss(context, real_reply,
                                                fake_reply)
            loss.backward()
            discriminator_optimizer.step()
            discriminator_loss.append(loss.item())
        # ---- generator (policy-gradient) steps ----
        for _ in range(args.generator_steps):
            generator_optimizer.zero_grad()
            context, real_reply = train_dataset.sample()
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)
            logprob_fake = generator.get_logprob(context, fake_reply)
            reward_fake = discriminator.get_reward(context, fake_reply)
            # Moving-average baseline to reduce gradient variance.
            baseline = 0 if len(rewards) == 0 else np.mean(list(rewards))
            if args.regs:
                # Per-prefix rewards, one per generated token.
                partial_rewards = torch.tensor(
                    [
                        discriminator.get_reward(context, fake_reply[:, :t])
                        for t in range(1, fake_reply.size(1) + 1)
                    ]
                ).to(device)
                loss = -torch.mean(partial_rewards * logprob_fake)
            else:
                loss = -(reward_fake - baseline) * torch.mean(logprob_fake)
            if args.teacher_forcing:
                # Mix in MLE on the human reply to stabilize training.
                logprob_real = generator.get_logprob(context, real_reply)
                reward_real = discriminator.get_reward(context, real_reply)
                loss -= torch.mean(logprob_real)
                rewards_real.append(reward_real)
            loss.backward()
            generator_optimizer.step()
            generator_loss.append(loss.item())
            rewards.append(reward_fake)
        # ---- periodic logging / checkpointing ----
        if iteration % args.log_every == 0:
            mean_reward = np.mean(list(rewards))
            mean_reward_real = np.mean(list(rewards_real))
            if args.discriminator_steps > 0:
                print(f"Discriminator Loss {np.mean(list(discriminator_loss))}")
            if args.generator_steps > 0:
                print(f"Generator Loss {np.mean(list(generator_loss))}")
            if args.teacher_forcing:
                print(f"Mean real reward: {mean_reward_real}")
            print(f"Mean fake reward: {mean_reward}\n")
            # Show one validation dialogue for qualitative inspection.
            context, real_reply = valid_dataset.sample()
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)
            reward_fake = discriminator.get_reward(context, fake_reply)
            print_dialogue(
                context=context,
                real_reply=real_reply,
                fake_reply=fake_reply,
                tokenizer=generator.tokenizer,
            )
            print(f"Reward: {reward_fake}\n")
            if mean_reward > best_reward:
                best_reward = mean_reward
                torch.save(discriminator.state_dict(),
                           args.discriminator_output_path)
                torch.save(generator.state_dict(), args.generator_output_path)
    # Always save the final state regardless of reward.
    torch.save(
        discriminator.state_dict(), "all_" + args.discriminator_output_path
    )
    torch.save(generator.state_dict(), "all_" + args.generator_output_path)
# ----- data -----
crop_size = (args.crop_size, args.crop_size)
# `load_size` is defined earlier in this script (outside this chunk).
image_dataset = ImageDataset(args.image_root, args.mask_root, load_size, crop_size)
data_loader = DataLoader(
    image_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    drop_last=False,
    pin_memory=True
)

# -----
# model
# -----
# assumes 4 input channels (image + mask) and 3 output channels — TODO confirm
generator = LBAM(4, 3)
discriminator = Discriminator(3)
# VGG features for perceptual/style losses during training.
extractor = VGG16FeatureExtractor()

# ----------
# load model
# ----------
start_iter = args.start_iter
if args.pre_trained != '':
    # Resume both networks and the iteration counter from a checkpoint.
    ckpt_dict_load = torch.load(args.pre_trained)
    start_iter = ckpt_dict_load['n_iter']
    generator.load_state_dict(ckpt_dict_load['generator'])
    discriminator.load_state_dict(ckpt_dict_load['discriminator'])
    print('Starting from iter ', start_iter)
def main():
    """Train a SeqGAN-style text generator with an evolutionary strategy.

    Pipeline:
      1. Sample "real" sequences from a fixed oracle LSTM.
      2. Pretrain the generator (MLE) and the discriminator on oracle data.
      3. Adversarial phase: evolve a population of generator copies
         (mutate + select by discriminator reward), then retrain the
         discriminator against the current elite generator.
    Relies on module-level constants (VOCAB_SIZE, BATCH_SIZE, ...) and
    helpers (generate_samples, evaluate, mutate_net) defined elsewhere in
    this file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', default=False, action='store_true',
                        help='Enable CUDA')
    args = parser.parse_args()
    use_cuda = args.cuda and torch.cuda.is_available()
    netG = Generator(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)
    netD = Discriminator(VOCAB_SIZE, D_EMB_SIZE, D_NUM_CLASSES,
                         D_FILTER_SIZES, D_NUM_FILTERS, DROPOUT, use_cuda)
    oracle = Oracle(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)
    if use_cuda:
        netG, netD, oracle = netG.cuda(), netD.cuda(), oracle.cuda()
    netG.create_optim(G_LR)
    netD.create_optim(D_LR, D_L2_REG)

    # generating synthetic data: the oracle defines the "real" distribution
    print('Generating data...')
    generate_samples(oracle, BATCH_SIZE, GENERATED_NUM, REAL_FILE)

    # pretrain generator with MLE on oracle samples, validating against
    # the oracle's likelihood of the generator's own samples
    gen_set = GeneratorDataset(REAL_FILE)
    genloader = DataLoader(dataset=gen_set, batch_size=BATCH_SIZE,
                           shuffle=True)
    print('\nPretraining generator...\n')
    for epoch in range(PRE_G_EPOCHS):
        loss = netG.pretrain(genloader)
        print('Epoch {} pretrain generator training loss: {}'.format(
            epoch + 1, loss))
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} pretrain generator val loss: {}'.format(
            epoch + 1, loss))

    # pretrain discriminator on real (oracle) vs fake (generator) samples
    print('\nPretraining discriminator...\n')
    for epoch in range(PRE_D_EPOCHS):
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
        dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
        disloader = DataLoader(dataset=dis_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        for k_step in range(K_STEPS):
            loss = netD.dtrain(disloader)
            print(
                'Epoch {} K-step {} pretrain discriminator training loss: {}'.
                format(epoch + 1, k_step + 1, loss))

    print('\nStarting adversarial training...')
    for epoch in range(TOTAL_EPOCHS):
        # fresh population of generator clones, each scored by the
        # discriminator reward
        nets = [copy.deepcopy(netG) for _ in range(POPULATION_SIZE)]
        population = [(net, evaluate(net, netD)) for net in nets]
        for g_step in range(G_STEPS):
            t_start = time.time()
            population.sort(key=lambda p: p[1], reverse=True)
            rewards = [p[1] for p in population[:PARENTS_COUNT]]
            reward_mean = np.mean(rewards)
            reward_max = np.max(rewards)
            reward_std = np.std(rewards)
            print(
                "Epoch %d step %d: reward_mean=%.2f, reward_max=%.2f, reward_std=%.2f, time=%.2f s"
                % (epoch, g_step, reward_mean, reward_max, reward_std,
                   time.time() - t_start))
            elite = population[0]
            # generate next population: keep the elite, fill the rest with
            # mutated children of random top-PARENTS_COUNT parents
            prev_population = population
            population = [elite]
            for _ in range(POPULATION_SIZE - 1):
                parent_idx = np.random.randint(0, PARENTS_COUNT)
                parent = prev_population[parent_idx][0]
                net = mutate_net(parent, use_cuda)
                # BUG FIX: previously evaluate(parent, netD) — the child's
                # fitness was the parent's score, so selection could never
                # react to the mutation. Score the mutated net itself.
                fitness = evaluate(net, netD)
                population.append((net, fitness))
        netG = elite[0]
        for d_step in range(D_STEPS):
            # train discriminator against the new elite generator
            generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
            dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
            disloader = DataLoader(dataset=dis_set, batch_size=BATCH_SIZE,
                                   shuffle=True)
            for k_step in range(K_STEPS):
                loss = netD.dtrain(disloader)
                print(
                    'D_step {}, K-step {} adversarial discriminator training loss: {}'
                    .format(d_step + 1, k_step + 1, loss))
        # epoch-level validation: oracle likelihood of generator samples
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} adversarial generator val loss: {}'.format(
            epoch + 1, loss))
class GAN:
    """Convolutional generator/discriminator pair for 28x28 single-channel
    images, trained in double precision (float64) with label smoothing.

    NOTE(review): d_criterion / d_optimizer / g_criterion / g_optimizer are
    initialized to None and passed straight into batch_train — presumably a
    caller assigns them before train() runs; verify against call sites.
    """

    def __init__(self, _device):
        # Device all label/noise tensors are moved to.
        self.device = _device
        self.batch_size = 64
        self.resolution = 28  # image side length expected by the discriminator
        # Loss functions and optimizers are injected externally after
        # construction (see class docstring).
        self.d_criterion = None
        self.d_optimizer = None
        self.g_criterion = None
        self.g_optimizer = None
        self.discriminator = Discriminator(num_layers=5,
                                           activations=["relu", "relu",
                                                        "relu", "sigmoid"],
                                           device=_device,
                                           num_nodes=[1, 64, 128, 64, 1],
                                           kernels=[5, 5, 3],
                                           strides=[2, 2, 2],
                                           dropouts=[.25, .25, 0],
                                           batch_size=64)
        # pass one image through the network so as to initialize the output
        # layer
        self.discriminator(torch.rand(
            size=[self.batch_size, 1, self.resolution, self.resolution]))
        self.generator = Generator(num_layers=6,
                                   activations=["relu", "relu", "relu",
                                                "relu", "tanh"],
                                   num_nodes=[1, 64, 128, 64, 64, 1],
                                   kernels=[3, 3, 3, 3],
                                   strides=[1, 1, 1, 1],
                                   batch_norms=[1, 1, 1, 0],
                                   upsamples=[1, 1, 0, 0],
                                   dropouts=[.25, .25, 0])

    def train(self, epochs: int, dataloader):
        """Run the adversarial loop for `epochs` passes over `dataloader`,
        alternating discriminator and generator updates and periodically
        displaying a generated sample."""
        self.display_output()
        for epoch in range(epochs):
            i = 0
            initial_loss = self.train_generator()
            true_loss = 1.
            false_loss = 1.
            print(f"#\tInitial Generator Loss: {initial_loss}")
            # `target` labels from the dataloader are unused: the GAN only
            # needs the images themselves.
            for data, target in dataloader:
                # Only update discriminator weights (train=True) while it is
                # still doing poorly on fakes; otherwise just evaluate it so
                # it does not overpower the generator.
                if false_loss > 0.7:
                    true_loss, false_loss = self.train_discriminator(data,
                                                                     True)
                else:
                    true_loss, false_loss = self.train_discriminator(data,
                                                                     False)
                generator_loss = self.train_generator()
                if i % 10 == 0:
                    print(
                        f"@\tIndex: {i}\tTrue Loss: {true_loss}\t"
                        f"False Loss: {false_loss}\t"
                        f"Generator Loss: {generator_loss}")
                    self.display_output()
                i += 1
        self.test()

    def train_discriminator(self, train_data, train):
        """One discriminator step: a real batch then a fake batch.

        train_data -- tensor of real images (indexed along dim 0).
        train      -- whether batch_train should update weights.
        Returns (loss_on_real_batch, loss_on_fake_batch).
        """
        # add noise to labels
        true = torch.ones((self.batch_size, 1))
        noise = torch.nn.functional.relu(0.01 * torch.randn(self.batch_size,
                                                            1))
        # one-sided label smoothing: real labels pulled slightly below 1
        true.sub_(noise)
        true = true.to(self.device, dtype=torch.float64)
        false = torch.zeros((self.batch_size, 1))
        noise = torch.nn.functional.relu(0.01 * torch.randn(self.batch_size,
                                                            1))
        # NOTE(review): subtracting pushes fake labels slightly *below* 0;
        # adding (labels just above 0) looks like the intent — confirm.
        false.sub_(noise)
        false = false.to(self.device, dtype=torch.float64)
        # random minibatch of real images
        index = np.random.randint(0, train_data.shape[0], self.batch_size)
        true_images = train_data[index]
        true_loss = self.discriminator.batch_train(true_images, true,
                                                   self.d_criterion,
                                                   self.d_optimizer, train)
        # FIXME: Extract 100 to argument
        noise = torch.randn(self.batch_size, 1, 1, 100,
                            dtype=torch.float64).to(self.device)
        generated_images = self.generator(noise)
        false_loss = self.discriminator.batch_train(generated_images, false,
                                                    self.d_criterion,
                                                    self.d_optimizer, train)
        return true_loss, false_loss

    def train_generator(self):
        """One generator step: push the discriminator to call a batch of
        generated images real (all-ones labels). Returns the loss."""
        valid = torch.ones((self.batch_size, 1), dtype=torch.float64).to(
            self.device)
        noise = torch.randn(self.batch_size, 1, 1, 100,
                            dtype=torch.float64).to(self.device)
        return self.generator.batch_train(self.discriminator, noise, valid,
                                          self.g_criterion, self.g_optimizer)

    # make noise and send it through the generator, displaying each result
    def test(self):
        """Generate a full batch from random noise and show every image."""
        noise = torch.randn(self.batch_size, 1, 1, 100,
                            dtype=torch.float64).to(self.device)
        image = self.generator(noise).detach().cpu().numpy()
        for i in range(np.size(image, 0)):
            picture = image[i, 0, :, :]
            plt.imshow(picture)
            plt.show()

    def display_output(self):
        """Generate and display a single sample image from random noise."""
        noise = torch.randn(1, 1, 1, 100, dtype=torch.float64).to(self.device)
        image = self.generator(noise).detach().cpu().numpy()
        picture = image[0, 0, :, :]
        plt.imshow(picture)
        plt.show()
def main_train():
    """Train an (un)supervised table-to-text translation system.

    Builds the full pipeline from command-line arguments: shared BPE word /
    field embeddings, encoder(s), attention decoder(s), linear generators,
    an optional adversarial discriminator ('mono' mode), four directional
    Translators (src2src, src2trg, trg2trg, trg2src) with backtranslation,
    plus trainers, validators, TensorBoard loggers, periodic checkpointing
    and an async BLEU evaluation thread.

    Fixes over the previous revision (no behavioral change to training):
    - --disable_field_loss help text said 'disable backtranslation'
      (copy-paste error); --cache and --log_interval help texts contradicted
      their actual defaults.
    - two debug logs reported the wrong module's device (trg generator and
      field embeddings both probed the src/word modules).
    - removed unused local `semi_loggers`; renamed inner loop variable so it
      no longer shadows the module `logger`.
    """
    # Build argument parser
    parser = argparse.ArgumentParser(description='Train a table to text model')
    # Training corpus
    corpora_group = parser.add_argument_group(
        'training corpora',
        'Corpora related arguments; specify either unaligned or'
        ' aligned training corpora')
    # "Languages (type,path)"
    corpora_group.add_argument('--src_corpus_params', type=str,
                               default='table, ./data/processed_data/train/train.box',
                               help='the source unaligned corpus (type,path). Type = text/table')
    corpora_group.add_argument('--trg_corpus_params', type=str,
                               default='text, ./data/processed_data/train/train.article',
                               help='the target unaligned corpus (type,path). Type = text/table')
    corpora_group.add_argument('--src_para_corpus_params', type=str, default='',
                               help='the source corpus of parallel data(type,path). Type = text/table')
    corpora_group.add_argument('--trg_para_corpus_params', type=str, default='',
                               help='the target corpus of parallel data(type,path). Type = text/table')
    # Maybe add src/target type (i.e. text/table)
    corpora_group.add_argument('--corpus_mode', type=str, default='mono',
                               help='training mode: "mono" (unsupervised) / "para" (supervised)')
    corpora_group.add_argument('--max_sentence_length', type=int, default=50,
                               help='the maximum sentence length for training (defaults to 50)')
    corpora_group.add_argument('--cache', type=int, default=100000,
                               help='the cache size (in sentences) for corpus reading (defaults to 100000)')
    # Embeddings/vocabulary
    embedding_group = parser.add_argument_group(
        'embeddings',
        'Embedding related arguments; either give pre-trained embeddings,'
        ' or a vocabulary and embedding dimensionality to'
        ' randomly initialize them')
    embedding_group.add_argument('--metadata_path', type=str, default='', required=True,
                                 help='Path for bin file created in pre-processing phase, '
                                      'containing BPEmb related metadata.')
    # Architecture
    architecture_group = parser.add_argument_group('architecture', 'Architecture related arguments')
    architecture_group.add_argument('--layers', type=int, default=2,
                                    help='the number of encoder/decoder layers (defaults to 2)')
    architecture_group.add_argument('--hidden', type=int, default=600,
                                    help='the number of dimensions for the hidden layer (defaults to 600)')
    architecture_group.add_argument('--dis_hidden', type=int, default=150,
                                    help='Number of dimensions for the discriminator hidden layers')
    architecture_group.add_argument('--n_dis_layers', type=int, default=2,
                                    help='Number of discriminator layers')
    architecture_group.add_argument('--disable_bidirectional', action='store_true',
                                    help='use a single direction encoder')
    architecture_group.add_argument('--disable_backtranslation', action='store_true',
                                    help='disable backtranslation')
    # help text fixed: previously said 'disable backtranslation' (copy-paste)
    architecture_group.add_argument('--disable_field_loss', action='store_true',
                                    help='disable field loss')
    architecture_group.add_argument('--disable_discriminator', action='store_true',
                                    help='disable discriminator')
    architecture_group.add_argument('--shared_enc', action='store_true',
                                    help='share enc for both directions')
    architecture_group.add_argument('--shared_dec', action='store_true',
                                    help='share dec for both directions')
    # Denoising
    denoising_group = parser.add_argument_group('denoising', 'Denoising related arguments')
    denoising_group.add_argument('--denoising_mode', type=int, default=1,
                                 help='0/1/2 = disabled/old/new')
    denoising_group.add_argument('--word_shuffle', type=int, default=3,
                                 help='shuffle words (only relevant in new mode)')
    denoising_group.add_argument('--word_dropout', type=float, default=0.1,
                                 help='randomly remove words (only relevant in new mode)')
    denoising_group.add_argument('--word_blank', type=float, default=0.2,
                                 help='randomly blank out words (only relevant in new mode)')
    # Optimization
    optimization_group = parser.add_argument_group('optimization', 'Optimization related arguments')
    optimization_group.add_argument('--batch', type=int, default=50,
                                    help='the batch size (defaults to 50)')
    optimization_group.add_argument('--learning_rate', type=float, default=0.0002,
                                    help='the global learning rate (defaults to 0.0002)')
    optimization_group.add_argument('--dropout', metavar='PROB', type=float, default=0.3,
                                    help='dropout probability for the encoder/decoder (defaults to 0.3)')
    optimization_group.add_argument('--param_init', metavar='RANGE', type=float, default=0.1,
                                    help='uniform initialization in the specified range (defaults to 0.1, 0 for module specific default initialization)')
    optimization_group.add_argument('--iterations', type=int, default=300000,
                                    help='the number of training iterations (defaults to 300000)')
    # Model saving
    saving_group = parser.add_argument_group('model saving', 'Arguments for saving the trained model')
    saving_group.add_argument('--save', metavar='PREFIX',
                              help='save models with the given prefix')
    saving_group.add_argument('--save_interval', type=int, default=0,
                              help='save intermediate models at this interval')
    # Logging/validation
    logging_group = parser.add_argument_group('logging', 'Logging and validation arguments')
    logging_group.add_argument('--log_interval', type=int, default=100,
                               help='log at this interval (defaults to 100)')
    logging_group.add_argument('--dbg_print_interval', type=int, default=1000,
                               help='log at this interval (defaults to 1000)')
    logging_group.add_argument('--src_valid_corpus', type=str, default='')
    logging_group.add_argument('--trg_valid_corpus', type=str, default='')
    logging_group.add_argument('--print_level', type=str, default='info',
                               help='logging level [debug | info]')
    # Other
    misc_group = parser.add_argument_group('misc', 'Misc. arguments')
    misc_group.add_argument('--encoding', default='utf-8',
                            help='the character encoding for input/output (defaults to utf-8)')
    misc_group.add_argument('--cuda', type=str, default='cpu',
                            help='device for training. default value: "cpu"')
    misc_group.add_argument('--bleu_device', type=str, default='',
                            help='device for calculating BLEU scores in case a validation dataset is given')

    # Parse arguments
    args = parser.parse_args()
    logger = logging.getLogger()
    if args.print_level == 'debug':
        logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
    elif args.print_level == 'info':
        logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    elif args.print_level == 'warning':
        logging.basicConfig(stream=sys.stderr, level=logging.WARNING)
    else:
        logging.basicConfig(stream=sys.stderr, level=logging.CRITICAL)

    # Validate arguments
    if args.src_corpus_params is None or args.trg_corpus_params is None:
        print("Must supply corpus")
        sys.exit(-1)
    args.src_corpus_params = args.src_corpus_params.split(',')
    args.trg_corpus_params = args.trg_corpus_params.split(',')
    assert len(args.src_corpus_params) == 2
    assert len(args.trg_corpus_params) == 2
    src_type, src_corpus_path = args.src_corpus_params
    trg_type, trg_corpus_path = args.trg_corpus_params
    src_type = src_type.strip()
    src_corpus_path = src_corpus_path.strip()
    trg_type = trg_type.strip()
    trg_corpus_path = trg_corpus_path.strip()
    # exactly one side is the table, the other the text
    assert src_type != trg_type
    assert (src_type in ['table', 'text']) and (trg_type in ['table', 'text'])
    corpus_size = get_num_lines(src_corpus_path + '.content')

    # Select device
    if torch.cuda.is_available():
        device = torch.device(args.cuda)
    else:
        device = torch.device('cpu')
    if args.bleu_device == '':
        args.bleu_device = device

    # Per-run TensorBoard writers
    current_time = str(datetime.datetime.now().timestamp())
    run_dir = 'run_' + current_time + '/'
    train_log_dir = 'logs/train/' + run_dir + args.save
    valid_log_dir = 'logs/valid/' + run_dir + args.save
    train_writer = SummaryWriter(train_log_dir)
    valid_writer = SummaryWriter(valid_log_dir)

    # Create optimizer lists (one list per translation direction; a module
    # shared by several directions registers its optimizer in each list)
    src2src_optimizers = []
    trg2trg_optimizers = []
    src2trg_optimizers = []
    trg2src_optimizers = []

    # Method to create a module optimizer and add it to the given lists
    def add_optimizer(module, directions=()):
        if args.param_init != 0.0:
            for param in module.parameters():
                param.data.uniform_(-args.param_init, args.param_init)
        optimizer = torch.optim.Adam(module.parameters(), lr=args.learning_rate)
        for direction in directions:
            direction.append(optimizer)
        return optimizer

    assert os.path.isfile(args.metadata_path)
    metadata = torch.load(args.metadata_path)
    bpemb_en = metadata.init_bpe_module()
    word_dict: BpeWordDict = torch.load(metadata.word_dict_path)
    field_dict: LabelDict = torch.load(metadata.field_dict_path)
    # hidden size is derived from the BPE embedding dim (word + half for
    # fields), doubled for a bidirectional encoder
    args.hidden = bpemb_en.dim + bpemb_en.dim // 2
    if not args.disable_bidirectional:
        args.hidden *= 2

    # Load embedding and/or vocab
    # word_dict = BpeWordDict.get(vocab=bpemb_en.words)
    w_sos_id = {'text': word_dict.bos_index, 'table': word_dict.sot_index}
    word_embeddings = nn.Embedding(len(word_dict), bpemb_en.dim,
                                   padding_idx=word_dict.pad_index)
    nn.init.normal_(word_embeddings.weight, 0, 0.1)
    nn.init.constant_(word_embeddings.weight[word_dict.pad_index], 0)
    with torch.no_grad():
        # copy pretrained BPEmb vectors into the leading rows
        word_embeddings.weight[:bpemb_en.vs, :] = torch.from_numpy(bpemb_en.vectors)
    word_embedding_size = word_embeddings.weight.data.size()[1]
    word_embeddings = word_embeddings.to(device)
    # pretrained word embeddings stay frozen; field embeddings are learned
    word_embeddings.weight.requires_grad = False
    logger.debug('w_embeddings is running on cuda: %d',
                 next(word_embeddings.parameters()).is_cuda)

    # field_dict: LabelDict = torch.load('./data/processed_data/train/field.dict')
    field_embeddings = nn.Embedding(len(field_dict), bpemb_en.dim // 2,
                                    padding_idx=field_dict.pad_index)
    nn.init.normal_(field_embeddings.weight, 0, 0.1)
    nn.init.constant_(field_embeddings.weight[field_dict.pad_index], 0)
    field_embedding_size = field_embeddings.weight.data.size()[1]
    field_embeddings = field_embeddings.to(device)
    field_embeddings.weight.requires_grad = True
    # fixed: previously probed word_embeddings here
    logger.debug('f_embeddings is running on cuda: %d',
                 next(field_embeddings.parameters()).is_cuda)

    # Both directions share the same embedding tables
    src_encoder_word_embeddings = word_embeddings
    trg_encoder_word_embeddings = word_embeddings
    src_encoder_field_embeddings = field_embeddings
    trg_encoder_field_embeddings = field_embeddings
    src_decoder_word_embeddings = word_embeddings
    trg_decoder_word_embeddings = word_embeddings
    src_decoder_field_embeddings = field_embeddings
    trg_decoder_field_embeddings = field_embeddings

    # Build output generators (shared across directions if --shared_dec)
    src_generator = LinearGenerator(args.hidden, len(word_dict), len(field_dict)).to(device)
    if args.shared_dec:
        trg_generator = src_generator
        add_optimizer(src_generator, (src2src_optimizers, trg2src_optimizers,
                                      trg2trg_optimizers, src2trg_optimizers))
    else:
        trg_generator = LinearGenerator(args.hidden, len(word_dict), len(field_dict)).to(device)
        add_optimizer(src_generator, (src2src_optimizers, trg2src_optimizers))
        add_optimizer(trg_generator, (trg2trg_optimizers, src2trg_optimizers))
    logger.debug('src generator is running on cuda: %d',
                 next(src_generator.parameters()).is_cuda)
    # fixed: previously probed src_generator here
    logger.debug('trg generator is running on cuda: %d',
                 next(trg_generator.parameters()).is_cuda)

    # Build encoder
    src_enc = RNNEncoder(word_embedding_size=word_embedding_size,
                         field_embedding_size=field_embedding_size,
                         hidden_size=args.hidden,
                         bidirectional=not args.disable_bidirectional,
                         layers=args.layers, dropout=args.dropout).to(device)
    if args.shared_enc:
        trg_enc = src_enc
        add_optimizer(src_enc, (src2src_optimizers, src2trg_optimizers,
                                trg2trg_optimizers, trg2src_optimizers))
    else:
        trg_enc = RNNEncoder(word_embedding_size=word_embedding_size,
                             field_embedding_size=field_embedding_size,
                             hidden_size=args.hidden,
                             bidirectional=not args.disable_bidirectional,
                             layers=args.layers, dropout=args.dropout).to(device)
        add_optimizer(src_enc, (src2src_optimizers, src2trg_optimizers))
        add_optimizer(trg_enc, (trg2trg_optimizers, trg2src_optimizers))
    logger.debug('encoder model is running on cuda: %d',
                 next(src_enc.parameters()).is_cuda)

    # Build decoders
    src_dec = RNNAttentionDecoder(word_embedding_size=word_embedding_size,
                                  field_embedding_size=field_embedding_size,
                                  hidden_size=args.hidden, layers=args.layers,
                                  dropout=args.dropout,
                                  input_feeding=False).to(device)
    if args.shared_dec:
        trg_dec = src_dec
        add_optimizer(src_dec, (src2src_optimizers, trg2src_optimizers,
                                trg2trg_optimizers, src2trg_optimizers))
    else:
        trg_dec = RNNAttentionDecoder(word_embedding_size=word_embedding_size,
                                      field_embedding_size=field_embedding_size,
                                      hidden_size=args.hidden, layers=args.layers,
                                      dropout=args.dropout,
                                      input_feeding=False).to(device)
        add_optimizer(src_dec, (src2src_optimizers, trg2src_optimizers))
        add_optimizer(trg_dec, (trg2trg_optimizers, src2trg_optimizers))
    logger.debug('decoder model is running on cuda: %d',
                 next(src_dec.parameters()).is_cuda)
    logger.debug('attention model is running on cuda: %d',
                 next(src_dec.attention.parameters()).is_cuda)

    # Adversarial discriminator only in unsupervised mode
    discriminator = None
    if (args.corpus_mode == 'mono') and not args.disable_discriminator:
        discriminator = Discriminator(args.hidden, args.dis_hidden,
                                      args.n_dis_layers, args.dropout)
        discriminator = discriminator.to(device)

    # Build translators (the autoencoding directions use denoising; the
    # cross directions do not)
    src2src_translator = Translator("src2src",
                                    encoder_word_embeddings=src_encoder_word_embeddings,
                                    decoder_word_embeddings=src_decoder_word_embeddings,
                                    encoder_field_embeddings=src_encoder_field_embeddings,
                                    decoder_field_embeddings=src_decoder_field_embeddings,
                                    generator=src_generator,
                                    src_word_dict=word_dict, trg_word_dict=word_dict,
                                    src_field_dict=field_dict, trg_field_dict=field_dict,
                                    src_type=src_type, trg_type=src_type,
                                    w_sos_id=w_sos_id[src_type], bpemb_en=bpemb_en,
                                    encoder=src_enc, decoder=src_dec,
                                    discriminator=discriminator,
                                    denoising=args.denoising_mode, device=device,
                                    max_word_shuffle_distance=args.word_shuffle,
                                    word_dropout_prob=args.word_dropout,
                                    word_blanking_prob=args.word_blank)
    src2trg_translator = Translator("src2trg",
                                    encoder_word_embeddings=src_encoder_word_embeddings,
                                    decoder_word_embeddings=trg_decoder_word_embeddings,
                                    encoder_field_embeddings=src_encoder_field_embeddings,
                                    decoder_field_embeddings=trg_decoder_field_embeddings,
                                    generator=trg_generator,
                                    src_word_dict=word_dict, trg_word_dict=word_dict,
                                    src_field_dict=field_dict, trg_field_dict=field_dict,
                                    src_type=src_type, trg_type=trg_type,
                                    w_sos_id=w_sos_id[trg_type], bpemb_en=bpemb_en,
                                    encoder=src_enc, decoder=trg_dec,
                                    discriminator=discriminator,
                                    denoising=0, device=device,
                                    max_word_shuffle_distance=args.word_shuffle,
                                    word_dropout_prob=args.word_dropout,
                                    word_blanking_prob=args.word_blank)
    trg2trg_translator = Translator("trg2trg",
                                    encoder_word_embeddings=trg_encoder_word_embeddings,
                                    decoder_word_embeddings=trg_decoder_word_embeddings,
                                    encoder_field_embeddings=trg_encoder_field_embeddings,
                                    decoder_field_embeddings=trg_decoder_field_embeddings,
                                    generator=trg_generator,
                                    src_word_dict=word_dict, trg_word_dict=word_dict,
                                    src_field_dict=field_dict, trg_field_dict=field_dict,
                                    src_type=trg_type, trg_type=trg_type,
                                    w_sos_id=w_sos_id[trg_type], bpemb_en=bpemb_en,
                                    encoder=trg_enc, decoder=trg_dec,
                                    discriminator=discriminator,
                                    denoising=args.denoising_mode, device=device,
                                    max_word_shuffle_distance=args.word_shuffle,
                                    word_dropout_prob=args.word_dropout,
                                    word_blanking_prob=args.word_blank)
    trg2src_translator = Translator("trg2src",
                                    encoder_word_embeddings=trg_encoder_word_embeddings,
                                    decoder_word_embeddings=src_decoder_word_embeddings,
                                    encoder_field_embeddings=trg_encoder_field_embeddings,
                                    decoder_field_embeddings=src_decoder_field_embeddings,
                                    generator=src_generator,
                                    src_word_dict=word_dict, trg_word_dict=word_dict,
                                    src_field_dict=field_dict, trg_field_dict=field_dict,
                                    src_type=trg_type, trg_type=src_type,
                                    w_sos_id=w_sos_id[src_type], bpemb_en=bpemb_en,
                                    encoder=trg_enc, decoder=src_dec,
                                    discriminator=discriminator,
                                    denoising=0, device=device,
                                    max_word_shuffle_distance=args.word_shuffle,
                                    word_dropout_prob=args.word_dropout,
                                    word_blanking_prob=args.word_blank)

    # Build trainers
    trainers = []
    iters_per_epoch = int(np.ceil(corpus_size / args.batch))
    print("CORPUS_SIZE = %d | BATCH_SIZE = %d | ITERS_PER_EPOCH = %d"
          % (corpus_size, args.batch, iters_per_epoch))
    if args.corpus_mode == 'mono':
        f_content = open(src_corpus_path + '.content', encoding=args.encoding,
                         errors='surrogateescape')
        f_labels = open(src_corpus_path + '.labels', encoding=args.encoding,
                        errors='surrogateescape')
        src_corpus_path = data.CorpusReader(f_content, f_labels,
                                            max_sentence_length=args.max_sentence_length,
                                            cache_size=args.cache)
        f_content = open(trg_corpus_path + '.content', encoding=args.encoding,
                         errors='surrogateescape')
        f_labels = open(trg_corpus_path + '.labels', encoding=args.encoding,
                        errors='surrogateescape')
        trg_corpus_path = data.CorpusReader(f_content, f_labels,
                                            max_sentence_length=args.max_sentence_length,
                                            cache_size=args.cache)
        if not args.disable_discriminator:
            disc_trainer = DiscTrainer(device, src_corpus_path, trg_corpus_path,
                                       src_enc, trg_enc,
                                       src_encoder_word_embeddings,
                                       src_encoder_field_embeddings,
                                       word_dict, field_dict, discriminator,
                                       args.learning_rate, batch_size=args.batch)
            trainers.append(disc_trainer)
        src2src_trainer = Trainer(translator=src2src_translator,
                                  optimizers=src2src_optimizers,
                                  corpus=src_corpus_path, batch_size=args.batch,
                                  iters_per_epoch=iters_per_epoch)
        trainers.append(src2src_trainer)
        if not args.disable_backtranslation:
            trgback2src_trainer = Trainer(translator=trg2src_translator,
                                          optimizers=trg2src_optimizers,
                                          corpus=data.BacktranslatorCorpusReader(
                                              corpus=src_corpus_path,
                                              translator=src2trg_translator),
                                          batch_size=args.batch,
                                          iters_per_epoch=iters_per_epoch)
            trainers.append(trgback2src_trainer)
        trg2trg_trainer = Trainer(translator=trg2trg_translator,
                                  optimizers=trg2trg_optimizers,
                                  corpus=trg_corpus_path, batch_size=args.batch,
                                  iters_per_epoch=iters_per_epoch)
        trainers.append(trg2trg_trainer)
        if not args.disable_backtranslation:
            srcback2trg_trainer = Trainer(translator=src2trg_translator,
                                          optimizers=src2trg_optimizers,
                                          corpus=data.BacktranslatorCorpusReader(
                                              corpus=trg_corpus_path,
                                              translator=trg2src_translator),
                                          batch_size=args.batch,
                                          iters_per_epoch=iters_per_epoch)
            trainers.append(srcback2trg_trainer)
    elif args.corpus_mode == 'para':
        fsrc_content = open(src_corpus_path + '.content', encoding=args.encoding,
                            errors='surrogateescape')
        fsrc_labels = open(src_corpus_path + '.labels', encoding=args.encoding,
                           errors='surrogateescape')
        ftrg_content = open(trg_corpus_path + '.content', encoding=args.encoding,
                            errors='surrogateescape')
        ftrg_labels = open(trg_corpus_path + '.labels', encoding=args.encoding,
                           errors='surrogateescape')
        corpus = data.CorpusReader(fsrc_content, fsrc_labels,
                                   trg_word_file=ftrg_content,
                                   trg_field_file=ftrg_labels,
                                   max_sentence_length=args.max_sentence_length,
                                   cache_size=args.cache)
        src2trg_trainer = Trainer(translator=src2trg_translator,
                                  optimizers=src2trg_optimizers, corpus=corpus,
                                  batch_size=args.batch,
                                  iters_per_epoch=iters_per_epoch)
        trainers.append(src2trg_trainer)

    # Build validators
    if args.src_valid_corpus != '' and args.trg_valid_corpus != '':
        with ExitStack() as stack:
            src_content_vfile = stack.enter_context(
                open(args.src_valid_corpus + '.content', encoding=args.encoding,
                     errors='surrogateescape'))
            src_labels_vfile = stack.enter_context(
                open(args.src_valid_corpus + '.labels', encoding=args.encoding,
                     errors='surrogateescape'))
            trg_content_vfile = stack.enter_context(
                open(args.trg_valid_corpus + '.content', encoding=args.encoding,
                     errors='surrogateescape'))
            trg_labels_vfile = stack.enter_context(
                open(args.trg_valid_corpus + '.labels', encoding=args.encoding,
                     errors='surrogateescape'))
            src_content = src_content_vfile.readlines()
            src_labels = src_labels_vfile.readlines()
            trg_content = trg_content_vfile.readlines()
            trg_labels = trg_labels_vfile.readlines()
            assert len(src_content) == len(trg_content) == len(src_labels) == len(trg_labels), \
                "Validation sizes do not match {} {} {} {}".format(
                    len(src_content), len(trg_content), len(src_labels), len(trg_labels))
            src_content = [list(map(int, line.strip().split())) for line in src_content]
            src_labels = [list(map(int, line.strip().split())) for line in src_labels]
            trg_content = [list(map(int, line.strip().split())) for line in trg_content]
            trg_labels = [list(map(int, line.strip().split())) for line in trg_labels]
            # keep only pairs where both sides are non-empty and short enough
            cache = []
            for src_sent, src_label, trg_sent, trg_label in zip(
                    src_content, src_labels, trg_content, trg_labels):
                if 0 < len(src_sent) <= args.max_sentence_length and \
                        0 < len(trg_sent) <= args.max_sentence_length:
                    cache.append((src_sent, src_label, trg_sent, trg_label))
            src_content, src_labels, trg_content, trg_labels = zip(*cache)
            src2trg_validator = Validator(src2trg_translator, src_content,
                                          trg_content, src_labels, trg_labels)
            if args.corpus_mode == 'mono':
                src2src_validator = Validator(src2src_translator, src_content,
                                              src_content, src_labels, src_labels)
                trg2src_validator = Validator(trg2src_translator, trg_content,
                                              src_content, trg_labels, src_labels)
                trg2trg_validator = Validator(trg2trg_translator, trg_content,
                                              trg_content, trg_labels, trg_labels)
            del src_content
            del src_labels
            del trg_content
            del trg_labels
    else:
        src2src_validator = None
        src2trg_validator = None
        trg2src_validator = None
        trg2trg_validator = None

    # Build loggers
    loggers = []
    if args.corpus_mode == 'mono':
        if not args.disable_backtranslation:
            loggers.append(Logger('Source to target (backtranslation)',
                                  srcback2trg_trainer, src2trg_validator, None,
                                  args.encoding, short_name='src2trg_bt',
                                  train_writer=train_writer,
                                  valid_writer=valid_writer))
            loggers.append(Logger('Target to source (backtranslation)',
                                  trgback2src_trainer, trg2src_validator, None,
                                  args.encoding, short_name='trg2src_bt',
                                  train_writer=train_writer,
                                  valid_writer=valid_writer))
        loggers.append(Logger('Source to source', src2src_trainer,
                              src2src_validator, None, args.encoding,
                              short_name='src2src', train_writer=train_writer,
                              valid_writer=valid_writer))
        loggers.append(Logger('Target to target', trg2trg_trainer,
                              trg2trg_validator, None, args.encoding,
                              short_name='trg2trg', train_writer=train_writer,
                              valid_writer=valid_writer))
    elif args.corpus_mode == 'para':
        loggers.append(Logger('Source to target', src2trg_trainer,
                              src2trg_validator, None, args.encoding,
                              short_name='src2trg_para',
                              train_writer=train_writer,
                              valid_writer=valid_writer))

    # Method to save models
    def save_models(name):
        # torch.save(src2src_translator, '{0}.{1}.src2src.pth'.format(args.save, name))
        # torch.save(trg2trg_translator, '{0}.{1}.trg2trg.pth'.format(args.save, name))
        torch.save(src2trg_translator, '{0}.{1}.src2trg.pth'.format(args.save, name))
        if args.corpus_mode == 'mono':
            torch.save(trg2src_translator, '{0}.{1}.trg2src.pth'.format(args.save, name))

    # One-time creation of the decoded reference file used by BLEU scoring
    ref_string_path = args.trg_valid_corpus + '.str.content'
    if not os.path.isfile(ref_string_path):
        print("Creating ref file... [%s]" % (ref_string_path))
        with ExitStack() as stack:
            fref_content = stack.enter_context(
                open(args.trg_valid_corpus + '.content', encoding=args.encoding,
                     errors='surrogateescape'))
            fref_str_content = stack.enter_context(
                open(ref_string_path, mode='w', encoding=args.encoding,
                     errors='surrogateescape'))
            for line in fref_content:
                ref_ids = [int(idstr) for idstr in line.strip().split()]
                ref_str = bpemb_en.decode_ids(ref_ids)
                fref_str_content.write(ref_str + '\n')
        print("Ref file created!")

    # Training
    for curr_iter in range(1, args.iterations + 1):
        print_dbg = (0 != args.dbg_print_interval) and (curr_iter % args.dbg_print_interval == 0)
        for trainer in trainers:
            trainer.step(print_dbg=print_dbg,
                         include_field_loss=not args.disable_field_loss)
        if args.save is not None and args.save_interval > 0 and curr_iter % args.save_interval == 0:
            save_models('it{0}'.format(curr_iter))
        if curr_iter % args.log_interval == 0:
            print()
            print('[{0}] TRAIN-STEP {1} x {2}'.format(args.save, curr_iter, args.batch))
            # loop variable renamed (was `logger`, shadowing the module logger)
            for train_logger in loggers:
                train_logger.log(curr_iter)
        if curr_iter % iters_per_epoch == 0:
            save_models('it{0}'.format(curr_iter))
            print()
            print('[{0}] VALID-STEP {1}'.format(args.save, curr_iter))
            for valid_logger in loggers:
                if valid_logger.validator is not None:
                    valid_logger.validate(curr_iter)
            # BLEU is computed asynchronously on its own device; join only
            # when it would contend with training for the same device
            model = '{0}.{1}.src2trg.pth'.format(args.save, 'it{0}'.format(curr_iter))
            bleu_thread = threading.Thread(
                target=calc_bleu,
                args=(model, args.save, args.src_valid_corpus,
                      args.trg_valid_corpus + '.str.result', ref_string_path,
                      bpemb_en, curr_iter, args.bleu_device, valid_writer))
            bleu_thread.start()
            if args.cuda == args.bleu_device or args.bleu_device == 'cpu':
                bleu_thread.join()
    save_models('final')
    train_writer.close()
    valid_writer.close()
def train(args):
    """Adversarially train a noise Generator against a Discriminator and a
    fixed keyword-spotting (KWS) model.

    The generator produces additive noise for speech spectrograms; it is
    optimized to (a) fool the discriminator, (b) push the KWS model away from
    the true label (adversarial term), and (c) keep the noise norm below a
    hinge threshold ``c``. The discriminator is trained with the standard
    non-saturating GAN loss. A checkpoint is written after every epoch.

    Args:
        args: parsed CLI namespace; fields used here are ``weights_path``
            (optional resume checkpoint), ``kws_model_path``, ``num_epochs``,
            ``c``, ``alpha``, ``beta`` and ``save_folder_path``.

    Side effects:
        Reads dataset/config files from disk ('data', 'dct_filter.npy',
        'spectrogram_mean.pkl', 'spectrogram_std.pkl') and writes one
        checkpoint file per epoch into ``args.save_folder_path``.
    """
    # --- data ------------------------------------------------------------
    config = SpeechDataset.default_config()
    config["wanted_words"] = "yes no marvin left right".split()
    config["data_folder"] = "data"
    config["cache_size"] = 32768
    config["batch_size"] = 64
    train_set, dev_set, test_set = SpeechDataset.splits(config)
    train_loader = data.DataLoader(train_set,
                                   batch_size=config["batch_size"],
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=train_set.collate_fn)
    # NOTE(review): dev_loader/test_loader are built but never used in this
    # function — kept for parity with the original; confirm whether
    # evaluation was meant to happen here.
    dev_loader = data.DataLoader(dev_set,
                                 batch_size=min(len(dev_set), 16),
                                 shuffle=True,
                                 collate_fn=dev_set.collate_fn)
    test_loader = data.DataLoader(test_set,
                                  batch_size=min(len(test_set), 16),
                                  shuffle=True,
                                  collate_fn=test_set.collate_fn)

    # --- models & optimizers ---------------------------------------------
    # NOTE(review): gen/disc are never moved to CUDA even though the inputs
    # are (see inp.cuda() below) — presumably the models handle device
    # placement internally; verify, otherwise this mixes CPU/GPU tensors.
    gen = Generator()
    disc = Discriminator()
    optim_gen = torch.optim.Adam(lr=1e-3, params=gen.parameters(),
                                 weight_decay=1e-3)
    optim_disc = torch.optim.Adam(lr=1e-3, params=disc.parameters(),
                                  weight_decay=1e-3)

    # --- resume or initialize --------------------------------------------
    start_epoch = 0
    if args.weights_path is not None:
        # Resume training from a full checkpoint (models + optimizers).
        weights_dict = torch.load(args.weights_path)
        start_epoch = weights_dict['epoch'] + 1
        gen.load_state_dict(weights_dict['gen_state_dict'])
        disc.load_state_dict(weights_dict['disc_state_dict'])
        optim_gen.load_state_dict(weights_dict['optim_gen_state_dict'])
        optim_disc.load_state_dict(weights_dict['optim_disc_state_dict'])
    else:
        # Fresh start: Xavier-init weight matrices, zero biases, ones for
        # remaining 1-D parameters (e.g. norm scales).
        gen_state_dict = gen.state_dict()
        for key in gen_state_dict.keys():
            if gen_state_dict[key].dim() >= 2:
                torch.nn.init.xavier_normal_(gen_state_dict[key], 1e-2)
            else:
                if key[-4:] == 'bias':
                    torch.nn.init.zeros_(gen_state_dict[key])
                else:
                    torch.nn.init.ones_(gen_state_dict[key])
        gen.load_state_dict(gen_state_dict)

    # --- frozen KWS victim model -----------------------------------------
    model_config = dict(dropout_prob=0.5, height=128, width=40, n_labels=7,
                        n_feature_maps1=64, n_feature_maps2=64,
                        conv1_size=(20, 8), conv2_size=(10, 4),
                        conv1_pool=(2, 2), conv1_stride=(1, 1),
                        conv2_stride=(1, 1), conv2_pool=(1, 1),
                        tf_variant=True)
    kws_model = SpeechModel(model_config)
    kws_model.load(args.kws_model_path)

    # DCT filterbank used to turn (log-)spectrograms into MFCC features.
    dct_filters = torch.from_numpy(
        np.load('dct_filter.npy').astype(np.float32))
    if torch.cuda.is_available():
        dct_filters = dct_filters.cuda()

    num_epochs = args.num_epochs
    c = args.c          # hinge threshold on the noise L2 norm
    alpha = args.alpha  # weight of the hinge (norm) loss
    beta = args.beta    # weight of the adversarial (KWS) loss

    # Per-dataset spectrogram statistics for input normalization.
    mean = torch.load('spectrogram_mean.pkl')
    std = torch.load('spectrogram_std.pkl')

    for epoch in range(start_epoch, num_epochs):
        gen.train()
        disc.train()
        for step, sample in enumerate(train_loader):
            inp, labels = sample
            if torch.cuda.is_available():
                inp = inp.cuda()
                labels = labels.cuda()

            # Generator sees the normalized spectrogram, (batch, freq, time).
            gen_noise = gen(((inp - mean) / std).permute(0, 2, 1))
            # Zero the noise beyond time-frame 101 — only the first 101
            # frames are fed to the KWS model below.
            gen_noise[:, :, 101:] = torch.zeros(gen_noise.shape[0], 128, 27)

            # Discriminator scores for noisy vs. clean normalized inputs.
            noise_score = disc((((inp - mean) / std).permute(0, 2, 1)
                                + gen_noise).unsqueeze(1))
            inp_score = disc(((inp - mean) / std).permute(0, 2, 1)
                             .unsqueeze(1))

            # Denormalized noisy input for the KWS model.
            kws_inp = inp + gen_noise.permute(0, 2, 1) * std
            kws_inp = kws_inp[:, :101, :].reshape(-1, 128, 101)
            # log only where positive to avoid log(0)/log(neg) producing
            # -inf/nan; non-positive entries keep their original values.
            kws_inp_clone = kws_inp.clone()
            kws_inp_clone[kws_inp_clone > 0] = torch.log(kws_inp[kws_inp > 0])
            mfcc_feat = torch.matmul(dct_filters,
                                     kws_inp_clone).permute(0, 2, 1)
            mfcc_feat = F.pad(mfcc_feat,
                              (0, 0, 0, 128 - mfcc_feat.shape[1]))
            # Fix: use the functional softmax instead of constructing a new
            # nn.Softmax module on every training step.
            kws_out = F.softmax(kws_model(mfcc_feat), dim=1)

            # --- optimize generator --------------------------------------
            optim_gen.zero_grad()
            # Non-saturating GAN loss: maximize log D(x + noise).
            loss_gen = -noise_score.log().mean()
            # Adversarial term: probability mass the KWS model still puts on
            # the true label (to be minimized).
            loss_adv = kws_out.gather(1, labels.view(-1, 1)).mean()
            # Hinge on the (denormalized) noise energy: penalize only the
            # part of the L2 norm that exceeds c.
            loss_hinge = nn.ReLU()(
                (gen_noise * std).norm(p=2, dim=(1, 2)) - c).mean()
            loss_gen_total = loss_gen + alpha * loss_hinge + beta * loss_adv
            # retain_graph: the discriminator update below reuses parts of
            # this graph (noise_score).
            loss_gen_total.backward(retain_graph=True)
            optim_gen.step()
            print("Epoch : ", epoch, " , Step : ", step)
            print("Generator Loss", loss_gen)
            print("Loss Adv", loss_adv)
            print("Loss Hinge", loss_hinge)

            # --- optimize discriminator ----------------------------------
            optim_disc.zero_grad()
            loss_disc = -(inp_score.log().mean()
                          + (1 - noise_score).log().mean())
            loss_disc.backward()
            optim_disc.step()
            print("Discriminator Loss", loss_disc)
            print("======================================")

        # Per-epoch checkpoint with everything needed to resume.
        weights_dict = {}
        weights_dict['epoch'] = epoch
        weights_dict['gen_state_dict'] = gen.state_dict()
        weights_dict['disc_state_dict'] = disc.state_dict()
        weights_dict['optim_gen_state_dict'] = optim_gen.state_dict()
        weights_dict['optim_disc_state_dict'] = optim_disc.state_dict()
        weights_dict_path = args.save_folder_path + '/epoch{}.weights'.format(
            epoch)
        torch.save(weights_dict, weights_dict_path)
def main():
    """Train a dialogue discriminator against a frozen (or partially frozen)
    generator on the DailyDialogue dataset.

    Per epoch: one pass over the training set (discriminator loss on
    real vs. generated replies), then a full validation pass tracking loss,
    mean real/fake rewards and accuracy. The discriminator state dict is
    saved to ``args.output_path`` whenever validation loss improves.

    Requires CUDA (device is hard-coded to "cuda").
    """
    args = parse_args()
    device = torch.device("cuda")

    generator = Generator.from_file(args.generator_path).to(device)
    if args.freeze:
        # Freeze everything except the shared embeddings and the last
        # decoder block.
        for name, param in generator.named_parameters():
            if ("shared" not in name) and ("decoder.block.5" not in name):
                param.requires_grad = False
    generator.eval()

    discriminator = Discriminator(tokenizer=generator.tokenizer).to(device)
    if args.freeze:
        # Same selective freezing as the generator.
        for name, param in discriminator.named_parameters():
            if ("shared" not in name) and ("decoder.block.5" not in name):
                param.requires_grad = False

    train_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "train/dialogues_train.txt"),
        tokenizer=generator.tokenizer,
    )
    valid_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "validation/dialogues_validation.txt"),
        tokenizer=generator.tokenizer,
    )
    print(len(train_dataset), len(valid_dataset))

    optimizer = AdamW(discriminator.parameters(), lr=args.lr)

    # Fix: np.float was removed in NumPy 1.24 — use the builtin float.
    best_loss = float("inf")
    for epoch in tqdm(range(args.num_epochs)):
        train_loss, valid_loss = [], []
        rewards_real, rewards_fake, accuracy = [], [], []

        # --- training pass (shuffled) ------------------------------------
        discriminator.train()
        for ind in np.random.permutation(len(train_dataset)):
            optimizer.zero_grad()
            context, real_reply = train_dataset.sample_dialouge(ind)
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)
            if args.partial:
                # Train on random-length reply prefixes so the
                # discriminator also scores partial replies.
                split_real = random.randint(1, real_reply.size(1))
                real_reply = real_reply[:, :split_real]
                split_fake = random.randint(1, fake_reply.size(1) - 1)
                fake_reply = fake_reply[:, :split_fake]
            loss, _, _ = discriminator.get_loss(context, real_reply,
                                                fake_reply)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # --- validation pass ---------------------------------------------
        discriminator.eval()
        for ind in range(len(valid_dataset)):
            context, real_reply = valid_dataset[ind]
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)
            if args.partial:
                split_real = random.randint(1, real_reply.size(1))
                real_reply = real_reply[:, :split_real]
                split_fake = random.randint(1, fake_reply.size(1) - 1)
                fake_reply = fake_reply[:, :split_fake]
            with torch.no_grad():
                loss, reward_real, reward_fake = discriminator.get_loss(
                    context, real_reply, fake_reply
                )
            valid_loss.append(loss.item())
            rewards_real.append(reward_real)
            rewards_fake.append(reward_fake)
            # Correct when real replies score > 0.5 and fakes < 0.5.
            accuracy.extend([reward_real > 0.5, reward_fake < 0.5])

        train_loss, valid_loss = np.mean(train_loss), np.mean(valid_loss)
        print(
            f"Epoch {epoch + 1}, Train Loss: {train_loss:.2f}, Valid Loss: {valid_loss:.2f}, Reward real: {np.mean(rewards_real):.2f}, Reward fake: {np.mean(rewards_fake):.2f}, Accuracy: {np.mean(accuracy):.2f}"
        )
        # Keep only the best-on-validation checkpoint.
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(discriminator.state_dict(), args.output_path)