def __init__(self, num_inputs, obs_space, recurrent=False, hidden_size=32):
        # obs_space is accepted for interface compatibility; it is not used in
        # this constructor.
        super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size)

        # With a recurrent base, the MLP operates on the recurrent hidden
        # state rather than on the raw observation.
        if recurrent:
            num_inputs = hidden_size

        # Orthogonal weight initialization with gain sqrt(2) and zero bias.
        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # Separate actor and critic towers. The hidden layer sizes are
        # hard-coded here rather than taken from hidden_size.
        self.actor = MLP(num_inputs, 32, [16, 32, 64])
        self.critic = MLP(num_inputs, 32, [16, 32, 64])

        # Scalar value head on top of the critic's 32-dimensional output.
        self.critic_linear = init_(nn.Linear(32, 1))

        self.train()
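
# NOTE: sketch only, not this repo's implementation. From the call sites in
# this file, the MLP helper is assumed to take
# (input_size, output_size, hidden_sizes) and build a plain fully connected
# stack; a minimal version under that assumption (the ReLU activation is a
# guess) could look like:
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_size, output_size, hidden_sizes):
        super().__init__()
        layers, last = [], input_size
        for size in hidden_sizes:
            layers += [nn.Linear(last, size), nn.ReLU()]
            last = size
        layers.append(nn.Linear(last, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)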

# A separate encoder class: convolutional front end for image observations.
def __init__(self, environment):
        super().__init__()

        # The observation shape is assumed to be height x width x channels
        # (HWC), so it is reordered into a channels/height/width record.
        observation_space = environment.observation_space.shape
        image_size = CHW(observation_space[2], observation_space[0],
                         observation_space[1])

        # Convolutional front end, followed by an MLP over the flattened
        # convolutional features.
        self.conv = SimpleConvNet(image_size.channels, 1, [], [3],
                                  {"USE_BATCH_NORM": True})
        self.mlp = MLP(
            self.conv.output_size((image_size.width, image_size.height)), 32,
            [32, 32])
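
# NOTE: sketch only. CHW is not defined in this excerpt; from the attribute
# accesses above it is assumed to be a simple (channels, height, width)
# record, e.g. a namedtuple:
from collections import namedtuple

CHW = namedtuple("CHW", ["channels", "height", "width"])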

# A separate encoder class: flat MLP over vector observations.
def __init__(self, environment):
        super().__init__()

        # Flatten the observation shape into a single input dimension.
        self.mlp = MLP(np.prod(environment.observation_space.shape), 32,
                       [16, 32, 64])

# create_empty factories from separate classes: each builds a fresh, untrained
# model on the default device.
def create_empty(self, environment, policy):
        # Ensemble of NUM_MODELS independent MLPs; each takes a 32-dimensional
        # feature vector concatenated with the encoded action.
        return datastructures.Ensemble([
            MLP(32 + get_action_space_size(environment.action_space), 32,
                [32, 32]) for _ in range(self.NUM_MODELS)
        ], environment).to(DefaultDevice.current())

def create_empty(self, environment, policy):
        # Single MLP mapping a 32-dimensional feature vector to 32 outputs.
        return MLP(32, 32, [32, 32]).to(DefaultDevice.current())

def create_empty(self, environment, policy):
        # MLP whose input is the encoded action alone.
        return MLP(get_action_space_size(environment.action_space), 32,
                   [32, 32]).to(DefaultDevice.current())
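
# NOTE: sketch only. get_action_space_size is not defined in this excerpt; it
# is assumed to return the width of the action encoding fed to the MLPs above:
# the number of actions for a discrete space, or the flattened dimensionality
# of a continuous (Box) space.
import numpy as np
from gym import spaces

def get_action_space_size(action_space):
    if isinstance(action_space, spaces.Discrete):
        return action_space.n
    return int(np.prod(action_space.shape))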