def test_image_collate(self):
    """Check that a default ``ImageCollateFunction`` yields consistent sizes.

    A batch from ``self.create_batch()`` is collated and the lengths of the
    three returned values (samples, labels, filenames) are compared.
    """
    batch = self.create_batch()
    img_collate = ImageCollateFunction()
    samples, labels, fnames = img_collate(batch)

    self.assertIsNotNone(img_collate)
    # assertEqual(first, second, msg): a third positional argument is the
    # failure message, not a third operand. Compare the lengths pairwise so
    # that the number of filenames is actually verified.
    self.assertEqual(len(samples), len(labels))
    self.assertEqual(len(labels), len(fnames))
def test_image_collate_tuple_input_size(self):
    """Check ``ImageCollateFunction`` when ``input_size`` is a tuple.

    The collate function is built with ``input_size=(32, 32)``. Its first
    return value is unpacked into two parts, and the lengths of both parts,
    the labels and the filenames are required to agree.
    """
    batch = self.create_batch()
    img_collate = ImageCollateFunction(input_size=(32, 32))
    samples, labels, fnames = img_collate(batch)
    samples0, samples1 = samples

    self.assertIsNotNone(img_collate)
    self.assertEqual(len(samples0), len(samples1))
    # assertEqual(first, second, msg): a third positional argument is the
    # failure message, not a third operand. Compare the lengths pairwise so
    # that the number of filenames is actually verified.
    self.assertEqual(len(samples1), len(labels))
    self.assertEqual(len(labels), len(fnames))
return z resnet = torchvision.models.resnet18() backbone = nn.Sequential(*list(resnet.children())[:-1]) model = BarlowTwins(backbone) device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True) dataset = LightlyDataset.from_torch_dataset(cifar10) # or create a dataset from a folder containing images or videos: # dataset = LightlyDataset("path/to/folder") collate_fn = ImageCollateFunction(input_size=32) dataloader = torch.utils.data.DataLoader( dataset, batch_size=256, collate_fn=collate_fn, shuffle=True, drop_last=True, num_workers=8, ) criterion = BarlowTwinsLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.06) print("Starting Training") for epoch in range(10):
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised embedding model as described by ``cfg``.

    Args:
        cfg: Nested configuration mapping. Keys read here: ``input_dir``,
            ``seed`` (optional), ``trainer``, ``loader``, ``checkpoint``,
            ``pre_trained``, ``model``, ``criterion``, ``optimizer``,
            ``collate``, ``checkpoint_callback`` and
            ``environment_variable_names``. The mapping is modified in
            place (``trainer.weights_summary``, ``trainer.gpus``,
            ``loader.num_workers`` and ``loader.batch_size`` may be
            overwritten).
        is_cli_call: When true, ``input_dir`` and an explicitly given
            ``checkpoint`` are passed through ``fix_input_path`` first.

    Returns:
        The value of ``encoder.checkpoint`` after training. The same value
        is also written to the environment variable whose name is stored
        under ``environment_variable_names.lightly_last_checkpoint_path``.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Seed torch and switch cuDNN to deterministic mode when a seed is given.
    if 'seed' in cfg.keys():
        seed_value = cfg['seed']
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # The string "None" stands in for a real None in the trainer options.
    if cfg["trainer"]["weights_summary"] == "None":
        cfg["trainer"]["weights_summary"] = None

    # Without CUDA, fall back to the CPU and clear any requested GPUs.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cpu' and cfg['trainer'] and cfg['trainer']['gpus']:
        cfg['trainer']['gpus'] = 0

    # More than one GPU selects the 'ddp' strategy.
    distributed_strategy = 'ddp' if cfg['trainer']['gpus'] > 1 else None

    if cfg['loader']['batch_size'] < 64:
        warnings.warn(
            (
                'Training a self-supervised model with a small batch size: {}! '
                'Small batch size may harm embedding quality. '
                'You can specify the batch size via the loader key-word: '
                'loader.batch_size=BSZ'
            ).format(cfg['loader']['batch_size'])
        )

    # A negative worker count means "use the number of available cores".
    if cfg['loader']['num_workers'] < 0:
        cfg['loader']['num_workers'] = cpu_count()

    checkpoint = cfg['checkpoint']
    use_pretrained = cfg['pre_trained']
    if checkpoint:
        # An explicitly specified checkpoint takes precedence.
        if is_cli_call:
            checkpoint = fix_input_path(checkpoint)
    elif use_pretrained:
        # No explicit checkpoint: try to get one from the model zoo.
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            warnings.warn(
                (
                    'Cannot download checkpoint for key {} '
                    'because it does not exist! '
                    'Model will be trained from scratch.'
                ).format(key)
            )

    # Load the PyTorch state dictionary and map it to the current device.
    state_dict = None
    if checkpoint:
        load_fn = load_state_dict_from_url if is_url(checkpoint) else torch.load
        state_dict = load_fn(checkpoint, map_location=device)['state_dict']

    # Build the model: every child of the ResNet except the last one,
    # preceded by a norm layer and followed by a 1x1 convolution that maps
    # to num_ftrs channels plus adaptive average pooling.
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    resnet_layers = list(resnet.children())
    last_conv_channels = resnet_layers[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *resnet_layers[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )
    model = _SimCLR(
        features,
        num_ftrs=cfg['model']['num_ftrs'],
        out_dim=cfg['model']['out_dim'],
    )
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(input_dir)
    # Never request a batch larger than the dataset itself.
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset),
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        **cfg['loader'],
        collate_fn=ImageCollateFunction(**cfg['collate']),
    )

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder.train_embedding(**cfg['trainer'], strategy=distributed_strategy)

    print(
        f'Best model is stored at: {bcolors.OKBLUE}{encoder.checkpoint}{bcolors.ENDC}'
    )
    env_var_name = cfg['environment_variable_names']['lightly_last_checkpoint_path']
    os.environ[env_var_name] = encoder.checkpoint
    return encoder.checkpoint
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised ``ResNetSimCLR`` model as described by ``cfg``.

    Args:
        cfg: Nested configuration mapping. Keys read here: ``data``,
            ``download``, ``root``, ``input_dir``, ``seed`` (optional),
            ``trainer``, ``loader``, ``checkpoint``, ``pre_trained``,
            ``model``, ``criterion``, ``optimizer``, ``collate`` and
            ``checkpoint_callback``. The mapping is modified in place
            (``trainer.gpus`` and ``loader.batch_size`` may be overwritten).
        is_cli_call: When true, ``root``, ``input_dir`` and an explicitly
            given ``checkpoint`` are passed through ``fix_input_path`` first.

    Returns:
        The ``checkpoint`` attribute of the object returned by
        ``train_embedding``.
    """
    data = cfg['data']
    download = cfg['download']

    root = cfg['root']
    if root and is_cli_call:
        root = fix_input_path(root)

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Seed torch and switch cuDNN to deterministic mode when a seed is given.
    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Always bind `device`. Previously it stayed undefined when CUDA was
    # unavailable and no GPUs were requested, which raised an
    # UnboundLocalError as soon as a checkpoint had to be loaded below.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        if cfg['trainer'] and cfg['trainer']['gpus']:
            # GPUs were requested but none are usable: train on the CPU.
            cfg['trainer']['gpus'] = 0

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        # NOTE(review): torch.load unpickles the file; checkpoints should
        # only come from trusted sources.
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device
            )['state_dict']
        else:
            state_dict = torch.load(
                checkpoint, map_location=device
            )['state_dict']

    # load model
    model = ResNetSimCLR(**cfg['model'])
    if state_dict is not None:
        model.load_from_state_dict(state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(root,
                             name=data, train=True, download=download,
                             from_folder=input_dir)

    # Never request a batch larger than the dataset itself.
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset)
    )

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    # train_embedding's return value replaces `encoder`; its `checkpoint`
    # attribute is what gets reported and returned.
    encoder = encoder.train_embedding(**cfg['trainer'])

    print('Best model is stored at: %s' % (encoder.checkpoint))
    return encoder.checkpoint