예제 #1
0
def do_continuous_train(cfg, train_loader, valid_dict,
                        source_tr_comp: TrainComponent,
                        current_tr_comp: TrainComponent, saver):
    """Build a supervised trainer over both models and start the run.

    The trainer is created from the source component's model plus the current
    component's model, optimizer and loss. The device and apex switch are read
    from ``cfg.MODEL.DEVICE`` and ``cfg.APEX.IF_ON``. The trainer is then
    handed to ``run`` together with ``cfg``, ``train_loader``, ``valid_dict``,
    the current component and ``saver``.
    """
    source_model = source_tr_comp.model
    current_model = current_tr_comp.model
    current_optimizer = current_tr_comp.optimizer
    current_loss = current_tr_comp.loss
    target_device = cfg.MODEL.DEVICE
    use_apex = cfg.APEX.IF_ON

    trainer = create_supervised_trainer(
        source_model, current_model, current_optimizer, current_loss,
        device=target_device, apex=use_apex)

    run(cfg, train_loader, valid_dict, current_tr_comp, saver, trainer)
예제 #2
0
def fine_tune_current_model(cfg, train_loader, valid_dict, tr_comp, saver,
                            *, max_epochs=60):
    """Fine-tune ``tr_comp.model`` with the parameters of its ``base`` frozen.

    Args:
        cfg: Configuration object; ``cfg.MODEL.DEVICE`` and ``cfg.APEX.IF_ON``
            are read here. It is deep-copied before the epoch override, so
            the caller's object is not modified by this function.
        train_loader: Training data loader, forwarded to ``run``.
        valid_dict: Validation data, used to construct the ``Eval`` helper.
        tr_comp: Train component providing ``model``, ``optimizer`` and
            ``loss``.
        saver: Forwarded to ``run``.
        max_epochs: Value written to ``TRAIN.MAX_EPOCHS`` in the copied
            config. Defaults to 60, the value that was previously hard-coded.
    """
    # Disable gradients for every parameter of the model's ``base`` submodule.
    for param in tr_comp.model.base.parameters():
        param.requires_grad = False

    trainer = create_supervised_trainer(tr_comp.model,
                                        tr_comp.optimizer,
                                        tr_comp.loss,
                                        device=cfg.MODEL.DEVICE,
                                        apex=cfg.APEX.IF_ON)

    evaler = Eval(valid_dict, cfg.MODEL.DEVICE)
    evaler.get_valid_eval_map(cfg, tr_comp.model)

    # Apply the epoch override on a deep copy so the caller's cfg is untouched.
    copy_cfg = copy.deepcopy(cfg)
    copy_cfg["TRAIN"]["MAX_EPOCHS"] = max_epochs
    run(copy_cfg, train_loader, tr_comp, saver, trainer, evaler)
예제 #3
0
def train_autoencoder(cfg, train_loader, valid_dict,
                      source_tr_comp: TrainComponent,
                      current_tr_comp: TrainComponent, saver,
                      *, max_epochs=90):
    """Create an autoencoder trainer and run it with an overridden epoch count.

    Args:
        cfg: Configuration object; ``cfg.MODEL.DEVICE`` and ``cfg.APEX.IF_ON``
            are read here. It is deep-copied before the epoch override, so
            the caller's object is not modified by this function.
        train_loader: Training data loader, forwarded to ``run``.
        valid_dict: Validation data, used to construct the ``Eval`` helper.
        source_tr_comp: Component whose ``model`` is passed as the first
            model to the trainer and to the evaluation-map setup.
        current_tr_comp: Component providing the second ``model`` plus the
            ``optimizer`` and ``loss``; also forwarded to ``run``.
        saver: Forwarded to ``run``.
        max_epochs: Value written to ``TRAIN.MAX_EPOCHS`` in the copied
            config. Defaults to 90, the value that was previously hard-coded.
    """
    trainer = create_autoencoder_trainer(source_tr_comp.model,
                                         current_tr_comp.model,
                                         current_tr_comp.optimizer,
                                         current_tr_comp.loss,
                                         device=cfg.MODEL.DEVICE,
                                         apex=cfg.APEX.IF_ON)

    evaler = Eval(valid_dict, cfg.MODEL.DEVICE)
    evaler.get_valid_eval_map_autoencoder(cfg, source_tr_comp.model,
                                          current_tr_comp.model)

    # Apply the epoch override on a deep copy so the caller's cfg is untouched.
    copy_cfg = copy.deepcopy(cfg)
    copy_cfg["TRAIN"]["MAX_EPOCHS"] = max_epochs
    run(copy_cfg, train_loader, current_tr_comp, saver, trainer, evaler)
예제 #4
0
def ebll_train(cfg, train_loader, valid_dict, source_tr_comp, current_tr_comp,
               autoencoder_tr, saver):
    """Create an EBLL trainer for the current component and run it.

    Gradients are enabled on every parameter of the current model's ``base``
    first. The trainer is then built from the source model, the autoencoder
    model and the current component's model, optimizer and loss. An ``Eval``
    helper is prepared from ``valid_dict`` and everything is passed to ``run``
    with the unmodified ``cfg``.
    """
    current_model = current_tr_comp.model

    # Enable gradients for every parameter under the current model's base.
    for base_param in current_model.base.parameters():
        base_param.requires_grad = True

    source_model = source_tr_comp.model
    trainer = create_ebll_trainer(
        source_model,
        autoencoder_tr.model,
        current_model,
        current_tr_comp.optimizer,
        current_tr_comp.loss,
        apex=cfg.APEX.IF_ON,
        device=cfg.MODEL.DEVICE,
    )

    evaler = Eval(valid_dict, cfg.MODEL.DEVICE)
    evaler.get_valid_eval_map_ebll(cfg, source_model, current_model)

    run(cfg, train_loader, current_tr_comp, saver, trainer, evaler)
예제 #5
0
파일: train.py 프로젝트: linxin98/diff
    base_model.eval()
    if diff_attention_model is not None:
        diff_attention_model = diff_attention_model.to(device)
        diff_attention_model.eval()
    logger.info('Get model.')

    # Get data loaders.
    train_loader, query_loader, gallery_loader = loader.get_data_loaders(config, base_model=base_model, device=device)
    logger.info('Get data loaders.')

    # Get loss.
    loss, center_loss = loss.get_loss(config, device)
    logger.info('Get loss.')

    # Get optimizer.
    base_optimizer, base_scheduler, diff_optimizer, diff_scheduler, center_optimizer, center_scheduler = optimizer.get_optimizer(
        config, base_model, diff_attention_model, center_loss=center_loss)
    logger.info('Get optimizer.')

    # Get trainer.
    trainer = trainer.get_trainer(config, base_model, diff_attention_model, loss, device, logger, query_loader,
                                  gallery_loader, base_optimizer=base_optimizer, base_scheduler=base_scheduler,
                                  diff_optimizer=diff_optimizer, diff_scheduler=diff_scheduler,
                                  center_optimizer=center_optimizer, center_scheduler=center_scheduler)
    logger.info('Get trainer.')

    # Do train.
    logger.info('Start training.')
    trainer.run(train_loader, max_epochs=config['trainer'].getint('epochs'))
    logger.info('Finish training.')