def make_critic_encoder(self, sdim, adim, hidden_dim):
        """Build the encoder used by the critic for (observation, action) inputs.

        For ``self.base_policy_type == 'mlp'`` a small subclass of the base
        policy is created: it flattens the observation (all dims after the
        batch dim), concatenates the action, and applies the base's ``fc1``
        followed by ``self.nonlin``. For any other base policy type the base
        class is instantiated directly with ``cat_end=adim``.

        Args:
            sdim: Observation shape; on the ``'mlp'`` path it must be an
                iterable of ints whose product is the flattened size.
            adim: Action dimension appended to the observation features.
            hidden_dim: Hidden size forwarded to the base policy.

        Returns:
            The constructed encoder module.
        """
        PolicyBase = make_base_policy(self.base_policy_type)
        if self.base_policy_type == 'mlp':

            class MlpCritEncoder(PolicyBase):
                def __init__(self,
                             sdim_h,
                             adim_h,
                             hidden_dim,
                             nonlin=F.leaky_relu):
                    # The base takes an input-shape sequence; the MLP input is
                    # the flattened observation followed by the action.
                    dim_cat = [sdim_h + adim_h]
                    # Forward the caller's activation instead of ignoring it.
                    super(MlpCritEncoder, self).__init__(dim_cat,
                                                         hidden_dim,
                                                         nonlin=nonlin)

                def forward(self, o, a):
                    # Flatten everything after the batch dimension, then
                    # append the action along the feature axis.
                    o = o.flatten(1)
                    x = torch.cat([o, a], dim=1)
                    x = self.fc1(x)
                    x = self.nonlin(x)
                    return x

            # Number of features in the flattened observation.
            sdim_h = 1
            for d in sdim:
                sdim_h *= d
            base = MlpCritEncoder(sdim_h,
                                  adim,
                                  hidden_dim,
                                  nonlin=F.leaky_relu)
        else:
            base = PolicyBase(sdim, hidden_dim=hidden_dim, cat_end=adim)
        return base
Example #2
0
 def make_state_encoder(self, observation_space):
     """Build the state encoder and store it on ``self.enc``.

     Only ``spaces.Box`` observation spaces are supported. The encoder is the
     base policy class for ``self.base_policy_type``, built from the box
     shape with ``self.hid_size`` and leaky-ReLU, and converted to double
     precision with ``.double()``.

     Args:
         observation_space: The observation space to encode.

     Raises:
         NotImplementedError: For tuple observation spaces and for any other
             space type that is not a ``spaces.Box``.
     """
     if isinstance(observation_space, tuple):
         raise NotImplementedError("tuple observation spaces are not supported")
     elif isinstance(observation_space, spaces.Box):
         BasePolicy = make_base_policy(self.base_policy_type)
         self.enc = BasePolicy(observation_space.shape,
                               self.hid_size,
                               nonlin=F.leaky_relu).double()
     else:
         # Fail loudly rather than silently leaving self.enc unset.
         raise NotImplementedError(
             "unsupported observation space: {!r}".format(observation_space))
Example #3
0
def make_policy(policy_type):
    """Return an ``ActorPolicy`` class built on the base policy for ``policy_type``.

    ``ActorPolicy(num_in_pol, num_out_pol, hidden_dim=..., **kwargs)`` runs
    the base network, maps its output to ``num_out_pol`` logits and turns
    them into a one-hot action (sampled or greedy), optionally returning
    extra quantities.
    """
    BasePolicy = make_base_policy(policy_type)

    class ActorBase(BasePolicy):
        """Base network followed by a linear layer producing action logits."""

        def __init__(self, *args, nonlin=F.leaky_relu, **kwargs):
            # Exactly two positional args are expected: input and output sizes.
            (num_in_pol, num_out_pol) = args
            self.nonlin = nonlin
            # ``onehot_dim`` is accepted for interface compatibility but is not
            # forwarded to the base policy; tolerate its absence.
            kwargs.pop("onehot_dim", None)
            super(ActorBase, self).__init__(num_in_pol, nonlin=nonlin, **kwargs)
            # hidden_dim must be passed as a keyword argument.
            h_dim = kwargs["hidden_dim"]
            # NOTE(review): start_layer is never used in forward(); kept so the
            # parameter set stays unchanged.
            self.start_layer = nn.Linear(h_dim, h_dim)
            self.out_layer = nn.Linear(h_dim, num_out_pol)

        def forward(self, obs):
            out1 = super(ActorBase, self).forward(obs.float())
            out2 = self.out_layer(out1)
            return out2

    class ActorPolicy(ActorBase):
        def __init__(self, *args, **kwargs):
            super(ActorPolicy, self).__init__(*args, **kwargs)

        def forward(self, obs, sample=True, return_all_probs=False,
                    return_log_pi=False, regularize=False,
                    return_entropy=False):
            """Choose an action for ``obs``.

            Returns the action alone, or ``[act, *extras]`` when any of the
            ``return_*``/``regularize`` flags is set, with extras appended in
            the order: probs, log-prob of the chosen action, ``[mean(out**2)]``,
            entropy.
            """
            out = super(ActorPolicy, self).forward(obs)
            probs = F.softmax(out, dim=1)
            on_gpu = next(self.parameters()).is_cuda
            if sample:
                int_act, act = categorical_sample(probs, use_cuda=on_gpu)
            else:
                act = onehot_from_logits(probs)
                # Index of the chosen action so return_log_pi also works when
                # not sampling.
                int_act = act.max(1, keepdim=True)[1]
            rets = [act]
            if return_log_pi or return_entropy:
                log_probs = F.log_softmax(out, dim=1)
            if return_all_probs:
                rets.append(probs)
            if return_log_pi:
                # log probability of the selected action
                rets.append(log_probs.gather(1, int_act))
            if regularize:
                rets.append([(out**2).mean()])
            if return_entropy:
                rets.append(-(log_probs * probs).sum(1).mean())
            if len(rets) == 1:
                return rets[0]
            return rets

    return ActorPolicy
def make_value_net(base_policy_type, double_obs_space=False, recurrent=False):
    """Build a ``Critic`` class (scalar value network) for ``base_policy_type``.

    Args:
        base_policy_type: Name passed to ``make_base_policy``. When it contains
            ``"mlp"``, the MLP-style constructor call is used for double
            observation spaces.
        double_obs_space: If True, the critic's ``observation_space`` is a pair.
            The first element's ``.shape`` feeds the base encoder, and the
            second element's ``.shape[0]`` is passed as an extra input size.
        recurrent: If True, an ``LSTMCell`` is applied after the encoder.

    Returns:
        The ``Critic`` class, a subclass of the base policy class.
    """
    encoder_cls = make_base_policy(base_policy_type, double_obs_space)

    class Critic(encoder_cls):
        """Encoder -> (LSTMCell | Linear) -> 1-unit output, with its own Adam optimizer."""

        def __init__(self, observation_space, hidden_dim, lr=0.001, nonlin=F.leaky_relu):
            if not double_obs_space:
                super().__init__(observation_space, hidden_dim, nonlin=nonlin)
            else:
                main_shape = observation_space[0].shape
                extra_dim = observation_space[1].shape[0]
                if "mlp" in base_policy_type:
                    # NOTE(review): the actor factories pass
                    # (shape, hidden_dim, extra_dim) here; confirm which order
                    # the base policy actually expects.
                    super().__init__(main_shape, extra_dim, hidden_dim=hidden_dim, nonlin=nonlin)
                else:
                    super().__init__(main_shape, hidden_dim, nonlin=nonlin, cat_end=extra_dim)
            self.recurrent = recurrent
            self.nonlin = nonlin
            # Layers are created in a fixed order; reordering would change the
            # random initialisation sequence.
            self.fc_mid = nn.Linear(self.hidden_dim, hidden_dim)  # unused when recurrent
            self.fc_out = nn.Linear(hidden_dim, 1)
            if recurrent:
                self.lstm = nn.LSTMCell(self.hidden_dim, hidden_dim)
            self.optimizer = optim.Adam(self.parameters(), lr=lr)

        def forward(self, x, hid_cell_state=None):
            """Return ``(value, (hx, cx))``; ``hx``/``cx`` are None when not recurrent."""
            features = super().forward(x)
            if not self.recurrent:
                hidden = self.nonlin(self.fc_mid(features))
                return self.fc_out(hidden), (None, None)
            h_next, c_next = self.lstm(features, hid_cell_state)
            # Snapshot the recurrent state before applying the activation.
            state = (h_next.clone(), c_next.clone())
            hidden = self.nonlin(h_next)
            return self.fc_out(hidden), state

        def update(self, loss):
            """Backpropagate ``loss`` and take one optimizer step."""
            optimizer = self.optimizer
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return Critic
def make_actor_net(base_policy_type, double_obs_space=False, recurrent=False, blocking=False):
    """Build an ``Actor`` class (policy network) for ``base_policy_type``.

    The network is: base encoder -> (LSTMCell | Linear + nonlin) -> linear
    action logits, plus an optional sigmoid "blocking" head.

    Args:
        base_policy_type: Name passed to ``make_base_policy``.
        double_obs_space: If True, ``observation_space`` is a pair (main
            space, extra vector space); see ``Actor.__init__``.
        recurrent: Use an ``LSTMCell`` instead of a linear middle layer.
        blocking: Add a one-unit sigmoid head returned as ``x_block``.

    Returns:
        The ``Actor`` class, a subclass of the base policy class.
    """
    BasePolicy = make_base_policy(base_policy_type, double_obs_space)

    class Actor(BasePolicy):
        def __init__(self, observation_space, hidden_dim, action_dim, lr=0.001, nonlin=F.leaky_relu):
            if double_obs_space:
                dim1_shape = observation_space[0].shape
                dim2 = observation_space[1].shape[0]
                if "mlp" in base_policy_type:
                    # NOTE(review): make_value_net passes (shape, dim2,
                    # hidden_dim=...) for the same case; confirm the base order.
                    super(Actor, self).__init__(dim1_shape, hidden_dim, dim2, nonlin=nonlin)
                else:
                    super(Actor, self).__init__(dim1_shape, hidden_dim, nonlin=nonlin, cat_end=dim2)
            else:
                super(Actor, self).__init__(observation_space, hidden_dim, nonlin=nonlin)
            self.recurrent = recurrent
            self.blocking = blocking
            self.nonlin = nonlin

            # Layer creation order is kept fixed (affects random initialisation).
            self.fc_out = nn.Linear(hidden_dim, action_dim)
            if recurrent:
                self.lstm = nn.LSTMCell(self.hidden_dim, hidden_dim)
            else:
                self.fc_mid = nn.Linear(self.hidden_dim, hidden_dim)

            if self.blocking:
                self.fc_block_mid = nn.Linear(hidden_dim, hidden_dim)
                self.fc_block_out = nn.Linear(hidden_dim, 1)
            self.optimizer = optim.Adam(self.parameters(), lr=lr)

        def forward(self, x, hid_cell_state=None):
            """Return ``(action_logits, hx, cx, x_block)``.

            ``hx``/``cx`` are None when not recurrent; ``x_block`` is None
            when ``blocking`` is False.
            """
            x = super(Actor, self).forward(x)
            hx, cx = None, None
            if self.recurrent:
                x, cx = self.lstm(x, hid_cell_state)
                hx = x
                # Snapshot the recurrent state before the activation is applied.
                hx, cx = hx.clone(), cx.clone()
                x = self.nonlin(x)
            else:
                x = self.nonlin(self.fc_mid(x))
            x_act = self.fc_out(x)
            if self.blocking:
                x_block = self.nonlin(self.fc_block_mid(x))
                # softmax removed since it is included in the loss function
                x_block = torch.sigmoid(self.fc_block_out(x_block))
            else:
                x_block = None
            return (x_act, hx, cx, x_block)

        def take_action(self, obs, greedy=False, hid_cell_state=None, valid_act=None):
            """Run the network and pick an action index.

            Args:
                obs: Network input.
                greedy: Take the argmax action instead of sampling.
                hid_cell_state: Optional LSTM ``(h, c)`` state.
                valid_act: Optional iterable of allowed action indices; only
                    supported for a batch of size 1 and only when sampling.

            Returns:
                ``(logits, action_index, (hx, cx), x_block)`` where
                ``action_index`` keeps a trailing dimension of size 1.
            """
            (a_probs, hx, cx, x_block) = self.forward(obs, hid_cell_state)
            if greedy:
                a_select = torch.argmax(F.softmax(a_probs, dim=-1), dim=-1, keepdim=True)
            else:
                # Detached: the sampling distribution needs no gradient.
                a_probs_new = F.softmax(a_probs.detach(), dim=-1)
                if valid_act is not None:
                    assert a_probs.size(0) == 1
                    non_valid = set(range(a_probs.size(-1))) - set(valid_act)
                    if non_valid:
                        idx = torch.tensor(sorted(non_valid), device=a_probs_new.device)
                        a_probs_new[0, idx] = 0
                if a_probs_new.sum() == 0:
                    # Nothing left to sample from; fall back to the unmasked
                    # distribution.
                    print("a_prob_new is zero")
                    a_probs_new = F.softmax(a_probs.detach(), dim=-1)
                a_select = torch.multinomial(a_probs_new, 1)
            return a_probs, a_select, (hx, cx), x_block

        def loss(self, a_prob, a_taken, adv, entropy_coeff):
            """Policy-gradient loss with an entropy bonus, averaged over the batch.

            Args:
                a_prob: Action probabilities (given to ``Categorical`` as
                    ``probs``), actions along the last dim.
                a_taken: Index tensor of taken actions, used with ``gather``.
                adv: Advantage weights multiplied with the log-probabilities.
                entropy_coeff: Scale of the entropy bonus.
            """
            a_prob_select = torch.gather(a_prob, dim=-1, index=a_taken)
            entropy = torch.distributions.Categorical(a_prob).entropy() * entropy_coeff
            loss = -torch.log(a_prob_select) * adv
            loss -= entropy.reshape(-1, 1)
            loss = loss.mean()
            return loss

        def update(self, loss):
            """Backpropagate ``loss`` and take one optimizer step."""
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    return Actor
def make_primal_net(base_policy_type, double_obs_space=False, recurrent=True, blocking=False):
    """Build an actor-critic ``Actor`` class with shared trunk and separate heads.

    The trunk is: base encoder -> residual block of two linear layers ->
    LSTMCell or Linear + nonlin. On top of it sit an action-logit head, a
    scalar value head, and an optional sigmoid "blocking" head.

    Args:
        base_policy_type: Name passed to ``make_base_policy``.
        double_obs_space: If True, ``observation_space`` is a pair (main
            space, extra vector space).
        recurrent: Use an ``LSTMCell`` after the residual block.
        blocking: Add a one-unit sigmoid head returned as ``x_block``.

    Returns:
        The ``Actor`` class, a subclass of the base policy class.
    """
    print("Recurrent is: {}".format(recurrent))
    BasePolicy = make_base_policy(base_policy_type, double_obs_space)

    class Actor(BasePolicy):
        def __init__(self, observation_space, hidden_dim, action_dim, lr=0.001, nonlin=F.relu):
            if double_obs_space:
                dim1_shape = observation_space[0].shape
                dim2 = observation_space[1].shape[0]
                if "mlp" in base_policy_type:
                    super(Actor, self).__init__(dim1_shape, hidden_dim, dim2, nonlin=nonlin)
                else:
                    super(Actor, self).__init__(dim1_shape, hidden_dim, nonlin=nonlin, cat_end=dim2)
            else:
                super(Actor, self).__init__(observation_space, hidden_dim, nonlin=nonlin)
            self.recurrent = recurrent
            self.blocking = blocking
            self.nonlin = nonlin

            # Layer creation order is kept fixed (affects random initialisation).
            self.act_mid = nn.Linear(hidden_dim, hidden_dim)
            self.act_out = nn.Linear(hidden_dim, action_dim)

            self.value_mid = nn.Linear(hidden_dim, hidden_dim)
            self.value_out = nn.Linear(hidden_dim, 1)

            if recurrent:
                self.lstm = nn.LSTMCell(self.hidden_dim, hidden_dim)
            else:
                self.fc_mid = nn.Linear(self.hidden_dim, hidden_dim)

            # Residual block before the LSTM. The residual add in forward()
            # assumes the base's self.hidden_dim equals hidden_dim.
            self.fc_mid1 = nn.Linear(hidden_dim, hidden_dim)
            self.fc_mid2 = nn.Linear(hidden_dim, hidden_dim)

            if self.blocking:
                self.fc_block_mid = nn.Linear(hidden_dim, hidden_dim)
                self.fc_block_out = nn.Linear(hidden_dim, 1)
            self.optimizer = optim.Adam(self.parameters(), lr=lr)

        def forward(self, x, hid_cell_state=None):
            """Return ``(action_logits, hx, cx, x_block, value)``.

            ``hx``/``cx`` are None when not recurrent; ``x_block`` is None
            when ``blocking`` is False.
            """
            x_dec = super(Actor, self).forward(x)

            # Residual block (always uses ReLU, independent of self.nonlin).
            x_mid = self.fc_mid1(x_dec)
            x_mid = F.relu(x_mid)
            x_mid = self.fc_mid2(x_mid)
            x_mid = x_mid + x_dec
            x_mid = F.relu(x_mid)

            hx, cx = None, None
            if self.recurrent:
                x, cx = self.lstm(x_mid, hid_cell_state)
                hx = x
                hx, cx = hx.clone(), cx.clone()
            else:
                x = self.nonlin(self.fc_mid(x_mid))

            # NOTE(review): neither head applies a nonlinearity between its mid
            # and out layers, so each is effectively one affine map; confirm.
            x_act = self.act_mid(x)
            x_act = self.act_out(x_act)

            # Chain value_mid -> value_out (value_mid's output used to be
            # discarded).
            x_val = self.value_mid(x)
            x_val = self.value_out(x_val)
            if self.blocking:
                x_block = self.nonlin(self.fc_block_mid(x))
                # softmax removed since it is included in the loss function
                x_block = torch.sigmoid(self.fc_block_out(x_block))
            else:
                x_block = None
            return (x_act, hx, cx, x_block, x_val)

        def take_action(self, obs, greedy=False, hid_cell_state=None, valid_act=None):
            """Run the network and pick an action index.

            Args:
                obs: Network input.
                greedy: Take the argmax action instead of sampling.
                hid_cell_state: Optional LSTM ``(h, c)`` state.
                valid_act: Optional iterable of allowed action indices; only
                    supported for a batch of size 1 and only when sampling.

            Returns:
                ``(logits, action_index, (hx, cx), x_block, value)``.
            """
            (a_probs, hx, cx, x_block, x_val) = self.forward(obs, hid_cell_state)

            if greedy:
                a_select = torch.argmax(F.softmax(a_probs, dim=-1), dim=-1, keepdim=True)
            else:
                # Detached: the sampling distribution needs no gradient.
                a_probs_new = F.softmax(a_probs.detach(), dim=-1)
                # Keep valid actions strictly positive so they can be sampled.
                a_probs_new = torch.clamp(a_probs_new, 1e-15, 1.0)
                if valid_act is not None:
                    assert a_probs.size(0) == 1
                    non_valid = set(range(a_probs.size(-1))) - set(valid_act)
                    if non_valid:
                        idx = torch.tensor(sorted(non_valid), device=a_probs_new.device)
                        a_probs_new[0, idx] = 0
                if a_probs_new.sum() == 0:
                    # Every action was masked out; fall back to the unmasked
                    # distribution.
                    print("a_prob_new is zero")
                    a_probs_new = F.softmax(a_probs.detach(), dim=-1)
                a_select = torch.multinomial(a_probs_new, 1)
            return a_probs, a_select, (hx, cx), x_block, x_val

    return Actor
 def make_state_encoder(self, sdim, hidden_dim):
     """Return a new state encoder for inputs described by ``sdim``.

     The encoder class comes from ``make_base_policy(self.base_policy_type)``
     and is constructed with ``sdim``, ``hidden_dim`` and leaky-ReLU.
     """
     encoder_cls = make_base_policy(self.base_policy_type)
     return encoder_cls(sdim, hidden_dim, nonlin=F.leaky_relu)