Example #1
    def _async_update(self, rank, device_id, grad_denom, is_global_overflow):

        # if is_global_overflow:
        # def patch_step(opt):
        #     """this function is copied from apex"""
        #     opt_step = opt.step
        #
        #     def skip_step(closure=None):
        #         if closure is not None:
        #             raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
        #         #logger.info(f"Device[{self.gpu_rank}] Gradient overflow. Skipping step. "
        #         #            "(This is from hack-for-optimizer-sync)")
        #         if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
        #             # Clear the master grads that wouldn't be zeroed by model.zero_grad()
        #             for param in opt._amp_stash.all_fp32_from_fp16_params:
        #                 param.grad = None
        #         if hasattr(opt, "most_recent_scale"):
        #             opt.most_recent_scale = 1.0
        #             opt.scale_set_by_backward = False
        #         opt.step = opt_step
        #         opt._amp_stash.already_patched = False
        #
        #     return skip_step
        #
        # # since some worker in the GPU pool hit a gradient overflow, we need to skip one step and keep going
        # if not self.optim.optimizer._amp_stash.already_patched:
        #     patch_step(self.optim.optimizer)
        #     self.dummy = 'dummy'
        # else:
        # is it possible to run out of memory in this case ? ...

        if self.num_replicas > 1:
            self._all_reduce_and_rescale_grads(grad_denom=grad_denom)

        # for param in amp.master_params(optimizer):
        #     param.grad.div_(iters_to_accumulate)

        normalize_gradients(amp.master_params(self.optim.optimizer),
                            grad_denom * self.args.update_frequency)

        if self.args.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim.optimizer),
                                           self.args.max_grad_norm)

        self.optim.step()

        return [0]
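
The update path above (all-reduce across replicas, normalize the master gradients, optionally clip, then step) can be summarized in plain PyTorch without apex. The sketch below is illustrative only; `model`, `optimizer`, and `grad_denom` are placeholder names, not the trainer's attributes.

import torch


def apply_update(model, optimizer, grad_denom, max_grad_norm=0.0):
    # rescale gradients that were accumulated over several minibatches
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(grad_denom)
    # optional global-norm clipping, mirroring the max_grad_norm check above
    if max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()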
Example #2
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model.train()
        self.loss_function.train()
        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()
        self.model.reset_states()

        dataset = train_data

        # data iterator: object that controls how we iterate over the dataset
        # data_iterator = DataIterator(dataset, dataset.collater, dataset.batches, seed=self.opt.seed,
        #                              num_workers=opt.num_workers, epoch=epoch, buffer_size=opt.buffer_size)
        data_iterator = generate_data_iterator(dataset,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_tokens, total_loss, total_words = 0, 0, 0
        total_non_pads = 0
        report_loss, report_tgt_words = 0, 0
        report_src_words = 0
        report_ctc_loss = 0
        report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0
        num_accumulated_words = 0
        num_accumulated_sents = 0
        grad_scaler = 1

        nan = False
        nan_counter = 0

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded

        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            batch = next(epoch_iterator)
            if isinstance(batch, list) and self.n_gpus == 1:
                batch = batch[0]
            batch = rewrap(batch)

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            # if opt.streaming:
            #     if train_data.is_new_stream():
            #         streaming_state = self.model.init_stream()
            # else:
            #     streaming_state = None

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                targets = batch.get('target_output')
                tgt_mask = targets.ne(onmt.constants.PAD)
                outputs = self.model(batch,
                                     streaming=opt.streaming,
                                     target_mask=tgt_mask,
                                     zero_encoder=opt.zero_encoder,
                                     mirror=opt.mirror_loss,
                                     streaming_state=streaming_state,
                                     nce=opt.nce)

                batch_size = batch.size

                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs,
                                               targets,
                                               model=self.model,
                                               vocab_mask=batch.vocab_mask)
                loss_data = loss_dict['data']
                loss = loss_dict[
                    'loss']  # a little trick to avoid gradient overflow with fp16
                full_loss = loss

                if opt.ctc_loss > 0.0:
                    ctc_loss = self.ctc_loss_function(outputs, targets)
                    ctc_loss_data = ctc_loss.item()
                    full_loss = full_loss + opt.ctc_loss * ctc_loss
                    report_ctc_loss += ctc_loss_data

                if opt.mirror_loss:
                    rev_loss = loss_dict['rev_loss']
                    rev_loss_data = loss_dict['rev_loss_data']
                    mirror_loss = loss_dict['mirror_loss']
                    full_loss = full_loss + rev_loss + mirror_loss
                    mirror_loss_data = loss_dict['mirror_loss'].item()
                else:
                    rev_loss_data = None
                    mirror_loss_data = 0

                # reconstruction loss
                if opt.reconstruct:
                    rec_loss = loss_dict['rec_loss']
                    rec_loss = rec_loss
                    full_loss = full_loss + rec_loss
                    rec_loss_data = loss_dict['rec_loss_data']
                else:
                    rec_loss_data = None
                #
                # if opt.lfv_multilingual:
                #     lid_logits = outputs['lid_logits']
                #     lid_labels = batch.get('target_lang')
                #     lid_loss_function = self.loss_function.get_loss_function('lid_loss')
                #     lid_loss = lid_loss_function(lid_logits, lid_labels)
                #     full_loss = full_loss + lid_loss

                optimizer = self.optim.optimizer

                # When the batch size is large, the gradients can easily overflow in fp16.
                # Dividing the loss by grad_scaler helps prevent this.
                full_loss.div_(grad_scaler)

                if self.cuda:
                    with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    full_loss.backward()

                del outputs

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(
                        '| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                else:
                    raise e

            if loss != loss:
                # catching NAN problem
                oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                num_accumulated_words = 0
                num_accumulated_sents = 0
                nan_counter = nan_counter + 1
                print("Warning!!! Loss is Nan")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurence. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
                counter = 0
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size
                tgt_size = batch.tgt_size

                counter = counter + 1
                num_accumulated_words += tgt_size
                num_accumulated_sents += batch_size

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif 0 < opt.batch_size_update <= num_accumulated_words:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated-gradient case: build the denominator used to rescale the summed gradients
                    grad_denom = 1 / grad_scaler
                    if self.opt.normalize_gradient:
                        grad_denom = num_accumulated_words * grad_denom
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(amp.master_params(optimizer),
                                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.opt.max_grad_norm)
                    self.optim.step()
                    self.optim.zero_grad()
                    self.model.zero_grad()
                    counter = 0
                    num_accumulated_words = 0
                    num_accumulated_sents = 0
                    num_updates = self.optim._step
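                    # note: -1 % opt.save_every equals opt.save_every - 1, so the check below
                    # fires once every opt.save_every updates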
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)

                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                        self.save(ep, valid_ppl, itr=data_iterator)

                num_words = tgt_size
                report_loss += loss_data
                report_tgt_words += num_words
                report_src_words += src_size
                total_loss += loss_data
                total_words += num_words
                total_tokens += batch.get('target_output').nelement()
                total_non_pads += batch.get('target_output').ne(
                    onmt.constants.PAD).sum().item()
                optim = self.optim
                batch_efficiency = total_non_pads / total_tokens

                if opt.reconstruct:
                    report_rec_loss += rec_loss_data

                if opt.mirror_loss:
                    report_rev_loss += rev_loss_data
                    report_mirror_loss += mirror_loss_data

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                                  (epoch, i + 1, len(data_iterator),
                                   math.exp(report_loss / report_tgt_words)))

                    if opt.reconstruct:
                        rec_ppl = math.exp(report_rec_loss /
                                           report_src_words.item())
                        log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                    if opt.mirror_loss:
                        rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                        log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                        # mirror loss per word
                        log_string += (" mir_loss: %6.2f ; " %
                                       (report_mirror_loss / report_tgt_words))

                    log_string += ("lr: %.7f ; updates: %7d; " %
                                   (optim.getLearningRate(), optim._step))

                    log_string += ("%5.0f src tok/s; %5.0f tgt tok/s; " %
                                   (report_src_words /
                                    (time.time() - start), report_tgt_words /
                                    (time.time() - start)))

                    if opt.ctc_loss > 0.0:
                        ctc_loss = report_ctc_loss / report_tgt_words
                        log_string += (" ctcloss: %8.2f ; " % ctc_loss)

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss = 0
                    report_tgt_words, report_src_words = 0, 0
                    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                    report_ctc_loss = 0
                    start = time.time()

                i = i + 1

        return total_loss / total_words
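
At its core, the loop above accumulates gradients from several minibatches (controlled by `update_frequency` / `batch_size_update`) before calling the optimizer once. A minimal, self-contained sketch of that accumulation pattern, with placeholder names (`model`, `criterion`, `loader`) rather than the trainer's objects:

import torch


def train_accumulated(model, criterion, optimizer, loader, update_frequency=4):
    model.train()
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(loader, start=1):
        loss = criterion(model(inputs), targets)
        # divide so the summed gradient matches a single large-batch gradient
        (loss / update_frequency).backward()
        if step % update_frequency == 0:
            optimizer.step()
            optimizer.zero_grad()
    # note: the trainer above also forces an update on the last minibatch of the epoch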
Example #3
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model_ae.train()
        self.loss_function_ae.train()
        self.lat_dis.train()
        self.loss_lat_dis.train()

        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model_ae.zero_grad()
        self.lat_dis.zero_grad()

        dataset = train_data
        data_iterator = generate_data_iterator(dataset,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_loss_ae, total_loss_lat_dis, total_frames, total_loss_adv = 0, 0, 0, 0

        report_loss_ae, report_loss_lat_dis, report_loss_adv, report_tgt_frames, report_sent = 0, 0, 0, 0, 0
        report_dis_frames, report_adv_frames = 0, 0

        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0
        step = 0

        num_accumulated_sents = 0
        grad_scaler = -1

        nan = False
        nan_counter = 0
        n_step_ae = opt.update_frequency
        n_step_lat_dis = opt.update_frequency
        mode_ae = True

        loss_lat_dis = 0.0

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded

        # while not data_iterator.end_of_epoch():
        while True:

            if data_iterator.end_of_epoch():
                data_iterator = generate_data_iterator(
                    dataset,
                    seed=self.opt.seed,
                    num_workers=opt.num_workers,
                    epoch=epoch,
                    buffer_size=opt.buffer_size)
                epoch_iterator = data_iterator.next_epoch_itr(
                    not streaming, pin_memory=opt.pin_memory)

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            batch = next(epoch_iterator)
            if isinstance(batch, list) and self.n_gpus == 1:
                batch = batch[0]
            batch = rewrap(batch)

            batch_size = batch.size
            if grad_scaler == -1:
                grad_scaler = 1  # if self.opt.update_frequency > 1 else batch.tgt_size

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                #    targets = batch.get('target_output')
                #  tgt_mask = targets.ne(onmt.constants.PAD)
                if mode_ae:
                    step = self.optim_ae._step
                    loss_ae, loss_adv, encoder_outputs = self.autoencoder_backward(
                        batch, step)
                else:
                    loss_lat_dis, encoder_outputs = self.lat_dis_backward(
                        batch)

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(
                        '| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss_ae = 0
                else:
                    raise e

            if loss_ae != loss_ae:
                # catching NAN problem
                oom = True
                self.model_ae.zero_grad()
                self.optim_ae.zero_grad()
                # self.lat_dis.zero_grad()
                # self.optim_lat_dis.zero_grad()

                nan_counter = nan_counter + 1
                print("Warning!!! Loss is Nan")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurence. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size

                if mode_ae:
                    report_adv_frames += encoder_outputs['src_mask'].sum(
                    ).item()
                    report_loss_adv += loss_adv
                else:
                    report_dis_frames += encoder_outputs['src_mask'].sum(
                    ).item()
                    report_loss_lat_dis += loss_lat_dis

                counter = counter + 1

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated-gradient case: build the denominator used to rescale the summed gradients
                    if (counter == 1
                            and self.opt.update_frequency != 1) or counter > 1:
                        grad_denom = 1 / grad_scaler
                    else:
                        grad_denom = 1.0
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(
                        amp.master_params(self.optim_ae.optimizer), grad_denom)
                    normalize_gradients(
                        amp.master_params(self.optim_lat_dis.optimizer),
                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim_ae.optimizer),
                            self.opt.max_grad_norm)
                        # torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim_lat_dis.optimizer),
                        #                                self.opt.max_grad_norm)

                    torch.nn.utils.clip_grad_value_(
                        amp.master_params(self.optim_lat_dis.optimizer), 0.01)

                    if mode_ae:
                        self.optim_ae.step()
                        self.optim_ae.zero_grad()
                        self.model_ae.zero_grad()
                        self.optim_lat_dis.zero_grad()
                        self.lat_dis.zero_grad()
                    else:
                        self.optim_lat_dis.step()
                        self.optim_lat_dis.zero_grad()
                        self.lat_dis.zero_grad()
                        self.optim_ae.zero_grad()
                        self.model_ae.zero_grad()

                    num_updates = self.optim_ae._step

                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every and mode_ae:
                        # without the mode_ae check this would save after every update, including discriminator steps
                        valid_loss_ae, valid_loss_lat_dis = self.eval(
                            self.valid_data)
                        print('Validation loss ae: %g' % valid_loss_ae)
                        print('Validation loss latent discriminator: %g' %
                              valid_loss_lat_dis)
                        self.save(0, valid_loss_ae, itr=data_iterator)

                    if num_updates == 1000000:
                        break

                    mode_ae = not mode_ae
                    counter = 0
                    # num_accumulated_words = 0

                    grad_scaler = -1
                    num_updates = self.optim_ae._step

                report_loss_ae += loss_ae

                # report_tgt_words += num_words
                num_accumulated_sents += batch_size
                report_sent += batch_size
                total_frames += src_size
                report_tgt_frames += src_size
                total_loss_ae += loss_ae
                total_loss_lat_dis += loss_lat_dis
                total_loss_adv += loss_adv

                optim_ae = self.optim_ae
                optim_lat_dis = self.optim_lat_dis
                # batch_efficiency = total_non_pads / total_tokens

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = (
                        "loss_ae : %6.2f ;  loss_lat_dis : %6.2f, loss_adv : %6.2f "
                        % (report_loss_ae / report_tgt_frames,
                           report_loss_lat_dis /
                           (report_dis_frames + 1e-5), report_loss_adv /
                           (report_adv_frames + 1e-5)))

                    log_string += (
                        "lr_ae: %.7f ; updates: %7d; " %
                        (optim_ae.getLearningRate(), optim_ae._step))

                    log_string += (
                        "lr_lat_dis: %.7f ; updates: %7d; " %
                        (optim_lat_dis.getLearningRate(), optim_lat_dis._step))
                    #
                    log_string += ("%5.0f src tok/s " %
                                   (report_tgt_frames / (time.time() - start)))

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss_ae = 0
                    report_loss_lat_dis = 0
                    report_loss_adv = 0
                    report_tgt_frames = 0
                    report_dis_frames = 0
                    report_adv_frames = 0
                    report_sent = 0
                    start = time.time()

                i = i + 1

        return total_loss_ae / total_frames * 100, total_loss_lat_dis / n_samples * 100, total_loss_adv / n_samples * 100
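
The loop above alternates between updating the autoencoder (together with its adversarial term) and updating the latent discriminator, flipping `mode_ae` after each optimizer step. A minimal sketch of that alternation, where the models, loss functions, and optimizers are placeholders rather than the trainer's attributes:

def alternating_step(batch, mode_ae, model_ae, lat_dis, optim_ae, optim_lat_dis,
                     ae_loss_fn, dis_loss_fn):
    if mode_ae:
        loss = ae_loss_fn(model_ae, lat_dis, batch)   # reconstruction + adversarial term
        loss.backward()
        optim_ae.step()
    else:
        loss = dis_loss_fn(lat_dis, model_ae, batch)  # discriminator on the latent codes
        loss.backward()
        optim_lat_dis.step()
    # both optimizers are zeroed after every update, as in the loop above
    optim_ae.zero_grad()
    optim_lat_dis.zero_grad()
    return not mode_ae  # flip the mode for the next update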
Example #4
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model.train()
        self.loss_function.train()
        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()

        dataset = train_data
        data_iterator = generate_data_iterator(dataset,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_loss, total_frames = 0, 0

        report_loss, report_tgt_frames, report_sent = 0, 0, 0

        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0

        num_accumulated_sents = 0
        grad_scaler = -1

        nan = False
        nan_counter = 0

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded

        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            batch = next(epoch_iterator)
            if isinstance(batch, list) and self.n_gpus == 1:
                batch = batch[0]
            batch = rewrap(batch)
            if grad_scaler == -1:
                grad_scaler = 1  # if self.opt.update_frequency > 1 else batch.tgt_size

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                #    targets = batch.get('target_output')
                #  tgt_mask = targets.ne(onmt.constants.PAD)
                outputs = self.model(batch)

                gate_padded = batch.get('gate_padded')

                if self.opt.n_frames_per_step > 1:
                    frame_idx = torch.arange(0, gate_padded.size(1),
                                             self.opt.n_frames_per_step)
                    gate_padded = gate_padded[:, frame_idx]

                src_org = batch.get('source_org')
                src_org = src_org.narrow(2, 1, src_org.size(2) - 1)

                target = [src_org.permute(1, 2, 0).contiguous(), gate_padded]
                loss = self.loss_function(outputs, target)

                batch_size = batch.size
                loss_data = loss.data.item()
                # a little trick to avoid gradient overflow with fp16
                full_loss = loss

                optimizer = self.optim.optimizer

                # When the batch size is large, the gradients can easily overflow in fp16.
                # Dividing the loss by grad_scaler helps prevent this.
                full_loss.div_(grad_scaler)

                if self.cuda:
                    with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    full_loss.backward()

                del outputs

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(
                        '| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                else:
                    raise e

            if loss != loss:
                # catching NAN problem
                oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                nan_counter = nan_counter + 1
                print("Warning!!! Loss is Nan")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurence. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size

                counter = counter + 1

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated-gradient case: build the denominator used to rescale the summed gradients
                    if (counter == 1
                            and self.opt.update_frequency != 1) or counter > 1:
                        grad_denom = 1 / grad_scaler
                        # if self.opt.normalize_gradient:
                        #     grad_denom = num_accumulated_words * grad_denom
                    else:
                        grad_denom = 1.0
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(amp.master_params(optimizer),
                                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.opt.max_grad_norm)
                    self.optim.step()
                    self.optim.zero_grad()
                    self.model.zero_grad()
                    counter = 0
                    # num_accumulated_words = 0

                    grad_scaler = -1
                    num_updates = self.optim._step
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)

                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                        self.save(ep, valid_ppl, itr=data_iterator)

                report_loss += loss_data
                # report_tgt_words += num_words
                num_accumulated_sents += batch_size
                report_sent += batch_size
                total_frames += src_size
                report_tgt_frames += src_size
                total_loss += loss_data

                optim = self.optim
                # batch_efficiency = total_non_pads / total_tokens

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = (
                        "Epoch %2d, %5d/%5d; ; loss : %6.2f ; " %
                        (epoch, i + 1, len(data_iterator), report_loss))

                    log_string += ("lr: %.7f ; updates: %7d; " %
                                   (optim.getLearningRate(), optim._step))
                    #
                    log_string += ("%5.0f src tok/s " %
                                   (report_tgt_frames / (time.time() - start)))

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss = 0
                    report_tgt_frames = 0
                    report_sent = 0
                    start = time.time()

                i = i + 1

        return total_loss / n_samples * 100
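
The backward pass above relies on apex amp loss scaling: the loss is first divided by a constant `grad_scaler` and then backpropagated through `amp.scale_loss` so that fp16 gradients do not overflow. A minimal sketch of just that step, assuming the model and optimizer were prepared earlier with `amp.initialize`; the names here are placeholders:

from apex import amp


def backward_fp16(loss, optimizer, grad_scaler=1.0, use_cuda=True):
    # normalize the loss by a constant before backward, as in full_loss.div_(grad_scaler)
    loss = loss / grad_scaler
    if use_cuda:
        # apex scales the loss up so small fp16 gradients are representable
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()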
Example #5
    def train_epoch(self, epoch, resume=False, batch_order=None, iteration=0):

        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()
        self.model.reset_states()

        if resume:
            train_data.batch_order = batch_order
            train_data.set_index(iteration)
            print("Resuming from iteration: %d" % iteration)
        else:
            batch_order = train_data.create_order()
            iteration = 0

        total_tokens, total_loss, total_words = 0, 0, 0
        total_non_pads = 0
        report_loss, report_tgt_words = 0, 0
        report_src_words = 0
        start = time.time()
        n_samples = len(train_data)

        counter = 0
        num_accumulated_words = 0
        num_accumulated_sents = 0
        denom = 3584
        nan = False

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        for i in range(iteration, n_samples):

            curriculum = (epoch < opt.curriculum)

            batches = [train_data.next(curriculum=curriculum)[0]]

            if (len(self.additional_data) > 0
                    and i % self.additional_data_ratio[0] == 0):
                for j in range(len(self.additional_data)):
                    for k in range(self.additional_data_ratio[j + 1]):
                        if self.additional_data_iteration[j] == len(
                                self.additional_data[j]):
                            self.additional_data_iteration[j] = 0
                            self.additional_data[j].shuffle()
                            self.additional_batch_order[
                                j] = self.additional_data[j].create_order()

                        batches.append(self.additional_data[j].next()[0])
                        self.additional_data_iteration[j] += 1

            for b in range(len(batches)):
                batch = batches[b]
                if self.cuda:
                    batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

                # if opt.streaming:
                #     if train_data.is_new_stream():
                #         streaming_state = self.model.init_stream()
                # else:
                #     streaming_state = None

                oom = False
                try:
                    # outputs is a dictionary containing keys/values necessary for loss function
                    # can be flexibly controlled within models for easier extensibility
                    targets = batch.get('target_output')
                    tgt_mask = targets.data.ne(onmt.constants.PAD)
                    outputs = self.model(batch,
                                         streaming=opt.streaming,
                                         target_mask=tgt_mask,
                                         zero_encoder=opt.zero_encoder,
                                         mirror=opt.mirror_loss,
                                         streaming_state=streaming_state)

                    batch_size = batch.size

                    outputs['tgt_mask'] = tgt_mask

                    loss_dict = self.loss_function(outputs,
                                                   targets,
                                                   model=self.model)
                    loss_data = loss_dict['data']
                    loss = loss_dict['loss'].div(
                        denom
                    )  # a little trick to avoid gradient overflow with fp16

                    optimizer = self.optim.optimizer

                    if self.cuda:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print(
                            '| WARNING: ran out of memory on GPU, skipping batch'
                        )
                        oom = True
                        torch.cuda.empty_cache()
                        loss = 0
                        if opt.streaming:  # reset stream in this case ...
                            streaming_state = self.model.init_stream()
                    else:
                        raise e

                if loss != loss:
                    # catching NAN problem
                    oom = True
                    self.model.zero_grad()
                    self.optim.zero_grad()
                    num_accumulated_words = 0
                    num_accumulated_sents = 0

                if not oom:
                    src_size = batch.src_size
                    tgt_size = batch.tgt_size

                    counter = counter + 1
                    num_accumulated_words += tgt_size
                    num_accumulated_sents += batch_size

                    #   We only update the parameters after getting gradients from n mini-batches
                    update_flag = False
                    if 0 < opt.batch_size_update <= num_accumulated_words:
                        update_flag = True
                    elif counter >= opt.update_frequency and 0 >= opt.batch_size_update:
                        update_flag = True
                    elif i == n_samples - 1:  # update for the last minibatch
                        update_flag = True

                    if update_flag:
                        grad_denom = 1 / denom
                        if self.opt.normalize_gradient:
                            grad_denom = num_accumulated_words / denom
                        normalize_gradients(amp.master_params(optimizer),
                                            grad_denom)
                        # Update the parameters.
                        if self.opt.max_grad_norm > 0:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer),
                                self.opt.max_grad_norm)
                        self.optim.step(grad_denom=grad_denom)
                        self.optim.zero_grad()
                        self.model.zero_grad()
                        counter = 0
                        num_accumulated_words = 0
                        num_accumulated_sents = 0
                        num_updates = self.optim._step
                        if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                            valid_loss = self.eval(self.valid_data)
                            valid_ppl = math.exp(min(valid_loss, 100))
                            print('Validation perplexity: %g' % valid_ppl)

                            ep = float(epoch) - 1. + (
                                (float(i) + 1.) / n_samples)

                            self.save(ep,
                                      valid_ppl,
                                      batch_order=batch_order,
                                      iteration=i)

                    num_words = tgt_size
                    report_loss += loss_data
                    report_tgt_words += num_words
                    report_src_words += src_size
                    total_loss += loss_data
                    total_words += num_words
                    total_tokens += batch.get('target_output').nelement()
                    total_non_pads += batch.get('target_output').ne(
                        onmt.constants.PAD).sum().item()
                    optim = self.optim
                    batch_efficiency = total_non_pads / total_tokens

                    if b == 0 and (i == 0 or (i % opt.log_interval
                                              == -1 % opt.log_interval)):
                        print((
                            "Epoch %2d, %5d/%5d; ; ppl: %6.2f ; lr: %.7f ; num updates: %7d "
                            + "%5.0f src tok/s; %5.0f tgt tok/s; %s elapsed") %
                              (epoch, i + 1, len(train_data),
                               math.exp(report_loss / report_tgt_words),
                               optim.getLearningRate(), optim._step,
                               report_src_words /
                               (time.time() - start), report_tgt_words /
                               (time.time() - start),
                               str(
                                   datetime.timedelta(
                                       seconds=int(time.time() -
                                                   self.start_time)))))

                        report_loss, report_tgt_words = 0, 0
                        report_src_words = 0
                        start = time.time()

        return total_loss / total_words
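
All of these training loops guard against two failure modes: CUDA out-of-memory errors (skip the batch and free the cache) and NaN losses (drop the step and reset the gradients). A minimal sketch of that defensive pattern, with `model`, `optimizer`, and `compute_loss` as placeholders:

import torch


def safe_step(model, optimizer, compute_loss, batch):
    try:
        loss = compute_loss(model, batch)
        loss.backward()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print('| WARNING: ran out of memory on GPU, skipping batch')
            torch.cuda.empty_cache()
            return False
        raise
    if loss != loss:  # NaN check: NaN is the only value not equal to itself
        model.zero_grad()
        optimizer.zero_grad()
        return False
    return True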
Example #6
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        # Clear the gradients of the model
        self.model.zero_grad()
        # self.model.module.reset_states()

        dataset = train_data
        data_iterator = generate_data_iterator(dataset,
                                               self.rank,
                                               self.world_size,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        # TODO: fix resume which is currently buggy
        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_tokens, total_loss, total_words = zero_tensor(), zero_tensor(
        ), zero_tensor()
        total_non_pads = zero_tensor()
        report_loss, report_tgt_words = zero_tensor(), zero_tensor()
        report_ctc_loss = zero_tensor()
        report_src_words = zero_tensor()
        report_rec_loss, report_rev_loss, report_mirror_loss = zero_tensor(
        ), zero_tensor(), zero_tensor()
        start = time.time()
        n_samples = len(data_iterator)

        counter = 0
        num_accumulated_words = zero_tensor()
        num_accumulated_sents = zero_tensor()
        grad_div = 1

        nan = False
        nan_counter = zero_tensor()

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded
        i = i * self.world_size

        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            # TODO: move everything to the multiGPU trainer
            samples = next(epoch_iterator)

            batch = prepare_sample(samples, device=self.device)

            if opt.streaming:
                if train_data.is_new_stream():
                    streaming_state = self.model.init_stream()
            else:
                streaming_state = None

            # TODO: dealing with oom during distributed training
            oom = False

            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                counter = counter + 1
                reduction_disabled = not (counter >= opt.update_frequency
                                          or i == (n_samples - 1))

                def maybe_no_sync():
                    if reduction_disabled and isinstance(
                            self.model, DDP_model):
                        # while we have not reached the updating step, the gradients do not
                        # need to be synchronized, so disable backward grad sync for speed
                        return self.model.no_sync()
                    else:
                        return contextlib.ExitStack()  # dummy context manager

                with maybe_no_sync():
                    with autocast():
                        targets = batch.get('target_output')
                        tgt_mask = targets.ne(onmt.constants.PAD)
                        outputs = self.model(batch,
                                             streaming=opt.streaming,
                                             target_mask=tgt_mask,
                                             zero_encoder=opt.zero_encoder,
                                             mirror=opt.mirror_loss,
                                             streaming_state=streaming_state,
                                             nce=opt.nce)

                        batch_size = batch.size
                        outputs['tgt_mask'] = tgt_mask

                        loss_dict = self.loss_function(outputs,
                                                       targets,
                                                       model=self.model)
                        loss_data = loss_dict['data']
                        loss = loss_dict[
                            'loss']  # a little trick to avoid gradient overflow with fp16
                        full_loss = loss

                        if opt.ctc_loss > 0.0:
                            ctc_loss = self.ctc_loss_function(outputs, targets)
                            ctc_loss_data = ctc_loss.item()
                            full_loss = full_loss + opt.ctc_loss * ctc_loss

                        if opt.mirror_loss:
                            rev_loss = loss_dict['rev_loss']
                            rev_loss_data = loss_dict['rev_loss_data']
                            mirror_loss = loss_dict['mirror_loss']
                            full_loss = full_loss + rev_loss + mirror_loss
                            mirror_loss_data = loss_dict['mirror_loss'].item()
                        else:
                            rev_loss_data = None
                            mirror_loss_data = 0

                        # reconstruction loss
                        if opt.reconstruct:
                            rec_loss = loss_dict['rec_loss']
                            rec_loss = rec_loss
                            full_loss = full_loss + rec_loss
                            rec_loss_data = loss_dict['rec_loss_data']
                        else:
                            rec_loss_data = None

                        if opt.lfv_multilingual:
                            lid_logits = outputs['lid_logits']
                            lid_labels = batch.get('target_lang')
                            lid_loss_function = self.loss_function.get_loss_function(
                                'lid_loss')
                            lid_loss = lid_loss_function(
                                lid_logits, lid_labels)
                            full_loss = full_loss + lid_loss

                        optimizer = self.optim.optimizer

                        # When the batch size is large, the gradients can easily overflow in fp16.
                        # Dividing the loss by grad_div helps prevent this.
                        full_loss.div_(grad_div)

                    # grad scaler has to be done outside of the autocast
                    self.grad_scaler.scale(full_loss).backward()

                del outputs

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('[WARNING]: ran out of memory on GPU %d' % self.rank,
                          flush=True)
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                    raise e
                else:
                    raise e

            batch_size = batch.size

            src_size = batch.src_size
            tgt_size = batch.tgt_size
            num_accumulated_words.add_(tgt_size)
            num_accumulated_sents.add_(batch_size)

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency:
                update_flag = True
            elif i == n_samples - 1:  # update for the last minibatch
                update_flag = True

            if update_flag:
                # accumulated-gradient case: build the denominator used to rescale the summed gradients
                self.all_reduce(num_accumulated_words,
                                op=dist.ReduceOp.SUM,
                                group=self.group)

                grad_denom = 1.0 / grad_div

                if self.opt.normalize_gradient:
                    grad_denom = num_accumulated_words.item() * grad_denom
                else:
                    grad_denom = 1

                # the gradient is scaled by world size, so in order to match the model without multiGPU
                # we rescale the model parameters w.r.t the world size
                grad_denom = grad_denom / self.world_size

                # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                normalize_gradients(self.model.parameters(), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    self.grad_scaler.unscale_(self.optim.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.opt.max_grad_norm)
                self.optim.step(scaler=self.grad_scaler)
                self.grad_scaler.update()
                # self.optim.zero_grad()
                self.model.zero_grad()
                counter = 0
                num_accumulated_words.zero_()
                num_accumulated_sents.zero_()

                num_updates = self.optim._step
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))

                    if self.is_main():
                        print('Validation perplexity: %g' % valid_ppl)
                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                        self.save(ep, valid_ppl, itr=data_iterator)

            num_words = tgt_size
            report_loss.add_(loss_data)
            report_tgt_words.add_(num_words)
            report_src_words.add_(src_size)
            total_loss.add_(loss_data)
            total_words.add_(num_words)
            # total_tokens += batch.get('target_output').nelement()
            # total_non_pads += batch.get('target_output').ne(onmt.constants.PAD).sum().item()
            # batch_efficiency = total_non_pads / total_tokens

            if opt.reconstruct:
                report_rec_loss.add_(rec_loss_data)

            if opt.mirror_loss:
                report_rev_loss.add_(rev_loss_data)
                report_mirror_loss.add_(mirror_loss_data)

            if opt.ctc_loss > 0.0:
                report_ctc_loss.add_(ctc_loss_data)

            # control the index a little bit to ensure the log is always printed
            if i == 0 or ((i + 1) % opt.log_interval < self.world_size):

                self.all_reduce(report_loss,
                                op=dist.ReduceOp.SUM,
                                group=self.group)
                self.all_reduce(report_tgt_words,
                                op=dist.ReduceOp.SUM,
                                group=self.group)
                self.all_reduce(report_src_words,
                                op=dist.ReduceOp.SUM,
                                group=self.group)

                if self.is_main():
                    log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                                  (epoch, i + 1, len(data_iterator),
                                   math.exp(report_loss.item() /
                                            report_tgt_words.item())))

                    if opt.reconstruct:
                        self.all_reduce(report_rec_loss,
                                        op=dist.ReduceOp.SUM,
                                        group=self.group)
                        rec_ppl = math.exp(report_rec_loss.item() /
                                           report_src_words.item())
                        log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                    if opt.mirror_loss:
                        self.all_reduce(report_rev_loss,
                                        op=dist.ReduceOp.SUM,
                                        group=self.group)
                        rev_ppl = math.exp(report_rev_loss.item() /
                                           report_tgt_words.item())
                        log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                        log_string += (" mir_loss: %6.2f ; " %
                                       (report_mirror_loss / report_tgt_words))

                    if opt.ctc_loss > 0.0:
                        # if torch.isinf(report_ctc_loss):
                        #     report_ctc_loss.zero_()
                        # self.all_reduce(report_ctc_loss, op=dist.ReduceOp.SUM, group=self.group)
                        ctc_loss = report_ctc_loss.item(
                        ) / report_tgt_words.item()
                        log_string += (" ctcloss: %8.2f ; " % ctc_loss)

                    log_string += (
                        "lr: %.7f ; updates: %7d; " %
                        (self.optim.getLearningRate(), self.optim._step))

                    log_string += (
                        "%5.0f src tok/s; %5.0f tgt tok/s; " %
                        (report_src_words.item() /
                         (time.time() - start), report_tgt_words.item() /
                         (time.time() - start)))

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    self.print(log_string, flush=True)

                report_loss.zero_()
                report_tgt_words.zero_()
                report_src_words.zero_()
                report_rec_loss.zero_()
                report_rev_loss.zero_()
                report_mirror_loss.zero_()
                report_ctc_loss.zero_()
                start = time.time()

            # increase i by world size
            i = i + self.world_size

        return total_loss / total_words
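
This variant uses native mixed precision (`torch.cuda.amp`) together with DistributedDataParallel: the forward pass runs under `autocast`, the backward goes through a `GradScaler`, and `no_sync()` skips the gradient all-reduce on accumulation steps. A minimal sketch of that combination, with `ddp_model`, `criterion`, and `loader` as placeholders rather than the trainer's objects:

import contextlib

import torch
from torch.cuda.amp import GradScaler, autocast


def train_ddp_amp(ddp_model, criterion, optimizer, loader, update_frequency=4):
    scaler = GradScaler()
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(loader, start=1):
        sync_now = (step % update_frequency == 0)
        # only all-reduce gradients on the step that actually updates the parameters
        ctx = contextlib.ExitStack() if sync_now else ddp_model.no_sync()
        with ctx:
            with autocast():
                loss = criterion(ddp_model(inputs), targets) / update_frequency
            scaler.scale(loss).backward()
        if sync_now:
            scaler.unscale_(optimizer)  # so clipping sees the true gradient magnitudes
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()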
Example #7
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model.train()
        self.loss_function.train()
        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()
        self.model.reset_states()

        dataset = train_data
        data_iterator = DataIterator(dataset,
                                     dataset.collater,
                                     dataset.batches,
                                     seed=self.opt.seed,
                                     num_workers=opt.num_workers,
                                     epoch=epoch,
                                     buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_tokens, total_loss, total_words = 0, 0, 0
        total_non_pads = 0
        report_loss, report_tgt_words = 0, 0
        report_src_words = 0
        report_sents = 0
        report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
        report_log_prior = 0
        report_log_variational_posterior = 0
        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0
        update_counter = 0
        num_accumulated_words = 0
        num_accumulated_sents = 0

        nan = False
        nan_counter = 0

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        i = data_iterator.iterations_in_epoch
        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)
            batch = next(epoch_iterator)
            batch = rewrap(batch)
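            # the loss is divided by grad_scaler before backward (see below) to
            # keep fp16 gradients in range; with gradient accumulation a fixed
            # batch_size_words is used so every micro-batch shares one scale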
            grad_scaler = self.opt.batch_size_words if self.opt.update_frequency > 1 else batch.tgt_size

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                targets = batch.get('target_output')
                tgt_mask = targets.data.ne(onmt.constants.PAD)
                outputs = self.model(batch,
                                     streaming=opt.streaming,
                                     target_mask=tgt_mask,
                                     zero_encoder=opt.zero_encoder,
                                     mirror=opt.mirror_loss,
                                     streaming_state=streaming_state)

                batch_size = batch.size

                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs,
                                               targets,
                                               model=self.model)
                loss_data = loss_dict['data']
                loss = loss_dict['loss']  # a little trick to avoid gradient overflow with fp16
                log_prior = self.model.log_prior()
                log_variational_posterior = self.model.log_variational_posterior()

                # the coeff starts off at 1 for each epoch
                # from BBB paper: The first mini batches in each epoch have large KL coeff
                # # the later minibatches are influenced by the data
                # denom = math.pow(1.5, min(32, update_counter))

                # min_coeff = 1 / (self.opt.model_size ** 2)
                # kl_coeff = max(1 / denom, min_coeff)
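                # active scheme: weight the KL term by 1 / (tokens per update)
                # so the penalty is spread roughly uniformly over the target
                # tokens that contribute to one parameter step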
                kl_coeff = 1 / (batch.tgt_size * opt.update_frequency)
                # kl_coeff = 1 / (self.opt.model_size ** 2)
                # kl_coeff = 1
                full_loss = loss + kl_coeff * (log_variational_posterior -
                                               log_prior)
                # print(log_variational_posterior, log_prior)

                if opt.mirror_loss:
                    rev_loss = loss_dict['rev_loss']
                    rev_loss_data = loss_dict['rev_loss_data']
                    mirror_loss = loss_dict['mirror_loss']
                    full_loss = full_loss + rev_loss + mirror_loss
                    mirror_loss_data = loss_dict['mirror_loss'].item()
                else:
                    rev_loss = None
                    rev_loss_data = None
                    mirror_loss_data = 0

                # reconstruction loss
                if opt.reconstruct:
                    rec_loss = loss_dict['rec_loss']
                    full_loss = full_loss + rec_loss
                    rec_loss_data = loss_dict['rec_loss_data']
                else:
                    rec_loss_data = None

                optimizer = self.optim.optimizer

                # When the batch size is large, each gradient step is very easy to explode on fp16
                # Normalizing the loss to grad scaler ensures this will not happen
                full_loss.div_(grad_scaler)

                if self.cuda:
                    with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    full_loss.backward()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                else:
                    raise e

            if loss != loss:
                # catching NAN problem
                oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                num_accumulated_words = 0
                num_accumulated_sents = 0
                nan_counter = nan_counter + 1
                print("Warning!!! Loss is Nan")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurence. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size
                tgt_size = batch.tgt_size

                counter = counter + 1
                num_accumulated_words += tgt_size
                num_accumulated_sents += batch_size

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
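                # chained comparison: true when opt.update_frequency is
                # positive and counter has reached it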
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif 0 < opt.batch_size_update <= num_accumulated_words:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated gradient case, in this case the update frequency
                    if (counter == 1
                            and self.opt.update_frequency != 1) or counter > 1:
                        grad_denom = 1 / grad_scaler
                        if self.opt.normalize_gradient:
                            grad_denom = num_accumulated_words * grad_denom
                    else:
                        grad_denom = 1
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(amp.master_params(optimizer),
                                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.opt.max_grad_norm)
                    self.optim.step()
                    self.optim.zero_grad()
                    self.model.zero_grad()
                    counter = 0
                    num_accumulated_words = 0
                    num_accumulated_sents = 0
                    num_updates = self.optim._step
                    update_counter += 1
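                    # -1 % opt.save_every == opt.save_every - 1, so this saves
                    # a checkpoint once every opt.save_every updates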
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)

                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                        self.save(ep, valid_ppl, itr=data_iterator)

                num_words = tgt_size
                report_loss += loss_data
                report_log_prior += log_prior.item()
                report_log_variational_posterior += log_variational_posterior.item()
                report_tgt_words += num_words
                report_src_words += src_size
                report_sents += 1
                total_loss += loss_data
                total_words += num_words
                total_tokens += batch.get('target_output').nelement()
                total_non_pads += batch.get('target_output').ne(
                    onmt.constants.PAD).sum().item()
                optim = self.optim
                batch_efficiency = total_non_pads / total_tokens

                if opt.reconstruct:
                    report_rec_loss += rec_loss_data

                if opt.mirror_loss:
                    report_rev_loss += rev_loss_data
                    report_mirror_loss += mirror_loss_data

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                                  (epoch, i + 1, len(data_iterator),
                                   math.exp(report_loss / report_tgt_words)))

                    kl_div = report_log_variational_posterior - report_log_prior
                    log_string += ("KL q||p: %6.2f ; " %
                                   (kl_div / report_sents))

                    if opt.reconstruct:
                        rec_ppl = math.exp(report_rec_loss /
                                           report_src_words.item())
                        log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                    if opt.mirror_loss:
                        rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                        log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                        # mirror loss per word
                        log_string += (" mir_loss: %6.2f ; " %
                                       (report_mirror_loss / report_tgt_words))

                    log_string += ("lr: %.7f ; updates: %7d; " %
                                   (optim.getLearningRate(), optim._step))

                    log_string += ("%5.0f src/s; %5.0f tgt/s; " %
                                   (report_src_words /
                                    (time.time() - start), report_tgt_words /
                                    (time.time() - start)))

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss = 0
                    report_tgt_words, report_src_words = 0, 0
                    report_sents = 0
                    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                    report_log_prior, report_log_variational_posterior = 0, 0
                    start = time.time()

                i = i + 1

        return total_loss / total_words
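
Beispiel #7 trains a Bayesian model in the Bayes-by-Backprop style: on top of the ordinary NLL it adds the KL term log q(w|theta) - log p(w), scaled by kl_coeff, before backpropagation. The following standalone sketch shows only that loss assembly; bayesian_training_loss and its arguments are hypothetical names chosen to mirror the quantities used above:

def bayesian_training_loss(nll, log_prior, log_variational_posterior,
                           tgt_tokens, update_frequency=1):
    # nll                       -- summed negative log-likelihood of the batch
    # log_prior                 -- model.log_prior()
    # log_variational_posterior -- model.log_variational_posterior()
    # tgt_tokens                -- number of target tokens in the batch
    # update_frequency          -- gradient-accumulation factor
    #
    # The KL penalty is spread over the tokens that contribute to one parameter
    # update, mirroring kl_coeff = 1 / (batch.tgt_size * opt.update_frequency).
    kl_coeff = 1.0 / (tgt_tokens * update_frequency)
    return nll + kl_coeff * (log_variational_posterior - log_prior)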
Beispiel #8
0
    def train_epoch(self, epoch, resume=False, batch_order=None, iteration=0):

        opt = self.opt
        train_data = self.train_data

        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()
        self.model.reset_states()

        if resume:
            train_data.batch_order = batch_order
            train_data.set_index(iteration)
            print("Resuming from iteration: %d" % iteration)
        else:
            batch_order = train_data.create_order()
            iteration = 0

        total_loss, total_words = 0, 0
        report_loss, report_tgt_words = 0, 0
        report_src_words = 0
        start = time.time()
        n_samples = len(train_data)

        counter = 0
        num_accumulated_words = 0
        num_accumulated_sents = 0

        for i in range(iteration, n_samples):

            curriculum = (epoch < opt.curriculum)

            batches = [train_data.next(curriculum=curriculum)[0]]

            if (len(self.additional_data) > 0
                    and i % self.additional_data_ratio[0] == 0):
                for j in range(len(self.additional_data)):
                    for k in range(self.additional_data_ratio[j + 1]):
                        if self.additional_data_iteration[j] == len(
                                self.additional_data[j]):
                            self.additional_data_iteration[j] = 0
                            self.additional_data[j].shuffle()
                            self.additional_batch_order[j] = \
                                self.additional_data[j].create_order()

                        batches.append(self.additional_data[j].next()[0])
                        self.additional_data_iteration[j] += 1

            for b in range(len(batches)):
                batch = batches[b]
                if self.cuda:
                    batch.cuda(fp16=self.opt.fp16)

                oom = False
                try:
                    # outputs is a dictionary containing keys/values necessary for loss function
                    # can be flexibly controlled within models for easier extensibility
                    targets = batch.get('target_output')
                    tgt_mask = targets.data.ne(onmt.Constants.PAD)
                    outputs = self.model(batch, target_masking=tgt_mask)

                    batch_size = batch.size

                    outputs['tgt_mask'] = tgt_mask

                    loss_dict = self.loss_function(outputs,
                                                   targets,
                                                   model=self.model)
                    loss_data = loss_dict['data']
                    loss = loss_dict['loss']

                    optimizer = self.optim.optimizer
                    with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()

                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print('| WARNING: ran out of memory on GPU, skipping batch')
                        oom = True
                        torch.cuda.empty_cache()
                    else:
                        raise e

                if not oom:
                    src_size = batch.src_size
                    tgt_size = batch.tgt_size

                    counter = counter + 1
                    num_accumulated_words += tgt_size
                    num_accumulated_sents += batch_size

                    #   We only update the parameters after getting gradients from n mini-batches
                    # simulating the multi-gpu situation
                    # if counter == opt.virtual_gpu:
                    # if counter >= opt.batch_size_update:

                    if num_accumulated_words >= opt.batch_size_update * 0.95:
                        grad_denom = 1
                        if self.opt.normalize_gradient:
                            grad_denom = num_accumulated_words
                            normalize_gradients(
                                apex.amp.master_params(optimizer), grad_denom)
                        # Update the parameters.
                        self.optim.step(grad_denom=grad_denom)
                        self.model.zero_grad()
                        counter = 0
                        num_accumulated_words = 0
                        num_accumulated_sents = 0
                        num_updates = self.optim._step
                        if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                            valid_loss = self.eval(self.valid_data)
                            valid_ppl = math.exp(min(valid_loss, 100))
                            print('Validation perplexity: %g' % valid_ppl)

                            ep = float(epoch) - 1. + (
                                (float(i) + 1.) / n_samples)

                            self.save(ep,
                                      valid_ppl,
                                      batch_order=batch_order,
                                      iteration=i)

                    num_words = tgt_size
                    report_loss += loss_data
                    report_tgt_words += num_words
                    report_src_words += src_size
                    total_loss += loss_data
                    total_words += num_words
                    optim = self.optim

                    if b == 0 and (i == 0 or (i % opt.log_interval
                                              == -1 % opt.log_interval)):
                        print((
                            "Epoch %2d, %5d/%5d; ; ppl: %6.2f ; lr: %.7f ; num updates: %7d "
                            + "%5.0f src tok/s; %5.0f tgt tok/s; %s elapsed") %
                              (epoch, i + 1, len(train_data),
                               math.exp(report_loss / report_tgt_words),
                               optim.getLearningRate(), optim._step,
                               report_src_words /
                               (time.time() - start), report_tgt_words /
                               (time.time() - start),
                               str(
                                   datetime.timedelta(
                                       seconds=int(time.time() -
                                                   self.start_time)))))

                        report_loss, report_tgt_words = 0, 0
                        report_src_words = 0
                        start = time.time()

        return total_loss / total_words
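
Beispiel #8 accumulates gradients until roughly opt.batch_size_update target words have been seen, optionally divides the summed gradients by that word count, and only then steps the optimizer. Below is a plain-torch sketch of the same update schedule, without apex/amp; compute_loss, batches, and the optional clipping (borrowed from Beispiel #7) are placeholder assumptions, not part of the original trainer:

import torch


def accumulate_and_step(model, optimizer, batches, compute_loss,
                        batch_size_update, normalize_gradient=True,
                        max_grad_norm=0.0):
    # compute_loss(model, batch) is assumed to return (loss, num_target_words).
    accumulated_words = 0
    for batch in batches:
        loss, num_words = compute_loss(model, batch)
        loss.backward()
        accumulated_words += num_words

        # mirror the "95% of batch_size_update words" trigger used above
        if accumulated_words >= batch_size_update * 0.95:
            if normalize_gradient:
                # divide the summed gradients by the number of accumulated words
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.div_(accumulated_words)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            accumulated_words = 0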