def process_x_pad_batch(self, x_pad_batch):
    """Apply MAM (masked acoustic model) preprocessing to one padded batch.

    When ``self.online_config`` is set, features are extracted on the fly
    from the raw padded input via ``self.preprocessor`` before masking;
    otherwise the batch is treated as precomputed spectrogram features.
    """
    if self.online_config is None:
        # Offline path: input is already a spectrogram batch.
        return process_train_MAM_data(spec=(x_pad_batch,), config=self.mam_config)

    # Online path: duplicate the input along the last axis
    # (batch_size, seq_len, channel=2)  -- assumes a trailing singleton dim; TODO confirm
    two_channel = torch.cat([x_pad_batch, x_pad_batch], dim=-1)
    # (batch_size, channel=2, seq_len)
    two_channel = two_channel.transpose(-1, -2).contiguous()
    feats = self.preprocessor(two_channel)
    return process_train_MAM_data(feats, config=self.mam_config)
def __getitem__(self, index):
    """Load one batch of paired (source, target) acoustic features and run MAM processing.

    Both feature lists are loaded from ``.npy`` files and zero-padded to the
    longest sequence in the batch before being handed to the MAM pipeline.
    """
    # Load acoustic features and pad to a uniform length.
    x_pad_batch = pad_sequence(
        [torch.FloatTensor(np.load(os.path.join(self.root, name))) for name in self.X[index]],
        batch_first=True,
    )
    # Target features, padded the same way; the pair is returned as (x_spec, t_spec).
    t_pad_batch = pad_sequence(
        [torch.FloatTensor(np.load(os.path.join(self.t_root, name))) for name in self.T[index]],
        batch_first=True,
    )
    return process_train_MAM_data(spec=(x_pad_batch, t_pad_batch), config=self.mam_config)
def __getitem__(self, index):
    """Pad one batch of in-memory acoustic features (optionally subsampled) and run MAM processing."""
    data = self.X[index]
    # Subsample each utterance when a positive sampling step is configured.
    if self.sample_step > 0:
        tensors = [torch.FloatTensor(self.sample(item)) for item in data]
    else:
        tensors = [torch.FloatTensor(item) for item in data]
    padded = pad_sequence(tensors, batch_first=True)
    return process_train_MAM_data(spec=(padded,), config=self.mam_config)
def __getitem__(self, index):
    """Load one batch of acoustic features from disk, pad them, and optionally run MAM processing.

    Each feature file is loaded from ``self.root``; when ``self.sample_step``
    is positive every utterance is subsampled first. MAM masking is applied
    only if ``self.run_mam`` is set.
    """

    def _load(name):
        # One feature matrix per .npy file under the dataset root.
        return np.load(os.path.join(self.root, name))

    if self.sample_step > 0:
        tensors = [torch.FloatTensor(self.sample(_load(name))) for name in self.X[index]]
    else:
        tensors = [torch.FloatTensor(_load(name)) for name in self.X[index]]

    x_pad_batch = pad_sequence(tensors, batch_first=True)
    if self.run_mam:
        x_pad_batch = process_train_MAM_data(spec=(x_pad_batch,), config=self.mam_config)
    return x_pad_batch
def train(self):
    ''' Self-Supervised Pre-Training of Transformer Model'''
    # Outer progress bar tracks optimizer steps (not batches) up to total_steps.
    pbar = tqdm(total=self.total_steps)
    while self.global_step <= self.total_steps:

        progress = tqdm(self.dataloader, desc="Iteration")

        step = 0        # batch counter within this epoch, used for gradient accumulation
        loss_val = 0    # accumulated loss since the last optimizer step / log
        for batch in progress:
            if 'online' in self.config:
                # batch are raw waveforms
                # batch: (batch_size, channel, max_len)
                # Extract acoustic features on the fly, then build MAM training targets.
                specs = self.preprocessor(batch.to(device=self.device))
                batch = process_train_MAM_data(
                    specs, config=self.transformer_config)
            # First element flags whether this batch is usable; the rest are model inputs.
            batch_is_valid, *batch = batch
            try:
                if self.global_step > self.total_steps:
                    break
                if not batch_is_valid:
                    continue
                step += 1

                if self.dual_transformer:
                    # Dual-stream variant: separate time-masked and freq-masked inputs.
                    time_masked, freq_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_dual_data(
                        batch)
                    loss, pred_spec = self.model(time_masked, freq_masked,
                                                 pos_enc, mask_label, attn_mask, spec_stacked)
                else:
                    spec_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_data(
                        batch)
                    loss, pred_spec = self.model(spec_masked,
                                                 pos_enc, mask_label, attn_mask, spec_stacked)

                # Accumulate Loss
                if self.gradient_accumulation_steps > 1:
                    # Scale so the summed gradients match a single large batch.
                    loss = loss / self.gradient_accumulation_steps
                if self.apex and self.args.multi_gpu:
                    raise NotImplementedError
                elif self.apex:
                    self.optimizer.backward(loss)
                elif self.args.multi_gpu:
                    # DataParallel returns one loss per GPU; reduce before backward.
                    loss = loss.sum()
                    loss.backward()
                else:
                    loss.backward()
                loss_val += loss.item()

                # Update: only step the optimizer every gradient_accumulation_steps batches.
                if (step + 1) % self.gradient_accumulation_steps == 0:
                    if self.apex:
                        # modify learning rate with special warm up BERT uses
                        # if config.apex is False, BertAdam is used and handles this automatically
                        lr_this_step = self.learning_rate * self.warmup_linear.get_lr(
                            self.global_step, self.warmup_proportion)
                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = lr_this_step

                    # Step: clip gradients, then skip the update entirely if the norm is NaN.
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.gradient_clipping)
                    if math.isnan(grad_norm):
                        print(
                            '[Runner] - Error : grad norm is NaN @ step ' + str(self.global_step))
                    else:
                        self.optimizer.step()
                    self.optimizer.zero_grad()

                    if self.global_step % self.log_step == 0:
                        # Log scalar metrics to tensorboard.
                        self.log.add_scalar('lr',
                                            self.optimizer.get_lr()[0], self.global_step)
                        self.log.add_scalar('loss',
                                            (loss_val), self.global_step)
                        self.log.add_scalar('gradient norm',
                                            grad_norm, self.global_step)
                        progress.set_description("Loss %.4f" % (loss_val))

                    if self.global_step % self.save_step == 0:
                        self.save_model('states')
                        # tensorboard log: masked input / prediction / ground-truth spectrograms
                        if self.dual_transformer:
                            spec_masked = time_masked
                        spec_list = [spec_masked, pred_spec, spec_stacked]
                        name_list = ['mask_spec', 'pred_spec', 'true_spec']
                        if self.dual_transformer:
                            # Insert the freq-masked stream and rename the time stream accordingly.
                            spec_list.insert(1, freq_masked)
                            name_list.insert(1, 'mask_freq')
                            name_list[0] = 'mask_time'
                        for i in range(len(spec_list)):
                            spec = self.up_sample_frames(
                                spec_list[i][0], return_first=True)
                            spec = plot_spectrogram_to_numpy(
                                spec.data.cpu().numpy())
                            self.log.add_image(name_list[i], spec, self.global_step)
                        # if self.dual_transformer:
                        #     self.model.PhoneticTransformer.PhoneRecognizer.set_num_updates(self.global_step//1000)

                    loss_val = 0
                    pbar.update(1)
                    self.global_step += 1

            except RuntimeError as e:
                # Best-effort recovery from OOM: free cached memory and drop this batch's grads.
                if 'CUDA out of memory' in str(e):
                    print('CUDA out of memory at step: ', self.global_step)
                    torch.cuda.empty_cache()
                    self.optimizer.zero_grad()
                else:
                    raise

    pbar.close()
    self.log.close()