def train(cfg, saver):
    """Train on a new dataset using features extracted by a source model.

    e.g.: train dukemtmc with features from a market-trained model.
    """
    src_name = cfg.DATASET.NAME
    feat_name = cfg.FEAT.DATASET_NAME

    extract_loader, _ = make_train_data_loader_for_extract(cfg, src_name)

    # NOTE(review): 702 looks like the source model's identity count
    # (DukeMTMC has 702 training ids) — confirm it matches the checkpoint.
    tr = TrainComponent(cfg, 702)
    saver.checkpoint_params = {'model': tr.model}
    saver.load_checkpoint(is_best=True)

    # Extract features with the restored source model, then move them off
    # the GPU so training below has the memory.
    feat, _ = do_extract(cfg, extract_loader, tr)
    feat = feat.cpu()
    if cfg.MODEL.DEVICE == 'cuda':
        torch.cuda.empty_cache()
    logger.info(f"Extracting feat is done. {feat.size()}")

    loader, _ = make_data_with_loader_with_feat_label(
        cfg, src_name, feat_name, feat)

    valid = make_multi_valid_data_loader(cfg, [src_name])

    do_train_with_feat(cfg, loader, valid, tr, saver)
Exemple #2
0
def test(cfg, saver):
    """Restore the best checkpoint and evaluate on the configured dataset."""
    valid = make_multi_valid_data_loader(cfg, [cfg.DATASET.NAME], verbose=True)

    component = TrainComponent(cfg)
    saver.checkpoint_params['model'] = component.model
    saver.load_checkpoint(is_best=True)

    inference(cfg, component.model, valid)
Exemple #3
0
def test(cfg, saver):
    """Restore the best checkpoint and evaluate on the configured dataset."""
    names = [cfg.DATASET.NAME]
    valid = make_multi_valid_data_loader(cfg, names, verbose=True)

    component = TrainComponent(cfg, 0)
    # NOTE(review): this variant registers the model on `saver.to_save` under
    # the key 'module', while sibling functions use
    # `saver.checkpoint_params['model']` — verify against the Saver API
    # before unifying.
    saver.to_save = {'module': component.model}
    saver.load_checkpoint(is_best=True)

    inference(cfg, component.model, valid)
Exemple #4
0
def train(cfg, saver):
    """Train on one dataset, or on the expanded joint set when JOINT.IF_ON."""
    names = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        # Joint training: merge the extra datasets into one expanded loader.
        names += list(cfg.JOINT.DATASET_NAME)
        loader, n_classes = make_train_data_loader_with_expand(cfg, names)
    else:
        loader, n_classes = make_train_data_loader(cfg, names[0])

    # Validate against every dataset that contributes training data.
    valid_loaders = make_multi_valid_data_loader(cfg, names)

    component = TrainComponent(cfg, n_classes)

    do_train(cfg, loader, valid_loaders, component, saver)
Exemple #5
0
def train(cfg, saver):
    """Train on one dataset (or the joint expansion) and register the model
    with the saver for checkpointing."""
    names = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        # Joint training: merge the extra datasets into one expanded loader.
        names += list(cfg.JOINT.DATASET_NAME)
        loader, n_classes = make_train_data_loader_with_expand(cfg, names)
    else:
        loader, n_classes = make_train_data_loader(cfg, names[0])

    valid_loaders = make_multi_valid_data_loader(cfg, names)

    component = TrainComponent(cfg, n_classes)

    # Register the model so the saver checkpoints it during training.
    saver.checkpoint_params['model'] = component.model

    do_train(cfg, loader, valid_loaders, component, saver)
def train(cfg, saver):
    """Continual training with distillation: a restored source model guides
    a new model fine-tuned on the continuation dataset.

    e.g.: train dukemtmc while distilling from a market-trained model.
    """
    # Teacher: restore the best checkpoint of the previously trained model.
    teacher = TrainComponent(cfg, 0)
    saver.checkpoint_params['model'] = teacher.model
    saver.load_checkpoint(is_best=True)

    names = [cfg.DATASET.NAME, cfg.CONTINUATION.DATASET_NAME]

    loader, n_classes = make_train_data_loader(cfg, names[1])

    # Student: also starts from the best checkpoint.
    student = TrainComponent(cfg, n_classes)
    saver.checkpoint_params['model'] = student.model
    saver.load_checkpoint(is_best=True)

    # Validate on both the original and the continuation dataset.
    valid = make_multi_valid_data_loader(cfg, names)

    do_continuous_train(cfg, loader, valid, teacher, student, saver)
Exemple #7
0
    logger.info("mAP: {:.2%}".format(mAP))
    for r in [1, 5, 10]:
        logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r - 1]))
    sum_result = (mAP + cmc[0]) / 2
    logger.info(f'sum_result: {sum_result:.2%}')
    cost_time = time.time() - time_start
    logger.info(f"cost_time: {cost_time:.4f} s")
    logger.info('-' * 60)


def eval_multi_dataset(cfg, valid_dict, tr_comp: TrainComponent):
    """Evaluate `tr_comp` on every dataset in `valid_dict`.

    `valid_dict` maps dataset name -> (dataloader, number of query images).
    A fresh evaluator is built per dataset so state does not carry over.
    """
    for dataset_name, (loader, num_query) in valid_dict.items():
        evaluator = create_supervised_evaluator(tr_comp)
        eval_one_dataset(cfg, dataset_name, loader, num_query, evaluator)


if __name__ == '__main__':
    # Evaluation entry point: build config/saver with TEST.IF_ON forced on.
    cfg, saver = main(["TEST.IF_ON", True])
    valid_dict = make_multi_valid_data_loader(cfg,
                                              cfg.TEST.DATASET_NAMES,
                                              verbose=True)
    # Load one specific training checkpoint from disk (hard-coded path).
    checkpoint = torch.load(
        '../run/direct/market1501/resnet50/experiment-34/model/train_checkpoint_37200.pt'
    )
    # Drop the classifier head: TrainComponent below is built with 0 classes,
    # so the checkpoint's classifier weights would not fit — presumably only
    # the feature backbone is evaluated. TODO confirm.
    checkpoint['model'].pop('classifier.weight')
    tr_comp = TrainComponent(cfg, 0)
    to_load = tr_comp.state_dict()
    # Restore the backbone weights into the freshly built component.
    saver.load_objects(to_load={'model': to_load['model']},
                       checkpoint={'model': checkpoint['model']})
    eval_multi_dataset(cfg, valid_dict, tr_comp)
Exemple #8
0
def train(cfg, saver):
    """Train a new dataset with EBLL-style distillation.

    Pipeline:
      1. restore the source model and train an autoencoder on its features;
      2. fine-tune a fresh model on the new dataset;
      3. sweep the autoencoder-loss weight (0.001 * k) and run EBLL training
         for each value, restarting from the fine-tuned weights every time.

    e.g.: train: dukemtmc with market module
    """
    dataset_name = [cfg.DATASET.NAME, cfg.EBLL.DATASET_NAME]
    source_train_loader, source_num_classes = make_train_data_loader(
        cfg, dataset_name[0])
    source_valid = make_multi_valid_data_loader(cfg, [dataset_name[0]])

    # Restore the best checkpoint of the source model.
    source_tr = TrainComponent(cfg)
    saver.checkpoint_params['model'] = source_tr.model
    saver.load_checkpoint(is_best=True)

    autoencoder_tr = TrainComponent(cfg, autoencoder=True)
    # BUG FIX: register the autoencoder's own model. The original assigned
    # source_tr.model under the 'autoencoder' key, so autoencoder checkpoints
    # would have saved/loaded the wrong module.
    saver.checkpoint_params['autoencoder'] = autoencoder_tr.model

    logger.info("")
    logger.info('*' * 60)
    logger.info("Start training autoencoder")
    logger.info('*' * 60)
    logger.info("")

    train_autoencoder(cfg, source_train_loader, source_valid, source_tr,
                      autoencoder_tr, saver)

    # Reset the best-result tracker before the next training phase.
    saver.best_result = 0

    train_loader, num_classes = make_train_data_loader(cfg, dataset_name[1])
    ebll_valid = make_multi_valid_data_loader(cfg, [dataset_name[1]])

    # Fresh model for the new dataset, initialised from the best checkpoint.
    current_tr = TrainComponent(cfg, num_classes)
    saver.checkpoint_params['model'] = current_tr.model
    saver.load_checkpoint(is_best=True)

    logger.info("")
    logger.info('*' * 60)
    logger.info("Start fine tuning current model")
    logger.info('*' * 60)
    logger.info("")

    fine_tune_current_model(cfg, train_loader, ebll_valid, current_tr, saver)

    saver.best_result = 0
    # Sweep of autoencoder-loss weights, as multiples of 0.001.
    k_s = [1, 5, 10, 15, 20, 30, 50, 100, 150]
    for k in k_s:
        logger.info("")
        logger.info('*' * 60)
        logger.info(f"Start ebll training using {0.001 * k}")
        logger.info('*' * 60)
        logger.info("")

        # Modify a deep copy so the sweep never mutates the shared config.
        copy_cfg = copy.deepcopy(cfg)
        copy_cfg["CONTINUATION"]["IF_ON"] = True
        copy_cfg["EBLL"]["AE_LOSS_WEIGHT"] = 0.001 * k
        ebll_tr = TrainComponent(copy_cfg, num_classes)
        # Every sweep iteration restarts from the fine-tuned weights.
        ebll_tr.model = copy.deepcopy(current_tr.model)
        saver.checkpoint_params['model'] = ebll_tr.model
        ebll_valid = make_multi_valid_data_loader(cfg, dataset_name)

        # NOTE(review): ebll_train receives the ORIGINAL cfg, not copy_cfg —
        # only TrainComponent sees the modified weight. Confirm intentional.
        ebll_train(cfg, train_loader, ebll_valid, source_tr, ebll_tr,
                   autoencoder_tr, saver)

        # Only the 'ae_dist' loss uses the swept weight; otherwise one run
        # is enough.
        if "ae_dist" not in cfg.CONTINUATION.LOSS_TYPE:
            break
Exemple #9
0
    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD))
    def log_validation_results(engine):
        logger.info(f"Valid - Epoch: {engine.state.epoch}")
        eval_multi_dataset(cfg, valid_dict, tr_comp)

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)


if __name__ == '__main__':
    # Configuration and checkpoint saver
    cfg, saver = main()

    # Training dataset(s): a single dataset, or the joint expansion when
    # JOINT.IF_ON is set.
    dataset_name = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        for name in cfg.JOINT.DATASET_NAME:
            dataset_name.append(name)
        train_loader, num_classes = make_train_data_loader_with_expand(cfg, dataset_name)
    else:
        train_loader, num_classes = make_train_data_loader(cfg, dataset_name[0])

    # Validation datasets (one loader per training dataset)
    valid_dict = make_multi_valid_data_loader(cfg, dataset_name)

    # Training components: model wrapper and ignite trainer engine
    tr_comp = TrainComponent(cfg, num_classes)
    trainer = create_supervised_trainer(tr_comp)

    run(cfg, train_loader, tr_comp, saver, trainer, valid_dict)