def simple_simclr_example():
    """Train SimCLR on CIFAR-10, then embed images with an ImageNet-pretrained SimCLR encoder.

    Part 1 trains a fresh ``SimCLR`` model on CIFAR-10 with 2 GPUs (DDP).
    Part 2 runs the CIFAR-10 validation images through the encoder of the
    ImageNet SimCLR checkpoint loaded at the start.

    Relies on module-level ``torch``, ``torchvision`` and ``pl``
    (pytorch_lightning) being imported.
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Load ResNet50 pretrained using SimCLR on ImageNet.
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

    # Swap in your own dataset here if needed, e.g.
    # MyDataset(transforms=SimCLRTrainDataTransform()) / MyDataset(transforms=SimCLREvalDataTransform()).
    train_dataset = torchvision.datasets.CIFAR10("", train=True, download=True, transform=SimCLRTrainDataTransform())
    val_dataset = torchvision.datasets.CIFAR10("", train=False, download=True, transform=SimCLREvalDataTransform())

    # SimCLR needs a lot of compute!
    model = SimCLR(gpus=2, num_samples=len(train_dataset), batch_size=32, dataset="cifar10")
    trainer = pl.Trainer(gpus=2, accelerator="ddp")
    trainer.fit(
        model,
        # Shuffle so contrastive batches (whose other members act as negatives) change every epoch.
        torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=12, shuffle=True),
        torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12),
    )

    # --------------------
    # Feature extraction with the pretrained encoder.
    simclr_resnet50 = simclr.encoder
    simclr_resnet50.eval()

    # Iterate in batches: the encoder expects a leading batch dimension,
    # which single dataset items do not have.
    loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12)
    with torch.no_grad():
        for views, y in loader:
            # The SimCLR transforms return a tuple of views per image; the last one is
            # presumably the deterministic (resize + center-crop) eval view — confirm
            # against the installed pl_bolts version.
            x = views[-1]
            # bolts ResNet encoders return a list of feature maps; the last entry is
            # the pooled representation.
            out = simclr_resnet50(x)[-1]
def simclr_example():
    """Train SimCLR on CIFAR-10 through a datamodule, then load a pretrained checkpoint.

    The first part fits a CIFAR-10 SimCLR model on 2 GPUs with DDP. The second
    part loads a pretrained SimCLR checkpoint (CIFAR-10 by default; the
    ImageNet URL is kept as a commented alternative) and freezes it.

    Relies on a module-level ``pl`` (pytorch_lightning) import.
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Data module.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = SimCLRTrainDataTransform(input_height=32)
    dm.val_transforms = SimCLREvalDataTransform(input_height=32)

    # Model.
    model = SimCLR(gpus=2, num_samples=dm.num_samples, batch_size=dm.batch_size, dataset="cifar10")

    # Fit.
    trainer = pl.Trainer(gpus=2, accelerator="ddp")
    trainer.fit(model, datamodule=dm)

    # --------------------
    # CIFAR-10 pretrained model (previously this pointed at the ImageNet checkpoint by mistake):
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt"
    # ImageNet pretrained model:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
    simclr.freeze()
def load_simclr_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
):
    """Load a SimCLR checkpoint and package its encoder as a model config.

    Args:
        path_or_url: Local path or URL of the SimCLR checkpoint. Defaults to the
            ImageNet weights under ``ROOT_S3_BUCKET``.

    Returns:
        A dict with ``'model'`` (the checkpoint's encoder) and ``'emb_size'`` (2048).
    """
    encoder = SimCLR.load_from_checkpoint(path_or_url, strict=False).encoder
    return {'model': encoder, 'emb_size': 2048}
def test_simclr(tmpdir, datadir):
    """Smoke-test SimCLR on CIFAR-10: one fast_dev_run step, CPU only, batch size 2."""
    batch_size = 2
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=batch_size)
    dm.train_transforms, dm.val_transforms = SimCLRTrainDataTransform(32), SimCLREvalDataTransform(32)

    model = SimCLR(
        batch_size=batch_size,
        num_samples=dm.num_samples,
        gpus=0,
        nodes=1,
        dataset='cifar10',
    )
    pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir).fit(model, datamodule=dm)
def load_simclr_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
):
    """Load a SimCLR checkpoint and return its encoder truncated by two child modules.

    Args:
        path_or_url: Local path or URL of the SimCLR checkpoint. Defaults to the
            ImageNet weights under ``ROOT_S3_BUCKET``.

    Returns:
        A tuple ``(backbone, 2048)``: ``backbone`` is an ``nn.Sequential`` made of
        every child of the encoder except the last two, and 2048 is the channel
        count reported to the caller.
    """
    encoder = SimCLR.load_from_checkpoint(path_or_url, strict=False).encoder
    # Drop the last two children (presumably the pooling layer and the fc head — confirm).
    kept_layers = list(encoder.children())[:-2]
    return nn.Sequential(*kept_layers), 2048
def cli_main():  # pragma: no-cover
    """Fine-tune a classifier on a pretrained SimCLR backbone, then run the test loop.

    Loads the checkpoint given by ``--ckpt_path``, builds the datamodule selected
    by ``--dataset`` (cifar10, stl10 or imagenet2012) with SimCLR transforms,
    trains an ``SSLFineTuner`` and finally calls ``trainer.test``.
    """
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    # `choices` rejects unknown names up front; previously an unknown name left
    # `dm` unbound and crashed only after the checkpoint had been loaded.
    parser.add_argument('--dataset', type=str, help='stl10, cifar10, imagenet2012', default='cifar10',
                        choices=('cifar10', 'stl10', 'imagenet2012'))
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_dir', type=str, help='path to the dataset directory', default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        dm.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples

    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # Fine-tune on the labelled split only.
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = dm.num_labeled_samples

        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)

    # finetune
    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)
def test_simclr(tmpdir):
    """Run one fast_dev_run batch of SimCLR on CIFAR-10 and check the reported loss is positive."""
    seed_everything()

    dm = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    model = SimCLR(batch_size=2, num_samples=dm.num_samples)
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(model, dm)

    assert float(trainer.progress_bar_dict['loss']) > 0
def test_simclr(tmpdir):
    """Overfit SimCLR (with online fine-tuning) on two CIFAR-10 batches; the loss metric must be positive."""
    reset_seed()

    dm = CIFAR10DataModule(tmpdir, num_workers=0)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    model = SimCLR(data_dir=tmpdir, batch_size=2, datamodule=dm, online_ft=True)
    trainer = pl.Trainer(overfit_batches=2, max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(model)

    final_loss = trainer.callback_metrics['loss']
    assert final_loss > 0
def __init__(
    self,
    weight_path='https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
):
    """Build a frozen feature extractor from the encoder of a SimCLR checkpoint.

    Args:
        weight_path: Path or URL of the SimCLR checkpoint; defaults to the
            bolts-hosted ImageNet weights.
    """
    super().__init__()
    # Keep an independent copy of the checkpoint's encoder; the SimCLR wrapper is discarded.
    backbone = deepcopy(
        SimCLR.load_from_checkpoint(weight_path, strict=False).encoder)
    # Neutralize the encoder's `fc` head by replacing it with an identity.
    backbone.fc = nn.Identity()
    self.encoder = backbone
    # NOTE(review): an earlier comment here said "freeze last block of resnet18", but the
    # default checkpoint is the ImageNet SimCLR encoder (presumably ResNet-50), and
    # `freeze()` is called with no arguments. Assuming this is a LightningModule, it
    # freezes every parameter registered so far, i.e. the whole encoder.
    self.freeze()
    # Channel count tracked by the ResNet builder (presumably the final block's output width — confirm).
    self.inplanes = self.encoder.inplanes
def train_self_supervised(checkpoint_dir=r'D:\Users\lVavrek\research\data'):
    """Pretrain SimCLR on LibriSpeech spectrogram images.

    Logs to TensorBoard under ``runs/SimCLR_libri_speech`` and keeps the three
    best checkpoints by ``val_loss``. Early stopping also monitors ``val_loss``.

    Args:
        checkpoint_dir: Directory for the checkpoints. Defaults to the
            previously hard-coded local path, so existing callers are unaffected.
    """
    logger = TensorBoardLogger('runs', name='SimCLR_libri_speech')

    # 8, 224, 8 worked well
    # 16, 224, 4 as well
    batch_size = 16
    input_height = 224
    num_workers = 4

    train_dataset = LibrispeechSpectrogramDataset(
        transform=SimCLRTrainDataTransform(input_height=input_height, gaussian_blur=False),
        train=True)
    val_dataset = LibrispeechSpectrogramDataset(
        transform=SimCLREvalDataTransform(input_height=input_height, gaussian_blur=False),
        train=False)

    # Shuffle the training set: SimCLR's contrastive loss uses the other samples of a
    # batch as negatives, so identical batch composition every epoch weakens training.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

    model = SimCLR(gpus=1, num_samples=len(train_dataset), batch_size=batch_size, dataset='librispeech')

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=checkpoint_dir,
        filename="self-supervised-librispeech-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )
    early_stopping = EarlyStopping(monitor="val_loss")

    trainer = Trainer(gpus=1, callbacks=[checkpoint_callback, early_stopping], logger=logger)
    trainer.fit(model, train_loader, val_loader)
def __init__(self): super().__init__() # init a pretrained resnet weight_path = r'D:\Users\lVavrek\research\data\sim-clr-backups\01112021-self-supervised-librispeech-epoch=19-val_loss=1.52.ckpt' simclr = SimCLR.load_from_checkpoint(weight_path, strict=False) backbone = simclr.encoder # extract last layer num_filters = backbone.fc.in_features layers = list(backbone.children())[:-1] self.feature_extractor = torch.nn.Sequential(*layers) # use the pretrained model to classify destination problem (PD or healthy, 2 classes) num_target_classes = 2 self.classifier = torch.nn.Linear(num_filters, num_target_classes) self.sigmoid = torch.nn.Sigmoid() self.celoss = torch.nn.CrossEntropyLoss()
def __init__(
    self,
    pre,
    weight_path='https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
):
    """Build an encoder followed by a small MLP head with 20 outputs.

    Args:
        pre: If truthy, take the encoder from the SimCLR checkpoint at
            ``weight_path``; otherwise use a randomly initialised torchvision ResNet-18.
        weight_path: Path or URL of the SimCLR checkpoint (defaults to the
            bolts-hosted CIFAR-10 weights).
    """
    super().__init__()
    self.pre = pre
    self.weight_path = weight_path
    if pre:
        # Independent copy of the checkpoint's encoder; the SimCLR wrapper is discarded.
        self.encoder = deepcopy(
            SimCLR.load_from_checkpoint(weight_path, strict=False).encoder)
    else:
        self.encoder = torchvision.models.resnet18(pretrained=False)
    # Called before `self.fc` is created. Assuming this is LightningModule.freeze, only
    # parameters registered so far (the encoder) are frozen and the MLP head below stays trainable.
    # NOTE(review): when `pre` is falsy this freezes a randomly initialised ResNet-18 — confirm intended.
    self.freeze()
    self.avgpool = self.encoder.avgpool
    # Feature width entering the head, read from the encoder's own `fc` layer.
    numft = self.encoder.fc.in_features
    self.fc = nn.Sequential(nn.Linear(in_features=numft, out_features=256), nn.ReLU(),
                            nn.Linear(in_features=256, out_features=100), nn.ReLU(),
                            nn.Linear(in_features=100, out_features=20))
def cli_main():
    """Set up SimCLR fine-tuning on an image-folder dataset.

    Parses CLI arguments and builds an ``ImageDataModule`` with SimCLR fine-tuning
    transforms. It then creates a ResNet-18 ``SimCLR`` model with a custom
    projection head and optionally restores its weights from ``--MODEL_PATH``.
    Finally it wraps the model in an ``SSLFineTuner``. As written, no training is
    started in this function.
    """

    def _str2bool(value):
        # argparse's `type=bool` maps every non-empty string (including "False") to True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument("--MODEL_PATH", default=None, type=str, help="path to model checkpoint")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--image_size", default=256, type=int, help="image size for SSL")
    parser.add_argument("--num_workers", default=1, type=int, help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument("--hidden_dims", default=128, type=int, help="hidden dimensions in classification layer added onto model for finetuning")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument("--patience", default=-1, type=int, help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.")
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument("--withold_train_percent", default=0, type=float, help="decimal from 0-1 representing how much of the training data to withold during finetuning")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument("--eval", default=True, type=_str2bool, help="Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument("--pretrain_encoder", default=False, type=_str2bool, help="initialize resnet encoder with pretrained imagenet weights. Ignored if MODEL_PATH is specified.")
    parser.add_argument("--version", default="0", type=str, help="version to name checkpoint for saving")
    parser.add_argument("--fix_backbone", default=True, type=_str2bool, help="Fix backbone during finetuning")
    args = parser.parse_args()

    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    val_split = args.val_split
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    pretrain = args.pretrain_encoder
    fix_backbone = args.fix_backbone
    # NOTE(review): --hidden_dims, --patience, --withold_train_percent, --eval and --version
    # are parsed but not used in this function.

    # Honour --image_size (previously hard-coded to 256, which is also its default).
    train_transform = SimCLRFinetuneTransform(image_size, eval_transform=False)
    val_transform = SimCLRFinetuneTransform(image_size, eval_transform=True)
    dm = ImageDataModule(URL, train_transform=train_transform, val_transform=val_transform,
                         val_split=val_split, num_workers=num_workers)
    dm.setup()

    # init model with batch size, num_samples (len of data), epochs to train, and autofinds learning rate
    model = SimCLR(arch='resnet18', batch_size=batch_size, num_samples=dm.num_samples, gpus=gpus,
                   dataset='None', max_epochs=epochs, learning_rate=lr)

    # model.encoder = resnet18(pretrained=pretrain, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=embedding_size)  # overrides

    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on CPU-only machines.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print('Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights')
    elif pretrain:
        # NOTE(review): the encoder override above is commented out, so --pretrain_encoder
        # currently only changes this message.
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    print('Finetuning to classify ', dm.num_classes, ' Classes')

    # The backbone is the SimCLR model built above (the name `backbone` was never defined here).
    tuner = SSLFineTuner(
        model,
        in_features=512,
        num_classes=dm.num_classes,
        epochs=epochs,
        hidden_dim=None,
        dropout=0,
        learning_rate=0.3,
        weight_decay=1e-6,
        nesterov=False,
        scheduler_type='cosine',
        gamma=0.1,
        final_lr=0.,
        fix_backbone=fix_backbone,
    )
# Build train/test image datasets: use existing train/ and test/ sub-folders when
# present, otherwise split every image found under data_dir with a fixed seed.
if os.path.exists(os.path.join(args.data_dir, 'train')):
    train_set = ImageFolderDataset(os.path.join(args.data_dir, 'train'))
    test_set = ImageFolderDataset(os.path.join(args.data_dir, 'test'))
else:
    files = utils.recursive_folder_image_paths(args.data_dir)
    random.seed(19)
    random.shuffle(files)
    split_index = int(train_test_ratio * len(files))
    train_files = files[:split_index]
    test_files = files[split_index:]
    train_set = ImageFilesDataset(train_files)
    test_set = ImageFilesDataset(test_files)

train_loader = DataLoader(train_set, batch_size=12, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=12, shuffle=False, num_workers=4)

# `num_samples` must count only the training samples: SimCLR presumably derives its
# steps per epoch (and thus the LR schedule) from it — confirm for the pl_bolts version.
# `dataset` is a dataset *name* string, not a DataLoader object.
model = SimCLR(gpus=1, num_samples=len(train_set), batch_size=12, dataset='custom')
trainer = pl.Trainer(gpus=1)
trainer.fit(model, train_loader, test_loader)
model.freeze()
def main():
    """Pretrain SimCLR on a folder of sneaker images.

    Required CLI arguments are ``--data_dir``, ``--log_dir`` and ``--input_height``;
    all others have defaults. A ``ModelCheckpoint`` monitoring ``val_loss`` saves
    every 10 epochs, keeps all checkpoints and also keeps ``last``. DDP is used
    when more than one GPU is requested.
    """
    parser = ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True, help="path to the folder of images")
    parser.add_argument("--log_dir", type=str, required=True, help="output training logging dir")
    # Options that declare a default are optional: argparse never applies the
    # default of a required option, so `required=True` made these defaults dead.
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--input_height",
        type=int,
        required=True,
        help="height of image input to SimCLR",
    )
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--gpus", type=int, default=0, help="Number of GPUs")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of dataloader workers")
    parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run")
    args = parser.parse_args()

    dm = SneakerDataModule(image_folder=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
    dm.train_transforms = SimCLRTrainDataTransform(args.input_height)
    dm.val_transforms = SimCLREvalDataTransform(args.input_height)

    model = SimCLR(
        num_samples=dm.num_samples,
        batch_size=dm.batch_size,
        learning_rate=args.learning_rate,
        max_epochs=args.max_epochs,
        gpus=args.gpus,
        dataset="sneakers",
    )

    # period=10: checkpoint every 10 epochs; save_top_k=-1 keeps every saved checkpoint.
    model_checkpoint_callback = ModelCheckpoint(
        monitor="val_loss", save_last=True, save_top_k=-1, period=10,
        filename='{epoch}-{val_loss:.2f}-{step}')

    # TODO set the logger folder
    # Warning message is "Missing logger folder: /lightning_logs"
    multi_gpu = args.gpus > 1
    trainer = pl.Trainer(
        default_root_dir=args.log_dir,
        callbacks=[model_checkpoint_callback],
        # checkpoint_callback=True,  # configures a default checkpointing callback
        max_epochs=args.max_epochs,
        gpus=args.gpus,
        accelerator='ddp' if multi_gpu else None,
        enable_pl_optimizer=multi_gpu,
    )
    trainer.fit(model, dm)
def cli_main():
    """Fine-tune a SimCLR ResNet-18 encoder for image classification and evaluate it.

    Steps:

    1. Build a ``FolderDataset2`` datamodule with fine-tuning transforms.
    2. Create a SimCLR model whose encoder is replaced by a ResNet-18 (ImageNet
       weights when ``--pretrain_encoder`` is true).
    3. Optionally restore the model's weights from ``--MODEL_PATH``.
    4. Train an ``SSLFineTuner`` with DDP and 16-bit precision.
    5. Optionally evaluate on the validation and training loaders.
    6. Save the tuner's ``state_dict`` under ``./models/Finetune/SIMCLR_Finetune_<version>/``.
    """

    def _str2bool(value):
        # argparse's `type=bool` maps every non-empty string (including "False") to True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument("--MODEL_PATH", default=None, type=str, help="path to model checkpoint")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--image_size", default=256, type=int, help="image size for SSL")
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument(
        "--hidden_dims", default=128, type=int,
        help="hidden dimensions in classification layer added onto model for finetuning")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=0.3, type=float, help="learning rate for training model")
    parser.add_argument(
        "--patience", default=-1, type=int,
        help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.")
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument(
        "--withold_train_percent", default=0, type=float,
        help="decimal from 0-1 representing how much of the training data to withold during finetuning")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument(
        "--eval", default=True, type=_str2bool,
        help="Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument(
        "--pretrain_encoder", default=False, type=_str2bool,
        help="initialize resnet encoder with pretrained imagenet weights. Ignored if MODEL_PATH is specified.")
    parser.add_argument("--version", default="0", type=str, help="version to name checkpoint for saving")
    parser.add_argument("--fix_backbone", default=True, type=_str2bool, help="Fix backbone during finetuning")
    parser.add_argument("--num_workers", default=0, type=int, help="number of workers to use to fetch data")
    args = parser.parse_args()

    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    val_split = args.val_split
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    pretrain = args.pretrain_encoder
    fix_backbone = args.fix_backbone
    # NOTE(review): --num_workers, --patience, --withold_train_percent and
    # --image_embedding_size are parsed but not used in this function.

    dm = FolderDataset2(URL, val_split=val_split,
                        train_transform=SimCLRFinetuneTransform(image_size),
                        val_transform=SimCLRFinetuneTransform(image_size))
    dm.setup()

    model = SimCLR(arch='resnet18', batch_size=batch_size, num_samples=dm.num_samples, gpus=1,
                   dataset='None', max_epochs=100, learning_rate=lr)
    # model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=128)  # overrides
    model.encoder = resnet18(pretrained=pretrain, first_conv=model.first_conv, maxpool1=model.maxpool1,
                             return_all_feature_maps=False)

    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on CPU-only machines.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print('Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights')
    elif pretrain:
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    print('Finetuning to classify ', dm.num_classes, ' Classes')

    # fix_backbone now follows --fix_backbone (it was hard-coded to True, the flag's default).
    tuner = SSLFineTuner(model,
                         in_features=512,
                         num_classes=dm.num_classes,
                         epochs=epochs,
                         hidden_dim=hidden_dims,
                         dropout=0,
                         learning_rate=lr,
                         weight_decay=1e-6,
                         nesterov=False,
                         scheduler_type='cosine',
                         gamma=0.1,
                         final_lr=0.,
                         fix_backbone=fix_backbone)

    trainer = pl.Trainer(gpus=gpus, num_nodes=1, precision=16, max_epochs=epochs,
                         distributed_backend='ddp', sync_batchnorm=False)
    trainer.fit(tuner, dm)

    Path(f"./models/Finetune/SIMCLR_Finetune_{version}").mkdir(parents=True, exist_ok=True)

    if eval_model:
        print('Evaluating Model...')
        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/validationMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)

        # Build the validation loader once instead of twice.
        val_loader = dm.val_dataloader()
        if val_loader is not None:
            eval_finetune(tuner, 'validation', val_loader, save_path)

        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/trainingMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'training', dm.train_dataloader(), save_path)

    print('Saving model...')
    torch.save(
        tuner.state_dict(),
        f"./models/Finetune/SIMCLR_Finetune_{version}/SIMCLR_FINETUNE_{version}.pt"
    )
def cli_main():
    """Self-supervised SimCLR pretraining on a folder of images.

    Parses CLI arguments and builds train/validation ``FolderDataset`` loaders with
    SimCLR augmentations. It optionally restores weights from ``--MODEL_PATH``,
    trains with a PyTorch Lightning ``Trainer`` (with early stopping when
    ``--patience`` > 0), and saves the model's ``state_dict`` to
    ``./models/SSL/SIMCLR_SSL_<version>/SIMCLR_SSL_<version>.pt``.
    """

    def _str2bool(value):
        # argparse's `type=bool` maps every non-empty string (including "False") to True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument("--MODEL_PATH", default=None, type=str, help="path to model checkpoint")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--image_size", default=256, type=int, help="image size for SSL")
    parser.add_argument(
        "--image_type", default="tif", type=str,
        help="extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)")
    parser.add_argument("--num_workers", default=1, type=int, help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument(
        "--patience", default=-1, type=int,
        help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.")
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder", default=False, type=_str2bool,
        help="initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint.")
    parser.add_argument(
        "--withold_train_percent", default=0, type=float,
        help="decimal from 0-1 representing how much of the training data to withold during SSL training")
    parser.add_argument("--version", default="0", type=str, help="version to name checkpoint for saving")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    args = parser.parse_args()

    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    # NOTE(review): --pretrain_encoder is parsed but has no effect while the encoder
    # override below stays commented out.

    # gets dataset. We can't combine since validation data has different transform needed
    train_dataset = FolderDataset(
        URL, validation=False, val_split=val_split, withold_train_percent=withold_train_percent,
        transform=SimCLRTrainDataTransform(image_size), image_type=image_type)
    # shuffle=True: contrastive batches (whose members serve as negatives) should vary per epoch.
    data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                                              drop_last=True, shuffle=True)
    print('Training Data Loaded...')

    val_dataset = FolderDataset(URL, validation=True, val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size), image_type=image_type)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                                             drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(train_dataset)

    # init model with batch size, num_samples (len of data), epochs to train, and autofinds learning rate
    model = SimCLR(arch='resnet18', batch_size=batch_size, num_samples=num_samples, gpus=gpus,
                   dataset='None', max_epochs=epochs, learning_rate=lr)

    # model.encoder = resnet18(pretrained=pretrain, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=embedding_size)  # overrides

    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        trainer = Trainer(gpus=gpus, max_epochs=epochs, callbacks=[cb], progress_bar_refresh_rate=5)
    else:
        trainer = Trainer(gpus=gpus, max_epochs=epochs, progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on CPU-only machines.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print('Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights')

    # Only move to CUDA when GPUs were requested; an unconditional .cuda() fails with --gpus 0.
    if gpus > 0:
        model.cuda()

    print('Model Initialized')
    trainer.fit(model, data_loader, val_loader)

    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")
def _parse_bool_flag(value):
    """Parse a CLI boolean; argparse's ``type=bool`` would treat "False" as True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


parser.add_argument('--data_dir', required=True, type=str, help='path to image data directory')
parser.add_argument('--save_dir', required=True, type=str, help='path to the directory to save results in')
parser.add_argument('--image_index', type=int, default=42)
parser.add_argument('--n_images', type=int, default=20)
parser.add_argument('--rgb', type=_parse_bool_flag, default=True)
args = parser.parse_args()

image_paths = utils.recursive_folder_image_paths(args.data_dir)

# NOTE(review): `--model_dir` is not among the arguments added here; it must be
# registered on `parser` earlier.
model = SimCLR.load_from_checkpoint(checkpoint_path=args.model_dir, strict=False)
model_enc = model.encoder
model_enc.eval()

transform = transforms.Compose(
    [transforms.Resize((32, 32)), transforms.ToTensor()])

# One row per image; the width 2048 assumes a ResNet-50 encoder — TODO confirm.
y = np.empty((len(image_paths), 2048), float)

import torch  # needed for no_grad; harmless if already imported at file level

# Inference only: skip building autograd graphs.
with torch.no_grad():
    for i, p in enumerate(tqdm(image_paths)):
        # Context manager closes the underlying file handle.
        with Image.open(p) as image:
            if args.rgb:
                image = image.convert('RGB')
            image = transform(image).unsqueeze_(0)

        # The bolts ResNet encoder returns a list of feature maps; [0] takes the first
        # (and, without return_all_feature_maps, only) entry.
        y_hat = model_enc(image)
        y_hat = y_hat[0].detach().numpy().reshape(1, -1)
        # NOTE(review): `y` is allocated above but never written in this code;
        # presumably `y[i] = y_hat` is intended.
def get_self_supervised_model(run_params):
    """Train (or restore) a SimCLR ResNet-50 on the self-supervised image dataset.

    Args:
        run_params: Mapping of run settings. Keys read here: PATH_PREFIX,
            MODEL_SAVE_NAME, RESIZE, RANDOM_RESIZE_CROP, RAW_PREPROCESS_FOLDER,
            SELF_SUPERVISED_BATCH_SIZE, SELF_SUPERVISED_EPOCHS,
            SELF_SUPERVISED_WARMUP_EPOCHS and SELF_SUPERVISED_TRAIN.

    Returns:
        The SimCLR model. The source of its weights depends on the settings:

        - SELF_SUPERVISED_TRAIN truthy: the model is trained, then reloaded from
          ``<PATH_PREFIX>/checkpoints/<MODEL_SAVE_NAME>.ckpt``.
        - Otherwise, if that checkpoint exists: the model is loaded from it.
        - Otherwise: the untrained model is returned.

    Relies on the module-level ``final_df`` and ``SelfSupervisedDataset``.
    """
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import (
        SimCLRTrainDataTransform,
        SimCLREvalDataTransform,
    )
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    checkpoints_dir = os.path.join(run_params["PATH_PREFIX"], "checkpoints")
    checkpoint_resume = os.path.join(checkpoints_dir, run_params["MODEL_SAVE_NAME"] + ".ckpt")

    # Train and eval transforms share the same crop size.
    crop_size = min(run_params["RESIZE"], run_params["RANDOM_RESIZE_CROP"])
    dataset = SelfSupervisedDataset(
        final_df,
        validation=False,
        transform=SimCLRTrainDataTransform(crop_size),
        prefix=run_params["RAW_PREPROCESS_FOLDER"] + "/",
    )
    val_dataset = SelfSupervisedDataset(
        final_df,
        validation=True,
        transform=SimCLREvalDataTransform(crop_size),
        prefix=run_params["RAW_PREPROCESS_FOLDER"] + "/",
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"], num_workers=0)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"], num_workers=0)
    num_samples = len(dataset)

    # init model with batch size, num_samples (len of data), epochs to train, and autofinds learning rate
    model_self_sup = SimCLR(
        gpus=1,
        arch="resnet50",
        dataset="",
        max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
        warmup_epochs=run_params["SELF_SUPERVISED_WARMUP_EPOCHS"],
        batch_size=run_params["SELF_SUPERVISED_BATCH_SIZE"],
        num_samples=num_samples,
    )

    if run_params["SELF_SUPERVISED_TRAIN"]:
        logger = TensorBoardLogger(
            os.path.join(run_params["PATH_PREFIX"], "tb_logs", "simCLR"),
            name=run_params["MODEL_SAVE_NAME"],
        )
        early_stopping = EarlyStopping("val_loss", patience=5)
        if os.path.exists(checkpoint_resume):
            # NOTE(review): this branch adds no ModelCheckpoint writing to `checkpoint_resume`,
            # so the reload after fit() picks up the checkpoint that existed before resuming,
            # not the result of this run — confirm that is intended.
            trainer = Trainer(
                gpus=1,
                max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
                logger=logger,
                auto_scale_batch_size=True,
                resume_from_checkpoint=checkpoint_resume,
                callbacks=[early_stopping],
            )
        else:
            checkpoint_callback = ModelCheckpoint(
                monitor="val_loss",
                dirpath=checkpoints_dir,
                filename=run_params["MODEL_SAVE_NAME"],
                save_top_k=1,
                mode="min",
            )
            trainer = Trainer(
                gpus=1,
                max_epochs=run_params["SELF_SUPERVISED_EPOCHS"],
                logger=logger,
                auto_scale_batch_size=True,
                callbacks=[checkpoint_callback, early_stopping],
            )

        trainer.fit(model_self_sup, data_loader, val_loader)
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    elif os.path.exists(checkpoint_resume):
        # load_from_checkpoint is a classmethod returning a NEW model; the result must be
        # kept (it was previously discarded, so the checkpoint was never used).
        model_self_sup = SimCLR.load_from_checkpoint(checkpoint_resume)
    else:
        print(
            f"No checkpoint found, so it could not load model from it\n{checkpoint_resume}"
        )
    return model_self_sup
def cli_main():
    """Fine-tune a SimCLR ResNet-18 on labelled image folders, evaluate it and save it.

    Steps:

    1. Build train/validation ``FolderDataset`` loaders with fine-tuning transforms.
    2. Create a SimCLR model with a custom projection head and optionally restore
       its weights from ``--MODEL_PATH``.
    3. Train an ``SSLFineTuner`` on top of it.
    4. Optionally write training/validation metrics via ``eval_finetune``.
    5. Save the tuner's ``state_dict`` under ``./models/Finetune/SIMCLR_Finetune_<version>/``.
    """

    def _str2bool(value):
        # argparse's `type=bool` maps every non-empty string (including "False") to True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument("--MODEL_PATH", default=None, type=str, help="path to model checkpoint")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--image_size", default=256, type=int, help="image size for SSL")
    parser.add_argument(
        "--image_type", default="tif", type=str,
        help="extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)")
    parser.add_argument("--num_workers", default=1, type=int, help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument(
        "--hidden_dims", default=128, type=int,
        help="hidden dimensions in classification layer added onto model for finetuning")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument(
        "--patience", default=-1, type=int,
        help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.")
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument(
        "--withold_train_percent", default=0, type=float,
        help="decimal from 0-1 representing how much of the training data to withold during finetuning")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument(
        "--eval", default=True, type=_str2bool,
        help="Eval Mode will train and evaluate the finetuned model's performance")
    parser.add_argument("--imagenet_weights", default=False, type=_str2bool, help="Use weights from a non-SSL")
    parser.add_argument("--version", default="0", type=str, help="version to name checkpoint for saving")
    args = parser.parse_args()

    DATA_PATH = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    eval_model = args.eval
    imagenet_weights = args.imagenet_weights

    # gets dataset. We can't combine since validation data has different transform needed
    finetune_dataset = FolderDataset(
        DATA_PATH, validation=False, val_split=val_split, withold_train_percent=withold_train_percent,
        transform=SimCLRFinetuneTransform(image_size), image_type=image_type)
    # Shuffle training batches so each epoch sees a different order.
    finetune_loader = torch.utils.data.DataLoader(finetune_dataset, batch_size=batch_size,
                                                  num_workers=num_workers, drop_last=True, shuffle=True)
    # Separate ordered loader for the training-set evaluation, matching the original evaluation order.
    finetune_eval_loader = torch.utils.data.DataLoader(finetune_dataset, batch_size=batch_size,
                                                       num_workers=num_workers, drop_last=True)
    print('Training Data Loaded...')

    finetune_val_dataset = FolderDataset(
        DATA_PATH, validation=True, val_split=val_split,
        transform=SimCLRFinetuneTransform(image_size), image_type=image_type)
    finetune_val_loader = torch.utils.data.DataLoader(finetune_val_dataset, batch_size=batch_size,
                                                      num_workers=num_workers, drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(finetune_dataset)
    model = SimCLR(arch='resnet18', batch_size=batch_size, num_samples=num_samples, gpus=gpus,
                   dataset='None', max_epochs=epochs, learning_rate=lr)

    # model.encoder = resnet18(pretrained=imagenet_weights, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=embedding_size)  # overrides

    if model_checkpoint is not None:
        # map_location='cpu' lets GPU-saved weights load on CPU-only machines.
        model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))
        print('Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights')
    elif imagenet_weights:
        # NOTE(review): the encoder override above is commented out, so this flag only changes the message.
        print('Using imagenet weights instead of a pretrained SSL model')
    else:
        print('Using random initialization of encoder')

    num_classes = len(set(finetune_dataset.labels))
    print('Finetuning to classify ', num_classes, ' Classes')

    tuner = SSLFineTuner(model, in_features=512, num_classes=num_classes, hidden_dim=hidden_dims, learning_rate=lr)
    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        trainer = Trainer(gpus=gpus, max_epochs=epochs, callbacks=[cb], progress_bar_refresh_rate=5)
    else:
        trainer = Trainer(gpus=gpus, max_epochs=epochs, progress_bar_refresh_rate=5)

    # Only move to CUDA when GPUs were requested; an unconditional .cuda() fails with --gpus 0.
    if gpus > 0:
        tuner.cuda()
    trainer.fit(tuner, train_dataloader=finetune_loader, val_dataloaders=finetune_val_loader)

    Path(f"./models/Finetune/SIMCLR_Finetune_{version}").mkdir(parents=True, exist_ok=True)

    if eval_model:
        print('Evaluating Model...')
        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/trainingMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'training', finetune_eval_loader, save_path)

        save_path = f"./models/Finetune/SIMCLR_Finetune_{version}/Evaluation/validationMetrics"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_finetune(tuner, 'validation', finetune_val_loader, save_path)

    print('Saving model...')
    torch.save(
        tuner.state_dict(),
        f"./models/Finetune/SIMCLR_Finetune_{version}/SIMCLR_FINETUNE_{version}.pt"
    )
def cli_main():
    """Train a SimCLR model (ResNet-18 backbone) self-supervised on a folder of images.

    Command-line options control the data (``--DATA_PATH``, ``--val_split``,
    ``--image_size``, ``--num_workers``), the model (``--image_embedding_size``,
    ``--pretrain_encoder``, and ``--MODEL_PATH`` to start from saved weights)
    and training (``--batch_size``, ``--epochs``, ``--lr``, ``--patience``,
    ``--gpus``). After fitting, the model's ``state_dict`` is written to
    ``./models/SSL/SIMCLR_SSL_<version>/SIMCLR_SSL_<version>.pt``.
    """

    def boolean(text):
        # argparse's ``type=bool`` returns True for any non-empty string
        # (including "False"), so parse the common spellings explicitly.
        # Raising ValueError makes argparse report "invalid boolean value".
        value = text.strip().lower()
        if value in ("true", "t", "yes", "y", "1"):
            return True
        if value in ("false", "f", "no", "n", "0", ""):
            return False
        raise ValueError(text)

    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, required=True, help="path to folders with images")
    parser.add_argument("--MODEL_PATH", default=None, type=str, help="path to model checkpoint")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--image_size", default=256, type=int, help="image size for SSL")
    # Register --num_workers exactly once: adding the same option string twice
    # makes argparse raise ArgumentError while the parser is being built.
    parser.add_argument("--num_workers", default=1, type=int, help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.",
    )
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=boolean,
        help="initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint.",
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help="decimal from 0-1 representing how much of the training data to withold during SSL training",
    )
    parser.add_argument("--version", default="0", type=str, help="version to name checkpoint for saving")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    args = parser.parse_args()

    data_path = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus
    # NOTE(review): --withold_train_percent is parsed but never forwarded to
    # ImageDataModule, so it currently has no effect. Confirm whether the data
    # module supports it before wiring it through.

    # Honour --image_size; its default (256) equals the size that used to be hard-coded.
    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)
    # NOTE(review): batch_size is passed to SimCLR below but not to
    # ImageDataModule. Confirm the data module's own batch size matches.
    dm = ImageDataModule(
        data_path,
        train_transform=train_transform,
        val_transform=val_transform,
        val_split=val_split,
        num_workers=num_workers,
    )
    dm.setup()

    # SimCLR is told the dataset length (dm.num_samples) and the training budget.
    # dataset='None' is the literal string, not Python None.
    model = SimCLR(
        arch='resnet18',
        batch_size=batch_size,
        num_samples=dm.num_samples,
        gpus=gpus,
        dataset='None',
        max_epochs=epochs,
        learning_rate=lr,
    )
    # Override the backbone so --pretrain_encoder takes effect, and the
    # projection head so its output size is --image_embedding_size.
    model.encoder = resnet18(
        pretrained=pretrain,
        first_conv=model.first_conv,
        maxpool1=model.maxpool1,
        return_all_feature_maps=False,
    )
    model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=embedding_size)

    # Early stopping on 'val_loss' is only enabled when a positive patience is given.
    trainer_kwargs = {"gpus": gpus, "max_epochs": epochs, "progress_bar_refresh_rate": 5}
    if patience > 0:
        trainer_kwargs["callbacks"] = [EarlyStopping('val_loss', patience=patience)]
    trainer = Trainer(**trainer_kwargs)

    if model_checkpoint is not None:
        # Loaded after the encoder override, so checkpoint weights replace any
        # ImageNet weights requested via --pretrain_encoder.
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    model.cuda()
    print('Model Initialized')
    trainer.fit(model, dm)

    output_dir = Path(f"./models/SSL/SIMCLR_SSL_{version}")
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / f"SIMCLR_SSL_{version}.pt")
def cli_main():
    """Evaluate the embeddings of a saved SimCLR ResNet-18 model on an image folder.

    Loads the ``state_dict`` stored at ``--MODEL_PATH`` into a SimCLR model and
    builds training and validation ``FolderDataset``s from ``--DATA_PATH``
    (split by ``--val_split``). It then runs ``eval_embeddings`` on the
    validation split and then the training split. Results are written under
    ``<MODEL_PATH without its extension>/Evaluation/``.
    """
    import os

    def boolean(text):
        # argparse's ``type=bool`` returns True for any non-empty string
        # (including "False"), so parse the common spellings explicitly.
        # Raising ValueError makes argparse report "invalid boolean value".
        value = text.strip().lower()
        if value in ("true", "t", "yes", "y", "1"):
            return True
        if value in ("false", "f", "no", "n", "0", ""):
            return False
        raise ValueError(text)

    parser = ArgumentParser()
    # Both paths are dereferenced unconditionally below, so require them up front.
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        required=True,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model",
    )
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        required=True,
        help="path to data. If folder already contains validation data only, set val_split to 0",
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal",
    )
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help="extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)",
    )
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    parser.add_argument("--image_size", default=128, type=int, help="height of square image to pass through model")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument("--rank", default=50, type=int, help="number of neighbors to search for")
    parser.add_argument(
        "--filter_same_group",
        default=False,
        type=boolean,
        help="custom arg for hurricane data to filter same hurricanes out",
    )
    args = parser.parse_args()

    model_path = args.MODEL_PATH
    data_path = args.DATA_PATH
    image_size = args.image_size
    image_type = args.image_type
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    filter_hur = args.filter_same_group

    # The training and validation splits need different transforms, so they
    # are built as two separate datasets over the same folder.
    train_dataset = FolderDataset(
        data_path,
        validation=False,
        val_split=val_split,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type,
    )
    print('Training Data Loaded...')
    val_dataset = FolderDataset(
        data_path,
        validation=True,
        val_split=val_split,
        transform=SimCLREvalDataTransform(image_size),
        image_type=image_type,
    )
    print('Validation Data Loaded...')

    # Rebuild the same architecture that was trained, then load its weights.
    model = SimCLR(
        arch='resnet18',
        batch_size=1,
        num_samples=len(train_dataset),
        gpus=gpus,
        dataset='None',
    )
    model.encoder = resnet18(
        pretrained=False,
        first_conv=model.first_conv,
        maxpool1=model.maxpool1,
        return_all_feature_maps=False,
    )
    model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=embedding_size)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    # NOTE(review): model.eval() is never called here. Confirm eval_embeddings
    # switches the model to inference mode (BatchNorm statistics, dropout).
    print('Successfully loaded your model for evaluation.')

    # Strip whatever extension the checkpoint has (the old MODEL_PATH[:-3]
    # only worked for 3-character suffixes such as ".pt").
    evaluation_root = f"{os.path.splitext(model_path)[0]}/Evaluation"
    for dataset, metrics_dir, done_message in (
        (val_dataset, "validationMetrics", 'Validation Data Evaluation Complete.'),
        (train_dataset, "trainingMetrics", 'Training Data Evaluation Complete.'),
    ):
        save_path = f"{evaluation_root}/{metrics_dir}"
        Path(save_path).mkdir(parents=True, exist_ok=True)
        eval_embeddings(model, dataset, save_path, rank_to, filter_hur)
        print(done_message)
    print(f'Please check {evaluation_root}/ for your results')
def cli_main():
    """Evaluate the embeddings of a saved SimCLR ResNet-18 model on an ImageDataModule.

    Builds an ``ImageDataModule`` over ``--DATA_PATH`` and recreates the SimCLR
    architecture. It loads the ``state_dict`` stored at ``--MODEL_PATH`` and runs
    ``eval_embeddings`` on the validation loader (when one exists) and the
    training loader. Results are written under
    ``<MODEL_PATH without its extension>/Evaluation/``.
    """
    import os

    def boolean(text):
        # argparse's ``type=bool`` returns True for any non-empty string
        # (including "False"), so parse the common spellings explicitly.
        # Raising ValueError makes argparse report "invalid boolean value".
        value = text.strip().lower()
        if value in ("true", "t", "yes", "y", "1"):
            return True
        if value in ("false", "f", "no", "n", "0", ""):
            return False
        raise ValueError(text)

    parser = ArgumentParser()
    # Required: the weights are loaded from it and the output folder is derived from it.
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        required=True,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model",
    )
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        help="path to data. If folder already contains validation data only, set val_split to 0",
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal",
    )
    parser.add_argument("--image_embedding_size", default=128, type=int, help="size of image representation of SIMCLR")
    # Default is 256 so default runs keep the transform size that used to be hard-coded.
    parser.add_argument("--image_size", default=256, type=int, help="height of square image to pass through model")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument("--rank", default=50, type=int, help="number of neighbors to search for")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for Evaluation")
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=boolean,
        help="initialize resnet encoder with pretrained imagenet weights. Will be ignored if MODEL_PATH is specified.",
    )
    args = parser.parse_args()

    model_path = args.MODEL_PATH
    data_path = args.DATA_PATH
    image_size = args.image_size
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    batch_size = args.batch_size
    pretrain = args.pretrain_encoder

    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)
    # NOTE(review): batch_size is passed to SimCLR below but not to
    # ImageDataModule. Confirm the data module's own batch size matches.
    dm = ImageDataModule(data_path, train_transform=train_transform, val_transform=val_transform, val_split=val_split)
    dm.setup()

    model = SimCLR(
        arch='resnet18',
        batch_size=batch_size,
        num_samples=dm.num_samples,
        gpus=gpus,
        dataset='None',
    )
    model.encoder = resnet18(
        pretrained=pretrain,
        first_conv=model.first_conv,
        maxpool1=model.maxpool1,
        return_all_feature_maps=False,
    )
    model.projection = Projection(input_dim=512, hidden_dim=256, output_dim=embedding_size)
    # The trained weights were previously never loaded, so an untrained (or
    # ImageNet-only) encoder was evaluated despite the success message below.
    # Loading after the encoder override means the checkpoint wins over
    # --pretrain_encoder.
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    # NOTE(review): model.eval() is never called here. Confirm eval_embeddings
    # switches the model to inference mode (BatchNorm statistics, dropout).
    print('Successfully loaded your model for evaluation.')

    # Strip whatever extension the checkpoint has (the old MODEL_PATH[:-3]
    # only worked for 3-character suffixes such as ".pt").
    evaluation_root = f"{os.path.splitext(model_path)[0]}/Evaluation"

    save_path = f"{evaluation_root}/validationMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    val_loader = dm.val_dataloader()  # may be None when there is no validation split
    if val_loader is not None:
        eval_embeddings(model, val_loader, save_path, rank_to)
        print('Validation Data Evaluation Complete.')

    save_path = f"{evaluation_root}/trainingMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, dm.train_dataloader(), save_path, rank_to)
    print('Training Data Evaluation Complete.')
    print(f'Please check {evaluation_root}/ for your results')