Example #1
def train_triplet_step(data, target, model, device, optimizer, miner, extras=None):
    model.train()

    loss_func = losses.MarginLoss()
    acc_calc = AccuracyCalculator()

    data = data.float().to(device)
    target = target.long().to(device)
    target = target.view(-1)
    
    if torch.is_tensor(extras):
        extras = extras.float().to(device)

    embedding = model(data, extras)
    triplets = miner.mine(embedding, target, embedding, target)
    
    loss = loss_func(embedding, target, triplets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    with torch.no_grad():
        acc_dict = acc_calc.get_accuracy(embedding, embedding, target, target, embeddings_come_from_same_source=True)
    
    return loss.item(), acc_dict["precision_at_1"]
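A step function like this is usually driven from an epoch loop. A minimal sketch, assuming model, train_loader, and num_epochs are defined elsewhere (they are hypothetical stand-ins, not part of the example above):

import torch
from pytorch_metric_learning import miners

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
miner = miners.BatchEasyHardMiner(pos_strategy="all", neg_strategy="all")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # `model` assumed defined

for epoch in range(num_epochs):  # `num_epochs` assumed defined
    for data, target, extras in train_loader:  # `train_loader` assumed defined
        loss_value, p_at_1 = train_triplet_step(data, target, model, device,
                                                optimizer, miner, extras)
    # loss_value / p_at_1 hold the values from the last batch of the epoch
    print(f"epoch {epoch}: loss={loss_value:.4f}, precision@1={p_at_1:.4f}")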
Example #2
def get_scores(inference_model, gallery_embeddings, query_embeddings,
               gallery_labels, query_labels):
    calculator = AccuracyCalculator()
    scores_dict = calculator.get_accuracy(
        query_embeddings.numpy(),
        gallery_embeddings.numpy(),
        query_labels,
        gallery_labels,
        embeddings_come_from_same_source=False)

    return scores_dict
Example #3
    def RSAtest(self, test_dl):
        print('Evaluation on RSA test set beginning:')
        trEmbeds, trLabels = self.get_all_embeddings(self.train_dl)
        rEmbeds, rLabels = self.get_all_embeddings(test_dl)
        accuracy_calculator = AccuracyCalculator(include=(), k=10)
        results = accuracy_calculator.get_accuracy(
            rEmbeds.cpu().numpy(),
            trEmbeds.cpu().numpy(),
            rLabels.reshape(-1).cpu().numpy(),
            trLabels.reshape(-1).cpu().numpy(), False)
        print('Accuracy on RSA test set: {}'.format(
            round(results['precision_at_1'] * 100.0, 2)))
        print('MAP@R on RSA test set: {}'.format(
            round(results['mean_average_precision_at_r'], 4)))
        return results
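All of these snippets pass arguments to get_accuracy in the same positional order: query embeddings, reference embeddings, query labels, reference labels, then the embeddings_come_from_same_source flag. A self-contained sketch with random data (shapes and seed are arbitrary):

import numpy as np
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

rng = np.random.default_rng(0)
reference = rng.normal(size=(100, 32)).astype(np.float32)  # gallery/train side
reference_labels = rng.integers(0, 5, size=100)
query = rng.normal(size=(20, 32)).astype(np.float32)       # test side
query_labels = rng.integers(0, 5, size=20)

calculator = AccuracyCalculator(include=("precision_at_1",
                                         "mean_average_precision_at_r"), k=10)
# Query comes first, reference second; the flag says whether each query
# would find itself among the references (False here: disjoint sets).
results = calculator.get_accuracy(query, reference, query_labels,
                                  reference_labels,
                                  embeddings_come_from_same_source=False)
print(results)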
Example #4
    def validation_epoch_end(self, outputs):
        embeddings = (torch.cat([x["embeddings"]
                                 for x in outputs]).cpu().numpy())
        labels = torch.cat([x["labels"] for x in outputs]).cpu().numpy()
        acc_calc = AccuracyCalculator()

        ref_embeddings = (embeddings if not self.hparams.infer_from_train
                          else self.data_index.retrieve(
                              range(len(self.data_index))))
        ref_labels = (labels if not self.hparams.infer_from_train else
                      self.data_index.labels.cpu().numpy())
        same_source = not self.hparams.infer_from_train

        torch.cuda.empty_cache()
        metrics = acc_calc.get_accuracy(embeddings, ref_embeddings, labels,
                                        ref_labels, same_source)

        metrics["n_samples"] = len(labels)

        if self.attacks is not None:
            for epsilon, _ in self.attacks.items():
                tmp_embeddings = (torch.cat([
                    x[f"embeddings_eps={epsilon}"] for x in outputs
                ]).cpu().numpy())

                tmp_metrics = acc_calc.get_accuracy(
                    tmp_embeddings,
                    ref_embeddings,
                    labels,
                    ref_labels,
                    same_source,
                )

                if self.adv_training:
                    for k, v in tmp_metrics.items():
                        metrics[f"{k}_adv"] = v
                else:
                    for k, v in tmp_metrics.items():
                        metrics[f"{k}_eps={epsilon}"] = v

        for k, v in metrics.items():
            self.log(k, v)

        if hasattr(self, "data_index"):
            self.data_index.reset()
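The validation_epoch_end above expects each element of outputs to be a dict with "embeddings" and "labels" keys (plus optional "embeddings_eps=..." entries when attacks are configured). The project's validation_step is not shown; a minimal sketch of one that would produce the required dicts:

    def validation_step(self, batch, batch_idx):
        # Sketch only: the adversarial "embeddings_eps=..." entries consumed
        # above are omitted here.
        images, labels = batch
        embeddings = self(images)
        return {"embeddings": embeddings, "labels": labels}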
Example #5
def validation_constructive(valid_loader, train_loader, model, scaler):
    calculator = AccuracyCalculator(k=1)
    model.eval()

    query_embeddings, query_labels = compute_embeddings(
        valid_loader, model, scaler)
    reference_embeddings, reference_labels = compute_embeddings(
        train_loader, model, scaler)

    acc_dict = calculator.get_accuracy(query_embeddings,
                                       reference_embeddings,
                                       query_labels,
                                       reference_labels,
                                       embeddings_come_from_same_source=False)

    del query_embeddings, query_labels, reference_embeddings, reference_labels
    torch.cuda.empty_cache()

    return acc_dict
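The compute_embeddings helper used here is not shown. A minimal sketch of what it plausibly does, assuming batches of (images, labels) and using the passed GradScaler only as a signal to run under autocast:

import torch

@torch.no_grad()
def compute_embeddings(loader, model, scaler=None):
    # Sketch: run the model over a loader and return stacked embeddings
    # and labels on the CPU. The batch structure is an assumption.
    device = next(model.parameters()).device
    all_embeddings, all_labels = [], []
    for images, labels in loader:
        images = images.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            embeddings = model(images)
        all_embeddings.append(embeddings.float().cpu())
        all_labels.append(labels.cpu())
    return torch.cat(all_embeddings), torch.cat(all_labels)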
Example #6
def representation(model, device, data_loader):
    model.eval()
    target_list = []
    embedding_list = []
    total_loss = 0.0
        
    loss_func = losses.MarginLoss()
    acc_calc = AccuracyCalculator()
    miner = miners.BatchEasyHardMiner(pos_strategy='all', neg_strategy='all')
    
    acc_dicts = defaultdict(list)
    with torch.no_grad():
        for [data, target, extras] in data_loader:
            data = data.float().to(device)
            target = target.long().to(device)
            target = target.view((-1,))
            
            if torch.is_tensor(extras):
                extras = extras.float().to(device)
            
            target_list += target.cpu().tolist()

            embedding = model(data, extras)
            triplets = miner.mine(embedding, target, embedding, target)
            
            embedding_list += embedding.cpu().tolist()
            
            total_loss += loss_func(embedding, target, triplets)
            
            acc_dict = acc_calc.get_accuracy(embedding, embedding, target, target, embeddings_come_from_same_source=True)
            for key in acc_dict:
                acc_dicts[key].append(acc_dict[key])

    total_loss /= len(data_loader.dataset)
    
    avg_acc_dict = {key: np.mean(acc_dicts[key]) for key in acc_dicts}
    return total_loss, avg_acc_dict, np.array(embedding_list), np.array(target_list)
class MusicMetricLearner(pl.LightningModule):
    def __init__(self,
                 datamodule: MusicMetricDatamodule,
                 conf: DictConfig = DEFAULT_HPARAMS):
        super().__init__()

        ### Lightning Config ###
        self.save_hyperparameters()
        self.automatic_optimization = False

        ### Setup Dataset ###
        # TODO: Figure out how to get PL's hyperparameter management to play well with MyPy
        # Right now, it's a mess.
        self.dm: MusicMetricDatamodule = self.hparams.datamodule  # type: ignore
        self.dm.epoch_length = self.hparams.conf.epoch_length  # type: ignore
        if not self.dm.is_setup:
            self.dm.setup()
        self.dm.batch_size = self.hparams.conf.batch_size  # type: ignore
        self.dm.m_per_class = self.hparams.conf.m_per_class  # type: ignore
        # For simplicity:
        assert self.hparams.conf.batch_size % self.hparams.conf.m_per_class == 0  # type: ignore

        ### Instantiate Model ###
        self.encoder = make_encoder(
            kind=conf.encoder,
            pretrained=self.hparams.conf.encoder_params.pretrained,  # type: ignore
            freeze_weights=self.hparams.conf.encoder_params.freeze_weights,  # type: ignore
            max_pool=self.hparams.conf.encoder_params.max_pool,  # type: ignore
        )
        self.embedder = EmbeddingMLP(
            category_embedding_dim=self.hparams.conf.category_embedding_dim,  # type: ignore
            hidden_dim=self.hparams.conf.hidden_dim,  # type: ignore
            out_dim=self.hparams.conf.embedding_dim,  # type: ignore
            normalize_embeddings=False,  # We'll normalize in the loss function
        )

        ### Instantiate Model Copy for MoCo Track Queue ###
        self.key_encoder = make_encoder(
            kind=self.hparams.conf.encoder,  # type: ignore
            pretrained=self.hparams.conf.encoder_params.pretrained,  # type: ignore
            freeze_weights=self.hparams.conf.encoder_params.freeze_weights,  # type: ignore
            max_pool=self.hparams.conf.encoder_params.max_pool,  # type: ignore
        )
        self.key_embedder = EmbeddingMLP(
            category_embedding_dim=self.hparams.conf.category_embedding_dim,  # type: ignore
            hidden_dim=self.hparams.conf.hidden_dim,  # type: ignore
            out_dim=self.hparams.conf.embedding_dim,  # type: ignore
            normalize_embeddings=False,
        )
        for model, key_model in (
            (self.encoder, self.key_encoder),
            (self.embedder, self.key_embedder),
        ):
            copy_parameters(model, key_model)

        self.key_encoder_momentum = self.hparams.conf.key_encoder_momentum  # type: ignore

        ### Create the Queue ###
        self.register_buffer(
            "queue",
            torch.randn(self.hparams.conf.embedding_dim,
                        self.hparams.conf.queue_size))  # type: ignore
        self.queue: torch.Tensor
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer(
            "queue_labels", torch.zeros((conf.queue_size, ), dtype=torch.long))
        self.queue_labels: torch.Tensor
        self.register_buffer("queue_is_full",
                             torch.tensor([0], dtype=torch.bool))
        self.queue_is_full: torch.Tensor
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        self.queue_ptr: torch.Tensor

        ### Create Category Label Buffers for the Queue, inited to -1, meaning missing ###
        for i in range(4):
            self.register_buffer(f"queue_{i}_labels",
                                 torch.full_like(self.queue_labels, -1))

        ### Instantiate Loss Function ###
        loss_func = self.hparams.conf.loss_func  # type: ignore
        if loss_func in ["selective", "xent", "bce"]:
            xent_only = True if loss_func in ["xent", "bce"] else False
            bce_all = True if loss_func == "bce" else False
            self.criterion = SelectivelyContrastiveLoss(
                hn_lambda=cast(
                    float,
                    self.hparams.conf.loss_params.hn_lambda),  # type: ignore
                temperature=cast(
                    float,
                    self.hparams.conf.loss_params.temperature),  # type: ignore
                hard_cutoff=cast(
                    float,
                    self.hparams.conf.loss_params.hard_cutoff),  # type: ignore
                xent_only=xent_only,
                bce_all=bce_all,
            )
        elif loss_func == "moco":
            self.criterion = MoCoCrossEntropyLoss(temperature=cast(
                float,
                self.hparams.conf.loss_params.temperature))  # type: ignore
        else:
            raise ValueError(f"Loss function {loss_func} not implemented")

        ### Add utilities for logging ###
        self.visualizer = UMAP(n_neighbors=10, min_dist=0.1, metric="cosine")
        self.accuracy = AccuracyCalculator()

    def forward_step(self, images: torch.Tensor,
                     category_n: torch.Tensor) -> torch.Tensor:
        encoded = self.encoder(images)
        embeddings, _ = self.embedder(encoded, category_n)
        return embeddings

    def track_forward(
        self,
        images: torch.Tensor,
        track_category_n: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        images_query, images_key = images[:, 0], images[:, 1]
        query_embeddings = self.forward_step(images_query, track_category_n)
        key_embeddings = self.key_forward_step(images_key, track_category_n)
        return query_embeddings, key_embeddings

    @torch.no_grad()
    def key_forward_step(self, key_images: torch.Tensor,
                         category_n: torch.Tensor) -> torch.Tensor:
        """(1) update the parameters of the reference model, (2) shuffle
        the input images along the batch dimension to simulate
        batch-norm across GPUs, and (3) unshuffle the returned embeddings.

        This prevents the model from learning query-key relationships by leaking
        information across the batch dimension through the batch norm parameters.

        Note that this only makes sense for models where we use a replacement for
        batchnorm that emulates multi-gpu behavior, with separate parameters across
        a split or splits. Not all model architectures have this implemented.
        """
        for query_model, key_model in (
            (self.encoder, self.key_encoder),
            (self.embedder, self.key_embedder),
        ):
            copy_parameters(query_model,
                            key_model,
                            momentum=self.key_encoder_momentum)

        key_images_shuffled, idx_unshuffle = batch_shuffle_single_gpu(
            key_images)
        key_encoded_shuffled = self.key_encoder(key_images_shuffled)
        key_embeddings_shuffled, _ = self.key_embedder(key_encoded_shuffled,
                                                       category_n)
        key_embeddings = batch_unshuffle_single_gpu(key_embeddings_shuffled,
                                                    idx_unshuffle)
        return key_embeddings

    def retrieve_embeddings_labels_from_queue(
            self, category_idx: Optional[int] = None):
        """Return the correctly-formatted embeddings and labels from the queue
        for downstream evaluation. Pass a category label index to retrieve class labels instead of
        track IDs.
        """
        if category_idx is None:
            label_queue = self.queue_labels
        else:
            label_queue = getattr(self, f"queue_{category_idx}_labels")
        if not self.queue_is_full.item():
            ptr = cast(int, self.queue_ptr.item())
            key_embeddings_from_queue = self.queue.T[:ptr].clone().detach()
            key_labels_from_queue = label_queue[:ptr].clone().detach()
        else:
            key_embeddings_from_queue = self.queue.T.clone().detach()
            key_labels_from_queue = label_queue.clone().detach()
        if category_idx is None:
            return key_embeddings_from_queue, key_labels_from_queue
        else:
            return self.remove_missing_labels_from_queue_segment(
                key_embeddings_from_queue, key_labels_from_queue)

    def remove_missing_labels_from_queue_segment(
            self, key_embeddings: torch.Tensor,
            key_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        present_mask = key_labels != -1  # True where a class label is present (not -1/missing)
        return key_embeddings[present_mask], key_labels[present_mask]

    def moco_track_loss(
        self,
        query_embeddings: torch.Tensor,
        track_labels: torch.Tensor,
    ) -> torch.Tensor:
        """Track loss using the queue to supply reference embeddings"""
        (
            key_embeddings_from_queue,
            key_labels_from_queue,
        ) = self.retrieve_embeddings_labels_from_queue()

        # log_callback = self.make_log_callback("train_track") if self.training else None
        log_callback = None

        track_loss, _ = self.criterion.forward(
            query_embeddings,
            track_labels,
            key_embeddings=key_embeddings_from_queue,
            key_labels=key_labels_from_queue,
            log_callback=log_callback,
        )
        return track_loss

    def moco_category_loss(
        self,
        query_category_embeddings: torch.Tensor,
        class_labels: torch.Tensor,
        category_idx: int,
    ):
        (
            key_embeddings_for_category,
            key_labels_for_category,
        ) = self.retrieve_embeddings_labels_from_queue(
            category_idx=category_idx)

        # log_callback = self.make_log_callback(f"train_cat{category_idx}") if self.training else None
        log_callback = None

        category_loss, _ = self.criterion.forward(
            query_category_embeddings,
            class_labels,
            key_embeddings=key_embeddings_for_category,
            key_labels=key_labels_for_category,
            log_callback=log_callback,
        )

        return category_loss

    @torch.no_grad()
    def dequeue_enqueue(
        self,
        new_keys: torch.Tensor,
        new_labels: torch.Tensor,
        category_idx: int,
        new_category_labels: torch.Tensor,
    ) -> None:
        """Add new embeddings and labels to the queue, evicting oldest keys
        if the queue is full
        """
        batch_size = new_keys.shape[0]
        ptr = int(self.queue_ptr)
        self.queue[:, ptr:ptr + batch_size] = new_keys.T  # store batch-last
        self.queue_labels[ptr:ptr + batch_size] = new_labels

        category_label_queue = getattr(self, f"queue_{category_idx}_labels")
        category_label_queue[ptr:ptr + batch_size] = new_category_labels

        for i in range(4):
            if i != category_idx:
                other_category_label_queue = getattr(self, f"queue_{i}_labels")
                other_category_label_queue[ptr:ptr + batch_size] = -1

        if not self.queue_is_full.item(
        ) and ptr + batch_size >= self.hparams.conf.queue_size:  # type: ignore
            self.queue_is_full[0] = torch.tensor(
                [1], device=self.queue_is_full.device, dtype=torch.bool)
        ptr = (ptr + batch_size) % cast(
            int, self.hparams.conf.queue_size)  # type: ignore
        self.queue_ptr[0] = ptr  # move pointer
        return None

    def validation_track_loss(
            self, images, track_labels,
            track_category_n) -> Tuple[torch.Tensor, torch.Tensor]:
        """Don't use this during training, because it doesn't use the queue.
        Returns loss and tensor of (batch_size, 2, embedding_dim), where rank 1
        includes query and key embeddings, in that order
        """
        assert not self.training
        assert images.shape[2] == 1 and len(images.shape) == 5
        query_embeddings, key_embeddings = self.track_forward(
            images=images, track_category_n=track_category_n)
        track_loss, _ = self.criterion.forward(
            query_embeddings,
            labels=track_labels,
            key_embeddings=key_embeddings,
            key_labels=track_labels,
        )
        return track_loss, torch.stack((query_embeddings, key_embeddings),
                                       dim=1)

    def category_loss(
            self, images: torch.Tensor, class_labels: torch.Tensor,
            category_n: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the loss and embeddings for a given category."""
        assert images.shape[1] == 1 and len(images.shape) == 4
        category_embeddings = self.forward_step(images=images,
                                                category_n=category_n)
        category_idx = cast(int, category_n[0].item())
        category_loss = self.moco_category_loss(
            query_category_embeddings=category_embeddings,
            class_labels=class_labels,
            category_idx=category_idx,
        )
        return category_loss, category_embeddings

    def get_label_maps(
            self, stage: Literal["train", "val",
                                 "test"]) -> List[Dict[int, str]]:
        """Helper function to retrieve names of classes to make logged values
        more easily interpretable.
        """
        if stage == "val":
            loaders = self.val_dataloader()
        elif stage == "test":
            loaders = self.test_dataloader()
        elif stage == "train":
            loaders = self.train_dataloader()
        label_maps: List[Dict[int, str]] = []
        for loader in loaders:
            label_maps.append(loader.dataset.class_label_map)
        return label_maps

    def make_log_callback(
        self,
        prefix: str,
        also_log: Optional[Dict[str, Any]] = None,
    ) -> Callable:
        """Returns a function that, when called, sends key value pairs to the logger,
        with the key prefixed with the supplied prefix. Key-value pairs in `also_log`
        will be logged simultaneously with the other values (for instance, you can supply
        step, category, or epoch information here.)
        """
        def log_callback(
            logging_dict: Dict[str, Union[float, int, torch.Tensor,
                                          np.ndarray]],
            do_print=False,
        ):
            """Pass to loss functions, datasets, etc., to log arbitrary values using available logger"""
            to_log = {
                f"{prefix}_{key}": value
                for key, value in logging_dict.items()
            }
            if also_log is not None:
                to_log.update(also_log)
            if hasattr(self.logger.experiment, "log"):
                self.logger.experiment.log(to_log)
            else:
                for key, value in to_log.items():
                    self.log(key, value)
            if do_print:
                for key, value in to_log.items():
                    self.print(f"{key}: {value}")

        return log_callback

    @torch.no_grad()
    def evaluate_and_log(
        self,
        stage: Literal["train", "val", "test"],
        batch_part: Optional[MusicBatchPart] = None,
        visualize: bool = False,
        compute_accuracies: bool = True,
        batch_idx: Optional[int] = None,
        track_loss: Optional[torch.Tensor] = None,
        track_embeddings: Optional[torch.Tensor] = None,
        category_loss: Optional[torch.Tensor] = None,
        category_embeddings: Optional[torch.Tensor] = None,
        loss: Optional[torch.Tensor] = None,
    ):
        """Prepare and send values to the logger. With the exception of "stage", every argument
        is optional, but some arguments must be passed together to have any effect.

        Arguments:
            stage {"train", "val", "test"} -- Used as prefix to logging keys

        Keyword Arguments:
            batch_part {Optional[MusicBatchPart]} -- Training input; optionally, concat
            multiple batches to log a larger set. (default: {None})
            visualize {bool} -- Whether to log a visualization of the embeddings (default: {False})
            compute_accuracies {bool} -- Whether to compute and log accuracy metrics (default: {True})
            batch_idx {Optional[int]} -- Current batch index (default: {None})
            track_loss {Optional[torch.Tensor]} -- Loss just for track embeddings (default: {None})
            track_embeddings {Optional[torch.Tensor]} -- The actual track embeddings (default: {None})
            category_loss {Optional[torch.Tensor]} -- Loss just for category embeddings (default: {None})
            category_embeddings {Optional[torch.Tensor]} -- The actual category embeddings (default: {None})
            loss {Optional[torch.Tensor]} -- The final value used for optimization (default: {None})
        """
        if track_loss is not None:
            self.log(f"{stage}_track_loss", track_loss)
        if category_loss is not None:
            self.log(f"{stage}_category_loss", category_loss)
        if loss is not None:
            self.log(f"{stage}_loss", loss, prog_bar=True)
        if batch_idx is not None:
            self.log(f"{stage}_batch_idx", batch_idx)
        if all((compute_accuracies, track_embeddings is not None, batch_part)):
            assert batch_part is not None
            assert track_embeddings is not None
            track_labels = batch_part["track_labels"]
            normalized = F.normalize(track_embeddings, p=2,
                                     dim=-1).clone().contiguous()
            try:
                if stage == "train":
                    (
                        key_embeddings,
                        key_labels,
                    ) = self.retrieve_embeddings_labels_from_queue()
                    accuracy = self.accuracy.get_accuracy(
                        normalized,
                        F.normalize(key_embeddings, p=2,
                                    dim=-1).clone().contiguous(),
                        track_labels,
                        key_labels,
                        embeddings_come_from_same_source=False,
                    )
                else:
                    query_normalized, key_normalized = (
                        normalized[:, 0, :],
                        normalized[:, 1, :],
                    )
                    accuracy = self.accuracy.get_accuracy(
                        query_normalized,
                        key_normalized,
                        track_labels,
                        track_labels,
                        embeddings_come_from_same_source=False,
                    )
            except RuntimeError as e:
                accuracy = {"accuracy_error": 1}
                self.print(f"Error: {e}")
            category = cast(int, batch_part["category_n"][0].item())
            accuracy_log = {
                f"{stage}_cat{category}_track_{k}": v
                for k, v in accuracy.items()
            }
            if hasattr(self.logger.experiment, "log"):
                self.logger.experiment.log(accuracy_log)
        if all((compute_accuracies, category_embeddings
                is not None, batch_part)):
            assert batch_part is not None
            assert category_embeddings is not None
            category = cast(int, batch_part["category_n"][0].item())
            class_labels = batch_part["class_labels"]
            normalized = (F.normalize(category_embeddings, p=2,
                                      dim=1).clone().contiguous())
            try:
                accuracy = self.accuracy.get_accuracy(
                    normalized,
                    normalized,
                    class_labels,
                    class_labels,
                    embeddings_come_from_same_source=True,
                )
            except RuntimeError as e:
                accuracy = {"accuracy_error": 1}
                self.print(f"Error: {e}")
            accuracy_log = {
                f"{stage}_cat{category}_{k}": v
                for k, v in accuracy.items()
            }
            if hasattr(self.logger.experiment, "log"):
                self.logger.experiment.log(accuracy_log)
        if all((visualize, category_embeddings is not None, batch_part)):
            assert batch_part is not None
            assert category_embeddings is not None
            labels = batch_part["class_labels"].clone().detach()
            category = cast(
                int, batch_part["category_n"][0].clone().detach().item())
            label_map = self.get_label_maps(stage)[category]
            try:
                visualization = visualizer_hook(
                    visualizer=self.visualizer,
                    embeddings=F.normalize(category_embeddings),
                    labels=labels,
                    label_map=label_map,
                    split_name=stage,
                    show_plot=False,
                )
                if hasattr(self.logger.experiment, "log"):
                    self.logger.experiment.log(
                        {f"{stage}_{category}": visualization})
            except Exception as e:
                self.print(f"Visualization error: {e}")
        return None

    def training_step(self, batch, batch_idx):
        self.train()
        total_loss = 0
        opt = cast(Optimizer, self.optimizers())
        batch_part: MusicBatchPart
        for batch_part in batch:

            def closure():
                opt.zero_grad()
                images: torch.Tensor = batch_part["images"]
                category_n: torch.Tensor = batch_part["category_n"]
                category_idx: int = category_n[0].item()
                assert torch.all(
                    category_n ==
                    category_idx)  # Mixed-category minibatches won't work
                class_labels: torch.Tensor = batch_part["class_labels"]
                track_category_n: torch.Tensor = batch_part["track_category_n"]
                track_labels: torch.Tensor = batch_part["track_labels"]

                query_embeddings, key_embeddings = self.track_forward(
                    images=images, track_category_n=track_category_n)
                self.dequeue_enqueue(
                    new_keys=key_embeddings,
                    new_labels=track_labels,
                    category_idx=category_idx,
                    new_category_labels=class_labels,
                )
                track_loss = self.moco_track_loss(
                    query_embeddings=query_embeddings,
                    track_labels=track_labels,
                )

                first_images = images[:, 0]
                category_loss, category_embeddings = self.category_loss(
                    images=first_images,
                    class_labels=class_labels,
                    category_n=category_n,
                )
                loss: torch.Tensor = category_loss + (
                    self.hparams.conf.loss_params.track_loss_alpha *
                    track_loss)

                self.manual_backward(loss, opt)

                self.evaluate_and_log(
                    stage="train",
                    # Accuracy computation is very slow, but provides useful
                    # info while tuning hyperparameters.
                    compute_accuracies=False,
                    visualize=False,
                    batch_part=batch_part,
                    batch_idx=batch_idx,
                    track_loss=track_loss,
                    track_embeddings=query_embeddings,
                    category_loss=category_loss,
                    category_embeddings=category_embeddings,
                    # Lightning doesn't like logging something called "loss"
                    # more than once per (its concept of) step.
                    loss=None,
                )

                return loss

            loss = closure()
            opt.step()
            total_loss += loss.clone().detach().item()
        self.log("train_loss",
                 total_loss,
                 prog_bar=True,
                 on_step=True,
                 on_epoch=True)

    def reset_saved_batches(self) -> None:
        self.saved_batches: Dict[int, List[MusicBatchPart]] = {
            k: []
            for k in range(4)
        }
        self.saved_category_embeddings: Dict[int, List[torch.Tensor]] = {
            k: []
            for k in range(len(self.saved_batches))
        }
        self.saved_track_embeddings: Dict[int, List[torch.Tensor]] = {
            k: []
            for k in range(len(self.saved_batches))
        }
        gc.collect()

    @torch.no_grad()
    def on_validation_epoch_start(self) -> None:
        self.reset_saved_batches()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx, dataloader_idx):
        batch_part = cast(MusicBatchPart, batch)
        track_loss, track_embeddings = self.validation_track_loss(
            images=batch_part["images"],
            track_labels=batch_part["track_labels"],
            track_category_n=batch_part["track_category_n"],
        )
        category_loss, category_embeddings = self.category_loss(
            images=batch_part["images"][:, 0],
            class_labels=batch_part["class_labels"],
            category_n=batch_part["category_n"],
        )
        loss = category_loss + (
            self.hparams.conf.loss_params.track_loss_alpha * track_loss)

        self.log("val_loss", loss, on_epoch=True)

        self.saved_batches[dataloader_idx].append(batch_part)
        self.saved_category_embeddings[dataloader_idx].append(
            category_embeddings)
        self.saved_track_embeddings[dataloader_idx].append(track_embeddings)

    @torch.no_grad()
    def on_validation_epoch_end(self) -> None:
        for category in range(len(self.saved_batches)):
            accumulated_batches: MusicBatchPart = {
                k: torch.cat([
                    batch_part[k]
                    for batch_part in self.saved_batches[category]
                ])
                for k in self.saved_batches[category][0]
            }
            # Pytorch doesn't free tensors with 0 references in containers; you have to clear them manually
            for batch_part in self.saved_batches[category]:
                for _, v in batch_part.items():
                    del v
            accumulated_category_embeddings: torch.Tensor
            accumulated_category_embeddings = torch.cat(
                self.saved_category_embeddings[category])
            for category_embedding in self.saved_category_embeddings[category]:
                del category_embedding
            accumulated_track_embeddings: torch.Tensor
            accumulated_track_embeddings = torch.cat(
                self.saved_track_embeddings[category])
            for track_embedding in self.saved_track_embeddings[category]:
                del track_embedding
            self.evaluate_and_log(
                stage="val",
                batch_part=accumulated_batches,
                compute_accuracies=True,
                visualize=True,
                category_embeddings=accumulated_category_embeddings,
                track_embeddings=accumulated_track_embeddings,
            )
            for _, v in accumulated_batches.items():
                del v
            del accumulated_batches
            del accumulated_category_embeddings
            del accumulated_track_embeddings
            gc.collect()

        self.reset_saved_batches()

    def configure_callbacks(self):
        lr_monitor = pl.callbacks.LearningRateMonitor()
        checkpoint = pl.callbacks.ModelCheckpoint(
            self.hparams.conf.checkpoint_path,
            monitor="train_loss_step",
            save_last=True,
            save_top_k=20,
        )
        return [checkpoint, lr_monitor]

    def configure_optimizers(self):
        lr = self.hparams.conf.learning_rate
        final_lr = self.hparams.conf.final_learning_rate
        momentum = self.hparams.conf.momentum
        weight_decay = self.hparams.conf.weight_decay
        if self.hparams.conf.optimizer == "adabound":
            opt = AdaBound(self.parameters(), lr=lr, final_lr=final_lr)
        elif self.hparams.conf.optimizer == "adam":
            opt = torch.optim.Adam(self.parameters(), lr=lr)
        elif self.hparams.conf.optimizer == "sgd":
            opt = torch.optim.SGD(self.parameters(),
                                  lr=lr,
                                  momentum=momentum,
                                  nesterov=True,
                                  weight_decay=weight_decay)
            max_steps = 60 * (self.hparams.conf.epoch_length //
                              self.hparams.conf.batch_size)
            sched = torch.optim.lr_scheduler.OneCycleLR(
                opt,
                max_lr=(25 * self.hparams.conf.learning_rate),
                total_steps=max_steps,
            )
            return [opt], [{
                "scheduler": sched,
                "interval": "step",
                "frequency": 1,
            }]
        elif self.hparams.conf.optimizer == "asgd":
            opt = torch.optim.ASGD(self.parameters(), lr=lr)
        else:
            raise ValueError("optimizer not implemented")

        return [opt]

    def train_dataloader(self, *args, **kwargs):
        loaders = [
            self.dm.train_dataloader(i, *args, **kwargs) for i in range(4)
        ]
        return loaders

    def val_dataloader(self, *args, **kwargs):
        loaders = [
            self.dm.val_dataloader(i, *args, **kwargs) for i in range(4)
        ]
        return loaders

    def test_dataloader(self, *args, **kwargs):
        loaders = [
            self.dm.test_dataloader(i, *args, **kwargs) for i in range(4)
        ]
        return loaders
Example #9
        total_loss, _, _, _ = representation(  # representation() as in Example #6
            model, device, validation_loader)
        lr_scheduler.step(total_loss)
        es(total_loss, step, model.state_dict(), output_dir / 'model.pt')

        save_checkpoint(
            model, optimizer, lr_scheduler,
            train_loader.sampler.state_dict(train_loader._infinite_iterator),
            step + 1, es, torch.random.get_rng_state())

_, acc_dict, embedding_list, target_list = representation(
    model, device, test_loader)
_, acc_dict_aug, embedding_list_aug, target_list_aug = representation(
    model, device, test_loader_aug)

results = {}
acc_calc = AccuracyCalculator()
for m, embedding, target in zip(['unaug', 'aug'],
                                [embedding_list, embedding_list_aug],
                                [target_list, target_list_aug]):
    results[m] = {}
    for grp in np.unique(target):
        # Evaluate queries from this class against the full reference set.
        mask = target == grp
        results[m][f'metrics_{classes[grp]}'] = acc_calc.get_accuracy(
            embedding[mask, :],
            embedding,
            target[mask],
            target,
            embeddings_come_from_same_source=True)

    results[m]['targets'] = target
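The nested results dict can then be flattened for a quick look at one metric per class, for example:

for m in results:
    for key, metrics in results[m].items():
        if key == 'targets':
            continue
        print(m, key, 'precision_at_1 =', metrics['precision_at_1'])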
Example #10
def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Running on device: {}'.format(device))

    # Data transformations
    trans_train = transforms.Compose([
        transforms.RandomApply(transforms=[
            transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
            # transforms.RandomPerspective(distortion_scale=0.6, p=1.0),
            transforms.RandomRotation(degrees=(0, 180)),
            transforms.RandomHorizontalFlip(),
        ]),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization,
    ])

    trans_val = transforms.Compose([
        # transforms.CenterCrop(120),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization,
    ])

    train_dataset = datasets.ImageFolder(os.path.join(data_dir,
                                                      "train_aligned"),
                                         transform=trans_train)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val_aligned"),
                                       transform=trans_val)

    # Prepare the model
    model = InceptionResnetV1(classify=False,
                              pretrained="vggface2",
                              dropout_prob=0.5).to(device)

    # for param in list(model.parameters())[:-8]:
    #     param.requires_grad = False

    trunk_optimizer = torch.optim.SGD(model.parameters(), lr=LR)

    # Set the loss function. ArcFaceLoss has learnable class weights,
    # so it gets its own optimizer below.
    loss = losses.ArcFaceLoss(len(train_dataset.classes), 512)
    loss_optimizer = torch.optim.SGD(loss.parameters(), lr=LR)

    # Package the above stuff into dictionaries.
    models = {"trunk": model}
    optimizers = {"trunk_optimizer": trunk_optimizer}
    loss_funcs = {"metric_loss": loss}
    mining_funcs = {}
    lr_scheduler = {
        "trunk_scheduler_by_plateau":
        torch.optim.lr_scheduler.ReduceLROnPlateau(trunk_optimizer)
    }

    # Create the tester
    record_keeper, _, _ = logging_presets.get_record_keeper(
        "logs", "tensorboard")
    hooks = logging_presets.get_hook_container(record_keeper)

    dataset_dict = {"val": val_dataset, "train": train_dataset}
    model_folder = "training_saved_models"

    def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname,
                        *args):
        logging.info("UMAP plot for the {} split and label set {}".format(
            split_name, keyname))
        label_set = np.unique(labels)
        num_classes = len(label_set)
        fig = plt.figure(figsize=(8, 7))
        plt.gca().set_prop_cycle(
            cycler("color", [
                plt.cm.nipy_spectral(i)
                for i in np.linspace(0, 0.9, num_classes)
            ]))
        for i in range(num_classes):
            idx = labels == label_set[i]
            plt.plot(umap_embeddings[idx, 0],
                     umap_embeddings[idx, 1],
                     ".",
                     markersize=1)
        plt.show()

    tester = testers.GlobalEmbeddingSpaceTester(
        end_of_testing_hook=hooks.end_of_testing_hook,
        dataloader_num_workers=4,
        accuracy_calculator=AccuracyCalculator(
            include=['mean_average_precision_at_r'], k="max_bin_count"))

    end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                                dataset_dict,
                                                model_folder,
                                                splits_to_eval=[('val',
                                                                 ['train'])])

    # Create the trainer
    trainer = trainers.MetricLossOnly(
        models,
        optimizers,
        batch_size,
        loss_funcs,
        mining_funcs,
        train_dataset,
        lr_schedulers=lr_scheduler,
        dataloader_num_workers=8,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook)

    trainer.train(num_epochs=num_epochs)
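Once trained, the trunk can be used directly to embed new face crops for retrieval. A small sketch (the preprocessing is assumed to match trans_val above):

import torch

@torch.no_grad()
def embed_images(model, images, device):
    # Sketch: compute L2-normalized embeddings for a batch of
    # already-preprocessed images.
    model.eval()
    embeddings = model(images.to(device))
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)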
Example #11
def train(train_data, test_data, save_model, num_epochs, lr, embedding_size,
          batch_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set trunk model and replace the softmax layer with an identity function
    trunk = torchvision.models.resnet18(pretrained=True)
    trunk_output_size = trunk.fc.in_features
    trunk.fc = common_functions.Identity()
    trunk = torch.nn.DataParallel(trunk.to(device))

    # Set embedder model. This takes in the output of the trunk and outputs 64 dimensional embeddings
    embedder = torch.nn.DataParallel(
        MLP([trunk_output_size, embedding_size]).to(device))

    # Set optimizers
    trunk_optimizer = torch.optim.Adam(trunk.parameters(),
                                       lr=lr / 10,
                                       weight_decay=0.0001)
    embedder_optimizer = torch.optim.Adam(embedder.parameters(),
                                          lr=lr,
                                          weight_decay=0.0001)

    # Set the loss function
    loss = losses.TripletMarginLoss(margin=0.1)

    # Set the mining function
    miner = miners.MultiSimilarityMiner(epsilon=0.1)

    # Set the dataloader sampler
    sampler = samplers.MPerClassSampler(train_data.targets,
                                        m=4,
                                        length_before_new_iter=len(train_data))

    save_dir = os.path.join(
        save_model, ''.join(str(lr).split('.')) + '_' + str(batch_size) + '_' +
        str(embedding_size))

    os.makedirs(save_dir, exist_ok=True)

    # Package the above stuff into dictionaries.
    models = {"trunk": trunk, "embedder": embedder}
    optimizers = {
        "trunk_optimizer": trunk_optimizer,
        "embedder_optimizer": embedder_optimizer
    }
    loss_funcs = {"metric_loss": loss}
    mining_funcs = {"tuple_miner": miner}

    record_keeper, _, _ = logging_presets.get_record_keeper(
        os.path.join(save_dir, "example_logs"),
        os.path.join(save_dir, "example_tensorboard"))
    hooks = logging_presets.get_hook_container(record_keeper)

    dataset_dict = {"val": test_data, "train": train_data}
    model_folder = "example_saved_models"

    def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname,
                        *args):
        logging.info("UMAP plot for the {} split and label set {}".format(
            split_name, keyname))
        label_set = np.unique(labels)
        num_classes = len(label_set)
        fig = plt.figure(figsize=(20, 15))
        plt.title(str(split_name) + '_' + str(keyname))
        plt.gca().set_prop_cycle(
            cycler("color", [
                plt.cm.nipy_spectral(i)
                for i in np.linspace(0, 0.9, num_classes)
            ]))
        for i in range(num_classes):
            idx = labels == label_set[i]
            plt.plot(umap_embeddings[idx, 0],
                     umap_embeddings[idx, 1],
                     ".",
                     markersize=1)
        plt.show()

    # Create the tester
    tester = testers.GlobalEmbeddingSpaceTester(
        end_of_testing_hook=hooks.end_of_testing_hook,
        visualizer=umap.UMAP(),
        visualizer_hook=visualizer_hook,
        dataloader_num_workers=32,
        accuracy_calculator=AccuracyCalculator(k="max_bin_count"))

    end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                                dataset_dict,
                                                model_folder,
                                                test_interval=1,
                                                patience=1)

    trainer = trainers.MetricLossOnly(
        models,
        optimizers,
        batch_size,
        loss_funcs,
        mining_funcs,
        train_data,
        sampler=sampler,
        dataloader_num_workers=32,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook)

    trainer.train(num_epochs=num_epochs)

    if save_model is not None:

        torch.save(models["trunk"].state_dict(),
                   os.path.join(save_dir, 'trunk.pth'))
        torch.save(models["embedder"].state_dict(),
                   os.path.join(save_dir, 'embedder.pth'))

        print('Model saved in ', save_dir)
Example #12
#
# train_loader = torch.utils.data.DataLoader(dataset1, batch_size=256, shuffle=True)
# test_loader = torch.utils.data.DataLoader(dataset2, batch_size=256)

output_size = 4
input_size = 768
hidden_size = 200
training_epochs = 30

model = LSTM_model(input_size, output_size, hidden_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 40

### pytorch-metric-learning stuff ###
distance = distances.LpDistance()
reducer = reducers.MeanReducer()
loss_func = losses.ProxyNCALoss(output_size, hidden_size * 2, softmax_scale=1)
mining_func = miners.TripletMarginMiner(margin=0.2,
                                        distance=distance,
                                        type_of_triplets="semihard")
accuracy_calculator = AccuracyCalculator(
    include=("mean_average_precision_at_r", ), k=10)
### pytorch-metric-learning stuff ###

for epoch in range(1, num_epochs + 1):
    train(model, loss_func, mining_func, device, train_loader, optimizer,
          epoch)
    # test(dataset2, model, accuracy_calculator)

torch.save(model.state_dict(),
           './metric_saved_model_' + working_aspect + '.ckpt')
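The test call in the loop above is commented out and its helper is not shown. A sketch in the style of the library's example notebooks (the original may differ):

from pytorch_metric_learning import testers

def get_all_embeddings(dataset, model):
    # Sketch: gather embeddings and labels for every item in a dataset.
    tester = testers.BaseTester()
    return tester.get_all_embeddings(dataset, model)

def test(test_set, model, accuracy_calculator):
    # Sketch: evaluate the embedding space with the test set as both
    # query and reference.
    embeddings, labels = get_all_embeddings(test_set, model)
    labels = labels.squeeze(1)
    accuracies = accuracy_calculator.get_accuracy(
        embeddings, embeddings, labels, labels,
        embeddings_come_from_same_source=True)
    print("test set MAP@R:", accuracies["mean_average_precision_at_r"])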