def predict_feats(self, feats, seq_lens, repeats):
    """Compute the three training losses for a batch of frame features.

    Runs a masked-LM style forward pass split at ``self.translate_layer_idx``:
    the lower layers produce ``feat_inner``, which feeds both a feature
    reconstruction head (CPC loss) and a translated branch that produces
    target logits for the segment losses.

    Args:
        feats: raw frame features (converted to tensor inside embed_feats).
        seq_lens: per-sample sequence lengths, used to build the attention mask.
        repeats: per-segment repeat counts for the intra-segment loss.

    Returns:
        (loss, intra_s_loss, inter_s_loss) — reconstruction loss and the two
        segment-level losses.
    """
    # embed_feats applies MLM masking (mask_lm=True) and returns the padded
    # feats alongside the embeddings, attention mask and predict mask.
    feats, embedding_output, attention_mask, predict_mask = self.embed_feats(
        feats, seq_lens, mask_lm=True)
    # Lower layers only: stop at the translation split point.
    feat_inner = self.forward(
        embedding_output, attention_mask,
        end_layer_idx=self.translate_layer_idx)
    # Upper layers resume from the split point for the reconstruction branch.
    output = self.forward(
        feat_inner, attention_mask,
        start_layer_idx=self.translate_layer_idx)
    output = self.feat_out_layer(output)
    # Positions that were masked out AND are not padding are the ones to predict.
    to_predict = (1 - predict_mask.squeeze()) * attention_mask  # shape: (N, T)
    loss = cpc_loss(output, feats, to_predict, attention_mask)
    # Translation branch: map the inner features, then run the same upper layers.
    translation_inner = self.translate(feat_inner)
    translated_logits = self.forward(
        translation_inner, attention_mask,
        start_layer_idx=self.translate_layer_idx)
    translated_logits = self.target_out(translated_logits)
    repeats = get_tensor_from_array(repeats)
    intra_s_loss = intra_segment_loss(
        translated_logits, repeats, attention_mask, self.sep_size)
    inter_s_loss = inter_segment_loss(translated_logits, attention_mask)
    # Side effect: track an EMA-style shadow copy of the inner features.
    self.update_shadow_variable(
        self.shadow_feat_inner, feat_inner, attention_mask)
    return loss, intra_s_loss, inter_s_loss
def embed_target(self, target_ids, seq_lens, mask_lm=True):
    """Embed target token ids, optionally applying MLM-style input masking.

    Args:
        target_ids: integer token ids (converted to a long tensor here).
        seq_lens: per-sample sequence lengths (mutated by wrap_with_embeddings).
        mask_lm: when True, randomly mask inputs and return a predict mask.

    Returns:
        (recon_target, embedding_output, attention_mask, predict_mask) —
        recon_target and predict_mask are None when mask_lm is False.
    """
    target_ids = get_tensor_from_array(target_ids).long()
    embedding_output = self.pretrained_word_embeddings(target_ids)
    predict_mask = None
    recon_target = None
    if mask_lm:
        input_mask, predict_mask = get_mlm_masks(
            target_ids, self.mask_prob, self.mask_but_no_prob)
        # Broadcast the (N, T) mask over the embedding dimension.
        input_mask = input_mask.unsqueeze(2)
        embedding_output = self.use_pretrained_mask_embedding(
            embedding_output, input_mask)
        # NOTE(review): recon_target is the already-masked embedding here —
        # presumably intentional (reconstruct the masked input), but confirm
        # against the loss that consumes it.
        recon_target = embedding_output
        recon_target = self.pad_front(recon_target, 0)
    # Pad ids/embeddings at the front (e.g. for a CLS-style slot) and extend
    # seq_lens accordingly.
    target_ids = self.pad_front(target_ids, 0)
    embedding_output, seq_lens = self.wrap_with_embeddings(
        embedding_output, seq_lens)
    embedding_output = self.add_beside_word_embeddings(embedding_output)
    embedding_output = self.pretrained_embedding_layer_norm(
        embedding_output)
    attention_mask = get_attention_mask(seq_lens, target_ids.shape[1])
    if predict_mask is not None:
        # Front-pad with 1 so the prepended position is never predicted
        # (to_predict uses 1 - predict_mask downstream).
        predict_mask = self.pad_front(predict_mask, 1)
    return recon_target, embedding_output, attention_mask, predict_mask
def predict(self, frame_feat, lens):
    """Run frame features through the full model and return target logits.

    Args:
        frame_feat: raw frame features (array-like; converted to tensor here).
        lens: per-sample sequence lengths used to build the attention mask.

    Returns:
        Output of the target projection layer for every frame position.
    """
    feats = get_tensor_from_array(frame_feat)
    attn_mask = get_attention_mask(lens, feats.shape[1])
    hidden = self.positional_encoding(self.feat_embeddings(feats))
    return self.target_out_layer(self.forward(hidden, attn_mask))
def finetune_loss(self, frame_feat, frame_label, lens):
    """Masked per-frame cross-entropy loss for supervised fine-tuning.

    Args:
        frame_feat: raw frame features fed to self.predict.
        frame_label: integer class label per frame.
        lens: per-sample sequence lengths; frames past each length are
            excluded from the loss via the attention mask.

    Returns:
        Scalar loss tensor averaged over valid frames, then over the batch.
    """
    # CrossEntropyLoss expects (N, C, T), so move classes to dim 1.
    logits = self.predict(frame_feat, lens).transpose(1, 2)
    labels = get_tensor_from_array(frame_label).long()
    per_frame_loss = nn.CrossEntropyLoss(reduction='none')(logits, labels)
    attn_mask = get_attention_mask(lens, frame_feat.shape[1])
    return masked_reduce_mean(per_frame_loss, attn_mask).mean()
def predict_batch(self, batch_frame_feat, batch_frame_len):
    """Run the generator in eval mode and return per-frame probabilities.

    Temporarily switches the generator to eval mode and disables autograd,
    then restores train mode before returning.

    Args:
        batch_frame_feat: raw frame features for the batch.
        batch_frame_len: per-sample sequence lengths for the attention mask.

    Returns:
        numpy array of softmax probabilities over target classes per frame.
    """
    self.generator.eval()
    with torch.no_grad():
        feats = get_tensor_from_array(batch_frame_feat)
        attn_mask = create_attention_mask(batch_frame_len, feats.shape[1])
        logits, _ = self.generator(feats, attn_mask)
        probs = torch.softmax(logits, dim=-1)
    self.generator.train()
    return probs.cpu().data.numpy()
def embed_feats(self, feats, seq_lens, mask_lm=True):
    """Embed continuous frame features, optionally with MLM-style masking.

    Counterpart of embed_target for continuous inputs; same padding and
    wrapping pipeline, without the pretrained layer norm.

    Args:
        feats: raw frame features (converted to tensor here).
        seq_lens: per-sample sequence lengths (mutated by wrap_with_embeddings).
        mask_lm: when True, randomly mask inputs and return a predict mask.

    Returns:
        (feats, embedding_output, attention_mask, predict_mask) —
        predict_mask is None when mask_lm is False.
    """
    feats = get_tensor_from_array(feats)
    embedding_output = self.feat_embeddings(feats)
    predict_mask = None
    if mask_lm:
        input_mask, predict_mask = get_mlm_masks(
            feats, self.mask_prob, self.mask_but_no_prob)
        # NOTE(review): embed_target unsqueezes input_mask to (N, T, 1) before
        # this call, but here it is passed as-is — confirm get_mlm_masks
        # already returns a broadcastable shape for continuous feats.
        embedding_output = self.use_pretrained_mask_embedding(
            embedding_output, input_mask)
    # Front-pad (e.g. for a CLS-style slot) and extend seq_lens to match.
    feats = self.pad_front(feats, 0)
    embedding_output, seq_lens = self.wrap_with_embeddings(
        embedding_output, seq_lens)
    embedding_output = self.add_beside_word_embeddings(embedding_output)
    attention_mask = get_attention_mask(seq_lens, feats.shape[1])
    if predict_mask is not None:
        # Front-pad with 1 so the prepended position is never predicted.
        predict_mask = self.pad_front(predict_mask, 1)
    return feats, embedding_output, attention_mask, predict_mask
def pretrain_loss(self, input_feats, seq_lens):
    """Masked-reconstruction (MLM-style) pretraining loss over frame features.

    Randomly replaces a subset of input frames with the learned mask vector,
    runs the model, and scores reconstruction of the original frames at the
    masked positions with a CPC loss.

    Args:
        input_feats: raw frame features (converted to tensor here).
        seq_lens: per-sample sequence lengths for the attention mask.

    Returns:
        Scalar CPC loss tensor.
    """
    feats = get_tensor_from_array(input_feats)
    attn_mask = get_attention_mask(seq_lens, feats.shape[1])
    keep_mask, predict_mask = get_mlm_masks(
        feats, self.mask_prob, self.mask_but_no_prob)
    # Keep unmasked frames; substitute the mask vector at masked positions.
    corrupted = keep_mask * feats + (1 - keep_mask) * self.feat_mask_vec
    # Zero out padding frames so they contribute nothing downstream.
    corrupted = corrupted * attn_mask.unsqueeze(2)
    hidden = self.positional_encoding(self.feat_embeddings(corrupted))
    reconstruction = self.feat_out_layer(self.forward(hidden, attn_mask))
    # Predict exactly the masked, non-padding positions.
    to_predict = (1 - predict_mask.squeeze()) * attn_mask  # shape: (N, T)
    return cpc_loss(reconstruction, feats, to_predict, attn_mask)
def get_pretrained_embeddings(self, token):
    """Look up the pretrained word embedding for a single token id.

    Args:
        token: integer token id.

    Returns:
        Embedding tensor for the token.  # shape: (1, E)
    """
    token_ids = get_tensor_from_array(np.array([[token]])).long()
    return self.pretrained_word_embeddings(token_ids)
def train(
        self,
        config,
        data_loader,
        dev_data_loader=None,
        aug=False,
):
    """Adversarial unsupervised training loop (generator vs. critic).

    Alternates ``config.dis_iter`` critic updates with ``config.gen_iter``
    generator updates per step, evaluating frame error rate (FER) on
    ``dev_data_loader`` every ``config.eval_step`` steps.

    Args:
        config: hyperparameter namespace (batch_size, repeat, step, dis_iter,
            gen_iter, phn_max_length, seg_loss_ratio, eval_step, print_step).
        data_loader: provides sample batches and (optionally augmented)
            target batches.
        dev_data_loader: validation loader used by frame_eval.
        aug: when True, draw real targets from the augmented target sampler.
    """
    print('TRAINING(unsupervised)...')
    if aug:
        get_target_batch = data_loader.get_aug_target_batch
    else:
        get_target_batch = data_loader.get_target_batch

    batch_size = config.batch_size * config.repeat
    logger = Logger(print_step=config.print_step)
    # Despite the name, max_fer tracks the BEST (lowest) FER seen so far.
    max_fer = 100.0

    for step in range(1, config.step + 1):
        # ---- critic (discriminator) updates; generator frozen in eval mode ----
        self.generator.eval()
        for _ in range(config.dis_iter):
            self.c_opt.zero_grad()
            batch_sample_feat, batch_sample_len, batch_repeat_num, batch_phn_label = data_loader.get_sample_batch(
                config.batch_size,
                repeat=config.repeat,
            )
            real_target_idx, batch_target_len = get_target_batch(batch_size)

            batch_sample_feat = get_tensor_from_array(batch_sample_feat)
            real_target_idx = get_tensor_from_array(real_target_idx).long()
            mask = create_attention_mask(batch_sample_len, config.phn_max_length)

            fake_target_logits, fake_target_idx = self.generator(batch_sample_feat, mask)

            real_score = self.critic(real_target_idx)
            fake_score = self.critic(fake_target_idx)
            if not self.wgan:
                # Standard GAN critic loss; epsilon guards log(0).
                c_loss = torch.mean(-torch.log(real_score + epsilon)) + \
                    torch.mean(-torch.log(1 - fake_score + epsilon))
            else:
                # WGAN critic loss: maximize real score, minimize fake score.
                c_loss = torch.mean(-real_score) + torch.mean(fake_score)

            c_loss.backward()
            self.c_opt.step()
            logger.update({
                'c_loss': c_loss.item(),
                'true_sample': array_to_string(real_target_idx[0].cpu().data.numpy()),
            })

        # ---- generator updates; critic frozen in eval mode ----
        self.generator.train()
        self.critic.eval()
        for _ in range(config.gen_iter):
            self.g_opt.zero_grad()
            batch_sample_feat, batch_sample_len, batch_repeat_num, batch_phn_label = data_loader.get_sample_batch(
                config.batch_size,
                repeat=config.repeat,
            )
            batch_sample_feat = get_tensor_from_array(batch_sample_feat)
            mask = create_attention_mask(batch_sample_len, config.phn_max_length)
            batch_repeat_num = get_tensor_from_array(batch_repeat_num)

            fake_target_logits, fake_target_idx = self.generator(batch_sample_feat, mask)
            fake_score = self.critic(fake_target_idx)  # shape: (N, 1)
            reward = self.critic.compute_G_reward(fake_score)
            kernel = self.critic.get_kernel()
            g_loss = self.generator.compute_loss(reward, kernel, fake_target_logits, fake_target_idx, mask)
            # Encourage consistent predictions within repeated segments.
            segment_loss = intra_segment_loss(
                fake_target_logits,
                batch_repeat_num,
                mask,
                sep_size=(config.batch_size * config.repeat) // 2,
            )
            total_loss = g_loss + config.seg_loss_ratio * segment_loss
            total_loss.backward()
            self.g_opt.step()
            logger.update({
                'g_loss': g_loss.item(),
                'seg_loss': segment_loss.item(),
                'fake_sample': array_to_string(fake_target_idx[0].cpu().data.numpy()),
                'baseline': self.critic.ema.average.item(),
            })
        self.critic.train()

        if step % config.eval_step == 0:
            step_fer = frame_eval(self.predict_batch, dev_data_loader)
            logger.update({'val_fer': step_fer}, ema=False)
            # Printed before updating, so 'max' shows the previous best FER.
            print(f'EVAL max: {max_fer:.2f} step: {step_fer:.2f}')
            if step_fer < max_fer:
                max_fer = step_fer
        logger.step()
    print('=' * 80)