コード例 #1
0
ファイル: checkpoint.py プロジェクト: kdubovikov/catalyst
    def load_checkpoint(*, filename, state: _State):
        """Restore run state and component weights from a checkpoint file.

        Args:
            filename: path of the checkpoint file to read
            state (_State): run state; its ``model``, ``criterion``,
                ``optimizer`` and ``scheduler`` are handed to
                ``utils.unpack_checkpoint``. Unless ``state.stage`` starts
                with ``"infer"``, its ``epoch``, ``stage_epoch`` and
                ``stage`` are overwritten with the checkpoint values.

        Raises:
            Exception: if ``filename`` is not an existing file
        """
        # Guard clause: nothing to do when the file is missing.
        if not os.path.isfile(filename):
            raise Exception(f"No checkpoint found at {filename}")

        print(f"=> loading checkpoint {filename}")
        ckpt = utils.load_checkpoint(filename)

        # Progress counters are only restored outside of inference stages.
        resumes_training = not state.stage.startswith("infer")
        if resumes_training:
            state.epoch = ckpt["epoch"]
            state.stage_epoch = ckpt["stage_epoch"]
            state.stage = ckpt["stage"]

        utils.unpack_checkpoint(
            ckpt,
            model=state.model,
            criterion=state.criterion,
            optimizer=state.optimizer,
            scheduler=state.scheduler,
        )

        summary = (
            f"(epoch {ckpt['epoch']}, "
            f"stage_epoch {ckpt['stage_epoch']}, "
            f"stage {ckpt['stage']})"
        )
        print(f"loaded checkpoint {filename} {summary}")
コード例 #2
0
def trace_model_from_checkpoint(logdir, logger, method_name='forward'):
    """Trace the model saved as ``checkpoints/best.pth`` inside ``logdir``.

    The experiment is rebuilt from ``configs/_config.json`` using the copy
    of the experiment directory that sits next to ``logdir``.

    Args:
        logdir: path to the run's log directory
        logger: object exposing ``info`` used for progress messages
        method_name: name of the model method forwarded to ``trace_model``

    Returns:
        whatever ``trace_model`` returns for the restored model
    """
    logger.info('Load config')
    config = safitty.load(f'{logdir}/configs/_config.json')
    # Distributed settings are dropped before the experiment is rebuilt.
    config.pop('distributed_params', None)

    # noinspection SpellCheckingInspection,PyTypeChecker
    # Use the copy of expdir stored alongside the logs for reproducibility.
    original_expdir = config['args']['expdir']
    logged_expdir = os.path.abspath(
        join(logdir, '../', os.path.basename(original_expdir))
    )

    logger.info('Import experiment and runner from logdir')
    imported = import_experiment_and_runner(Path(logged_expdir))
    ExperimentType, RunnerType = imported
    experiment: Experiment = ExperimentType(config)

    logger.info('Load model state from checkpoints/best.pth')
    first_stage = next(iter(experiment.stages))
    model = experiment.get_model(first_stage)
    state = utils.load_checkpoint(f'{logdir}/checkpoints/best.pth')
    utils.unpack_checkpoint(state, model=model)

    logger.info('Tracing')
    traced_model = trace_model(model, experiment, RunnerType, method_name)

    logger.info('Done')
    return traced_model
コード例 #3
0
ファイル: sampler.py プロジェクト: olgaiv39/catalyst
    def load_checkpoint(self,
                        *,
                        filepath: str = None,
                        db_server: DBSpec = None) -> bool:
        """Fetch a checkpoint and load its weights into ``self.agent``.

        When ``filepath`` is given the checkpoint is read from it. Otherwise
        ``db_server`` is polled every 3 seconds until it returns a checkpoint
        and its epoch is greater than the epoch observed on entry.

        Args:
            filepath (str): checkpoint file to read; takes precedence
            db_server (DBSpec): source polled when ``filepath`` is ``None``

        Returns:
            bool: ``True`` after the weights were loaded; ``False`` when no
            source was given, or when the server reports training as
            disabled while its epoch still equals the epoch seen on entry.
        """
        if filepath is not None:
            fetched = utils.load_checkpoint(filepath)
        elif db_server is None:
            return False
        else:
            epoch_on_entry = db_server.epoch
            fetched = db_server.get_checkpoint()

            def _training_stalled():
                # Training is off and no new epoch has appeared.
                return (not db_server.training_enabled
                        and db_server.epoch == epoch_on_entry)

            if _training_stalled():
                return False

            # Wait for a checkpoint that belongs to a newer epoch.
            while fetched is None or db_server.epoch <= epoch_on_entry:
                time.sleep(3.0)
                fetched = db_server.get_checkpoint()

                if _training_stalled():
                    return False

        self.checkpoint = fetched
        state_key = f"{self._weights_sync_mode}_state_dict"
        raw_weights = self.checkpoint[state_key]
        # Move every entry to the sampler's device before loading.
        device_weights = {
            name: utils.any2device(value, device=self._device)
            for name, value in raw_weights.items()
        }
        self.agent.load_state_dict(device_weights)
        self.agent.to(self._device)
        self.agent.eval()

        return True
コード例 #4
0
def _load_checkpoint(*, filename, state: State):
    """Restore run state and component weights from a checkpoint file.

    Args:
        filename: path of the checkpoint file to read
        state (State): run state; its ``model``, ``criterion``,
            ``optimizer`` and ``scheduler`` are handed to
            ``utils.unpack_checkpoint``. Unless ``state.stage_name`` starts
            with ``"infer"``, its ``stage_name``, ``epoch`` and
            ``global_epoch`` are overwritten with the checkpoint values.

    Raises:
        Exception: if ``filename`` is not an existing file
    """
    # Guard clause: bail out early when the file is missing.
    if not os.path.isfile(filename):
        raise Exception(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    ckpt = utils.load_checkpoint(filename)

    if not state.stage_name.startswith("infer"):
        # Copied in the same order as the fields are listed here.
        for field in ("stage_name", "epoch", "global_epoch"):
            setattr(state, field, ckpt[field])
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?

    utils.unpack_checkpoint(
        ckpt,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )

    summary = (
        f"(global epoch {ckpt['global_epoch']}, "
        f"epoch {ckpt['epoch']}, "
        f"stage {ckpt['stage_name']})"
    )
    print(f"loaded checkpoint {filename} {summary}")
コード例 #5
0
def trace_model_from_checkpoint(logdir,
                                logger,
                                method_name='forward',
                                file='best'):
    """Trace the model stored in ``checkpoints/{file}.pth`` of ``logdir``.

    The experiment is rebuilt from ``configs/_config.json``; the parent of
    ``logdir`` is prepended to ``sys.path`` before the experiment code is
    imported. Tracing is run on CPU, in ``eval`` mode, without gradients,
    on a native batch taken from the first stage.

    Args:
        logdir: path to the run's log directory
        logger: object exposing ``info`` used for progress messages
        method_name: name of the model method forwarded to ``trace_model``
        file: checkpoint file name without the ``.pth`` suffix

    Returns:
        whatever ``trace_model`` returns for the restored model
    """
    logger.info('Load config')
    config = safitty.load(f'{logdir}/configs/_config.json')
    # Distributed settings are dropped before the experiment is rebuilt.
    config.pop('distributed_params', None)

    # noinspection SpellCheckingInspection,PyTypeChecker
    # Use the copy of expdir stored alongside the logs for reproducibility.
    expdir_name = config['args']['expdir']
    logger.info(f'expdir_name from args: {expdir_name}')

    logs_parent = os.path.abspath(join(logdir, '../'))
    sys.path.insert(0, logs_parent)

    expdir_from_logs = os.path.abspath(join(logdir, '../', expdir_name))

    logger.info(f'expdir_from_logs: {expdir_from_logs}')
    logger.info('Import experiment and runner from logdir')

    imported = import_experiment_and_runner(Path(expdir_from_logs))
    ExperimentType, RunnerType = imported
    experiment: Experiment = ExperimentType(config)

    logger.info(f'Load model state from checkpoints/{file}.pth')
    model = experiment.get_model(next(iter(experiment.stages)))
    state = utils.load_checkpoint(f'{logdir}/checkpoints/{file}.pth')
    utils.unpack_checkpoint(state, model=model)

    # Fixed tracing setup.
    device = 'cpu'
    trace_options = {
        'mode': 'eval',
        'requires_grad': False,
        'opt_level': None,
    }
    first_stage = list(experiment.stages)[0]

    runner: RunnerType = RunnerType()
    runner.model, runner.device = model, device

    sample_batch = experiment.get_native_batch(first_stage, 0)

    logger.info('Tracing')
    traced_model = trace_model(
        model,
        runner,
        sample_batch,
        method_name=method_name,
        device=device,
        **trace_options,
    )

    logger.info('Done')
    return traced_model
コード例 #6
0
ファイル: quantization.py プロジェクト: skaruk/catalyst
def quantize_model_from_checkpoint(
    logdir: Union[str, Path],
    checkpoint_name: str,
    stage: str = None,
    qconfig_spec: Optional[Union[Set, Dict]] = None,
    dtype: Optional[torch.dtype] = torch.qint8,
    backend: str = None,
) -> Model:
    """
    Quantize model using created experiment and runner.

    Args:
        logdir (Union[str, Path]): Path to Catalyst logdir with model
        checkpoint_name (str): Name of model checkpoint to use
            (file name inside ``checkpoints/`` without ``.pth``)
        stage (str): experiment's stage name; when ``None`` the first
            stage of the experiment is used
        qconfig_spec: torch.quantization.quantize_dynamic
                parameter, you can define layers to be quantize
        dtype: type of the model parameters, default int8
        backend: defines backend for quantization; when given it is
            assigned to ``torch.backends.quantized.engine``

    Returns:
        Quantized model
    """
    if backend is not None:
        torch.backends.quantized.engine = backend

    # Normalise once so that plain ``str`` logdirs (allowed by the
    # documented contract) work with the ``/`` path operator below.
    logdir = Path(logdir)
    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
    logger.info("Load config")
    config: Dict[str, dict] = load_config(config_path)

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir = logdir / "code" / config_expdir.name

    logger.info("Import experiment and runner from logdir")
    experiment, _, _ = prepare_config_api_components(
        expdir=expdir, config=config
    )

    logger.info(f"Load model state from checkpoints/{checkpoint_name}.pth")
    if stage is None:
        stage = list(experiment.stages)[0]

    model = experiment.get_model(stage)
    checkpoint = load_checkpoint(checkpoint_path)
    unpack_checkpoint(checkpoint, model=model)

    logger.info("Quantization is running...")
    # The model is moved to CPU before being handed to quantize_dynamic.
    quantized_model = quantization.quantize_dynamic(
        model.cpu(),
        qconfig_spec=qconfig_spec,
        dtype=dtype,
    )

    logger.info("Done")
    return quantized_model
コード例 #7
0
def main(args, unknown_args):
    """Build the environment, algorithm and trainer from config and run it.

    Args:
        args: parsed command-line arguments; ``seed``, ``deterministic``,
            ``benchmark``, ``logdir``, ``configs``, ``expdir`` and
            ``resume`` are read here
        unknown_args: extra arguments forwarded to ``parse_args_uargs``

    Raises:
        NotImplementedError: if the configured algorithm name is in neither
            the off-policy nor the on-policy name collection
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    has_logdir = args.logdir is not None
    if has_logdir:
        os.makedirs(args.logdir, exist_ok=True)
        dump_environment(config, args.logdir, args.configs)

    if args.expdir is not None:
        # Imported for its side effects; the returned module is not used.
        _expdir_module = import_module(expdir=args.expdir)  # noqa: F841
        if has_logdir:
            dump_code(args.expdir, args.logdir)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # The algorithm name is removed from the config section when read.
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS, trainer_fn, sync_epoch = (
            OFFPOLICY_ALGORITHMS, OffpolicyTrainer, False
        )
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS, trainer_fn, sync_epoch = (
            ONPOLICY_ALGORITHMS, OnpolicyTrainer, True
        )
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError()

    db_params = config.get("db", {})
    db_server = DATABASES.get_from_params(**db_params, sync_epoch=sync_epoch)

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        # Only the non-optimizer part of the checkpoint is restored.
        resumed = utils.load_checkpoint(filepath=args.resume)
        resumed = utils.any2device(resumed, utils.get_device())
        algorithm.unpack_checkpoint(checkpoint=resumed, with_optimizer=False)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        monitoring_params=config.get("monitoring_params", None),
        **config["trainer"],
    )

    trainer.run()
コード例 #8
0
ファイル: runner.py プロジェクト: kiminh/dnn.cool
    def best(self) -> nn.Module:
        """Return ``self.model`` with the ``best_full.pth`` weights loaded.

        If ``tuned_params.pkl`` exists in the default log directory, its
        contents are also loaded into the task flow's decoder.

        Returns:
            nn.Module: the same model object held in ``self.model``
        """
        net = self.model
        run_dir = self.project_dir / self.default_logdir
        weights_file = run_dir / 'checkpoints' / 'best_full.pth'
        unpack_checkpoint(load_checkpoint(weights_file), net)

        tuned_file = run_dir / 'tuned_params.pkl'
        if tuned_file.exists():
            # Tuned parameters are optional; apply them only when present.
            tuned = torch.load(tuned_file)
            self.task_flow.get_decoder().load_tuned(tuned)
        return net
コード例 #9
0
 def _load(
     self,
     runner: "IRunner",
     resume_logpath: Any = None,
     resume_model: str = None,
     resume_runner: str = None,
 ):
     if resume_logpath is not None:
         runner.engine.wait_for_everyone()
         if self.mode == "model":
             try:
                 unwrapped_model = runner.engine.unwrap_model(runner.model)
                 unwrapped_model.load_state_dict(
                     load_checkpoint(resume_logpath))
             except BaseException:
                 checkpoint = load_checkpoint(resume_logpath)
                 unpack_checkpoint(checkpoint=checkpoint,
                                   model=runner.model)
         else:
             checkpoint = load_checkpoint(resume_logpath)
             unpack_checkpoint(checkpoint=checkpoint, model=runner.model)
     if resume_runner is not None:
         runner.engine.wait_for_everyone()
         checkpoint = load_checkpoint(resume_runner)
         unpack_checkpoint(
             checkpoint=checkpoint,
             model=runner.model,
             criterion=runner.criterion,
             optimizer=runner.optimizer,
             scheduler=runner.scheduler,
         )
         runner.epoch_step = checkpoint["epoch_step"]
         runner.batch_step = checkpoint["batch_step"]
         runner.sample_step = checkpoint["sample_step"]
     if resume_model is not None:
         runner.engine.wait_for_everyone()
         unwrapped_model = runner.engine.unwrap_model(runner.model)
         unwrapped_model.load_state_dict(load_checkpoint(resume_model))
コード例 #10
0
ファイル: inference.py プロジェクト: Ramstein/Retinopathy2
def model_fn(model_dir):
    """Build the inference model from the checkpoint found in ``model_dir``.

    Reads the names ``checkpoint_fname``, ``apply_softmax`` and ``tta``
    from the enclosing scope to locate the checkpoint and to decide which
    wrappers are applied around the network.

    Args:
        model_dir: directory that contains the checkpoint file

    Returns:
        the model in eval mode, optionally wrapped with softmax, a TTA
        wrapper and ``nn.DataParallel`` (when more than one CUDA device
        is reported)
    """
    # e.g. '/opt/ml/model/model.pth'
    checkpoint_file = path.join(model_dir, checkpoint_fname)

    # already available in this method torch.load(model_path, map_location=lambda storage, loc: storage)
    checkpoint = load_checkpoint(checkpoint_file)
    params = checkpoint['checkpoint_data']['cmd_args']

    model_name = 'seresnext50d_gap'

    # NOTE(review): with the hard-coded name above this fallback to the
    # checkpoint's own model name never runs.
    if model_name is None:
        model_name = params['model']

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    class_names = get_class_names(
        coarse_grading=params.get('coarse', False))
    model = get_model(model_name,
                      pretrained=False,
                      num_classes=len(class_names))
    unpack_checkpoint(checkpoint, model=model)
    report_checkpoint(checkpoint)

    has_cuda = torch.cuda.is_available()
    model.to(torch.device('cuda' if has_cuda else 'cpu'))
    model = model.eval()

    if apply_softmax:
        model = nn.Sequential(model, ApplySoftmaxToLogits())

    # Test-time augmentation wrappers, selected by the ``tta`` setting.
    if tta in ('flip', 'fliplr'):
        model = FlipLRMultiheadTTA(model)

    if tta == 'flip4':
        model = Flip4MultiheadTTA(model)

    if tta == 'fliplr_ms':
        model = MultiscaleFlipLRMultiheadTTA(model)

    with torch.no_grad():
        if torch.cuda.is_available():
            model = model.cuda()
            gpu_count = torch.cuda.device_count()
            if gpu_count > 1:
                model = nn.DataParallel(
                    model, device_ids=list(range(gpu_count)))

    return model
コード例 #11
0
    def load_checkpoint(*, filename, state: _State):
        """Restore the epoch counter and component weights from a file.

        Args:
            filename: path of the checkpoint file to read
            state (_State): run state; its ``epoch`` is overwritten and its
                ``model``, ``criterion``, ``optimizer`` and ``scheduler``
                are handed to ``utils.unpack_checkpoint``

        Raises:
            Exception: if ``filename`` is not an existing file
        """
        # Guard clause: nothing to do when the file is missing.
        if not os.path.isfile(filename):
            raise Exception(f"No checkpoint found at {filename}")

        print(f"=> loading checkpoint {filename}")
        ckpt = utils.load_checkpoint(filename)

        state.epoch = ckpt["epoch"]

        utils.unpack_checkpoint(
            ckpt,
            model=state.model,
            criterion=state.criterion,
            optimizer=state.optimizer,
            scheduler=state.scheduler,
        )

        epoch_note = f"(epoch {ckpt['epoch']})"
        print(f"loaded checkpoint {filename} {epoch_note}")
コード例 #12
0
ファイル: sampler.py プロジェクト: zkid18/catalyst
    def _load_checkpoint(self, *, filepath: str = None, db_server: "IRLDatabase" = None):
        if filepath is not None:
            checkpoint = utils.load_checkpoint(filepath)
        elif db_server is not None:
            checkpoint = db_server.get_checkpoint()
            while checkpoint is None:
                time.sleep(3.0)
                checkpoint = db_server.get_checkpoint()
        else:
            raise NotImplementedError("No checkpoint found")

        # self.checkpoint = checkpoint
        # weights = utils.any2device(checkpoint[self._weights_key], device=self.device)
        # weights = {k: utils.any2device(v, device=self.device) for k, v in weights.items()}
        weights = checkpoint[self._weights_key]
        self.actor.load_state_dict(weights)
        self.actor.to(self.device)
        self.actor.eval()
コード例 #13
0
ファイル: sampler.py プロジェクト: lightforever/catalyst
    def load_checkpoint(
        self, *, filepath: str = None, db_server: DBSpec = None
    ):
        """Fetch a checkpoint and load its weights into ``self.agent``.

        Args:
            filepath (str): checkpoint file to read; takes precedence
            db_server (DBSpec): source polled every 3 seconds until it
                returns a checkpoint; used only when ``filepath`` is
                ``None``

        Raises:
            NotImplementedError: if neither source is given
        """
        if filepath is None and db_server is None:
            raise NotImplementedError

        if filepath is not None:
            fetched = utils.load_checkpoint(filepath)
        else:
            # Poll until the database has a checkpoint to hand out.
            while True:
                fetched = db_server.load_checkpoint()
                if fetched is not None:
                    break
                time.sleep(3.0)

        self.checkpoint = fetched
        state_key = f"{self._weights_sync_mode}_state_dict"
        raw_weights = self.checkpoint[state_key]
        # Move every entry to the sampler's device before loading.
        device_weights = {
            name: utils.any2device(value, device=self._device)
            for name, value in raw_weights.items()
        }
        self.agent.load_state_dict(device_weights)
        self.agent.to(self._device)
        self.agent.eval()
コード例 #14
0
def trace_model_from_checkpoint(
    logdir: Path,
    method_name: str,
    checkpoint_name: str,
    mode: str = "eval",
    requires_grad: bool = False,
):
    """Trace the model stored in ``checkpoints/{checkpoint_name}.pth``.

    The experiment is rebuilt from ``configs/_config.json`` using the copy
    of the experiment directory kept under ``logdir / "code"``.

    Args:
        logdir (Path): path to the run's log directory
        method_name (str): model method name forwarded to ``trace_model``
        checkpoint_name (str): checkpoint file name without ``.pth``
        mode (str): forwarded to ``trace_model``
        requires_grad (bool): forwarded to ``trace_model``

    Returns:
        whatever ``trace_model`` returns for the restored model
    """
    print("Load config")
    config: Dict[str, dict] = safitty.load(
        logdir / "configs" / "_config.json"
    )

    # Only the directory name of the original expdir is needed: the copy
    # stored with the logs is used for reproducibility.
    original_expdir = safitty.get(config, "args", "expdir", apply=Path)
    logged_expdir = Path(logdir) / "code" / original_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = import_experiment_and_runner(logged_expdir)
    experiment: Experiment = ExperimentType(config)

    print(f"Load model state from checkpoints/{checkpoint_name}.pth")
    first_stage = next(iter(experiment.stages))
    model = experiment.get_model(first_stage)
    weights_file = logdir / "checkpoints" / f"{checkpoint_name}.pth"
    utils.unpack_checkpoint(utils.load_checkpoint(weights_file), model=model)

    print("Tracing")
    trace_kwargs = dict(
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
    )
    traced_model = trace_model(model, experiment, RunnerType, **trace_kwargs)

    print("Done")
    return traced_model
コード例 #15
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc',
                        '--accumulation-steps',
                        type=int,
                        default=1,
                        help='Number of batches to process')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='resnet18_gap',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Epoch to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f',
                        '--fold',
                        action='append',
                        type=int,
                        default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--criterion-reg',
                        type=str,
                        default=None,
                        nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-ord',
                        type=str,
                        default=None,
                        nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-cls',
                        type=str,
                        default=['ce'],
                        nargs='+',
                        help='Criterion')
    parser.add_argument('-l1',
                        type=float,
                        default=0,
                        help='L1 regularization loss')
    parser.add_argument('-l2',
                        type=float,
                        default=0,
                        help='L2 regularization loss')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument('-p',
                        '--preprocessing',
                        default=None,
                        help='Preprocessing method')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=multiprocessing.cpu_count(),
                        type=int,
                        help='Num workers')
    parser.add_argument('-a',
                        '--augmentations',
                        default='medium',
                        type=str,
                        help='')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s',
                        '--scheduler',
                        default='multistep',
                        type=str,
                        help='')
    parser.add_argument('--size',
                        default=512,
                        type=int,
                        help='Image size for training & inference')
    parser.add_argument('-wd',
                        '--weight-decay',
                        default=0,
                        type=float,
                        help='L2 weight decay')
    parser.add_argument('-wds',
                        '--weight-decay-step',
                        default=None,
                        type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d',
                        '--dropout',
                        default=0.0,
                        type=float,
                        help='Dropout before head layer')
    parser.add_argument(
        '--warmup',
        default=0,
        type=int,
        help=
        'Number of warmup epochs with 0.1 of the initial LR and frozed encoder'
    )
    parser.add_argument('-x',
                        '--experiment',
                        default=None,
                        type=str,
                        help='Dropout before head layer')

    args = parser.parse_args()

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    use_unsupervised = False
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse
    class_names = get_class_names(coarse_grading)

    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    if folds is None or len(folds) == 0:
        folds = [None]

    for fold in folds:
        torch.cuda.empty_cache()
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'

        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'

        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

        checkpoint_prefix += f'_{random_name}'

        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        os.makedirs(log_dir, exist_ok=False)

        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes,
                          dropout=dropout).cuda()

        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transfering weights from model checkpoint",
                  transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']

            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(collections.OrderedDict([(name,
                                                                    value)]),
                                          strict=False)
                except Exception as e:
                    print(e)

            report_checkpoint(checkpoint)

        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_aptos2019=use_aptos2019,
            use_aptos2015=use_aptos2015,
            use_idrid=use_idrid,
            use_messidor=use_messidor,
            use_unsupervised=False,
            coarse_grading=coarse_grading,
            image_size=image_size,
            augmentation=augmentations,
            preprocessing=preprocessing,
            target_dtype=int,
            fold=fold,
            folds=4)

        train_loader, valid_loader = get_dataloaders(
            train_ds,
            valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            train_sizes=train_sizes,
            balance=balance,
            balance_datasets=balance_datasets,
            balance_unlabeled=False)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        print('Datasets         :', data_dir)
        print('  Train size     :', len(train_loader),
              len(train_loader.dataset))
        print('  Valid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('  Aptos 2019     :', use_aptos2019)
        print('  Aptos 2015     :', use_aptos2015)
        print('  IDRID          :', use_idrid)
        print('  Messidor       :', use_messidor)
        print('Train session    :', directory_prefix)
        print('  FP16 mode      :', fp16)
        print('  Fast mode      :', fast)
        print('  Mixup          :', mixup)
        print('  Balance cls.   :', balance)
        print('  Balance ds.    :', balance_datasets)
        print('  Warmup epoch   :', warmup)
        print('  Train epochs   :', num_epochs)
        print('  Fine-tune ephs :', fine_tune)
        print('  Workers        :', num_workers)
        print('  Fold           :', fold)
        print('  Log dir        :', log_dir)
        print('  Augmentations  :', augmentations)
        print('Model            :', model_name)
        print('  Parameters     :', count_parameters(model))
        print('  Image size     :', image_size)
        print('  Dropout        :', dropout)
        print('  Classes        :', class_names, num_classes)
        print('Optimizer        :', optimizer_name)
        print('  Learning rate  :', learning_rate)
        print('  Batch size     :', batch_size)
        print('  Criterion (cls):', criterion_cls_name)
        print('  Criterion (reg):', criterion_reg_name)
        print('  Criterion (ord):', criterion_ord_name)
        print('  Scheduler      :', scheduler_name)
        print('  Weight decay   :', weight_decay, weight_decay_step)
        print('  L1 reg.        :', l1)
        print('  L2 reg.        :', l2)
        print('  Early stopping :', early_stopping)

        # model training
        callbacks = []
        criterions = {}

        main_metric = 'cls/kappa'
        if criterion_reg_name is not None:
            cb, crits = get_reg_callbacks(criterion_reg_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_ord_name is not None:
            cb, crits = get_ord_callbacks(criterion_ord_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_cls_name is not None:
            cb, crits = get_cls_callbacks(criterion_cls_name,
                                          num_classes=num_classes,
                                          num_epochs=num_epochs,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if l1 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l1,
                                         end_wd=l1,
                                         schedule=None,
                                         prefix='l1',
                                         p=1)
            ]

        if l2 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l2,
                                         end_wd=l2,
                                         schedule=None,
                                         prefix='l2',
                                         p=2)
            ]

        callbacks += [CustomOptimizerCallback()]

        runner = SupervisedRunner(input_key='image')

        # Pretrain/warmup
        if warmup:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer('Adam',
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate * 0.1)

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=None,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'warmup'),
                         num_epochs=warmup,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer

        # Main train
        if num_epochs:
            set_trainable(model.encoder, True, False)

            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate,
                                      weight_decay=weight_decay)

            if use_swa:
                from torchcontrib.optim import SWA
                optimizer = SWA(optimizer,
                                swa_start=len(train_loader),
                                swa_freq=512)

            scheduler = get_scheduler(scheduler_name,
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=num_epochs,
                                      batches_in_epoch=len(train_loader))

            # Additional callbacks that specific to main stage only added here to copy of callbacks
            main_stage_callbacks = callbacks
            if early_stopping:
                es_callback = EarlyStoppingCallback(early_stopping,
                                                    min_delta=1e-4,
                                                    metric=main_metric,
                                                    minimize=False)
                main_stage_callbacks = callbacks + [es_callback]

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=main_stage_callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'main'),
                         num_epochs=num_epochs,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer, scheduler

            best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)

            # Restoring best model from checkpoint
            checkpoint = load_checkpoint(best_checkpoint)
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # Stage 3 - Fine tuning
        if fine_tune:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate)
            scheduler = get_scheduler('multistep',
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=fine_tune,
                                      batches_in_epoch=len(train_loader))

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'finetune'),
                         num_epochs=fine_tune,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)
コード例 #16
0
def main():
    """Command-line entry point: parse arguments, build the model and data pipeline, and train.

    The function performs these steps in order:

    1. Parse command-line arguments (including ``--local_rank`` for distributed launches).
    2. When the ``WORLD_SIZE`` environment variable is present, bind the CUDA device to
       ``--local_rank`` and initialise the NCCL process group.
    3. Build the model with ``get_model`` and optionally transfer weights (``--transfer``)
       or load a checkpoint (``--checkpoint``).
    4. Build train/valid datasets, optionally extended with xView2 data (``--data-dir-xview2``)
       and online pseudolabeling (``--opl``).
    5. When ``--epochs`` is positive, train with ``SupervisedRunner``; afterwards the master
       process cleans the best checkpoint and writes a predicted mask for ``sample_color.jpg``
       into the log directory.

    Raises:
        ValueError: if neither ``--data-dir`` nor the ``INRIA_DATA_DIR`` environment variable is set.
    """
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument(
        "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="b6_unet32_s2", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "-l2",
        "--criterion2",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 2 mask",
    )
    parser.add_argument(
        "-l4",
        "--criterion4",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 4 mask",
    )
    parser.add_argument(
        "-l8",
        "--criterion8",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 8 mask",
    )
    parser.add_argument(
        "-l16",
        "--criterion16",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 16 mask",
    )

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    # WORLD_SIZE is only consulted when present in the environment; without it the run stays single-process.
    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    # A non-distributed run is always treated as the master process.
    is_master = args.is_master or (not args.distributed)

    # Value handed to runner.train(fp16=...): a dict of options, or False when nothing is enabled.
    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}
        if args.fp16:
            distributed_params["amp"] = True
    elif args.fp16:
        distributed_params = {"amp": True}
    else:
        distributed_params = False

    # Offset the seed by the local rank so every process gets a different random stream.
    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    criterions2 = args.criterion2
    criterions4 = args.criterion4
    criterions8 = args.criterion8
    criterions16 = args.criterion16

    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    # Each -l occurrence is a list of tokens (nargs="+"); the first token is the criterion name.
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {"full_size_mask": False}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if any([criterions2, criterions4, criterions8, criterions16]):
        custom_model_kwargs["need_supervision_masks"] = True
        print("Enabling supervision masks")

    model: nn.Module = get_model(model_name, num_classes=16, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    # An explicit experiment name replaces the auto-generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    default_callbacks = [
        JaccardMetricPerImage(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            prefix="jaccard",
            inputs_to_labels=depth2mask,
            outputs_to_labels=decode_depth_mask,
        ),
    ]

    # Checkpointing, hyper-parameter logging and visualisation are only attached on the master process.
    if is_master:
        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="jaccard", target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }
            ),
        ]

        if show:
            visualize_inria_predictions = partial(
                draw_inria_predictions,
                image_key=INPUT_IMAGE_KEY,
                image_id_key=INPUT_IMAGE_ID_KEY,
                targets_key=INPUT_MASK_KEY,
                outputs_key=OUTPUT_MASK_KEY,
                inputs_to_labels=depth2mask,
                outputs_to_labels=decode_depth_mask,
                max_images=16,
            )
            default_callbacks += [
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False),
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True),
            ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
        make_mask_target_fn=mask_to_ce_target,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        # train_sampler may be None (the code below checks for that as well), so fall back to
        # the dataset length instead of dereferencing a missing sampler.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)

        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, num_samples)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    if run_train:
        loaders = collections.OrderedDict()
        # Copy so that stage-specific callbacks do not leak into default_callbacks.
        callbacks = default_callbacks.copy()

        if online_pseudolabeling:
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            if args.distributed:
                label_sampler = DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False)
            else:
                label_sampler = None

            loaders["infer"] = DataLoader(
                unlabeled_label,
                batch_size=batch_size // 2,
                num_workers=num_workers,
                pin_memory=True,
                sampler=label_sampler,
                drop_last=False,
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="infer",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(
                    train_sampler, args.world_size, args.local_rank, shuffle=True
                )
            else:
                train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True)
            valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            # DataLoader does not accept shuffle=True together with a sampler.
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(
            valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler
        )

        loss_callbacks, loss_criterions = get_criterions(
            criterions, criterions2, criterions4, criterions8, criterions16
        )
        callbacks += loss_callbacks

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        log_dir = os.path.join("runs", checkpoint_prefix)

        if is_master:
            # exist_ok=False: refuse to write into the directory of an earlier run with the same name.
            os.makedirs(log_dir, exist_ok=False)
            config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
            with open(config_fname, "w") as f:
                train_session_args = vars(args)
                f.write(json.dumps(train_session_args, indent=2))

            print("Train session    :", checkpoint_prefix)
            print("  FP16 mode      :", fp16)
            print("  Fast mode      :", args.fast)
            print("  Train mode     :", train_mode)
            print("  Epochs         :", num_epochs)
            print("  Workers        :", num_workers)
            print("  Data dir       :", data_dir)
            print("  Log dir        :", log_dir)
            print("  Augmentations  :", augmentations)
            print("  Train size     :", "batches", len(loaders["train"]), "dataset", len(train_ds))
            print("  Valid size     :", "batches", len(loaders["valid"]), "dataset", len(valid_ds))
            print("Model            :", model_name)
            print("  Parameters     :", count_parameters(model))
            print("  Image size     :", image_size)
            print("Optimizer        :", optimizer_name)
            print("  Learning rate  :", learning_rate)
            print("  Batch size     :", batch_size)
            print("  Criterion      :", criterions)
            print("  Use weight mask:", need_weight_mask)
            if args.distributed:
                print("Distributed")
                print("  World size     :", args.world_size)
                print("  Local rank     :", args.local_rank)
                print("  Is master      :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=loss_criterions,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        if is_master:
            best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")

            model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
            clean_checkpoint(best_checkpoint, model_checkpoint)

            unpack_checkpoint(torch.load(model_checkpoint), model=model)

            mask = predict(
                model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size
            )
            # Binarise the prediction into a 0/255 uint8 image before saving.
            mask = ((mask > 0) * 255).astype(np.uint8)
            name = os.path.join(log_dir, "sample_color.jpg")
            cv2.imwrite(name, mask)
def main():
    """Command-line entry point: train a single-logit classifier and/or run inference with it.

    The ``--run-mode`` argument selects what happens:

    * ``fit`` / ``fit_predict`` - train the model with ``SupervisedRunner`` and log into a fresh
      ``runs/adversarial/<model>/<timestamp>_<criterion>`` directory.
    * ``predict`` / ``fit_predict`` - unless ``--fast`` is set, load ``best.pth`` found under the
      log directory, run the model over every image listed in ``train.csv`` and write the
      sigmoid scores to ``test_in_train.csv`` in the log directory.

    ``--transfer`` copies matching weights from another checkpoint on a best-effort basis and
    ``--checkpoint`` restores model weights (and, when training, the optimizer state).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory for INRIA sattelite dataset')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='cls_resnet18',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Epoch to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-fe', '--freeze-encoder', action='store_true')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('-l',
                        '--criterion',
                        type=str,
                        default='bce',
                        help='Criterion')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=multiprocessing.cpu_count(),
                        type=int,
                        help='Num workers')
    parser.add_argument('-a',
                        '--augmentations',
                        default='hard',
                        type=str,
                        help='')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-tm',
                        '--train-mode',
                        default='random',
                        type=str,
                        help='')
    parser.add_argument('-rm',
                        '--run-mode',
                        default='fit_predict',
                        type=str,
                        help='')
    parser.add_argument('--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    run_mode = args.run_mode
    log_dir = None
    fp16 = args.fp16
    freeze_encoder = args.freeze_encoder

    run_train = run_mode == 'fit_predict' or run_mode == 'fit'
    run_predict = run_mode == 'fit_predict' or run_mode == 'predict'

    model = maybe_cuda(get_model(model_name, num_classes=1))

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint['model_state_dict']

        # Best-effort transfer: load tensors one at a time so a single failing entry
        # is reported and skipped instead of aborting the whole transfer.
        for name, value in pretrained_dict.items():
            try:
                model.load_state_dict(collections.OrderedDict([(name, value)]),
                                      strict=False)
            except Exception as e:
                print(e)

    checkpoint = None
    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        checkpoint_epoch = checkpoint['epoch']
        print('Loaded model weights from:', args.checkpoint)
        print('Epoch                    :', checkpoint_epoch)
        print('Metrics (Train):', 'f1  :',
              checkpoint['epoch_metrics']['train']['f1_score'], 'loss:',
              checkpoint['epoch_metrics']['train']['loss'])
        print('Metrics (Valid):', 'f1  :',
              checkpoint['epoch_metrics']['valid']['f1_score'], 'loss:',
              checkpoint['epoch_metrics']['valid']['loss'])

        # Two levels above the checkpoint file; used as the log directory for predict-only runs.
        log_dir = os.path.dirname(
            os.path.dirname(fs.auto_file(args.checkpoint)))

    if run_train:

        if freeze_encoder:
            set_trainable(model.encoder, trainable=False, freeze_bn=True)

        criterion = get_loss(args.criterion)
        parameters = get_optimizable_parameters(model)
        optimizer = get_optimizer(optimizer_name, parameters, learning_rate)

        if checkpoint is not None:
            # Restoring the optimizer is optional; training continues with a fresh one on failure.
            try:
                unpack_checkpoint(checkpoint, optimizer=optimizer)
                print('Restored optimizer state from checkpoint')
            except Exception as e:
                print('Failed to restore optimizer state from checkpoint', e)

        train_loader, valid_loader = get_dataloaders(
            data_dir=data_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        current_time = datetime.now().strftime('%b%d_%H_%M')
        prefix = f'adversarial/{args.model}/{current_time}_{args.criterion}'

        if fp16:
            prefix += '_fp16'

        if fast:
            prefix += '_fast'

        # Training always logs into a new directory, replacing any log_dir derived from --checkpoint.
        log_dir = os.path.join('runs', prefix)
        os.makedirs(log_dir, exist_ok=False)

        scheduler = MultiStepLR(optimizer,
                                milestones=[10, 30, 50, 70, 90],
                                gamma=0.5)

        print('Train session    :', prefix)
        print('\tFP16 mode      :', fp16)
        print('\tFast mode      :', args.fast)
        print('\tTrain mode     :', train_mode)
        print('\tEpochs         :', num_epochs)
        print('\tEarly stopping :', early_stopping)
        print('\tWorkers        :', num_workers)
        print('\tData dir       :', data_dir)
        print('\tLog dir        :', log_dir)
        print('\tAugmentations  :', augmentations)
        print('\tTrain size     :', len(train_loader),
              len(train_loader.dataset))
        print('\tValid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('Model            :', model_name)
        print('\tParameters     :', count_parameters(model))
        print('\tImage size     :', image_size)
        print('\tFreeze encoder :', freeze_encoder)
        print('Optimizer        :', optimizer_name)
        print('\tLearning rate  :', learning_rate)
        print('\tBatch size     :', batch_size)
        print('\tCriterion      :', args.criterion)

        # model training
        visualization_fn = partial(draw_classification_predictions,
                                   class_names=['Train', 'Test'])

        callbacks = [
            F1ScoreCallback(),
            AUCCallback(),
            ShowPolarBatchesCallback(visualization_fn,
                                     metric='f1_score',
                                     minimize=False),
        ]

        if early_stopping:
            callbacks += [
                EarlyStoppingCallback(early_stopping,
                                      metric='auc',
                                      minimize=False)
            ]

        runner = SupervisedRunner(input_key='image')
        runner.train(fp16=fp16,
                     model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     callbacks=callbacks,
                     loaders=loaders,
                     logdir=log_dir,
                     num_epochs=num_epochs,
                     verbose=True,
                     main_metric='auc',
                     minimize_metric=False,
                     state_kwargs={"cmd_args": vars(args)})

    if run_predict and not fast:
        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = load_checkpoint(
            fs.auto_file('best.pth', where=log_dir))
        unpack_checkpoint(best_checkpoint, model=model)

        model.eval()

        train_csv = pd.read_csv(os.path.join(data_dir, 'train.csv'))
        train_csv['id_code'] = train_csv['id_code'].apply(
            lambda x: os.path.join(data_dir, 'train_images', f'{x}.png'))
        test_ds = RetinopathyDataset(train_csv['id_code'],
                                     None,
                                     get_test_aug(image_size),
                                     target_as_array=True)
        test_dl = DataLoader(test_ds,
                             batch_size,
                             pin_memory=True,
                             num_workers=num_workers)

        test_ids = []
        test_preds = []

        # torch.no_grad() only takes effect as a context manager (or decorator); calling it
        # as a bare statement leaves autograd enabled, so the loop is wrapped explicitly.
        with torch.no_grad():
            for batch in tqdm(test_dl, desc='Inference'):
                images = batch['image'].cuda()
                outputs = model(images)
                predictions = to_numpy(outputs['logits'].sigmoid().squeeze(1))
                test_ids.extend(batch['image_id'])
                test_preds.extend(predictions)

        df = pd.DataFrame.from_dict({
            'id_code': test_ids,
            'is_test': test_preds
        })
        df.to_csv(os.path.join(log_dir, 'test_in_train.csv'), index=False)
コード例 #18
0
def _split_loss_spec(spec):
    """Return ``(loss_name, loss_weight)`` for one command-line loss spec.

    Loss options are declared with ``action="append", nargs="+"``, so every
    occurrence arrives as a list of strings: ``["bce", "0.5"]`` when a weight
    is given and ``["bce"]`` when it is not.  A missing weight defaults to
    ``1.0``.  A bare (non-sequence) value is treated as a name with weight 1.
    """
    if isinstance(spec, (list, tuple)):
        if len(spec) == 2:
            return spec[0], spec[1]
        return spec[0], 1.0
    return spec, 1.0


def _register_losses(specs, criterions_dict, callbacks, losses, *, prefix,
                     input_key, output_key, report_key, **criterion_kwargs):
    """Create one criterion callback per spec and record it in place.

    Args:
        specs: Iterable of loss specs as produced by argparse.
        criterions_dict: Dict updated with the criterions that
            ``get_criterion_callback`` returns.
        callbacks: List the new criterion callbacks are appended to.
        losses: List the new criterion names are appended to.
        prefix: Prefix forwarded to ``get_criterion_callback``.
        input_key: Target key forwarded to ``get_criterion_callback``.
        output_key: Prediction key forwarded to ``get_criterion_callback``.
        report_key: Label printed in front of the "Using loss" message.
        **criterion_kwargs: Extra keyword arguments forwarded unchanged
            (e.g. ``ignore_index``).
    """
    for spec in specs:
        loss_name, loss_weight = _split_loss_spec(spec)
        cd, criterion_callback, criterion_name = get_criterion_callback(
            loss_name,
            prefix=prefix,
            input_key=input_key,
            output_key=output_key,
            loss_weight=float(loss_weight),
            **criterion_kwargs,
        )
        criterions_dict.update(cd)
        callbacks.append(criterion_callback)
        losses.append(criterion_name)
        print(report_key, "Using loss", loss_name, loss_weight)


def main():
    """Train a model on labeled data concatenated with a pseudolabeled set.

    Parses the command line, builds the model (optionally transferring or
    loading weights), creates ``runs/<prefix>`` and dumps the arguments there
    as JSON, assembles the loss callbacks requested on the command line,
    trains with ``SupervisedRunner`` and finally passes the best checkpoint
    of the run to ``clean_checkpoint``.

    Raises:
        FileExistsError: If the run directory already exists.
        ValueError: If training is requested but no loss was configured.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        required=True,
                        help="Data directory for INRIA sattelite dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    parser.add_argument(
        "--disaster-type-loss",
        type=str,
        default=None,  # [["ce", 1.0]],
        action="append",
        nargs="+",
        help="Criterion for classifying disaster type",
    )
    parser.add_argument(
        "--damage-type-loss",
        type=str,
        default=None,  # [["bce", 1.0]],
        action="append",
        nargs="+",
        help=
        "Criterion for classifying presence of building with particular damage type",
    )

    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion")
    # One auxiliary loss option per supported output stride.
    for stride in (4, 8, 16, 32):
        parser.add_argument(f"--mask{stride}",
                            type=str,
                            default=None,
                            action="append",
                            nargs="+",
                            help=f"Criterion for mask with stride {stride}")
    parser.add_argument("--embedding", type=str, default=None)

    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="safe",
                        type=str,
                        help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("--fold", default=0, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=0.0,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("-pl", "--pseudolabeling", type=str, required=True)
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--only-buildings", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")
    parser.add_argument("--crops",
                        action="store_true",
                        help="Train on random crops")
    parser.add_argument("--post-transform", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    segmentation_losses = args.criterion
    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    only_buildings = args.only_buildings
    freeze_bn = args.freeze_bn
    train_on_crops = args.crops
    enable_post_image_transform = args.post_transform
    disaster_type_loss = args.disaster_type_loss
    train_batch_size = args.batch_size
    embedding_criterion = args.embedding
    damage_type_loss = args.damage_type_loss
    pseudolabels_dir = args.pseudolabeling

    # Compute batch size for validation: when training on crops the batch is
    # rescaled by crop area relative to a 1024x1024 image (at least 1).
    if train_on_crops:
        valid_batch_size = max(1,
                               (train_batch_size *
                                (image_size[0] * image_size[1])) // (1024**2))
    else:
        valid_batch_size = train_batch_size

    run_train = num_epochs > 0

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        torch_utils.freeze_bn(model)
        print("Freezing bn params")

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None)
    main_metric = "weighted_f1"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_{args.size}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if pseudolabels_dir:
        checkpoint_prefix += "_pseudo"

    if train_on_crops:
        checkpoint_prefix += "_crops"

    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to write into an existing run directory.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        f.write(json.dumps(cmd_args, indent=2))

    default_callbacks = [
        CompetitionMetricCallback(input_key=INPUT_MASK_KEY,
                                  output_key=OUTPUT_MASK_KEY,
                                  prefix="weighted_f1"),
        ConfusionMatrixCallback(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            class_names=[
                "land", "no_damage", "minor_damage", "major_damage",
                "destroyed"
            ],
            ignore_index=UNLABELED_SAMPLE,
        ),
    ]

    if show:
        default_callbacks += [
            ShowPolarBatchesCallback(draw_predictions,
                                     metric=main_metric + "_batch",
                                     minimize=False)
        ]

    # The returned sampler is not used below: the train loader shuffles the
    # concatenation of the labeled and pseudolabeled datasets instead.
    train_ds, valid_ds, _train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        fast=fast,
        fold=fold,
        balance=balance,
        only_buildings=only_buildings,
        train_on_crops=train_on_crops,
        crops_multiplication_factor=1,
        enable_post_image_transform=enable_post_image_transform,
    )

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        unlabeled_train = get_pseudolabeling_dataset(
            data_dir,
            include_masks=True,
            image_size=image_size,
            augmentation="medium_nmd",
            train_on_crops=train_on_crops,
            enable_post_image_transform=enable_post_image_transform,
            pseudolabels_dir=pseudolabels_dir,
        )

        train_ds = train_ds + unlabeled_train

        print("Using online pseudolabeling with ", len(unlabeled_train),
              "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=valid_batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True)

        # Create losses.  Every option defaults to None when it was not given
        # on the command line, so each group is guarded accordingly.
        if segmentation_losses is not None:
            _register_losses(
                segmentation_losses,
                criterions_dict,
                callbacks,
                losses,
                prefix="segmentation",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_MASK_KEY,
                report_key=INPUT_MASK_KEY,
            )

        # Auxiliary mask losses at reduced output strides.
        strided_masks = (
            ("mask4", args.mask4, OUTPUT_MASK_4_KEY),
            ("mask8", args.mask8, OUTPUT_MASK_8_KEY),
            ("mask16", args.mask16, OUTPUT_MASK_16_KEY),
            ("mask32", args.mask32, OUTPUT_MASK_32_KEY),
        )
        for mask_prefix, mask_specs, mask_output_key in strided_masks:
            if mask_specs is not None:
                _register_losses(
                    mask_specs,
                    criterions_dict,
                    callbacks,
                    losses,
                    prefix=mask_prefix,
                    input_key=INPUT_MASK_KEY,
                    output_key=mask_output_key,
                    report_key=mask_output_key,
                )

        if disaster_type_loss is not None:
            callbacks += [
                ConfusionMatrixCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    class_names=DISASTER_TYPES,
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                    prefix=f"{DISASTER_TYPE_KEY}/confusion_matrix",
                ),
                AccuracyCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    prefix=f"{DISASTER_TYPE_KEY}/accuracy",
                    activation="Softmax",
                ),
            ]

            _register_losses(
                disaster_type_loss,
                criterions_dict,
                callbacks,
                losses,
                prefix=DISASTER_TYPE_KEY,
                input_key=DISASTER_TYPE_KEY,
                output_key=DISASTER_TYPE_KEY,
                report_key=DISASTER_TYPE_KEY,
                ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
            )

        if damage_type_loss is not None:
            callbacks += [
                # MultilabelConfusionMatrixCallback(
                #     input_key=DAMAGE_TYPE_KEY,
                #     output_key=DAMAGE_TYPE_KEY,
                #     class_names=DAMAGE_TYPES,
                #     prefix=f"{DAMAGE_TYPE_KEY}/confusion_matrix",
                # ),
                AccuracyCallback(
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    prefix=f"{DAMAGE_TYPE_KEY}/accuracy",
                    activation="Sigmoid",
                    threshold=0.5,
                )
            ]

            _register_losses(
                damage_type_loss,
                criterions_dict,
                callbacks,
                losses,
                prefix=DAMAGE_TYPE_KEY,
                input_key=DAMAGE_TYPE_KEY,
                output_key=DAMAGE_TYPE_KEY,
                report_key=DAMAGE_TYPE_KEY,
            )

        if embedding_criterion is not None:
            cd, embedding_callback, criterion_name = get_criterion_callback(
                embedding_criterion,
                prefix="embedding",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_EMBEDDING_KEY,
                loss_weight=1.0,
            )
            criterions_dict.update(cd)
            callbacks.append(embedding_callback)
            losses.append(criterion_name)
            print(OUTPUT_EMBEDDING_KEY, "Using loss", embedding_criterion)

        # Fail with a clear message instead of training with nothing to
        # optimize when no loss option was supplied.
        if not losses:
            raise ValueError(
                "No loss configured: pass at least one of --criterion, "
                "--mask4/8/16/32, --disaster-type-loss, --damage-type-loss "
                "or --embedding")

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        # CyclicLR is stepped per batch rather than per epoch.
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("Data             ")
        print("  Augmentations  :", augmentations)
        print("  Train size     :", len(loaders["train"]), len(train_ds))
        print("  Valid size     :", len(loaders["valid"]), len(valid_ds))
        print("  Image size     :", image_size)
        print("  Train on crops :", train_on_crops)
        print("  Balance        :", balance)
        print("  Buildings only :", only_buildings)
        print("  Post transform :", enable_post_image_transform)
        print("  Pseudolabels   :", pseudolabels_dir)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("  Criterion      :", segmentation_losses)
        print("  Damage type    :", damage_type_loss)
        print("  Disaster type  :", disaster_type_loss)
        print(" Embedding      :", embedding_criterion)

        # Single source of truth for the stage directory: the checkpoints are
        # read back from the same place the runner was told to write them.
        stage_logdir = os.path.join(log_dir, "opl")

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=stage_logdir,
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        # Training is finished. Export the best checkpoint of this stage.
        best_checkpoint = os.path.join(stage_logdir, "checkpoints",
                                       "best.pth")

        model_checkpoint = os.path.join(stage_logdir, "checkpoints",
                                        f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        del optimizer, loaders
コード例 #19
0
def main(args, _=None):
    """Entry point of the ``catalyst-data text2embeddings`` script.

    Builds a BERT model and tokenizer from ``args``, runs every non-null
    text of ``args.txt_col`` from ``args.in_csv`` through it and stores the
    pooled features of each (selected) layer in a memory-mapped file
    ``{args.out_prefix}.{layer}.npy``.  The filtered table is saved as
    ``{args.out_prefix}.df.csv``.  With ``args.force_save`` every memmap is
    additionally flushed and re-saved through ``np.save``.

    Args:
        args: Parsed command-line namespace.
        _: Unused second positional argument.
    """
    batch_size = args.batch_size
    num_workers = args.num_workers
    max_length = args.max_length
    pooling_groups = args.pooling.split(",")
    bert_level = args.bert_level

    # Selecting a single level only makes sense with hidden states enabled.
    assert bert_level is None or args.output_hidden_states, \
        "You need hidden states output for level specification"

    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    if not getattr(args, "in_huggingface", False):
        # Local config / vocabulary files.
        model_config = BertConfig.from_pretrained(args.in_config)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel(config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_vocab)
    else:
        # Model referenced by its huggingface identifier.
        pretrained_name = args.in_huggingface
        model_config = BertConfig.from_pretrained(pretrained_name)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel.from_pretrained(pretrained_name,
                                          config=model_config)
        tokenizer = BertTokenizer.from_pretrained(pretrained_name)

    weights_path = getattr(args, "in_model", None)
    if weights_path is not None:
        # The loaded object is wrapped under "model_state_dict" before
        # being handed to unpack_checkpoint.
        utils.unpack_checkpoint(
            checkpoint={
                "model_state_dict": utils.load_checkpoint(weights_path)
            },
            model=model,
        )

    model = model.eval()
    model, _, _, _, device = utils.process_components(model=model)

    table = pd.read_csv(args.in_csv).dropna(subset=[args.txt_col])
    table.to_csv(f"{args.out_prefix}.df.csv", index=False)
    table = table.reset_index().drop("index", axis=1)
    records = list(table.to_dict("index").values())
    num_samples = len(records)

    text_tokenizer = partial(
        tokenize_text,
        strip=args.strip,
        lowercase=args.lowercase,
        remove_punctuation=args.remove_punctuation,
    )
    open_fn = LambdaReader(
        input_key=args.txt_col,
        output_key=None,
        lambda_fn=text_tokenizer,
        tokenizer=tokenizer,
        max_length=max_length,
    )

    dataloader = utils.get_loader(
        records,
        open_fn,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    if args.verbose:
        dataloader = tqdm(dataloader)

    def selected_layers(layer_outputs):
        """Yield ``(storage_key, value)`` for every layer that is kept."""
        for name, value in layer_outputs.items():
            if bert_level is not None and bert_level != name:
                continue
            # Non-string layer names are zero-padded to two digits.
            key = name if isinstance(name, str) else f"{name:02d}"
            yield key, value

    features = {}
    with torch.no_grad():
        for batch_idx, batch_input in enumerate(dataloader):
            batch_input = utils.any2device(batch_input, device)
            batch_output = model(**batch_input)

            mask = None
            if args.mask_for_max_length:
                mask = batch_input["attention_mask"].unsqueeze(-1)

            # A DDP-wrapped model (several gpus) exposes the network as
            # ``.module``; otherwise (cpu or one gpu) use the model itself.
            if utils.check_ddp_wrapped(model):
                bert_config = model.module.config
            else:
                bert_config = model.config

            batch_features = process_bert_output(
                bert_output=batch_output,
                hidden_size=bert_config.hidden_size,
                output_hidden_states=bert_config.output_hidden_states,
                pooling_groups=pooling_groups,
                mask=mask,
            )

            # create storage based on network output
            if batch_idx == 0:
                for key, value in selected_layers(batch_features):
                    _rows, embedding_size = value.shape
                    features[key] = np.memmap(
                        f"{args.out_prefix}.{key}.npy",
                        dtype=np.float32,
                        mode="w+",
                        shape=(num_samples, embedding_size),
                    )

            first_row = batch_idx * batch_size
            last_row = min((batch_idx + 1) * batch_size, num_samples)
            indices = np.arange(first_row, last_row)
            for key, value in selected_layers(batch_features):
                features[key][indices] = _detach(value)

    if args.force_save:
        for key, mmap in features.items():
            mmap.flush()
            np.save(f"{args.out_prefix}.{key}.force.npy",
                    mmap,
                    allow_pickle=False)
コード例 #20
0
            EarlyStoppingCallback(patience=10, min_delta=0.001),
            CriterionCallback(),
        ]
    elif args.task == "classification":
        callbacks = [
            AUCCallback(class_names=["Fish", "Flower", "Gravel", "Sugar"],
                        num_classes=4),
            EarlyStoppingCallback(patience=10, min_delta=0.001),
            CriterionCallback(),
        ]

    if args.gradient_accumulation:
        callbacks.append(
            OptimizerCallback(accumulation_steps=args.gradient_accumulation))

    checkpoint = utils.load_checkpoint(f"{logdir}/checkpoints/best.pth")
    model.cuda()
    utils.unpack_checkpoint(checkpoint, model=model)
    #
    #
    runner = SupervisedRunner()
    if args.train:
        print("Training")
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            main_metric="dice",
            minimize_metric=False,
            scheduler=scheduler,
            loaders=loaders,
コード例 #21
0
def run_stage_training(model: Union[TimmRgbModel,
                                    YCrCbModel], config: StageConfig,
                       exp_config: ExperimenetConfig, experiment_dir: str):
    """Run a single training stage described by ``config``.

    Steps: freeze the model as configured, build datasets / loaders /
    criterions / callbacks, print a summary of the stage, train with
    ``SupervisedRunner`` into ``<experiment_dir>/<stage_name>``, pass the
    stage's ``best.pth`` to ``clean_checkpoint`` and, when
    ``config.restore_best`` is set, load the resulting checkpoint back into
    ``model``.

    Args:
        model: Model to train; it is modified in place.
        config: Settings of this stage.
        exp_config: Settings shared by the whole experiment.
        experiment_dir: Directory that receives the stage sub-directory and
            the exported checkpoint.
    """
    # Preparing model
    freeze_model(model, freeze_bn=config.freeze_bn)

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=exp_config.data_dir,
        image_size=config.image_size,
        augmentation=config.augmentations,
        balance=config.balance,
        fast=config.fast,
        fold=exp_config.fold,
        features=model.required_features,
        obliterate_p=config.obliterate_p,
    )

    criterions_dict, loss_callbacks = get_criterions(
        modification_flag=config.modification_flag_loss,
        modification_type=config.modification_type_loss,
        embedding_loss=config.embedding_loss,
        feature_maps_loss=config.feature_maps_loss,
        num_epochs=config.epochs,
        mixup=config.mixup,
        cutmix=config.cutmix,
        tsa=config.tsa,
    )

    hyper_parameters = {
        "model": exp_config.model_name,
        "scheduler": config.schedule,
        "optimizer": config.optimizer,
        "augmentations": config.augmentations,
        "size": config.image_size[0],
        "weight_decay": config.weight_decay,
    }
    callbacks = loss_callbacks + [
        OptimizerCallback(accumulation_steps=config.accumulation_steps,
                          decouple_weight_decay=False),
        HyperParametersCallback(hparam_dict=hyper_parameters),
    ]

    if config.show:
        callbacks.append(
            ShowPolarBatchesCallback(draw_predictions,
                                     metric="loss",
                                     minimize=True))

    # Shuffle only when no explicit sampler was returned with the datasets.
    loaders = collections.OrderedDict([
        ("train",
         DataLoader(
             train_ds,
             batch_size=config.train_batch_size,
             num_workers=exp_config.num_workers,
             pin_memory=True,
             drop_last=True,
             shuffle=train_sampler is None,
             sampler=train_sampler,
         )),
        ("valid",
         DataLoader(valid_ds,
                    batch_size=config.valid_batch_size,
                    num_workers=exp_config.num_workers,
                    pin_memory=True)),
    ])

    # Stage summary: each row is printed as the arguments of one print call.
    summary = [
        ("Stage            :", config.stage_name),
        ("  FP16 mode      :", config.fp16),
        ("  Fast mode      :", config.fast),
        ("  Epochs         :", config.epochs),
        ("  Workers        :", exp_config.num_workers),
        ("  Data dir       :", exp_config.data_dir),
        ("  Experiment dir :", experiment_dir),
        ("Data              ", ),
        ("  Augmentations  :", config.augmentations),
        ("  Obliterate (%) :", config.obliterate_p),
        ("  Negative images:", config.negative_image_dir),
        ("  Train size     :", len(loaders["train"]), "batches",
         len(train_ds), "samples"),
        ("  Valid size     :", len(loaders["valid"]), "batches",
         len(valid_ds), "samples"),
        ("  Image size     :", config.image_size),
        ("  Balance        :", config.balance),
        ("  Mixup          :", config.mixup),
        ("  CutMix         :", config.cutmix),
        ("  TSA            :", config.tsa),
        ("Model            :", exp_config.model_name),
        ("  Parameters     :", count_parameters(model)),
        ("  Dropout        :", exp_config.dropout),
        ("Optimizer        :", config.optimizer),
        ("  Learning rate  :", config.learning_rate),
        ("  Weight decay   :", config.weight_decay),
        ("  Scheduler      :", config.schedule),
        ("  Batch sizes    :", config.train_batch_size,
         config.valid_batch_size),
        ("Losses            ", ),
        ("  Flag           :", config.modification_flag_loss),
        ("  Type           :", config.modification_type_loss),
        ("  Embedding      :", config.embedding_loss),
        ("  Feature maps   :", config.feature_maps_loss),
    ]
    for summary_row in summary:
        print(*summary_row)

    optimizer = get_optimizer(
        config.optimizer,
        get_optimizable_parameters(model),
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler = get_scheduler(
        config.schedule,
        optimizer,
        lr=config.learning_rate,
        num_epochs=config.epochs,
        batches_in_epoch=len(loaders["train"]),
    )
    # CyclicLR gets a per-batch scheduler callback.
    if isinstance(scheduler, CyclicLR):
        callbacks.append(SchedulerCallback(mode="batch"))

    # model training
    stage_dir = os.path.join(experiment_dir, config.stage_name)
    runner = SupervisedRunner(input_key=model.required_features,
                              output_key=None)
    runner.train(
        fp16=config.fp16,
        model=model,
        criterion=criterions_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=stage_dir,
        num_epochs=config.epochs,
        verbose=config.verbose,
        main_metric=config.main_metric,
        minimize_metric=config.main_metric_minimize,
        checkpoint_data={"config": config},
    )

    # Drop the training objects before touching checkpoints.
    del optimizer, loaders, callbacks, runner

    best_checkpoint = os.path.join(stage_dir, "checkpoints", "best.pth")
    model_checkpoint = os.path.join(experiment_dir,
                                    f"{exp_config.checkpoint_prefix}.pth")
    clean_checkpoint(best_checkpoint, model_checkpoint)

    # Restore state of best model
    if config.restore_best:
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

    # Some memory cleanup
    torch.cuda.empty_cache()
    gc.collect()
コード例 #22
0
    #     criterion = lovasz_softmax()
    elif args.loss == 'BCEMulticlassDiceLoss':
        criterion = BCEMulticlassDiceLoss()
    elif args.loss == 'MulticlassDiceMetricCallback':
        criterion = MulticlassDiceMetricCallback()
    elif args.loss == 'BCE':
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = smp.utils.losses.BCEDiceLoss(eps=1.)

    if args.make_prediction:
        print('MAKING PREDICTIONS')
        model1 = get_model(model_type=args.segm_type, encoder=args.encoder, encoder_weights=args.encoder_weights,
                          activation=None, task=args.task)
        loaders['test'] = test_loader
        checkpoint = utils.load_checkpoint('/home/dex/Desktop/ml/cloud artgor/logs/weights/lb 6582/best.pth')
        model1.cuda()
        utils.unpack_checkpoint(checkpoint, model=model1)

        model2 = get_model(model_type=args.segm_type, encoder=args.encoder, encoder_weights=args.encoder_weights,
                          activation=None, task=args.task)
        checkpoint = utils.load_checkpoint('/home/dex/Desktop/ml/cloud artgor/logs/weights/lb 6582/cont fitted/best.pth')
        model2.cuda()
        utils.unpack_checkpoint(checkpoint, model=model2)

        model = Model([model1, model2])
        runner = SupervisedRunner(model=model)

        with open(f'{logdir}/class_params.json', 'r') as f:
            class_params = json.load(f)
        print('prediction postprocess params', class_params)
コード例 #23
0
ファイル: trace.py プロジェクト: saswat0/catalyst
def trace_model_from_checkpoint(
    logdir: Path,
    method_name: str,
    checkpoint_name: str,
    stage: str = None,
    loader: Union[str, int] = None,
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
):
    """
    Traces a model restored from a checkpoint stored in a Catalyst logdir.

    The config is read from ``<logdir>/configs/_config.json``, the
    experiment and runner classes are imported from the copy of the
    experiment directory kept under ``<logdir>/code``, the model for
    ``stage`` is created and the weights from
    ``<logdir>/checkpoints/<checkpoint_name>.pth`` are loaded into it
    before tracing.

    Args:
        logdir (Union[str, Path]): Path to Catalyst logdir with model
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        checkpoint_name (str): Name of model checkpoint to use
            (without the ``.pth`` extension)
        stage (str): experiment's stage name; when ``None`` the first
            stage of the experiment is used
        loader (Union[str, int]): experiment's loader name or its index;
            when ``None`` index ``0`` is used
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        the traced model
    """
    # The docstring allows a plain string, so normalize once before any
    # ``/`` path arithmetic (``str / str`` would raise TypeError).
    logdir = Path(logdir)

    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
    print("Load config")
    config: Dict[str, dict] = load_config(config_path)
    # ``runner_params`` may be missing or explicitly null in the config.
    runner_params = config.get("runner_params", {}) or {}

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir = logdir / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = import_experiment_and_runner(expdir)
    experiment: ConfigExperiment = ExperimentType(config)

    print(f"Load model state from checkpoints/{checkpoint_name}.pth")
    if stage is None:
        stage = list(experiment.stages)[0]

    model = experiment.get_model(stage)
    checkpoint = load_checkpoint(checkpoint_path)
    unpack_checkpoint(checkpoint, model=model)

    runner: RunnerType = RunnerType(**runner_params)
    runner.model, runner.device = model, device

    if loader is None:
        loader = 0
    batch = get_native_batch_from_loaders(
        loaders=experiment.get_loaders(stage), loader=loader)

    # function to run prediction on batch
    def predict_fn(model, inputs, **kwargs):
        # Temporarily swap the runner's model so ``predict_batch`` runs on
        # the model handed in by the tracer; always swap it back, even if
        # the prediction raises.
        _model = runner.model
        runner.model = model
        try:
            return runner.predict_batch(inputs, **kwargs)
        finally:
            runner.model = _model

    print("Tracing")
    traced_model = trace_model(
        model=model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    print("Done")
    return traced_model
コード例 #24
0
def main():
    """
    Load a checkpoint, predict masks for the test images and save them.

    Command-line arguments select the model, data directory, checkpoint,
    batch size and optional test-time augmentation. Predictions are
    written under ``<run_dir>/submit`` (``run_dir`` is two directory levels
    above the checkpoint file); every mask is additionally re-encoded with
    the external ``gdal_translate`` tool into a ``*_compressed`` directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="unet", help="")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=None,
                        required=True,
                        help="Data dir")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        required=True,
        help="Checkpoint filename to use as initial model weights",
    )
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=16,
                        help="Batch size for inference")
    parser.add_argument("-tta",
                        "--tta",
                        default=None,
                        type=str,
                        help="Type of TTA to use [fliplr, d4]")
    args = parser.parse_args()

    data_dir = args.data_dir
    checkpoint_file = auto_file(args.checkpoint)
    run_dir = os.path.dirname(os.path.dirname(checkpoint_file))
    out_dir = os.path.join(run_dir, "submit")
    os.makedirs(out_dir, exist_ok=True)

    # Report what was stored alongside the weights.
    checkpoint = load_checkpoint(checkpoint_file)
    checkpoint_epoch = checkpoint["epoch"]
    print("Loaded model weights from", args.checkpoint)
    print("Epoch   :", checkpoint_epoch)
    print(
        "Metrics (Train):",
        "IoU:",
        checkpoint["epoch_metrics"]["train"]["jaccard"],
        "Acc:",
        checkpoint["epoch_metrics"]["train"]["accuracy"],
    )
    print(
        "Metrics (Valid):",
        "IoU:",
        checkpoint["epoch_metrics"]["valid"]["jaccard"],
        "Acc:",
        checkpoint["epoch_metrics"]["valid"]["accuracy"],
    )

    model = get_model(args.model)
    unpack_checkpoint(checkpoint, model=model)
    # threshold = checkpoint["epoch_metrics"]["valid"].get("optimized_jaccard/threshold", 0.5)
    threshold = 0.5
    print("Using threshold", threshold)

    # Keep only the mask output and apply a sigmoid to it.
    model = nn.Sequential(PickModelOutput(model, OUTPUT_MASK_KEY),
                          nn.Sigmoid())

    # Optional test-time augmentation; any other value leaves the model
    # unwrapped. NOTE: "flipscale" is accepted here although the --tta help
    # text does not list it.
    if args.tta == "fliplr":
        model = TTAWrapper(model, fliplr_image2mask)
    elif args.tta == "flipscale":
        model = TTAWrapper(model, fliplr_image2mask)
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])
    elif args.tta == "d4":
        model = TTAWrapper(model, d4_image2mask)

    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.eval()

    # Predict a single sample image first and store it next to the run.
    mask = predict(model,
                   read_inria_image("sample_color.jpg"),
                   image_size=(512, 512),
                   batch_size=args.batch_size)
    mask = ((mask > threshold) * 255).astype(np.uint8)
    name = os.path.join(run_dir, "sample_color.jpg")
    cv2.imwrite(name, mask)

    test_predictions_dir = os.path.join(out_dir, "test_predictions")
    test_predictions_dir_compressed = os.path.join(
        out_dir, "test_predictions_compressed")

    # Keep predictions made with different TTA modes in separate folders.
    if args.tta is not None:
        test_predictions_dir += f"_{args.tta}"
        test_predictions_dir_compressed += f"_{args.tta}"

    os.makedirs(test_predictions_dir, exist_ok=True)
    os.makedirs(test_predictions_dir_compressed, exist_ok=True)

    test_images = find_in_dir(os.path.join(data_dir, "test", "images"))
    for fname in tqdm(test_images, total=len(test_images)):
        image = read_inria_image(fname)
        mask = predict(model,
                       image,
                       image_size=(512, 512),
                       batch_size=args.batch_size)
        mask = ((mask > threshold) * 255).astype(np.uint8)
        name = os.path.join(test_predictions_dir, os.path.basename(fname))
        cv2.imwrite(name, mask)

        name_compressed = os.path.join(test_predictions_dir_compressed,
                                       os.path.basename(fname))
        # Pass the arguments as a list (no shell) so that paths containing
        # spaces or shell metacharacters are handed to gdal_translate intact.
        command = [
            "gdal_translate",
            "--config", "GDAL_PAM_ENABLED", "NO",
            "-co", "COMPRESS=CCITTFAX4",
            "-co", "NBITS=1",
            name,
            name_compressed,
        ]
        subprocess.call(command)
コード例 #25
0
ファイル: trace.py プロジェクト: saswat0/catalyst
def trace_model_from_runner(
    runner: IRunner,
    checkpoint_name: str = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
) -> ScriptModule:
    """
    Traces the runner's model on the runner's current input batch.

    When ``checkpoint_name`` is given, the model weights are temporarily
    replaced by the ones from ``<logdir>/checkpoints/<name>.pth``. The
    original weights, train/eval mode, ``requires_grad`` flags and device
    are put back before returning, including when building the batch or
    tracing raises.

    Args:
        runner (Runner): Current runner.
        checkpoint_name (str): Name of model checkpoint to use, if None
            traces current model from runner
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        (ScriptModule): Traced model

    Raises:
        AssertionError: if an argument name of the traced method is not a
            key of ``runner.input``.
    """
    logdir = runner.logdir
    model = get_nn_from_ddp_module(runner.model)

    if checkpoint_name is not None:
        # Keep the current weights so they can be restored afterwards.
        dumped_checkpoint = pack_checkpoint(model=model)
        checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
        checkpoint = load_checkpoint(filepath=checkpoint_path)
        unpack_checkpoint(checkpoint=checkpoint, model=model)

    # Dumping previous runner of the model, we will need it to restore
    _device, _is_training, _requires_grad = (
        runner.device,
        model.training,
        get_requires_grad(model),
    )

    # Function to run prediction on batch
    def predict_fn(model: Model, inputs, **kwargs):
        return model(**inputs, **kwargs)

    try:
        # getting input names of args for method since we don't have Runner
        # and we don't know input_key to preprocess batch for method call
        fn = getattr(model, method_name)
        method_argnames = _get_input_argnames(fn=fn, exclude=["self"])

        batch = {}
        for name in method_argnames:
            # TODO: We don't know input_keys without runner
            assert name in runner.input, (
                "Input batch should contain the same keys as input argument "
                "names of `forward` function to be traced correctly")
            batch[name] = runner.input[name]

        batch = any2device(batch, device)

        model.to(device)

        traced_model = trace_model(
            model=model,
            predict_fn=predict_fn,
            batch=batch,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            device=device,
        )
    finally:
        # Undo every change made to the runner's model, whether or not
        # tracing succeeded, so the caller can continue with it.
        if checkpoint_name is not None:
            unpack_checkpoint(checkpoint=dumped_checkpoint, model=model)

        # Restore previous runner of the model
        getattr(model, "train" if _is_training else "eval")()
        set_requires_grad(model, _requires_grad)
        model.to(_device)

    return traced_model
コード例 #26
0
def main():
    """
    Entry point of the segmentation test script.

    Parses the configuration, builds the test data loader, and, unless
    ``args.has_prediction`` is set, creates the model selected by
    ``args.arch``, loads its weights from ``args.model_path`` and runs
    ``test``. For every split other than ``'test'`` the accuracy is then
    computed with ``cal_acc``.

    Raises:
        ValueError: if ``args.arch`` matches none of the known
            architectures.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    # ``args`` and ``logger`` are module-level globals set here.
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel mean/std scaled by 255.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    # Restrict the run to the [index_start, index_end) slice of the list;
    # index_step == 0 means "up to the end".
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    # Read the class names with a context manager so the file is closed.
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]

    if not args.has_prediction:
        print('arch: ', args.arch)
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h,
                           mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax,
                           pretrained=False)
        elif 'of' in args.arch:
            from catalyst_cityscapes.model.mymodel import ofPSPNet
            print('ofPSPNet')
            model = ofPSPNet(encoder_name=str(args.layers),
                             classes=args.classes,
                             zoom_factor=args.zoom_factor,
                             pretrained=False)

        elif 'smp' in args.arch:
            from catalyst_cityscapes.model.mymodel import smpPSPNet
            print('smpPSPNet')
            model = smpPSPNet(encoder_name='resnet%d' % args.layers,
                              classes=args.classes)
        else:
            # Previously this fell through and failed later with an
            # unbound ``model`` variable; fail with a clear message instead.
            raise ValueError(
                "unsupported architecture '{}'".format(args.arch))

        # logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            if 'of' in args.arch or 'smp' in args.arch:
                # These architectures are loaded through the catalyst
                # checkpoint helpers.
                from catalyst import utils
                checkpoint = utils.load_checkpoint(args.model_path)
                print('checkpoint keys', list(checkpoint))
                utils.unpack_checkpoint(checkpoint, model=model)
            else:
                checkpoint = torch.load(args.model_path)
                model.load_state_dict(checkpoint['state_dict'], strict=False)

            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.model_path))
        test(test_loader,
             test_data.data_list,
             model,
             args.classes,
             mean,
             std,
             args.base_size,
             args.test_h,
             args.test_w,
             args.scales,
             gray_folder,
             color_folder,
             colors,
             is_med=args.get('is_med', False))
    if args.split != 'test':
        cal_acc(test_data.data_list,
                gray_folder,
                args.classes,
                names,
                is_med=args.get('is_med', False),
                label_mapping=args.get('label_mapping', None))
コード例 #27
0
ファイル: catalyst.py プロジェクト: xang1234/mlcomp
    def _checkpoint_fix_config(self, experiment):
        """
        Adjust the experiment config so that training resumes from a
        checkpoint described by ``self.resume``.

        Does nothing when ``self.resume`` is empty. Otherwise the checkpoint
        file (``last_full.pth`` or ``best_full.pth``) is located: copied
        from the master computer when running on another host, or taken
        from the master task's folder when the task id differs. If the file
        does not exist the method only logs that fact and returns.

        With a checkpoint available, stages listed before the checkpoint's
        stage are removed from ``experiment.stages_config``; the checkpoint's
        stage is either removed too (it was already finished, or
        ``load_best`` is set) or has its ``num_epochs`` reduced by the number
        of epochs already done. Finally every ``CheckpointCallback`` of the
        first remaining stage gets its ``resume`` parameter pointed at the
        checkpoint file.
        """
        resume = self.resume
        if not resume:
            return

        checkpoint_dir = join(experiment.logdir, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)

        if resume.get('load_last'):
            checkpoint_file = 'last_full.pth'
        else:
            checkpoint_file = 'best_full.pth'
        path = join(checkpoint_dir, checkpoint_file)

        hostname = socket.gethostname()
        if hostname != resume['master_computer']:
            # The checkpoint lives on another machine: copy it over.
            master = self.computer_provider.by_name(
                resume['master_computer'])
            remote_path = join(
                master.root_folder,
                str(resume['master_task_id']),
                'log',
                'checkpoints',
                checkpoint_file,
            )
            self.info(f'copying checkpoint from: computer = '
                      f'{resume["master_computer"]} path_from={remote_path} '
                      f'path_to={path}')

            copied = copy_remote(
                session=self.session,
                computer_from=resume['master_computer'],
                path_from=remote_path,
                path_to=path,
            )
            if copied:
                self.info('checkpoint copied successfully')
            else:
                self.error(f'copying from '
                           f'{resume["master_computer"]}/'
                           f'{remote_path} failed')
        elif self.task.id != resume['master_task_id']:
            # Same machine, different task: read the master task's file
            # directly instead of copying it.
            path = join(
                TASK_FOLDER,
                str(resume['master_task_id']),
                'log',
                'checkpoints',
                checkpoint_file,
            )
            self.info(f'master_task_id!=task.id, using checkpoint'
                      f' from task_id = {resume["master_task_id"]}')

        if not os.path.exists(path):
            self.info(f'no checkpoint at {path}')
            return

        ckpt = load_checkpoint(path)
        stages_config = experiment.stages_config
        # Iterate over a snapshot because entries are deleted on the way.
        for stage_name, stage_params in list(stages_config.items()):
            if stage_name != ckpt['stage']:
                # Stage listed before the checkpoint's stage: drop it.
                del stages_config[stage_name]
                continue

            done_epochs = ckpt['checkpoint_data']['stage_epoch'] + 1
            state_params = stage_params['state_params']

            # if it is the last epoch in the stage
            if done_epochs == state_params['num_epochs'] \
                    or resume.get('load_best'):
                del stages_config[stage_name]
            else:
                self.checkpoint_stage_epoch = done_epochs
                state_params['num_epochs'] -= done_epochs
            break

        first_stage = experiment.stages_config[experiment.stages[0]]
        for callback_params in first_stage['callbacks_params'].values():
            if callback_params.get('callback') == 'CheckpointCallback':
                callback_params['resume'] = path

        self.info(f'found checkpoint at {path}')
コード例 #28
0
    else:
        criterion = smp.utils.losses.BCEDiceLoss(eps=1.)

    if args.multigpu:
        model = nn.DataParallel(model)

    if args.task == 'segmentation':
        callbacks = [DiceCallback(), EarlyStoppingCallback(patience=10, min_delta=0.001), CriterionCallback()]
    elif args.task == 'classification':
        callbacks = [AUCCallback(class_names=['Fish', 'Flower', 'Gravel', 'Sugar'], num_classes=4),
                     EarlyStoppingCallback(patience=10, min_delta=0.001), CriterionCallback()]

    if args.gradient_accumulation:
        callbacks.append(OptimizerCallback(accumulation_steps=args.gradient_accumulation))

    checkpoint = utils.load_checkpoint(f'{logdir}/checkpoints/best.pth')
    model.cuda()
    utils.unpack_checkpoint(checkpoint, model=model)
    #
    #
    runner = SupervisedRunner()
    if args.train:
        print('Training')
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            main_metric='dice',
            minimize_metric=False,
            scheduler=scheduler,
            loaders=loaders,
コード例 #29
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Change of obliteration")
    parser.add_argument("-nid", "--negative-image-dir", type=str, default=None, help="Change of obliteration")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs", "--warmup-batch-size", type=int, default=None, help="Batch Size during training, e.g. -b 64"
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument("-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters for N epochs")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    parser.add_argument(
        "-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument(
        "--modification-type-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    assert (
        args.modification_flag_loss or args.modification_type_loss or args.embedding_loss
    ), "At least one of losses must be set"

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    freeze_encoder = args.freeze_encoder
    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    fine_tune = args.fine_tune
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir
    warmup_batch_size = args.warmup_batch_size or args.batch_size

    # Compute batch size for validation
    valid_batch_size = train_batch_size
    run_train = num_epochs > 0

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if embedding_loss is not None:
        custom_model_kwargs["need_embedding"] = True

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()
    required_features = model.required_features

    if mask_loss is not None:
        required_features.append(INPUT_TRUE_MODIFICATION_MASK)

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        freeze_model(model, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = []

    if show:
        default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)]

    # Pretrain/warmup
    if warmup:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=0,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            feature_maps_loss=feature_maps_loss,
            num_epochs=warmup,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=warmup_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=warmup_batch_size, num_workers=num_workers, pin_memory=True)

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            freeze_model(model.encoder, freeze_parameters=True, freeze_bn=None)

        optimizer = get_optimizer(
            "Ranger", get_optimizable_parameters(model), weight_decay=weight_decay, learning_rate=3e-4
        )
        scheduler = None

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", augmentations)
        print("  Negative images:", negative_image_dir)
        print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", mixup)
        print("  CutMix         :", cutmix)
        print("  TSA            :", tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout, "(Non-default)" if dropout is not None else "")
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("  Mask           :", mask_loss)
        print("  Bits           :", bits_loss)

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        # Restore state of best model
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir, fold=fold, features=required_features, max_images=16536
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=num_epochs,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", augmentations)
        print("  Obliterate (%) :", obliterate_p)
        print("  Negative images:", negative_image_dir)
        print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", mixup)
        print("  CutMix         :", cutmix)
        print("  TSA            :", tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("  Mask           :", mask_loss)
        print("  Bits           :", bits_loss)

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

        # Restore state of best model
        clean_checkpoint(best_checkpoint, model_checkpoint)
        # unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        torch.cuda.empty_cache()
        gc.collect()

    if fine_tune:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation="light",
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=fine_tune,
            mixup=False,
            cutmix=False,
            tsa=False,
        )

        callbacks = (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(
                    hparam_dict={
                        "model": model_name,
                        "scheduler": scheduler_name,
                        "optimizer": optimizer_name,
                        "augmentations": augmentations,
                        "size": image_size[0],
                        "weight_decay": weight_decay,
                    }
                ),
            ]
        )

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", augmentations)
        print("  Obliterate (%) :", obliterate_p)
        print("  Negative images:", negative_image_dir)
        print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", mixup)
        print("  CutMix         :", cutmix)
        print("  TSA            :", tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("  Mask           :", mask_loss)
        print("  Bits           :", bits_loss)

        optimizer = get_optimizer(
            "SGD", get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            "cos", optimizer, lr=learning_rate, num_epochs=fine_tune, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "finetune"),
            num_epochs=fine_tune,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        best_checkpoint = os.path.join(log_dir, "finetune", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_finetune.pth")

        clean_checkpoint(best_checkpoint, model_checkpoint)
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        del optimizer, loaders, runner, callbacks
コード例 #30
0
ファイル: trace.py プロジェクト: smivv/catalyst
def trace_model_from_checkpoint(
    logdir: Union[str, Path],
    method_name: str,
    checkpoint_name: str,
    stage: str = None,
    loader: Union[str, int] = None,
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
):
    """
    Traces model using created experiment and runner.

    The experiment is rebuilt from the config stored in
    ``<logdir>/configs/_config.json`` and from the copy of the experiment
    directory kept under ``<logdir>/code``; the model weights are then
    restored from ``<logdir>/checkpoints/<checkpoint_name>.pth``.

    Args:
        logdir (Union[str, Path]): Path to Catalyst logdir with model
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        checkpoint_name (str): Name of model checkpoint to use
            (without the ``.pth`` extension)
        stage (str): experiment's stage name; when ``None`` the first
            stage of the experiment is used
        loader (Union[str, int]): experiment's loader name or its index;
            when ``None`` index ``0`` is used
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        the traced model
    """
    # The docstring promises ``Union[str, Path]``; normalise once so that
    # the ``/`` joins below also work when a plain string is passed.
    logdir = Path(logdir)

    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
    print("Load config")
    config: Dict[str, dict] = safitty.load(config_path)

    # Get expdir name
    config_expdir = safitty.get(config, "args", "expdir", apply=Path)
    # We will use copy of expdir from logs for reproducibility
    expdir = logdir / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = import_experiment_and_runner(expdir)
    experiment: Experiment = ExperimentType(config)

    print(f"Load model state from checkpoints/{checkpoint_name}.pth")
    if stage is None:
        # Fall back to the first stage declared by the experiment
        stage = list(experiment.stages)[0]

    model = experiment.get_model(stage)
    checkpoint = utils.load_checkpoint(checkpoint_path)
    utils.unpack_checkpoint(checkpoint, model=model)

    runner: RunnerType = RunnerType()
    runner.model, runner.device = model, device

    if loader is None:
        # Fall back to the first loader (by index)
        loader = 0
    # Batch handed to ``trace_model`` together with the model and runner
    batch = experiment.get_native_batch(stage, loader)

    print("Tracing")
    traced = trace_model(
        model,
        runner,
        batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    print("Done")
    return traced