Пример #1
0
	def __init__(self,phi_layers,rho_layers,activation,env_name):
		"""Set up the phi/rho sub-networks and bookkeeping attributes.

		Args:
			phi_layers: layer specification passed to ``FeedForward`` for phi.
			rho_layers: layer specification passed to ``FeedForward`` for rho.
			activation: activation passed to both ``FeedForward`` networks.
			env_name: environment name, stored as-is on ``self.env_name``.
		"""
		super(DeepSetObstacles, self).__init__()

		# Build both networks (phi first), then attach them in that same
		# order so the sub-module registration order is unchanged.
		phi_net = FeedForward(phi_layers,activation)
		rho_net = FeedForward(rho_layers,activation)
		self.phi = phi_net
		self.rho = rho_net

		self.env_name = env_name
		# The device attribute starts out as the CPU device.
		self.device = torch.device('cpu')
Пример #2
0
	def __init__(self,param,learning_module):
		"""Build the policy sub-networks and derive the observation dimensions.

		Args:
			param: configuration object. Its ``il_*_network_architecture``
				attributes are indexed as sequences whose first / last items
				expose ``in_features`` / ``out_features``.
			learning_module: name of the learning module. The neighbor,
				obstacle and psi networks are only created when it equals
				``"DeepSet"``.
		"""
		super(Empty_Net, self).__init__()

		# Compare by value. ``is`` tests object identity, which only matches a
		# string literal by interning accident (and is a SyntaxWarning on
		# Python >= 3.8), so equal strings built at runtime would be rejected.
		if learning_module == "DeepSet":

			self.model_neighbors = DeepSet(
				param.il_phi_network_architecture,
				param.il_rho_network_architecture,
				param.il_network_activation,
				param.env_name
				)
			self.model_obstacles = DeepSetObstacles(
				param.il_phi_obs_network_architecture,
				param.il_rho_obs_network_architecture,
				param.il_network_activation,
				param.env_name
				)
			self.psi = FeedForward(
				param.il_psi_network_architecture,
				param.il_network_activation)

		self.param = param
		self.device = torch.device('cpu')

		# Width of one neighbor entry = input width of the neighbor phi network.
		self.dim_neighbor = param.il_phi_network_architecture[0].in_features
		# Action width = output width of the last psi layer.
		self.dim_action = param.il_psi_network_architecture[-1].out_features
		# psi consumes [rho_neighbors, rho_obstacles, goal]; whatever is left of
		# its input width after both rho outputs is the goal/state width.
		self.dim_state = param.il_psi_network_architecture[0].in_features - \
						param.il_rho_network_architecture[-1].out_features - \
						param.il_rho_obs_network_architecture[-1].out_features
Пример #3
0
class DeepSet(nn.Module):
	"""Sum-pooling set network: ``rho(sum_i phi(x_i))``.

	The input rows are cut into consecutive column slices of width
	``phi.in_dim``; ``phi`` is applied to every slice, the results are summed
	and the sum is passed through ``rho``.
	"""

	def __init__(self,phi_layers,rho_layers,activation,env_name):
		"""Create the phi and rho ``FeedForward`` networks.

		Args:
			phi_layers: layer specification passed to ``FeedForward`` for phi.
			rho_layers: layer specification passed to ``FeedForward`` for rho.
			activation: activation passed to both networks.
			env_name: environment name, stored on ``self.env_name``.
		"""
		super(DeepSet, self).__init__()

		phi_net = FeedForward(phi_layers,activation)
		rho_net = FeedForward(rho_layers,activation)
		# phi is attached before rho, keeping the sub-module order.
		self.phi = phi_net
		self.rho = rho_net
		self.env_name = env_name
		self.device = torch.device('cpu')

	def to(self, device):
		"""Move the module to ``device`` and remember it on ``self.device``."""
		# Remembered so forward() can allocate its accumulator on this device.
		self.device = device
		for sub_net in (self.phi, self.rho):
			sub_net.to(device)
		return super().to(device)

	def export_to_onnx(self, filename):
		"""Export phi and rho via their own ``export_to_onnx`` methods.

		``filename`` is used as a prefix; ``_phi`` / ``_rho`` are appended.
		"""
		phi_target = "{}_phi".format(filename)
		rho_target = "{}_rho".format(filename)
		self.phi.export_to_onnx(phi_target)
		self.rho.export_to_onnx(rho_target)

	def forward(self,x):
		"""Return ``rho`` applied to the sum of ``phi`` over all input slices.

		Args:
			x: 2-D tensor; columns are read in consecutive slices of width
				``phi.in_dim``. Trailing columns that do not fill a whole slice
				are ignored.
		"""
		pooled = torch.zeros((x.size()[0],self.rho.in_dim), device=self.device)
		width = self.phi.in_dim
		slice_count = x.size()[1] // width
		for start in range(0, slice_count * width, width):
			pooled += self.phi(x[:,start:start + width])
		return self.rho(pooled)
Пример #4
0
class DeepSet(nn.Module):
	"""Sum-pooling set network with an environment-specific forward pass.

	``forward`` dispatches on ``env_name``: ``'Consensus'`` uses
	``consensus_forward``; the integrator environments use ``si_forward``.
	"""

	def __init__(self,phi_layers,rho_layers,activation,env_name):
		"""Create the phi and rho ``FeedForward`` networks.

		Args:
			phi_layers: layer specification passed to ``FeedForward`` for phi.
			rho_layers: layer specification passed to ``FeedForward`` for rho.
			activation: activation passed to both networks.
			env_name: environment name used by ``forward`` for dispatch.
		"""
		super(DeepSet, self).__init__()
		
		self.phi = FeedForward(phi_layers,activation)
		self.rho = FeedForward(rho_layers,activation)
		self.env_name = env_name
		self.device = torch.device('cpu')

	def to(self, device):
		"""Move the module to ``device`` and remember it on ``self.device``."""
		self.device = device
		self.phi.to(device)
		self.rho.to(device)
		return super().to(device)

	def export_to_onnx(self, filename):
		"""Export phi and rho, appending ``_phi`` / ``_rho`` to ``filename``."""
		self.phi.export_to_onnx("{}_phi".format(filename))
		self.rho.export_to_onnx("{}_rho".format(filename))

	def forward(self,x):
		"""Run the forward pass that matches ``self.env_name``.

		Raises:
			ValueError: if ``env_name`` is not one of the supported
				environments (previously this silently returned ``None``).
		"""
		if self.env_name == 'Consensus':
			return self.consensus_forward(x)
		if self.env_name in ('SingleIntegrator','DoubleIntegrator','SingleIntegratorVelSensing'):
			return self.si_forward(x)
		raise ValueError('DeepSet: unsupported env_name {!r}'.format(self.env_name))

	def consensus_forward(self,x):
		"""Forward pass for the consensus environment.

		Args:
			x: iterable of histories. The first entry is kept as-is (the
				original comment calls the entries "relative neighbor
				histories"); ``phi`` is applied to every later entry and the
				results are summed.

		Returns:
			``rho`` applied to the concatenation of the first entry and the sum.

		Raises:
			ValueError: if ``x`` is empty.
		"""
		# Allocate on the module's device so this also works after to(device).
		summ = torch.zeros((self.phi.out_dim), device=self.device)
		self_history = None
		for step_rnh, rnh in enumerate(x):
			# ndmin=1 promotes scalar entries to 1-element vectors.
			rnh = torch.from_numpy(np.array(rnh, ndmin=1)).float().to(self.device)
			if step_rnh == 0:
				self_history = rnh
			else:
				summ += self.phi(rnh)

		if self_history is None:
			raise ValueError('consensus_forward needs at least one history entry')

		return self.rho(torch.cat((self_history,summ)))

	def si_forward(self,x):
		"""Forward pass for the integrator environments.

		Args:
			x: 2-D tensor; columns are read in consecutive slices of width
				``phi.in_dim``, each slice goes through ``phi`` and the results
				are summed before ``rho`` is applied.
		"""
		X = torch.zeros((len(x),self.rho.in_dim), device=self.device)
		width = self.phi.in_dim
		num_elements = x.size()[1] // width
		for i in range(num_elements):
			X += self.phi(x[:,i*width:(i+1)*width])
		return self.rho(X)
Пример #5
0
    def __init__(self, param):
        """Build the sub-networks, the barrier helper and derived dimensions.

        Args:
            param: configuration object. Its ``il_*_network_architecture``
                attributes are indexed as sequences whose first / last items
                expose ``in_features`` / ``out_features``.
        """
        super(Barrier_Net, self).__init__()

        phi_arch = param.il_phi_network_architecture
        rho_arch = param.il_rho_network_architecture
        rho_obs_arch = param.il_rho_obs_network_architecture
        psi_arch = param.il_psi_network_architecture
        act = param.il_network_activation

        # Sub-networks are attached in the order neighbors, obstacles, psi.
        self.model_neighbors = DeepSet(phi_arch, rho_arch, act, param.env_name)
        self.model_obstacles = DeepSetObstacles(
            param.il_phi_obs_network_architecture, rho_obs_arch, act,
            param.env_name)
        self.psi = FeedForward(psi_arch, act)

        self.param = param
        self.bf = Barrier_Fncs(param)
        self.layers = psi_arch
        self.activation = act
        self.device = torch.device('cpu')

        # Width of one neighbor entry and of the produced action.
        self.dim_neighbor = phi_arch[0].in_features
        self.dim_action = psi_arch[-1].out_features

        # psi's input width minus both rho output widths is the state width.
        remaining = psi_arch[0].in_features
        remaining = remaining - rho_arch[-1].out_features
        remaining = remaining - rho_obs_arch[-1].out_features
        self.dim_state = remaining
Пример #6
0
class DeepSetObstacles(nn.Module):
	"""Sum-pooling set network over obstacle observations.

	Obstacle columns are consumed two at a time. When ``phi`` takes 4 inputs,
	each 2-column slice is concatenated with ``vel`` before being fed to
	``phi``; when it takes 2 inputs the slice is used on its own.
	"""

	def __init__(self,phi_layers,rho_layers,activation,env_name):
		"""Create the phi and rho ``FeedForward`` networks.

		Args:
			phi_layers: layer specification passed to ``FeedForward`` for phi.
			rho_layers: layer specification passed to ``FeedForward`` for rho.
			activation: activation passed to both networks.
			env_name: environment name, stored on ``self.env_name``.
		"""
		super(DeepSetObstacles, self).__init__()
		
		self.phi = FeedForward(phi_layers,activation)
		self.rho = FeedForward(rho_layers,activation)
		self.env_name = env_name
		self.device = torch.device('cpu')

	def to(self, device):
		"""Move the module to ``device`` and remember it on ``self.device``."""
		self.device = device
		self.phi.to(device)
		self.rho.to(device)
		return super().to(device)

	def export_to_onnx(self, filename):
		"""Export phi and rho, appending ``_phi`` / ``_rho`` to ``filename``."""
		self.phi.export_to_onnx("{}_phi".format(filename))
		self.rho.export_to_onnx("{}_rho".format(filename))

	def forward(self, x, vel):
		"""Return ``rho`` applied to the sum of ``phi`` over obstacle slices.

		Args:
			x: 2-D tensor whose columns are read in consecutive pairs.
			vel: 2-D tensor concatenated (along dim 1) to every pair when
				``phi.in_dim == 4``; unused when ``phi.in_dim == 2``.
				The original comment describes it as the agent's own velocity.

		Raises:
			ValueError: if ``phi.in_dim`` is neither 2 nor 4. Previously this
				printed a message and called ``exit()``, which terminated the
				whole process with a success status.
		"""
		in_dim = self.phi.in_dim
		if in_dim not in (2, 4):
			raise ValueError('unknown phi.in_dim! {}'.format(in_dim))

		X = torch.zeros((len(x),self.rho.in_dim), device=self.device)
		# Both supported layouts advance through x in 2-column steps.
		num_elements = x.size()[1] // 2
		for i in range(num_elements):
			element = x[:,i*2:(i+1)*2]
			if in_dim == 4:
				element = torch.cat((element, vel), dim=1)
			X += self.phi(element)
		return self.rho(X)
Пример #7
0
class Empty_Net(nn.Module):
    """Policy network ``psi(cat(rho_neighbors, rho_obstacles, goal))``.

    Observation column layout, as used by the index helpers below:
    column 0 holds the number of neighbors, the next ``dim_state`` columns
    are the goal block, followed by ``dim_neighbor`` columns per neighbor,
    and all remaining columns are obstacle data (2 columns per obstacle).
    """

    def __init__(self, param):
        """Build the sub-networks and derive the observation dimensions.

        Args:
            param: configuration object. Its ``il_*_network_architecture``
                attributes are indexed as sequences whose first / last items
                expose ``in_features`` / ``out_features``.
        """
        super(Empty_Net, self).__init__()

        self.model_neighbors = DeepSet(param.il_phi_network_architecture,
                                       param.il_rho_network_architecture,
                                       param.il_network_activation,
                                       param.env_name)
        self.model_obstacles = DeepSetObstacles(
            param.il_phi_obs_network_architecture,
            param.il_rho_obs_network_architecture, param.il_network_activation,
            param.env_name)
        self.psi = FeedForward(param.il_psi_network_architecture,
                               param.il_network_activation)

        self.param = param
        self.device = torch.device('cpu')

        self.dim_neighbor = param.il_phi_network_architecture[0].in_features
        self.dim_action = param.il_psi_network_architecture[-1].out_features
        # psi's input width minus both rho output widths is the goal width.
        self.dim_state = param.il_psi_network_architecture[0].in_features - \
            param.il_rho_network_architecture[-1].out_features - \
            param.il_rho_obs_network_architecture[-1].out_features

    def to(self, device):
        """Move all sub-networks to ``device`` and remember it."""
        self.device = device
        self.model_neighbors.to(device)
        self.model_obstacles.to(device)
        self.psi.to(device)
        return super().to(device)

    def save_weights(self, filename):
        """Save the state dicts of all five networks to ``filename``."""
        torch.save(
            {
                'neighbors_phi_state_dict':
                self.model_neighbors.phi.state_dict(),
                'neighbors_rho_state_dict':
                self.model_neighbors.rho.state_dict(),
                'obstacles_phi_state_dict':
                self.model_obstacles.phi.state_dict(),
                'obstacles_rho_state_dict':
                self.model_obstacles.rho.state_dict(),
                'psi_state_dict': self.psi.state_dict(),
            }, filename)

    def load_weights(self, filename):
        """Load the state dicts written by ``save_weights``."""
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be read on this module's device; load_state_dict copies the values
        # into the existing parameters either way.
        checkpoint = torch.load(filename, map_location=self.device)
        self.model_neighbors.phi.load_state_dict(
            checkpoint['neighbors_phi_state_dict'])
        self.model_neighbors.rho.load_state_dict(
            checkpoint['neighbors_rho_state_dict'])
        self.model_obstacles.phi.load_state_dict(
            checkpoint['obstacles_phi_state_dict'])
        self.model_obstacles.rho.load_state_dict(
            checkpoint['obstacles_rho_state_dict'])
        self.psi.load_state_dict(checkpoint['psi_state_dict'])

    def policy(self, x):
        """Compute actions for all agents, batching equal-shaped observations.

        Args:
            x: sequence of per-agent observations; each is indexed as
                ``x_i[0][0]`` and ``x_i.shape[1]`` and only row 0 is used
                (presumably shape ``(1, n_columns)`` -- TODO confirm).

        Returns:
            ``numpy`` array of shape ``(len(x), dim_action)``.
        """
        # Agents with the same neighbor count and observation width can share
        # one forward pass.
        grouping = dict()
        for i, x_i in enumerate(x):
            key = (int(x_i[0][0]), x_i.shape[1])
            grouping.setdefault(key, []).append(i)

        A = np.empty((len(x), self.dim_action))
        for idxs in grouping.values():
            batch = torch.Tensor([x[idx][0] for idx in idxs])
            a = self(batch).detach().numpy()
            # Scatter the batched result back to the agents' original rows.
            for i, idx in enumerate(idxs):
                A[idx, :] = a[i]

        return A

    def export_to_onnx(self, filename):
        """Export all sub-networks, using ``filename`` as a prefix."""
        self.model_neighbors.export_to_onnx("{}_neighbors".format(filename))
        self.model_obstacles.export_to_onnx("{}_obstacles".format(filename))
        self.psi.export_to_onnx("{}_psi".format(filename))

    def get_num_neighbors(self, x):
        """Return the neighbor count stored in ``x[0, 0]``."""
        return int(x[0, 0])

    def get_num_obstacles(self, x):
        """Return the number of obstacles encoded in ``x`` (2 columns each)."""
        num_neighbors = self.get_num_neighbors(x)
        return int((x.shape[1] - 1 - self.dim_state -
                    num_neighbors * self.dim_neighbor) / 2)

    def get_agent_idx_all(self, x):
        """Return the column indices of all neighbor entries."""
        num_neighbors = self.get_num_neighbors(x)
        idx = np.arange(1 + self.dim_state,
                        1 + self.dim_state + self.dim_neighbor * num_neighbors,
                        dtype=int)
        return idx

    def get_obstacle_idx_all(self, x):
        """Return the column indices of all obstacle entries."""
        num_neighbors = self.get_num_neighbors(x)
        idx = np.arange((1 + self.dim_state) +
                        self.dim_neighbor * num_neighbors,
                        x.size()[1],
                        dtype=int)
        return idx

    def get_goal_idx(self, x):
        """Return the column indices of the goal block."""
        idx = np.arange(1, 1 + self.dim_state, dtype=int)
        return idx

    def __call__(self, x):
        """Evaluate the policy on a batch of observations.

        Args:
            x: 2-D tensor. All rows are treated as having the neighbor count
                found in ``x[0, 0]`` (batches are grouped by number of
                neighbors).

        Returns:
            Output of ``psi`` for the batch.
        """
        rho_neighbors = self.model_neighbors.forward(
            x[:, self.get_agent_idx_all(x)])

        # Columns 3:5 are negated and handed to the obstacle network as `vel`
        # (presumably the agent's own velocity -- TODO confirm against the
        # observation layout).
        vel = -x[:, 3:5]
        rho_obstacles = self.model_obstacles.forward(
            x[:, self.get_obstacle_idx_all(x)], vel)

        g = x[:, self.get_goal_idx(x)]

        x = torch.cat((rho_neighbors, rho_obstacles, g), 1)
        x = self.psi(x)

        return x
Пример #8
0
class Barrier_Net(nn.Module):
    """Learned policy ("empty" action) blended with a barrier/safety action.

    ``__call__`` accepts either a ``torch.Tensor`` or a ``numpy.ndarray``
    observation and combines the network output with a barrier action
    according to ``param.safety``.
    """

    # Values of ``param.safety`` understood by ``__call__``.
    _SAFETY_MODES = ("potential", "fdbk_si", "fdbk_di", "cf_si", "cf_si_2",
                     "cf_di", "cf_di_2")
    # Modes whose barrier action comes from the ``*_fdbk_si`` helper; the
    # remaining non-"potential" modes use ``*_fdbk_di``.
    _SI_FEEDBACK_MODES = ("fdbk_si", "cf_si", "cf_si_2")

    def __init__(self, param, learning_module):
        """Build the sub-networks, the barrier helper and derived dimensions.

        Args:
            param: configuration object. Its ``il_*_network_architecture``
                attributes are indexed as sequences whose first / last items
                expose ``in_features`` / ``out_features``.
            learning_module: accepted for interface compatibility; not used
                by this constructor.
        """
        super(Barrier_Net, self).__init__()
        self.model_neighbors = DeepSet(param.il_phi_network_architecture,
                                       param.il_rho_network_architecture,
                                       param.il_network_activation,
                                       param.env_name)
        self.model_obstacles = DeepSetObstacles(
            param.il_phi_obs_network_architecture,
            param.il_rho_obs_network_architecture, param.il_network_activation,
            param.env_name)
        self.psi = FeedForward(param.il_psi_network_architecture,
                               param.il_network_activation)

        self.param = param
        self.bf = Barrier_Fncs(param)
        self.layers = param.il_psi_network_architecture
        self.activation = param.il_network_activation
        self.device = torch.device('cpu')

        self.dim_neighbor = param.il_phi_network_architecture[0].in_features
        self.dim_action = param.il_psi_network_architecture[-1].out_features
        # psi's input width minus both rho output widths is the state width.
        self.dim_state = param.il_psi_network_architecture[0].in_features - \
            param.il_rho_network_architecture[-1].out_features - \
            param.il_rho_obs_network_architecture[-1].out_features

    def to(self, device):
        """Move all sub-networks and the barrier helper to ``device``."""
        self.device = device
        self.model_neighbors.to(device)
        self.model_obstacles.to(device)
        self.psi.to(device)
        self.bf.to(device)
        return super().to(device)

    def save_weights(self, filename):
        """Save the state dicts of all five networks to ``filename``."""
        torch.save(
            {
                'neighbors_phi_state_dict':
                self.model_neighbors.phi.state_dict(),
                'neighbors_rho_state_dict':
                self.model_neighbors.rho.state_dict(),
                'obstacles_phi_state_dict':
                self.model_obstacles.phi.state_dict(),
                'obstacles_rho_state_dict':
                self.model_obstacles.rho.state_dict(),
                'psi_state_dict': self.psi.state_dict(),
            }, filename)

    def load_weights(self, filename):
        """Load the state dicts written by ``save_weights``."""
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be read on this module's device; load_state_dict copies the values
        # into the existing parameters either way.
        checkpoint = torch.load(filename, map_location=self.device)
        self.model_neighbors.phi.load_state_dict(
            checkpoint['neighbors_phi_state_dict'])
        self.model_neighbors.rho.load_state_dict(
            checkpoint['neighbors_rho_state_dict'])
        self.model_obstacles.phi.load_state_dict(
            checkpoint['obstacles_phi_state_dict'])
        self.model_obstacles.rho.load_state_dict(
            checkpoint['obstacles_rho_state_dict'])
        self.psi.load_state_dict(checkpoint['psi_state_dict'])

    def policy(self, x):
        """Compute actions for all agents.

        When ``param.rollout_batch_on`` is set and at least two observations
        share the same (neighbor count, width) key, observations are grouped
        and evaluated in batches; otherwise each one is evaluated on its own.

        Args:
            x: sequence of per-agent observations; each is indexed as
                ``x_i[0][0]`` and ``x_i.shape[1]``.

        Returns:
            ``numpy`` array of shape ``(len(x), dim_action)``.
        """
        if self.param.rollout_batch_on:
            grouping = dict()
            for i, x_i in enumerate(x):
                key = (int(x_i[0][0]), x_i.shape[1])
                grouping.setdefault(key, []).append(i)

            # Batching only pays off when some group has more than one agent.
            if len(grouping) < len(x):
                A = np.empty((len(x), self.dim_action))
                for idxs in grouping.values():
                    batch = torch.Tensor([x[idx][0] for idx in idxs])
                    a = self(batch).detach().numpy()
                    for i, idx in enumerate(idxs):
                        A[idx, :] = a[i]
                return A

        # Per-agent evaluation (batching disabled or nothing to group).
        A = np.empty((len(x), self.dim_action))
        for i, x_i in enumerate(x):
            A[i, :] = self(x_i)
        return A

    def export_to_onnx(self, filename):
        """Export all sub-networks, using ``filename`` as a prefix."""
        self.model_neighbors.export_to_onnx("{}_neighbors".format(filename))
        self.model_obstacles.export_to_onnx("{}_obstacles".format(filename))
        self.psi.export_to_onnx("{}_psi".format(filename))

    def __call__(self, x):
        """Return the safety-filtered action for observation ``x``.

        Args:
            x: observation batch, exactly a ``torch.Tensor`` or a
                ``numpy.ndarray`` (subclasses are not accepted).

        Returns:
            The blended action, of the same array kind as ``x``.

        Exits the process (``SystemExit`` with a message) when the type of
        ``x`` or the value of ``param.safety`` is not recognized.
        """
        if type(x) is torch.Tensor:
            return self._torch_action(x)
        if type(x) is np.ndarray:
            return self._numpy_action(x)
        # exit() takes a single argument; the message is formatted first so
        # the intended text is reported instead of a TypeError.
        exit('type(x) not recognized: {}'.format(type(x)))

    def _torch_action(self, x):
        """Blend the learned and barrier actions using the torch helpers."""
        mode = self.param.safety
        if mode not in self._SAFETY_MODES:
            exit('self.param.safety: {} not recognized'.format(
                self.param.safety))

        bf = self.bf
        P, H = bf.torch_get_relative_positions_and_safety_functions(x)

        if mode == "potential":
            # NOTE(review): the numpy path scales by param.b_gamma here while
            # this path uses param.kp -- confirm which gain is intended.
            barrier_action = -1 * self.param.kp * bf.torch_get_grad_phi(
                x, P, H)
        elif mode in self._SI_FEEDBACK_MODES:
            barrier_action = bf.torch_fdbk_si(x, P, H)
        else:
            barrier_action = bf.torch_fdbk_di(x, P, H)

        empty_action = self.empty(x)
        empty_action = bf.torch_scale(empty_action, self.param.pi_max)

        if mode in ("potential", "fdbk_si", "fdbk_di"):
            # Adaptive scaling: scale the learned action, add the barrier one.
            if mode == "fdbk_di":
                adaptive_scaling = bf.torch_get_adaptive_scaling_di(
                    x, empty_action, barrier_action, P, H)
            else:
                adaptive_scaling = bf.torch_get_adaptive_scaling_si(
                    x, empty_action, barrier_action, P, H)
            action = torch.mul(adaptive_scaling,
                               empty_action) + barrier_action
        else:
            # Convex combination of the learned and the barrier action.
            # NOTE(review): the "_2" torch helpers are called with
            # (x, empty, barrier, P, H) whereas their numpy counterparts get
            # (x, P, H, empty, barrier) -- argument orders kept as found.
            if mode == "cf_si":
                cf_alpha = bf.torch_get_cf_si(x, P, H, empty_action,
                                              barrier_action)
            elif mode == "cf_si_2":
                cf_alpha = bf.torch_get_cf_si_2(x, empty_action,
                                                barrier_action, P, H)
            elif mode == "cf_di":
                cf_alpha = bf.torch_get_cf_di(x, P, H, empty_action,
                                              barrier_action)
            else:
                cf_alpha = bf.torch_get_cf_di_2(x, empty_action,
                                                barrier_action, P, H)
            action = torch.mul(cf_alpha, empty_action) + torch.mul(
                1 - cf_alpha, barrier_action)

        return bf.torch_scale(action, self.param.a_max)

    def _numpy_action(self, x):
        """Blend the learned and barrier actions using the numpy helpers."""
        mode = self.param.safety
        if mode not in self._SAFETY_MODES:
            exit('self.param.safety: {} not recognized'.format(
                self.param.safety))

        bf = self.bf
        P, H = bf.numpy_get_relative_positions_and_safety_functions(x)

        if mode == "potential":
            barrier_action = -1 * self.param.b_gamma * bf.numpy_get_grad_phi(
                x, P, H)
        elif mode in self._SI_FEEDBACK_MODES:
            barrier_action = bf.numpy_fdbk_si(x, P, H)
        else:
            barrier_action = bf.numpy_fdbk_di(x, P, H)

        # The network itself runs in torch; convert in and out around it.
        empty_action = self.empty(torch.tensor(x).float()).detach().numpy()
        empty_action = bf.numpy_scale(empty_action, self.param.pi_max)

        if mode in ("potential", "fdbk_si", "fdbk_di"):
            if mode == "fdbk_di":
                adaptive_scaling = bf.numpy_get_adaptive_scaling_di(
                    x, empty_action, barrier_action, P, H)
            else:
                adaptive_scaling = bf.numpy_get_adaptive_scaling_si(
                    x, empty_action, barrier_action, P, H)
            action = adaptive_scaling * empty_action + barrier_action
        else:
            if mode == "cf_si":
                cf_alpha = bf.numpy_get_cf_si(x, P, H, empty_action,
                                              barrier_action)
            elif mode == "cf_si_2":
                cf_alpha = bf.numpy_get_cf_si_2(x, P, H, empty_action,
                                                barrier_action)
            elif mode == "cf_di":
                cf_alpha = bf.numpy_get_cf_di(x, P, H, empty_action,
                                              barrier_action)
            else:
                cf_alpha = bf.numpy_get_cf_di_2(x, P, H, empty_action,
                                                barrier_action)
            action = cf_alpha * empty_action + (1 - cf_alpha) * barrier_action

        return bf.numpy_scale(action, self.param.a_max)

    def empty(self, x):
        """Evaluate the learned policy without any safety blending.

        Args:
            x: 2-D tensor. Column indices for the neighbor, obstacle and goal
                blocks are obtained from ``self.bf``; batches are grouped by
                number of neighbors.

        Returns:
            Output of ``psi`` on ``cat(rho_neighbors, rho_obstacles, goal)``.
        """
        rho_neighbors = self.model_neighbors.forward(
            x[:, self.bf.get_agent_idx_all(x)])

        # Columns 3:5 are negated and handed to the obstacle network as `vel`
        # (presumably the agent's own velocity -- TODO confirm against the
        # observation layout).
        vel = -x[:, 3:5]
        rho_obstacles = self.model_obstacles.forward(
            x[:, self.bf.get_obstacle_idx_all(x)], vel)

        g = x[:, self.bf.get_goal_idx(x)]

        x = torch.cat((rho_neighbors, rho_obstacles, g), 1)
        x = self.psi(x)
        return x