예제 #1
0
    def __init__(
        self,
        archive_file=DEFAULT_ARCHIVE_FILE,
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/mle_policy_camrest.zip'
    ):
        """
        Build the Camrest policy network and load its saved weights.

        Args:
            archive_file (str): Path of a local zip archive with the model.
                When it is not an existing file, the archive is obtained
                through ``cached_path(model_file)`` instead.
            model_file (str): Location handed to ``cached_path`` when
                ``archive_file`` is not an existing file.
        Raises:
            Exception: If ``archive_file`` is not a file and ``model_file``
                is empty.
        """
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        # root_dir is four directory levels above this module's directory;
        # the data/ paths below are resolved against it.
        root_dir = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(cur_dir))))

        with open(os.path.join(cur_dir, 'config.json'), 'r') as f:
            cfg = json.load(f)

        voc_file = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
        voc_opp_file = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
        self.vector = CamrestVector(voc_file, voc_opp_file)

        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
                                          self.vector.da_dim).to(device=DEVICE)

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for MLE Policy is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.join(cur_dir, 'save')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(model_dir, exist_ok=True)
        if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')):
            # Context manager guarantees the archive handle is closed even
            # when extraction fails.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)
        self.load(cfg['load'])
예제 #2
0
파일: loader.py 프로젝트: zqwerty/tatk
class ActPolicyDataLoaderCamrest(ActPolicyDataloader):
    """
    Data loader producing vectorized (state, action) pairs from the Camrest
    dialogue data.

    On construction the processed data is loaded from the 'processed_data'
    directory next to this module when that directory exists; otherwise it
    is built from the raw zipped JSON files and written there.
    """

    def __init__(self):
        """
        Create the Camrest vectorizer, then load or build the processed data.
        """
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        # root_dir is four directory levels above this module's directory.
        root_dir = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(cur_dir))))
        voc_file = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
        voc_opp_file = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
        self.vector = CamrestVector(voc_file, voc_opp_file)

        processed_dir = os.path.join(cur_dir, 'processed_data')
        if os.path.exists(processed_dir):
            print('Load processed data file')
            self._load_data(processed_dir)
        else:
            print('Start preprocessing the dataset')
            self._build_data(root_dir, processed_dir)

    def _build_data(self, root_dir, processed_dir):  # TODO
        """
        Build ``self.data`` from the raw Camrest archives and pickle it.

        Args:
            root_dir (str): Directory containing 'data/camrest/<part>.json.zip'.
            processed_dir (str): Directory that receives one '<part>.pkl'
                file per data split; it is created if missing.
        """
        parts = ['train', 'val', 'test']

        raw_data = {}
        for part in parts:
            zip_path = os.path.join(root_dir,
                                    'data/camrest/{}.json.zip'.format(part))
            # Close each archive as soon as its JSON member has been parsed.
            with zipfile.ZipFile(zip_path, 'r') as archive:
                with archive.open('{}.json'.format(part), 'r') as f:
                    raw_data[part] = json.load(f)

        self.data = {}
        for part in parts:
            self.data[part] = []

            for dialog in raw_data[part]:
                sess = dialog['dial']
                # A fresh state per dialogue; it is updated turn by turn.
                state = default_state()
                last_turn = len(sess) - 1
                for i, turn in enumerate(sess):
                    state['user_action'] = turn['usr']['dialog_act']
                    if i == last_turn:
                        state['terminated'] = True
                    for da in turn['usr']['slu']:
                        first_slot = da['slots'][0]
                        # Entries whose first slot is named 'slot' are not
                        # written into the belief state.
                        if first_slot[0] != 'slot':
                            state['belief_state'][first_slot[0]] = first_slot[1]
                    action = turn['sys']['dialog_act']
                    self.data[part].append([
                        self.vector.state_vectorize(state),
                        self.vector.action_vectorize(action)
                    ])
                    # The system act becomes visible to the state only from
                    # the next turn on.
                    state['system_action'] = action

        os.makedirs(processed_dir, exist_ok=True)
        for part in parts:
            with open(os.path.join(processed_dir, '{}.pkl'.format(part)),
                      'wb') as f:
                pickle.dump(self.data[part], f)
예제 #3
0
파일: mle.py 프로젝트: luweishuang/tatk
class MLP(Policy):
    """
    Dialog policy for Camrest backed by a ``MultiDiscretePolicy`` network.
    """

    def __init__(self, is_train=False):
        """
        Build the vectorizer and the policy network, then try to load weights.

        Args:
            is_train (bool): Stored on the instance as ``self.is_train``.
        """
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        # root_dir is three directory levels above this module's directory.
        root_dir = os.path.dirname(os.path.dirname(os.path.dirname(cur_dir)))
        self.is_train = is_train

        with open(os.path.join(cur_dir, 'config.json'), 'r') as f:
            cfg = json.load(f)

        voc_file = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
        voc_opp_file = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
        self.vector = CamrestVector(voc_file, voc_opp_file)

        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
                                          self.vector.da_dim).to(device=DEVICE)

        self.load(cfg['load'])

    def predict(self, state):
        """
        Predict a system action given state.
        Args:
            state (dict): Dialog state. Please refer to util/state.py
        Returns:
            action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
        """
        s_vec = torch.Tensor(self.vector.state_vectorize(state))
        # The network runs on DEVICE; the result is moved back to the CPU so
        # it can be converted to a numpy array.
        a = self.policy.select_action(s_vec.to(device=DEVICE)).cpu()
        action = self.vector.action_devectorize(a.numpy())

        return action

    def init_session(self):
        """
        Restore after one session. Currently a no-op.
        """
        pass

    def load(self, filename):
        """
        Load policy weights from '<filename>_mle.pol.mdl' next to this module.

        Nothing is loaded (and nothing is raised) when the file is missing.

        Args:
            filename (str): Checkpoint path prefix, relative to this module's
                directory.
        """
        policy_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                  filename + '_mle.pol.mdl')
        if os.path.exists(policy_mdl):
            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be loaded when only DEVICE is available.
            self.policy.load_state_dict(
                torch.load(policy_mdl, map_location=DEVICE))
            print('<<dialog policy>> loaded checkpoint from file: {}'.format(
                policy_mdl))
예제 #4
0
파일: loader.py 프로젝트: zqwerty/tatk
    def __init__(self):
        """
        Create the Camrest vectorizer and make the processed data available.

        The data is loaded from the 'processed_data' directory next to this
        module when that directory exists; otherwise it is built first.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        # Walk four more levels up from the module directory.
        root_dir = module_dir
        for _ in range(4):
            root_dir = os.path.dirname(root_dir)

        self.vector = CamrestVector(
            os.path.join(root_dir, 'data/camrest/sys_da_voc.txt'),
            os.path.join(root_dir, 'data/camrest/usr_da_voc.txt'))

        processed_dir = os.path.join(module_dir, 'processed_data')
        if not os.path.exists(processed_dir):
            print('Start preprocessing the dataset')
            self._build_data(root_dir, processed_dir)
            return
        print('Load processed data file')
        self._load_data(processed_dir)
예제 #5
0
파일: mle.py 프로젝트: luweishuang/tatk
    def __init__(self, is_train=False):
        """
        Build the vectorizer and the policy network, then load weights.

        Args:
            is_train (bool): Stored on the instance as ``self.is_train``.
        """
        self.is_train = is_train

        module_dir = os.path.dirname(os.path.abspath(__file__))
        # Walk three more levels up from the module directory.
        root_dir = module_dir
        for _ in range(3):
            root_dir = os.path.dirname(root_dir)

        config_path = os.path.join(module_dir, 'config.json')
        with open(config_path, 'r') as config_file:
            cfg = json.load(config_file)

        self.vector = CamrestVector(
            os.path.join(root_dir, 'data/camrest/sys_da_voc.txt'),
            os.path.join(root_dir, 'data/camrest/usr_da_voc.txt'))

        network = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
                                      self.vector.da_dim)
        self.policy = network.to(device=DEVICE)

        self.load(cfg['load'])
예제 #6
0
파일: train.py 프로젝트: zqwerty/tatk
 def __init__(self, manager, cfg):
     """
     Initialise data via ``_init_data`` and create the policy and optimizer.

     Args:
         manager: Passed through to ``self._init_data``.
         cfg (dict): Configuration; 'h_dim' and 'lr' are read here.
     """
     self._init_data(manager, cfg)

     # root_dir is not defined in this method; it comes from the enclosing
     # scope.
     vectorizer = CamrestVector(
         os.path.join(root_dir, 'data/camrest/sys_da_voc.txt'),
         os.path.join(root_dir, 'data/camrest/usr_da_voc.txt'))

     network = MultiDiscretePolicy(vectorizer.state_dim, cfg['h_dim'],
                                   vectorizer.da_dim)
     self.policy = network.to(device=DEVICE)
     self.policy.eval()

     trainable = self.policy.parameters()
     self.policy_optim = torch.optim.Adam(trainable, lr=cfg['lr'])
예제 #7
0
파일: mle.py 프로젝트: zqwerty/tatk
 def __init__(self,
              archive_file=DEFAULT_ARCHIVE_FILE,
              model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/mle_policy_camrest.zip'):
     """
     Build the vectorizer and the policy network, then delegate to ``load``.

     Args:
         archive_file: Forwarded unchanged to ``self.load``.
         model_file: Forwarded unchanged to ``self.load``.
     """
     module_dir = os.path.dirname(os.path.abspath(__file__))
     # Walk four more levels up from the module directory.
     root_dir = module_dir
     for _ in range(4):
         root_dir = os.path.dirname(root_dir)

     config_path = os.path.join(module_dir, 'config.json')
     with open(config_path, 'r') as config_file:
         cfg = json.load(config_file)

     self.vector = CamrestVector(
         os.path.join(root_dir, 'data/camrest/sys_da_voc.txt'),
         os.path.join(root_dir, 'data/camrest/usr_da_voc.txt'))

     network = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
                                   self.vector.da_dim)
     self.policy = network.to(device=DEVICE)

     self.load(archive_file, model_file, cfg['load'])
예제 #8
0
파일: train.py 프로젝트: luweishuang/tatk
    def __init__(self, manager, cfg):
        """
        Create the datasets, the policy network, its optimizer and the loss.

        Args:
            manager: Object whose ``create_dataset(part, batch_size)`` is
                called for the 'train', 'val' and 'test' parts.
            cfg (dict): Configuration; 'batchsz', 'save_dir',
                'print_per_batch', 'save_per_epoch', 'h_dim' and 'lr' are
                read here.
        """
        batch_size = cfg['batchsz']
        self.data_train = manager.create_dataset('train', batch_size)
        self.data_valid = manager.create_dataset('val', batch_size)
        self.data_test = manager.create_dataset('test', batch_size)

        self.save_dir = cfg['save_dir']
        self.print_per_batch = cfg['print_per_batch']
        self.save_per_epoch = cfg['save_per_epoch']

        # root_dir is not defined in this method; it comes from the
        # enclosing scope.
        vectorizer = CamrestVector(
            os.path.join(root_dir, 'data/camrest/sys_da_voc.txt'),
            os.path.join(root_dir, 'data/camrest/usr_da_voc.txt'))

        network = MultiDiscretePolicy(vectorizer.state_dim, cfg['h_dim'],
                                      vectorizer.da_dim)
        self.policy = network.to(device=DEVICE)
        self.policy.eval()

        trainable = self.policy.parameters()
        self.policy_optim = torch.optim.Adam(trainable, lr=cfg['lr'])
        self.multi_entropy_loss = nn.MultiLabelSoftMarginLoss()
예제 #9
0
class MLE(Policy):
    """
    MLE dialog policy for Camrest backed by a ``MultiDiscretePolicy`` network.
    """

    def __init__(
        self,
        archive_file=DEFAULT_ARCHIVE_FILE,
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/mle_policy_camrest.zip'
    ):
        """
        Build the policy network, unpack the model archive and load weights.

        Args:
            archive_file (str): Path of a local zip archive with the model.
                When it is not an existing file, the archive is obtained
                through ``cached_path(model_file)`` instead.
            model_file (str): Location handed to ``cached_path`` when
                ``archive_file`` is not an existing file.
        Raises:
            Exception: If ``archive_file`` is not a file and ``model_file``
                is empty.
        """
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        # root_dir is four directory levels above this module's directory;
        # the data/ paths below are resolved against it.
        root_dir = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(cur_dir))))

        with open(os.path.join(cur_dir, 'config.json'), 'r') as f:
            cfg = json.load(f)

        voc_file = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
        voc_opp_file = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
        self.vector = CamrestVector(voc_file, voc_opp_file)

        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
                                          self.vector.da_dim).to(device=DEVICE)

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for MLE Policy is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.join(cur_dir, 'save')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(model_dir, exist_ok=True)
        if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')):
            # Context manager guarantees the archive handle is closed even
            # when extraction fails.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)
        self.load(cfg['load'])

    def predict(self, state):
        """
        Predict a system action given state.
        Args:
            state (dict): Dialog state. Please refer to util/state.py
        Returns:
            action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
        """
        s_vec = torch.Tensor(self.vector.state_vectorize(state))
        # The network runs on DEVICE; the result is moved back to the CPU so
        # it can be converted to a numpy array.
        a = self.policy.select_action(s_vec.to(device=DEVICE)).cpu()
        action = self.vector.action_devectorize(a.numpy())

        return action

    def init_session(self):
        """
        Restore after one session. Currently a no-op.
        """
        pass

    def load(self, filename):
        """
        Load policy weights from '<filename>_mle.pol.mdl' next to this module.

        Nothing is loaded (and nothing is raised) when the file is missing.

        Args:
            filename (str): Checkpoint path prefix, relative to this module's
                directory.
        """
        policy_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                  filename + '_mle.pol.mdl')
        if os.path.exists(policy_mdl):
            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be loaded when only DEVICE is available.
            self.policy.load_state_dict(
                torch.load(policy_mdl, map_location=DEVICE))
            print('<<dialog policy>> loaded checkpoint from file: {}'.format(
                policy_mdl))