Example #1
def test_predict_with_condition(num_cond_layers):
    np.random.seed(1336)
    torch.manual_seed(1336)
    input_size, goal_size, h, num_normal_layers = 10, 1, 16, 1

    conditional_enc = ConditionalEncoder(
        input_size, goal_size, h, num_normal_layers + num_cond_layers, num_cond_layers
    )
    l_layer = linear_layer(h, 1)

    optimizer = torch.optim.Adam(
        list(conditional_enc.parameters()) + list(l_layer.parameters()), lr=0.001
    )
    batch_size = 200
    for _ in range(300):
        input_tensor = torch.rand((batch_size, input_size))
        goal_tensor = (torch.rand((batch_size, goal_size)) > 0.5).float()
        # If the goal is 1, the target is the sum of the inputs; otherwise it is 0
        target = torch.sum(input_tensor, dim=1, keepdim=True) * goal_tensor
        target = target.detach()  # detach() is not in-place; keep the returned tensor
        prediction = l_layer(conditional_enc(input_tensor, goal_tensor))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2

        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
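
Stripped of the ml-agents classes, the same conditional-regression setup fits in a few lines. A minimal self-contained sketch, assuming a plain MLP over the concatenated input and goal stands in for ConditionalEncoder (a hypothetical simplification, not the library's implementation):

# Minimal sketch: a plain MLP on [input, goal] replaces ConditionalEncoder.
import torch

torch.manual_seed(1336)
input_size, goal_size, h, batch_size = 10, 1, 16, 200
model = torch.nn.Sequential(
    torch.nn.Linear(input_size + goal_size, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for _ in range(300):
    input_tensor = torch.rand((batch_size, input_size))
    goal_tensor = (torch.rand((batch_size, goal_size)) > 0.5).float()
    # Same target as above: the sum of the inputs when the goal is 1, else 0.
    target = torch.sum(input_tensor, dim=1, keepdim=True) * goal_tensor
    prediction = model(torch.cat([input_tensor, goal_tensor], dim=1))
    error = torch.mean((prediction - target) ** 2) / 2
    optimizer.zero_grad()
    error.backward()
    optimizer.step()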
Example #2
def compute_gradient_magnitude(self, policy_batch: AgentBuffer,
                               expert_batch: AgentBuffer) -> torch.Tensor:
    """
    Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
    especially for off-policy training. Computes gradients w.r.t. a randomly
    interpolated input.
    """
    policy_inputs = self.get_state_inputs(policy_batch)
    expert_inputs = self.get_state_inputs(expert_batch)
    interp_inputs = []
    for policy_input, expert_input in zip(policy_inputs, expert_inputs):
        obs_epsilon = torch.rand(policy_input.shape)
        interp_input = obs_epsilon * policy_input + (
            1 - obs_epsilon) * expert_input
        interp_input.requires_grad = True  # For gradient calculation
        interp_inputs.append(interp_input)
    if self._settings.use_actions:
        policy_action = self.get_action_input(policy_batch)
        expert_action = self.get_action_input(expert_batch)
        action_epsilon = torch.rand(policy_action.shape)
        policy_dones = torch.as_tensor(policy_batch[BufferKey.DONE],
                                       dtype=torch.float).unsqueeze(1)
        expert_dones = torch.as_tensor(expert_batch[BufferKey.DONE],
                                       dtype=torch.float).unsqueeze(1)
        dones_epsilon = torch.rand(policy_dones.shape)
        action_inputs = torch.cat(
            [
                action_epsilon * policy_action +
                (1 - action_epsilon) * expert_action,
                dones_epsilon * policy_dones +
                (1 - dones_epsilon) * expert_dones,
            ],
            dim=1,
        )
        action_inputs.requires_grad = True
        hidden, _ = self.encoder(interp_inputs, action_inputs)
        encoder_input = tuple(interp_inputs + [action_inputs])
    else:
        hidden, _ = self.encoder(interp_inputs)
        encoder_input = tuple(interp_inputs)
    if self._settings.use_vail:
        use_vail_noise = True
        z_mu = self._z_mu_layer(hidden)
        hidden = z_mu + torch.randn_like(
            z_mu) * self._z_sigma * use_vail_noise
    estimate = self._estimator(hidden).squeeze(1).sum()
    gradient = torch.autograd.grad(estimate,
                                   encoder_input,
                                   create_graph=True)[0]
    # The gradient of the norm can be NaN at 0, so use a safe norm instead.
    safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
    gradient_mag = torch.mean((safe_norm - 1)**2)
    return gradient_mag
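
Stripped of the buffer and encoder plumbing, the penalty reduces to the standard WGAN-GP term. A minimal sketch, assuming a single 2-D observation tensor and any callable discriminator:

import torch

def gradient_penalty(discriminator, policy_obs, expert_obs, eps=1e-7):
    # Interpolate between policy and expert samples with a random mixing factor.
    epsilon = torch.rand((policy_obs.shape[0], 1))
    interp = epsilon * policy_obs + (1 - epsilon) * expert_obs
    interp.requires_grad = True
    estimate = discriminator(interp).sum()
    gradient = torch.autograd.grad(estimate, interp, create_graph=True)[0]
    # eps keeps the square root differentiable when the gradient is 0.
    safe_norm = (torch.sum(gradient ** 2, dim=1) + eps).sqrt()
    return torch.mean((safe_norm - 1) ** 2)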
Example #3
def test_simple_transformer_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
    transformer = ResidualSelfAttention(embedding_size, [n_k])
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(list(transformer.parameters()) +
                                 list(l_layer.parameters()),
                                 lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(250):
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand(
            (batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.3
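
As a side note, the per-sample Python loop that gathers the closest key can be replaced by a single torch.gather call. An equivalent sketch with the shapes used above:

# Vectorized form of the target-building loop: picks key[i, argmin[i], :].
index = argmin.reshape((batch_size, 1, 1)).expand(-1, -1, size)
target = torch.gather(key, dim=1, index=index).squeeze(1)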
Example #4
    def compute_gradient_magnitude(self, policy_batch: AgentBuffer,
                                   expert_batch: AgentBuffer) -> torch.Tensor:
        """
        Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
        especially for off-policy training. Computes gradients w.r.t. a randomly
        interpolated input.
        """
        policy_obs = self.get_state_encoding(policy_batch)
        expert_obs = self.get_state_encoding(expert_batch)
        obs_epsilon = torch.rand(policy_obs.shape)
        encoder_input = obs_epsilon * policy_obs + (1 -
                                                    obs_epsilon) * expert_obs
        if self._settings.use_actions:
            policy_action = self.get_action_input(policy_batch)
            expert_action = self.get_action_input(expert_batch)
            action_epsilon = torch.rand(policy_action.shape)
            policy_dones = torch.as_tensor(policy_batch["done"],
                                           dtype=torch.float).unsqueeze(1)
            expert_dones = torch.as_tensor(expert_batch["done"],
                                           dtype=torch.float).unsqueeze(1)
            dones_epsilon = torch.rand(policy_dones.shape)
            encoder_input = torch.cat(
                [
                    encoder_input,
                    action_epsilon * policy_action +
                    (1 - action_epsilon) * expert_action,
                    dones_epsilon * policy_dones +
                    (1 - dones_epsilon) * expert_dones,
                ],
                dim=1,
            )
        hidden = self.encoder(encoder_input)
        if self._settings.use_vail:
            use_vail_noise = True
            z_mu = self._z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
        estimate = self._estimator(hidden).squeeze(1).sum()

        gradient = torch.autograd.grad(estimate,
                                       encoder_input,
                                       create_graph=True)[0]
        # The gradient of the norm can be NaN at 0, so use a safe norm instead.
        safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
        gradient_mag = torch.mean((safe_norm - 1)**2)
        return gradient_mag
Example #5
def test_predict_closest_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
Example #6
def test_all_masking(mask_value):
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2,
                                 dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
Example #7
def test_multi_head_attention_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(50):
        query = torch.rand(
            (batch_size, n_q, size)) * point_range * 2 - point_range
        key = torch.rand(
            (batch_size, n_k, size)) * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((query - key)**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5
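
For comparison, PyTorch's built-in torch.nn.MultiheadAttention can attend over the same shapes. A rough sketch, assuming batch_first tensors, hypothetical query/output projections, and the dimensions used in the test:

import torch

batch_size, n_q, n_k, size, embedding_size = 200, 1, 5, 3, 64
mha = torch.nn.MultiheadAttention(
    embedding_size, num_heads=8, kdim=size, vdim=size, batch_first=True
)
query_proj = torch.nn.Linear(size, embedding_size)   # lift queries to embed dim
output_proj = torch.nn.Linear(embedding_size, size)  # map back to a 3-D point

query = torch.rand((batch_size, n_q, size))
key = torch.rand((batch_size, n_k, size))
attn_out, attn_weights = mha(query_proj(query), key, key)
prediction = output_proj(attn_out).reshape((batch_size, size))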
Example #8
def generate_input_helper(pattern):
    # batch_size and size are assumed to come from the enclosing test scope.
    # Even-indexed entries of `pattern` append blocks of random entities;
    # odd-indexed entries append blocks of zero (masked-out) entities.
    _input = torch.zeros((batch_size, 0, size))
    for i in range(len(pattern)):
        if i % 2 == 0:
            _input = torch.cat(
                [_input,
                 torch.rand((batch_size, pattern[i], size))],
                dim=1)
        else:
            _input = torch.cat(
                [_input,
                 torch.zeros((batch_size, pattern[i], size))],
                dim=1)
    return _input
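
For example, a pattern of [2, 1, 3] produces six entities per sample, with the single odd-position block zeroed out. A usage sketch, assuming batch_size and size are defined as in the test:

batch_size, size = 4, 3
x = generate_input_helper([2, 1, 3])
assert x.shape == (batch_size, 6, size)
assert torch.all(x[:, 2] == 0)  # entities from odd pattern entries are zeros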
Example #9
def test_layer_norm():
    torch.manual_seed(0)
    torch_ln = torch.nn.LayerNorm(10, elementwise_affine=False)
    cust_ln = LayerNorm()

    sample_input = torch.rand(10)
    assert torch.all(
        torch.isclose(
            torch_ln(sample_input), cust_ln(sample_input), atol=1e-5, rtol=0.0
        )
    )
    sample_input = torch.rand((4, 10))
    assert torch.all(
        torch.isclose(
            torch_ln(sample_input), cust_ln(sample_input), atol=1e-5, rtol=0.0
        )
    )
    sample_input = torch.rand((7, 6, 10))
    assert torch.all(
        torch.isclose(
            torch_ln(sample_input), cust_ln(sample_input), atol=1e-5, rtol=0.0
        )
    )
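
The custom LayerNorm under test normalizes only the last dimension and learns no affine parameters. A minimal sketch of such a layer, assuming the usual 1e-5 epsilon:

import torch

class SimpleLayerNorm(torch.nn.Module):
    # Zero-mean, unit-variance normalization over the last dimension,
    # matching torch.nn.LayerNorm(..., elementwise_affine=False).
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        return (x - mean) / torch.sqrt(var + 1e-5)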
Example #10
def test_predict_minimum_training():
    # Of 5 numbers, predict the index of the minimum
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size],
                                         embedding_size, [n_k],
                                         concat_self=False)
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )

    batch_size = 200
    onehots = ModelUtils.actions_to_onehot(
        torch.arange(0, n_k, dtype=torch.float).unsqueeze(1), [n_k])[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # Create the target: the index of the minimum value
            argmin = torch.argmin(inp, dim=1)
            argmin = argmin.squeeze()
            argmin = argmin.detach()
        sliced_oh = onehots[:, :num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)

        embeddings = entity_embeddings(inp, [inp])
        masks = EntityEmbeddings.get_masks([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        print(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1
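
The ModelUtils.actions_to_onehot call above just builds one one-hot row per index and broadcasts it across the batch. An equivalent sketch with torch.nn.functional.one_hot, assuming the same shapes:

import torch

n_k, batch_size = 5, 200
# One row per index 0..n_k-1, then expand across the batch dimension.
onehots = torch.nn.functional.one_hot(torch.arange(n_k), n_k).float()
onehots = onehots.unsqueeze(0).expand((batch_size, -1, -1))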