示例#1
0
文件: eval.py 项目: zyg11/VKD
    def eval(self,
             saver: Saver,
             iteration: int,
             verbose: bool,
             do_tb: bool = True):
        """Run the enabled evaluation settings and optionally log their scores.

        When ``self.perform_x2v`` is truthy the ``i2v`` and ``v2v``
        evaluations run (in that order); when ``self.perform_x2i`` is truthy
        the ``i2i`` and ``v2i`` evaluations run. Each evaluation returns a
        ``(cmc_scores, mAP)`` pair.

        Args:
            saver: Object receiving the metrics through ``dump_metric_tb``.
            iteration: Step value attached to every logged metric.
            verbose: Forwarded unchanged to each ``evaluate_*`` method.
            do_tb: When false the evaluations still run but nothing is
                handed to ``saver``.
        """

        def run_setting(tag, evaluate):
            # One evaluation pass followed by (optional) logging under `tag`.
            cmc_scores, mean_ap = evaluate(verbose=verbose)
            if not do_tb:
                return
            saver.dump_metric_tb(mean_ap, iteration, tag, 'mAP')
            self.tb_cmc(saver, cmc_scores, iteration, tag)

        if self.perform_x2v:
            run_setting('i2v', self.evaluate_i2v)
            run_setting('v2v', self.evaluate_v2v)

        if self.perform_x2i:
            run_setting('i2i', self.evaluate_i2i)
            run_setting('v2i', self.evaluate_v2i)
示例#2
0
文件: train_v2v.py 项目: zyg11/VKD
def main():
    """Train the model with a triplet loss plus a classification loss.

    Builds the configuration, data loaders, network, optimizer and LR
    scheduler, then runs ``args.num_epochs`` training epochs. Before the
    training pass of an epoch ``e > 0``:

    * an evaluation is run when ``e`` is a multiple of
      ``args.eval_epoch_interval``;
    * a checkpoint is saved when ``e`` is a multiple of
      ``args.save_epoch_interval``.

    Average losses and the current learning rate are logged every
    ``args.print_epoch_interval`` epochs. After the last epoch a final
    evaluation is run and the ``chk_end`` checkpoint is saved.
    """
    conf = Conf()
    args = parse(conf)
    device = conf.get_device()

    conf.suppress_random(set_determinism=args.set_determinism)
    saver = Saver(conf.log_path, args.exp_name)

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    num_pids = train_loader.dataset.get_num_pids()

    net = nn.DataParallel(get_model(args, num_pids))
    net = net.to(device)

    # Log the wrapped module (not the DataParallel wrapper) and the run arguments.
    saver.write_logs(net.module, vars(args))

    opt = Adam(net.parameters(), lr=1e-4, weight_decay=args.wd)
    milestones = list(
        range(args.first_milestone, args.num_epochs, args.step_milestone))
    scheduler = lr_scheduler.MultiStepLR(opt,
                                         milestones=milestones,
                                         gamma=args.gamma)

    triplet_loss = OnlineTripletLoss('soft', True, reduction='mean').to(device)
    class_loss = nn.CrossEntropyLoss(reduction='mean').to(device)

    def run_evaluation(step):
        # A fresh Evaluator is built for every evaluation, as before.
        evaluator = Evaluator(net, query_loader, gallery_loader,
                              queryimg_loader, galleryimg_loader,
                              DATA_CONFS[args.dataset_name], device)
        evaluator.eval(saver, step, args.verbose)

    print("EXP_NAME: ", args.exp_name)

    for e in range(args.num_epochs):

        if e % args.eval_epoch_interval == 0 and e > 0:
            run_evaluation(e)

        if e % args.save_epoch_interval == 0 and e > 0:
            saver.save_net(net.module, f'chk_{e // args.save_epoch_interval}')

        avm = AvgMeter(['triplet', 'class'])

        # The camera ids yielded by the loader are not used for training.
        for x, y, _cams in train_loader:
            # Presumably needed because evaluation may leave the network in
            # eval mode — TODO confirm against Evaluator.
            net.train()

            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            embeddings, f_class = net(x, return_logits=True)

            triplet_loss_batch = triplet_loss(embeddings, y)
            class_loss_batch = class_loss(f_class, y)
            loss = triplet_loss_batch + class_loss_batch

            avm.add([triplet_loss_batch.item(), class_loss_batch.item()])

            loss.backward()
            opt.step()

        if e % args.print_epoch_interval == 0:
            stats = avm()
            str_ = f"Epoch: {e}"
            for (loss_name, loss_avg) in stats:
                str_ += f" - {loss_name} {loss_avg:.2f}"
                saver.dump_metric_tb(loss_avg, e, 'losses', f"avg_{loss_name}")
            saver.dump_metric_tb(opt.param_groups[0]['lr'], e, 'lr', 'lr')
            print(str_)

        scheduler.step()

    # Step for the final evaluation: the index of the last epoch. Computed
    # explicitly instead of reusing the loop variable, which is unbound (and
    # raised NameError) when num_epochs <= 0.
    # NOTE(review): when this index is itself a multiple of
    # eval_epoch_interval, an in-loop evaluation was already logged at the
    # same step — confirm whether that overlap is intended.
    final_step = max(args.num_epochs - 1, 0)
    run_evaluation(final_step)

    saver.save_net(net.module, 'chk_end')
    saver.writer.close()
示例#3
0
文件: eval.py 项目: zyg11/VKD
 def tb_cmc(saver: Saver, cmc_scores, it, method):
     """Log the rank-1, rank-5 and rank-10 entries of ``cmc_scores``.

     Each value is passed to ``saver.dump_metric_tb`` with step ``it``,
     group ``method`` and a name of the form ``cmc<rank>``.
     ``cmc_scores`` is indexed at positions 0, 4 and 9, so it must hold at
     least ten entries.
     """
     # Ranks are 1-based in the metric name, 0-based in the score sequence.
     for rank in (1, 5, 10):
         saver.dump_metric_tb(cmc_scores[rank - 1], it, f'{method}',
                              f'cmc{rank}')