Example 1
def test_remove_dim_np():
    shape = [2, 3, 4, 5]
    arr = np.zeros(shape)
    new_arr = pt_util.remove_dim(arr, 1)
    assert new_arr.shape == (6, 4, 5)

    new_arr = pt_util.remove_dim(arr, 2)
    assert new_arr.shape == (2, 12, 5)

    new_arr = pt_util.remove_dim(arr, 3)
    assert new_arr.shape == (2, 3, 20)
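
A minimal sketch of what remove_dim might do for a single integer dim, inferred from the assertions above (a guess, not the actual pt_util source): it merges dimension dim into the preceding dimension with a reshape.

import numpy as np

def remove_dim_single(arr, dim):
    # Hypothetical reimplementation: merge dimension `dim` into `dim - 1`.
    d = dim % arr.ndim  # also accepts negative indices
    assert d > 0, "dimension 0 has no preceding dimension to merge into"
    shape = arr.shape
    return arr.reshape(shape[:d - 1] + (shape[d - 1] * shape[d],) + shape[d + 1:])

assert remove_dim_single(np.zeros((2, 3, 4, 5)), 1).shape == (6, 4, 5)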
Example 2
def process_video_data(batch):
    data = pt_util.remove_dim(batch["data"], 1)
    queue_data = pt_util.remove_dim(batch["queue_data"], 1)
    batch = {
        "data": data,
        "queue_data": queue_data,
        "data_source": "YT",
        "batch_type": "video",
        "batch_size": len(data),
        "num_frames": args.num_frames,
        "imagenet_labels": torch.full((len(data),), -1, dtype=torch.int64),
    }
    return batch
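
A hypothetical call, assuming batch["data"] arrives as a [B, T, C, H, W] clip tensor and that args.num_frames matches the clip length (in the original script, args is a module-level argparse namespace; the stub below only illustrates the shapes):

import types
import torch

args = types.SimpleNamespace(num_frames=8)  # stand-in for the script's parsed arguments

clips = torch.zeros(4, 8, 3, 224, 224)      # [B=4, T=8, C, H, W]
out = process_video_data({"data": clips, "queue_data": clips.clone()})
assert out["data"].shape == (32, 3, 224, 224)        # T folded into B
assert out["imagenet_labels"].tolist() == [-1] * 32  # placeholder labels for video frames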
Example 3
def test_remove_dim_np_multi_dim():
    shape = [2, 3, 4, 5]
    arr = np.zeros(shape)
    new_arr = pt_util.remove_dim(arr, (1, 2))
    assert new_arr.shape == (24, 5)

    new_arr = pt_util.remove_dim(arr, (1, 3))
    assert new_arr.shape == (6, 20)

    new_arr = pt_util.remove_dim(arr, (3, 1))
    assert new_arr.shape == (6, 20)

    new_arr = pt_util.remove_dim(arr, (2, 3))
    assert new_arr.shape == (2, 60)
Example 4
    def convert_batch(self, batch, batch_type: str = "train") -> Dict:
        batch["data_source"] = "Kinetics400"
        data = batch["data"]
        batch_size, seq_len = data.shape[:2]
        data = pt_util.remove_dim(data, 1)
        batch["data"] = data
        batch["batch_type"] = ("images", len(batch["data"]))
        batch["batch_size"] = batch_size
        batch["num_frames"] = seq_len
        return super(EndTaskKinetics400Solver, self).convert_batch(batch)
Example 5
    def process_imagenet_data(self, data):
        images, labels = data
        data = images[:self.num_frames]
        queue_data = images[self.num_frames:]
        if self.num_frames > 1:
            data = pt_util.remove_dim(torch.stack(data, dim=1), 1)
            queue_data = pt_util.remove_dim(torch.stack(queue_data, dim=1), 1)
            labels = labels.repeat_interleave(self.num_frames)
        else:
            data = data[0]
            queue_data = queue_data[0]
        batch = {
            "data": data,
            "queue_data": queue_data,
            "imagenet_labels": labels,
            "data_source": "IN",
            "num_frames": self.num_frames,
            "batch_type": "images",
            "batch_size": len(data),
        }
        return batch
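
The repeat_interleave call keeps the labels aligned with the flattened frames: after torch.stack(data, dim=1) and remove_dim(..., 1), the num_frames views of each image sit next to each other, so each label must be repeated num_frames times in place. A quick illustration:

import torch

labels = torch.tensor([7, 9])       # labels for two images
print(labels.repeat_interleave(3))  # tensor([7, 7, 7, 9, 9, 9]), one copy per frame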
Example 6
def get_edges(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    # if len(image.shape) == 3:
    # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    start_t = time.time()
    if image.shape[-1] == 3:
        image = (
            np.float32(0.299) * image[..., 0] + np.float32(0.587) * image[..., 1] + np.float32(0.114) * image[..., 2]
        )
    # edge = cv2.Canny(image, 0, 50)
    image_16 = image.astype(np.int16)
    x_edge_array = EDGE_ARRAY
    y_edge_array = EDGE_ARRAY.T
    while len(x_edge_array.shape) < len(image_16.shape):
        x_edge_array = x_edge_array[np.newaxis, ...]
        y_edge_array = y_edge_array[np.newaxis, ...]

    # Work around an OpenCV quirk: filter2D can reject some input shapes, so
    # fall back to scipy's correlate when it raises.
    try:
        edge1 = cv2.filter2D(image_16, -1, x_edge_array, borderType=cv2.BORDER_REFLECT)
    except cv2.error:
        edge1 = scipy.ndimage.correlate(image_16, x_edge_array, mode="reflect")
    try:
        edge2 = cv2.filter2D(image_16, -1, y_edge_array, borderType=cv2.BORDER_REFLECT)
    except cv2.error:
        edge2 = scipy.ndimage.correlate(image_16, y_edge_array, mode="reflect")

    edge1 = np.abs(edge1)
    edge2 = np.abs(edge2)
    edge = (edge1 > 10) | (edge2 > 10)
    edge = edge.astype(np.uint8) * 255  # cv2.dilate expects a uint8 image
    dilation_size = int(min(edge.shape[-2], edge.shape[-1]) * 0.01)
    if dilation_size > 1:
        dilation_kernel = np.ones((dilation_size, dilation_size), dtype=np.uint8)
        original_edge_shape = edge.shape
        while len(edge.shape) > 2:
            edge = pt_util.remove_dim(edge, 1)
        dilated = cv2.dilate(edge, dilation_kernel)
        edge = edge.reshape(original_edge_shape)
        dilated = dilated.reshape(original_edge_shape)
    else:
        dilated = edge
    inverted = 255 - dilated
    print("edges time", time.time() - start_t)
    return edge, inverted
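
EDGE_ARRAY is a module-level kernel that is not shown in this snippet. One plausible definition consistent with its use here (transposed for the y direction, int16 to match image_16) is a Sobel x-kernel; this is an assumption, not the original constant:

import numpy as np

# Hypothetical stand-in for the module-level EDGE_ARRAY constant.
EDGE_ARRAY = np.array([[-1, 0, 1],
                       [-2, 0, 2],
                       [-1, 0, 1]], dtype=np.int16)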
Example 7
def test_remove_dim_np_neg_dim():
    shape = [2, 3, 4, 5]
    arr = np.zeros(shape)
    new_arr = pt_util.remove_dim(arr, -1)
    assert new_arr.shape == (2, 3, 20)

    new_arr = pt_util.remove_dim(arr, -2)
    assert new_arr.shape == (2, 12, 5)

    new_arr = pt_util.remove_dim(arr, (-2, -1))
    assert new_arr.shape == (2, 60)

    new_arr = pt_util.remove_dim(arr, (-1, -2))
    assert new_arr.shape == (2, 60)

    new_arr = pt_util.remove_dim(arr, (1, -1))
    assert new_arr.shape == (6, 20)

    new_arr = pt_util.remove_dim(arr, (-1, 1))
    assert new_arr.shape == (6, 20)
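
Extending the sketch from Example 1 to cover the tuple and negative-dim assertions (again inferred behavior, not the actual pt_util code): normalize negative indices against the original number of dimensions, then collapse the highest index first so the lower ones stay valid.

import numpy as np

def remove_dim_multi(arr, dims):
    # Hypothetical reimplementation: collapse each listed dimension into its
    # predecessor, processing indices in descending order so that removing one
    # dimension never shifts the indices still to be processed.
    if isinstance(dims, int):
        dims = (dims,)
    for d in sorted((d % arr.ndim for d in dims), reverse=True):
        shape = arr.shape
        arr = arr.reshape(shape[:d - 1] + (shape[d - 1] * shape[d],) + shape[d + 1:])
    return arr

assert remove_dim_multi(np.zeros((2, 3, 4, 5)), (-1, 1)).shape == (6, 20)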
Example 8
    def get_embeddings(self, inputs, jigsaw=False, shuffle=False):
        data = inputs["data"]
        if shuffle:
            # Shuffle
            shuffle_order = torch.randperm(data.shape[0], device=self.device)
            unshuffle_order = torch.zeros(data.shape[0],
                                          dtype=torch.int64,
                                          device=self.device)
            unshuffle_order.index_copy_(
                0, shuffle_order,
                torch.arange(data.shape[0], device=self.device))
            data = data[shuffle_order].contiguous()

        if jigsaw:
            if (data.shape[2] % 3) != 0 or (data.shape[3] % 3) != 0:
                data = F.pad(
                    data, (0, 3 - data.shape[3] % 3, 0, 3 - data.shape[2] % 3))
            # [N, C, H, W]
            data = pt_util.split_dim(data, 2, 3, data.shape[2] // 3)
            # [N, C, 3, H/3, W]
            data = pt_util.split_dim(data, 4, 3, data.shape[4] // 3)
            # [N, C, 3, H/3, 3, W/3]
            data = data.permute(0, 2, 4, 1, 3, 5).contiguous()
            # [N, 3, 3, C, H/3, W/3]
            data = pt_util.remove_dim(data, (1, 2))
            # [N*9, C, H/3, W/3]

        images = data.to(self.feature_extractor_device)
        return_val = self.extract_features(images)
        features = return_val["extracted_features"]

        if jigsaw:
            features = features.to(self.device)
            features = self.jigsaw_linear(features)
            features = pt_util.split_dim(features, 0, -1, 9)
            # Shuffle all permutations independently
            rand_orders = torch.stack([
                torch.randperm(9, device=features.device)
                for _ in range(features.shape[0])
            ])
            features = features[pt_util.expand_new_dim(
                torch.arange(features.shape[0], device=features.device), 1, 9),
                                rand_orders]
            features = pt_util.remove_dim(features, 2)
            features = self.jigsaw_embedding(features)
            return_val["extracted_features"] = features
            output = features
        else:
            features = features.to(self.device)
            return_val["extracted_features"] = features
            output = self.embedding(features)

        return_val["prenorm_features"] = output
        output = F.normalize(output, dim=1)

        return_val["embeddings"] = output

        if shuffle:
            # Unshuffle
            return_val_new = {}
            for key, val in return_val.items():
                if isinstance(val, torch.Tensor):
                    val = val.to(self.device)
                    val = val[unshuffle_order]
                return_val_new[key] = val
            return_val = return_val_new

        if "batch_types" in inputs:
            return_val = self.split_dict_by_type(inputs["batch_types"],
                                                 inputs["batch_sizes"],
                                                 return_val)
        return return_val
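
split_dim and expand_new_dim are used above but not shown. From their call sites, plausible signatures (assumptions, not the actual pt_util API) are split_dim(t, dim, a, b), which splits dimension dim into two dimensions of sizes (a, b) where one may be -1, and expand_new_dim(t, dim, size), which inserts a new axis at dim and broadcasts it to size:

import torch

def split_dim_sketch(t, dim, a, b):
    # Inverse of remove_dim: replace dimension `dim` (of size a*b) with (a, b).
    dim = dim % t.dim()
    shape = list(t.shape)
    shape[dim:dim + 1] = [a, b]
    return t.reshape(shape)

def expand_new_dim_sketch(t, dim, size):
    # Insert a new axis at `dim` and broadcast it to `size` without copying.
    sizes = [-1] * (t.dim() + 1)
    sizes[dim] = size
    return t.unsqueeze(dim).expand(*sizes)

assert split_dim_sketch(torch.zeros(6, 4), 0, 2, 3).shape == (2, 3, 4)
assert expand_new_dim_sketch(torch.arange(5), 1, 9).shape == (5, 9)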
Example 9
    def recurrent_generator(self, advantages, num_mini_batch):
        num_processes = self.rewards.size(1)
        assert num_processes >= num_mini_batch, (
            "PPO requires the number of processes ({}) "
            "to be greater than or equal to the number of "
            "PPO mini batches ({}).".format(num_processes, num_mini_batch)
        )
        num_envs_per_batch = num_processes // num_mini_batch
        perm = torch.randperm(num_processes)
        for start_ind in range(0, num_processes, num_envs_per_batch):
            obs_batch = []
            recurrent_hidden_states_batch = []
            actions_batch = []
            value_preds_batch = []
            return_batch = []
            masks_batch = []
            old_action_log_probs_batch = []
            adv_targ = []
            additional_obs_batch = []

            for offset in range(num_envs_per_batch):
                ind = perm[start_ind + offset]
                obs_batch.append(self.obs[:-1, ind])
                additional_obs_batch.append(
                    {key: val[:-1, ind] for key, val in self.additional_observations_dict.items()}
                )
                recurrent_hidden_states_batch.append(self.recurrent_hidden_states[0:1, ind])
                actions_batch.append(self.actions[:, ind])
                value_preds_batch.append(self.value_preds[:-1, ind])
                return_batch.append(self.returns[:-1, ind])
                masks_batch.append(self.masks[:-1, ind])
                old_action_log_probs_batch.append(self.action_log_probs[:, ind])
                adv_targ.append(advantages[:, ind])

            T, N = self.num_forward_rollout_steps, num_envs_per_batch
            # These are all tensors of size (T, N, -1)
            obs_batch = torch.stack(obs_batch, 1)
            actions_batch = torch.stack(actions_batch, 1)
            value_preds_batch = torch.stack(value_preds_batch, 1)
            return_batch = torch.stack(return_batch, 1)
            masks_batch = torch.stack(masks_batch, 1)
            old_action_log_probs_batch = torch.stack(old_action_log_probs_batch, 1)
            adv_targ = torch.stack(adv_targ, 1)
            additional_obs_batch = {
                key: torch.stack([additional_obs_batch[ii][key] for ii in range(num_envs_per_batch)], 1)
                for key in additional_obs_batch[0]
            }

            # States is just a (N, -1) tensor
            recurrent_hidden_states_batch = torch.stack(recurrent_hidden_states_batch, 1).view(N, -1)

            # Flatten the (T, N, ...) tensors to (T * N, ...)
            obs_batch = pt_util.remove_dim(obs_batch, 1)
            additional_obs_batch = {key: pt_util.remove_dim(val, 1) for key, val in additional_obs_batch.items()}
            actions_batch = pt_util.remove_dim(actions_batch, 1)
            value_preds_batch = pt_util.remove_dim(value_preds_batch, 1)
            return_batch = pt_util.remove_dim(return_batch, 1)
            masks_batch = pt_util.remove_dim(masks_batch, 1)
            old_action_log_probs_batch = pt_util.remove_dim(old_action_log_probs_batch, 1)
            adv_targ = pt_util.remove_dim(adv_targ, 1)

            yield (obs_batch, recurrent_hidden_states_batch, actions_batch, value_preds_batch, return_batch,
                   masks_batch, old_action_log_probs_batch, adv_targ, additional_obs_batch)
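
A quick sanity check (using a plain reshape, which the remove_dim sketch from Example 1 reduces to) that this flattening is time-major: after stacking on dim 1, element (t, n) of each (T, N, ...) tensor lands at flat index t * N + n, so all environments for a given step stay contiguous.

import numpy as np

T, N = 3, 2
x = np.arange(T * N).reshape(T, N)
flat = x.reshape(T * N)  # what remove_dim(x, 1) does here
for t in range(T):
    for n in range(N):
        assert flat[t * N + n] == x[t, n]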