def __init__(
    self,
    observation_specs: List[ObservationSpec],
    network_settings: NetworkSettings,
    action_spec: ActionSpec,
):
    super().__init__()
    self.normalize = network_settings.normalize
    self.use_lstm = network_settings.memory is not None
    self.h_size = network_settings.hidden_units
    self.m_size = (
        network_settings.memory.memory_size
        if network_settings.memory is not None
        else 0
    )
    self.action_spec = action_spec
    self.observation_encoder = ObservationEncoder(
        observation_specs,
        self.h_size,
        network_settings.vis_encode_type,
        self.normalize,
    )
    self.processors = self.observation_encoder.processors

    # Modules for multi-agent self-attention
    obs_only_ent_size = self.observation_encoder.total_enc_size
    q_ent_size = (
        obs_only_ent_size
        + sum(self.action_spec.discrete_branches)
        + self.action_spec.continuous_size
    )

    attention_embedding_size = self.h_size
    self.obs_encoder = EntityEmbedding(
        obs_only_ent_size, None, attention_embedding_size
    )
    self.obs_action_encoder = EntityEmbedding(
        q_ent_size, None, attention_embedding_size
    )

    self.self_attn = ResidualSelfAttention(attention_embedding_size)

    self.linear_encoder = LinearEncoder(
        attention_embedding_size,
        network_settings.num_layers,
        self.h_size,
        kernel_gain=(0.125 / self.h_size) ** 0.5,
    )

    # Optional recurrent memory
    if self.use_lstm:
        self.lstm = LSTM(self.h_size, self.m_size)
    else:
        self.lstm = None  # type: ignore
    # Largest number of agents seen so far, stored as a non-trainable parameter
    # so it is serialized with the model.
    self._current_max_agents = torch.nn.Parameter(
        torch.as_tensor(1), requires_grad=False
    )
def test_predict_minimum_training():
    # of 5 numbers, predict index of min
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embedding = EntityEmbedding(size, n_k, embedding_size)  # no self
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embedding.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    # One-hot positional tags for the entities (torch.arange replaces the
    # deprecated torch.range).
    onehots = ModelUtils.actions_to_onehot(
        torch.arange(0, n_k).unsqueeze(1), [n_k]
    )[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # create the target: the index of the minimum value
            argmin = torch.argmin(inp, dim=1)
            argmin = argmin.squeeze()
            argmin = argmin.detach()
        sliced_oh = onehots[:, : num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)

        embeddings = entity_embedding(inp, inp)
        masks = get_zero_entities_mask([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        print(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1
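# Hedged side-note (not part of the original tests): the sketch below illustrates
# the assumed contract of get_zero_entities_mask, namely that it returns one mask
# per input tensor, shaped [batch, num_entities], marking all-zero (padded)
# entities with 1 and real entities with 0. The exact threshold and dtype are
# assumptions, not taken from the source above.
def _sketch_zero_entities_mask():
    entities = torch.zeros((2, 3, 4))  # batch of 2, 3 entities of size 4, all padding
    entities[:, 0, :] = torch.rand((2, 4)) + 0.5  # make the first entity clearly non-zero
    masks = get_zero_entities_mask([entities])
    # Under the assumption above: entity 0 is unmasked, entities 1 and 2 are masked.
    assert masks[0].shape == (2, 3)
    assert (masks[0][:, 0] == 0).all()
    assert (masks[0][:, 1:] == 1).all()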
def test_predict_closest_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target: the key closest to the query in Euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
def test_all_masking(mask_value):
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target: the key closest to the query in Euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        # Force the mask to be all zeros or all ones, depending on mask_value
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
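# Hedged note (illustration only): test_all_masking takes mask_value as a parameter,
# so in a pytest suite it is presumably driven by parametrization. A minimal sketch
# follows, assuming pytest is available; the chosen values [0, 1] (fully unmasked /
# fully masked) are an assumption, not taken from the source above.
import pytest


@pytest.mark.parametrize("mask_value", [0, 1])
def test_all_masking_parametrized_sketch(mask_value):
    # Delegates to the test above for both degenerate mask settings.
    test_all_masking(mask_value)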
def get_encoder_for_obs(
    obs_spec: ObservationSpec,
    normalize: bool,
    h_size: int,
    attention_embedding_size: int,
    vis_encode_type: EncoderType,
) -> Tuple[nn.Module, int]:
    """
    Returns the encoder appropriate for the given observation spec, together with
    the size it contributes to the flattened encoding (0 for variable-length
    observations, which are aggregated by attention instead).
    :param obs_spec: The observation spec to build an encoder for.
    :param normalize: Normalize all vector inputs.
    :param h_size: Number of hidden units per layer excluding attention layers.
    :param attention_embedding_size: Number of hidden units per attention layer.
    :param vis_encode_type: Type of visual encoder to use.
    """
    shape = obs_spec.shape
    dim_prop = obs_spec.dimension_property
    # VISUAL
    if dim_prop in ModelUtils.VALID_VISUAL_PROP:
        visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
        return (visual_encoder_class(shape[0], shape[1], shape[2], h_size), h_size)
    # VECTOR
    if dim_prop in ModelUtils.VALID_VECTOR_PROP:
        return (VectorInput(shape[0], normalize), shape[0])
    # VARIABLE LENGTH
    if dim_prop in ModelUtils.VALID_VAR_LEN_PROP:
        return (
            EntityEmbedding(
                entity_size=shape[1],
                entity_num_max_elements=shape[0],
                embedding_size=attention_embedding_size,
            ),
            0,
        )
    # OTHER
    raise UnityTrainerException(f"Unsupported Sensor with specs {obs_spec}")
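# Hedged usage sketch (not part of the original source): exercises the VECTOR branch
# of get_encoder_for_obs. The ObservationSpec field names (shape, dimension_property,
# observation_type, name), the import paths, and the enum members used here are
# assumptions based on the ML-Agents layout; treat them as illustrative only.
from mlagents_envs.base_env import DimensionProperty, ObservationSpec, ObservationType
from mlagents.trainers.settings import EncoderType


def _sketch_vector_encoder():
    vec_spec = ObservationSpec(
        shape=(8,),
        dimension_property=(DimensionProperty.NONE,),
        observation_type=ObservationType.DEFAULT,
        name="vector_obs",
    )
    encoder, size = get_encoder_for_obs(
        vec_spec,
        normalize=True,
        h_size=128,
        attention_embedding_size=128,
        vis_encode_type=EncoderType.SIMPLE,
    )
    # A plain vector observation contributes its full width to the flat encoding.
    assert size == 8
    return encoder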