Example #1
0
class DQN(Policy):
    """DQN dialog policy with a target network and a double-DQN Q loss."""

    def __init__(self, is_train=False, dataset='Multiwoz'):
        """Load config, build the state/action vectorizer, networks and optimizer.

        Args:
            is_train (bool): create the optimizer and logging handler when True.
            dataset (str): only 'Multiwoz' is supported here.
        """
        with open(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'config.json'), 'r') as f:
            cfg = json.load(f)
        self.save_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), cfg['save_dir'])
        self.save_per_epoch = cfg['save_per_epoch']
        self.training_iter = cfg['training_iter']
        self.training_batch_iter = cfg['training_batch_iter']
        self.batch_size = cfg['batch_size']
        self.gamma = cfg['gamma']
        self.is_train = is_train
        if is_train:
            init_logging_handler(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             cfg['log_dir']))

        # construct multiwoz vector
        if dataset == 'Multiwoz':
            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
            voc_opp_file = os.path.join(root_dir,
                                        'data/multiwoz/usr_da_voc.txt')
            self.vector = MultiWozVector(voc_file,
                                         voc_opp_file,
                                         composite_actions=True,
                                         vocab_size=cfg['vocab_size'])

        # replay memory
        self.memory = MemoryReplay(cfg['memory_size'])

        self.net = EpsilonGreedyPolicy(self.vector.state_dim, cfg['hv_dim'],
                                       self.vector.da_dim,
                                       cfg['epsilon_spec']).to(device=DEVICE)
        self.target_net = copy.deepcopy(self.net)

        # Double-DQN wiring (see calc_q_loss): the *online* network selects
        # the greedy action for the next state and the target network
        # evaluates it.  The previous code aliased both names to target_net,
        # which silently degenerated the loss to vanilla DQN and contradicted
        # the comments in calc_q_loss.
        self.online_net = self.net
        self.eval_net = self.target_net

        if is_train:
            self.net_optim = optim.Adam(self.net.parameters(), lr=cfg['lr'])

        self.loss_fn = nn.MSELoss()

    def update_memory(self, sample):
        """Append one transition (or batch of transitions) to replay memory."""
        self.memory.append(sample)

    def predict(self, state):
        """
        Predict an system action given state.
        Args:
            state (dict): Dialog state. Please refer to util/state.py
        Returns:
            action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
        """
        s_vec = torch.Tensor(self.vector.state_vectorize(state))
        a = self.net.select_action(s_vec.to(device=DEVICE))

        # Move to CPU before the numpy conversion; .numpy() raises on a CUDA
        # tensor and .cpu() is a no-op when the tensor is already on CPU.
        action = self.vector.action_devectorize(a.cpu().numpy())

        state['system_action'] = action
        return action

    def init_session(self):
        """
        Restore after one session
        """
        self.memory.reset()

    def calc_q_loss(self, batch):
        '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
        s = torch.from_numpy(np.stack(batch.state)).to(device=DEVICE)
        a = torch.from_numpy(np.stack(batch.action)).to(device=DEVICE)
        r = torch.from_numpy(np.stack(batch.reward)).to(device=DEVICE)
        next_s = torch.from_numpy(np.stack(batch.next_state)).to(device=DEVICE)
        mask = torch.Tensor(np.stack(batch.mask)).to(device=DEVICE)

        q_preds = self.net(s)
        with torch.no_grad():
            # Use online_net to select actions in next state
            online_next_q_preds = self.online_net(next_s)
            # Use eval_net to calculate next_q_preds for actions chosen by online_net
            next_q_preds = self.eval_net(next_s)
        # Q(s, a) for the action actually taken in the batch
        act_q_preds = q_preds.gather(
            -1,
            a.argmax(-1).long().unsqueeze(-1)).squeeze(-1)
        online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True)
        max_next_q_preds = next_q_preds.gather(-1, online_actions).squeeze(-1)
        # mask == 0 terminates the bootstrap at episode end
        max_q_targets = r + self.gamma * mask * max_next_q_preds

        q_loss = self.loss_fn(act_q_preds, max_q_targets)

        return q_loss

    def update(self, epoch):
        """Run one epoch of optimization, anneal epsilon and sync target net."""
        total_loss = 0.
        for i in range(self.training_iter):
            round_loss = 0.
            # 1. batch a sample from memory
            batch = self.memory.get_batch(batch_size=self.batch_size)

            for _ in range(self.training_batch_iter):
                # 2. calculate the Q loss
                loss = self.calc_q_loss(batch)

                # 3. make a optimization step
                self.net_optim.zero_grad()
                loss.backward()
                self.net_optim.step()

                round_loss += loss.item()

            logging.debug(
                '<<dialog policy dqn>> epoch {}, iteration {}, loss {}'.format(
                    epoch, i, round_loss / self.training_batch_iter))
            total_loss += round_loss
        total_loss /= (self.training_batch_iter * self.training_iter)
        logging.debug('<<dialog policy dqn>> epoch {}, total_loss {}'.format(
            epoch, total_loss))

        # update the epsilon value
        self.net.update_epsilon(epoch)

        # update the target network (hard copy of the online weights)
        self.target_net.load_state_dict(self.net.state_dict())

        if (epoch + 1) % self.save_per_epoch == 0:
            self.save(self.save_dir, epoch)

    def save(self, directory, epoch):
        """Save the online network weights under <directory>/<epoch>_dqn.pol.mdl."""
        if not os.path.exists(directory):
            os.makedirs(directory)

        torch.save(self.net.state_dict(),
                   directory + '/' + str(epoch) + '_dqn.pol.mdl')

        logging.info(
            '<<dialog policy>> epoch {}: saved network to mdl'.format(epoch))

    def load(self, filename):
        """Load checkpoint weights into both the online and target networks.

        Tries the legacy '<filename>.dqn.mdl' names as well as the
        '<filename>_dqn.pol.mdl' suffix that save() actually writes, both as
        given and relative to this module's directory.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        dqn_mdl_candidates = [
            filename + '.dqn.mdl',
            filename + '_dqn.pol.mdl',
            os.path.join(base_dir, filename + '.dqn.mdl'),
            os.path.join(base_dir, filename + '_dqn.pol.mdl'),
        ]
        for dqn_mdl in dqn_mdl_candidates:
            if os.path.exists(dqn_mdl):
                # read the checkpoint once; load_state_dict copies, so it is
                # safe to reuse the same dict for both networks
                state_dict = torch.load(dqn_mdl, map_location=DEVICE)
                self.net.load_state_dict(state_dict)
                self.target_net.load_state_dict(state_dict)
                logging.info(
                    '<<dialog policy>> loaded checkpoint from file: {}'.format(
                        dqn_mdl))
                break
class PG_generator(Policy):
    """Policy-gradient (REINFORCE) dialog policy used as a pretrained action
    generator; resolves its config and checkpoint relative to this module."""

    def __init__(self, is_train=False, dataset='Multiwoz'):
        """Build the policy network and load the pretrained checkpoint.

        Args:
            is_train (bool): create the optimizer and logging handler when True.
            dataset (str): only 'Multiwoz' is supported.
        """
        # Resolve files relative to this module instead of the hard-coded,
        # machine-specific absolute path ("/home/raliegh/...") the original
        # used, which breaks on every other machine.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(module_dir, 'config.json'), 'r') as f:
            cfg = json.load(f)
        self.save_dir = os.path.join(module_dir, cfg['save_dir'])
        self.save_per_epoch = cfg['save_per_epoch']
        self.update_round = cfg['update_round']
        self.optim_batchsz = cfg['batchsz']
        self.gamma = cfg['gamma']
        self.is_train = is_train
        if is_train:
            init_logging_handler(cfg['log_dir'])
        # load vocabulary
        if dataset == 'Multiwoz':
            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
            voc_opp_file = os.path.join(root_dir,
                                        'data/multiwoz/usr_da_voc.txt')
            self.vector = MultiWozVector(voc_file, voc_opp_file)
            self.policy = MultiDiscretePolicy(
                self.vector.state_dim, cfg['h_dim'],
                self.vector.da_dim).to(device=DEVICE)

        if is_train:
            self.policy_optim = optim.RMSprop(self.policy.parameters(),
                                              lr=cfg['lr'])
        # load the best pretrained model; load() silently skips missing files
        self.load(
            os.path.join(module_dir, 'save/best/best_pg_from_web.pol.mdl'))

    def predict(self, state):
        """
        Predict an system action given state.
        Args:
            state (tensor): Dialog state. Please refer to util/state.py
        Returns:
            action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
        """
        with torch.no_grad():
            s_vec = torch.Tensor(self.vector.state_vectorize(state))
            # NOTE: unlike the other policies this returns the raw action
            # tensor, not a devectorized act list — callers rely on that.
            a = self.policy.select_action(s_vec.to(device=DEVICE),
                                          self.is_train).cpu()
        return a

    def load(self, filename):
        """Load policy weights from the first existing candidate path.

        Candidates are the name as given plus common suffixes, searched both
        as-is and relative to this module's directory.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        policy_mdl_candidates = [
            filename,
            filename + '.pol.mdl',
            filename + '_pg.pol.mdl',
            os.path.join(module_dir, filename),
            os.path.join(module_dir, filename + '.pol.mdl'),
            os.path.join(module_dir, filename + '_pg.pol.mdl'),
        ]
        for policy_mdl in policy_mdl_candidates:
            if os.path.exists(policy_mdl):
                self.policy.load_state_dict(
                    torch.load(policy_mdl, map_location=DEVICE))
                logging.info(
                    '<<dialog policy>> loaded checkpoint from file: {}'.format(
                        policy_mdl))
                break
Example #3
0
class DQfD(Policy):
    """Deep Q-learning from Demonstrations (DQfD) dialog policy.

    Combines a double-DQN TD loss with a large-margin supervised auxiliary
    loss computed on expert demonstration transitions (see compute_loss and
    aux_loss).
    """

    def __init__(self, train=True):
        """Load config, build dueling Q/target-Q networks and the optimizer.

        Args:
            train (bool): when True, start epsilon at epsilon_init; otherwise
                use the final (evaluation) epsilon.
        """
        # load configuration file
        with open(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'config.json'), 'r') as f:
            cfg = json.load(f)
        self.gamma = cfg['gamma']
        # epsilon-greedy exploration schedule: anneal from init to final
        self.epsilon_init = cfg['epsilon_init']
        self.epsilon_final = cfg['epsilon_final']
        self.istrain = train
        if self.istrain:
            self.epsilon = self.epsilon_init
        else:
            self.epsilon = self.epsilon_final
        self.epsilon_degrade_period = cfg['epsilon_degrade_period']
        # tau: margin of the large-margin auxiliary loss (see aux_loss)
        self.tau = cfg['tau']
        self.action_number = cfg[
            'action_number']  # total number of actions considered
        init_logging_handler(
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         cfg['log_dir']))
        # load action mapping file (maps discrete action index -> dialog act)
        action_map_file = os.path.join(root_dir,
                                       'convlab2/policy/act_500_list.txt')
        _, self.ind2act_dict = read_action_map(action_map_file)
        # load vector for MultiWoz 2.1
        voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
        voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
        self.vector = MultiWozVector(voc_file, voc_opp_file)
        # build Q network
        # current Q network to be trained
        self.Q = DuelDQN(self.vector.state_dim, cfg['h_dim'],
                         self.action_number).to(device=DEVICE)
        # target Q network (hard-synced from Q via update_net)
        self.target_Q = DuelDQN(self.vector.state_dim, cfg['h_dim'],
                                self.action_number).to(device=DEVICE)
        self.target_Q.load_state_dict(self.Q.state_dict())
        # define optimizer
        # self.optimizer = RAdam(self.Q.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        self.optimizer = optim.Adam(self.Q.parameters(),
                                    lr=cfg['lr'],
                                    weight_decay=cfg['weight_decay'])
        # step-wise learning-rate decay, floored externally by min_lr
        self.scheduler = StepLR(self.optimizer,
                                step_size=cfg['lr_decay_step'],
                                gamma=cfg['lr_decay'])
        self.min_lr = cfg['min_lr']
        # loss function
        self.criterion = torch.nn.MSELoss()

    def predict(self, state):
        """Predict an system action and its index given state."""
        s_vec = torch.Tensor(self.vector.state_vectorize(state))
        # hard-coded closing rule: echo the user's 'bye' instead of querying Q
        if state['user_action'] == [['bye', 'general', 'none', 'none']]:
            action = [['bye', 'general', 'none', 'none']]
        else:
            a, a_ind = self.Q.select_action(s_vec.to(device=DEVICE),
                                            self.epsilon, self.ind2act_dict,
                                            self.istrain)
            action = self.vector.action_devectorize(a)
        state['system_action'] = action
        return action

    def predict_ind(self, state):
        """Predict an system action and its index given state."""
        s_vec = torch.Tensor(self.vector.state_vectorize(state))
        if state['user_action'] == [['bye', 'general', 'none', 'none']]:
            action = [['bye', 'general', 'none', 'none']]
            # 489 is presumably the index of the 'bye' act in
            # act_500_list.txt — TODO confirm against the action map
            a_ind = 489
        else:
            a, a_ind = self.Q.select_action(s_vec.to(device=DEVICE),
                                            self.epsilon, self.ind2act_dict,
                                            self.istrain)
            action = self.vector.action_devectorize(a)
        state['system_action'] = action
        return action, a_ind

    def init_session(self):
        """Restore after one session"""
        pass

    def aux_loss(self, s, a, candidate_a_ind, expert_label):
        """compute auxiliary loss given batch of states, actions and expert labels

        Large-margin loss from DQfD: max_a [Q(s, a) + l(a_e, a)] - Q(s, a_e),
        where l is 0 for the expert action and tau otherwise; averaged over
        the expert transitions in the batch (0 when there are none).
        """
        # only keep those expert demonstrations by setting expert label to 1
        s_exp = s[np.where(expert_label == 1)[0]]
        a_exp = a[np.where(expert_label == 1)[0]]
        candidate_a_ind_exp = candidate_a_ind[np.where(expert_label == 1)[0]]
        # if there exist expert demonstration in current batch
        if s_exp.size(0) > 0:
            # compute q value predictions for states for each action
            q_all = self.Q(s_exp)
            # only when agent take the same action as the expert does, the act_diff term(i.e. l(a_e,a)) is 0
            act_diff = q_all.new_full(q_all.size(), self.tau)
            for row, exp_ind in enumerate(candidate_a_ind_exp):
                act_diff[row, exp_ind] = 0
            # compute aux_loss = max(Q(s, a) + l(a_e, a)) - Q(s, a_e)
            q_max_act = (q_all + act_diff).max(dim=1)[0]
            q_exp_act = q_all.gather(-1, a_exp.unsqueeze(1)).squeeze(-1)
            aux_loss = (q_max_act - q_exp_act).sum() / s_exp.size(0)
        else:
            aux_loss = 0
        return aux_loss

    def update_net(self):
        """update target network by copying parameters from online network"""
        self.target_Q.load_state_dict(self.Q.state_dict())

    def compute_loss(self, s, a, r, s_next, mask, expert_label,
                     candidate_a_ind):
        """compute loss for batch

        Total loss = double-DQN TD loss + large-margin auxiliary loss.
        """
        # q value predictions for current state for each action
        q_preds = self.Q(s)
        with torch.no_grad():
            # online net for action selection in next state
            online_next_q_preds = self.Q(s_next)
            # target net for q value predicting in the next state
            next_q_preds = self.target_Q(s_next)
        # select q value predictions for corresponding actions
        act_q_preds = q_preds.gather(-1, a.unsqueeze(1)).squeeze(-1)
        # use online net to choose action for the next state
        online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True)
        # use target net to predict the corresponding value
        max_next_q_preds = next_q_preds.gather(-1, online_actions).squeeze(-1)
        # compute target q values (mask == 0 cuts the bootstrap at episode end)
        max_q_targets = r + self.gamma * max_next_q_preds * mask
        # q loss
        q_loss = self.criterion(act_q_preds, max_q_targets)
        # auxiliary loss
        aux_loss_term = self.aux_loss(s, a, candidate_a_ind, expert_label)
        # total loss
        loss = q_loss + aux_loss_term
        return loss

    def update(self, loss):
        """update online network"""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, directory, epoch):
        """save model to directory"""
        if not os.path.exists(directory):
            os.makedirs(directory)

        torch.save(self.Q.state_dict(),
                   directory + '/' + str(epoch) + '_dqn.mdl')

        logging.info(
            '<<dialog policy>> epoch {}: saved network to mdl'.format(epoch))

    def load(self, filename):
        """load model

        Weights are loaded into both Q and target_Q; candidates are tried
        as given and relative to this module's directory.
        """
        dqn_mdl_candidates = [
            filename + '_dqn.mdl',
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         filename + '_dqn.mdl')
        ]
        for dqn_mdl in dqn_mdl_candidates:
            if os.path.exists(dqn_mdl):
                self.Q.load_state_dict(torch.load(dqn_mdl,
                                                  map_location=DEVICE))
                self.target_Q.load_state_dict(
                    torch.load(dqn_mdl, map_location=DEVICE))
                logging.info(
                    '<<dialog policy>> loaded checkpoint from file: {}'.format(
                        dqn_mdl))
                break
Example #4
0
class GDPL(Policy):
    """Guided Dialog Policy Learning: PPO-style policy optimization with a
    learned (IRL) reward estimator supplying the rewards."""

    def __init__(self, is_train=False, dataset='Multiwoz'):
        """Load config and build the policy and value networks.

        Args:
            is_train (bool): create optimizers and logging when True.
            dataset (str): only 'Multiwoz' is supported here.
        """
        with open(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'config.json'), 'r') as f:
            cfg = json.load(f)
        self.save_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), cfg['save_dir'])
        self.save_per_epoch = cfg['save_per_epoch']
        self.update_round = cfg['update_round']
        self.optim_batchsz = cfg['batchsz']
        self.gamma = cfg['gamma']
        # epsilon: PPO clipping range; tau: GAE lambda (see est_adv)
        self.epsilon = cfg['epsilon']
        self.tau = cfg['tau']
        self.is_train = is_train
        if is_train:
            init_logging_handler(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             cfg['log_dir']))

        # construct policy and value network
        if dataset == 'Multiwoz':
            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
            voc_opp_file = os.path.join(root_dir,
                                        'data/multiwoz/usr_da_voc.txt')
            self.vector = MultiWozVector(voc_file, voc_opp_file)
            self.policy = MultiDiscretePolicy(
                self.vector.state_dim, cfg['h_dim'],
                self.vector.da_dim).to(device=DEVICE)

        self.value = Value(self.vector.state_dim,
                           cfg['hv_dim']).to(device=DEVICE)
        if is_train:
            self.policy_optim = optim.RMSprop(self.policy.parameters(),
                                              lr=cfg['lr'])
            self.value_optim = optim.Adam(self.value.parameters(),
                                          lr=cfg['lr'])

    def predict(self, state):
        """
        Predict an system action given state.
        Args:
            state (dict): Dialog state. Please refer to util/state.py
        Returns:
            action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
        """
        s_vec = torch.Tensor(self.vector.state_vectorize(state))
        a = self.policy.select_action(s_vec.to(device=DEVICE),
                                      self.is_train).cpu()
        return self.vector.action_devectorize(a.numpy())

    def init_session(self):
        """
        Restore after one session
        """
        pass

    def est_adv(self, r, v, mask):
        """
        we save a trajectory in continuous space and it reaches the ending of current trajectory when mask=0.
        :param r: reward, Tensor, [b]
        :param v: estimated value, Tensor, [b]
        :param mask: indicates ending for 0 otherwise 1, Tensor, [b]
        :return: A(s, a), V-target(s), both Tensor
        """
        batchsz = v.size(0)

        # v_target is worked out by Bellman equation.
        v_target = torch.Tensor(batchsz).to(device=DEVICE)
        delta = torch.Tensor(batchsz).to(device=DEVICE)
        A_sa = torch.Tensor(batchsz).to(device=DEVICE)

        prev_v_target = 0
        prev_v = 0
        prev_A_sa = 0
        for t in reversed(range(batchsz)):
            # mask here indicates a end of trajectory
            # this value will be treated as the target value of value network.
            # mask = 0 means the immediate reward is the real V(s) since it's end of trajectory.
            # formula: V(s_t) = r_t + gamma * V(s_t+1)
            v_target[t] = r[t] + self.gamma * prev_v_target * mask[t]

            # please refer to : https://arxiv.org/abs/1506.02438
            # for generalized adavantage estimation
            # formula: delta(s_t) = r_t + gamma * V(s_t+1) - V(s_t)
            delta[t] = r[t] + self.gamma * prev_v * mask[t] - v[t]

            # formula: A(s, a) = delta(s_t) + gamma * lamda * A(s_t+1, a_t+1)
            # here use symbol tau as lambda, but original paper uses symbol lambda.
            A_sa[t] = delta[t] + self.gamma * self.tau * prev_A_sa * mask[t]

            # update previous
            prev_v_target = v_target[t]
            prev_v = v[t]
            prev_A_sa = A_sa[t]

        # normalize A_sa
        # NOTE(review): std() is NaN for a batch of size 1 — callers are
        # assumed to pass batchsz > 1; confirm upstream batch collection.
        A_sa = (A_sa - A_sa.mean()) / A_sa.std()

        return A_sa, v_target

    def update(self, epoch, batchsz, s, a, next_s, mask, rewarder):
        """One training epoch: update the IRL rewarder, estimate advantages,
        then run `update_round` passes of clipped-surrogate optimization.

        Args:
            epoch (int): current epoch index (used for logging/saving).
            batchsz (int): number of transitions in s/a/next_s/mask.
            s, a, next_s, mask: batched tensors of states, actions,
                successor states and episode-continuation masks.
            rewarder: IRL reward estimator providing update_irl()/estimate().
        """
        # update reward estimator
        rewarder.update_irl((s, a, next_s), batchsz, epoch)

        # get estimated V(s) and PI_old(s, a)
        # actually, PI_old(s, a) can be saved when interacting with env, so as to save the time of one forward elapsed
        # v: [b, 1] => [b]
        v = self.value(s).squeeze(-1).detach()
        log_pi_old_sa = self.policy.get_log_prob(s, a).detach()

        # estimate advantage and v_target according to GAE and Bellman Equation
        r = rewarder.estimate(s, a, next_s, log_pi_old_sa).detach()
        A_sa, v_target = self.est_adv(r, v, mask)

        for i in range(self.update_round):

            # 1. shuffle current batch
            perm = torch.randperm(batchsz)
            # shuffle the variable for mutliple optimize
            v_target_shuf, A_sa_shuf, s_shuf, a_shuf, log_pi_old_sa_shuf = \
                v_target[perm], A_sa[perm], s[perm], a[perm], log_pi_old_sa[perm]

            # 2. get mini-batch for optimizing
            optim_chunk_num = int(np.ceil(batchsz / self.optim_batchsz))
            # chunk the optim_batch for total batch
            v_target_shuf = torch.chunk(v_target_shuf, optim_chunk_num)
            A_sa_shuf = torch.chunk(A_sa_shuf, optim_chunk_num)
            s_shuf = torch.chunk(s_shuf, optim_chunk_num)
            a_shuf = torch.chunk(a_shuf, optim_chunk_num)
            log_pi_old_sa_shuf = torch.chunk(log_pi_old_sa_shuf,
                                             optim_chunk_num)

            # 3. iterate all mini-batch to optimize
            policy_loss, value_loss = 0., 0.
            for v_target_b, A_sa_b, s_b, a_b, log_pi_old_sa_b in zip(
                    v_target_shuf, A_sa_shuf, s_shuf, a_shuf,
                    log_pi_old_sa_shuf):
                # 1. update value network (MSE to the Bellman target)
                self.value_optim.zero_grad()
                v_b = self.value(s_b).squeeze(-1)
                loss = (v_b - v_target_b).pow(2).mean()
                value_loss += loss.item()

                # backprop
                loss.backward()
                self.value_optim.step()

                # 2. update policy network by clipping
                self.policy_optim.zero_grad()
                # [b, 1]
                log_pi_sa = self.policy.get_log_prob(s_b, a_b)
                # ratio = exp(log_Pi(a|s) - log_Pi_old(a|s)) = Pi(a|s) / Pi_old(a|s)
                # we use log_pi for stability of numerical operation
                # [b, 1] => [b]
                ratio = (log_pi_sa - log_pi_old_sa_b).exp().squeeze(-1)
                surrogate1 = ratio * A_sa_b
                surrogate2 = torch.clamp(ratio, 1 - self.epsilon,
                                         1 + self.epsilon) * A_sa_b
                # this is element-wise comparing.
                # we add negative symbol to convert gradient ascent to gradient descent
                surrogate = -torch.min(surrogate1, surrogate2).mean()
                policy_loss += surrogate.item()

                # backprop
                surrogate.backward()
                # gradient clipping, for stability
                # (clip_grad_norm_ replaces the long-deprecated clip_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 10)
                self.policy_optim.step()

            value_loss /= optim_chunk_num
            policy_loss /= optim_chunk_num
            # NOTE(review): log tags/checkpoint suffixes say 'ppo' although
            # this is the GDPL class — kept verbatim so existing log parsers
            # and checkpoints keep working.
            logging.debug(
                '<<dialog policy ppo>> epoch {}, iteration {}, value, loss {}'.
                format(epoch, i, value_loss))
            logging.debug(
                '<<dialog policy ppo>> epoch {}, iteration {}, policy, loss {}'
                .format(epoch, i, policy_loss))

        if (epoch + 1) % self.save_per_epoch == 0:
            self.save(self.save_dir, epoch)

    def save(self, directory, epoch):
        """Save value and policy network weights under `directory`."""
        if not os.path.exists(directory):
            os.makedirs(directory)

        torch.save(self.value.state_dict(),
                   directory + '/' + str(epoch) + '_ppo.val.mdl')
        torch.save(self.policy.state_dict(),
                   directory + '/' + str(epoch) + '_ppo.pol.mdl')

        logging.info(
            '<<dialog policy>> epoch {}: saved network to mdl'.format(epoch))

    def load(self, filename):
        """Load value/policy weights saved by save(), relative to this module."""
        value_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 filename + '_ppo.val.mdl')
        if os.path.exists(value_mdl):
            self.value.load_state_dict(torch.load(value_mdl))
            print('<<dialog policy>> loaded checkpoint from file: {}'.format(
                value_mdl))

        policy_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                  filename + '_ppo.pol.mdl')
        if os.path.exists(policy_mdl):
            self.policy.load_state_dict(torch.load(policy_mdl))
            print('<<dialog policy>> loaded checkpoint from file: {}'.format(
                policy_mdl))
Example #5
0
class PG(Policy):
    """REINFORCE (vanilla policy gradient) dialog policy."""

    def __init__(self, is_train=False, dataset='Multiwoz'):
        """Load config and build the multi-discrete policy network.

        Args:
            is_train (bool): create the optimizer and logging when True.
            dataset (str): only 'Multiwoz' is supported here.
        """
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
            cfg = json.load(f)
        self.save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['save_dir'])
        self.save_per_epoch = cfg['save_per_epoch']
        self.update_round = cfg['update_round']
        self.optim_batchsz = cfg['batchsz']
        self.gamma = cfg['gamma']
        self.is_train = is_train
        if is_train:
            init_logging_handler(cfg['log_dir'])

        if dataset == 'Multiwoz':
            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
            voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
            self.vector = MultiWozVector(voc_file, voc_opp_file)
            self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)

        if is_train:
            self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr'])

    def predict(self, state):
        """
        Predict an system action given state.
        Args:
            state (dict): Dialog state. Please refer to util/state.py
        Returns:
            action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
        """
        s_vec = torch.Tensor(self.vector.state_vectorize(state))
        a = self.policy.select_action(s_vec.to(device=DEVICE), self.is_train).cpu()
        action = self.vector.action_devectorize(a.numpy())
        state['system_action'] = action

        return action

    def init_session(self):
        """
        Restore after one session
        """
        pass

    def est_return(self, r, mask):
        """
        we save a trajectory in continuous space and it reaches the ending of current trajectory when mask=0.
        :param r: reward, Tensor, [b]
        :param mask: indicates ending for 0 otherwise 1, Tensor, [b]
        :return: V-target(s), Tensor
        """
        batchsz = r.size(0)

        # v_target is worked out by Bellman equation.
        v_target = torch.Tensor(batchsz).to(device=DEVICE)

        prev_v_target = 0
        for t in reversed(range(batchsz)):
            # mask here indicates a end of trajectory
            # mask = 0 means the immediate reward is the real V(s) since it's end of trajectory.
            # formula: V(s_t) = r_t + gamma * V(s_t+1)
            v_target[t] = r[t] + self.gamma * prev_v_target * mask[t]

            # update previous
            prev_v_target = v_target[t]

        return v_target

    def update(self, epoch, batchsz, s, a, r, mask):
        """One training epoch of REINFORCE over the collected batch.

        Args:
            epoch (int): current epoch index (for logging/saving).
            batchsz (int): number of transitions in s/a/r/mask.
            s, a, r, mask: batched tensors of states, actions, rewards and
                episode-continuation masks.
        """
        v_target = self.est_return(r, mask)

        for i in range(self.update_round):

            # 1. shuffle current batch
            perm = torch.randperm(batchsz)
            # shuffle the variable for mutliple optimize
            v_target_shuf, s_shuf, a_shuf = v_target[perm], s[perm], a[perm]

            # 2. get mini-batch for optimizing
            optim_chunk_num = int(np.ceil(batchsz / self.optim_batchsz))
            # chunk the optim_batch for total batch
            v_target_shuf, s_shuf, a_shuf = torch.chunk(v_target_shuf, optim_chunk_num), \
                                            torch.chunk(s_shuf, optim_chunk_num), \
                                            torch.chunk(a_shuf, optim_chunk_num)

            # 3. iterate all mini-batch to optimize
            policy_loss = 0.
            for v_target_b, s_b, a_b in zip(v_target_shuf, s_shuf, a_shuf):
                # update policy network
                self.policy_optim.zero_grad()
                # [b, 1]
                log_pi_sa = self.policy.get_log_prob(s_b, a_b)
                # REINFORCE objective: maximize E[log pi(a|s) * G_t];
                # negative sign converts gradient ascent to gradient descent
                surrogate = - (log_pi_sa * v_target_b).mean()
                policy_loss += surrogate.item()

                # backprop
                surrogate.backward()
                # zero out NaN gradients; skip parameters that received no
                # gradient at all (p.grad is None would crash on indexing)
                for p in self.policy.parameters():
                    if p.grad is not None:
                        p.grad[p.grad != p.grad] = 0.0
                # gradient clipping, for stability
                # (clip_grad_norm_ replaces the long-deprecated clip_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 10)
                self.policy_optim.step()

            policy_loss /= optim_chunk_num
            logging.debug('<<dialog policy pg>> epoch {}, iteration {}, policy, loss {}'.format(epoch, i, policy_loss))

        if (epoch + 1) % self.save_per_epoch == 0:
            self.save(self.save_dir, epoch)

    def save(self, directory, epoch):
        """Save the policy weights under <directory>/<epoch>_pg.pol.mdl."""
        if not os.path.exists(directory):
            os.makedirs(directory)

        torch.save(self.policy.state_dict(), directory + '/' + str(epoch) + '_pg.pol.mdl')

        logging.info('<<dialog policy>> epoch {}: saved network to mdl'.format(epoch))

    def load(self, filename):
        """Load policy weights from the first existing candidate path."""
        policy_mdl_candidates = [
            filename,
            filename + '.pol.mdl',
            filename + '_pg.pol.mdl',
            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename),
            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'),
            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_pg.pol.mdl')
        ]
        for policy_mdl in policy_mdl_candidates:
            if os.path.exists(policy_mdl):
                self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
                logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
                break

    def load_from_pretrained(self, archive_file, model_file, filename):
        """Download/extract a pretrained archive (if needed) and load weights.

        Raises:
            Exception: when neither a local archive nor a model URL is given.
        """
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for PG Policy is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'save')
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)
        if not os.path.exists(os.path.join(model_dir, 'best_pg.pol.mdl')):
            archive = zipfile.ZipFile(archive_file, 'r')
            archive.extractall(model_dir)

        policy_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_pg.pol.mdl')
        if os.path.exists(policy_mdl):
            self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
            logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))

    @classmethod
    def from_pretrained(cls,
                        archive_file="",
                        model_file="https://convlab.blob.core.windows.net/convlab-2/pg_policy_multiwoz.zip"):
        """Alternate constructor: build a PG policy and load pretrained weights."""
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
            cfg = json.load(f)
        model = cls()
        model.load_from_pretrained(archive_file, model_file, cfg['load'])
        return model