class SelfSupervisedLearner(pl.LightningModule):
    """Lightning module that trains a backbone with BYOL.

    Aggregates per-batch losses at epoch end and reports the epoch mean
    through the attached logger's experiment object (``log_metric`` API).
    """

    def __init__(self, net, **kwargs):
        super().__init__()
        # BYOL wraps `net` and owns both the online and target encoders.
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        # BYOL's forward pass returns the self-supervised loss for the batch.
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        # Mean of the per-batch training losses collected this epoch.
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.logger.experiment.log_metric('train_loss', avg_loss)
        return {'train_loss': avg_loss}

    def validation_step(self, images, _):
        loss = self.forward(images)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.logger.experiment.log_metric('val_loss', avg_loss)
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # LR is a module-level constant.
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        # EMA-update the target encoder after each optimizer step.
        # Guarded on use_momentum: update_moving_average() asserts when
        # momentum is disabled (matches the other learners in this file).
        if self.learner.use_momentum:
            self.learner.update_moving_average()
class SelfSupervisedLearner(pl.LightningModule):
    """Lightning module that self-trains a backbone with BYOL on `dataset`."""

    def __init__(self, net, dataset: Dataset, batch_size: int = 32, **kwargs):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = dataset
        # BYOL owns the online/target encoder pair built around `net`.
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        """Return the BYOL loss for a batch of images."""
        return self.learner(images)

    def train_dataloader(self):
        """Shuffled, pinned-memory loader over the training dataset."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
            shuffle=True,
            pin_memory=True,
        )

    def training_step(self, images, _):
        return {'loss': self.forward(images)}

    def configure_optimizers(self):
        # LR is a module-level constant.
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        # Refresh the EMA target network only when momentum is in use.
        if self.learner.use_momentum:
            self.learner.update_moving_average()
def train_resnet_by_byol(args, train_loader):
    """Self-supervised pre-training of a ResNet-50 backbone via BYOL.

    Args:
        args: namespace with `device`, `height`, `lr`, and `epoch` attributes.
        train_loader: iterable yielding (images, labels); labels are ignored.

    Returns:
        The trained ResNet-50 module (BYOL's online-encoder backbone).
    """
    # define cnn(resnet) and byol
    resnet = resnet50(pretrained=True).to(args.device)
    learner = BYOL(resnet, image_size=args.height, hidden_layer='avgpool')
    opt = torch.optim.Adam(learner.parameters(), lr=args.lr)

    # train resnet via BYOL
    print("BYOL training start -- ")
    for epoch in range(args.epoch):
        loss_history = []
        for data, _ in train_loader:
            images = data.to(args.device)
            loss = learner(images)
            # .item() is the idiomatic scalar extraction; replaces the
            # roundabout detach().cpu().numpy().tolist() chain.
            loss_history.append(loss.item())
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average()  # update moving average of target encoder
        if epoch % 5 == 0 and loss_history:
            # Average over batches actually seen; dividing by
            # len(train_loader) raised ZeroDivisionError on an empty loader.
            print(f"EPOCH: {epoch} / loss: {sum(loss_history)/len(loss_history)}")
    return resnet
class SelfSupervisedLearner(pl.LightningModule):
    """Minimal Lightning wrapper that trains a backbone with BYOL."""

    def __init__(self, net, **kwargs):
        super().__init__()
        # BYOL wraps `net` and manages the online/target encoder pair.
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        # BYOL's forward pass directly yields the training loss.
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def configure_optimizers(self):
        # LR is a module-level constant.
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        # EMA-update the target encoder after each optimizer step.
        # Guarded on use_momentum: update_moving_average() asserts when
        # momentum is disabled (matches the other learners in this file).
        if self.learner.use_momentum:
            self.learner.update_moving_average()
class SelfSupervisedLearner(pl.LightningModule):
    """Lightning BYOL trainer that logs the loss per step and per epoch."""

    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        """Run BYOL on the batch and return its loss."""
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        # Surface the loss on the progress bar and in the logger,
        # at both step and epoch granularity.
        self.log(
            'loss',
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return {'loss': loss}

    def configure_optimizers(self):
        # Adam with its library-default learning rate.
        return torch.optim.Adam(self.parameters())

    def on_before_zero_grad(self, _):
        # Only EMA-update the target encoder when momentum is enabled.
        if self.learner.use_momentum:
            self.learner.update_moving_average()
def run_train(dataloader):
    """Train an AlexNet backbone with BYOL over `dataloader`.

    Side effects: logs per-step training loss to wandb and saves the
    trained backbone weights to ./improved-net.pt.
    """
    if torch.cuda.is_available():
        dev = "cuda:0"
    else:
        dev = "cpu"
    device = torch.device(dev)
    model = AlexNet().to(device)
    learner = BYOL(
        model,
        image_size=32,
    )
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
    for epoch in range(EPOCH):
        for step, (images, _) in tqdm(
                enumerate(dataloader),
                # len(dataloader) is the actual batch count; the previous
                # int(60000 / BATCH_SIZE) hard-coded an MNIST-sized dataset.
                total=len(dataloader),
                desc="Epoch {d}: Training on batch".format(d=epoch)):
            images = images.to(device)
            loss = learner(images)
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average(
            )  # update moving average of target encoder
            # .item() yields a detached Python float; logging loss.cpu()
            # kept a tensor attached to the autograd graph alive.
            wandb.log({"train_loss": loss.item(), "step": step, "epoch": epoch})

    # save your improved network
    torch.save(model.state_dict(), './improved-net.pt')
class BYOLHandler:
    """Encapsulates different utility methods for working with BYOL model."""

    def __init__(self, device: Union[str, torch.device] = "cpu",
                 load_from_path: str = None,
                 augmentations: nn.Sequential = None) -> None:
        """Initialise model and learner setup.

        Args:
            device: device to place the backbone and batches on.
            load_from_path: optional path to a saved state_dict to restore.
            augmentations: optional augmentation pipeline; overrides BYOL's
                default when given.
        """
        self.device = device
        # Backbone: randomly initialised Wide-ResNet-101-2 on the target device.
        self.model = models.wide_resnet101_2(pretrained=False).to(self.device)
        # Timestamp used to name the output directory in save().
        self.timestamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
        if load_from_path is not None:
            print("Loading model...")
            state_dict = torch.load(load_from_path)
            self.model.load_state_dict(state_dict)
        # BYOL learner over the backbone's 'avgpool' features; a
        # caller-supplied augmentation pipeline replaces BYOL's default.
        if augmentations is None:
            self.learner = BYOL(self.model, image_size=64, hidden_layer="avgpool")
        if augmentations is not None:
            self.learner = BYOL(self.model, image_size=64, hidden_layer="avgpool", augment_fn=augmentations)
        self.opt = torch.optim.Adam(self.learner.parameters(), lr=0.0001, betas=(0.9, 0.999))
        # Per-batch loss trace, plotted by save().
        self.loss_history: list[float] = []

    def train(self, dataset: torch.utils.data.Dataset, epochs: int = 1, use_tqdm: bool = False) -> None:
        """Train model on dataset for specified number of epochs."""
        for i in range(epochs):
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
            if use_tqdm:
                dataloader = tqdm(dataloader)
            # NOTE(review): the loop unpacks bare image batches (no labels);
            # confirm the dataset yields tensors directly.
            for images in dataloader:
                device_images = images.to(self.device)
                loss = self.learner(device_images)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                # EMA update of BYOL's target encoder after each step.
                self.learner.update_moving_average()
                self.loss_history.append(loss.detach().item())
                # Release the batch eagerly to keep GPU memory flat.
                del device_images
                torch.cuda.empty_cache()
            print("Epochs performed:", i + 1)

    def save(self) -> None:
        """Save model."""
        # outputs/<timestamp>/model.pt plus a loss-curve plot.
        if not os.path.exists("outputs"):
            os.mkdir("outputs")
        dirpath = os.path.join("outputs", self.timestamp)
        if not os.path.exists(dirpath):
            os.mkdir(dirpath)
        save_path = os.path.join(dirpath, "model.pt")
        torch.save(self.model.state_dict(), save_path)
        # Plot loss history:
        fig, ax = plt.subplots()
        ax.plot(self.loss_history)
        fig.tight_layout()
        fig.savefig(os.path.join(dirpath, "loss_v_batch.png"), dpi=300)
        plt.close()

    def infer(self, dataset: torch.utils.data.Dataset, use_tqdm: bool = False) -> np.ndarray:
        """Use model to infer embeddings of provided dataset.

        Returns:
            Array of shape (len(dataset), embedding_dim) with one flattened
            embedding row per input sample, in dataset order.
        """
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=False, num_workers=0)
        embeddings_list: list[np.ndarray] = []
        with torch.no_grad():
            self.model.eval()
            if use_tqdm:
                dataloader = tqdm(dataloader)
            for data in dataloader:
                device_data = data.to(self.device)
                # With return_embedding=True, BYOL returns
                # (projection, embedding); only the embedding is kept.
                projection, embedding = self.learner(device_data, return_embedding=True)
                np_embedding = embedding.detach().cpu().numpy()
                # Flatten everything past the batch dimension.
                np_embedding = np.reshape(np_embedding, (data.shape[0], -1))
                embeddings_list.append(np_embedding)
                del device_data
                del projection
                del embedding
                torch.cuda.empty_cache()
            # Restore training mode after inference.
            self.model.train()
        return np.concatenate(embeddings_list, axis=0)
import torch
from byol_pytorch import BYOL
from torchvision import models

# Pre-trained ResNet-50 backbone wrapped in a BYOL learner that taps
# features from the 'avgpool' layer.
resnet = models.resnet50(pretrained=True)
learner = BYOL(resnet, image_size=256, hidden_layer='avgpool')
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)


def sample_unlabelled_images():
    """Stand-in data source: a random batch of 20 RGB 256x256 images."""
    return torch.randn(20, 3, 256, 256)


for _ in range(100):
    batch = sample_unlabelled_images()
    loss = learner(batch)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()  # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
class SelfSupervisedLearner(pl.LightningModule):
    """BYOL trainer with grouped learning rates and a warmup+cosine schedule.

    The backbone (online encoder net) trains at 0.2x the base learning rate;
    the projector and predictor train at the full rate. The schedule is a
    linear warmup for the first ~5% of steps, then cosine annealing.
    """

    def __init__(
            self,
            net,
            train_dataset: Dataset,
            valid_dataset: Dataset,
            epochs: int,
            learning_rate: float,
            batch_size: int = 32,
            num_gpus: int = 1,
            **kwargs):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.learner = BYOL(net, **kwargs)
        # Used to derive steps-per-epoch for the LR schedule under multi-GPU.
        self.num_gpus = num_gpus

    def forward(self, images):
        # BYOL's forward pass returns the self-supervised loss.
        return self.learner(images)

    def train_dataloader(self):
        # drop_last keeps every step at a full batch, matching the
        # floor() used in the step count below.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
            shuffle=True,
            pin_memory=True,
            drop_last=True
        )

    def get_progress_bar_dict(self):
        # don't show the experiment version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
            shuffle=False,
            pin_memory=True,
            drop_last=False
        )

    def validation_step(self, batch, batch_idx):
        loss = self.forward(batch)
        # sync_dist averages the metric across processes under DDP.
        self.log('val_loss', loss, sync_dist=True)

    def training_step(self, images, _):
        loss = self.forward(images)
        # opt = self.optimizers()
        # self.manual_backward(loss, opt)
        # opt.step()
        # opt.zero_grad()
        self.log("loss", loss, sync_dist=True)
        # print(loss)
        return loss

    def configure_optimizers(self):
        # Two parameter groups: the backbone at a reduced LR, and the
        # projector + predictor at the full base LR.
        layer_groups = [
            [self.learner.online_encoder.net],
            [
                self.learner.online_encoder.projector,
                self.learner.online_predictor
            ]
        ]
        # Per-process steps per epoch (dataset is sharded across GPUs).
        steps_per_epochs = math.floor(
            len(self.train_dataset) / self.batch_size / self.num_gpus
        )
        print("Steps per epochs:", steps_per_epochs)
        n_steps = steps_per_epochs * self.epochs
        # ~5% of steps for linear warmup, the rest for cosine annealing.
        lr_durations = [
            int(n_steps*0.05),
            int(np.ceil(n_steps*0.95)) + 1
        ]
        # Step indices at which each schedule stage begins.
        break_points = [0] + list(np.cumsum(lr_durations))[:-1]
        optimizer = RAdam([
            {
                "params": chain.from_iterable([x.parameters() for x in layer_groups[0]]),
                # Backbone trains at 20% of the base learning rate.
                "lr": self.learning_rate * 0.2
            },
            {
                "params": chain.from_iterable([x.parameters() for x in layer_groups[1]]),
                "lr": self.learning_rate
            }
        ])
        scheduler = {
            'scheduler': MultiStageScheduler(
                [
                    LinearLR(optimizer, 0.01, lr_durations[0]),
                    CosineAnnealingLR(optimizer, lr_durations[1])
                ],
                start_at_epochs=break_points
            ),
            # 'scheduler': CosineAnnealingLR(optimizer, n_steps, eta_min=1e-8),
            # Advance the scheduler every optimizer step, not every epoch.
            'interval': 'step',
            'frequency': 1,
            'strict': True,
        }
        print(optimizer)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler
        }

    def on_before_zero_grad(self, _):
        # EMA-update the target encoder only when BYOL momentum is enabled.
        if self.learner.use_momentum:
            self.learner.update_moving_average()