예제 #1
0
 def __init__(self,
              dictionary,
              corpus,
              np2var,
              bad_tokens=None):
     """Initialise the criterion and store its collaborators.

     Args:
         dictionary: stored unchanged on ``self.dictionary``.
         corpus: stored unchanged on ``self.corpus``.
         np2var: stored unchanged on ``self.np2var``; presumably a
             numpy-to-tensor converter — confirm against callers.
         bad_tokens: token strings stored on ``self.bad_tokens``. When
             omitted (or ``None``) it defaults to
             ``['<disconnect>', '<disagree>']``. A new list is built on every
             call so instances never share, and cannot mutate, one common
             default list.
     """
     super(CombinedNLLEntropy4CLF, self).__init__()
     self.dictionary = dictionary
     # The domain is fixed to 'object_division' regardless of the arguments.
     self.domain = domain.get_domain('object_division')
     self.corpus = corpus
     self.np2var = np2var
     if bad_tokens is None:
         bad_tokens = ['<disconnect>', '<disagree>']
     self.bad_tokens = bad_tokens
예제 #2
0
import json
import torch as th
import logging
import sys
sys.path.append("../")
from convlab2.policy.larl.multiwoz.latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
import convlab2.policy.larl.multiwoz.latent_dialog.corpora as corpora
from convlab2.policy.larl.multiwoz.latent_dialog.data_loaders import BeliefDbDataLoaders
from convlab2.policy.larl.multiwoz.latent_dialog.evaluators import MultiWozEvaluator
from convlab2.policy.larl.multiwoz.latent_dialog.models_task import SysPerfectBD2Cat
from convlab2.policy.larl.multiwoz.latent_dialog.main import train, validate
import convlab2.policy.larl.multiwoz.latent_dialog.domain as domain
from convlab2.policy.larl.multiwoz.experiments_woz.dialog_utils import task_generate

# Look up the domain specification by name.
# NOTE(review): the name is 'object_division' even though the data paths in
# the config point at MultiWOZ files; presumably only helpers such as
# domain_info.input_length() are needed from it — confirm before changing.
domain_name = 'object_division'
domain_info = domain.get_domain(domain_name)
config = Pack(
    seed=10,
    train_path='../data/norm-multi-woz/train_dials.json',
    valid_path='../data/norm-multi-woz/val_dials.json',
    test_path='../data/norm-multi-woz/test_dials.json',
    max_vocab_size=1000,
    last_n_model=5,
    max_utt_len=50,
    max_dec_len=50,
    backward_size=2,
    batch_size=32,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=1e-05,
예제 #3
0
    def __init__(
            self,
            archive_file=DEFAULT_ARCHIVE_FILE,
            cuda_device=DEFAULT_CUDA_DEVICE,
            model_file="https://convlab.blob.core.windows.net/convlab-2/larl.zip"
    ):
        """Load the LaRL policy: unpack archives, read vocab maps and weights.

        Args:
            archive_file: local path to the model zip. If it is not an existing
                file, ``model_file`` is resolved through ``cached_path`` and
                used instead.
            cuda_device: accepted for interface compatibility; it is not read
                in this constructor. GPU use is decided by
                ``torch.cuda.is_available()``.
            model_file: location handed to ``cached_path`` when
                ``archive_file`` does not exist.

        Raises:
            Exception: if ``archive_file`` does not exist and ``model_file``
                is empty.
        """
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for LaRL is specified!")
            archive_file = cached_path(model_file)

        # The model archive is unpacked next to this source file. The weights
        # and pickled dictionary are read back from <temp_path>/larl_model.
        temp_path = os.path.dirname(os.path.abspath(__file__))
        # `with` closes the zip handle even if extraction raises.
        with zipfile.ZipFile(archive_file, 'r') as zip_ref:
            zip_ref.extractall(temp_path)

        self.prev_state = default_state()
        self.prev_active_domain = None

        domain_name = 'object_division'
        domain_info = domain.get_domain(domain_name)
        self.db = Database()
        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                 'data')
        train_data_path = os.path.join(data_path, 'train_dials.json')
        if not os.path.exists(train_data_path):
            # Unpack the bundled data archive on first use. The handle is now
            # closed; previously it was left open.
            zipped_file = os.path.join(data_path, 'norm-multi-woz.zip')
            with zipfile.ZipFile(zipped_file, 'r') as archive:
                archive.extractall(data_path)

        norm_multiwoz_path = data_path

        def _load_json(filename):
            # Read one JSON vocabulary map from the data directory.
            with open(os.path.join(norm_multiwoz_path, filename)) as f:
                return json.load(f)

        self.input_lang_index2word = _load_json('input_lang.index2word.json')
        self.input_lang_word2index = _load_json('input_lang.word2index.json')
        self.output_lang_index2word = _load_json('output_lang.index2word.json')
        self.output_lang_word2index = _load_json('output_lang.word2index.json')

        config = Pack(
            seed=10,
            train_path=train_data_path,
            max_vocab_size=1000,
            last_n_model=5,
            max_utt_len=50,
            max_dec_len=50,
            backward_size=2,
            batch_size=1,
            use_gpu=True,
            op='adam',
            init_lr=0.001,
            l2_norm=1e-05,
            momentum=0.0,
            grad_clip=5.0,
            dropout=0.5,
            max_epoch=100,
            embed_size=100,
            num_layers=1,
            utt_rnn_cell='gru',
            utt_cell_size=300,
            bi_utt_cell=True,
            enc_use_attn=True,
            dec_use_attn=True,
            dec_rnn_cell='lstm',
            dec_cell_size=300,
            dec_attn_mode='cat',
            y_size=10,
            k_size=20,
            beta=0.001,
            simple_posterior=True,
            contextual_posterior=True,
            use_mi=False,
            use_pr=True,
            use_diversity=False,
            beam_size=20,
            fix_batch=True,
            fix_train_batch=False,
            avg_type='word',
            print_step=300,
            ckpt_step=1416,
            improve_threshold=0.996,
            patient_increase=2.0,
            save_model=True,
            early_stop=False,
            gen_type='greedy',
            preview_batch_num=None,
            k=domain_info.input_length(),
            init_range=0.1,
            pretrain_folder='2019-09-20-21-43-06-sl_cat',
            forward_only=False)

        # Requesting the GPU only takes effect when CUDA is actually available.
        config.use_gpu = config.use_gpu and torch.cuda.is_available()
        self.corpus = corpora_inference.NormMultiWozCorpus(config)
        self.model = SysPerfectBD2Cat(self.corpus, config)
        self.config = config

        model_path = os.path.join(temp_path, 'larl_model/best-model')
        # NOTE(review): torch.load and pickle.load below unpickle files taken
        # from an archive that may have been downloaded from ``model_file``.
        # Unpickling can run arbitrary code, so only use trusted archives.
        if config.use_gpu:
            state_dict = torch.load(model_path)
        else:
            # Keep every loaded tensor on CPU storage.
            state_dict = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)
        self.model.load_state_dict(state_dict)
        if config.use_gpu:
            self.model.cuda()
        self.model.eval()
        # `with` closes the pickle file; the original left the handle open.
        with open(os.path.join(temp_path, 'larl_model/svdic.pkl'), 'rb') as f:
            self.dic = pickle.load(f)
예제 #4
0
 def __init__(self, model, device_id):
     self.model = model
     self.domain = domain.get_domain('object_division')
     self.device_id = device_id