Example 1
    def __init__(self, opt):

        super().__init__(opt)

        # self.eos = onmt.constants.EOS
        # self.pad = onmt.constants.PAD
        # self.bos = self.bos_id

        self.src_bos = onmt.constants.SRC_BOS
        self.src_eos = onmt.constants.SRC_EOS
        self.src_pad = onmt.constants.SRC_PAD
        self.src_unk = onmt.constants.SRC_UNK

        self.tgt_bos = self.bos_id
        self.tgt_pad = onmt.constants.TGT_PAD
        self.tgt_eos = onmt.constants.TGT_EOS
        self.tgt_unk = onmt.constants.TGT_UNK

        self.search = BeamSearch(self.tgt_dict)

        self.vocab_size = self.tgt_dict.size()
        self.min_len = 1
        self.normalize_scores = opt.normalize
        self.len_penalty = opt.alpha
        self.buffering = not opt.no_buffering
        # self.buffering = False  # buffering is currently bugged

        self.no_repeat_ngram_size = getattr(opt, 'no_repeat_ngram_size', 0)

        self.dynamic_max_len = getattr(opt, 'dynamic_max_len', False)

        self.dynamic_max_len_scale = getattr(opt, 'dynamic_max_len_scale', 1.2)
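        # NOTE (illustrative, not from the original code): when dynamic_max_len
        # is enabled, the decoding length limit is presumably derived from the
        # source length rather than a fixed maximum, e.g.
        #   max_len = int(src_len * self.dynamic_max_len_scale)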

        if opt.verbose:
            # print('* Current bos id is: %d, default bos id is: %d' % (self.tgt_bos, onmt.constants.BOS))
            print("src bos id is %d; src eos id is %d;  src pad id is %d; src unk id is %d"
                  % (self.src_bos, self.src_eos, self.src_pad, self.src_unk))
            print("tgt bos id is %d; tgt eos id is %d;  tgt_pad id is %d; tgt unk id is %d"
                  % (self.tgt_bos, self.tgt_eos, self.tgt_pad, self.tgt_unk))
            print('* Using fast beam search implementation')

        if opt.vocab_list:
            word_list = list()
            with open(opt.vocab_list) as f:
                for line in f:
                    word = line.strip()
                    word_list.append(word)

            self.filter = torch.zeros(self.tgt_dict.size())
            for word_idx in [self.tgt_eos, self.tgt_unk]:
                self.filter[word_idx] = 1

            for word in word_list:
                idx = self.tgt_dict.lookup(word)
                if idx is not None:
                    self.filter[idx] = 1

            self.filter = self.filter.bool()
            # print(self.filter)
            if opt.cuda:
                self.filter = self.filter.cuda()

            self.use_filter = True
        else:
            self.use_filter = False
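        # NOTE (illustrative sketch, assuming the mask is consumed during
        # decoding): a boolean filter like this is typically applied by masking
        # the disallowed words out of the output distribution, e.g.
        #   log_probs.masked_fill_(~self.filter, -float('inf'))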

        if opt.sub_model:
            self.sub_models = list()
            self.sub_model_types = list()

            # model paths are given as a single string with '|' as the delimiter
            sub_models = opt.sub_model.split("|")

            print("Loading sub-models ...")
            self.n_sub_models = len(sub_models)
            self.sub_type = 'text'

            for i, model_path in enumerate(sub_models):
                checkpoint = torch.load(model_path,
                                        map_location=lambda storage, loc: storage)

                model_opt = checkpoint['opt']
                model_opt = backward_compatible(model_opt)
                if hasattr(model_opt, "enc_not_load_state"):
                    model_opt.enc_not_load_state = True
                    model_opt.dec_not_load_state = True

                dicts = checkpoint['dicts']

                # update special tokens
                onmt.constants = add_tokenidx(model_opt, onmt.constants, dicts)
                # self.bos_token = model_opt.tgt_bos_word

                """"BE CAREFUL: the sub-models might mismatch with the main models in terms of language dict"""
                """"REQUIRE RE-matching"""

                if i == 0:
                    if "src" in checkpoint['dicts']:
                        self.src_dict = checkpoint['dicts']['src']
                #     else:
                #         self._type = "audio"
                #     self.tgt_dict = checkpoint['dicts']['tgt']
                #
                #     if "langs" in checkpoint["dicts"]:
                #         self.lang_dict = checkpoint['dicts']['langs']
                #
                #     else:
                #         self.lang_dict = {'src': 0, 'tgt': 1}
                #
                #     self.bos_id = self.tgt_dict.labelToIdx[self.bos_token]
                if opt.verbose:
                    print('Loading sub-model from %s' % model_path)

                model = build_model(model_opt, checkpoint['dicts'])
                optimize_model(model)
                model.load_state_dict(checkpoint['model'])

                if model_opt.model in model_list:
                    # if model.decoder.positional_encoder.len_max < self.opt.max_sent_length:
                    #     print("Not enough len to decode. Renewing .. ")
                    #     model.decoder.renew_buffer(self.opt.max_sent_length)
                    model.renew_buffer(self.opt.max_sent_length)

                if opt.fp16:
                    model = model.half()

                if opt.cuda:
                    model = model.cuda()
                else:
                    model = model.cpu()

                if opt.dynamic_quantile == 1:

                    engines = torch.backends.quantized.supported_engines
                    if 'fbgemm' in engines:
                        torch.backends.quantized.engine = 'fbgemm'
                    else:
                        print(
                            "[INFO] fbgemm is not found in the available engines. "
                            "Possibly the CPU does not support AVX2. "
                            "It is recommended to disable quantization (set to 0).")
                        torch.backends.quantized.engine = 'qnnpack'

                    model = torch.quantization.quantize_dynamic(
                        model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
                    )
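                    # dynamic quantization stores the Linear/LSTM weights in
                    # int8 and dequantizes them on the fly, trading a small
                    # amount of accuracy for faster CPU inference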

                model.eval()

                self.sub_models.append(model)
                self.sub_model_types.append(model_opt.model)
        else:
            self.n_sub_models = 0
            self.sub_models = []

        if opt.ensemble_weight:
            ensemble_weight = [float(item) for item in opt.ensemble_weight.split("|")]
            assert len(ensemble_weight) == self.n_models

            if opt.sub_ensemble_weight:
                sub_ensemble_weight = [float(item) for item in opt.sub_ensemble_weight.split("|")]
                assert len(sub_ensemble_weight) == self.n_sub_models
                ensemble_weight = ensemble_weight + sub_ensemble_weight

            total = sum(ensemble_weight)
            self.ensemble_weight = [item / total for item in ensemble_weight]
        else:
            self.ensemble_weight = None
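        # NOTE (illustrative): the normalized weights are presumably used to
        # combine the per-model output distributions during decoding, e.g.
        #   probs = sum(w * p for w, p in zip(self.ensemble_weight, model_probs))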

        print(self.main_model_opt)
Example 2
def main():

    if not opt.multi_dataset:
        if opt.data_format in ['bin', 'raw']:
            start = time.time()

            if opt.data.endswith(".train.pt"):
                print("Loading data from '%s'" % opt.data)
                dataset = torch.load(opt.data)
            else:
                print("Loading data from %s" % opt.data + ".train.pt")
                dataset = torch.load(opt.data + ".train.pt")

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

            dicts = dataset['dicts']
            onmt.constants = add_tokenidx(opt, onmt.constants, dicts)

            # For backward compatibility
            train_dict = defaultdict(lambda: None, dataset['train'])
            valid_dict = defaultdict(lambda: None, dataset['valid'])

            if train_dict['src_lang'] is not None:
                assert 'langs' in dicts
                train_src_langs = train_dict['src_lang']
                train_tgt_langs = train_dict['tgt_lang']
            else:
                # allocate new languages
                dicts['langs'] = {'src': 0, 'tgt': 1}
                train_src_langs = list()
                train_tgt_langs = list()
                # Allocate one entry for the bilingual case
                train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))
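
            # NOTE: numpy_to_torch is assumed to convert every numpy array in a
            # list to the corresponding torch tensor, leaving other entries
            # untouched, roughly:
            #   [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in data]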

            if not opt.streaming:
                train_data = onmt.Dataset(
                    numpy_to_torch(train_dict['src']),
                    numpy_to_torch(train_dict['tgt']),
                    train_dict['src_sizes'],
                    train_dict['tgt_sizes'],
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    augment=opt.augment_speech,
                    sa_f=opt.sa_f,
                    sa_t=opt.sa_t,
                    upsampling=opt.upsampling,
                    num_split=1)
            else:
                train_data = onmt.StreamDataset(
                    train_dict['src'],
                    train_dict['tgt'],
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    augment=opt.augment_speech,
                    upsampling=opt.upsampling)

            if valid_dict['src_lang'] is not None:
                assert 'langs' in dicts
                valid_src_langs = valid_dict['src_lang']
                valid_tgt_langs = valid_dict['tgt_lang']
            else:
                # allocate new languages
                valid_src_langs = list()
                valid_tgt_langs = list()

                # Allocate one entry for the bilingual case
                valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            if not opt.streaming:
                valid_data = onmt.Dataset(
                    numpy_to_torch(valid_dict['src']),
                    numpy_to_torch(valid_dict['tgt']),
                    valid_dict['src_sizes'],
                    valid_dict['tgt_sizes'],
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    cleaning=True,
                    upsampling=opt.upsampling)
            else:
                valid_data = onmt.StreamDataset(
                    numpy_to_torch(valid_dict['src']),
                    numpy_to_torch(valid_dict['tgt']),
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    upsampling=opt.upsampling)

            print(' * number of training sentences: %d' %
                  len(dataset['train']['src']))
            print(' * maximum batch size (words per batch): %d' %
                  opt.batch_size_words)

        elif opt.data_format in ['scp', 'scpmem', 'mmem']:
            print("Loading memory mapped data files ....")
            start = time.time()
            from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
            from onmt.data.scp_dataset import SCPIndexDataset

            dicts = torch.load(opt.data + ".dict.pt")
            onmt.constants = add_tokenidx(opt, onmt.constants, dicts)

            if opt.data_format in ['scp', 'scpmem']:
                audio_data = torch.load(opt.data + ".scp_path.pt")

            # allocate languages if not present
            if 'langs' not in dicts:
                dicts['langs'] = {'src': 0, 'tgt': 1}
            else:
                print(dicts['langs'])

            train_path = opt.data + '.train'
            if opt.data_format in ['scp', 'scpmem']:
                train_src = SCPIndexDataset(audio_data['train'],
                                            concat=opt.concat)
            else:
                train_src = MMapIndexedDataset(train_path + '.src')

            train_tgt = MMapIndexedDataset(train_path + '.tgt')

            # check the lang files if they exist (in the case of multi-lingual models)
            if os.path.exists(train_path + '.src_lang.bin'):
                assert 'langs' in dicts
                train_src_langs = MMapIndexedDataset(train_path + '.src_lang')
                train_tgt_langs = MMapIndexedDataset(train_path + '.tgt_lang')
            else:
                train_src_langs = list()
                train_tgt_langs = list()
                # Allocate a Tensor(1) for the bilingual case
                train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            # check the length files if they exist
            if os.path.exists(train_path + '.src_sizes.npy'):
                train_src_sizes = np.load(train_path + '.src_sizes.npy')
                train_tgt_sizes = np.load(train_path + '.tgt_sizes.npy')
            else:
                train_src_sizes, train_tgt_sizes = None, None

            if opt.encoder_type == 'audio':
                data_type = 'audio'
            else:
                data_type = 'text'

            if not opt.streaming:
                train_data = onmt.Dataset(
                    train_src,
                    train_tgt,
                    train_src_sizes,
                    train_tgt_sizes,
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    src_align_right=opt.src_align_right,
                    upsampling=opt.upsampling,
                    augment=opt.augment_speech,
                    sa_f=opt.sa_f,
                    sa_t=opt.sa_t,
                    cleaning=True,
                    verbose=True,
                    num_split=1)
            else:
                train_data = onmt.StreamDataset(
                    train_src,
                    train_tgt,
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=False,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    upsampling=opt.upsampling)

            valid_path = opt.data + '.valid'
            if opt.data_format in ['scp', 'scpmem']:
                valid_src = SCPIndexDataset(audio_data['valid'],
                                            concat=opt.concat)
            else:
                valid_src = MMapIndexedDataset(valid_path + '.src')
            valid_tgt = MMapIndexedDataset(valid_path + '.tgt')

            if os.path.exists(valid_path + '.src_lang.bin'):
                assert 'langs' in dicts
                valid_src_langs = MMapIndexedDataset(valid_path + '.src_lang')
                valid_tgt_langs = MMapIndexedDataset(valid_path + '.tgt_lang')
            else:
                valid_src_langs = list()
                valid_tgt_langs = list()

                # Allocate one entry for the bilingual case
                valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            # check the length files if they exist
            if os.path.exists(valid_path + '.src_sizes.npy'):
                valid_src_sizes = np.load(valid_path + '.src_sizes.npy')
                valid_tgt_sizes = np.load(valid_path + '.tgt_sizes.npy')
            else:
                valid_src_sizes, valid_tgt_sizes = None, None

            if not opt.streaming:
                valid_data = onmt.Dataset(
                    valid_src,
                    valid_tgt,
                    valid_src_sizes,
                    valid_tgt_sizes,
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    multiplier=opt.batch_size_multiplier,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    src_align_right=opt.src_align_right,
                    cleaning=True,
                    verbose=True,
                    debug=True)
            else:
                # for validation data, we have to go through the sentences one by one (very slow, but ensures correctness)
                valid_data = onmt.StreamDataset(
                    valid_src,
                    valid_tgt,
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents)

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

        else:
            raise NotImplementedError

        print(' * number of sentences in training data: %d' %
              train_data.size())
        print(' * number of sentences in validation data: %d' %
              valid_data.size())

    else:
        print("[INFO] Reading multiple dataset ...")
        # raise NotImplementedError

        dicts = torch.load(opt.data + ".dict.pt")
        onmt.constants = add_tokenidx(opt, onmt.constants, dicts)

        root_dir = os.path.dirname(opt.data)

        print("Loading training data ...")

        train_dirs, valid_dirs = dict(), dict()

        # scan the data directory to find the training data
        for dir_ in os.listdir(root_dir):
            if os.path.isdir(os.path.join(root_dir, dir_)):
                if str(dir_).startswith("train"):
                    idx = int(dir_.split(".")[1])
                    train_dirs[idx] = dir_
                if dir_.startswith("valid"):
                    idx = int(dir_.split(".")[1])
                    valid_dirs[idx] = dir_
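
        # NOTE (illustrative): this scan assumes a directory layout such as
        #   <root_dir>/train.0.<name>/, <root_dir>/train.1.<name>/, ...
        # where the integer after the first '.' indexes the sub-dataset.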

        train_sets, valid_sets = list(), list()

        for (idx_, dir_) in sorted(train_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)
            print("[INFO] Loading training data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in ['scp', 'scpmem', 'mmem']:
                from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
                from onmt.data.scp_dataset import SCPIndexDataset

                if opt.data_format in ['scp', 'scpmem']:
                    audio_data = torch.load(
                        os.path.join(data_dir, "data.scp_path.pt"))
                    src_data = SCPIndexDataset(audio_data, concat=opt.concat)
                else:
                    src_data = MMapIndexedDataset(
                        os.path.join(data_dir, "data.src"))

                tgt_data = MMapIndexedDataset(
                    os.path.join(data_dir, "data.tgt"))

                src_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.src_lang'))
                tgt_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.tgt_lang'))

                if os.path.exists(os.path.join(data_dir,
                                               'data.src_sizes.npy')):
                    src_sizes = np.load(
                        os.path.join(data_dir, 'data.src_sizes.npy'))
                    tgt_sizes = np.load(
                        os.path.join(data_dir, 'data.tgt_sizes.npy'))
                else:
                    src_sizes, tgt_sizes = None, None

                if opt.encoder_type == 'audio':
                    data_type = 'audio'
                else:
                    data_type = 'text'

                if not opt.streaming:
                    train_data = onmt.Dataset(
                        src_data,
                        tgt_data,
                        src_sizes,
                        tgt_sizes,
                        src_lang_data,
                        tgt_lang_data,
                        batch_size_words=opt.batch_size_words,
                        data_type=data_type,
                        sorting=True,
                        batch_size_sents=opt.batch_size_sents,
                        multiplier=opt.batch_size_multiplier,
                        src_align_right=opt.src_align_right,
                        upsampling=opt.upsampling,
                        cleaning=True,
                        verbose=True,
                        num_split=1)

                    train_sets.append(train_data)

                else:
                    print("Multi-dataset not implemented for Streaming tasks.")
                    raise NotImplementedError

        for (idx_, dir_) in sorted(valid_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)

            print("[INFO] Loading validation data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in ['scp', 'scpmem', 'mmem']:

                if opt.data_format in ['scp', 'scpmem']:
                    audio_data = torch.load(
                        os.path.join(data_dir, "data.scp_path.pt"))
                    src_data = SCPIndexDataset(audio_data, concat=opt.concat)
                else:
                    src_data = MMapIndexedDataset(
                        os.path.join(data_dir, "data.src"))

                tgt_data = MMapIndexedDataset(
                    os.path.join(data_dir, "data.tgt"))

                src_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.src_lang'))
                tgt_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.tgt_lang'))

                if os.path.exists(os.path.join(data_dir,
                                               'data.src_sizes.npy')):
                    src_sizes = np.load(
                        os.path.join(data_dir, 'data.src_sizes.npy'))
                    tgt_sizes = np.load(
                        os.path.join(data_dir, 'data.tgt_sizes.npy'))
                else:
                    src_sizes, tgt_sizes = None, None

                if opt.encoder_type == 'audio':
                    data_type = 'audio'
                else:
                    data_type = 'text'

                if not opt.streaming:
                    valid_data = onmt.Dataset(
                        src_data,
                        tgt_data,
                        src_sizes,
                        tgt_sizes,
                        src_lang_data,
                        tgt_lang_data,
                        batch_size_words=opt.batch_size_words,
                        multiplier=opt.batch_size_multiplier,
                        data_type=data_type,
                        sorting=True,
                        batch_size_sents=opt.batch_size_sents,
                        src_align_right=opt.src_align_right,
                        cleaning=True,
                        verbose=True,
                        debug=True)

                    valid_sets.append(valid_data)

                else:
                    raise NotImplementedError

        train_data = train_sets
        valid_data = valid_sets

    if opt.load_from:
        checkpoint = torch.load(opt.load_from,
                                map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        del checkpoint['model']
        del checkpoint['optim']
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print(' * vocabulary size. target = %d' % (dicts['tgt'].size()))

    os.environ['MASTER_ADDR'] = opt.master_addr  # default 'localhost'
    os.environ['MASTER_PORT'] = opt.master_port  # default '8888'

    # spawn N processes for N gpus
    # each process has a different trainer
    torch.multiprocessing.spawn(run_process,
                                nprocs=len(opt.gpus),
                                args=(train_data, valid_data, dicts, opt,
                                      checkpoint))
Example 3
parser.add_argument('-model_src', required=True,
                    help='Path to the source model .pt file')
parser.add_argument('-model_tgt', required=True,
                    help='Path to the target model .pt file')
parser.add_argument('-model_out', required=True,
                    help='Path to the output model .pt file')

opt = parser.parse_args()
# first, load the source model
print(opt.model_src)
checkpoint = torch.load(opt.model_src, map_location=lambda storage, loc: storage)

model_opt = checkpoint['opt']
model_opt = backward_compatible(model_opt)

src_dicts = checkpoint['dicts']
# update special tokens
onmt.constants = add_tokenidx(model_opt, onmt.constants, src_dicts)

model = build_model(model_opt, checkpoint['dicts'])
model.load_state_dict(checkpoint['model'])

# now load the second (target) model
print(opt.model_tgt)
checkpoint = torch.load(opt.model_tgt, map_location=lambda storage, loc: storage)
# model_opt = checkpoint['opt']
# model_opt = backward_compatible(model_opt)
tgt_dicts = checkpoint['dicts']

# tgt_model = build_model(model_opt, checkpoint['dicts'])

# check the embedding
lang_emb = copy.deepcopy(model.encoder.language_embedding.weight.data)
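
# NOTE (illustrative sketch, not part of the original snippet): a typical next
# step would be to re-map the copied language-embedding rows into the target
# model's language indices, e.g.
#   for lang, idx in tgt_dicts['langs'].items():
#       if lang in src_dicts['langs']:
#           model.encoder.language_embedding.weight.data[idx] = \
#               lang_emb[src_dicts['langs'][lang]]
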
Example 4
def main():

    if not opt.multi_dataset:
        if opt.data_format in ['bin', 'raw']:
            start = time.time()

            if opt.data.endswith(".train.pt"):
                print("Loading data from '%s'" % opt.data)
                dataset = torch.load(opt.data)
            else:
                print("Loading data from %s" % opt.data + ".train.pt")
                dataset = torch.load(opt.data + ".train.pt")

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

            dicts = dataset['dicts']

            # For backward compatibility
            train_dict = defaultdict(lambda: None, dataset['train'])
            valid_dict = defaultdict(lambda: None, dataset['valid'])

            if train_dict['src_lang'] is not None:
                assert 'langs' in dicts
                train_src_langs = train_dict['src_lang']
                train_tgt_langs = train_dict['tgt_lang']
            else:
                # allocate new languages
                dicts['langs'] = {'src': 0, 'tgt': 1}
                train_src_langs = list()
                train_tgt_langs = list()
                # Allocate one entry for the bilingual case
                train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            if not opt.streaming:
                train_data = onmt.Dataset(
                    numpy_to_torch(train_dict['src']),
                    numpy_to_torch(train_dict['tgt']),
                    train_dict['src_sizes'],
                    train_dict['tgt_sizes'],
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    augment=opt.augment_speech,
                    sa_f=opt.sa_f,
                    sa_t=opt.sa_t,
                    upsampling=opt.upsampling,
                    num_split=len(opt.gpus),
                    cleaning=True)
            else:
                train_data = onmt.StreamDataset(
                    train_dict['src'],
                    train_dict['tgt'],
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    augment=opt.augment_speech,
                    upsampling=opt.upsampling)

            if valid_dict['src_lang'] is not None:
                assert 'langs' in dicts
                valid_src_langs = valid_dict['src_lang']
                valid_tgt_langs = valid_dict['tgt_lang']
            else:
                # allocate new languages
                valid_src_langs = list()
                valid_tgt_langs = list()

                # Allocate one entry for the bilingual case
                valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            if not opt.streaming:
                valid_data = onmt.Dataset(
                    numpy_to_torch(valid_dict['src']),
                    numpy_to_torch(valid_dict['tgt']),
                    valid_dict['src_sizes'],
                    valid_dict['tgt_sizes'],
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    upsampling=opt.upsampling,
                    cleaning=True)
            else:
                valid_data = onmt.StreamDataset(
                    numpy_to_torch(valid_dict['src']),
                    numpy_to_torch(valid_dict['tgt']),
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=dataset.get("type", "text"),
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    upsampling=opt.upsampling)

            print(' * number of training sentences: %d' %
                  len(dataset['train']['src']))
            print(' * maximum batch size (words per batch): %d' %
                  opt.batch_size_words)

        elif opt.data_format in ['scp', 'scpmem', 'mmem']:
            print("Loading memory mapped data files ....")
            start = time.time()
            from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
            from onmt.data.scp_dataset import SCPIndexDataset

            dicts = torch.load(opt.data + ".dict.pt")
            if opt.data_format in ['scp', 'scpmem']:
                audio_data = torch.load(opt.data + ".scp_path.pt")

            # allocate languages if not present
            if 'langs' not in dicts:
                dicts['langs'] = {'src': 0, 'tgt': 1}
            else:
                print(dicts['langs'])

            train_path = opt.data + '.train'
            if opt.data_format in ['scp', 'scpmem']:
                train_src = SCPIndexDataset(audio_data['train'],
                                            concat=opt.concat)
            else:
                train_src = MMapIndexedDataset(train_path + '.src')

            train_tgt = MMapIndexedDataset(train_path + '.tgt')

            # check the lang files if they exist (in the case of multi-lingual models)
            if os.path.exists(train_path + '.src_lang.bin'):
                assert 'langs' in dicts
                train_src_langs = MMapIndexedDataset(train_path + '.src_lang')
                train_tgt_langs = MMapIndexedDataset(train_path + '.tgt_lang')
            else:
                train_src_langs = list()
                train_tgt_langs = list()
                # Allocate a Tensor(1) for the bilingual case
                train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            # check the length files if they exist
            if os.path.exists(train_path + '.src_sizes.npy'):
                train_src_sizes = np.load(train_path + '.src_sizes.npy')
                train_tgt_sizes = np.load(train_path + '.tgt_sizes.npy')
            else:
                train_src_sizes, train_tgt_sizes = None, None

            if opt.encoder_type == 'audio':
                data_type = 'audio'
            else:
                data_type = 'text'

            if not opt.streaming:
                train_data = onmt.Dataset(
                    train_src,
                    train_tgt,
                    train_src_sizes,
                    train_tgt_sizes,
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    src_align_right=opt.src_align_right,
                    augment=opt.augment_speech,
                    sa_f=opt.sa_f,
                    sa_t=opt.sa_t,
                    upsampling=opt.upsampling,
                    cleaning=True,
                    verbose=True)
            else:
                train_data = onmt.StreamDataset(
                    train_src,
                    train_tgt,
                    train_src_langs,
                    train_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=False,
                    batch_size_sents=opt.batch_size_sents,
                    multiplier=opt.batch_size_multiplier,
                    upsampling=opt.upsampling)

            valid_path = opt.data + '.valid'
            if opt.data_format in ['scp', 'scpmem']:
                valid_src = SCPIndexDataset(audio_data['valid'],
                                            concat=opt.concat)
            else:
                valid_src = MMapIndexedDataset(valid_path + '.src')
            valid_tgt = MMapIndexedDataset(valid_path + '.tgt')

            if os.path.exists(valid_path + '.src_lang.bin'):
                assert 'langs' in dicts
                valid_src_langs = MMapIndexedDataset(valid_path + '.src_lang')
                valid_tgt_langs = MMapIndexedDataset(valid_path + '.tgt_lang')
            else:
                valid_src_langs = list()
                valid_tgt_langs = list()

                # Allocate one entry for the bilingual case
                valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
                valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))

            # check the length files if they exist
            if os.path.exists(valid_path + '.src_sizes.npy'):
                valid_src_sizes = np.load(valid_path + '.src_sizes.npy')
                valid_tgt_sizes = np.load(valid_path + '.tgt_sizes.npy')
            else:
                valid_src_sizes, valid_tgt_sizes = None, None

            if not opt.streaming:
                valid_data = onmt.Dataset(
                    valid_src,
                    valid_tgt,
                    valid_src_sizes,
                    valid_tgt_sizes,
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents,
                    src_align_right=opt.src_align_right,
                    cleaning=True,
                    verbose=True,
                    debug=True,
                    num_split=len(opt.gpus))
            else:
                # for validation data, we have to go through the sentences one by one (very slow, but ensures correctness)
                valid_data = onmt.StreamDataset(
                    valid_src,
                    valid_tgt,
                    valid_src_langs,
                    valid_tgt_langs,
                    batch_size_words=opt.batch_size_words,
                    data_type=data_type,
                    sorting=True,
                    batch_size_sents=opt.batch_size_sents)

            elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
            print("Done after %s" % elapse)

        else:
            raise NotImplementedError

        print(' * number of sentences in training data: %d' %
              train_data.size())
        print(' * number of sentences in validation data: %d' %
              valid_data.size())

    else:
        print("[INFO] Reading multiple dataset ...")
        # raise NotImplementedError

        dicts = torch.load(opt.data + ".dict.pt")

        root_dir = os.path.dirname(opt.data)

        print("Loading training data ...")

        train_dirs, valid_dirs = dict(), dict()

        # scan the data directory to find the training data
        for dir_ in os.listdir(root_dir):
            if os.path.isdir(os.path.join(root_dir, dir_)):
                if str(dir_).startswith("train"):
                    idx = int(dir_.split(".")[1])
                    train_dirs[idx] = dir_
                if dir_.startswith("valid"):
                    idx = int(dir_.split(".")[1])
                    valid_dirs[idx] = dir_

        train_sets, valid_sets = list(), list()

        for (idx_, dir_) in sorted(train_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)
            print("[INFO] Loading training data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in ['scp', 'scpmem', 'mmem']:
                from onmt.data.mmap_indexed_dataset import MMapIndexedDataset
                from onmt.data.scp_dataset import SCPIndexDataset

                if opt.data_format in ['scp', 'scpmem']:
                    audio_data = torch.load(
                        os.path.join(data_dir, "data.scp_path.pt"))
                    src_data = SCPIndexDataset(audio_data, concat=opt.concat)
                else:
                    src_data = MMapIndexedDataset(
                        os.path.join(data_dir, "data.src"))

                tgt_data = MMapIndexedDataset(
                    os.path.join(data_dir, "data.tgt"))

                src_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.src_lang'))
                tgt_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.tgt_lang'))

                if os.path.exists(os.path.join(data_dir,
                                               'data.src_sizes.npy')):
                    src_sizes = np.load(
                        os.path.join(data_dir, 'data.src_sizes.npy'))
                    tgt_sizes = np.load(
                        os.path.join(data_dir, 'data.tgt_sizes.npy'))
                else:
                    src_sizes, tgt_sizes = None, None

                if opt.encoder_type == 'audio':
                    data_type = 'audio'
                else:
                    data_type = 'text'

                if not opt.streaming:
                    train_data = onmt.Dataset(
                        src_data,
                        tgt_data,
                        src_sizes,
                        tgt_sizes,
                        src_lang_data,
                        tgt_lang_data,
                        batch_size_words=opt.batch_size_words,
                        data_type=data_type,
                        sorting=True,
                        batch_size_sents=opt.batch_size_sents,
                        multiplier=opt.batch_size_multiplier,
                        src_align_right=opt.src_align_right,
                        augment=opt.augment_speech,
                        sa_f=opt.sa_f,
                        sa_t=opt.sa_t,
                        upsampling=opt.upsampling,
                        cleaning=True,
                        verbose=True,
                        num_split=len(opt.gpus))

                    train_sets.append(train_data)

                else:
                    print("Multi-dataset not implemented for Streaming tasks.")
                    raise NotImplementedError

        for (idx_, dir_) in sorted(valid_dirs.items()):

            data_dir = os.path.join(root_dir, dir_)

            print("[INFO] Loading validation data %i from %s" % (idx_, dir_))

            if opt.data_format in ['bin', 'raw']:
                raise NotImplementedError

            elif opt.data_format in ['scp', 'scpmem', 'mmem']:

                if opt.data_format in ['scp', 'scpmem']:
                    audio_data = torch.load(
                        os.path.join(data_dir, "data.scp_path.pt"))
                    src_data = SCPIndexDataset(audio_data, concat=opt.concat)
                else:
                    src_data = MMapIndexedDataset(
                        os.path.join(data_dir, "data.src"))

                tgt_data = MMapIndexedDataset(
                    os.path.join(data_dir, "data.tgt"))

                src_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.src_lang'))
                tgt_lang_data = MMapIndexedDataset(
                    os.path.join(data_dir, 'data.tgt_lang'))

                if os.path.exists(os.path.join(data_dir,
                                               'data.src_sizes.npy')):
                    src_sizes = np.load(
                        os.path.join(data_dir, 'data.src_sizes.npy'))
                    tgt_sizes = np.load(
                        os.path.join(data_dir, 'data.tgt_sizes.npy'))
                else:
                    src_sizes, tgt_sizes = None, None

                if opt.encoder_type == 'audio':
                    data_type = 'audio'
                else:
                    data_type = 'text'

                if not opt.streaming:
                    valid_data = onmt.Dataset(
                        src_data,
                        tgt_data,
                        src_sizes,
                        tgt_sizes,
                        src_lang_data,
                        tgt_lang_data,
                        batch_size_words=opt.batch_size_words,
                        data_type=data_type,
                        sorting=True,
                        batch_size_sents=opt.batch_size_sents,
                        src_align_right=opt.src_align_right,
                        cleaning=True,
                        verbose=True,
                        debug=True,
                        num_split=len(opt.gpus))

                    valid_sets.append(valid_data)

                else:
                    raise NotImplementedError

        train_data = train_sets
        valid_data = valid_sets

    if opt.load_from:
        checkpoint = torch.load(opt.load_from,
                                map_location=lambda storage, loc: storage)
        print("* Loading dictionaries from the checkpoint")
        dicts = checkpoint['dicts']
    else:
        dicts['tgt'].patch(opt.patch_vocab_multiplier)
        checkpoint = None

    # Put the vocab mask from dicts to the datasets
    for data in [train_data, valid_data]:
        if isinstance(data, list):
            for i, data_ in enumerate(data):
                data_.set_mask(dicts['tgt'].vocab_mask)
                data[i] = data_
        else:
            data.set_mask(dicts['tgt'].vocab_mask)
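    # NOTE: set_mask is assumed to attach the target-vocabulary mask to each
    # dataset so that batches can restrict the generator to the tokens actually
    # in use.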

    if "src" in dicts:
        print(' * vocabulary size. source = %d; target = %d' %
              (dicts['src'].size(), dicts['tgt'].size()))
    else:
        print('[INFO] vocabulary size. target = %d' % (dicts['tgt'].size()))

    print('* Building model...')

    # update special tokens
    onmt.constants = add_tokenidx(opt, onmt.constants, dicts)

    if not opt.fusion:
        if opt.bayes_by_backprop:
            model = build_bayesian_model(opt, dicts)
        else:
            model = build_model(opt, dicts)
        """ Building the loss function """
        if opt.nce:
            from onmt.modules.nce.nce_loss import NCELoss
            loss_function = NCELoss(opt.model_size,
                                    dicts['tgt'].size(),
                                    noise_ratio=opt.nce_noise,
                                    logz=9,
                                    label_smoothing=opt.label_smoothing)
        else:
            loss_function = NMTLossFunc(opt.model_size,
                                        dicts['tgt'].size(),
                                        label_smoothing=opt.label_smoothing,
                                        mirror=opt.mirror_loss,
                                        fast_xentropy=opt.fast_xentropy)
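        # For reference (a common formulation, assumed here): with label
        # smoothing eps over a vocabulary of size V, the target distribution is
        #   q(w) = 1 - eps  if w is the gold token, else  eps / (V - 1)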

        # This function replaces modules with the more optimized counterparts so that it can run faster
        # Currently exp with LayerNorm
        if not opt.memory_profiling:
            optimize_model(model, fp16=opt.fp16)

    else:
        from onmt.model_factory import build_fusion
        from onmt.modules.loss import FusionLoss

        model = build_fusion(opt, dicts)

        loss_function = FusionLoss(dicts['tgt'].size(),
                                   label_smoothing=opt.label_smoothing)

    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)

    if not opt.debugging and len(opt.gpus) == 1:
        if opt.bayes_by_backprop:

            from onmt.train_utils.bayes_by_backprop_trainer import BayesianTrainer
            trainer = BayesianTrainer(model, loss_function, train_data,
                                      valid_data, dicts, opt)

        else:
            trainer = XETrainer(model, loss_function, train_data, valid_data,
                                dicts, opt)
    else:
        print(
            "MultiGPU is not supported by this train.py. Use train_distributed.py with the same arguments "
            "for MultiGPU training")
        raise NotImplementedError

    trainer.run(checkpoint=checkpoint)
Example 5
def build_tm_model(opt, dicts):
    onmt.constants = add_tokenidx(opt, onmt.constants, dicts)

    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
    else:
        raise NotImplementedError
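    # NOTE: PositionalEncoding is assumed to implement the standard sinusoidal
    # encoding, PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and
    # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)), up to len_max positions.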

    if opt.reconstruct:
        # reconstruction is only compatible with the relative transformer and a text encoder
        assert opt.model == 'relative_transformer'
        assert opt.encoder_type == 'text'

    # BUILD GENERATOR
    if opt.copy_generator:
        if opt.nce_noise > 0:
            print("[INFO] Copy generator overrides NCE.")
            opt.nce = False
            opt.nce_noise = 0
        generators = [
            CopyGenerator(opt.model_size,
                          dicts['tgt'].size(),
                          fix_norm=opt.fix_norm_output_embedding)
        ]
    elif opt.nce_noise > 0:
        from onmt.modules.nce.nce_linear import NCELinear
        from onmt.modules.nce.nce_utils import build_unigram_noise
        noise_distribution = build_unigram_noise(
            torch.FloatTensor(list(dicts['tgt'].frequencies.values())))

        generator = NCELinear(opt.model_size,
                              dicts['tgt'].size(),
                              fix_norm=opt.fix_norm_output_embedding,
                              noise_distribution=noise_distribution,
                              noise_ratio=opt.nce_noise)
        generators = [generator]
    else:
        generators = [
            onmt.modules.base_seq2seq.Generator(
                opt.model_size,
                dicts['tgt'].size(),
                fix_norm=opt.fix_norm_output_embedding)
        ]

    # BUILD EMBEDDINGS
    if 'src' in dicts and ((not hasattr(opt, "enc_pretrained_model"))
                           or (not opt.enc_pretrained_model)):
        embedding_src = nn.Embedding(dicts['src'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.SRC_PAD)
    else:
        # no source dict, or the pretrained encoder brings its own embeddings
        embedding_src = None

    if opt.join_embedding and embedding_src is not None:
        embedding_tgt = embedding_src
        print("* Joining the weights of encoder and decoder word embeddings")
    elif not opt.dec_pretrained_model:
        embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.TGT_PAD)
    else:
        assert opt.model == "pretrain_transformer"
        embedding_tgt = None

    if opt.use_language_embedding:
        print("* Create language embeddings with %d languages" %
              len(dicts['langs']))
        language_embeddings = nn.Embedding(len(dicts['langs']), opt.model_size)
    else:
        language_embeddings = None

    if opt.ctc_loss != 0:
        generators.append(
            onmt.modules.base_seq2seq.Generator(opt.model_size,
                                                dicts['tgt'].size() + 1))
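        # the extra output dimension presumably reserves a class for the CTC
        # blank symbol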

    if opt.model in ['conformer', 'speech_transformer', 'hybrid_transformer']:
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_recognizer.relative_transformer import \
            SpeechTransformerEncoder, SpeechTransformerDecoder

        if opt.model == 'conformer':
            from onmt.models.speech_recognizer.conformer import ConformerEncoder, Conformer
            from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder
            opt.cnn_downsampling = True  # force this flag so that the decoder-side masking is computed correctly
            encoder = ConformerEncoder(opt, None, None, 'audio')

            # decoder = SpeechLSTMDecoder(opt, embedding_tgt, language_embeddings=language_embeddings)
            decoder = SpeechTransformerDecoder(
                opt,
                embedding_tgt,
                positional_encoder,
                language_embeddings=language_embeddings)

            # model = Conformer(encoder, decoder, nn.ModuleList(generators), ctc=opt.ctc_loss > 0.0)
            model = RelativeTransformer(encoder,
                                        decoder,
                                        nn.ModuleList(generators),
                                        None,
                                        None,
                                        mirror=opt.mirror_loss,
                                        ctc=opt.ctc_loss > 0.0)
        elif opt.model == 'hybrid_transformer':
            from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder, SpeechLSTMEncoder, SpeechLSTMSeq2Seq
            encoder = SpeechTransformerEncoder(opt, None, positional_encoder,
                                               opt.encoder_type)

            decoder = SpeechLSTMDecoder(
                opt, embedding_tgt, language_embeddings=language_embeddings)

            model = SpeechLSTMSeq2Seq(encoder,
                                      decoder,
                                      nn.ModuleList(generators),
                                      ctc=opt.ctc_loss > 0.0)
        else:
            encoder = SpeechTransformerEncoder(opt, None, positional_encoder,
                                               opt.encoder_type)

            decoder = SpeechTransformerDecoder(
                opt,
                embedding_tgt,
                positional_encoder,
                language_embeddings=language_embeddings)
            model = RelativeTransformer(encoder,
                                        decoder,
                                        nn.ModuleList(generators),
                                        None,
                                        None,
                                        mirror=opt.mirror_loss,
                                        ctc=opt.ctc_loss > 0.0)

        # If we use the multilingual model and weights are partitioned:
        if opt.multilingual_partitioned_weights:
            # this is basically the language embeddings
            factor_embeddings = nn.Embedding(len(dicts['langs']),
                                             opt.mpw_factor_size)

            encoder.factor_embeddings = factor_embeddings
            decoder.factor_embeddings = factor_embeddings

    elif opt.model in ["LSTM", 'lstm']:
        # print("LSTM")
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder, SpeechLSTMEncoder, SpeechLSTMSeq2Seq

        encoder = SpeechLSTMEncoder(opt, None, opt.encoder_type)

        decoder = SpeechLSTMDecoder(opt,
                                    embedding_tgt,
                                    language_embeddings=language_embeddings)

        model = SpeechLSTMSeq2Seq(encoder,
                                  decoder,
                                  nn.ModuleList(generators),
                                  ctc=opt.ctc_loss > 0.0)

    elif opt.model in ['multilingual_translator', 'translator']:
        onmt.constants.init_value = opt.param_init
        from onmt.models.multilingual_translator.relative_transformer import \
            RelativeTransformerEncoder, RelativeTransformerDecoder

        encoder = RelativeTransformerEncoder(
            opt,
            embedding_src,
            None,
            opt.encoder_type,
            language_embeddings=language_embeddings)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        model = RelativeTransformer(encoder,
                                    decoder,
                                    nn.ModuleList(generators),
                                    None,
                                    None,
                                    mirror=opt.mirror_loss)

    elif opt.model in ['transformer', 'stochastic_transformer']:
        onmt.constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = TransformerEncoder(opt, None, positional_encoder,
                                         opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                "text",
                language_embeddings=language_embeddings)
            audio_encoder = TransformerEncoder(opt, None, positional_encoder,
                                               "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = TransformerDecoder(opt,
                                     embedding_tgt,
                                     positional_encoder,
                                     language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            nn.ModuleList(generators),
                            mirror=opt.mirror_loss)

    elif opt.model == 'relative_transformer':
        from onmt.models.relative_transformer import \
            RelativeTransformerEncoder, RelativeTransformerDecoder

        if opt.encoder_type == "text":
            encoder = RelativeTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        if opt.encoder_type == "audio":
            # raise NotImplementedError
            encoder = RelativeTransformerEncoder(
                opt,
                None,
                None,
                encoder_type=opt.encoder_type,
                language_embeddings=language_embeddings)

        generator = nn.ModuleList(generators)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        if opt.reconstruct:
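            # Reconstruction objective: a reverse decoder with its own
            # generator over the source vocabulary regenerates the source.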
            rev_decoder = RelativeTransformerDecoder(
                opt,
                embedding_src,
                None,
                language_embeddings=language_embeddings)
            rev_generator = [
                onmt.modules.base_seq2seq.Generator(
                    opt.model_size,
                    dicts['src'].size(),
                    fix_norm=opt.fix_norm_output_embedding)
            ]
            rev_generator = nn.ModuleList(rev_generator)
        else:
            rev_decoder = None
            rev_generator = None

        model = RelativeTransformer(encoder,
                                    decoder,
                                    generator,
                                    rev_decoder,
                                    rev_generator,
                                    mirror=opt.mirror_loss)

    elif opt.model == 'universal_transformer':
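        # Legacy Universal Transformer (presumably sharing parameters across
        # layer iterations); kept under onmt.legacy.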
        from onmt.legacy.old_models.universal_transformer import UniversalTransformerDecoder, UniversalTransformerEncoder

        generator = nn.ModuleList(generators)

        if opt.encoder_type == "text":
            encoder = UniversalTransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = UniversalTransformerEncoder(opt, None,
                                                  positional_encoder,
                                                  opt.encoder_type)

        decoder = UniversalTransformerDecoder(
            opt,
            embedding_tgt,
            positional_encoder,
            language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            generator,
                            mirror=opt.mirror_loss)

    elif opt.model == 'pretrain_transformer':
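        # Hybrid model: either side may be initialized from a pretrained
        # BERT/RoBERTa checkpoint; at least one pretrained side is required.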
        assert (opt.enc_pretrained_model or opt.dec_pretrained_model)
        from onmt.models.pretrain_transformer import PretrainTransformer
        print(f"pos_emb_type: {opt.pos_emb_type}")
        print(f"max_pos_length: {opt.max_pos_length }")
        print(
            f"Share position embeddings cross heads: {not opt.diff_head_pos}")
        print()
        if opt.enc_pretrained_model:
            print("* Build encoder with enc_pretrained_model: {}".format(
                opt.enc_pretrained_model))
        if opt.enc_pretrained_model == "bert":
            from pretrain_module.configuration_bert import BertConfig
            from pretrain_module.modeling_bert import BertModel

            enc_bert_config = BertConfig.from_json_file(opt.enc_config_file)
            encoder = BertModel(
                enc_bert_config,
                bert_word_dropout=opt.enc_pretrain_word_dropout,
                bert_emb_dropout=opt.enc_pretrain_emb_dropout,
                bert_atten_dropout=opt.enc_pretrain_attn_dropout,
                bert_hidden_dropout=opt.enc_pretrain_hidden_dropout,
                bert_hidden_size=opt.enc_pretrain_hidden_size,
                is_decoder=False,
                before_plm_output_ln=opt.before_enc_output_ln,
                gradient_checkpointing=opt.enc_gradient_checkpointing,
                max_pos_len=opt.max_pos_length,
                diff_head_pos=opt.diff_head_pos,
                pos_emb_type=opt.pos_emb_type,
            )

        elif opt.enc_pretrained_model == "roberta":
            from pretrain_module.configuration_roberta import RobertaConfig
            from pretrain_module.modeling_roberta import RobertaModel
            enc_roberta_config = RobertaConfig.from_json_file(
                opt.enc_config_file)

            encoder = RobertaModel(
                enc_roberta_config,
                bert_word_dropout=opt.enc_pretrain_word_dropout,
                bert_emb_dropout=opt.enc_pretrain_emb_dropout,
                bert_atten_dropout=opt.enc_pretrain_attn_dropout,
                bert_hidden_dropout=opt.enc_pretrain_hidden_dropout,
                bert_hidden_size=opt.enc_pretrain_hidden_size,
                is_decoder=False,
                before_plm_output_ln=opt.before_enc_output_ln,
                gradient_checkpointing=opt.enc_gradient_checkpointing,
                max_pos_len=opt.max_pos_length,
                diff_head_pos=opt.diff_head_pos,
                pos_emb_type=opt.pos_emb_type,
            )
        elif not opt.enc_pretrained_model:
            print(" Encoder is not from pretrained model")
            encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        else:
            print("Warning: only bert and roberta are implemented for encoder")
            exit(-1)

        if opt.load_from or not opt.enc_state_dict:
            if opt.verbose:
                print("  No weights loading from {} for encoder".format(
                    opt.enc_pretrained_model))
        elif opt.enc_pretrained_model:
            print("  Loading weights for encoder from: \n", opt.enc_state_dict)

            enc_model_state_dict = torch.load(opt.enc_state_dict,
                                              map_location="cpu")

            encoder.from_pretrained(state_dict=enc_model_state_dict,
                                    model=encoder,
                                    output_loading_info=opt.verbose,
                                    model_prefix=opt.enc_pretrained_model)

        if opt.dec_pretrained_model:
            print("* Build decoder with dec_pretrained_model: {}".format(
                opt.dec_pretrained_model))

        if opt.dec_pretrained_model == "bert":
            if opt.enc_pretrained_model != "bert":
                from pretrain_module.configuration_bert import BertConfig
                from pretrain_module.modeling_bert import BertModel
            dec_bert_config = BertConfig.from_json_file(opt.dec_config_file)
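            # is_decoder=True presumably switches the model to causal
            # self-attention plus cross-attention over the encoder output.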
            decoder = BertModel(
                dec_bert_config,
                bert_word_dropout=opt.dec_pretrain_word_dropout,
                bert_emb_dropout=opt.dec_pretrain_emb_dropout,
                bert_atten_dropout=opt.dec_pretrain_attn_dropout,
                bert_hidden_dropout=opt.dec_pretrain_hidden_dropout,
                bert_hidden_size=opt.dec_pretrain_hidden_size,
                is_decoder=True,
                gradient_checkpointing=opt.dec_gradient_checkpointing,
                max_pos_len=opt.max_pos_length,
                diff_head_pos=opt.diff_head_pos,
                pos_emb_type=opt.pos_emb_type,
            )

        elif opt.dec_pretrained_model == "roberta":
            if opt.enc_pretrained_model != "roberta":
                from pretrain_module.configuration_roberta import RobertaConfig
                from pretrain_module.modeling_roberta import RobertaModel

            dec_roberta_config = RobertaConfig.from_json_file(
                opt.dec_config_file)

            decoder = RobertaModel(
                dec_roberta_config,
                bert_word_dropout=opt.dec_pretrain_word_dropout,
                bert_emb_dropout=opt.dec_pretrain_emb_dropout,
                bert_atten_dropout=opt.dec_pretrain_attn_dropout,
                bert_hidden_dropout=opt.dec_pretrain_hidden_dropout,
                bert_hidden_size=opt.dec_pretrain_hidden_size,
                is_decoder=True,
                gradient_checkpointing=opt.dec_gradient_checkpointing,
                max_pos_len=opt.max_pos_length,
                diff_head_pos=opt.diff_head_pos,
                pos_emb_type=opt.pos_emb_type,
            )

        elif not opt.dec_pretrained_model:
            print(" Decoder is not from pretrained model")
            decoder = TransformerDecoder(
                opt,
                embedding_tgt,
                positional_encoder,
                language_embeddings=language_embeddings)
        else:
            print("Warning: only bert and roberta are implemented for decoder")
            exit(-1)

        if opt.load_from or not opt.dec_state_dict:
            if opt.verbose:
                print("  No weights loaded from {} for decoder".format(
                    opt.dec_pretrained_model))
        elif opt.dec_pretrained_model:
            print("  Loading weights for decoder from:\n", opt.dec_state_dict)
            dec_model_state_dict = torch.load(opt.dec_state_dict,
                                              map_location="cpu")

            decoder.from_pretrained(state_dict=dec_model_state_dict,
                                    model=decoder,
                                    output_loading_info=opt.verbose,
                                    model_prefix=opt.dec_pretrained_model)

        encoder.enc_pretrained_model = opt.enc_pretrained_model
        decoder.dec_pretrained_model = opt.dec_pretrained_model

        encoder.input_type = opt.encoder_type

        model = PretrainTransformer(encoder, decoder,
                                    nn.ModuleList(generators))
    else:
        raise NotImplementedError

    if opt.tie_weights:
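        # Weight tying shares the decoder input embedding matrix with the
        # generator's output projection, saving parameters.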
        print("* Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    return model
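# --- Usage sketch (editor's addition): how build_model is typically invoked.
# A minimal sketch, assuming a checkpoint saved by this codebase with 'opt',
# 'dicts' and 'model' entries (as in the loader below); "model.pt" is a
# hypothetical placeholder path.
#
#   checkpoint = torch.load("model.pt", map_location=lambda storage, loc: storage)
#   model_opt = backward_compatible(checkpoint['opt'])
#   model = build_model(model_opt, checkpoint['dicts'])
#   model.load_state_dict(checkpoint['model'])
#   model.eval()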
Example n. 6
    def __init__(self, opt):
        self.opt = opt
        self.tt = torch.cuda if opt.cuda else torch
        self.beam_accum = None
        self.beta = opt.beta
        self.alpha = opt.alpha
        self.start_with_bos = opt.start_with_bos
        self.fp16 = opt.fp16
        self.attributes = opt.attributes  # attributes separated by '|', e.g. de|domain1
        self.sampling = opt.sampling
        self.src_lang = opt.src_lang
        self.tgt_lang = opt.tgt_lang

        if self.attributes:
            self.attributes = self.attributes.split("|")

        self.models = list()
        self.model_types = list()

        # models are string with | as delimiter
        models = opt.model.split("|")

        print(models)
        self.n_models = len(models)
        self._type = 'text'
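
        # Each ensemble member is restored from its own checkpoint; options
        # are made backward compatible and the vocabularies are taken from
        # the first checkpoint (i == 0).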

        for i, model_path in enumerate(models):
            checkpoint = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)

            model_opt = checkpoint['opt']
            model_opt = backward_compatible(model_opt)
            if hasattr(model_opt, "enc_state_dict"):
                model_opt.enc_state_dict = None
                model_opt.dec_state_dict = None

            self.main_model_opt = model_opt
            dicts = checkpoint['dicts']

            # update special tokens
            onmt.constants = add_tokenidx(model_opt, onmt.constants, dicts)
            self.bos_token = model_opt.tgt_bos_word

            if i == 0:
                if "src" in checkpoint['dicts']:
                    self.src_dict = checkpoint['dicts']['src']
                else:
                    self._type = "audio"
                    # self.src_dict = self.tgt_dict

                self.tgt_dict = checkpoint['dicts']['tgt']

                if "langs" in checkpoint["dicts"]:
                    self.lang_dict = checkpoint['dicts']['langs']

                else:
                    self.lang_dict = {'src': 0, 'tgt': 1}

                self.bos_id = self.tgt_dict.labelToIdx[self.bos_token]

            model = build_model(model_opt, checkpoint['dicts'])
            optimize_model(model)
            if opt.verbose:
                print('Loading model from %s' % model_path)
            model.load_state_dict(checkpoint['model'])

            if model_opt.model in model_list:
                # extend the positional buffers so that decoding up to
                # max_sent_length is possible
                model.renew_buffer(self.opt.max_sent_length)


            if opt.fp16:
                model = model.half()

            if opt.cuda:
                model = model.cuda()
            else:
                model = model.cpu()

            if opt.dynamic_quantile == 1:
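                # Post-training dynamic quantization: Linear and LSTM weights
                # are converted to int8 for faster CPU inference.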

                engines = torch.backends.quantized.supported_engines
                if 'fbgemm' in engines:
                    torch.backends.quantized.engine = 'fbgemm'
                else:
                    print(
                        "[INFO] fbgemm is not found in the available engines. Possibly the CPU does not support AVX2."
                        " It is recommended to disable Quantization (set to 0)."
                    )
                    torch.backends.quantized.engine = 'qnnpack'

                # convert the custom functions to their autograd equivalent first
                model.convert_autograd()

                model = torch.quantization.quantize_dynamic(
                    model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)

            model.eval()

            self.models.append(model)
            self.model_types.append(model_opt.model)

        # language model
        if opt.lm is not None:
            if opt.verbose:
                print('Loading language model from %s' % opt.lm)

            lm_chkpoint = torch.load(opt.lm,
                                     map_location=lambda storage, loc: storage)

            lm_opt = lm_chkpoint['opt']

            # note: the vocabularies are reused from the last translation
            # checkpoint; the LM is assumed to share the same vocabulary
            lm_model = build_language_model(lm_opt, checkpoint['dicts'])
            lm_model.load_state_dict(lm_chkpoint['model'])

            if opt.fp16:
                lm_model = lm_model.half()

            if opt.cuda:
                lm_model = lm_model.cuda()
            else:
                lm_model = lm_model.cpu()

            self.lm_model = lm_model

        self.cuda = opt.cuda
        self.ensemble_op = opt.ensemble_op

        if opt.autoencoder is not None:
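            # Optional autoencoder, built around the first model; presumably
            # used to post-process or rescore its outputs.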
            if opt.verbose:
                print('Loading autoencoder from %s' % opt.autoencoder)
            checkpoint = torch.load(opt.autoencoder,
                                    map_location=lambda storage, loc: storage)
            model_opt = checkpoint['opt']

            # posSize= checkpoint['autoencoder']['nmt.decoder.positional_encoder.pos_emb'].size(0)
            # self.models[0].decoder.renew_buffer(posSize)

            # Build model from the saved option
            self.autoencoder = Autoencoder(self.models[0], model_opt)

            self.autoencoder.load_state_dict(checkpoint['autoencoder'])

            if opt.cuda:
                self.autoencoder = self.autoencoder.cuda()
                self.models[0] = self.models[0].cuda()
            else:
                self.autoencoder = self.autoencoder.cpu()
                self.models[0] = self.models[0].cpu()

            self.models[0].autoencoder = self.autoencoder
        if opt.verbose:
            print('Done')