# Example #1: standalone training script
def main():
    """End-to-end training entry point.

    Builds configuration, dataloaders, model, optimizer and lr schedule,
    runs the epoch loop with periodic evaluation/checkpointing, then
    performs a final evaluation and saves the last checkpoint.
    """
    conf = Conf()
    args = parse(conf)
    device = conf.get_device()

    conf.suppress_random(set_determinism=args.set_determinism)
    saver = Saver(conf.log_path, args.exp_name)

    # Five loaders: train plus video- and image-level query/gallery splits.
    (train_loader, query_loader, gallery_loader,
     queryimg_loader, galleryimg_loader) = get_dataloaders(
         args.dataset_name, conf.nas_path, device, args)

    num_pids = train_loader.dataset.get_num_pids()

    # Wrap for multi-GPU; note the raw model lives at net.module.
    net = nn.DataParallel(get_model(args, num_pids)).to(device)

    saver.write_logs(net.module, vars(args))

    opt = Adam(net.parameters(), lr=1e-4, weight_decay=args.wd)
    scheduler = lr_scheduler.MultiStepLR(
        opt,
        milestones=list(range(args.first_milestone, args.num_epochs,
                              args.step_milestone)),
        gamma=args.gamma)

    triplet_loss = OnlineTripletLoss('soft', True, reduction='mean').to(device)
    class_loss = nn.CrossEntropyLoss(reduction='mean').to(device)

    print("EXP_NAME: ", args.exp_name)

    for e in range(args.num_epochs):

        # Periodic evaluation, skipped at epoch 0.
        if e > 0 and e % args.eval_epoch_interval == 0:
            Evaluator(net, query_loader, gallery_loader, queryimg_loader,
                      galleryimg_loader, DATA_CONFS[args.dataset_name],
                      device).eval(saver, e, args.verbose)

        # Periodic checkpoint, skipped at epoch 0.
        if e > 0 and e % args.save_epoch_interval == 0:
            saver.save_net(net.module, f'chk_{e // args.save_epoch_interval}')

        avm = AvgMeter(['triplet', 'class'])

        for it, (x, y, cams) in enumerate(train_loader):
            net.train()
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            embeddings, f_class = net(x, return_logits=True)

            t_loss = triplet_loss(embeddings, y)
            c_loss = class_loss(f_class, y)
            avm.add([t_loss.item(), c_loss.item()])

            (t_loss + c_loss).backward()
            opt.step()

        if e % args.print_epoch_interval == 0:
            str_ = f"Epoch: {e}"
            for (l, m) in avm():
                str_ += f" - {l} {m:.2f}"
                saver.dump_metric_tb(m, e, 'losses', f"avg_{l}")
            # Read the live lr from the param group so scheduler decay shows up.
            saver.dump_metric_tb(opt.param_groups[0]['lr'], e, 'lr', 'lr')
            print(str_)

        scheduler.step()

    # Final evaluation on the fully trained model.
    Evaluator(net, query_loader, gallery_loader, queryimg_loader,
              galleryimg_loader, DATA_CONFS[args.dataset_name],
              device).eval(saver, e, args.verbose)

    saver.save_net(net.module, 'chk_end')
    saver.writer.close()
# Example #2: teacher-student distillation generation
    def __call__(self, teacher_net: TriNet, student_net: TriNet):
        """Run one distillation generation: train `student_net` against a
        frozen `teacher_net` for `self.args.num_epochs` epochs.

        The total loss combines triplet + cross-entropy on the student with
        a teacher-student similarity term (weighted by lambda_coeff) and a
        KL distillation term (weighted by kl_coeff). Advances `self._epoch`
        once per epoch and `self._gen` once at the end.

        Returns the trained student network.
        """
        opt = Adam(student_net.parameters(), lr=self.lr(self._gen), weight_decay=1e-5)

        milestones = list(range(self.args.first_milestone, self.args.num_epochs,
                                self.args.step_milestone))

        scheduler = lr_scheduler.MultiStepLR(opt, milestones=milestones, gamma=self.args.gamma)

        for e in range(self.args.num_epochs):

            # Periodic evaluation, skipped at the first epoch of this generation.
            if e % self.args.eval_epoch_interval == 0 and e > 0:
                self.evaluate(student_net)

            avm = AvgMeter(['kl', 'triplet', 'class', 'similarity', 'loss'])

            student_net.student_mode()
            teacher_net.teacher_mode()

            for x, y, cams in self.train_loader:

                x, y = x.to(self.device), y.to(self.device)
                # Shuffle frames independently per sample so the student sees
                # a random subset of each clip.
                x_ = torch.stack([x[i, torch.randperm(x.shape[1])] for i in range(x.shape[0])])

                # Teacher gets the full clip; student only the first
                # num_student_images of the shuffled frames.
                x_teacher, x_student = x, x_[:, :self.args.num_student_images]

                # Teacher provides targets only — no gradients through it.
                with torch.no_grad():
                    teacher_emb, teacher_logits = teacher_net(x_teacher, return_logits=True)

                opt.zero_grad()

                student_emb, student_logits = student_net(x_student, return_logits=True)

                kl_div_batch = self.distill_loss(teacher_logits, student_logits)
                similarity_loss_batch = self.similarity_loss(teacher_emb, student_emb)
                triplet_loss_batch = self.triplet_loss(student_emb, y)
                class_loss_batch = self.class_loss(student_logits, y)

                loss = (triplet_loss_batch + class_loss_batch) + \
                       self.args.lambda_coeff * similarity_loss_batch + \
                       self.args.kl_coeff * kl_div_batch

                avm.add([kl_div_batch.item(), triplet_loss_batch.item(),
                         class_loss_batch.item(), similarity_loss_batch.item(),
                         loss.item()])

                loss.backward()
                opt.step()

            scheduler.step()

            if self._epoch % self.args.print_epoch_interval == 0:
                stats = avm()
                str_ = f"Epoch: {self._epoch}"
                for (l, m) in stats:
                    str_ += f" - {l} {m:.2f}"
                    self.saver.dump_metric_tb(m, self._epoch, 'losses', f"avg_{l}")
                # FIX: opt.defaults['lr'] is the constructor-time value and never
                # reflects MultiStepLR decay; read the live param-group lr instead
                # so the logged learning-rate curve is actually correct.
                self.saver.dump_metric_tb(opt.param_groups[0]['lr'], self._epoch, 'lr', 'lr')
                print(str_)

            self._epoch += 1

        self._gen += 1

        return student_net