Example #1
def train(cfg):
    logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    criterion = build_loss(cfg)

    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    train_loader = build_data(cfg, is_train=True)
    val_loader = build_data(cfg, is_train=False)

    logger.info(train_loader.dataset)
    logger.info(val_loader.dataset)

    arguments = dict()
    arguments["iteration"] = 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             criterion, checkpointer, device, checkpoint_period, arguments,
             logger)
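
This train entry point (like the variants below) assumes a fully assembled, yacs-style cfg node. As a minimal sketch of a driver that builds such a cfg and calls it, assuming a hypothetical get_cfg_defaults() that returns the project's default CfgNode (everything else is argparse plus the standard yacs API):

import argparse

def main():
    parser = argparse.ArgumentParser(description="Run training from a config file")
    parser.add_argument("--config-file", default="", help="path to a YAML config")
    parser.add_argument("opts", nargs=argparse.REMAINDER,
                        help="'KEY VALUE' pairs overriding config entries")
    args = parser.parse_args()

    cfg = get_cfg_defaults()  # hypothetical: the project's default CfgNode
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)  # apply command-line overrides

    train(cfg)

if __name__ == "__main__":
    main()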
Example #2
def train(cfg):
    logger = setup_logger(name="Train", level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    # Wrap with DataParallel only when more than one GPU is visible; checking
    # len() of the raw string miscounts (e.g. "0" has length 1, "0,1" length 3)
    # and raises KeyError when the variable is unset.
    if len(os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")) > 1:
        model = torch.nn.DataParallel(model)

    criterion = build_loss(cfg)

    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    train_loader = build_data(cfg, is_train=True)
    val_loader = build_data(cfg, is_train=False)

    logger.info(train_loader.dataset)
    # val_loader here is a collection of loaders; log each underlying dataset
    for loader in val_loader:
        logger.info(loader.dataset)

    arguments = dict()
    arguments["iteration"] = 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    ckp_save_path = os.path.join(cfg.SAVE_DIR, cfg.NAME)

    os.makedirs(ckp_save_path, exist_ok=True)
    checkpointer = Checkpointer(model, optimizer, scheduler, ckp_save_path)

    tb_save_path = os.path.join(cfg.TB_SAVE_DIR, cfg.NAME)
    os.makedirs(tb_save_path, exist_ok=True)
    writer = SummaryWriter(tb_save_path)

    do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        writer,
        device,
        checkpoint_period,
        arguments,
        logger,
    )
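
One caveat with the DataParallel wrap above: the wrapped model's state_dict() prefixes every key with "module.", so checkpoints written by Checkpointer will not load into an unwrapped model. A small helper (hypothetical, not part of this project) that keeps checkpoint keys portable:

import torch

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # DataParallel/DistributedDataParallel keep the real network under .module
    return model.module if hasattr(model, "module") else model

# Assuming Checkpointer calls model.state_dict() internally, passing
# unwrap_model(model) instead of the wrapped model avoids the "module." prefix.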
Example #3
def train(cfg):
    logger = setup_logger(name="Train", level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    criterion = build_loss(cfg)

    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    train_loader = build_data(cfg, cfg.DATA.TRAIN_IMG_SOURCE, is_train=True)
    query_loader = build_data(cfg,
                              cfg.DATA.TEST_QUERY_IMG_SOURCE,
                              is_train=False)
    logger.info(train_loader.dataset)
    logger.info(query_loader.dataset)
    gallery_loader = None
    if cfg.DATA.TEST_GALLERY_IMG_SOURCE:
        gallery_loader = build_data(cfg,
                                    cfg.DATA.TEST_GALLERY_IMG_SOURCE,
                                    is_train=False)
        logger.info(gallery_loader.dataset)

    arguments = dict()
    arguments["iteration"] = 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(
        cfg,
        model,
        train_loader,
        query_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger,
        gallery_loader=gallery_loader,
    )
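
Example #3 follows the usual metric-learning evaluation protocol: a query set is ranked against a gallery set, falling back to the query set itself when no gallery source is configured. A minimal sketch of that ranking step with recall@1 as the metric; embedding extraction is assumed to happen elsewhere, and when query and gallery are the same set the similarity diagonal must be masked to exclude trivial self-matches:

import torch
import torch.nn.functional as F

@torch.no_grad()
def recall_at_1(query_emb, query_labels, gallery_emb, gallery_labels,
                mask_diagonal=False):
    # Cosine similarity between every query and every gallery embedding
    q = F.normalize(query_emb, dim=1)
    g = F.normalize(gallery_emb, dim=1)
    sim = q @ g.t()  # (num_query, num_gallery)
    if mask_diagonal:
        # Query set doubles as the gallery: forbid self-matches
        sim.fill_diagonal_(float("-inf"))
    nearest = sim.argmax(dim=1)  # best gallery match per query
    return (gallery_labels[nearest] == query_labels).float().mean().item()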
Example #4
def train(cfg):
    logger = setup_logger(name="Train", level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    train_loader = build_data(cfg, is_train=True)
    # Infer the number of classes from the training labels (labels are 0-indexed)
    num_classes = max(int(i) for i in train_loader.dataset.label_list) + 1
    cfg.num_classes = num_classes
    criterion = build_loss(cfg.LOSSES.NAME, num_classes, cfg)
    # Rebuild the train loader so it reflects the updated cfg
    train_loader = build_data(cfg, is_train=True)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # Some losses (e.g. center loss) are returned together with a dedicated
    # optimizer for their own learnable parameters
    if isinstance(criterion, tuple):
        criterion, optimizer_center = criterion
        criterion = criterion.to(device)
        scheduler_center = build_lr_scheduler(cfg, optimizer_center)
    else:
        optimizer_center = None
        scheduler_center = None
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    val_loader = build_data(cfg, is_train=False)

    trainVal_loader = build_trainVal_data(cfg, val_loader[0].dataset)

    # XBM (cross-batch memory) can use its own loss; 'same' reuses the main criterion
    if cfg.LOSSES.NAME_XBM_LOSS != 'same':
        criterion_xbm = build_loss(cfg.LOSSES.NAME_XBM_LOSS, num_classes, cfg)
    else:
        criterion_xbm = None

    logger.info(train_loader.dataset)
    logger.info(trainVal_loader.dataset)
    # val_loader is a list of loaders (see val_loader[0] above)
    for loader in val_loader:
        logger.info(loader.dataset)

    arguments = dict()
    arguments["iteration"] = 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    ckp_save_path = os.path.join(cfg.SAVE_DIR, cfg.NAME)
    os.makedirs(ckp_save_path, exist_ok=True)
    checkpointer = Checkpointer(model, optimizer, scheduler, ckp_save_path)

    tb_save_path = os.path.join(cfg.TB_SAVE_DIR, cfg.NAME)
    os.makedirs(tb_save_path, exist_ok=True)
    writer = SummaryWriter(tb_save_path)

    do_train(
        cfg,
        model,
        train_loader,
        trainVal_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,
        scheduler_center,
        criterion,
        criterion_xbm,
        checkpointer,
        writer,
        device,
        arguments,
        logger,
    )
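
The isinstance(criterion, tuple) branch in this last example suggests a center-loss style setup: the loss module owns learnable per-class centers, which need an optimizer (and LR scheduler) separate from the model's. A sketch of what such a build_loss branch could return; the CenterLoss module below is illustrative, not this project's actual implementation:

import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    # Illustrative center loss: learnable per-class centers pulled toward features
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        # Mean squared distance between each feature and its class center
        return ((features - self.centers[labels]) ** 2).sum(dim=1).mean()

def build_center_loss(num_classes, feat_dim, lr=0.5):
    criterion = CenterLoss(num_classes, feat_dim)
    # The centers belong to the loss, not the model, hence a dedicated optimizer
    optimizer_center = torch.optim.SGD(criterion.parameters(), lr=lr)
    return criterion, optimizer_center  # matches the tuple branch above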