Exemple #1
0
def _predict_for_phase(predictor, phase, **loader_kwargs):
    """Build the evaluation loader for ``phase`` and run the predictor on it.

    Args:
        predictor: Object whose ``predict(loader)`` returns a
            ``(predictions, groundtruth)`` pair.
        phase: Phase name forwarded to ``get_loader``.
        **loader_kwargs: Extra keyword arguments forwarded unchanged to
            ``get_loader`` (``data_args``, ``logger`` and, for the test
            phases, ``test_args``).

    Returns:
        Whatever ``predictor.predict`` returns for the loader, i.e. the
        ``(predictions, groundtruth)`` pair.
    """
    loader = get_loader(phase=phase, is_training=False, **loader_kwargs)
    return predictor.predict(loader)


def _log_phase_metrics(logger, metrics, phase):
    """Emit the "Writing metrics" message and log ``metrics`` for ``phase``."""
    logger.log_stdout(f"Writing metrics to {logger.metrics_path}.")
    logger.log_metrics(metrics, phase=phase)


def test(args):
    """Run model testing.

    Loads the checkpoint at ``args.test_args.ckpt_path``, evaluates it on
    the requested phase and logs the resulting metrics. When the requested
    phase is ``'test'``, the ``'valid'`` phase is evaluated first with
    threshold tuning enabled, and the thresholds found there are then
    applied to the ``'test'`` and ``'dense_test'`` phases.

    Args:
        args: Namespace providing ``test_args``, ``logger_args``,
            ``gpu_ids`` and ``device``.
    """
    test_args = args.test_args

    # Load the model at ckpt_path.
    ckpt_path = test_args.ckpt_path
    ckpt_save_dir = Path(ckpt_path).parent

    # Get model args from checkpoint and add them to
    # command-line specified model args.
    model_args, data_args, optim_args, logger_args = ModelSaver.get_args(
        cl_logger_args=args.logger_args,
        ckpt_save_dir=ckpt_save_dir)

    # The checkpoint info returned alongside the model is not needed here.
    model, _ = ModelSaver.load_model(ckpt_path=ckpt_path,
                                     gpu_ids=args.gpu_ids,
                                     model_args=model_args,
                                     is_training=False)

    # Get logger.
    logger = Logger(logger_args=logger_args,
                    data_args=data_args,
                    optim_args=optim_args,
                    test_args=test_args)

    # Instantiate the Predictor class for obtaining model predictions.
    predictor = Predictor(model=model, device=args.device)

    # For a test run, evaluate 'valid' first to obtain the thresholds.
    is_test = test_args.phase == 'test'
    phase = 'valid' if is_test else test_args.phase

    print(f"======================{phase}=======================")
    predictions, groundtruth = _predict_for_phase(predictor,
                                                  phase,
                                                  data_args=data_args,
                                                  logger=logger)

    # Evaluator that tunes the threshold on this phase.
    evaluator = Evaluator(logger=logger, tune_threshold=True)
    metrics = evaluator.evaluate(groundtruth, predictions)
    _log_phase_metrics(logger, metrics, phase)

    # Evaluate dense to get back thresholds.
    # NOTE(review): this pass requests the same `phase` as the pass above,
    # not a "dense_"-prefixed phase like the 'dense_test' pass below --
    # confirm that this is intended.
    dense_predictions, dense_groundtruth = _predict_for_phase(
        predictor, phase, data_args=data_args, logger=logger)
    dense_metrics = evaluator.dense_evaluate(dense_groundtruth,
                                             dense_predictions)
    _log_phase_metrics(logger, dense_metrics, phase)

    if not is_test:
        return

    # Test phase: reuse the threshold found on the validation phase.
    phase = 'test'
    threshold = metrics['threshold']
    print(f"======================{phase}=======================")

    predictions, groundtruth = _predict_for_phase(predictor,
                                                  phase,
                                                  data_args=data_args,
                                                  test_args=test_args,
                                                  logger=logger)

    # Evaluator with the fixed threshold; no tuning on the test phase.
    evaluator = Evaluator(logger=logger,
                          threshold=threshold,
                          tune_threshold=False)
    metrics = evaluator.evaluate(groundtruth, predictions)
    _log_phase_metrics(logger, metrics, phase)

    # Dense test: read the dense thresholds before the prediction pass so a
    # missing key fails before any prediction work is done.
    phase = 'dense_test'
    threshold_dense = dense_metrics["threshold_dense"]
    threshold_tunef1_dense = dense_metrics["threshold_tunef1_dense"]
    dense_predictions, dense_groundtruth = _predict_for_phase(
        predictor,
        phase,
        data_args=data_args,
        test_args=test_args,
        logger=logger)
    dense_metrics = evaluator.dense_evaluate(
        dense_groundtruth,
        dense_predictions,
        threshold=threshold_dense,
        threshold_tunef1=threshold_tunef1_dense)
    _log_phase_metrics(logger, dense_metrics, phase)
Exemple #2
0
def _evaluate_valid_sets(predictor, evaluator, valid_loader,
                         dense_valid_loader):
    """Evaluate on the valid and dense-valid loaders and merge the metrics.

    Args:
        predictor: Object whose ``predict(loader)`` returns a
            ``(predictions, groundtruth)`` pair.
        evaluator: Object providing ``evaluate`` and ``dense_evaluate``.
        valid_loader: Loader passed to the first prediction pass.
        dense_valid_loader: Loader passed to the dense prediction pass.

    Returns:
        One dict containing the result of ``evaluator.evaluate`` updated
        with the result of ``evaluator.dense_evaluate``; on a key clash the
        dense value wins.
    """
    predictions, groundtruth = predictor.predict(valid_loader)
    metrics = evaluator.evaluate(groundtruth, predictions)

    # Evaluate on dense dataset
    dense_predictions, dense_groundtruth = predictor.predict(
        dense_valid_loader)
    dense_metrics = evaluator.dense_evaluate(dense_groundtruth,
                                             dense_predictions)

    # Merge the metrics dicts together
    return {**metrics, **dense_metrics}


def train(args):
    """Run model training.

    Builds (or loads) the model, the data loaders, the evaluator, the
    checkpoint saver and the optimizer, then trains until
    ``optimizer.is_finished_training()`` is true. Every
    ``optimizer.iters_per_eval`` global steps the model is evaluated on the
    valid and dense-valid loaders, and a final checkpoint is saved when
    training ends.

    Args:
        args: Namespace providing the nested ``model_args``,
            ``logger_args``, ``optim_args`` and ``data_args`` namespaces
            as well as ``gpu_ids`` and ``device``.
    """

    # Get nested namespaces.
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args

    # Get logger.
    logger = Logger(logger_args)

    if model_args.ckpt_path:
        # CL-specified args are used to load the model, rather than the
        # ones saved to args.json.
        model_args.pretrained = False
        ckpt_path = model_args.ckpt_path
        # NOTE(review): this unconditional assert makes the rest of this
        # branch unreachable unless asserts are stripped (python -O).
        # Confirm whether resuming from a checkpoint is meant to be
        # disabled, and if so replace it with an explicit exception.
        assert False
        model, ckpt_info = ModelSaver.load_model(ckpt_path=ckpt_path,
                                                 gpu_ids=args.gpu_ids,
                                                 model_args=model_args,
                                                 is_training=True)
        optim_args.start_epoch = ckpt_info['epoch'] + 1
    else:
        # If no ckpt_path is provided, instantiate a new randomly
        # initialized model.
        model_fn = models.__dict__[model_args.model]
        model = model_fn(model_args)
        model = nn.DataParallel(model, args.gpu_ids)
    # Put model on gpu or cpu and put into training mode.
    # NOTE(review): training mode is only set here; this assumes
    # predictor.predict() leaves the model in training mode -- confirm.
    model = model.to(args.device)
    model.train()

    # Get train and valid loader objects.
    train_loader = get_loader(phase="train",
                              data_args=data_args,
                              is_training=True,
                              logger=logger)
    valid_loader = get_loader(phase="valid",
                              data_args=data_args,
                              is_training=False,
                              logger=logger)
    dense_valid_loader = get_loader(phase="dense_valid",
                                    data_args=data_args,
                                    is_training=False,
                                    logger=logger)

    # Instantiate the predictor class for obtaining model predictions.
    predictor = Predictor(model, args.device)

    # Instantiate the evaluator class for evaluating models.
    # By default, get best performance on validation set.
    evaluator = Evaluator(logger=logger, tune_threshold=True)

    # Instantiate the saver class for saving model checkpoints.
    saver = ModelSaver(save_dir=logger_args.save_dir,
                       iters_per_save=logger_args.iters_per_save,
                       max_ckpts=logger_args.max_ckpts,
                       metric_name=optim_args.metric_name,
                       maximize_metric=optim_args.maximize_metric,
                       keep_topk=True,
                       logger=logger)

    # Instantiate the optimizer class for guiding model training.
    optimizer = Optimizer(parameters=model.parameters(),
                          optim_args=optim_args,
                          batch_size=data_args.batch_size,
                          iters_per_print=logger_args.iters_per_print,
                          iters_per_visual=logger_args.iters_per_visual,
                          iters_per_eval=logger_args.iters_per_eval,
                          dataset_len=len(train_loader.dataset),
                          logger=logger)
    if model_args.ckpt_path:
        # Load the same optimizer as used in the original training.
        optimizer.load_optimizer(ckpt_path=model_args.ckpt_path,
                                 gpu_ids=args.gpu_ids)

    loss_fn = evaluator.get_loss_fn(loss_fn_name=optim_args.loss_fn)

    # Most recent evaluation metrics. Stays None until the first evaluation
    # so the uses after the loops can detect that none has run yet instead
    # of failing on an unbound name.
    metrics = None

    # Run training
    while not optimizer.is_finished_training():
        optimizer.start_epoch()

        for inputs, targets in train_loader:
            optimizer.start_iter()

            if optimizer.global_step % optimizer.iters_per_eval == 0:
                # Only evaluate every iters_per_eval examples.
                metrics = _evaluate_valid_sets(predictor, evaluator,
                                               valid_loader,
                                               dense_valid_loader)

                # Log metrics to stdout.
                logger.log_metrics(metrics, phase='valid')

                # Log to tb
                logger.log_scalars(metrics,
                                   optimizer.global_step,
                                   phase='valid')

                if optimizer.global_step % logger_args.iters_per_save == 0:
                    # Only save every iters_per_save examples directly
                    # after evaluation.
                    saver.save(iteration=optimizer.global_step,
                               epoch=optimizer.epoch,
                               model=model,
                               optimizer=optimizer,
                               device=args.device,
                               metric_val=metrics[optim_args.metric_name])

                # Step learning rate scheduler.
                optimizer.step_scheduler(metrics[optim_args.metric_name])

            with torch.set_grad_enabled(True):

                # Run the minibatch through the model.
                logits = model(inputs.to(args.device))

                # Compute the minibatch loss.
                loss = loss_fn(logits, targets.to(args.device))

                # Log the data from this iteration.
                optimizer.log_iter(inputs, logits, targets, loss)

                # Perform a backward pass.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            optimizer.end_iter()

        if metrics is None:
            # No periodic evaluation has happened yet (e.g. the loader
            # yielded no batches); evaluate now so end_epoch gets metrics.
            metrics = _evaluate_valid_sets(predictor, evaluator,
                                           valid_loader, dense_valid_loader)
        optimizer.end_epoch(metrics)

    if metrics is None:
        # The training loop never ran; evaluate once so the final
        # checkpoint can still be saved with a metric value.
        metrics = _evaluate_valid_sets(predictor, evaluator,
                                       valid_loader, dense_valid_loader)

    # Save the most recent model.
    saver.save(iteration=optimizer.global_step,
               epoch=optimizer.epoch,
               model=model,
               optimizer=optimizer,
               device=args.device,
               metric_val=metrics[optim_args.metric_name])