Example 1
File: ppo.py Project: zqwerty/tatk
    def __init__(self, is_train=False, dataset='Multiwoz'):

        # load hyperparameters from the config.json shipped alongside this module
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
            cfg = json.load(f)
        self.save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['save_dir'])
        self.save_per_epoch = cfg['save_per_epoch']
        self.update_round = cfg['update_round']
        self.optim_batchsz = cfg['batchsz']
        self.gamma = cfg['gamma']
        self.epsilon = cfg['epsilon']
        self.tau = cfg['tau']
        self.is_train = is_train
        if is_train:
            init_logging_handler(os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['log_dir']))

        # construct policy and value network
        if dataset == 'Multiwoz':
            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
            voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
            self.vector = MultiWozVector(voc_file, voc_opp_file)
            self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)

        self.value = Value(self.vector.state_dim, cfg['hv_dim']).to(device=DEVICE)
        if is_train:
            # separate optimizers: RMSprop for the policy, Adam for the value network
            self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr'])
            self.value_optim = optim.Adam(self.value.parameters(), lr=cfg['lr'])
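The constructor above reads every hyperparameter from a config.json placed next to ppo.py. Based purely on the cfg[...] lookups in the excerpt, the file would contain roughly the keys sketched below, shown here as the Python dict that json.load(f) would return; the values are illustrative placeholders, not the ones shipped with tatk.

# Sketch of the expected config.json contents. Key names come from the
# excerpt above; all values and the per-key interpretations are assumptions.
cfg = {
    "save_dir": "save",        # where checkpoints are written
    "log_dir": "log",          # where init_logging_handler writes logs
    "save_per_epoch": 5,
    "update_round": 5,
    "batchsz": 32,
    "gamma": 0.99,             # discount factor
    "epsilon": 0.2,            # PPO clipping range (assumed)
    "tau": 0.95,               # GAE parameter (assumed)
    "lr": 1e-4,                # learning rate for both optimizers
    "h_dim": 100,              # policy hidden size
    "hv_dim": 50,              # value-network hidden size
}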
Example 2
File: pg.py Project: zyq0104/tatk
    def __init__(self, is_train=False):
        with open(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'config.json'), 'r') as f:
            cfg = json.load(f)
        self.save_dir = cfg['save_dir']
        self.save_per_epoch = cfg['save_per_epoch']
        self.update_round = cfg['update_round']
        self.optim_batchsz = cfg['batchsz']
        self.gamma = cfg['gamma']
        self.is_train = is_train
        if is_train:
            init_logging_handler(cfg['log_dir'])

        # self.vector (a state vectoriser such as the MultiWozVector in Example 1)
        # is assumed to be constructed before this point; the excerpt does not show it
        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
                                          self.vector.da_dim).to(device=DEVICE)
        if is_train:
            self.policy_optim = optim.RMSprop(self.policy.parameters(),
                                              lr=cfg['lr'])
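To make the pieces wired up above concrete, here is a minimal, self-contained sketch of the kind of policy-gradient update this constructor prepares (RMSprop optimizer, gamma-discounted returns). StandInPolicy is a hypothetical stand-in for MultiDiscretePolicy, which the excerpt does not define, and the block is a generic REINFORCE update under those assumptions, not the repository's actual implementation.

# Generic REINFORCE sketch; StandInPolicy and all dimensions are placeholders.
import torch
import torch.nn as nn
import torch.optim as optim

class StandInPolicy(nn.Module):
    def __init__(self, s_dim=10, h_dim=32, a_dim=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(s_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, a_dim))

    def forward(self, s):
        # log-probabilities over a small discrete action set
        return torch.log_softmax(self.net(s), dim=-1)

policy = StandInPolicy()
policy_optim = optim.RMSprop(policy.parameters(), lr=1e-4)
gamma = 0.99

# fake episode: states, chosen actions, rewards
states = torch.randn(5, 10)
actions = torch.randint(0, 4, (5,))
rewards = torch.rand(5)

# discounted returns G_t = r_t + gamma * G_{t+1}
returns = torch.zeros_like(rewards)
running = 0.0
for t in reversed(range(len(rewards))):
    running = rewards[t] + gamma * running
    returns[t] = running

# REINFORCE loss: -mean_t G_t * log pi(a_t | s_t)
log_probs = policy(states).gather(1, actions.unsqueeze(1)).squeeze(1)
loss = -(returns * log_probs).mean()

policy_optim.zero_grad()
loss.backward()
policy_optim.step()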
Example 3
        # precision, recall and F1 from accumulated true positives (a_TP),
        # false positives (a_FP) and false negatives (a_FN)
        prec = a_TP / (a_TP + a_FP)
        rec = a_TP / (a_TP + a_FN)
        F1 = 2 * prec * rec / (prec + rec)
        print(a_TP, a_FP, a_FN, F1)
        # accuracy: correct predictions (t_corr) over total (t_tot)
        print(t_corr, t_tot, t_corr / t_tot)

    def save(self, directory, epoch):
        if not os.path.exists(directory):
            os.makedirs(directory)

        torch.save(self.user.state_dict(),
                   directory + '/' + str(epoch) + '_simulator.mdl')
        logging.info(
            '<<user simulator>> epoch {}: saved network to mdl'.format(epoch))


if __name__ == '__main__':
    with open('config.json', 'r') as f:
        cfg = json.load(f)

    init_logging_handler(cfg['log_dir'])
    env = VHUS_Trainer(cfg)

    logging.debug('start training')

    best = float('inf')
    for e in range(cfg['epoch']):
        env.imitating(e)               # one epoch of supervised (imitation) training
        best = env.imit_test(e, best)  # evaluate and update the best (lowest) score so far
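The save() method in Example 3 writes the user-simulator weights to directory + '/' + str(epoch) + '_simulator.mdl'. Below is a self-contained sketch of the same save/reload round trip; TinySimulator is a hypothetical stand-in, since the real network class behind self.user is not shown in the excerpt.

# Save/reload round trip matching the filename pattern used by save() above.
import os
import torch
import torch.nn as nn

class TinySimulator(nn.Module):        # placeholder for the real simulator network
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

directory, epoch = 'save', 9
os.makedirs(directory, exist_ok=True)

# save exactly as Example 3 does: directory + '/' + str(epoch) + '_simulator.mdl'
model = TinySimulator()
torch.save(model.state_dict(), directory + '/' + str(epoch) + '_simulator.mdl')

# reload into a freshly constructed network of the same architecture
restored = TinySimulator()
restored.load_state_dict(torch.load(directory + '/' + str(epoch) + '_simulator.mdl'))
restored.eval()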