Example #1
def get_tensor_dict_list(self, tensor_inserter: TensorInserter,
                         data_dicts: List[DataDict],
                         model_dicts: List[ModelDict],
                         batch_idx) -> List[TensorDict]:
    # Build one TensorDict per (data_dict, model_dict) pair for this batch.
    tensor_dicts = [
        tensor_inserter.insert_tensors(TensorDict(), data_dict, model_dict,
                                       batch_idx)
        for data_dict, model_dict in zip(data_dicts, model_dicts)
    ]
    # Label each TensorDict with the index of the dataset it came from;
    # `device` is assumed to be defined in the enclosing module.
    for i, tensor_dict in enumerate(tensor_dicts):
        origins_tensor = torch.full([len(batch_idx)], i).long().to(device)
        tensor_dict.set(TensorKey.origins_tensor, origins_tensor)
    return tensor_dicts
Example #2
def visualize_with_paths(data_dict_paths, model_dicts_path):
    data_dicts = [
        torch.load(dataset_path, map_location="cpu")
        for dataset_path in data_dict_paths
    ]
    data_dict1, data_dict2 = data_dicts

    points1 = data_dict1.get(DataKey.states)
    points2 = data_dict2.get(DataKey.states)

    n, d = points1.shape
    sample_idx = np.random.choice(n, 1000, replace=False)
    points1 = points1[sample_idx]
    points2 = points2[sample_idx]

    model_dicts = torch.load(model_dicts_path, map_location="cpu")
    model_dict1, model_dict2 = model_dicts
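    # Convert each point set: decoded with its own model, decoded into the
    # other dataset's state space, and its latent encoding.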
    points1_decoded, points1_decoded_into_points2, points1_encoded = convert(
        model_dict1, model_dict2, points1)
    points2_decoded, points2_decoded_into_points1, points2_encoded = convert(
        model_dict2, model_dict1, points2)

    state_scaler1 = model_dict1.get(ModelKey.state_scaler)
    state_scaler2 = model_dict2.get(ModelKey.state_scaler)

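    # Loss terms: a nearest-neighbour L2 loss over encoded states (keyed by
    # origin) and a plain MSE loss.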
    nn_loss_calculator = LossCalculatorNearestNeighborL2(
        TensorKey.encoded_states_tensor, TensorKey.origins_tensor, 1.)
    mse_loss = nn.MSELoss()

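    # TensorDict for dataset 1: latent encodings, origin label 0, and scaled
    # states (the decoded, scaled version is kept for the MSE check below).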
    tensor_dict1 = TensorDict()
    points1_tensor = torch.as_tensor(points1).float()
    points1_scaled_tensor = state_scaler1.forward(points1_tensor)
    points1_decoded_tensor = torch.as_tensor(points1_decoded).float()
    points1_decoded_scaled_tensor = state_scaler1.forward(
        points1_decoded_tensor)
    tensor_dict1.set(TensorKey.encoded_states_tensor,
                     torch.as_tensor(points1_encoded).float())
    tensor_dict1.set(TensorKey.origins_tensor,
                     torch.zeros(len(points1)))  # one label per sampled point
    tensor_dict1.set(TensorKey.states_tensor, points1_scaled_tensor)

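    # Same for dataset 2, with origin label 1.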
    tensor_dict2 = TensorDict()
    points2_tensor = torch.as_tensor(points2).float()
    points2_scaled_tensor = state_scaler2.forward(points2_tensor)
    tensor_dict2.set(TensorKey.encoded_states_tensor,
                     torch.as_tensor(points2_encoded).float())
    tensor_dict2.set(TensorKey.origins_tensor,
                     torch.ones(len(points2)))  # one label per sampled point
    tensor_dict2.set(TensorKey.states_tensor, points2_scaled_tensor)

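    # Reconstruction MSE for dataset 1, computed in scaled state space.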
    print(
        mse_loss.forward(points1_decoded_scaled_tensor, points1_scaled_tensor))

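    # Nearest-neighbour loss between the two encoded point sets.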
    nn_loss = nn_loss_calculator.get_loss([tensor_dict1, tensor_dict2])
    print(nn_loss)

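    # Figure 1: the sampled points from both datasets in the original state space.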
    plt.figure()
    plt.xlim(-2, 6)
    plt.ylim(-4, 4)
    plt.scatter(points1[:, 0], points1[:, 1], alpha=0.5)
    plt.scatter(points2[:, 0], points2[:, 1], alpha=0.5)

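    # Draw arrows from a random ~10% of dataset-1 points to where they land
    # when decoded into dataset 2's state space.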
    diffs = points1_decoded_into_points2 - points1
    n, _ = points1.shape
    for i in range(n):
        eps = np.random.random()
        if eps < 0.10:
            plt.arrow(points1[i, 0],
                      points1[i, 1],
                      diffs[i, 0],
                      diffs[i, 1],
                      alpha=0.2,
                      width=0.05,
                      length_includes_head=True)

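    # Figure 2: both point sets in the encoded (latent) space.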
    plt.figure()
    plt.scatter(points1_encoded[:, 0], points1_encoded[:, 1], alpha=0.5)
    plt.scatter(points2_encoded[:, 0],
                points2_encoded[:, 1],
                alpha=0.5,
                c="C1")
    plt.show()
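
A minimal invocation sketch (the file names and the __main__ guard are placeholders, not part of the original code):

if __name__ == "__main__":
    # Hypothetical paths: two saved DataDict files and one file containing both ModelDicts.
    visualize_with_paths(["data_dict_a.pt", "data_dict_b.pt"], "model_dicts.pt")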