示例#1
0
 def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                    model_dict: ModelDict,
                    batch_idx: np.ndarray) -> TensorDict:
     """Extract an attribute from the configured model and store it.

     Looks up the model under ``self.model_key``, applies
     ``self.attribute_lambda`` to it, and writes the result into
     ``tensor_dict`` under ``self.target_tensor_key``.
     """
     extracted = self.attribute_lambda(model_dict.get(self.model_key))
     tensor_dict.set(self.target_tensor_key, extracted)
     return tensor_dict
示例#2
0
 def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                    model_dict: ModelDict,
                    batch_idx: np.ndarray) -> TensorDict:
     """Transform a source tensor and store the device-moved result.

     Reads ``self.source_tensor_key``, applies ``self.transform_lambda``,
     moves the output to the module-level ``device``, and writes it under
     ``self.target_tensor_key``.
     """
     # Fetch, transform, and move to the target device in one chain.
     result = self.transform_lambda(
         tensor_dict.get(self.source_tensor_key)).to(device)
     tensor_dict.set(self.target_tensor_key, result)
     return tensor_dict
示例#3
0
 def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                    model_dict: ModelDict,
                    batch_idx: np.ndarray) -> TensorDict:
     """Fetch raw data for the batch and store it as a 2-D tensor.

     The data returned by ``self.get_data`` is converted to ``self.dtype``,
     moved to the module-level ``device``, and flattened so its leading
     dimension equals ``len(batch_idx)``.
     """
     raw = self.get_data(data_dict, model_dict, batch_idx)
     as_tensor = torch.as_tensor(raw, dtype=self.dtype).to(device)
     # Flatten each sample; batch dimension matches the index array length.
     flattened = as_tensor.reshape(len(batch_idx), -1)
     tensor_dict.set(self.tensor_key, flattened)
     return tensor_dict
示例#4
0
 def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                    model_dict: ModelDict,
                    batch_idx: np.ndarray) -> TensorDict:
     """Element-wise sum of several source tensors.

     Stacks the tensors found under ``self.source_tensor_keys`` and stores
     their sum over the stacking dimension under ``self.target_tensor_key``.
     """
     stacked = torch.stack(
         [tensor_dict.get(key) for key in self.source_tensor_keys])
     tensor_dict.set(self.target_tensor_key, torch.sum(stacked, dim=0))
     return tensor_dict
示例#5
0
 def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                    model_dict: ModelDict,
                    batch_idx: np.ndarray) -> TensorDict:
     """Run the configured model forward and store its primary output.

     Optionally perturbs the input via ``self.add_noise`` when
     ``self.noise_scale`` is effectively non-zero; the second element of
     the model's forward result is discarded.
     """
     model = model_dict.get(self.model_key)
     model_input = self.get_model_input(tensor_dict)
     # Threshold guards against a float-zero noise scale.
     if self.noise_scale > 1e-9:
         model_input = self.add_noise(model_input)
     output, _ = model.forward(model_input)
     tensor_dict.set(self.target_tensor_key, output)
     return tensor_dict
示例#6
0
 def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                    model_dict: ModelDict,
                    batch_idx: np.ndarray) -> TensorDict:
     """Sample from the configured model and store samples and log-probs.

     Optionally perturbs the input via ``self.add_noise`` when
     ``self.noise_scale`` is effectively non-zero; the trailing two
     elements of the model's sample result are discarded.
     """
     model = model_dict.get(self.model_key)
     model_input = self.get_model_input(tensor_dict)
     # Threshold guards against a float-zero noise scale.
     if self.noise_scale > 1e-9:
         model_input = self.add_noise(model_input)
     samples, log_probs, _, _ = model.sample(model_input)
     tensor_dict.set(self.samples_tensor_key, samples)
     tensor_dict.set(self.log_probs_tensor_key, log_probs)
     return tensor_dict
示例#7
0
 def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                    model_dict: ModelDict,
                    batch_idx: np.ndarray) -> TensorDict:
     """Apply a model-aware transform to several source tensors.

     Collects the tensors under ``self.source_tensor_keys``, passes them
     with the model to ``self.transform_lambda``, and stores each result
     under the corresponding key in ``self.target_tensor_keys``.
     """
     model = model_dict.get(self.model_key)
     inputs = [tensor_dict.get(k) for k in self.source_tensor_keys]
     outputs = self.transform_lambda(model, inputs)
     # Pair each output with its target key; extras on either side are
     # dropped by zip, matching the original pairing behavior.
     for target_key, output in zip(self.target_tensor_keys, outputs):
         tensor_dict.set(target_key, output)
     return tensor_dict
示例#8
0
 def get_tensor_dict_list(self, tensor_inserter: TensorInserter,
                          data_dicts: List[DataDict],
                          model_dicts: List[ModelDict],
                          batch_idx) -> List[TensorDict]:
     """Build one TensorDict per (data_dict, model_dict) pair.

     Each TensorDict is filled by ``tensor_inserter`` and then tagged with
     an integer origin label (its pair's position) so downstream consumers
     can tell which source produced which rows.
     """
     tensor_dicts = []
     for origin, (data_dict, model_dict) in enumerate(
             zip(data_dicts, model_dicts)):
         tensor_dict = tensor_inserter.insert_tensors(
             TensorDict(), data_dict, model_dict, batch_idx)
         # One label per batch element, constant within a dict.
         labels = torch.full([len(batch_idx)], origin).long().to(device)
         tensor_dict.set(TensorKey.origins_tensor, labels)
         tensor_dicts.append(tensor_dict)
     return tensor_dicts
示例#9
0
def visualize_with_paths(data_dict_paths, model_dicts_path, n_samples=1000):
    """Visualize two datasets and their cross-model encodings/decodings.

    Loads two serialized DataDicts and a pair of ModelDicts, subsamples
    ``n_samples`` states from each dataset, converts them through both
    models, prints an MSE reconstruction loss and a nearest-neighbor loss,
    and shows scatter plots of the raw points (with decode arrows) and the
    encoded points.

    Args:
        data_dict_paths: paths of exactly two serialized DataDicts.
        model_dicts_path: path of a serialized pair of ModelDicts.
        n_samples: number of points sampled from each dataset; defaults to
            1000, preserving the previously hard-coded behavior.
    """
    data_dicts = [
        torch.load(dataset_path, map_location="cpu")
        for dataset_path in data_dict_paths
    ]
    data_dict1, data_dict2 = data_dicts

    points1 = data_dict1.get(DataKey.states)
    points2 = data_dict2.get(DataKey.states)

    n, _ = points1.shape
    # NOTE(review): the same index set is applied to both datasets; this
    # assumes dataset 2 has at least as many rows as dataset 1 — confirm.
    sample_idx = np.random.choice(range(n), n_samples, replace=False)
    points1 = points1[sample_idx]
    points2 = points2[sample_idx]

    model_dicts = torch.load(model_dicts_path, map_location="cpu")
    model_dict1, model_dict2 = model_dicts
    points1_decoded, points1_decoded_into_points2, points1_encoded = convert(
        model_dict1, model_dict2, points1)
    points2_decoded, points2_decoded_into_points1, points2_encoded = convert(
        model_dict2, model_dict1, points2)

    state_scaler1 = model_dict1.get(ModelKey.state_scaler)
    state_scaler2 = model_dict2.get(ModelKey.state_scaler)

    nn_loss_calculator = LossCalculatorNearestNeighborL2(
        TensorKey.encoded_states_tensor, TensorKey.origins_tensor, 1.)
    mse_loss = nn.MSELoss()

    tensor_dict1 = TensorDict()
    points1_tensor = torch.as_tensor(points1).float()
    points1_scaled_tensor = state_scaler1.forward(points1_tensor)
    points1_decoded_tensor = torch.as_tensor(points1_decoded).float()
    points1_decoded_scaled_tensor = state_scaler1.forward(
        points1_decoded_tensor)
    tensor_dict1.set(TensorKey.encoded_states_tensor,
                     torch.as_tensor(points1_encoded).float())
    # Bug fix: origin labels must match the subsampled length, not the
    # original dataset size `n` — the other tensors here hold only the
    # sampled rows.
    tensor_dict1.set(TensorKey.origins_tensor, torch.zeros(len(points1)))
    tensor_dict1.set(TensorKey.states_tensor, points1_scaled_tensor)

    tensor_dict2 = TensorDict()
    points2_tensor = torch.as_tensor(points2).float()
    points2_scaled_tensor = state_scaler2.forward(points2_tensor)
    tensor_dict2.set(TensorKey.encoded_states_tensor,
                     torch.as_tensor(points2_encoded).float())
    # Same fix as above for the second dataset's labels.
    tensor_dict2.set(TensorKey.origins_tensor, torch.ones(len(points2)))
    tensor_dict2.set(TensorKey.states_tensor, points2_scaled_tensor)

    print(
        mse_loss.forward(points1_decoded_scaled_tensor, points1_scaled_tensor))

    nn_loss = nn_loss_calculator.get_loss([tensor_dict1, tensor_dict2])
    print(nn_loss)

    plt.figure()
    plt.xlim(-2, 6)
    plt.ylim(-4, 4)
    plt.scatter(points1[:, 0], points1[:, 1], alpha=0.5)
    plt.scatter(points2[:, 0], points2[:, 1], alpha=0.5)

    # Draw arrows from a ~10% random subset of points1 toward their images
    # in dataset 2's space, to keep the plot readable.
    diffs = points1_decoded_into_points2 - points1
    for i in range(len(points1)):
        if np.random.random() < 0.10:
            plt.arrow(points1[i, 0],
                      points1[i, 1],
                      diffs[i, 0],
                      diffs[i, 1],
                      alpha=0.2,
                      width=0.05,
                      length_includes_head=True)

    plt.figure()
    plt.scatter(points1_encoded[:, 0], points1_encoded[:, 1], alpha=0.5)
    plt.scatter(points2_encoded[:, 0],
                points2_encoded[:, 1],
                alpha=0.5,
                c="C1")
    plt.show()
示例#10
0
 def get_loss(self, tensor_dicts: TensorDict):
     """Return the weighted mean of squared values of the input tensor."""
     squared = tensor_dicts.get(self.input_tensor_key) ** 2
     return self.weight * torch.mean(squared)
示例#11
0
 def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                    model_dict: ModelDict,
                    batch_idx: np.ndarray) -> TensorDict:
     """Store a freshly generated tensor under the target key.

     ``self.generate_lambda`` takes no arguments; neither the data nor the
     models nor the batch indices influence the generated tensor.
     """
     generated = self.generate_lambda()
     tensor_dict.set(self.target_tensor_key, generated)
     return tensor_dict