Ejemplo n.º 1
0
def main(args):
    """Tag every image listed in a CSV with keywords predicted by a model.

    Loads the pretrained ``args.arch`` model, runs it over the images
    referenced by column ``args.img_col`` of ``args.in_csv`` and writes the
    input table, extended with column ``args.keywords_col``, to
    ``args.out_csv``.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``arch``, ``labels``, ``n_keywords``, ``in_csv``, ``img_col``,
            ``datapath``, ``batch_size``, ``n_workers``, ``verbose``,
            ``keywords_col`` and ``out_csv``.
    """
    model = models.__dict__[args.arch](pretrained=True)
    model = model.eval()
    model, device = UtilsFactory.prepare_model(model)

    # Context manager so the labels file handle is always closed.
    with open(args.labels) as labels_file:
        labels = json.load(labels_file)

    i2k = Images2Keywords(model, args.n_keywords, labels)

    # Read the table once: the untouched frame is written back at the end,
    # while a list of per-row dicts feeds the loader.
    input_csv = pd.read_csv(args.in_csv)
    images_df = input_csv.reset_index().drop("index", axis=1)
    images_df = list(images_df.to_dict("index").values())

    open_fn = ImageReader(input_key=args.img_col,
                          output_key="image",
                          datapath=args.datapath)

    dataloader = UtilsFactory.create_loader(images_df,
                                            open_fn,
                                            batch_size=args.batch_size,
                                            workers=args.n_workers,
                                            dict_transform=dict_transformer)

    keywords = []
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for batch in dataloader:
            keywords_batch = i2k(batch["image"].to(device))
            keywords += keywords_batch

    input_csv[args.keywords_col] = keywords
    input_csv.to_csv(args.out_csv, index=False)
Ejemplo n.º 2
0
def trace_model_from_checkpoint(logdir, method_name):
    """Rebuild the model stored in a log directory and trace it.

    Args:
        logdir: experiment log directory (``str`` or ``Path``). It is
            expected to contain ``configs/_config.json``,
            ``checkpoints/best.pth`` and a ``code/`` copy of the expdir.
        method_name: forwarded unchanged to ``trace_model``.

    Returns:
        Whatever ``trace_model`` returns for the restored model.
    """
    # Normalize once: the ``/`` joins below require a Path, and callers may
    # pass a plain string.
    logdir = Path(logdir)
    config_path = logdir / "configs/_config.json"
    checkpoint_path = logdir / "checkpoints/best.pth"
    print("Load config")
    config: Dict[str, dict] = safitty.load(config_path)

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir_from_logs = logdir / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = \
        import_experiment_and_runner(expdir_from_logs)
    experiment: Experiment = ExperimentType(config)

    print("Load model state from checkpoints/best.pth")
    # The model is built for the first stage listed by the experiment.
    model = experiment.get_model(next(iter(experiment.stages)))
    checkpoint = UtilsFactory.load_checkpoint(checkpoint_path)
    UtilsFactory.unpack_checkpoint(checkpoint, model=model)

    print("Tracing")
    traced = trace_model(model, experiment, RunnerType, method_name)

    print("Done")
    return traced
Ejemplo n.º 3
0
def main(args, _=None):
    """Encode every image listed in a CSV and save the features as ``.npy``.

    Sets the module-level ``IMG_SIZE`` to a square of ``args.img_size``,
    runs a ``ResnetEncoder`` over the images referenced by column
    ``args.img_col`` of ``args.in_csv`` and stores the concatenated outputs
    in ``args.out_npy``.

    Args:
        args: parsed command-line namespace.
        _: ignored.
    """
    global IMG_SIZE

    side = args.img_size
    IMG_SIZE = (side, side)

    encoder = ResnetEncoder(arch=args.arch, pooling=args.pooling).eval()
    encoder, device = UtilsFactory.prepare_model(encoder)

    # One dict per CSV row, in file order.
    table = pd.read_csv(args.in_csv).reset_index().drop("index", axis=1)
    records = list(table.to_dict("index").values())

    reader = ImageReader(
        input_key=args.img_col,
        output_key="image",
        datapath=args.datapath,
    )
    loader = UtilsFactory.create_loader(
        records,
        reader,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dict_transform=dict_transformer,
    )
    if args.verbose:
        loader = tqdm(loader)

    chunks = []
    with torch.no_grad():
        for batch in loader:
            encoded = encoder(batch["image"].to(device))
            chunks.append(encoded.cpu().detach().numpy())

    np.save(args.out_npy, np.concatenate(chunks, axis=0))
Ejemplo n.º 4
0
 def _preprocess_model_for_stage(self, stage: str, model: _Model):
     """Warm-start ``model`` for every stage except the first one.

     For a stage at a position greater than zero in ``self.stages`` the
     weights from ``{self.logdir}/checkpoints/best.pth`` are unpacked into
     ``model``. The same model object is returned in both cases.
     """
     if self.stages.index(stage) <= 0:
         # First stage: nothing to restore.
         return model

     best_path = f"{self.logdir}/checkpoints/best.pth"
     state = UtilsFactory.load_checkpoint(best_path)
     UtilsFactory.unpack_checkpoint(state, model=model)
     return model
Ejemplo n.º 5
0
def main(args):
    """Encode a text column of a CSV into vectors and save them as ``.npy``.

    The encoder is a fastText model when ``args.fasttext_model`` is given,
    otherwise a gensim word2vec model when ``args.w2v_model`` is given.

    Args:
        args: parsed command-line namespace.

    Raises:
        NotImplementedError: if neither model path is provided.
    """
    table = pd.read_csv(args.in_csv).reset_index().drop("index", axis=1)
    records = list(table.to_dict("index").values())

    if args.fasttext_model is not None:
        encode_fn = create_fasttext_encode_fn(
            args.fasttext_model,
            normalize=args.normalize,
        )
    elif args.w2v_model is not None:
        encode_fn = create_gensim_encode_fn(
            args.w2v_model,
            sep=args.txt_sep,
            normalize=args.normalize,
        )
    else:
        raise NotImplementedError

    reader = LambdaReader(
        input_key=args.txt_col,
        output_key="txt",
        encode_fn=encode_fn,
    )
    loader = UtilsFactory.create_loader(
        records,
        reader,
        batch_size=args.batch_size,
        workers=args.n_workers,
    )
    if args.verbose:
        loader = tqdm(loader)

    # The reader already produced the encoded text; just collect the batches.
    encoded_batches = [batch["txt"] for batch in loader]

    np.save(args.out_npy, np.concatenate(encoded_batches, axis=0))
Ejemplo n.º 6
0
def report_by_dir(folder):
    """Summarize the best checkpoint stored in ``folder`` as a flat dict.

    Args:
        folder: ``/``-separated directory path (a string) holding
            ``best.pth``.

    Returns:
        A dict with ``exp_name`` (last path component of ``folder``),
        ``epoch`` and every entry of the checkpoint's ``valid_metrics``.
    """
    state = UtilsFactory.load_checkpoint(f"{folder}/best.pth")
    report = {
        "exp_name": folder.rsplit("/", 1)[-1],
        "epoch": state["epoch"],
    }
    report.update(state["valid_metrics"])
    return report
Ejemplo n.º 7
0
    def save_checkpoint(self,
                        logdir,
                        checkpoint,
                        is_best,
                        save_n_best=5,
                        main_metric="loss",
                        minimize_metric=True):
        """Save a checkpoint and keep at most ``save_n_best`` files on disk.

        Args:
            logdir: experiment directory; files go to ``{logdir}/checkpoints/``.
            checkpoint: dict with at least ``stage``, ``epoch`` and
                ``valid_metrics`` entries.
            is_best: forwarded to ``UtilsFactory.save_checkpoint``.
            save_n_best: maximum number of tracked checkpoint files.
            main_metric: key in ``checkpoint["valid_metrics"]`` used for
                ranking.
            minimize_metric: ``True`` if a lower metric value is better.
        """
        suffix = f"{checkpoint['stage']}.{checkpoint['epoch']}"
        filepath = UtilsFactory.save_checkpoint(
            logdir=f"{logdir}/checkpoints/",
            checkpoint=checkpoint,
            suffix=suffix,
            is_best=is_best,
            is_last=True)

        checkpoint_metric = checkpoint["valid_metrics"].get(main_metric)
        # Fall back to the epoch only when the metric is missing. A plain
        # ``or`` would also throw away legitimate falsy values such as 0.0.
        if checkpoint_metric is None:
            checkpoint_metric = checkpoint.get("epoch", -1)

        self.top_best_metrics.append((filepath, checkpoint_metric))
        # Best first, so the worst tracked checkpoint ends up last.
        self.top_best_metrics = sorted(self.top_best_metrics,
                                       key=lambda x: x[1],
                                       reverse=not minimize_metric)
        if len(self.top_best_metrics) > save_n_best:
            last_filepath, _ = self.top_best_metrics.pop(-1)
            os.remove(last_filepath)
Ejemplo n.º 8
0
 def _init(self):
     """
     Hook for subclasses to perform model-specific initialization.

     The base implementation passes ``self.model`` through
     ``UtilsFactory.prepare_model`` and stores the returned model and
     device on the instance.
     """
     prepared_model, device = UtilsFactory.prepare_model(self.model)
     self.model = prepared_model
     self.device = device
Ejemplo n.º 9
0
    def load_checkpoint(*, filename, state):
        """Restore training state from a checkpoint file.

        Args:
            filename: path to the checkpoint file.
            state: object whose ``model``, ``criterion``, ``optimizer`` and
                ``scheduler`` receive the stored state and whose ``epoch``
                is set from the checkpoint.

        Raises:
            FileNotFoundError: if ``filename`` is not an existing file. It is
                a subclass of ``Exception``, so callers that caught the
                generic exception raised previously keep working.
        """
        # Guard clause: fail early and keep the happy path unindented.
        if not os.path.isfile(filename):
            raise FileNotFoundError(f'no checkpoint found at "{filename}"')

        print(f'=> loading checkpoint "{filename}"')
        checkpoint = UtilsFactory.load_checkpoint(filename)

        state.epoch = checkpoint["epoch"]

        UtilsFactory.unpack_checkpoint(checkpoint,
                                       model=state.model,
                                       criterion=state.criterion,
                                       optimizer=state.optimizer,
                                       scheduler=state.scheduler)

        print('loaded checkpoint "{}" (epoch {})'.format(
            filename, checkpoint["epoch"]))
Ejemplo n.º 10
0
 def save(self):
     """Write a checkpoint when ``epoch`` is a multiple of ``save_period``.

     The checkpoint comes from ``self.algorithm.prepare_checkpoint()``, is
     tagged with the current epoch and saved under ``self.logdir``.
     """
     if self.epoch % self.save_period != 0:
         return

     state = self.algorithm.prepare_checkpoint()
     state["epoch"] = self.epoch
     saved_to = UtilsFactory.save_checkpoint(
         logdir=self.logdir,
         checkpoint=state,
         suffix=str(self.epoch),
     )
     print("Checkpoint saved to: %s" % saved_to)
Ejemplo n.º 11
0
    def create_loaders(self, train_df, val_df):
        """Build the training and validation data loaders.

        Args:
            train_df: training records passed to the loader factory.
            val_df: validation records passed to the loader factory.

        Returns:
            ``OrderedDict`` with a shuffled ``'train'`` loader and an
            order-preserving ``'valid'`` loader.
        """
        common = dict(open_fn=self.get_input_pair,
                      batch_size=self.batch_size,
                      num_workers=self.num_workers)

        loaders = collections.OrderedDict()
        loaders['train'] = UtilsFactory.create_loader(train_df,
                                                      shuffle=True,
                                                      **common)
        # Evaluation does not benefit from shuffling, and a fixed order keeps
        # predictions aligned with ``val_df`` rows when this loader is reused
        # for inference.
        loaders['valid'] = UtilsFactory.create_loader(val_df,
                                                      shuffle=False,
                                                      **common)

        return loaders
Ejemplo n.º 12
0
def load_model(network, model_weights_path, channels, neighbours):
    """Build a network, adapt its first convolution and load its weights.

    Args:
        network: model name forwarded to ``get_model``.
        model_weights_path: file containing a state dict for the model.
        channels: channel spec forwarded to ``count_channels``.
        neighbours: multiplier applied to the channel count to get the
            number of input planes of the first convolution.

    Returns:
        Tuple ``(model, device)`` as produced by
        ``UtilsFactory.prepare_model``, with the weights loaded.
    """
    model = get_model(network)
    model.encoder.conv1 = nn.Conv2d(
        count_channels(channels) * neighbours, 64, kernel_size=(7, 7),
        stride=(2, 2), padding=(3, 3), bias=False
    )
    # The device comes from prepare_model only. The previous hand-made
    # 'gpu'/'cpu' guess was dead code and 'gpu' is not a valid torch device.
    model, device = UtilsFactory.prepare_model(model)
    state_dict = torch.load(model_weights_path,
                            map_location=torch.device(device))
    model.load_state_dict(state_dict)
    return model, device
Ejemplo n.º 13
0
    def create_test_loaders(self, test_df):
        """Build the loader dictionary used for testing.

        Args:
            test_df: test records passed to the loader factory.

        Returns:
            ``OrderedDict`` mapping ``'test'`` to the created loader.
        """
        # No shuffling: outputs stay aligned with the rows of ``test_df``.
        test_loader = UtilsFactory.create_loader(test_df,
                                                 open_fn=self.get_input_pair,
                                                 batch_size=self.batch_size,
                                                 num_workers=self.num_workers,
                                                 shuffle=False)

        loaders = collections.OrderedDict()
        # Expose the loader itself; the raw records were returned here before,
        # leaving the freshly built loader unused.
        loaders['test'] = test_loader
        return loaders
Ejemplo n.º 14
0
 def _save_checkpoint(self):
     """Write a checkpoint when ``epoch`` is a multiple of ``save_period``.

     The checkpoint comes from ``self.algorithm.pack_checkpoint()``, is
     tagged with the current epoch and saved under ``self.logdir``.
     """
     if self.epoch % self.save_period != 0:
         return

     state = self.algorithm.pack_checkpoint()
     state["epoch"] = self.epoch
     saved_to = UtilsFactory.save_checkpoint(
         logdir=self.logdir, checkpoint=state, suffix=str(self.epoch))
     print(f"Checkpoint saved to: {saved_to}")
Ejemplo n.º 15
0
def predict(data_path, model_weights_path, network, test_df_path, save_path,
            size, channels, neighbours, classification_head):
    """Run the model over every image listed in a CSV and save the outputs.

    Args:
        data_path: root directory handed to ``get_filepath`` to find images.
        model_weights_path: file passed to ``torch.load``; a bare state dict
            when ``classification_head`` is truthy, otherwise a checkpoint
            dict holding ``'model_state_dict'``.
        network: model name forwarded to ``get_model``.
        test_df_path: CSV with ``name``, ``position`` and ``dataset_folder``
            columns.
        save_path: directory under which ``predictions`` is created.
        size: side length used to reshape model inputs and outputs.
        channels: channel spec forwarded to ``count_channels`` and
            ``filter_by_channels``.
        neighbours: multiplier applied to the channel count.
        classification_head: whether ``model.predict`` returns a
            ``(prediction, label)`` pair instead of a single prediction.
    """
    model = get_model(network, classification_head)
    in_channels = count_channels(channels) * neighbours
    model.encoder.conv1 = nn.Conv2d(in_channels,
                                    64,
                                    kernel_size=(7, 7),
                                    stride=(2, 2),
                                    padding=(3, 3),
                                    bias=False)

    model, device = UtilsFactory.prepare_model(model)

    # Always deserialize onto the CPU so GPU-saved weights also load on a
    # CPU-only host; load_state_dict copies them into the model's own
    # parameters afterwards.
    checkpoint = torch.load(model_weights_path, map_location='cpu')
    if classification_head:
        model.load_state_dict(checkpoint)
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    test_df = pd.read_csv(test_df_path)

    predictions_path = os.path.join(save_path, "predictions")

    if not os.path.exists(predictions_path):
        os.makedirs(predictions_path, exist_ok=True)
        print("Prediction directory created.")

    to_tensor = transforms.ToTensor()
    for _, image_info in tqdm(test_df.iterrows()):
        filename = '_'.join([image_info['name'], image_info['position']])
        image_path = get_filepath(data_path,
                                  image_info['dataset_folder'],
                                  'images',
                                  filename,
                                  file_type='tiff')

        image_tensor = filter_by_channels(read_tensor(image_path), channels,
                                          neighbours)
        if image_tensor.ndim == 2:
            # Add an explicit trailing axis to 2-D arrays.
            image_tensor = np.expand_dims(image_tensor, -1)

        image = to_tensor(image_tensor)
        batch = image.view(1, in_channels, size, size).to(
            device, dtype=torch.float)
        if classification_head:
            # The label returned alongside the prediction is not used here.
            prediction, _label = model.predict(batch)
        else:
            prediction = model.predict(batch)

        result = prediction.view(size, size).detach().cpu().numpy()

        cv.imwrite(get_filepath(predictions_path, filename, file_type='png'),
                   result * 255)
Ejemplo n.º 16
0
def train(args):
    """Train a model with ``SupervisedRunner`` and run inference afterwards.

    The first encoder convolution is replaced to accept
    ``count_channels(args.channels)`` input planes. After training, the
    validation loader is run through ``runner.infer`` with the checkpoint
    ``{save_path}/checkpoints/best.pth``.

    Args:
        args: parsed command-line namespace.
    """
    set_random_seed(42)
    model = get_model(args.network)
    print('Loading model')
    first_conv = nn.Conv2d(
        count_channels(args.channels),
        64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    model.encoder.conv1 = first_conv
    model, _ = UtilsFactory.prepare_model(model)

    train_records = pd.read_csv(args.train_df).to_dict('records')
    val_records = pd.read_csv(args.val_df).to_dict('records')

    dataset = Dataset(
        args.channels,
        args.dataset_path,
        args.image_size,
        args.batch_size,
        args.num_workers,
    )
    loaders = dataset.create_loaders(train_records, val_records)
    print(loaders['train'].dataset.data)

    criterion = BCE_Dice_Loss(bce_weight=0.2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10, 20, 40],
        gamma=0.3,
    )

    # Log directory name: network name followed by the channel names.
    run_name = '_'.join([args.network, *args.channels])
    save_path = os.path.join(args.logdir, run_name)

    runner = SupervisedRunner()

    # Training phase.
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=[DiceCallback()],
        logdir=save_path,
        num_epochs=args.epochs,
        verbose=True,
    )

    # Inference phase on the validation loader.
    best_checkpoint = f'{save_path}/checkpoints/best.pth'
    runner.infer(
        model=model,
        loaders=collections.OrderedDict([('infer', loaders['valid'])]),
        callbacks=[
            CheckpointCallback(resume=best_checkpoint),
            InferCallback(),
        ],
    )
Ejemplo n.º 17
0
    def _prepare_model(self, stage: str = None):
        """
        Hook for subclasses to perform model-specific initialization.

        When ``stage`` is given, the model is first obtained from
        ``self.experiment.get_model(stage)``. The model is then passed
        through ``UtilsFactory.prepare_model`` and the returned model and
        device are stored on the instance.
        """
        if stage is not None:
            self.model = self.experiment.get_model(stage)

        prepared_model, device = UtilsFactory.prepare_model(self.model)
        self.model = prepared_model
        self.device = device
Ejemplo n.º 18
0
def load_model(network, model_weights_path, channels):
    """Build a network, load checkpoint weights and prepare it for inference.

    Args:
        network: model name forwarded to ``get_model``.
        model_weights_path: checkpoint file holding ``'model_state_dict'``.
        channels: channel spec forwarded to ``count_channels``; its result
            is the number of input planes of the first convolution.

    Returns:
        Tuple ``(model, device)`` from ``UtilsFactory.prepare_model``; the
        model is switched to eval mode before being prepared.
    """
    model = get_model(network)
    model.encoder.conv1 = torch.nn.Conv2d(count_channels(channels),
                                          64,
                                          kernel_size=(7, 7),
                                          stride=(2, 2),
                                          padding=(3, 3),
                                          bias=False)
    # Deserialize onto the CPU: the model has not been moved to a device
    # yet, and this keeps GPU-saved checkpoints loadable on CPU-only hosts.
    checkpoint = torch.load(model_weights_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model, device = UtilsFactory.prepare_model(model.eval())

    return model, device
Ejemplo n.º 19
0
 def load_actor_weights(self):
     """Load actor weights from a checkpoint file or from redis.

     ``self.resume`` takes precedence; otherwise the weights are read from
     ``self.redis_server`` under the key ``{redis_prefix}_actor_weights``.
     The actor is switched to eval mode afterwards.

     Raises:
         NotImplementedError: if neither source is configured.
     """
     if self.resume is not None:
         checkpoint = UtilsFactory.load_checkpoint(self.resume)
         state = checkpoint["actor_state_dict"]
     elif self.redis_server is not None:
         payload = self.redis_server.get(f"{self.redis_prefix}_actor_weights")
         raw_state = deserialize(payload)
         state = {
             name: self.to_tensor(value)
             for name, value in raw_state.items()
         }
     else:
         raise NotImplementedError

     self.actor.load_state_dict(state)
     self.actor.eval()
Ejemplo n.º 20
0
def train(args):
    """Train and run inference for one model per fold.

    For each fold index in ``range(args.folds)`` a fresh model is trained on
    ``train{fold}.csv`` / ``val{fold}.csv`` from ``args.dataset_path`` and
    logged to ``{args.logdir}/fold{fold}``. The validation loader is then
    run through ``runner.infer`` with that fold's ``best.pth`` checkpoint.

    Args:
        args: parsed command-line namespace.
    """
    set_random_seed(42)
    for fold in range(args.folds):
        model = get_model(args.network)

        print("Loading model")
        model, _ = UtilsFactory.prepare_model(model)

        train_csv = os.path.join(args.dataset_path, f'train{fold}.csv')
        train_records = pd.read_csv(train_csv).to_dict('records')
        val_csv = os.path.join(args.dataset_path, f'val{fold}.csv')
        val_records = pd.read_csv(val_csv).to_dict('records')

        dataset = Dataset(
            args.channels,
            args.dataset_path,
            args.image_size,
            args.batch_size,
            args.num_workers,
        )
        loaders = dataset.create_loaders(train_records, val_records)

        criterion = BCE_Dice_Loss(bce_weight=0.2)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[10, 20, 40],
            gamma=0.3,
        )

        runner = SupervisedRunner()
        fold_logdir = os.path.join(args.logdir, f'fold{fold}')

        # Training phase.
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            callbacks=[DiceCallback()],
            logdir=fold_logdir,
            num_epochs=args.epochs,
            verbose=True,
        )

        # Inference phase on the validation loader.
        best_checkpoint = f'{fold_logdir}/checkpoints/best.pth'
        runner.infer(
            model=model,
            loaders=collections.OrderedDict([("infer", loaders["valid"])]),
            callbacks=[
                CheckpointCallback(resume=best_checkpoint),
                InferCallback(),
            ],
        )

        print(f'Fold {fold} ended')
Ejemplo n.º 21
0
 def save_checkpoint(
     self,
     logdir,
     checkpoint,
     is_best,
 ):
     """Save ``checkpoint`` under ``{logdir}/checkpoints/`` and report it.

     The file suffix is built from the checkpoint's ``stage`` and the
     instance's ``count`` attribute.
     """
     stage = checkpoint['stage']
     saved_to = UtilsFactory.save_checkpoint(
         logdir=f"{logdir}/checkpoints/",
         checkpoint=checkpoint,
         suffix=f"{stage}.iter.{self.count}",
         is_best=is_best,
         is_last=True,
     )
     print(f"\nSaved checkpoint at {saved_to}")
Ejemplo n.º 22
0
    def __init__(
        self,
        agent: Union[ActorSpec, CriticSpec],
        env: EnvironmentSpec,
        db_server: DBSpec = None,
        exploration_handler: ExplorationHandler = None,
        logdir: str = None,
        id: int = 0,
        mode: str = "infer",  # train/valid/infer
        weights_sync_period: int = 1,
        weights_sync_mode: str = None,
        seeds: List = None,
        trajectory_limit: int = None,
        force_store: bool = False,
        gc_period: int = 10,
    ):
        """Set up a sampler: device, seeding, logging, trajectory sampler.

        ``id`` offsets the seeder's initial seed (``42 + id``) and is stored
        as the sampler id. ``mode == "infer"`` makes the trajectory sampler
        deterministic. A falsy ``trajectory_limit`` is replaced by the
        largest ``int32`` value.
        """
        self._device = UtilsFactory.get_device()
        self._sampler_id = id

        self._infer = (mode == "infer")
        self.seeds = seeds
        # The seeder is bounded by the number of explicit seeds, if any.
        if seeds is None:
            seed_count = None
        else:
            seed_count = len(seeds)
        self._seeder = Seeder(init_seed=42 + id, max_seed=seed_count)

        # logging
        self._prepare_logger(logdir, mode)
        self._sample_flag = mp.Value(c_bool, False)

        # environment, model, exploration & action handlers
        self.env = env
        self.agent = agent
        self.exploration_handler = exploration_handler
        self.trajectory_index = 0
        self.trajectory_sampler = TrajectorySampler(
            env=self.env,
            agent=self.agent,
            device=self._device,
            deterministic=self._infer,
            sample_flag=self._sample_flag,
        )

        # synchronization configuration
        self.db_server = db_server
        self._weights_sync_period = weights_sync_period
        self._weights_sync_mode = weights_sync_mode
        if not trajectory_limit:
            trajectory_limit = np.iinfo(np.int32).max
        self._trajectory_limit = trajectory_limit
        self._force_store = force_store
        self._gc_period = gc_period
        self._db_loop_thread = None
Ejemplo n.º 23
0
    def load_checkpoint(self, filepath, load_optimizer=True):
        """Restore actor/critic state (and optionally optimizers) from a file.

        For each of ``actor`` and ``critic`` present on the instance, the
        ``{name}_state_dict`` entry of the checkpoint is loaded. With
        ``load_optimizer`` set, the same is done for the
        ``{name}_optimizer`` and ``{name}_scheduler`` attributes that are
        not ``None``.
        """
        checkpoint = UtilsFactory.load_checkpoint(filepath)
        for role in ("actor", "critic"):
            network = getattr(self, role, None)
            if network is not None:
                network_state = checkpoint[f"{role}_state_dict"]
                network.load_state_dict(network_state)

            if not load_optimizer:
                continue

            for part in ("optimizer", "scheduler"):
                attr_name = f"{role}_{part}"
                target = getattr(self, attr_name, None)
                if target is None:
                    continue
                target_state = checkpoint[f"{attr_name}_state_dict"]
                target.load_state_dict(target_state)
Ejemplo n.º 24
0
    def _get_optimizer(self, *, model_params, **params):
        """Create an optimizer, or a dict of optimizers, from config params.

        When ``_key_value`` is set in ``params``, every remaining entry is
        treated as a nested config and a dict of optimizers is returned.
        Otherwise one optimizer is built through ``OPTIMIZERS``. With
        ``load_from_previous_stage`` set, its state is restored from
        ``{self.logdir}/checkpoints/best.pth`` and the remaining params are
        then written into every param group.
        """
        if params.pop("_key_value", False):
            return {
                name: self._get_optimizer(model_params=model_params,
                                          **nested_params)
                for name, nested_params in params.items()
            }

        resume = params.pop("load_from_previous_stage", False)
        optimizer = OPTIMIZERS.get_from_params(**params, params=model_params)
        if not resume:
            return optimizer

        best_path = f"{self.logdir}/checkpoints/best.pth"
        state = UtilsFactory.load_checkpoint(best_path)
        UtilsFactory.unpack_checkpoint(state, optimizer=optimizer)
        # Re-apply the configured values on top of the restored state.
        for option, value in params.items():
            for group in optimizer.param_groups:
                group[option] = value

        return optimizer
Ejemplo n.º 25
0
    def create_loaders(self, train_df, val_df):
        """Build class-balanced training and validation loaders.

        Each record is labelled ``1`` when its ``mask_pxl`` equals zero and
        ``0`` otherwise; a ``BalanceClassSampler`` in ``upsampling`` mode
        built from those labels drives both loaders.

        Returns:
            ``OrderedDict`` with ``'train'`` and ``'valid'`` loaders.
        """
        def balanced_loader(records):
            # Same recipe for both splits: label, sampler, loader.
            labels = [(record["mask_pxl"] == 0) * 1 for record in records]
            sampler = BalanceClassSampler(labels, mode="upsampling")
            return UtilsFactory.create_loader(
                records,
                open_fn=self.get_input_pair,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=sampler is None,
                sampler=sampler,
            )

        return collections.OrderedDict([
            ('train', balanced_loader(train_df)),
            ('valid', balanced_loader(val_df)),
        ])
Ejemplo n.º 26
0
    def __init__(
        self,
        agent: Union[ActorSpec, CriticSpec],
        env: EnvironmentSpec,
        db_server: DBSpec = None,
        exploration_handler: ExplorationHandler = None,
        logdir: str = None,
        id: int = 0,
        mode: str = "infer",
        buffer_size: int = int(1e4),
        weights_sync_period: int = 1,
        seeds: List = None,
        episode_limit: int = None,
        force_store: bool = False,
        gc_period: int = 10,
    ):
        """Set up a sampler: device, seed, logging and episode runner.

        ``id`` offsets the seed (``42 + id``) and is stored as the sampler
        id. ``mode == "infer"`` makes the episode runner deterministic. A
        falsy ``episode_limit`` is replaced by ``_BIG_NUM``. Weights are
        synchronized for the critic when ``env.discrete_actions`` is truthy
        and for the actor otherwise.
        """
        self._device = UtilsFactory.prepare_device()
        self._seed = 42 + id
        self._sampler_id = id

        self._infer = (mode == "infer")
        self.seeds = seeds

        # logging
        self._prepare_logger(logdir, mode)

        # environment, model, exploration & action handlers
        self.env = env
        self.agent = agent
        self.exploration_handler = exploration_handler
        self.episode_index = 0
        self.episode_runner = EpisodeRunner(
            env=self.env,
            agent=self.agent,
            device=self._device,
            capacity=buffer_size,
            deterministic=self._infer,
        )

        # synchronization configuration
        self.db_server = db_server
        self.weights_sync_period = weights_sync_period
        if not episode_limit:
            episode_limit = _BIG_NUM
        self.episode_limit = episode_limit
        self._force_store = force_store
        if env.discrete_actions:
            self._sampler_weight_mode = "critic"
        else:
            self._sampler_weight_mode = "actor"
        self._gc_period = gc_period
Ejemplo n.º 27
0
def get_trainer_components(
    *,
    agent,
    loss_params=None,
    optimizer_params=None,
    scheduler_params=None,
    grad_clip_params=None
):
    """Build criterion, optimizer, scheduler and grad clipper for an agent.

    Each ``*_params`` argument is first passed through ``_copy_params`` and
    then handed to the matching registry. The criterion is moved to CUDA
    when one was created and CUDA is available.

    Returns:
        Dict with the copied params and the created objects under the keys
        ``loss_params``, ``criterion``, ``optimizer_params``, ``optimizer``,
        ``scheduler_params``, ``scheduler``, ``grad_clip_params`` and
        ``grad_clip_fn``.
    """
    # criterion
    loss_cfg = _copy_params(loss_params)
    criterion = CRITERIONS.get_from_params(**loss_cfg)
    if criterion is not None and torch.cuda.is_available():
        criterion = criterion.cuda()

    # optimizer
    trainable = UtilsFactory.get_optimizable_params(agent.parameters())
    optimizer_cfg = _copy_params(optimizer_params)
    optimizer = OPTIMIZERS.get_from_params(**optimizer_cfg, params=trainable)

    # scheduler
    scheduler_cfg = _copy_params(scheduler_params)
    scheduler = SCHEDULERS.get_from_params(
        **scheduler_cfg, optimizer=optimizer)

    # grad clipping
    grad_clip_cfg = _copy_params(grad_clip_params)
    grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_cfg)

    return dict(
        loss_params=loss_cfg,
        criterion=criterion,
        optimizer_params=optimizer_cfg,
        optimizer=optimizer,
        scheduler_params=scheduler_cfg,
        scheduler=scheduler,
        grad_clip_params=grad_clip_cfg,
        grad_clip_fn=grad_clip_fn,
    )
Ejemplo n.º 28
0
    def load_checkpoint(self,
                        *,
                        filepath: str = None,
                        db_server: DBSpec = None):
        """Load agent weights from a checkpoint file or from the database.

        ``filepath`` takes precedence over ``db_server``. Which weights are
        read is selected by ``self._sampler_weight_mode``. Afterwards the
        agent is moved to ``self._device`` and switched to eval mode.

        Raises:
            NotImplementedError: if neither source is given.
        """
        mode = self._sampler_weight_mode
        if filepath is not None:
            checkpoint = UtilsFactory.load_checkpoint(filepath)
            state = checkpoint[f"{mode}_state_dict"]
        elif db_server is not None:
            raw_state = db_server.load_weights(prefix=mode)
            state = {
                name: self._to_tensor(value)
                for name, value in raw_state.items()
            }
        else:
            raise NotImplementedError

        self.agent.load_state_dict(state)
        self.agent.to(self._device)
        self.agent.eval()
Ejemplo n.º 29
0
def train(args):
    """Train the ``fpn50_season`` model with a hand-written epoch loop.

    After every epoch the model is validated; whenever the mean dice
    improves, the weights are saved to
    ``{args.logdir}/weights/epoch{epoch}.pth``. A summary of the best epoch
    is printed at the end.

    Args:
        args: parsed command-line namespace.
    """
    set_random_seed(42)
    model = get_model('fpn50_season')

    print("Loading model")
    model, device = UtilsFactory.prepare_model(model)

    train_records = pd.read_csv(args.train_df).to_dict('records')
    val_records = pd.read_csv(args.val_df).to_dict('records')

    dataset = SeasonDataset(
        args.channels,
        args.dataset_path,
        args.image_size,
        args.batch_size,
        args.num_workers,
    )
    loaders = dataset.create_loaders(train_records, val_records)

    criterion = BCE_Dice_Loss(bce_weight=0.2)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 20, 40], gamma=0.3)

    # The same weight is passed to both the train and the validation step.
    segmentation_weight = 0.8
    best_valid_dice, best_epoch, best_accuracy = -1, -1, -1

    for epoch in range(args.epochs):
        train_iter(loaders['train'], model, device, criterion, optimizer,
                   segmentation_weight)
        dice_mean, valid_accuracy = valid_iter(
            loaders['valid'], model, device, criterion, segmentation_weight)

        improved = dice_mean > best_valid_dice
        if improved:
            best_valid_dice = dice_mean
            best_epoch = epoch
            best_accuracy = valid_accuracy
            os.makedirs(f'{args.logdir}/weights', exist_ok=True)
            torch.save(model.state_dict(),
                       f'{args.logdir}/weights/epoch{epoch}.pth')

        scheduler.step()
        print("Epoch {0} ended".format(epoch))

    print("Best epoch: ", best_epoch, "with dice ", best_valid_dice,
          "and season prediction accuracy", best_accuracy)
Ejemplo n.º 30
0
    def load_checkpoint(self, filepath, load_optimizer=True):
        """Restore the parent state plus every entry of ``self.critics``.

        After delegating to the parent implementation, the checkpoint entry
        ``critics{i}_state_dict`` is loaded into each critic that is not
        ``None``. With ``load_optimizer`` set, the ``critics_optimizer`` and
        ``critics_scheduler`` attributes that are not ``None`` are restored
        as well; this happens once per critic, inside the same loop.
        """
        super().load_checkpoint(filepath, load_optimizer)

        checkpoint = UtilsFactory.load_checkpoint(filepath)
        prefix = "critics"
        for index, critic in enumerate(self.critics):
            if critic is not None:
                critic_state = checkpoint[f"{prefix}{index}_state_dict"]
                critic.load_state_dict(critic_state)

            if not load_optimizer:
                continue

            for part in ("optimizer", "scheduler"):
                attr_name = f"{prefix}_{part}"
                target = getattr(self, attr_name, None)
                if target is None:
                    continue
                target_state = checkpoint[f"{attr_name}_state_dict"]
                target.load_state_dict(target_state)
Ejemplo n.º 31
0
    def prepare_loaders(
            *,
            mode: str,
            stage: str = None,
            n_workers: int = None,
            batch_size: int = None,
            datapath=None,
            in_csv=None,
            in_csv_train=None, in_csv_valid=None, in_csv_infer=None,
            train_folds=None, valid_folds=None,
            tag2class=None, class_column=None, tag_column=None,
            folds_seed=42, n_folds=5):
        """Create train / valid / infer loaders from the given CSV inputs.

        The CSV arguments are resolved by ``parse_in_csvs``. A loader is
        created only for a split that contains at least one record. The
        train loader uses a ``BalanceClassSampler`` in ``upsampling`` mode
        over the ``class`` field; the other loaders use no sampler and no
        shuffling. ``mode`` is accepted but not used in this body.

        Returns:
            ``OrderedDict`` holding the created loaders under the keys
            ``"train"``, ``"valid"`` and ``"infer"``.
        """
        loaders = collections.OrderedDict()

        _, df_train, df_valid, df_infer = parse_in_csvs(
            in_csv=in_csv,
            in_csv_train=in_csv_train,
            in_csv_valid=in_csv_valid,
            in_csv_infer=in_csv_infer,
            train_folds=train_folds,
            valid_folds=valid_folds,
            tag2class=tag2class,
            class_column=class_column,
            tag_column=tag_column,
            folds_seed=folds_seed,
            n_folds=n_folds)

        # Every record is read as an image plus an integer class target.
        image_reader = ImageReader(
            row_key="filepath", dict_key="image", datapath=datapath)
        target_reader = ScalarReader(
            row_key="class", dict_key="targets",
            default_value=-1, dtype=np.int64)
        open_fn = ReaderCompose(readers=[image_reader, target_reader])

        # Arguments shared by all three loaders.
        shared = dict(
            open_fn=open_fn,
            dataset_cache_prob=-1,
            batch_size=batch_size,
            workers=n_workers)

        if len(df_train) > 0:
            train_labels = [row["class"] for row in df_train]
            sampler = BalanceClassSampler(train_labels, mode="upsampling")

            train_loader = UtilsFactory.create_loader(
                data_source=df_train,
                dict_transform=DataSource.prepare_transforms(
                    mode="train", stage=stage),
                shuffle=sampler is None,
                sampler=sampler,
                **shared)

            print("Train samples", len(train_loader) * batch_size)
            print("Train batches", len(train_loader))
            loaders["train"] = train_loader

        if len(df_valid) > 0:
            valid_loader = UtilsFactory.create_loader(
                data_source=df_valid,
                dict_transform=DataSource.prepare_transforms(
                    mode="valid", stage=stage),
                shuffle=False,
                sampler=None,
                **shared)

            print("Valid samples", len(valid_loader) * batch_size)
            print("Valid batches", len(valid_loader))
            loaders["valid"] = valid_loader

        if len(df_infer) > 0:
            infer_loader = UtilsFactory.create_loader(
                data_source=df_infer,
                dict_transform=DataSource.prepare_transforms(
                    mode="infer", stage=None),
                shuffle=False,
                sampler=None,
                **shared)

            print("Infer samples", len(infer_loader) * batch_size)
            print("Infer batches", len(infer_loader))
            loaders["infer"] = infer_loader

        return loaders