Example no. 1
    def __init__(self, model, loss_function, train_data, valid_data, dicts, opt, setup_optimizer=True):
        super().__init__(model, loss_function, train_data, valid_data, dicts, opt)

        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages, opt.label_smoothing, opt.fast_xentropy)
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')

        self.n_gpus = len(self.opt.gpus)

        if opt.ctc_loss != 0:
            from onmt.speech.ctc_loss import CTC
            self.ctc_loss_function = CTC(dicts['tgt'].size(), opt.model_size, 0.0, reduce=True)

        if self.cuda:
            torch.cuda.set_device(self.opt.gpus[0])
            if self.opt.seed >= 0:
                torch.manual_seed(self.opt.seed)
            self.loss_function = self.loss_function.cuda()
            self.model = self.model.cuda()
            if opt.ctc_loss > 0.0:
                self.ctc_loss_function = self.ctc_loss_function.cuda()

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if not self.opt.fp16:
                opt_level = "O0"
                keep_batchnorm_fp32 = False
            elif self.opt.fp16_mixed:
                opt_level = "O1"
                keep_batchnorm_fp32 = None
            else:
                opt_level = "O2"
                keep_batchnorm_fp32 = False

            if self.cuda:
                self.model, self.optim.optimizer = amp.initialize(self.model,
                                                                  self.optim.optimizer,
                                                                  opt_level=opt_level,
                                                                  keep_batchnorm_fp32=keep_batchnorm_fp32,
                                                                  loss_scale="dynamic",
                                                                  verbosity=1 if self.opt.verbose else 0)

            print(self.optim.optimizer)

        # An ugly hack to switch between align right and align left
        if hasattr(self.model, 'relative'):
            if self.model.relative:
                self.train_data.src_align_right = True
                self.train_data.tgt_align_right = False
                self.valid_data.src_align_right = True
                self.valid_data.tgt_align_right = False
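
The fp16 branch above maps the two boolean flags (opt.fp16, opt.fp16_mixed) onto an NVIDIA Apex AMP opt_level: O0 is pure FP32, O1 is mixed precision with automatic casting, and O2 keeps FP16 model weights with FP32 master weights. A minimal sketch of that mapping, assuming the Apex amp package is available (the helper name is illustrative, not part of the repository):

def select_amp_settings(fp16, fp16_mixed):
    """Map the two precision flags to an Apex opt_level and batchnorm policy."""
    if not fp16:
        return "O0", False   # pure FP32 training
    if fp16_mixed:
        return "O1", None    # mixed precision; let Apex pick the batchnorm dtype
    return "O2", False       # FP16 weights with FP32 master weights

# Hypothetical usage, mirroring the constructor above:
# opt_level, keep_bn_fp32 = select_amp_settings(opt.fp16, opt.fp16_mixed)
# model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level,
#                                   keep_batchnorm_fp32=keep_bn_fp32,
#                                   loss_scale="dynamic")
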
Example no. 2
class XETrainer(BaseTrainer):
    def __init__(self,
                 model,
                 loss_function,
                 train_data,
                 valid_data,
                 dicts,
                 opt,
                 setup_optimizer=True):
        super().__init__(model, loss_function, train_data, valid_data, dicts,
                         opt)

        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages,
                                           opt.label_smoothing,
                                           opt.fast_xentropy)
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')

        self.n_gpus = len(self.opt.gpus)

        if opt.ctc_loss != 0:
            from onmt.speech.ctc_loss import CTC
            self.ctc_loss_function = CTC(0.0, reduce=True)

        if self.cuda:
            torch.cuda.set_device(self.opt.gpus[0])
            if self.opt.seed >= 0:
                torch.manual_seed(self.opt.seed)
            self.loss_function = self.loss_function.cuda()
            self.model = self.model.cuda()
            if opt.ctc_loss > 0.0:
                self.ctc_loss_function = self.ctc_loss_function.cuda()

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if not self.opt.fp16:
                opt_level = "O0"
                keep_batchnorm_fp32 = False
            elif self.opt.fp16_mixed:
                opt_level = "O1"
                keep_batchnorm_fp32 = None
            else:
                opt_level = "O2"
                keep_batchnorm_fp32 = False

            if self.cuda:
                self.model, self.optim.optimizer = amp.initialize(
                    self.model,
                    self.optim.optimizer,
                    opt_level=opt_level,
                    keep_batchnorm_fp32=keep_batchnorm_fp32,
                    loss_scale="dynamic",
                    verbosity=1 if self.opt.verbose else 0)
        # An ugly hack to switch between align right and align left
        if hasattr(self.model, 'relative'):
            if self.model.relative:
                self.train_data.src_align_right = True
                self.train_data.tgt_align_right = False
                self.valid_data.src_align_right = True
                self.valid_data.tgt_align_right = False

    def save(self, epoch, valid_ppl, itr=None):

        opt = self.opt
        model = self.model
        dicts = self.dicts

        model_state_dict = self.model.state_dict()
        optim_state_dict = self.optim.state_dict()

        if itr:
            itr_state_dict = itr.state_dict()
        else:
            itr_state_dict = None

        #  drop a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'dicts': dicts,
            'opt': opt,
            'epoch': epoch,
            'itr': itr_state_dict,
            'optim': optim_state_dict,
            'amp': amp.state_dict()
        }

        file_name = '%s_ppl_%.6f_e%.2f.pt' % (opt.save_model, valid_ppl, epoch)
        print('Writing to %s' % file_name)
        torch.save(checkpoint, file_name)

        # check the save directory here
        checkpoint_dir = os.path.dirname(opt.save_model)
        existed_save_files = checkpoint_paths(checkpoint_dir)
        for save_file in existed_save_files[opt.keep_save_files:]:
            print(" * Deleting old save file %s ...." % save_file)
            os.remove(save_file)

    def eval(self, data):
        total_loss = 0
        total_words = 0
        opt = self.opt

        self.model.eval()
        self.loss_function.eval()
        self.model.reset_states()

        # the data iterator creates an epoch iterator
        data_iterator = generate_data_iterator(data,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=1,
                                               buffer_size=opt.buffer_size)
        epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None
        """ PyTorch semantics: save space by not creating gradients """

        data_size = len(epoch_iterator)
        i = 0

        with torch.no_grad():
            # for i in range(len()):
            while not data_iterator.end_of_epoch():
                # batch = data.next()[0]
                batch = next(epoch_iterator)
                if isinstance(batch, list):
                    batch = batch[0]
                batch = rewrap(batch)

                if self.cuda:
                    batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)
                """ outputs can be either 
                        hidden states from decoder or
                        prob distribution from decoder generator
                """
                targets = batch.get('target_output')
                tgt_mask = targets.ne(onmt.constants.PAD)
                outputs = self.model(batch,
                                     streaming=opt.streaming,
                                     target_mask=tgt_mask,
                                     mirror=opt.mirror_loss,
                                     streaming_state=streaming_state,
                                     nce=opt.nce)

                if opt.streaming:
                    streaming_state = outputs['streaming_state']

                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs,
                                               targets,
                                               model=self.model,
                                               eval=True,
                                               vocab_mask=batch.vocab_mask)

                loss_data = loss_dict['data']

                total_loss += loss_data
                total_words += batch.tgt_size
                i = i + 1

        self.model.train()
        self.loss_function.train()
        return total_loss / total_words

    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        self.model.train()
        self.loss_function.train()
        # Clear the gradients of the model
        # self.runner.zero_grad()
        self.model.zero_grad()
        self.model.reset_states()

        dataset = train_data

        # data iterator: object that controls the batching and iteration order over the dataset
        # data_iterator = DataIterator(dataset, dataset.collater, dataset.batches, seed=self.opt.seed,
        #                              num_workers=opt.num_workers, epoch=epoch, buffer_size=opt.buffer_size)
        data_iterator = generate_data_iterator(dataset,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_tokens, total_loss, total_words = 0, 0, 0
        total_non_pads = 0
        report_loss, report_tgt_words = 0, 0
        report_src_words = 0
        report_ctc_loss = 0
        report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
        start = time.time()
        n_samples = len(epoch_iterator)

        counter = 0
        num_accumulated_words = 0
        num_accumulated_sents = 0
        grad_scaler = 1

        nan = False
        nan_counter = 0

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded

        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            batch = next(epoch_iterator)
            if isinstance(batch, list) and self.n_gpus == 1:
                batch = batch[0]
            batch = rewrap(batch)

            if self.cuda:
                batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

            # if opt.streaming:
            #     if train_data.is_new_stream():
            #         streaming_state = self.model.init_stream()
            # else:
            #     streaming_state = None

            oom = False
            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                targets = batch.get('target_output')
                tgt_mask = targets.ne(onmt.constants.PAD)
                outputs = self.model(batch,
                                     streaming=opt.streaming,
                                     target_mask=tgt_mask,
                                     zero_encoder=opt.zero_encoder,
                                     mirror=opt.mirror_loss,
                                     streaming_state=streaming_state,
                                     nce=opt.nce)

                batch_size = batch.size

                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs,
                                               targets,
                                               model=self.model,
                                               vocab_mask=batch.vocab_mask)
                loss_data = loss_dict['data']
                loss = loss_dict[
                    'loss']  # a little trick to avoid gradient overflow with fp16
                full_loss = loss

                if opt.ctc_loss > 0.0:
                    ctc_loss = self.ctc_loss_function(outputs, targets)
                    ctc_loss_data = ctc_loss.item()
                    full_loss = full_loss + opt.ctc_loss * ctc_loss
                    report_ctc_loss += ctc_loss_data

                if opt.mirror_loss:
                    rev_loss = loss_dict['rev_loss']
                    rev_loss_data = loss_dict['rev_loss_data']
                    mirror_loss = loss_dict['mirror_loss']
                    full_loss = full_loss + rev_loss + mirror_loss
                    mirror_loss_data = loss_dict['mirror_loss'].item()
                else:
                    rev_loss_data = None
                    mirror_loss_data = 0

                # reconstruction loss
                if opt.reconstruct:
                    rec_loss = loss_dict['rec_loss']
                    full_loss = full_loss + rec_loss
                    rec_loss_data = loss_dict['rec_loss_data']
                else:
                    rec_loss_data = None
                #
                # if opt.lfv_multilingual:
                #     lid_logits = outputs['lid_logits']
                #     lid_labels = batch.get('target_lang')
                #     lid_loss_function = self.loss_function.get_loss_function('lid_loss')
                #     lid_loss = lid_loss_function(lid_logits, lid_labels)
                #     full_loss = full_loss + lid_loss

                optimizer = self.optim.optimizer

                # When the batch size is large, each gradient step can easily overflow in fp16
                # Normalizing the loss by grad_scaler helps prevent this
                full_loss.div_(grad_scaler)

                if self.cuda:
                    with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    full_loss.backward()

                del outputs

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(
                        '| WARNING: ran out of memory on GPU, skipping batch')
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                else:
                    raise e

            if loss != loss:
                # catching NAN problem
                oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                num_accumulated_words = 0
                num_accumulated_sents = 0
                nan_counter = nan_counter + 1
                print("Warning!!! Loss is Nan")
                if nan_counter >= 15:
                    raise ValueError(
                        "Training stopped because of multiple NaN occurrences. "
                        "For ASR, using the Relative Transformer is more stable and recommended."
                    )
                counter = 0
            else:
                nan_counter = 0

            if not oom:
                src_size = batch.src_size
                tgt_size = batch.tgt_size

                counter = counter + 1
                num_accumulated_words += tgt_size
                num_accumulated_sents += batch_size

                #   We only update the parameters after getting gradients from n mini-batches
                update_flag = False
                if counter >= opt.update_frequency > 0:
                    update_flag = True
                elif 0 < opt.batch_size_update <= num_accumulated_words:
                    update_flag = True
                elif i == n_samples:  # update for the last minibatch
                    update_flag = True

                if update_flag:
                    # accumulated-gradient case: enough mini-batches have been collected, so update now
                    grad_denom = 1 / grad_scaler
                    if self.opt.normalize_gradient:
                        grad_denom = num_accumulated_words * grad_denom
                    # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                    normalize_gradients(amp.master_params(optimizer),
                                        grad_denom)
                    # Update the parameters.
                    if self.opt.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.opt.max_grad_norm)
                    self.optim.step()
                    self.optim.zero_grad()
                    self.model.zero_grad()
                    counter = 0
                    num_accumulated_words = 0
                    num_accumulated_sents = 0
                    num_updates = self.optim._step
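                    # num_updates % save_every == -1 % save_every is true on update
                    # numbers save_every-1, 2*save_every-1, ..., i.e. once every
                    # save_every updates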
                    if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                        valid_loss = self.eval(self.valid_data)
                        valid_ppl = math.exp(min(valid_loss, 100))
                        print('Validation perplexity: %g' % valid_ppl)

                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)

                        self.save(ep, valid_ppl, itr=data_iterator)

                num_words = tgt_size
                report_loss += loss_data
                report_tgt_words += num_words
                report_src_words += src_size
                total_loss += loss_data
                total_words += num_words
                total_tokens += batch.get('target_output').nelement()
                total_non_pads += batch.get('target_output').ne(
                    onmt.constants.PAD).sum().item()
                optim = self.optim
                batch_efficiency = total_non_pads / total_tokens

                if opt.reconstruct:
                    report_rec_loss += rec_loss_data

                if opt.mirror_loss:
                    report_rev_loss += rev_loss_data
                    report_mirror_loss += mirror_loss_data

                if i == 0 or (i % opt.log_interval == -1 % opt.log_interval):
                    log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                                  (epoch, i + 1, len(data_iterator),
                                   math.exp(report_loss / report_tgt_words)))

                    if opt.reconstruct:
                        rec_ppl = math.exp(report_rec_loss /
                                           report_src_words)
                        log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                    if opt.mirror_loss:
                        rev_ppl = math.exp(report_rev_loss / report_tgt_words)
                        log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                        # mirror loss per word
                        log_string += (" mir_loss: %6.2f ; " %
                                       (report_mirror_loss / report_tgt_words))

                    log_string += ("lr: %.7f ; updates: %7d; " %
                                   (optim.getLearningRate(), optim._step))

                    log_string += ("%5.0f src tok/s; %5.0f tgt tok/s; " %
                                   (report_src_words /
                                    (time.time() - start), report_tgt_words /
                                    (time.time() - start)))

                    if opt.ctc_loss > 0.0:
                        ctc_loss = report_ctc_loss / report_tgt_words
                        log_string += (" ctcloss: %8.2f ; " % ctc_loss)

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    print(log_string)

                    report_loss = 0
                    report_tgt_words, report_src_words = 0, 0
                    report_rec_loss, report_rev_loss, report_mirror_loss = 0, 0, 0
                    report_ctc_loss = 0
                    start = time.time()

                i = i + 1

        return total_loss / total_words

    # def run(self, save_file=None):
    def run(self, checkpoint=None):

        opt = self.opt
        model = self.model
        optim = self.optim

        if checkpoint is not None:
            self.model.load_state_dict(checkpoint['model'])
            prec_opt = checkpoint['opt'] if 'opt' in checkpoint else None

            if not opt.reset_optim:
                print("* Loading optimizer states ... ")
                self.optim.load_state_dict(checkpoint['optim'])
                if prec_opt is not None and hasattr(prec_opt, "fp16_mixed"):
                    # Only load amp information if the precision mode is the same
                    # Maybe it's better to allow switching between optimization modes?
                    if opt.fp16_mixed == prec_opt.fp16_mixed and opt.fp16 == prec_opt.fp16:
                        if 'amp' in checkpoint:
                            amp.load_state_dict(checkpoint['amp'])

                # Only load the progress when we use the same optimizer
                if 'itr' in checkpoint:
                    itr_progress = checkpoint['itr']
                else:
                    itr_progress = None

                resume = True
                start_epoch = checkpoint[
                    'epoch'] if 'epoch' in checkpoint else 1
                if start_epoch is None:
                    start_epoch = 1
            else:
                itr_progress = None
                resume = False
                start_epoch = 1

            del checkpoint['model']
            del checkpoint['optim']
            del checkpoint
        else:
            itr_progress = None
            print('Initializing model parameters')
            init_model_parameters(model, opt)
            resume = False
            start_epoch = 1

        if opt.load_encoder_from:
            self.load_encoder_weight(opt.load_encoder_from)

        if opt.load_decoder_from:
            self.load_decoder_weight(opt.load_decoder_from)

        # if we are on a GPU: warm up the memory allocator
        if self.cuda:
            self.warm_up()

            valid_loss = self.eval(self.valid_data)
            valid_ppl = math.exp(min(valid_loss, 100))
            print('Validation perplexity: %g' % valid_ppl)

        self.start_time = time.time()

        if opt.starting_step > 0:
            self.optim.override_starting_step(opt.starting_step)

        if opt.override_ctc_loss >= 0:
            opt.ctc_loss = opt.override_ctc_loss

        for epoch in range(start_epoch, start_epoch + opt.epochs):
            print('')

            #  (1) train for one epoch on the training set
            train_loss = self.train_epoch(epoch,
                                          resume=resume,
                                          itr_progress=itr_progress)
            train_ppl = math.exp(min(train_loss, 100))
            print('Train perplexity: %g' % train_ppl)

            #  (2) evaluate on the validation set
            valid_loss = self.eval(self.valid_data)
            valid_ppl = math.exp(min(valid_loss, 100))
            print('Validation perplexity: %g' % valid_ppl)

            self.save(epoch, valid_ppl)
            itr_progress = None
            resume = False
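
train_epoch above accumulates gradients over several mini-batches (controlled by opt.update_frequency and opt.batch_size_update) before taking one optimizer step. A minimal, standalone sketch of that accumulation pattern in plain PyTorch; names such as update_frequency and batches are illustrative assumptions, not the trainer's API:

def train_with_accumulation(model, optimizer, loss_fn, batches, update_frequency=4):
    """Collect gradients from `update_frequency` mini-batches, then update once."""
    model.zero_grad()
    for counter, (inputs, targets) in enumerate(batches, start=1):
        loss = loss_fn(model(inputs), targets)
        # dividing keeps the accumulated gradient comparable to one large batch
        (loss / update_frequency).backward()
        if counter % update_frequency == 0:
            optimizer.step()
            optimizer.zero_grad()
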
Example no. 3
class Trainer(object):

    def __init__(self, device, train_data, valid_data, dicts, opt, setup_optimizer=True):
        """
        :param model:
        :param device: int (GPU id)
        :param loss_function:
        :param train_data:
        :param valid_data:
        :param dicts:
        :param opt:
        """

        self.device = device
        opt.node_rank = 0
        opt.nodes = 1
        self.world_size = len(opt.gpus)

        # in the case of single node distributed, it should equal self.device
        self.rank = self.device

        # make a group to later use with self.all_reduce
        self.group = dist.group.WORLD

        self.print("[INFO] Training Options:", opt)
        if self.world_size > 1:
            dist.init_process_group(backend='nccl', init_method='env://', world_size=self.world_size, rank=self.rank)

        self.model = None

        if self.rank == 0:
            self.train_data = train_data
            self.valid_data = valid_data
        else:
            # Do we really need to deepcopy the data instances (which could easily cause a memory leak)?
            self.train_data = copy.deepcopy(train_data)
            self.valid_data = copy.deepcopy(valid_data)

        self.dicts = dicts
        self.opt = opt
        self.cuda = (len(opt.gpus) >= 1 and opt.gpus[0] >= 0)

        assert self.cuda, "[ERROR] Training is only available on GPUs."

        self.start_time = 0

        # setting up models and others
        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages, opt.label_smoothing, opt.fast_xentropy)
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')

        torch.manual_seed(self.opt.seed)

        # note: we must start creating models after creating the processes
        # for some reason passing a pre-created model to a process creates a "pickle" error
        if not opt.fusion:

            if self.is_main():
                print("[INFO] Building models .... ", flush=True)
            model = build_model(opt, dicts)

            """ Building the loss function """
            if opt.ctc_loss > 0.0:
                from onmt.speech.ctc_loss import CTC
                self.ctc_loss_function = CTC(0.0, reduce=True)

            if opt.nce:
                from onmt.modules.nce.nce_loss import NCELoss
                loss_function = NCELoss(opt.model_size, dicts['tgt'].size(), noise_ratio=opt.nce_noise,
                                        logz=9, label_smoothing=opt.label_smoothing)
            else:
                loss_function = NMTLossFunc(opt.model_size, dicts['tgt'].size(),
                                            label_smoothing=opt.label_smoothing,
                                            mirror=opt.mirror_loss,
                                            fast_xentropy=opt.fast_xentropy)

            # This function replaces modules with more optimized counterparts so that the model runs faster
            # Currently experimenting with LayerNorm
            if not opt.memory_profiling:
                # distributed is required to convert BatchNorm to SyncBatchNorm for DDP
                optimize_model(model, distributed=(self.world_size > 1))

        init_model_parameters(model, opt)
        self.model = model
        self.loss_function = loss_function
        self.grad_scaler = torch.cuda.amp.GradScaler()

        if opt.load_from:
            checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(checkpoint['model'])
            if 'scaler' in checkpoint and checkpoint['scaler'] is not None:
                self.grad_scaler.load_state_dict(checkpoint['scaler'])

        if self.cuda:
            torch.cuda.set_device(self.device)

            self.loss_function = self.loss_function.cuda(device=self.device)
            self.model = self.model.cuda(device=self.device)
            if opt.ctc_loss > 0.0:
                self.ctc_loss_function = self.ctc_loss_function.cuda(device=self.device)

            # Ensure that the distributed copies have the same initial parameters
            # Manual seed may not work the same for different GPU models.
            # if self.world_size > 1:
            #     params = [p for p in self.model.parameters()]
            #
            #     with torch.no_grad():
            #         if not self.is_main():
            #             # zero everything except for the main model
            #             for p in params:
            #                 p.zero_()
            #         else:
            #             for p in params:
            #                 p.add_(0)
            #
            # # run all_reduce to ensure that all models have exactly the same parameters
            # if self.world_size > 1:
            #     params = [p for p in self.model.parameters()]
            #     all_reduce_and_rescale_tensors(params, 1)

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if self.is_main():
                print("[INFO] Optimizer: ", self.optim.optimizer)

            if opt.load_from:
                if 'optim' in checkpoint and checkpoint['optim'] is not None and not opt.reset_optim:
                    self.optim.load_state_dict(checkpoint['optim'])

        if self.world_size > 1:
            # find_unused_parameters may be required for dropped layer (parameters that are not connected to
            # any particular graph)
            find_unused_parameters = True

            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank],
                                                                   output_device=self.rank,
                                                                   find_unused_parameters=find_unused_parameters)

        print("[INFO] Process %d ready." % self.rank, flush=True)

        # if self.world_size > 1:
        #     params = self.model.module.state_dict()
        #     param_keys = list(params.keys())
        #     # key = random.choice(param_keys)
        #     key = param_keys[0]
        #
        #     print(params[key].sum(), flush=True)

    def is_main(self):

        return self.rank == 0

    def all_reduce(self, tensor, **kwargs):

        if self.world_size > 1:
            dist.all_reduce(tensor, **kwargs)

        # otherwise, do nothing
        return

    def print(self, *content, flush=False):
        """
        A helper function to print only on the main process
        :param flush:
        :param content:
        :return:
        """

        if self.is_main():
            print(*content, flush=flush)
        else:
            return

    def load_encoder_weight(self, checkpoint_file):

        print("Loading pretrained models from %s" % checkpoint_file)
        checkpoint = torch.load(checkpoint_file, map_location=lambda storage, loc: storage)

        pretrained_model = build_model(checkpoint['opt'], checkpoint['dicts'])
        pretrained_model.load_state_dict(checkpoint['model'])

        print("Loading pretrained encoder weights ...")
        pretrained_model.encoder.language_embedding = None
        enc_language_embedding = self.model.encoder.language_embedding
        self.model.encoder.language_embedding = None
        encoder_state_dict = pretrained_model.encoder.state_dict()

        self.model.encoder.load_state_dict(encoder_state_dict)
        self.model.encoder.language_embedding = enc_language_embedding
        return

    def load_decoder_weight(self, checkpoint_file):

        self.print("Loading pretrained models from %s" % checkpoint_file)
        checkpoint = torch.load(checkpoint_file, map_location=lambda storage, loc: storage)
        chkpoint_dict = checkpoint['dicts']

        pretrained_model = build_model(checkpoint['opt'], chkpoint_dict)
        pretrained_model.load_state_dict(checkpoint['model'])

        self.print("Loading pretrained decoder weights ...")
        # first we have to remove the embeddings, which probably have different sizes ...
        pretrained_word_emb = pretrained_model.decoder.word_lut
        pretrained_model.decoder.word_lut = None
        pretrained_lang_emb = pretrained_model.decoder.language_embeddings
        pretrained_model.decoder.language_embeddings = None

        # actually we assume that two decoders have the same language embeddings...
        untrained_word_emb = self.model.decoder.word_lut
        self.model.decoder.word_lut = None
        untrained_lang_emb = self.model.decoder.language_embeddings
        self.model.decoder.language_embeddings = None

        decoder_state_dict = pretrained_model.decoder.state_dict()
        self.model.decoder.load_state_dict(decoder_state_dict)

        # now we load the embeddings ....
        n_copies = 0
        for token in self.dicts['tgt'].labelToIdx:

            untrained_id = self.dicts['tgt'].labelToIdx[token]

            if token in chkpoint_dict['tgt'].labelToIdx:
                pretrained_id = chkpoint_dict['tgt'].labelToIdx[token]
                untrained_word_emb.weight.data[untrained_id].copy_(pretrained_word_emb.weight.data[pretrained_id])

                self.model.generator[0].linear.bias.data[untrained_id].copy_(pretrained_model
                                                                             .generator[0].linear.bias.data[
                                                                                 pretrained_id])
                n_copies += 1

        self.print("Copied embedding for %d words" % n_copies)
        self.model.decoder.word_lut = untrained_word_emb

        # now we load the language embeddings ...
        if pretrained_lang_emb and untrained_lang_emb and 'langs' in chkpoint_dict:
            for lang in self.dicts['langs']:

                untrained_id = self.dicts['langs'][lang]
                if lang in chkpoint_dict['langs']:
                    pretrained_id = chkpoint_dict['langs'][lang]
                    untrained_lang_emb.weight.data[untrained_id].copy_(pretrained_lang_emb.weight.data[pretrained_id])

        self.model.decoder.language_embeddings = untrained_lang_emb

    def warm_up(self):
        """
        Warmup the memory allocator, by attempting to fit the largest batch
        :return:
        """

        # the reporter is needed for the memory-profiling calls further below
        if self.opt.memory_profiling:
            from pytorch_memlab import MemReporter
            reporter = MemReporter()

        batch = self.train_data[0].get_largest_batch() if isinstance(self.train_data, list) \
            else self.train_data.get_largest_batch()
        opt = self.opt

        if self.cuda:
            batch.cuda(fp16=False)

        self.model.train()
        self.loss_function.train()
        self.model.zero_grad()
        oom = False

        if self.opt.memory_profiling:
            self.print("Input size: ")
            self.print(batch.size, batch.src_size, batch.tgt_size)

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        try:
            with autocast():
                targets = batch.get('target_output')
                tgt_mask = None
                outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                     zero_encoder=opt.zero_encoder,
                                     mirror=opt.mirror_loss, streaming_state=streaming_state,
                                     nce=opt.nce)

                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs, targets, model=self.model)
                loss_data = loss_dict['data']
                loss = loss_dict['loss']  # a little trick to avoid gradient overflow with fp16
                full_loss = loss

                if opt.ctc_loss > 0.0:
                    ctc_loss = self.ctc_loss_function(outputs, targets)
                    ctc_loss_data = ctc_loss.item()
                    full_loss = full_loss + opt.ctc_loss * ctc_loss

                if opt.mirror_loss:
                    rev_loss = loss_dict['rev_loss']
                    mirror_loss = loss_dict['mirror_loss']
                    full_loss = full_loss + rev_loss + mirror_loss

                # reconstruction loss
                if opt.reconstruct:
                    rec_loss = loss_dict['rec_loss']
                    full_loss = full_loss + rec_loss

                optimizer = self.optim.optimizer

                if self.opt.memory_profiling:
                    reporter.report(verbose=True)

                # for obj in gc.get_objects():
                #     try:
                #         if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                #             # print(varname(obj))
                #             # we can rule out parameter cost later
                #             # if 'parameter' not in type(obj):
                #             # if len(obj.shape) == 3:
                #             # if not isinstance(obj, torch.nn.parameter.Parameter):
                #             #     tensor = obj
                #             #     numel = tensor.
                #             print(type(obj), obj.type(), obj.size())
                #     except:
                #         pass

                # print("Memory profiling complete.")
                # print(torch.cuda.memory_summary())
                # exit()

            self.grad_scaler.scale(full_loss).backward()
            # if self.cuda:
            #     with amp.scale_loss(full_loss, optimizer) as scaled_loss:
            #         scaled_loss.backward()
            # else:
            #     loss.div_(batch.tgt_size).backward()

            if self.opt.memory_profiling:
                print('========= after backward =========')
                reporter.report(verbose=True)

            self.model.zero_grad()
            self.optim.zero_grad()
            # self.optim.step()
            # self.optim.reset()

        except RuntimeError as e:
            if 'out of memory' in str(e):
                oom = True
            else:
                raise e

        if oom:
            print("[INFO] Warning: out-of-memory in warming up. "
                  "This is due to the largest batch is too big for the GPU.",
                  flush=True)
        else:
            self.print("[INFO] Warming up successfully.", flush=True)

        if self.opt.memory_profiling:
            if hasattr(torch.cuda, 'memory_summary'):
                print(torch.cuda.memory_summary())
            exit()

    def save(self, epoch, valid_ppl, itr=None):

        opt = self.opt
        model = self.model
        dicts = self.dicts

        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_state_dict = self.model.module.state_dict()
        else:
            model_state_dict = self.model.state_dict()
        optim_state_dict = self.optim.state_dict()

        if itr:
            itr_state_dict = itr.state_dict()
        else:
            itr_state_dict = None

        #  drop a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'dicts': dicts,
            'opt': opt,
            'epoch': epoch,
            'itr': itr_state_dict,
            'optim': optim_state_dict,
            'scaler': self.grad_scaler.state_dict()
        }

        file_name = '%s_ppl_%.6f_e%.2f.pt' % (opt.save_model, valid_ppl, epoch)
        print('Writing to %s' % file_name)
        torch.save(checkpoint, file_name)

        # check the save directory here
        checkpoint_dir = os.path.dirname(opt.save_model)
        existed_save_files = checkpoint_paths(checkpoint_dir)
        for save_file in existed_save_files[opt.keep_save_files:]:
            print(" * Deleting old save file %s ...." % save_file)
            os.remove(save_file)

    def eval(self, data):
        self.print("[INFO] Running cross-entropy evaluation...", flush=True)
        opt = self.opt
        rank = self.device
        world_size = self.world_size

        # the data iterator creates an epoch iterator
        # for eval, we use fill_value=False
        data_iterator = generate_data_iterator(data, rank, world_size, seed=self.opt.seed,
                                               num_workers=opt.num_workers, epoch=1, buffer_size=opt.buffer_size,
                                               fill_value=False)
        epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)

        data_size = len(epoch_iterator)
        i = 0

        self.model.eval()
        self.loss_function.eval()
        # self.model.module.reset_states()

        total_loss = zero_tensor()
        total_words = zero_tensor()

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        with torch.no_grad():
            while not data_iterator.end_of_epoch():
                samples = next(epoch_iterator)

                if samples:
                    with autocast():
                        batch = prepare_sample(samples, device=self.device)
                        targets = batch.get('target_output')
                        tgt_mask = targets.ne(onmt.constants.PAD)
                        outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                             mirror=opt.mirror_loss, streaming_state=streaming_state, nce=opt.nce)

                        outputs['tgt_mask'] = tgt_mask
                        with autocast(enabled=False):
                            loss_dict = self.loss_function(outputs, targets, model=self.model, eval=True)
                            loss_data = loss_dict['data']

                    total_loss.add_(loss_data)
                    total_words.add_(batch.tgt_size)
                    i = i + 1

        # allreduce the total loss and total words from other processes
        self.all_reduce(total_loss, op=dist.ReduceOp.SUM, group=self.group)
        self.all_reduce(total_words, op=dist.ReduceOp.SUM, group=self.group)

        self.model.train()
        self.loss_function.train()

        return total_loss / total_words

    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        # Clear the gradients of the model
        self.model.zero_grad()
        # self.model.module.reset_states()

        dataset = train_data
        data_iterator = generate_data_iterator(dataset, self.rank, self.world_size,
                                               seed=self.opt.seed, num_workers=opt.num_workers,
                                               epoch=epoch, buffer_size=opt.buffer_size)

        # TODO: fix resume which is currently buggy
        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(not streaming, pin_memory=opt.pin_memory)

        total_tokens, total_loss, total_words = zero_tensor(), zero_tensor(), zero_tensor()
        total_non_pads = zero_tensor()
        report_loss, report_tgt_words = zero_tensor(), zero_tensor()
        report_ctc_loss = zero_tensor()
        report_src_words = zero_tensor()
        report_rec_loss, report_rev_loss, report_mirror_loss = zero_tensor(), zero_tensor(), zero_tensor()
        start = time.time()
        n_samples = len(data_iterator)

        counter = 0
        num_accumulated_words = zero_tensor()
        num_accumulated_sents = zero_tensor()
        grad_div = 1

        nan = False
        nan_counter = zero_tensor()

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        # i = number of yielded items from the dataset at the moment
        i = data_iterator.iterations_in_epoch if not isinstance(train_data, list) else epoch_iterator.n_yielded
        # i = i * self.world_size  #

        while not data_iterator.end_of_epoch():

            # curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            # TODO: move everything to the multiGPU trainer
            samples = next(epoch_iterator)

            batch = prepare_sample(samples, device=self.device)

            if opt.streaming:
                if train_data.is_new_stream():
                    streaming_state = self.model.init_stream()
            else:
                streaming_state = None

            # TODO: dealing with oom during distributed training
            oom = False

            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                counter = counter + 1
                reduction_disabled = not (counter >= opt.update_frequency or i == (n_samples - 1))

                def maybe_no_sync():
                    # if not reduction_disabled and isinstance(self.model, DDP_model):
                    #     return self.model.no_sync()
                    # else:
                    #     # when we dont reach the updating step, we do not need to synchronize the gradients
                    #     # thus disabling the backward grad sync to improve speed
                    return contextlib.ExitStack()  # dummy contextmanager

                with maybe_no_sync():
                    with autocast():
                        targets = batch.get('target_output')
                        tgt_mask = targets.ne(onmt.constants.PAD)
                        outputs = self.model(batch, streaming=opt.streaming, target_mask=tgt_mask,
                                             zero_encoder=opt.zero_encoder,
                                             mirror=opt.mirror_loss, streaming_state=streaming_state,
                                             nce=opt.nce)

                        batch_size = batch.size
                        outputs['tgt_mask'] = tgt_mask

                        # Loss functions must be computed in FP32 regions
                        with autocast(enabled=False):
                            loss_dict = self.loss_function(outputs, targets, model=self.model)
                        loss_data = loss_dict['data']
                        loss = loss_dict['loss']  # a little trick to avoid gradient overflow with fp16
                        full_loss = loss

                        if opt.ctc_loss > 0.0:
                            with autocast(enabled=False):
                                ctc_loss = self.ctc_loss_function(outputs, targets)
                            ctc_loss_data = ctc_loss.item()
                            full_loss = full_loss + opt.ctc_loss * ctc_loss

                        if opt.mirror_loss:
                            rev_loss = loss_dict['rev_loss']
                            rev_loss_data = loss_dict['rev_loss_data']
                            mirror_loss = loss_dict['mirror_loss']
                            full_loss = full_loss + rev_loss + mirror_loss
                            mirror_loss_data = loss_dict['mirror_loss'].item()
                        else:
                            rev_loss_data = None
                            mirror_loss_data = 0

                        # reconstruction loss
                        if opt.reconstruct:
                            rec_loss = loss_dict['rec_loss']
                            full_loss = full_loss + rec_loss
                            rec_loss_data = loss_dict['rec_loss_data']
                        else:
                            rec_loss_data = None

                        if opt.lfv_multilingual:
                            lid_logits = outputs['lid_logits']
                            lid_labels = batch.get('target_lang')
                            lid_loss_function = self.loss_function.get_loss_function('lid_loss')
                            lid_loss = lid_loss_function(lid_logits, lid_labels)
                            full_loss = full_loss + lid_loss

                        optimizer = self.optim.optimizer

                        # When the batch size is large, each gradient step can easily overflow in fp16
                        # Normalizing the loss by grad_div helps prevent this
                        full_loss.div_(grad_div)

                    # grad scaler has to be done outside of the autocast
                    self.grad_scaler.scale(full_loss).backward()

                del outputs

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('[WARNING]: ran out of memory on GPU %d' % self.rank, flush=True)
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                    raise e
                else:
                    raise e

            batch_size = batch.size

            src_size = batch.src_size
            tgt_size = batch.tgt_size

            if loss != loss:
                # catching NAN problem
                # oom = True
                self.model.zero_grad()
                self.optim.zero_grad()
                num_accumulated_words.zero_()
                num_accumulated_sents.zero_()

                # print("Warning!!! Loss is Nan")
                # nan_counter.add_(1)
                # if nan_counter >= 15:
                #     raise ValueError("Training stopped because of multiple NaN occurence. "
                #                      "For ASR, using the Relative Transformer is more stable and recommended.")
            else:
                # nan_counter = 0
                num_words = tgt_size
                report_loss.add_(loss_data)
                total_loss.add_(loss_data)
                report_tgt_words.add_(num_words)
                report_src_words.add_(src_size)
                total_words.add_(num_words)

                if opt.ctc_loss > 0.0:
                    report_ctc_loss.add_(ctc_loss_data)
                if opt.reconstruct:
                    report_rec_loss.add_(rec_loss_data)
                if opt.mirror_loss:
                    report_rev_loss.add_(rev_loss_data)
                    report_mirror_loss.add_(mirror_loss_data)

                num_accumulated_words.add_(tgt_size)
                num_accumulated_sents.add_(batch_size)

            # self.all_reduce(nan_counter, op=dist.ReduceOp.SUM, group=self.group)
            # if we have NaN in one process, then restart all
            # if nan_counter.item() > 0:
            #     self.optim.zero_grad()
            #     self.model.zero_grad()
            #     counter = 0
            #     nan_counter.zero_()

            # We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency:
                update_flag = True
            elif i == n_samples - 1:  # update for the last minibatch
                update_flag = True

            if update_flag:
                # accumulated-gradient case: enough mini-batches have been collected, so update now
                self.all_reduce(num_accumulated_words, op=dist.ReduceOp.SUM, group=self.group)
                # self.all_reduce(nan_counter, op=dist.ReduceOp.SUM, group=self.group)

                # grad_denom = 1.0 / grad_div
                #
                # if self.opt.normalize_gradient:
                #     grad_denom = num_accumulated_words.item() * grad_denom

                # the gradient is scaled by world size, so in order to match the model without multiGPU
                # we rescale the model parameters w.r.t the world size
                # grad_denom = grad_denom / self.world_size

                # self.grad_scaler.unscale_(self.optim.optimizer)

                # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                # normalize_gradients(self.model.parameters(), grad_denom)

                # Update the parameters.
                if self.opt.max_grad_norm > 0:
                    self.grad_scaler.unscale_(self.optim.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.opt.max_grad_norm)

                self.optim.step(scaler=self.grad_scaler)
                counter = 0

                self.grad_scaler.update()
                self.optim.zero_grad()
                self.model.zero_grad()

                num_accumulated_words.zero_()
                num_accumulated_sents.zero_()
                nan_counter.zero_()

                num_updates = self.optim._step
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))

                    if self.is_main():
                        print('Validation perplexity: %g' % valid_ppl)
                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                        self.save(ep, valid_ppl, itr=data_iterator)

            num_words = tgt_size

            # report_tgt_words.add_(num_words)
            # report_src_words.add_(src_size)
            # total_words.add_(num_words)
            # total_tokens += batch.get('target_output').nelement()
            # total_non_pads += batch.get('target_output').ne(onmt.constants.PAD).sum().item()
            # batch_efficiency = total_non_pads / total_tokens

            # increase i by world size
            i = i + self.world_size

            # control the index a little bit to ensure the log is always printed
            if i == self.world_size or (i % opt.log_interval < self.world_size):

                self.all_reduce(report_loss, op=dist.ReduceOp.SUM, group=self.group)
                self.all_reduce(report_tgt_words, op=dist.ReduceOp.SUM, group=self.group)
                self.all_reduce(report_src_words, op=dist.ReduceOp.SUM, group=self.group)
                # Looks like this op is bugged for some reason?
                if opt.ctc_loss > 0:
                    self.all_reduce(report_ctc_loss, op=dist.ReduceOp.SUM, group=self.group)

                if self.is_main():
                    log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                                  (epoch, i, len(data_iterator),
                                   math.exp(report_loss.item() / report_tgt_words.item())))

                    if opt.reconstruct:
                        self.all_reduce(report_rec_loss, op=dist.ReduceOp.SUM, group=self.group)
                        rec_ppl = math.exp(report_rec_loss.item() / report_src_words.item())
                        log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                    if opt.mirror_loss:
                        self.all_reduce(report_rev_loss, op=dist.ReduceOp.SUM, group=self.group)
                        rev_ppl = math.exp(report_rev_loss.item() / report_tgt_words.item())
                        log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                        log_string += (" mir_loss: %6.2f ; " % (report_mirror_loss / report_tgt_words))

                    if opt.ctc_loss > 0.0:
                        ctc_loss = report_ctc_loss.item() / report_tgt_words.item()
                        log_string += (" ctcloss: %8.2f ; " % ctc_loss)

                    log_string += ("lr: %.7f ; updates: %7d; " %
                                   (self.optim.getLearningRate(),
                                    self.optim._step))

                    log_string += ("%5.0f src tok/s; %5.0f tgt tok/s; " %
                                   (report_src_words.item() / (time.time() - start),
                                    report_tgt_words.item() / (time.time() - start)))

                    log_string += ("%s elapsed" %
                                   str(datetime.timedelta(seconds=int(time.time() - self.start_time))))

                    self.print(log_string, flush=True)

                report_loss.zero_()
                report_tgt_words.zero_()
                report_src_words.zero_()
                report_rec_loss.zero_()
                report_rev_loss.zero_()
                report_mirror_loss.zero_()
                report_ctc_loss.zero_()
                start = time.time()



        return total_loss / total_words

    # def run(self, save_file=None):
    def run(self, checkpoint=None):

        opt = self.opt

        if checkpoint is not None:

            # TODO: support loading checkpoints for each process
            prec_opt = checkpoint['opt'] if 'opt' in checkpoint else None

            if not opt.reset_optim:

                # Only load the progress when we use the same optimizer
                if 'itr' in checkpoint:
                    itr_progress = checkpoint['itr']
                else:
                    itr_progress = None

                resume = True
                start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 1
                if start_epoch is None:
                    start_epoch = 1
            else:
                itr_progress = None
                resume = False
                start_epoch = 1

            # optim_state_dict = checkpoint['optim']
            # # del checkpoint['optim']
            del checkpoint
        else:
            itr_progress = None

            resume = False
            start_epoch = 1

        if opt.load_encoder_from:
            self.load_encoder_weight(opt.load_encoder_from)
        #
        if opt.load_decoder_from:
            self.load_decoder_weight(opt.load_decoder_from)

        # if we are on a GPU: warm up the memory allocator
        if self.cuda:
            self.warm_up()

        valid_loss = self.eval(self.valid_data)
        valid_ppl = math.exp(min(valid_loss, 100))
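        # the loss is clamped at 100 before exponentiation so math.exp cannot overflow for a
        # diverged model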

        if opt.starting_step > 0:
            self.optim.override_starting_step(opt.starting_step)

        if self.is_main():
            print('[INFO] Validation perplexity: %g' % valid_ppl, flush=True)

        self.start_time = time.time()

        for epoch in range(start_epoch, start_epoch + opt.epochs):
            self.print('')

            #  (1) train for one epoch on the training set
            train_loss = self.train_epoch(epoch, resume=resume, itr_progress=itr_progress)
            train_ppl = math.exp(min(train_loss, 100))
            self.print('[INFO] Train perplexity: %g' % train_ppl)

            #  (2) evaluate on the validation set
            valid_loss = self.eval(self.valid_data)
            valid_ppl = math.exp(min(valid_loss, 100))

            if self.is_main():
                print('[INFO] Validation perplexity: %g' % valid_ppl)
                self.save(epoch, valid_ppl)

            itr_progress = None
            resume = False
Example n. 4
0
    def __init__(self, device, train_data, valid_data, dicts, opt, setup_optimizer=True):
        """
        :param model:
        :param device: int (GPU id)
        :param loss_function:
        :param train_data:
        :param valid_data:
        :param dicts:
        :param opt:
        """

        self.device = device
        opt.node_rank = 0
        opt.nodes = 1
        self.world_size = len(opt.gpus)

        # in the case of single node distributed, it should equal self.device
        self.rank = self.device

        # make a group to later use with self.all_reduce
        self.group = dist.group.WORLD

        self.print("[INFO] Training Options:", opt)
        if self.world_size > 1:
            dist.init_process_group(backend='nccl', init_method='env://', world_size=self.world_size, rank=self.rank)
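            # with init_method='env://', torch.distributed reads MASTER_ADDR and MASTER_PORT from the
            # environment; world_size and rank are supplied explicitly here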

        self.model = None

        if self.rank == 0:
            self.train_data = train_data
            self.valid_data = valid_data
        else:
            # Do we really need to deepcopy the data instances (which could easily cause a memory leak)?
            self.train_data = copy.deepcopy(train_data)
            self.valid_data = copy.deepcopy(valid_data)

        self.dicts = dicts
        self.opt = opt
        self.cuda = (len(opt.gpus) >= 1 and opt.gpus[0] >= 0)

        assert self.cuda, "[ERROR] Training is only available on GPUs."

        self.start_time = 0

        # setting up models and others
        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages, opt.label_smoothing, opt.fast_xentropy)
            # note: self.loss_function does not exist yet at this point; the LID loss is attached
            # further below, right after the main loss function has been assigned

        torch.manual_seed(self.opt.seed)

        # note: we must start creating models only after creating the processes;
        # for some reason, passing a pre-created model to a process raises a "pickle" error
        if not opt.fusion:

            if self.is_main():
                print("[INFO] Building models .... ", flush=True)
            model = build_model(opt, dicts)

            """ Building the loss function """
            if opt.ctc_loss > 0.0:
                from onmt.speech.ctc_loss import CTC
                self.ctc_loss_function = CTC(0.0, reduce=True)

            if opt.nce:
                from onmt.modules.nce.nce_loss import NCELoss
                loss_function = NCELoss(opt.model_size, dicts['tgt'].size(), noise_ratio=opt.nce_noise,
                                        logz=9, label_smoothing=opt.label_smoothing)
            else:
                loss_function = NMTLossFunc(opt.model_size, dicts['tgt'].size(),
                                            label_smoothing=opt.label_smoothing,
                                            mirror=opt.mirror_loss,
                                            fast_xentropy=opt.fast_xentropy)

            # This function replaces modules with more optimized counterparts so that the model can run faster
            # Currently experimenting with LayerNorm
            if not opt.memory_profiling:
                # distributed is required to convert BatchNorm to SyncBatchNorm for DDP
                optimize_model(model, distributed=(self.world_size > 1))

        init_model_parameters(model, opt)
        self.model = model
        self.loss_function = loss_function
        if opt.lfv_multilingual:
            # attach the LID loss only after the main loss function has been built
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')
        self.grad_scaler = torch.cuda.amp.GradScaler()
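        # sketch of the usual native-AMP pattern this scaler is meant for (the actual training step
        # lives elsewhere):
        #   with torch.cuda.amp.autocast():
        #       loss = model(...)
        #   self.grad_scaler.scale(loss).backward()
        #   self.grad_scaler.step(self.optim.optimizer)
        #   self.grad_scaler.update()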

        if opt.load_from:
            checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(checkpoint['model'])
            if 'scaler' in checkpoint and checkpoint['scaler'] is not None:
                self.grad_scaler.load_state_dict(checkpoint['scaler'])

        if self.cuda:
            torch.cuda.set_device(self.device)

            self.loss_function = self.loss_function.cuda(device=self.device)
            self.model = self.model.cuda(device=self.device)
            if opt.ctc_loss > 0.0:
                self.ctc_loss_function = self.ctc_loss_function.cuda(device=self.device)

            # Ensure that the distributed copies have the same initial parameters
            # Manual seed may not work the same for different GPU models.
            # if self.world_size > 1:
            #     params = [p for p in self.model.parameters()]
            #
            #     with torch.no_grad():
            #         if not self.is_main():
            #             # zero everything except for the main model
            #             for p in params:
            #                 p.zero_()
            #         else:
            #             for p in params:
            #                 p.add_(0)
            #
            # # run all_reduce to ensure that all models have exactly the same parameters
            # if self.world_size > 1:
            #     params = [p for p in self.model.parameters()]
            #     all_reduce_and_rescale_tensors(params, 1)

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if self.is_main():
                print("[INFO] Optimizer: ", self.optim.optimizer)

            if opt.load_from:
                if 'optim' in checkpoint and checkpoint['optim'] is not None and not opt.reset_optim:
                    self.optim.load_state_dict(checkpoint['optim'])

        if self.world_size > 1:
            # find_unused_parameters may be required for dropped layers (parameters that are not
            # connected to the graph for a given batch)
            find_unused_parameters = True

            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank],
                                                                   output_device=self.rank,
                                                                   find_unused_parameters=find_unused_parameters)
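            # note: find_unused_parameters=True makes DDP walk the autograd graph after every backward
            # pass to mark parameters that received no gradient, which adds some overhead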

        print("[INFO] Process %d ready." % self.rank, flush=True)
Example n. 5
0
    def __init__(self,
                 device,
                 train_data,
                 valid_data,
                 dicts,
                 opt,
                 setup_optimizer=True):
        """
        :param model:
        :param device: int (GPU id)
        :param loss_function:
        :param train_data:
        :param valid_data:
        :param dicts:
        :param opt:
        """

        self.device = device
        opt.node_rank = 0
        opt.nodes = 1
        self.world_size = len(opt.gpus)

        # in the case of single node distributed, it should equal self.device
        self.rank = self.device

        # make a group to later use with self.all_reduce
        self.group = dist.group.WORLD

        self.print("[INFO] Training Options:", opt)
        if self.world_size > 1:
            dist.init_process_group(backend='nccl',
                                    init_method='env://',
                                    world_size=self.world_size,
                                    rank=self.rank)

        self.model = None

        if self.rank == 0:
            self.train_data = train_data
            self.valid_data = valid_data
        else:
            # Do we really need to deepcopy the data instances (which could easily cause a memory leak)?
            self.train_data = copy.deepcopy(train_data)
            self.valid_data = copy.deepcopy(valid_data)

        self.dicts = dicts
        self.opt = opt
        self.cuda = (len(opt.gpus) >= 1 and opt.gpus[0] >= 0)

        assert self.cuda, "[ERROR] Training is only available on GPUs."

        self.start_time = 0

        # setting up models and others
        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages,
                                           opt.label_smoothing,
                                           opt.fast_xentropy)
            # note: self.loss_function does not exist yet at this point; the LID loss is attached
            # further below, right after the main loss function has been assigned

        torch.manual_seed(self.opt.seed)

        # note: we must start creating models only after creating the processes;
        # for some reason, passing a pre-created model to a process raises a "pickle" error
        if not opt.fusion:

            if self.is_main():
                print("[INFO] Building models .... ", flush=True)
            model = build_model(opt, dicts)
            """ Building the loss function """
            if opt.ctc_loss > 0.0:
                from onmt.speech.ctc_loss import CTC
                self.ctc_loss_function = CTC(0.0, reduce=True)

            if opt.nce:
                from onmt.modules.nce.nce_loss import NCELoss
                loss_function = NCELoss(opt.model_size,
                                        dicts['tgt'].size(),
                                        noise_ratio=opt.nce_noise,
                                        logz=9,
                                        label_smoothing=opt.label_smoothing)
            else:
                loss_function = NMTLossFunc(
                    opt.model_size,
                    dicts['tgt'].size(),
                    label_smoothing=opt.label_smoothing,
                    mirror=opt.mirror_loss,
                    fast_xentropy=opt.fast_xentropy)

            # This function replaces modules with more optimized counterparts so that the model can run faster
            # Currently experimenting with LayerNorm
            if not opt.memory_profiling:
                # distributed is required to convert BatchNorm to SyncBatchNorm for DDP
                optimize_model(model, distributed=(self.world_size > 1))

        init_model_parameters(model, opt)
        self.model = model
        self.loss_function = loss_function
        if opt.lfv_multilingual:
            # attach the LID loss only after the main loss function has been built
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')
        # self.grad_scaler = torch.cuda.amp.GradScaler()

        if self.cuda:
            torch.cuda.set_device(self.device)

            self.loss_function = self.loss_function.cuda(device=self.device)
            self.model = self.model.cuda(device=self.device)
            if opt.ctc_loss > 0.0:
                self.ctc_loss_function = self.ctc_loss_function.cuda(
                    device=self.device)

        if opt.load_from:
            checkpoint = torch.load(opt.load_from,
                                    map_location=lambda storage, loc: storage)
            # keep the checkpoint options around to compare precision settings when loading amp state
            prec_opt = checkpoint['opt'] if 'opt' in checkpoint else None

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if self.is_main():
                print("[INFO] Optimizer: ", self.optim.optimizer)

            if opt.load_from:
                if 'optim' in checkpoint and checkpoint[
                        'optim'] is not None and not opt.reset_optim:
                    self.optim.load_state_dict(checkpoint['optim'])

        if not self.opt.fp16:
            opt_level = "O0"
            keep_batchnorm_fp32 = False
        elif self.opt.fp16_mixed:
            opt_level = "O1"
            keep_batchnorm_fp32 = None
        else:
            opt_level = "O2"
            keep_batchnorm_fp32 = False
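        # apex amp opt levels, roughly: O0 = pure fp32, O1 = mixed precision via automatic casting,
        # O2 = "almost fp16" with fp32 master weights (see the apex documentation for exact behaviour)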

        self.opt_level = opt_level

        if self.cuda:
            self.model, self.optim.optimizer = amp.initialize(
                self.model,
                self.optim.optimizer,
                opt_level=opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32,
                loss_scale="dynamic",
                verbosity=1 if self.opt.verbose else 0)
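            # per the apex documentation, amp.initialize is called after the model is on the GPU and
            # before any DistributedDataParallel wrapping, which matches the order here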

        if opt.load_from:
            self.model.load_state_dict(checkpoint['model'])
            if prec_opt is not None and hasattr(prec_opt, "fp16_mixed"):
                # Only load amp information if the mode is the same
                # Maybe it's better to allow switching between optimization modes?
                if opt.fp16_mixed == prec_opt.fp16_mixed and opt.fp16 == prec_opt.fp16:
                    if 'amp' in checkpoint:
                        try:
                            amp.load_state_dict(checkpoint['amp'])
                        except Exception:
                            # loading the amp state can fail
                            pass

        if self.world_size > 1:
            # find_unused_parameters may be required for dropped layers (parameters that are not
            # connected to the graph for a given batch)
            # find_unused_parameters = True

            self.model = DDP(self.model,
                             delay_allreduce=True,
                             gradient_average=False)
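            # apex DDP with delay_allreduce=True all-reduces the gradients once after backward completes;
            # with gradient_average=False they are summed rather than divided by world_size, so any
            # averaging presumably happens elsewhere (e.g. in the loss normalization)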

        print("[INFO] Process %d ready." % self.rank, flush=True)