Example #1
@torch.no_grad()  # gradients are not needed during evaluation
def eval_epoch(val_loader, model, epoch, cfg):
    '''Evaluate the model on the val set.

    Args:
      val_loader (loader): data loader to provide validation data.
      model (model): model to evaluate the performance.
      epoch (int): number of the current epoch of training.
      cfg (CfgNode): configs. Details can be found in config/defaults.py
    '''
    if is_master_proc():
        log.info('Testing..')

    model.eval()
    test_loss = 0.0
    correct = total = 0.0
    for batch_idx, (inputs, labels) in enumerate(val_loader):
        inputs, labels = inputs.cuda(non_blocking=True), labels.cuda()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels, reduction='mean')

        # Gather all predictions across all devices.
        if cfg.NUM_GPUS > 1:
            loss = all_reduce([loss])[0]
            outputs, labels = all_gather([outputs, labels])

        # Accuracy.
        batch_correct = topks_correct(outputs, labels, (1, ))[0]
        correct += batch_correct.item()
        total += labels.size(0)

        if is_master_proc():
            test_loss += loss.item()
            test_acc = correct / total
            log.info('Loss: %.3f | Acc: %.3f' % (test_loss /
                                                 (batch_idx + 1), test_acc))
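
A minimal sketch of the topks_correct helper assumed above, under the
assumption that outputs is a (batch, num_classes) score tensor and labels
holds integer class indices (hypothetical implementation, not the
project's own):

import torch

def topks_correct(preds, labels, ks):
    '''Number of correctly classified samples for each top-k in ks.'''
    # Indices of the max(ks) highest-scoring classes, shape (batch, max_k).
    _, top_idx = preds.topk(max(ks), dim=1, largest=True, sorted=True)
    # (batch, max_k) boolean mask: does position j hold the true label?
    hits = top_idx.eq(labels.view(-1, 1))
    # For each k, count rows with at least one hit in the first k columns.
    return [hits[:, :k].any(dim=1).float().sum() for k in ks]

Returning tensors rather than Python floats keeps the .item() call in the
loop above valid.
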
Example #2
def train_epoch(train_loader, model, optimizer, epoch, cfg):
    '''Train the model for one epoch.

    Args:
      train_loader (DataLoader): training data loader.
      model (model): the video model to train.
      optimizer (optim): the optimizer to perform optimization on the model's parameters.
      epoch (int): current epoch of training.
      cfg (CfgNode): configs. Details can be found in config/defaults.py
    '''
    if is_master_proc():
        log.info('Epoch: %d' % epoch)

    model.train()
    num_batches = len(train_loader)
    train_loss = 0.0
    correct = total = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.cuda(non_blocking=True), labels.cuda()

        # Update lr.
        lr = get_epoch_lr(cfg, epoch + float(batch_idx) / num_batches)
        set_lr(optimizer, lr)

        # Forward.
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels, reduction='mean')

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Gather all predictions across all devices.
        if cfg.NUM_GPUS > 1:
            loss = all_reduce([loss])[0]
            outputs, labels = all_gather([outputs, labels])

        # Accuracy.
        batch_correct = topks_correct(outputs, labels, (1, ))[0]
        correct += batch_correct.item()
        total += labels.size(0)

        if is_master_proc():
            train_loss += loss.item()
            train_acc = correct / total
            log.info('Loss: %.3f | Acc: %.3f | LR: %.3f' %
                     (train_loss / (batch_idx + 1), train_acc, lr))
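
Example #2 assumes get_epoch_lr and set_lr helpers. A minimal sketch under
the assumption of a plain cosine schedule driven by cfg.SOLVER.BASE_LR and
cfg.SOLVER.MAX_EPOCH; the schedule is illustrative, while set_lr follows
the standard PyTorch pattern:

import math

def get_epoch_lr(cfg, cur_epoch):
    '''Cosine-decayed learning rate; cur_epoch may be fractional.'''
    return (0.5 * cfg.SOLVER.BASE_LR *
            (1.0 + math.cos(math.pi * cur_epoch / cfg.SOLVER.MAX_EPOCH)))

def set_lr(optimizer, new_lr):
    '''Write new_lr into every parameter group of the optimizer.'''
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
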
Example #3
def save_checkpoint(model, optimizer, epoch, cfg):
    '''Save checkpoint.

    Args:
      model (model): model to save.
      optimizer (optim): optimizer to save.
      epoch (int): current epoch index.
      cfg (CfgNode): configs to save.
    '''
    # Only save on master process.
    if not dist.is_master_proc(cfg.NUM_GPUS):
        return

    state_dict = model.module.state_dict(
    ) if cfg.NUM_GPUS > 1 else model.state_dict()
    for k, v in state_dict.items():
        state_dict[k] = v.cpu()

    checkpoint = {
        'epoch': epoch,
        'model': state_dict,
        'optimizer': optimizer.state_dict(),
        'cfg': cfg.dump(),
    }

    os.makedirs(cfg.TRAIN.CHECKPOINT_DIR, exist_ok=True)
    save_path = os.path.join(cfg.TRAIN.CHECKPOINT_DIR, '%d.pth' % epoch)
    torch.save(checkpoint, save_path)
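
A loading counterpart can be sketched directly from the checkpoint keys
written above; the function name and signature are illustrative, not part
of the original code:

import torch

def load_checkpoint(save_path, model, optimizer=None):
    '''Restore the saved states; returns the stored epoch index.'''
    # Load on CPU first to avoid a transient GPU memory spike.
    checkpoint = torch.load(save_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch']
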
Example #4
def save_checkpoint(path_to_job, model, optimizer, epoch, gs, cfg):
    """
    Save a checkpoint.
    Args:
        path_to_job (string): the path to the folder of the current job.
        model (model): model to save the weight to the checkpoint.
        optimizer (optim): optimizer to save the historical state.
        epoch (int): current number of epoch of the model.
        gs (int): current global step of training.
        cfg (CfgNode): configs to save.
    """
    # Save checkpoints only from the master process.
    if not du.is_master_proc(cfg.NUM_GPUS):
        return
    # Ensure that the checkpoint dir exists.
    os.makedirs(get_checkpoint_dir(path_to_job), exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting.
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state.
    checkpoint = {
        "epoch": epoch,
        "global_step": gs,
        "model_state": sd,
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint.
    path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1)
    torch.save(checkpoint, path_to_checkpoint)
    return path_to_checkpoint
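
The get_checkpoint_dir and get_path_to_checkpoint helpers are not shown. A
plausible sketch: the folder name follows make_checkpoint_dir in Example #8
and the zero-padded file name is inferred from the "{:05d}.pyth" assertion
in Example #13, so treat the exact naming as an assumption:

import os

def get_checkpoint_dir(path_to_job):
    '''Checkpoints live in a "checkpoints" subfolder of the job folder.'''
    return os.path.join(path_to_job, "checkpoints")

def get_path_to_checkpoint(path_to_job, epoch):
    '''Zero-padded, per-epoch checkpoint file name.'''
    name = "checkpoint_epoch_{:05d}.pyth".format(epoch)
    return os.path.join(get_checkpoint_dir(path_to_job), name)
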
Example #5
def test(cfg):
    """
    Perform multi-view testing on the trained video model.
    Args:
        cfg (CfgNode): configs. Details can be found in
        config.py
    """
    # Set random seed from configs.
    if cfg.RNG_SEED != -1:
        random.seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)
        torch.cuda.manual_seed_all(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.NUM_GPUS)

    # Print config.
    logger.info("Test with config:")
    logger.info(pprint.pformat(cfg))

    # Model for testing
    model = build_model(cfg)
    # Print model statistics.
    if du.is_master_proc(cfg.NUM_GPUS):
        misc.log_model_info(model, cfg, use_train_input=False)

    if cfg.TEST.CHECKPOINT_FILE_PATH:
        if os.path.isfile(cfg.TEST.CHECKPOINT_FILE_PATH):
            logger.info("=> loading checkpoint '{}'".format(
                cfg.TEST.CHECKPOINT_FILE_PATH))
            ms = model.module if cfg.NUM_GPUS > 1 else model
            # Load the checkpoint on CPU to avoid GPU mem spike.
            checkpoint = torch.load(cfg.TEST.CHECKPOINT_FILE_PATH,
                                    map_location='cpu')
            ms.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                cfg.TEST.CHECKPOINT_FILE_PATH, checkpoint['epoch']))
    else:
        logger.info("Test with random initialization for debugging")

    # Create video testing loaders
    test_loader = loader.construct_loader(cfg, "test")
    logger.info("Testing model for {} iterations".format(len(test_loader)))

    # Create meters for multi-view testing.
    test_meter = TestMeter(
        cfg.TEST.DATASET_SIZE,
        cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS,
        cfg.MODEL.NUM_CLASSES,
        len(test_loader),
        cfg.DATA.MULTI_LABEL,
        cfg.DATA.ENSEMBLE_METHOD,
        cfg.LOG_PERIOD,
    )

    cudnn.benchmark = True

    # Perform multi-view test on the entire dataset.
    perform_test(test_loader, model, test_meter, cfg)
Example #6
def save_checkpoint(path_to_job, model, optimizer, epoch, cfg):
    """
    Save a checkpoint.
    Args:
        path_to_job (string): the path to the folder of the current job.
        model (model): model to save the weight to the checkpoint.
        optimizer (optim): optimizer to save the historical state.
        epoch (int): current number of epoch of the model.
        cfg (CfgNode): configs to save.
    """
    # Save checkpoints only from the master process.
    if not du.is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS):
        return
    # Ensure that the checkpoint dir exists.
    PathManager.mkdirs(get_checkpoint_dir(path_to_job))
    # Omit the DDP wrapper in the multi-gpu setting.
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    normalized_sd = sub_to_normal_bn(sd)

    # Record the state.
    checkpoint = {
        "epoch": epoch,
        "model_state": normalized_sd,
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint.
    path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1)
    with PathManager.open(path_to_checkpoint, "wb") as f:
        torch.save(checkpoint, f)
    return path_to_checkpoint
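
For symmetry, the checkpoint written through PathManager can be read back
through the same abstraction. This usage snippet reuses path_to_checkpoint
from the example above and assumes the fvcore/iopath PathManager:

# torch.load accepts an open binary file object.
with PathManager.open(path_to_checkpoint, "rb") as f:
    checkpoint = torch.load(f, map_location="cpu")
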
Example #7
def setup_logging(output_dir=None):
    """
    Sets up the logging for multiple processes. Only enable the logging for the
    master process, and suppress logging for the non-master processes.
    """
    if du.is_master_proc():
        # Clear any pre-existing handlers on the root logger; the stream and
        # file handlers added below provide the output, so a basicConfig call
        # here would only duplicate every message.
        logging.root.handlers = []
    else:
        # Suppress logging for non-master processes.
        _suppress_print()

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    plain_formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )

    if du.is_master_proc():
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    if output_dir is not None and du.is_master_proc(du.get_world_size()):
        filename = os.path.join(output_dir, "stdout.log")
        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
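
The two private helpers called above are not shown. Minimal sketches of the
usual pattern (hypothetical implementations): _suppress_print silences
worker output by replacing the print builtin, and _cached_log_stream
memoizes the append-mode file handle so repeated calls share one stream:

import builtins
import functools

def _suppress_print():
    '''Replace the built-in print with a no-op on non-master workers.'''
    def print_none(*objects, sep=" ", end="\n", file=None, flush=False):
        pass
    builtins.print = print_none

@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    '''Open the log file once, in append mode, and reuse the stream.'''
    return open(filename, "a")
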
Example #8
def make_checkpoint_dir(path_to_job):
    """
    Creates the checkpoint directory (if not present already).
    Args:
        path_to_job (string): the path to the folder of the current job.
    """
    checkpoint_dir = os.path.join(path_to_job, "checkpoints")
    # Create the checkpoint dir from the master process; exist_ok makes the
    # call safe if another process created it first.
    if du.is_master_proc():
        os.makedirs(checkpoint_dir, exist_ok=True)
    return checkpoint_dir
Example #9
def build_trainer(cfg):
    """
    Build training model and its associated tools, including optimizer,
    dataloaders and meters.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    Returns:
        model (nn.Module): training model.
        optimizer (Optimizer): optimizer.
        train_loader (DataLoader): training data loader.
        val_loader (DataLoader): validation data loader.
        precise_bn_loader (DataLoader): training data loader for computing
            precise BN.
        train_meter (TrainMeter): tool for measuring training stats.
        val_meter (ValMeter): tool for measuring validation stats.
    """
    # Build the video model and print model statistics.
    model = build_model(cfg)
    if du.is_master_proc() and cfg.LOG_MODEL_INFO:
        misc.log_model_info(model, cfg, is_train=True)

    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)

    # Create the video train and val loaders.
    train_loader = loader.construct_loader(cfg, "train")
    val_loader = loader.construct_loader(cfg, "val")
    precise_bn_loader = loader.construct_loader(
        cfg, "train", is_precise_bn=True
    )
    # Create meters.
    train_meter = TrainMeter(len(train_loader), cfg)
    val_meter = ValMeter(len(val_loader), cfg)

    return (
        model,
        optimizer,
        train_loader,
        val_loader,
        precise_bn_loader,
        train_meter,
        val_meter,
    )
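
The returned tuple is meant to be unpacked in a single assignment, as the
multigrid rebuild in Example #13 does:

(model, optimizer, train_loader, val_loader,
 precise_bn_loader, train_meter, val_meter) = build_trainer(cfg)
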
Example #10
def test(cfg):
    """
    Test a model
    """
    logging.setup_logging(logger, cfg)

    logger.info("Test with config:")
    logger.info(pprint.pformat(cfg))

    model = model_builder.build_model(cfg)
    if du.is_master_proc():
        misc.log_model_info(model)

    if cfg.TEST.CHECKPOINT_FILE_PATH != "":
        logger.info("Load from given checkpoint file.")
        gs, checkpoint_epoch = cu.load_checkpoint(
            cfg.TEST.CHECKPOINT_FILE_PATH,
            model,
            cfg.NUM_GPUS > 1,
            optimizer=None,
            inflation=False,
            convert_from_caffe2=False)
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint(cfg.OUTPUT_DIR):
        logger.info("Load from last checkpoint.")
        last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR)
        gs, checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model,
                                                  cfg.NUM_GPUS > 1, None)
        start_epoch = checkpoint_epoch + 1

    # Create the video test loader.
    test_loader = loader.construct_loader(cfg, "test")

    test_meter = TestMeter(cfg)

    if cfg.TEST.AUGMENT_TEST:
        evaluate_with_augmentation(test_loader, model, test_meter, cfg)
    else:
        evaluate(test_loader, model, test_meter, cfg)
Example #11
def test(cfg):
    """
    Perform multi-view testing on the pretrained video model.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    """
    # Set up environment.
    du.init_distributed_training(cfg)
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    # Print config.
    logger.info("Test with config:")
    logger.info(cfg)

    # Build the video model and print model statistics.
    model = build_model(cfg)
    if du.is_master_proc() and cfg.LOG_MODEL_INFO:
        misc.log_model_info(model, cfg, is_train=False)

    # Load a checkpoint to test if applicable.
    if cfg.TEST.CHECKPOINT_FILE_PATH != "":
        cu.load_checkpoint(
            cfg.TEST.CHECKPOINT_FILE_PATH,
            model,
            cfg.NUM_GPUS > 1,
            None,
            inflation=False,
            convert_from_caffe2=cfg.TEST.CHECKPOINT_TYPE == "caffe2",
        )
    elif cu.has_checkpoint(cfg.OUTPUT_DIR):
        last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR)
        cu.load_checkpoint(last_checkpoint, model, cfg.NUM_GPUS > 1)
    elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "":
        # If no checkpoint is found in TEST.CHECKPOINT_FILE_PATH or in the
        # current checkpoint folder, try to load the checkpoint from
        # TRAIN.CHECKPOINT_FILE_PATH and test it.
        cu.load_checkpoint(
            cfg.TRAIN.CHECKPOINT_FILE_PATH,
            model,
            cfg.NUM_GPUS > 1,
            None,
            inflation=False,
            convert_from_caffe2=cfg.TRAIN.CHECKPOINT_TYPE == "caffe2",
        )
    else:
        # raise NotImplementedError("Unknown way to load checkpoint.")
        logger.info("Testing with random initialization. Only for debugging.")

    # Create video testing loaders.
    test_loader = loader.construct_loader(cfg, "test")
    logger.info("Testing model for {} iterations".format(len(test_loader)))

    if cfg.DETECTION.ENABLE:
        assert cfg.NUM_GPUS == cfg.TEST.BATCH_SIZE
        test_meter = AVAMeter(len(test_loader), cfg, mode="test")
    else:
        assert (
            len(test_loader.dataset) %
            (cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS) == 0)
        # Create meters for multi-view testing.
        test_meter = TestMeter(
            len(test_loader.dataset) //
            (cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS),
            cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS,
            cfg.MODEL.NUM_CLASSES,
            len(test_loader),
            cfg.DATA.MULTI_LABEL,
            cfg.DATA.ENSEMBLE_METHOD,
        )

    # Set up writer for logging to Tensorboard format.
    if cfg.TENSORBOARD.ENABLE and du.is_master_proc(
            cfg.NUM_GPUS * cfg.NUM_SHARDS):
        writer = tb.TensorboardWriter(cfg)
    else:
        writer = None

    # Perform multi-view test on the entire dataset.
    perform_test(test_loader, model, test_meter, cfg, writer)
    if writer is not None:
        writer.close()
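
The first TestMeter argument is the number of unique videos, which is why
the dataset size must divide evenly by the number of clips per video. A
quick worked check with hypothetical numbers:

num_clips_per_video = 10 * 3   # NUM_ENSEMBLE_VIEWS * NUM_SPATIAL_CROPS
num_samples = 1500             # len(test_loader.dataset), assumed
assert num_samples % num_clips_per_video == 0
num_videos = num_samples // num_clips_per_video   # -> 50 unique videos
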
Example #12
def train(cfg):
    """
    Train a video model for many epochs on train set and evaluate it on val set.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    """
    # Setup logging format.
    logging.setup_logging(logger, cfg)

    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Build the video model and print model statistics.
    model = model_builder.build_model(cfg)
    if du.is_master_proc():
        misc.log_model_info(model)

    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)

    # Record global step
    gs = 0

    # Load a checkpoint to resume training if applicable.
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint(cfg.OUTPUT_DIR):
        logger.info("Load from last checkpoint.")
        last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR)
        gs, checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model,
                                                  cfg.NUM_GPUS > 1, optimizer)
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "":
        logger.info("Load from given checkpoint file.")
        if cfg.TRAIN.LOAD_PART_OF_CHECKPOINT:
            gs, checkpoint_epoch = cu.load_part_of_checkpoint(
                cfg.TRAIN.CHECKPOINT_FILE_PATH,
                model,
                cfg.NUM_GPUS > 1,
                optimizer=None)
        else:
            gs, checkpoint_epoch = cu.load_checkpoint(
                cfg.TRAIN.CHECKPOINT_FILE_PATH,
                model,
                cfg.NUM_GPUS > 1,
                optimizer=None,
                inflation=False,
                convert_from_caffe2=False)
        start_epoch = checkpoint_epoch + 1
    else:
        start_epoch = 0

    # Create the video train and val loaders.
    train_loader = loader.construct_loader(cfg, "train")
    val_loader = loader.construct_loader(cfg, "val")

    # Create meters.
    train_meter = TrainMeter(len(train_loader), cfg)
    val_meter = ValMeter(cfg)

    # Perform the training loop.
    logger.info("Start epoch: {} gs {}".format(start_epoch + 1, gs + 1))

    for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
        # Shuffle the dataset.
        loader.shuffle_dataset(train_loader, cur_epoch)

        # Evaluate the model on validation set.
        if misc.is_eval_epoch(cfg, cur_epoch):
            if cfg.TRAIN.USE_CENTER_VALIDATION:
                validation_epoch_center(val_loader, model, val_meter,
                                        cur_epoch, cfg)
            else:
                validation_epoch(val_loader, model, val_meter, cur_epoch, cfg)
        # Train for one epoch.
        gs = train_epoch(train_loader, model, optimizer, train_meter,
                         cur_epoch, gs, cfg)

        # Compute precise BN stats.
        # if cfg.BN.USE_PRECISE_STATS and len(get_bn_modules(model)) > 0:
        #     calculate_and_update_precise_bn(
        #         train_loader, model, cfg.BN.NUM_BATCHES_PRECISE
        #     )
        # Save a checkpoint.
        if cu.is_checkpoint_epoch(cur_epoch, cfg.TRAIN.CHECKPOINT_PERIOD):
            cu.save_checkpoint(cfg.OUTPUT_DIR, model, optimizer, cur_epoch, gs,
                               cfg)
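
In this version is_checkpoint_epoch takes the raw checkpoint period. A
minimal sketch consistent with the call above (hypothetical
implementation):

def is_checkpoint_epoch(cur_epoch, checkpoint_period):
    '''True on every checkpoint_period-th epoch (epochs are 0-indexed).'''
    return (cur_epoch + 1) % checkpoint_period == 0
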
Example #13
def train(cfg):
    """
    Train a video model for many epochs on train set and evaluate it on val set.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    """
    # Set up environment.
    du.init_distributed_training(cfg)
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    # Init multigrid.
    multigrid = None
    if cfg.MULTIGRID.LONG_CYCLE or cfg.MULTIGRID.SHORT_CYCLE:
        multigrid = MultigridSchedule()
        cfg = multigrid.init_multigrid(cfg)
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, _ = multigrid.update_long_cycle(cfg, cur_epoch=0)
    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Build the video model and print model statistics.
    model = build_model(cfg)
    if du.is_master_proc() and cfg.LOG_MODEL_INFO:
        misc.log_model_info(model, cfg, is_train=True)

    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)

    # Load a checkpoint to resume training if applicable.
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint(cfg.OUTPUT_DIR):
        last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR)
        logger.info("Load from last checkpoint, {}.".format(last_checkpoint))
        checkpoint_epoch = cu.load_checkpoint(
            last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer
        )
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "":
        logger.info("Load from given checkpoint file.")
        checkpoint_epoch = cu.load_checkpoint(
            cfg.TRAIN.CHECKPOINT_FILE_PATH,
            model,
            cfg.NUM_GPUS > 1,
            optimizer,
            inflation=cfg.TRAIN.CHECKPOINT_INFLATE,
            convert_from_caffe2=cfg.TRAIN.CHECKPOINT_TYPE == "caffe2",
        )
        start_epoch = checkpoint_epoch + 1
    else:
        start_epoch = 0

    # Create the video train and val loaders.
    train_loader = loader.construct_loader(cfg, "train")
    val_loader = loader.construct_loader(cfg, "val")
    precise_bn_loader = loader.construct_loader(
        cfg, "train", is_precise_bn=True
    )

    # Create meters.
    if cfg.DETECTION.ENABLE:
        train_meter = AVAMeter(len(train_loader), cfg, mode="train")
        val_meter = AVAMeter(len(val_loader), cfg, mode="val")
    else:
        train_meter = TrainMeter(len(train_loader), cfg)
        val_meter = ValMeter(len(val_loader), cfg)

    # Set up writer for logging to Tensorboard format.
    if cfg.TENSORBOARD.ENABLE and du.is_master_proc(
        cfg.NUM_GPUS * cfg.NUM_SHARDS
    ):
        writer = tb.TensorboardWriter(cfg)
    else:
        writer = None

    # Perform the training loop.
    logger.info("Start epoch: {}".format(start_epoch + 1))

    for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, changed = multigrid.update_long_cycle(cfg, cur_epoch)
            if changed:
                (
                    model,
                    optimizer,
                    train_loader,
                    val_loader,
                    precise_bn_loader,
                    train_meter,
                    val_meter,
                ) = build_trainer(cfg)

                # Load checkpoint.
                if cu.has_checkpoint(cfg.OUTPUT_DIR):
                    last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR)
                    assert "{:05d}.pyth".format(cur_epoch) in last_checkpoint
                else:
                    last_checkpoint = cfg.TRAIN.CHECKPOINT_FILE_PATH
                logger.info("Load from {}".format(last_checkpoint))
                cu.load_checkpoint(
                    last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer
                )

        # Shuffle the dataset.
        loader.shuffle_dataset(train_loader, cur_epoch)
        # Train for one epoch.
        train_epoch(
            train_loader, model, optimizer, train_meter, cur_epoch, cfg, writer
        )

        # Compute precise BN stats.
        if cfg.BN.USE_PRECISE_STATS and len(get_bn_modules(model)) > 0:
            calculate_and_update_precise_bn(
                precise_bn_loader,
                model,
                min(cfg.BN.NUM_BATCHES_PRECISE, len(precise_bn_loader)),
            )
        _ = misc.aggregate_sub_bn_stats(model)

        # Save a checkpoint.
        if cu.is_checkpoint_epoch(
            cfg, cur_epoch, None if multigrid is None else multigrid.schedule
        ):
            cu.save_checkpoint(cfg.OUTPUT_DIR, model,
                               optimizer, cur_epoch, cfg)
        # Evaluate the model on validation set.
        if misc.is_eval_epoch(
            cfg, cur_epoch, None if multigrid is None else multigrid.schedule
        ):
            eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer)

    if writer is not None:
        writer.close()
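
loader.shuffle_dataset gives each epoch a fresh sample ordering under
distributed training. A minimal sketch of the standard DistributedSampler
pattern (hypothetical implementation):

from torch.utils.data.distributed import DistributedSampler

def shuffle_dataset(loader, cur_epoch):
    '''Re-seed the sampler so every epoch shuffles differently.'''
    sampler = loader.sampler
    if isinstance(sampler, DistributedSampler):
        # set_epoch changes the deterministic shuffle seed per epoch.
        sampler.set_epoch(cur_epoch)
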
Example #14
def train(cfg):
    """
    Train function.
    Args:
        cfg (CfgNode) : configs. Details can be found in
            config.py
    """
    # Set random seed from configs.
    if cfg.RNG_SEED != -1:
        random.seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)
        torch.cuda.manual_seed_all(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.NUM_GPUS, os.path.join(cfg.LOG_DIR, "log.txt"))

    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Model for training.
    model = build_model(cfg)
    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)

    # Print model statistics.
    if du.is_master_proc(cfg.NUM_GPUS):
        misc.log_model_info(model, cfg, use_train_input=True)

    # Create dataloaders.
    train_loader = loader.construct_loader(cfg, 'train')
    val_loader = loader.construct_loader(cfg, 'val')

    if cfg.SOLVER.MAX_EPOCH != -1:
        max_epoch = cfg.SOLVER.MAX_EPOCH * cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS
        num_steps = max_epoch * len(train_loader)
        cfg.SOLVER.NUM_STEPS = cfg.SOLVER.MAX_EPOCH * len(train_loader)
        cfg.SOLVER.WARMUP_PROPORTION = cfg.SOLVER.WARMUP_EPOCHS / cfg.SOLVER.MAX_EPOCH
    else:
        num_steps = cfg.SOLVER.NUM_STEPS * cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS
        max_epoch = math.ceil(num_steps / len(train_loader))
        cfg.SOLVER.MAX_EPOCH = cfg.SOLVER.NUM_STEPS / len(train_loader)
        cfg.SOLVER.WARMUP_EPOCHS = cfg.SOLVER.MAX_EPOCH * cfg.SOLVER.WARMUP_PROPORTION

    start_epoch = 0
    global_step = 0
    if cfg.TRAIN.CHECKPOINT_FILE_PATH:
        if os.path.isfile(cfg.TRAIN.CHECKPOINT_FILE_PATH):
            logger.info(
                "=> loading checkpoint '{}'".format(
                    cfg.TRAIN.CHECKPOINT_FILE_PATH
                )
            )
            ms = model.module if cfg.NUM_GPUS > 1 else model
            # Load the checkpoint on CPU to avoid GPU mem spike.
            checkpoint = torch.load(
                cfg.TRAIN.CHECKPOINT_FILE_PATH, map_location='cpu'
            )
            start_epoch = checkpoint['epoch']
            ms.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            global_step = checkpoint['epoch'] * len(train_loader)
            logger.info(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    cfg.TRAIN.CHECKPOINT_FILE_PATH,
                    checkpoint['epoch']
                )
            )
    else:
        logger.info("Training with random initialization.")

    # Create meters.
    train_meter = TrainMeter(
        len(train_loader),
        num_steps,
        max_epoch,
        cfg
    )
    val_meter = ValMeter(
        len(val_loader),
        max_epoch,
        cfg
    )

    # Perform the training loop.
    logger.info("Start epoch: {}".format(start_epoch+1))

    cudnn.benchmark = True

    best_epoch, best_top1_err, top5_err, best_map = 0, 100.0, 100.0, 0.0

    for cur_epoch in range(start_epoch, max_epoch):
        is_best_epoch = False
        # Shuffle the dataset.
        # loader.shuffle_dataset(train_loader, cur_epoch)
        # Train for one epoch.
        global_step = train_epoch(
            train_loader,
            model,
            optimizer,
            train_meter,
            cur_epoch,
            global_step,
            num_steps,
            cfg
        )

        if cfg.BN.USE_PRECISE_STATS and len(get_bn_modules(model)) > 0:
            calculate_and_update_precise_bn(
                train_loader, model, cfg.BN.NUM_BATCHES_PRECISE
            )

        if misc.is_eval_epoch(cfg, cur_epoch, max_epoch):
            stats = eval_epoch(val_loader, model, val_meter, cur_epoch, cfg)
            if cfg.DATA.MULTI_LABEL:
                if best_map < float(stats["map"]):
                    best_epoch = cur_epoch + 1
                    best_map = float(stats["map"])
                    is_best_epoch = True
                logger.info(
                    "BEST: epoch: {}, best_map: {:.2f}".format(
                        best_epoch, best_map,
                    )
                )
            else:
                if best_top1_err > float(stats["top1_err"]):
                    best_epoch = cur_epoch + 1
                    best_top1_err = float(stats["top1_err"])
                    top5_err = float(stats["top5_err"])
                    is_best_epoch = True
                logger.info(
                    "BEST: epoch: {}, best_top1_err: {:.2f}, top5_err: {:.2f}".format(
                        best_epoch, best_top1_err, top5_err
                    )
                )

        sd = (model.module.state_dict() if cfg.NUM_GPUS > 1
              else model.state_dict())

        ckpt = {
            'epoch': cur_epoch + 1,
            'model_arch': cfg.MODEL.DOWNSTREAM_ARCH,
            'state_dict': sd,
            'optimizer': optimizer.state_dict(),
        }

        if (cur_epoch + 1) % cfg.SAVE_EVERY_EPOCH == 0 and du.get_rank() == 0:
            save_checkpoint(
                ckpt,
                filename=os.path.join(cfg.SAVE_DIR, f'epoch{cur_epoch+1}.pyth')
            )

        if is_best_epoch and du.get_rank() == 0:
            save_checkpoint(
                ckpt,
                filename=os.path.join(cfg.SAVE_DIR, "epoch_best.pyth")
            )
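
The save_checkpoint(ckpt, filename) helper used above is not shown; a
minimal sketch, assuming it simply serializes the dict (the real helper
may add atomic writes or a best-copy step):

import torch

def save_checkpoint(state, filename):
    '''Serialize the checkpoint dict to disk.'''
    torch.save(state, filename)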