Example #1
def train_resnet_by_byol(args, train_loader):

    # define cnn(resnet) and byol
    resnet = resnet50(pretrained=True).to(args.device)
    learner = BYOL(resnet, image_size = args.height, hidden_layer = 'avgpool')
    opt = torch.optim.Adam(learner.parameters(), lr=args.lr)

    # train resnet via BYOL
    print("BYOL training start -- ")
    for epoch in range(args.epoch):
        loss_history = []
        for data, _ in train_loader:
            images = data.to(args.device)
            loss = learner(images)
            loss_history.append(loss.item())
            
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average() # update moving average of target encoder
            
        if epoch % 5 == 0:
            print(f"EPOCH: {epoch} / loss: {sum(loss_history)/len(train_loader)}")

    return resnet
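
A minimal, hypothetical driver for train_resnet_by_byol is sketched below; the argument names mirror the attributes the function reads (device, height, lr, epoch), while the CIFAR-10 dataset, batch size, and output path are illustrative assumptions rather than part of the original example.

# Hypothetical driver for train_resnet_by_byol (all defaults are assumptions).
import argparse
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

parser = argparse.ArgumentParser()
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--height", type=int, default=32)   # image size passed to BYOL
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--epoch", type=int, default=20)
args = parser.parse_args()

transform = transforms.Compose([transforms.Resize(args.height), transforms.ToTensor()])
train_set = CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

trained_resnet = train_resnet_by_byol(args, train_loader)
torch.save(trained_resnet.state_dict(), "./byol-resnet50.pt")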
Example #2
def get_learner(config):
    resnet = torchvision.models.resnet18(pretrained=True)
    learner = BYOL(resnet, image_size=32, hidden_layer='avgpool')

    learner = learner.cuda()
    opt = torch.optim.Adam(learner.parameters(), lr=config.lr)

    return learner, opt
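
A short sketch of how the returned learner and optimizer might be driven for one training step; the SimpleNamespace config and the random batch are assumptions for illustration, and a CUDA device is assumed as in the original snippet.

# Hypothetical usage of get_learner -- config only needs an `lr` attribute here.
import torch
from types import SimpleNamespace

config = SimpleNamespace(lr=3e-4)
learner, opt = get_learner(config)

images = torch.randn(8, 3, 32, 32).cuda()  # dummy batch matching image_size=32
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average()  # update moving average of target encoder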
Example #3
def run_train(dataloader):

    if torch.cuda.is_available():
        dev = "cuda:0"
    else:
        dev = "cpu"

    device = torch.device(dev)

    model = AlexNet().to(device)

    learner = BYOL(
        model,
        image_size=32,
    )

    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

    for epoch in range(EPOCH):

        for step, (images, _) in tqdm(
                enumerate(dataloader),
                total=len(dataloader),
                desc=f"Epoch {epoch}: Training on batch"):
            images = images.to(device)
            loss = learner(images)
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average()  # update moving average of target encoder

            wandb.log({"train_loss": loss.item(), "step": step, "epoch": epoch})

    # save your improved network
    torch.save(model.state_dict(), './improved-net.pt')
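
run_train relies on module-level names (EPOCH, BATCH_SIZE, AlexNet, wandb, tqdm, BYOL) that are not shown in the snippet; the setup below is an assumed reconstruction, with a small CIFAR-sized stand-in in place of the script's own AlexNet so that image_size=32 works end to end.

# Assumed module-level setup for run_train; every value here is illustrative.
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
from byol_pytorch import BYOL
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

EPOCH = 10
BATCH_SIZE = 256


class AlexNet(nn.Module):
    """Small stand-in for the script's AlexNet, sized for 32x32 inputs."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))


wandb.init(project="byol-alexnet")  # hypothetical project name

dataset = CIFAR10(root="./data", train=True, download=True,
                  transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

run_train(dataloader)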
Example #4
class BYOLHandler:
    """Encapsulates different utility methods for working with BYOL model."""
    def __init__(self,
                 device: Union[str, torch.device] = "cpu",
                 load_from_path: str = None,
                 augmentations: nn.Sequential = None) -> None:
        """Initialise model and learner setup."""
        self.device = device
        self.model = models.wide_resnet101_2(pretrained=False).to(self.device)
        self.timestamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")

        if load_from_path is not None:
            print("Loading model...")
            state_dict = torch.load(load_from_path, map_location=self.device)
            self.model.load_state_dict(state_dict)

        if augmentations is None:
            self.learner = BYOL(self.model,
                                image_size=64,
                                hidden_layer="avgpool")
        else:
            self.learner = BYOL(self.model,
                                image_size=64,
                                hidden_layer="avgpool",
                                augment_fn=augmentations)

        self.opt = torch.optim.Adam(self.learner.parameters(),
                                    lr=0.0001,
                                    betas=(0.9, 0.999))
        self.loss_history: list[float] = []

    def train(self,
              dataset: torch.utils.data.Dataset,
              epochs: int = 1,
              use_tqdm: bool = False) -> None:
        """Train model on dataset for specified number of epochs."""
        for i in range(epochs):
            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=64,
                                                     shuffle=True)
            if use_tqdm:
                dataloader = tqdm(dataloader)
            for images in dataloader:
                device_images = images.to(self.device)
                loss = self.learner(device_images)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                self.learner.update_moving_average()
                self.loss_history.append(loss.detach().item())
                del device_images
                torch.cuda.empty_cache()
            print("Epochs performed:", i + 1)

    def save(self) -> None:
        """Save model."""
        if not os.path.exists("outputs"):
            os.mkdir("outputs")
        dirpath = os.path.join("outputs", self.timestamp)
        if not os.path.exists(dirpath):
            os.mkdir(dirpath)

        save_path = os.path.join(dirpath, "model.pt")
        torch.save(self.model.state_dict(), save_path)

        # Plot loss history:
        fig, ax = plt.subplots()
        ax.plot(self.loss_history)
        fig.tight_layout()
        fig.savefig(os.path.join(dirpath, "loss_v_batch.png"), dpi=300)
        plt.close()

    def infer(self,
              dataset: torch.utils.data.Dataset,
              use_tqdm: bool = False) -> np.ndarray:
        """Use model to infer embeddings of provided dataset."""
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=512,
                                                 shuffle=False,
                                                 num_workers=0)
        embeddings_list: list[np.ndarray] = []
        with torch.no_grad():
            self.model.eval()
            if use_tqdm:
                dataloader = tqdm(dataloader)
            for data in dataloader:
                device_data = data.to(self.device)
                projection, embedding = self.learner(device_data,
                                                     return_embedding=True)
                np_embedding = embedding.detach().cpu().numpy()
                np_embedding = np.reshape(np_embedding, (data.shape[0], -1))
                embeddings_list.append(np_embedding)
                del device_data
                del projection
                del embedding
                torch.cuda.empty_cache()
            self.model.train()

        return np.concatenate(embeddings_list, axis=0)
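
A brief, assumed end-to-end use of BYOLHandler; RandomImageDataset is a dummy stand-in for an unlabelled image dataset and only needs to yield tensors shaped (3, 64, 64) to match image_size=64.

# Hypothetical usage of BYOLHandler (dataset and sizes are assumptions).
import torch
from torch.utils.data import Dataset


class RandomImageDataset(Dataset):
    def __init__(self, length: int = 256):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return torch.randn(3, 64, 64)


device = "cuda" if torch.cuda.is_available() else "cpu"
handler = BYOLHandler(device=device)
handler.train(RandomImageDataset(), epochs=2, use_tqdm=True)
handler.save()                                    # writes outputs/<timestamp>/model.pt and a loss plot
embeddings = handler.infer(RandomImageDataset())  # (N, feature_dim) numpy array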
Example #5
import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(resnet, image_size=256, hidden_layer='avgpool')

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)


def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)


for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()  # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
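
After training, the saved weights can be loaded back into a fresh backbone for downstream use; a minimal sketch:

# Restore the BYOL-trained backbone (sketch); strict loading works because
# only the resnet's own state_dict was saved above.
import torch
from torchvision import models

backbone = models.resnet50(pretrained=False)
backbone.load_state_dict(torch.load('./improved-net.pt'))
backbone.eval()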
Example #6
    elif args.depth == 101:
        model = models.resnet101(pretrained=False).cuda()
    else:
        raise ValueError(f"Unsupported depth: {args.depth}")

learner = BYOL(
    model,
    image_size=args.image_size,
    hidden_layer='avgpool',
    projection_size=256,
    projection_hidden_size=4096,
    moving_average_decay=0.99,
    use_momentum=False  # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=args.lr)
ds = ImagesDataset(args.image_folder, args.image_size)
trainloader = DataLoader(ds, batch_size=args.batch_size, num_workers=NUM_WORKERS, shuffle=True)

losses = AverageMeter()

for epoch in range(args.epoch):
    bar = Bar('Processing', max=len(trainloader))
    for batch_idx, inputs in enumerate(trainloader):
        loss = learner(inputs.cuda())
        losses.update(loss.data.item(), inputs.size(0))

        opt.zero_grad()
        loss.backward()
        opt.step()
        # plot progress