def critic_loss_on_batch(self, batch):
    batch.domain_x = cudable(batch.domain_x)
    batch.domain_y = cudable(batch.domain_y)

    x_hid = self.encoder(batch.domain_x)
    y_hid = self.encoder(batch.domain_y)
    x2y_hid = self.gen_x2y(x_hid)
    y2x_hid = self.gen_y2x(y_hid)

    # Critic loss
    critic_x_preds_x, critic_x_preds_y2x = self.critic_x(x_hid), self.critic_x(y2x_hid)
    critic_y_preds_y, critic_y_preds_x2y = self.critic_y(y_hid), self.critic_y(x2y_hid)
    critic_x_loss = self.critic_criterion(critic_x_preds_x, critic_x_preds_y2x)
    critic_y_loss = self.critic_criterion(critic_y_preds_y, critic_y_preds_x2y)

    # Gradient penalty (WGAN-GP)
    critic_x_gp = wgan_gp(self.critic_x, x_hid, y2x_hid)
    critic_y_gp = wgan_gp(self.critic_y, y_hid, x2y_hid)

    critic_x_total_loss = critic_x_loss + self.config.hp.gp_lambda * critic_x_gp
    critic_y_total_loss = critic_y_loss + self.config.hp.gp_lambda * critic_y_gp
    critics_total_loss = (critic_x_total_loss + critic_y_total_loss) / 2

    losses_info = {
        'critic_loss/domain_x': critic_x_loss.item(),
        'critic_loss/domain_y': critic_y_loss.item(),
        'critic_loss/gp_x': critic_x_gp.item(),
        'critic_loss/gp_y': critic_y_gp.item(),
    }

    return critics_total_loss, losses_info
def wgan_gp(critic, real_data, fake_data, *critic_args, **critic_kwargs):
    "Computes gradient penalty according to WGAN-GP paper"
    assert real_data.size() == fake_data.size()

    batch_size, ndim = real_data.size(0), real_data.dim()

    # Sample one interpolation coefficient per example and broadcast it
    # over the remaining dimensions
    if ndim == 1:
        eps = cudable(torch.rand(batch_size))
    else:
        eps = cudable(torch.rand(batch_size, *np.ones(ndim - 1).astype(int)))
    eps = eps.expand(real_data.size())

    interpolations = eps * real_data + (1 - eps) * fake_data
    preds = critic(interpolations, *critic_args, **critic_kwargs)
    grads = autograd.grad(
        outputs=preds,
        inputs=interpolations,
        grad_outputs=cudable(torch.ones(preds.size())),
        retain_graph=True,
        create_graph=True,
        only_inputs=True
    )[0]
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()

    return gp
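# As a quick sanity check, wgan_gp can be exercised with a toy linear critic
# on random hidden states. This is a hypothetical smoke test, not part of the
# repo; it only assumes `cudable` is in scope (the move-to-GPU helper used
# throughout, an identity on CPU).
import torch
import torch.nn as nn

toy_critic = nn.Linear(8, 1)
real = torch.randn(4, 8, requires_grad=True)
fake = torch.randn(4, 8, requires_grad=True)

gp = wgan_gp(toy_critic, real, fake)
assert gp.dim() == 0 and gp.item() >= 0  # a non-negative scalar penalty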
def gen_loss_on_batch(self, batch):
    batch.domain_x = cudable(batch.domain_x)
    batch.domain_y = cudable(batch.domain_y)

    x_hid = self.encoder(batch.domain_x)
    y_hid = self.encoder(batch.domain_y)
    x2y_hid = self.gen_x2y(x_hid)
    y2x_hid = self.gen_y2x(y_hid)
    x2y2x_hid = self.gen_y2x(x2y_hid)
    y2x2y_hid = self.gen_x2y(y2x_hid)

    # Generator loss (consists of making the critics' life harder and an lp loss)
    critic_x_preds_y2x = self.critic_x(y2x_hid)
    critic_y_preds_x2y = self.critic_y(x2y_hid)
    adv_x_loss = self.generator_criterion(critic_y_preds_x2y)
    adv_y_loss = self.generator_criterion(critic_x_preds_y2x)
    adv_loss = (adv_x_loss + adv_y_loss) / 2

    # Cycle-consistency penalty in hidden space
    x_hid_lp_loss = torch.norm(x_hid - x2y2x_hid, p=self.config.hp.p_norm)
    y_hid_lp_loss = torch.norm(y_hid - y2x2y_hid, p=self.config.hp.p_norm)
    lp_loss = (x_hid_lp_loss + y_hid_lp_loss) / 2

    # Total loss
    gen_loss = adv_loss + self.config.hp.lp_loss_coef * lp_loss

    losses_info = {
        'gen_adv_loss/domain_x': adv_x_loss.item(),
        'gen_adv_loss/domain_y': adv_y_loss.item(),
        'gen_lp_loss/domain_x': x_hid_lp_loss.item(),
        'gen_lp_loss/domain_y': y_hid_lp_loss.item(),
    }

    return gen_loss, losses_info
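# Neither critic_criterion nor generator_criterion is defined in this section.
# Under the standard WGAN objective they reduce to signed means of the critic
# outputs; a minimal sketch of what they might look like (names and exact
# signatures here are assumptions, not the repo's actual definitions):
def critic_criterion(real_preds, fake_preds):
    # Critic maximizes E[critic(real)] - E[critic(fake)],
    # i.e. minimizes the negated difference
    return fake_preds.mean() - real_preds.mean()


def generator_criterion(fake_preds):
    # Generator tries to maximize the critic's score on its fakes
    return -fake_preds.mean()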
def validate(self):
    rec_losses = []
    bleus = []

    for batch in self.val_dataloader:
        batch.src, batch.trg = cudable(batch.src), cudable(batch.trg)

        # CE loss
        rec_loss = self.loss_on_batch(batch)
        rec_losses.append(rec_loss.item())

        # BLEU
        encs, enc_mask = self.transformer.encoder(batch.src)
        preds = InferenceState({
            'model': self.transformer.decoder,
            'inputs': encs,
            'enc_mask': enc_mask,
            'vocab': self.vocab_trg,
            'max_len': 50
        }).inference()
        preds = itos_many(preds, self.vocab_trg)
        gold = itos_many(batch.trg, self.vocab_trg)
        bleu = compute_bleu_for_sents(preds, gold)
        bleus.append(bleu)

    self.writer.add_scalar('val/rec_loss', np.mean(rec_losses), self.num_iters_done)
    self.writer.add_scalar('val/bleu', np.mean(bleus), self.num_iters_done)
    self.losses['val_bleu'].append(np.mean(bleus))

    # Log translation samples from the last validation batch
    texts = ['Translation: {}\n\n Gold: {}'.format(t, g) for t, g in zip(preds, gold)]
    text = '\n\n ================== \n\n'.join(texts[:10])
    self.writer.add_text('Samples', text, self.num_iters_done)
def loss_on_batch(self, batch):
    batch.src, batch.trg = cudable(batch.src), cudable(batch.trg)

    recs = self.transformer(batch.src, batch.trg)
    # Shift targets by one: predict token t+1 from tokens up to t
    targets = batch.trg[:, 1:].contiguous().view(-1)
    rec_loss = self.criterion(recs.view(-1, len(self.vocab_trg)), targets)

    return rec_loss
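# self.criterion (like rec_criterion in ae_loss_on_batch further below) is
# constructed elsewhere. A plausible sketch, assuming a token-level
# cross-entropy that ignores padding; the '<pad>' token name and the .stoi
# lookup (torchtext's vocab API) are assumptions:
import torch.nn as nn

pad_idx = vocab_trg.stoi['<pad>']  # hypothetical pad-index lookup
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)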
def predict(lines):
    lines = tokenize(lines)

    # Grouping all lines into batches
    src, trg = generate_dataset_with_middle_chars(lines)
    examples = [Example.fromlist([m, o], fields) for m, o in zip(src, trg)]
    ds = Dataset(examples, fields)
    dataloader = data.BucketIterator(ds, batch_size, repeat=False, shuffle=False)

    word_translations = []

    for batch in dataloader:
        # Generating predictions
        batch.src = cudable(batch.src)
        batch.trg = cudable(batch.trg)

        morphs = morph_chars_idx(batch.trg, field.vocab)
        morphs = cudable(torch.from_numpy(morphs).float())

        first_chars_embs = decoder.embed(batch.trg[:, :n_first_chars])

        z = encoder(batch.src)
        z = merge_z(torch.cat([z, morphs], dim=1))
        z = decoder.gru(first_chars_embs, z.unsqueeze(0))[1].squeeze(0)

        out = simple_inference(decoder, z, field.vocab, max_len=30)

        first_chars = batch.trg[:, :n_first_chars].cpu().numpy().tolist()
        results = [s + p for s, p in zip(first_chars, out)]
        results = itos_many(results, field.vocab, sep='')

        word_translations.extend(results)

    transfered = group_by_lens(word_translations, [len(s.split()) for s in lines])
    transfered = [mix_transfered(o, t) for o, t in zip(lines, transfered)]

    return transfered
def pad_to_max(self, seq, max_len):
    if len(seq) == max_len:
        return seq

    if self.gumbel:
        # TODO: we fill with eps to prevent numerical issues in loss computation later
        # But loss on pads is always zero, why are we doing this?
        eps = 1e-8
        pads = cudable(T.zeros(max_len - len(seq), len(self.vocab)))
        pads = pads.fill_(eps).index_fill_(1, T.tensor(self.pad_idx), 1.)
    else:
        pads = cudable(T.zeros(max_len - len(seq)).fill_(self.pad_idx))

    return T.cat((seq, pads), dim=0)
def init_models(self):
    size = self.config.hp.size
    dropout_p = self.config.hp.dropout
    dropword_p = self.config.hp.dropword

    self.encoder = RNNEncoder(size, size, self.vocab, dropword_p)
    self.decoder = RNNDecoder(size, size, self.vocab, dropword_p)
    self.split_nn = SplitNN(size, self.config.hp.style_vec_size)
    self.motivator = FFN([self.config.hp.style_vec_size, 1], dropout=dropout_p)
    self.critic = FFN([size, size, 1], dropout=dropout_p)
    self.merge_nn = MergeNN(size, self.config.hp.style_vec_size)

    # Let's save all AE params into a single list for future use
    self.ae_params = list(chain(
        self.encoder.parameters(),
        self.decoder.parameters(),
        self.split_nn.parameters(),
        self.merge_nn.parameters(),
    ))

    self.dissonet = cudable(DissoNet(
        self.encoder, self.decoder, self.split_nn,
        self.motivator, self.critic, self.merge_nn))

    if torch.cuda.device_count() > 1:
        print('Going to parallelize on {} GPUs'.format(torch.cuda.device_count()))
        self.dissonet = nn.DataParallel(self.dissonet)
def transfer_style_on_batch(self, batch):
    batch = cudable(batch)

    state_x = self.encoder(batch.domain_x)
    state_y = self.encoder(batch.domain_y)

    # Disentangle content from style, then recombine across domains
    content_x, style_x = self.dissonet.split_nn(state_x)
    content_y, style_y = self.dissonet.split_nn(state_y)

    state_x2y = self.merge_nn(content_x, style_y)
    state_y2x = self.merge_nn(content_y, style_x)
    state_x2x = self.merge_nn(content_x, style_x)
    state_y2y = self.merge_nn(content_y, style_y)

    x2y = simple_inference(self.decoder, state_x2y, self.vocab, eos_token='|')
    y2x = simple_inference(self.decoder, state_y2x, self.vocab, eos_token='|')
    x2x = simple_inference(self.decoder, state_x2x, self.vocab, eos_token='|')
    y2y = simple_inference(self.decoder, state_y2y, self.vocab, eos_token='|')

    return x2y, y2x, x2x, y2y
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    mask = cudable(torch.from_numpy(mask) == 0)

    return mask
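# For illustration: subsequent_mask(4) yields a (1, 4, 4) lower-triangular
# mask, so position i may attend only to positions <= i. (The printed dtype
# is uint8 or bool depending on the PyTorch version.)
mask = subsequent_mask(4).squeeze(0)
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]   1/True = may attend, 0/False = masked out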
def stack_finished(self):
    "Pads finished sequences with <pad> token and stacks into tensor"
    max_len = max(len(s) for s in self.finished)

    for i, seq in enumerate(self.finished):
        self.finished[i] = self.pad_to_max(seq, max_len)

    self.finished = cudable(T.stack(self.finished))
def transfer_style_on_batch(self, batch):
    batch.domain_x = cudable(batch.domain_x)
    batch.domain_y = cudable(batch.domain_y)

    x_z = self.encoder(batch.domain_x)
    y_z = self.encoder(batch.domain_y)

    x2y_z = self.gen_x2y(x_z)
    y2x_z = self.gen_y2x(y_z)

    # Cycle reconstructions: x -> y -> x and y -> x -> y
    x2x_z = self.gen_y2x(x2y_z)
    y2y_z = self.gen_x2y(y2x_z)

    x2y = simple_inference(self.decoder, x2y_z, self.vocab, eos_token='|')
    y2x = simple_inference(self.decoder, y2x_z, self.vocab, eos_token='|')
    x2x = simple_inference(self.decoder, x2x_z, self.vocab, eos_token='|')
    y2y = simple_inference(self.decoder, y2y_z, self.vocab, eos_token='|')

    return x2y, y2x, x2x, y2y
def pad_to(self, seq, n):
    "Pads sequence with zero vectors to the desired length"
    assert len(seq) <= n

    if len(seq) == n:
        return seq

    pads = torch.zeros(n - len(seq), seq.size(1))
    pads = cudable(pads).type(seq.type())

    return torch.cat([seq, pads])
def forward(self, encs, encs_mask):
    tokens = torch.arange(self.config.n_vecs).long().unsqueeze(0).repeat(encs.size(0), 1)
    x = self.embed(cudable(tokens))
    x = x * np.sqrt(self.config.d_model)

    for _ in range(self.config.n_steps):
        x = self.layer(encs, x, encs_mask, None)

    x = self.norm(x)

    return x
def init_models(self):
    size = self.config.hp.model_size
    dropout_p = self.config.hp.dropout
    dropword_p = self.config.hp.dropword

    self.encoder = cudable(RNNEncoder(size, size, self.vocab, dropword_p, noise=self.config.hp.noiseness))
    self.decoder = cudable(RNNDecoder(size, size, self.vocab, dropword_p))

    def create_critic():
        return FFN([size, size, 1], dropout_p)

    # GAN from X to Y
    self.gen_x2y = cudable(Generator(size, self.config.hp.gen_n_rec_steps))
    self.critic_y = cudable(create_critic())

    # GAN from Y to X
    self.gen_y2x = cudable(Generator(size, self.config.hp.gen_n_rec_steps))
    self.critic_x = cudable(create_critic())

    self.transfer_strength = TransferStrength(self.config)
    self.content_preservation = ContentPreservation(self.config)
def validate(self):
    losses = []

    for batch in self.val_dataloader:
        batch.domain_x = cudable(batch.domain_x)
        batch.domain_y = cudable(batch.domain_y)

        *_, rec_losses_info = self.ae_loss_on_batch(batch)
        *_, gen_losses_info = self.gen_loss_on_batch(batch)

        # The gradient penalty needs autograd, so grads must be
        # enabled even during validation
        with torch.enable_grad():
            *_, critic_losses_info = self.critic_loss_on_batch(batch)

        losses_info = dict(
            list(critic_losses_info.items())
            + list(gen_losses_info.items())
            + list(rec_losses_info.items()))
        losses.append(losses_info)

    for l in losses[0].keys():
        value = np.mean([info[l] for info in losses])
        self.writer.add_scalar('VAL/' + l, value, self.num_iters_done)

    # Ok, let's now validate style transfer and auto-encoding
    self.validate_inference()
def init_lm(config_path, state_path, model_cls_name: str):
    model_cls = MODEL_CLASSES[model_cls_name]
    hp = load_config(config_path).get('hp')
    get_path = create_get_path_fn(state_path)

    # Loading vocab
    field = Field(eos_token=EOS_TOKEN, batch_first=True,
                  tokenize=char_tokenize, pad_first=True)
    field.vocab = pickle.load(open(get_path('vocab', 'pickle'), 'rb'))

    print('Loading models..')
    device = None if torch.cuda.is_available() else 'cpu'

    if model_cls is RNNLM:
        lm = cudable(RNNLM(hp.model_size, field.vocab, n_layers=hp.n_layers)).eval()
        lm.load_state_dict(torch.load(get_path('lm'), map_location=device))
    elif model_cls is ConditionalLM:
        lm = cudable(ConditionalLM(hp.model_size, field.vocab)).eval()
        lm.load_state_dict(torch.load(get_path('lm'), map_location=device))
    elif model_cls is CharLMFromEmbs:
        rnn_lm = cudable(RNNLM(hp.model_size, field.vocab, n_layers=hp.n_layers))
        style_embed = cudable(nn.Embedding(2, hp.model_size))
        rnn_lm.load_state_dict(torch.load(get_path('lm'), map_location=device))
        style_embed.load_state_dict(torch.load(get_path('style_embed'), map_location=device))
        lm = cudable(CharLMFromEmbs(rnn_lm, style_embed, n_layers=hp.n_layers)).eval()
    else:
        raise NotImplementedError

    return lm, field
def ae_loss_on_batch(self, batch):
    batch.domain_x = cudable(batch.domain_x)
    batch.domain_y = cudable(batch.domain_y)

    x_hid = self.encoder(batch.domain_x)
    y_hid = self.encoder(batch.domain_y)

    # Reconstruction loss (teacher forcing: inputs and targets shifted by one)
    recs_x = self.decoder(x_hid, batch.domain_x[:, :-1])
    recs_y = self.decoder(y_hid, batch.domain_y[:, :-1])
    rec_loss_x = self.rec_criterion(
        recs_x.view(-1, len(self.vocab)),
        batch.domain_x[:, 1:].contiguous().view(-1))
    rec_loss_y = self.rec_criterion(
        recs_y.view(-1, len(self.vocab)),
        batch.domain_y[:, 1:].contiguous().view(-1))
    rec_loss = (rec_loss_x + rec_loss_y) / 2

    losses_info = {
        'rec_loss/domain_x': rec_loss_x.item(),
        'rec_loss/domain_y': rec_loss_y.item(),
    }

    return rec_loss, losses_info
def onehot_gumbel_softmax(logits, temperature):
    """
    input: [*, n_class]
    return: [*, n_class] a one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    _, idx = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, idx.view(-1, 1), 1)
    y_hard = cudable(y_hard.view(*shape))

    # Straight-through estimator: the forward pass returns the one-hot
    # y_hard, while gradients flow through the soft sample y
    return (y_hard - y).detach() + y
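# gumbel_softmax_sample is referenced above but not shown in this section.
# A minimal sketch, assuming the usual formulation (softmax over
# Gumbel-perturbed logits, with sample_gumbel as defined further below):
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature):
    # Perturb logits with Gumbel(0, 1) noise, then anneal with temperature:
    # the lower the temperature, the closer the sample is to one-hot
    noisy_logits = logits + sample_gumbel(logits.size())
    return F.softmax(noisy_logits / temperature, dim=-1)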
def validate(self):
    losses = []
    rec_losses = []

    for batch in self.val_dataloader:
        batch = cudable(batch)

        # Autograd must stay enabled here even during validation
        with torch.enable_grad():
            rec_loss, *_, losses_info = self.loss_on_batch(batch)

        rec_losses.append(rec_loss.item())
        losses.append(losses_info)

    for l in losses[0].keys():
        value = np.mean([info[l] for info in losses])
        self.writer.add_scalar('VAL/' + l, value, self.num_iters_done)

    self.losses['val_rec_loss'].append(np.mean(rec_losses))

    # Ok, let's now validate style transfer and auto-encoding
    self.validate_inference()
def init_models(self):
    self.transformer = cudable(Transformer(
        self.config.hp.transformer, self.vocab_src, self.vocab_trg))
def sample_gumbel(shape, eps=1e-20):
    # Inverse-CDF sampling: if U ~ Uniform(0, 1), then
    # -log(-log(U)) ~ Gumbel(0, 1); eps guards against log(0)
    U = cudable(torch.rand(shape))
    G = -torch.log(-torch.log(U + eps) + eps)

    return G
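# Hypothetical sanity check (not from the repo): Gumbel(0, 1) has mean equal
# to the Euler-Mascheroni constant (~0.5772), so a large sample mean should
# land close to it.
g = sample_gumbel((100000,))
assert abs(g.mean().item() - 0.5772) < 0.02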
def init_z(self, batch_size, style):
    styles = cudable(torch.ones(self.n_layers, batch_size).fill_(style)).long()
    z = self.style_embed(styles)

    return z
def generate_active_seqs(self):
    return cudable(T.tensor([[self.bos_idx] for _ in range(self.batch_size)]).long())
def predict(sentences: List[str], n_lines: int, temperature: float = None, max_len: int = None):
    "For each sentence generates `n_lines` lines sequentially to form a dialog"
    dialogs = [s for s in sentences]  # Copy so we do not mutate the original list
    batch_size = len(dialogs)
    temperature = temperature or DEFAULT_TEMPERATURE
    max_len = max_len or DEFAULT_MAX_LINE_LEN

    for _ in range(n_lines):
        # Each dialog is already an EOS-separated string, so it goes in as-is
        examples = [Example.fromlist([d], [('text', field)]) for d in dialogs]
        dataset = Dataset(examples, [('text', field)])
        dataloader = data.BucketIterator(dataset, batch_size, shuffle=False, repeat=False)
        batch = next(iter(dataloader))  # We have a single batch

        # As we pad first, truncating the left context loses no information
        text = cudable(batch.text[:, -MAX_CONTEXT_SIZE:])

        if model_cls_name == 'CharLMFromEmbs':
            z = lm.init_z(text.size(0), 1)
            z = lm(z, text, return_z=True)[1]
        elif model_cls_name == 'ConditionalLM':
            z = cudable(torch.zeros(2, len(text), 2048))
            z = lm(z, text, style=1, return_z=True)[1]
        elif model_cls_name == 'WeightedLMEnsemble':
            z = cudable(torch.zeros(2, 1, len(text), 4096))
            z = lm(z, text, return_z=True)[1]
        else:
            embs = lm.embed(text)
            z = lm.gru(embs)[1]

        next_lines = InferenceState({
            'model': lm,
            'inputs': z,
            'vocab': field.vocab,
            'max_len': max_len,
            'bos_token': EOS_TOKEN,  # We start inferring a new reply when we see EOS
            'eos_token': EOS_TOKEN,
            'temperature': temperature,
            'sample_type': 'sample',
            'inputs_batch_dim': 1 if model_cls_name != 'WeightedLMEnsemble' else 2,
            'substitute_inputs': True,
            'kwargs': inference_kwargs
        }).inference()

        next_lines = itos_many(next_lines, field.vocab, sep='')
        next_lines = [slice_unfinished_sentence(l) for l in next_lines]
        dialogs = [d + EOS_TOKEN + l for d, l in zip(dialogs, next_lines)]

    dialogs = [d.split(EOS_TOKEN) for d in dialogs]
    dialogs = [[s for s in d if len(s) != 0] for d in dialogs]
    dialogs = [assign_speakers(d) for d in dialogs]

    return dialogs
def init_models(self):
    self.classifier = cudable(RNNClassifier(self.config.hp.model_size, self.vocab))
# models_dir = 'experiments/char-wm/checkpoints'
models_dir = 'models/style-model'

versions = set([int(f.split('-')[1].split('.')[0]) for f in os.listdir(models_dir) if '-' in f])
latest_iter = max(versions)
get_path = lambda m, ext='pth': os.path.join(models_dir, '{}-{}.{}'.format(m, latest_iter, ext))
print('Latest iter (version) found: {}. Loading from it.'.format(latest_iter))

# Loading vocab
field = Field(init_token='<bos>', eos_token='<eos>', batch_first=True, tokenize=char_tokenize)
field.vocab = pickle.load(open(get_path('vocab', 'pickle'), 'rb'))
fields = [('src', field), ('trg', field)]

print('Loading models..')
encoder = cudable(RNNEncoder(512, 512, field.vocab)).eval()
decoder = cudable(RNNDecoder(512, 512, field.vocab)).eval()
merge_z = cudable(FFN([512 + MORPHS_SIZE, 512])).eval()

location = None if torch.cuda.is_available() else 'cpu'
encoder.load_state_dict(torch.load(get_path('encoder'), map_location=location))
decoder.load_state_dict(torch.load(get_path('decoder'), map_location=location))
merge_z.load_state_dict(torch.load(get_path('merge_z'), map_location=location))