示例#1
0
    def calc_q_loss(self, batch):
        '''
        Compute the Q value loss with next actions selected by online_net and evaluated by eval_net.

        Reads 'states', 'next_states', 'actions', 'rewards' and 'dones' from batch.
        When the memory's class name contains 'Prioritized', its priorities are updated
        with the absolute difference between targets and predictions before returning.
        Returns the value of self.net.loss_fn(predictions, targets).
        '''
        obs, next_obs = batch['states'], batch['next_states']
        all_q = self.net(obs)
        with torch.no_grad():
            # online_net picks the next actions, eval_net provides the values for those picks
            next_q_online = self.online_net(next_obs)
            next_q_eval = self.eval_net(next_obs)

        # Q value of the action stored in the batch for each row
        taken = batch['actions'].long().unsqueeze(-1)
        chosen_q = all_q.gather(-1, taken).squeeze(-1)

        # value of the online_net's greedy next action, read from eval_net
        greedy_next = next_q_online.argmax(dim=-1, keepdim=True)
        next_value = next_q_eval.gather(-1, greedy_next).squeeze(-1)

        # target = reward + gamma * next value, with the next value zeroed where done
        not_done = 1 - batch['dones']
        td_targets = batch['rewards'] + self.gamma * not_done * next_value
        logger.debug(
            f'act_q_preds: {chosen_q}\nmax_q_targets: {td_targets}')
        q_loss = self.net.loss_fn(chosen_q, td_targets)

        # TODO use the same loss_fn but do not reduce yet
        if 'Prioritized' not in util.get_class_name(self.body.memory):
            return q_loss
        # PER: feed the absolute errors back to the memory as new priorities
        abs_errors = (td_targets - chosen_q.detach()).abs().cpu().numpy()
        self.body.memory.update_priorities(abs_errors)
        return q_loss
示例#2
0
def init_global_nets(algorithm):
    '''
    Initialize global_nets for Hogwild using an identical instance of an algorithm from an isolated Session
    in spec.meta.distributed, specify either:
    - 'shared': global network parameter is shared all the time. In this mode, algorithm local network will be replaced directly by global_net via overriding by identical attribute name
    - 'synced': global network parameter is periodically synced to local network after each gradient push. In this mode, algorithm will keep a separate reference to `global_{net}` for each of its network

    Only nets that have a matching `*optim` attribute on the algorithm are included.
    If an optim's class name contains 'Global', the optim and its `*lr_scheduler` are
    added under their own attribute names as well.

    @param algorithm: object exposing `agent.spec`, `net_names` and the net/optim/lr_scheduler attributes
    @returns {dict} global_nets: attribute name -> net, optim or lr_scheduler object
    '''
    dist_mode = algorithm.agent.spec['meta']['distributed']
    # include the offending value so a misconfigured spec is easy to diagnose
    assert dist_mode in ('shared', 'synced'), f'Unrecognized distributed mode: {dist_mode}'
    global_nets = {}
    for net_name in algorithm.net_names:
        optim_name = net_name.replace('net', 'optim')
        if not hasattr(algorithm, optim_name):
            # only for trainable network, i.e. has an optim
            continue
        g_net = getattr(algorithm, net_name)
        g_net.share_memory()  # make net global
        if dist_mode == 'shared':  # use the same name to override the local net
            global_nets[net_name] = g_net
        else:  # keep a separate reference for syncing
            global_nets[f'global_{net_name}'] = g_net
        # if optim is Global, set to override the local optim and its scheduler
        optim = getattr(algorithm, optim_name)
        if 'Global' in util.get_class_name(optim):
            optim.share_memory()  # make optim global
            global_nets[optim_name] = optim
            lr_scheduler_name = net_name.replace('net', 'lr_scheduler')
            lr_scheduler = getattr(algorithm, lr_scheduler_name)
            global_nets[lr_scheduler_name] = lr_scheduler
    logger.info(
        f'Initialized global_nets attr {list(global_nets.keys())} for Hogwild')
    return global_nets
示例#3
0
    def calc_q_loss(self, batch):
        '''
        Compute the Q value loss using a reward augmented by the output of self.fetch_irl_reward.

        Side effects: adds the mean of the first fetch_irl_reward output to self.reward_count,
        increments self.batch_count, and, when the memory's class name contains 'Prioritized',
        updates its priorities with the absolute errors between targets and predictions.
        Returns the value of self.net.loss_fn(predictions, targets).
        '''
        # fetch_irl_reward returns four reward variants; only the first two are used here
        irl_raw, irl_log, _irl_log_double, _irl_log_minus_one = self.fetch_irl_reward(
            batch)
        self.reward_count += irl_raw.mean().item()
        self.batch_count += 1

        # Reward definition: this is the place to switch the reward function.
        # Currently the second fetch_irl_reward output (the "log" variant), moved to CPU,
        # is added to the rewards stored in the batch.
        # Alternatives tried in earlier revisions: batch rewards alone, the raw variant plus
        # batch rewards, the log variant alone, the log-minus-one variant alone, the
        # log-double variant plus batch rewards, and the log variant plus only the
        # positive batch rewards.
        shaped_rewards = irl_log.to("cpu") + batch['rewards']

        obs, next_obs = batch['states'], batch['next_states']
        all_q = self.net(obs)
        with torch.no_grad():
            # online_net picks the next actions, eval_net provides the values for those picks
            next_q_online = self.online_net(next_obs)
            next_q_eval = self.eval_net(next_obs)

        # Q value of the action stored in the batch for each row
        taken = batch['actions'].long().unsqueeze(-1)
        chosen_q = all_q.gather(-1, taken).squeeze(-1)

        # value of the online_net's greedy next action, read from eval_net
        greedy_next = next_q_online.argmax(dim=-1, keepdim=True)
        next_value = next_q_eval.gather(-1, greedy_next).squeeze(-1)

        # target = shaped reward + gamma * next value, with the next value zeroed where done
        not_done = 1 - batch['dones']
        td_targets = shaped_rewards + self.gamma * not_done * next_value
        logger.debug(
            f'act_q_preds: {chosen_q}\nmax_q_targets: {td_targets}')
        q_loss = self.net.loss_fn(chosen_q, td_targets)

        # TODO use the same loss_fn but do not reduce yet
        if 'Prioritized' not in util.get_class_name(self.body.memory):
            return q_loss
        # PER: feed the absolute errors back to the memory as new priorities
        abs_errors = (td_targets - chosen_q.detach()).abs().cpu().numpy()
        self.body.memory.update_priorities(abs_errors)
        return q_loss
示例#4
0
def init_params(module, init_fn):
    '''
    Initialize module's weights using init_fn, and biases to 0.0

    - modules whose class name contains 'Net' are skipped (containers, not pytorch layers)
    - 'BatchNorm', 'Conv' and 'Linear' layers get init_fn on `weight` and 0.0 on `bias`;
      either is left alone when it is None (e.g. a layer built with bias=False or affine=False)
    - 'GRU' layers get init_fn on every parameter with 'weight' in its name and 0.0 on
      every parameter with 'bias' in its name
    - any other module is left untouched

    @param module: the module (layer) to initialize in place
    @param init_fn: callable applied in place to each weight tensor
    '''
    bias_init = 0.0
    classname = util.get_class_name(module)
    if 'Net' in classname:  # skip if it's a net, not pytorch layer
        return
    if any(k in classname for k in ('BatchNorm', 'Conv', 'Linear')):
        # weight/bias are None when the layer was created without them;
        # passing None to an init function would raise
        if module.weight is not None:
            init_fn(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, bias_init)
    elif 'GRU' in classname:
        for name, param in module.named_parameters():
            if 'weight' in name:
                init_fn(param)
            elif 'bias' in name:
                nn.init.constant_(param, bias_init)
示例#5
0
    def calc_q_loss(self, batch):
        '''
        Compute the Q value loss using predicted and target Q values from the appropriate networks

        Reads 'states', 'next_states', 'actions', 'rewards' and 'dones' from batch.
        When the memory's class name contains 'Prioritized', its priorities are updated
        with the absolute errors between targets and predictions before returning.
        Returns the value of self.net.loss_fn(act_q_preds, max_q_targets).
        '''
        states = batch['states']
        next_states = batch['next_states']
        q_preds = self.net(states)
        with torch.no_grad():
            next_q_preds = self.net(next_states)
        # Q value of the action stored in the batch; the action dim is squeezed away
        act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
        # Bellman equation: compute max_q_targets using reward and max estimated Q values (0 if no next_state)
        # keepdim=False drops the action dim so the targets line up with act_q_preds;
        # keeping a trailing singleton dim would make the targets broadcast against
        # the predictions in the loss and in the PER errors below
        max_next_q_preds, _ = next_q_preds.max(dim=-1, keepdim=False)
        max_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * max_next_q_preds
        logger.debug(f'act_q_preds: {act_q_preds}\nmax_q_targets: {max_q_targets}')
        q_loss = self.net.loss_fn(act_q_preds, max_q_targets)

        # TODO use the same loss_fn but do not reduce yet
        if 'Prioritized' in util.get_class_name(self.body.memory):  # PER
            # absolute errors become the new priorities
            errors = (max_q_targets - act_q_preds.detach()).abs().cpu().numpy()
            self.body.memory.update_priorities(errors)
        return q_loss
示例#6
0
文件: base.py 项目: temporaer/ConvLab
 def _is_discrete(self, action_space):
     '''Return True unless the class name of action_space is 'Box'.'''
     space_type = util.get_class_name(action_space)
     is_box = space_type == 'Box'
     return not is_box