Code example #1
File: dqn.py Project: zz-jacob/ConvLab-2
    def __init__(self, is_train=False, dataset='Multiwoz'):

        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
            cfg = json.load(f)
        self.save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['save_dir'])
        self.save_per_epoch = cfg['save_per_epoch']
        self.training_iter = cfg['training_iter']
        self.training_batch_iter = cfg['training_batch_iter']
        self.batch_size = cfg['batch_size']
        self.epsilon = cfg['epsilon_spec']['start']
        self.rule_bot = RuleBasedMultiwozBot()
        self.gamma = cfg['gamma']
        self.is_train = is_train
        if is_train:
            init_logging_handler(os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['log_dir']))

        # construct multiwoz vector
        if dataset == 'Multiwoz':
            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
            voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
            self.vector = MultiWozVector(voc_file, voc_opp_file, composite_actions=True, vocab_size=cfg['vocab_size'])

        # replay memory
        self.memory = MemoryReplay(cfg['memory_size'])

        self.net = EpsilonGreedyPolicy(self.vector.state_dim, cfg['hv_dim'], self.vector.da_dim, cfg['epsilon_spec']).to(device=DEVICE)
        self.target_net = copy.deepcopy(self.net)

        # keep online_net and eval_net as aliases of the target network
        self.online_net = self.target_net
        self.eval_net = self.target_net

        if is_train:
            self.net_optim = optim.Adam(self.net.parameters(), lr=cfg['lr'])

        self.loss_fn = nn.MSELoss()
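
Every hyperparameter above is read from a config.json stored next to dqn.py. The following is only an illustrative sketch of that file, limited to the keys used in this snippet; the values are placeholders rather than the project's real settings, and EpsilonGreedyPolicy may expect further fields inside epsilon_spec that are not visible here.

{
    "save_dir": "save",
    "log_dir": "log",
    "save_per_epoch": 5,
    "training_iter": 10,
    "training_batch_iter": 1,
    "batch_size": 32,
    "memory_size": 50000,
    "vocab_size": 500,
    "hv_dim": 100,
    "lr": 0.001,
    "gamma": 0.99,
    "epsilon_spec": {"start": 0.1}
}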
Code example #2
    def __init__(self, is_train=False, dataset='Multiwoz'):
        with open("/home/raliegh/图片/ConvLab-2/convlab2/policy/pg/config.json",
                  'r') as f:
            cfg = json.load(f)
        self.save_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), cfg['save_dir'])
        self.save_per_epoch = cfg['save_per_epoch']
        self.update_round = cfg['update_round']
        self.optim_batchsz = cfg['batchsz']
        self.gamma = cfg['gamma']
        self.is_train = is_train
        if is_train:
            init_logging_handler(cfg['log_dir'])
        # load vocabulary
        if dataset == 'Multiwoz':
            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
            voc_opp_file = os.path.join(root_dir,
                                        'data/multiwoz/usr_da_voc.txt')
            self.vector = MultiWozVector(voc_file, voc_opp_file)
            self.policy = MultiDiscretePolicy(
                self.vector.state_dim, cfg['h_dim'],
                self.vector.da_dim).to(device=DEVICE)

        # self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
        if is_train:
            self.policy_optim = optim.RMSprop(self.policy.parameters(),
                                              lr=cfg['lr'])
        # load the best model from the web.
        self.load(
            "/home/raliegh/图片/ConvLab-2/convlab2/policy/pg/save/best/best_pg_from_web.pol.mdl"
        )
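
Note that this variant hard-codes absolute /home/raliegh/... paths for both config.json and the pretrained model, so it only runs on that machine. A minimal sketch, assuming the same file layout, of resolving those paths relative to the module instead, mirroring examples #1 and #3:

        # Sketch (assumption): build paths relative to this file rather than
        # hard-coding an absolute /home/... prefix.
        pg_dir = os.path.dirname(os.path.abspath(__file__))
        cfg_path = os.path.join(pg_dir, 'config.json')
        model_path = os.path.join(pg_dir, 'save/best/best_pg_from_web.pol.mdl')
        with open(cfg_path, 'r') as f:
            cfg = json.load(f)
        # ... rest of __init__ unchanged, then:
        self.load(model_path)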
Code example #3
    def __init__(self, is_train=False, dataset='Multiwoz'):

        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
            cfg = json.load(f)
        self.save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['save_dir'])
        self.save_per_epoch = cfg['save_per_epoch']
        self.update_round = cfg['update_round']
        self.optim_batchsz = cfg['batchsz']
        self.gamma = cfg['gamma']
        self.epsilon = cfg['epsilon']
        self.tau = cfg['tau']
        self.is_train = is_train
        if is_train:
            init_logging_handler(os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['log_dir']))

        # construct policy and value network
        if dataset == 'Multiwoz':
            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
            voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
            self.vector = MultiWozVector(voc_file, voc_opp_file)
            self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)

        self.value = Value(self.vector.state_dim, cfg['hv_dim']).to(device=DEVICE)
        if is_train:
            self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['policy_lr'])
            self.value_optim = optim.Adam(self.value.parameters(), lr=cfg['value_lr'])
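
The constructor above pairs a policy network with a separate state-value network and gives each its own optimizer. A minimal sketch of a single value-network update with this setup; `states` and `returns` are hypothetical tensors (a batch of vectorized states and their estimated returns), not variables from the snippet:

        # Sketch (assumption): fit the value network to the returns by MSE regression.
        value_loss = torch.nn.functional.mse_loss(self.value(states).squeeze(-1), returns)
        self.value_optim.zero_grad()
        value_loss.backward()
        self.value_optim.step()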
Code example #4
File: DQfD.py Project: JQWang-77/MSc-Project
    def __init__(self, train=True):
        # load configuration file
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               'config.json'), 'r') as f:
            cfg = json.load(f)
        self.gamma = cfg['gamma']
        self.epsilon_init = cfg['epsilon_init']
        self.epsilon_final = cfg['epsilon_final']
        self.istrain = train
        if self.istrain:
            self.epsilon = self.epsilon_init
        else:
            self.epsilon = self.epsilon_final
        self.epsilon_degrade_period = cfg['epsilon_degrade_period']
        self.tau = cfg['tau']
        self.action_number = cfg['action_number']  # total number of actions considered
        init_logging_handler(os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                          cfg['log_dir']))
        # load action mapping file
        action_map_file = os.path.join(root_dir, 'convlab2/policy/act_500_list.txt')
        _, self.ind2act_dict = read_action_map(action_map_file)
        # load vector for MultiWoz 2.1
        voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
        voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
        self.vector = MultiWozVector(voc_file, voc_opp_file)
        # build Q network
        # current Q network to be trained
        self.Q = DuelDQN(self.vector.state_dim, cfg['h_dim'],
                         self.action_number).to(device=DEVICE)
        # target Q network
        self.target_Q = DuelDQN(self.vector.state_dim, cfg['h_dim'],
                                self.action_number).to(device=DEVICE)
        self.target_Q.load_state_dict(self.Q.state_dict())
        # define optimizer
        # self.optimizer = RAdam(self.Q.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        self.optimizer = optim.Adam(self.Q.parameters(),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        self.scheduler = StepLR(self.optimizer,
                                step_size=cfg['lr_decay_step'],
                                gamma=cfg['lr_decay'])
        self.min_lr = cfg['min_lr']
        # loss function
        self.criterion = torch.nn.MSELoss()
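
Here target_Q starts as a copy of Q and a tau coefficient is read from the config, which points to a soft (Polyak) target update elsewhere in the class; that method is not part of this snippet, so the following is only a sketch of what such an update typically looks like:

    def update_target_network(self):
        # Sketch (assumption): blend the current Q weights into target_Q with factor tau.
        for target_param, param in zip(self.target_Q.parameters(), self.Q.parameters()):
            target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)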
Code example #5
import os
import json
import logging
import sys
root_dir = os.path.dirname(
    os.path.dirname(
        os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
sys.path.append(root_dir)

from convlab2.util.train_util import init_logging_handler
from convlab2.task.camrest.goal_generator import GoalGenerator
from convlab2.policy.vhus.camrest.usermanager import UserDataManager
from convlab2.policy.vhus.train import VHUS_Trainer

if __name__ == '__main__':
    with open('config.json', 'r') as f:
        cfg = json.load(f)

    init_logging_handler(cfg['log_dir'])
    manager = UserDataManager()
    goal_gen = GoalGenerator()
    env = VHUS_Trainer(cfg, manager, goal_gen)

    logging.debug('start training')

    best = float('inf')
    for e in range(cfg['epoch']):
        env.imitating(e)
        best = env.imit_test(e, best)
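
Because this script opens 'config.json' with a bare relative path, it has to be launched from the directory that contains that file. A small sketch, assuming the same layout as the examples above, of making it independent of the working directory:

    # Sketch (assumption): resolve config.json relative to the script itself.
    cfg_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json')
    with open(cfg_path, 'r') as f:
        cfg = json.load(f)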