def _restore_best(saver, component):
    """Load the best saved checkpoint weights into *component*'s model."""
    saver.to_save = {'module': component.model}
    saver.load_checkpoint(is_best=True)


def train(cfg, saver):
    """Continual training with distillation from a previous-task model.

    e.g.: train dukemtmc with a module pretrained on market.
    """
    # Teacher model from the previous task (0 classes: weights come from
    # the checkpoint, the classifier head is not trained here).
    source_tr = TrainComponent(cfg, 0)
    _restore_best(saver, source_tr)

    names = [cfg.DATASET.NAME, cfg.CONTINUATION.DATASET_NAME]
    # Train on the continuation dataset (second entry).
    train_loader, num_classes = make_train_data_loader(cfg, names[1])

    # Student model for the new task, warm-started from its best checkpoint.
    current_tr = TrainComponent(cfg, num_classes)
    _restore_best(saver, current_tr)

    # Validate on both the old and the new dataset.
    valid = make_multi_valid_data_loader(cfg, names)
    do_continuous_train(cfg, train_loader, valid, source_tr, current_tr, saver)


if __name__ == '__main__':
    cfg, saver = main(
        ["CONTINUATION.IF_ON", True, "TEST.IF_RE_RANKING", False])
    train(cfg, saver)
# NOTE(review): this chunk starts mid-definition — the `def train(cfg, saver):`
# header and the file's imports fall outside this view (the __main__ guard
# below calls `train(cfg, saver)`). Statements are laid out flat here only
# because the original formatting was lost; confirm against the full file.

# Extract features of the source dataset with a pretrained model, then train
# on the target dataset using those features as extra supervision.
source_name = cfg.DATASET.NAME
target_name = cfg.FEAT.DATASET_NAME
source_loader, _ = make_train_data_loader_for_extract(cfg, source_name)
# 702 is hard-coded — presumably the identity count of the pretrained source
# model (DukeMTMC-reID has 702 train identities); TODO confirm.
tr = TrainComponent(cfg, 702)
to_load = {'model': tr.model}
saver.checkpoint_params = to_load
saver.load_checkpoint(is_best=True)
feat, _ = do_extract(cfg, source_loader, tr)
# Move features off the GPU and free cached device memory before training.
feat = feat.cpu()
if cfg.MODEL.DEVICE == 'cuda':
    torch.cuda.empty_cache()
logger.info(f"Extracting feat is done. {feat.size()}")
train_loader, num_classes = make_data_with_loader_with_feat_label(
    cfg, source_name, target_name, feat)
valid = make_multi_valid_data_loader(cfg, [source_name])
# inference(cfg, train_component.module, valid)
do_train_with_feat(cfg, train_loader, valid, tr, saver)


if __name__ == '__main__':
    cfg, saver = main(["FEAT.IF_ON", True])
    train(cfg, saver)
# Fix: `sys` was used without being imported in this chunk (NameError at
# import time); the sibling scripts import it first.
import sys

sys.path.append('.')
sys.path.append('..')

from data import make_train_data_loader, make_train_data_loader_with_expand, make_multi_valid_data_loader
from tools.component import main, TrainComponent
from engine.trainer import do_train


def train(cfg, saver):
    """Train on a single dataset, or jointly on several when JOINT.IF_ON.

    :param cfg: experiment configuration node (DATASET / JOINT sections used)
    :param saver: checkpoint saver passed through to the training loop
    """
    dataset_name = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        # Joint training: merge the extra datasets into one expanded set.
        for name in cfg.JOINT.DATASET_NAME:
            dataset_name.append(name)
        train_loader, num_classes = make_train_data_loader_with_expand(
            cfg, dataset_name)
    else:
        train_loader, num_classes = make_train_data_loader(
            cfg, dataset_name[0])
    # Validate on every dataset involved in training.
    valid = make_multi_valid_data_loader(cfg, dataset_name)
    train_component = TrainComponent(cfg, num_classes)
    do_train(cfg, train_loader, valid, train_component, saver)


if __name__ == '__main__':
    cfg, saver = main()
    train(cfg, saver)
import sys

sys.path.append('.')
sys.path.append('..')

from data import make_multi_valid_data_loader
from engine.inference import inference
from tools.component import main, TrainComponent


def test(cfg, saver):
    """Run inference with the best saved checkpoint on the configured dataset."""
    names = [cfg.DATASET.NAME]
    loaders = make_multi_valid_data_loader(cfg, names, verbose=True)

    # 0 classes: the model's weights come entirely from the checkpoint.
    component = TrainComponent(cfg, 0)
    saver.to_save = {'module': component.model}
    saver.load_checkpoint(is_best=True)

    inference(cfg, component.model, loaders)


if __name__ == '__main__':
    cfg, saver = main(["TEST.IF_ON", True])
    test(cfg, saver)
# NOTE(review): this chunk begins mid-function — `k_s`, `num_classes`,
# `current_tr`, `source_tr`, `autoencoder_tr`, `train_loader` and
# `dataset_name` are all defined above, outside this view. Statements are
# laid out flat only because the original formatting was lost.

# Sweep the EBLL autoencoder loss weight over the values in k_s.
for k in k_s:
    logger.info("")
    logger.info('*' * 60)
    logger.info(f"Start ebll training using {0.001 * k}")
    logger.info('*' * 60)
    logger.info("")
    # Deep-copy the config per step so AE_LOSS_WEIGHT edits don't accumulate.
    copy_cfg = copy.deepcopy(cfg)
    copy_cfg["CONTINUATION"]["IF_ON"] = True
    copy_cfg["EBLL"]["AE_LOSS_WEIGHT"] = 0.001 * k
    ebll_tr = TrainComponent(copy_cfg, num_classes)
    # Each run starts from a fresh copy of the current model's weights.
    ebll_tr.model = copy.deepcopy(current_tr.model)
    saver.checkpoint_params['model'] = ebll_tr.model
    # NOTE(review): the valid loader and ebll_train receive the original
    # `cfg`, not `copy_cfg` — confirm the weighted config is only meant to
    # affect TrainComponent construction.
    ebll_valid = make_multi_valid_data_loader(cfg, dataset_name)
    ebll_train(cfg, train_loader, ebll_valid, source_tr, ebll_tr,
               autoencoder_tr, saver)
    # Only sweep multiple weights when the ae_dist loss term is enabled;
    # otherwise a single iteration suffices.
    if "ae_dist" not in cfg.CONTINUATION.LOSS_TYPE:
        break


if __name__ == '__main__':
    cfg, saver = main(["EBLL.IF_ON", True, "TEST.IF_RE_RANKING", False])
    train(cfg, saver)