class VAECritic(nn.Module):
    '''
    Critic network that feeds observations through a VAE and then an MLP.

    Network Architecture: (input) -> VAE -> MLP -> (one value per sample)
    '''

    def __init__(self, vae_weights_path, obs_dim, conv_layer_sizes, hidden_sizes, activation):
        '''
        Build the VAE, load its weights and create the value MLP.

        Args:
            vae_weights_path (Str): Path handed to ``VAE.load_weights``
            obs_dim (tuple): observation dimension of the environment in the form of (C, H, W);
                accepted for interface compatibility, not read by this constructor
            conv_layer_sizes: accepted for interface compatibility, not read by this constructor
            hidden_sizes (list): list of number of neurons in each hidden layer of the MLP
            activation (nn.modules.activation): Activation function forwarded to ``mlp``
        '''
        super().__init__()

        vae = VAE()
        vae.load_weights(vae_weights_path)
        self.v_vae = vae

        # MLP input width is the VAE latent size; the single output unit is the value estimate.
        layer_widths = [vae.latent_dim, *hidden_sizes, 1]
        self.v_mlp = mlp(layer_widths, activation)

    def forward(self, obs):
        '''
        Forward propagation for critic network

        Args:
            obs (Tensor [n, obs_dim]): batch of observation from environment

        Return:
            MLP output with its trailing dimension squeezed away
        '''
        vae_out = self.v_vae(obs)
        value = self.v_mlp(vae_out)
        # Drop the trailing size-1 dimension so the value has the right shape.
        return value.squeeze(-1)

    def dataparallel(self, ngpu):
        '''
        Spread the critic over devices ``0 .. ngpu - 1``.

        Delegates to the VAE's own ``dataparallel`` and replaces ``self.v_mlp``
        with an ``nn.DataParallel`` wrapper around it.

        Args:
            ngpu (int): number of gpus to use
        '''
        print(f"Critic network using {ngpu} gpus, gpu id: {list(range(ngpu))}")
        self.v_vae.dataparallel(ngpu)
        device_ids = list(range(ngpu))
        self.v_mlp = nn.DataParallel(self.v_mlp, device_ids)
class VAECategoricalActor(Actor):
    '''
    Actor network for discrete outputs built from a VAE followed by an MLP.

    Network Architecture: (input) -> VAE -> MLP -> (logits) -> Categorical
    '''

    def __init__(self, vae_weights_path, obs_dim, act_dim, hidden_sizes, activation):
        '''
        Build the VAE, load its weights and create the logits MLP.

        The original author notes that the input is assumed to be in the
        shape (3, 128, 128).

        Args:
            vae_weights_path (Str): Path handed to ``VAE.load_weights``
            obs_dim (tuple): observation dimension of the environment in the form of (C, H, W);
                accepted for interface compatibility, not read by this constructor
            act_dim (int): action dimension of the environment (width of the MLP output)
            hidden_sizes (list): list of number of neurons in each layer of MLP after output from VAE
            activation (nn.modules.activation): Activation function forwarded to ``mlp``
        '''
        super().__init__()

        vae = VAE()
        vae.load_weights(vae_weights_path)
        self.logits_vae = vae

        layer_widths = [vae.latent_dim, *hidden_sizes, act_dim]
        # NOTE(review): the output activation passed here is nn.Tanh, so the logits
        # are squashed before they reach Categorical -- confirm this is intended.
        self.logits_mlp = mlp(layer_widths, activation, output_activation=nn.Tanh)

        # Initialise the actor's final layer weights to be 1/100 of their initial value.
        # Index -2 is used because the last entry is presumably the output
        # activation module -- verify against the ``mlp`` helper's layout.
        final_layer = self.logits_mlp[-2]
        final_layer.weight.data /= 100

    def _distribution(self, obs):
        '''
        Forward propagation for actor network

        Args:
            obs (Tensor [n, obs_dim]): batch of observation from environment

        Return:
            Categorical distribution built from the MLP output used as logits
        '''
        vae_out = self.logits_vae(obs)
        action_logits = self.logits_mlp(vae_out)
        return Categorical(logits=action_logits)

    def _log_prob_from_distribution(self, pi, act):
        '''
        Log probability of selecting ``act`` under the distribution ``pi``.

        Args:
            pi: distribution from _distribution() function
            act: action(s) whose log probability is evaluated

        Return:
            result of ``pi.log_prob(act)``
        '''
        return pi.log_prob(act)

    def dataparallel(self, ngpu):
        '''
        Spread the actor over devices ``0 .. ngpu - 1``.

        Delegates to the VAE's own ``dataparallel`` and replaces
        ``self.logits_mlp`` with an ``nn.DataParallel`` wrapper around it.

        Args:
            ngpu (int): number of gpus to use
        '''
        print(f"Actor network using {ngpu} gpus, gpu id: {list(range(ngpu))}")
        self.logits_vae.dataparallel(ngpu)
        device_ids = list(range(ngpu))
        self.logits_mlp = nn.DataParallel(self.logits_mlp, device_ids)
class VAEActor(nn.Module):
    '''
    Actor network built from a VAE followed by an MLP.

    Network Architecture: (input) -> VAE -> MLP -> (output) * act_limit
    '''

    def __init__(self, vae_weights_path, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        '''
        Build the VAE, load its weights and create the policy MLP.

        The original author notes that the VAE is pretrained on observation
        images and that the observation space is assumed to be in the
        shape (3, 128, 128).

        Args:
            vae_weights_path (Str): Path handed to ``VAE.load_weights``
            obs_dim (tuple): observation dimension of the environment in the form of (C, H, W);
                accepted for interface compatibility, not read by this constructor
            act_dim (int): action dimension of the environment (width of the MLP output)
            hidden_sizes (list): list of number of neurons in each layer of MLP after output from VAE
            activation (nn.modules.activation): Activation function forwarded to ``mlp``
            act_limit (float): the greatest magnitude possible for the action in the environment
        '''
        super().__init__()

        vae = VAE()
        vae.load_weights(vae_weights_path)
        self.pi_vae = vae

        layer_widths = [vae.latent_dim, *hidden_sizes, act_dim]
        # nn.Tanh is requested as the output activation; forward() then scales by act_limit.
        self.pi_mlp = mlp(layer_widths, activation, output_activation=nn.Tanh)

        self.act_limit = act_limit

    def forward(self, obs):
        '''
        Forward propagation for actor network

        Args:
            obs (Tensor [n, obs_dim]): batch of observation from environment

        Return:
            output of actor network * act_limit
        '''
        vae_out = self.pi_vae(obs)
        raw_action = self.pi_mlp(vae_out)
        return raw_action * self.act_limit

    def dataparallel(self, ngpu):
        '''
        Spread the actor over devices ``0 .. ngpu - 1``.

        Delegates to the VAE's own ``dataparallel`` and replaces ``self.pi_mlp``
        with an ``nn.DataParallel`` wrapper around it.

        Args:
            ngpu (int): number of gpus to use
        '''
        print(f"Actor Network using {ngpu} gpus, gpu id: {list(range(ngpu))}")
        self.pi_vae.dataparallel(ngpu)
        device_ids = list(range(ngpu))
        self.pi_mlp = nn.DataParallel(self.pi_mlp, device_ids)