def cli_main():
    """Pre-train SimCLR on cifar10/stl10/imagenet2012 with an online linear probe."""
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    # Build the CLI: trainer flags plus SimCLR's own hyper-parameters.
    parser = ArgumentParser()
    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)
    # model args
    parser = SimCLR.add_model_specific_args(parser)
    args = parser.parse_args()

    # Pick the requested datamodule and attach the SimCLR two-view transforms.
    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = SimCLRTrainDataTransform(32)
        datamodule.val_transforms = SimCLREvalDataTransform(32)
        args.num_samples = datamodule.num_samples
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Use the mixed (labeled + unlabeled) loaders for SSL pre-training.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        args.num_samples = datamodule.num_unlabeled_samples
        _, height, _ = datamodule.size()
        datamodule.train_transforms = SimCLRTrainDataTransform(height)
        datamodule.val_transforms = SimCLREvalDataTransform(height)
    elif args.dataset == 'imagenet2012':
        datamodule = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, height, _ = datamodule.size()
        datamodule.train_transforms = SimCLRTrainDataTransform(height)
        datamodule.val_transforms = SimCLREvalDataTransform(height)

    model = SimCLR(**args.__dict__)

    # Online finetuning callback: only the first augmented view is fed to the probe.
    def _probe_to_device(batch, device):
        (view_one, _view_two), labels = batch
        return view_one.to(device), labels.to(device)

    online_evaluator = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes=datamodule.num_classes)
    online_evaluator.to_device = _probe_to_device

    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model, datamodule)
def test_simclr(tmpdir, datadir):
    """Smoke-test SimCLR training on CIFAR-10 with a single fast-dev-run batch."""
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    model = SimCLR(batch_size=2, num_samples=dm.num_samples, gpus=0, nodes=1, dataset='cifar10')
    trainer.fit(model, datamodule=dm)
def cli_main():  # pragma: no-cover
    """Load a pre-trained SimCLR checkpoint and fine-tune a supervised head."""
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, cifar10', default='cifar10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_dir', type=str, help='path to ckpt', default=os.getcwd())
    args = parser.parse_args()

    # Restore the pre-trained encoder; strict=False tolerates keys absent
    # from the fine-tuning graph.
    backbone = SimCLR.load_from_checkpoint(args.ckpt_path, strict=False)

    # Choose the datamodule and attach SimCLR transforms per dataset.
    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = SimCLRTrainDataTransform(32)
        datamodule.val_transforms = SimCLREvalDataTransform(32)
        datamodule.test_transforms = SimCLREvalDataTransform(32)
        args.num_samples = datamodule.num_samples
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Fine-tuning is supervised, so use the labeled splits only.
        datamodule.train_dataloader = datamodule.train_dataloader_labeled
        datamodule.val_dataloader = datamodule.val_dataloader_labeled
        args.num_samples = datamodule.num_labeled_samples
        _, height, _ = datamodule.size()
        datamodule.train_transforms = SimCLRTrainDataTransform(height)
        datamodule.val_transforms = SimCLREvalDataTransform(height)
    elif args.dataset == 'imagenet2012':
        datamodule = ImagenetDataModule.from_argparse_args(args, image_size=196)
        _, height, _ = datamodule.size()
        datamodule.train_transforms = SimCLRTrainDataTransform(height)
        datamodule.val_transforms = SimCLREvalDataTransform(height)

    # Supervised head on top of the frozen SSL backbone.
    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=datamodule.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, datamodule)
    trainer.test(datamodule=datamodule)
def test_simclr_transforms(img_size):
    """Both SimCLR transform pipelines should accept a PIL RGB image of the given size."""
    pl.seed_everything(0)
    channels, height, width = img_size
    pil_image = transforms.ToPILImage(mode='RGB')(torch.rand(channels, height, width))

    # Each pipeline is exercised once; any exception fails the test.
    SimCLREvalDataTransform(input_height=height)(pil_image)
    SimCLRTrainDataTransform(input_height=height)(pil_image)
def test_simclr(tmpdir):
    """A single fast-dev-run step of SimCLR must yield a positive loss."""
    seed_everything()

    dm = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    model = SimCLR(batch_size=2, num_samples=dm.num_samples)
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(model, dm)

    # The contrastive loss reported in the progress bar should be > 0.
    assert float(trainer.progress_bar_dict['loss']) > 0
def test_simsiam(tmpdir, datadir):
    """A single fast-dev-run step of SimSiam must yield a negative loss value."""
    seed_everything()

    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    model = SimSiam(batch_size=2, num_samples=dm.num_samples, gpus=0, nodes=1, dataset='cifar10')
    trainer = pl.Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, datamodule=dm)

    # SimSiam's loss is expected below zero (the test asserts strictly < 0).
    assert float(trainer.progress_bar_dict['loss']) < 0
def cli_main():
    """Train SimCLR with per-dataset hyper-parameter presets.

    Supports stl10 / cifar10 / imagenet; any other ``--dataset`` value raises
    NotImplementedError.  Each branch mutates ``args`` in place so the presets
    flow into ``SimCLR(**args.__dict__)`` and the transforms below.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

    parser = ArgumentParser()

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        # SSL pre-training uses the mixed (labeled + unlabeled) loaders.
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        # STL-10 images are large enough to keep the first conv but drop maxpool.
        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]
        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.

    elif args.dataset == 'cifar10':
        # Grow the validation split so every (node x gpu) gets at least one batch.
        val_split = 5000
        if args.num_nodes * args.gpus * args.batch_size > val_split:
            val_split = args.num_nodes * args.gpus * args.batch_size

        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers, val_split=val_split)

        args.num_samples = dm.num_samples

        # Small 32x32 inputs: shrink the stem (no maxpool, 3x3 first conv).
        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5

    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.

        # NOTE: the imagenet branch overrides CLI values with the paper-style
        # large-scale recipe (64 per-gpu batch, 8 nodes x 8 gpus, LARS + sgd).
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = 'sgd'
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]

    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # Two-view SimCLR augmentations share the per-dataset settings chosen above.
    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimCLR(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval: linear probe trained alongside SSL pre-training
        online_evaluator = SSLOnlineEvaluator(drop_p=0., hidden_dim=None, z_dim=args.hidden_mlp, num_classes=dm.num_classes, dataset=args.dataset)

    # Keep the best (by val_loss) and the last checkpoint.
    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]

    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=True if args.gpus > 1 else False,
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=dm)
def cli_main():
    """Evaluate embeddings of an SSL-trained SimCLR ResNet-18 on a dataset.

    Fixes over the previous revision:
    * The train/val transforms hard-coded an input height of 256, silently
      ignoring ``--image_size`` (the parsed value was never used).
    * The weights at ``--MODEL_PATH`` were never actually loaded even though
      the script printed a success message and wrote results under that path;
      the state dict is now restored (mirroring the sibling eval script).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model")
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        help="path to data. If folder already contains validation data only, set val_split to 0"
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--image_size",
                        default=128,
                        type=int,
                        help="height of square image to pass through model")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--rank",
                        default=50,
                        type=int,
                        help="number of neighbors to search for")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for Evaluation")
    # NOTE(review): type=bool is an argparse pitfall — any non-empty string
    # parses as True.  Kept as-is for interface compatibility.
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=bool,
        help="initialize resnet encoder with pretrained imagenet weights. Will be ignored if MODEL_PATH is specified."
    )
    args = parser.parse_args()

    MODEL_PATH = args.MODEL_PATH
    URL = args.DATA_PATH
    image_size = args.image_size
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    batch_size = args.batch_size
    pretrain = args.pretrain_encoder  # unused here: the encoder override below is commented out

    # BUG FIX: transforms previously hard-coded 256, ignoring --image_size.
    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)

    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split)
    dm.setup()

    # init model with batch size, num_samples (len of data), epochs to train, and autofinds learning rate
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None')
    # model.encoder = resnet18(pretrained=pretrain, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)  # overrides

    if MODEL_PATH is not None:
        # BUG FIX: the checkpoint was never loaded before — evaluation ran on
        # freshly initialized weights despite the success message below.
        model.load_state_dict(torch.load(MODEL_PATH))
    model.cuda()
    print('Successfully loaded your model for evaluation.')

    # running eval on validation data
    save_path = f"{MODEL_PATH[:-3]}/Evaluation/validationMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    if dm.val_dataloader() is not None:
        eval_embeddings(model, dm.val_dataloader(), save_path, rank_to)
        print('Validation Data Evaluation Complete.')

    # running eval on training data
    save_path = f"{MODEL_PATH[:-3]}/Evaluation/trainingMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, dm.train_dataloader(), save_path, rank_to)
    print('Training Data Evaluation Complete.')

    print(f'Please check {MODEL_PATH[:-3]}/Evaluation/ for your results')
import pytorch_lightning as pl

from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)

# data: CIFAR-10 with the SimCLR two-view augmentation pipelines
dm = CIFAR10DataModule(num_workers=4)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

# model: BYOL self-supervised learner
model = BYOL(num_classes=10)

trainer = pl.Trainer()
trainer.fit(model, dm)
def cli_main():
    """Train a SimCLR ResNet-18 on a folder of images and save the weights.

    Fixes over the previous revision:
    * ``--num_workers`` was registered twice (defaults 1 and 0), which makes
      argparse raise ``ArgumentError`` before parsing anything — the script
      crashed on startup.  It is now registered once (and the duplicate
      ``num_workers = args.num_workers`` assignment is gone too).
    * The train/val transforms hard-coded an input height of 256, silently
      ignoring ``--image_size``; they now honor the flag.
    """
    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    # BUG FIX: this option was previously declared a second time further down,
    # which is an argparse conflict and crashed the script.
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    # NOTE(review): type=bool is an argparse pitfall — any non-empty string
    # parses as True.  Kept as-is for interface compatibility.
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=bool,
        help="initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help="decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    args = parser.parse_args()

    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder  # unused: encoder override below is commented out
    withold_train_percent = args.withold_train_percent  # NOTE(review): parsed but never forwarded — confirm intent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

    # BUG FIX: was hard-coded to 256, ignoring --image_size.
    train_transform = SimCLRTrainDataTransform(image_size)
    val_transform = SimCLREvalDataTransform(image_size)

    dm = ImageDataModule(URL,
                         train_transform=train_transform,
                         val_transform=val_transform,
                         val_split=val_split,
                         num_workers=num_workers)
    dm.setup()

    # init model with batch size, num_samples (len of data), epochs to train, and autofinds learning rate
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=dm.num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)
    # model.encoder = resnet18(pretrained=pretrain, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)  # overrides

    # Early stopping only when a positive patience was requested.
    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          callbacks=[cb],
                          progress_bar_refresh_rate=5)
    else:
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        # Resume from raw weights only — optimizer/trainer state is not kept.
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    model.cuda()
    print('Model Initialized')
    trainer.fit(model, dm)

    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")
def cli_main():
    """Train a SimCLR ResNet-18 on images read straight from a folder.

    Builds train/val ``FolderDataset`` loaders with SimCLR augmentations,
    optionally resumes model weights from ``--MODEL_PATH``, trains with an
    optional early-stopping callback, and saves the final state dict under
    ``./models/SSL/SIMCLR_SSL_<version>/``.
    """
    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument("--MODEL_PATH",
                        default=None,
                        type=str,
                        help="path to model checkpoint")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--image_size",
                        default=256,
                        type=int,
                        help="image size for SSL")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help="extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of CPU cores to use for data processing")
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--epochs",
                        default=200,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help="automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    # NOTE(review): type=bool is an argparse pitfall — any non-empty string
    # parses as True.
    parser.add_argument(
        "--pretrain_encoder",
        default=False,
        type=bool,
        help="initialize resnet encoder with pretrained imagenet weights. Cannot be true if passing previous SSL model checkpoint."
    )
    parser.add_argument(
        "--withold_train_percent",
        default=0,
        type=float,
        help="decimal from 0-1 representing how much of the training data to withold during SSL training"
    )
    parser.add_argument("--version",
                        default="0",
                        type=str,
                        help="version to name checkpoint for saving")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    args = parser.parse_args()

    # Unpack CLI args into locals for readability below.
    URL = args.DATA_PATH
    batch_size = args.batch_size
    image_size = args.image_size
    image_type = args.image_type
    num_workers = args.num_workers
    embedding_size = args.image_embedding_size
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    pretrain = args.pretrain_encoder  # unused: the encoder override below is commented out
    withold_train_percent = args.withold_train_percent
    version = args.version
    model_checkpoint = args.MODEL_PATH
    gpus = args.gpus

    # #testing
    # batch_size = 128
    # image_type = 'tif'
    # image_size = 256
    # num_workers = 4
    # URL ='/content/UCMerced_LandUse/Images'
    # embedding_size = 128
    # epochs = 2
    # lr = 1e-3
    # patience = 1
    # val_split = 0.2
    # pretrain = False
    # withold_train_percent = 0.2
    # version = "1"
    # model_checkpoint = '/content/models/SSL/SIMCLR_SSL_0.pt'
    # gpus = 1

    # #gets dataset. We can't combine since validation data has different transform needed
    train_dataset = FolderDataset(
        URL,
        validation=False,
        val_split=val_split,
        withold_train_percent=withold_train_percent,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)
    data_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              drop_last=True)
    print('Training Data Loaded...')
    val_dataset = FolderDataset(URL,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             drop_last=True)
    print('Validation Data Loaded...')

    num_samples = len(train_dataset)

    # init model with batch size, num_samples (len of data), epochs to train, and autofinds learning rate
    model = SimCLR(arch='resnet18',
                   batch_size=batch_size,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None',
                   max_epochs=epochs,
                   learning_rate=lr)
    # model.encoder = resnet18(pretrained=pretrain, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)  # overrides

    # Early stopping on val_loss only when a positive patience was requested.
    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          callbacks=[cb],
                          progress_bar_refresh_rate=5)
    else:
        trainer = Trainer(gpus=gpus,
                          max_epochs=epochs,
                          progress_bar_refresh_rate=5)

    if model_checkpoint is not None:
        # Resume raw weights only — optimizer/trainer state is not restored.
        model.load_state_dict(torch.load(model_checkpoint))
        print(
            'Successfully loaded your checkpoint. Keep in mind that this does not preserve the previous trainer states, only the model weights'
        )

    model.cuda()
    print('Model Initialized')
    trainer.fit(model, data_loader, val_loader)

    # Persist the final weights under a versioned directory.
    Path(f"./models/SSL/SIMCLR_SSL_{version}").mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/SSL/SIMCLR_SSL_{version}/SIMCLR_SSL_{version}.pt")
def main():
    """Pre-train SimCLR on a folder of sneaker images, driven by CLI flags."""
    parser = ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True, help="path to the folder of images")
    parser.add_argument("--log_dir", type=str, required=True, help="output training logging dir")
    parser.add_argument("--learning_rate", type=float, required=True, default=1e-3, help="learning rate")
    parser.add_argument("--input_height", type=int, required=True, help="height of image input to SimCLR")
    parser.add_argument("--batch_size", type=int, default=1024, required=True)
    parser.add_argument("--gpus", type=int, default=0, required=True, help="Number of GPUs")
    parser.add_argument("--num_workers", type=int, default=0, required=True, help="Number of dataloader workers")
    parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run")
    args = parser.parse_args()

    datamodule = SneakerDataModule(image_folder=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
    datamodule.train_transforms = SimCLRTrainDataTransform(args.input_height)
    datamodule.val_transforms = SimCLREvalDataTransform(args.input_height)

    model = SimCLR(
        num_samples=datamodule.num_samples,
        batch_size=datamodule.batch_size,
        learning_rate=args.learning_rate,
        max_epochs=args.max_epochs,
        gpus=args.gpus,
        dataset="sneakers",
    )

    # save_top_k=-1 keeps every periodic checkpoint; save_last also writes last.ckpt
    checkpoint_cb = ModelCheckpoint(
        monitor="val_loss",
        save_last=True,
        save_top_k=-1,
        period=10,
        filename='{epoch}-{val_loss:.2f}-{step}')

    # TODO set the logger folder
    # Warning message is "Missing logger folder: /lightning_logs"
    trainer = pl.Trainer(
        default_root_dir=args.log_dir,
        callbacks=[checkpoint_cb],
        max_epochs=args.max_epochs,
        gpus=args.gpus,
        accelerator='ddp' if args.gpus > 1 else None,
        enable_pl_optimizer=True if args.gpus > 1 else False,
    )
    trainer.fit(model, datamodule)
def cli_main():
    """Evaluate embeddings of an SSL-trained SimCLR ResNet-18 over a folder dataset.

    Loads the weights from ``--MODEL_PATH``, runs ``eval_embeddings`` over the
    validation and training ``FolderDataset`` splits, and writes metrics under
    ``<MODEL_PATH minus .pt>/Evaluation/``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--MODEL_PATH",
        type=str,
        help="path to .pt file containing SSL-trained SimCLR Resnet18 Model")
    parser.add_argument(
        "--DATA_PATH",
        type=str,
        help="path to data. If folder already contains validation data only, set val_split to 0"
    )
    parser.add_argument(
        "--val_split",
        default=0.2,
        type=float,
        help="amount of data to use for validation as a decimal")
    parser.add_argument(
        "--image_type",
        default="tif",
        type=str,
        help="extension of image for PIL to open and parse - i.e. jpeg, gif, tif, etc. Only put the extension name, not the dot (.)"
    )
    parser.add_argument("--image_embedding_size",
                        default=128,
                        type=int,
                        help="size of image representation of SIMCLR")
    parser.add_argument("--image_size",
                        default=128,
                        type=int,
                        help="height of square image to pass through model")
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--rank",
                        default=50,
                        type=int,
                        help="number of neighbors to search for")
    # NOTE(review): type=bool is an argparse pitfall — any non-empty string
    # parses as True.
    parser.add_argument(
        "--filter_same_group",
        default=False,
        type=bool,
        help="custom arg for hurricane data to filter same hurricanes out")
    args = parser.parse_args()

    # Unpack CLI args into locals for readability below.
    MODEL_PATH = args.MODEL_PATH
    DATA_PATH = args.DATA_PATH
    image_size = args.image_size
    image_type = args.image_type
    embedding_size = args.image_embedding_size
    val_split = args.val_split
    gpus = args.gpus
    rank_to = args.rank
    filter_hur = args.filter_same_group

    # testing
    # MODEL_PATH = '/content/models/SSL/SIMCLR_SSL_0.pt'
    # DATA_PATH = '/content/UCMerced_LandUse/Images'
    # image_size = 128
    # image_type = 'tif'
    # embedding_size = 128
    # val_split = 0.2
    # gpus = 1

    # #gets dataset. We can't combine since validation data has different transform needed
    train_dataset = FolderDataset(
        DATA_PATH,
        validation=False,
        val_split=val_split,
        transform=SimCLRTrainDataTransform(image_size),
        image_type=image_type)
    print('Training Data Loaded...')
    val_dataset = FolderDataset(DATA_PATH,
                                validation=True,
                                val_split=val_split,
                                transform=SimCLREvalDataTransform(image_size),
                                image_type=image_type)
    print('Validation Data Loaded...')

    # load model
    num_samples = len(train_dataset)

    # init model with batch size, num_samples (len of data), epochs to train, and autofinds learning rate
    model = SimCLR(arch='resnet18',
                   batch_size=1,
                   num_samples=num_samples,
                   gpus=gpus,
                   dataset='None')
    # model.encoder = resnet18(pretrained=False, first_conv=model.first_conv, maxpool1=model.maxpool1, return_all_feature_maps=False)
    model.projection = Projection(input_dim=512,
                                  hidden_dim=256,
                                  output_dim=embedding_size)  # overrides
    model.load_state_dict(torch.load(MODEL_PATH))
    model.cuda()
    print('Successfully loaded your model for evaluation.')

    # running eval on validation data
    # MODEL_PATH[:-3] strips the ".pt" suffix to name the output directory.
    save_path = f"{MODEL_PATH[:-3]}/Evaluation/validationMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, val_dataset, save_path, rank_to, filter_hur)
    print('Validation Data Evaluation Complete.')

    # running eval on training data
    save_path = f"{MODEL_PATH[:-3]}/Evaluation/trainingMetrics"
    Path(save_path).mkdir(parents=True, exist_ok=True)
    eval_embeddings(model, train_dataset, save_path, rank_to, filter_hur)
    print('Training Data Evaluation Complete.')

    print(f'Please check {MODEL_PATH[:-3]}/Evaluation/ for your results')