def train_resnet_by_byol(args, train_loader):
    """Self-supervised training of a ResNet-50 backbone via BYOL.

    Args:
        args: namespace providing ``device``, ``height`` (input image size),
            ``lr`` and ``epoch`` attributes.
        train_loader: iterable yielding ``(images, labels)`` batches; labels
            are ignored (training is self-supervised).

    Returns:
        The trained ``resnet`` module (the BYOL wrapper is discarded).
    """
    # define cnn (resnet) and byol wrapper
    resnet = resnet50(pretrained=True).to(args.device)
    learner = BYOL(resnet, image_size=args.height, hidden_layer='avgpool')
    opt = torch.optim.Adam(learner.parameters(), lr=args.lr)

    # train resnet via BYOL
    print("BYOL training start -- ")
    for epoch in range(args.epoch):
        loss_history = []
        for data, _ in train_loader:
            images = data.to(args.device)
            loss = learner(images)
            # .item() extracts the scalar loss directly; the original
            # detach().cpu().numpy().tolist() round-trip is unnecessary.
            loss_history.append(loss.item())
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average()  # update moving average of target encoder
        if epoch % 5 == 0:
            print(f"EPOCH: {epoch} / loss: {sum(loss_history)/len(train_loader)}")
    return resnet
def get_learner(config):
    """Build a BYOL learner wrapping a pretrained ResNet-18.

    Args:
        config: configuration object; kept for interface compatibility.

    Returns:
        The BYOL learner moved to the default CUDA device.
    """
    resnet = torchvision.models.resnet18(pretrained=True)
    learner = BYOL(resnet, image_size=32, hidden_layer='avgpool')
    # NOTE(review): the original constructed an Adam optimizer here
    # (torch.optim.Adam(learner.parameters(), lr=config.lr)) and then
    # discarded it — it was never used or returned. Callers must create
    # their own optimizer over learner.parameters().
    return learner.cuda()
def run_train(dataloader):
    """Train AlexNet with BYOL over `dataloader` and save the backbone.

    Side effects: logs per-step training loss to wandb and writes the
    trained backbone weights to ./improved-net.pt.

    Args:
        dataloader: torch DataLoader yielding ``(images, labels)`` batches;
            labels are ignored.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = AlexNet().to(device)
    learner = BYOL(
        model,
        image_size=32,
    )
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

    for epoch in range(EPOCH):
        for step, (images, _) in tqdm(
                enumerate(dataloader),
                # was hard-coded to int(60000 / BATCH_SIZE), which is only
                # correct for one specific dataset size
                total=len(dataloader),
                desc="Epoch {d}: Training on batch".format(d=epoch)):
            images = images.to(device)
            loss = learner(images)
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average(
            )  # update moving average of target encoder
            # log a plain float, not a CPU tensor
            wandb.log({"train_loss": loss.item(), "step": step, "epoch": epoch})

    # save your improved network
    torch.save(model.state_dict(), './improved-net.pt')
class BYOLHandler:
    """Encapsulates different utility methods for working with BYOL model."""

    def __init__(self,
                 device: Union[str, torch.device] = "cpu",
                 load_from_path: str = None,
                 augmentations: nn.Sequential = None) -> None:
        """Initialise model and learner setup.

        Args:
            device: device the backbone and training batches are placed on.
            load_from_path: optional path to a saved state dict to restore.
            augmentations: optional augmentation pipeline forwarded to BYOL;
                when None, BYOL's built-in default augmentations are used.
        """
        self.device = device
        self.model = models.wide_resnet101_2(pretrained=False).to(self.device)
        # timestamp identifies the output directory used by save()
        self.timestamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")

        if load_from_path is not None:
            print("Loading model...")
            state_dict = torch.load(load_from_path)
            self.model.load_state_dict(state_dict)

        # Single if/else (the original tested `is None` and `is not None`
        # in two separate, mutually exclusive ifs).
        if augmentations is None:
            self.learner = BYOL(self.model, image_size=64,
                                hidden_layer="avgpool")
        else:
            self.learner = BYOL(self.model, image_size=64,
                                hidden_layer="avgpool",
                                augment_fn=augmentations)

        self.opt = torch.optim.Adam(self.learner.parameters(), lr=0.0001,
                                    betas=(0.9, 0.999))
        # per-batch detached losses, appended across all train() calls
        self.loss_history: list[float] = []

    def train(self, dataset: torch.utils.data.Dataset, epochs: int = 1,
              use_tqdm: bool = False) -> None:
        """Train model on dataset for specified number of epochs.

        Args:
            dataset: dataset yielding image batches (no labels expected).
            epochs: number of passes over the dataset.
            use_tqdm: wrap each epoch's loader in a tqdm progress bar.
        """
        # Build the loader once; shuffle=True reshuffles on every
        # iteration, so rebuilding it per epoch (as before) was redundant.
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                                 shuffle=True)
        for i in range(epochs):
            epoch_loader = tqdm(dataloader) if use_tqdm else dataloader
            for images in epoch_loader:
                device_images = images.to(self.device)
                loss = self.learner(device_images)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                # keep the target network as an EMA of the online network
                self.learner.update_moving_average()
                self.loss_history.append(loss.detach().item())
                # free the batch eagerly to keep GPU memory pressure down
                del device_images
                torch.cuda.empty_cache()
            print("Epochs performed:", i + 1)

    def save(self) -> None:
        """Save model weights and a loss-history plot under outputs/<timestamp>."""
        dirpath = os.path.join("outputs", self.timestamp)
        # makedirs creates both levels and tolerates pre-existing dirs,
        # replacing the paired exists()/mkdir() calls.
        os.makedirs(dirpath, exist_ok=True)
        save_path = os.path.join(dirpath, "model.pt")
        torch.save(self.model.state_dict(), save_path)

        # Plot loss history:
        fig, ax = plt.subplots()
        ax.plot(self.loss_history)
        fig.tight_layout()
        fig.savefig(os.path.join(dirpath, "loss_v_batch.png"), dpi=300)
        plt.close()

    def infer(self, dataset: torch.utils.data.Dataset,
              use_tqdm: bool = False) -> np.ndarray:
        """Use model to infer embeddings of provided dataset.

        Args:
            dataset: dataset yielding image batches.
            use_tqdm: wrap the loader in a tqdm progress bar.

        Returns:
            Array of shape (num_samples, embedding_dim) with one flattened
            embedding row per input sample.
        """
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=512,
                                                 shuffle=False, num_workers=0)
        embeddings_list: list[np.ndarray] = []
        with torch.no_grad():
            self.model.eval()
            if use_tqdm:
                dataloader = tqdm(dataloader)
            for data in dataloader:
                device_data = data.to(self.device)
                projection, embedding = self.learner(device_data,
                                                     return_embedding=True)
                np_embedding = embedding.detach().cpu().numpy()
                # flatten everything after the batch dimension
                np_embedding = np.reshape(np_embedding, (data.shape[0], -1))
                embeddings_list.append(np_embedding)
                del device_data
                del projection
                del embedding
                torch.cuda.empty_cache()
            # restore training mode after inference
            self.model.train()
        return np.concatenate(embeddings_list, axis=0)
import torch
from byol_pytorch import BYOL
from torchvision import models

# Backbone to be improved by self-supervised pretraining.
resnet = models.resnet50(pretrained=True)

learner = BYOL(resnet, image_size=256, hidden_layer='avgpool')
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)


def sample_unlabelled_images():
    """Return a random batch standing in for unlabelled training images."""
    return torch.randn(20, 3, 256, 256)


for _step in range(100):
    loss = learner(sample_unlabelled_images())
    opt.zero_grad()
    loss.backward()
    opt.step()
    # keep the target encoder as a moving average of the online encoder
    learner.update_moving_average()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
elif args.depth == 101: model = models.resnet101(pretrained=False).cuda() else: assert ("Not supported Depth") learner = BYOL( model, image_size=args.image_size, hidden_layer='avgpool', projection_size=256, projection_hidden_size=4096, moving_average_decay=0.99, use_momentum=False # turn off momentum in the target encoder ) opt = torch.optim.Adam(learner.parameters(), lr=args.lr) ds = ImagesDataset(args.image_folder, args.image_size) trainloader = DataLoader(ds, batch_size=args.batch_size, num_workers=NUM_WORKERS, shuffle=True) losses = AverageMeter() for epoch in range(args.epoch): bar = Bar('Processing', max=len(trainloader)) for batch_idx, inputs in enumerate(trainloader): loss = learner(inputs.cuda()) losses.update(loss.data.item(), inputs.size(0)) opt.zero_grad() loss.backward() opt.step() # plot progress