Example #1
 def __init__(self, module, name=None):
     super(SplitBias, self).__init__()
     self.module = module
     self.add_bias = AddBias(module.bias.data)
     self.module.bias = None
     self.name = name
     self.module.split_bias = True
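
Every snippet on this page depends on an AddBias module that the page itself never shows. For context, here is a minimal sketch, assuming (as the call sites suggest) that AddBias stores the given tensor as a standalone nn.Parameter and simply adds it to its input; the widely copied pytorch-a2c-ppo-acktr utilities follow this pattern:

import torch
import torch.nn as nn

class AddBias(nn.Module):
    """Holds a bias vector as a standalone learnable parameter and
    adds it (with broadcasting) to the input."""

    def __init__(self, bias):
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        if x.dim() == 2:
            # (batch, features): broadcast over the batch dimension
            bias = self._bias.t().view(1, -1)
        else:
            # (batch, channels, H, W): broadcast over batch and space
            bias = self._bias.t().view(1, -1, 1, 1)
        return x + bias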
Example #2
    def __init__(self,
                 num_inputs,
                 num_actions,
                 recurrent=False,
                 hidden_size=64):
        super(MLPBase_av, self).__init__(recurrent, num_inputs, hidden_size)

        if recurrent:
            # Unsupported here; the original MLPBase would set
            # num_inputs = hidden_size for the recurrent path.
            raise NotImplementedError

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))

        self.actor1 = nn.Sequential(init_(nn.Linear(num_inputs, hidden_size)),
                                    nn.Tanh(),
                                    init_(nn.Linear(hidden_size, hidden_size)),
                                    nn.Tanh())

        self.critic1 = nn.Sequential(
            init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(),
            init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh())

        self.actor2 = nn.Sequential(init_(nn.Linear(num_inputs, hidden_size)),
                                    nn.Tanh(),
                                    init_(nn.Linear(hidden_size, hidden_size)),
                                    nn.Tanh())

        self.critic2 = nn.Sequential(
            init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(),
            init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh())

        self.critic_linear1 = init_(nn.Linear(hidden_size, 1))

        self.critic_linear2 = init_(nn.Linear(hidden_size, 1))

        # Action distribution
        init_dist_ = lambda m: init(m, nn.init.orthogonal_,
                                    lambda x: nn.init.constant_(x, 0))

        self.fc_mean1 = init_dist_(nn.Linear(hidden_size, num_actions))
        self.logstd1 = AddBias(torch.zeros(num_actions))

        self.fc_mean2 = init_dist_(nn.Linear(hidden_size, num_actions))
        self.logstd2 = AddBias(torch.zeros(num_actions))

        self.train()
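
The init(...) helper used throughout these examples is also external to this page. A plausible minimal version, matching its call sites here (a weight initializer applied with an optional gain, a bias initializer, and the module returned so the call can be inlined in nn.Sequential):

def init(module, weight_init, bias_init, gain=1):
    # e.g. weight_init = nn.init.orthogonal_, bias_init = zeros.
    # Example #6's init_zeros passes a weight lambda taking **kwargs,
    # consistent with gain being forwarded as a keyword argument.
    weight_init(module.weight.data, gain=gain)
    bias_init(module.bias.data)
    return module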
Example #3
    def __init__(self, num_inputs, num_outputs):
        super(DiagGaussian, self).__init__()

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0))

        self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
        self.logstd = AddBias(torch.zeros(num_outputs))
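
The matching forward pass is not shown. A sketch of the usual pattern, assuming these snippets follow it: feeding zeros through AddBias returns the learned bias itself, so logstd acts as a state-independent log standard deviation (the original code may wrap the result in a custom distribution class rather than torch.distributions.Normal):

    def forward(self, x):
        action_mean = self.fc_mean(x)
        # AddBias on zeros yields its learned bias: a state-independent
        # log standard deviation, broadcast to the batch shape.
        action_logstd = self.logstd(torch.zeros_like(action_mean))
        return torch.distributions.Normal(action_mean, action_logstd.exp())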
Example #4
    def __init__(self, num_inputs, num_outputs, init_log_std=0):
        super(DiagGaussian, self).__init__()

        # init_ = lambda m: init(m, nn.init.orthogonal_,
        #                        lambda x: nn.init.constant_(x, 0))

        self.fc_mean = nn.Linear(num_inputs, num_outputs)
        self.tanh = nn.Tanh()
        self.logstd = AddBias(torch.zeros(num_outputs) + init_log_std)
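
Relative to Example #3, this variant bounds the mean with tanh (useful for action spaces normalized to (-1, 1)) and seeds the log standard deviation at init_log_std rather than 0, while deliberately leaving fc_mean at PyTorch's default initialization (the orthogonal init_ is commented out). Its forward pass would differ from the sketch above only in passing the fc_mean output through self.tanh before building the distribution.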
Example #5
    def __init__(self,
                 obs_shape,
                 action_space,
                 base=None,
                 base_kwargs=None,
                 dimh=2):
        super(Policy, self).__init__()

        net = None

        dim_input = obs_shape[0]
        if len(obs_shape) == 1:
            net = [
                # actor base
                SpLinear(dim_input, dimh, actv_fn='tanh'),
                SpLinear(dimh, dimh, actv_fn='tanh'),
                # critic base
                SpLinear(dim_input, dimh, actv_fn='tanh'),
                SpLinear(dimh, dimh, actv_fn='tanh'),
                SpLinear(dimh, 1, actv_fn='none'),
            ]
        else:
            raise NotImplementedError

        if action_space.__class__.__name__ == "Discrete":
            self.action_type = 'discrete'
            num_outputs = action_space.n
            net.append(SpLinear(dimh, num_outputs, actv_fn='none',
                                init_type=1))
        elif action_space.__class__.__name__ == "Box":
            self.action_type = 'continuous'
            num_outputs = action_space.shape[0]
            net.append(SpLinear(dimh, num_outputs, actv_fn='none',
                                init_type=2))
            self.logstd = AddBias(torch.zeros(num_outputs))
        else:
            raise NotImplementedError

        self.net = nn.ModuleList(net)
        self.next_layer = {
            0: [1],
            1: [5],
            2: [3],
            3: [4],
        }
        '''
        self.next_layer = {
            0: [3],
            1: [2],
        }
        '''
        self.layers_to_split = list(self.next_layer.keys())
        self.n_elites = 64
Example #6
    def __init__(self, num_inputs, num_outputs, zll=False):
        super(DiagGaussian, self).__init__()

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0))

        init_zeros = lambda m: init(m,
                                    lambda x, **kwargs: nn.init.constant_(x, 0),
                                    lambda x: nn.init.constant_(x, 0))

        init_last_layer = init_zeros if zll else init_

        self.fc_mean = init_last_layer(nn.Linear(num_inputs, num_outputs))
        self.logstd = AddBias(torch.zeros(num_outputs))
Example #7
    def __init__(self, num_inputs, num_outputs, activation=None):
        super(DiagGaussian, self).__init__()

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0))

        if activation is None:
            self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
        elif activation == "tanh":
            self.fc_mean = nn.Sequential(
                init_(nn.Linear(num_inputs, num_outputs)), nn.Tanh())
        else:
            raise NotImplementedError
        self.logstd = AddBias(torch.zeros(num_outputs))
Example #8
File: kfac.py  Project: youngleox/nero
 def __init__(self, module):
     super(SplitBias, self).__init__()
     self.module = module
     self.add_bias = AddBias(module.bias.data)
     self.module.bias = None
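
SplitBias wraps a layer so that its bias becomes a separate AddBias step, which lets a K-FAC preconditioner treat the weight and the bias as independent factors. The forward pass is not shown on this page; a minimal sketch consistent with the constructor:

 def forward(self, input):
     # Run the now bias-free wrapped module, then re-add the bias as
     # its own module so K-FAC can factor weight and bias separately.
     x = self.module(input)
     x = self.add_bias(x)
     return x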
Example #9
    def __init__(self, num_outputs):
        super(DiagGaussianDist, self).__init__()

        self.logstd = AddBias(torch.zeros(num_outputs))