def main():
    """Train a video re-id model with joint triplet + cross-entropy losses.

    Builds the experiment config, data loaders, model, optimizer and LR
    schedule from CLI args, runs the epoch loop with periodic evaluation,
    checkpointing and TensorBoard logging, then evaluates and saves a
    final checkpoint.
    """
    conf = Conf()
    args = parse(conf)
    device = conf.get_device()
    conf.suppress_random(set_determinism=args.set_determinism)
    saver = Saver(conf.log_path, args.exp_name)

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    num_pids = train_loader.dataset.get_num_pids()

    net = nn.DataParallel(get_model(args, num_pids))
    net = net.to(device)

    saver.write_logs(net.module, vars(args))

    opt = Adam(net.parameters(), lr=1e-4, weight_decay=args.wd)
    milestones = list(range(args.first_milestone, args.num_epochs,
                            args.step_milestone))
    scheduler = lr_scheduler.MultiStepLR(opt, milestones=milestones,
                                         gamma=args.gamma)

    triplet_loss = OnlineTripletLoss('soft', True, reduction='mean').to(device)
    class_loss = nn.CrossEntropyLoss(reduction='mean').to(device)

    print("EXP_NAME: ", args.exp_name)

    for e in range(args.num_epochs):
        if e % args.eval_epoch_interval == 0 and e > 0:
            ev = Evaluator(net, query_loader, gallery_loader, queryimg_loader,
                           galleryimg_loader, DATA_CONFS[args.dataset_name],
                           device)
            ev.eval(saver, e, args.verbose)

        if e % args.save_epoch_interval == 0 and e > 0:
            saver.save_net(net.module, f'chk_{e // args.save_epoch_interval}')

        avm = AvgMeter(['triplet', 'class'])

        # Hoisted out of the batch loop: training mode is invariant for the
        # whole epoch (evaluation only runs before this point).
        net.train()
        for x, y, _cams in train_loader:  # camera ids are unused in training
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            embeddings, f_class = net(x, return_logits=True)

            triplet_loss_batch = triplet_loss(embeddings, y)
            class_loss_batch = class_loss(f_class, y)
            loss = triplet_loss_batch + class_loss_batch

            avm.add([triplet_loss_batch.item(), class_loss_batch.item()])

            loss.backward()
            opt.step()

        if e % args.print_epoch_interval == 0:
            stats = avm()
            str_ = f"Epoch: {e}"
            for (l, m) in stats:
                str_ += f" - {l} {m:.2f}"
                saver.dump_metric_tb(m, e, 'losses', f"avg_{l}")
            # param_groups[0]['lr'] is the live (scheduler-adjusted) LR.
            saver.dump_metric_tb(opt.param_groups[0]['lr'], e, 'lr', 'lr')
            print(str_)

        scheduler.step()

    # Final evaluation/checkpoint. Use args.num_epochs - 1 instead of the
    # loop variable so this does not NameError when num_epochs == 0.
    ev = Evaluator(net, query_loader, gallery_loader, queryimg_loader,
                   galleryimg_loader, DATA_CONFS[args.dataset_name], device)
    ev.eval(saver, args.num_epochs - 1, args.verbose)
    saver.save_net(net.module, 'chk_end')
    saver.writer.close()
def __call__(self, teacher_net: TriNet, student_net: TriNet):
    """Run one distillation generation: train the student against the teacher.

    For every batch, frames of each tracklet are shuffled and the student
    sees only the first ``num_student_images`` of them, while the teacher
    (frozen, no grad) sees the full clip. The student is optimized with
    triplet + cross-entropy on its own outputs plus KL (logits) and
    similarity (embeddings) distillation terms against the teacher.

    Args:
        teacher_net: frozen source network (put in teacher mode).
        student_net: network being trained (put in student mode).

    Returns:
        The trained ``student_net``.
    """
    opt = Adam(student_net.parameters(), lr=self.lr(self._gen),
               weight_decay=1e-5)
    milestones = list(range(self.args.first_milestone, self.args.num_epochs,
                            self.args.step_milestone))
    scheduler = lr_scheduler.MultiStepLR(opt, milestones=milestones,
                                         gamma=self.args.gamma)

    for e in range(self.args.num_epochs):
        if e % self.args.eval_epoch_interval == 0 and e > 0:
            self.evaluate(student_net)

        avm = AvgMeter(['kl', 'triplet', 'class', 'similarity', 'loss'])
        student_net.student_mode()
        teacher_net.teacher_mode()

        for x, y, cams in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            # Shuffle the frames of each tracklet independently so the
            # student's subsampled view is a random subset of the clip.
            x_ = torch.stack([x[i, torch.randperm(x.shape[1])]
                              for i in range(x.shape[0])])
            x_teacher, x_student = x, x_[:, :self.args.num_student_images]

            # Teacher is a fixed target: no gradients through it.
            with torch.no_grad():
                teacher_emb, teacher_logits = teacher_net(
                    x_teacher, return_logits=True)

            opt.zero_grad()
            student_emb, student_logits = student_net(
                x_student, return_logits=True)

            kl_div_batch = self.distill_loss(teacher_logits, student_logits)
            similarity_loss_batch = self.similarity_loss(teacher_emb,
                                                         student_emb)
            triplet_loss_batch = self.triplet_loss(student_emb, y)
            class_loss_batch = self.class_loss(student_logits, y)

            loss = (triplet_loss_batch + class_loss_batch) + \
                self.args.lambda_coeff * (similarity_loss_batch) + \
                self.args.kl_coeff * (kl_div_batch)

            avm.add([kl_div_batch.item(), triplet_loss_batch.item(),
                     class_loss_batch.item(), similarity_loss_batch.item(),
                     loss.item()])

            loss.backward()
            opt.step()

        scheduler.step()

        if self._epoch % self.args.print_epoch_interval == 0:
            stats = avm()
            str_ = f"Epoch: {self._epoch}"
            for (l, m) in stats:
                str_ += f" - {l} {m:.2f}"
                self.saver.dump_metric_tb(m, self._epoch, 'losses',
                                          f"avg_{l}")
            # FIX: opt.defaults['lr'] is the constructor-time value and is
            # never touched by MultiStepLR; the live, scheduler-decayed LR
            # lives in param_groups (matches what main() logs).
            self.saver.dump_metric_tb(opt.param_groups[0]['lr'],
                                      self._epoch, 'lr', 'lr')
            print(str_)

        self._epoch += 1

    self._gen += 1
    return student_net