Ejemplo n.º 1
0
def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Measures model and train-loader timings and logs them.

    Logs three things: the per-iteration timings ("iter_times"), the same
    timings scaled by the number of batches in the matching loader
    ("epoch_times"), and the data loader overhead relative to the combined
    train forward + backward time.

    Args:
        model: model passed to the eval and train timing helpers.
        loss_fun: loss function passed to the train timing helper.
        train_loader: loader that is timed; its length scales the train timings.
        test_loader: loader whose length scales the test forward timing.
    """
    logger.info("Computing model and loader timings...")
    # Per-iteration measurements
    eval_fw = compute_time_eval(model)
    fw, bw = compute_time_train(model, loss_fun)
    fw_bw = fw + bw
    loader_time = compute_time_loader(train_loader)
    # Report the per-iteration timings
    iter_times = {
        "test_fw_time": eval_fw,
        "train_fw_time": fw,
        "train_bw_time": bw,
        "train_fw_bw_time": fw_bw,
        "train_loader_time": loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Report the per-epoch timings: each per-iteration value times the number
    # of batches in its loader (test loader for the test entry, train otherwise)
    num_test_iters, num_train_iters = len(test_loader), len(train_loader)
    epoch_times = {
        key: per_iter * (num_test_iters if key.startswith("test_") else num_train_iters)
        for key, per_iter in iter_times.items()
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1): the part of
    # the loader time that exceeds fw+bw time, as a fraction of fw+bw time
    exposed_loader_time = max(0, loader_time - fw_bw)
    overhead = exposed_loader_time / fw_bw
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
Ejemplo n.º 2
0
def compute_time_model(model, loss_fun):
    """Measures model timings only and logs them as "iter_times".

    Args:
        model: model passed to the eval and train timing helpers.
        loss_fun: loss function passed to the train timing helper.
    """
    logger.info("Computing model timings only...")
    # Per-iteration measurements
    eval_fw = compute_time_eval(model)
    fw, bw = compute_time_train(model, loss_fun)
    # Report the per-iteration timings, including the combined fw+bw time
    iter_times = {
        "test_fw_time": eval_fw,
        "train_fw_time": fw,
        "train_bw_time": bw,
        "train_fw_bw_time": fw + bw,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
Ejemplo n.º 3
0
def setup_env():
    """Prepares the output dir, logging, RNG seeds and CUDNN for a run.

    Only the master process creates the output directory and dumps the
    config; every process sets up logging, logs the config, seeds NumPy and
    PyTorch with cfg.RNG_SEED, and applies the cfg.CUDNN.BENCHMARK setting.
    """
    if dist.is_master_proc():
        # Create the output dir if needed, then persist the config
        os.makedirs(cfg.OUT_DIR, exist_ok=True)
        config.dump_cfg()
    # Logging must be set up before the first log call below
    logging.setup_logging()
    # Emit the config twice: human readable text, then the json form
    logger.info("Config:\n{}".format(cfg))
    logger.info(logging.dump_log_data(cfg, "cfg"))
    # Seed both RNGs from the same value (see RNG comment in core/config.py)
    seed = cfg.RNG_SEED
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Apply the configured CUDNN benchmark mode
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
Ejemplo n.º 4
0
def setup_model():
    """Builds the model, logs it, and places it on the current GPU.

    Returns:
        The model moved to the current CUDA device. When cfg.NUM_GPUS > 1 it
        is wrapped in DistributedDataParallel, with the wrapped module's
        `complexity` attribute re-exposed on the wrapper.

    Raises:
        AssertionError: if cfg.NUM_GPUS exceeds the number of CUDA devices.
    """
    # Construct the model and log its structure and complexity
    built = builders.build_model()
    logger.info("Model:\n{}".format(built))
    complexity_stats = net.complexity(built)
    logger.info(logging.dump_log_data(complexity_stats, "complexity"))
    # Check the GPU budget before moving the model onto the current device
    assert (
        cfg.NUM_GPUS <= torch.cuda.device_count()
    ), "Cannot use more GPU devices than available"
    device = torch.cuda.current_device()
    built = built.cuda(device=device)
    # Single-GPU setting: no data parallel wrapper needed
    if cfg.NUM_GPUS <= 1:
        return built
    # Multi-GPU setting: one replica per process, bound to the current device
    parallel = torch.nn.parallel.DistributedDataParallel(
        module=built, device_ids=[device], output_device=device
    )
    # Forward the wrapped module's complexity function through the wrapper
    parallel.complexity = parallel.module.complexity
    return parallel
Ejemplo n.º 5
0
 def log_epoch_stats(self, cur_epoch):
     """Logs the stats from get_epoch_stats(cur_epoch) under the "test_epoch" tag."""
     epoch_summary = logging.dump_log_data(
         self.get_epoch_stats(cur_epoch), "test_epoch"
     )
     logger.info(epoch_summary)
Ejemplo n.º 6
0
 def log_iter_stats(self, cur_epoch, cur_iter):
     """Logs iteration stats under the "test_iter" tag once every cfg.LOG_PERIOD iterations.

     Nothing is logged unless (cur_iter + 1) is a multiple of cfg.LOG_PERIOD.
     """
     is_log_iter = (cur_iter + 1) % cfg.LOG_PERIOD == 0
     if is_log_iter:
         iter_summary = logging.dump_log_data(
             self.get_iter_stats(cur_epoch, cur_iter), "test_iter"
         )
         logger.info(iter_summary)