Example no. 1
    def __init__(self, _device):
        self.device = _device
        self.batch_size = 64
        self.resolution = 28
        self.d_criterion = None
        self.d_optimizer = None
        self.g_criterion = None
        self.g_optimizer = None

        self.discriminator = Discriminator(num_layers=5,
                                           activations=["relu", "relu", "relu",
                                                        "sigmoid"],
                                           device=_device,
                                           num_nodes=[1, 64, 128, 64, 1],
                                           kernels=[5, 5, 3],
                                           strides=[2, 2, 2],
                                           dropouts=[.25, .25, 0],
                                           batch_size=64)

        # pass one batch of random images through the network so the lazily
        # built output layer gets initialized
        self.discriminator(torch.rand(
            size=[self.batch_size, 1, self.resolution, self.resolution]))

        self.generator = Generator(num_layers=6,
                                   activations=["relu", "relu", "relu", "relu",
                                                "tanh"],
                                   num_nodes=[1, 64, 128, 64, 64, 1],
                                   kernels=[3, 3, 3, 3],
                                   strides=[1, 1, 1, 1],
                                   batch_norms=[1, 1, 1, 0],
                                   upsamples=[1, 1, 0, 0],
                                   dropouts=[.25, .25, 0])
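
The dummy forward pass above is a common trick when a layer's input size is only known at runtime. A minimal, self-contained sketch of the same idea using PyTorch's nn.LazyLinear (this toy LazyHead class is illustrative, not part of the example's codebase):

import torch
import torch.nn as nn

class LazyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, stride=2)
        self.head = nn.LazyLinear(1)  # in_features inferred on first call

    def forward(self, x):
        return self.head(torch.flatten(self.conv(x), 1))

net = LazyHead()
net(torch.rand(64, 1, 28, 28))  # dummy batch materializes the lazy layer
assert all(p.numel() > 0 for p in net.parameters())
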
Example no. 2
    def __init__(self, verbosity=True, latent_dim=100):
        img_shape = (128, 128, 3)

        the_disc = Discriminator()
        the_gen = Generator()
        self.discriminator = the_disc.define_discriminator(
            verb=verbosity, sample_shape=img_shape)
        self.generator = the_gen.define_generator(verb=verbosity,
                                                  sample_shape=img_shape,
                                                  latent_dim=latent_dim)
        optimizer = Adam(0.0002, 0.5)
        self.discriminator.compile(
            loss=['binary_crossentropy', 'categorical_crossentropy'],
            loss_weights=[0.5, 0.5],
            optimizer=optimizer,
            metrics=['accuracy'])

        # Freeze the discriminator only after compiling it: Keras captures
        # the trainable flag at compile time, so freezing first would keep
        # the discriminator from ever training.
        self.discriminator.trainable = False

        noise = Input(shape=(latent_dim, ))
        img = self.generator(noise)

        valid, _ = self.discriminator(img)

        self.combined = Model(noise, valid)
        self.combined.compile(loss=['binary_crossentropy'],
                              optimizer=optimizer)
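
A caveat with this Keras pattern: the trainable flag is baked in at compile() time, so the order of compiling and freezing matters. A minimal sketch of the compile-then-freeze sequence, with toy stand-ins for the example's Generator/Discriminator:

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

d_in = Input(shape=(100,))
disc = Model(d_in, Dense(1, activation='sigmoid')(d_in))
disc.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

disc.trainable = False  # only affects models compiled after this point
z = Input(shape=(10,))
gen = Model(z, Dense(100)(z))
combined = Model(z, disc(gen(z)))
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
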
Example no. 3
    def __init__(self, damsm, device=DEVICE):
        self.gen = Generator(device)
        self.disc = Discriminator(device)
        self.damsm = damsm.to(device)
        self.damsm.txt_enc.eval(), self.damsm.img_enc.eval()
        freeze_params_(self.damsm.txt_enc), freeze_params_(self.damsm.img_enc)

        self.device = device
        self.gen.apply(init_weights), self.disc.apply(init_weights)

        self.gen_optimizer = torch.optim.Adam(self.gen.parameters(),
                                              lr=GENERATOR_LR,
                                              betas=(0.5, 0.999))

        self.discriminators = [self.disc.d64, self.disc.d128, self.disc.d256]
        self.disc_optimizers = [
            torch.optim.Adam(d.parameters(),
                             lr=DISCRIMINATOR_LR,
                             betas=(0.5, 0.999)) for d in self.discriminators
        ]
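
freeze_params_ and init_weights are defined elsewhere in this project; a plausible minimal implementation of the freezing helper, assuming it only disables gradients in place:

def freeze_params_(module):
    # Disable gradient tracking for every parameter of the module so the
    # pretrained DAMSM encoders stay fixed during GAN training.
    for p in module.parameters():
        p.requires_grad = False
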
Example no. 4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda',
                        default=False,
                        action='store_true',
                        help='Enable CUDA')
    args = parser.parse_args()
    use_cuda = args.cuda and torch.cuda.is_available()

    random.seed(SEED)
    np.random.seed(SEED)

    netG = Generator(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, G_LR, use_cuda)
    netD = Discriminator(VOCAB_SIZE, D_EMB_SIZE, D_NUM_CLASSES, D_FILTER_SIZES,
                         D_NUM_FILTERS, DROPOUT, D_LR, D_L2_REG, use_cuda)
    oracle = Oracle(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)

    # generating synthetic data
    # print('Generating data...')
    # generate_samples(oracle, BATCH_SIZE, GENERATED_NUM, REAL_FILE)

    # pretrain generator
    gen_set = GeneratorDataset(REAL_FILE)
    genloader = DataLoader(dataset=gen_set,
                           batch_size=BATCH_SIZE,
                           shuffle=True)

    print('\nPretraining generator...\n')
    for epoch in range(PRE_G_EPOCHS):
        loss = netG.pretrain(genloader)
        print('Epoch {} pretrain generator training loss: {}'.format(
            epoch + 1, loss))

        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} pretrain generator val loss: {}'.format(
            epoch + 1, loss))

    # pretrain discriminator
    print('\nPretraining discriminator...\n')
    for epoch in range(D_STEPS):
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
        dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
        disloader = DataLoader(dataset=dis_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)

        for _ in range(K_STEPS):
            loss = netD.dtrain(disloader)
            print('Epoch {} pretrain discriminator training loss: {}'.format(
                epoch + 1, loss))

    # adversarial training
    rollout = Rollout(netG,
                      update_rate=ROLLOUT_UPDATE_RATE,
                      rollout_num=ROLLOUT_NUM)
    print('\n#####################################################')
    print('Adversarial training...\n')

    for epoch in range(TOTAL_EPOCHS):
        for _ in range(G_STEPS):
            netG.pgtrain(BATCH_SIZE, SEQUENCE_LEN, rollout, netD)

        for d_step in range(D_STEPS):
            # train discriminator
            generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
            dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
            disloader = DataLoader(dataset=dis_set,
                                   batch_size=BATCH_SIZE,
                                   shuffle=True)

            for k_step in range(K_STEPS):
                loss = netD.dtrain(disloader)
                print(
                    'D_step {}, K-step {} adversarial discriminator training loss: {}'
                    .format(d_step + 1, k_step + 1, loss))
        rollout.update_params()

        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} adversarial generator val loss: {}'.format(
            epoch + 1, loss))
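
netG.pgtrain and Rollout are defined elsewhere; the policy-gradient update they presumably implement is the SeqGAN REINFORCE rule. A hedged sketch of that loss, with the tensor shapes as assumptions:

import torch

def seqgan_pg_loss(log_probs, rewards):
    """log_probs: (batch, seq_len) log-probabilities of the sampled tokens;
    rewards: (batch, seq_len) per-token rewards estimated by Monte Carlo
    rollouts scored with the discriminator."""
    return -(log_probs * rewards).sum(dim=1).mean()
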
Example no. 5
class AttnGAN:
    def __init__(self, damsm, device=DEVICE):
        self.gen = Generator(device)
        self.disc = Discriminator(device)
        self.damsm = damsm.to(device)
        self.damsm.txt_enc.eval(), self.damsm.img_enc.eval()
        freeze_params_(self.damsm.txt_enc), freeze_params_(self.damsm.img_enc)

        self.device = device
        self.gen.apply(init_weights), self.disc.apply(init_weights)

        self.gen_optimizer = torch.optim.Adam(self.gen.parameters(),
                                              lr=GENERATOR_LR,
                                              betas=(0.5, 0.999))

        self.discriminators = [self.disc.d64, self.disc.d128, self.disc.d256]
        self.disc_optimizers = [
            torch.optim.Adam(d.parameters(),
                             lr=DISCRIMINATOR_LR,
                             betas=(0.5, 0.999)) for d in self.discriminators
        ]

    def train(self,
              dataset,
              epoch,
              batch_size=GAN_BATCH,
              test_sample_every=5,
              hist_avg=False,
              evaluator=None):
        start_time = time.strftime("%Y-%m-%d-%H-%M", time.gmtime())
        os.makedirs(f'{OUT_DIR}/{start_time}')

        if hist_avg:
            avg_g_params = deepcopy(list(p.data
                                         for p in self.gen.parameters()))

        loader_config = {
            'batch_size': batch_size,
            'shuffle': True,
            'drop_last': True,
            'collate_fn': dataset.collate_fn
        }
        train_loader = DataLoader(dataset.train, **loader_config)

        metrics = {
            'IS': [],
            'FID': [],
            'loss': {
                'g': [],
                'd': []
            },
            'accuracy': {
                'real': [],
                'fake': [],
                'mismatched': [],
                'unconditional_real': [],
                'unconditional_fake': []
            }
        }

        if evaluator is not None:
            evaluator = evaluator(dataset, self.damsm.img_enc.inception_model,
                                  batch_size, self.device)

        noise = torch.FloatTensor(batch_size, D_Z).to(self.device)
        gen_updates = 0

        self.disc.train()
        for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True):
            self.gen.train(), self.disc.train()
            g_loss = 0
            w_loss = 0
            s_loss = 0
            kl_loss = 0
            g_stage_loss = np.zeros(3, dtype=float)
            d_loss = np.zeros(3, dtype=float)
            real_acc = np.zeros(3, dtype=float)
            fake_acc = np.zeros(3, dtype=float)
            mismatched_acc = np.zeros(3, dtype=float)
            uncond_real_acc = np.zeros(3, dtype=float)
            uncond_fake_acc = np.zeros(3, dtype=float)
            disc_skips = np.zeros(3, dtype=int)

            train_pbar = tqdm(train_loader,
                              desc='Training',
                              leave=False,
                              dynamic_ncols=True)
            for batch in train_pbar:
                real_imgs = [batch['img64'], batch['img128'], batch['img256']]

                with torch.no_grad():
                    word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]

                # Generate images
                noise.data.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)

                # Discriminator loss (with label smoothing)
                (batch_d_loss, batch_real_acc, batch_fake_acc,
                 batch_mismatched_acc, batch_uncond_real_acc,
                 batch_uncond_fake_acc,
                 batch_disc_skips) = self.discriminator_step(
                     real_imgs, generated, sent_embs, 0.1)

                d_grad_norm = [grad_norm(d) for d in self.discriminators]

                d_loss += batch_d_loss
                real_acc += batch_real_acc
                fake_acc += batch_fake_acc
                mismatched_acc += batch_mismatched_acc
                uncond_real_acc += batch_uncond_real_acc
                uncond_fake_acc += batch_uncond_fake_acc
                disc_skips += batch_disc_skips

                # Generator loss
                batch_g_losses = self.generator_step(generated, word_embs,
                                                     sent_embs, mu, logvar,
                                                     batch['label'])
                g_total, batch_g_stage_loss, batch_w_loss, batch_s_loss, batch_kl_loss = batch_g_losses
                g_stage_loss += batch_g_stage_loss
                w_loss += batch_w_loss
                s_loss += batch_s_loss
                kl_loss += batch_kl_loss
                gen_updates += 1

                avg_g_loss = g_total.item() / batch_size
                g_loss += avg_g_loss

                if hist_avg:
                    for p, avg_p in zip(self.gen.parameters(), avg_g_params):
                        avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                    if gen_updates % 1000 == 0:
                        tqdm.write(
                            'Replacing generator weights with their moving average'
                        )
                        for p, avg_p in zip(self.gen.parameters(),
                                            avg_g_params):
                            p.data.copy_(avg_p)

                train_pbar.set_description(
                    f'Training (G: {grad_norm(self.gen):.2f}  '
                    f'D64: {d_grad_norm[0]:.2f}  '
                    f'D128: {d_grad_norm[1]:.2f}  '
                    f'D256: {d_grad_norm[2]:.2f})')

            batches = len(train_loader)

            g_loss /= batches
            g_stage_loss /= batches
            w_loss /= batches
            s_loss /= batches
            kl_loss /= batches
            d_loss /= batches
            real_acc /= batches
            fake_acc /= batches
            mismatched_acc /= batches
            uncond_real_acc /= batches
            uncond_fake_acc /= batches

            metrics['loss']['g'].append(g_loss)
            metrics['loss']['d'].append(d_loss)
            metrics['accuracy']['real'].append(real_acc)
            metrics['accuracy']['fake'].append(fake_acc)
            metrics['accuracy']['mismatched'].append(mismatched_acc)
            metrics['accuracy']['unconditional_real'].append(uncond_real_acc)
            metrics['accuracy']['unconditional_fake'].append(uncond_fake_acc)

            sep = '_' * 10
            tqdm.write(f'{sep}Epoch {e}{sep}')

            if e % test_sample_every == 0:
                self.gen.eval()
                generated_samples = [
                    resolution.unsqueeze(0)
                    for resolution in self.sample_test_set(dataset)
                ]
                self._save_generated(generated_samples, e,
                                     f'{OUT_DIR}/{start_time}')

                if evaluator is not None:
                    scores = evaluator.evaluate(self)
                    for k, v in scores.items():
                        metrics[k].append(v)
                        tqdm.write(f'{k}: {v:.2f}')

            tqdm.write(
                f'Generator avg loss: total({g_loss:.3f})  '
                f'stage0({g_stage_loss[0]:.3f})  stage1({g_stage_loss[1]:.3f})  stage2({g_stage_loss[2]:.3f})  '
                f'w({w_loss:.3f})  s({s_loss:.3f})  kl({kl_loss:.3f})')

            for i, _ in enumerate(self.discriminators):
                tqdm.write(f'Discriminator{i} avg: '
                           f'loss({d_loss[i]:.3f})  '
                           f'r-acc({real_acc[i]:.3f})  '
                           f'f-acc({fake_acc[i]:.3f})  '
                           f'm-acc({mismatched_acc[i]:.3f})  '
                           f'ur-acc({uncond_real_acc[i]:.3f})  '
                           f'uf-acc({uncond_fake_acc[i]:.3f})  '
                           f'skips({disc_skips[i]})')

        return metrics

    def sample_test_set(self,
                        dataset,
                        nb_samples=8,
                        nb_captions=2,
                        noise_variations=2):
        subset = dataset.test
        sample_indices = np.random.choice(len(subset),
                                          nb_samples,
                                          replace=False)
        cap_indices = np.random.choice(10, nb_captions, replace=False)
        texts = [
            subset.data[f'caption_{cap_idx}'].iloc[sample_idx]
            for sample_idx in sample_indices for cap_idx in cap_indices
        ]

        generated_samples = [
            self.generate_from_text(texts, dataset)
            for _ in range(noise_variations)
        ]
        combined_img64 = torch.FloatTensor()
        combined_img128 = torch.FloatTensor()
        combined_img256 = torch.FloatTensor()

        for noise_variant in generated_samples:
            noise_var_img64 = torch.FloatTensor()
            noise_var_img128 = torch.FloatTensor()
            noise_var_img256 = torch.FloatTensor()
            for i in range(nb_samples):
                # rows: samples, columns: captions * noise variants
                row64 = torch.cat([
                    noise_variant[0][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                  dim=-1).cpu()
                row128 = torch.cat([
                    noise_variant[1][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                   dim=-1).cpu()
                row256 = torch.cat([
                    noise_variant[2][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                   dim=-1).cpu()
                noise_var_img64 = torch.cat([noise_var_img64, row64], dim=-2)
                noise_var_img128 = torch.cat([noise_var_img128, row128],
                                             dim=-2)
                noise_var_img256 = torch.cat([noise_var_img256, row256],
                                             dim=-2)
            combined_img64 = torch.cat([combined_img64, noise_var_img64],
                                       dim=-1)
            combined_img128 = torch.cat([combined_img128, noise_var_img128],
                                        dim=-1)
            combined_img256 = torch.cat([combined_img256, noise_var_img256],
                                        dim=-1)

        return combined_img64, combined_img128, combined_img256

    @staticmethod
    def KL_loss(mu, logvar):
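        # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I):
        # KL = -0.5 * mean(1 + logvar - mu^2 - exp(logvar))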
        loss = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        loss = torch.mean(loss).mul_(-0.5)
        return loss

    def generator_step(self, generated_imgs, word_embs, sent_embs, mu, logvar,
                       class_labels):
        self.gen.zero_grad()
        avg_stage_g_loss = [0, 0, 0]

        local_features, global_features = self.damsm.img_enc(
            generated_imgs[-1])
        batch_size = sent_embs.size(0)
        match_labels = torch.LongTensor(range(batch_size)).to(self.device)

        w1_loss, w2_loss, _ = self.damsm.words_loss(local_features, word_embs,
                                                    class_labels, match_labels)
        w_loss = (w1_loss + w2_loss) * LAMBDA

        s1_loss, s2_loss = self.damsm.sentence_loss(global_features, sent_embs,
                                                    class_labels, match_labels)
        s_loss = (s1_loss + s2_loss) * LAMBDA

        kl_loss = self.KL_loss(mu, logvar)

        g_total = w_loss + s_loss + kl_loss

        for i, d in enumerate(self.discriminators):
            features = d(generated_imgs[i])
            fake_logits = d.logit(features, sent_embs)

            real_labels = torch.ones_like(fake_logits).to(self.device)

            disc_error = F.binary_cross_entropy_with_logits(
                fake_logits, real_labels)

            uncond_fake_logits = d.logit(features)
            uncond_disc_error = F.binary_cross_entropy_with_logits(
                uncond_fake_logits, real_labels)

            stage_loss = disc_error + uncond_disc_error
            avg_stage_g_loss[i] = stage_loss.item() / batch_size
            g_total += stage_loss

        g_total.backward()
        self.gen_optimizer.step()

        return (g_total, avg_stage_g_loss, w_loss.item() / batch_size,
                s_loss.item() / batch_size, kl_loss.item())

    def discriminator_step(self,
                           real_imgs,
                           generated_imgs,
                           sent_embs,
                           label_smoothing,
                           skip_acc_threshold=0.9,
                           p_flip=0.05,
                           halting=False):
        self.disc.zero_grad()
        batch_size = sent_embs.size(0)

        avg_d_loss = [0, 0, 0]
        real_accuracy = [0, 0, 0]
        fake_accuracy = [0, 0, 0]
        mismatched_accuracy = [0, 0, 0]
        uncond_real_accuracy = [0, 0, 0]
        uncond_fake_accuracy = [0, 0, 0]
        skipped = [0, 0, 0]

        for i, d in enumerate(self.discriminators):
            real_features = d(real_imgs[i].to(self.device))
            fake_features = d(generated_imgs[i].detach())

            real_logits = d.logit(real_features, sent_embs)

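            # One-sided label smoothing: real targets become 1 - label_smoothing
            # so the discriminator is not pushed toward saturated logits.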
            real_labels = torch.full_like(real_logits,
                                          1 - label_smoothing).to(self.device)
            fake_labels = torch.zeros_like(real_logits,
                                           dtype=torch.float).to(self.device)

            # flip_mask = torch.Tensor(real_labels.size()).bernoulli_(p_flip).type(torch.bool)
            # real_labels[flip_mask], fake_labels[flip_mask] = fake_labels[flip_mask], real_labels[flip_mask]

            real_error = F.binary_cross_entropy_with_logits(
                real_logits, real_labels)
            # Real images should be classified as real
            real_accuracy[i] = (real_logits >=
                                0).sum().item() / real_logits.numel()

            fake_logits = d.logit(fake_features, sent_embs)
            fake_error = F.binary_cross_entropy_with_logits(
                fake_logits, fake_labels)
            # Generated images should be classified as fake
            fake_accuracy[i] = (fake_logits <
                                0).sum().item() / fake_logits.numel()

            mismatched_logits = d.logit(real_features,
                                        rotate_tensor(sent_embs, 1))
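            # rotate_tensor(sent_embs, 1) shifts the embeddings by one position
            # along the batch, pairing each real image with another sample's
            # caption.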
            mismatched_error = F.binary_cross_entropy_with_logits(
                mismatched_logits, fake_labels)
            # Images with mismatched descriptions should be classified as fake
            mismatched_accuracy[i] = (mismatched_logits < 0).sum().item(
            ) / mismatched_logits.numel()

            uncond_real_logits = d.logit(real_features)
            uncond_real_error = F.binary_cross_entropy_with_logits(
                uncond_real_logits, real_labels)
            uncond_real_accuracy[i] = (uncond_real_logits >= 0).sum().item(
            ) / uncond_real_logits.numel()

            uncond_fake_logits = d.logit(fake_features)
            uncond_fake_error = F.binary_cross_entropy_with_logits(
                uncond_fake_logits, fake_labels)
            uncond_fake_accuracy[i] = (uncond_fake_logits < 0).sum().item(
            ) / uncond_fake_logits.numel()

            error = (real_error + uncond_real_error) / 2 + (
                fake_error + uncond_fake_error + mismatched_error) / 3

            if (not halting or fake_accuracy[i] + real_accuracy[i]
                    < skip_acc_threshold * 2):
                error.backward()
                self.disc_optimizers[i].step()
            else:
                skipped[i] = 1

            avg_d_loss[i] = error.item() / batch_size

        return (avg_d_loss, real_accuracy, fake_accuracy, mismatched_accuracy,
                uncond_real_accuracy, uncond_fake_accuracy, skipped)

    def generate_from_text(self, texts, dataset, noise=None):
        encoded = [dataset.train.encode_text(t) for t in texts]
        generated = self.generate_from_encoded_text(encoded, dataset, noise)
        return generated

    def generate_from_encoded_text(self, encoded, dataset, noise=None):
        with torch.no_grad():
            w_emb, s_emb = self.damsm.txt_enc(encoded)
            attn_mask = torch.tensor(encoded).to(
                self.device) == dataset.vocab[END_TOKEN]
            if noise is None:
                noise = torch.FloatTensor(len(encoded), D_Z).to(self.device)
                noise.data.normal_(0, 1)
            generated, att, mu, logvar = self.gen(noise, s_emb, w_emb,
                                                  attn_mask)
        return generated

    def _save_generated(self, generated, epoch, out_dir=OUT_DIR):
        nb_samples = generated[0].size(0)
        save_dir = f'{out_dir}/epoch_{epoch:03}'
        os.makedirs(save_dir)

        for i in range(nb_samples):
            save_image(generated[0][i],
                       f'{save_dir}/{i}_64.jpg',
                       normalize=True,
                       range=(-1, 1))
            save_image(generated[1][i],
                       f'{save_dir}/{i}_128.jpg',
                       normalize=True,
                       range=(-1, 1))
            save_image(generated[2][i],
                       f'{save_dir}/{i}_256.jpg',
                       normalize=True,
                       range=(-1, 1))

    def save(self, name, save_dir=GAN_MODEL_DIR, metrics=None):
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.gen.state_dict(), f'{save_dir}/{name}_generator.pt')
        torch.save(self.disc.state_dict(),
                   f'{save_dir}/{name}_discriminator.pt')
        if metrics is not None:
            with open(f'{save_dir}/{name}_metrics.json', 'w') as f:
                metrics = pre_json_metrics(metrics)
                json.dump(metrics, f)

    def load_(self, name, load_dir=GAN_MODEL_DIR):
        self.gen.load_state_dict(torch.load(f'{load_dir}/{name}_generator.pt'))
        self.disc.load_state_dict(
            torch.load(f'{load_dir}/{name}_discriminator.pt'))
        self.gen.eval(), self.disc.eval()

    @staticmethod
    def load(name, damsm, load_dir=GAN_MODEL_DIR, device=DEVICE):
        attngan = AttnGAN(damsm, device=device)
        attngan.load_(name, load_dir)
        return attngan

    def validate_test_set(self,
                          dataset,
                          batch_size=GAN_BATCH,
                          save_dir=f'{OUT_DIR}/test_samples'):
        os.makedirs(save_dir, exist_ok=True)

        loader = DataLoader(dataset.test,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=False,
                            collate_fn=dataset.collate_fn)
        loader = tqdm(loader,
                      dynamic_ncols=True,
                      leave=True,
                      desc='Generating samples for test set')

        self.gen.eval()
        with torch.no_grad():
            i = 0
            for batch in loader:
                word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]
                noise = torch.FloatTensor(len(batch['caption']),
                                          D_Z).to(self.device)
                noise.data.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)

                for img in generated[-1]:
                    save_image(img,
                               f'{save_dir}/{i}.jpg',
                               normalize=True,
                               range=(-1, 1))
                    i += 1

    def get_d_score(self, imgs, sent_embs):
        d = self.disc.d256
        features = d(imgs.to(self.device))
        scores = d.logit(features, sent_embs)
        return scores

    def accept_prob(self, score1, score2):
        return min(1, (1 / score1 - 1) / (1 / score2 - 1))

    def d_scores_test(self, dataset):
        with torch.no_grad():
            loader = DataLoader(dataset.test,
                                batch_size=20,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            scores = []
            d = self.disc.d256
            for b in loader:
                img = b['img256'].to(self.device)
                f = d(img)
                l = d.logit(f)
                scores.append(torch.sigmoid(l))
            scores = [x.item() for s in scores for x in s.reshape(-1)]
        return scores

    def z_test(self, scores, labels):
        labels = np.array(labels)
        scores = np.array(scores)
        num = np.sum(labels - scores)
        denom = np.sqrt(np.sum(scores * (1 - scores)))
        return num / denom

    def d_scores_gen(self, dataset):
        with torch.no_grad():
            loader = DataLoader(dataset.test,
                                batch_size=20,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            scores = []
            d = self.disc.d256
            for b in loader:
                noise = torch.FloatTensor(len(b['caption']),
                                          D_Z).to(self.device)
                noise.data.normal_(0, 1)
                word_embs, sent_embs = self.damsm.txt_enc(b['caption'])
                attn_mask = torch.tensor(b['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]
                generated, _, _, _ = self.gen(noise, sent_embs, word_embs,
                                              attn_mask)

                f = d(generated[-1])
                l = d.logit(f)
                scores.append(torch.sigmoid(l))
            scores = [x.item() for s in scores for x in s.reshape(-1)]
        return scores

    def mh_sample(self, dataset, k, save_dir='test_samples', batch=GAN_BATCH):
        evaluator = IS_FID_Evaluator(dataset,
                                     self.damsm.img_enc.inception_model, batch,
                                     self.device)
        # self.disc.d256.train()
        with torch.no_grad():
            l = len(dataset.test)
            score_real = self.d_scores_test(dataset)
            score_gen = self.d_scores_gen(dataset)
            print(np.mean(score_real))
            print(np.mean(score_gen))
            portion = -l // 5
            score_test = score_real[:portion] + score_gen[:portion]
            label_test = [1] * (len(score_test) //
                                2) + [0] * (len(score_test) // 2)

            print('Z test before calibration: ',
                  self.z_test(torch.tensor(score_test), label_test))

            score_real_calib = score_real[portion:]
            score_gen_calib = score_gen[portion:]
            # score_calib = score_real_calib + score_gen_calib
            score_calib = score_gen_calib + score_real_calib
            label_calib = len(score_gen_calib) * [0] + len(
                score_real_calib) * [1]

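            # Calibrate the raw discriminator scores with Platt scaling
            # (logistic regression on held-out real/generated scores) so they
            # can be treated as probabilities in the MH acceptance test.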
            cal_clf = LogisticRegression()
            cal_clf.fit(np.array(score_calib).reshape(-1, 1), label_calib)

            score_pred = cal_clf.predict_proba(
                np.array(score_test).reshape(-1, 1))[:, 1]
            print('Score pred avg: ', np.mean(score_pred))
            test_pred = cal_clf.predict(np.array(score_test).reshape(-1, 1))

            print('Z test after calibration: ',
                  self.z_test(score_pred, label_test))
            print('Accuracy: ',
                  sum((test_pred == label_test)) / len(test_pred))

            os.makedirs(save_dir, exist_ok=True)
            loader = DataLoader(dataset.test,
                                batch_size=1,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            loader = tqdm(loader,
                          dynamic_ncols=True,
                          leave=True,
                          desc='Generating samples for test set')

            imgs = []
            true_probs = 0
            noaccept = 0
            for i, sample in enumerate(loader):
                if i > l - (l // 10):
                    continue
                word_embs, sent_embs = self.damsm.txt_enc(sample['caption'])
                attn_mask = torch.tensor(sample['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]

                img_chain = []
                while len(img_chain) < k:
                    noise = torch.FloatTensor(batch, D_Z).to(self.device)
                    noise.data.normal_(0, 1)
                    generated, _, _, _ = self.gen(
                        noise, sent_embs.repeat(batch, 1),
                        word_embs.repeat(batch, 1, 1),
                        attn_mask.repeat(batch, 1))

                    for img in generated[-1]:
                        img_chain.append(img)

                img_chain = img_chain[:k]
                img_chain = torch.stack(img_chain).to(self.device)

                score_chain = []
                d_loader = DataLoader(img_chain,
                                      batch_size=batch,
                                      shuffle=False,
                                      drop_last=False)
                for d_batch in d_loader:
                    scores = self.get_d_score(d_batch,
                                              sent_embs.repeat(batch, 1))
                    scores = scores.reshape(-1, 1).cpu().numpy()
                    scores = cal_clf.predict_proba(scores)[:, 1]
                    for s in scores:
                        score_chain.append(s)
                chosen = 0
                for j, s in enumerate(score_chain[1:], 1):
                    alpha = self.accept_prob(score_chain[chosen], s)
                    if np.random.rand() < alpha:
                        chosen = j

                if chosen == 0:
                    # +1 maps the argmax over score_chain[1:] back to the
                    # matching img_chain index
                    imgs.append(img_chain[torch.tensor(
                        score_chain[1:]).argmax() + 1].cpu())
                    noaccept += 1
                else:
                    imgs.append(img_chain[chosen].cpu())
                true_probs += score_chain[0]

            print(noaccept)
            print(true_probs / len(dataset.test))
            mu_real, sig_real = evaluator.mu_real, evaluator.sig_real
            mu_fake, sig_fake = activation_statistics(
                self.damsm.img_enc.inception_model, imgs)
            print('FID: ', frechet_dist(mu_real, sig_real, mu_fake, sig_fake))
            return imgs
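
The acceptance rule in accept_prob is the Metropolis-Hastings odds ratio used in MH-GAN-style sampling: with a calibrated score s = D(x) read as p(real | x), the chain moves from the current sample (score1) to the proposal (score2) with probability min(1, ((1/s1) - 1) / ((1/s2) - 1)). A quick numeric check:

def accept_prob(score1, score2):
    return min(1.0, (1 / score1 - 1) / (1 / score2 - 1))

print(accept_prob(0.4, 0.8))  # 1.0   -> always accept a better-scored sample
print(accept_prob(0.8, 0.4))  # ~0.17 -> rarely accept a worse-scored sample
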
Example no. 6
def main():
    args = parse_args()
    device = torch.device("cuda")

    generator = Generator.from_file(args.generator_path).to(device)
    generator.eval()
    discriminator = Discriminator(tokenizer=generator.tokenizer).to(device)

    train_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "train/dialogues_train.txt"),
        tokenizer=generator.tokenizer,
    )
    valid_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "validation/dialogues_validation.txt"),
        tokenizer=generator.tokenizer,
    )

    print(len(train_dataset), len(valid_dataset))

    optimizer = AdamW(discriminator.parameters(), lr=args.lr)

    for epoch in tqdm(range(args.num_epochs)):
        train_loss, valid_loss = [], []
        rewards_real, rewards_fake, accuracy = [], [], []
        discriminator.train()
        for ind in np.random.permutation(len(train_dataset)):
            optimizer.zero_grad()
            context, real_reply = train_dataset.sample_dialouge(ind)
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)

            loss, _, _ = discriminator.get_loss(context, real_reply,
                                                fake_reply)
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        discriminator.eval()
        real_replies, fake_replies = [], []
        for ind in range(len(valid_dataset)):
            context, real_reply = valid_dataset[ind]
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)

            with torch.no_grad():
                loss, reward_real, reward_fake = discriminator.get_loss(
                    context, real_reply, fake_reply)
            valid_loss.append(loss.item())
            rewards_real.append(reward_real)
            rewards_fake.append(reward_fake)
            accuracy.extend([reward_real > 0.5, reward_fake < 0.5])

            real_reply, fake_reply = (
                generator.tokenizer.decode(real_reply[0]),
                generator.tokenizer.decode(fake_reply[0]),
            )
            real_replies.append(real_reply)
            fake_replies.append(fake_reply)

        train_loss, valid_loss = np.mean(train_loss), np.mean(valid_loss)
        print(
            f"Epoch {epoch + 1}, Train Loss: {train_loss:.2f}, "
            f"Valid Loss: {valid_loss:.2f}, "
            f"Reward real: {np.mean(rewards_real):.2f}, "
            f"Reward fake: {np.mean(rewards_fake):.2f}, "
            f"Accuracy: {np.mean(accuracy):.2f}"
        )
        print(f"Adversarial accuracy, {np.mean(accuracy):.2f}")
        for order in range(1, 5):
            print(
                f"BLEU-{order}: {bleuscore(real_replies, fake_replies, order=order)}"
            )
        print(f"DIST-1: {dist1(fake_replies)}")
        print(f"DIST-2: {dist2unbiased(fake_replies)}")
Example no. 7
def main():
    args = parse_args()
    device = torch.device("cuda")

    generator = Generator.from_file(args.generator_path).to(device)
    if args.freeze:
        for name, param in generator.named_parameters():
            if ("shared" not in name) and ("decoder.block.5" not in name):
                param.requires_grad = False
    discriminator = Discriminator.from_file(
        args.discriminator_path, tokenizer=generator.tokenizer
    ).to(device)
    if args.freeze:
        for name, param in discriminator.named_parameters():
            if ("shared" not in name) and ("decoder.block.5" not in name):
                param.requires_grad = False
    train_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "train/dialogues_train.txt"),
        tokenizer=generator.tokenizer,
        debug=args.debug,
    )
    valid_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "validation/dialogues_validation.txt"),
        tokenizer=generator.tokenizer,
        debug=args.debug,
    )

    print(len(train_dataset), len(valid_dataset))

    generator_optimizer = AdamW(generator.parameters(), lr=args.lr)
    discriminator_optimizer = AdamW(discriminator.parameters(), lr=args.lr)

    rewards = deque([], maxlen=args.log_every * args.generator_steps)
    rewards_real = deque([], maxlen=args.log_every * args.generator_steps)
    generator_loss = deque([], maxlen=args.log_every * args.generator_steps)
    discriminator_loss = deque([], maxlen=args.log_every * args.discriminator_steps)
    best_reward = 0

    generator.train()
    discriminator.train()

    for iteration in tqdm(range(args.num_iterations)):
        for _ in range(args.discriminator_steps):
            discriminator_optimizer.zero_grad()
            context, real_reply = train_dataset.sample()
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)

            if args.regs:
                split_real = random.randint(1, real_reply.size(1))
                real_reply = real_reply[:, :split_real]
                split_fake = random.randint(1, fake_reply.size(1))
                fake_reply = fake_reply[:, :split_fake]

            loss, _, _ = discriminator.get_loss(context, real_reply, fake_reply)
            loss.backward()
            discriminator_optimizer.step()

            discriminator_loss.append(loss.item())

        for _ in range(args.generator_steps):
            generator_optimizer.zero_grad()
            context, real_reply = train_dataset.sample()
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)

            logprob_fake = generator.get_logprob(context, fake_reply)
            reward_fake = discriminator.get_reward(context, fake_reply)

            baseline = 0 if len(rewards) == 0 else np.mean(list(rewards))
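            # Subtracting a moving-average baseline from the reward reduces
            # the variance of the REINFORCE gradient estimate without
            # biasing it.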

            if args.regs:
                partial_rewards = torch.tensor(
                    [
                        discriminator.get_reward(context, fake_reply[:, :t])
                        for t in range(1, fake_reply.size(1) + 1)
                    ]
                ).to(device)
                loss = -torch.mean(partial_rewards * logprob_fake)

            else:
                loss = -(reward_fake - baseline) * torch.mean(logprob_fake)

            if args.teacher_forcing:
                logprob_real = generator.get_logprob(context, real_reply)
                reward_real = discriminator.get_reward(context, real_reply)
                loss -= torch.mean(logprob_real)
                rewards_real.append(reward_real)

            loss.backward()
            generator_optimizer.step()

            generator_loss.append(loss.item())
            rewards.append(reward_fake)

        if iteration % args.log_every == 0:
            mean_reward = np.mean(list(rewards))
            mean_reward_real = (np.mean(list(rewards_real))
                                if rewards_real else 0.0)

            if args.discriminator_steps > 0:
                print(f"Discriminator Loss {np.mean(list(discriminator_loss))}")
            if args.generator_steps > 0:
                print(f"Generator Loss {np.mean(list(generator_loss))}")
                if args.teacher_forcing:
                    print(f"Mean real reward: {mean_reward_real}")
                print(f"Mean fake reward: {mean_reward}\n")

            context, real_reply = valid_dataset.sample()
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)
            reward_fake = discriminator.get_reward(context, fake_reply)

            print_dialogue(
                context=context,
                real_reply=real_reply,
                fake_reply=fake_reply,
                tokenizer=generator.tokenizer,
            )
            print(f"Reward: {reward_fake}\n")

            if mean_reward > best_reward:
                best_reward = mean_reward
                torch.save(discriminator.state_dict(), args.discriminator_output_path)
                torch.save(generator.state_dict(), args.generator_output_path)
            torch.save(
                discriminator.state_dict(), "all_" + args.discriminator_output_path
            )
            torch.save(generator.state_dict(), "all_" + args.generator_output_path)
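
The --regs branch implements reward-for-every-generation-step: every prefix of the sampled reply is scored separately, so the generator receives a per-step signal instead of a single terminal reward. A hedged sketch of that loss, mirroring the loop above and treating the discriminator's get_reward API as an assumption:

import torch

def regs_loss(get_reward, context, fake_reply, logprob_fake):
    # Score every prefix fake_reply[:, :t] and weight the corresponding
    # log-probabilities by those partial rewards (REGS-style REINFORCE).
    partial_rewards = torch.tensor([
        get_reward(context, fake_reply[:, :t])
        for t in range(1, fake_reply.size(1) + 1)
    ])
    return -torch.mean(partial_rewards * logprob_fake)
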
Example no. 8
crop_size = (args.crop_size, args.crop_size)
image_dataset = ImageDataset(args.image_root, args.mask_root, load_size, crop_size)
data_loader = DataLoader(
    image_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    drop_last=False,
    pin_memory=True
)

# -----
# model
# -----
generator = LBAM(4, 3)
discriminator = Discriminator(3)
extractor = VGG16FeatureExtractor()

# ----------
# load model
# ----------
start_iter = args.start_iter
if args.pre_trained != '':

    ckpt_dict_load = torch.load(args.pre_trained)
    start_iter = ckpt_dict_load['n_iter']
    generator.load_state_dict(ckpt_dict_load['generator'])
    discriminator.load_state_dict(ckpt_dict_load['discriminator'])

    print('Starting from iter ', start_iter)
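
The checkpoint writer matching this loader is not shown; a minimal sketch, assuming the same dictionary keys and leaving the path to the caller:

import torch

def save_checkpoint(path, n_iter, generator, discriminator):
    # Keys mirror the ones read back in the load block above.
    torch.save({
        'n_iter': n_iter,
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
    }, path)
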
Example no. 9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda',
                        default=False,
                        action='store_true',
                        help='Enable CUDA')
    args = parser.parse_args()
    use_cuda = args.cuda and torch.cuda.is_available()

    netG = Generator(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)
    netD = Discriminator(VOCAB_SIZE, D_EMB_SIZE, D_NUM_CLASSES, D_FILTER_SIZES,
                         D_NUM_FILTERS, DROPOUT, use_cuda)
    oracle = Oracle(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)

    if use_cuda:
        netG, netD, oracle = netG.cuda(), netD.cuda(), oracle.cuda()

    netG.create_optim(G_LR)
    netD.create_optim(D_LR, D_L2_REG)

    # generating synthetic data
    print('Generating data...')
    generate_samples(oracle, BATCH_SIZE, GENERATED_NUM, REAL_FILE)

    # pretrain generator
    gen_set = GeneratorDataset(REAL_FILE)
    genloader = DataLoader(dataset=gen_set,
                           batch_size=BATCH_SIZE,
                           shuffle=True)

    print('\nPretraining generator...\n')
    for epoch in range(PRE_G_EPOCHS):
        loss = netG.pretrain(genloader)
        print('Epoch {} pretrain generator training loss: {}'.format(
            epoch + 1, loss))

        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} pretrain generator val loss: {}'.format(
            epoch + 1, loss))

    # pretrain discriminator
    print('\nPretraining discriminator...\n')
    for epoch in range(PRE_D_EPOCHS):
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
        dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
        disloader = DataLoader(dataset=dis_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)

        for k_step in range(K_STEPS):
            loss = netD.dtrain(disloader)
            print(
                'Epoch {} K-step {} pretrain discriminator training loss: {}'.
                format(epoch + 1, k_step + 1, loss))

    print('\nStarting adversarial training...')
    for epoch in range(TOTAL_EPOCHS):

        nets = [copy.deepcopy(netG) for _ in range(POPULATION_SIZE)]
        population = [(net, evaluate(net, netD)) for net in nets]
        for g_step in range(G_STEPS):
            t_start = time.time()
            population.sort(key=lambda p: p[1], reverse=True)
            rewards = [p[1] for p in population[:PARENTS_COUNT]]
            reward_mean = np.mean(rewards)
            reward_max = np.max(rewards)
            reward_std = np.std(rewards)
            print(
                "Epoch %d step %d: reward_mean=%.2f, reward_max=%.2f, reward_std=%.2f, time=%.2f s"
                % (epoch, g_step, reward_mean, reward_max, reward_std,
                   time.time() - t_start))

            elite = population[0]
            # generate next population
            prev_population = population
            population = [elite]
            for _ in range(POPULATION_SIZE - 1):
                parent_idx = np.random.randint(0, PARENTS_COUNT)
                parent = prev_population[parent_idx][0]
                net = mutate_net(parent, use_cuda)
                fitness = evaluate(net, netD)
                population.append((net, fitness))

        netG = elite[0]

        for d_step in range(D_STEPS):
            # train discriminator
            generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
            dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
            disloader = DataLoader(dataset=dis_set,
                                   batch_size=BATCH_SIZE,
                                   shuffle=True)

            for k_step in range(K_STEPS):
                loss = netD.dtrain(disloader)
                print(
                    'D_step {}, K-step {} adversarial discriminator training loss: {}'
                    .format(d_step + 1, k_step + 1, loss))

        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} adversarial generator val loss: {}'.format(
            epoch + 1, loss))
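
mutate_net and evaluate are defined elsewhere; a plausible Gaussian-perturbation mutation in the spirit of simple evolution strategies (the noise scale is an assumption):

import copy
import torch

MUTATION_STD = 0.01  # assumed perturbation scale

def mutate_net(net, use_cuda=False):
    # use_cuda kept for signature compatibility; randn_like already
    # matches the parameter's device.
    child = copy.deepcopy(net)
    with torch.no_grad():
        for p in child.parameters():
            p.add_(torch.randn_like(p) * MUTATION_STD)
    return child
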
Example no. 10
class GAN:
    def __init__(self, _device):
        self.device = _device
        self.batch_size = 64
        self.resolution = 28
        self.d_criterion = None
        self.d_optimizer = None
        self.g_criterion = None
        self.g_optimizer = None

        self.discriminator = Discriminator(num_layers=5,
                                           activations=["relu", "relu", "relu",
                                                        "sigmoid"],
                                           device=_device,
                                           num_nodes=[1, 64, 128, 64, 1],
                                           kernels=[5, 5, 3],
                                           strides=[2, 2, 2],
                                           dropouts=[.25, .25, 0],
                                           batch_size=64)

        # pass one batch of random images through the network so the lazily
        # built output layer gets initialized
        self.discriminator(torch.rand(
            size=[self.batch_size, 1, self.resolution, self.resolution]))

        self.generator = Generator(num_layers=6,
                                   activations=["relu", "relu", "relu", "relu",
                                                "tanh"],
                                   num_nodes=[1, 64, 128, 64, 64, 1],
                                   kernels=[3, 3, 3, 3],
                                   strides=[1, 1, 1, 1],
                                   batch_norms=[1, 1, 1, 0],
                                   upsamples=[1, 1, 0, 0],
                                   dropouts=[.25, .25, 0])

    def train(self, epochs: int, dataloader):
        self.display_output()
        for epoch in range(epochs):
            i = 0
            initial_loss = self.train_generator()
            true_loss = 1.
            false_loss = 1.
            print(f"#\tInitial Generator Loss: {initial_loss}")
            for data, target in dataloader:
                if false_loss > 0.7:
                    true_loss, false_loss = self.train_discriminator(data, True)
                else:
                    true_loss, false_loss = self.train_discriminator(data,
                                                                     False)
                generator_loss = self.train_generator()
                if i % 10 == 0:
                    print(
                        f"@\tIndex: {i}\tTrue Loss: {true_loss}\t"
                        f"False Loss: {false_loss}\t"
                        f"Generator Loss: {generator_loss}")
                    self.display_output()
                i += 1
        self.test()

    def train_discriminator(self, train_data, train):

        # add noise to labels
        true = torch.ones((self.batch_size, 1))
        noise = torch.nn.functional.relu(0.01 * torch.randn(self.batch_size, 1))
        true.sub_(noise)
        true = true.to(self.device, dtype=torch.float64)

        false = torch.zeros((self.batch_size, 1))
        noise = torch.nn.functional.relu(0.01 * torch.randn(self.batch_size, 1))
        false.sub_(noise)
        false = false.to(self.device, dtype=torch.float64)

        index = np.random.randint(0, train_data.shape[0], self.batch_size)
        true_images = train_data[index]
        true_loss = self.discriminator.batch_train(true_images, true,
                                                   self.d_criterion,
                                                   self.d_optimizer, train)

        # FIXME: Extract 100 to argument
        noise = torch.randn(self.batch_size, 1, 1, 100, dtype=torch.float64).to(
            self.device)
        generated_images = self.generator(noise)
        false_loss = self.discriminator.batch_train(generated_images, false,
                                                    self.d_criterion,
                                                    self.d_optimizer, train)

        return true_loss, false_loss

    def train_generator(self):
        valid = torch.ones((self.batch_size, 1), dtype=torch.float64).to(
            self.device)
        noise = torch.randn(self.batch_size, 1, 1, 100, dtype=torch.float64).to(
            self.device)
        return self.generator.batch_train(self.discriminator, noise, valid,
                                          self.g_criterion, self.g_optimizer)

    # make noise, and send through discriminator
    def test(self):
        noise = torch.randn(self.batch_size, 1, 1, 100, dtype=torch.float64).to(
            self.device)
        image = self.generator(noise).detach().cpu().numpy()
        for i in range(np.size(image, 0)):
            picture = image[i, 0, :, :]
            plt.imshow(picture)
            plt.show()

    def display_output(self):
        noise = torch.randn(1, 1, 1, 100, dtype=torch.float64).to(self.device)
        image = self.generator(noise).detach().cpu().numpy()
        picture = image[0, 0, :, :]
        plt.imshow(picture)
        plt.show()
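
The label-noise trick in train_discriminator, in isolation: soften the hard 0/1 targets by a small non-negative random amount so the discriminator never trains against perfectly confident labels:

import torch
import torch.nn.functional as F

batch_size = 64
true = torch.ones(batch_size, 1) - F.relu(0.01 * torch.randn(batch_size, 1))
false = torch.zeros(batch_size, 1) - F.relu(0.01 * torch.randn(batch_size, 1))
# "real" targets sit at or just below 1.0, "fake" targets at or just below
# 0.0, mirroring the sub_ calls in the snippet above
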
Example no. 11
def main_train():
    # Build argument parser
    parser = argparse.ArgumentParser(description='Train a table to text model')

    # Training corpus
    corpora_group = parser.add_argument_group('training corpora',
                                              'Corpora related arguments; specify either unaligned or'
                                              ' aligned training corpora')
    # "Languages (type,path)"
    corpora_group.add_argument('--src_corpus_params', type=str,
                               default='table, ./data/processed_data/train/train.box',
                               help='the source unaligned corpus (type,path). Type = text/table')
    corpora_group.add_argument('--trg_corpus_params', type=str,
                               default='text, ./data/processed_data/train/train.article',
                               help='the target unaligned corpus (type,path). Type = text/table')
    corpora_group.add_argument('--src_para_corpus_params', type=str, default='',
                               help='the source corpus of parallel data(type,path). Type = text/table')
    corpora_group.add_argument('--trg_para_corpus_params', type=str, default='',
                               help='the target corpus of parallel data(type,path). Type = text/table')
    # Maybe add src/target type (i.e. text/table)
    corpora_group.add_argument('--corpus_mode', type=str, default='mono',
                               help='training mode: "mono" (unsupervised) / "para" (supervised)')

    corpora_group.add_argument('--max_sentence_length', type=int, default=50,
                               help='the maximum sentence length for training (defaults to 50)')
    corpora_group.add_argument('--cache', type=int, default=100000,
                               help='the cache size (in sentences) for corpus reading (defaults to 100000)')

    # Embeddings/vocabulary
    embedding_group = parser.add_argument_group('embeddings',
                                                'Embedding related arguments; either give pre-trained embeddings,'
                                                ' or a vocabulary and embedding dimensionality to'
                                                ' randomly initialize them')
    embedding_group.add_argument('--metadata_path', type=str, required=True,
                                 help='Path for bin file created in pre-processing phase, '
                                      'containing BPEmb related metadata.')

    # Architecture
    architecture_group = parser.add_argument_group('architecture', 'Architecture related arguments')
    architecture_group.add_argument('--layers', type=int, default=2,
                                    help='the number of encoder/decoder layers (defaults to 2)')
    architecture_group.add_argument('--hidden', type=int, default=600,
                                    help='the number of dimensions for the hidden layer (defaults to 600)')
    architecture_group.add_argument('--dis_hidden', type=int, default=150,
                                    help='Number of dimensions for the discriminator hidden layers')
    architecture_group.add_argument('--n_dis_layers', type=int, default=2,
                                    help='Number of discriminator layers')
    architecture_group.add_argument('--disable_bidirectional', action='store_true',
                                    help='use a single direction encoder')
    architecture_group.add_argument('--disable_backtranslation', action='store_true', help='disable backtranslation')
    architecture_group.add_argument('--disable_field_loss', action='store_true', help='disable field loss')
    architecture_group.add_argument('--disable_discriminator', action='store_true', help='disable discriminator')
    architecture_group.add_argument('--shared_enc', action='store_true', help='share enc for both directions')
    architecture_group.add_argument('--shared_dec', action='store_true', help='share dec for both directions')

    # Denoising
    denoising_group = parser.add_argument_group('denoising', 'Denoising related arguments')
    denoising_group.add_argument('--denoising_mode', type=int, default=1, help='0/1/2 = disabled/old/new')
    denoising_group.add_argument('--word_shuffle', type=int, default=3,
                                 help='shuffle words (only relevant in new mode)')
    denoising_group.add_argument('--word_dropout', type=float, default=0.1,
                                 help='randomly remove words (only relevant in new mode)')
    denoising_group.add_argument('--word_blank', type=float, default=0.2,
                                 help='randomly blank out words (only relevant in new mode)')

    # Optimization
    optimization_group = parser.add_argument_group('optimization', 'Optimization related arguments')
    optimization_group.add_argument('--batch', type=int, default=50, help='the batch size (defaults to 50)')
    optimization_group.add_argument('--learning_rate', type=float, default=0.0002,
                                    help='the global learning rate (defaults to 0.0002)')
    optimization_group.add_argument('--dropout', metavar='PROB', type=float, default=0.3,
                                    help='dropout probability for the encoder/decoder (defaults to 0.3)')
    optimization_group.add_argument('--param_init', metavar='RANGE', type=float, default=0.1,
                                    help='uniform initialization in the specified range (defaults to 0.1; 0 keeps module-specific default initialization)')
    optimization_group.add_argument('--iterations', type=int, default=300000,
                                    help='the number of training iterations (defaults to 300000)')

    # Model saving
    saving_group = parser.add_argument_group('model saving', 'Arguments for saving the trained model')
    saving_group.add_argument('--save', metavar='PREFIX', help='save models with the given prefix')
    saving_group.add_argument('--save_interval', type=int, default=0, help='save intermediate models at this interval')

    # Logging/validation
    logging_group = parser.add_argument_group('logging', 'Logging and validation arguments')
    logging_group.add_argument('--log_interval', type=int, default=100, help='log at this interval (defaults to 100)')
    logging_group.add_argument('--dbg_print_interval', type=int, default=1000,
                               help='print debug output at this interval (defaults to 1000)')
    logging_group.add_argument('--src_valid_corpus', type=str, default='')
    logging_group.add_argument('--trg_valid_corpus', type=str, default='')
    logging_group.add_argument('--print_level', type=str, default='info', help='logging level [debug | info]')

    # Other
    misc_group = parser.add_argument_group('misc', 'Misc. arguments')
    misc_group.add_argument('--encoding', default='utf-8',
                            help='the character encoding for input/output (defaults to utf-8)')
    misc_group.add_argument('--cuda', type=str, default='cpu', help='device for training, e.g. "cuda:0" (defaults to "cpu")')
    misc_group.add_argument('--bleu_device', type=str, default='',
                            help='device for calculating BLEU scores in case a validation dataset is given')

    # Parse arguments
    args = parser.parse_args()

    logger = logging.getLogger()
    if args.print_level == 'debug':
        logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
    elif args.print_level == 'info':
        logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    elif args.print_level == 'warning':
        logging.basicConfig(stream=sys.stderr, level=logging.WARNING)
    else:
        logging.basicConfig(stream=sys.stderr, level=logging.CRITICAL)

    # Validate arguments
    if args.src_corpus_params is None or args.trg_corpus_params is None:
        print('Must supply both --src_corpus_params and --trg_corpus_params')
        sys.exit(-1)

    args.src_corpus_params = args.src_corpus_params.split(',')
    args.trg_corpus_params = args.trg_corpus_params.split(',')
    assert len(args.src_corpus_params) == 2
    assert len(args.trg_corpus_params) == 2

    src_type, src_corpus_path = args.src_corpus_params
    trg_type, trg_corpus_path = args.trg_corpus_params

    src_type = src_type.strip()
    src_corpus_path = src_corpus_path.strip()
    trg_type = trg_type.strip()
    trg_corpus_path = trg_corpus_path.strip()

    assert src_type != trg_type
    assert (src_type in ['table', 'text']) and (trg_type in ['table', 'text'])

    corpus_size = get_num_lines(src_corpus_path + '.content')

    # Select device
    if torch.cuda.is_available():
        device = torch.device(args.cuda)
    else:
        device = torch.device('cpu')

    if args.bleu_device == '':
        args.bleu_device = device

    current_time = str(datetime.datetime.now().timestamp())
    run_dir = 'run_' + current_time + '/'
    train_log_dir = 'logs/train/' + run_dir + (args.save or '')
    valid_log_dir = 'logs/valid/' + run_dir + (args.save or '')

    train_writer = SummaryWriter(train_log_dir)
    valid_writer = SummaryWriter(valid_log_dir)

    # Create optimizer lists
    src2src_optimizers = []
    trg2trg_optimizers = []
    src2trg_optimizers = []
    trg2src_optimizers = []

    # Method to create a module optimizer and add it to the given lists
    def add_optimizer(module, directions=()):
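        # Optionally re-initialize the module's parameters uniformly, then register one shared
        # Adam optimizer in every direction list that trains this module.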
        if args.param_init != 0.0:
            for param in module.parameters():
                param.data.uniform_(-args.param_init, args.param_init)
        optimizer = torch.optim.Adam(module.parameters(), lr=args.learning_rate)
        for direction in directions:
            direction.append(optimizer)
        return optimizer

    assert os.path.isfile(args.metadata_path)

    metadata = torch.load(args.metadata_path)
    bpemb_en = metadata.init_bpe_module()
    word_dict: BpeWordDict = torch.load(metadata.word_dict_path)
    field_dict: LabelDict = torch.load(metadata.field_dict_path)

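    # Hidden size tracks the concatenated word (dim) and field (dim // 2) embeddings,
    # doubled when the encoder is bidirectional.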
    args.hidden = bpemb_en.dim + bpemb_en.dim // 2
    if not args.disable_bidirectional:
        args.hidden *= 2

    # Load embedding and/or vocab
    # word_dict = BpeWordDict.get(vocab=bpemb_en.words)
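    # The start-of-sequence token depends on the output modality: BOS for text, SOT for tables.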
    w_sos_id = {'text': word_dict.bos_index, 'table': word_dict.sot_index}

    word_embeddings = nn.Embedding(len(word_dict), bpemb_en.dim, padding_idx=word_dict.pad_index)
    nn.init.normal_(word_embeddings.weight, 0, 0.1)
    nn.init.constant_(word_embeddings.weight[word_dict.pad_index], 0)
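    # Copy the pre-trained BPEmb vectors into the first bpemb_en.vs rows; rows beyond the
    # BPE vocabulary keep their random initialization.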
    with torch.no_grad():
        word_embeddings.weight[:bpemb_en.vs, :] = torch.from_numpy(bpemb_en.vectors)
    word_embedding_size = word_embeddings.weight.data.size()[1]
    word_embeddings = word_embeddings.to(device)
    word_embeddings.weight.requires_grad = False
    logger.debug('w_embeddings is running on cuda: %d', next(word_embeddings.parameters()).is_cuda)

    # field_dict: LabelDict = torch.load('./data/processed_data/train/field.dict')
    field_embeddings = nn.Embedding(len(field_dict), bpemb_en.dim // 2, padding_idx=field_dict.pad_index)
    nn.init.normal_(field_embeddings.weight, 0, 0.1)
    nn.init.constant_(field_embeddings.weight[field_dict.pad_index], 0)
    field_embedding_size = field_embeddings.weight.data.size()[1]
    field_embeddings = field_embeddings.to(device)
    field_embeddings.weight.requires_grad = True
    logger.debug('f_embeddings is running on cuda: %d', next(field_embeddings.parameters()).is_cuda)

    src_encoder_word_embeddings = word_embeddings
    trg_encoder_word_embeddings = word_embeddings
    src_encoder_field_embeddings = field_embeddings
    trg_encoder_field_embeddings = field_embeddings

    src_decoder_word_embeddings = word_embeddings
    trg_decoder_word_embeddings = word_embeddings
    src_decoder_field_embeddings = field_embeddings
    trg_decoder_field_embeddings = field_embeddings

    src_generator = LinearGenerator(args.hidden, len(word_dict), len(field_dict)).to(device)

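    # With a shared decoder a single output generator serves all four directions,
    # so its optimizer is registered in every direction list.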
    if args.shared_dec:
        trg_generator = src_generator
        add_optimizer(src_generator, (src2src_optimizers, trg2src_optimizers, trg2trg_optimizers, src2trg_optimizers))
    else:
        trg_generator = LinearGenerator(args.hidden, len(word_dict), len(field_dict)).to(device)
        add_optimizer(src_generator, (src2src_optimizers, trg2src_optimizers))
        add_optimizer(trg_generator, (trg2trg_optimizers, src2trg_optimizers))

    logger.debug('src generator is running on cuda: %d', next(src_generator.parameters()).is_cuda)
    logger.debug('trg generator is running on cuda: %d', next(trg_generator.parameters()).is_cuda)

    # Build encoder
    src_enc = RNNEncoder(word_embedding_size=word_embedding_size, field_embedding_size=field_embedding_size,
                         hidden_size=args.hidden, bidirectional=not args.disable_bidirectional,
                         layers=args.layers, dropout=args.dropout).to(device)

    if args.shared_enc:
        trg_enc = src_enc
        add_optimizer(src_enc, (src2src_optimizers, src2trg_optimizers, trg2trg_optimizers, trg2src_optimizers))
    else:
        trg_enc = RNNEncoder(word_embedding_size=word_embedding_size, field_embedding_size=field_embedding_size,
                             hidden_size=args.hidden, bidirectional=not args.disable_bidirectional,
                             layers=args.layers, dropout=args.dropout).to(device)
        add_optimizer(src_enc, (src2src_optimizers, src2trg_optimizers))
        add_optimizer(trg_enc, (trg2trg_optimizers, trg2src_optimizers))

    logger.debug('encoder model is running on cuda: %d', next(src_enc.parameters()).is_cuda)

    # Build decoders
    src_dec = RNNAttentionDecoder(word_embedding_size=word_embedding_size,
                                  field_embedding_size=field_embedding_size, hidden_size=args.hidden,
                                  layers=args.layers, dropout=args.dropout, input_feeding=False).to(device)

    if args.shared_dec:
        trg_dec = src_dec
        add_optimizer(src_dec, (src2src_optimizers, trg2src_optimizers, trg2trg_optimizers, src2trg_optimizers))
    else:
        trg_dec = RNNAttentionDecoder(word_embedding_size=word_embedding_size,
                                      field_embedding_size=field_embedding_size, hidden_size=args.hidden,
                                      layers=args.layers, dropout=args.dropout, input_feeding=False).to(device)
        add_optimizer(src_dec, (src2src_optimizers, trg2src_optimizers))
        add_optimizer(trg_dec, (trg2trg_optimizers, src2trg_optimizers))

    logger.debug('decoder model is running on cuda: %d', next(src_dec.parameters()).is_cuda)
    logger.debug('attention model is running on cuda: %d', next(src_dec.attention.parameters()).is_cuda)

    discriminator = None

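    # The adversarial discriminator on encoder states is only built for unsupervised (mono) training.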
    if (args.corpus_mode == 'mono') and not args.disable_discriminator:
        discriminator = Discriminator(args.hidden, args.dis_hidden, args.n_dis_layers, args.dropout)
        discriminator = discriminator.to(device)

    # Build translators
    src2src_translator = Translator("src2src",
                                    encoder_word_embeddings=src_encoder_word_embeddings,
                                    decoder_word_embeddings=src_decoder_word_embeddings,
                                    encoder_field_embeddings=src_encoder_field_embeddings,
                                    decoder_field_embeddings=src_decoder_field_embeddings,
                                    generator=src_generator,
                                    src_word_dict=word_dict, trg_word_dict=word_dict,
                                    src_field_dict=field_dict, trg_field_dict=field_dict,
                                    src_type=src_type, trg_type=src_type, w_sos_id=w_sos_id[src_type],
                                    bpemb_en=bpemb_en, encoder=src_enc, decoder=src_dec, discriminator=discriminator,
                                    denoising=args.denoising_mode, device=device,
                                    max_word_shuffle_distance=args.word_shuffle,
                                    word_dropout_prob=args.word_dropout,
                                    word_blanking_prob=args.word_blank)
    src2trg_translator = Translator("src2trg",
                                    encoder_word_embeddings=src_encoder_word_embeddings,
                                    decoder_word_embeddings=trg_decoder_word_embeddings,
                                    encoder_field_embeddings=src_encoder_field_embeddings,
                                    decoder_field_embeddings=trg_decoder_field_embeddings,
                                    generator=trg_generator,
                                    src_word_dict=word_dict, trg_word_dict=word_dict,
                                    src_field_dict=field_dict, trg_field_dict=field_dict,
                                    src_type=src_type, trg_type=trg_type, w_sos_id=w_sos_id[trg_type],
                                    bpemb_en=bpemb_en, encoder=src_enc, decoder=trg_dec, discriminator=discriminator,
                                    denoising=0, device=device,
                                    max_word_shuffle_distance=args.word_shuffle,
                                    word_dropout_prob=args.word_dropout,
                                    word_blanking_prob=args.word_blank)
    trg2trg_translator = Translator("trg2trg",
                                    encoder_word_embeddings=trg_encoder_word_embeddings,
                                    decoder_word_embeddings=trg_decoder_word_embeddings,
                                    encoder_field_embeddings=trg_encoder_field_embeddings,
                                    decoder_field_embeddings=trg_decoder_field_embeddings,
                                    generator=trg_generator,
                                    src_word_dict=word_dict, trg_word_dict=word_dict,
                                    src_field_dict=field_dict, trg_field_dict=field_dict,
                                    src_type=trg_type, trg_type=trg_type, w_sos_id=w_sos_id[trg_type],
                                    bpemb_en=bpemb_en, encoder=trg_enc, decoder=trg_dec, discriminator=discriminator,
                                    denoising=args.denoising_mode, device=device,
                                    max_word_shuffle_distance=args.word_shuffle,
                                    word_dropout_prob=args.word_dropout,
                                    word_blanking_prob=args.word_blank)
    trg2src_translator = Translator("trg2src",
                                    encoder_word_embeddings=trg_encoder_word_embeddings,
                                    decoder_word_embeddings=src_decoder_word_embeddings,
                                    encoder_field_embeddings=trg_encoder_field_embeddings,
                                    decoder_field_embeddings=src_decoder_field_embeddings,
                                    generator=src_generator,
                                    src_word_dict=word_dict, trg_word_dict=word_dict,
                                    src_field_dict=field_dict, trg_field_dict=field_dict,
                                    src_type=trg_type, trg_type=src_type, w_sos_id=w_sos_id[src_type],
                                    bpemb_en=bpemb_en, encoder=trg_enc, decoder=src_dec, discriminator=discriminator,
                                    denoising=0, device=device,
                                    max_word_shuffle_distance=args.word_shuffle,
                                    word_dropout_prob=args.word_dropout,
                                    word_blanking_prob=args.word_blank)

    # Build trainers
    trainers = []
    iters_per_epoch = int(np.ceil(corpus_size / args.batch))
    print("CORPUS_SIZE = %d | BATCH_SIZE = %d | ITERS_PER_EPOCH = %d" % (corpus_size, args.batch, iters_per_epoch))

    if args.corpus_mode == 'mono':
        f_content = open(src_corpus_path + '.content', encoding=args.encoding, errors='surrogateescape')
        f_labels = open(src_corpus_path + '.labels', encoding=args.encoding, errors='surrogateescape')
        src_corpus_path = data.CorpusReader(f_content, f_labels, max_sentence_length=args.max_sentence_length,
                                            cache_size=args.cache)
        f_content = open(trg_corpus_path + '.content', encoding=args.encoding, errors='surrogateescape')
        f_labels = open(trg_corpus_path + '.labels', encoding=args.encoding, errors='surrogateescape')
        trg_corpus_path = data.CorpusReader(f_content, f_labels, max_sentence_length=args.max_sentence_length,
                                            cache_size=args.cache)

        if not args.disable_discriminator:
            disc_trainer = DiscTrainer(device, src_corpus_path, trg_corpus_path, src_enc, trg_enc, src_encoder_word_embeddings,
                                       src_encoder_field_embeddings, word_dict, field_dict, discriminator,
                                       args.learning_rate, batch_size=args.batch)
            trainers.append(disc_trainer)

        src2src_trainer = Trainer(translator=src2src_translator, optimizers=src2src_optimizers, corpus=src_corpus_path,
                                  batch_size=args.batch, iters_per_epoch=iters_per_epoch)
        trainers.append(src2src_trainer)
        if not args.disable_backtranslation:
            trgback2src_trainer = Trainer(translator=trg2src_translator, optimizers=trg2src_optimizers,
                                          corpus=data.BacktranslatorCorpusReader(corpus=src_corpus_path,
                                                                                 translator=src2trg_translator),
                                          batch_size=args.batch, iters_per_epoch=iters_per_epoch)
            trainers.append(trgback2src_trainer)

        trg2trg_trainer = Trainer(translator=trg2trg_translator, optimizers=trg2trg_optimizers, corpus=trg_corpus_path,
                                  batch_size=args.batch, iters_per_epoch=iters_per_epoch)
        trainers.append(trg2trg_trainer)
        if not args.disable_backtranslation:
            srcback2trg_trainer = Trainer(translator=src2trg_translator, optimizers=src2trg_optimizers,
                                          corpus=data.BacktranslatorCorpusReader(corpus=trg_corpus_path,
                                                                                 translator=trg2src_translator),
                                          batch_size=args.batch, iters_per_epoch=iters_per_epoch)
            trainers.append(srcback2trg_trainer)
    elif args.corpus_mode == 'para':
        fsrc_content = open(src_corpus_path + '.content', encoding=args.encoding, errors='surrogateescape')
        fsrc_labels = open(src_corpus_path + '.labels', encoding=args.encoding, errors='surrogateescape')
        ftrg_content = open(trg_corpus_path + '.content', encoding=args.encoding, errors='surrogateescape')
        ftrg_labels = open(trg_corpus_path + '.labels', encoding=args.encoding, errors='surrogateescape')
        corpus = data.CorpusReader(fsrc_content, fsrc_labels, trg_word_file=ftrg_content, trg_field_file=ftrg_labels,
                                   max_sentence_length=args.max_sentence_length,
                                   cache_size=args.cache)
        src2trg_trainer = Trainer(translator=src2trg_translator, optimizers=src2trg_optimizers, corpus=corpus,
                                  batch_size=args.batch, iters_per_epoch=iters_per_epoch)
        trainers.append(src2trg_trainer)

    # Build validators
    if args.src_valid_corpus != '' and args.trg_valid_corpus != '':
        with ExitStack() as stack:
            src_content_vfile = stack.enter_context(open(args.src_valid_corpus + '.content', encoding=args.encoding,
                                                         errors='surrogateescape'))
            src_labels_vfile = stack.enter_context(open(args.src_valid_corpus + '.labels', encoding=args.encoding,
                                                        errors='surrogateescape'))
            trg_content_vfile = stack.enter_context(open(args.trg_valid_corpus + '.content', encoding=args.encoding,
                                                         errors='surrogateescape'))
            trg_labels_vfile = stack.enter_context(open(args.trg_valid_corpus + '.labels', encoding=args.encoding,
                                                        errors='surrogateescape'))

            src_content = src_content_vfile.readlines()
            src_labels = src_labels_vfile.readlines()
            trg_content = trg_content_vfile.readlines()
            trg_labels = trg_labels_vfile.readlines()
            assert len(src_content) == len(trg_content) == len(src_labels) == len(trg_labels), \
                "Validation sizes do not match {} {} {} {}".format(len(src_content), len(trg_content),
                                                                   len(src_labels), len(trg_labels))

            src_content = [list(map(int, line.strip().split())) for line in src_content]
            src_labels = [list(map(int, line.strip().split())) for line in src_labels]
            trg_content = [list(map(int, line.strip().split())) for line in trg_content]
            trg_labels = [list(map(int, line.strip().split())) for line in trg_labels]

            cache = []
            for src_sent, src_label, trg_sent, trg_label in zip(src_content, src_labels, trg_content, trg_labels):
                if 0 < len(src_sent) <= args.max_sentence_length and 0 < len(trg_sent) <= args.max_sentence_length:
                    cache.append((src_sent, src_label, trg_sent, trg_label))

            src_content, src_labels, trg_content, trg_labels = zip(*cache)

            src2trg_validator = Validator(src2trg_translator, src_content, trg_content, src_labels, trg_labels)

            if args.corpus_mode == 'mono':
                src2src_validator = Validator(src2src_translator, src_content, src_content, src_labels, src_labels)

                trg2src_validator = Validator(trg2src_translator, trg_content, src_content, trg_labels, src_labels)

                trg2trg_validator = Validator(trg2trg_translator, trg_content, trg_content, trg_labels, trg_labels)

            del src_content
            del src_labels
            del trg_content
            del trg_labels
    else:
        src2src_validator = None
        src2trg_validator = None
        trg2src_validator = None
        trg2trg_validator = None

    # Build loggers
    loggers = []
    semi_loggers = []

    if args.corpus_mode == 'mono':
        if not args.disable_backtranslation:
            loggers.append(Logger('Source to target (backtranslation)', srcback2trg_trainer, src2trg_validator,
                                  None, args.encoding, short_name='src2trg_bt', train_writer=train_writer,
                                  valid_writer=valid_writer))
            loggers.append(Logger('Target to source (backtranslation)', trgback2src_trainer, trg2src_validator,
                                  None, args.encoding, short_name='trg2src_bt', train_writer=train_writer,
                                  valid_writer=valid_writer))

        loggers.append(Logger('Source to source', src2src_trainer, src2src_validator, None, args.encoding,
                              short_name='src2src', train_writer=train_writer, valid_writer=valid_writer))
        loggers.append(Logger('Target to target', trg2trg_trainer, trg2trg_validator, None, args.encoding,
                              short_name='trg2trg', train_writer=train_writer, valid_writer=valid_writer))
    elif args.corpus_mode == 'para':
        loggers.append(Logger('Source to target', src2trg_trainer, src2trg_validator, None, args.encoding,
                              short_name='src2trg_para', train_writer=train_writer, valid_writer=valid_writer))

    # Method to save models
    def save_models(name):
        # torch.save(src2src_translator, '{0}.{1}.src2src.pth'.format(args.save, name))
        # torch.save(trg2trg_translator, '{0}.{1}.trg2trg.pth'.format(args.save, name))
        torch.save(src2trg_translator, '{0}.{1}.src2trg.pth'.format(args.save, name))
        if args.corpus_mode == 'mono':
            torch.save(trg2src_translator, '{0}.{1}.trg2src.pth'.format(args.save, name))

    ref_string_path = args.trg_valid_corpus + '.str.content'

    if args.trg_valid_corpus != '' and not os.path.isfile(ref_string_path):
        print("Creating ref file... [%s]" % ref_string_path)

        with ExitStack() as stack:

            fref_content = stack.enter_context(
                open(args.trg_valid_corpus + '.content', encoding=args.encoding, errors='surrogateescape'))
            fref_str_content = stack.enter_context(
                open(ref_string_path, mode='w', encoding=args.encoding, errors='surrogateescape'))

            for line in fref_content:
                ref_ids = [int(idstr) for idstr in line.strip().split()]
                ref_str = bpemb_en.decode_ids(ref_ids)
                fref_str_content.write(ref_str + '\n')

        print("Ref file created!")

    # Training
    for curr_iter in range(1, args.iterations + 1):
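        # One optimization step per trainer per iteration: denoising autoencoders,
        # backtranslation, and (optionally) the discriminator.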
        print_dbg = (0 != args.dbg_print_interval) and (curr_iter % args.dbg_print_interval == 0)

        for trainer in trainers:
            trainer.step(print_dbg=print_dbg, include_field_loss=not args.disable_field_loss)

        if args.save is not None and args.save_interval > 0 and curr_iter % args.save_interval == 0:
            save_models('it{0}'.format(curr_iter))

        if curr_iter % args.log_interval == 0:
            print()
            print('[{0}] TRAIN-STEP {1} x {2}'.format(args.save, curr_iter, args.batch))
            for logger in loggers:
                logger.log(curr_iter)

        if curr_iter % iters_per_epoch == 0:
            save_models('it{0}'.format(curr_iter))
            print()
            print('[{0}] VALID-STEP {1}'.format(args.save, curr_iter))
            for logger in loggers:
                if logger.validator is not None:
                    logger.validate(curr_iter)

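            # Score BLEU on the freshly saved src2trg checkpoint in a background thread;
            # join immediately when it shares the training device.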
            if args.src_valid_corpus != '' and args.trg_valid_corpus != '':
                model = '{0}.{1}.src2trg.pth'.format(args.save, 'it{0}'.format(curr_iter))

                bleu_thread = threading.Thread(target=calc_bleu,
                                               args=(model, args.save, args.src_valid_corpus,
                                                     args.trg_valid_corpus + '.str.result', ref_string_path,
                                                     bpemb_en, curr_iter, args.bleu_device, valid_writer))
                bleu_thread.start()
                if args.cuda == args.bleu_device or args.bleu_device == 'cpu':
                    bleu_thread.join()

    save_models('final')
    train_writer.close()
    valid_writer.close()
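
A minimal sketch of exposing main_train() as a script entry point; the command in the comment is illustrative and assumes the file is saved as train.py:

if __name__ == '__main__':
    # Hypothetical invocation (paths, prefix and device are illustrative):
    #   python train.py --metadata_path ./data/processed_data/train/metadata.bin \
    #       --corpus_mode mono --save table2text --cuda cuda:0
    main_train()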
Example n. 12
def train(args):
    config = SpeechDataset.default_config()
    config["wanted_words"] = "yes no marvin left right".split()
    config["data_folder"] = "data"
    config["cache_size"] = 32768
    config["batch_size"] = 64
    train_set, dev_set, test_set = SpeechDataset.splits(config)

    train_loader = data.DataLoader(train_set,
                                   batch_size=config["batch_size"],
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=train_set.collate_fn)
    dev_loader = data.DataLoader(dev_set,
                                 batch_size=min(len(dev_set), 16),
                                 shuffle=True,
                                 collate_fn=dev_set.collate_fn)
    test_loader = data.DataLoader(test_set,
                                  batch_size=min(len(test_set), 16),
                                  shuffle=True,
                                  collate_fn=test_set.collate_fn)

    gen = Generator()
    disc = Discriminator()
    optim_gen = torch.optim.Adam(lr=1e-3,
                                 params=gen.parameters(),
                                 weight_decay=1e-3)
    optim_disc = torch.optim.Adam(lr=1e-3,
                                  params=disc.parameters(),
                                  weight_decay=1e-3)

    start_epoch = 0

    if args.weights_path is not None:
        weights_dict = torch.load(args.weights_path)
        start_epoch = weights_dict['epoch'] + 1
        gen.load_state_dict(weights_dict['gen_state_dict'])
        disc.load_state_dict(weights_dict['disc_state_dict'])
        optim_gen.load_state_dict(weights_dict['optim_gen_state_dict'])
        optim_disc.load_state_dict(weights_dict['optim_disc_state_dict'])
    else:
        gen_state_dict = gen.state_dict()
        for key in gen_state_dict.keys():
            if gen_state_dict[key].dim() >= 2:
                torch.nn.init.xavier_normal_(gen_state_dict[key], 1e-2)
            else:
                if key[-4:] == 'bias':
                    torch.nn.init.zeros_(gen_state_dict[key])
                else:
                    torch.nn.init.ones_(gen_state_dict[key])
        gen.load_state_dict(gen_state_dict)

    model_config = dict(dropout_prob=0.5,
                        height=128,
                        width=40,
                        n_labels=7,
                        n_feature_maps1=64,
                        n_feature_maps2=64,
                        conv1_size=(20, 8),
                        conv2_size=(10, 4),
                        conv1_pool=(2, 2),
                        conv1_stride=(1, 1),
                        conv2_stride=(1, 1),
                        conv2_pool=(1, 1),
                        tf_variant=True)
    kws_model = SpeechModel(model_config)
    kws_model.load(args.kws_model_path)

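    # Precomputed DCT filterbank for projecting log-spectrograms to the MFCC features
    # the keyword-spotting (KWS) model expects.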
    dct_filters = torch.from_numpy(
        np.load('dct_filter.npy').astype(np.float32))
    if torch.cuda.is_available():
        dct_filters = dct_filters.cuda()

    num_epochs = args.num_epochs
    c = args.c
    alpha = args.alpha
    beta = args.beta
    mean = torch.load('spectrogram_mean.pkl')
    std = torch.load('spectrogram_std.pkl')

    for epoch in range(start_epoch, num_epochs):
        gen.train()
        disc.train()

        for step, sample in enumerate(train_loader):
            inp, labels = sample
            if torch.cuda.is_available():
                inp = inp.cuda()
                labels = labels.cuda()

            gen_noise = gen(((inp - mean) / std).permute(0, 2, 1))
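            # Zero the generated noise beyond the first 101 time frames; the remaining 27 of 128 frames are padding.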
            gen_noise[:, :, 101:] = torch.zeros(gen_noise.shape[0], 128, 27)
            noise_score = disc((((inp - mean) / std).permute(0, 2, 1) +
                                gen_noise).unsqueeze(1))
            inp_score = disc(((inp - mean) / std).permute(0, 2,
                                                          1).unsqueeze(1))

            kws_inp = inp + gen_noise.permute(0, 2, 1) * std
            kws_inp = kws_inp[:, :101, :].reshape(-1, 128, 101)
            kws_inp_clone = kws_inp.clone()
            kws_inp_clone[kws_inp_clone > 0] = torch.log(kws_inp[kws_inp > 0])
            mfcc_feat = torch.matmul(dct_filters,
                                     kws_inp_clone).permute(0, 2, 1)
            mfcc_feat = F.pad(mfcc_feat, (0, 0, 0, 128 - mfcc_feat.shape[1]))

            kws_out = nn.Softmax(dim=1)(kws_model(mfcc_feat))

            # Optimise Generator
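            # Total objective: fool the discriminator (loss_gen), bound the noise energy by c via a
            # hinge penalty (loss_hinge), and lower the KWS model's confidence in the true label (loss_adv).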
            optim_gen.zero_grad()
            loss_gen = -noise_score.log().mean()
            loss_adv = kws_out.gather(1, labels.view(-1, 1)).mean()
            loss_hinge = nn.ReLU()((gen_noise * std).norm(p=2, dim=(1, 2)) -
                                   c).mean()
            loss_gen_total = loss_gen + alpha * loss_hinge + beta * loss_adv
            loss_gen_total.backward(retain_graph=True)
            optim_gen.step()
            print("Epoch : ", epoch, " , Step : ", step)
            print("Generator Loss", loss_gen)
            print("Loss Adv", loss_adv)
            print("Loss Hinge", loss_hinge)

            # Optimise Discriminator
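            # Standard GAN discriminator objective: maximize log D(clean) + log(1 - D(noisy)).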
            optim_disc.zero_grad()
            loss_disc = -(inp_score.log().mean() +
                          (1 - noise_score).log().mean())
            loss_disc.backward()
            optim_disc.step()
            print("Discriminator Loss", loss_disc)

            print("======================================")

        weights_dict = {}
        weights_dict['epoch'] = epoch
        weights_dict['gen_state_dict'] = gen.state_dict()
        weights_dict['disc_state_dict'] = disc.state_dict()
        weights_dict['optim_gen_state_dict'] = optim_gen.state_dict()
        weights_dict['optim_disc_state_dict'] = optim_disc.state_dict()
        weights_dict_path = args.save_folder_path + '/epoch{}.weights'.format(
            epoch)
        torch.save(weights_dict, weights_dict_path)
Example n. 13
def main():
    args = parse_args()
    device = torch.device("cuda")

    generator = Generator.from_file(args.generator_path).to(device)
    if args.freeze:
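        # Keep only the shared embeddings and the last decoder block trainable.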
        for name, param in generator.named_parameters():
            if ("shared" not in name) and ("decoder.block.5" not in name):
                param.requires_grad = False
    generator.eval()

    discriminator = Discriminator(tokenizer=generator.tokenizer).to(device)
    if args.freeze:
        for name, param in discriminator.named_parameters():
            if ("shared" not in name) and ("decoder.block.5" not in name):
                param.requires_grad = False

    train_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "train/dialogues_train.txt"),
        tokenizer=generator.tokenizer,
    )
    valid_dataset = DailyDialogueDataset(
        path_join(args.dataset_path, "validation/dialogues_validation.txt"),
        tokenizer=generator.tokenizer,
    )

    print(f"train/valid examples: {len(train_dataset)}/{len(valid_dataset)}")

    optimizer = AdamW(discriminator.parameters(), lr=args.lr)

    best_loss = float("inf")

    for epoch in tqdm(range(args.num_epochs)):
        train_loss, valid_loss = [], []
        rewards_real, rewards_fake, accuracy = [], [], []
        discriminator.train()
        for ind in np.random.permutation(len(train_dataset)):
            optimizer.zero_grad()
            context, real_reply = train_dataset.sample_dialouge(ind)
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)

            if args.partial:
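                # Truncate replies to random prefixes so the discriminator also learns to score partial generations.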
                split_real = random.randint(1, real_reply.size(1))
                real_reply = real_reply[:, :split_real]
                split_fake = random.randint(1, fake_reply.size(1) - 1)
                fake_reply = fake_reply[:, :split_fake]

            loss, _, _ = discriminator.get_loss(context, real_reply, fake_reply)
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        discriminator.eval()
        for ind in range(len(valid_dataset)):
            context, real_reply = valid_dataset[ind]
            context, real_reply = (
                context.to(device),
                real_reply.to(device),
            )
            fake_reply = generator.generate(context, do_sample=True)

            if args.partial:
                split_real = random.randint(1, real_reply.size(1))
                real_reply = real_reply[:, :split_real]
                split_fake = random.randint(1, fake_reply.size(1) - 1)
                fake_reply = fake_reply[:, :split_fake]

            with torch.no_grad():
                loss, reward_real, reward_fake = discriminator.get_loss(
                    context, real_reply, fake_reply
                )
            valid_loss.append(loss.item())
            rewards_real.append(reward_real)
            rewards_fake.append(reward_fake)
            accuracy.extend([reward_real > 0.5, reward_fake < 0.5])

        train_loss, valid_loss = np.mean(train_loss), np.mean(valid_loss)
        print(
            f"Epoch {epoch + 1}, Train Loss: {train_loss:.2f}, Valid Loss: {valid_loss:.2f}, Reward real: {np.mean(rewards_real):.2f}, Reward fake: {np.mean(rewards_fake):.2f}, Accuracy: {np.mean(accuracy):.2f}"
        )
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(discriminator.state_dict(), args.output_path)
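
For reference, a minimal sketch of the parse_args() helper this entry point assumes; the attribute names match what main() reads above, while the defaults and help strings are illustrative:

import argparse


def parse_args():
    # Hypothetical parser; main() above reads exactly these attributes.
    parser = argparse.ArgumentParser(description='Train a dialogue discriminator against a fixed generator')
    parser.add_argument('--generator_path', required=True, help='checkpoint consumed by Generator.from_file')
    parser.add_argument('--dataset_path', required=True, help='DailyDialog root containing train/ and validation/')
    parser.add_argument('--output_path', default='discriminator.pt', help='where the best state dict is saved')
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--freeze', action='store_true', help='train only shared embeddings and the last decoder block')
    parser.add_argument('--partial', action='store_true', help='also score random reply prefixes')
    return parser.parse_args()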