def fast_test_p_s(config, base, loaders, current_step, if_test_forget=True):
    """Evaluate Rank-1 and mAP on every test dataset with the fast evaluator.

    For each ``(dataset_name, dataset_loaders)`` entry of
    ``loaders.test_loader_dict`` the loader at index 0 is stored as the query
    set and the loader at index 1 as the gallery set. Features come from
    ``base.model_dict['tasknet']``; when ``config.if_test_metagraph`` is set,
    the output of ``base.model_dict['metagraph']`` is added to them and the
    fused features are evaluated as well.

    Args:
        config: reads ``test_metric`` and ``if_test_metagraph``.
        base: provides ``set_all_model_eval()``, ``device`` and ``model_dict``.
        loaders: provides ``test_loader_dict``.
        current_step: forwarded to the tasknet model.
        if_test_forget: kept for interface compatibility; not used here.

    Returns:
        ``(results_dict, results_str)``. ``results_dict`` maps
        ``'{dataset_name}_tasknet_mAP'`` / ``'{dataset_name}_tasknet_Rank1'``
        (and the ``_fuse_`` variants when the metagraph is tested) to the
        metric scaled by 100. ``results_str`` holds one ``'\\n{key}: {value}'``
        entry per result.
    """
    base.set_all_model_eval()
    print('****** start perform fast testing! ******')
    test_metagraph = config.if_test_metagraph

    def _new_meters():
        # One accumulator per quantity that is actually evaluated.
        return {
            'features': CatMeter(),
            'fuse': CatMeter(),
            'pids': CatMeter(),
            'cids': CatMeter(),
        }

    def _rank1_and_map(query, gallery, feature_key):
        # Distance matrix between query and gallery, then CMC / mAP.
        distance_matrix = compute_distance_matrix(
            query[feature_key].get_val(),
            gallery[feature_key].get_val(),
            config.test_metric)
        distance_matrix = distance_matrix.data.cpu().numpy()
        cmc, mean_ap = fast_evaluate_rank(distance_matrix,
                                          query['pids'].get_val_numpy(),
                                          gallery['pids'].get_val_numpy(),
                                          query['cids'].get_val_numpy(),
                                          gallery['cids'].get_val_numpy(),
                                          max_rank=50,
                                          use_metric_cuhk03=False,
                                          use_cython=True)
        return cmc[0] * 100, mean_ap * 100

    results_dict = {}
    for dataset_name, temp_loaders in loaders.test_loader_dict.items():
        query, gallery = _new_meters(), _new_meters()
        print(time_now(), f' {dataset_name} feature start ')
        with torch.no_grad():
            for loader_id, loader in enumerate(temp_loaders):
                for data in loader:
                    # compute features
                    images, pids, cids = data[0:3]
                    images = images.to(base.device)
                    features, _ = base.model_dict['tasknet'](images,
                                                             current_step)
                    if test_metagraph:
                        features_metagraph, _ = base.model_dict['metagraph'](
                            features)
                        features_fuse = features + features_metagraph

                    # loader 0 -> query, loader 1 -> gallery, others ignored
                    if loader_id == 0:
                        side = query
                    elif loader_id == 1:
                        side = gallery
                    else:
                        continue
                    side['features'].update(features.data)
                    if test_metagraph:
                        side['fuse'].update(features_fuse.data)
                    side['pids'].update(pids)
                    side['cids'].update(cids)
        print(time_now(), f' {dataset_name} feature done')

        rank1, mean_ap = _rank1_and_map(query, gallery, 'features')
        results_dict[f'{dataset_name}_tasknet_mAP'] = mean_ap
        results_dict[f'{dataset_name}_tasknet_Rank1'] = rank1
        if test_metagraph:
            rank1, mean_ap = _rank1_and_map(query, gallery, 'fuse')
            results_dict[f'{dataset_name}_fuse_mAP'] = mean_ap
            results_dict[f'{dataset_name}_fuse_Rank1'] = rank1

    results_str = ''.join(f'\n{criterion}: {value}'
                          for criterion, value in results_dict.items())
    return results_dict, results_str
def test_continual_neck(config, base, loaders, current_step):
    """Run the full evaluation (mAP, CMC, precision-recall) on one dataset.

    The dataset is chosen by ``config.test_dataset`` (``'market'``, ``'duke'``
    or ``'mix'``). Query and gallery features are extracted with
    ``base.model_dict['tasknet']`` and passed to ``ReIDEvaluator`` and
    ``PrecisionRecall``.

    Args:
        config: reads ``test_dataset``, ``test_metric`` and ``test_mode``.
        base: provides ``set_all_model_eval()``, ``device`` and ``model_dict``.
        loaders: provides the ``<dataset>_query_loader`` /
            ``<dataset>_gallery_loader`` attributes.
        current_step: forwarded to the tasknet model.

    Returns:
        ``(mAP, CMC[0:150], pres, recalls, thresholds)``.

    Raises:
        AssertionError: if ``config.test_dataset`` is not a supported name.
            Raised explicitly so the check survives ``python -O``.
    """
    base.set_all_model_eval()
    print('****** start perform full testing! ******')

    # meters
    query_features_meter, query_pids_meter, query_cids_meter = (
        CatMeter(), CatMeter(), CatMeter())
    gallery_features_meter, gallery_pids_meter, gallery_cids_meter = (
        CatMeter(), CatMeter(), CatMeter())

    # init dataset: index 0 is the query loader, index 1 the gallery loader
    if config.test_dataset == 'market':
        eval_loaders = [loaders.market_query_loader,
                        loaders.market_gallery_loader]
    elif config.test_dataset == 'duke':
        eval_loaders = [loaders.duke_query_loader,
                        loaders.duke_gallery_loader]
    elif config.test_dataset == 'mix':
        eval_loaders = [loaders.mix_query_loader, loaders.mix_gallery_loader]
    else:
        raise AssertionError(
            'test dataset error, expect mix/market/duke/, given {}'.format(
                config.test_dataset))

    print(time_now(), 'feature start')

    # compute query and gallery features
    with torch.no_grad():
        for loader_id, loader in enumerate(eval_loaders):
            for data in loader:
                images, pids, cids, _ = data
                images = images.to(base.device)
                features, _ = base.model_dict['tasknet'](images, current_step)
                if loader_id == 0:
                    # save as query features
                    query_features_meter.update(features.data)
                    query_pids_meter.update(pids)
                    query_cids_meter.update(cids)
                elif loader_id == 1:
                    # save as gallery features
                    gallery_features_meter.update(features.data)
                    gallery_pids_meter.update(pids)
                    gallery_cids_meter.update(cids)

    print(time_now(), 'feature done')

    # Convert once; both evaluators consume the same arrays.
    query_features = query_features_meter.get_val_numpy()
    gallery_features = gallery_features_meter.get_val_numpy()
    query_cids = query_cids_meter.get_val_numpy()
    query_pids = query_pids_meter.get_val_numpy()
    gallery_cids = gallery_cids_meter.get_val_numpy()
    gallery_pids = gallery_pids_meter.get_val_numpy()

    # compute mAP and rank@k
    mAP, CMC = ReIDEvaluator(dist=config.test_metric,
                             mode=config.test_mode).evaluate(
                                 query_features, query_cids, query_pids,
                                 gallery_features, gallery_cids, gallery_pids)

    # compute precision-recall curve
    thresholds = np.linspace(1.0, 0.0, num=101)
    pres, recalls, thresholds = PrecisionRecall(
        dist=config.test_metric, mode=config.test_mode).evaluate(
            thresholds, query_features, query_cids, query_pids,
            gallery_features, gallery_cids, gallery_pids)

    return mAP, CMC[0:150], pres, recalls, thresholds
def main(config):
    """Entry point: train, test or visualize according to ``config.mode``.

    * ``'train'``: runs the continual loop over
      ``range(start_train_step, loaders.total_step)``. Each step trains for
      ``config.total_train_epochs`` epochs (step 0) or
      ``config.total_continual_train_epochs`` epochs (later steps), saves the
      model every epoch, and calls ``fast_test_p_s`` periodically.
    * ``'test'``: restores ``config.resume_test_model``, runs
      ``test_continual_neck`` and plots the precision-recall curve.
    * ``'visualize'``: restores ``config.resume_visualize_model`` and calls
      ``visualize``.

    Args:
        config: run configuration; the attributes read are the ones
            referenced in the body below.

    Raises:
        AssertionError: if ``config.mode`` is not one of the three modes.
    """
    # init loaders and base
    loaders = IncrementalReIDLoaders(config)
    base = Base_metagraph_p_s(config, loaders)

    # init logger
    logger = Logger(os.path.join(base.output_dirs_dict['logs'], 'log.txt'))
    logger(config)

    use_visdom = config.visualize_train_by_visdom
    if use_visdom:
        port = 8097
        visdom_dict = {
            key: VisdomFeatureMapsLogger('image',
                                         pad_value=1,
                                         nrow=8,
                                         port=port,
                                         env=config.running_time,
                                         opts={'title': title})
            for key, title in (
                ('feature_maps_fake', 'featuremaps fake'),
                ('feature_maps_true', 'featuremaps true'),
                ('feature_maps', 'featuremaps'),
            )
        }

    if config.mode not in ('train', 'test', 'visualize'):
        # Explicit raise so the check is not stripped by ``python -O``.
        raise AssertionError(
            'unknown mode {!r}, expect train/test/visualize'.format(
                config.mode))

    if config.mode == 'train':  # train mode
        # Start from scratch unless resuming; without these defaults the loop
        # below raised NameError when auto-resume was disabled.
        start_train_step, start_train_epoch = 0, 0
        # automatically resume model from the latest one
        if config.auto_resume_training_from_lastest_steps:
            start_train_step, start_train_epoch = base.resume_last_model()

        # continual loop
        for current_step in range(start_train_step, loaders.total_step):
            if current_step > 0:
                current_total_train_epochs = config.total_continual_train_epochs
                logger(f'save_and_frozen old model in {current_step}')
                old_model = base.copy_model_and_frozen(model_name='tasknet')
                old_graph_model = base.copy_model_and_frozen(
                    model_name='metagraph')
            else:
                current_total_train_epochs = config.total_train_epochs
                old_model = None
                old_graph_model = None

            for current_epoch in range(start_train_epoch,
                                       current_total_train_epochs):
                visdom_result_dict = {}
                # save model
                base.save_model(current_step, current_epoch)
                # train
                str_lr, dict_lr = base.get_current_learning_rate()
                logger(str_lr)
                # NOTE(review): once current_epoch >= config.epoch_start_joint
                # no training call is made and `results` keeps the previous
                # epoch's value (or is undefined on the first iteration).
                # The original indentation was lost; confirm the intended
                # nesting of this branch.
                if current_epoch < config.epoch_start_joint:
                    results = train_p_s_an_epoch(
                        config,
                        base,
                        loaders,
                        current_step,
                        old_model,
                        old_graph_model,
                        current_epoch,
                        output_featuremaps=config.output_featuremaps)

                if config.output_featuremaps:
                    results_dict, results_str, heatmaps = results
                    # Heatmaps go to visdom only; skip when visdom is off
                    # (visdom_dict does not exist in that case).
                    if use_visdom:
                        if config.output_featuremaps_from_fixed:
                            heatmaps_true, heatmaps_fake = (
                                output_featuremaps_from_fixed(
                                    base, current_epoch))
                            visdom_dict['feature_maps_fake'].images(
                                heatmaps_fake)
                            visdom_dict['feature_maps_true'].images(
                                heatmaps_true)
                        else:
                            visdom_dict['feature_maps'].images(heatmaps)
                else:
                    results_dict, results_str = results
                logger('Time: {}; Step: {}; Epoch: {}; {}'.format(
                    time_now(), current_step, current_epoch, results_str))

                if (config.test_frequency > 0
                        and current_epoch % config.test_frequency == 0):
                    rank_map_dict, rank_map_str = fast_test_p_s(
                        config,
                        base,
                        loaders,
                        current_step,
                        if_test_forget=config.if_test_forget)
                    logger(
                        f'Time: {time_now()}; Test Dataset: {config.test_dataset}: {rank_map_str}'
                    )
                    visdom_result_dict.update(rank_map_dict)

                if current_epoch == config.total_train_epochs - 1:
                    # test
                    rank_map_dict, rank_map_str = fast_test_p_s(
                        config,
                        base,
                        loaders,
                        current_step,
                        if_test_forget=config.if_test_forget)
                    logger(
                        f'Time: {time_now()}; Step: {current_step}; Epoch: {current_epoch} Test Dataset: {config.test_dataset}, {rank_map_str}'
                    )
                    print(f'Current step {current_step} is finished.')
                    visdom_result_dict.update(rank_map_dict)

                if use_visdom:
                    visdom_result_dict.update(results_dict)
                    visdom_result_dict.update(dict_lr)
                    if current_step > 0:
                        global_current_epoch = current_epoch + (
                            current_step - 1
                        ) * current_total_train_epochs + config.total_train_epochs
                    else:
                        global_current_epoch = current_epoch
                    for name, value in visdom_result_dict.items():
                        # Create the line plot lazily the first time a metric
                        # name shows up.
                        if name not in visdom_dict:
                            visdom_dict[name] = VisdomPlotLogger(
                                'line',
                                port=port,
                                env=config.running_time,
                                opts={'title': f'train {name}'})
                        visdom_dict[name].log(global_current_epoch,
                                              value,
                                              name=str(current_step))

            # The resume offset applies only to the step that was resumed;
            # every later step starts at epoch 0. Resetting here (rather than
            # only when the end-of-step test fires) covers steps whose epoch
            # range never reaches config.total_train_epochs - 1.
            start_train_epoch = 0
            if current_step > 0:
                # Drop the frozen copies before the next step makes new ones.
                del old_model, old_graph_model

    elif config.mode == 'test':  # test mode
        base.resume_from_model(config.resume_test_model)
        mAP, CMC, pres, recalls, thresholds = test_continual_neck(
            config, base, loaders, 0)
        logger('Time: {}; Test Dataset: {}, \nmAP: {} \nRank: {}'.format(
            time_now(), config.test_dataset, mAP, CMC))
        # Arguments now line up with the placeholders; mAP and CMC used to be
        # printed under the precision/recall labels.
        logger(
            'Time: {}; Test Dataset: {}, \nprecision: {} \nrecall: {}\nthresholds: {}'
            .format(time_now(), config.test_dataset, pres, recalls,
                    thresholds))
        plot_prerecall_curve(config, pres, recalls, thresholds, mAP, CMC,
                             'none')

    elif config.mode == 'visualize':  # visualization mode
        base.resume_from_model(config.resume_visualize_model)
        visualize(config, base, loaders)
def fast_test_continual_neck(config,
                             base,
                             loaders,
                             current_step,
                             if_test_forget=True):
    """Evaluate Rank-1 and mAP on one dataset with the fast rank evaluator.

    The dataset is chosen by ``config.test_dataset`` (``'market'``, ``'duke'``
    or ``'mix'``). Features come from ``base.model_dict['tasknet']``; when
    ``config.if_test_metagraph`` is set, the metagraph model's features are
    evaluated too. For ``'mix'`` with ``if_test_forget`` true, a second pass
    runs over the mix validation loaders. Other datasets have no validation
    loaders here, so that pass is skipped for them; it previously failed with
    ``NameError``.

    Args:
        config: reads ``test_dataset``, ``test_metric``, ``if_test_metagraph``
            and ``use_local_label4validation``.
        base: provides ``set_all_model_eval()``, ``device`` and ``model_dict``.
        loaders: provides the query/gallery (and mix validation) loaders.
        current_step: forwarded to both models.
        if_test_forget: whether to also evaluate on the validation loaders.

    Returns:
        ``(results_dict, results_str)``. Keys are ``tasknet_mAP``,
        ``tasknet_Rank1``, optionally ``metagraph_*`` and the
        ``*_validation_*`` variants; values are scaled by 100.
        ``results_str`` holds one ``'\\n{key}: {value}'`` entry per result.

    Raises:
        AssertionError: if ``config.test_dataset`` is not a supported name.
            Raised explicitly so the check survives ``python -O``.
    """
    base.set_all_model_eval()
    print('****** start perform fast testing! ******')
    test_metagraph = config.if_test_metagraph

    # init dataset: index 0 is the query loader, index 1 the gallery loader
    loaders_validation = None
    if config.test_dataset == 'market':
        eval_loaders = [loaders.market_query_loader,
                        loaders.market_gallery_loader]
    elif config.test_dataset == 'duke':
        eval_loaders = [loaders.duke_query_loader,
                        loaders.duke_gallery_loader]
    elif config.test_dataset == 'mix':
        if if_test_forget:
            loaders_validation = [
                loaders.mix_validation_query_loader,
                loaders.mix_validation_gallery_loader
            ]
        eval_loaders = [loaders.mix_query_loader, loaders.mix_gallery_loader]
    else:
        raise AssertionError(
            'test dataset error, expect mix/market/duke/, given {}'.format(
                config.test_dataset))

    def _new_meters():
        return {
            'features': CatMeter(),
            'metagraph': CatMeter(),
            'pids': CatMeter(),
            'cids': CatMeter(),
        }

    def _extract(pair_of_loaders, is_validation):
        # Run both models over the query/gallery loaders and collect outputs.
        query, gallery = _new_meters(), _new_meters()
        with torch.no_grad():
            for loader_id, loader in enumerate(pair_of_loaders):
                for data in loader:
                    if is_validation:
                        images, cids = data[0], data[2]
                        if config.use_local_label4validation:
                            pids = data[3]
                        else:
                            pids = data[1]
                    else:
                        images, pids, cids, _ = data
                    images = images.to(base.device)
                    features, featuremaps = base.model_dict['tasknet'](
                        images, current_step)
                    if test_metagraph:
                        features_metagraph, _ = base.model_dict['metagraph'](
                            featuremaps=featuremaps,
                            label=None,
                            current_step=current_step)

                    # loader 0 -> query, loader 1 -> gallery, others ignored
                    if loader_id == 0:
                        side = query
                    elif loader_id == 1:
                        side = gallery
                    else:
                        continue
                    side['features'].update(features.data)
                    if test_metagraph:
                        side['metagraph'].update(features_metagraph.data)
                    side['pids'].update(pids)
                    side['cids'].update(cids)
        return query, gallery

    def _rank1_and_map(query, gallery, feature_key):
        # Distance matrix between query and gallery, then CMC / mAP.
        distance_matrix = compute_distance_matrix(
            query[feature_key].get_val(),
            gallery[feature_key].get_val(),
            config.test_metric)
        distance_matrix = distance_matrix.data.cpu().numpy()
        cmc, mean_ap = fast_evaluate_rank(distance_matrix,
                                          query['pids'].get_val_numpy(),
                                          gallery['pids'].get_val_numpy(),
                                          query['cids'].get_val_numpy(),
                                          gallery['cids'].get_val_numpy(),
                                          max_rank=50,
                                          use_metric_cuhk03=False,
                                          use_cython=True)
        return cmc[0] * 100, mean_ap * 100

    results_dict = {}

    print(time_now(), ' feature start ')
    query, gallery = _extract(eval_loaders, is_validation=False)
    print(time_now(), 'feature done')

    rank1, mean_ap = _rank1_and_map(query, gallery, 'features')
    results_dict['tasknet_mAP'] = mean_ap
    results_dict['tasknet_Rank1'] = rank1
    if test_metagraph:
        rank1, mean_ap = _rank1_and_map(query, gallery, 'metagraph')
        results_dict['metagraph_mAP'] = mean_ap
        results_dict['metagraph_Rank1'] = rank1

    if if_test_forget:
        if loaders_validation is None:
            # Only the 'mix' dataset defines validation loaders.
            print(time_now(), 'validation skipped: no validation loaders for',
                  config.test_dataset)
        else:
            print(time_now(), 'validation feature start')
            query, gallery = _extract(loaders_validation, is_validation=True)
            print(time_now(), 'validation feature done')

            rank1, mean_ap = _rank1_and_map(query, gallery, 'features')
            results_dict['tasknet_validation_mAP'] = mean_ap
            results_dict['tasknet_validation_Rank1'] = rank1
            if test_metagraph:
                rank1, mean_ap = _rank1_and_map(query, gallery, 'metagraph')
                results_dict['metagraph_validation_mAP'] = mean_ap
                results_dict['metagraph_validation_Rank1'] = rank1

    results_str = ''.join(f'\n{criterion}: {value}'
                          for criterion, value in results_dict.items())
    return results_dict, results_str