def run_stage(self, stage, number):
    # Train one stage: merge the stage-specific overrides into the base config,
    # resume from the best checkpoint of the previous stage if one exists,
    # then run a fresh pytorch_lightning Trainer.
    logging.info(f"start {stage}")
    stage_cfg = update_config(self.base_cfg, Dict(self.base_cfg.stages[stage]))
    weights_path = self.get_stage_weights_path(stage)

    previous_checkpoint = self.get_best_previous_checkpoint(number)
    if previous_checkpoint:
        logging.info(f"start from previous {previous_checkpoint}")
        pipeline = ImageNetLightningPipeline.load_from_checkpoint_params(
            checkpoint_path=previous_checkpoint, hparams=stage_cfg
        )
    else:
        pipeline = ImageNetLightningPipeline(stage_cfg)

    trainer = object_from_dict(
        stage_cfg.trainer,
        checkpoint_callback=object_from_dict(stage_cfg.checkpoint, filepath=weights_path),
        logger=object_from_dict(
            stage_cfg.logger,
            path=self.log_path,
            run_name=f"{stage}",
            version=self.base_cfg.version,
        ),
    )
    trainer.fit(pipeline)

    # Free the model and trainer before the next stage starts.
    del pipeline, trainer
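# Both run_stage above and the loader-check scripts below build objects from config
# sub-dicts via object_from_dict, which is defined elsewhere in the repo. A minimal
# sketch of the assumed behaviour -- the "type" key convention and the use of
# pydoc.locate are assumptions, not confirmed details of this codebase:
import pydoc


def object_from_dict(d, **default_kwargs):
    """Instantiate the class named under d["type"], passing the remaining keys
    plus any keyword overrides to its constructor (assumed behaviour)."""
    kwargs = {k: v for k, v in d.items() if k != "type"}
    kwargs.update(default_kwargs)
    return pydoc.locate(d["type"])(**kwargs)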
import matplotlib.pyplot as plt
import numpy as np
from addict import Dict
from fire import Fire
from tqdm import tqdm

# Project-local helpers (fit, set_determenistic, update_config, object_from_dict)
# are assumed to be imported from elsewhere in the repo.


def main():
    # Sanity-check the validation dataloader by plotting one denormalized batch.
    # val_dataset = HakunaDataset(mode="val", path=PATHS["data.path"], long_side=320, crop_size=(192, 256))
    # val_dataloader = DataLoader(val_dataset, num_workers=4, batch_size=8, collate_fn=fast_collate)
    cfg = Dict(Fire(fit))
    set_determenistic(cfg.seed)

    add_dict = Dict({"val_data": {"batch_size": 8}})
    print(add_dict)
    cfg = Dict(update_config(cfg, add_dict))
    print(cfg.val_data)

    loader = object_from_dict(cfg.val_data)
    batch_size = loader.batch_size

    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    for images, targets in tqdm(loader, total=len(loader)):
        print(images.shape)
        print(targets.shape)
        img = np.transpose(images.cpu().numpy(), (0, 2, 3, 1))  # NCHW -> NHWC
        labels = targets.cpu().numpy()

        plt.figure(figsize=(25, 35))
        for i in range(batch_size):
            plt.subplot(2, 4, i + 1)
            # Undo ImageNet normalization: x * std + mean, then scale to uint8.
            shw = np.uint8(np.clip(255 * (img[i] * imagenet_std + imagenet_mean), 0, 255))
            plt.title(np.argmax(labels[i]))
            plt.imshow(shw)
        plt.show()
def main():
    # Same check as above, but with a square grid sized from the batch size.
    # Imports are assumed to be the same as in the previous script.
    cfg = Dict(Fire(fit))
    set_determenistic(cfg.seed)

    add_dict = Dict({"data": {"batch_size": 24}})
    print(add_dict)
    cfg = Dict(update_config(cfg, add_dict))
    print(cfg.data)

    loader = object_from_dict(cfg.data, mode="val")
    batch_size = loader.batch_size
    side = int(np.sqrt(batch_size))

    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    for images, targets in tqdm(loader, total=len(loader)):
        print(images.shape)
        print(targets.shape)
        img = np.transpose(images.cpu().numpy(), (0, 2, 3, 1))  # NCHW -> NHWC
        labels = targets.cpu().numpy()

        plt.figure(figsize=(25, 35))
        # Only side * side images fit into the grid; with batch_size=24 the grid is
        # 4x4, so the last images of the batch are skipped.
        for i in range(min(batch_size, side * side)):
            plt.subplot(side, side, i + 1)
            # Undo ImageNet normalization: x * std + mean, then scale to uint8.
            shw = np.uint8(np.clip(255 * (img[i] * imagenet_std + imagenet_mean), 0, 255))
            plt.imshow(shw)
        plt.show()
        break  # one batch is enough for a visual check
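# Both visual-check scripts repeat the same inline un-normalization. A small helper
# like the following -- hypothetical, not part of the original code -- keeps the
# x * std + mean direction in one place:
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def denormalize_for_display(chw_image):
    """Take one CHW float image as produced by the loader and return an HWC uint8 array."""
    hwc = np.transpose(chw_image, (1, 2, 0))
    return np.uint8(np.clip(255 * (hwc * IMAGENET_STD + IMAGENET_MEAN), 0, 255))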
    # Fragment of the pipeline class: the signature of the method that returns the
    # sampler class is outside this excerpt.
        return torch.utils.data.distributed.DistributedSampler  # for pytorch_lightning

    def dataset(self):  # for pytorch_lightning
        return None


if __name__ == "__main__":
    cfg = Dict(Fire(fit))
    set_determenistic(cfg.seed)

    add_dict = Dict({"data": {"batch_size": 25}})
    print(add_dict)
    cfg = Dict(update_config(cfg, add_dict))
    print(cfg)

    loader = object_from_dict(cfg.data, mode="val")
    batch_size = loader.batch_size
    side = int(np.sqrt(batch_size))  # 25 images -> 5x5 grid

    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    for images, targets in tqdm(loader, total=len(loader)):
        print(images.shape)
        print(targets.shape)
        # The plotting part of this loop (as in the scripts above) is not shown in the source.
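# The class fragment above returns the DistributedSampler class for pytorch_lightning to use.
# A hypothetical illustration of how such a sampler is typically wired into a DataLoader
# (build_loader, its defaults, and the distributed flag are assumptions, not code from this repo):
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def build_loader(dataset, batch_size, distributed=False, num_workers=4):
    # In distributed runs each process sees its own shard and the sampler shuffles;
    # otherwise shuffling is left to the DataLoader itself.
    sampler = DistributedSampler(dataset) if distributed else None
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
    )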