def generate_necessary_file(root_dir):
    voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
    voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
    vector = MultiWozVector(voc_file, voc_opp_file)
    action_map_file = os.path.join(root_dir, 'convlab2/policy/act_500_list.txt')
    act2ind_dict, ind2act_dict = read_action_map(action_map_file)
    return vector, act2ind_dict, ind2act_dict
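
# Illustrative usage only (a sketch, not part of the original module): the three
# objects returned above are typically consumed together when vectorizing dialog
# states and decoding Q-network outputs back into system actions, e.g.
#
#   vector, act2ind_dict, ind2act_dict = generate_necessary_file(root_dir)
#   state_vec = vector.state_vectorize(state)   # state: a MultiWOZ dialog state dict
#   sys_action = ind2act_dict[action_index]     # action_index: argmax over Q-values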
def __init__(self, train=True):
    # load configuration file
    with open(
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         'config.json'), 'r') as f:
        cfg = json.load(f)
    self.gamma = cfg['gamma']
    self.epsilon_init = cfg['epsilon_init']
    self.epsilon_final = cfg['epsilon_final']
    self.istrain = train
    if self.istrain:
        self.epsilon = self.epsilon_init
    else:
        self.epsilon = self.epsilon_final
    self.epsilon_degrade_period = cfg['epsilon_degrade_period']
    self.tau = cfg['tau']
    self.action_number = cfg['action_number']  # total number of actions considered
    init_logging_handler(
        os.path.join(os.path.dirname(os.path.abspath(__file__)),
                     cfg['log_dir']))

    # load action mapping file
    action_map_file = os.path.join(root_dir, 'convlab2/policy/act_500_list.txt')
    _, self.ind2act_dict = read_action_map(action_map_file)

    # load vector for MultiWoz 2.1
    voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
    voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
    self.vector = MultiWozVector(voc_file, voc_opp_file)

    # build Q network
    # current Q network to be trained
    self.Q = DuelDQN(self.vector.state_dim, cfg['h_dim'],
                     self.action_number).to(device=DEVICE)
    # target Q network
    self.target_Q = DuelDQN(self.vector.state_dim, cfg['h_dim'],
                            self.action_number).to(device=DEVICE)
    self.target_Q.load_state_dict(self.Q.state_dict())

    # define optimizer
    # self.optimizer = RAdam(self.Q.parameters(), lr=cfg['lr'],
    #                        weight_decay=cfg['weight_decay'])
    self.optimizer = optim.Adam(self.Q.parameters(), lr=cfg['lr'],
                                weight_decay=cfg['weight_decay'])
    self.scheduler = StepLR(self.optimizer,
                            step_size=cfg['lr_decay_step'],
                            gamma=cfg['lr_decay'])
    self.min_lr = cfg['min_lr']

    # loss function
    self.criterion = torch.nn.MSELoss()
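
# The fields initialized above (epsilon_init / epsilon_final / epsilon_degrade_period,
# tau, and the Q / target_Q pair) are the standard ingredients of epsilon-greedy
# exploration with a softly-updated target network. A minimal sketch of how they are
# commonly used follows; the method names `update_epsilon` and `sync_target_net` are
# illustrative assumptions, not the repository's own API:
#
#   def update_epsilon(self, step):
#       # linear annealing from epsilon_init to epsilon_final over the degrade period
#       ratio = min(step / self.epsilon_degrade_period, 1.0)
#       self.epsilon = self.epsilon_init + ratio * (self.epsilon_final - self.epsilon_init)
#
#   def sync_target_net(self):
#       # Polyak (soft) update of the target network, controlled by tau
#       for t_param, param in zip(self.target_Q.parameters(), self.Q.parameters()):
#           t_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * t_param.data)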