def __init__(self, is_train=False, dataset='Multiwoz'):
    """Load the trainer configuration and build the policy/value networks.

    Args:
        is_train: when True, also set up the log handler and the two
            optimizers (RMSprop for the policy, Adam for the value net).
        dataset: corpus selector; only 'Multiwoz' is handled here, and it
            builds the state/action vectorizer plus both networks.
    """
    # Hoist the module directory once instead of recomputing
    # os.path.dirname(os.path.abspath(__file__)) for every path below.
    # Resolving against the module directory keeps the trainer working
    # regardless of the current working directory.
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, 'config.json'), 'r') as f:
        cfg = json.load(f)
    self.save_dir = os.path.join(here, cfg['save_dir'])
    self.save_per_epoch = cfg['save_per_epoch']
    self.update_round = cfg['update_round']
    self.optim_batchsz = cfg['batchsz']
    self.gamma = cfg['gamma']
    self.epsilon = cfg['epsilon']
    self.tau = cfg['tau']
    self.is_train = is_train
    if is_train:
        init_logging_handler(os.path.join(here, cfg['log_dir']))

    # construct policy and value network
    if dataset == 'Multiwoz':
        # NOTE(review): original flattened source does not show nesting —
        # assumed the vectorizer and both networks are built only for
        # the 'Multiwoz' branch, since they all depend on self.vector.
        voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
        voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
        self.vector = MultiWozVector(voc_file, voc_opp_file)
        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
                                          self.vector.da_dim).to(device=DEVICE)
        self.value = Value(self.vector.state_dim, cfg['hv_dim']).to(device=DEVICE)

    if is_train:
        self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr'])
        self.value_optim = optim.Adam(self.value.parameters(), lr=cfg['lr'])
def __init__(self, is_train=False):
    """Read config.json (next to this module) and build the policy network.

    Args:
        is_train: when True, also install the log handler and create the
            RMSprop optimizer for the policy parameters.
    """
    cfg_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'config.json')
    with open(cfg_path, 'r') as cfg_file:
        cfg = json.load(cfg_file)

    # Plain config values copied onto the instance.
    # NOTE(review): unlike the sibling trainer init, save_dir/log_dir are
    # used as-is (relative to the CWD, not the module dir) — confirm intended.
    self.save_dir = cfg['save_dir']
    self.save_per_epoch = cfg['save_per_epoch']
    self.update_round = cfg['update_round']
    self.optim_batchsz = cfg['batchsz']
    self.gamma = cfg['gamma']
    self.is_train = is_train

    if is_train:
        init_logging_handler(cfg['log_dir'])

    # NOTE(review): self.vector is not assigned in this method — presumably
    # a subclass or earlier call sets it before this runs; verify.
    self.policy = MultiDiscretePolicy(
        self.vector.state_dim, cfg['h_dim'], self.vector.da_dim
    ).to(device=DEVICE)

    if is_train:
        self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr'])
prec = a_TP / (a_TP + a_FP) rec = a_TP / (a_TP + a_FN) F1 = 2 * prec * rec / (prec + rec) print(a_TP, a_FP, a_FN, F1) print(t_corr, t_tot, t_corr / t_tot) def save(self, directory, epoch): if not os.path.exists(directory): os.makedirs(directory) torch.save(self.user.state_dict(), directory + '/' + str(epoch) + '_simulator.mdl') logging.info( '<<user simulator>> epoch {}: saved network to mdl'.format(epoch)) if __name__ == '__main__': with open('config.json', 'r') as f: cfg = json.load(f) init_logging_handler(cfg['log_dir']) env = VHUS_Trainer(cfg) logging.debug('start training') best = float('inf') for e in range(cfg['epoch']): env.imitating(e) best = env.imit_test(e, best)