# NOTE: the import paths below are assumptions based on this codebase's
# typical layout; adjust them to the actual module locations.
# (create_wrapper and get_test_name are defined elsewhere in this file.)
import logging
import os
import time

import numpy as np
from caffe2.python import workspace
from tensorboardX import SummaryWriter

from core.config import config as cfg
from core.config import print_cfg
from models import model_builder_video
from utils.lfb import get_lfb
from utils.timer import Timer
import utils.bn_helper as bn_helper
import utils.checkpoints as checkpoints
import utils.metrics as metrics
import utils.misc as misc

logger = logging.getLogger(__name__)


def test_net(lfb=None):
    """
    Test a model.
    For AVA, we can test on either a center crop or multiple crops.
    For EPIC-Kitchens, we test on a center crop for simplicity.
    For Charades, we follow prior work (e.g. non-local net) and perform
    3-spatial-shifts * 10-clip testing.
    """
    if cfg.DATASET in ['ava', 'avabox']:
        for threshold in cfg.AVA.DETECTION_SCORE_THRESH_EVAL:
            cfg.AVA.DETECTION_SCORE_THRESH = threshold

            if cfg.AVA.TEST_MULTI_CROP:
                cfg.LFB.WRITE_LFB = False
                cfg.LFB.LOAD_LFB = False

                for flip in [False, True]:
                    cfg.AVA.FORCE_TEST_FLIP = flip

                    for scale in cfg.AVA.TEST_MULTI_CROP_SCALES:
                        cfg.TEST.SCALE = scale
                        cfg.TEST.CROP_SIZE = min(256, scale)

                        # Recompute the LFB for each (flip, scale) setting;
                        # reuse it across the 3 spatial shifts.
                        lfb = None
                        for shift in range(3):
                            out_name = 'detections_%s.csv' % \
                                get_test_name(shift)
                            if os.path.isfile(out_name):
                                logger.info("%s already exists." % out_name)
                                continue

                            if cfg.LFB.ENABLED and lfb is None:
                                lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE,
                                              is_train=False)
                            test_one_crop(lfb=lfb, suffix='_final_test',
                                          shift=shift)

                metrics.combine_ava_multi_crops()
            else:
                test_one_crop(lfb=lfb, suffix='_final_test')
    else:
        if cfg.DATASET == 'charades':
            cfg.CHARADES.NUM_TEST_CLIPS = \
                cfg.CHARADES.NUM_TEST_CLIPS_FINAL_EVAL
        test_one_crop(lfb=lfb, suffix='_final_test')
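
# The multi-crop branch above runs one test pass per (flip, scale, shift)
# combination and then merges the resulting detection files. Below is a
# minimal standalone sketch of that grid, assuming the non-local-net style
# 3-spatial-shift protocol; `_enumerate_test_crops` is a hypothetical helper
# for illustration only, not part of this codebase.
def _enumerate_test_crops(scales):
    """Yield (flip, scale, crop_size, shift) in the order test_net runs them.

    E.g. with scales = [224, 256, 320], this yields
    2 flips * 3 scales * 3 shifts = 18 test configurations.
    """
    for flip in [False, True]:
        for scale in scales:
            for shift in range(3):
                yield flip, scale, min(256, scale), shift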
def train(opts):
    """Train a model."""
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])

    # Generate seed.
    misc.generate_random_seed(opts)

    # Create checkpoint dir.
    checkpoint_dir = checkpoints.create_and_get_checkpoint_directory()
    logger.info('Checkpoint directory created: {}'.format(checkpoint_dir))

    # Create TensorBoard logger.
    tb_writer = SummaryWriter(os.path.join(cfg.CHECKPOINT.DIR, 'tb'))

    # Set training-time-specific configurations.
    cfg.AVA.FULL_EVAL = cfg.AVA.FULL_EVAL_DURING_TRAINING
    cfg.AVA.DETECTION_SCORE_THRESH = cfg.AVA.DETECTION_SCORE_THRESH_TRAIN
    cfg.CHARADES.NUM_TEST_CLIPS = cfg.CHARADES.NUM_TEST_CLIPS_DURING_TRAINING

    # Infer the long-term feature banks for train/test if enabled.
    test_lfb, train_lfb = None, None
    if cfg.LFB.ENABLED:
        test_lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=False)
        train_lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=True)

    # Build test_model.
    # We build test_model first, so that we don't overwrite its init.
    test_model, test_timer, test_meter = create_wrapper(
        is_train=False,
        lfb=test_lfb,
    )
    total_test_iters = misc.get_total_test_iters(test_model)
    logger.info('Test iters: {}'.format(total_test_iters))

    # Build train_model.
    train_model, train_timer, train_meter = create_wrapper(
        is_train=True,
        lfb=train_lfb,
    )

    # Build BN auxiliary model (used for "precise BN" statistics).
    if cfg.TRAIN.COMPUTE_PRECISE_BN:
        bn_aux = bn_helper.BatchNormHelper()
        bn_aux.create_bn_aux_model(node_id=opts.node_id)

    # Load checkpoint or pre-trained weights.
    # See checkpoints.load_model_from_params_file for more details.
    start_model_iter = 0
    if cfg.CHECKPOINT.RESUME or cfg.TRAIN.PARAMS_FILE:
        start_model_iter = checkpoints.load_model_from_params_file(train_model)

    logger.info("------------- Training model... -------------")
    train_meter.reset()
    last_checkpoint = checkpoints.get_checkpoint_resume_file()

    for curr_iter in range(start_model_iter, cfg.SOLVER.MAX_ITER):
        train_model.UpdateWorkspaceLr(curr_iter)

        train_timer.tic()
        # SGD step.
        workspace.RunNet(train_model.net.Proto().name)
        train_timer.toc()

        if curr_iter == start_model_iter:
            misc.print_net(train_model)
            os.system('nvidia-smi')
            misc.show_flops_params(train_model)

        misc.check_nan_losses()

        # Checkpoint.
        if (curr_iter + 1) % cfg.CHECKPOINT.CHECKPOINT_PERIOD == 0 \
                or curr_iter + 1 == cfg.SOLVER.MAX_ITER:
            if cfg.TRAIN.COMPUTE_PRECISE_BN:
                bn_aux.compute_and_update_bn_stats(curr_iter)
            last_checkpoint = os.path.join(
                checkpoint_dir, 'c2_model_iter{}.pkl'.format(curr_iter + 1))
            checkpoints.save_model_params(model=train_model,
                                          params_file=last_checkpoint,
                                          model_iter=curr_iter)

        train_meter.calculate_and_log_all_metrics_train(
            curr_iter, train_timer, suffix='_train', tb_writer=tb_writer)

        # Evaluation.
        if (curr_iter + 1) % cfg.TRAIN.EVAL_PERIOD == 0:
            if cfg.TRAIN.COMPUTE_PRECISE_BN:
                bn_aux.compute_and_update_bn_stats(curr_iter)

            test_meter.reset()
            logger.info("=> Testing model")
            for test_iter in range(0, total_test_iters):
                test_timer.tic()
                workspace.RunNet(test_model.net.Proto().name)
                test_timer.toc()
                test_meter.calculate_and_log_all_metrics_test(
                    test_iter, test_timer, total_test_iters, suffix='_test')

            test_meter.finalize_metrics(name='iter%d' % (curr_iter + 1))
            test_meter.compute_and_log_best()
            test_meter.log_final_metrics(curr_iter)
            tb_writer.add_scalar('Test/mini_MAP',
                                 test_meter.full_map, curr_iter + 1)

            # Finalize and reset train_meter after test.
            train_meter.finalize_metrics(is_train=True)
            json_stats = metrics.get_json_stats_dict(
                train_meter, test_meter, curr_iter)
            misc.log_json_stats(json_stats)
            train_meter.reset()

    train_model.shutdown_data_loader()
    test_model.shutdown_data_loader()

    if cfg.TRAIN.TEST_AFTER_TRAIN:
        cfg.TEST.PARAMS_FILE = last_checkpoint
        test_net(test_lfb)
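
# train() delegates the learning-rate schedule to
# train_model.UpdateWorkspaceLr(curr_iter). Below is a minimal sketch of the
# stepwise decay such schedules typically implement, assuming a base LR that
# is multiplied by gamma at fixed iteration boundaries; `_step_lr_at_iter`
# and its default values are illustrative, not this repo's API.
def _step_lr_at_iter(curr_iter, base_lr=0.01, steps=(0, 300000, 400000),
                     gamma=0.1):
    """Return base_lr decayed by gamma once per boundary already passed."""
    num_decays = sum(1 for s in steps if 0 < s <= curr_iter)
    return base_lr * (gamma ** num_decays)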
def test_one_crop(lfb=None, suffix='', shift=None):
    """Test one crop."""
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    np.random.seed(cfg.RNG_SEED)

    cfg.AVA.FULL_EVAL = True

    if lfb is None and cfg.LFB.ENABLED:
        print_cfg()
        lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=False)

    print_cfg()

    workspace.ResetWorkspace()
    logger.info("Done ResetWorkspace...")

    timer = Timer()
    logger.warning('Testing started...')  # for monitoring cluster jobs

    if shift is None:
        shift = cfg.TEST.CROP_SHIFT

    test_model = model_builder_video.ModelBuilder(
        train=False,
        use_cudnn=True,
        cudnn_exhaustive_search=True,
        split=cfg.TEST.DATA_TYPE)

    test_model.build_model(lfb=lfb, suffix=suffix, shift=shift)

    if cfg.PROF_DAG:
        test_model.net.Proto().type = 'prof_dag'
    else:
        test_model.net.Proto().type = 'dag'
    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)

    misc.save_net_proto(test_model.net)
    misc.save_net_proto(test_model.param_init_net)

    total_test_net_iters = misc.get_total_test_iters(test_model)

    test_model.start_data_loader()
    test_meter = metrics.MetricsCalculator(
        model=test_model,
        split=cfg.TEST.DATA_TYPE,
        video_idx_to_name=test_model.input_db._video_idx_to_name,
        total_num_boxes=(test_model.input_db._num_boxes_used
                         if cfg.DATASET in ['ava', 'avabox'] else None))

    if cfg.TEST.PARAMS_FILE:
        checkpoints.load_model_from_params_file_for_test(
            test_model, cfg.TEST.PARAMS_FILE)
    else:
        raise Exception('No params file specified for testing model.')

    begin_time = time.time()
    for test_iter in range(total_test_net_iters):
        timer.tic()
        workspace.RunNet(test_model.net.Proto().name)
        timer.toc()

        if test_iter == 0:
            misc.print_net(test_model)
            os.system('nvidia-smi')
            misc.show_flops_params(test_model)

        test_meter.calculate_and_log_all_metrics_test(
            test_iter, timer, total_test_net_iters, suffix)

    logger.info('Total testing time: {:.2f} sec'.format(
        time.time() - begin_time))

    test_meter.finalize_metrics(name=get_test_name(shift))
    test_meter.log_final_metrics(test_iter, total_test_net_iters)
    test_model.shutdown_data_loader()
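
# test_one_crop() receives `shift` in {0, 1, 2}, matching the 3 spatial
# shifts mentioned in the test_net docstring. Below is a minimal sketch of
# how such a shift is commonly mapped to a crop offset along the longer
# image side (left/center/right or top/center/bottom); `_shift_to_offset`
# is a hypothetical illustration, the real mapping lives in the data loader.
def _shift_to_offset(shift, image_len, crop_len):
    """Map shift 0/1/2 to the start offset of a min/center/max crop."""
    assert shift in (0, 1, 2) and image_len >= crop_len
    if shift == 0:
        return 0
    if shift == 1:
        return (image_len - crop_len) // 2
    return image_len - crop_len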