Beispiel #1
0
def train(cfg, saver):
    """Build the data loaders and training component, then run ``do_train``.

    When ``cfg.JOINT.IF_ON`` is set, the primary dataset ``cfg.DATASET.NAME``
    is combined with every name in ``cfg.JOINT.DATASET_NAME``. The combined
    list is loaded through ``make_train_data_loader_with_expand``. Otherwise
    only the primary dataset is loaded through ``make_train_data_loader``.
    Validation loaders are built for every collected dataset name.

    Args:
        cfg: experiment configuration node.
        saver: checkpoint helper, forwarded unchanged to ``do_train``.
    """
    names = [cfg.DATASET.NAME]
    if not cfg.JOINT.IF_ON:
        loader_and_classes = make_train_data_loader(cfg, names[0])
    else:
        names.extend(cfg.JOINT.DATASET_NAME)
        loader_and_classes = make_train_data_loader_with_expand(cfg, names)
    train_loader, num_classes = loader_and_classes

    valid = make_multi_valid_data_loader(cfg, names)
    component = TrainComponent(cfg, num_classes)
    do_train(cfg, train_loader, valid, component, saver)
Beispiel #2
0
def train(cfg, saver):
    """Build loaders and the training component, register the model, train.

    The training set is the primary dataset ``cfg.DATASET.NAME``. When
    ``cfg.JOINT.IF_ON`` is set, every dataset in ``cfg.JOINT.DATASET_NAME``
    is joined to it and loaded via ``make_train_data_loader_with_expand``.
    A validation loader dict is built for every collected dataset name. The
    component's model is stored in ``saver.checkpoint_params['model']``
    before ``do_train`` is called.

    Args:
        cfg: experiment configuration node.
        saver: checkpoint helper; its ``checkpoint_params`` dict is updated.
    """
    names = [cfg.DATASET.NAME]
    if not cfg.JOINT.IF_ON:
        loader_and_classes = make_train_data_loader(cfg, names[0])
    else:
        names.extend(cfg.JOINT.DATASET_NAME)
        loader_and_classes = make_train_data_loader_with_expand(cfg, names)
    train_loader, num_classes = loader_and_classes

    valid_loaders = make_multi_valid_data_loader(cfg, names)
    component = TrainComponent(cfg, num_classes)

    # Expose the model to the saver so checkpoints capture it.
    saver.checkpoint_params['model'] = component.model

    do_train(cfg, train_loader, valid_loaders, component, saver)
def train(cfg, saver):
    """Train on a new dataset, continuing from a previously trained model.

    Example: train on dukemtmc starting from a market-trained model.

    Two components are built: a source component (``num_classes=0``) and a
    current component sized for ``cfg.CONTINUATION.DATASET_NAME``. Each is
    registered as ``saver.checkpoint_params['model']``, and then
    ``saver.load_checkpoint(is_best=True)`` is called. Validation loaders
    cover both ``cfg.DATASET.NAME`` and the continuation dataset.

    Args:
        cfg: experiment configuration node.
        saver: checkpoint helper; ``checkpoint_params['model']`` ends up
            pointing at the current component's model.
    """
    def _register_and_load_best(component):
        # Point the saver at this component's model, then load the best
        # checkpoint (presumably into that model -- confirm in saver).
        saver.checkpoint_params['model'] = component.model
        saver.load_checkpoint(is_best=True)

    source_component = TrainComponent(cfg, 0)
    _register_and_load_best(source_component)

    names = [cfg.DATASET.NAME, cfg.CONTINUATION.DATASET_NAME]
    train_loader, num_classes = make_train_data_loader(cfg, names[1])

    current_component = TrainComponent(cfg, num_classes)
    _register_and_load_best(current_component)

    valid = make_multi_valid_data_loader(cfg, names)

    do_continuous_train(cfg, train_loader, valid, source_component,
                        current_component, saver)
Beispiel #4
0
def _log_banner(title):
    """Log ``title`` framed by blank lines and a row of 60 asterisks."""
    logger.info("")
    logger.info('*' * 60)
    logger.info(title)
    logger.info('*' * 60)
    logger.info("")


def train(cfg, saver):
    """Train on a new dataset with EBLL-style distillation from a source model.

    Example: train on dukemtmc starting from a market-trained model.

    The run has three phases:

    1. Build the source component, register it with the saver and call
       ``saver.load_checkpoint(is_best=True)``. Then train an autoencoder on
       the source dataset (``cfg.DATASET.NAME``) via ``train_autoencoder``.
    2. Build a component for the new dataset (``cfg.EBLL.DATASET_NAME``),
       register it, load the best checkpoint, and fine-tune it via
       ``fine_tune_current_model``, validating on the new dataset only.
    3. For each autoencoder-loss weight ``0.001 * k`` (k in ``k_s``), copy
       the fine-tuned model into a fresh component and run ``ebll_train``,
       validating on both datasets. The sweep stops after the first weight
       unless ``cfg.CONTINUATION.LOSS_TYPE`` contains ``"ae_dist"``.

    ``saver.best_result`` is reset to 0 before phases 2 and 3.

    Args:
        cfg: experiment configuration node. This function only modifies deep
            copies of it.
        saver: checkpoint helper; its ``checkpoint_params`` entries and
            ``best_result`` are reassigned as the phases progress.
    """
    dataset_name = [cfg.DATASET.NAME, cfg.EBLL.DATASET_NAME]

    # --- Phase 1: source model + autoencoder -------------------------------
    source_train_loader, _ = make_train_data_loader(cfg, dataset_name[0])
    source_valid = make_multi_valid_data_loader(cfg, [dataset_name[0]])

    source_tr = TrainComponent(cfg)
    saver.checkpoint_params['model'] = source_tr.model
    saver.load_checkpoint(is_best=True)

    autoencoder_tr = TrainComponent(cfg, autoencoder=True)
    # Register the autoencoder component's own model under this key, so the
    # network trained below is what the saver tracks as 'autoencoder'.
    saver.checkpoint_params['autoencoder'] = autoencoder_tr.model

    _log_banner("Start training autoencoder")
    train_autoencoder(cfg, source_train_loader, source_valid, source_tr,
                      autoencoder_tr, saver)

    # --- Phase 2: fine-tune a model on the new dataset ---------------------
    saver.best_result = 0

    train_loader, num_classes = make_train_data_loader(cfg, dataset_name[1])
    current_valid = make_multi_valid_data_loader(cfg, [dataset_name[1]])

    current_tr = TrainComponent(cfg, num_classes)
    saver.checkpoint_params['model'] = current_tr.model
    saver.load_checkpoint(is_best=True)

    _log_banner("Start fine tuning current model")
    fine_tune_current_model(cfg, train_loader, current_valid, current_tr,
                            saver)

    # --- Phase 3: EBLL training, sweeping the autoencoder-loss weight ------
    saver.best_result = 0
    # The validation set (both datasets) does not depend on k; build it once.
    ebll_valid = make_multi_valid_data_loader(cfg, dataset_name)
    k_s = [1, 5, 10, 15, 20, 30, 50, 100, 150]
    for k in k_s:
        ae_loss_weight = 0.001 * k
        _log_banner(f"Start ebll training using {ae_loss_weight}")

        copy_cfg = copy.deepcopy(cfg)
        copy_cfg["CONTINUATION"]["IF_ON"] = True
        copy_cfg["EBLL"]["AE_LOSS_WEIGHT"] = ae_loss_weight
        ebll_tr = TrainComponent(copy_cfg, num_classes)
        # Every sweep point starts from the same fine-tuned weights.
        ebll_tr.model = copy.deepcopy(current_tr.model)
        saver.checkpoint_params['model'] = ebll_tr.model

        # NOTE(review): ebll_train receives the original ``cfg``, not
        # ``copy_cfg``. Any EBLL.AE_LOSS_WEIGHT / CONTINUATION.IF_ON it reads
        # from cfg is the unswept value. Confirm the weight only takes effect
        # through ``ebll_tr``.
        ebll_train(cfg, train_loader, ebll_valid, source_tr, ebll_tr,
                   autoencoder_tr, saver)

        # Only sweep further weights when the "ae_dist" loss is configured.
        if "ae_dist" not in cfg.CONTINUATION.LOSS_TYPE:
            break
Beispiel #5
0
    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD))
    def log_validation_results(engine):
        """Log the current epoch, then evaluate on every validation dataset.

        Registered on ``trainer`` to fire every ``cfg.EVAL.EPOCH_PERIOD``
        completed epochs.
        """
        epoch = engine.state.epoch
        logger.info(f"Valid - Epoch: {epoch}")
        eval_multi_dataset(cfg, valid_dict, tr_comp)

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)


if __name__ == '__main__':
    # Configuration and checkpoint saver
    cfg, saver = main()

    # Training data: the primary dataset, optionally joined with extra ones
    dataset_name = [cfg.DATASET.NAME]
    if cfg.JOINT.IF_ON:
        dataset_name.extend(cfg.JOINT.DATASET_NAME)
        train_loader, num_classes = make_train_data_loader_with_expand(
            cfg, dataset_name)
    else:
        train_loader, num_classes = make_train_data_loader(
            cfg, dataset_name[0])

    # Validation data for every dataset collected above
    valid_dict = make_multi_valid_data_loader(cfg, dataset_name)

    # Training components and the trainer engine built from them
    tr_comp = TrainComponent(cfg, num_classes)
    trainer = create_supervised_trainer(tr_comp)

    run(cfg, train_loader, tr_comp, saver, trainer, valid_dict)