Ejemplo n.º 1
0
def trace_model_from_checkpoint(logdir, method_name):
    """Rebuild a run's model from its logdir and trace it.

    Reads ``<logdir>/configs/_config.json``, imports the experiment and runner
    classes from the copy of the experiment code stored under
    ``<logdir>/code/<expdir name>``, builds the model for the first stage,
    loads ``<logdir>/checkpoints/best.pth`` into it and passes everything to
    ``trace_model``.

    Args:
        logdir: run directory, as ``str`` or ``pathlib.Path``.
        method_name: forwarded unchanged to ``trace_model``.

    Returns:
        Whatever ``trace_model`` returns.
    """
    # Normalise once: the paths below use ``/`` and would fail on a plain str.
    logdir = Path(logdir)
    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / "best.pth"
    print("Load config")
    config: Dict[str, dict] = safitty.load(config_path)

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # Use the copy of expdir saved in the logs for reproducibility.
    expdir_from_logs = logdir / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = \
        import_experiment_and_runner(expdir_from_logs)
    experiment: Experiment = ExperimentType(config)

    print("Load model state from checkpoints/best.pth")
    # The model is built for the first stage listed by the experiment.
    model = experiment.get_model(next(iter(experiment.stages)))
    checkpoint = UtilsFactory.load_checkpoint(checkpoint_path)
    UtilsFactory.unpack_checkpoint(checkpoint, model=model)

    print("Tracing")
    traced = trace_model(model, experiment, RunnerType, method_name)

    print("Done")
    return traced
Ejemplo n.º 2
0
 def _preprocess_model_for_stage(self, stage: str, model: _Model):
     """Return ``model``, reloading best-checkpoint weights after stage one.

     For the stage at index 0 of ``self.stages`` the model is returned as-is.
     For any later stage, ``{self.logdir}/checkpoints/best.pth`` is loaded
     and unpacked into ``model`` (presumably in place, since the same object
     is returned) before returning it.
     """
     if not self.stages.index(stage) > 0:
         return model
     best_checkpoint = UtilsFactory.load_checkpoint(
         f"{self.logdir}/checkpoints/best.pth")
     UtilsFactory.unpack_checkpoint(best_checkpoint, model=model)
     return model
Ejemplo n.º 3
0
    def load_checkpoint(*, filename, state):
        """Restore training state from the checkpoint file ``filename``.

        Loads the file with ``UtilsFactory.load_checkpoint``, copies its
        ``"epoch"`` entry to ``state.epoch`` and unpacks it into
        ``state.model``, ``state.criterion``, ``state.optimizer`` and
        ``state.scheduler``.

        Args:
            filename: path to the checkpoint file.
            state: object exposing ``epoch``, ``model``, ``criterion``,
                ``optimizer`` and ``scheduler`` attributes.

        Raises:
            FileNotFoundError: if ``filename`` is not an existing regular
                file. It subclasses ``Exception``, so callers catching
                ``Exception`` still work.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError(
                "no checkpoint found at \"{}\"".format(filename))

        print("=> loading checkpoint \"{}\"".format(filename))
        checkpoint = UtilsFactory.load_checkpoint(filename)

        epoch = checkpoint["epoch"]
        state.epoch = epoch

        UtilsFactory.unpack_checkpoint(checkpoint,
                                       model=state.model,
                                       criterion=state.criterion,
                                       optimizer=state.optimizer,
                                       scheduler=state.scheduler)

        print("loaded checkpoint \"{}\" (epoch {})".format(filename, epoch))
Ejemplo n.º 4
0
    def _get_optimizer(self, *, model_params, **params):
        key_value_flag = params.pop("_key_value", False)

        if key_value_flag:
            optimizer = {}
            for key, params_ in params.items():
                optimizer[key] = self._get_optimizer(model_params=model_params,
                                                     **params_)
        else:
            load_from_previous_stage = \
                params.pop("load_from_previous_stage", False)
            optimizer = OPTIMIZERS.get_from_params(**params,
                                                   params=model_params)

            if load_from_previous_stage:
                checkpoint_path = \
                    f"{self.logdir}/checkpoints/best.pth"
                checkpoint = UtilsFactory.load_checkpoint(checkpoint_path)
                UtilsFactory.unpack_checkpoint(checkpoint, optimizer=optimizer)
                for key, value in params.items():
                    for pg in optimizer.param_groups:
                        pg[key] = value

        return optimizer
Ejemplo n.º 5
0
def main():
    """Predict masks for the training images and save them as PNG-like files.

    CLI options: ``--model`` (name for ``get_model``), ``--data-dir`` (root
    containing ``train/images``), ``--checkpoint`` (resolved with
    ``auto_file``), ``--batch-size`` and ``--tta``. Each predicted mask is
    scaled by 255, cast to ``uint8`` and written with ``cv2.imwrite`` into
    ``<run_dir>/evaluation`` under the source image's base name. Here
    ``run_dir`` is two directory levels above the checkpoint file.
    """
    argument_specs = (
        (('-m', '--model'), dict(type=str, default='unet', help='')),
        (('-dd', '--data-dir'),
         dict(type=str, default=None, required=True, help='Data dir')),
        (('-c', '--checkpoint'),
         dict(type=str,
              default=None,
              required=True,
              help='Checkpoint filename to use as initial model weights')),
        (('-b', '--batch-size'),
         dict(type=int, default=16, help='Batch size for inference')),
        (('-tta', '--tta'),
         dict(default=None,
              type=str,
              help='Type of TTA to use [fliplr, d4]')),
    )
    parser = argparse.ArgumentParser()
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    ckpt_path = auto_file(args.checkpoint)
    # The run directory is the parent of the checkpoint's directory.
    evaluation_dir = os.path.join(
        os.path.dirname(os.path.dirname(ckpt_path)), 'evaluation')
    os.makedirs(evaluation_dir, exist_ok=True)

    model = get_model(args.model)

    state = UtilsFactory.load_checkpoint(ckpt_path)
    epoch = state['epoch']
    print('Loaded model weights from', args.checkpoint)
    print('Epoch   :', epoch)
    for split, label in (('train', 'Train'), ('valid', 'Valid')):
        split_metrics = state['epoch_metrics'][split]
        print(f'Metrics ({label}):', 'IoU:', split_metrics['jaccard'],
              'Acc:', split_metrics['accuracy'])

    UtilsFactory.unpack_checkpoint(state, model=model)
    model = model.cuda().eval()

    image_files = find_in_dir(os.path.join(args.data_dir, 'train', 'images'))
    for image_path in tqdm(image_files, total=len(image_files)):
        probabilities = predict(model,
                                read_rgb_image(image_path),
                                tta=args.tta,
                                image_size=(512, 512),
                                batch_size=args.batch_size,
                                activation='sigmoid')
        mask_u8 = (probabilities * 255).astype(np.uint8)
        cv2.imwrite(
            os.path.join(evaluation_dir, os.path.basename(image_path)),
            mask_u8)
Ejemplo n.º 6
0
def main():
    """Predict masks for the Inria training images and save them to disk.

    CLI options: ``--model`` (name for ``get_model``), ``--data-dir`` (root
    containing ``train/images``), ``--checkpoint`` (resolved with
    ``auto_file``), ``--batch-size`` and ``--tta``. Images are read with
    ``read_inria_rgb``. Each predicted mask is scaled by 255, cast to
    ``uint8`` and written with ``cv2.imwrite`` into ``<run_dir>/evaluation``
    under the source image's base name. Here ``run_dir`` is two directory
    levels above the checkpoint file.
    """
    argument_specs = (
        (("-m", "--model"), dict(type=str, default="unet", help="")),
        (("-dd", "--data-dir"),
         dict(type=str, default=None, required=True, help="Data dir")),
        (("-c", "--checkpoint"),
         dict(type=str,
              default=None,
              required=True,
              help="Checkpoint filename to use as initial model weights")),
        (("-b", "--batch-size"),
         dict(type=int, default=16, help="Batch size for inference")),
        (("-tta", "--tta"),
         dict(default=None,
              type=str,
              help="Type of TTA to use [fliplr, d4]")),
    )
    parser = argparse.ArgumentParser()
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    ckpt_path = auto_file(args.checkpoint)
    # The run directory is the parent of the checkpoint's directory.
    evaluation_dir = os.path.join(
        os.path.dirname(os.path.dirname(ckpt_path)), "evaluation")
    os.makedirs(evaluation_dir, exist_ok=True)

    model = get_model(args.model)

    state = UtilsFactory.load_checkpoint(ckpt_path)
    epoch = state["epoch"]
    print("Loaded model weights from", args.checkpoint)
    print("Epoch   :", epoch)
    for split, label in (("train", "Train"), ("valid", "Valid")):
        split_metrics = state["epoch_metrics"][split]
        print(f"Metrics ({label}):", "IoU:", split_metrics["jaccard"],
              "Acc:", split_metrics["accuracy"])

    UtilsFactory.unpack_checkpoint(state, model=model)
    model = model.cuda().eval()

    image_files = find_in_dir(os.path.join(args.data_dir, "train", "images"))
    for image_path in tqdm(image_files, total=len(image_files)):
        probabilities = predict(model,
                                read_inria_rgb(image_path),
                                tta=args.tta,
                                image_size=(512, 512),
                                batch_size=args.batch_size,
                                activation="sigmoid")
        mask_u8 = (probabilities * 255).astype(np.uint8)
        cv2.imwrite(
            os.path.join(evaluation_dir, os.path.basename(image_path)),
            mask_u8)