Пример #1
0
    def run_stage(self, stage, number):
        """Train one stage of the staged training schedule.

        The stage config is produced by ``update_config`` from ``self.base_cfg``
        and the stage-specific overrides ``self.base_cfg.stages[stage]``. If
        ``self.get_best_previous_checkpoint(number)`` returns a truthy value, the
        pipeline is restored from that checkpoint with the new stage
        hyper-parameters; otherwise a fresh ``ImageNetLightningPipeline`` is
        built. A trainer is then created from ``stage_cfg.trainer`` (with a
        checkpoint callback writing to ``self.get_stage_weights_path(stage)``
        and a logger under ``self.log_path``) and ``trainer.fit`` is called.

        Args:
            stage: Key into ``self.base_cfg.stages``; also used as the logger's
                run name.
            number: Forwarded to ``self.get_best_previous_checkpoint`` to pick
                the checkpoint to resume from (presumably the stage's position
                in the schedule — confirm against that helper).
        """
        logging.info("start %s", stage)
        stage_cfg = update_config(self.base_cfg,
                                  Dict(self.base_cfg.stages[stage]))
        weights_path = self.get_stage_weights_path(stage)

        previous_checkpoint = self.get_best_previous_checkpoint(number)
        if previous_checkpoint:
            # Reported through logging, consistent with the "start" message
            # above (the original used print() here).
            logging.info("start from previous %s", previous_checkpoint)
            pipeline = ImageNetLightningPipeline.load_from_checkpoint_params(
                checkpoint_path=previous_checkpoint, hparams=stage_cfg)
        else:
            pipeline = ImageNetLightningPipeline(stage_cfg)

        # Built in the same order as the original nested call:
        # checkpoint callback, then logger, then the trainer itself.
        checkpoint_callback = object_from_dict(stage_cfg.checkpoint,
                                               filepath=weights_path)
        experiment_logger = object_from_dict(stage_cfg.logger,
                                             path=self.log_path,
                                             run_name=f"{stage}",
                                             version=self.base_cfg.version)
        trainer = object_from_dict(stage_cfg.trainer,
                                   checkpoint_callback=checkpoint_callback,
                                   logger=experiment_logger)

        trainer.fit(pipeline)
        # Explicitly drop the references to the (potentially large) model and
        # trainer once fitting is done.
        del pipeline, trainer
Пример #2
0
def main():
    """Display validation batches as image grids for a visual sanity check.

    The config comes from ``Fire(fit)``, the seed is applied with
    ``set_determenistic``, and ``val_data.batch_size`` is overridden to 8. The
    loader is built from ``cfg.val_data``, and every batch is shown as a 2x4
    grid of de-normalised images.
    """
    cfg = Dict(Fire(fit))
    set_determenistic(cfg.seed)

    add_dict = {"val_data": {"batch_size": 8}}
    add_dict = Dict(add_dict)

    print(add_dict, "\t")

    cfg = Dict(update_config(cfg, add_dict))

    print("\t")

    print(cfg.data)
    loader = object_from_dict(cfg.val_data)
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    rows, cols = 2, 4

    for images, targets in tqdm(loader, total=len(loader)):
        print(images.shape)
        print(targets.shape)

        # Move channels last (N, C, H, W) -> (N, H, W, C) for matplotlib.
        img = np.transpose(images.cpu().numpy(), (0, 2, 3, 1))

        plt.figure(figsize=(25, 35))
        # Iterate over the images actually present (the last batch can be
        # smaller than the configured batch size), capped at the grid size.
        for i, image in enumerate(img[: rows * cols]):
            plt.subplot(rows, cols, i + 1)
            # Invert x_norm = (x - mean) / std, i.e. x = x_norm * std + mean.
            # Assumes the dataset normalised [0, 1] images with exactly these
            # ImageNet constants -- TODO confirm against the dataset transform.
            restored = image * imagenet_std + imagenet_mean
            plt.imshow(np.clip(255 * restored, 0, 255).astype(np.uint8))
        plt.show()
Пример #3
0
def main():
    """Display the first validation batch as a near-square image grid.

    The config comes from ``Fire(fit)``, the seed is applied with
    ``set_determenistic``, and ``data.batch_size`` is overridden to 24. The
    loader is built with ``object_from_dict(cfg.data, mode="val")``, and the
    first batch is shown de-normalised before the loop stops.
    """
    cfg = Dict(Fire(fit))
    set_determenistic(cfg.seed)

    add_dict = {"data": {"batch_size": 24}}
    add_dict = Dict(add_dict)

    print(add_dict, "\t")

    cfg = Dict(update_config(cfg, add_dict))

    print("\t")

    print(cfg.data)
    loader = object_from_dict(cfg.data, mode="val")

    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    for images, targets in tqdm(loader, total=len(loader)):
        print(images.shape)
        print(targets.shape)

        # Move channels last (N, C, H, W) -> (N, H, W, C) for matplotlib.
        img = np.transpose(images.cpu().numpy(), (0, 2, 3, 1))

        # Size the grid from the actual batch and round UP: int(sqrt(24)) == 4
        # yields only 16 cells, and plt.subplot(4, 4, 17) raises ValueError.
        n_images = len(img)
        cols = max(1, int(np.ceil(np.sqrt(n_images))))
        rows = int(np.ceil(n_images / cols))

        plt.figure(figsize=(25, 35))
        for i, image in enumerate(img):
            plt.subplot(rows, cols, i + 1)
            # Invert x_norm = (x - mean) / std, i.e. x = x_norm * std + mean.
            # Assumes the dataset normalised [0, 1] images with exactly these
            # ImageNet constants -- TODO confirm against the dataset transform.
            restored = image * imagenet_std + imagenet_mean
            plt.imshow(np.clip(255 * restored, 0, 255).astype(np.uint8))

        plt.show()

        # Only the first batch is inspected.
        break
Пример #4
0
        return torch.utils.data.distributed.DistributedSampler  # for pytorch_lightning

    def dataset(self):
        """Dataset hook that intentionally supplies nothing.

        The original inline comment marks this as existing "for
        pytorch_lightning"; presumably the framework probes for this attribute
        (confirm). It always returns ``None``.
        """
        return


if __name__ == "__main__":
    cfg = Dict(Fire(fit))
    set_determenistic(cfg.seed)

    add_dict = {"data": {"batch_size": 25}}
    add_dict = Dict(add_dict)

    print(add_dict, "\t")

    cfg = Dict(update_config(cfg, add_dict))

    print("\t")

    print(cfg)
    loader = object_from_dict(cfg.data, mode="val")

    batch_size = loader.batch_size
    side = int(np.sqrt(batch_size))
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    for images, targets in tqdm(loader, total=len(loader)):
        print(images.shape)
        print(targets.shape)