def test_momentum_works(self):
    model = nn.Sequential(
        nn.Linear(32, 32),
        nn.ReLU(),
    )
    model_momentum = copy.deepcopy(model)
    update_momentum(model, model_momentum, 0.99)
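# For reference: a minimal sketch of the exponential moving average that an
# update_momentum(model, model_ema, m) helper of this shape performs. This is
# an illustration, not lightly's actual implementation; the in-place EMA form
# is an assumption.
import torch
from torch import nn


@torch.no_grad()
def update_momentum_sketch(model: nn.Module, model_ema: nn.Module, m: float) -> None:
    # ema_param <- m * ema_param + (1 - m) * param
    for param, param_ema in zip(model.parameters(), model_ema.parameters()):
        param_ema.data.mul_(m).add_(param.data, alpha=1.0 - m)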
def training_step(self, batch, batch_idx):
    (x0, x1), _, _ = batch

    # update momentum
    update_momentum(self.backbone, self.backbone_momentum, 0.99)
    update_momentum(self.projection_head, self.projection_head_momentum, 0.99)

    # `distributed` is assumed to be a flag defined elsewhere (True when training with DDP).
    def step(x0_, x1_):
        x1_, shuffle = batch_shuffle(x1_, distributed=distributed)
        x0_ = self.backbone(x0_).flatten(start_dim=1)
        x0_ = self.projection_head(x0_)

        x1_ = self.backbone_momentum(x1_).flatten(start_dim=1)
        x1_ = self.projection_head_momentum(x1_)
        x1_ = batch_unshuffle(x1_, shuffle, distributed=distributed)
        return x0_, x1_

    # We use a symmetric loss (the model trains faster with little compute overhead):
    # https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb
    loss_1 = self.criterion(*step(x0, x1))
    loss_2 = self.criterion(*step(x1, x0))
    loss = 0.5 * (loss_1 + loss_2)
    self.log('train_loss_ssl', loss)
    return loss
def training_step(self, batch, batch_idx):
    # Update the momentum encoders before computing this step's keys.
    update_momentum(self.backbone, self.backbone_momentum, m=0.99)
    update_momentum(self.projection_head, self.projection_head_momentum, m=0.99)

    (x_query, x_key), _, _ = batch
    query = self.forward(x_query)
    key = self.forward_momentum(x_key)
    loss = self.criterion(query, key)
    return loss
def training_step(self, batch, batch_idx):
    update_momentum(self.backbone, self.backbone_momentum, m=0.99)
    update_momentum(self.projection_head, self.projection_head_momentum, m=0.99)

    (x0, x1), _, _ = batch
    # Predictions come from the online network, targets from the momentum network.
    p0 = self.forward(x0)
    z0 = self.forward_momentum(x0)
    p1 = self.forward(x1)
    z1 = self.forward_momentum(x1)
    # Symmetrized loss: each view's prediction targets the other view's projection.
    loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))
    return loss
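# The step above assumes two branches on the module: forward() runs the online
# network through a prediction head, forward_momentum() runs the momentum
# (target) network without gradients. A minimal sketch of that split; the
# exact attribute names (prediction_head, *_momentum) are assumptions.
import torch


def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Online branch: backbone -> projection head -> prediction head.
    y = self.backbone(x).flatten(start_dim=1)
    z = self.projection_head(y)
    return self.prediction_head(z)


@torch.no_grad()
def forward_momentum(self, x: torch.Tensor) -> torch.Tensor:
    # Target branch: no gradients flow here; the weights only change
    # through update_momentum in training_step.
    y = self.backbone_momentum(x).flatten(start_dim=1)
    return self.projection_head_momentum(y).detach()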
def training_step(self, batch, batch_idx):
    update_momentum(self.student_backbone, self.teacher_backbone, m=0.99)
    update_momentum(self.student_head, self.teacher_head, m=0.99)

    views, _, _ = batch
    views = [view.to(self.device) for view in views]
    # The teacher only sees the two global views; the student sees all views.
    global_views = views[:2]
    teacher_out = [self.forward_teacher(view) for view in global_views]
    student_out = [self.forward(view) for view in views]
    loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch)
    return loss
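# forward_teacher above is assumed to be the no-gradient twin of forward: the
# teacher only processes the global views and is updated exclusively through
# update_momentum. A minimal sketch; attribute names follow the step above.
import torch


@torch.no_grad()
def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
    y = self.teacher_backbone(x).flatten(start_dim=1)
    return self.teacher_head(y)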
def training_step(self, batch, batch_idx):
    (x_q, x_k), _, _ = batch

    # update momentum
    update_momentum(self.backbone, self.backbone_momentum, 0.99)
    update_momentum(self.projection_head, self.projection_head_momentum, 0.99)

    # get queries
    q = self.backbone(x_q).flatten(start_dim=1)
    q = self.projection_head(q)

    # get keys
    k, shuffle = batch_shuffle(x_k)
    k = self.backbone_momentum(k).flatten(start_dim=1)
    k = self.projection_head_momentum(k)
    k = batch_unshuffle(k, shuffle)

    loss = self.criterion(q, k)
    self.log("train_loss_ssl", loss)
    return loss
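# batch_shuffle/batch_unshuffle implement MoCo's shuffling-BN trick: keys are
# shuffled before the momentum encoder so BatchNorm statistics cannot leak
# query/key pairings, then restored afterwards. A minimal single-GPU sketch;
# lightly's real helpers also handle the distributed case.
import torch


def batch_shuffle_sketch(batch: torch.Tensor):
    # Random permutation of the batch dimension; keep the indices to undo it.
    shuffle = torch.randperm(batch.shape[0], device=batch.device)
    return batch[shuffle], shuffle


def batch_unshuffle_sketch(batch: torch.Tensor, shuffle: torch.Tensor):
    # Invert the permutation so each key lines up with its query again.
    unshuffle = torch.argsort(shuffle)
    return batch[unshuffle]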
def training_step(self, batch, batch_idx):
    (x0, x1), _, _ = batch

    # update momentum
    update_momentum(self.backbone, self.backbone_momentum, 0.99)
    update_momentum(self.projection_head, self.projection_head_momentum, 0.99)

    def step(x0_, x1_):
        x0_ = self.backbone(x0_).flatten(start_dim=1)
        x0_ = self.projection_head(x0_)
        x0_ = self.prediction_head(x0_)

        x1_ = self.backbone_momentum(x1_).flatten(start_dim=1)
        x1_ = self.projection_head_momentum(x1_)
        return x0_, x1_

    p0, z1 = step(x0, x1)
    p1, z0 = step(x1, x0)

    loss = self.criterion((z0, p0), (z1, p1))
    self.log('train_loss_ssl', loss)
    return loss
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = NTXentLoss(memory_bank_size=4096)
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for (x_query, x_key), _, _ in dataloader:
        update_momentum(model.backbone, model.backbone_momentum, m=0.99)
        update_momentum(model.projection_head, model.projection_head_momentum, m=0.99)
        x_query = x_query.to(device)
        x_key = x_key.to(device)
        query = model(x_query)
        key = model.forward_momentum(x_key)
        loss = criterion(query, key)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
    num_workers=8,
)

criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
# move loss to correct device because it also contains parameters
criterion = criterion.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for views, _, _ in dataloader:
        update_momentum(model.student_backbone, model.teacher_backbone, m=0.99)
        update_momentum(model.student_head, model.teacher_head, m=0.99)
        views = [view.to(device) for view in views]
        global_views = views[:2]
        teacher_out = [model.forward_teacher(view) for view in global_views]
        student_out = [model.forward(view) for view in views]
        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
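# One detail the loop above relies on but does not show: the teacher branch
# must never receive gradients directly. In lightly-style setups this is done
# once at construction time; a minimal sketch, assuming the
# deactivate_requires_grad helper from lightly.models.utils.
import copy

from lightly.models.utils import deactivate_requires_grad


def init_teacher(model) -> None:
    # Teacher starts as a copy of the student, then is frozen; from here on,
    # update_momentum is the only thing that changes its weights.
    model.teacher_backbone = copy.deepcopy(model.student_backbone)
    model.teacher_head = copy.deepcopy(model.student_head)
    deactivate_requires_grad(model.teacher_backbone)
    deactivate_requires_grad(model.teacher_head)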