예제 #1
0
    def __init__(self, model, loss_function, train_data, valid_data, dicts, opt, setup_optimizer=True):
        """
        Set up a single-process trainer on top of the state prepared by the parent class.

        :param model: network to train; its parameters are initialised here
        :param loss_function: training criterion (moved to the GPU when CUDA is used)
        :param train_data: training dataset
        :param valid_data: validation dataset
        :param dicts: vocabularies; dicts['tgt'].size() is handed to the CTC loss
        :param opt: training options (gpus, seed, fp16, fp16_mixed, ctc_loss, verbose, ...)
        :param setup_optimizer: when True, build onmt.Optim and, on CUDA, run amp.initialize
        """
        # NOTE(review): self.opt, self.cuda, self.model, self.loss_function and the data
        # attributes are presumably assigned by the parent constructor - confirm.
        super().__init__(model, loss_function, train_data, valid_data, dicts, opt)

        self.n_gpus = len(self.opt.gpus)

        # One flag decides both the creation and the device placement of the CTC loss,
        # so the loss can never be created and then left behind on the CPU.
        use_ctc = opt.ctc_loss != 0
        if use_ctc:
            from onmt.speech.ctc_loss import CTC
            self.ctc_loss_function = CTC(dicts['tgt'].size(), opt.model_size, 0.0, reduce=True)

        init_model_parameters(model, opt)

        if self.cuda:
            torch.cuda.set_device(self.opt.gpus[0])
            # a negative seed leaves the random generator unseeded
            if self.opt.seed >= 0:
                torch.manual_seed(self.opt.seed)
            self.loss_function = self.loss_function.cuda()
            self.model = self.model.cuda()
            if use_ctc:
                self.ctc_loss_function = self.ctc_loss_function.cuda()

        if setup_optimizer:
            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            # Map the precision flags onto the amp optimisation level:
            # no fp16 -> "O0", fp16 + fp16_mixed -> "O1", fp16 only -> "O2".
            if not self.opt.fp16:
                opt_level = "O0"
                keep_batchnorm_fp32 = False
            elif self.opt.fp16_mixed:
                opt_level = "O1"
                keep_batchnorm_fp32 = None
            else:
                opt_level = "O2"
                keep_batchnorm_fp32 = False

            if self.cuda:
                self.model, self.optim.optimizer = amp.initialize(self.model,
                                                                  self.optim.optimizer,
                                                                  opt_level=opt_level,
                                                                  keep_batchnorm_fp32=keep_batchnorm_fp32,
                                                                  loss_scale="dynamic",
                                                                  verbosity=1 if self.opt.verbose else 0)

            print(self.optim.optimizer)

        # An ugly hack to switch between align right and align left:
        # models flagged as `relative` get right-aligned sources and left-aligned targets.
        if getattr(self.model, 'relative', False):
            for dataset in (self.train_data, self.valid_data):
                dataset.src_align_right = True
                dataset.tgt_align_right = False
예제 #2
0
    def __init__(self, model, loss_function, train_data, valid_data, dicts,
                 opt):
        """
        Initialise the model weights and hand the model to a MultiprocessingRunner.

        :param model: network to train; initialised here and passed to the runner
        :param loss_function: training criterion, shared with the runner
        :param train_data: training dataset
        :param valid_data: validation dataset
        :param dicts: vocabularies
        :param opt: training options (gpus, load_encoder_from, load_decoder_from,
                    lfv_multilingual, ...)
        """
        init_model_parameters(model, opt)

        # The model is bound to the instance only while pretrained weights are loaded
        # (presumably the loaders read self.model - TODO confirm); it is reset to None
        # right afterwards and the runner below receives the model directly.
        self.model = model

        if opt.load_encoder_from:
            self.load_encoder_weight(opt.load_encoder_from)

        if opt.load_decoder_from:
            self.load_decoder_weight(opt.load_decoder_from)

        self.model = None

        # plain bookkeeping
        self.train_data, self.valid_data = train_data, valid_data
        self.dicts = dicts
        self.opt = opt

        # CUDA is used when the first listed GPU id is non-negative
        has_gpu_entry = len(opt.gpus) >= 1
        self.cuda = (has_gpu_entry and opt.gpus[0] >= 0)

        self.loss_function = loss_function
        self.start_time = 0

        # optional language-identification loss, registered under the name 'lid_loss'
        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            self.loss_function.add_loss_function(
                CrossEntropyLIDLoss(opt.n_languages,
                                    opt.label_smoothing,
                                    opt.fast_xentropy),
                'lid_loss')

        self.model_wrapper = MultiprocessingRunner(opt, model, loss_function,
                                                   device_ids=opt.gpus)
예제 #3
0
    def run(self, checkpoint=None):
        """
        Train for opt.epochs epochs, optionally resuming from a checkpoint.

        :param checkpoint: None for a fresh run, otherwise a dict with a 'model' entry and
                           optional 'optim', 'opt', 'amp', 'itr' and 'epoch' entries.
                           The 'model' and 'optim' entries are removed from the dict once
                           they have been consumed.
        """
        opt = self.opt

        # Defaults for a run that does not resume the optimizer / data position.
        itr_progress = None
        resume = False
        start_epoch = 1

        if checkpoint is not None:
            self.model.load_state_dict(checkpoint['model'])
            prev_opt = checkpoint.get('opt')

            if not opt.reset_optim:
                print("* Loading optimizer states ... ")
                self.optim.load_state_dict(checkpoint['optim'])

                # Only load amp information if the precision mode is the same as in the
                # run that produced the checkpoint.
                same_precision = (prev_opt is not None
                                  and hasattr(prev_opt, "fp16_mixed")
                                  and opt.fp16_mixed == prev_opt.fp16_mixed
                                  and opt.fp16 == prev_opt.fp16)
                if same_precision and 'amp' in checkpoint:
                    amp.load_state_dict(checkpoint['amp'])

                # Only load the progress when we use the same optimizer
                itr_progress = checkpoint.get('itr')
                resume = True

                start_epoch = checkpoint.get('epoch')
                if start_epoch is None:
                    start_epoch = 1

            # Drop the large state dicts. pop() instead of del so that a model-only
            # checkpoint (no 'optim' entry) used with reset_optim does not raise KeyError.
            checkpoint.pop('model', None)
            checkpoint.pop('optim', None)
            del checkpoint
        else:
            print('Initializing model parameters')
            init_model_parameters(self.model, opt)

        if opt.load_encoder_from:
            self.load_encoder_weight(opt.load_encoder_from)

        if opt.load_decoder_from:
            self.load_decoder_weight(opt.load_decoder_from)

        # if we are on a GPU: warm up the memory allocator, then report the
        # validation perplexity before any training step
        if self.cuda:
            self.warm_up()

            valid_loss = self.eval(self.valid_data)
            # the loss is clamped to 100 before exponentiation to avoid overflow
            valid_ppl = math.exp(min(valid_loss, 100))
            print('Validation perplexity: %g' % valid_ppl)

        self.start_time = time.time()

        if opt.starting_step > 0:
            self.optim.override_starting_step(opt.starting_step)

        if opt.override_ctc_loss >= 0:
            opt.ctc_loss = opt.override_ctc_loss

        for epoch in range(start_epoch, start_epoch + opt.epochs):
            print('')

            #  (1) train for one epoch on the training set
            train_loss = self.train_epoch(epoch,
                                          resume=resume,
                                          itr_progress=itr_progress)
            train_ppl = math.exp(min(train_loss, 100))
            print('Train perplexity: %g' % train_ppl)

            #  (2) evaluate on the validation set
            valid_loss = self.eval(self.valid_data)
            valid_ppl = math.exp(min(valid_loss, 100))
            print('Validation perplexity: %g' % valid_ppl)

            self.save(epoch, valid_ppl)

            # resume information only applies to the first epoch after loading
            itr_progress = None
            resume = False
예제 #4
0
    def run(self, checkpoint=None):
        """
        Train the auto-encoder for opt.epochs epochs, optionally resuming from a checkpoint.

        :param checkpoint: None for a fresh run, otherwise a dict with a 'model' entry and
                           optional 'optim', 'opt', 'amp', 'itr' and 'epoch' entries.
                           The 'model' and 'optim' entries are removed from the dict once
                           they have been consumed.
        """
        opt = self.opt

        # Defaults for a run that does not resume the optimizer / data position.
        itr_progress = None
        resume = False
        start_epoch = 1

        if checkpoint is not None:
            self.model_ae.load_state_dict(checkpoint['model'])
            prev_opt = checkpoint.get('opt')

            if not opt.reset_optim:
                print("* Loading optimizer states ... ")
                self.optim_ae.load_state_dict(checkpoint['optim'])

                # Only load amp information if the precision mode is the same as in the
                # run that produced the checkpoint.
                same_precision = (prev_opt is not None
                                  and hasattr(prev_opt, "fp16_mixed")
                                  and opt.fp16_mixed == prev_opt.fp16_mixed
                                  and opt.fp16 == prev_opt.fp16)
                if same_precision and 'amp' in checkpoint:
                    amp.load_state_dict(checkpoint['amp'])

                # Only load the progress when we use the same optimizer
                itr_progress = checkpoint.get('itr')
                resume = True

                start_epoch = checkpoint.get('epoch')
                if start_epoch is None:
                    start_epoch = 1

            # Drop the large state dicts. pop() instead of del so that a model-only
            # checkpoint (no 'optim' entry) used with reset_optim does not raise KeyError.
            checkpoint.pop('model', None)
            checkpoint.pop('optim', None)
            del checkpoint
        else:
            print('Initializing model parameters')
            init_model_parameters(self.model_ae, opt)

        # Report both validation losses before any training step
        # (the allocator warm-up used by other trainers is not run here).
        valid_loss_ae, valid_loss_lat_dis = self.eval(self.valid_data)
        print('Validation loss ae: %g' % valid_loss_ae)
        print('Validation loss latent discriminator: %g' % valid_loss_lat_dis)

        self.start_time = time.time()

        for epoch in range(start_epoch, start_epoch + opt.epochs):
            print('')

            #  (1) train for one epoch on the training set
            train_loss_ae, train_loss_lat_dis, train_loss_adv = self.train_epoch(
                epoch, resume=resume, itr_progress=itr_progress)

            print('Train loss ae: %g' % train_loss_ae)
            print('Train loss latent discriminator: %g' % train_loss_lat_dis)
            print('Train loss adversarial : %g' % train_loss_adv)

            #  (2) evaluate on the validation set
            valid_loss_ae, valid_loss_lat_dis = self.eval(self.valid_data)
            print('Validation loss ae: %g' % valid_loss_ae)
            print('Validation loss latent discriminator: %g' %
                  valid_loss_lat_dis)

            # the auto-encoder validation loss is what gets recorded with the checkpoint
            self.save(epoch, valid_loss_ae)

            # resume information only applies to the first epoch after loading
            itr_progress = None
            resume = False
예제 #5
0
    def __init__(self, device, train_data, valid_data, dicts, opt, setup_optimizer=True):
        """
        Build the model, the loss function and (optionally) the optimizer for one training process.

        :param device: int (GPU id); also used as the rank of this process
        :param train_data: training dataset (deep-copied on every rank except 0)
        :param valid_data: validation dataset (deep-copied on every rank except 0)
        :param dicts: vocabularies; dicts['tgt'].size() sizes the loss function
        :param opt: training options
        :param setup_optimizer: when True, create onmt.Optim and restore its state from
                                the opt.load_from checkpoint unless opt.reset_optim is set
        :raises NotImplementedError: if opt.fusion is set (no model is built on that path)
        """

        self.device = device
        opt.node_rank = 0
        opt.nodes = 1
        self.world_size = len(opt.gpus)

        # in the case of single node distributed, it should equal self.device
        self.rank = self.device

        # make a group to later use with self.all_reduce
        self.group = dist.group.WORLD

        self.print("[INFO] Training Options:", opt)
        if self.world_size > 1:
            dist.init_process_group(backend='nccl', init_method='env://', world_size=self.world_size, rank=self.rank)

        self.model = None

        if self.rank == 0:
            self.train_data = train_data
            self.valid_data = valid_data
        else:
            # Do we really need to deepcopy the data instances (which could cause memory leak easily)
            self.train_data = copy.deepcopy(train_data)
            self.valid_data = copy.deepcopy(valid_data)

        self.dicts = dicts
        self.opt = opt
        self.cuda = (len(opt.gpus) >= 1 and opt.gpus[0] >= 0)

        assert self.cuda, "[ERROR] Training is only available on GPUs."

        self.start_time = 0

        torch.manual_seed(self.opt.seed)

        # note: we must start creating models after creating the processes
        # for some reason passing a pre-created model to a process creates a "pickle" error
        if opt.fusion:
            # Nothing below can work without a model and a loss function, so fail with
            # an explicit message instead of an unbound-variable error further down.
            raise NotImplementedError("[ERROR] opt.fusion is not supported by this trainer: "
                                      "no model or loss function is built for it.")

        if self.is_main():
            print("[INFO] Building models .... ", flush=True)
        model = build_model(opt, dicts)

        # Building the loss function(s)
        if opt.ctc_loss > 0.0:
            from onmt.speech.ctc_loss import CTC
            self.ctc_loss_function = CTC(0.0, reduce=True)

        if opt.nce:
            from onmt.modules.nce.nce_loss import NCELoss
            loss_function = NCELoss(opt.model_size, dicts['tgt'].size(), noise_ratio=opt.nce_noise,
                                    logz=9, label_smoothing=opt.label_smoothing)
        else:
            loss_function = NMTLossFunc(opt.model_size, dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing,
                                        mirror=opt.mirror_loss,
                                        fast_xentropy=opt.fast_xentropy)

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        if not opt.memory_profiling:
            # distributed is required to convert BatchNorm to SyncBatchNorm for DDP
            optimize_model(model, distributed=(self.world_size > 1))

        init_model_parameters(model, opt)
        self.model = model
        self.loss_function = loss_function
        self.grad_scaler = torch.cuda.amp.GradScaler()

        # The language-identification loss must be registered only after
        # self.loss_function exists, and before the loss is moved to the GPU.
        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages, opt.label_smoothing, opt.fast_xentropy)
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')

        if opt.load_from:
            # map_location keeps every tensor on the CPU while loading
            checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(checkpoint['model'])
            if 'scaler' in checkpoint and checkpoint['scaler'] is not None:
                self.grad_scaler.load_state_dict(checkpoint['scaler'])

        if self.cuda:
            torch.cuda.set_device(self.device)

            self.loss_function = self.loss_function.cuda(device=self.device)
            self.model = self.model.cuda(device=self.device)
            if opt.ctc_loss > 0.0:
                self.ctc_loss_function = self.ctc_loss_function.cuda(device=self.device)

            # NOTE: the explicit "zero the non-main replicas, then all-reduce" synchronisation
            # of the initial parameters is currently disabled in this trainer.

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if self.is_main():
                print("[INFO] Optimizer: ", self.optim.optimizer)

            if opt.load_from:
                if 'optim' in checkpoint and checkpoint['optim'] is not None and not opt.reset_optim:
                    self.optim.load_state_dict(checkpoint['optim'])

        if self.world_size > 1:
            # find_unused_parameters may be required for dropped layer (parameters that are not connected to
            # any particular graph)
            find_unused_parameters = True

            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank],
                                                                   output_device=self.rank,
                                                                   find_unused_parameters=find_unused_parameters)

        print("[INFO] Process %d ready." % self.rank, flush=True)
예제 #6
0
    def run(self, checkpoint=None):
        """
        Train for opt.epochs epochs starting at opt.start_epoch, optionally resuming
        from a checkpoint.

        :param checkpoint: None for a fresh run, otherwise a dict with a 'model' entry.
                           When opt.reset_optim is not set it must also hold 'optim' and
                           'epoch', and may hold 'opt', 'amp', 'batch_order' (together with
                           'iteration'), 'additional_batch_order' and
                           'additional_data_iteration'. The 'model' and 'optim' entries are
                           removed from the dict once they have been consumed.
        """
        opt = self.opt

        # Defaults for a run that does not resume in the middle of an epoch.
        batch_order = None
        iteration = 0
        resume = False

        if checkpoint is not None:
            self.model.load_state_dict(checkpoint['model'])
            prev_opt = checkpoint.get('opt')

            if not opt.reset_optim:
                self.optim.load_state_dict(checkpoint['optim'])

                # Only load amp information if the precision mode is the same as in the
                # run that produced the checkpoint.
                same_precision = (prev_opt is not None
                                  and hasattr(prev_opt, "fp16_mixed")
                                  and opt.fp16_mixed == prev_opt.fp16_mixed
                                  and opt.fp16 == prev_opt.fp16)
                if same_precision and 'amp' in checkpoint:
                    amp.load_state_dict(checkpoint['amp'])

                # continue with the batch that follows the last one recorded
                if 'batch_order' in checkpoint:
                    batch_order = checkpoint['batch_order']
                    iteration = checkpoint['iteration'] + 1
                opt.start_epoch = int(
                    math.floor(float(checkpoint['epoch'] + 1)))

                resume = True
                if len(self.additional_data) > 0:
                    if 'additional_batch_order' in checkpoint:
                        self.additional_batch_order = checkpoint[
                            'additional_batch_order']
                        self.additional_data_iteration = checkpoint[
                            'additional_data_iteration']
                    else:
                        self.init_additional_data()
            else:
                self.init_additional_data()

            # Drop the large state dicts. pop() instead of del so that a model-only
            # checkpoint (no 'optim' entry) used with reset_optim does not raise KeyError.
            checkpoint.pop('model', None)
            checkpoint.pop('optim', None)
            del checkpoint
        else:
            print('Initializing model parameters')
            init_model_parameters(self.model, opt)
            self.init_additional_data()

        if opt.load_encoder_from:
            self.load_encoder_weight(opt.load_encoder_from)

        if opt.load_decoder_from:
            self.load_decoder_weight(opt.load_decoder_from)

        # report the validation perplexity before any training step;
        # the loss is clamped to 100 before exponentiation to avoid overflow
        valid_loss = self.eval(self.valid_data)
        valid_ppl = math.exp(min(valid_loss, 100))
        print('Validation perplexity: %g' % valid_ppl)

        self.start_time = time.time()

        for epoch in range(opt.start_epoch, opt.start_epoch + opt.epochs):
            print('')

            #  (1) train for one epoch on the training set
            train_loss = self.train_epoch(epoch,
                                          resume=resume,
                                          batch_order=batch_order,
                                          iteration=iteration)
            train_ppl = math.exp(min(train_loss, 100))
            print('Train perplexity: %g' % train_ppl)

            #  (2) evaluate on the validation set
            valid_loss = self.eval(self.valid_data)
            valid_ppl = math.exp(min(valid_loss, 100))
            print('Validation perplexity: %g' % valid_ppl)

            self.save(epoch, valid_ppl)

            # resume information only applies to the first epoch after loading
            batch_order = None
            iteration = None
            resume = False
예제 #7
0
    def __init__(self,
                 device,
                 train_data,
                 valid_data,
                 dicts,
                 opt,
                 setup_optimizer=True):
        """
        Build the model, the loss function and (optionally) the amp-wrapped optimizer
        for one training process.

        :param device: int (GPU id); also used as the rank of this process
        :param train_data: training dataset (deep-copied on every rank except 0)
        :param valid_data: validation dataset (deep-copied on every rank except 0)
        :param dicts: vocabularies; dicts['tgt'].size() sizes the loss function
        :param opt: training options
        :param setup_optimizer: when True, create onmt.Optim, run amp.initialize and,
                                with more than one process, wrap the model in DDP
        :raises NotImplementedError: if opt.fusion is set (no model is built on that path)
        """

        self.device = device
        opt.node_rank = 0
        opt.nodes = 1
        self.world_size = len(opt.gpus)

        # in the case of single node distributed, it should equal self.device
        self.rank = self.device

        # make a group to later use with dist.all_reduce
        self.group = dist.group.WORLD

        self.print("[INFO] Training Options:", opt)
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=self.world_size,
                                rank=self.rank)

        self.model = None

        if self.rank == 0:
            self.train_data = train_data
            self.valid_data = valid_data
        else:
            self.train_data = copy.deepcopy(train_data)
            self.valid_data = copy.deepcopy(valid_data)

        self.dicts = dicts
        self.opt = opt
        self.cuda = (len(opt.gpus) >= 1 and opt.gpus[0] >= 0)

        assert self.cuda, "[ERROR] Training is only available on GPUs."

        self.start_time = 0

        torch.manual_seed(self.opt.seed)

        # note: we must start creating models after creating the processes
        # for some reason passing a pre-created model to a process creates a "pickle" error
        if opt.fusion:
            # Nothing below can work without a model and a loss function, so fail with
            # an explicit message instead of an unbound-variable error further down.
            raise NotImplementedError("[ERROR] opt.fusion is not supported by this trainer: "
                                      "no model or loss function is built for it.")

        if self.is_main():
            print("BUILDING MODEL .... ", flush=True)
        model = build_model(opt, dicts)

        # Building the loss function
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(
                dicts['tgt'].size(),
                label_smoothing=opt.label_smoothing,
                ctc_weight=opt.ctc_loss)
        elif opt.nce:
            from onmt.modules.nce.nce_loss import NCELoss
            loss_function = NCELoss(opt.model_size,
                                    dicts['tgt'].size(),
                                    noise_ratio=opt.nce_noise,
                                    logz=9,
                                    label_smoothing=opt.label_smoothing)
        else:
            loss_function = NMTLossFunc(
                opt.model_size,
                dicts['tgt'].size(),
                label_smoothing=opt.label_smoothing,
                mirror=opt.mirror_loss,
                fast_xentropy=opt.fast_xentropy)

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        if not opt.memory_profiling:
            # distributed is required to convert BatchNorm to SyncBatchNorm for DDP
            optimize_model(model, distributed=(self.world_size > 1))

        init_model_parameters(model, opt)
        self.model = model
        self.loss_function = loss_function

        # The language-identification loss must be registered only after
        # self.loss_function exists, and before the loss is moved to the GPU.
        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages,
                                           opt.label_smoothing,
                                           opt.fast_xentropy)
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')

        if self.cuda:
            torch.cuda.set_device(self.device)

            self.loss_function = self.loss_function.cuda(device=self.device)
            self.model = self.model.cuda(device=self.device)

            # Ensure that the distributed copies have the same initial parameters
            # Manual seed may not work the same for different GPU models.
            if self.world_size > 1:
                params = [p for p in self.model.parameters()]

                with torch.no_grad():
                    if not self.is_main():
                        # every replica except the main one starts from zeros ...
                        for p in params:
                            p.zero_()
                    else:
                        for p in params:
                            p.add_(0)

                # ... so the all-reduce below presumably leaves each replica holding
                # the main process' values - TODO confirm against the helper
                all_reduce_and_rescale_tensors(params, 1)

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if self.is_main():
                print("[INFO] Optimizer: ", self.optim.optimizer)

            # Map the precision flags onto the amp optimisation level:
            # no fp16 -> "O0", fp16 + fp16_mixed -> "O1", fp16 only -> "O2".
            if not self.opt.fp16:
                opt_level = "O0"
                keep_batchnorm_fp32 = False
            elif self.opt.fp16_mixed:
                opt_level = "O1"
                keep_batchnorm_fp32 = None
            else:
                opt_level = "O2"
                keep_batchnorm_fp32 = False

            if self.cuda:
                self.model, self.optim.optimizer = amp.initialize(
                    self.model,
                    self.optim.optimizer,
                    opt_level=opt_level,
                    keep_batchnorm_fp32=keep_batchnorm_fp32,
                    loss_scale="dynamic",
                    verbosity=1 if self.opt.verbose else 0)

            # wrap the model into DDP after initializing by amp
            if self.world_size > 1:
                # delay_allreduce is required to avoid allreduce error during backward pass
                # NOTE(review): DDP here is presumably the apex implementation rather than
                # torch.nn.parallel.DistributedDataParallel - confirm against the imports.
                self.model = DDP(self.model,
                                 delay_allreduce=True,
                                 gradient_average=False)

        print("[INFO] Process %d ready." % self.rank, flush=True)
예제 #8
0
def _load_multi_dataset_dir(data_dir):
    """Load the raw components of one sub-dataset directory (multi-dataset mode).

    Reads the module-level ``opt`` for ``data_format`` and ``concat``. Every file
    is looked up *inside* ``data_dir`` under the ``data.*`` naming scheme.

    :param data_dir: path of a ``train.<idx>`` / ``valid.<idx>`` directory
    :return: tuple ``(src_data, tgt_data, src_lang_data, tgt_lang_data,
             src_sizes, tgt_sizes)``; the optional entries (``tgt_data``,
             ``tgt_lang_data``, ``src_sizes``, ``tgt_sizes``) are ``None``
             when the corresponding file is absent.
    """
    from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
    from onmt.data.scp_dataset import SCPIndexDataset

    def in_dir(name):
        return os.path.join(data_dir, name)

    if opt.data_format in ['scp', 'scpmem']:
        # NOTE(review): torch.load unpickles the file; only use trusted data.
        audio_data = torch.load(in_dir("data.scp_path.pt"))
        src_data = SCPIndexDataset(audio_data, concat=opt.concat)
    else:
        src_data = MMapIndexedDataset(in_dir("data.src"))

    # The existence checks must probe the same files that are opened below
    # (the prefix passed to MMapIndexedDataset plus the '.bin' suffix, as in
    # the single-dataset branch of main()).
    if os.path.exists(in_dir('data.tgt.bin')):
        tgt_data = MMapIndexedDataset(in_dir("data.tgt"))
    else:
        tgt_data = None

    src_lang_data = MMapIndexedDataset(in_dir('data.src_lang'))

    if os.path.exists(in_dir('data.tgt_lang.bin')):
        tgt_lang_data = MMapIndexedDataset(in_dir('data.tgt_lang'))
    else:
        tgt_lang_data = None

    # Both size arrays default to None so that tgt_sizes is always bound.
    # Target sizes are only looked up when the source sizes exist.
    src_sizes, tgt_sizes = None, None
    if os.path.exists(in_dir('data.src_sizes.npy')):
        src_sizes = np.load(in_dir('data.src_sizes.npy'))
        if os.path.exists(in_dir('data.tgt_sizes.npy')):
            tgt_sizes = np.load(in_dir('data.tgt_sizes.npy'))

    return (src_data, tgt_data, src_lang_data, tgt_lang_data,
            src_sizes, tgt_sizes)


def main():
    """Load the data, build the model and loss, and launch the trainer.

    All configuration is read from the module-level ``opt``:

    * ``opt.multi_dataset`` selects between a single train/valid dataset pair
      and a list of datasets found in ``train.<idx>`` / ``valid.<idx>``
      sub-directories next to ``opt.data``.
    * ``opt.data_format`` selects the on-disk format: ``bin``/``raw`` (one
      ``.train.pt`` file) or ``scp``/``scpmem``/``mmem`` (memory-mapped files).
    * ``opt.load_from`` optionally names a checkpoint whose dictionaries
      replace the ones loaded from the data.

    :raises NotImplementedError: for unsupported data formats, streaming in
        multi-dataset mode, the ``speech_ae`` model, any model without a
        dedicated trainer here, debugging mode, or a GPU count other than one.
    """
    if not opt.multi_dataset:
        if opt.data_format in ['bin', 'raw']:
            start = time.time()

            # NOTE(review): torch.load unpickles the file; only use trusted data.
            if opt.data.endswith(".train.pt"):
                print("Loading data from '%s'" % opt.data)
                dataset = torch.load(opt.data)
            else:
                print("Loading data from %s" % opt.data + ".train.pt")
                dataset = torch.load(opt.data + ".train.pt")

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

            dicts = dataset['dicts']

            # For backward compatibility: missing keys read as None
            train_dict = defaultdict(lambda: None, dataset['train'])
            valid_dict = defaultdict(lambda: None, dataset['valid'])

            if train_dict['src_lang'] is not None:
                assert 'langs' in dicts
                train_src_langs = train_dict['src_lang']
                train_tgt_langs = train_dict['tgt_lang']
            else:
                # allocate new languages
                dicts['langs'] = {'src': 0, 'tgt': 1}
                # Allocate one language id per side for the bilingual case
                train_src_langs = [torch.Tensor([dicts['langs']['src']])]
                train_tgt_langs = [torch.Tensor([dicts['langs']['tgt']])]

            train_data = onmt.Dataset(numpy_to_torch(train_dict['src']),
                                      numpy_to_torch(train_dict['tgt']),
                                      train_dict['src_sizes'],
                                      train_dict['tgt_sizes'],
                                      train_src_langs,
                                      train_tgt_langs,
                                      batch_size_words=opt.batch_size_words,
                                      data_type=dataset.get("type", "text"),
                                      sorting=True,
                                      batch_size_sents=opt.batch_size_sents,
                                      multiplier=opt.batch_size_multiplier,
                                      augment=opt.augment_speech,
                                      upsampling=opt.upsampling,
                                      num_split=len(opt.gpus))

            if valid_dict['src_lang'] is not None:
                assert 'langs' in dicts
                valid_src_langs = valid_dict['src_lang']
                valid_tgt_langs = valid_dict['tgt_lang']
            else:
                # Allocate one language id per side for the bilingual case
                valid_src_langs = [torch.Tensor([dicts['langs']['src']])]
                valid_tgt_langs = [torch.Tensor([dicts['langs']['tgt']])]

            valid_data = onmt.Dataset(numpy_to_torch(valid_dict['src']),
                                      numpy_to_torch(valid_dict['tgt']),
                                      valid_dict['src_sizes'],
                                      valid_dict['tgt_sizes'],
                                      valid_src_langs,
                                      valid_tgt_langs,
                                      batch_size_words=opt.batch_size_words,
                                      data_type=dataset.get("type", "text"),
                                      sorting=True,
                                      batch_size_sents=opt.batch_size_sents,
                                      upsampling=opt.upsampling)

            print(' * number of training sentences. %d' %
                  len(dataset['train']['src']))
            print(' * maximum batch size (words per batch). %d' %
                  opt.batch_size_words)

        elif opt.data_format in ['scp', 'scpmem', 'mmem']:
            print("Loading memory mapped data files ....")
            start = time.time()
            from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
            from onmt.data.scp_dataset import SCPIndexDataset

            uses_scp = opt.data_format in ['scp', 'scpmem']

            dicts = torch.load(opt.data + ".dict.pt")
            if uses_scp:
                audio_data = torch.load(opt.data + ".scp_path.pt")

            # allocate languages if not
            if 'langs' not in dicts:
                dicts['langs'] = {'src': 0, 'tgt': 1}
            else:
                print(dicts['langs'])

            train_path = opt.data + '.train'
            if uses_scp:
                train_src = SCPIndexDataset(audio_data['train'],
                                            concat=opt.concat)
            else:
                train_src = MMapIndexedDataset(train_path + '.src')

            if os.path.exists(train_path + '.tgt.bin'):
                train_tgt = MMapIndexedDataset(train_path + '.tgt')
            else:
                train_tgt = None

            # check the lang files if they exist (in the case of multi-lingual models)
            if os.path.exists(train_path + '.src_lang.bin'):
                assert 'langs' in dicts
                train_src_langs = MMapIndexedDataset(train_path + '.src_lang')
                if os.path.exists(train_path + '.tgt_lang.bin'):
                    train_tgt_langs = MMapIndexedDataset(train_path +
                                                         '.tgt_lang')
                else:
                    train_tgt_langs = None
            else:
                # Allocate a Tensor(1) for the bilingual case
                train_src_langs = [torch.Tensor([dicts['langs']['src']])]
                train_tgt_langs = [torch.Tensor([dicts['langs']['tgt']])]

            # check the length files if they exist
            if os.path.exists(train_path + '.src_sizes.npy'):
                train_src_sizes = np.load(train_path + '.src_sizes.npy')
            else:
                train_src_sizes = None

            if os.path.exists(train_path + '.tgt_sizes.npy'):
                train_tgt_sizes = np.load(train_path + '.tgt_sizes.npy')
            else:
                train_tgt_sizes = None

            if opt.encoder_type == 'audio':
                data_type = 'audio'
            else:
                data_type = 'text'

            train_data = onmt.Dataset(
                train_src,
                train_tgt,
                train_src_sizes,
                train_tgt_sizes,
                train_src_langs,
                train_tgt_langs,
                batch_size_words=opt.batch_size_words,
                data_type=data_type,
                sorting=True,
                batch_size_sents=opt.batch_size_sents,
                multiplier=opt.batch_size_multiplier,
                max_length_multiplier=opt.n_frames_per_step,
                augment=opt.augment_speech,
                src_align_right=opt.src_align_right,
                upsampling=opt.upsampling,
                cleaning=True,
                verbose=True,
                num_split=len(opt.gpus))

            valid_path = opt.data + '.valid'
            if uses_scp:
                valid_src = SCPIndexDataset(audio_data['valid'],
                                            concat=opt.concat)
            else:
                valid_src = MMapIndexedDataset(valid_path + '.src')

            if os.path.exists(valid_path + '.tgt.bin'):
                valid_tgt = MMapIndexedDataset(valid_path + '.tgt')
            else:
                valid_tgt = None

            # NOTE(review): the presence of the *training* language files decides
            # whether the validation language files are opened; presumably both
            # splits are always produced together -- confirm with the
            # preprocessing script.
            if os.path.exists(train_path + '.src_lang.bin'):
                assert 'langs' in dicts
                valid_src_langs = MMapIndexedDataset(valid_path + '.src_lang')
                if os.path.exists(train_path + '.tgt_lang.bin'):
                    valid_tgt_langs = MMapIndexedDataset(valid_path +
                                                         '.tgt_lang')
                else:
                    valid_tgt_langs = None
            else:
                # Allocate a Tensor(1) for the bilingual case
                valid_src_langs = [torch.Tensor([dicts['langs']['src']])]
                valid_tgt_langs = [torch.Tensor([dicts['langs']['tgt']])]

            # check the length files if they exist
            if os.path.exists(valid_path + '.src_sizes.npy'):
                valid_src_sizes = np.load(valid_path + '.src_sizes.npy')
            else:
                valid_src_sizes = None

            if os.path.exists(valid_path + '.tgt_sizes.npy'):
                valid_tgt_sizes = np.load(valid_path + '.tgt_sizes.npy')
            else:
                valid_tgt_sizes = None

            valid_data = onmt.Dataset(
                valid_src,
                valid_tgt,
                valid_src_sizes,
                valid_tgt_sizes,
                valid_src_langs,
                valid_tgt_langs,
                batch_size_words=opt.batch_size_words,
                data_type=data_type,
                sorting=True,
                batch_size_sents=opt.batch_size_sents,
                max_length_multiplier=opt.n_frames_per_step,
                src_align_right=opt.src_align_right,
                cleaning=True,
                verbose=True,
                debug=True)

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

        else:
            raise NotImplementedError

        print(' * number of sentences in training data: %d' %
              train_data.size())
        print(' * number of sentences in validation data: %d' %
              valid_data.size())

    else:
        print("[INFO] Reading multiple dataset ...")

        dicts = torch.load(opt.data + ".dict.pt")

        root_dir = os.path.dirname(opt.data)

        print("Loading training data ...")

        train_dirs, valid_dirs = dict(), dict()

        # scan the data directory to find the training/validation sub-directories;
        # they are named "train.<idx>" / "valid.<idx>"
        for dir_ in os.listdir(root_dir):
            if os.path.isdir(os.path.join(root_dir, dir_)):
                if str(dir_).startswith("train"):
                    idx = int(dir_.split(".")[1])
                    train_dirs[idx] = dir_
                if dir_.startswith("valid"):
                    idx = int(dir_.split(".")[1])
                    valid_dirs[idx] = dir_

        train_sets, valid_sets = list(), list()

        mmap_formats = ['scp', 'scpmem', 'mmem']
        data_type = 'audio' if opt.encoder_type == 'audio' else 'text'

        for (idx_, dir_) in sorted(train_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)
            print("[INFO] Loading training data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in mmap_formats:
                (src_data, tgt_data, src_lang_data, tgt_lang_data,
                 src_sizes, tgt_sizes) = _load_multi_dataset_dir(data_dir)

                if opt.streaming:
                    print("Multi-dataset not implemented for Streaming tasks.")
                    raise NotImplementedError

                train_data = onmt.Dataset(
                    src_data,
                    tgt_data,
                    src_sizes,
                    tgt_sizes,
                    src_lang_data,
                    tgt_lang_data,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    max_length_multiplier=opt.n_frames_per_step,
                    src_align_right=opt.src_align_right,
                    augment=opt.augment_speech,
                    upsampling=opt.upsampling,
                    cleaning=True,
                    verbose=True,
                    num_split=len(opt.gpus))

                train_sets.append(train_data)

        for (idx_, dir_) in sorted(valid_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)

            print("[INFO] Loading validation data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in mmap_formats:
                (src_data, tgt_data, src_lang_data, tgt_lang_data,
                 src_sizes, tgt_sizes) = _load_multi_dataset_dir(data_dir)

                if opt.streaming:
                    raise NotImplementedError

                valid_data = onmt.Dataset(
                    src_data,
                    tgt_data,
                    src_sizes,
                    tgt_sizes,
                    src_lang_data,
                    tgt_lang_data,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    multiplier=opt.batch_size_multiplier,
                    max_length_multiplier=opt.n_frames_per_step,
                    batch_size_sents=opt.batch_size_sents,
                    src_align_right=opt.src_align_right,
                    cleaning=True,
                    verbose=True,
                    debug=True)

                valid_sets.append(valid_data)

        # in multi-dataset mode the trainer receives lists of datasets
        train_data = train_sets
        valid_data = valid_sets

    if opt.load_from:
        # load onto CPU regardless of the device the checkpoint was saved from
        checkpoint = torch.load(opt.load_from,
                                map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        if "tgt" in dicts:
            dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    elif "tgt" in dicts:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    print('* Building model...')

    if not opt.fusion:
        if opt.bayes_by_backprop:
            model = build_bayesian_model(opt, dicts)
        else:
            # lat_dis is only consumed by the speech_FN trainer below
            model, lat_dis = build_model(opt, dicts)

        # Building the loss function
        if opt.ctc_loss != 0:
            loss_function = NMTAndCTCLossFunc(
                dicts['tgt'].size(),
                label_smoothing=opt.label_smoothing,
                ctc_weight=opt.ctc_loss)
        elif opt.model == "speech_ae":
            loss_function = Tacotron2Loss()
        elif opt.model == "speech_FN":
            loss_function_ae = Tacotron2Loss()
            loss_function_lat_dis = AttributeLoss()
            loss_function = (loss_function_ae, loss_function_lat_dis)
        elif opt.nce:
            from onmt.modules.nce.nce_loss import NCELoss
            loss_function = NCELoss(opt.model_size,
                                    dicts['tgt'].size(),
                                    noise_ratio=opt.nce_noise,
                                    logz=9,
                                    label_smoothing=opt.label_smoothing)
        else:
            loss_function = NMTLossFunc(opt.model_size,
                                        dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing,
                                        mirror=opt.mirror_loss,
                                        fast_xentropy=opt.fast_xentropy)

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        if not opt.memory_profiling:
            optimize_model(model)

    else:
        from onmt.model_factory import build_fusion
        from onmt.modules.loss import FusionLoss

        model = build_fusion(opt, dicts)

        loss_function = FusionLoss(dicts['tgt'].size(),
                                   label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    # We need to initialize the model parameters before sending out to distributed
    print('Initializing model parameters')
    init_model_parameters(model, opt)

    if not opt.debugging and len(opt.gpus) == 1:
        if opt.bayes_by_backprop:

            from onmt.train_utils.bayes_by_backprop_trainer import BayesianTrainer
            trainer = BayesianTrainer(model, loss_function, train_data,
                                      valid_data, dicts, opt)
        elif opt.model == "speech_ae":
            # no trainer is wired up for the auto-encoder model here
            raise NotImplementedError
        elif opt.model == "speech_FN":

            trainer = SpeechFNTrainer(model, lat_dis, loss_function,
                                      train_data, valid_data, dicts, opt)

        else:
            # the generic cross-entropy trainer is not wired up here
            raise NotImplementedError

        trainer.run(checkpoint=checkpoint)
    else:
        raise NotImplementedError
예제 #9
0
    def __init__(self,
                 device,
                 train_data,
                 valid_data,
                 dicts,
                 opt,
                 setup_optimizer=True):
        """
        Set up one training process: join the process group (multi-GPU only),
        build the model and the loss function, move them to the GPU, create the
        optimizer, initialize apex amp, optionally restore a checkpoint and wrap
        the model into DDP.

        :param device: int (GPU id); also used as the rank of this process
        :param train_data: training dataset(s); deep-copied on non-zero ranks
        :param valid_data: validation dataset(s); deep-copied on non-zero ranks
        :param dicts: vocabulary dictionaries (``dicts['tgt']`` is required)
        :param opt: training options; ``node_rank`` and ``nodes`` are overwritten
        :param setup_optimizer: create the optimizer (and restore its state from
            the checkpoint when available); when False only the model is passed
            to amp
        :raises NotImplementedError: if ``opt.fusion`` is set (no model is built
            for that mode here)
        """

        self.device = device
        opt.node_rank = 0
        opt.nodes = 1
        self.world_size = len(opt.gpus)

        # in the case of single node distributed, it should equal self.device
        self.rank = self.device

        # make a group to later use with self.all_reduce
        self.group = dist.group.WORLD

        self.print("[INFO] Training Options:", opt)
        if self.world_size > 1:
            dist.init_process_group(backend='nccl',
                                    init_method='env://',
                                    world_size=self.world_size,
                                    rank=self.rank)

        self.model = None

        if self.rank == 0:
            self.train_data = train_data
            self.valid_data = valid_data
        else:
            # Do we really need to deepcopy the data instances (which could cause memory leak easily)
            self.train_data = copy.deepcopy(train_data)
            self.valid_data = copy.deepcopy(valid_data)

        self.dicts = dicts
        self.opt = opt
        self.cuda = (len(opt.gpus) >= 1 and opt.gpus[0] >= 0)

        assert self.cuda, "[ERROR] Training is only available on GPUs."

        self.start_time = 0

        # The language-id loss is created here but can only be registered once
        # self.loss_function exists (see below).
        lid_loss = None
        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages,
                                           opt.label_smoothing,
                                           opt.fast_xentropy)

        torch.manual_seed(self.opt.seed)

        # note: we must start creating models after creating the processes
        # for some reason passing a pre-created model to a process creates a "pickle" error
        if opt.fusion:
            # No model / loss is built for fusion in this trainer; fail with an
            # explicit error instead of an unbound-variable NameError.
            raise NotImplementedError(
                "Fusion models are not supported by this trainer.")

        if self.is_main():
            print("[INFO] Building models .... ", flush=True)
        model = build_model(opt, dicts)

        # Building the loss function
        if opt.ctc_loss > 0.0:
            from onmt.speech.ctc_loss import CTC
            self.ctc_loss_function = CTC(0.0, reduce=True)

        if opt.nce:
            from onmt.modules.nce.nce_loss import NCELoss
            loss_function = NCELoss(opt.model_size,
                                    dicts['tgt'].size(),
                                    noise_ratio=opt.nce_noise,
                                    logz=9,
                                    label_smoothing=opt.label_smoothing)
        else:
            loss_function = NMTLossFunc(
                opt.model_size,
                dicts['tgt'].size(),
                label_smoothing=opt.label_smoothing,
                mirror=opt.mirror_loss,
                fast_xentropy=opt.fast_xentropy)

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        if not opt.memory_profiling:
            # distributed is required to convert BatchNorm to SyncBatchNorm for DDP
            optimize_model(model, distributed=(self.world_size > 1))

        init_model_parameters(model, opt)
        self.model = model
        self.loss_function = loss_function
        # self.grad_scaler = torch.cuda.amp.GradScaler()

        # Register the auxiliary loss before the loss function is moved to the GPU.
        if lid_loss is not None:
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')

        if self.cuda:
            torch.cuda.set_device(self.device)

            self.loss_function = self.loss_function.cuda(device=self.device)
            self.model = self.model.cuda(device=self.device)
            if opt.ctc_loss > 0.0:
                self.ctc_loss_function = self.ctc_loss_function.cuda(
                    device=self.device)

        checkpoint = None
        if opt.load_from:
            # load onto CPU regardless of the device the checkpoint was saved from
            checkpoint = torch.load(opt.load_from,
                                    map_location=lambda storage, loc: storage)

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if self.is_main():
                print("[INFO] Optimizer: ", self.optim.optimizer)

            if checkpoint is not None:
                if 'optim' in checkpoint and checkpoint[
                        'optim'] is not None and not opt.reset_optim:
                    self.optim.load_state_dict(checkpoint['optim'])

        # Map the precision options onto an apex amp optimization level.
        if not self.opt.fp16:
            opt_level = "O0"
            keep_batchnorm_fp32 = False
        elif self.opt.fp16_mixed:
            opt_level = "O1"
            keep_batchnorm_fp32 = None
        else:
            opt_level = "O2"
            keep_batchnorm_fp32 = False

        self.opt_level = opt_level

        if self.cuda:
            amp_options = dict(opt_level=opt_level,
                               keep_batchnorm_fp32=keep_batchnorm_fp32,
                               loss_scale="dynamic",
                               verbosity=1 if self.opt.verbose else 0)
            if setup_optimizer:
                self.model, self.optim.optimizer = amp.initialize(
                    self.model, self.optim.optimizer, **amp_options)
            else:
                # Without an optimizer, amp.initialize returns only the model.
                self.model = amp.initialize(self.model, **amp_options)

        if checkpoint is not None:
            self.model.load_state_dict(checkpoint['model'])
            # options the checkpoint was trained with (if they were stored)
            prec_opt = checkpoint['opt'] if 'opt' in checkpoint else None
            if prec_opt is not None and hasattr(prec_opt, "fp16_mixed"):
                # Only load amp information if the mode is the same
                # Maybe its better to change between optimization mode?
                if opt.fp16_mixed == prec_opt.fp16_mixed and opt.fp16 == prec_opt.fp16:
                    if 'amp' in checkpoint:
                        try:
                            amp.load_state_dict(checkpoint['amp'])
                        except Exception as err:
                            # loading the amp state can fail; training continues
                            # with a freshly initialized amp state
                            print("[WARNING] Could not restore amp state: %s"
                                  % err, flush=True)

        # wrap the model into DDP after initializing by amp
        if self.world_size > 1:
            # find_unused_parameters may be required for dropped layer (parameters that are not connected to
            # any particular graph)
            # find_unused_parameters = True

            self.model = DDP(self.model,
                             delay_allreduce=True,
                             gradient_average=False)

        print("[INFO] Process %d ready." % self.rank, flush=True)