Example #1
class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        # assumes a logger whose `experiment` exposes `log_metric` (e.g. the Neptune logger)
        self.logger.experiment.log_metric('train_loss', avg_loss)
        return {'train_loss': avg_loss}

    def validation_step(self, images, _):
        loss = self.forward(images)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.logger.experiment.log_metric('val_loss', avg_loss)
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        self.learner.update_moving_average()
Example #2
class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, dataset: Dataset, batch_size: int = 32, **kwargs):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = dataset
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def train_dataloader(self):
        return DataLoader(self.dataset,
                          batch_size=self.batch_size,
                          num_workers=NUM_WORKERS,
                          shuffle=True,
                          pin_memory=True)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        if self.learner.use_momentum:
            self.learner.update_moving_average()
Example #3
def train_resnet_by_byol(args, train_loader):

    # define the CNN (ResNet-50) backbone and the BYOL learner
    resnet = resnet50(pretrained=True).to(args.device)
    learner = BYOL(resnet, image_size=args.height, hidden_layer='avgpool')
    opt = torch.optim.Adam(learner.parameters(), lr=args.lr)

    # train resnet via BYOL
    print("BYOL training start -- ")
    for epoch in range(args.epoch):
        loss_history=[]
        for data, _ in train_loader:
            images = data.to(args.device)
            loss = learner(images)
            loss_history.append(loss.item())  # record per-batch loss as a Python float
            
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average() # update moving average of target encoder
            
        if epoch % 5 == 0:
            print(f"EPOCH: {epoch} / loss: {sum(loss_history)/len(train_loader)}")

    return resnet
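
A sketch of how train_resnet_by_byol might be invoked; the argument values and the CIFAR-10 loader are assumptions, chosen so that each batch unpacks into (images, labels) as the loop above expects.

import argparse
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

args = argparse.Namespace(
    device="cuda" if torch.cuda.is_available() else "cpu",
    height=224,   # passed to BYOL as image_size
    lr=3e-4,
    epoch=20,
)

transform = transforms.Compose([
    transforms.Resize((args.height, args.height)),
    transforms.ToTensor(),
])
train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)

trained_resnet = train_resnet_by_byol(args, train_loader)
torch.save(trained_resnet.state_dict(), "./byol-resnet50.pt")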
Example #4
class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        self.learner.update_moving_average()
Example #5
class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        self.log('loss',
                 loss,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def on_before_zero_grad(self, _):
        if self.learner.use_momentum:
            self.learner.update_moving_average()
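
A minimal sketch of driving the module above with a Trainer; the ResNet backbone, the random-tensor dataset, and the batch size are assumptions for illustration.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import models

class RandomImages(Dataset):
    """Stand-in dataset that yields plain 3x256x256 image tensors."""
    def __len__(self):
        return 32

    def __getitem__(self, idx):
        return torch.randn(3, 256, 256)

resnet = models.resnet50(pretrained=False)
module = SelfSupervisedLearner(resnet, image_size=256, hidden_layer='avgpool')

trainer = pl.Trainer(max_epochs=1)
trainer.fit(module, DataLoader(RandomImages(), batch_size=8, shuffle=True))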
Example #6
def run_train(dataloader):

    if torch.cuda.is_available():
        dev = "cuda:0"
    else:
        dev = "cpu"

    device = torch.device(dev)

    model = AlexNet().to(device)

    learner = BYOL(
        model,
        image_size=32,
    )

    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

    for epoch in range(EPOCH):

        for step, (images, _) in tqdm(
                enumerate(dataloader),
                total=len(dataloader),  # number of batches in one epoch
                desc="Epoch {d}: Training on batch".format(d=epoch)):
            images = images.to(device)
            loss = learner(images)
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average()  # update moving average of target encoder

            wandb.log({"train_loss": loss.cpu(), "step": step, "epoch": epoch})

    # save your improved network
    torch.save(model.state_dict(), './improved-net.pt')
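
The constants and dataloader run_train expects are not shown above; a plausible setup (all values are assumptions) uses a 3-channel 32x32 image source such as CIFAR-10. AlexNet here is presumably a CIFAR-sized variant, since torchvision's AlexNet cannot process 32x32 inputs.

import torch
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

EPOCH = 10          # hypothetical epoch count read by run_train
BATCH_SIZE = 128    # hypothetical batch size for the dataloader

wandb.init(project="byol-alexnet")  # hypothetical project name

train_set = datasets.CIFAR10(root="./data", train=True, download=True,
                             transform=transforms.ToTensor())
dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

run_train(dataloader)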
Example #7
class BYOLHandler:
    """Encapsulates different utility methods for working with BYOL model."""
    def __init__(self,
                 device: Union[str, torch.device] = "cpu",
                 load_from_path: str = None,
                 augmentations: nn.Sequential = None) -> None:
        """Initialise model and learner setup."""
        self.device = device
        self.model = models.wide_resnet101_2(pretrained=False).to(self.device)
        self.timestamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")

        if load_from_path is not None:
            print("Loading model...")
            state_dict = torch.load(load_from_path, map_location=self.device)
            self.model.load_state_dict(state_dict)

        if augmentations is None:
            self.learner = BYOL(self.model,
                                image_size=64,
                                hidden_layer="avgpool")
        else:
            self.learner = BYOL(self.model,
                                image_size=64,
                                hidden_layer="avgpool",
                                augment_fn=augmentations)

        self.opt = torch.optim.Adam(self.learner.parameters(),
                                    lr=0.0001,
                                    betas=(0.9, 0.999))
        self.loss_history: list[float] = []

    def train(self,
              dataset: torch.utils.data.Dataset,
              epochs: int = 1,
              use_tqdm: bool = False) -> None:
        """Train model on dataset for specified number of epochs."""
        for i in range(epochs):
            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=64,
                                                     shuffle=True)
            if use_tqdm:
                dataloader = tqdm(dataloader)
            for images in dataloader:
                device_images = images.to(self.device)
                loss = self.learner(device_images)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                self.learner.update_moving_average()
                self.loss_history.append(loss.detach().item())
                del device_images
                torch.cuda.empty_cache()
            print("Epochs performed:", i + 1)

    def save(self) -> None:
        """Save model."""
        if not os.path.exists("outputs"):
            os.mkdir("outputs")
        dirpath = os.path.join("outputs", self.timestamp)
        if not os.path.exists(dirpath):
            os.mkdir(dirpath)

        save_path = os.path.join(dirpath, "model.pt")
        torch.save(self.model.state_dict(), save_path)

        # Plot loss history:
        fig, ax = plt.subplots()
        ax.plot(self.loss_history)
        fig.tight_layout()
        fig.savefig(os.path.join(dirpath, "loss_v_batch.png"), dpi=300)
        plt.close()

    def infer(self,
              dataset: torch.utils.data.Dataset,
              use_tqdm: bool = False) -> np.ndarray:
        """Use model to infer embeddings of provided dataset."""
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=512,
                                                 shuffle=False,
                                                 num_workers=0)
        embeddings_list: list[np.ndarray] = []
        with torch.no_grad():
            self.model.eval()
            if use_tqdm:
                dataloader = tqdm(dataloader)
            for data in dataloader:
                device_data = data.to(self.device)
                projection, embedding = self.learner(device_data,
                                                     return_embedding=True)
                np_embedding = embedding.detach().cpu().numpy()
                np_embedding = np.reshape(np_embedding, (data.shape[0], -1))
                embeddings_list.append(np_embedding)
                del device_data
                del projection
                del embedding
                torch.cuda.empty_cache()
            self.model.train()

        return np.concatenate(embeddings_list, axis=0)
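
A brief usage sketch for BYOLHandler; the random-tensor dataset stands in for a real unlabelled 64x64 image dataset and is an assumption for illustration.

import torch

class RandomImageDataset(torch.utils.data.Dataset):
    """Yields plain 3x64x64 tensors, matching what BYOLHandler.train and .infer iterate over."""
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(3, 64, 64)

handler = BYOLHandler(device="cuda" if torch.cuda.is_available() else "cpu")
handler.train(RandomImageDataset(), epochs=1, use_tqdm=True)
handler.save()                                    # writes outputs/<timestamp>/model.pt plus a loss plot
embeddings = handler.infer(RandomImageDataset())  # (N, feature_dim) numpy array
print(embeddings.shape)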
Example #8
import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(resnet, image_size=256, hidden_layer='avgpool')

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)


def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)


for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()  # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
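
Once trained, the same learner can also return embeddings directly; this mirrors the return_embedding call used in the BYOLHandler example above (the printed shape assumes the ResNet-50 'avgpool' backbone).

with torch.no_grad():
    projection, embedding = learner(sample_unlabelled_images(), return_embedding=True)
print(embedding.shape)  # expected: torch.Size([20, 2048]) for a ResNet-50 'avgpool' backbone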
Example #9
class SelfSupervisedLearner(pl.LightningModule):
    def __init__(
            self, net, train_dataset: Dataset, valid_dataset: Dataset,
            epochs: int, learning_rate: float,
            batch_size: int = 32, num_gpus: int = 1, **kwargs):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.learner = BYOL(net, **kwargs)
        self.num_gpus = num_gpus

    def forward(self, images):
        return self.learner(images)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size,
            num_workers=NUM_WORKERS, shuffle=True, pin_memory=True,
            drop_last=True
        )

    def get_progress_bar_dict(self):
        # don't show the experiment version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset, batch_size=self.batch_size,
            num_workers=NUM_WORKERS, shuffle=False, pin_memory=True,
            drop_last=False
        )

    def validation_step(self, batch, batch_idx):
        loss = self.forward(batch)
        self.log('val_loss', loss, sync_dist=True)

    def training_step(self, images, _):
        loss = self.forward(images)
        # opt = self.optimizers()
        # self.manual_backward(loss, opt)
        # opt.step()
        # opt.zero_grad()
        self.log("loss", loss, sync_dist=True)
        # print(loss)
        return loss

    def configure_optimizers(self):
        layer_groups = [
            [self.learner.online_encoder.net],
            [
                self.learner.online_encoder.projector,
                self.learner.online_predictor
            ]
        ]
        steps_per_epoch = math.floor(
            len(self.train_dataset) / self.batch_size / self.num_gpus
        )
        print("Steps per epoch:", steps_per_epoch)
        n_steps = steps_per_epoch * self.epochs
        lr_durations = [
            int(n_steps*0.05),
            int(np.ceil(n_steps*0.95)) + 1
        ]
        break_points = [0] + list(np.cumsum(lr_durations))[:-1]
        optimizer = RAdam([
            {
                "params": chain.from_iterable([x.parameters() for x in layer_groups[0]]),
                "lr": self.learning_rate * 0.2
            },
            {
                "params": chain.from_iterable([x.parameters() for x in layer_groups[1]]),
                "lr": self.learning_rate
            }
        ])
        scheduler = {
            'scheduler': MultiStageScheduler(
                [
                    LinearLR(optimizer, 0.01, lr_durations[0]),
                    CosineAnnealingLR(optimizer, lr_durations[1])
                ],
                start_at_epochs=break_points
            ),
            # 'scheduler': CosineAnnealingLR(optimizer, n_steps, eta_min=1e-8),
            'interval': 'step',
            'frequency': 1,
            'strict': True,
        }
        print(optimizer)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler
        }

    def on_before_zero_grad(self, _):
        if self.learner.use_momentum:
            self.learner.update_moving_average()
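
A sketch of wiring this module into a Trainer run; train_ds and valid_ds are placeholder datasets of plain image tensors, and NUM_WORKERS plus the RAdam / MultiStageScheduler / LinearLR imports are assumed to be available as in the snippet above.

import pytorch_lightning as pl
from torchvision import models

resnet = models.resnet50(pretrained=False)
module = SelfSupervisedLearner(
    resnet,
    train_dataset=train_ds,      # placeholder: Dataset yielding image tensors
    valid_dataset=valid_ds,      # placeholder
    epochs=30,
    learning_rate=2e-3,
    batch_size=32,
    num_gpus=1,
    image_size=256,
    hidden_layer='avgpool',
)

trainer = pl.Trainer(max_epochs=30, gpus=1)
trainer.fit(module)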