def main():
    """Load precomputed embeddings from an h5 file and print retrieval results."""
    parser = utils.ArgParser(description=__doc__)
    parser.add_argument("path_to_embeddings",
                        type=str,
                        help="Provide path to h5 embeddings file.")
    args = parser.parse_args()
    path_to_embeddings = Path(args.path_to_embeddings)
    print(f"Testing retrieval on embeddings: {path_to_embeddings}")

    # read all four embedding matrices into memory, then close the file
    emb_keys = ["vid_emb", "par_emb", "clip_emb", "sent_emb"]
    with h5py.File(path_to_embeddings, "r") as h5:
        data_collector = {key: np.array(h5[key]) for key in emb_keys}

    # evaluate retrieval at both granularities (video-paragraph, clip-sentence)
    print(retrieval.VALHEADER)
    retrieval.compute_retrieval(data_collector, "vid_emb", "par_emb")
    retrieval.compute_retrieval(data_collector, "clip_emb", "sent_emb")
# --- Example 2 ---
def main():
    """Load embeddings (legacy or current h5 layout) and print retrieval results."""
    parser = utils.ArgParser(description=__doc__)
    parser.add_argument("path_to_embeddings",
                        type=str,
                        help="Provide path to h5 embeddings file.")
    args = parser.parse_args()
    path_to_embeddings = Path(args.path_to_embeddings)
    print(f"Testing retrieval on embeddings: {path_to_embeddings}")

    emb_keys = ["vid_emb", "par_emb", "clip_emb", "sent_emb"]
    # legacy files stored the normalized embeddings under short dataset names
    legacy_names = {
        "vid_emb": "vid_norm",
        "par_emb": "par_norm",
        "clip_emb": "clip_norm",
        "sent_emb": "sent_norm",
    }
    with h5py.File(path_to_embeddings, "r") as h5:
        if "vid_emb" in h5:
            # new version: datasets stored directly under the target keys
            data_collector = {key: np.array(h5[key]) for key in emb_keys}
        else:
            # backwards compatibility: translate legacy dataset names
            data_collector = {
                key: np.array(h5[legacy_names[key]]) for key in emb_keys}

    # evaluate retrieval at both granularities (video-paragraph, clip-sentence)
    print(retrieval.VALHEADER)
    retrieval.compute_retrieval(data_collector, "vid_emb", "par_emb")
    retrieval.compute_retrieval(data_collector, "clip_emb", "sent_emb")
# --- Example 3 ---
    def validate_epoch(
        self,
        data_loader: data.DataLoader,
        val_clips: bool = False,
        save_embs: bool = False
    ) -> (Tuple[float, float, bool, Tuple[Dict[str, float], Optional[Dict[
            str, float]]]]):
        """
        Validate a single epoch.

        Args:
            data_loader: Dataloader for validation
            val_clips: Whether to compute low-level (clip-sentence) retrieval results.
            save_embs: Save embeddings (normalized and unnormalized) to an h5 file.

        Returns:
            Tuple of validation loss, validation score, epoch is best and custom metrics: tuple of
                video-paragraph retrieval, optional clip-sentence retrieval.
        """
        self.hook_pre_val_epoch(
        )  # pre val epoch hook: set models to val and start timers
        forward_time_total = 0
        # accumulators start as python floats and become tensors after the
        # first `+= loss` below
        loss_total: th.Tensor = 0.
        contr_loss_total: th.Tensor = 0.
        cc_loss_total: th.Tensor = 0.
        data_collector = {}

        # decide what to collect
        save_clip_num, save_sent_num, save_key = [], [], []
        collect_keys = ["vid_emb", "par_emb"]
        if val_clips or save_embs:
            # clips can be requested for validation and when saving embeddings
            collect_keys += ["clip_emb", "sent_emb"]
        if save_embs:
            # only need the context when saving embeddings
            collect_keys += ["vid_context", "par_context"]

        # ---------- Dataloader Iteration ----------
        num_steps = 0
        pbar = tqdm(total=len(data_loader),
                    desc=f"Validate epoch {self.state.current_epoch}")
        for _step, batch in enumerate(
                data_loader):  # type: RetrievalDataBatchTuple
            # move data to cuda
            if self.check_cuda():
                batch.to_cuda(non_blocking=self.cfg.cuda_non_blocking)

            if save_embs:
                # collect meta information for saving
                save_clip_num.extend(batch.clip_num.cpu().numpy().tolist())
                # NOTE(review): sent_num is filled from batch.clip_num —
                # presumably clips and sentences are paired 1:1; confirm the
                # batch has no separate sent_num field before changing this.
                save_sent_num.extend(batch.clip_num.cpu().numpy().tolist())
                save_key.extend(batch.key)

            # ---------- forward pass ----------
            self.hook_pre_step_timer()  # hook for step timing

            with autocast(enabled=self.cfg.fp16_val):
                visual_data = self.model_mgr.encode_visual(batch)
                text_data = self.model_mgr.encode_text(batch)
                contr_loss = self.compute_total_constrastive_loss(
                    visual_data, text_data)
                contr_loss_total += contr_loss
                cc_loss = self.compute_cyclecons_loss(visual_data, text_data)
                cc_loss_total += cc_loss
                loss_total += contr_loss + cc_loss

            self.hook_post_forward_step_timer()
            forward_time_total += self.timedelta_step_forward
            num_steps += 1

            # ---------- data collection ----------
            all_data = {**visual_data.dict(), **text_data.dict()}
            for key in collect_keys:
                emb = all_data.get(key)
                # collect embeddings into list, on CPU otherwise the gpu runs OOM
                if data_collector.get(key) is None:
                    data_collector[key] = [emb.data.cpu()]
                else:
                    data_collector[key] += [emb.data.cpu()]
            pbar.update()
        pbar.close()

        # ---------- validation done ----------

        # postprocess collected embeddings: concatenate batches and keep an
        # L2-normalized copy for retrieval
        data_collector_norm = {}
        for key in collect_keys:
            data_collector[key] = th.cat(data_collector[key], dim=0)
            data_collector_norm[key] = F.normalize(data_collector[key])

        if save_embs:
            # save both normalized and unnormalized embeddings plus batch metadata
            os.makedirs(self.exp.path_embeddings, exist_ok=True)
            filename = self.exp.path_embeddings / f"embeddings_{self.state.current_epoch}.h5"
            with h5py.File(filename, mode="w") as h5:
                h5["clip_num"] = save_clip_num
                h5["sent_num"] = save_sent_num
                h5["key"] = save_key
                for key in collect_keys:
                    h5[key] = data_collector_norm[key].numpy()
                    h5[f"{key}_before_norm"] = data_collector[key].numpy()
            # fix: log the actual output path (the f-string was previously
            # broken and printed a literal placeholder instead of the file)
            self.logger.info(f"Saved embeddings to {filename}\n")

        # calculate total loss and feed meters
        loss_total /= num_steps
        contr_loss_total /= num_steps
        cc_loss_total /= num_steps
        forward_time_total /= num_steps
        self.metrics.update_meter(CMeters.VAL_LOSS_CONTRASTIVE,
                                  contr_loss_total)
        self.metrics.update_meter(CMeters.VAL_LOSS_CC, cc_loss_total)

        # calculate video-paragraph retrieval and print output table
        self.logger.info(retrieval.VALHEADER)
        res_v2p, res_p2v, sum_vp_at_1, str_vp = retrieval.compute_retrieval(
            data_collector_norm,
            "vid_emb",
            "par_emb",
            print_fn=self.logger.info)

        # calculate clip-sentence retrieval and print output table
        res_c2s, res_s2c, sum_cs_at_1, clipsent_results = None, None, None, None
        str_cs = ""
        if val_clips:
            res_c2s, res_s2c, sum_cs_at_1, str_cs = retrieval.compute_retrieval(
                data_collector_norm,
                "clip_emb",
                "sent_emb",
                print_fn=self.logger.info)
            clipsent_results = (res_c2s, res_s2c, sum_cs_at_1)

        # feed retrieval results to meters
        for modality, dict_ret in zip(CMeters.RET_MODALITIES,
                                      [res_v2p, res_p2v, res_c2s, res_s2c]):
            if dict_ret is None:
                # clip-sentence results are None when val_clips is off
                continue
            # iterate over result keys
            for metric in CMeters.RET_METRICS:
                # feed averagemeters; r1 goes to the base logger class
                logger_class = "val_ret"
                if metric == "r1":
                    logger_class = "val_base"
                self.metrics.update_meter(
                    f"{logger_class}/{modality}-{metric}", dict_ret[metric])

        # print some more details about the retrieval (time, number of datapoints)
        self.logger.info(
            f"Loss {loss_total:.5f} (Contr: {contr_loss_total:.5f}, CC: {cc_loss_total:.5f}) "
            f"Retrieval: {str_vp}{str_cs}total {timer() - self.timer_val_epoch:.3f}s, "
            f"forward {forward_time_total:.3f}s")

        # find field which determines whether this is a new best epoch
        if self.cfg.val.det_best_field == "val_score_at_1":
            val_score = sum_vp_at_1
        elif self.cfg.val.det_best_field == "val_loss":
            val_score = loss_total
        elif self.cfg.val.det_best_field == "val_clip_sent_score_at_1":
            # only meaningful when val_clips is True, otherwise sum_cs_at_1 is None
            val_score = sum_cs_at_1
        else:
            raise NotImplementedError(
                f"best field {self.cfg.val.det_best_field} not known")

        # check for a new best epoch and update validation results
        is_best = self.check_is_new_best(val_score)
        self.hook_post_val_epoch(loss_total, is_best)

        if self.is_test:
            # for test runs, save the validation results separately to a file
            self.metrics.feed_metrics(False, self.state.total_step,
                                      self.state.current_epoch)
            metrics_file = self.exp.path_base / f"val_ep_{self.state.current_epoch}.json"
            self.metrics.save_epoch_to_file(metrics_file)
            self.logger.info(f"Saved validation results to {metrics_file}")

        return loss_total, val_score, is_best, ((res_v2p, res_p2v,
                                                 sum_vp_at_1),
                                                clipsent_results)