Code example #1
import os

# MultiWozVector comes from ConvLab-2 and read_action_map is a helper defined
# elsewhere in this project; both names are assumed to be in scope here.
def generate_necessary_file(root_dir):
    # dialogue-act vocabularies for the system and user sides of MultiWOZ
    voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
    voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
    vector = MultiWozVector(voc_file, voc_opp_file)
    # mapping between system dialogue acts and their action indices
    action_map_file = os.path.join(root_dir, 'convlab2/policy/act_500_list.txt')
    act2ind_dict, ind2act_dict = read_action_map(action_map_file)
    return vector, act2ind_dict, ind2act_dict
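A minimal usage sketch for this helper, assuming a ConvLab-2 checkout whose layout matches the paths above; the root_dir value below is a placeholder, not taken from the original project:

# hypothetical call site for generate_necessary_file
root_dir = '/path/to/ConvLab-2'  # assumed: the checkout containing data/ and convlab2/
vector, act2ind_dict, ind2act_dict = generate_necessary_file(root_dir)
# vector.state_dim is the Q-network input size (see example #2);
# ind2act_dict maps an action index back to its dialogue-act label
print(vector.state_dim, len(ind2act_dict))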
Code example #2
File: DQfD.py  Project: JQWang-77/MSc-Project
 def __init__(self, train=True):
     # load configuration file
     with open(
             os.path.join(os.path.dirname(os.path.abspath(__file__)),
                          'config.json'), 'r') as f:
         cfg = json.load(f)
     self.gamma = cfg['gamma']
     self.epsilon_init = cfg['epsilon_init']
     self.epsilon_final = cfg['epsilon_final']
     self.istrain = train
     if self.istrain:
         self.epsilon = self.epsilon_init
     else:
         self.epsilon = self.epsilon_final
     self.epsilon_degrade_period = cfg['epsilon_degrade_period']
     self.tau = cfg['tau']
     self.action_number = cfg[
         'action_number']  # total number of actions considered
     init_logging_handler(
         os.path.join(os.path.dirname(os.path.abspath(__file__)),
                      cfg['log_dir']))
     # load action mapping file (root_dir and read_action_map are defined or
     # imported at module level in DQfD.py; root_dir is the ConvLab-2 root directory)
     action_map_file = os.path.join(root_dir,
                                    'convlab2/policy/act_500_list.txt')
     _, self.ind2act_dict = read_action_map(action_map_file)
     # load vector for MultiWoz 2.1
     voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
     voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
     self.vector = MultiWozVector(voc_file, voc_opp_file)
     # build Q networks (DuelDQN and DEVICE are defined or imported at module
     # level in this project)
     # current Q network to be trained
     self.Q = DuelDQN(self.vector.state_dim, cfg['h_dim'],
                      self.action_number).to(device=DEVICE)
     # target Q network
     self.target_Q = DuelDQN(self.vector.state_dim, cfg['h_dim'],
                             self.action_number).to(device=DEVICE)
     self.target_Q.load_state_dict(self.Q.state_dict())
     # define optimizer
     # self.optimizer = RAdam(self.Q.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
     self.optimizer = optim.Adam(self.Q.parameters(),
                                 lr=cfg['lr'],
                                 weight_decay=cfg['weight_decay'])
     self.scheduler = StepLR(self.optimizer,
                             step_size=cfg['lr_decay_step'],
                             gamma=cfg['lr_decay'])
     self.min_lr = cfg['min_lr']
     # loss function
     self.criterion = torch.nn.MSELoss()
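The constructor above only loads hyperparameters and builds the networks; the update logic lives elsewhere in DQfD.py and is not shown here. As a rough illustration of how epsilon_degrade_period and tau are commonly used, a linear epsilon schedule and a Polyak (soft) target-network update might look like the sketch below; the method names and the exact schedule are assumptions, not code from the project:

 def update_epsilon(self, step):
     # assumed linear decay from epsilon_init to epsilon_final over
     # epsilon_degrade_period steps
     frac = min(1.0, step / self.epsilon_degrade_period)
     self.epsilon = self.epsilon_init + frac * (self.epsilon_final - self.epsilon_init)

 def soft_update_target(self):
     # assumed Polyak averaging: target_Q <- tau * Q + (1 - tau) * target_Q
     for t_param, param in zip(self.target_Q.parameters(), self.Q.parameters()):
         t_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * t_param.data)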