def init_weights(
    main: nn.Sequential, critic_linear: nn.Linear
) -> Tuple[nn.Sequential, nn.Linear]:
    """
    Re-initialize the layers of ``main`` and ``critic_linear``.

    Every ``nn.Conv2d`` / ``nn.Linear`` yielded by ``main.modules()`` is passed
    through the ``init`` helper together with an orthogonal weight
    initializer, a zero-constant bias initializer and the ReLU gain.
    ``nn.Sequential`` containers are dropped and every other module is carried
    over untouched, so the rebuilt network is a flat ``nn.Sequential``.
    ``critic_linear`` goes through the same helper, but without an explicit
    gain argument.

    Returns:
        The rebuilt ``nn.Sequential`` and whatever ``init`` returned for
        ``critic_linear``.
    """

    def zero_bias(tensor):
        return nn.init.constant_(tensor, 0)

    relu_gain = nn.init.calculate_gain("relu")

    def init_hidden(layer):
        return init(layer, nn.init.orthogonal_, zero_bias, relu_gain)

    rebuilt: List[nn.Module] = []
    for layer in main.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            rebuilt.append(init_hidden(layer))
            continue
        if isinstance(layer, nn.Sequential):
            # Containers (``main`` itself included) are flattened away.
            continue
        rebuilt.append(layer)

    new_main = nn.Sequential(*rebuilt)
    new_critic_linear = init(critic_linear, nn.init.orthogonal_, zero_bias)
    return new_main, new_critic_linear
def init_weights(
    actor: nn.Sequential, critic: nn.Sequential, critic_linear: nn.Linear
) -> Tuple[nn.Sequential, nn.Sequential, nn.Linear]:
    """
    Re-initialize the layers of ``actor``, ``critic`` and ``critic_linear``.

    Each ``nn.Linear`` yielded by ``actor.modules()`` / ``critic.modules()`` is
    passed through the ``init`` helper with an orthogonal weight initializer,
    a zero-constant bias initializer and a gain of ``np.sqrt(2)``.
    ``nn.Sequential`` containers are dropped and all other modules are kept
    as they are, giving a flat ``nn.Sequential`` per network.
    ``critic_linear`` is initialized with the same arguments.

    Returns:
        The rebuilt actor, the rebuilt critic, and whatever ``init`` returned
        for ``critic_linear``.
    """

    def zero_bias(tensor):
        return nn.init.constant_(tensor, 0)

    def init_layer(layer):
        return init(layer, nn.init.orthogonal_, zero_bias, np.sqrt(2))

    def rebuild(network: nn.Sequential) -> nn.Sequential:
        # Flatten ``network`` while re-initializing its linear layers.
        flat: List[nn.Module] = []
        for layer in network.modules():
            if isinstance(layer, nn.Linear):
                flat.append(init_layer(layer))
                continue
            if isinstance(layer, nn.Sequential):
                continue
            flat.append(layer)
        return nn.Sequential(*flat)

    # Actor first, then critic, then the critic head: same order as before so
    # that random-number consumption is unchanged.
    new_actor = rebuild(actor)
    new_critic = rebuild(critic)
    new_critic_linear = init_layer(critic_linear)
    return new_actor, new_critic, new_critic_linear
def __init__(self, num_inputs: int, num_outputs: int):
    """
    Build the single linear layer of the Bernoulli head.

    Args:
        num_inputs: Input feature count of the linear layer.
        num_outputs: Output feature count of the linear layer.
    """
    super(Bernoulli, self).__init__()

    def zero_bias(tensor):
        return nn.init.constant_(tensor, 0)

    # The ``init`` helper receives an orthogonal weight initializer and a
    # zero-constant bias initializer; no gain is passed.
    projection = nn.Linear(num_inputs, num_outputs)
    self.linear = init(projection, nn.init.orthogonal_, zero_bias)
def __init__(self, num_inputs: int, num_outputs: int):
    """
    Build the mean layer and the ``logstd`` bias module of the head.

    Args:
        num_inputs: Input feature count of the mean layer.
        num_outputs: Output feature count of the mean layer, and the length
            of the zero tensor handed to ``AddBias``.
    """
    super(DiagGaussian, self).__init__()

    def zero_bias(tensor):
        return nn.init.constant_(tensor, 0)

    # The ``init`` helper receives an orthogonal weight initializer and a
    # zero-constant bias initializer; no gain is passed.
    mean_layer = nn.Linear(num_inputs, num_outputs)
    self.fc_mean = init(mean_layer, nn.init.orthogonal_, zero_bias)

    initial_logstd = torch.zeros(num_outputs)
    self.logstd = AddBias(initial_logstd)
def __init__(self, num_inputs: int, num_outputs: int):
    """
    Build the single linear layer of the Categorical head.

    Args:
        num_inputs: Input feature count of the linear layer.
        num_outputs: Output feature count of the linear layer.
    """
    super(Categorical, self).__init__()

    def zero_bias(tensor):
        return nn.init.constant_(tensor, 0)

    # The ``init`` helper receives an orthogonal weight initializer, a
    # zero-constant bias initializer and ``gain=0.01``.
    projection = nn.Linear(num_inputs, num_outputs)
    self.linear = init(
        projection,
        nn.init.orthogonal_,
        zero_bias,
        gain=0.01,
    )
def __init__(self, num_inputs: int, num_outputs_list: List[int]):
    """
    Build one Categorical and one linear layer per entry of
    ``num_outputs_list``.

    Args:
        num_inputs: Input feature count shared by every layer.
        num_outputs_list: Output feature count for each layer, in order.
    """
    super(CategoricalProduct, self).__init__()
    self.num_inputs = num_inputs
    self.num_outputs_list = num_outputs_list

    # NOTE(review): this is a plain Python list, not an ``nn.ModuleList``.
    # If this class is an ``nn.Module``, the parameters of these
    # ``Categorical`` objects are not registered on it — confirm intended.
    self.distributions = []
    for size in num_outputs_list:
        self.distributions.append(Categorical(num_inputs, size))

    def zero_bias(tensor):
        return nn.init.constant_(tensor, 0)

    # Each layer is created and then immediately handed to ``init`` with an
    # orthogonal weight initializer, a zero-constant bias initializer and
    # ``gain=0.01``, one layer at a time.
    heads = nn.ModuleList()
    for size in num_outputs_list:
        layer = nn.Linear(num_inputs, size)
        heads.append(init(layer, nn.init.orthogonal_, zero_bias, gain=0.01))
    self.linears = heads