Example #1
def trace_model_from_checkpoint(logdir, method_name):
    config_path = logdir / "configs/_config.json"
    checkpoint_path = logdir / "checkpoints/best.pth"
    print("Load config")
    config: Dict[str, dict] = safitty.load(config_path)

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir_from_logs = Path(logdir) / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = \
        import_experiment_and_runner(expdir_from_logs)
    experiment: Experiment = ExperimentType(config)

    print("Load model state from checkpoints/best.pth")
    model = experiment.get_model(next(iter(experiment.stages)))
    checkpoint = UtilsFactory.load_checkpoint(checkpoint_path)
    UtilsFactory.unpack_checkpoint(checkpoint, model=model)

    print("Tracing")
    traced = trace_model(model, experiment, RunnerType, method_name)

    print("Done")
    return traced
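The load_checkpoint / unpack_checkpoint pair used above is the core pattern shared by every example on this page. A minimal standalone sketch of that pattern, assuming the import path of older Catalyst releases (adjust it to your version) and a model and logdir defined elsewhere:

# Minimal sketch: restore model weights from a Catalyst "best.pth" checkpoint.
from pathlib import Path

import torch
from catalyst.dl.utils import UtilsFactory  # path used by older Catalyst releases


def load_best_weights(model: torch.nn.Module, logdir: str) -> torch.nn.Module:
    checkpoint_path = Path(logdir) / "checkpoints" / "best.pth"
    # load_checkpoint returns a plain dict: epoch, metrics, *_state_dict entries
    checkpoint = UtilsFactory.load_checkpoint(checkpoint_path)
    # unpack_checkpoint copies the stored state dicts into the passed objects
    UtilsFactory.unpack_checkpoint(checkpoint, model=model)
    return model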
Example #2
def report_by_dir(folder):
    checkpoint = f"{folder}/best.pth"
    checkpoint = UtilsFactory.load_checkpoint(checkpoint)
    exp_name = folder.rsplit("/", 1)[-1]
    row = {"exp_name": exp_name, "epoch": checkpoint["epoch"]}
    row.update(checkpoint["valid_metrics"])
    return row
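A hedged usage sketch for report_by_dir: collect one row per run directory and print them as a table. The logs/* layout and the pandas dependency are assumptions for illustration, not part of the original snippet.

# Hypothetical usage of report_by_dir: compare runs side by side.
import glob

import pandas as pd

rows = [report_by_dir(folder) for folder in sorted(glob.glob("logs/*"))]
report = pd.DataFrame(rows).sort_values("exp_name")
print(report.to_string(index=False))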
Example #3
    def _preprocess_model_for_stage(self, stage: str, model: _Model):
        stage_index = self.stages.index(stage)
        if stage_index > 0:
            # Every stage after the first one warm-starts from the best
            # checkpoint accumulated so far.
            checkpoint_path = f"{self.logdir}/checkpoints/best.pth"
            checkpoint = UtilsFactory.load_checkpoint(checkpoint_path)
            UtilsFactory.unpack_checkpoint(checkpoint, model=model)
        return model
Example #4
    def load_actor_weights(self):
        if self.resume is not None:
            # Restore actor weights from a local checkpoint file.
            checkpoint = UtilsFactory.load_checkpoint(self.resume)
            weights = checkpoint["actor_state_dict"]
            self.actor.load_state_dict(weights)
        elif self.redis_server is not None:
            # Otherwise pull the latest serialized weights from the Redis server.
            weights = deserialize(
                self.redis_server.get(f"{self.redis_prefix}_actor_weights"))
            weights = {k: self.to_tensor(v) for k, v in weights.items()}
            self.actor.load_state_dict(weights)
        else:
            raise NotImplementedError
        self.actor.eval()
Example #5
    def load_checkpoint(self, filepath, load_optimizer=True):
        checkpoint = UtilsFactory.load_checkpoint(filepath)
        for key in ["actor", "critic"]:
            value_l = getattr(self, key, None)
            if value_l is not None:
                value_r = checkpoint[f"{key}_state_dict"]
                value_l.load_state_dict(value_r)

            if load_optimizer:
                for key2 in ["optimizer", "scheduler"]:
                    key2 = f"{key}_{key2}"
                    value_l = getattr(self, key2, None)
                    if value_l is not None:
                        value_r = checkpoint[f"{key2}_state_dict"]
                        value_l.load_state_dict(value_r)
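The loops above rely on a key-naming convention inside the checkpoint dict: <component>_state_dict for the networks and <component>_optimizer_state_dict / <component>_scheduler_state_dict for their training state. A minimal sketch of a writer that produces compatible keys, using plain torch.save instead of whatever helper the original project uses:

# Sketch of a checkpoint writer whose keys match the loader above.
import torch


def save_rl_checkpoint(path, actor, critic, actor_optimizer, critic_optimizer):
    checkpoint = {
        "actor_state_dict": actor.state_dict(),
        "critic_state_dict": critic.state_dict(),
        "actor_optimizer_state_dict": actor_optimizer.state_dict(),
        "critic_optimizer_state_dict": critic_optimizer.state_dict(),
    }
    torch.save(checkpoint, path)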
Example #6
    def load_checkpoint(*, filename, state):
        if os.path.isfile(filename):
            print("=> loading checkpoint \"{}\"".format(filename))
            checkpoint = UtilsFactory.load_checkpoint(filename)

            state.epoch = checkpoint["epoch"]

            UtilsFactory.unpack_checkpoint(checkpoint,
                                           model=state.model,
                                           criterion=state.criterion,
                                           optimizer=state.optimizer,
                                           scheduler=state.scheduler)

            print("loaded checkpoint \"{}\" (epoch {})".format(
                filename, checkpoint["epoch"]))
        else:
            raise Exception("no checkpoint found at \"{}\"".format(filename))
Example #7
    def load_checkpoint(self,
                        *,
                        filepath: str = None,
                        db_server: DBSpec = None):
        if filepath is not None:
            checkpoint = UtilsFactory.load_checkpoint(filepath)
            weights = checkpoint[f"{self._sampler_weight_mode}_state_dict"]
            self.agent.load_state_dict(weights)
        elif db_server is not None:
            weights = db_server.load_weights(prefix=self._sampler_weight_mode)
            weights = {k: self._to_tensor(v) for k, v in weights.items()}
            self.agent.load_state_dict(weights)
        else:
            raise NotImplementedError

        self.agent.to(self._device)
        self.agent.eval()
Example #8
    def load_checkpoint(self, filepath, load_optimizer=True):
        super().load_checkpoint(filepath, load_optimizer)

        checkpoint = UtilsFactory.load_checkpoint(filepath)
        key = "critics"
        for i in range(len(self.critics)):
            value_l = getattr(self, key, None)
            value_l = value_l[i] if value_l is not None else None
            if value_l is not None:
                value_r = checkpoint[f"{key}{i}_state_dict"]
                value_l.load_state_dict(value_r)
            if load_optimizer:
                for key2 in ["optimizer", "scheduler"]:
                    key2 = f"{key}_{key2}"
                    value_l = getattr(self, key2, None)
                    if value_l is not None:
                        value_r = checkpoint[f"{key2}_state_dict"]
                        value_l.load_state_dict(value_r)
Example #9
    def _get_optimizer(self, *, model_params, **params):
        key_value_flag = params.pop("_key_value", False)

        if key_value_flag:
            optimizer = {}
            for key, params_ in params.items():
                optimizer[key] = self._get_optimizer(model_params=model_params,
                                                     **params_)
        else:
            load_from_previous_stage = \
                params.pop("load_from_previous_stage", False)
            optimizer = OPTIMIZERS.get_from_params(**params,
                                                   params=model_params)

            if load_from_previous_stage:
                checkpoint_path = \
                    f"{self.logdir}/checkpoints/best.pth"
                checkpoint = UtilsFactory.load_checkpoint(checkpoint_path)
                UtilsFactory.unpack_checkpoint(checkpoint, optimizer=optimizer)
                for key, value in params.items():
                    for pg in optimizer.param_groups:
                        pg[key] = value

        return optimizer
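After restoring the optimizer state, the example above pushes the current stage's hyperparameters into every param group, so the new stage does not silently keep the previous stage's learning rate or weight decay. A self-contained illustration of that override with plain PyTorch objects (every name below is made up for the sketch):

# Restore optimizer state, then force the new stage's settings onto each group.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
saved_state = optimizer.state_dict()      # stands in for the loaded checkpoint

optimizer.load_state_dict(saved_state)    # restored state still carries lr=1e-3
for pg in optimizer.param_groups:
    pg["lr"] = 1e-4                       # override with the new stage's value
print(optimizer.param_groups[0]["lr"])    # -> 0.0001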
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', type=str, default='unet', help='')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default=None,
                        required=True,
                        help='Data dir')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        required=True,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=16,
                        help='Batch size for inference')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    args = parser.parse_args()

    data_dir = args.data_dir
    checkpoint_file = auto_file(args.checkpoint)
    run_dir = os.path.dirname(os.path.dirname(checkpoint_file))
    out_dir = os.path.join(run_dir, 'evaluation')
    os.makedirs(out_dir, exist_ok=True)

    model = get_model(args.model)

    checkpoint = UtilsFactory.load_checkpoint(checkpoint_file)
    checkpoint_epoch = checkpoint['epoch']
    print('Loaded model weights from', args.checkpoint)
    print('Epoch   :', checkpoint_epoch)
    print('Metrics (Train):', 'IoU:',
          checkpoint['epoch_metrics']['train']['jaccard'], 'Acc:',
          checkpoint['epoch_metrics']['train']['accuracy'])
    print('Metrics (Valid):', 'IoU:',
          checkpoint['epoch_metrics']['valid']['jaccard'], 'Acc:',
          checkpoint['epoch_metrics']['valid']['accuracy'])

    UtilsFactory.unpack_checkpoint(checkpoint, model=model)

    model = model.cuda().eval()

    train_images = find_in_dir(os.path.join(data_dir, 'train', 'images'))
    for fname in tqdm(train_images, total=len(train_images)):
        image = read_rgb_image(fname)
        mask = predict(model,
                       image,
                       tta=args.tta,
                       image_size=(512, 512),
                       batch_size=args.batch_size,
                       activation='sigmoid')
        mask = (mask * 255).astype(np.uint8)
        name = os.path.join(out_dir, os.path.basename(fname))
        cv2.imwrite(name, mask)
Example #11
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="unet", help="")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=None,
                        required=True,
                        help="Data dir")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        required=True,
        help="Checkpoint filename to use as initial model weights",
    )
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=16,
                        help="Batch size for inference")
    parser.add_argument("-tta",
                        "--tta",
                        default=None,
                        type=str,
                        help="Type of TTA to use [fliplr, d4]")
    args = parser.parse_args()

    data_dir = args.data_dir
    checkpoint_file = auto_file(args.checkpoint)
    run_dir = os.path.dirname(os.path.dirname(checkpoint_file))
    out_dir = os.path.join(run_dir, "evaluation")
    os.makedirs(out_dir, exist_ok=True)

    model = get_model(args.model)

    checkpoint = UtilsFactory.load_checkpoint(checkpoint_file)
    checkpoint_epoch = checkpoint["epoch"]
    print("Loaded model weights from", args.checkpoint)
    print("Epoch   :", checkpoint_epoch)
    print(
        "Metrics (Train):",
        "IoU:",
        checkpoint["epoch_metrics"]["train"]["jaccard"],
        "Acc:",
        checkpoint["epoch_metrics"]["train"]["accuracy"],
    )
    print(
        "Metrics (Valid):",
        "IoU:",
        checkpoint["epoch_metrics"]["valid"]["jaccard"],
        "Acc:",
        checkpoint["epoch_metrics"]["valid"]["accuracy"],
    )

    UtilsFactory.unpack_checkpoint(checkpoint, model=model)

    model = model.cuda().eval()

    train_images = find_in_dir(os.path.join(data_dir, "train", "images"))
    for fname in tqdm(train_images, total=len(train_images)):
        image = read_inria_rgb(fname)
        mask = predict(model,
                       image,
                       tta=args.tta,
                       image_size=(512, 512),
                       batch_size=args.batch_size,
                       activation="sigmoid")
        mask = (mask * 255).astype(np.uint8)
        name = os.path.join(out_dir, os.path.basename(fname))
        cv2.imwrite(name, mask)