def main():
    """Load the four embedding matrices from an h5 file and print retrieval metrics."""
    parser = utils.ArgParser(description=__doc__)
    parser.add_argument("path_to_embeddings", type=str, help="Provide path to h5 embeddings file.")
    args = parser.parse_args()
    emb_file = Path(args.path_to_embeddings)
    print(f"Testing retrieval on embeddings: {emb_file}")
    # load embeddings: materialize each dataset as a numpy array while the file is open
    with h5py.File(emb_file, "r") as h5:
        data_collector = {
            name: np.array(h5[name])
            for name in ("vid_emb", "par_emb", "clip_emb", "sent_emb")}
    # compute retrieval: high-level (video-paragraph) then low-level (clip-sentence)
    print(retrieval.VALHEADER)
    retrieval.compute_retrieval(data_collector, "vid_emb", "par_emb")
    retrieval.compute_retrieval(data_collector, "clip_emb", "sent_emb")
def main():
    """Load embeddings from an h5 file (old or new key layout) and print retrieval metrics.

    Old files store the normalized embeddings under short names ("vid_norm",
    "par_norm", ...); new files use the "*_emb" keys directly. Either layout is
    mapped onto the canonical "*_emb" keys before computing retrieval.
    """
    parser = utils.ArgParser(description=__doc__)
    parser.add_argument("path_to_embeddings", type=str, help="Provide path to h5 embeddings file.")
    args = parser.parse_args()
    path_to_embeddings = Path(args.path_to_embeddings)
    print(f"Testing retrieval on embeddings: {path_to_embeddings}")
    # load embeddings
    with h5py.File(path_to_embeddings, "r") as h5:
        if "vid_emb" not in h5:
            # backwards compatibility: map old short names to the canonical keys.
            # (The original unpacked a 12-tuple of source names but only ever
            # used 4 of them; the dead names are dropped here.)
            key_map = {
                "vid_emb": "vid_norm",
                "par_emb": "par_norm",
                "clip_emb": "clip_norm",
                "sent_emb": "sent_norm",
            }
        else:
            # new version: file keys already match the canonical names
            key_map = {key: key for key in ("vid_emb", "par_emb", "clip_emb", "sent_emb")}
        data_collector = {
            key_target: np.array(h5[key_source]) for key_target, key_source in key_map.items()}
    # compute retrieval
    print(retrieval.VALHEADER)
    retrieval.compute_retrieval(data_collector, "vid_emb", "par_emb")
    retrieval.compute_retrieval(data_collector, "clip_emb", "sent_emb")
def validate_epoch(
        self, data_loader: data.DataLoader, val_clips: bool = False, save_embs: bool = False
) -> (Tuple[float, float, bool, Tuple[Dict[str, float], Optional[Dict[
        str, float]]]]):
    """
    Validate a single epoch.

    Runs the full validation dataloader through the visual and text encoders,
    accumulates contrastive and cycle-consistency losses, collects embeddings
    on CPU, then computes video-paragraph (and optionally clip-sentence)
    retrieval and feeds all results to the metric meters.

    Args:
        data_loader: Dataloader for validation
        val_clips: Whether to compute low-level retrieval results.
        save_embs: Save embeddings to file

    Returns:
        Tuple of validation loss, validation score, epoch is best and custom
        metrics: tuple of video-paragraph retrieval, optional clip-sentence
        retrieval.
    """
    self.hook_pre_val_epoch()  # pre val epoch hook: set models to val and start timers
    forward_time_total = 0
    # loss accumulators start as python floats and become tensors after the
    # first `+=` with a tensor loss; divided by num_steps at the end
    loss_total: th.Tensor = 0.
    contr_loss_total: th.Tensor = 0.
    cc_loss_total: th.Tensor = 0.
    data_collector = {}

    # decide what to collect
    save_clip_num, save_sent_num, save_key = [], [], []
    collect_keys = ["vid_emb", "par_emb"]
    if val_clips or save_embs:
        # clips can be requested for validation and when saving embeddings
        collect_keys += ["clip_emb", "sent_emb"]
    if save_embs:
        # only need the context when saving embeddings
        collect_keys += ["vid_context", "par_context"]

    # ---------- Dataloader Iteration ----------
    num_steps = 0
    pbar = tqdm(total=len(data_loader), desc=f"Validate epoch {self.state.current_epoch}")
    for _step, batch in enumerate(data_loader):  # type: RetrievalDataBatchTuple
        # move data to cuda
        if self.check_cuda():
            batch.to_cuda(non_blocking=self.cfg.cuda_non_blocking)
        if save_embs:
            # collect meta information for saving
            save_clip_num.extend(batch.clip_num.cpu().numpy().tolist())
            # NOTE(review): sent_num is extended from batch.clip_num, not a
            # batch.sent_num field — possibly intentional if clip and sentence
            # counts always match, but verify against the batch definition.
            save_sent_num.extend(batch.clip_num.cpu().numpy().tolist())
            save_key.extend(batch.key)

        # ---------- forward pass ----------
        self.hook_pre_step_timer()  # hook for step timing
        with autocast(enabled=self.cfg.fp16_val):
            visual_data = self.model_mgr.encode_visual(batch)
            text_data = self.model_mgr.encode_text(batch)
            contr_loss = self.compute_total_constrastive_loss(
                visual_data, text_data)
            contr_loss_total += contr_loss
            cc_loss = self.compute_cyclecons_loss(visual_data,
                                                  text_data)
            cc_loss_total += cc_loss
            loss_total += contr_loss + cc_loss
        self.hook_post_forward_step_timer()
        forward_time_total += self.timedelta_step_forward
        num_steps += 1

        # ---------- data collection ----------
        # merge encoder outputs into one flat dict keyed like collect_keys
        all_data = {**visual_data.dict(), **text_data.dict()}
        for key in collect_keys:
            emb = all_data.get(key)
            # collect embeddings into list, on CPU otherwise the gpu runs OOM
            if data_collector.get(key) is None:
                data_collector[key] = [emb.data.cpu()]
            else:
                data_collector[key] += [emb.data.cpu()]
        pbar.update()
    pbar.close()

    # ---------- validation done ----------
    # postprocess collected embeddings: concatenate per-batch chunks and keep
    # both the raw and the L2-normalized version of each embedding matrix
    data_collector_norm = {}
    for key in collect_keys:
        data_collector[key] = th.cat(data_collector[key], dim=0)
        data_collector_norm[key] = F.normalize(data_collector[key])
    if save_embs:
        # save unnormalized embeddings
        os.makedirs(self.exp.path_embeddings, exist_ok=True)
        filename = self.exp.path_embeddings / f"embeddings_{self.state.current_epoch}.h5"
        with h5py.File(filename, mode="w") as h5:
            h5["clip_num"] = save_clip_num
            h5["sent_num"] = save_sent_num
            h5["key"] = save_key
            for key in collect_keys:
                # normalized under the plain key, raw under "*_before_norm"
                h5[key] = data_collector_norm[key].numpy()
                h5[f"{key}_before_norm"] = data_collector[key].numpy()
        # NOTE(review): this log message prints the literal text "(unknown)";
        # it probably should interpolate {filename} — confirm and fix upstream.
        self.logger.info(f"Saved embeddings to (unknown)\n")

    # calculate total loss and feed meters
    loss_total /= num_steps
    contr_loss_total /= num_steps
    cc_loss_total /= num_steps
    forward_time_total /= num_steps
    self.metrics.update_meter(CMeters.VAL_LOSS_CONTRASTIVE, contr_loss_total)
    self.metrics.update_meter(CMeters.VAL_LOSS_CC, cc_loss_total)

    # calculate video-paragraph retrieval and print output table
    self.logger.info(retrieval.VALHEADER)
    res_v2p, res_p2v, sum_vp_at_1, str_vp = retrieval.compute_retrieval(
        data_collector_norm, "vid_emb", "par_emb", print_fn=self.logger.info)

    # calculate clip-sentence retrieval and print output table
    res_c2s, res_s2c, sum_cs_at_1, clipsent_results = None, None, None, None
    str_cs = ""
    if val_clips:
        res_c2s, res_s2c, sum_cs_at_1, str_cs = retrieval.compute_retrieval(
            data_collector_norm, "clip_emb", "sent_emb", print_fn=self.logger.info)
        clipsent_results = (res_c2s, res_s2c, sum_cs_at_1)

    # feed retrieval results to meters
    for modality, dict_ret in zip(CMeters.RET_MODALITIES,
                                  [res_v2p, res_p2v, res_c2s, res_s2c]):
        if dict_ret is None:
            # clip-sentence results stay None when val_clips is False
            continue
        # iterate over result keys
        for metric in CMeters.RET_METRICS:
            # feed averagemeters; r1 goes to the base logger class
            logger_class = "val_ret"
            if metric == "r1":
                logger_class = "val_base"
            self.metrics.update_meter(
                f"{logger_class}/{modality}-{metric}", dict_ret[metric])

    # print some more details about the retrieval (time, number of datapoints)
    self.logger.info(
        f"Loss {loss_total:.5f} (Contr: {contr_loss_total:.5f}, CC: {cc_loss_total:.5f}) "
        f"Retrieval: {str_vp}{str_cs}total {timer() - self.timer_val_epoch:.3f}s, "
        f"forward {forward_time_total:.3f}s")

    # find field which determines whether this is a new best epoch
    if self.cfg.val.det_best_field == "val_score_at_1":
        val_score = sum_vp_at_1
    elif self.cfg.val.det_best_field == "val_loss":
        val_score = loss_total
    elif self.cfg.val.det_best_field == "val_clip_sent_score_at_1":
        # NOTE(review): sum_cs_at_1 is None unless val_clips=True — this config
        # combination would return None as the score; confirm callers guard it.
        val_score = sum_cs_at_1
    else:
        raise NotImplementedError(
            f"best field {self.cfg.val.det_best_field} not known")

    # check for a new best epoch and update validation results
    is_best = self.check_is_new_best(val_score)
    self.hook_post_val_epoch(loss_total, is_best)
    if self.is_test:
        # for test runs, save the validation results separately to a file
        self.metrics.feed_metrics(False, self.state.total_step, self.state.current_epoch)
        metrics_file = self.exp.path_base / f"val_ep_{self.state.current_epoch}.json"
        self.metrics.save_epoch_to_file(metrics_file)
        self.logger.info(f"Saved validation results to {metrics_file}")
    return loss_total, val_score, is_best, ((res_v2p, res_p2v, sum_vp_at_1),
                                            clipsent_results)