def train(cfg, saver):
    """
    Train a new dataset with source data
    e.g.: train: dukemtmc with market feat

    Restores the best source-domain model, extracts features from the
    source training set, then trains on loaders built from those
    feature/label pairs.
    """
    src = cfg.DATASET.NAME
    tgt = cfg.FEAT.DATASET_NAME

    # NOTE(review): 702 classes is hard-coded — presumably the source
    # dataset's identity count; confirm against the checkpoint.
    extract_loader, _ = make_train_data_loader_for_extract(cfg, src)
    component = TrainComponent(cfg, 702)
    saver.checkpoint_params = {'model': component.model}
    saver.load_checkpoint(is_best=True)

    feat, _ = do_extract(cfg, extract_loader, component)
    feat = feat.cpu()
    # Free GPU memory held by the extraction pass before training starts.
    if cfg.MODEL.DEVICE == 'cuda':
        torch.cuda.empty_cache()
    logger.info(f"Extracting feat is done. {feat.size()}")

    train_loader, num_classes = make_data_with_loader_with_feat_label(
        cfg, src, tgt, feat)
    valid = make_multi_valid_data_loader(cfg, [src])
    do_train_with_feat(cfg, train_loader, valid, component, saver)
def test(cfg, saver):
    """Restore the best checkpoint and run inference on the configured dataset."""
    names = [cfg.DATASET.NAME]
    valid = make_multi_valid_data_loader(cfg, names, verbose=True)

    component = TrainComponent(cfg)
    saver.checkpoint_params['model'] = component.model
    saver.load_checkpoint(is_best=True)

    inference(cfg, component.model, valid)
def test(cfg, saver):
    """Restore the best checkpoint and run inference on the configured dataset."""
    names = [cfg.DATASET.NAME]
    valid = make_multi_valid_data_loader(cfg, names, verbose=True)

    component = TrainComponent(cfg, 0)
    # NOTE(review): this variant registers the model under the 'module'
    # key via saver.to_save, while sibling entry points use
    # saver.checkpoint_params['model'] — confirm this saver expects it.
    saver.to_save = {'module': component.model}
    saver.load_checkpoint(is_best=True)

    inference(cfg, component.model, valid)
def train(cfg, saver):
    """
    Train a new dataset with distillation
    e.g.: train: dukemtmc with market module

    Loads the best checkpoint into both the frozen source (teacher)
    component and the current (student) component, then runs the
    continuous-training loop.
    """
    teacher = TrainComponent(cfg, 0)
    saver.checkpoint_params['model'] = teacher.model
    saver.load_checkpoint(is_best=True)

    names = [cfg.DATASET.NAME, cfg.CONTINUATION.DATASET_NAME]
    train_loader, num_classes = make_train_data_loader(cfg, names[1])

    student = TrainComponent(cfg, num_classes)
    saver.checkpoint_params['model'] = student.model
    saver.load_checkpoint(is_best=True)

    valid = make_multi_valid_data_loader(cfg, names)
    do_continuous_train(cfg, train_loader, valid, teacher, student, saver)
def train(cfg, saver):
    """Build train/valid loaders (joint or single dataset) and run training."""
    names = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        # Joint training: expand the loader over every configured dataset.
        names.extend(cfg.JOINT.DATASET_NAME)
        train_loader, num_classes = make_train_data_loader_with_expand(
            cfg, names)
    else:
        train_loader, num_classes = make_train_data_loader(cfg, names[0])

    valid = make_multi_valid_data_loader(cfg, names)
    component = TrainComponent(cfg, num_classes)
    do_train(cfg, train_loader, valid, component, saver)
def train(cfg, saver):
    """Build train/valid loaders (joint or single dataset), register the
    model for checkpointing, and run training."""
    names = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        # Joint training: expand the loader over every configured dataset.
        names.extend(cfg.JOINT.DATASET_NAME)
        train_loader, num_classes = make_train_data_loader_with_expand(
            cfg, names)
    else:
        train_loader, num_classes = make_train_data_loader(cfg, names[0])

    valid_dict = make_multi_valid_data_loader(cfg, names)
    component = TrainComponent(cfg, num_classes)
    saver.checkpoint_params['model'] = component.model
    do_train(cfg, train_loader, valid_dict, component, saver)
def train(cfg, saver):
    """
    Train a new dataset with distillation
    e.g.: train: dukemtmc with market module

    Three stages:
      1. Load the pre-trained source model and train an autoencoder on
         the source dataset.
      2. Fine-tune a model on the target (EBLL) dataset.
      3. Sweep EBLL training over autoencoder loss weights (0.001 * k).

    Args:
        cfg: experiment configuration (yacs-style node).
        saver: checkpoint manager; ``checkpoint_params`` maps names to
            modules to save/restore, ``best_result`` tracks the metric.
    """
    dataset_name = [cfg.DATASET.NAME, cfg.EBLL.DATASET_NAME]

    # --- Stage 1: train autoencoder on the source dataset -------------
    source_train_loader, source_num_classes = make_train_data_loader(
        cfg, dataset_name[0])
    source_valid = make_multi_valid_data_loader(cfg, [dataset_name[0]])

    source_tr = TrainComponent(cfg)
    saver.checkpoint_params['model'] = source_tr.model
    saver.load_checkpoint(is_best=True)

    autoencoder_tr = TrainComponent(cfg, autoencoder=True)
    # BUG FIX: register the autoencoder's own model for checkpointing.
    # The original registered source_tr.model under the 'autoencoder'
    # key, so the trained autoencoder was never saved/restored.
    saver.checkpoint_params['autoencoder'] = autoencoder_tr.model

    logger.info("")
    logger.info('*' * 60)
    logger.info("Start training autoencoder")
    logger.info('*' * 60)
    logger.info("")
    train_autoencoder(cfg, source_train_loader, source_valid, source_tr,
                      autoencoder_tr, saver)
    # Reset best-metric tracking before the next stage.
    saver.best_result = 0

    # --- Stage 2: fine-tune current model on the target dataset -------
    train_loader, num_classes = make_train_data_loader(cfg, dataset_name[1])
    ebll_valid = make_multi_valid_data_loader(cfg, [dataset_name[1]])

    current_tr = TrainComponent(cfg, num_classes)
    saver.checkpoint_params['model'] = current_tr.model
    saver.load_checkpoint(is_best=True)

    logger.info("")
    logger.info('*' * 60)
    logger.info("Start fine tuning current model")
    logger.info('*' * 60)
    logger.info("")
    fine_tune_current_model(cfg, train_loader, ebll_valid, current_tr, saver)
    saver.best_result = 0

    # --- Stage 3: EBLL sweep over autoencoder loss weights ------------
    k_s = [1, 5, 10, 15, 20, 30, 50, 100, 150]
    for k in k_s:
        logger.info("")
        logger.info('*' * 60)
        logger.info(f"Start ebll training using {0.001 * k}")
        logger.info('*' * 60)
        logger.info("")
        copy_cfg = copy.deepcopy(cfg)
        copy_cfg["CONTINUATION"]["IF_ON"] = True
        copy_cfg["EBLL"]["AE_LOSS_WEIGHT"] = 0.001 * k
        ebll_tr = TrainComponent(copy_cfg, num_classes)
        # Every sweep entry starts from the same fine-tuned weights.
        ebll_tr.model = copy.deepcopy(current_tr.model)
        saver.checkpoint_params['model'] = ebll_tr.model
        ebll_valid = make_multi_valid_data_loader(cfg, dataset_name)
        ebll_train(cfg, train_loader, ebll_valid, source_tr, ebll_tr,
                   autoencoder_tr, saver)
        # The weight being swept only matters for the "ae_dist" loss;
        # otherwise a single pass suffices.
        if "ae_dist" not in cfg.CONTINUATION.LOSS_TYPE:
            break