def test_zero_mask_layer():
    """Check that get_zero_entities_mask masks exactly the all-zero entities.

    Inputs are built from alternating segments: even-indexed segments are
    random (non-zero) entities, odd-indexed segments are all-zero entities.
    The returned mask must be 0 for random entities and 1 for zero entities.
    """
    batch_size, size = 10, 30

    def generate_input_helper(pattern):
        # Concatenate segments along the entity dimension (dim=1):
        # even index -> random entities, odd index -> zero entities.
        _input = torch.zeros((batch_size, 0, size))
        for i, seg_len in enumerate(pattern):
            if i % 2 == 0:
                segment = torch.rand((batch_size, seg_len, size))
            else:
                segment = torch.zeros((batch_size, seg_len, size))
            _input = torch.cat([_input, segment], dim=1)
        return _input

    masking_pattern_1 = [3, 2, 3, 4]
    masking_pattern_2 = [5, 7, 8, 2]
    input_1 = generate_input_helper(masking_pattern_1)
    input_2 = generate_input_helper(masking_pattern_2)

    masks = get_zero_entities_mask([input_1, input_2])
    assert len(masks) == 2
    masks_1 = masks[0]
    masks_2 = masks[1]
    assert masks_1.shape == (batch_size, sum(masking_pattern_1))
    assert masks_2.shape == (batch_size, sum(masking_pattern_2))

    def check_mask(mask, pattern):
        # NOTE(fix): the old assertion `assert mask[0, 1] == 0 if i % 2 == 0
        # else 1` was vacuous in the else branch (it asserted the constant 1),
        # always looked at position [0, 1], and keyed on the segment *length*
        # parity instead of the segment *index*. Walk every position instead.
        offset = 0
        for seg_idx, seg_len in enumerate(pattern):
            expected = 0 if seg_idx % 2 == 0 else 1
            for j in range(seg_len):
                assert mask[0, offset + j] == expected
            offset += seg_len

    check_mask(masks_1, masking_pattern_1)
    check_mask(masks_2, masking_pattern_2)
def forward(
    self,
    inputs: List[torch.Tensor],
    actions: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode every observation, optionally concatenate actions, then run the
    result through the linear encoder (and the LSTM when enabled).

    :param inputs: One Tensor per processor in self.processors.
    :param actions: Optional action Tensor appended to the encoding.
    :param memories: Optional LSTM memories; passed through when no LSTM.
    :param sequence_length: LSTM sequence length (used only with use_lstm).
    :return: Tuple of (encoding, memories).
    """
    fixed_encodings: List[torch.Tensor] = []
    deferred: List[Tuple[nn.Module, torch.Tensor]] = []
    # Split the processors: EntityEmbedding inputs need the encoding of the
    # fixed-size inputs first, so defer them; everything else encodes now.
    for idx, processor in enumerate(self.processors):
        if isinstance(processor, EntityEmbedding):
            deferred.append((processor, inputs[idx]))
        else:
            fixed_encodings.append(processor(inputs[idx]))

    have_encoding = bool(fixed_encodings)
    if have_encoding:
        encoded_self = torch.cat(fixed_encodings, dim=1)

    if deferred:
        # Variable-length inputs: mask out all-zero entities, embed each
        # input (conditioned on the self-encoding when one exists), then
        # run residual self-attention over the concatenated embeddings.
        masks = get_zero_entities_mask([obs for _, obs in deferred])
        processed_self = self.x_self_encoder(encoded_self) if have_encoding else None
        embeddings: List[torch.Tensor] = [
            proc(processed_self, obs) for proc, obs in deferred
        ]
        qkv = torch.cat(embeddings, dim=1)
        attention_embedding = self.rsa(qkv, masks)
        pieces = (
            [encoded_self, attention_embedding]
            if have_encoding
            else [attention_embedding]
        )
        encoded_self = torch.cat(pieces, dim=1)
        have_encoding = True

    if not have_encoding:
        raise Exception(
            "The trainer was unable to process any of the provided inputs. "
            "Make sure the trained agents has at least one sensor attached to them."
        )

    if actions is not None:
        encoded_self = torch.cat([encoded_self, actions], dim=1)
    encoding = self.linear_encoder(encoded_self)

    if self.use_lstm:
        # Resize to (batch, sequence length, encoding size)
        encoding = encoding.reshape([-1, sequence_length, self.h_size])
        encoding, memories = self.lstm(encoding, memories)
        encoding = encoding.reshape([-1, self.m_size // 2])
    return encoding, memories
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
    """
    Encode a list of observations using this module's processors and,
    when variable-length observations are present, its RSA.

    Note: the processors, the RSA, and the x_self encoder are attributes
    of this module (self.processors, self.rsa, self.x_self_encoder),
    not parameters of this method.

    :param inputs: List of Tensors, one per processor in self.processors.
    :return: The concatenated encoding of all inputs.
    :raises UnityTrainerException: if no input could be processed at all.
    """
    fixed_encodings: List[torch.Tensor] = []
    deferred: List[Tuple[nn.Module, torch.Tensor]] = []
    # EntityEmbedding inputs depend on the encoding of the fixed-size
    # inputs, so defer them; everything else can be encoded immediately.
    for idx, processor in enumerate(self.processors):
        if isinstance(processor, EntityEmbedding):
            deferred.append((processor, inputs[idx]))
        else:
            fixed_encodings.append(processor(inputs[idx]))

    have_encoding = bool(fixed_encodings)
    if have_encoding:
        encoded_self = torch.cat(fixed_encodings, dim=1)

    if deferred and self.rsa is not None:
        # Variable-length inputs: mask all-zero entities, embed each input
        # (conditioned on the self-encoding when available), then attend.
        masks = get_zero_entities_mask([obs for _, obs in deferred])
        processed_self = (
            self.x_self_encoder(encoded_self)
            if have_encoding and self.x_self_encoder is not None
            else None
        )
        embeddings: List[torch.Tensor] = [
            proc(processed_self, obs) for proc, obs in deferred
        ]
        qkv = torch.cat(embeddings, dim=1)
        attention_embedding = self.rsa(qkv, masks)
        pieces = (
            [encoded_self, attention_embedding]
            if have_encoding
            else [attention_embedding]
        )
        encoded_self = torch.cat(pieces, dim=1)
        have_encoding = True

    if not have_encoding:
        raise UnityTrainerException(
            "The trainer was unable to process any of the provided inputs. "
            "Make sure the trained agents has at least one sensor attached to them."
        )
    return encoded_self
def test_predict_minimum_training():
    """Train an EntityEmbedding + RSA stack to predict the argmin of a set.

    Of up to 5 numbers (each tagged with a positional one-hot), the network
    must learn to output the index of the minimum; the test asserts the
    cross-entropy loss converges below 0.1.
    """
    # of 5 numbers, predict index of min
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embedding = EntityEmbedding(size, n_k, embedding_size)  # no self
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embedding.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    # FIX: torch.range is deprecated (and slated for removal). arange(0, n_k)
    # yields the same values 0..n_k-1; cast to float to match torch.range's
    # old output dtype exactly.
    onehots = ModelUtils.actions_to_onehot(
        torch.arange(0, n_k, dtype=torch.float).unsqueeze(1), [n_k]
    )[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # create the target : The minimum
            argmin = torch.argmin(inp, dim=1)
            argmin = argmin.squeeze()
            argmin = argmin.detach()
        sliced_oh = onehots[:, : num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)
        embeddings = entity_embedding(inp, inp)
        masks = get_zero_entities_mask([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        print(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1
def test_predict_closest_training():
    """Train an EntityEmbedding + RSA stack to attend to the closest key.

    Given a query point and 5 candidate keys, the network must learn to
    reproduce the key nearest (in Euclidean distance) to the query; the
    test asserts the MSE converges below 0.02.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            # FIX: vectorized gather replaces the per-row Python loop +
            # torch.stack; advanced indexing yields the identical tensor.
            target = key[torch.arange(batch_size), argmin]
            target = target.detach()
        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        # FIX: call the module, not .forward(), so nn.Module hooks run.
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02