def log_epoch_stats(self, cur_epoch):
    """
    Summarize validation results for the finished epoch and emit them as a
    JSON stats record.

    For multi-label data the record carries the mAP computed by ``get_map``
    over all accumulated predictions/labels. Otherwise it carries the epoch
    top-1/top-5 error plus the lowest values seen so far, which are also
    stored back on the meter.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
    """
    cfg = self._cfg
    stats = {
        "_type": "val_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, cfg.SOLVER.MAX_EPOCH),
        "time_diff": self.iter_timer.seconds(),
        "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
        "RAM": "{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()),
    }
    if not cfg.DATA.MULTI_LABEL:
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        # Keep the best (lowest) errors across epochs on the meter itself.
        self.min_top1_err = min(self.min_top1_err, top1_err)
        self.min_top5_err = min(self.min_top5_err, top5_err)
        stats.update(
            top1_err=top1_err,
            top5_err=top5_err,
            min_top1_err=self.min_top1_err,
            min_top5_err=self.min_top5_err,
        )
    else:
        # Concatenate the accumulated tensors and move them to host numpy.
        preds_np = torch.cat(self.all_preds).cpu().numpy()
        labels_np = torch.cat(self.all_labels).cpu().numpy()
        stats["map"] = get_map(preds_np, labels_np)
    logging.log_json_stats(stats)
def log_epoch_stats(self, cur_epoch, preds=None, labels=None):
    """
    Log the stats of the current training epoch.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
        preds (sequence, optional): predictions forwarded to
            ``calc_binary_stats``. Ignored when ``None`` or empty.
        labels (sequence, optional): labels forwarded to
            ``calc_binary_stats``. Ignored when ``None`` or empty.

    Returns:
        dict or None: the stats that were logged, or ``None`` when no
            samples were accumulated (``self.num_samples <= 0``).
    """
    if self.num_samples <= 0:
        # Nothing accumulated: warn and bail out before dividing by zero.
        self.logger.warning(
            "TrainMeter log_epoch_stats numSample {}".format(
                self.num_samples))
        return None
    # Remaining time = seconds per iteration times (self.MAX_EPOCH minus the
    # iterations covered through this epoch). self.MAX_EPOCH is compared
    # against an iteration count here.
    eta_sec = self.iter_timer.seconds() * (
        self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters)
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    stats = {
        "_type": "train_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "time_diff": self.iter_timer.seconds(),
        "eta": eta,
        "lr": self.lr,
        "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()),
        "RAM": "{:.2f}/{:.2f} GB".format(*misc.cpu_mem_usage()),
    }
    if not self._cfg.DATA.MULTI_LABEL:
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        avg_loss = self.loss_total / self.num_samples
        stats["top1_err"] = top1_err
        stats["top5_err"] = top5_err
        stats["loss"] = avg_loss
    # Defaults are None (not []) so no mutable default is shared between
    # calls. Callers passing sequences get the original behaviour.
    if (preds is not None and labels is not None
            and len(preds) > 0 and len(labels) > 0):
        calc_binary_stats(preds, labels, stats, self._cfg)
    self.logger.info(stats)
    return stats
def log_iter_stats(self, cur_epoch, cur_iter):
    """
    Log windowed-median validation accuracies every ``LOG_PERIOD`` iterations.

    The record is sent both to TensorBoard (via ``log_to_tensorboard``) and to
    the JSON stats log.

    Args:
        cur_epoch (int): the current epoch index; logged as ``cur_epoch + 1``.
        cur_iter (int): the current iteration index; logged as
            ``cur_iter + 1``.
    """
    if (cur_iter + 1) % self._cfg.LOG_PERIOD:
        return
    iters_left = self.max_iter - cur_iter - 1
    eta_sec = self.iter_timer.seconds() * iters_left
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    mem_usage = misc.gpu_mem_usage()

    stats = {
        "_type": "val_iter",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
        "time_diff": self.iter_timer.seconds(),
        "eta": eta,
    }
    # Smoothed meters, queried in a fixed order so the record's key order
    # stays stable.
    smoothed = (
        ("verb_top1_acc", self.mb_verb_top1_acc),
        ("verb_top5_acc", self.mb_verb_top5_acc),
        ("noun_top1_acc", self.mb_noun_top1_acc),
        ("noun_top5_acc", self.mb_noun_top5_acc),
        ("top1_acc", self.mb_top1_acc),
        ("top5_acc", self.mb_top5_acc),
    )
    for key, meter in smoothed:
        stats[key] = meter.get_win_median()
    stats["mem"] = int(np.ceil(mem_usage))

    log_to_tensorboard(self.tb_writer, stats)
    logging.log_json_stats(stats)
def log_epoch_stats(self, cur_epoch):
    """
    Log the aggregated training statistics at the end of an epoch.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
    """
    # self.MAX_EPOCH is compared against an iteration count here, so the
    # difference is the number of iterations still ahead.
    iters_left = self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters
    eta_sec = self.iter_timer.seconds() * iters_left
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    mem_usage = misc.gpu_mem_usage()

    n = self.num_samples
    top1_err = self.num_top1_mis / n
    top5_err = self.num_top5_mis / n
    avg_loss = self.loss_total / n

    stats = dict(
        _type="train_epoch",
        epoch="{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        time_diff=self.iter_timer.seconds(),
        eta=eta,
        top1_err=top1_err,
        top5_err=top5_err,
        loss=avg_loss,
        lr=self.lr,
        mem=int(np.ceil(mem_usage)),
    )
    logging.log_json_stats(stats)
def log_iter_stats(self, cur_epoch, cur_iter):
    """
    Log windowed training statistics every ``LOG_PERIOD`` iterations.

    Args:
        cur_epoch (int): the current epoch index; logged as ``cur_epoch + 1``.
        cur_iter (int): the current iteration index within the epoch; logged
            as ``cur_iter + 1``.
    """
    if (cur_iter + 1) % self._cfg.LOG_PERIOD:
        return
    # Global iteration count implied by epoch_iters iterations per epoch.
    iters_done = cur_epoch * self.epoch_iters + cur_iter + 1
    eta_sec = self.iter_timer.seconds() * (self.MAX_EPOCH - iters_done)
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    mem_usage = misc.gpu_mem_usage()
    stats = dict(
        _type="train_iter",
        epoch="{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        iter="{}/{}".format(cur_iter + 1, self.epoch_iters),
        time_diff=self.iter_timer.seconds(),
        eta=eta,
        top1_err=self.mb_top1_err.get_win_median(),
        top5_err=self.mb_top5_err.get_win_median(),
        loss=self.loss.get_win_median(),
        lr=self.lr,
        mem=int(np.ceil(mem_usage)),
    )
    logging.log_json_stats(stats)
def log_iter_stats(self, cur_epoch, cur_iter):
    """
    Log validation progress every ``LOG_PERIOD`` iterations.

    Besides timing and memory, the record includes the windowed median of
    every meter stored in ``self.stats`` (keyed by the same names).

    Args:
        cur_epoch (int): the current epoch index; logged as ``cur_epoch + 1``.
        cur_iter (int): the current iteration index; logged as
            ``cur_iter + 1``.
    """
    if (cur_iter + 1) % self._cfg.LOG_PERIOD:
        return
    remaining = self.max_iter - cur_iter - 1
    eta_sec = self.iter_timer.seconds() * remaining
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    mem_usage = misc.gpu_mem_usage()
    stats = {
        "_type": "val_iter",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
        "time_diff": self.iter_timer.seconds(),
        "time_left": eta,
        "mem": int(np.ceil(mem_usage)),
    }
    stats.update(
        {name: meter.get_win_median() for name, meter in self.stats.items()}
    )
    logging.log_json_stats(stats)
def log_iter_stats(self, cur_epoch, cur_iter, preds=None, labels=None):
    """
    Log the stats of the current validation iteration.

    Only every ``LOG_PERIOD``-th iteration (1-based) is logged.

    Args:
        cur_epoch (int): the current epoch index; logged as ``cur_epoch + 1``.
        cur_iter (int): the current iteration index; logged as
            ``cur_iter + 1``.
        preds (sequence, optional): predictions forwarded to
            ``calc_binary_stats``. Ignored when ``None`` or empty.
        labels (sequence, optional): labels forwarded to
            ``calc_binary_stats``. Ignored when ``None`` or empty.

    Returns:
        dict or None: the logged stats, or ``None`` on non-logging
            iterations.
    """
    if (cur_iter + 1) % self._cfg.LOG_PERIOD != 0:
        return None
    eta_sec = self.iter_timer.seconds() * (self.max_iter - cur_iter - 1)
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    stats = {
        "_type": "val_iter",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
        "time_diff": self.iter_timer.seconds(),
        "eta": eta,
        "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()),
    }
    if not self._cfg.DATA.MULTI_LABEL:
        stats["top1_err"] = self.mb_top1_err.get_win_median()
        stats["top5_err"] = self.mb_top5_err.get_win_median()
    # Defaults are None (not []) so no mutable default is shared between
    # calls. Callers passing sequences get the original behaviour.
    if (preds is not None and labels is not None
            and len(preds) > 0 and len(labels) > 0):
        calc_binary_stats(preds, labels, stats, self._cfg)
    self.logger.info(stats)
    return stats
def log_iter_stats(self, cur_epoch, cur_iter):
    """
    Log validation progress every ``LOG_PERIOD`` iterations.

    When ``cfg.TASK == "ssl"`` the record type is tagged with an ``_ssl``
    suffix. Top-1/top-5 windowed errors are included only for
    single-label data.

    Args:
        cur_epoch (int): the current epoch index; logged as ``cur_epoch + 1``.
        cur_iter (int): the current iteration index; logged as
            ``cur_iter + 1``.
    """
    if (cur_iter + 1) % self._cfg.LOG_PERIOD:
        return
    remaining = self.max_iter - cur_iter - 1
    eta_sec = self.iter_timer.seconds() * remaining
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    suffix = "_ssl" if self._cfg.TASK == "ssl" else ""
    stats = {
        "_type": "val_iter{}".format(suffix),
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
        "time_diff": self.iter_timer.seconds(),
        "eta": eta,
        "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
    }
    if not self._cfg.DATA.MULTI_LABEL:
        stats.update(
            top1_err=self.mb_top1_err.get_win_median(),
            top5_err=self.mb_top5_err.get_win_median(),
        )
    logging.log_json_stats(stats)
def log_iter_stats(self, cur_epoch, cur_iter):
    """
    Log training progress every ``LOG_PERIOD`` iterations.

    The record includes the total, data and network timers, the windowed
    median loss, the learning rate, and (for single-label data) the windowed
    top-1/top-5 errors.

    Args:
        cur_epoch (int): the current epoch index; logged as ``cur_epoch + 1``.
        cur_iter (int): the current iteration index within the epoch; logged
            as ``cur_iter + 1``.
    """
    if (cur_iter + 1) % self._cfg.LOG_PERIOD:
        return
    # Global iteration count implied by epoch_iters iterations per epoch.
    iters_done = cur_epoch * self.epoch_iters + cur_iter + 1
    eta_sec = self.iter_timer.seconds() * (self.MAX_EPOCH - iters_done)
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    stats = dict(
        _type="train_iter",
        epoch="{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        iter="{}/{}".format(cur_iter + 1, self.epoch_iters),
        dt=self.iter_timer.seconds(),
        dt_data=self.data_timer.seconds(),
        dt_net=self.net_timer.seconds(),
        eta=eta,
        loss=self.loss.get_win_median(),
        lr=self.lr,
        gpu_mem="{:.2f}G".format(misc.gpu_mem_usage()),
    )
    if not self._cfg.DATA.MULTI_LABEL:
        stats["top1_err"] = self.mb_top1_err.get_win_median()
        stats["top5_err"] = self.mb_top5_err.get_win_median()
    logging.log_json_stats(stats)
def log_epoch_stats(self, cur_epoch):
    """
    Emit the end-of-epoch training summary as a JSON stats record.

    Top-1/top-5 error and mean loss over the epoch's samples are included
    only for single-label data.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
    """
    # self.MAX_EPOCH is compared against an iteration count here.
    iters_left = self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters
    eta_sec = self.iter_timer.seconds() * iters_left
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    stats = {
        "_type": "train_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "dt": self.iter_timer.seconds(),
        "dt_data": self.data_timer.seconds(),
        "dt_net": self.net_timer.seconds(),
        "eta": eta,
        "lr": self.lr,
        "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
        "RAM": "{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()),
    }
    if not self._cfg.DATA.MULTI_LABEL:
        n = self.num_samples
        stats.update(
            top1_err=self.num_top1_mis / n,
            top5_err=self.num_top5_mis / n,
            loss=self.loss_total / n,
        )
    logging.log_json_stats(stats)
def log_epoch_stats(self, cur_epoch):
    """
    Log the epoch's validation errors and update the best-so-far errors.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.

    Returns:
        bool: True if this epoch's top-1 error is strictly lower than the
            best top-1 error recorded before this call.
    """
    n = self.num_samples
    top1_err = self.num_top1_mis / n
    top5_err = self.num_top5_mis / n
    # Compare against the previous best *before* it is overwritten.
    is_best_epoch = top1_err < self.min_top1_err
    self.min_top1_err = min(self.min_top1_err, top1_err)
    self.min_top5_err = min(self.min_top5_err, top5_err)
    mem_usage = misc.gpu_mem_usage()
    stats = dict(
        _type="val_epoch",
        epoch="{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        time_diff=self.iter_timer.seconds(),
        top1_err=top1_err,
        top5_err=top5_err,
        min_top1_err=self.min_top1_err,
        min_top5_err=self.min_top5_err,
        mem=int(np.ceil(mem_usage)),
    )
    logging.log_json_stats(stats)
    return is_best_epoch
def log_epoch_stats(self, cur_epoch):
    """
    Log the epoch-level mAP summary when the meter is in "val" or "test" mode.

    In any other mode this is a no-op.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
    """
    if self.mode not in ["val", "test"]:
        return
    # finalize_metrics is expected to populate self.full_map, which is read
    # below (log=False presumably silences its own logging) — confirm.
    self.finalize_metrics(log=False)
    stats = dict(
        _type="{}_epoch".format(self.mode),
        cur_epoch="{}".format(cur_epoch + 1),
        mode=self.mode,
        map=self.full_map,
        gpu_mem="{:.2f}G".format(misc.gpu_mem_usage()),
        RAM="{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()),
    )
    logging.log_json_stats(stats)
def log_epoch_stats(self, cur_epoch):
    """
    Log the epoch's global averages for every meter in ``self.stats``.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
    """
    mem_usage = misc.gpu_mem_usage()
    stats = {
        "_type": "val_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "time_diff": self.iter_timer.seconds(),
        "mem": int(np.ceil(mem_usage)),
    }
    # Meter values are appended after the fixed fields, in self.stats order.
    stats.update(
        (name, meter.get_global_avg()) for name, meter in self.stats.items()
    )
    logging.log_json_stats(stats)
def log_epoch_stats(self, cur_epoch):
    """
    Log verb/noun/action validation accuracies for the epoch and update the
    best-so-far values stored on the meter.

    The record is sent both to TensorBoard (via ``log_to_tensorboard``) and to
    the JSON stats log.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.

    Returns:
        bool: True if this epoch's top-1 accuracy is strictly higher than the
            best top-1 accuracy recorded before this call.
    """
    n = self.num_samples
    acc = {
        "verb_top1_acc": self.num_verb_top1_cor / n,
        "verb_top5_acc": self.num_verb_top5_cor / n,
        "noun_top1_acc": self.num_noun_top1_cor / n,
        "noun_top5_acc": self.num_noun_top5_cor / n,
        "top1_acc": self.num_top1_cor / n,
        "top5_acc": self.num_top5_cor / n,
    }
    self.max_verb_top1_acc = max(self.max_verb_top1_acc, acc["verb_top1_acc"])
    self.max_verb_top5_acc = max(self.max_verb_top5_acc, acc["verb_top5_acc"])
    self.max_noun_top1_acc = max(self.max_noun_top1_acc, acc["noun_top1_acc"])
    self.max_noun_top5_acc = max(self.max_noun_top5_acc, acc["noun_top5_acc"])
    # Compare against the previous best *before* it is overwritten.
    is_best_epoch = acc["top1_acc"] > self.max_top1_acc
    self.max_top1_acc = max(self.max_top1_acc, acc["top1_acc"])
    self.max_top5_acc = max(self.max_top5_acc, acc["top5_acc"])
    mem_usage = misc.gpu_mem_usage()
    stats = {
        "_type": "val_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "time_diff": self.iter_timer.seconds(),
        **acc,
        "max_verb_top1_acc": self.max_verb_top1_acc,
        "max_verb_top5_acc": self.max_verb_top5_acc,
        "max_noun_top1_acc": self.max_noun_top1_acc,
        "max_noun_top5_acc": self.max_noun_top5_acc,
        "max_top1_acc": self.max_top1_acc,
        "max_top5_acc": self.max_top5_acc,
        "mem": int(np.ceil(mem_usage)),
    }
    log_to_tensorboard(self.tb_writer, stats, False)
    logging.log_json_stats(stats)
    return is_best_epoch
def log_epoch_stats(self, cur_epoch):
    """
    Log verb/noun/action validation accuracies for the epoch and update the
    best-so-far values stored on the meter.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.

    Returns:
        tuple: ``(is_best_epoch, top1)``, where:
            - ``is_best_epoch`` is True if this epoch's top-1 accuracy is
              strictly higher than the best recorded before this call;
            - ``top1`` is a dict holding this epoch's ``top1_acc``,
              ``verb_top1_acc`` and ``noun_top1_acc``.
    """
    n = self.num_samples
    acc = {
        "verb_top1_acc": self.num_verb_top1_cor / n,
        "verb_top5_acc": self.num_verb_top5_cor / n,
        "noun_top1_acc": self.num_noun_top1_cor / n,
        "noun_top5_acc": self.num_noun_top5_cor / n,
        "top1_acc": self.num_top1_cor / n,
        "top5_acc": self.num_top5_cor / n,
    }
    self.max_verb_top1_acc = max(self.max_verb_top1_acc, acc["verb_top1_acc"])
    self.max_verb_top5_acc = max(self.max_verb_top5_acc, acc["verb_top5_acc"])
    self.max_noun_top1_acc = max(self.max_noun_top1_acc, acc["noun_top1_acc"])
    self.max_noun_top5_acc = max(self.max_noun_top5_acc, acc["noun_top5_acc"])
    # Compare against the previous best *before* it is overwritten.
    is_best_epoch = acc["top1_acc"] > self.max_top1_acc
    self.max_top1_acc = max(self.max_top1_acc, acc["top1_acc"])
    self.max_top5_acc = max(self.max_top5_acc, acc["top5_acc"])
    stats = {
        "_type": "val_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "time_diff": self.iter_timer.seconds(),
        **acc,
        "max_verb_top1_acc": self.max_verb_top1_acc,
        "max_verb_top5_acc": self.max_verb_top5_acc,
        "max_noun_top1_acc": self.max_noun_top1_acc,
        "max_noun_top5_acc": self.max_noun_top5_acc,
        "max_top1_acc": self.max_top1_acc,
        "max_top5_acc": self.max_top5_acc,
        "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
        "RAM": "{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()),
    }
    logging.log_json_stats(stats)
    top1 = {
        "top1_acc": acc["top1_acc"],
        "verb_top1_acc": acc["verb_top1_acc"],
        "noun_top1_acc": acc["noun_top1_acc"],
    }
    return is_best_epoch, top1
def log_iter_stats(self, cur_epoch, cur_iter):
    """
    Log training progress every ``cfg.LOGS.PERIOD`` iterations.

    On the master process, error/loss values, the ETA (in seconds), the
    1-based epoch and the learning rate are also written to
    ``self.tblogger`` at a global step derived from ``epoch_iters``.

    Args:
        cur_epoch (int): the current epoch index; logged as ``cur_epoch + 1``.
        cur_iter (int): the current iteration index within the epoch; logged
            as ``cur_iter + 1``.
    """
    if (cur_iter + 1) % self._cfg.LOGS.PERIOD:
        return
    eta_sec = self.iter_timer.seconds() * (
        self.MAX_EPOCH - (cur_epoch * self.epoch_iters + cur_iter + 1))
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    stats = dict(
        _type="train_iter",
        epoch="{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        iter="{}/{}".format(cur_iter + 1, self.epoch_iters),
        time_diff=self.iter_timer.seconds(),
        eta=eta,
        loss=self.loss.get_win_median(),
        lr=self.lr,
        gpu_mem="{:.2f} GB".format(misc.gpu_mem_usage()),
    )
    if not self._cfg.DATA.MULTI_LABEL:
        stats["top1_err"] = self.mb_top1_err.get_win_median()
        stats["top5_err"] = self.mb_top5_err.get_win_median()
    logging.log_json_stats(stats)

    if not du.is_master_proc():
        return
    global_step = cur_iter + 1 + self.epoch_iters * cur_epoch
    for key, value in stats.items():
        if "err" in key or "loss" in key:
            self.tblogger.add_scalar("train/{}".format(key), value, global_step)
        elif key == "eta":
            # Log the numeric seconds, not the formatted timedelta string.
            self.tblogger.add_scalar("other/eta", eta_sec, global_step)
        elif key == "epoch":
            self.tblogger.add_scalar("other/epoch", cur_epoch + 1, global_step)
        elif key == "lr":
            self.tblogger.add_scalar("other/lr", value, global_step)
def log_epoch_stats(self, cur_epoch):
    """
    Log verb/noun/action training accuracies and mean losses for the epoch.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
    """
    eta_sec = self.iter_timer.seconds() * (
        self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters
    )
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    n = self.num_samples
    acc = {
        "verb_top1_acc": self.num_verb_top1_cor / n,
        "verb_top5_acc": self.num_verb_top5_cor / n,
        "noun_top1_acc": self.num_noun_top1_cor / n,
        "noun_top5_acc": self.num_noun_top5_cor / n,
        "top1_acc": self.num_top1_cor / n,
        "top5_acc": self.num_top5_cor / n,
    }
    losses = {
        "verb_loss": self.loss_verb_total / n,
        "noun_loss": self.loss_noun_total / n,
        "loss": self.loss_total / n,
    }
    stats = {
        "_type": "train_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "dt": self.iter_timer.seconds(),
        "dt_data": self.data_timer.seconds(),
        "dt_net": self.net_timer.seconds(),
        "eta": eta,
        **acc,
        **losses,
        "lr": self.lr,
        "gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()),
        "RAM": "{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()),
    }
    logging.log_json_stats(stats)
def log_epoch_stats(self, cur_epoch, preds=None, labels=None):
    """
    Log the stats of the current validation epoch.

    For multi-label data the record carries the mAP from ``get_map``.
    Otherwise it carries top-1/top-5 error plus the lowest values seen so
    far, which are also stored back on the meter.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
        preds (sequence, optional): predictions forwarded to
            ``calc_binary_stats``. Ignored when ``None`` or empty.
        labels (sequence, optional): labels forwarded to
            ``calc_binary_stats``. Ignored when ``None`` or empty.

    Returns:
        dict or None: the stats that were logged, or ``None`` when no
            samples were accumulated (``self.num_samples <= 0``).
    """
    if self.num_samples <= 0:
        # Nothing accumulated: warn and bail out before dividing by zero.
        self.logger.warning("ValMeter log_epoch_stats numSample {}".format(
            self.num_samples))
        return None
    stats = {
        "_type": "val_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "time_diff": self.iter_timer.seconds(),
        "gpu_mem": "{:.2f} GB".format(misc.gpu_mem_usage()),
        "RAM": "{:.2f}/{:.2f} GB".format(*misc.cpu_mem_usage()),
    }
    if self._cfg.DATA.MULTI_LABEL:
        stats["map"] = get_map(
            torch.cat(self.all_preds).cpu().numpy(),
            torch.cat(self.all_labels).cpu().numpy(),
        )
    else:
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        self.min_top1_err = min(self.min_top1_err, top1_err)
        self.min_top5_err = min(self.min_top5_err, top5_err)
        stats["top1_err"] = top1_err
        stats["top5_err"] = top5_err
        stats["min_top1_err"] = self.min_top1_err
        stats["min_top5_err"] = self.min_top5_err
    # Defaults are None (not []) so no mutable default is shared between
    # calls. Callers passing sequences get the original behaviour.
    if (preds is not None and labels is not None
            and len(preds) > 0 and len(labels) > 0):
        calc_binary_stats(preds, labels, stats, self._cfg)
    self.logger.info(stats)
    return stats
def log_epoch_stats(self, cur_epoch):
    """
    Log verb/noun/action training accuracies and mean losses for the epoch.

    The record is sent both to TensorBoard (via ``log_to_tensorboard``) and to
    the JSON stats log.

    Args:
        cur_epoch (int): the current epoch index; it is logged as
            ``cur_epoch + 1``.
    """
    eta_sec = self.iter_timer.seconds() * (
        self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters)
    eta = str(datetime.timedelta(seconds=int(eta_sec)))
    mem_usage = misc.gpu_mem_usage()
    n = self.num_samples
    acc = {
        "verb_top1_acc": self.num_verb_top1_cor / n,
        "verb_top5_acc": self.num_verb_top5_cor / n,
        "noun_top1_acc": self.num_noun_top1_cor / n,
        "noun_top5_acc": self.num_noun_top5_cor / n,
        "top1_acc": self.num_top1_cor / n,
        "top5_acc": self.num_top5_cor / n,
    }
    losses = {
        "verb_loss": self.loss_verb_total / n,
        "noun_loss": self.loss_noun_total / n,
        "loss": self.loss_total / n,
    }
    stats = {
        "_type": "train_epoch",
        "epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH),
        "time_diff": self.iter_timer.seconds(),
        "eta": eta,
        **acc,
        **losses,
        "lr": self.lr,
        "mem": int(np.ceil(mem_usage)),
    }
    log_to_tensorboard(self.tb_writer, stats, False)
    logging.log_json_stats(stats)
def test(cfg):
    """
    Run multi-view testing of a trained video model and return a summary.

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py

    Returns:
        str: a one-line summary built from the following values:
            - the output prefix ("lin", "knn" or "");
            - the first test dataset name;
            - the top-1 and top-5 accuracy read from ``test_meter.stats``;
            - the current GPU memory usage;
            - ``cfg.MODEL.NUM_CLASSES``.
    """
    # Environment, RNG seeding and log formatting come first so every later
    # step runs under the same setup.
    du.init_distributed_training(cfg)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    logging.setup_logging(cfg.OUTPUT_DIR)

    logger.info("Test with config:")
    logger.info(cfg)

    model = build_model(cfg)
    prefix = "lin" if cfg.MODEL.DETACH_FINAL_FC else ""
    if du.is_master_proc() and cfg.LOG_MODEL_INFO:
        misc.log_model_info(model, cfg, use_train_input=False)

    use_knn = (
        cfg.TASK == "ssl"
        and cfg.MODEL.MODEL_NAME == "ContrastiveModel"
        and cfg.CONTRASTIVE.KNN_ON
    )
    if use_knn:
        knn_source = loader.construct_loader(cfg, "train")
        prefix = "knn"
        # A wrapped model exposes the underlying model as `.module`.
        target = model.module if hasattr(model, "module") else model
        target.init_knn_labels(knn_source)

    cu.load_test_checkpoint(cfg, model)

    test_loader = loader.construct_loader(cfg, "test")
    logger.info("Testing model for {} iterations".format(len(test_loader)))

    if cfg.DETECTION.ENABLE:
        assert cfg.NUM_GPUS == cfg.TEST.BATCH_SIZE or cfg.NUM_GPUS == 0
        test_meter = AVAMeter(len(test_loader), cfg, mode="test")
    else:
        views_per_video = (
            cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
        )
        dataset_size = test_loader.dataset.num_videos
        # The dataset size must split evenly into groups of views.
        assert dataset_size % views_per_video == 0
        num_classes = (
            cfg.CONTRASTIVE.NUM_CLASSES_DOWNSTREAM
            if cfg.TASK == "ssl"
            else cfg.MODEL.NUM_CLASSES
        )
        test_meter = TestMeter(
            dataset_size // views_per_video,
            views_per_video,
            num_classes,
            len(test_loader),
            cfg.DATA.MULTI_LABEL,
            cfg.DATA.ENSEMBLE_METHOD,
        )

    # TensorBoard writer only on the master process and only when enabled.
    writer = (
        tb.TensorboardWriter(cfg)
        if cfg.TENSORBOARD.ENABLE
        and du.is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS)
        else None
    )

    test_meter = perform_test(test_loader, model, test_meter, cfg, writer)
    if writer is not None:
        writer.close()

    top1 = test_meter.stats["top1_acc"]
    top5 = test_meter.stats["top5_acc"]
    dataset_name = cfg.TEST.DATASET[0]
    # Top-1 accuracy appears twice: inside the "_a..." tag and after
    # "Top1 Acc:".
    result_string = (
        "_a{}{}{} Top1 Acc: {} Top5 Acc: {} MEM: {:.2f} dataset: {}{}".format(
            prefix,
            dataset_name,
            top1,
            top1,
            top5,
            misc.gpu_mem_usage(),
            dataset_name,
            cfg.MODEL.NUM_CLASSES,
        )
    )
    logger.info("testing done: {}".format(result_string))
    return result_string
def train(cfg):
    """
    Train a video model for many epochs on the train set and evaluate it on
    the val set.

    Setup steps:
        - seed the RNGs and optionally initialize multigrid;
        - build the model, optimizer and mixed-precision GradScaler;
        - optionally resume from a checkpoint;
        - build the train/val (and optional precise-BN) loaders, the meters
          and an optional TensorBoard writer.

    Each epoch then does the following:
        - optionally rebuilds the train loader for the next data chunk;
        - optionally rebuilds the trainer on a multigrid long-cycle change;
        - shuffles the data and trains for one epoch;
        - optionally recomputes precise BN statistics;
        - saves a checkpoint and evaluates when scheduled.

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py

    Returns:
        str: a summary string built from ``100 - val_meter.min_top1_err``,
            ``100 - val_meter.min_top5_err`` and the current GPU memory
            usage.
    """
    # Set up environment.
    du.init_distributed_training(cfg)
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    # Init multigrid.
    multigrid = None
    if cfg.MULTIGRID.LONG_CYCLE or cfg.MULTIGRID.SHORT_CYCLE:
        multigrid = MultigridSchedule()
        cfg = multigrid.init_multigrid(cfg)
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, _ = multigrid.update_long_cycle(cfg, cur_epoch=0)
    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Build the video model and print model statistics.
    model = build_model(cfg)
    if du.is_master_proc() and cfg.LOG_MODEL_INFO:
        misc.log_model_info(model, cfg, use_train_input=True)

    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)
    # Create a GradScaler for mixed precision training (it is disabled when
    # cfg.TRAIN.MIXED_PRECISION is False).
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)

    # Load a checkpoint to resume training if applicable.
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint(cfg.OUTPUT_DIR):
        logger.info("Load from last checkpoint.")
        last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR, task=cfg.TASK)
        if last_checkpoint is not None:
            # Resume this task: restore model/optimizer (and scaler when AMP
            # is on) and continue after the saved epoch.
            checkpoint_epoch = cu.load_checkpoint(
                last_checkpoint,
                model,
                cfg.NUM_GPUS > 1,
                optimizer,
                scaler if cfg.TRAIN.MIXED_PRECISION else None,
            )
            start_epoch = checkpoint_epoch + 1
        elif "ssl_eval" in cfg.TASK:
            # No checkpoint for this task yet: for ssl_eval tasks, load the
            # latest checkpoint of the "ssl" task with epoch_reset=True.
            # NOTE(review): get_last_checkpoint may also return None here;
            # confirm load_checkpoint handles that.
            last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR, task="ssl")
            checkpoint_epoch = cu.load_checkpoint(
                last_checkpoint,
                model,
                cfg.NUM_GPUS > 1,
                optimizer,
                scaler if cfg.TRAIN.MIXED_PRECISION else None,
                epoch_reset=True,
                clear_name_pattern=cfg.TRAIN.CHECKPOINT_CLEAR_NAME_PATTERN,
            )
            start_epoch = checkpoint_epoch + 1
        else:
            start_epoch = 0
    elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "":
        logger.info("Load from given checkpoint file.")
        checkpoint_epoch = cu.load_checkpoint(
            cfg.TRAIN.CHECKPOINT_FILE_PATH,
            model,
            cfg.NUM_GPUS > 1,
            optimizer,
            scaler if cfg.TRAIN.MIXED_PRECISION else None,
            inflation=cfg.TRAIN.CHECKPOINT_INFLATE,
            convert_from_caffe2=cfg.TRAIN.CHECKPOINT_TYPE == "caffe2",
            epoch_reset=cfg.TRAIN.CHECKPOINT_EPOCH_RESET,
            clear_name_pattern=cfg.TRAIN.CHECKPOINT_CLEAR_NAME_PATTERN,
        )
        start_epoch = checkpoint_epoch + 1
    else:
        start_epoch = 0

    # Create the video train and val loaders.
    train_loader = loader.construct_loader(cfg, "train")
    val_loader = loader.construct_loader(cfg, "val")
    precise_bn_loader = (loader.construct_loader(
        cfg, "train", is_precise_bn=True)
                         if cfg.BN.USE_PRECISE_STATS else None)

    # NOTE(review): kNN label initialization for ContrastiveModel (SSL) is
    # disabled during training; test() still performs it before testing.

    # Create meters.
    if cfg.DETECTION.ENABLE:
        train_meter = AVAMeter(len(train_loader), cfg, mode="train")
        val_meter = AVAMeter(len(val_loader), cfg, mode="val")
    else:
        # NOTE(review): these meters get a fixed size of 1e6 where the
        # detection branch passes len(train_loader)/len(val_loader). Confirm
        # the iteration counters/ETA derived from it are intended.
        train_meter = TrainMeter(1e6, cfg)
        val_meter = ValMeter(1e6, cfg)

    # set up writer for logging to Tensorboard format.
    if cfg.TENSORBOARD.ENABLE and du.is_master_proc(
            cfg.NUM_GPUS * cfg.NUM_SHARDS):
        writer = tb.TensorboardWriter(cfg)
    else:
        writer = None

    # Perform the training loop.
    logger.info("Start epoch: {}".format(start_epoch + 1))

    epoch_timer = EpochTimer()
    for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
        if cur_epoch > 0 and cfg.DATA.LOADER_CHUNK_SIZE > 0:
            # Chunked loading: from the second epoch on, rebuild the train
            # loader with SKIP_ROWS pointing at chunk (cur_epoch % num_chunks).
            num_chunks = math.ceil(cfg.DATA.LOADER_CHUNK_OVERALL_SIZE /
                                   cfg.DATA.LOADER_CHUNK_SIZE)
            skip_rows = (cur_epoch) % num_chunks * cfg.DATA.LOADER_CHUNK_SIZE
            logger.info(
                f"=================+++ num_chunks {num_chunks} skip_rows {skip_rows}"
            )
            cfg.DATA.SKIP_ROWS = skip_rows
            logger.info(f"|===========| skip_rows {skip_rows}")
            train_loader = loader.construct_loader(cfg, "train")
            # NOTE(review): shuffle_dataset is called again with the same
            # epoch below, so it runs twice on epochs that rebuild here.
            loader.shuffle_dataset(train_loader, cur_epoch)
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, changed = multigrid.update_long_cycle(cfg, cur_epoch)
            if changed:
                # Rebuild model, optimizer, loaders and meters for the new
                # long-cycle settings.
                # NOTE(review): the GradScaler is neither rebuilt nor passed
                # to load_checkpoint here.
                (
                    model,
                    optimizer,
                    train_loader,
                    val_loader,
                    precise_bn_loader,
                    train_meter,
                    val_meter,
                ) = build_trainer(cfg)

                # Load checkpoint.
                if cu.has_checkpoint(cfg.OUTPUT_DIR):
                    last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR,
                                                             task=cfg.TASK)
                    # The latest checkpoint must be the one for this epoch.
                    assert "{:05d}.pyth".format(cur_epoch) in last_checkpoint
                else:
                    last_checkpoint = cfg.TRAIN.CHECKPOINT_FILE_PATH
                logger.info("Load from {}".format(last_checkpoint))
                cu.load_checkpoint(last_checkpoint, model, cfg.NUM_GPUS > 1,
                                   optimizer)

        # Shuffle the dataset.
        loader.shuffle_dataset(train_loader, cur_epoch)
        if hasattr(train_loader.dataset, "_set_epoch_num"):
            train_loader.dataset._set_epoch_num(cur_epoch)

        # Train for one epoch.
        epoch_timer.epoch_tic()
        train_epoch(
            train_loader,
            model,
            optimizer,
            scaler,
            train_meter,
            cur_epoch,
            cfg,
            writer,
        )
        epoch_timer.epoch_toc()
        logger.info(
            f"Epoch {cur_epoch} takes {epoch_timer.last_epoch_time():.2f}s. Epochs "
            f"from {start_epoch} to {cur_epoch} take "
            f"{epoch_timer.avg_epoch_time():.2f}s in average and "
            f"{epoch_timer.median_epoch_time():.2f}s in median.")
        logger.info(
            f"For epoch {cur_epoch}, each iteraction takes "
            f"{epoch_timer.last_epoch_time()/len(train_loader):.2f}s in average. "
            f"From epoch {start_epoch} to {cur_epoch}, each iteraction takes "
            f"{epoch_timer.avg_epoch_time()/len(train_loader):.2f}s in average."
        )

        # Always checkpoint the final epoch, in addition to the schedule.
        is_checkp_epoch = (cu.is_checkpoint_epoch(
            cfg,
            cur_epoch,
            None if multigrid is None else multigrid.schedule,
        ) or cur_epoch == cfg.SOLVER.MAX_EPOCH - 1)
        is_eval_epoch = misc.is_eval_epoch(
            cfg, cur_epoch, None if multigrid is None else multigrid.schedule)

        # Compute precise BN stats.
        if ((is_checkp_epoch or is_eval_epoch) and cfg.BN.USE_PRECISE_STATS
                and len(get_bn_modules(model)) > 0):
            calculate_and_update_precise_bn(
                precise_bn_loader,
                model,
                min(cfg.BN.NUM_BATCHES_PRECISE, len(precise_bn_loader)),
                cfg.NUM_GPUS > 0,
            )
        _ = misc.aggregate_sub_bn_stats(model)

        # Save a checkpoint.
        if is_checkp_epoch:
            cu.save_checkpoint(
                cfg.OUTPUT_DIR,
                model,
                optimizer,
                cur_epoch,
                cfg,
                scaler if cfg.TRAIN.MIXED_PRECISION else None,
            )
        # Evaluate the model on validation set.
        if is_eval_epoch:
            eval_epoch(
                val_loader,
                model,
                val_meter,
                cur_epoch,
                cfg,
                train_loader,
                writer,
            )
    if writer is not None:
        writer.close()
    # NOTE(review): reads val_meter.min_top1_err / min_top5_err; confirm the
    # AVAMeter used when cfg.DETECTION.ENABLE provides these attributes.
    result_string = "Top1 Acc: {:.2f} Top5 Acc: {:.2f} MEM: {:.2f}" "".format(
        100 - val_meter.min_top1_err,
        100 - val_meter.min_top5_err,
        misc.gpu_mem_usage(),
    )
    logger.info("training done: {}".format(result_string))
    return result_string