Example #1
0
def _upsample_attention(attn, width, height):
    """Smooth an attention map and resize it to ``width`` x ``height``.

    Args:
        attn: 2-D numpy array holding a single attention map.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        The map after a 3x3 box blur followed by a bicubic resize.
    """
    attn = cv2.blur(attn, (3, 3))
    # cv2.resize expects the target size as (width, height), not (rows, cols).
    return cv2.resize(attn, (width, height), interpolation=cv2.INTER_CUBIC)


def main():
    """Dump side-by-side attention maps of two networks on gallery images.

    Two networks are restored from ``args.net1`` / ``args.net2``, a forward
    hook is attached to ``backbone.features_layers[4]`` of each, and the
    maps returned by ``extract_grad_cam`` are computed for the last 50
    batches of ``galleryimg_loader``. For every (sample, view) pair four
    images are written under ``<dest_path>/<net1>__vs__<net2>/``:

    * ``<net1>/<i>.png`` - image with the first network's map,
    * ``<net2>/<i>.png`` - image with the second network's map,
    * ``orig/<i>.png``   - image without any map,
    * ``both/<i>.png``   - the two overlays concatenated along the width.

    ``<i>`` is a running counter shared by all four folders.
    """
    conf = Conf()
    conf.suppress_random()
    device = conf.get_device()

    args = parse(conf)

    net1_name = Path(args.net1).name
    net2_name = Path(args.net2).name

    # ---- Output folders
    dest_path = args.dest_path / (net1_name + '__vs__' + net2_name)
    both_path = dest_path / 'both'
    net1_path = dest_path / net1_name
    net2_path = dest_path / net2_name
    orig_path = dest_path / 'orig'
    for out_dir in (dest_path, both_path, net1_path, net2_path, orig_path):
        out_dir.mkdir(exist_ok=True, parents=True)

    # ---- Restore net
    net1 = Saver.load_net(args.net1, args.chk_net1,
                          args.dataset_name).to(device)
    net2 = Saver.load_net(args.net2, args.chk_net2,
                          args.dataset_name).to(device)

    net1.eval()
    net2.eval()

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    # register hooks
    hook_net_1, hook_net_2 = Hook(), Hook()

    net1.backbone.features_layers[4].register_forward_hook(hook_net_1)
    net2.backbone.features_layers[4].register_forward_hook(hook_net_2)

    # Per-channel statistics used to undo the input normalisation (these are
    # the usual ImageNet mean / std; presumably the loader normalises with
    # the same values - TODO confirm). Built once instead of once per image.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Only the last 50 batches of the loader are visualised; earlier batches
    # are still iterated (and loaded) but skipped.
    first_batch = len(galleryimg_loader) - 50

    dst_idx = 0

    for idx_batch, (vids,
                    *_) in enumerate(tqdm(galleryimg_loader, 'iterating..')):
        if idx_batch < first_batch:
            continue
        net1.zero_grad()
        net2.zero_grad()

        hook_net_1.reset()
        hook_net_2.reset()

        vids = vids.to(device)
        attn_1 = extract_grad_cam(net1, vids, device, hook_net_1)
        attn_2 = extract_grad_cam(net2, vids, device, hook_net_2)

        B, N_VIEWS = attn_1.shape[0], attn_1.shape[1]

        # One device -> host transfer per batch rather than three per view.
        vids_np = vids.cpu().numpy()
        attn_1_np = attn_1.cpu().numpy()
        attn_2_np = attn_2.cpu().numpy()

        for idx_b in range(B):
            for idx_v in range(N_VIEWS):
                # CHW -> HWC, then de-normalise and clamp to [0, 1].
                el_img = vids_np[idx_b, idx_v].transpose(1, 2, 0)
                el_img = np.clip((el_img * std) + mean, 0, 1)

                height, width = el_img.shape[0], el_img.shape[1]
                el_attn_1 = _upsample_attention(attn_1_np[idx_b, idx_v],
                                                width, height)
                el_attn_2 = _upsample_attention(attn_2_np[idx_b, idx_v],
                                                width, height)

                out_name = f'{dst_idx}.png'

                save_img(el_img, el_attn_1, net1_path / out_name)
                save_img(el_img, el_attn_2, net2_path / out_name)

                save_img(el_img, None, orig_path / out_name)

                # Side-by-side comparison: both image and maps are stacked
                # along the width axis so they stay aligned.
                save_img(np.concatenate([el_img, el_img], 1),
                         np.concatenate([el_attn_1, el_attn_2], 1),
                         both_path / out_name)

                dst_idx += 1
Example #2
0
        self._gen += 1

        return student_net


# Script entry point: restores a teacher network, derives (or loads) a student
# network, evaluates the student once, then re-initialises some of its layers
# and creates the Saver for this experiment.
if __name__ == '__main__':
    conf = Conf()
    device = conf.get_device()
    args = parse(conf)

    # Seeding / determinism is configured only after the arguments are parsed,
    # because the determinism switch comes from the command line.
    conf.suppress_random(set_determinism=args.set_determinism)

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    # Teacher is always restored from a checkpoint.
    teacher_net: TriNet = Saver.load_net(args.teacher,
                                         args.teacher_chk_name, args.dataset_name).to(device)

    # Without an explicit student checkpoint the student starts as an
    # independent deep copy of the teacher; otherwise it is loaded from disk.
    student_net: TriNet = deepcopy(teacher_net) if args.student is None \
        else Saver.load_net(args.student, args.student_chk_name, args.dataset_name)
    student_net = student_net.to(device)

    ev = Evaluator(student_net, query_loader, gallery_loader, queryimg_loader, galleryimg_loader,
                   DATA_CONFS[args.dataset_name], device)

    # Evaluation of the student before any layer is re-initialised, framed by
    # two separator lines on stdout. No saver / iteration is passed and
    # do_tb is off (presumably disables TensorBoard logging - TODO confirm).
    print('v' * 100)
    ev.eval(saver=None, iteration=None, verbose=True, do_tb=False)
    print('v' * 100)

    # NOTE(review): the flags presumably select whether "layer 4" / "layer 3"
    # of the student are re-initialised - confirm against TriNet.reinit_layers.
    student_net.reinit_layers(args.reinit_l4, args.reinit_l3)

    saver = Saver(conf.log_path, args.exp_name)