Example #1
    def update_core(self):
        """Main update routine of the CustomUpdater."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        # Get the next batch (a list of json files)
        batch = train_iter.next()
        self.iteration += 1
        x = self.converter(batch, self.device)

        # Compute the loss at this time step and accumulate it
        loss = self.model(*x).mean() / self.accum_grad
        loss.backward()  # Backprop
        # gradient noise injection
        if self.grad_noise:
            from espnet.asr.asr_utils import add_gradient_noise
            add_gradient_noise(self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55)
        loss.detach()  # Truncate the graph

        # update parameters
        self.forward_count += 1
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0
        # compute the gradient norm to check if it is normal or not
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_clip_threshold)
        logging.info('grad norm={}'.format(grad_norm))
        if math.isnan(grad_norm):
            logging.warning('grad norm is nan. Do not update model.')
        else:
            optimizer.step()
        optimizer.zero_grad()
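
All five examples share the same accumulate/clip/step skeleton. Below is a minimal, self-contained sketch of that pattern outside the updater class; the model, data, and threshold are placeholders, not ESPnet code.

    # Minimal sketch of the accumulate/clip/step pattern (placeholder model/data).
    import logging
    import math

    import torch

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accum_grad = 4                # forward passes per optimizer step
    grad_clip_threshold = 5.0
    forward_count = 0

    for step in range(16):
        x = torch.randn(2, 8)
        loss = model(x).pow(2).mean() / accum_grad  # pre-scale so grads average
        loss.backward()                             # gradients accumulate in .grad

        forward_count += 1
        if forward_count != accum_grad:
            continue                                # keep accumulating
        forward_count = 0

        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), grad_clip_threshold)
        if math.isnan(grad_norm):
            logging.warning('grad norm is nan. Do not update model.')
        else:
            optimizer.step()
        optimizer.zero_grad()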
Example #2
    def update_core(self):
        """Main update routine of the CustomUpdater."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        epoch = train_iter.epoch

        # Get the next batch (a list of json files)
        batch = train_iter.next()
        # self.iteration += 1  # Incrementing here may trigger an early
        # report; the increment is done elsewhere automatically.
        x = _recursive_to(batch, self.device)
        is_new_epoch = train_iter.epoch != epoch
        # When the last minibatch in the current epoch is given,
        # gradient accumulation is turned off in order to evaluate the model
        # on the validation set in every epoch.
        # see details in https://github.com/espnet/espnet/pull/1388

        # Compute the loss at this time step and accumulate it
        if self.ngpu == 0:
            loss = self.model(*x).mean() / self.accum_grad
        else:
            # apex does not support torch.nn.DataParallel
            loss = (data_parallel(self.model, x, range(self.ngpu)).mean() /
                    self.accum_grad)
        if self.use_apex:
            from apex import amp

            # NOTE: for compatibility with the noam optimizer
            opt = optimizer.optimizer if hasattr(optimizer,
                                                 "optimizer") else optimizer
            with amp.scale_loss(loss, opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # gradient noise injection
        if self.grad_noise:
            from espnet.asr.asr_utils import add_gradient_noise

            add_gradient_noise(self.model,
                               self.iteration,
                               duration=100,
                               eta=1.0,
                               scale_factor=0.55)

        # update parameters
        self.forward_count += 1
        if not is_new_epoch and self.forward_count != self.accum_grad:
            return
        self.forward_count = 0
        # compute the gradient norm to check if it is normal or not
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.grad_clip_threshold)
        logging.info("grad norm={}".format(grad_norm))
        if math.isnan(grad_norm):
            logging.warning("grad norm is nan. Do not update model.")
        else:
            optimizer.step()
        optimizer.zero_grad()
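
Example #2 routes backward through apex's amp.scale_loss. For reference, the same loss-scaling idea with PyTorch's built-in torch.cuda.amp looks roughly like the sketch below; this is a different API than the apex one used above, not what ESPnet calls here, and it requires a CUDA device.

    import torch

    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    for step in range(4):
        x = torch.randn(2, 8, device='cuda')
        with torch.cuda.amp.autocast():
            loss = model(x).pow(2).mean()
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.unscale_(optimizer)     # recover true grads before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer)         # skips the step on inf/nan grads
        scaler.update()
        optimizer.zero_grad()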
Example #3
    def update_core(self):
        """Main update routine of the CustomUpdater."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        # Get the next batch (a list of json files)
        batch = train_iter.next()
        # self.iteration += 1  # Incrementing here may trigger an early report; the increment is done elsewhere automatically.
        x = self.converter(batch, self.device)

        # Compute the loss at this time step and accumulate it
        if self.ngpu == 0:
            loss = self.model(*x).mean() / self.accum_grad
        else:
            # apex does not support torch.nn.DataParallel
            if 'espnet.nets.pytorch_backend.e2e_asr_transformer' in self.model.__class__.__module__:
                loss = data_parallel(self.model, x + (self.iteration, ),
                                     range(self.ngpu)).mean() / self.accum_grad
            else:
                loss = data_parallel(self.model, x, range(
                    self.ngpu)).mean() / self.accum_grad
        if self.use_apex:
            from apex import amp
            # NOTE: for compatibility with the noam optimizer
            opt = optimizer.optimizer if hasattr(optimizer,
                                                 "optimizer") else optimizer
            with amp.scale_loss(loss, opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # gradient noise injection
        if self.grad_noise:
            from espnet.asr.asr_utils import add_gradient_noise
            add_gradient_noise(self.model,
                               self.iteration,
                               duration=100,
                               eta=1.0,
                               scale_factor=0.55)
        loss.detach()  # Truncate the graph

        # update parameters
        self.forward_count += 1
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0
        # compute the gradient norm to check if it is normal or not
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.grad_clip_threshold)
        logging.info('grad norm={}'.format(grad_norm))
        if math.isnan(grad_norm):
            logging.warning('grad norm is nan. Do not update model.')
        else:
            optimizer.step()
        optimizer.zero_grad()
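
Every example guards a call to espnet.asr.asr_utils.add_gradient_noise behind self.grad_noise. A sketch of what such a helper typically does, following the annealed-noise schedule of Neelakantan et al. (2015); the exact formula below is an assumption for illustration, not a copy of ESPnet's implementation.

    import torch

    def add_gradient_noise_sketch(model, iteration, duration=100,
                                  eta=1.0, scale_factor=0.55):
        """Add annealed Gaussian noise to every gradient (hypothetical helper)."""
        interval = (iteration // duration) + 1
        sigma = eta / interval ** scale_factor   # noise decays as training proceeds
        for param in model.parameters():
            if param.grad is not None:
                param.grad += sigma * torch.randn_like(param.grad)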
Example #4
    def update_core(self):
        """Main update routine of the CustomUpdater."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        # Get the next batch (a list of json files)
        train_unl_iter = self.get_iterator('sub')
        labeled_batch = train_iter.next()
        unlabeled_batch = train_unl_iter.next()

        # Collect process information for computing the current consistency weight
        epoch = train_iter.epoch
        process_info = {
            'epoch': epoch,
            'current_position': train_iter.current_position,
            'batch_len': train_iter.len
        }

        # self.iteration += 1  # Incrementing here may trigger an early report; the increment is done elsewhere automatically.
        labeled_x = _recursive_to(labeled_batch, self.device)
        unlabeled_x = _recursive_to(unlabeled_batch, self.device)
        is_new_epoch = train_iter.epoch != epoch
        # When the last minibatch in the current epoch is given,
        # gradient accumulation is turned off in order to evaluate the model
        # on the validation set in every epoch.
        # see details in https://github.com/espnet/espnet/pull/1388
        # Compute the loss at this time step and accumulate it
        if self.ngpu == 0:
            loss = self.model(*labeled_x, *unlabeled_x, process_info)
        else:
            # apex does not support torch.nn.DataParallel
            loss = data_parallel(self.model,
                                 (*labeled_x, *unlabeled_x, process_info),
                                 range(self.ngpu))
        loss = loss.mean() / self.accum_grad
        loss.backward()

        # learning rate cosine rampdown for SGD optimizer
        # TODO: make it only for sgd
        # if epoch > self.cosine_rampdown_starts:
        #     for p in optimizer.param_groups:
        #         p["lr"] *= cosine_rampdown(epoch - self.cosine_rampdown_starts,
        #                     self.cosine_rampdown_ends - self.cosine_rampdown_starts)
        #     logging.info("learning rate decayed to " + str(p["lr"]))

        # gradient noise injection
        if self.grad_noise:
            from espnet.asr.asr_utils import add_gradient_noise
            add_gradient_noise(self.model,
                               self.iteration,
                               duration=100,
                               eta=1.0,
                               scale_factor=0.55)
        loss.detach()  # Truncate the graph
        # update parameters
        self.forward_count += 1
        if not is_new_epoch and self.forward_count != self.accum_grad:
            return
        self.forward_count = 0
        # compute the gradient norm to check if it is normal or not
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.enc.parameters(),
                                                   self.grad_clip_threshold)
        logging.info('grad norm={}'.format(grad_norm))
        if math.isnan(grad_norm):
            logging.warning('grad norm is nan. Do not update model.')
        else:
            optimizer.step()
            global_step = (epoch - self.consistency_rampup_starts
                           ) * train_iter.len + train_iter.current_position
            global_step = global_step if global_step > 0 else 0
            if epoch < self.consistency_rampup_starts:
                update_ema_variables(self.model.enc, self.model.ema_enc, 0,
                                     global_step)
            elif epoch < self.consistency_rampup_ends:
                update_ema_variables(self.model.enc, self.model.ema_enc,
                                     self.ema_pre_decay, global_step)
            else:
                update_ema_variables(self.model.enc, self.model.ema_enc,
                                     self.ema_post_decay, global_step)
        optimizer.zero_grad()
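
Example #4 maintains a mean-teacher copy of the encoder via update_ema_variables, which is not shown in the snippet. A common implementation of that update, in the style of Tarvainen & Valpola's mean teacher, would look like the sketch below; this is an assumption, since the original helper isn't included.

    import torch

    def update_ema_variables(model, ema_model, alpha, global_step):
        """Exponential moving average of model parameters into ema_model."""
        # Ramp the decay up from 0 so early, noisy weights are forgotten quickly.
        alpha = min(1 - 1 / (global_step + 1), alpha)
        with torch.no_grad():
            for ema_p, p in zip(ema_model.parameters(), model.parameters()):
                ema_p.mul_(alpha).add_(p, alpha=1 - alpha)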
Example #5
    def update_core(self):
        """Main update routine of the CustomUpdater."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        asr_optimizer = self.get_optimizer('main')
        tts_optimizer = self.get_optimizer('tts_opt')

        # Get the next batch (a list of json files)
        batch = train_iter.next()
        # Compute the loss at this time step and accumulate it
        if self.ngpu == 0:
            x = self.converter(batch, self.device)
            asr_loss = self.model(*x).mean() / self.accum_grad
        else:
            # apex does not support torch.nn.DataParallel
            #if (batch[0][1][0][0:5] == np.array([1,1,1,1,1])).all():
            if len(batch[0]) == 3:
                xs_pad, ilens, ys_pad, spembs = self.converter(batch, self.device)
                x = (xs_pad, ilens, ys_pad)
                if 'espnet.nets.pytorch_backend.e2e_asr_transformer' in self.model.__class__.__module__:
                    fake_loss, best_hyps = data_parallel(self.model, x+(self.iteration, True,), range(self.ngpu))
                else:
                    fake_loss, best_hyps = data_parallel(self.model, x+(True,), range(self.ngpu))
                    if self.text_only:
                        ttsasr_loss = data_parallel(self.model, x+(False,True,), range(self.ngpu)).mean() / self.accum_grad
                # calculate no of nbest and repeat based on it
                #set_requires_grad(self.tts_model, False)
                if self.tts:
                    x_tts = self.random_sampler(best_hyps, ilens, xs_pad, spembs)
                    #tts_loss, after_outs, before_outs, logits, att_ws = self.tts_model(*x_tts+(None,True,))
                    tts_loss, after_outs, before_outs, logits, att_ws = self.tts_model(*x_tts+(True,))
                    #tts_loss = self.loss_fn_tts(after_outs, before_outs, logits, x_tts[4], x_tts[2])
                    #comparison with orig hyp
                    #x_tts_orig = self.random_sampler(x[2], x[1], x[0], spembs)
                    #x_tts_orig[0][x_tts_orig[0] == -1] = 0
                    #tts_loss_j, after_outs_j, before_outs_j, logits_j, att_ws_j = self.tts_model(x_tts_orig[0], x_tts_orig[1], x_tts_orig[2], x_tts_orig[3], x_tts_orig[4], x_tts_orig[5], True)
                    #tts_loss_j = self.loss_fn_tts(after_outs_j, before_outs_j, logits_j, x_tts_orig[3], x_tts_orig[2])
                    #logging.info("true loss is: " + str(tts_loss_j.mean()))
                    logging.info("fake loss is: " + str(fake_loss.mean()))
                    policy_loss = self.policy_rewards(fake_loss, tts_loss)
                    logging.info('tts_loss: ' + str(float(tts_loss.mean())))
                    logging.info('policy_loss: ' + str(float(policy_loss.mean())))
                    asr_loss = policy_loss.mean() / self.accum_grad
                    # asr_loss = tts_loss.mean() / self.accum_grad
                    if self.text_only:
                        asr_loss = asr_loss + ttsasr_loss
                else:
                    asr_loss = fake_loss.mean() / self.accum_grad
                    logging.info('asr_loss: ' + str(float(asr_loss)))
            else:
                xs_pad, ilens, ys_pad = self.converter(batch, self.device)
                x = (xs_pad, ilens, ys_pad)
                asr_loss = data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad
                logging.info('asr_sup_loss: ' + str(float(asr_loss)))
        if self.use_apex:
            from apex import amp
            # NOTE: for compatibility with the noam optimizer
            opt = (asr_optimizer.optimizer
                   if hasattr(asr_optimizer, "optimizer") else asr_optimizer)
            with amp.scale_loss(asr_loss, opt) as scaled_loss:
                scaled_loss.backward()
        else:
            asr_loss.backward()
        # gradient noise injection
        if self.grad_noise:
            from espnet.asr.asr_utils import add_gradient_noise
            add_gradient_noise(self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55)
        asr_loss.detach()  # Truncate the graph

        # update parameters
        self.forward_count += 1
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0
        # compute the gradient norm to check if it is normal or not
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_clip_threshold)
        logging.info('ASR grad norm={}'.format(grad_norm))

        #if (batch[0][1][0][0:5] == np.array([1,1,1,1,1])).all():
        if len(batch[0]) == 3:
            tts_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.tts_model.parameters(), self.grad_clip_threshold)
            logging.info('TTS grad norm={}'.format(tts_grad_norm))
            if math.isnan(tts_grad_norm):
                logging.warning('TTS grad norm is nan. Do not update model.')
            else:
                if self.update_tts:
                    tts_optimizer.step()
        if math.isnan(grad_norm):
            logging.warning('grad norm is nan. Do not update model.')
        else:
            asr_optimizer.step()
        asr_optimizer.zero_grad()
        #if (batch[0][1][0][0:5] == np.array([1,1,1,1,1])).all(): # cheap trick by BMK
        if len(batch[0]) == 3: # cheap trick by BMK
            tts_optimizer.zero_grad()
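
Example #5 drives two optimizers off one backward pass. A minimal sketch of that pattern with placeholder modules (not the ASR/TTS networks above):

    import torch

    asr = torch.nn.Linear(8, 4)
    tts = torch.nn.Linear(4, 8)
    asr_optimizer = torch.optim.Adam(asr.parameters(), lr=1e-3)
    tts_optimizer = torch.optim.Adam(tts.parameters(), lr=1e-3)

    x = torch.randn(2, 8)
    loss = tts(asr(x)).pow(2).mean()  # one graph, gradients reach both modules
    loss.backward()

    torch.nn.utils.clip_grad_norm_(asr.parameters(), 5.0)
    torch.nn.utils.clip_grad_norm_(tts.parameters(), 5.0)
    asr_optimizer.step()
    tts_optimizer.step()
    asr_optimizer.zero_grad()
    tts_optimizer.zero_grad()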