Example 1
# Assumed imports; project-level helpers referenced below (FastText, Discriminator,
# WordSampler, optimizers, GPU, _data_queue, _wasserstein_distance, _orthogonalize)
# are defined elsewhere in the module and are not shown here.
import torch
import torch.nn as nn
from torch import autograd

class Trainer:
    def __init__(self, corpus_data_0, corpus_data_1, *, params, n_samples=10000000):
        self.fast_text = [FastText(corpus_data_0.model).to(GPU), FastText(corpus_data_1.model).to(GPU)]
        self.discriminator = Discriminator(params.emb_dim, n_layers=params.d_n_layers, n_units=params.d_n_units,
                                           drop_prob=params.d_drop_prob, drop_prob_input=params.d_drop_prob_input,
                                           leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU)
        self.mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False)
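        # Initialize the mapping to the identity matrix.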
        self.mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim)))
        self.mapping = self.mapping.to(GPU)
        self.ft_optimizer, self.ft_scheduler = [], []
        for id in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(self.fast_text[id].parameters(),
                                                            lr=params.ft_lr, mode="max", factor=params.ft_lr_decay,
                                                            patience=params.ft_lr_patience)
            self.ft_optimizer.append(optimizer)
            self.ft_scheduler.append(scheduler)
        self.a_optimizer, self.a_scheduler = [], []
        for id in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(
                [{"params": self.fast_text[id].u.parameters()}, {"params": self.fast_text[id].v.parameters()}],
                lr=params.a_lr, mode="max", factor=params.a_lr_decay, patience=params.a_lr_patience)
            self.a_optimizer.append(optimizer)
            self.a_scheduler.append(scheduler)
        if params.d_optimizer == "SGD":
            self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(self.discriminator.parameters(),
                                                                          lr=params.d_lr, mode="max", wd=params.d_wd)

        elif params.d_optimizer == "RMSProp":
            self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(self.discriminator.parameters(),
                                                                               params.n_steps,
                                                                               lr=params.d_lr, wd=params.d_wd)
        else:
            raise ValueError(f"Optimizer {params.d_optimizer} not found.")
        if params.m_optimizer == "SGD":
            self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(self.mapping.parameters(),
                                                                          lr=params.m_lr, mode="max", wd=params.m_wd,
                                                                          factor=params.m_lr_decay,
                                                                          patience=params.m_lr_patience)
        elif params.m_optimizer == "RMSProp":
            self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(self.mapping.parameters(),
                                                                               params.n_steps,
                                                                               lr=params.m_lr, wd=params.m_wd)
        else:
            raise ValueError(f"Optimizer {params.m_optimizer} not found.")
        self.m_beta = params.m_beta
        self.smooth = params.smooth
        self.wgan = params.wgan
        self.d_clip_mode = params.d_clip_mode
        if params.wgan:
            self.loss_fn = _wasserstein_distance
        else:
            self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        self.corpus_data_queue = [
            _data_queue(corpus_data_0, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences,
                        batch_size=params.ft_bs),
            _data_queue(corpus_data_1, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences,
                        batch_size=params.ft_bs)
        ]
        self.sampler = [
            WordSampler(corpus_data_0.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top),
            WordSampler(corpus_data_1.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top)]
        self.d_bs = params.d_bs
        self.dic_0, self.dic_1 = corpus_data_0.dic, corpus_data_1.dic
        self.d_gp = params.d_gp

    def fast_text_step(self):
        losses = []
        for id in [0, 1]:
            self.ft_optimizer[id].zero_grad()
            u_b, v_b = next(self.corpus_data_queue[id])
            s = self.fast_text[id](u_b, v_b)
            loss = FastText.loss_fn(s)
            loss.backward()
            self.ft_optimizer[id].step()
            losses.append(loss.item())
        return losses[0], losses[1]

    def get_adv_batch(self, *, reverse, fix_embedding=False, gp=False):
        batch = [[self.sampler[id].sample() for _ in range(self.d_bs)]
                 for id in [0, 1]]
        batch = [self.fast_text[id].model.get_bag(batch[id], self.fast_text[id].u.weight.device)
                 for id in [0, 1]]
        if fix_embedding:
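            # Compute the embeddings without tracking gradients so they stay fixed during this update.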
            with torch.no_grad():
                x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1) for id in [0, 1]]
        else:
            x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1) for id in [0, 1]]
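        # Smoothed binary labels: one half near 0, the other near 1; reverse controls which half is labelled as real.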
        y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[: self.d_bs] = 1 - y[: self.d_bs]
        else:
            y[self.d_bs:] = 1 - y[self.d_bs:]
        x[0] = self.mapping(x[0])
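        # For the gradient penalty, also return random interpolates between mapped language-0 and language-1 embeddings.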
        if gp:
            t = torch.FloatTensor(self.d_bs, 1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0])
            z = x[0] * t + x[1] * (1.0 - t)
            x = torch.cat(x, 0)
            return x, y, z
        else:
            x = torch.cat(x, 0)
            return x, y

    def adversarial_step(self, fix_embedding=False):
        for id in [0, 1]:
            self.a_optimizer[id].zero_grad()
        self.m_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True, fix_embedding=fix_embedding)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        for id in [0, 1]:
            self.a_optimizer[id].step()
        self.m_optimizer.step()
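        # Pull the mapping back towards an orthogonal matrix (strength controlled by m_beta).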
        _orthogonalize(self.mapping, self.m_beta)
        return loss.item()

    def discriminator_step(self):
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            if self.d_gp > 0:
                x, y, z = self.get_adv_batch(reverse=False, gp=True)
            else:
                x, y = self.get_adv_batch(reverse=False)
                z = None
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
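        # Gradient penalty: push the norm of the discriminator's gradient at the interpolates towards 1.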
        if self.d_gp > 0:
            z.requires_grad_()
            z_out = self.discriminator(z)
            g = autograd.grad(z_out, z, grad_outputs=torch.ones_like(z_out, device=GPU),
                              retain_graph=True, create_graph=True, only_inputs=True)[0]
            gp = torch.mean((g.norm(p=2, dim=1) - 1.0) ** 2)
            loss += self.d_gp * gp
        loss.backward()
        self.d_optimizer.step()
        if self.wgan:
            self.discriminator.clip_weights(self.d_clip_mode)
        return loss.item()

    def scheduler_step(self, metric):
        for id in [0, 1]:
            self.ft_scheduler[id].step(metric)
            self.a_scheduler[id].step(metric)
        # self.d_scheduler.step(metric)
        self.m_scheduler.step(metric)
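
A minimal sketch of how this Trainer might be driven; the loop structure and the hypothetical params.d_steps, params.eval_every, and evaluate callback are assumptions, not part of the class above.

def train(trainer, params, evaluate):
    for step in range(params.n_steps):
        # Several discriminator updates per adversarial update, as is common in GAN training.
        for _ in range(params.d_steps):           # hypothetical hyper-parameter
            trainer.discriminator_step()
        trainer.adversarial_step()
        trainer.fast_text_step()
        if (step + 1) % params.eval_every == 0:   # hypothetical evaluation interval
            metric = evaluate(trainer)            # hypothetical validation criterion to maximize
            trainer.scheduler_step(metric)
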
Example 2
# Assumed imports; project-level helpers referenced below (Discriminator, Mapping,
# WordSampler, optimizers, GPU, normalize_embeddings_np, _wasserstein_distance)
# are defined elsewhere in the module and are not shown here.
import os
import fastText
import numpy as np
import scipy.linalg
import torch
import torch.nn as nn
from torch import autograd

class Trainer:
    def __init__(self, params, *, n_samples=10000000):
        self.model = [
            fastText.load_model(
                os.path.join(params.dataDir, params.model_path_0)),
            fastText.load_model(
                os.path.join(params.dataDir, params.model_path_1))
        ]
        self.dic = [
            list(zip(*self.model[id].get_words(include_freq=True)))
            for id in [0, 1]
        ]
        x = [
            np.empty((params.vocab_size, params.emb_dim), dtype=np.float64)
            for _ in [0, 1]
        ]
        for id in [0, 1]:
            for i in range(params.vocab_size):
                x[id][i, :] = self.model[id].get_word_vector(
                    self.dic[id][i][0])
            x[id] = normalize_embeddings_np(x[id], params.normalize_pre)
        u0, s0, _ = scipy.linalg.svd(x[0], full_matrices=False)
        u1, s1, _ = scipy.linalg.svd(x[1], full_matrices=False)
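        # Optionally align the spectra of the two embedding spaces by averaging their singular values.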
        if params.spectral_align_pre:
            s = (s0 + s1) * 0.5
            x[0] = u0 @ np.diag(s)
            x[1] = u1 @ np.diag(s)
        else:
            x[0] = u0 @ np.diag(s0)
            x[1] = u1 @ np.diag(s1)
        self.embedding = [
            nn.Embedding.from_pretrained(torch.from_numpy(x[id]).to(
                torch.float).to(GPU),
                                         freeze=True,
                                         sparse=True) for id in [0, 1]
        ]
        self.discriminator = Discriminator(
            params.emb_dim,
            n_layers=params.d_n_layers,
            n_units=params.d_n_units,
            drop_prob=params.d_drop_prob,
            drop_prob_input=params.d_drop_prob_input,
            leaky=params.d_leaky,
            batch_norm=params.d_bn).to(GPU)
        self.mapping = Mapping(params.emb_dim).to(GPU)
        if params.d_optimizer == "SGD":
            self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(
                self.discriminator.parameters(),
                lr=params.d_lr,
                mode="max",
                wd=params.d_wd)

        elif params.d_optimizer == "RMSProp":
            self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(
                self.discriminator.parameters(),
                params.n_steps,
                lr=params.d_lr,
                wd=params.d_wd)
        else:
            raise ValueError(f"Optimizer {params.d_optimizer} not found.")
        if params.m_optimizer == "SGD":
            self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(
                self.mapping.parameters(),
                lr=params.m_lr,
                mode="max",
                wd=params.m_wd,
                factor=params.m_lr_decay,
                patience=params.m_lr_patience)
        elif params.m_optimizer == "RMSProp":
            self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(
                self.mapping.parameters(),
                params.n_steps,
                lr=params.m_lr,
                wd=params.m_wd)
        else:
            raise ValueError(f"Optimizer {params.m_optimizer} not found.")
        self.m_beta = params.m_beta
        self.smooth = params.smooth
        self.wgan = params.wgan
        self.d_clip_mode = params.d_clip_mode
        if params.wgan:
            self.loss_fn = _wasserstein_distance
        else:
            self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        self.sampler = [
            WordSampler(self.dic[id],
                        n_urns=n_samples,
                        alpha=params.a_sample_factor,
                        top=params.a_sample_top) for id in [0, 1]
        ]
        self.d_bs = params.d_bs
        self.d_gp = params.d_gp

    def get_adv_batch(self, *, reverse, gp=False):
        batch = [
            torch.LongTensor(
                [self.sampler[id].sample()
                 for _ in range(self.d_bs)]).view(self.d_bs, 1).to(GPU)
            for id in [0, 1]
        ]
        with torch.no_grad():
            x = [
                self.embedding[id](batch[id]).view(self.d_bs, -1)
                for id in [0, 1]
            ]
        y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[:self.d_bs] = 1 - y[:self.d_bs]
        else:
            y[self.d_bs:] = 1 - y[self.d_bs:]
        x[0] = self.mapping(x[0])
        if gp:
            t = torch.FloatTensor(self.d_bs,
                                  1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0])
            z = x[0] * t + x[1] * (1.0 - t)
            x = torch.cat(x, 0)
            return x, y, z
        else:
            x = torch.cat(x, 0)
            return x, y

    def adversarial_step(self):
        self.m_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.m_optimizer.step()
        self.mapping.clip_weights()
        return loss.item()

    def discriminator_step(self):
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            if self.d_gp > 0:
                x, y, z = self.get_adv_batch(reverse=False, gp=True)
            else:
                x, y = self.get_adv_batch(reverse=False)
                z = None
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        if self.d_gp > 0:
            z.requires_grad_()
            z_out = self.discriminator(z)
            g = autograd.grad(z_out,
                              z,
                              grad_outputs=torch.ones_like(z_out, device=GPU),
                              retain_graph=True,
                              create_graph=True,
                              only_inputs=True)[0]
            gp = torch.mean((g.norm(p=2, dim=1) - 1.0)**2)
            loss += self.d_gp * gp
        loss.backward()
        self.d_optimizer.step()
        if self.wgan:
            self.discriminator.clip_weights(self.d_clip_mode)
        return loss.item()

    def scheduler_step(self, metric):
        self.m_scheduler.step(metric)
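
A minimal sketch of how the trained mapping from this Trainer might be used; the translate helper and the plain cosine nearest-neighbour retrieval are illustrative assumptions, not part of the class above.

def translate(trainer, word_id_0, k=5):
    # Return the k closest language-1 words for word index word_id_0 of language 0.
    with torch.no_grad():
        src = trainer.mapping(trainer.embedding[0].weight[word_id_0:word_id_0 + 1])
        tgt = trainer.embedding[1].weight
        # Cosine similarity between the mapped source vector and every target vector.
        src = src / src.norm(dim=1, keepdim=True)
        tgt = tgt / tgt.norm(dim=1, keepdim=True)
        scores = (tgt @ src.t()).squeeze(1)
        best = scores.topk(k).indices.tolist()
    return [trainer.dic[1][i][0] for i in best]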