Example #1
    def __init__(self, 
                 archive_file=DEFAULT_ARCHIVE_FILE, 
                 use_cuda=False,
                 is_user=False,
                 model_file=None):

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for SC-LSTM is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(os.path.join(model_dir, 'resource')):
            archive = zipfile.ZipFile(archive_file, 'r')
            archive.extractall(model_dir)

        self.USE_CUDA = use_cuda
        self.args, self.config = parse(is_user)
        self.dataset = SimpleDatasetWoz(self.config)

        # get model hyper-parameters
        hidden_size = self.config.getint('MODEL', 'hidden_size')

        # get feat size
        d_size = self.dataset.do_size + self.dataset.da_size + self.dataset.sv_size  # length of the one-hot feature vector
        vocab_size = len(self.dataset.word2index)

        self.model = LMDeep('sclstm', vocab_size, vocab_size, hidden_size, d_size, n_layer=self.args['n_layer'], use_cuda=use_cuda)
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), self.args['model_path'])
        # print(model_path)
        assert os.path.isfile(model_path)
        self.model.load_state_dict(torch.load(model_path))
        self.model.eval()
        if use_cuda:
            self.model.cuda()
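
All of these constructors repeat the same bootstrap: if the packaged archive is missing on disk, resolve `model_file` (a local path or URL) through `cached_path`, then extract the zip next to the module before loading weights. Below is a minimal standalone sketch of that shared pattern; the function name, argument names and the `resource` marker directory are hypothetical (the marker mirrors Example #1), and `cached_path` is assumed to be allennlp's download-and-cache helper used in these examples.

import os
import zipfile

from allennlp.common.file_utils import cached_path  # assumption: the same download-and-cache helper used above


def ensure_extracted(archive_file, model_file, target_dir, marker='resource'):
    """Download (if needed) and unzip a model archive into target_dir."""
    if not os.path.isfile(archive_file):
        if not model_file:
            raise Exception("No model archive is specified!")
        archive_file = cached_path(model_file)  # resolves a URL or path to a local cached file
    if not os.path.exists(os.path.join(target_dir, marker)):
        with zipfile.ZipFile(archive_file, 'r') as archive:
            archive.extractall(target_dir)
    return target_dir
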
Example #2
    def __init__(self,
                 config_file=os.path.join(
                     os.path.dirname(os.path.abspath(__file__)),
                     'config/multiwoz.cfg'),
                 model_file=None):
        self.config = configparser.ConfigParser()
        self.config.read(config_file)
        self.c = Classifier.classifier(self.config)
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                  self.config.get("train", "output"))
        model_dir = os.path.dirname(model_path)
        if not os.path.exists(model_path):
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            if not model_file:
                print('Load from ', os.path.join(model_dir,
                                                 'svm_multiwoz.zip'))
                archive = zipfile.ZipFile(
                    os.path.join(model_dir, 'svm_multiwoz.zip'), 'r')
            else:
                print('Load from model_file param')
                archive_file = cached_path(model_file)
                archive = zipfile.ZipFile(archive_file, 'r')
            archive.extractall(os.path.dirname(model_dir))
            archive.close()
        self.c.load(model_path)
Example #3
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
                 model_file=None,
                 context_size=3):
        """ Constructor for NLU class. """

        self.context_size = context_size

        check_for_gpu(cuda_device)

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for MILU is specified!")

            archive_file = cached_path(model_file)

        archive = load_archive(archive_file, cuda_device=cuda_device)
        self.tokenizer = SpacyWordSplitter(language="en_core_web_sm")
        _special_case = [{ORTH: u"id", LEMMA: u"id"}]
        self.tokenizer.spacy.tokenizer.add_special_case(u"id", _special_case)

        dataset_reader_params = archive.config["dataset_reader"]
        self.dataset_reader = DatasetReader.from_params(dataset_reader_params)
        self.model = archive.model
        self.model.eval()
Example #4
File: nlu.py Project: wbj0110/ConvLab
    def __init__(self, mode, config_file, model_file):
        """
        BERT NLU initialization.

        Args:
            mode (str):
                can be either `'usr'`, `'sys'` or `'all'`, representing which side of the data the model was trained on.

            config_file (str):
                name of the configuration file under this module's `configs/` directory.

            model_file (str):
                model path or URL

        Example:
            nlu = BERTNLU(mode='all', model_file='https://convlab.blob.core.windows.net/models/bert_multiwoz_all_context.zip')
        """
        assert mode == 'usr' or mode == 'sys' or mode == 'all'
        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                   'configs/{}'.format(config_file))
        config = json.load(open(config_file))
        DEVICE = config['DEVICE']
        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        data_dir = os.path.join(root_dir, config['data_dir'])
        output_dir = os.path.join(root_dir, config['output_dir'])

        if not os.path.exists(os.path.join(data_dir, 'intent_vocab.json')):
            preprocess(mode)

        intent_vocab = json.load(
            open(os.path.join(data_dir, 'intent_vocab.json')))
        tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
        dataloader = Dataloader(
            intent_vocab=intent_vocab,
            tag_vocab=tag_vocab,
            pretrained_weights=config['model']['pretrained_weights'])

        print('intent num:', len(intent_vocab))
        print('tag num:', len(tag_vocab))

        bert_config = BertConfig.from_pretrained(
            config['model']['pretrained_weights'])

        best_model_path = os.path.join(output_dir, 'pytorch_model.bin')
        if not os.path.exists(best_model_path):
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            print('Load from model_file param')
            archive_file = cached_path(model_file)
            archive = zipfile.ZipFile(archive_file, 'r')
            archive.extractall(root_dir)
            archive.close()
        print('Load from', best_model_path)
        model = JointBERT(bert_config, config['model'], DEVICE,
                          dataloader.tag_dim, dataloader.intent_dim)
        model.load_state_dict(
            torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE))
        model.to(DEVICE)
        model.eval()

        self.model = model
        self.dataloader = dataloader
        print("BERTNLU loaded")
Example #5
File: predictor.py Project: smksyj/ConvLab
    def __init__(self, archive_file, model_file=None, use_cuda=False):
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for DA-predictor is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(os.path.join(model_dir, 'checkpoints')):
            archive = zipfile.ZipFile(archive_file, 'r')
            archive.extractall(model_dir)

        load_dir = os.path.join(model_dir,
                                "checkpoints/predictor/save_step_15120")
        if not os.path.exists(load_dir):
            archive = zipfile.ZipFile(f'{load_dir}.zip', 'r')
            archive.extractall(os.path.dirname(load_dir))

        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=False)
        self.max_seq_length = 256
        self.domain = 'restaurant'
        self.model = BertForSequenceClassification.from_pretrained(
            load_dir,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(-1)),
            num_labels=44)
        self.device = 'cuda' if use_cuda else 'cpu'
        self.model.to(self.device)
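
The constructor above only prepares the tokenizer and the classifier; it does not show inference. As a hedged sketch (not in the source, helper name hypothetical), a single utterance would typically be encoded for `BertForSequenceClassification` with the standard `[CLS] ... [SEP]` layout, truncated to `max_seq_length`:

import torch


def encode_utterance(tokenizer, text, max_seq_length, device):
    # Standard single-sentence BERT input: [CLS] + wordpieces + [SEP], truncated to max_seq_length.
    tokens = ["[CLS]"] + tokenizer.tokenize(text)[:max_seq_length - 2] + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    return torch.tensor([input_ids], dtype=torch.long, device=device)

# Hypothetical use with the predictor built above:
# logits = predictor.model(encode_utterance(predictor.tokenizer, "any cheap restaurants ?", 256, predictor.device))
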
Example #6
    def __init__(self, archive_file, model_file=None, use_cuda=False):
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for HDSA is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(os.path.join(model_dir, 'checkpoints')):
            archive = zipfile.ZipFile(archive_file, 'r')
            archive.extractall(model_dir)

        with open(os.path.join(model_dir, "data/vocab.json"), 'r') as f:
            vocabulary = json.load(f)

        vocab, ivocab = vocabulary['vocab'], vocabulary['rev']
        self.tokenizer = Tokenizer(vocab, ivocab, False)
        self.max_seq_length = 50

        self.decoder = TableSemanticDecoder(
            vocab_size=self.tokenizer.vocab_len,
            d_word_vec=128,
            n_layers=3,
            d_model=128,
            n_head=4,
            dropout=0.2)
        self.device = 'cuda' if use_cuda else 'cpu'
        self.decoder.to(self.device)
        checkpoint_file = os.path.join(
            model_dir, "checkpoints/generator/BERT_dim128_w_domain")
        self.decoder.load_state_dict(torch.load(checkpoint_file))

        with open(os.path.join(model_dir, 'data/svdic.pkl'), 'rb') as f:
            self.dic = pickle.load(f)
Example #7
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
                 model_file=None):
        """ Constructor for NLU class. """
        SysPolicy.__init__(self)

        check_for_gpu(cuda_device)

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for MILU is specified!")
            archive_file = cached_path(model_file)

        archive = load_archive(archive_file, cuda_device=cuda_device)
        dataset_reader_params = archive.config["dataset_reader"]
        self.dataset_reader = DatasetReader.from_params(dataset_reader_params)
        self.action_vocab = self.dataset_reader.action_vocab
        self.state_encoder = self.dataset_reader.state_encoder
        self.model = archive.model
        self.model.eval()
Example #8
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
                 model_file=None):

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for MDRG is specified!")
            archive_file = cached_path(model_file)

        temp_path = tempfile.mkdtemp()
        zip_ref = zipfile.ZipFile(archive_file, 'r')
        zip_ref.extractall(temp_path)
        zip_ref.close()

        self.dic = pickle.load(
            open(os.path.join(temp_path, 'mdrg/svdic.pkl'), 'rb'))
        # Load dictionaries
        with open(os.path.join(temp_path,
                               'mdrg/input_lang.index2word.json')) as f:
            input_lang_index2word = json.load(f)
        with open(os.path.join(temp_path,
                               'mdrg/input_lang.word2index.json')) as f:
            input_lang_word2index = json.load(f)
        with open(os.path.join(temp_path,
                               'mdrg/output_lang.index2word.json')) as f:
            output_lang_index2word = json.load(f)
        with open(os.path.join(temp_path,
                               'mdrg/output_lang.word2index.json')) as f:
            output_lang_word2index = json.load(f)
        self.response_model = Model(args, input_lang_index2word,
                                    output_lang_index2word,
                                    input_lang_word2index,
                                    output_lang_word2index)
        self.response_model.loadModel(os.path.join(temp_path, 'mdrg/mdrg'))

        shutil.rmtree(temp_path)

        self.prev_state = init_state()
        self.prev_active_domain = None
Example #9
    def __init__(self, archive_file=DEFAULT_ARCHIVE_FILE, model_file=None):
        SysPolicy.__init__(self)

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for Sequicity is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(os.path.join(model_dir, 'data')):
            archive = zipfile.ZipFile(archive_file, 'r')
            archive.extractall(model_dir)

        cfg.init_handler('tsdf-multiwoz')

        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        self.m = Model('multiwoz')
        self.m.count_params()
        self.m.load_model()
        self.reset()
Example #10
    def _read(self, file_path):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        if file_path.endswith("zip"):
            archive = zipfile.ZipFile(file_path, "r")
            data_file = archive.open(os.path.basename(file_path)[:-4])
        else:
            data_file = open(file_path, "r")

        logger.info("Reading instances from lines in file at: %s", file_path)

        dialogs = json.load(data_file)

        for dial_name in dialogs:
            dialog = dialogs[dial_name]["log"]
            self.dst.init_session()
            for i, turn in enumerate(dialog):
                if i % 2 == 0:  # user turn
                    self.dst.update(user_act=turn["dialog_act"])
                else:  # system turn
                    delex_act = {}
                    for domain_act in turn["dialog_act"]:
                        domain, act_type = domain_act.split('-', 1)
                        if act_type in ['NoOffer', 'OfferBook']:
                            delex_act[domain_act] = ['none']
                        elif act_type in ['Select']:
                            for sv in turn["dialog_act"][domain_act]:
                                if sv[0] != "none":
                                    delex_act[domain_act] = [sv[0]]
                                    break
                        else:
                            delex_act[domain_act] = [
                                sv[0] for sv in turn["dialog_act"][domain_act]
                            ]
                    state_vector = self.state_encoder.encode(self.dst.state)
                    action_index = self.find_best_delex_act(delex_act)

                    yield self.text_to_instance(state_vector, action_index)
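
For reference, the delexicalisation loop above keeps only slot names, collapses `NoOffer`/`OfferBook` acts to `['none']`, and keeps the first non-`none` slot for `Select`. A small illustration with a made-up system turn:

# Hypothetical system-turn annotation and the delex_act the loop above would build from it.
turn = {"dialog_act": {
    "Hotel-Inform": [["Price", "cheap"], ["Area", "north"]],
    "Hotel-NoOffer": [["none", "none"]],
}}
# resulting delex_act:
# {
#     "Hotel-Inform": ["Price", "Area"],   # slot names kept, values dropped
#     "Hotel-NoOffer": ["none"],           # NoOffer / OfferBook collapse to ['none']
# }
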
Example #11
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
                 model_file=None):
        SysPolicy.__init__(self)

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for LaRL is specified!")
            archive_file = cached_path(model_file)

        temp_path = tempfile.mkdtemp()
        zip_ref = zipfile.ZipFile(archive_file, 'r')
        zip_ref.extractall(temp_path)
        zip_ref.close()

        self.prev_state = init_state()
        self.prev_active_domain = None

        domain_name = 'object_division'
        domain_info = domain.get_domain(domain_name)

        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
        train_data_path = os.path.join(data_path, 'norm-multi-woz', 'train_dials.json')
        if not os.path.exists(train_data_path):
            zipped_file = os.path.join(data_path, 'norm-multi-woz.zip')
            archive = zipfile.ZipFile(zipped_file, 'r')
            archive.extractall(data_path)

        norm_multiwoz_path = os.path.join(data_path, 'norm-multi-woz')
        with open(os.path.join(norm_multiwoz_path, 'input_lang.index2word.json')) as f:
            self.input_lang_index2word = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'input_lang.word2index.json')) as f:
            self.input_lang_word2index = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'output_lang.index2word.json')) as f:
            self.output_lang_index2word = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'output_lang.word2index.json')) as f:
            self.output_lang_word2index = json.load(f)

        config = Pack(
            seed=10,
            train_path=train_data_path,
            max_vocab_size=1000,
            last_n_model=5,
            max_utt_len=50,
            max_dec_len=50,
            backward_size=2,
            batch_size=1,
            use_gpu=True,
            op='adam',
            init_lr=0.001,
            l2_norm=1e-05,
            momentum=0.0,
            grad_clip=5.0,
            dropout=0.5,
            max_epoch=100,
            embed_size=100,
            num_layers=1,
            utt_rnn_cell='gru',
            utt_cell_size=300,
            bi_utt_cell=True,
            enc_use_attn=True,
            dec_use_attn=True,
            dec_rnn_cell='lstm',
            dec_cell_size=300,
            dec_attn_mode='cat',
            y_size=10,
            k_size=20,
            beta=0.001,
            simple_posterior=True,
            contextual_posterior=True,
            use_mi=False,
            use_pr=True,
            use_diversity=False,
            #
            beam_size=20,
            fix_batch=True,
            fix_train_batch=False,
            avg_type='word',
            print_step=300,
            ckpt_step=1416,
            improve_threshold=0.996,
            patient_increase=2.0,
            save_model=True,
            early_stop=False,
            gen_type='greedy',
            preview_batch_num=None,
            k=domain_info.input_length(),
            init_range=0.1,
            pretrain_folder='2019-09-20-21-43-06-sl_cat',
            forward_only=False
        )

        config.use_gpu = config.use_gpu and torch.cuda.is_available()
        self.corpus = corpora_inference.NormMultiWozCorpus(config)
        self.model = SysPerfectBD2Cat(self.corpus, config)
        self.config = config
        if config.use_gpu:
            self.model.load_state_dict(torch.load(
                os.path.join(temp_path, 'larl_model/best-model')))
            self.model.cuda()
        else:
            self.model.load_state_dict(torch.load(os.path.join(
                temp_path, 'larl_model/best-model'), map_location=lambda storage, loc: storage))
        self.model.eval()
        self.dic = pickle.load(
            open(os.path.join(temp_path, 'larl_model/svdic.pkl'), 'rb'))
Example #12
    def _read(self, file_path):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        if file_path.endswith("zip"):
            archive = zipfile.ZipFile(file_path, "r")
            data_file = archive.open(os.path.basename(file_path)[:-4])
        else:
            data_file = open(file_path, "r")

        logger.info("Reading instances from lines in file at: %s", file_path)

        dialogs = json.load(data_file)

        for dial_name in dialogs:
            dialog = dialogs[dial_name]["log"]
            context_tokens_list = []
            for i, turn in enumerate(dialog):
                if self._agent and self._agent == "user" and i % 2 != 1:
                    continue
                if self._agent and self._agent == "system" and i % 2 != 0:
                    continue

                tokens = turn["text"].split()

                dialog_act = {}
                for dacts in turn["span_info"]:
                    if dacts[0] not in dialog_act:
                        dialog_act[dacts[0]] = []
                    dialog_act[dacts[0]].append(
                        [dacts[1], " ".join(tokens[dacts[3]:dacts[4] + 1])])

                spans = turn["span_info"]
                tags = []
                for j in range(len(tokens)):
                    for span in spans:
                        if j == span[3]:
                            tags.append("B-" + span[0] + "+" + span[1])
                            break
                        if j > span[3] and j <= span[4]:
                            tags.append("I-" + span[0] + "+" + span[1])
                            break
                    else:
                        tags.append("O")

                intents = []
                for dacts in turn["dialog_act"]:
                    for dact in turn["dialog_act"][dacts]:
                        if dacts not in dialog_act or dact[0] not in [
                                sv[0] for sv in dialog_act[dacts]
                        ]:
                            if dact[1] in [
                                    "none", "?", "yes", "no", "do nt care",
                                    "do n't care"
                            ]:
                                intents.append(dacts + "+" + dact[0] + "*" +
                                               dact[1])

                for dacts in turn["dialog_act"]:
                    for dact in turn["dialog_act"][dacts]:
                        if dacts not in dialog_act:
                            dialog_act[dacts] = turn["dialog_act"][dacts]
                            break
                        elif dact[0] not in [
                                sv[0] for sv in dialog_act[dacts]
                        ]:
                            dialog_act[dacts].append(dact)

                num_context = random.randint(
                    0, self._context_size
                ) if self._random_context_size else self._context_size
                if len(context_tokens_list) > 0 and num_context > 0:
                    wrapped_context_tokens = [
                        Token(token) for context_tokens in
                        context_tokens_list[-num_context:]
                        for token in context_tokens
                    ]
                else:
                    wrapped_context_tokens = [Token("SENT_END")]
                wrapped_tokens = [Token(token) for token in tokens]
                context_tokens_list.append(tokens + ["SENT_END"])

                yield self.text_to_instance(wrapped_context_tokens,
                                            wrapped_tokens, tags, intents,
                                            dialog_act)
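
The context window at the end of this reader flattens the previous `num_context` turns (each stored with a trailing `SENT_END`) into one token list. A small worked example, assuming a fixed `context_size` of 2:

# Hypothetical contents of context_tokens_list after two earlier turns.
context_tokens_list = [["hello", "SENT_END"],
                       ["any", "cheap", "food", "?", "SENT_END"]]
num_context = 2
flat = [token for context_tokens in context_tokens_list[-num_context:]
        for token in context_tokens]
# flat == ['hello', 'SENT_END', 'any', 'cheap', 'food', '?', 'SENT_END']
# Each element is then wrapped as Token(token) before text_to_instance is called.
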
Example #13
    def _read(self, file_path):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        if file_path.endswith("zip"):
            archive = zipfile.ZipFile(file_path, "r")
            data_file = archive.open(os.path.basename(file_path)[:-4])
        else:
            data_file = open(file_path, "r")

        logger.info("Reading instances from lines in file at: %s", file_path)

        dialogs = json.load(data_file)

        for dial_name in dialogs:
            dialog = dialogs[dial_name]["log"]
            for turn in dialog:
                tokens = turn["text"].split()
                spans = turn["span_info"]
                tags = []
                domain = "None"
                intent = "None"
                for i in range(len(tokens)):
                    for span in spans:
                        if i == span[3]:
                            new_domain, new_intent = span[0].split("-", 1)
                            if domain == "None":
                                domain = new_domain
                            elif domain != new_domain:
                                continue
                            if intent == "None":
                                intent = new_intent
                            elif intent != new_intent:
                                continue
                            tags.append("B-" + span[1])
                            break
                        if i > span[3] and i <= span[4]:
                            new_domain, new_intent = span[0].split("-", 1)
                            if domain != new_domain:
                                continue
                            if intent != new_intent:
                                continue
                            tags.append("I-" + span[1])
                            break
                    else:
                        tags.append("O")

                if domain != "None":
                    assert intent != "None", "intent must not be None when domain is not None"
                elif turn["dialog_act"] != {}:
                    assert intent == "None", "intent must be None when domain is None"
                    di = list(turn["dialog_act"].keys())[0]
                    dai = turn["dialog_act"][di][0]
                    domain = di.split("-")[0]
                    intent = di.split("-", 1)[-1] + "+" + dai[0] + "*" + dai[1]

                dialog_act = {}
                for dacts in turn["span_info"]:
                    if dacts[0] not in dialog_act:
                        dialog_act[dacts[0]] = []
                    dialog_act[dacts[0]].append(
                        [dacts[1], " ".join(tokens[dacts[3]:dacts[4] + 1])])

                for dacts in turn["dialog_act"]:
                    for dact in turn["dialog_act"][dacts]:
                        if dacts not in dialog_act:
                            dialog_act[dacts] = turn["dialog_act"][dacts]
                            break
                        elif dact[0] not in [
                                sv[0] for sv in dialog_act[dacts]
                        ]:
                            dialog_act[dacts].append(dact)

                tokens = [Token(token) for token in tokens]

                yield self.text_to_instance(tokens, tags, domain, intent,
                                            dialog_act)
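
A hedged sketch of the annotation layout this reader expects, inferred from the indexing above: each `span_info` entry is `[domain-act, slot, value, start_token, end_token]`, and `dialog_act` maps a `Domain-Act` string to `[slot, value]` pairs. The made-up turn below would yield the tags shown in the trailing comment:

# Hypothetical MultiWOZ-style user turn shaped to match the indexing used above.
turn = {
    "text": "i want a cheap restaurant in the centre",
    "span_info": [
        ["Restaurant-Inform", "Price", "cheap", 3, 3],
        ["Restaurant-Inform", "Area", "centre", 7, 7],
    ],
    "dialog_act": {
        "Restaurant-Inform": [["Price", "cheap"], ["Area", "centre"]],
    },
}
# Token-level tags produced by the loop above:
# ['O', 'O', 'O', 'B-Price', 'O', 'O', 'O', 'B-Area']
# with domain == 'Restaurant' and intent == 'Inform'.
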
Example #14
File: test.py Project: smksyj/ConvLab
                                 'rb'))
    for key in data:
        print('{} set size: {}'.format(key, len(data[key])))
    print('intent num:', len(intent_vocab))
    print('tag num:', len(tag_vocab))

    dataloader = Dataloader(data, intent_vocab, tag_vocab,
                            config['model']["pre-trained"])

    best_model_path = os.path.join(output_dir, 'bestcheckpoint.tar')
    if not os.path.exists(best_model_path):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        print('Load from zipped_model_path param')
        archive_file = cached_path(config['zipped_model_path'])
        archive = zipfile.ZipFile(archive_file, 'r')
        archive.extractall()
        archive.close()
    print('Load from', best_model_path)
    checkpoint = torch.load(best_model_path, map_location=DEVICE)
    print('best_intent_step', checkpoint['best_intent_step'])
    print('best_tag_step', checkpoint['best_tag_step'])

    model = BertNLU(config['model'],
                    dataloader.intent_dim,
                    dataloader.tag_dim,
                    DEVICE=DEVICE,
                    intent_weight=dataloader.intent_weight)
    model_dict = model.state_dict()
    state_dict = {
Example #15
File: evaluate.py Project: smksyj/ConvLab
    tag_vocab = pickle.load(open(os.path.join(data_dir, 'tag_vocab.pkl'),
                                 'rb'))
    for key in data:
        print('{} set size: {}'.format(key, len(data[key])))
    print('intent num:', len(intent_vocab))
    print('tag num:', len(tag_vocab))

    dataloader = Dataloader(data, intent_vocab, tag_vocab,
                            config['model']["pre-trained"])

    best_model_path = os.path.join(output_dir, 'bestcheckpoint.tar')
    if not os.path.exists(best_model_path):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        print('Load from zipped_model_path param')
        archive_file = cached_path(
            os.path.join(root_dir, config['zipped_model_path']))
        archive = zipfile.ZipFile(archive_file, 'r')
        archive.extractall(root_dir)
        archive.close()
    print('Load from', best_model_path)
    checkpoint = torch.load(best_model_path, map_location=DEVICE)
    print('best_intent_step', checkpoint['best_intent_step'])
    print('best_tag_step', checkpoint['best_tag_step'])

    model = BertNLU(config['model'],
                    dataloader.intent_dim,
                    dataloader.tag_dim,
                    DEVICE=DEVICE,
                    intent_weight=dataloader.intent_weight)
    model_dict = model.state_dict()
    state_dict = {
Example #16
File: nlu.py Project: smksyj/ConvLab
    def __init__(self, mode, model_file):
        """
        BERT NLU initialization.

        Args:
            mode (str):
                can be either `'usr'`, `'sys'` or `'all'`, representing which side of the data the model was trained on.

            model_file (str):
                model path or URL

        Example:
            nlu = BERTNLU(mode='usr', model_file='https://convlab.blob.core.windows.net/models/bert_multiwoz_usr.zip')
        """
        assert mode == 'usr' or mode == 'sys' or mode == 'all'
        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                   'configs/multiwoz_{}.json'.format(mode))
        config = json.load(open(config_file))
        DEVICE = config['DEVICE']
        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        data_dir = os.path.join(root_dir, config['data_dir'])
        output_dir = os.path.join(root_dir, config['output_dir'])

        if not os.path.exists(os.path.join(data_dir, 'data.pkl')):
            preprocess(mode)

        data = pickle.load(open(os.path.join(data_dir, 'data.pkl'), 'rb'))
        intent_vocab = pickle.load(
            open(os.path.join(data_dir, 'intent_vocab.pkl'), 'rb'))
        tag_vocab = pickle.load(
            open(os.path.join(data_dir, 'tag_vocab.pkl'), 'rb'))

        dataloader = Dataloader(data, intent_vocab, tag_vocab,
                                config['model']["pre-trained"])

        best_model_path = os.path.join(output_dir, 'bestcheckpoint.tar')
        if not os.path.exists(best_model_path):
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            print('Load from model_file param')
            archive_file = cached_path(model_file)
            archive = zipfile.ZipFile(archive_file, 'r')
            archive.extractall(root_dir)
            archive.close()
        print('Load from', best_model_path)
        checkpoint = torch.load(best_model_path, map_location=DEVICE)
        print('train step', checkpoint['step'])

        model = BertNLU(config['model'],
                        dataloader.intent_dim,
                        dataloader.tag_dim,
                        DEVICE=DEVICE,
                        intent_weight=dataloader.intent_weight)
        model_dict = model.state_dict()
        state_dict = {
            k: v
            for k, v in checkpoint['model_state_dict'].items()
            if k in model_dict.keys()
        }
        model_dict.update(state_dict)
        model.load_state_dict(model_dict)
        model.to(DEVICE)
        model.eval()

        self.model = model
        self.dataloader = dataloader
        print("BERTNLU loaded")