コード例 #1
0
ファイル: solver.py プロジェクト: 592595/TERA
    def train(self):
        """Run the APC (Autoregressive Predictive Coding) pre-training loop.

        Trains ``self.model`` for ``self.config.total_steps`` steps over
        ``self.dataloader``, logging scalar losses and spectrogram images to
        TensorBoard, and keeping at most ``self.config.max_keep`` checkpoints
        on disk (oldest checkpoints are deleted first).
        """
        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)

        self.model.train()
        pbar = tqdm(total=self.config.total_steps)

        model_kept = []
        global_step = 1

        while global_step <= self.config.total_steps:

            for batch_x in tqdm(self.dataloader, desc="Iteration"):
                if global_step > self.config.total_steps: break

                batch_x, batch_l = self.process_data(batch_x)
                # Sort utterances by length, longest first (presumably for
                # downstream RNN sequence packing — TODO confirm in model).
                _, indices = torch.sort(batch_l, descending=True)

                # Fix: dropped the deprecated torch.autograd.Variable wrapper
                # (a no-op since PyTorch 0.4); plain tensors autograd directly.
                batch_x = batch_x[indices].cuda()
                batch_l = batch_l[indices].cuda()

                # APC objective: predict frames `time_shift` steps ahead —
                # the input drops the last `time_shift` frames and the target
                # below drops the first `time_shift` frames.
                outputs, _ = self.model(batch_x[:, :-self.config.time_shift, :],
                                        batch_l - self.config.time_shift)

                loss = self.criterion(outputs, batch_x[:, self.config.time_shift:, :])
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_thresh)

                # Step: skip the optimizer update on a NaN gradient norm, but
                # still clear the gradients so the bad batch is discarded.
                if math.isnan(grad_norm):
                    print('Error : grad norm is NaN @ step ' + str(global_step))
                else:
                    self.optimizer.step()
                self.optimizer.zero_grad()

                if global_step % self.config.log_step == 0:
                    self.log.add_scalar('training loss (step-wise)', float(loss.item()), global_step)
                    self.log.add_scalar('gradient norm', grad_norm, global_step)

                # log spectrogram images and save a checkpoint
                if global_step % self.config.save_step == 0:
                    pred_spec = plot_spectrogram_to_numpy(outputs[0].data.cpu().numpy())
                    true_spec = plot_spectrogram_to_numpy(batch_x[0].data.cpu().numpy())
                    self.log.add_image('pred_spec', pred_spec, global_step)
                    self.log.add_image('true_spec', true_spec, global_step)
                    new_model_path = os.path.join(self.model_dir, 'apc-%d' % global_step + '.ckpt')
                    torch.save(self.model.state_dict(), new_model_path)
                    model_kept.append(new_model_path)

                    # Rotate old checkpoints so at most `max_keep` remain.
                    if len(model_kept) > self.config.max_keep:
                        os.remove(model_kept[0])
                        model_kept.pop(0)

                pbar.update(1)
                global_step += 1

        # Fix: close the progress bar (previously leaked; the other training
        # loops in this file close theirs).
        pbar.close()
コード例 #2
0
    def forward(self,
                data,
                records=None,
                global_step=0,
                log_step=1000,
                **kwargs):
        """Compute the masked-reconstruction pre-training loss for one batch.

        Args:
            data:
                [spec_masked, pos_enc, mask_label, attn_mask, spec_target]

            records:
                dict collecting spectrogram images that can be logged on
                Tensorboard later by self.log_records every log_step.
                Defaults to a fresh dict per call.
                NOTE(review): the original docstring claimed defaultdict(list)
                append-and-average semantics, but this code assigns plain
                keys — confirm against the logging caller.

        Return:
            loss, records
        """
        # Fix: the original signature used a mutable default (`records={}`),
        # which is created once and shared across every call that omits the
        # argument — entries from one call would leak into the next.
        if records is None:
            records = {}

        spec_masked, pos_enc, mask_label, attn_mask, spec_target = data[
            0], data[1], data[2], data[3], data[4]
        spec_masked = spec_masked.to(self.device)

        if pos_enc.dim() == 3:
            # pos_enc: (batch_size, seq_len, hidden_size)
            # GPU memory need (batch_size * seq_len * hidden_size)
            pos_enc = pos_enc.float().to(self.device)
        elif pos_enc.dim() == 2:
            # pos_enc: (seq_len, hidden_size)
            # GPU memory only need (seq_len * hidden_size) even after expanded
            pos_enc = pos_enc.float().to(self.device).expand(
                spec_masked.size(0), *pos_enc.size())

        mask_label = mask_label.bool().to(self.device)
        attn_mask = attn_mask.float().to(self.device)
        spec_target = spec_target.to(self.device)

        loss, pred_spec = self.model(spec_masked, pos_enc, mask_label,
                                     attn_mask, spec_target)

        if global_step % log_step == 0:
            # Periodically snapshot the masked / predicted / target
            # spectrograms of the first utterance in the batch as images.
            spec_list = [spec_masked, pred_spec, spec_target]
            name_list = ['mask_spec', 'pred_spec', 'true_spec']

            for i in range(len(spec_list)):
                spec = plot_spectrogram_to_numpy(
                    spec_list[i][0].data.cpu().numpy())
                records[name_list[i]] = spec

        return loss, records
コード例 #3
0
    def exec(self):
        ''' Training Unsupervised End-to-end Mockingjay Model

        Runs the self-supervised masked-reconstruction pre-training loop
        until self.global_step exceeds self.total_steps, with gradient
        accumulation, optional apex mixed precision, TensorBoard logging,
        periodic checkpointing, and CUDA out-of-memory recovery.
        '''
        self.verbose('Training set total ' + str(len(self.dataloader)) +
                     ' batches.')

        pbar = tqdm(total=self.total_steps)
        while self.global_step <= self.total_steps:

            progress = tqdm(self.dataloader, desc="Iteration")

            # `step` counts valid batches within this pass over the
            # dataloader and drives gradient accumulation; `self.global_step`
            # counts optimizer updates across the whole run.
            step = 0
            for batch_is_valid, *batch in progress:
                try:
                    if self.global_step > self.total_steps: break
                    # Skip batches the collate function flagged as invalid.
                    if not batch_is_valid: continue
                    step += 1

                    spec_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_data(
                        batch)
                    loss, pred_spec = self.model(spec_masked, pos_enc,
                                                 mask_label, attn_mask,
                                                 spec_stacked)

                    # Accumulate Loss: scale down so the gradients summed over
                    # `gradient_accumulation_steps` batches match one update.
                    if self.gradient_accumulation_steps > 1:
                        loss = loss / self.gradient_accumulation_steps
                    if self.apex and self.paras.multi_gpu:
                        raise NotImplementedError
                    elif self.apex:
                        # apex optimizer wraps backward for mixed precision.
                        self.optimizer.backward(loss)
                    elif self.paras.multi_gpu:
                        # DataParallel returns one loss per GPU; reduce first.
                        loss = loss.sum()
                        loss.backward()
                    else:
                        loss.backward()

                    # Update: only every `gradient_accumulation_steps` valid
                    # batches (progress/global_step also advance only here).
                    if (step + 1) % self.gradient_accumulation_steps == 0:
                        if self.apex:
                            # modify learning rate with special warm up BERT uses
                            # if config.apex is False, BertAdam is used and handles this automatically
                            lr_this_step = self.learning_rate * self.warmup_linear.get_lr(
                                self.global_step, self.warmup_proportion)
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = lr_this_step

                        # Step: skip the optimizer update (but still clear the
                        # accumulated gradients) when the grad norm is NaN.
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.gradient_clipping)
                        if math.isnan(grad_norm):
                            self.verbose('Error : grad norm is NaN @ step ' +
                                         str(self.global_step))
                        else:
                            self.optimizer.step()
                        self.optimizer.zero_grad()

                        if self.global_step % self.log_step == 0:
                            # Log lr, un-scaled loss, and grad norm to
                            # TensorBoard; mirror the loss on the tqdm bar.
                            self.log.add_scalar('lr',
                                                self.optimizer.get_lr()[0],
                                                self.global_step)
                            self.log.add_scalar(
                                'loss', (loss.item() *
                                         self.gradient_accumulation_steps),
                                self.global_step)
                            self.log.add_scalar('gradient norm', grad_norm,
                                                self.global_step)
                            progress.set_description(
                                "Loss %.4f" %
                                (loss.item() *
                                 self.gradient_accumulation_steps))

                        if self.global_step % self.save_step == 0:
                            # Checkpoint, then log the masked / predicted /
                            # target spectrograms (first utterance, frames
                            # up-sampled back to the original rate) as images.
                            self.save_model('mockingjay')
                            mask_spec = self.up_sample_frames(
                                spec_masked[0], return_first=True)
                            pred_spec = self.up_sample_frames(
                                pred_spec[0], return_first=True)
                            true_spec = self.up_sample_frames(
                                spec_stacked[0], return_first=True)
                            mask_spec = plot_spectrogram_to_numpy(
                                mask_spec.data.cpu().numpy())
                            pred_spec = plot_spectrogram_to_numpy(
                                pred_spec.data.cpu().numpy())
                            true_spec = plot_spectrogram_to_numpy(
                                true_spec.data.cpu().numpy())
                            self.log.add_image('mask_spec', mask_spec,
                                               self.global_step)
                            self.log.add_image('pred_spec', pred_spec,
                                               self.global_step)
                            self.log.add_image('true_spec', true_spec,
                                               self.global_step)

                        pbar.update(1)
                        self.global_step += 1

                except RuntimeError as e:
                    # Recover from CUDA OOM: drop this batch, free cached GPU
                    # memory, clear any partial gradients, and keep training.
                    if 'CUDA out of memory' in str(e):
                        print('CUDA out of memory at step: ', self.global_step)
                        torch.cuda.empty_cache()
                        self.optimizer.zero_grad()
                    else:
                        raise

        pbar.close()
        self.log.close()
        self.reset_train()
    def train(self):
        ''' Self-Supervised Pre-Training of Transformer Model

        Same masked-reconstruction loop as the runner above, but logs the
        loss accumulated over each gradient-accumulation window (loss_val)
        and supports resuming: the progress bar is fast-forwarded to the
        restored self.global_step.
        '''

        pbar = tqdm(total=self.total_steps)
        # Resume support: fast-forward the bar to the restored step count.
        pbar.n = self.global_step - 1

        while self.global_step <= self.total_steps:

            progress = tqdm(self.dataloader, desc="Iteration")

            # `step` counts valid batches in this pass and drives gradient
            # accumulation; `loss_val` sums the (scaled) loss over the
            # current accumulation window and is reset after each update.
            step = 0
            loss_val = 0
            for batch in progress:
                # First element flags whether the collated batch is usable.
                batch_is_valid, *batch = batch
                try:
                    if self.global_step > self.total_steps: break
                    if not batch_is_valid: continue
                    step += 1

                    spec_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_data(
                        batch)
                    loss, pred_spec = self.model(spec_masked, pos_enc,
                                                 mask_label, attn_mask,
                                                 spec_stacked)

                    # Accumulate Loss: scale down so gradients summed over
                    # `gradient_accumulation_steps` batches match one update.
                    if self.gradient_accumulation_steps > 1:
                        loss = loss / self.gradient_accumulation_steps
                    if self.apex and self.args.multi_gpu:
                        raise NotImplementedError
                    elif self.apex:
                        # apex optimizer wraps backward for mixed precision.
                        self.optimizer.backward(loss)
                    elif self.args.multi_gpu:
                        # DataParallel returns one loss per GPU; reduce first.
                        loss = loss.sum()
                        loss.backward()
                    else:
                        loss.backward()
                    loss_val += loss.item()

                    # Update: only every `gradient_accumulation_steps` valid
                    # batches (pbar/global_step also advance only here).
                    if (step + 1) % self.gradient_accumulation_steps == 0:
                        if self.apex:
                            # modify learning rate with special warm up BERT uses
                            # if config.apex is False, BertAdam is used and handles this automatically
                            lr_this_step = self.learning_rate * self.warmup_linear.get_lr(
                                self.global_step, self.warmup_proportion)
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = lr_this_step

                        # Step: skip the optimizer update (but still clear the
                        # accumulated gradients) when the grad norm is NaN.
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.gradient_clipping)
                        if math.isnan(grad_norm):
                            print(
                                '[Runner] - Error : grad norm is NaN @ step ' +
                                str(self.global_step))
                        else:
                            self.optimizer.step()
                        self.optimizer.zero_grad()

                        if self.global_step % self.log_step == 0:
                            # Log lr, window-accumulated loss, and grad norm;
                            # mirror the loss on the tqdm bar.
                            self.log.add_scalar('lr',
                                                self.optimizer.get_lr()[0],
                                                self.global_step)
                            self.log.add_scalar('loss', (loss_val),
                                                self.global_step)
                            self.log.add_scalar('gradient norm', grad_norm,
                                                self.global_step)
                            progress.set_description("Loss %.4f" % (loss_val))

                        if self.global_step % self.save_step == 0:
                            self.save_model('states')

                            # tensorboard log: masked / predicted / target
                            # spectrograms (first utterance, frames up-sampled
                            # back to the original rate) as images.
                            spec_list = [spec_masked, pred_spec, spec_stacked]
                            name_list = ['mask_spec', 'pred_spec', 'true_spec']

                            for i in range(len(spec_list)):
                                spec = self.up_sample_frames(spec_list[i][0],
                                                             return_first=True)
                                spec = plot_spectrogram_to_numpy(
                                    spec.data.cpu().numpy())
                                self.log.add_image(name_list[i], spec,
                                                   self.global_step)

                        # Reset the accumulation-window loss for the next update.
                        loss_val = 0
                        pbar.update(1)
                        self.global_step += 1

                except RuntimeError as e:
                    # Recover from CUDA OOM: drop this batch, free cached GPU
                    # memory, clear any partial gradients, and keep training.
                    if 'CUDA out of memory' in str(e):
                        print('CUDA out of memory at step: ', self.global_step)
                        torch.cuda.empty_cache()
                        self.optimizer.zero_grad()
                    else:
                        raise

        pbar.close()
        self.log.close()