Example #1
import numpy as np

import pt_util  # the project's tensor utility module (import path assumed)


def test_split_dim():
    shape = [4, 6, 8, 10]
    arr = np.zeros(shape)
    new_arr = pt_util.split_dim(arr, 0, 2, 2)
    assert new_arr.shape == (2, 2, 6, 8, 10)

    new_arr = pt_util.split_dim(arr, 1, 2, 3)
    assert new_arr.shape == (4, 2, 3, 8, 10)

    new_arr = pt_util.split_dim(arr, 1, 3, 2)
    assert new_arr.shape == (4, 3, 2, 8, 10)

    new_arr = pt_util.split_dim(arr, 2, 2, -1)
    assert new_arr.shape == (4, 6, 2, 4, 10)

    new_arr = pt_util.split_dim(arr, 2, -1, 2)
    assert new_arr.shape == (4, 6, 4, 2, 10)

    new_arr = pt_util.split_dim(arr, -1, 2, 5)
    assert new_arr.shape == (4, 6, 8, 2, 5)

    new_arr = pt_util.split_dim(arr, -1, 5, -1)
    assert new_arr.shape == (4, 6, 8, 5, 2)

    new_arr = pt_util.split_dim(arr, -2, 2, -1)
    assert new_arr.shape == (4, 6, 2, 4, 10)
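
The assertions above pin down the contract of pt_util.split_dim: one dimension is replaced by two whose product equals its original size, and a -1 factor is inferred. A minimal sketch consistent with those assertions (split_dim_sketch is illustrative, not the project's implementation):

import numpy as np


def split_dim_sketch(arr, dim, size_a, size_b):
    # Replace dimension `dim` with two dimensions (size_a, size_b);
    # a -1 in either factor is inferred from the original size.
    shape = list(arr.shape)
    if dim < 0:
        dim += len(shape)  # normalize negative axes
    orig_size = shape[dim]
    if size_a == -1:
        size_a = orig_size // size_b
    if size_b == -1:
        size_b = orig_size // size_a
    assert size_a * size_b == orig_size
    shape[dim:dim + 1] = [size_a, size_b]
    return arr.reshape(shape)
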
Example #2
    def get_image_output(self, network_outputs):
        with torch.no_grad():
            image_output = {}
            predictions = torch.argmax(network_outputs["outputs"], dim=1)
            labels = network_outputs["labels"]

            batch_size = network_outputs["batch_size"]
            seq_len = network_outputs["num_frames"]

            acc = pt_util.to_numpy(predictions == labels)

            inputs = network_outputs["data"]
            inputs = to_uint8(inputs)
            im_height, im_width = inputs.shape[1:3]

            # Regroup the flat (batch * time) axis into (batch, time, H, W, C).
            inputs = pt_util.split_dim(inputs, 0, batch_size, seq_len)

            # Visualize a random subset of the sequences.
            rand_order = np.random.choice(len(inputs), min(len(inputs), seq_len), replace=False)

            scale_factor = im_width / 320.0
            images = []
            for bb in rand_order:
                correct = acc[bb]
                image_seq = inputs[bb].copy()
                pred_cls = self.ind_to_label_func(predictions[bb])
                gt_cls = self.ind_to_label_func(labels[bb])
                for ii, image in enumerate(image_seq):
                    # 10px border: green for a correct prediction, red otherwise.
                    if correct:
                        image[:10, :, :] = (0, 255, 0)
                        image[-10:, :, :] = (0, 255, 0)
                        image[:, :10, :] = (0, 255, 0)
                        image[:, -10:, :] = (0, 255, 0)
                    else:
                        image[:10, :, :] = (255, 0, 0)
                        image[-10:, :, :] = (255, 0, 0)
                        image[:, :10, :] = (255, 0, 0)
                        image[:, -10:, :] = (255, 0, 0)
                    if ii == 0:
                        image = drawing.draw_contrast_text_cv2(
                            image, "P: " + pred_cls, (10, 10 + int(30 * scale_factor))
                        )
                        if not correct:
                            image = drawing.draw_contrast_text_cv2(
                                image, "GT: " + gt_cls, (10, 10 + int(2 * 30 * scale_factor))
                            )
                    images.append(image)

            n_cols = seq_len
            n_rows = len(images) // n_cols

            subplot = drawing.subplot(images, n_rows, n_cols, im_width, im_height)
            image_output["images/classifier_outputs"] = subplot
            return image_output
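
Several snippets on this page call a to_uint8 helper that is not shown. Judging from its call sites (float image batches in, HWC uint8 numpy images out, with an optional padding border matching the np.pad calls in Example #4), it plausibly resembles this sketch; the NCHW layout and [0, 1] input range are assumptions:

import numpy as np
import torch


def to_uint8_sketch(tensor, padding=0):
    # Assumed NCHW float tensor in roughly [0, 1] -> NHWC uint8 numpy images.
    images = tensor.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = np.clip(images * 255.0, 0, 255).astype(np.uint8)
    if padding:
        # Black border on the spatial dims, as the border-painting code expects.
        images = np.pad(images, ((0, 0), (padding, padding), (padding, padding), (0, 0)), "constant")
    return images
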
Example #3
    def forward(self, batch):
        if self.freeze_feature_extractor:
            with torch.no_grad():
                feature_extractor_outputs = self.feature_extractor.extract_features(batch["data"])
                extracted_features = feature_extractor_outputs["extracted_features"].detach()
        else:
            feature_extractor_outputs = self.feature_extractor.extract_features(batch["data"])
            extracted_features = feature_extractor_outputs["extracted_features"]
        extracted_features = extracted_features.to(self.model.device)
        # Regroup per-frame features into (batch_size, num_frames, feat_dim) for the temporal model.
        extracted_features = pt_util.split_dim(extracted_features, 0, batch["batch_size"], batch["num_frames"])
        output = self.model(extracted_features)
        output = {"outputs": output}
        output.update(batch)
        return output
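
Here split_dim undoes the frame flattening that let a 2D feature extractor process every frame in one pass: features arrive as (batch_size * num_frames, feat_dim) and are regrouped to (batch_size, num_frames, feat_dim) for the temporal model. A toy shape check of that round trip (sizes are made up):

import torch

batch_size, num_frames, feat_dim = 4, 8, 512
flat = torch.randn(batch_size * num_frames, feat_dim)  # per-frame extractor output
grouped = flat.reshape(batch_size, num_frames, feat_dim)  # what split_dim(flat, 0, B, T) yields
assert grouped.shape == (batch_size, num_frames, feat_dim)
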
Example #4
    def get_image_output(self, network_outputs) -> Dict[str, np.ndarray]:
        with torch.no_grad():
            image_output = {}

            # matching image
            batch_size, _, im_height, im_width = network_outputs["data"].shape

            inputs = network_outputs["data"]
            queue_inputs = network_outputs["queue_data"]
            inputs = to_uint8(inputs, padding=10)
            queue_inputs = to_uint8(queue_inputs, padding=10)
            num_frames = 1 if self.num_frames is None else self.num_frames
            # Group consecutive frames: (B * T, H, W, C) -> (B, T, H, W, C).
            inputs = pt_util.split_dim(inputs, 0, -1, num_frames)
            queue_inputs = pt_util.split_dim(queue_inputs, 0, -1, num_frames)
            images = []
            color = (255, 128, 0)
            for bb in range(
                    min(len(inputs), max(2 * num_frames,
                                         int(32 / num_frames)))):
                for ss in range(num_frames):
                    image = inputs[bb, ss]
                    images.append(image)
                for ss in range(num_frames):
                    image = queue_inputs[bb, ss].copy()
                    image[:10, :, :] = color
                    image[-10:, :, :] = color
                    image[:, :10, :] = color
                    image[:, -10:, :] = color
                    images.append(image)

            n_cols = max(2 * num_frames, 8)
            n_rows = len(images) // n_cols
            subplot = drawing.subplot(images, n_rows, n_cols, im_width,
                                      im_height)
            image_output["images/inputs"] = subplot

            if "vince_similarities" in network_outputs:
                # Nearest neighbor image
                inputs = network_outputs["data"]
                queue_inputs = network_outputs["queue_data"]

                inputs = to_uint8(inputs, padding=10)
                queue_inputs = to_uint8(queue_inputs, padding=10)

                vince_similarities = network_outputs["vince_similarities"]
                logits = vince_similarities / self.args.vince_temperature
                vince_softmax = F.softmax(logits, dim=1)

                queue_images = network_outputs["queue_images"]

                n_neighbors = 9
                topk_val, topk_ind = torch.topk(vince_softmax,
                                                n_neighbors,
                                                dim=1,
                                                largest=True,
                                                sorted=True)
                topk_ind = pt_util.to_numpy(topk_ind)
                topk_val = pt_util.to_numpy(topk_val)

                label = network_outputs["vince_similarities_mask"]

                images = []
                rand_order = np.random.choice(batch_size,
                                              min(batch_size, n_neighbors + 1),
                                              replace=False)
                for bb in rand_order:
                    query_image = inputs[bb].copy()
                    color = (90, 46, 158)
                    if network_outputs["batch_type"] == "images":
                        # Different colors for imagenet vs videos.
                        color = (24, 178, 24)
                    query_image[:10, :, :] = color
                    query_image[-10:, :, :] = color
                    query_image[:, :10, :] = color
                    query_image[:, -10:, :] = color
                    images.append(query_image)
                    found_neighbor = False
                    for nn, neighbor in enumerate(topk_ind[bb]):
                        color = (128, 128, 128)
                        score = topk_val[bb, nn]

                        if self.args.inter_batch_comparison:
                            if neighbor < batch_size:
                                image = queue_inputs[neighbor].copy()
                                data_source = network_outputs["data_source"]
                            else:
                                # Offset by batch_size for the inter-batch negatives
                                offset = batch_size
                                image = to_uint8(queue_images[neighbor -
                                                              offset],
                                                 padding=10)
                                data_source = network_outputs[
                                    "queue_data_sources"][neighbor - offset]
                        else:
                            if neighbor == 0:
                                image = queue_inputs[bb].copy()
                                data_source = network_outputs["data_source"]
                            else:
                                # Offset by 1 for the positive examples
                                image = to_uint8(queue_images[neighbor - 1],
                                                 padding=10)
                                data_source = network_outputs[
                                    "queue_data_sources"][neighbor - 1]

                        if label[bb, neighbor]:
                            if self.args.inter_batch_comparison and neighbor < batch_size:
                                found_neighbor = True
                                color = (255, 128, 0)
                            elif neighbor == 0:
                                found_neighbor = True
                                color = (255, 128, 0)
                            elif data_source == "self":
                                color = (144, 72, 0)
                            else:
                                color = (0, 0, 203)
                        elif data_source == "self":
                            color = (255, 0, 193)
                        if not found_neighbor and nn == n_neighbors - 1:
                            # Last slot in the row and the true match never appeared;
                            # insert it anyway to show what it looks like.
                            image = queue_inputs[bb].copy()
                            color = (255, 0, 0)

                        if color == (128, 128, 128):
                            color = (90, 46, 158)
                            if data_source == "IN":
                                # Different colors for imagenet vs videos.
                                color = (24, 178, 24)
                        image[:10, :, :] = color
                        image[-10:, :, :] = color
                        image[:, :10, :] = color
                        image[:, -10:, :] = color
                        images.append(image)

                n_rows = n_neighbors + 1
                n_cols = n_neighbors + 1
                subplot = drawing.subplot(images, n_rows, n_cols, im_width,
                                          im_height)
                image_output["images/outputs"] = subplot

            if network_outputs["data_source"] == "IN":
                # imagenet image
                predictions = torch.argmax(
                    network_outputs["imagenet_decoder_0"], dim=1)
                labels = network_outputs["imagenet_labels"]
                acc = pt_util.to_numpy(predictions == labels)
                batch_size = acc.shape[0]

                inputs = network_outputs["data"][:batch_size]
                inputs = to_uint8(inputs, padding=10)

                images = []
                rand_order = np.random.choice(len(inputs),
                                              min(len(inputs), 25),
                                              replace=False)
                scale_factor = im_width / 320.0

                for bb in rand_order:
                    correct = acc[bb]
                    image = inputs[bb].copy()
                    pred_cls = util_functions.imagenet_label_to_class(
                        predictions[bb])
                    gt_cls = util_functions.imagenet_label_to_class(labels[bb])

                    if correct:
                        image[:10, :, :] = (0, 255, 0)
                        image[-10:, :, :] = (0, 255, 0)
                        image[:, :10, :] = (0, 255, 0)
                        image[:, -10:, :] = (0, 255, 0)
                    else:
                        image[:10, :, :] = (255, 0, 0)
                        image[-10:, :, :] = (255, 0, 0)
                        image[:, :10, :] = (255, 0, 0)
                        image[:, -10:, :] = (255, 0, 0)
                    image = drawing.draw_contrast_text_cv2(
                        image, "P: " + pred_cls,
                        (10, 10 + int(30 * scale_factor)))
                    if not correct:
                        image = drawing.draw_contrast_text_cv2(
                            image, "GT: " + gt_cls,
                            (10, 10 + int(2 * 30 * scale_factor)))
                    images.append(image)

                n_cols = int(np.sqrt(len(images)))
                n_rows = len(images) // n_cols

                subplot = drawing.subplot(images, n_rows, n_cols, im_width,
                                          im_height)
                image_output["images/imagenet_outputs"] = subplot

            if "attention_masks" in network_outputs:
                # Attention image
                inputs = network_outputs["data"]
                inputs = to_uint8(inputs, padding=10)

                queue_inputs = network_outputs["queue_data"]
                queue_inputs = to_uint8(queue_inputs, padding=10)

                attention_masks = network_outputs["attention_masks"]
                attention_masks = pt_util.to_numpy(
                    F.interpolate(attention_masks, (im_height, im_width),
                                  mode="bilinear",
                                  align_corners=False).permute(0, 2, 3, 1))
                attention_masks = np.pad(attention_masks,
                                         ((0, 0), (10, 10), (10, 10), (0, 0)),
                                         "constant")

                queue_attention_masks = network_outputs[
                    "queue_attention_masks"]
                queue_attention_masks = pt_util.to_numpy(
                    F.interpolate(queue_attention_masks, (im_height, im_width),
                                  mode="bilinear",
                                  align_corners=False).permute(0, 2, 3, 1))
                queue_attention_masks = np.pad(queue_attention_masks,
                                               ((0, 0), (10, 10), (10, 10),
                                                (0, 0)), "constant")

                rand_order = np.random.choice(len(inputs),
                                              min(len(inputs), 25),
                                              replace=False)

                subplots = []
                attention_color = np.array([255, 0, 0], dtype=np.float32)
                for bb in rand_order:
                    images = []
                    for img_src, mask_src in ((inputs, attention_masks),
                                              (queue_inputs,
                                               queue_attention_masks)):
                        image = img_src[bb].copy()
                        attention_mask = mask_src[bb].copy()
                        attention_mask -= attention_mask.min()
                        attention_mask /= attention_mask.max() + 1e-8
                        # Blend the normalized mask as a red overlay on the image.
                        output = attention_mask * attention_color + (1 - attention_mask) * image
                        output = output.astype(np.uint8)
                        images.append(image)
                        images.append(output)
                    subplot = drawing.subplot(images, 2, 2, im_width,
                                              im_height)
                    subplots.append(subplot)

                n_cols = int(np.sqrt(len(subplots)))
                n_rows = len(subplots) // n_cols

                subplot = drawing.subplot(subplots,
                                          n_rows,
                                          n_cols,
                                          im_width * 2,
                                          im_height * 2,
                                          border=5)
                image_output["images/attention"] = subplot

        return image_output
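
Examples #2 and #4 repeat the same border-painting idiom: overwrite 10-pixel strips on all four edges to color-code an image (correct vs. incorrect, batch type, match quality). Factored out, it is four slice assignments; this helper is an illustration, not part of the project:

import numpy as np


def paint_border(image, color, thickness=10):
    # In place: overwrite `thickness`-pixel strips on all four edges of an HWC image.
    image[:thickness, :, :] = color
    image[-thickness:, :, :] = color
    image[:, :thickness, :] = color
    image[:, -thickness:, :] = color
    return image
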
Example #5
    def get_embeddings(self, inputs, jigsaw=False, shuffle=False):
        data = inputs["data"]
        if shuffle:
            # Shuffle
            shuffle_order = torch.randperm(data.shape[0], device=self.device)
            unshuffle_order = torch.zeros(data.shape[0],
                                          dtype=torch.int64,
                                          device=self.device)
            unshuffle_order.index_copy_(
                0, shuffle_order,
                torch.arange(data.shape[0], device=self.device))
            data = data[shuffle_order].contiguous()

        if jigsaw:
            if (data.shape[2] % 3) != 0 or (data.shape[3] % 3) != 0:
                # Pad H and W up to the next multiple of 3; the extra modulo avoids
                # padding a dimension that is already a multiple of 3.
                data = F.pad(data, (0, (3 - data.shape[3] % 3) % 3, 0, (3 - data.shape[2] % 3) % 3))
            # [N, C, H, W]
            data = pt_util.split_dim(data, 2, 3, data.shape[2] // 3)
            # [N, C, 3, H/3, W]
            data = pt_util.split_dim(data, 4, 3, data.shape[4] // 3)
            # [N, C, 3, H/3, 3, W/3]
            data = data.permute(0, 2, 4, 1, 3, 5).contiguous()
            # [N, 3, 3, C, H/3, W/3]
            data = pt_util.remove_dim(data, (1, 2))
            # [N*9, C, H/3, W/3]

        images = data.to(self.feature_extractor_device)
        return_val = self.extract_features(images)
        features = return_val["extracted_features"]

        if jigsaw:
            features = features.to(self.device)
            features = self.jigsaw_linear(features)
            features = pt_util.split_dim(features, 0, -1, 9)
            # Shuffle all permutations independently
            rand_orders = torch.stack([
                torch.randperm(9, device=features.device)
                for _ in range(features.shape[0])
            ])
            # Batched fancy indexing: row i gathers its 9 patch features in rand_orders[i] order.
            features = features[pt_util.expand_new_dim(
                torch.arange(features.shape[0], device=features.device), 1, 9),
                                rand_orders]
            features = pt_util.remove_dim(features, 2)
            features = self.jigsaw_embedding(features)
            return_val["extracted_features"] = features
            output = features
        else:
            features = features.to(self.device)
            return_val["extracted_features"] = features
            output = self.embedding(features)

        return_val["prenorm_features"] = output
        output = F.normalize(output, dim=1)

        return_val["embeddings"] = output

        if shuffle:
            # Unshuffle
            return_val_new = {}
            for key, val in return_val.items():
                if isinstance(val, torch.Tensor):
                    val = val.to(self.device)
                    val = val[unshuffle_order]
                return_val_new[key] = val
            return_val = return_val_new

        if "batch_types" in inputs:
            return_val = self.split_dict_by_type(inputs["batch_types"],
                                                 inputs["batch_sizes"],
                                                 return_val)
        return return_val
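
pt_util.remove_dim appears to be the inverse of split_dim: it folds each listed dimension into the one before it, innermost first, which matches the shape comments in the jigsaw branch ([N, 3, 3, C, H/3, W/3] -> [N*9, C, H/3, W/3] for dims (1, 2)). A sketch under that assumption:

import torch


def remove_dim_sketch(tensor, dims):
    # Fold each listed dim into its predecessor, innermost first.
    if isinstance(dims, int):
        dims = (dims,)
    shape = list(tensor.shape)
    for dim in sorted(dims, reverse=True):
        shape[dim - 1:dim + 1] = [shape[dim - 1] * shape[dim]]
    return tensor.reshape(shape)
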
Example #6
import torch

import pt_util  # the project's tensor utility module (import path assumed)

# Module-level cache, set on first call: True when rows can contain
# different numbers of positives (so boolean masking can't be reshaped).
USE_FLOAT = None


def similarity_cross_entropy(similarities,
                             temperature,
                             n_feat,
                             n_rows1,
                             mask=None,
                             n_positives_per_row=None):
    global USE_FLOAT
    similarities = similarities / temperature
    if mask is None:
        assert n_positives_per_row is not None
        # Default identity mask
        mask = (torch.eye(n_feat, device=similarities.device,
                          dtype=torch.bool).repeat_interleave(
                              n_positives_per_row,
                              1).repeat_interleave(n_rows1, 0))

    assert mask.shape == similarities.shape
    similarities = pt_util.split_dim(similarities, 0, n_feat, n_rows1)
    mask = pt_util.split_dim(mask, 0, n_feat, n_rows1)

    # Numerically stable softmax: subtract each row's max before exponentiating.
    row_maxes = torch.max(similarities, dim=-1, keepdim=True)[0]
    scaled_similarities = similarities - row_maxes

    mask_row_sum = mask.sum(-1)
    if USE_FLOAT is None:
        USE_FLOAT = mask_row_sum.min() != mask_row_sum.max()
    if USE_FLOAT:
        float_mask = mask.float()
        inv_float_mask = 1 - float_mask
        neg_similarities = scaled_similarities * inv_float_mask + -2**20 * float_mask
        pos_similarities = scaled_similarities * float_mask + -2**20 * inv_float_mask
    else:
        # Same number of items per row
        neg_similarities = scaled_similarities[~mask].view(n_feat, n_rows1, -1)
        pos_similarities = scaled_similarities[mask].view(
            n_feat, n_rows1, mask.shape[2] - neg_similarities.shape[2])

    neg_similarities_exp = torch.exp(neg_similarities).sum(-1, keepdim=True)

    pos_similarities_exp = torch.exp(pos_similarities)
    similarity_log_softmax = pos_similarities - torch.log(
        pos_similarities_exp + neg_similarities_exp)
    dists = -similarity_log_softmax
    softmax_weights = torch.exp(similarity_log_softmax.detach())

    if USE_FLOAT:
        dists_mean = dists[mask].mean()
        softmax_weight = softmax_weights[mask].mean()
    else:
        dists_mean = dists.mean()
        softmax_weight = softmax_weights.mean()

    return dict(
        # similarity_log_softmax=similarity_log_softmax.mean(),
        dists=dists,
        dist=dists_mean,
        # similarity_raw_scores=similarity_raw_scores.mean(),
        # similarities=similarities.mean(),
        softmax_weights=softmax_weights,
        softmax_weight=softmax_weight,
    )
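
The loss above is an InfoNCE-style cross entropy computed with the usual max-subtraction trick: each positive's log-probability is its shifted score minus the log of its own exponential plus the summed exponentials of that row's negatives. A small self-contained check of that identity with one positive per row (shapes are made up):

import torch
import torch.nn.functional as F

scores = torch.randn(2, 3, 5)  # (n_feat, n_rows1, candidates)
mask = torch.zeros(2, 3, 5, dtype=torch.bool)
mask[..., 0] = True  # one positive per row, the rest are negatives

row_max = scores.max(dim=-1, keepdim=True)[0]
shifted = scores - row_max
neg_exp = torch.exp(shifted.masked_fill(mask, float("-inf"))).sum(-1, keepdim=True)
pos = shifted.masked_fill(~mask, float("-inf"))
log_softmax = pos - torch.log(torch.exp(pos) + neg_exp)
# With exactly one positive per row this reduces to the standard log-softmax.
reference = F.log_softmax(scores, dim=-1)
assert torch.allclose(log_softmax[mask], reference[mask], atol=1e-6)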