def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader."""
    logger.info("Computing model and loader timings...")
    # Measure per-iteration times: eval forward, train forward/backward, loader
    fw_eval = compute_time_eval(model)
    fw, bw = compute_time_train(model, loss_fun)
    fw_bw = fw + bw
    data_time = compute_time_loader(train_loader)
    # Log the per-iteration timings
    iter_times = {
        "_type": "iter_time",
        "test_fw_time": fw_eval,
        "train_fw_time": fw,
        "train_bw_time": bw,
        "train_fw_bw_time": fw_bw,
        "train_loader_time": data_time,
    }
    logger.info(logging.dump_json_stats(iter_times))
    # Scale per-iteration timings up to per-epoch timings
    num_train, num_test = len(train_loader), len(test_loader)
    epoch_times = {
        "_type": "epoch_time",
        "test_fw_time": fw_eval * num_test,
        "train_fw_time": fw * num_train,
        "train_bw_time": bw * num_train,
        "train_fw_bw_time": fw_bw * num_train,
        "train_loader_time": data_time * num_train,
    }
    logger.info(logging.dump_json_stats(epoch_times))
    # Loader overhead: extra time spent waiting on data relative to compute
    # (assuming DATA_LOADER.NUM_WORKERS>1)
    overhead = max(0, data_time - fw_bw) / fw_bw
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
def time_model():
    """Times a model.

    NOTE(review): the previous body called ``net.compute_time_full(model,
    loss_fun)``, but ``compute_time_full`` takes the train and test loaders
    as well (see its definition in this file), so the call raised a
    TypeError; it also logs its own stats and returns None, so dumping its
    return value logged None. Fixed by constructing the loaders, passing
    them through, and letting the helper do its own logging.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model and loss_fun
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    # Construct the data loaders required by the timing helper
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    # Compute precise time (compute_time_full logs its own stats)
    logger.info("Computing precise time...")
    net.compute_time_full(model, loss_fun, train_loader, test_loader)
def time_model():
    """Times a model."""
    assert cfg.PREC_TIME.ENABLED, "PREC_TIME.ENABLED must be set."
    # Prepare the environment, then build the model and its loss function
    setup_env()
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    # Run the precise timing pass and log the resulting stats
    logger.info("Computing precise time...")
    timings = net.compute_precise_time(model, loss_fun)
    logger.info(logging.dump_json_stats(timings))
    # Timing pushes data through the model; reset BN running stats afterwards
    net.reset_bn_stats(model)
def test_model():
    """Evaluates the model."""
    # Configure logging first so everything below is captured
    logging.setup_logging()
    logger.info("Config:\n{}".format(cfg))
    # Seed the RNGs (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN autotuner
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
    # Build the model before the loaders to speed up debugging, then log it
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    logger.info(logging.dump_json_stats(net.complexity(model)))
    # Optionally run precise timing, then restore clean BN statistics
    if cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        loss_fun = builders.build_loss_fun()
        timings = net.compute_precise_time(model, loss_fun)
        logger.info(logging.dump_json_stats(timings))
        net.reset_bn_stats(model)
    # Restore the weights to be evaluated
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Build the test loader and meter, then run a single evaluation pass
    data_loader = loader.construct_test_loader()
    meter = meters.TestMeter(len(data_loader))
    test_epoch(data_loader, model, meter, 0)
def train_model():
    """Trains the model."""
    # Prepare the environment, then construct model, loss, and optimizer
    setup_env()
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Resume from the last checkpoint if possible, else load initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        resumed_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = resumed_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Precise timing only applies to fresh (non-resumed) runs
    if start_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        timings = net.compute_precise_time(model, loss_fun)
        logger.info(logging.dump_json_stats(timings))
        net.reset_bn_stats(model)
    # Build the data loaders and their matching meters
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Main training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, epoch)
        # Recompute precise BN statistics if requested
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Periodic checkpointing
        if (epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Periodic evaluation, plus a final evaluation at the last epoch
        done = epoch + 1
        if done % cfg.TRAIN.EVAL_PERIOD == 0 or done == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, epoch)
def setup_model():
    """Sets up a model for training or testing and log the results."""
    # Build the model and log it along with its complexity stats
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    logger.info(logging.dump_json_stats(net.complexity(model)))
    # Move the model onto the current GPU device
    err_str = "Cannot use more GPU devices than available"
    assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
    device = torch.cuda.current_device()
    model = model.cuda(device=device)
    # Wrap in DDP for multi-GPU runs, keeping the complexity fn reachable
    if cfg.NUM_GPUS > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[device], output_device=device
        )
        model.complexity = model.module.complexity
    return model
def train_model():
    """Trains the model."""
    # Configure logging first so everything below is captured
    logging.setup_logging()
    logger.info("Config:\n{}".format(cfg))
    # Seed the RNGs (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN autotuner
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
    # Build the model before the loaders to speed up debugging, then log it
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    logger.info(logging.dump_json_stats(net.complexity(model)))
    # Loss function and optimizer
    loss_fun = builders.build_loss_fun()
    optimizer = optim.construct_optimizer(model)
    # Resume from the last checkpoint if possible, else load initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        resumed_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = resumed_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Precise timing only applies to fresh (non-resumed) runs
    if start_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        timings = net.compute_precise_time(model, loss_fun)
        logger.info(logging.dump_json_stats(timings))
        net.reset_bn_stats(model)
    # Build the data loaders and their matching meters
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Main training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, epoch)
        # Recompute precise BN statistics if requested
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint on checkpoint epochs
        if checkpoint.is_checkpoint_epoch(epoch):
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model on eval epochs
        if is_eval_epoch(epoch):
            test_epoch(test_loader, model, test_meter, epoch)
def log_epoch_stats(self, cur_epoch):
    # Assemble the epoch-level stats and log them as a JSON record
    epoch_stats = self.get_epoch_stats(cur_epoch)
    logger.info(logging.dump_json_stats(epoch_stats))
def log_iter_stats(self, cur_epoch, cur_iter):
    # Log only on every LOG_PERIOD-th iteration; skip the rest
    if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
        return
    logger.info(logging.dump_json_stats(self.get_iter_stats(cur_epoch, cur_iter)))