示例#1
0
def train(cfg, saver):
    """Train on a new dataset with distillation from a source model.

    Example: train on DukeMTMC using a Market-trained model as the source.

    ``saver.load_checkpoint(is_best=True)`` is called once for the source
    model and once for the model being trained, so both start from the
    saver's best checkpoint. Validation covers the original dataset
    (``cfg.DATASET.NAME``) and the new one (``cfg.CONTINUATION.DATASET_NAME``).

    Args:
        cfg: project config node.
        saver: checkpoint helper. Its ``to_save`` mapping is pointed at each
            model right before ``load_checkpoint`` is called (presumably to
            select the load target).
    """
    def restore_best(component):
        # Point the saver at this component's model, then load the best
        # checkpoint (presumably into that model).
        saver.to_save = {'module': component.model}
        saver.load_checkpoint(is_best=True)

    # NOTE(review): the source model is built with 0 classes; presumably its
    # classifier head is not used during distillation. Confirm.
    source_tr = TrainComponent(cfg, 0)
    restore_best(source_tr)

    old_name = cfg.DATASET.NAME
    new_name = cfg.CONTINUATION.DATASET_NAME

    # Only the new dataset is trained on; its class count sizes the new model.
    train_loader, num_classes = make_train_data_loader(cfg, new_name)

    current_tr = TrainComponent(cfg, num_classes)
    restore_best(current_tr)

    valid = make_multi_valid_data_loader(cfg, [old_name, new_name])

    do_continuous_train(cfg, train_loader, valid, source_tr, current_tr, saver)


if __name__ == '__main__':
    # Config overrides as alternating key/value entries:
    # CONTINUATION.IF_ON=True, TEST.IF_RE_RANKING=False.
    overrides = ["CONTINUATION.IF_ON", True, "TEST.IF_RE_RANKING", False]
    cfg, saver = main(overrides)
    train(cfg, saver)
示例#2
0
    # The source dataset comes from DATASET.NAME. The target dataset used for
    # feature-label training comes from FEAT.DATASET_NAME.
    source_name = cfg.DATASET.NAME
    target_name = cfg.FEAT.DATASET_NAME

    # Loader over the source dataset, used for feature extraction below.
    # Its second return value is unused here.
    source_loader, _ = make_train_data_loader_for_extract(cfg, source_name)

    # NOTE(review): the class count 702 is hard-coded. It presumably has to
    # match the classifier size stored in the best checkpoint (702 is the
    # number of training identities in DukeMTMC-reID). Confirm, and consider
    # moving it into the config.
    tr = TrainComponent(cfg, 702)
    # Register the model with the saver under key 'model', then load the
    # best checkpoint (presumably into tr.model).
    to_load = {'model': tr.model}
    saver.checkpoint_params = to_load
    saver.load_checkpoint(is_best=True)

    # Extract features over the source loader with the restored component.
    # Move them to CPU memory, and free cached CUDA memory when running on
    # GPU before the next stage allocates.
    feat, _ = do_extract(cfg, source_loader, tr)
    feat = feat.cpu()
    if cfg.MODEL.DEVICE == 'cuda':
        torch.cuda.empty_cache()
    logger.info(f"Extracting feat is done. {feat.size()}")

    # Build the training loader from both dataset names plus the extracted
    # features (presumably used as soft labels, per the helper's name).
    # NOTE(review): num_classes is not used below; tr keeps its hard-coded
    # 702 classes. Confirm this is intended.
    train_loader, num_classes = make_data_with_loader_with_feat_label(
        cfg, source_name, target_name, feat)

    # Validation covers the source dataset only.
    valid = make_multi_valid_data_loader(cfg, [source_name])

    # NOTE(review): this commented-out call is stale. `train_component` is not
    # defined in this scope; the component here is `tr`.
    # inference(cfg, train_component.module, valid)

    do_train_with_feat(cfg, train_loader, valid, tr, saver)


if __name__ == '__main__':
    # Config overrides as alternating key/value entries: FEAT.IF_ON=True.
    overrides = ["FEAT.IF_ON", True]
    cfg, saver = main(overrides)
    train(cfg, saver)
示例#3
0
import sys

# Relative sys.path entries are resolved against the current working
# directory. Adding '.' and '..' lets the project packages (data, tools,
# engine) be imported when the script runs from the repository root or one
# level below it (presumably the intended launch locations).
sys.path.append('.')
sys.path.append('..')

from data import make_train_data_loader, make_train_data_loader_with_expand, make_multi_valid_data_loader
from tools.component import main, TrainComponent
from engine.trainer import do_train


def train(cfg, saver):
    """Train a model on one dataset or on several datasets jointly.

    When ``cfg.JOINT.IF_ON`` is set, the datasets listed in
    ``cfg.JOINT.DATASET_NAME`` are appended, in config order, after
    ``cfg.DATASET.NAME``. All of them are passed to
    ``make_train_data_loader_with_expand``, which presumably merges them
    into one training set. Otherwise only ``cfg.DATASET.NAME`` is used.
    Validation runs on every dataset in that list.

    Args:
        cfg: project config node.
        saver: checkpoint helper forwarded to ``do_train``.
    """
    dataset_name = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        dataset_name.extend(cfg.JOINT.DATASET_NAME)
        train_loader, num_classes = make_train_data_loader_with_expand(
            cfg, dataset_name)
    else:
        train_loader, num_classes = make_train_data_loader(
            cfg, dataset_name[0])

    valid = make_multi_valid_data_loader(cfg, dataset_name)

    train_component = TrainComponent(cfg, num_classes)

    do_train(cfg, train_loader, valid, train_component, saver)


if __name__ == '__main__':
    # No config overrides: train with the configuration `main` builds.
    cfg, saver = main()
    train(cfg, saver)
示例#4
0
import sys

# Relative sys.path entries resolve against the current working directory,
# so the project packages import from the repository root or one level below
# it (presumably the intended launch locations).
sys.path.extend(['.', '..'])

from data import make_multi_valid_data_loader
from engine.inference import inference
from tools.component import main, TrainComponent


def test(cfg, saver):
    """Evaluate the saver's best checkpoint on ``cfg.DATASET.NAME``.

    Args:
        cfg: project config node.
        saver: checkpoint helper. Its ``to_save`` mapping is pointed at the
            model before ``load_checkpoint(is_best=True)`` is called.
    """
    valid = make_multi_valid_data_loader(
        cfg, [cfg.DATASET.NAME], verbose=True)

    # NOTE(review): built with 0 classes; presumably the classifier head is
    # unused at inference time. Confirm against TrainComponent.
    component = TrainComponent(cfg, 0)
    saver.to_save = {'module': component.model}
    saver.load_checkpoint(is_best=True)

    inference(cfg, component.model, valid)


if __name__ == '__main__':
    # Config overrides as alternating key/value entries: TEST.IF_ON=True.
    overrides = ["TEST.IF_ON", True]
    cfg, saver = main(overrides)
    test(cfg, saver)
示例#5
0
    # Sweep the EBLL autoencoder-loss weight: value k uses weight 0.001 * k.
    for k in k_s:
        logger.info("")
        logger.info('*' * 60)
        logger.info(f"Start ebll training using {0.001 * k}")
        logger.info('*' * 60)
        logger.info("")

        # Work on a deep copy so the per-run overrides never modify `cfg`.
        copy_cfg = copy.deepcopy(cfg)
        copy_cfg["CONTINUATION"]["IF_ON"] = True
        copy_cfg["EBLL"]["AE_LOSS_WEIGHT"] = 0.001 * k
        ebll_tr = TrainComponent(copy_cfg, num_classes)
        # Each run starts from its own deep copy of current_tr's model, so
        # runs neither share nor mutate the same weights.
        ebll_tr.model = copy.deepcopy(current_tr.model)
        # Point the saver's 'model' entry at this run's model.
        saver.checkpoint_params['model'] = ebll_tr.model
        ebll_valid = make_multi_valid_data_loader(cfg, dataset_name)

        # NOTE(review): ebll_train and the validation loader above receive
        # the original `cfg`, not `copy_cfg`; copy_cfg is only passed to
        # TrainComponent. If ebll_train reads cfg.EBLL.AE_LOSS_WEIGHT
        # directly, every run would use the same weight. Confirm.
        ebll_train(cfg,
                   train_loader,
                   ebll_valid,
                   source_tr,
                   ebll_tr,
                   autoencoder_tr,
                   saver)

        # Unless the loss type includes "ae_dist", stop after the first weight
        # (presumably the sweep only matters for that loss term).
        if "ae_dist" not in cfg.CONTINUATION.LOSS_TYPE:
            break


if __name__ == '__main__':
    # Config overrides as alternating key/value entries: EBLL.IF_ON=True,
    # TEST.IF_RE_RANKING=False.
    overrides = ["EBLL.IF_ON", True, "TEST.IF_RE_RANKING", False]
    cfg, saver = main(overrides)
    train(cfg, saver)