def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                   model_dict: ModelDict, batch_idx: np.ndarray) -> TensorDict:
    """Store a value derived from a model under ``self.target_tensor_key``.

    The model stored at ``self.model_key`` in ``model_dict`` is passed to
    ``self.attribute_lambda``; whatever that returns is written into
    ``tensor_dict``. ``data_dict`` and ``batch_idx`` are not used here.

    Returns:
        The same ``tensor_dict`` instance, updated in place.
    """
    tensor_dict.set(self.target_tensor_key,
                    self.attribute_lambda(model_dict.get(self.model_key)))
    return tensor_dict
def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                   model_dict: ModelDict, batch_idx: np.ndarray) -> TensorDict:
    """Apply ``self.transform_lambda`` to one tensor already in ``tensor_dict``.

    The tensor at ``self.source_tensor_key`` is transformed, moved to the
    module-level ``device`` and stored at ``self.target_tensor_key``.
    ``data_dict``, ``model_dict`` and ``batch_idx`` are not used here.

    Returns:
        The same ``tensor_dict`` instance, updated in place.
    """
    transform = self.transform_lambda
    result = transform(tensor_dict.get(self.source_tensor_key))
    tensor_dict.set(self.target_tensor_key, result.to(device))
    return tensor_dict
def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                   model_dict: ModelDict, batch_idx: np.ndarray) -> TensorDict:
    """Fetch raw data for a batch and store it as a 2-D tensor.

    ``self.get_data`` supplies the values, which are converted to a tensor of
    ``self.dtype``, moved to ``device`` and reshaped to
    ``(len(batch_idx), -1)`` before being stored at ``self.tensor_key``.

    Returns:
        The same ``tensor_dict`` instance, updated in place.
    """
    raw = self.get_data(data_dict, model_dict, batch_idx)
    batch_size = len(batch_idx)
    as_tensor = torch.as_tensor(raw, dtype=self.dtype)
    # One row per batch index; all remaining dimensions are flattened.
    tensor_dict.set(self.tensor_key,
                    as_tensor.to(device).reshape(batch_size, -1))
    return tensor_dict
def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                   model_dict: ModelDict, batch_idx: np.ndarray) -> TensorDict:
    """Store the element-wise sum of several tensors from ``tensor_dict``.

    The tensors at ``self.source_tensor_keys`` are stacked along a new
    leading axis (so they must share a shape) and summed over that axis.
    The result goes to ``self.target_tensor_key``.

    Returns:
        The same ``tensor_dict`` instance, updated in place.
    """
    stacked = torch.stack(list(map(tensor_dict.get, self.source_tensor_keys)))
    tensor_dict.set(self.target_tensor_key, stacked.sum(dim=0))
    return tensor_dict
def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                   model_dict: ModelDict, batch_idx: np.ndarray) -> TensorDict:
    """Run a model's forward pass and store its primary output.

    The input comes from ``self.get_model_input``. When ``self.noise_scale``
    exceeds 1e-9, ``self.add_noise`` is applied first. ``model.forward``
    must return a 2-tuple; only its first element is stored, at
    ``self.target_tensor_key``.

    Returns:
        The same ``tensor_dict`` instance, updated in place.
    """
    model = model_dict.get(self.model_key)
    model_input = self.get_model_input(tensor_dict)
    # Treat a negligible noise scale as "no noise".
    should_perturb = self.noise_scale > 1e-9
    if should_perturb:
        model_input = self.add_noise(model_input)
    output, _unused = model.forward(model_input)
    tensor_dict.set(self.target_tensor_key, output)
    return tensor_dict
def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                   model_dict: ModelDict, batch_idx: np.ndarray) -> TensorDict:
    """Sample from a model and store the samples and their log-probabilities.

    The input comes from ``self.get_model_input``. When ``self.noise_scale``
    exceeds 1e-9, ``self.add_noise`` is applied first. ``model.sample`` must
    return exactly four values; the first two are written to
    ``self.samples_tensor_key`` and ``self.log_probs_tensor_key``.

    Returns:
        The same ``tensor_dict`` instance, updated in place.
    """
    model = model_dict.get(self.model_key)
    model_input = self.get_model_input(tensor_dict)
    if not self.noise_scale <= 1e-9:
        model_input = self.add_noise(model_input)
    samples, log_probs, _extra_a, _extra_b = model.sample(model_input)
    tensor_dict.set(self.samples_tensor_key, samples)
    tensor_dict.set(self.log_probs_tensor_key, log_probs)
    return tensor_dict
def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                   model_dict: ModelDict, batch_idx: np.ndarray) -> TensorDict:
    """Apply a multi-input, multi-output transform that also sees the model.

    ``self.transform_lambda(model, inputs)`` receives the model at
    ``self.model_key`` and a list of the tensors at
    ``self.source_tensor_keys``. Its outputs are paired positionally with
    ``self.target_tensor_keys``. Because pairing uses ``zip``, extra outputs
    or extra keys are silently ignored.

    Returns:
        The same ``tensor_dict`` instance, updated in place.
    """
    model = model_dict.get(self.model_key)
    inputs = list(map(tensor_dict.get, self.source_tensor_keys))
    outputs = self.transform_lambda(model, inputs)
    for key, value in zip(self.target_tensor_keys, outputs):
        tensor_dict.set(key, value)
    return tensor_dict
def get_tensor_dict_list(self, tensor_inserter: TensorInserter,
                         data_dicts: List[DataDict],
                         model_dicts: List[ModelDict],
                         batch_idx) -> List[TensorDict]:
    """Build one TensorDict per (data_dict, model_dict) pair for a batch.

    Each pair is run through ``tensor_inserter.insert_tensors`` on a fresh
    ``TensorDict``. Afterwards, every dict receives an origin-label tensor at
    ``TensorKey.origins_tensor``: a long tensor of length ``len(batch_idx)``
    filled with the dict's position ``i`` in the list, placed on ``device``.

    Returns:
        The list of populated TensorDicts, in input order.
    """
    tensor_dicts = []
    for data_dict, model_dict in zip(data_dicts, model_dicts):
        tensor_dicts.append(tensor_inserter.insert_tensors(
            TensorDict(), data_dict, model_dict, batch_idx))

    # Second pass: label each row of the i-th dict with i.
    batch_size = len(batch_idx)
    for origin, tensor_dict in enumerate(tensor_dicts):
        labels = torch.full([batch_size], origin, dtype=torch.long).to(device)
        tensor_dict.set(TensorKey.origins_tensor, labels)
    return tensor_dicts
def _make_visualization_tensor_dict(encoded, scaled_states, origin):
    """Pack encoded states, per-row origin labels and scaled states into a TensorDict."""
    encoded_tensor = torch.as_tensor(encoded).float()
    tensor_dict = TensorDict()
    tensor_dict.set(TensorKey.encoded_states_tensor, encoded_tensor)
    # Exactly one label per encoded row (i.e. per *sampled* point). The labels
    # are long, matching the labels get_tensor_dict_list produces for training.
    tensor_dict.set(
        TensorKey.origins_tensor,
        torch.full([encoded_tensor.shape[0]], origin, dtype=torch.long))
    tensor_dict.set(TensorKey.states_tensor, scaled_states)
    return tensor_dict


def _plot_translations(points1, points2, translated1, arrow_fraction):
    """Scatter both point sets; draw arrows from a random subset of points1 to their translations."""
    plt.figure()
    plt.xlim(-2, 6)
    plt.ylim(-4, 4)
    plt.scatter(points1[:, 0], points1[:, 1], alpha=0.5)
    plt.scatter(points2[:, 0], points2[:, 1], alpha=0.5)
    diffs = translated1 - points1
    # Draw each arrow independently with probability `arrow_fraction`
    # so the plot stays readable.
    draw_mask = np.random.random(len(points1)) < arrow_fraction
    for i in np.flatnonzero(draw_mask).tolist():
        plt.arrow(points1[i, 0], points1[i, 1], diffs[i, 0], diffs[i, 1],
                  alpha=0.2, width=0.05, length_includes_head=True)


def _plot_encodings(points1_encoded, points2_encoded):
    """Scatter the first two latent dimensions of both encoded point sets."""
    plt.figure()
    plt.scatter(points1_encoded[:, 0], points1_encoded[:, 1], alpha=0.5)
    plt.scatter(points2_encoded[:, 0], points2_encoded[:, 1], alpha=0.5,
                c="C1")


def visualize_with_paths(data_dict_paths, model_dicts_path, num_samples=1000,
                         arrow_fraction=0.10):
    """Compare two datasets through a pair of models, numerically and visually.

    Steps:
      * Load two data dicts and a pair of model dicts.
      * Draw the same random subsample of row indices from both datasets'
        ``DataKey.states``.
      * Run each subsample through ``convert`` with the models in both orders.
      * Print the scaled reconstruction MSE for dataset 1 and the
        nearest-neighbour L2 loss between the encodings.
      * Show two figures: the raw states with arrows to their translations,
        and the encoded states.

    Args:
        data_dict_paths: Exactly two paths to ``torch.save``-d data dicts.
        model_dicts_path: Path to a ``torch.save``-d sequence of exactly two
            model dicts.
        num_samples: Number of rows to subsample. Capped at the number of
            rows in the first dataset.
        arrow_fraction: Probability that each sampled point gets an arrow in
            the first figure.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only use with
    # trusted files.
    data_dict1, data_dict2 = [
        torch.load(dataset_path, map_location="cpu")
        for dataset_path in data_dict_paths
    ]
    points1 = data_dict1.get(DataKey.states)
    points2 = data_dict2.get(DataKey.states)

    # The same indices are applied to both datasets. This assumes dataset 2
    # has at least as many rows as dataset 1 — TODO confirm they are paired.
    n_rows = points1.shape[0]
    sample_size = min(num_samples, n_rows)
    sample_idx = np.random.choice(n_rows, sample_size, replace=False)
    points1 = points1[sample_idx]
    points2 = points2[sample_idx]

    model_dict1, model_dict2 = torch.load(model_dicts_path, map_location="cpu")

    points1_decoded, points1_decoded_into_points2, points1_encoded = convert(
        model_dict1, model_dict2, points1)
    _, _, points2_encoded = convert(model_dict2, model_dict1, points2)

    state_scaler1 = model_dict1.get(ModelKey.state_scaler)
    state_scaler2 = model_dict2.get(ModelKey.state_scaler)
    points1_scaled_tensor = state_scaler1.forward(
        torch.as_tensor(points1).float())
    points1_decoded_scaled_tensor = state_scaler1.forward(
        torch.as_tensor(points1_decoded).float())
    points2_scaled_tensor = state_scaler2.forward(
        torch.as_tensor(points2).float())

    tensor_dict1 = _make_visualization_tensor_dict(
        points1_encoded, points1_scaled_tensor, origin=0)
    tensor_dict2 = _make_visualization_tensor_dict(
        points2_encoded, points2_scaled_tensor, origin=1)

    mse_loss = nn.MSELoss()
    print(mse_loss.forward(points1_decoded_scaled_tensor,
                           points1_scaled_tensor))
    nn_loss_calculator = LossCalculatorNearestNeighborL2(
        TensorKey.encoded_states_tensor, TensorKey.origins_tensor, 1.)
    nn_loss = nn_loss_calculator.get_loss([tensor_dict1, tensor_dict2])
    print(nn_loss)

    _plot_translations(points1, points2, points1_decoded_into_points2,
                       arrow_fraction)
    _plot_encodings(points1_encoded, points2_encoded)
    plt.show()
def get_loss(self, tensor_dicts: TensorDict):
    """Return ``self.weight`` times the mean squared value of one tensor.

    Despite its plural name, ``tensor_dicts`` is used as a single
    TensorDict. The tensor is read from ``self.input_tensor_key``.
    """
    values = tensor_dicts.get(self.input_tensor_key)
    mean_square = values.pow(2).mean()
    return self.weight * mean_square
def insert_tensors(self, tensor_dict: TensorDict, data_dict: DataDict,
                   model_dict: ModelDict, batch_idx: np.ndarray) -> TensorDict:
    """Store the output of the zero-argument ``self.generate_lambda``.

    The generated value is written to ``self.target_tensor_key``. None of
    ``data_dict``, ``model_dict`` or ``batch_idx`` is consulted.

    Returns:
        The same ``tensor_dict`` instance, updated in place.
    """
    tensor_dict.set(self.target_tensor_key, self.generate_lambda())
    return tensor_dict