Example #1
 def test_momentum_works(self):
     model = nn.Sequential(
         nn.Linear(32, 32),
         nn.ReLU(),
     )
     model_momentum = copy.deepcopy(model)
     # blend the copy toward the online model: the momentum copy keeps
     # 99% of its weights and takes 1% from `model`
     update_momentum(model, model_momentum, 0.99)
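
For reference, `update_momentum` performs an exponential-moving-average update of the copy's weights. A minimal sketch of the idea (assuming lightly's semantics, where `m` is the weight kept on the momentum copy; `ema_update` is an illustrative stand-in, not lightly's API):

import torch

@torch.no_grad()
def ema_update(model, model_ema, m=0.99):
    # illustrative stand-in for lightly.models.utils.update_momentum:
    # p_ema <- m * p_ema + (1 - m) * p, parameter by parameter
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.mul_(m).add_(p, alpha=1 - m)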
Example #2
    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch

        # update momentum
        update_momentum(self.backbone, self.backbone_momentum, 0.99)
        update_momentum(self.projection_head, self.projection_head_momentum,
                        0.99)

        def step(x0_, x1_):
            # `distributed` is defined outside this snippet (True for multi-GPU runs)
            x1_, shuffle = batch_shuffle(x1_, distributed=distributed)
            x0_ = self.backbone(x0_).flatten(start_dim=1)
            x0_ = self.projection_head(x0_)

            x1_ = self.backbone_momentum(x1_).flatten(start_dim=1)
            x1_ = self.projection_head_momentum(x1_)
            x1_ = batch_unshuffle(x1_, shuffle, distributed=distributed)
            return x0_, x1_

        # We use a symmetric loss (the model trains faster with little compute overhead)
        # https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb
        loss_1 = self.criterion(*step(x0, x1))
        loss_2 = self.criterion(*step(x1, x0))

        loss = 0.5 * (loss_1 + loss_2)
        self.log('train_loss_ssl', loss)
        return loss
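
`batch_shuffle` / `batch_unshuffle` implement MoCo's shuffling-BN trick: keys are permuted across the batch before the momentum encoder so batch-norm statistics cannot leak which query matches which key, then restored afterwards. A single-process sketch of the idea (illustrative stand-ins, not lightly's API; lightly's helpers also cover the distributed case):

import torch

def simple_batch_shuffle(x):
    # random permutation of the batch dimension
    shuffle = torch.randperm(x.shape[0], device=x.device)
    return x[shuffle], shuffle

def simple_batch_unshuffle(x, shuffle):
    # argsort of a permutation yields its inverse
    return x[torch.argsort(shuffle)]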
Example #3
 def training_step(self, batch, batch_idx):
     update_momentum(self.backbone, self.backbone_momentum, m=0.99)
     update_momentum(self.projection_head,
                     self.projection_head_momentum,
                     m=0.99)
     (x_query, x_key), _, _ = batch
     query = self.forward(x_query)       # online (gradient) branch
     key = self.forward_momentum(x_key)  # momentum (no-gradient) branch
     loss = self.criterion(query, key)
     return loss
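
The snippet relies on `self.forward` and `self.forward_momentum`, which it does not show. A plausible definition of those methods, assuming the backbone/projection-head layout used throughout these examples:

import torch

def forward(self, x):
    features = self.backbone(x).flatten(start_dim=1)
    return self.projection_head(features)

@torch.no_grad()
def forward_momentum(self, x):
    # the key encoder runs without gradients; its weights move only via EMA
    features = self.backbone_momentum(x).flatten(start_dim=1)
    return self.projection_head_momentum(features).detach()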
Example #4
 def training_step(self, batch, batch_idx):
     update_momentum(self.backbone, self.backbone_momentum, m=0.99)
     update_momentum(self.projection_head, self.projection_head_momentum, m=0.99)
     (x0, x1), _, _ = batch
     p0 = self.forward(x0)            # online predictions
     z0 = self.forward_momentum(x0)   # momentum targets
     p1 = self.forward(x1)
     z1 = self.forward_momentum(x1)
     # symmetric loss: each prediction is scored against the other view's target
     loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))
     return loss
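
Here the `p*` are online-network predictions and the `z*` are momentum-network targets; `self.criterion` is not shown, but for this BYOL-style setup it is typically a negative cosine similarity. A hypothetical sketch:

import torch.nn.functional as F

def negative_cosine_similarity(p, z):
    # illustrative, not lightly's loss class;
    # gradients flow only through p, the momentum target z is fixed
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()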
Example #5
 def training_step(self, batch, batch_idx):
     update_momentum(self.student_backbone, self.teacher_backbone, m=0.99)
     update_momentum(self.student_head, self.teacher_head, m=0.99)
     views, _, _ = batch
     views = [view.to(self.device) for view in views]
     global_views = views[:2]
     # only the two global crops go through the teacher; the student sees all views
     teacher_out = [self.forward_teacher(view) for view in global_views]
     student_out = [self.forward(view) for view in views]
     loss = self.criterion(teacher_out,
                           student_out,
                           epoch=self.current_epoch)
     return loss
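
All of these snippets keep m fixed at 0.99; the BYOL and DINO papers instead ramp the momentum toward 1.0 over training with a cosine schedule (newer lightly versions provide a similar cosine_schedule helper). A sketch of that schedule:

import math

def cosine_momentum(step, total_steps, m_start=0.996, m_end=1.0):
    # m_start at step 0, rising along a cosine to m_end at the last step
    progress = step / max(1, total_steps)
    return m_end - (m_end - m_start) * (math.cos(math.pi * progress) + 1) / 2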
Example #6
    def training_step(self, batch, batch_idx):
        (x_q, x_k), _, _ = batch

        # update momentum
        update_momentum(self.backbone, self.backbone_momentum, 0.99)
        update_momentum(self.projection_head, self.projection_head_momentum,
                        0.99)

        # get queries
        q = self.backbone(x_q).flatten(start_dim=1)
        q = self.projection_head(q)

        # get keys
        k, shuffle = batch_shuffle(x_k)
        k = self.backbone_momentum(k).flatten(start_dim=1)
        k = self.projection_head_momentum(k)
        k = batch_unshuffle(k, shuffle)

        loss = self.criterion(q, k)
        self.log("train_loss_ssl", loss)
        return loss
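
None of the examples show how the momentum copies are created in the first place. A typical setup, sketched as a hypothetical helper (`deactivate_requires_grad` is lightly's utility for freezing the momentum branch):

import copy
from lightly.models.utils import deactivate_requires_grad

def init_momentum_copies(module):
    # hypothetical helper: run inside __init__ of the modules above
    module.backbone_momentum = copy.deepcopy(module.backbone)
    module.projection_head_momentum = copy.deepcopy(module.projection_head)
    # the momentum branch receives no gradients; it changes only via update_momentum
    deactivate_requires_grad(module.backbone_momentum)
    deactivate_requires_grad(module.projection_head_momentum)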
Example #7
    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch

        # update momentum
        update_momentum(self.backbone, self.backbone_momentum, 0.99)
        update_momentum(self.projection_head, self.projection_head_momentum, 0.99)

        def step(x0_, x1_):
            x0_ = self.backbone(x0_).flatten(start_dim=1)
            x0_ = self.projection_head(x0_)
            x0_ = self.prediction_head(x0_)

            x1_ = self.backbone_momentum(x1_).flatten(start_dim=1)
            x1_ = self.projection_head_momentum(x1_)
            return x0_, x1_

        p0, z1 = step(x0, x1)
        p1, z0 = step(x1, x0)
        
        loss = self.criterion((z0, p0), (z1, p1))
        self.log('train_loss_ssl', loss)
        return loss
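
The tuple signature `criterion((z0, p0), (z1, p1))` suggests a symmetric loss that matches each branch's prediction against the other branch's momentum projection. A hypothetical sketch of such a criterion (assumed, the snippet does not show it):

import torch.nn.functional as F

def symmetric_criterion(out0, out1):
    # illustrative, not lightly's loss class
    z0, p0 = out0
    z1, p1 = out1
    loss0 = -F.cosine_similarity(p0, z1.detach(), dim=-1).mean()
    loss1 = -F.cosine_similarity(p1, z0.detach(), dim=-1).mean()
    return 0.5 * (loss0 + loss1)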
Example #8
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = NTXentLoss(memory_bank_size=4096)
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for (x_query, x_key), _, _ in dataloader:
        update_momentum(model.backbone, model.backbone_momentum, m=0.99)
        update_momentum(model.projection_head, model.projection_head_momentum, m=0.99)
        x_query = x_query.to(device)
        x_key = x_key.to(device)
        query = model(x_query)
        key = model.forward_momentum(x_key)
        loss = criterion(query, key)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
Example #9
# the start of this snippet was truncated; the DataLoader call is reconstructed
# to match Example #8, and the batch size is an assumption
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
# move loss to correct device because it also contains parameters
criterion = criterion.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for views, _, _ in dataloader:
        update_momentum(model.student_backbone, model.teacher_backbone, m=0.99)
        update_momentum(model.student_head, model.teacher_head, m=0.99)
        views = [view.to(device) for view in views]
        global_views = views[:2]
        teacher_out = [model.forward_teacher(view) for view in global_views]
        student_out = [model.forward(view) for view in views]
        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")