Example #1
 def process_x_pad_batch(self, x_pad_batch):
     if self.online_config is not None:
         x_pad_batch = torch.cat([x_pad_batch, x_pad_batch], dim=-1) # duplicate the mono channel: (batch_size, seq_len, channel=2)
         x_pad_batch = x_pad_batch.transpose(-1, -2).contiguous() # (batch_size, channel=2, seq_len)
         feats = self.preprocessor(x_pad_batch)
         return process_train_MAM_data(feats, config=self.mam_config)
     else:
         return process_train_MAM_data(spec=(x_pad_batch,), config=self.mam_config)
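Judging from how the training loop below unpacks its output, process_train_MAM_data returns a validity flag followed by the masked-acoustic-modeling inputs (spec_masked, pos_enc, mask_label, attn_mask, spec_stacked). As a rough, hypothetical illustration of the core idea only (not the repo's implementation), time-masking a fraction of frames could look like:

 import torch

 def mask_frames(spec, mask_proportion=0.15):
     # Toy sketch: zero out a random subset of frames per utterance and return
     # a boolean label marking which frames the model must reconstruct.
     batch_size, seq_len, _ = spec.shape
     mask_label = torch.zeros(batch_size, seq_len, dtype=torch.bool)
     for b in range(batch_size):
         n_masked = max(1, int(seq_len * mask_proportion))
         idx = torch.randperm(seq_len)[:n_masked]
         mask_label[b, idx] = True
     spec_masked = spec.masked_fill(mask_label.unsqueeze(-1), 0.0)
     return spec_masked, mask_label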
Example #2
 def __getitem__(self, index):
     # Load acoustic feature and pad
     x_batch = [torch.FloatTensor(np.load(os.path.join(self.root, x_file))) for x_file in self.X[index]]
     x_pad_batch = pad_sequence(x_batch, batch_first=True)
     # Return (x_spec, t_spec)
     t_batch = [torch.FloatTensor(np.load(os.path.join(self.t_root, t_file))) for t_file in self.T[index]]
     t_pad_batch = pad_sequence(t_batch, batch_first=True)
     batch = process_train_MAM_data(spec=(x_pad_batch, t_pad_batch), config=self.mam_config)
     return batch
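Since each index of self.X already holds a list of utterances, __getitem__ returns a fully padded batch on its own. A minimal, hypothetical way to drive such a dataset (assuming `dataset` is an instance of the class above) is a DataLoader with batch_size=1 whose collate function unwraps the singleton list:

 from torch.utils.data import DataLoader

 # Each dataset item is already a complete batch, so the loader's batch size is 1
 # and the collate function simply unwraps the one-element list.
 dataloader = DataLoader(dataset, batch_size=1, shuffle=True,
                         collate_fn=lambda items: items[0])
 for batch in dataloader:
     batch_is_valid, *model_inputs = batch  # tuple layout as unpacked in the training loop below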
Example #3
 def __getitem__(self, index):
     # Load acoustic feature and pad
     if self.sample_step > 0:
         x_batch = [torch.FloatTensor(self.sample(x_data)) for x_data in self.X[index]]
     else:
         x_batch = [torch.FloatTensor(x_data) for x_data in self.X[index]]
     x_pad_batch = pad_sequence(x_batch, batch_first=True)
     x_pad_batch = process_train_MAM_data(spec=(x_pad_batch,), config=self.mam_config)
     return x_pad_batch
Example #4
 def __getitem__(self, index):
     # Load acoustic feature and pad
     if self.sample_step > 0:
         x_batch = [torch.FloatTensor(self.sample(np.load(os.path.join(self.root, x_file)))) for x_file in self.X[index]]
     else:
         x_batch = [torch.FloatTensor(np.load(os.path.join(self.root, x_file))) for x_file in self.X[index]]
     x_pad_batch = pad_sequence(x_batch, batch_first=True)
     if self.run_mam:
         x_pad_batch = process_train_MAM_data(spec=(x_pad_batch,), config=self.mam_config)
     return x_pad_batch
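Both __getitem__ variants above call a self.sample helper when sample_step > 0, but its definition is not shown. A plausible reading, given the name and the step attribute, is simple frame subsampling (hypothetical sketch; the repo's actual helper may differ):

 def sample(self, x_data):
     # Hypothetical helper: keep every `sample_step`-th frame to shorten long
     # utterances before padding.
     return x_data[::self.sample_step]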
Example #5
    def train(self):
        '''Self-supervised pre-training of the Transformer model.'''

        pbar = tqdm(total=self.total_steps)
        while self.global_step <= self.total_steps:

            progress = tqdm(self.dataloader, desc="Iteration")

            step = 0
            loss_val = 0
            for batch in progress:
                if 'online' in self.config:
                    # batch holds raw waveforms: (batch_size, channel, max_len);
                    # extract features and apply MAM masking on the fly
                    specs = self.preprocessor(batch.to(device=self.device))
                    batch = process_train_MAM_data(
                        specs, config=self.transformer_config)

                # process_train_MAM_data returns a validity flag followed by the model inputs
                batch_is_valid, *batch = batch
                try:
                    if self.global_step > self.total_steps: break
                    if not batch_is_valid: continue
                    step += 1

                    if self.dual_transformer:
                        time_masked, freq_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_dual_data(
                            batch)
                        loss, pred_spec = self.model(time_masked, freq_masked,
                                                     pos_enc, mask_label,
                                                     attn_mask, spec_stacked)
                    else:
                        spec_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_data(
                            batch)
                        loss, pred_spec = self.model(spec_masked, pos_enc,
                                                     mask_label, attn_mask,
                                                     spec_stacked)

                    # Accumulate Loss
                    if self.gradient_accumulation_steps > 1:
                        loss = loss / self.gradient_accumulation_steps
                    if self.apex and self.args.multi_gpu:
                        raise NotImplementedError
                    elif self.apex:
                        self.optimizer.backward(loss)
                    elif self.args.multi_gpu:
                        loss = loss.sum()
                        loss.backward()
                    else:
                        loss.backward()
                    loss_val += loss.item()

                    # Update
                    if (step + 1) % self.gradient_accumulation_steps == 0:
                        if self.apex:
                            # modify learning rate with the special warm-up BERT uses;
                            # if config.apex is False, BertAdam is used and handles this automatically
                            lr_this_step = self.learning_rate * self.warmup_linear.get_lr(
                                self.global_step, self.warmup_proportion)
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = lr_this_step

                        # Step
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.gradient_clipping)
                        if math.isnan(grad_norm):
                            print(
                                '[Runner] - Error : grad norm is NaN @ step ' +
                                str(self.global_step))
                        else:
                            self.optimizer.step()
                        self.optimizer.zero_grad()

                        if self.global_step % self.log_step == 0:
                            # Log
                            self.log.add_scalar('lr',
                                                self.optimizer.get_lr()[0],
                                                self.global_step)
                            self.log.add_scalar('loss', loss_val,
                                                self.global_step)
                            self.log.add_scalar('gradient norm', grad_norm,
                                                self.global_step)
                            progress.set_description("Loss %.4f" % loss_val)

                        if self.global_step % self.save_step == 0:
                            self.save_model('states')

                            # tensorboard log
                            if self.dual_transformer: spec_masked = time_masked
                            spec_list = [spec_masked, pred_spec, spec_stacked]
                            name_list = ['mask_spec', 'pred_spec', 'true_spec']
                            if self.dual_transformer:
                                spec_list.insert(1, freq_masked)
                                name_list.insert(1, 'mask_freq')
                                name_list[0] = 'mask_time'

                            for i in range(len(spec_list)):
                                spec = self.up_sample_frames(spec_list[i][0],
                                                             return_first=True)
                                spec = plot_spectrogram_to_numpy(
                                    spec.data.cpu().numpy())
                                self.log.add_image(name_list[i], spec,
                                                   self.global_step)

                            # if self.dual_transformer:
                            #     self.model.PhoneticTransformer.PhoneRecognizer.set_num_updates(self.global_step//1000)

                        # global_step counts optimizer updates, not raw batches
                        loss_val = 0
                        pbar.update(1)
                        self.global_step += 1

                except RuntimeError as e:
                    if 'CUDA out of memory' in str(e):
                        print('CUDA out of memory at step: ', self.global_step)
                        torch.cuda.empty_cache()
                        self.optimizer.zero_grad()
                    else:
                        raise

        pbar.close()
        self.log.close()
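The apex branch above rescales the learning rate manually via warmup_linear.get_lr. Assuming this follows the classic BERT warmup-linear schedule (a linear ramp over the first warmup_proportion of training, then linear decay to zero), the multiplier can be sketched as:

 def warmup_linear_factor(step, total_steps, warmup_proportion):
     # LR multiplier: ramp from 0 to 1 over the warm-up phase, then decay
     # linearly back to 0 by the end of training.
     progress = step / total_steps
     if progress < warmup_proportion:
         return progress / warmup_proportion
     return max((1.0 - progress) / (1.0 - warmup_proportion), 0.0)

In the non-apex branch, BertAdam applies the equivalent schedule internally, which is why the loop only touches param_group['lr'] when apex is enabled.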