예제 #1
0
def train(cfg, saver):
    """Extract features with a restored source model, then train on them.

    Example: train on dukemtmc using features from a market-trained model.

    Steps:
      1. Build an extraction loader for ``cfg.DATASET.NAME``.
      2. Build a ``TrainComponent`` and load the best checkpoint into it.
         ``saver.checkpoint_params`` is replaced with a dict that holds only
         the model.
      3. Run ``do_extract`` and move the resulting features to the CPU.
      4. Build a loader that pairs those features with labels for
         ``cfg.FEAT.DATASET_NAME``, then call ``do_train_with_feat``.
         Validation is on the source dataset only.

    Args:
        cfg: experiment config node.
        saver: checkpoint manager used to restore and later save the model.
    """
    # NOTE(review): presumably the class count of the dataset the checkpoint
    # was trained on (the head must match to load) — confirm before reuse.
    checkpoint_num_classes = 702

    src = cfg.DATASET.NAME
    tgt = cfg.FEAT.DATASET_NAME

    extract_loader, _ = make_train_data_loader_for_extract(cfg, src)

    component = TrainComponent(cfg, checkpoint_num_classes)
    saver.checkpoint_params = {'model': component.model}
    saver.load_checkpoint(is_best=True)

    features, _ = do_extract(cfg, extract_loader, component)
    features = features.cpu()
    if cfg.MODEL.DEVICE == 'cuda':
        # Hand cached CUDA memory from the extraction pass back before training.
        torch.cuda.empty_cache()
    logger.info(f"Extracting feat is done. {features.size()}")

    feat_loader, _num_classes = make_data_with_loader_with_feat_label(
        cfg, src, tgt, features)
    valid = make_multi_valid_data_loader(cfg, [src])

    do_train_with_feat(cfg, feat_loader, valid, component, saver)
예제 #2
0
def test(cfg, saver):
    """Run inference with the best saved checkpoint on ``cfg.DATASET.NAME``.

    The freshly built model is placed in ``saver.checkpoint_params['model']``
    before ``load_checkpoint(is_best=True)`` is called. The restore presumably
    targets that entry.
    """
    valid_loaders = make_multi_valid_data_loader(
        cfg, [cfg.DATASET.NAME], verbose=True)

    component = TrainComponent(cfg)
    saver.checkpoint_params['model'] = component.model
    saver.load_checkpoint(is_best=True)

    inference(cfg, component.model, valid_loaders)
예제 #3
0
def test(cfg, saver):
    """Run inference with the best saved checkpoint on ``cfg.DATASET.NAME``.

    The model is built with ``TrainComponent(cfg, 0)``. It is handed to the
    saver by replacing ``saver.to_save`` with ``{'module': model}`` before the
    best checkpoint is loaded.
    """
    valid_loaders = make_multi_valid_data_loader(
        cfg, [cfg.DATASET.NAME], verbose=True)

    component = TrainComponent(cfg, 0)
    # NOTE(review): sibling entry points register via
    # saver.checkpoint_params['model']; this one uses saver.to_save with the
    # key 'module' — presumably an older saver API, confirm it still applies.
    saver.to_save = {'module': component.model}
    saver.load_checkpoint(is_best=True)

    inference(cfg, component.model, valid_loaders)
예제 #4
0
def train(cfg, saver):
    """Continue training on a new dataset with distillation from a source model.

    Example: train on dukemtmc with a market-trained model as the source.

    1. The source model, ``TrainComponent(cfg, 0)``, is loaded with the best
       checkpoint.
    2. A train loader is made for ``cfg.CONTINUATION.DATASET_NAME``.
    3. A current model sized to that loader's class count is built. The best
       checkpoint is loaded again, this time into the current model.
    4. ``do_continuous_train`` runs, validating on both the source dataset
       and the new one.
    """
    source_component = TrainComponent(cfg, 0)
    saver.checkpoint_params['model'] = source_component.model
    saver.load_checkpoint(is_best=True)

    source_name = cfg.DATASET.NAME
    new_name = cfg.CONTINUATION.DATASET_NAME

    train_loader, num_classes = make_train_data_loader(cfg, new_name)

    # NOTE(review): the head size here (num_classes) differs from the
    # source's 0 — presumably load_checkpoint tolerates that; confirm.
    current_component = TrainComponent(cfg, num_classes)
    saver.checkpoint_params['model'] = current_component.model
    saver.load_checkpoint(is_best=True)

    valid = make_multi_valid_data_loader(cfg, [source_name, new_name])

    do_continuous_train(cfg, train_loader, valid,
                        source_component, current_component, saver)
예제 #5
0
def train(cfg, saver):
    """Train a fresh model on ``cfg.DATASET.NAME``, optionally joined with others.

    When ``cfg.JOINT.IF_ON`` is true, the names in ``cfg.JOINT.DATASET_NAME``
    are added to the dataset list. A single expanded loader is then built
    over all of them. Either way, validation loaders cover every name in
    the list.
    """
    names = [cfg.DATASET.NAME]
    if not cfg.JOINT.IF_ON:
        train_loader, num_classes = make_train_data_loader(cfg, names[0])
    else:
        names.extend(cfg.JOINT.DATASET_NAME)
        train_loader, num_classes = make_train_data_loader_with_expand(
            cfg, names)

    valid = make_multi_valid_data_loader(cfg, names)
    component = TrainComponent(cfg, num_classes)

    do_train(cfg, train_loader, valid, component, saver)
예제 #6
0
def train(cfg, saver):
    """Train a fresh model and register it with the saver.

    Training uses ``cfg.DATASET.NAME``. When ``cfg.JOINT.IF_ON`` is set, the
    datasets in ``cfg.JOINT.DATASET_NAME`` are added, and one expanded loader
    spans all of them. Validation loaders are made for every listed dataset.
    The model is stored in ``saver.checkpoint_params['model']`` before
    ``do_train`` starts.
    """
    names = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        names += cfg.JOINT.DATASET_NAME  # in-place, same as appending each
        loader, n_classes = make_train_data_loader_with_expand(cfg, names)
    else:
        loader, n_classes = make_train_data_loader(cfg, names[0])

    valid_dict = make_multi_valid_data_loader(cfg, names)
    component = TrainComponent(cfg, n_classes)
    saver.checkpoint_params['model'] = component.model

    do_train(cfg, loader, valid_dict, component, saver)
예제 #7
0
def _log_stage_banner(title):
    """Log *title* framed by blank lines and two 60-asterisk rules."""
    logger.info("")
    logger.info('*' * 60)
    logger.info(title)
    logger.info('*' * 60)
    logger.info("")


def train(cfg, saver):
    """Train a new dataset with EBLL on top of a restored source model.

    EBLL is presumably Encoder-Based Lifelong Learning, given the autoencoder
    stage. Example: train on dukemtmc starting from a market-trained model.

    Three stages run in sequence:

    1. Load the best checkpoint into a source model for ``cfg.DATASET.NAME``.
       Then train an autoencoder against it with ``train_autoencoder``.
    2. Reset ``saver.best_result`` to 0. Build a model sized for
       ``cfg.EBLL.DATASET_NAME``, load the best checkpoint into it, and
       fine-tune it on that dataset alone with ``fine_tune_current_model``.
    3. Reset ``saver.best_result`` to 0 again. For each autoencoder-loss
       weight ``0.001 * k`` (k in ``ae_weight_steps``), build an EBLL
       component from a copied config with ``CONTINUATION.IF_ON = True`` and
       that weight. Give it an independent deep copy of the fine-tuned model
       and run ``ebll_train``, validating on both datasets. Only the first
       weight is tried unless ``"ae_dist"`` is in
       ``cfg.CONTINUATION.LOSS_TYPE``.

    Args:
        cfg: experiment config node (item access ``cfg["EBLL"]`` is used on
            its deep copy).
        saver: checkpoint manager; its ``checkpoint_params`` entries
            ``'model'`` and ``'autoencoder'`` are set here.
    """
    source_name = cfg.DATASET.NAME
    target_name = cfg.EBLL.DATASET_NAME
    dataset_name = [source_name, target_name]

    # ---- Stage 1: restore source model, train autoencoder ----------------
    source_train_loader, _ = make_train_data_loader(cfg, source_name)
    source_valid = make_multi_valid_data_loader(cfg, [source_name])

    source_tr = TrainComponent(cfg)
    saver.checkpoint_params['model'] = source_tr.model
    saver.load_checkpoint(is_best=True)

    autoencoder_tr = TrainComponent(cfg, autoencoder=True)
    # Register the autoencoder's own network under 'autoencoder'. The
    # source model must not be registered here, or the saver never tracks the
    # autoencoder that train_autoencoder presumably updates.
    saver.checkpoint_params['autoencoder'] = autoencoder_tr.model

    _log_stage_banner("Start training autoencoder")
    train_autoencoder(cfg,
                      source_train_loader,
                      source_valid,
                      source_tr,
                      autoencoder_tr,
                      saver)

    # ---- Stage 2: fine-tune a model for the target dataset ---------------
    saver.best_result = 0  # restart best-result tracking for this stage

    train_loader, num_classes = make_train_data_loader(cfg, target_name)
    target_valid = make_multi_valid_data_loader(cfg, [target_name])

    current_tr = TrainComponent(cfg, num_classes)
    saver.checkpoint_params['model'] = current_tr.model
    saver.load_checkpoint(is_best=True)

    _log_stage_banner("Start fine tuning current model")
    fine_tune_current_model(cfg,
                            train_loader,
                            target_valid,
                            current_tr,
                            saver)

    # ---- Stage 3: EBLL runs over a sweep of autoencoder-loss weights ------
    saver.best_result = 0
    # Validation over both datasets does not depend on k: build it once.
    joint_valid = make_multi_valid_data_loader(cfg, dataset_name)

    ae_weight_steps = [1, 5, 10, 15, 20, 30, 50, 100, 150]
    for k in ae_weight_steps:
        ae_weight = 0.001 * k
        _log_stage_banner(f"Start ebll training using {ae_weight}")

        copy_cfg = copy.deepcopy(cfg)
        copy_cfg["CONTINUATION"]["IF_ON"] = True
        copy_cfg["EBLL"]["AE_LOSS_WEIGHT"] = ae_weight
        ebll_tr = TrainComponent(copy_cfg, num_classes)
        # Each run starts from its own copy of the fine-tuned weights.
        ebll_tr.model = copy.deepcopy(current_tr.model)
        saver.checkpoint_params['model'] = ebll_tr.model

        # NOTE(review): ebll_train receives the original cfg, not copy_cfg,
        # so any EBLL.AE_LOSS_WEIGHT / CONTINUATION.IF_ON it reads from cfg
        # is not the swept value — confirm the weight reaches training only
        # through ebll_tr.
        ebll_train(cfg,
                   train_loader,
                   joint_valid,
                   source_tr,
                   ebll_tr,
                   autoencoder_tr,
                   saver)

        # Presumably the AE weight only matters for the "ae_dist" loss, so a
        # single run is enough without it.
        if "ae_dist" not in cfg.CONTINUATION.LOSS_TYPE:
            break