Code example #1
def prepare_sample(batch, device=None, fp16=False):

    # batch is a Batch object: rewrap it and move its tensors to the GPU
    batch = rewrap(batch)
    batch.cuda(fp16=fp16, device=device)
    return batch
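
Note: batch.cuda(fp16=..., device=...) above is a method of the project's Batch wrapper, not a stock PyTorch call. For comparison, here is a minimal standalone sketch of the same idea for a plain dict of tensors; the helper name move_to_device is hypothetical and not part of the library:

import torch

def move_to_device(tensors, device='cuda', fp16=False):
    """Move a dict of tensors to a device, casting floats to half if asked."""
    out = {}
    for name, t in tensors.items():
        t = t.to(device)
        if fp16 and t.is_floating_point():
            t = t.half()  # cast only float tensors; keep index tensors intact
        out[name] = t
    return out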
Code example #2
File: mp_trainer.py  Project: nlp-dke/NMTGMinor
def prepare_sample(batch, device=None):
    """
    Put minibatch on the corresponding GPU
    :param batch:
    :param device:
    :return:
    """
    if isinstance(batch, list):
        batch = batch[0]
    batch = rewrap(batch)
    batch.cuda(fp16=False, device=device)

    return batch
Code example #3
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model.train()
        self.loss_function.train()
        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()
        self.model.reset_states()

        dataset = train_data

        # data iterator: object that controls the order and batching of the data
        data_iterator = generate_data_iterator(dataset,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_tokens, total_loss, total_words = 0, 0, 0
        total_non_pads = 0
        report_loss, report_tgt_words = 0, 0
        report_src_words = 0
        report_ctc_loss = 0
        report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0
        num_accumulated_words = 0
        num_accumulated_sents = 0
        grad_scaler = 1

        nan = False
        nan_counter = 0

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded

        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            batch = next(epoch_iterator)
            if isinstance(batch, list) and self.n_gpus == 1:
                batch = batch[0]
            batch = rewrap(batch)

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                targets = batch.get('target_output')
                tgt_mask = targets.ne(onmt.constants.PAD)
                outputs = self.model(batch,
                                     streaming=opt.streaming,
                                     target_mask=tgt_mask,
                                     zero_encoder=opt.zero_encoder,
                                     mirror=opt.mirror_loss,
                                     streaming_state=streaming_state,
                                     nce=opt.nce)

                batch_size = batch.size

                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs,
                                               targets,
                                               model=self.model,
                                               vocab_mask=batch.vocab_mask)
                loss_data = loss_dict['data']
                loss = loss_dict[
                    'loss']  # a little trick to avoid gradient overflow with fp16
                full_loss = loss

                if opt.ctc_loss > 0.0:
                    ctc_loss = self.ctc_loss_function(outputs, targets)
                    ctc_loss_data = ctc_loss.item()
                    full_loss = full_loss + opt.ctc_loss * ctc_loss
                    report_ctc_loss += ctc_loss_data

                if opt.mirror_loss:
                    rev_loss = loss_dict['rev_loss']
                    rev_loss_data = loss_dict['rev_loss_data']
                    mirror_loss = loss_dict['mirror_loss']
                    full_loss = full_loss + rev_loss + mirror_loss
                    mirror_loss_data = loss_dict['mirror_loss'].item()
                else:
                    rev_loss_data = None
                    mirror_loss_data = 0

                # reconstruction loss
                if opt.reconstruct:
                    rec_loss = loss_dict['rec_loss']
                    full_loss = full_loss + rec_loss
                    rec_loss_data = loss_dict['rec_loss_data']
                else:
                    rec_loss_data = None

                optimizer = self.optim.optimizer

                # When the batch size is large, each gradient step is very easy to explode on fp16
                # Normalizing the loss to grad scaler ensures this will not happen
                full_loss.div_(grad_scaler)

                if self.cuda:
                    with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    full_loss.backward()

                del outputs

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                else:
                    raise e

            if loss != loss:
                # loss != loss is only true for NaN: reset and skip this batch
                oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                num_accumulated_words = 0
                num_accumulated_sents = 0
                nan_counter = nan_counter + 1
                print("Warning: loss is NaN")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurence. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
                counter = 0
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size
                tgt_size = batch.tgt_size

                counter = counter + 1
                num_accumulated_words += tgt_size
                num_accumulated_sents += batch_size

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif 0 < opt.batch_size_update <= num_accumulated_words:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated-gradient case: compute the denominator used
                    # to normalize the summed gradients
                    grad_denom = 1 / grad_scaler
                    if self.opt.normalize_gradient:
                        grad_denom = num_accumulated_words * grad_denom
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(amp.master_params(optimizer),
                                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.opt.max_grad_norm)
                    self.optim.step()
                    self.optim.zero_grad()
                    self.model.zero_grad()
                    counter = 0
                    num_accumulated_words = 0
                    num_accumulated_sents = 0
                    num_updates = self.optim._step
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)

                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                        self.save(ep, valid_ppl, itr=data_iterator)

                num_words = tgt_size
                report_loss += loss_data
                report_tgt_words += num_words
                report_src_words += src_size
                total_loss += loss_data
                total_words += num_words
                total_tokens += batch.get('target_output').nelement()
                total_non_pads += batch.get('target_output').ne(
                    onmt.constants.PAD).sum().item()
                optim = self.optim
                batch_efficiency = total_non_pads / total_tokens

                if opt.reconstruct:
                    report_rec_loss += rec_loss_data

                if opt.mirror_loss:
                    report_rev_loss += rev_loss_data
                    report_mirror_loss += mirror_loss_data

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                                  (epoch, i + 1, len(data_iterator),
                                   math.exp(report_loss / report_tgt_words)))

                    if opt.reconstruct:
                        rec_ppl = math.exp(report_rec_loss /
                                           report_src_words)
                        log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                    if opt.mirror_loss:
                        rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                        log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                        # mirror loss per word
                        log_string += (" mir_loss: %6.2f ; " %
                                       (report_mirror_loss / report_tgt_words))

                    log_string += ("lr: %.7f ; updates: %7d; " %
                                   (optim.getLearningRate(), optim._step))

                    log_string += ("%5.0f src tok/s; %5.0f tgt tok/s; " %
                                   (report_src_words /
                                    (time.time() - start), report_tgt_words /
                                    (time.time() - start)))

                    if opt.ctc_loss > 0.0:
                        ctc_loss = report_ctc_loss / report_tgt_words
                        log_string += (" ctcloss: %8.2f ; " % ctc_loss)

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss = 0
                    report_tgt_words, report_src_words = 0, 0
                    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                    report_ctc_loss = 0
                    start = time.time()

                i = i + 1

        return total_loss / total_words
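
The update_flag logic above implements gradient accumulation: gradients from several mini-batches are summed, normalized once, clipped, and then applied in a single optimizer step. A self-contained sketch of that pattern in plain PyTorch (toy model and data; no apex/amp, so the fp16 loss scaling from the example is omitted):

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
update_frequency = 4  # plays the role of opt.update_frequency above
counter = 0

for step in range(20):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / update_frequency).backward()  # gradients accumulate across calls
    counter += 1
    if counter >= update_frequency or step == 19:  # also flush the last batch
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        counter = 0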
Code example #4
    def eval(self, data):
        total_loss = 0
        total_words = 0
        opt = self.opt

        self.model.eval()
        self.loss_function.eval()
        self.model.reset_states()

        # the data iterator creates an epoch iterator
        data_iterator = generate_data_iterator(data,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=1,
                                               buffer_size=opt.buffer_size)
        epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None
        """ PyTorch semantics: save space by not creating gradients """

        data_size = len(epoch_iterator)
        i = 0

        with torch.no_grad():
            while not data_iterator.end_of_epoch():
                batch = next(epoch_iterator)
                if isinstance(batch, list):
                    batch = batch[0]
                batch = rewrap(batch)

                if self.cuda:
                    batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)
                """ outputs can be either 
                        hidden states from decoder or
                        prob distribution from decoder generator
                """
                targets = batch.get('target_output')
                tgt_mask = targets.ne(onmt.constants.PAD)
                outputs = self.model(batch,
                                     streaming=opt.streaming,
                                     target_mask=tgt_mask,
                                     mirror=opt.mirror_loss,
                                     streaming_state=streaming_state,
                                     nce=opt.nce)

                if opt.streaming:
                    streaming_state = outputs['streaming_state']

                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs,
                                               targets,
                                               model=self.model,
                                               eval=True,
                                               vocab_mask=batch.vocab_mask)

                loss_data = loss_dict['data']

                total_loss += loss_data
                total_words += batch.tgt_size
                i = i + 1

        self.model.train()
        self.loss_function.train()
        return total_loss / total_words
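
The eval method accumulates a summed token-level loss under torch.no_grad() and divides by the number of target words; train_epoch then reports math.exp of that average as perplexity. A self-contained sketch of the same bookkeeping, with a toy classifier standing in for the model:

import math
import torch

model = torch.nn.Linear(8, 5)  # toy stand-in for the decoder + generator
criterion = torch.nn.CrossEntropyLoss(reduction='sum')

model.eval()
total_loss, total_words = 0.0, 0
with torch.no_grad():  # no gradients are created during evaluation
    for _ in range(10):
        x = torch.randn(32, 8)
        y = torch.randint(0, 5, (32,))
        total_loss += criterion(model(x), y).item()
        total_words += y.numel()
ppl = math.exp(min(total_loss / total_words, 100))  # same clamp as above
print('Validation perplexity: %g' % ppl)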
Code example #5
File: trainer.py  Project: quanpn90/SpeechGAN
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model_ae.train()
        self.loss_function_ae.train()
        self.lat_dis.train()
        self.loss_lat_dis.train()

        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model_ae.zero_grad()
        self.lat_dis.zero_grad()

        dataset = train_data
        data_iterator = generate_data_iterator(dataset,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_loss_ae, total_loss_lat_dis, total_frames, total_loss_adv = 0, 0, 0, 0

        report_loss_ae, report_loss_lat_dis, report_loss_adv, report_tgt_frames, report_sent = 0, 0, 0, 0, 0
        report_dis_frames, report_adv_frames = 0, 0

        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0
        step = 0

        num_accumulated_sents = 0
        grad_scaler = -1

        nan = False
        nan_counter = 0
        n_step_ae = opt.update_frequency
        n_step_lat_dis = opt.update_frequency
        mode_ae = True

        loss_lat_dis = 0.0

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded

        # while not data_iterator.end_of_epoch():
        while True:

            if data_iterator.end_of_epoch():
                data_iterator = generate_data_iterator(
                    dataset,
                    seed=self.opt.seed,
                    num_workers=opt.num_workers,
                    epoch=epoch,
                    buffer_size=opt.buffer_size)
                epoch_iterator = data_iterator.next_epoch_itr(
                    not streaming, pin_memory=opt.pin_memory)

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            batch = next(epoch_iterator)
            if isinstance(batch, list) and self.n_gpus == 1:
                batch = batch[0]
            batch = rewrap(batch)

            batch_size = batch.size
            if grad_scaler == -1:
                grad_scaler = 1  # if self.opt.update_frequency > 1 else batch.tgt_size

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                if mode_ae:
                    step = self.optim_ae._step
                    loss_ae, loss_adv, encoder_outputs = self.autoencoder_backward(
                        batch, step)
                else:
                    loss_lat_dis, encoder_outputs = self.lat_dis_backward(
                        batch)

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss_ae = 0  # the NaN check below reads loss_ae, not loss
                else:
                    raise e

            if loss_ae != loss_ae:
                # loss_ae != loss_ae is only true for NaN: reset and skip
                oom = True
                self.model_ae.zero_grad()
                self.optim_ae.zero_grad()

                nan_counter = nan_counter + 1
                print("Warning: loss is NaN")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurence. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size

                if mode_ae:
                    report_adv_frames += encoder_outputs['src_mask'].sum(
                    ).item()
                    report_loss_adv += loss_adv
                else:
                    report_dis_frames += encoder_outputs['src_mask'].sum(
                    ).item()
                    report_loss_lat_dis += loss_lat_dis

                counter = counter + 1

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated-gradient case: compute the denominator used
                    # to normalize the summed gradients
                    if (counter == 1
                            and self.opt.update_frequency != 1) or counter > 1:
                        grad_denom = 1 / grad_scaler
                    else:
                        grad_denom = 1.0
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(
                        amp.master_params(self.optim_ae.optimizer), grad_denom)
                    normalize_gradients(
                        amp.master_params(self.optim_lat_dis.optimizer),
                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim_ae.optimizer),
                            self.opt.max_grad_norm)
                        # torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim_lat_dis.optimizer),
                        #                                self.opt.max_grad_norm)

                    torch.nn.utils.clip_grad_value_(
                        amp.master_params(self.optim_lat_dis.optimizer), 0.01)

                    if mode_ae:
                        self.optim_ae.step()
                        self.optim_ae.zero_grad()
                        self.model_ae.zero_grad()
                        self.optim_lat_dis.zero_grad()
                        self.lat_dis.zero_grad()
                    else:
                        self.optim_lat_dis.step()
                        self.optim_lat_dis.zero_grad()
                        self.lat_dis.zero_grad()
                        self.optim_ae.zero_grad()
                        self.model_ae.zero_grad()

                    num_updates = self.optim_ae._step

                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every and mode_ae:
                        # without the mode_ae check the model would be saved after every update
                        valid_loss_ae, valid_loss_lat_dis = self.eval(
                            self.valid_data)
                        print('Validation loss ae: %g' % valid_loss_ae)
                        print('Validation loss latent discriminator: %g' %
                              valid_loss_lat_dis)
                        self.save(0, valid_loss_ae, itr=data_iterator)

                    if num_updates == 1000000:
                        break

                    mode_ae = not mode_ae
                    counter = 0
                    # num_accumulated_words = 0

                    grad_scaler = -1
                    num_updates = self.optim_ae._step

                report_loss_ae += loss_ae

                # report_tgt_words += num_words
                num_accumulated_sents += batch_size
                report_sent += batch_size
                total_frames += src_size
                report_tgt_frames += src_size
                total_loss_ae += loss_ae
                total_loss_lat_dis += loss_lat_dis
                total_loss_adv += loss_adv

                optim_ae = self.optim_ae
                optim_lat_dis = self.optim_lat_dis
                # batch_efficiency = total_non_pads / total_tokens

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = (
                        "loss_ae : %6.2f ;  loss_lat_dis : %6.2f, loss_adv : %6.2f "
                        % (report_loss_ae / report_tgt_frames,
                           report_loss_lat_dis /
                           (report_dis_frames + 1e-5), report_loss_adv /
                           (report_adv_frames + 1e-5)))

                    log_string += (
                        "lr_ae: %.7f ; updates: %7d; " %
                        (optim_ae.getLearningRate(), optim_ae._step))

                    log_string += (
                        "lr_lat_dis: %.7f ; updates: %7d; " %
                        (optim_lat_dis.getLearningRate(), optim_lat_dis._step))
                    #
                    log_string += ("%5.0f src tok/s " %
                                   (report_tgt_frames / (time.time() - start)))

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss_ae = 0
                    report_loss_lat_dis = 0
                    report_loss_adv = 0
                    report_tgt_frames = 0
                    report_dis_frames = 0
                    report_adv_frames = 0
                    report_sent = 0
                    start = time.time()

                i = i + 1

        return total_loss_ae / total_frames * 100, total_loss_lat_dis / n_samples * 100, total_loss_adv / n_samples * 100
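
This trainer interleaves two objectives: an autoencoder update (with an adversarial term) and a latent-discriminator update, toggled by mode_ae after every parameter step. A condensed, self-contained sketch of that alternation with toy modules; the losses are placeholders, not the project's loss functions:

import torch
import torch.nn.functional as F

ae = torch.nn.Linear(8, 8)       # toy autoencoder
lat_dis = torch.nn.Linear(8, 2)  # toy latent discriminator
optim_ae = torch.optim.Adam(ae.parameters(), lr=1e-3)
optim_lat_dis = torch.optim.Adam(lat_dis.parameters(), lr=1e-3)
mode_ae = True

for step in range(20):
    x = torch.randn(16, 8)
    labels = torch.zeros(16, dtype=torch.long)  # stand-in for source_lang
    if mode_ae:
        # autoencoder step; a full version would also add an adversarial
        # term that rewards fooling the discriminator
        loss = F.mse_loss(ae(x), x)
        optim_ae.zero_grad()
        loss.backward()
        optim_ae.step()
    else:
        # discriminator step: train on detached latent codes so no
        # gradient flows back into the autoencoder
        loss = F.cross_entropy(lat_dis(ae(x).detach()), labels)
        optim_lat_dis.zero_grad()
        loss.backward()
        optim_lat_dis.step()
    mode_ae = not mode_ae  # flip roles after each update, as above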
Code example #6
File: trainer.py  Project: quanpn90/SpeechGAN
    def eval(self, data):
        total_loss_ae = 0
        total_loss_lat_dis = 0
        total_tgt_frames = 0
        total_src = 0
        total_sent = 0
        opt = self.opt

        self.model_ae.eval()
        self.loss_function_ae.eval()
        self.lat_dis.eval()
        self.loss_lat_dis.eval()
        # self.model.reset_states()

        # the data iterator creates an epoch iterator
        data_iterator = generate_data_iterator(data,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=1,
                                               buffer_size=opt.buffer_size)
        epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)
        """ PyTorch semantics: save space by not creating gradients """

        data_size = len(epoch_iterator)
        i = 0

        with torch.no_grad():
            while not data_iterator.end_of_epoch():
                batch = next(epoch_iterator)
                if isinstance(batch, list):
                    batch = batch[0]
                batch = rewrap(batch)

                if self.cuda:
                    batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)
                """ outputs can be either 
                        hidden states from decoder or
                        prob distribution from decoder generator
                """
                encoder_outputs, decoder_outputs = self.model_ae(batch)

                gate_padded = batch.get('gate_padded')

                if self.opt.n_frames_per_step > 1:
                    # keep the last frame of each decoder step (avoid
                    # shadowing the built-in slice)
                    frame_slice = torch.arange(self.opt.n_frames_per_step - 1,
                                               gate_padded.size(1),
                                               self.opt.n_frames_per_step)
                    gate_padded = gate_padded[:, frame_slice]

                src_org = batch.get('source_org')
                src_org = src_org.narrow(2, 1, src_org.size(2) - 1)
                target = [src_org.permute(1, 2, 0).contiguous(), gate_padded]
                loss_ae = self.loss_function_ae(decoder_outputs, target)
                loss_ae_data = loss_ae.data.item()

                preds = self.lat_dis(encoder_outputs['context'])

                loss_lat_dis = self.loss_lat_dis(
                    preds,
                    batch.get('source_lang'),
                    mask=encoder_outputs['src_mask'],
                    adversarial=False)
                total_src += encoder_outputs['src_mask'].float().sum().item()
                loss_lat_dis_data = loss_lat_dis.data.item()

                total_loss_ae += loss_ae_data
                total_loss_lat_dis += loss_lat_dis_data
                total_tgt_frames += batch.src_size
                total_sent += batch.size
                i = i + 1

        return total_loss_ae / total_tgt_frames, total_loss_lat_dis / total_src * 100
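
The gate_padded indexing in both eval and train keeps one gate target per decoder step when the model emits n_frames_per_step frames at a time. A tiny sketch of that striding with torch.arange:

import torch

gate_padded = torch.zeros(2, 12)  # (batch, frames)
n_frames_per_step = 3
idx = torch.arange(n_frames_per_step - 1, gate_padded.size(1),
                   n_frames_per_step)
print(gate_padded[:, idx].shape)  # torch.Size([2, 4]): one gate per step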
Code example #7
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model.train()
        self.loss_function.train()
        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()

        dataset = train_data
        data_iterator = generate_data_iterator(dataset,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_loss, total_frames = 0, 0

        report_loss, report_tgt_frames, report_sent = 0, 0, 0

        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0

        num_accumulated_sents = 0
        grad_scaler = -1

        nan = False
        nan_counter = 0

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded

        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            batch = next(epoch_iterator)
            if isinstance(batch, list) and self.n_gpus == 1:
                batch = batch[0]
            batch = rewrap(batch)
            if grad_scaler == -1:
                grad_scaler = 1  # if self.opt.update_frequency > 1 else batch.tgt_size

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                outputs = self.model(batch)

                gate_padded = batch.get('gate_padded')

                if self.opt.n_frames_per_step > 1:
                    # keep the first frame of each decoder step (avoid
                    # shadowing the built-in slice)
                    frame_slice = torch.arange(0, gate_padded.size(1),
                                               self.opt.n_frames_per_step)
                    gate_padded = gate_padded[:, frame_slice]

                src_org = batch.get('source_org')
                src_org = src_org.narrow(2, 1, src_org.size(2) - 1)

                target = [src_org.permute(1, 2, 0).contiguous(), gate_padded]
                loss = self.loss_function(outputs, target)

                batch_size = batch.size
                loss_data = loss.data.item()
                # a little trick to avoid gradient overflow with fp16
                full_loss = loss

                optimizer = self.optim.optimizer

                # When the batch size is large, each gradient step is very easy to explode on fp16
                # Normalizing the loss to grad scaler ensures this will not happen
                full_loss.div_(grad_scaler)

                if self.cuda:
                    with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    full_loss.backward()

                del outputs

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                else:
                    raise e

            if loss != loss:
                # loss != loss is only true for NaN: reset and skip this batch
                oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                nan_counter = nan_counter + 1
                print("Warning: loss is NaN")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurence. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size

                counter = counter + 1

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated-gradient case: pick the denominator used to
                    # normalize the summed gradients
                    if (counter == 1
                            and self.opt.update_frequency != 1) or counter > 1:
                        grad_denom = 1 / grad_scaler
                    else:
                        grad_denom = 1.0
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(amp.master_params(optimizer),
                                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.opt.max_grad_norm)
                    self.optim.step()
                    self.optim.zero_grad()
                    self.model.zero_grad()
                    counter = 0
                    # num_accumulated_words = 0

                    grad_scaler = -1
                    num_updates = self.optim._step
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)

                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                        self.save(ep, valid_ppl, itr=data_iterator)

                report_loss += loss_data
                # report_tgt_words += num_words
                num_accumulated_sents += batch_size
                report_sent += batch_size
                total_frames += src_size
                report_tgt_frames += src_size
                total_loss += loss_data

                optim = self.optim
                # batch_efficiency = total_non_pads / total_tokens

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = (
                        "Epoch %2d, %5d/%5d; ; loss : %6.2f ; " %
                        (epoch, i + 1, len(data_iterator), report_loss))

                    log_string += ("lr: %.7f ; updates: %7d; " %
                                   (optim.getLearningRate(), optim._step))
                    #
                    log_string += ("%5.0f src tok/s " %
                                   (report_tgt_frames / (time.time() - start)))

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss = 0
                    report_tgt_frames = 0
                    report_sent = 0
                    start = time.time()

                i = i + 1

        return total_loss / n_samples * 100
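
The fp16 path here relies on NVIDIA apex (amp.scale_loss, amp.master_params). For reference, the torch.cuda.amp API that later replaced apex expresses the same scale, backward, unscale, clip, step cycle roughly like this (a sketch with a toy model; requires a CUDA device):

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(16, 8, device='cuda')
    y = torch.randn(16, 1, device='cuda')
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scaled to avoid fp16 underflow
    scaler.unscale_(optimizer)     # so clipping sees the true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()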
Code example #8
    def eval(self, data):
        total_loss = 0
        total_tgt_frames = 0
        total_sent = 0
        opt = self.opt

        self.model.eval()
        self.loss_function.eval()
        # self.model.reset_states()

        # the data iterator creates an epoch iterator
        data_iterator = generate_data_iterator(data,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=1,
                                               buffer_size=opt.buffer_size)
        epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None
        """ PyTorch semantics: save space by not creating gradients """

        data_size = len(epoch_iterator)
        i = 0

        with torch.no_grad():
            while not data_iterator.end_of_epoch():
                batch = next(epoch_iterator)
                if isinstance(batch, list):
                    batch = batch[0]
                batch = rewrap(batch)

                if self.cuda:
                    batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)
                """ outputs can be either 
                        hidden states from decoder or
                        prob distribution from decoder generator
                """

                outputs = self.model(batch)

                gate_padded = batch.get('gate_padded')

                if self.opt.n_frames_per_step > 1:
                    # keep the last frame of each decoder step (avoid
                    # shadowing the built-in slice)
                    frame_slice = torch.arange(self.opt.n_frames_per_step - 1,
                                               gate_padded.size(1),
                                               self.opt.n_frames_per_step)
                    gate_padded = gate_padded[:, frame_slice]

                src_org = batch.get('source_org')
                src_org = src_org.narrow(2, 1, src_org.size(2) - 1)
                target = [src_org.permute(1, 2, 0).contiguous(), gate_padded]
                loss = self.loss_function(outputs, target)
                loss_data = loss.data.item()

                total_loss += loss_data
                total_tgt_frames += batch.src_size
                total_sent += batch.size
                i = i + 1

        self.model.train()
        self.loss_function.train()
        return total_loss / data_size * 100
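
src_org.narrow(2, 1, src_org.size(2) - 1) above drops the first element along the feature dimension (dimension 2), e.g. a prepended flag channel. A minimal illustration of narrow(dim, start, length):

import torch

src_org = torch.arange(24.).view(2, 3, 4)           # (time, batch, feat)
trimmed = src_org.narrow(2, 1, src_org.size(2) - 1)
print(trimmed.shape)                                 # torch.Size([2, 3, 3])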
Code example #9
    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model.train()
        self.loss_function.train()
        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()
        self.model.reset_states()

        dataset = train_data
        data_iterator = DataIterator(dataset,
                                     dataset.collater,
                                     dataset.batches,
                                     seed=self.opt.seed,
                                     num_workers=opt.num_workers,
                                     epoch=epoch,
                                     buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_tokens, total_loss, total_words = 0, 0, 0
        total_non_pads = 0
        report_loss, report_tgt_words = 0, 0
        report_src_words = 0
        report_sents = 0
        report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
        report_log_prior = 0
        report_log_variational_posterior = 0
        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0
        update_counter = 0
        num_accumulated_words = 0
        num_accumulated_sents = 0

        nan = False
        nan_counter = 0

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        i = data_iterator.iterations_in_epoch
        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)
            batch = next(epoch_iterator)
            batch = rewrap(batch)
            grad_scaler = self.opt.batch_size_words if self.opt.update_frequency > 1 else batch.tgt_size

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                targets = batch.get('target_output')
                tgt_mask = targets.data.ne(onmt.constants.PAD)
                outputs = self.model(batch,
                                     streaming=opt.streaming,
                                     target_mask=tgt_mask,
                                     zero_encoder=opt.zero_encoder,
                                     mirror=opt.mirror_loss,
                                     streaming_state=streaming_state)

                batch_size = batch.size

                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs,
                                               targets,
                                               model=self.model)
                loss_data = loss_dict['data']
                loss = loss_dict[
                    'loss']  # a little trick to avoid gradient overflow with fp16
                log_prior = self.model.log_prior()
                log_variational_posterior = self.model.log_variational_posterior(
                )

                # KL coefficient (from the Bayes-by-Backprop paper): the first
                # mini-batches in each epoch carry a large KL weight, while the
                # later ones are dominated by the data term; here it is simply
                # normalized by the number of target tokens per update
                kl_coeff = 1 / (batch.tgt_size * opt.update_frequency)
                full_loss = loss + kl_coeff * (log_variational_posterior -
                                               log_prior)

                if opt.mirror_loss:
                    rev_loss = loss_dict['rev_loss']
                    rev_loss_data = loss_dict['rev_loss_data']
                    mirror_loss = loss_dict['mirror_loss']
                    full_loss = full_loss + rev_loss + mirror_loss
                    mirror_loss_data = loss_dict['mirror_loss'].item()
                else:
                    rev_loss = None
                    rev_loss_data = None
                    mirror_loss_data = 0

                # reconstruction loss
                if opt.reconstruct:
                    rec_loss = loss_dict['rec_loss']
                    full_loss = full_loss + rec_loss
                    rec_loss_data = loss_dict['rec_loss_data']
                else:
                    rec_loss_data = None

                optimizer = self.optim.optimizer

                # When the batch size is large, each gradient step is very easy to explode on fp16
                # Normalizing the loss to grad scaler ensures this will not happen
                full_loss.div_(grad_scaler)

                if self.cuda:
                    with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    full_loss.backward()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                else:
                    raise e

            if loss != loss:
                # loss != loss is only true for NaN: reset and skip this batch
                oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                num_accumulated_words = 0
                num_accumulated_sents = 0
                nan_counter = nan_counter + 1
                print("Warning: loss is NaN")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurence. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size
                tgt_size = batch.tgt_size

                counter = counter + 1
                num_accumulated_words += tgt_size
                num_accumulated_sents += batch_size

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif 0 < opt.batch_size_update <= num_accumulated_words:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated-gradient case: pick the denominator used to
                    # normalize the summed gradients
                    if (counter == 1
                            and self.opt.update_frequency != 1) or counter > 1:
                        grad_denom = 1 / grad_scaler
                        if self.opt.normalize_gradient:
                            grad_denom = num_accumulated_words * grad_denom
                    else:
                        grad_denom = 1
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(amp.master_params(optimizer),
                                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.opt.max_grad_norm)
                    self.optim.step()
                    self.optim.zero_grad()
                    self.model.zero_grad()
                    counter = 0
                    num_accumulated_words = 0
                    num_accumulated_sents = 0
                    num_updates = self.optim._step
                    update_counter += 1
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)

                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                        self.save(ep, valid_ppl, itr=data_iterator)

                num_words = tgt_size
                report_loss += loss_data
                report_log_prior += log_prior.item()
                report_log_variational_posterior += log_variational_posterior.item(
                )
                report_tgt_words += num_words
                report_src_words += src_size
                report_sents += 1
                total_loss += loss_data
                total_words += num_words
                total_tokens += batch.get('target_output').nelement()
                total_non_pads += batch.get('target_output').ne(
                    onmt.constants.PAD).sum().item()
                optim = self.optim
                batch_efficiency = total_non_pads / total_tokens

                if opt.reconstruct:
                    report_rec_loss += rec_loss_data

                if opt.mirror_loss:
                    report_rev_loss += rev_loss_data
                    report_mirror_loss += mirror_loss_data

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                                  (epoch, i + 1, len(data_iterator),
                                   math.exp(report_loss / report_tgt_words)))

                    kl_div = report_log_variational_posterior - report_log_prior
                    log_string += ("KL q||p: %6.2f ; " %
                                   (kl_div / report_sents))

                    if opt.reconstruct:
                        rec_ppl = math.exp(report_rec_loss /
                                           report_src_words)
                        log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                    if opt.mirror_loss:
                        rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                        log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                        # mirror loss per word
                        log_string += (" mir_loss: %6.2f ; " %
                                       (report_mirror_loss / report_tgt_words))

                    log_string += ("lr: %.7f ; updates: %7d; " %
                                   (optim.getLearningRate(), optim._step))

                    log_string += ("%5.0f src/s; %5.0f tgt/s; " %
                                   (report_src_words /
                                    (time.time() - start), report_tgt_words /
                                    (time.time() - start)))

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss = 0
                    report_tgt_words, report_src_words = 0, 0
                    report_sents = 0
                    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                    report_log_prior, report_log_variational_posterior = 0, 0
                    start = time.time()

                i = i + 1

        return total_loss / total_words
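
This last trainer is Bayesian: on top of the data loss it adds kl_coeff * (log_variational_posterior - log_prior), a minibatch-weighted ELBO in the spirit of Bayes by Backprop. A self-contained sketch of the loss assembly, with scalar stand-ins for the model's log prior and log posterior:

import torch

nll = torch.tensor(2.5)                # data term from the loss function
log_prior = torch.tensor(-10.0)        # stand-in for model.log_prior()
log_var_posterior = torch.tensor(-4.0) # stand-in for model.log_variational_posterior()

tgt_size, update_frequency = 1000, 4
kl_coeff = 1.0 / (tgt_size * update_frequency)  # per-token KL weight, as above
full_loss = nll + kl_coeff * (log_var_posterior - log_prior)
print(full_loss.item())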