# Third-party imports needed by this module. Project-local helpers (FastText,
# Discriminator, WordSampler, optimizers, _data_queue, _wasserstein_distance,
# _orthogonalize and the GPU device constant) are defined elsewhere in the repo.
import torch
import torch.nn as nn
from torch import autograd


class Trainer:
    """Jointly trains two monolingual fastText models, a discriminator, and a
    linear mapping that adversarially aligns the two embedding spaces (the
    mapping is applied to language 0)."""

    def __init__(self, corpus_data_0, corpus_data_1, *, params, n_samples=10000000):
        self.fast_text = [FastText(corpus_data_0.model).to(GPU),
                          FastText(corpus_data_1.model).to(GPU)]
        self.discriminator = Discriminator(params.emb_dim, n_layers=params.d_n_layers,
                                           n_units=params.d_n_units, drop_prob=params.d_drop_prob,
                                           drop_prob_input=params.d_drop_prob_input,
                                           leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU)
        # Linear mapping W, initialized to the identity matrix.
        self.mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False)
        self.mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim)))
        self.mapping = self.mapping.to(GPU)
        # Optimizers/schedulers for the full fastText models.
        self.ft_optimizer, self.ft_scheduler = [], []
        for id in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(self.fast_text[id].parameters(),
                                                            lr=params.ft_lr, mode="max",
                                                            factor=params.ft_lr_decay,
                                                            patience=params.ft_lr_patience)
            self.ft_optimizer.append(optimizer)
            self.ft_scheduler.append(scheduler)
        # Optimizers/schedulers used when the u/v embeddings are updated adversarially.
        self.a_optimizer, self.a_scheduler = [], []
        for id in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(
                [{"params": self.fast_text[id].u.parameters()},
                 {"params": self.fast_text[id].v.parameters()}],
                lr=params.a_lr, mode="max", factor=params.a_lr_decay,
                patience=params.a_lr_patience)
            self.a_optimizer.append(optimizer)
            self.a_scheduler.append(scheduler)
        if params.d_optimizer == "SGD":
            self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(
                self.discriminator.parameters(), lr=params.d_lr, mode="max", wd=params.d_wd)
        elif params.d_optimizer == "RMSProp":
            self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(
                self.discriminator.parameters(), params.n_steps, lr=params.d_lr, wd=params.d_wd)
        else:
            raise Exception(f"Optimizer {params.d_optimizer} not found.")
        if params.m_optimizer == "SGD":
            self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(
                self.mapping.parameters(), lr=params.m_lr, mode="max", wd=params.m_wd,
                factor=params.m_lr_decay, patience=params.m_lr_patience)
        elif params.m_optimizer == "RMSProp":
            self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(
                self.mapping.parameters(), params.n_steps, lr=params.m_lr, wd=params.m_wd)
        else:
            raise Exception(f"Optimizer {params.m_optimizer} not found.")
        self.m_beta = params.m_beta
        self.smooth = params.smooth
        self.wgan = params.wgan
        self.d_clip_mode = params.d_clip_mode
        if params.wgan:
            self.loss_fn = _wasserstein_distance
        else:
            self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        self.corpus_data_queue = [
            _data_queue(corpus_data_0, n_threads=(params.n_threads + 1) // 2,
                        n_sentences=params.n_sentences, batch_size=params.ft_bs),
            _data_queue(corpus_data_1, n_threads=(params.n_threads + 1) // 2,
                        n_sentences=params.n_sentences, batch_size=params.ft_bs)
        ]
        self.sampler = [
            WordSampler(corpus_data_0.dic, n_urns=n_samples, alpha=params.a_sample_factor,
                        top=params.a_sample_top),
            WordSampler(corpus_data_1.dic, n_urns=n_samples, alpha=params.a_sample_factor,
                        top=params.a_sample_top)
        ]
        self.d_bs = params.d_bs
        self.dic_0, self.dic_1 = corpus_data_0.dic, corpus_data_1.dic
        self.d_gp = params.d_gp

    def fast_text_step(self):
        """One fastText training step (one batch per language)."""
        losses = []
        for id in [0, 1]:
            self.ft_optimizer[id].zero_grad()
            u_b, v_b = next(self.corpus_data_queue[id])
            s = self.fast_text[id](u_b, v_b)
            loss = FastText.loss_fn(s)
            loss.backward()
            self.ft_optimizer[id].step()
            losses.append(loss.item())
        return losses[0], losses[1]

    def get_adv_batch(self, *, reverse, fix_embedding=False, gp=False):
        """Sample a discriminator batch: mapped language-0 embeddings, language-1
        embeddings, smoothed labels, and (optionally) interpolation points for
        the gradient penalty."""
        batch = [[self.sampler[id].sample() for _ in range(self.d_bs)] for id in [0, 1]]
        batch = [self.fast_text[id].model.get_bag(batch[id], self.fast_text[id].u.weight.device)
                 for id in [0, 1]]
        if fix_embedding:
            with torch.no_grad():
                x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1)
                     for id in [0, 1]]
        else:
            x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1)
                 for id in [0, 1]]
        # Smoothed targets in [0, smooth]; `reverse` flips which half counts as positive.
        y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[:self.d_bs] = 1 - y[:self.d_bs]
        else:
            y[self.d_bs:] = 1 - y[self.d_bs:]
        x[0] = self.mapping(x[0])
        if gp:
            # Random interpolates between mapped and target embeddings (WGAN-GP).
            t = torch.FloatTensor(self.d_bs, 1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0])
            z = x[0] * t + x[1] * (1.0 - t)
            x = torch.cat(x, 0)
            return x, y, z
        else:
            x = torch.cat(x, 0)
            return x, y

    def adversarial_step(self, fix_embedding=False):
        """Update the embeddings and the mapping so that mapped language-0 vectors
        fool the (frozen) discriminator, then re-orthogonalize the mapping."""
        for id in [0, 1]:
            self.a_optimizer[id].zero_grad()
        self.m_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True, fix_embedding=fix_embedding)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        for id in [0, 1]:
            self.a_optimizer[id].step()
        self.m_optimizer.step()
        _orthogonalize(self.mapping, self.m_beta)
        return loss.item()

    def discriminator_step(self):
        """Update the discriminator on a batch built with frozen embeddings/mapping,
        with an optional gradient penalty or WGAN weight clipping."""
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            if self.d_gp > 0:
                x, y, z = self.get_adv_batch(reverse=False, gp=True)
            else:
                x, y = self.get_adv_batch(reverse=False)
                z = None
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        if self.d_gp > 0:
            # Gradient penalty: push the gradient norm at interpolates towards 1.
            z.requires_grad_()
            z_out = self.discriminator(z)
            g = autograd.grad(z_out, z, grad_outputs=torch.ones_like(z_out, device=GPU),
                              retain_graph=True, create_graph=True, only_inputs=True)[0]
            gp = torch.mean((g.norm(p=2, dim=1) - 1.0) ** 2)
            loss += self.d_gp * gp
        loss.backward()
        self.d_optimizer.step()
        if self.wgan:
            self.discriminator.clip_weights(self.d_clip_mode)
        return loss.item()

    def scheduler_step(self, metric):
        for id in [0, 1]:
            self.ft_scheduler[id].step(metric)
            self.a_scheduler[id].step(metric)
        # self.d_scheduler.step(metric)
        self.m_scheduler.step(metric)
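# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original code). The alternation of
# discriminator and adversarial updates follows the methods above; the names
# `params.d_steps_per_iter`, `evaluate`, and `eval_every` are hypothetical
# placeholders for whatever driver script the project actually uses.
# ---------------------------------------------------------------------------
def _example_training_loop(trainer, params, evaluate, eval_every=10000):
    """Sketch: alternate fastText, discriminator, and adversarial updates."""
    for step in range(params.n_steps):
        # One monolingual fastText update per language.
        trainer.fast_text_step()
        # A few discriminator updates, then one step that fools the discriminator
        # and re-orthogonalizes the mapping.
        for _ in range(params.d_steps_per_iter):
            trainer.discriminator_step()
        trainer.adversarial_step()
        # Periodically feed a validation metric to the plateau-style schedulers.
        if (step + 1) % eval_every == 0:
            trainer.scheduler_step(evaluate(trainer))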
# A second Trainer variant (it shares the class name above, so it presumably lives
# in its own module): pre-trained fastText embeddings are loaded, normalized and
# optionally spectrum-aligned, then kept frozen while only the mapping and the
# discriminator are trained adversarially. Project-local helpers (Discriminator,
# Mapping, WordSampler, optimizers, normalize_embeddings_np, _wasserstein_distance,
# GPU) are defined elsewhere in the repo.
import os

import numpy as np
import scipy.linalg
import torch
import torch.nn as nn
from torch import autograd

import fastText


class Trainer:
    def __init__(self, params, *, n_samples=10000000):
        # Load the two pre-trained fastText models and their (word, frequency) lists.
        self.model = [
            fastText.load_model(os.path.join(params.dataDir, params.model_path_0)),
            fastText.load_model(os.path.join(params.dataDir, params.model_path_1))
        ]
        self.dic = [
            list(zip(*self.model[id].get_words(include_freq=True)))
            for id in [0, 1]
        ]
        # Build the top-vocab_size embedding matrices and normalize them.
        x = [np.empty((params.vocab_size, params.emb_dim), dtype=np.float64) for _ in [0, 1]]
        for id in [0, 1]:
            for i in range(params.vocab_size):
                x[id][i, :] = self.model[id].get_word_vector(self.dic[id][i][0])
            x[id] = normalize_embeddings_np(x[id], params.normalize_pre)
        # Rotate each space onto its left singular vectors; with spectral_align_pre,
        # both spaces additionally share the averaged spectrum (a small demo follows
        # the class).
        u0, s0, _ = scipy.linalg.svd(x[0], full_matrices=False)
        u1, s1, _ = scipy.linalg.svd(x[1], full_matrices=False)
        if params.spectral_align_pre:
            s = (s0 + s1) * 0.5
            x[0] = u0 @ np.diag(s)
            x[1] = u1 @ np.diag(s)
        else:
            x[0] = u0 @ np.diag(s0)
            x[1] = u1 @ np.diag(s1)
        self.embedding = [
            nn.Embedding.from_pretrained(torch.from_numpy(x[id]).to(torch.float).to(GPU),
                                         freeze=True, sparse=True)
            for id in [0, 1]
        ]
        self.discriminator = Discriminator(params.emb_dim, n_layers=params.d_n_layers,
                                           n_units=params.d_n_units, drop_prob=params.d_drop_prob,
                                           drop_prob_input=params.d_drop_prob_input,
                                           leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU)
        self.mapping = Mapping(params.emb_dim).to(GPU)
        if params.d_optimizer == "SGD":
            self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(
                self.discriminator.parameters(), lr=params.d_lr, mode="max", wd=params.d_wd)
        elif params.d_optimizer == "RMSProp":
            self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(
                self.discriminator.parameters(), params.n_steps, lr=params.d_lr, wd=params.d_wd)
        else:
            raise Exception(f"Optimizer {params.d_optimizer} not found.")
        if params.m_optimizer == "SGD":
            self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(
                self.mapping.parameters(), lr=params.m_lr, mode="max", wd=params.m_wd,
                factor=params.m_lr_decay, patience=params.m_lr_patience)
        elif params.m_optimizer == "RMSProp":
            self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(
                self.mapping.parameters(), params.n_steps, lr=params.m_lr, wd=params.m_wd)
        else:
            raise Exception(f"Optimizer {params.m_optimizer} not found.")
        self.m_beta = params.m_beta
        self.smooth = params.smooth
        self.wgan = params.wgan
        self.d_clip_mode = params.d_clip_mode
        if params.wgan:
            self.loss_fn = _wasserstein_distance
        else:
            self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        self.sampler = [
            WordSampler(self.dic[id], n_urns=n_samples, alpha=params.a_sample_factor,
                        top=params.a_sample_top)
            for id in [0, 1]
        ]
        self.d_bs = params.d_bs
        self.d_gp = params.d_gp

    def get_adv_batch(self, *, reverse, gp=False):
        """Sample a discriminator batch from the frozen embeddings; same label and
        gradient-penalty conventions as in the variant above."""
        batch = [
            torch.LongTensor([self.sampler[id].sample() for _ in range(self.d_bs)]
                             ).view(self.d_bs, 1).to(GPU)
            for id in [0, 1]
        ]
        with torch.no_grad():
            x = [self.embedding[id](batch[id]).view(self.d_bs, -1) for id in [0, 1]]
        y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[:self.d_bs] = 1 - y[:self.d_bs]
        else:
            y[self.d_bs:] = 1 - y[self.d_bs:]
        x[0] = self.mapping(x[0])
        if gp:
            t = torch.FloatTensor(self.d_bs, 1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0])
            z = x[0] * t + x[1] * (1.0 - t)
            x = torch.cat(x, 0)
            return x, y, z
        else:
            x = torch.cat(x, 0)
            return x, y

    def adversarial_step(self):
        """Update only the mapping to fool the (frozen) discriminator, then apply
        the mapping's weight clipping."""
        self.m_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.m_optimizer.step()
        self.mapping.clip_weights()
        return loss.item()

    def discriminator_step(self):
        """Update the discriminator, with an optional gradient penalty or WGAN
        weight clipping."""
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            if self.d_gp > 0:
                x, y, z = self.get_adv_batch(reverse=False, gp=True)
            else:
                x, y = self.get_adv_batch(reverse=False)
                z = None
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        if self.d_gp > 0:
            # Gradient penalty: push the gradient norm at interpolates towards 1.
            z.requires_grad_()
            z_out = self.discriminator(z)
            g = autograd.grad(z_out, z, grad_outputs=torch.ones_like(z_out, device=GPU),
                              retain_graph=True, create_graph=True, only_inputs=True)[0]
            gp = torch.mean((g.norm(p=2, dim=1) - 1.0) ** 2)
            loss += self.d_gp * gp
        loss.backward()
        self.d_optimizer.step()
        if self.wgan:
            self.discriminator.clip_weights(self.d_clip_mode)
        return loss.item()

    def scheduler_step(self, metric):
        self.m_scheduler.step(metric)
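# ---------------------------------------------------------------------------
# Illustrative check (not part of the original code) for the SVD preprocessing
# in Trainer.__init__ above. Replacing X = U S Vt by U @ diag(S) is only an
# orthogonal change of basis (pairwise inner products X @ X.T are unchanged),
# and with spectral_align_pre the two spaces end up with identical singular
# values. Tiny self-contained demo on random matrices:
# ---------------------------------------------------------------------------
def _demo_spectral_align(n=50, d=8, seed=0):
    import numpy as np
    import scipy.linalg
    rng = np.random.RandomState(seed)
    x0, x1 = rng.randn(n, d), rng.randn(n, d)
    u0, s0, _ = scipy.linalg.svd(x0, full_matrices=False)
    u1, s1, _ = scipy.linalg.svd(x1, full_matrices=False)
    # Without spectrum averaging, the Gram matrix (geometry) is preserved exactly.
    y0 = u0 @ np.diag(s0)
    assert np.allclose(y0 @ y0.T, x0 @ x0.T)
    # With averaging, both spaces share the same spectrum.
    s = (s0 + s1) * 0.5
    z0, z1 = u0 @ np.diag(s), u1 @ np.diag(s)
    assert np.allclose(scipy.linalg.svdvals(z0), scipy.linalg.svdvals(z1))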