Exemplo n.º 1
0
    def __init__(self, config):
        """Set up the VHUS model, the data splits, the optimizer and the losses.

        Args:
            config: training configuration mapping. It is forwarded to
                ``VHUS`` and the keys 'print_per_batch', 'save_dir',
                'save_per_epoch', 'batchsz', 'alpha' and 'lr' are read here.
        """
        data_mgr = UserDataManager()
        n_goal, n_usr, n_sys = data_mgr.get_voc_size()
        self.user = VHUS(config, n_goal, n_usr, n_sys).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_mgr

        # Reporting / checkpointing settings.
        self.print_per_batch = config['print_per_batch']
        self.save_dir = config['save_dir']
        self.save_per_epoch = config['save_per_epoch']

        # Load the segmented sequences and split them into train/test/valid.
        goals, usr_das, sys_das = data_mgr.data_loader_seg()
        (goals_tr, usr_tr, sys_tr,
         goals_te, usr_te, sys_te,
         goals_va, usr_va, sys_va) = data_mgr.train_test_val_split_seg(
             goals, usr_das, sys_das)

        # Each split is stored together with the configured batch size.
        self.data_train = (goals_tr, usr_tr, sys_tr, config['batchsz'])
        self.data_valid = (goals_va, usr_va, sys_va, config['batchsz'])
        self.data_test = (goals_te, usr_te, sys_te, config['batchsz'])

        self.alpha = config['alpha']
        self.optim = torch.optim.Adam(self.user.parameters(), lr=config['lr'])
        # PAD has index 0, so padded positions are excluded from the NLL loss.
        self.nll_loss = nn.NLLLoss(ignore_index=0)
        self.bce_loss = nn.BCEWithLogitsLoss()
Exemplo n.º 2
0
    def __init__(
        self,
        archive_file=DEFAULT_ARCHIVE_FILE,
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_camrest.zip'
    ):
        """Create the VHUS user simulator in evaluation mode and load its weights.

        The model archive is extracted into a ``save`` directory next to this
        module unless ``best_simulator.mdl`` is already present there.

        Args:
            archive_file: local path to the zipped model. If it is not an
                existing file, ``model_file`` is resolved via ``cached_path``.
            model_file: fallback location of the zipped model.

        Raises:
            Exception: if ``archive_file`` does not exist and ``model_file``
                is empty.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(module_dir, 'config.json'), 'r') as f:
            config = json.load(f)
        manager = UserDataManager()
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = manager
        # Switch the model to evaluation mode; this class only runs inference.
        self.user.eval()

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for VHUS Policy is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.join(module_dir, 'save')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(model_dir, exist_ok=True)
        if not os.path.exists(os.path.join(model_dir, 'best_simulator.mdl')):
            # The context manager closes the zip handle even if extraction fails.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)
        self.load(config['load'])
Exemplo n.º 3
0
Arquivo: vhus.py Projeto: zqwerty/tatk
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_camrest.zip'):
        """Build the VHUS user simulator in evaluation mode and load its weights.

        Args:
            archive_file: local model archive, forwarded to ``self.load``.
            model_file: fallback model location, forwarded to ``self.load``.
        """
        here = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(here, 'config.json')
        with open(config_path, 'r') as cfg_file:
            config = json.load(cfg_file)

        data_mgr = UserDataManager()
        goal_n, usr_n, sys_n = data_mgr.get_voc_size()
        self.user = VHUS(config, goal_n, usr_n, sys_n).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_mgr
        self.user.eval()

        # Archive resolution and weight loading are delegated to load().
        self.load(archive_file, model_file, config['load'])
Exemplo n.º 4
0
    def __init__(self):
        """Build the VHUS user simulator from the module's config.json.

        The model is moved to ``DEVICE`` and put in evaluation mode. Its
        weights are then loaded via ``self.load`` with the config's 'load'
        entry.
        """
        cfg_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                'config.json')
        with open(cfg_path, 'r') as handle:
            config = json.load(handle)

        voc_mgr = UserDataManager()
        goal_n, usr_n, sys_n = voc_mgr.get_voc_size()
        self.user = VHUS(config, goal_n, usr_n, sys_n).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = voc_mgr
        self.user.eval()

        self.load(config['load'])
Exemplo n.º 5
0
    def __init__(self):
        """Initialise the CamRest agenda-based user policy.

        Sets the turn and initiative limits, builds a goal generator over the
        CamRest676 corpus, clears the per-dialogue fields and then runs
        ``Policy.__init__``.
        """
        self.max_turn = 40
        self.max_initiative = 4

        # Climb four directory levels from this file; the corpus path below
        # is resolved relative to that directory.
        base = os.path.abspath(__file__)
        for _ in range(4):
            base = os.path.dirname(base)
        corpus_file = os.path.join(base, 'data/camrest/CamRest676_v2.json')
        self.goal_generator = GoalGenerator(corpus_path=corpus_file)

        # Per-dialogue fields start out empty.
        self.__turn = 0
        self.goal = None
        self.agenda = None

        Policy.__init__(self)
Exemplo n.º 6
0
    def __init__(self, max_goal_num=100, seed=2019):
        """Initialise the CamRest agenda-based user policy with reproducible goals.

        Args:
            max_goal_num: number of per-goal random seeds to pre-generate.
            seed: value used to seed the module-level ``random`` generator
                before the per-goal seeds are drawn.
        """
        self.max_turn = 40
        self.max_initiative = 4

        self.goal_generator = GoalGenerator(corpus_path=os.path.join(
            os.path.dirname(
                os.path.dirname(
                    os.path.dirname(os.path.dirname(os.path.abspath(
                        __file__))))), 'data/camrest/CamRest676_v2.json'))

        self.__turn = 0
        self.goal = None
        self.agenda = None

        # NOTE: this seeds the process-wide ``random`` module, so it also
        # affects any other code that draws from ``random`` afterwards.
        random.seed(seed)
        # The upper bound must be an int. ``randint(1, 1e7)`` raises TypeError
        # on Python >= 3.12 (float bounds were deprecated in 3.10). The
        # integer bound yields the same sequence on older versions.
        self.goal_seeds = [random.randint(1, 10 ** 7) for _ in range(max_goal_num)]

        Policy.__init__(self)
Exemplo n.º 7
0
"""Train the VHUS user simulator on CamRest via imitation learning."""
import json
import logging
import os
import sys

# Make the repository root (five directory levels above this file) importable
# so the ``tatk`` imports below resolve when the script is run directly.
root_dir = os.path.dirname(
    os.path.dirname(
        os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
sys.path.append(root_dir)

# These imports must follow the sys.path tweak above.
from tatk.util.train_util import init_logging_handler  # noqa: E402
from tatk.task.camrest.goal_generator import GoalGenerator  # noqa: E402
from tatk.policy.vhus.camrest.usermanager import UserDataManager  # noqa: E402
from tatk.policy.vhus.train import VHUS_Trainer  # noqa: E402

if __name__ == '__main__':
    # Resolve config.json next to this script (as the simulator classes do),
    # so the script also works when launched from another directory.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               'config.json')
    with open(config_path, 'r') as f:
        cfg = json.load(f)

    init_logging_handler(cfg['log_dir'])
    manager = UserDataManager()
    goal_gen = GoalGenerator()
    env = VHUS_Trainer(cfg, manager, goal_gen)

    logging.debug('start training')

    # imit_test receives the best score so far and returns the updated one.
    best = float('inf')
    for e in range(cfg['epoch']):
        env.imitating(e)
        best = env.imit_test(e, best)