예제 #1
파일: NN.py 프로젝트: xuhuazhe/PYTORCH_RL
class StochasticActorCriticNN(nn.Module) :
	"""Actor-critic network: shared feature trunk, critic heads and a Gaussian policy head.

	The trunk is either a 3-layer conv net (when ``CNN['use_cnn']``) or a
	2-layer MLP with batch norm; both produce a ``batch x 128`` feature ``fx``.

	* ``evaluate(x, a)`` returns the critic output Q(x, a), shape ``batch x 1``.
	* ``evaluateV(x)`` returns the state-value head output (``batch x 1`` when
	  dueling, ``batch x 64`` value features otherwise).
	* ``act(x)`` samples an action from the policy distribution.

	Args:
		state_dim: size of the flat state vector. Replaced by
			``CNN['input_size']`` (conv input channels) when the CNN trunk is used.
		action_dim: size of the action vector.
		action_scaler: the policy mean is ``tanh(.) * action_scaler``.
		CNN: dict with keys ``'use_cnn'`` (bool) and ``'input_size'`` (int).
			Defaults to ``{'use_cnn': False, 'input_size': 3}``.
		dueling: if True, Q = V + A; otherwise the V and A features are
			concatenated and mapped to Q by ``critic_final``.
		HER: if True, the input dimension is doubled (presumably state
			concatenated with a goal -- TODO confirm with callers).
	"""
	def __init__(self, state_dim=3, action_dim=2, action_scaler=2.0, CNN=None, dueling=True, HER=True) :
		# nn.Module must be initialised before any submodule is assigned;
		# otherwise the first nn.Conv2d/nn.Linear assignment raises AttributeError.
		super(StochasticActorCriticNN, self).__init__()

		if CNN is None :
			# A fresh dict per instance: a dict literal as default would be shared.
			CNN = {'use_cnn': False, 'input_size': 3}

		self.state_dim = state_dim
		self.action_dim = action_dim
		self.action_scaler = action_scaler
		self.dueling = dueling

		# CNN: dictionary with 'use_cnn' (bool) and 'input_size' (int).
		self.CNN = CNN
		if self.CNN['use_cnn'] :
			self.state_dim = self.CNN['input_size']

		self.HER = HER
		if self.HER :
			self.state_dim *= 2

		# Feature trunk -> batch x 128.
		if self.CNN['use_cnn'] :
			self.conv1 = nn.Conv2d(self.state_dim, 16, kernel_size=5, stride=2)
			self.bn1 = nn.BatchNorm2d(16)
			self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
			self.bn2 = nn.BatchNorm2d(32)
			self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
			self.bn3 = nn.BatchNorm2d(32)
			# NOTE(review): 192 is the flattened conv3 output size, which only
			# matches one specific input resolution -- confirm against the frames used.
			self.featx = nn.Linear(192, 128)
		else :
			self.fc1 = nn.Linear(self.state_dim, 512)
			self.bn1 = nn.BatchNorm1d(512)
			self.fc2 = nn.Linear(512, 256)
			self.bn2 = nn.BatchNorm1d(256)
			self.featx = nn.Linear(256, 128)

		# Critic, state-value path:
		if self.dueling :
			self.critic_Vhead = nn.Linear(128, 1)
		else :
			self.critic_Vhead = nn.Linear(128, 64)

		# Critic, action path: action -> 256 -> 128, later concatenated with fx (-> 256).
		self.critic_afc1 = nn.Linear(self.action_dim, 256)
		self.critic_afc2 = nn.Linear(256, 128)

		if self.dueling :
			# One advantage value per sample so that Q = V + A is batch x 1.
			# (A 128-unit head made V + A broadcast to batch x 128.)
			self.critic_ahead = nn.Linear(256, 1)
		else :
			self.critic_ahead = nn.Linear(256, 64)
			# Maps concat(V features, A features) = 64 + 64 = 128 -> scalar Q.
			self.critic_final = nn.Linear(128, 1)

		# Actor head: one mean and one log-variance per action dimension.
		self.actor_final = nn.Linear(128, 2*self.action_dim)

	def features(self, x) :
		"""Run the shared trunk and return the ``batch x 128`` feature tensor."""
		if self.CNN['use_cnn'] :
			x1 = F.relu(self.bn1(self.conv1(x)))
			x2 = F.relu(self.bn2(self.conv2(x1)))
			x3 = F.relu(self.bn3(self.conv3(x2)))
			x4 = x3.view(x3.size(0), -1)
			fx = F.relu(self.featx(x4))
		else :
			x1 = F.relu(self.bn1(self.fc1(x)))
			x2 = F.relu(self.bn2(self.fc2(x1)))
			fx = F.relu(self.featx(x2))
		return fx

	def evaluate(self, x, a) :
		"""Return the critic output Q(x, a) with shape ``batch x 1``.

		``a`` must have last dimension ``action_dim`` (input of ``critic_afc1``).
		Side effects: stores ``self.Vvalue`` and, when dueling, ``self.Advantage``.
		"""
		fx = self.features(x)

		# V value:
		self.Vvalue = v = self.critic_Vhead(fx)

		a1 = F.relu(self.critic_afc1(a))
		a2 = F.relu(self.critic_afc2(a1))
		# batch x 128 (state features) ++ batch x 128 (action features) -> batch x 256
		afx = torch.cat([fx, a2], dim=1)

		if self.dueling :
			self.Advantage = advantage = self.critic_ahead(afx)
			out = advantage + v
		else :
			advantage = self.critic_ahead(afx)
			concat = torch.cat([v, advantage], dim=1)
			out = self.critic_final(concat)

		return out

	def evaluateV(self, x) :
		"""Return the state-value head output for ``x`` (also stored in ``self.Vvalue``)."""
		fx = self.features(x)
		self.Vvalue = v = self.critic_Vhead(fx)
		return v

	def act(self, x) :
		"""Sample an action for the state batch ``x``.

		Side effects: stores ``self.mean``, ``self.log_var`` and ``self.dist``
		on the instance.
		"""
		fx = self.features(x)

		xx = self.actor_final(fx)
		# First half of the outputs -> mean, second half -> log-variance.
		meanxx, log_varxx = torch.chunk(xx, 2, dim=1)
		# Squash the mean into [-action_scaler, action_scaler].
		unscaled = torch.tanh(meanxx)
		self.mean = unscaled * self.action_scaler

		# NOTE(review): ReLU forces log_var >= 0, i.e. variance >= 1 -- confirm
		# this floor is intended.
		self.log_var = F.relu(log_varxx)

		# NOTE(review): `Normal` here takes a `log_var` keyword, so it is not
		# torch.distributions.Normal; presumably a project-level class -- confirm.
		self.dist = Normal(self.mean, log_var=self.log_var)
		action = self.dist.sample()

		return action
예제 #2
파일: NN.py 프로젝트: xuhuazhe/PYTORCH_RL
class StochasticActorNN(nn.Module) :
	"""Gaussian policy network (actor only).

	The trunk is either a 3-layer conv net with batch norm (when
	``CNN['use_cnn']``; ReLU activations, ``batch x 128`` features) or a
	400-300-200 MLP using ``actfn`` (``batch x 200`` features). ``forward(x)``
	samples an action from the policy distribution.

	Args:
		state_dim: size of the flat state vector. Replaced by
			``CNN['input_size']`` (conv input channels) when the CNN trunk is used.
		action_dim: size of the action vector.
		action_scaler: the policy mean is ``tanh(.) * action_scaler``.
		CNN: dict with keys ``'use_cnn'`` (bool) and ``'input_size'`` (int).
			Defaults to ``{'use_cnn': False, 'input_size': 3}``.
		HER: if True, the input dimension is doubled (presumably state
			concatenated with a goal -- TODO confirm with callers).
		actfn: activation used by the MLP trunk (default: leaky ReLU, slope 0.1).
	"""
	def __init__(self, state_dim=3, action_dim=2, action_scaler=2.0, CNN=None, HER=True, actfn=lambda x : F.leaky_relu(x, 0.1)) :
		# nn.Module must be initialised before any submodule is assigned;
		# otherwise the first nn.Conv2d/nn.Linear assignment raises AttributeError.
		super(StochasticActorNN, self).__init__()

		if CNN is None :
			# A fresh dict per instance: a dict literal as default would be shared.
			CNN = {'use_cnn': False, 'input_size': 3}

		self.state_dim = state_dim
		self.action_dim = action_dim
		self.action_scaler = action_scaler

		# CNN: dictionary with 'use_cnn' (bool) and 'input_size' (int).
		self.CNN = CNN
		if self.CNN['use_cnn'] :
			self.state_dim = self.CNN['input_size']

		self.HER = HER
		if self.HER :
			self.state_dim *= 2

		self.actfn = actfn

		# Feature trunk.
		if self.CNN['use_cnn'] :
			self.conv1 = nn.Conv2d(self.state_dim, 16, kernel_size=5, stride=2)
			self.bn1 = nn.BatchNorm2d(16)
			self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
			self.bn2 = nn.BatchNorm2d(32)
			self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
			self.bn3 = nn.BatchNorm2d(32)
			# NOTE(review): 2592 is the flattened conv3 output size, which only
			# matches one specific input resolution -- confirm against the frames used.
			self.featx = nn.Linear(2592, 128)
		else :
			self.fc1 = nn.Linear(self.state_dim, 400)
			self.fc1.weight.data = init_weights(self.fc1.weight.data.size())
			self.fc2 = nn.Linear(400, 300)
			self.fc2.weight.data = init_weights(self.fc2.weight.data.size())
			self.featx = nn.Linear(300, 200)

		self.featx.weight.data = init_weights(self.featx.weight.data.size())

		# Actor head: one mean and one log-variance per action dimension.
		# Its input width follows the trunk (128 for CNN, 200 for MLP); a fixed
		# 200 made the CNN path fail with a shape mismatch in forward().
		self.actor_final = nn.Linear(self.featx.out_features, 2*self.action_dim)

	def features(self, x) :
		"""Run the feature trunk (``batch x 128`` for CNN, ``batch x 200`` for MLP)."""
		if self.CNN['use_cnn'] :
			x1 = F.relu(self.bn1(self.conv1(x)))
			x2 = F.relu(self.bn2(self.conv2(x1)))
			x3 = F.relu(self.bn3(self.conv3(x2)))
			x4 = x3.view(x3.size(0), -1)
			fx = F.relu(self.featx(x4))
		else :
			x1 = self.actfn(self.fc1(x))
			x2 = self.actfn(self.fc2(x1))
			fx = self.actfn(self.featx(x2))
		return fx

	def forward(self, x) :
		"""Sample an action for the state batch ``x``.

		Side effects: stores ``self.mean``, ``self.log_var`` and ``self.dist``
		on the instance.
		"""
		fx = self.features(x)

		xx = self.actor_final(fx)
		# First half of the outputs -> mean, second half -> log-variance.
		meanxx, log_varxx = torch.chunk(xx, 2, dim=1)
		# Squash the mean into [-action_scaler, action_scaler].
		unscaled = torch.tanh(meanxx)
		self.mean = unscaled * self.action_scaler

		# NOTE(review): ReLU forces log_var >= 0, i.e. variance >= 1 -- confirm
		# this floor is intended.
		self.log_var = F.relu(log_varxx)

		# NOTE(review): `Normal` here takes a `log_var` keyword, so it is not
		# torch.distributions.Normal; presumably a project-level class -- confirm.
		self.dist = Normal(self.mean, log_var=self.log_var)
		action = self.dist.sample()

		return action