class StochasticActorCriticNN(nn.Module):
    """Actor-critic network with shared state features and a Gaussian policy.

    States are mapped to a 128-d feature vector, through either a small CNN or a
    two-layer MLP with batch norm. That vector feeds:

    * a critic, which scores a (state, action) pair via :meth:`evaluate`, or the
      state alone via :meth:`evaluateV`;
    * an actor, which builds a ``Normal`` distribution over actions and samples
      from it via :meth:`act`.

    Args:
        state_dim: Size of the flat state vector (replaced by
            ``CNN['input_size']`` when a CNN is used).
        action_dim: Size of the action vector.
        action_scaler: Multiplier applied to the ``tanh``-squashed action mean.
        CNN: Dict with keys ``'use_cnn'`` (bool) and ``'input_size'`` (int, used
            as the number of input channels of the first convolution). Defaults
            to ``{'use_cnn': False, 'input_size': 3}``.
        dueling: If True, Q(s, a) = A(s, a) + V(s). Otherwise Q is a linear layer
            over the concatenation of a 64-d V-path and a 64-d A-path embedding.
        HER: If True, the input width (or channel count) is doubled. Presumably
            the state is concatenated with a goal; TODO confirm with the caller.
    """

    def __init__(self, state_dim=3, action_dim=2, action_scaler=2.0,
                 CNN=None, dueling=True, HER=True):
        super().__init__()
        if CNN is None:
            # Build a fresh dict per instance; a dict literal default would be
            # shared by every instance constructed without an explicit CNN.
            CNN = {'use_cnn': False, 'input_size': 3}
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_scaler = action_scaler
        self.dueling = dueling
        # dictionary with:
        # - 'input_size': int
        # - 'use_cnn': bool
        self.CNN = CNN
        if self.CNN['use_cnn']:
            self.state_dim = self.CNN['input_size']
        self.HER = HER
        if self.HER:
            self.state_dim *= 2

        # Features (both paths end in a 128-d vector):
        if self.CNN['use_cnn']:
            self.conv1 = nn.Conv2d(self.state_dim, 16, kernel_size=5, stride=2)
            self.bn1 = nn.BatchNorm2d(16)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
            self.bn2 = nn.BatchNorm2d(32)
            self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
            self.bn3 = nn.BatchNorm2d(32)
            # 192 must equal the flattened conv3 output (32 * H' * W'), which
            # depends on the input resolution -- TODO confirm for your inputs.
            self.featx = nn.Linear(192, 128)
        else:
            self.fc1 = nn.Linear(self.state_dim, 512)
            self.bn1 = nn.BatchNorm1d(512)
            self.fc2 = nn.Linear(512, 256)
            self.bn2 = nn.BatchNorm1d(256)
            self.featx = nn.Linear(256, 128)
        self.featx.weight.data.uniform_(-EPS, EPS)

        # Critic network:
        ## state value path: a scalar V(s) when dueling, otherwise a 64-d
        ## embedding that is concatenated with the action path below.
        self.critic_Vhead = nn.Linear(128, 1 if self.dueling else 64)
        self.critic_Vhead.weight.data.uniform_(-EPS, EPS)
        ## action value path:
        self.critic_afc1 = nn.Linear(self.action_dim, 256)
        self.critic_afc1.weight.data.uniform_(-EPS, EPS)
        self.critic_afc2 = nn.Linear(256, 128)
        self.critic_afc2.weight.data.uniform_(-EPS, EPS)
        # Input is cat(state features [128], action embedding [128]) = 256.
        # When dueling, the advantage must be a scalar so that A(s, a) + V(s)
        # yields one Q-value per sample; a wider head would broadcast against
        # the scalar V and produce batch x k "Q-values".
        self.critic_ahead = nn.Linear(256, 1 if self.dueling else 64)
        self.critic_ahead.weight.data.uniform_(-EPS, EPS)
        # Linear layer after concatenating the V and A paths (64 + 64 = 128).
        # Only used when dueling is False.
        self.critic_final = nn.Linear(128, 1)
        self.critic_final.weight.data.uniform_(-EPS, EPS)

        # Actor network: outputs [mean, log_var] for each action dimension.
        self.actor_final = nn.Linear(128, 2 * self.action_dim)
        self.actor_final.weight.data.uniform_(-EPS, EPS)

    def features(self, x):
        """Map a batch of states ``x`` to a batch x 128 feature tensor."""
        if self.CNN['use_cnn']:
            x1 = F.relu(self.bn1(self.conv1(x)))
            x2 = F.relu(self.bn2(self.conv2(x1)))
            x3 = F.relu(self.bn3(self.conv3(x2)))
            x4 = x3.view(x3.size(0), -1)  # flatten each sample
            fx = F.relu(self.featx(x4))  # batch x 128
        else:
            x1 = F.relu(self.bn1(self.fc1(x)))
            x2 = F.relu(self.bn2(self.fc2(x1)))
            fx = F.relu(self.featx(x2))  # batch x 128
        return fx

    def evaluate(self, x, a):
        """Return Q(x, a) with shape batch x 1.

        ``a`` is expected to have ``action_dim`` features per sample. Side
        effects: stores V(x) in ``self.Vvalue`` and, in dueling mode, A(x, a)
        in ``self.Advantage``.
        """
        fx = self.features(x)
        # V value:
        self.Vvalue = v = self.critic_Vhead(fx)
        a1 = F.relu(self.critic_afc1(a))
        a2 = F.relu(self.critic_afc2(a1))  # batch x 128
        afx = torch.cat([fx, a2], dim=1)  # batch x 256
        if self.dueling:
            self.Advantage = advantage = self.critic_ahead(afx)  # batch x 1
            out = advantage + v
        else:
            advantage = self.critic_ahead(afx)  # batch x 64
            concat = torch.cat([v, advantage], dim=1)  # batch x 128
            out = self.critic_final(concat)
        return out

    def evaluateV(self, x):
        """Return the critic's state-value head for ``x`` (also stored in ``self.Vvalue``)."""
        fx = self.features(x)
        # V value:
        self.Vvalue = v = self.critic_Vhead(fx)
        return v

    def act(self, x):
        """Sample an action for the batch of states ``x``.

        Side effects: stores ``self.mean``, ``self.log_var`` and ``self.dist``.
        """
        fx = self.features(x)
        xx = self.actor_final(fx)
        meanxx, log_varxx = torch.chunk(xx, 2, dim=1)
        # Squash the mean into (-action_scaler, action_scaler).
        self.mean = torch.tanh(meanxx) * self.action_scaler
        # NOTE(review): relu forces log_var >= 0, i.e. variance >= 1, so the
        # policy can never be more precise than unit variance -- confirm this
        # is intended.
        self.log_var = F.relu(log_varxx)
        # `Normal` here is the project's distribution class, which accepts a
        # `log_var` keyword argument.
        self.dist = Normal(self.mean, log_var=self.log_var)
        action = self.dist.sample()
        return action
class StochasticActorNN(nn.Module):
    """Gaussian policy network (actor only).

    States are mapped to a 200-d feature vector, through either a small CNN
    (ReLU activations) or a three-layer MLP using ``actfn``. A final linear
    layer produces the mean and log-variance of a ``Normal`` distribution, from
    which :meth:`forward` samples an action.

    Args:
        state_dim: Size of the flat state vector (replaced by
            ``CNN['input_size']`` when a CNN is used).
        action_dim: Size of the action vector.
        action_scaler: Multiplier applied to the ``tanh``-squashed action mean.
        CNN: Dict with keys ``'use_cnn'`` (bool) and ``'input_size'`` (int, used
            as the number of input channels of the first convolution). Defaults
            to ``{'use_cnn': False, 'input_size': 3}``.
        HER: If True, the input width (or channel count) is doubled. Presumably
            the state is concatenated with a goal; TODO confirm with the caller.
        actfn: Activation used by the MLP feature path.
    """

    def __init__(self, state_dim=3, action_dim=2, action_scaler=2.0,
                 CNN=None, HER=True,
                 actfn=lambda x: F.leaky_relu(x, 0.1)):
        super().__init__()
        if CNN is None:
            # Build a fresh dict per instance; a dict literal default would be
            # shared by every instance constructed without an explicit CNN.
            CNN = {'use_cnn': False, 'input_size': 3}
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_scaler = action_scaler
        # dictionary with:
        # - 'input_size': int
        # - 'use_cnn': bool
        self.CNN = CNN
        if self.CNN['use_cnn']:
            self.state_dim = self.CNN['input_size']
        self.HER = HER
        if self.HER:
            self.state_dim *= 2
        self.actfn = actfn

        # Width of the feature vector consumed by actor_final; both feature
        # paths must produce exactly this many features.
        feat_dim = 200

        # Features:
        if self.CNN['use_cnn']:
            self.conv1 = nn.Conv2d(self.state_dim, 16, kernel_size=5, stride=2)
            self.bn1 = nn.BatchNorm2d(16)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
            self.bn2 = nn.BatchNorm2d(32)
            self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
            self.bn3 = nn.BatchNorm2d(32)
            # 2592 must equal the flattened conv3 output (32 * H' * W'), which
            # depends on the input resolution -- TODO confirm for your inputs.
            # The output width must be feat_dim to feed actor_final (a 128-wide
            # output made forward() fail with a shape mismatch).
            self.featx = nn.Linear(2592, feat_dim)
        else:
            self.fc1 = nn.Linear(self.state_dim, 400)
            self.fc1.weight.data = init_weights(self.fc1.weight.data.size())
            self.fc2 = nn.Linear(400, 300)
            self.fc2.weight.data = init_weights(self.fc2.weight.data.size())
            self.featx = nn.Linear(300, feat_dim)
            self.featx.weight.data = init_weights(self.featx.weight.data.size())

        # Actor network: normal distribution output -> [mean, log_var] for each
        # action dimension.
        self.actor_final = nn.Linear(feat_dim, 2 * self.action_dim)
        self.actor_final.weight.data.uniform_(-EPS, EPS)

    def features(self, x):
        """Map a batch of states ``x`` to a batch x 200 feature tensor."""
        if self.CNN['use_cnn']:
            x1 = F.relu(self.bn1(self.conv1(x)))
            x2 = F.relu(self.bn2(self.conv2(x1)))
            x3 = F.relu(self.bn3(self.conv3(x2)))
            x4 = x3.view(x3.size(0), -1)  # flatten each sample
            fx = F.relu(self.featx(x4))  # batch x 200
        else:
            x1 = self.actfn(self.fc1(x))
            x2 = self.actfn(self.fc2(x1))
            fx = self.actfn(self.featx(x2))  # batch x 200
        return fx

    def forward(self, x):
        """Sample an action for the batch of states ``x``.

        Side effects: stores ``self.mean``, ``self.log_var`` and ``self.dist``.
        """
        fx = self.features(x)
        xx = self.actor_final(fx)
        meanxx, log_varxx = torch.chunk(xx, 2, dim=1)
        # Squash the mean into (-action_scaler, action_scaler).
        self.mean = torch.tanh(meanxx) * self.action_scaler
        # NOTE(review): relu forces log_var >= 0, i.e. variance >= 1, so the
        # policy can never be more precise than unit variance -- confirm this
        # is intended.
        self.log_var = F.relu(log_varxx)
        # `Normal` here is the project's distribution class, which accepts a
        # `log_var` keyword argument.
        self.dist = Normal(self.mean, log_var=self.log_var)
        action = self.dist.sample()
        return action