def train(cfg):
    """Assemble every training component from *cfg* and launch the loop.

    Builds the logger, model (moved to the configured device), loss,
    optimizer, LR scheduler, train/val data loaders and checkpointer,
    then delegates to ``do_train``.
    """
    logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
    logger.info(cfg)

    # Model on the configured device.
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_model(cfg)
    model.to(device)

    # Loss and optimization machinery.
    criterion = build_loss(cfg)
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # Data pipelines; log both datasets for the run record.
    train_loader = build_data(cfg, is_train=True)
    val_loader = build_data(cfg, is_train=False)
    logger.info(train_loader.dataset)
    logger.info(val_loader.dataset)

    # Mutable training state shared with do_train (resume support).
    arguments = {"iteration": 0}
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger,
    )
def train(cfg):
    """Build every training component from *cfg*, then run the training loop.

    Sets up logging, the model (optionally wrapped in ``DataParallel`` when
    more than one GPU is visible), loss, optimizer, LR scheduler, data
    loaders, a checkpointer and a TensorBoard writer, and hands them all to
    ``do_train``.
    """
    logger = setup_logger(name="Train", level=cfg.LOGGER.LEVEL)
    logger.info(cfg)

    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # BUG FIX: the original tested len(os.environ["CUDA_VISIBLE_DEVICES"]) > 1,
    # which counts *characters*, not devices — "10" (a single GPU id) enabled
    # DataParallel while "2" did not, and an unset variable raised KeyError.
    # Count the comma-separated device ids instead.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    num_gpus = len([d for d in visible.split(",") if d.strip()])
    if num_gpus > 1:
        model = torch.nn.DataParallel(model)

    criterion = build_loss(cfg)
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    train_loader = build_data(cfg, is_train=True)
    val_loader = build_data(cfg, is_train=False)
    logger.info(train_loader.dataset)
    # val_loader is iterated here, so it is presumably a collection of
    # loaders (one per eval split), each exposing .dataset — TODO confirm
    # against build_data(is_train=False).
    for loader in val_loader:
        logger.info(loader.dataset)

    # Mutable training state shared with do_train (resume support).
    arguments = {"iteration": 0}
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    # Checkpoints and TensorBoard logs each go under a per-run subdirectory.
    ckp_save_path = os.path.join(cfg.SAVE_DIR, cfg.NAME)
    os.makedirs(ckp_save_path, exist_ok=True)
    checkpointer = Checkpointer(model, optimizer, scheduler, ckp_save_path)

    tb_save_path = os.path.join(cfg.TB_SAVE_DIR, cfg.NAME)
    os.makedirs(tb_save_path, exist_ok=True)
    writer = SummaryWriter(tb_save_path)

    do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        writer,
        device,
        checkpoint_period,
        arguments,
        logger,
    )
def train(cfg):
    """Set up and run training with separate query / optional gallery eval data.

    Builds the model, loss, optimizer and scheduler from *cfg*, creates the
    train and query loaders (plus a gallery loader when configured), then
    delegates to ``do_train``.
    """
    logger = setup_logger(name="Train", level=cfg.LOGGER.LEVEL)
    logger.info(cfg)

    # Model on the configured device.
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_model(cfg)
    model.to(device)

    # Loss and optimization machinery.
    criterion = build_loss(cfg)
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # Training data plus the evaluation query set.
    train_loader = build_data(cfg, cfg.DATA.TRAIN_IMG_SOURCE, is_train=True)
    query_loader = build_data(cfg, cfg.DATA.TEST_QUERY_IMG_SOURCE, is_train=False)
    logger.info(train_loader.dataset)
    logger.info(query_loader.dataset)

    # The gallery set is optional; leave the loader as None when unconfigured.
    gallery_loader = None
    if cfg.DATA.TEST_GALLERY_IMG_SOURCE:
        gallery_loader = build_data(
            cfg, cfg.DATA.TEST_GALLERY_IMG_SOURCE, is_train=False
        )
        logger.info(gallery_loader.dataset)

    # Mutable training state shared with do_train (resume support).
    arguments = {"iteration": 0}
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(
        cfg,
        model,
        train_loader,
        query_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger,
        gallery_loader=gallery_loader,
    )
def train(cfg):
    """Set up model, losses, optimizers and data loaders, then run training.

    Infers the number of classes from the training labels, mutates ``cfg``
    with it, builds the (possibly composite) loss, the model, optimizers and
    schedulers, the train/trainVal/val loaders, a checkpointer and a
    TensorBoard writer, and hands everything to ``do_train``.
    """
    logger = setup_logger(name="Train", level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    # First build of the training data, used only to discover the label
    # space: num_classes is the highest integer label + 1.
    train_loader = build_data(cfg, is_train=True)
    num_classes = max(set([int(i) for i in train_loader.dataset.label_list])) + 1
    # Mutate cfg so downstream builders can see the inferred class count.
    cfg.num_classes = num_classes
    criterion = build_loss(cfg.LOSSES.NAME, num_classes, cfg)
    # Rebuilt after cfg.num_classes is set — presumably build_data (or its
    # sampler) depends on it; TODO confirm, otherwise this second build is
    # redundant work.
    train_loader = build_data(cfg, is_train=True)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    # Some losses come back as a (criterion, extra optimizer) pair — the
    # extra optimizer trains parameters owned by the loss itself.
    if isinstance(criterion, tuple):
        criterion, optimizer_center = criterion
        # NOTE(review): hard-codes CUDA here, ignoring cfg.MODEL.DEVICE —
        # confirm this is intended for CPU-only configs.
        criterion = criterion.cuda()
        scheduler_center = build_lr_scheduler(cfg, optimizer_center)
    else:
        optimizer_center = None
        scheduler_center = None
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)
    # val_loader is indexed and iterated below, so it is presumably a list
    # of loaders (one per eval split) — TODO confirm.
    val_loader = build_data(cfg, is_train=False)
    trainVal_loader = build_trainVal_data(cfg, val_loader[0].dataset)
    # Optional separate loss for the XBM branch (presumably cross-batch
    # memory — confirm); 'same' means reuse the main criterion.
    if cfg.LOSSES.NAME_XBM_LOSS != 'same':
        criterion_xbm = build_loss(cfg.LOSSES.NAME_XBM_LOSS, num_classes, cfg)
    else:
        criterion_xbm = None
    logger.info(train_loader.dataset)
    logger.info(trainVal_loader.dataset)
    for x in val_loader:
        logger.info(x.dataset)
    # Mutable training state shared with do_train (resume support).
    arguments = dict()
    arguments["iteration"] = 0
    # NOTE(review): checkpoint_period is computed but not passed to do_train
    # below — confirm whether do_train reads it from cfg instead.
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    # Checkpoints and TensorBoard logs each go under a per-run subdirectory.
    ckp_save_path = os.path.join(cfg.SAVE_DIR, cfg.NAME)
    os.makedirs(ckp_save_path, exist_ok=True)
    checkpointer = Checkpointer(model, optimizer, scheduler, ckp_save_path)
    tb_save_path = os.path.join(cfg.TB_SAVE_DIR, cfg.NAME)
    os.makedirs(tb_save_path, exist_ok=True)
    writer = SummaryWriter(tb_save_path)
    do_train(
        cfg,
        model,
        train_loader,
        trainVal_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,
        scheduler_center,
        criterion,
        criterion_xbm,
        checkpointer,
        writer,
        device,
        arguments,
        logger,
    )