Esempio n. 1
0
File: eval.py Progetto: zyg11/VKD
def main():
    """Evaluate a trained TriNet checkpoint on the query/gallery splits.

    Reads the training logs stored next to ``args.trinet_folder`` to recover
    the ``backbone`` and ``metric`` used at training time, rebuilds the model
    with them, loads the weights from
    ``<trinet_folder>/chk/<trinet_chk_name>`` and runs one verbose
    evaluation pass with TensorBoard logging disabled.

    Raises:
        ValueError: if the training split's number of identities differs
            from the ``num_classes`` hyper-parameter recorded for the
            checkpoint.
    """
    conf = Conf()
    conf.suppress_random()
    device = conf.get_device()

    args = parse(conf)

    trinet_folder = Path(args.trinet_folder)

    # ---- SAVER OLD NET TO RESTORE PARAMS
    # Reuse the backbone/metric recorded at training time so the rebuilt
    # model matches the checkpoint being loaded.
    saver_trinet = Saver(trinet_folder.parent, trinet_folder.name)
    old_params, old_hparams = saver_trinet.load_logs()
    args.backbone = old_params['backbone']
    args.metric = old_params['metric']

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)
    num_pids = train_loader.dataset.get_num_pids()

    # Explicit check rather than `assert`, which is removed under `python -O`.
    if num_pids != old_hparams['num_classes']:
        raise ValueError(
            f"dataset has {num_pids} identities but the checkpoint was "
            f"trained with num_classes={old_hparams['num_classes']}")

    net = get_model(args, num_pids).to(device)
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine or on a different device. torch.device() accepts a str, an int
    # index or an existing torch.device, whatever get_device() returns.
    state_dict = torch.load(trinet_folder / 'chk' / args.trinet_chk_name,
                            map_location=torch.device(device))
    net.load_state_dict(state_dict)

    evaluator = Evaluator(net,
                          query_loader,
                          gallery_loader,
                          queryimg_loader,
                          galleryimg_loader,
                          device=device,
                          data_conf=DATA_CONFS[args.dataset_name])

    evaluator.eval(None, 0, verbose=True, do_tb=False)
Esempio n. 2
0
def main():
    """Train a re-identification network with triplet + cross-entropy losses.

    The model returned by ``get_model`` is wrapped in ``nn.DataParallel`` and
    optimized with Adam (lr 1e-4, weight decay ``args.wd``). A MultiStepLR
    schedule decays the lr by ``args.gamma`` at epochs ``first_milestone``,
    ``first_milestone + step_milestone``, ... below ``num_epochs``.

    At the start of every epoch ``e > 0``:

    * if ``e`` is a multiple of ``eval_epoch_interval``, the current weights
      (after ``e`` completed epochs) are evaluated and logged at step ``e``;
    * if ``e`` is a multiple of ``save_epoch_interval``, checkpoint
      ``chk_<e // save_epoch_interval>`` is written.

    After training, a final evaluation is logged at step ``num_epochs``, the
    number of completed epochs. Checkpoint ``chk_end`` is saved, and the
    TensorBoard writer is closed even if training raises.
    """
    conf = Conf()
    args = parse(conf)
    device = conf.get_device()

    conf.suppress_random(set_determinism=args.set_determinism)
    saver = Saver(conf.log_path, args.exp_name)

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    num_pids = train_loader.dataset.get_num_pids()

    net = nn.DataParallel(get_model(args, num_pids))
    net = net.to(device)

    saver.write_logs(net.module, vars(args))

    opt = Adam(net.parameters(), lr=1e-4, weight_decay=args.wd)
    milestones = list(
        range(args.first_milestone, args.num_epochs, args.step_milestone))
    scheduler = lr_scheduler.MultiStepLR(opt,
                                         milestones=milestones,
                                         gamma=args.gamma)

    triplet_loss = OnlineTripletLoss('soft', True, reduction='mean').to(device)
    class_loss = nn.CrossEntropyLoss(reduction='mean').to(device)

    data_conf = DATA_CONFS[args.dataset_name]

    def evaluate(step):
        """Run one evaluation pass, logging its metrics at ``step``."""
        # Keyword arguments, as in the evaluation script, so the call does not
        # depend on the positional order of data_conf/device.
        ev = Evaluator(net, query_loader, gallery_loader, queryimg_loader,
                       galleryimg_loader, data_conf=data_conf, device=device)
        ev.eval(saver, step, args.verbose)

    print("EXP_NAME: ", args.exp_name)

    try:
        for e in range(args.num_epochs):

            if e % args.eval_epoch_interval == 0 and e > 0:
                evaluate(e)

            if e % args.save_epoch_interval == 0 and e > 0:
                saver.save_net(net.module, f'chk_{e // args.save_epoch_interval}')

            avm = AvgMeter(['triplet', 'class'])

            # Nothing inside the batch loop changes the module mode, so one call
            # per epoch suffices. It is needed after evaluate(), which
            # presumably leaves the net in eval mode.
            net.train()

            for x, y, _cams in train_loader:
                x, y = x.to(device), y.to(device)

                opt.zero_grad()
                embeddings, f_class = net(x, return_logits=True)

                triplet_loss_batch = triplet_loss(embeddings, y)
                class_loss_batch = class_loss(f_class, y)
                loss = triplet_loss_batch + class_loss_batch

                avm.add([triplet_loss_batch.item(), class_loss_batch.item()])

                loss.backward()
                opt.step()

            if e % args.print_epoch_interval == 0:
                stats = avm()
                str_ = f"Epoch: {e}"
                for (l, m) in stats:
                    str_ += f" - {l} {m:.2f}"
                    saver.dump_metric_tb(m, e, 'losses', f"avg_{l}")
                saver.dump_metric_tb(opt.param_groups[0]['lr'], e, 'lr', 'lr')
                print(str_)

            scheduler.step()

        # Use the number of completed epochs as the step rather than the loop
        # variable. The loop variable is unbound when num_epochs == 0, and it
        # would label post-training weights with the last epoch's index.
        evaluate(args.num_epochs)

        saver.save_net(net.module, 'chk_end')
    finally:
        # Flush and close TensorBoard logs even when training fails.
        saver.writer.close()
Esempio n. 3
0
def main():
    """Dump Grad-CAM maps of two networks on the last gallery batches.

    Both networks are restored with ``Saver.load_net`` and a ``Hook`` is
    registered on ``backbone.features_layers[4]`` of each. For the last
    ``NUM_LAST_BATCHES`` batches of the gallery image loader, every view of
    every clip is written as ``<idx>.png`` under
    ``<dest_path>/<net1>__vs__<net2>/``:

    * ``orig/``   - the de-normalized frame (``save_img`` with no map);
    * ``<net1>/`` - the frame plus net1's blurred, resized attention map;
    * ``<net2>/`` - the frame plus net2's blurred, resized attention map;
    * ``both/``   - frame and maps of both nets concatenated side by side.

    If both networks have the same folder name (e.g. two checkpoints of one
    run), their sub-folders get ``_net1`` / ``_net2`` suffixes. Otherwise
    net2's images would overwrite net1's.
    """
    # Only the tail of the gallery loader is visualized.
    NUM_LAST_BATCHES = 50
    # Per-channel statistics used to undo input normalization before saving.
    # These are the common ImageNet mean/std, presumably matching the data
    # pipeline; confirm. The second array is a standard deviation, not a
    # variance.
    img_mean = np.array([0.485, 0.456, 0.406])
    img_std = np.array([0.229, 0.224, 0.225])

    conf = Conf()
    conf.suppress_random()
    device = conf.get_device()

    args = parse(conf)

    net1_name, net2_name = Path(args.net1).name, Path(args.net2).name
    dest_path = args.dest_path / f'{net1_name}__vs__{net2_name}'

    net1_dir, net2_dir = net1_name, net2_name
    if net1_dir == net2_dir:
        net1_dir, net2_dir = f'{net1_name}_net1', f'{net2_name}_net2'

    both_path = dest_path / 'both'
    net1_path = dest_path / net1_dir
    net2_path = dest_path / net2_dir
    orig_path = dest_path / 'orig'
    # parents=True also creates dest_path itself.
    for out_dir in (both_path, net1_path, net2_path, orig_path):
        out_dir.mkdir(exist_ok=True, parents=True)

    # ---- Restore nets
    net1 = Saver.load_net(args.net1, args.chk_net1,
                          args.dataset_name).to(device)
    net2 = Saver.load_net(args.net2, args.chk_net2,
                          args.dataset_name).to(device)

    net1.eval()
    net2.eval()

    _, _, _, _, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    # register hooks
    hook_net_1, hook_net_2 = Hook(), Hook()

    net1.backbone.features_layers[4].register_forward_hook(hook_net_1)
    net2.backbone.features_layers[4].register_forward_hook(hook_net_2)

    def upsample(attn, size_hw):
        """Blur a 2-D attention map (3x3 box) and resize it to ``size_hw``."""
        attn = cv2.blur(attn, (3, 3))
        return cv2.resize(attn, (size_hw[1], size_hw[0]),
                          interpolation=cv2.INTER_CUBIC)

    first_batch = len(galleryimg_loader) - NUM_LAST_BATCHES
    dst_idx = 0

    for idx_batch, (vids,
                    *_) in enumerate(tqdm(galleryimg_loader, 'iterating..')):
        if idx_batch < first_batch:
            continue
        net1.zero_grad()
        net2.zero_grad()

        hook_net_1.reset()
        hook_net_2.reset()

        vids = vids.to(device)
        attn_1 = extract_grad_cam(net1, vids, device, hook_net_1)
        attn_2 = extract_grad_cam(net2, vids, device, hook_net_2)

        # Move each batch to host memory once. detach() is needed because
        # .numpy() raises on tensors that require grad.
        vids_np = vids.detach().cpu().numpy()
        attn_1_np = attn_1.detach().cpu().numpy()
        attn_2_np = attn_2.detach().cpu().numpy()

        B, N_VIEWS = attn_1_np.shape[0], attn_1_np.shape[1]

        for idx_b in range(B):
            for idx_v in range(N_VIEWS):
                # CHW -> HWC, then undo the mean/std normalization.
                el_img = vids_np[idx_b, idx_v].transpose(1, 2, 0)
                el_img = np.clip(el_img * img_std + img_mean, 0, 1)

                size_hw = el_img.shape[:2]
                el_attn_1 = upsample(attn_1_np[idx_b, idx_v], size_hw)
                el_attn_2 = upsample(attn_2_np[idx_b, idx_v], size_hw)

                save_img(el_img, el_attn_1, net1_path / f'{dst_idx}.png')
                save_img(el_img, el_attn_2, net2_path / f'{dst_idx}.png')

                save_img(el_img, None, orig_path / f'{dst_idx}.png')

                save_img(np.concatenate([el_img, el_img], 1),
                         np.concatenate([el_attn_1, el_attn_2], 1),
                         both_path / f'{dst_idx}.png')

                dst_idx += 1
Esempio n. 4
0
                for (l, m) in stats:
                    str_ += f" - {l} {m:.2f}"
                    self.saver.dump_metric_tb(m, self._epoch, 'losses', f"avg_{l}")
                self.saver.dump_metric_tb(opt.defaults['lr'], self._epoch, 'lr', 'lr')
                print(str_)

            self._epoch += 1

        self._gen += 1

        return student_net


if __name__ == '__main__':
    conf = Conf()
    device = conf.get_device()
    args = parse(conf)

    conf.suppress_random(set_determinism=args.set_determinism)

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    teacher_net: TriNet = Saver.load_net(args.teacher,
                                         args.teacher_chk_name, args.dataset_name).to(device)

    student_net: TriNet = deepcopy(teacher_net) if args.student is None \
        else Saver.load_net(args.student, args.student_chk_name, args.dataset_name)
    student_net = student_net.to(device)

    ev = Evaluator(student_net, query_loader, gallery_loader, queryimg_loader, galleryimg_loader,