def _denormalize_image(img_chw):
    """Turn a normalised image tensor into an (H, W, C) float array clipped to [0, 1].

    ``img_chw`` is moved to CPU, converted to numpy and transposed with
    ``(1, 2, 0)``, i.e. it is expected channel-first with 3 channels (the
    mean/std below have 3 entries). The per-channel normalisation
    ``x_norm = (x - mean) / std`` is then inverted and the result clipped.
    """
    # Standard ImageNet per-channel statistics. These are presumably the same
    # values used to normalise the gallery images; confirm against the
    # dataset transforms built by get_dataloaders.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = img_chw.cpu().numpy().transpose(1, 2, 0)
    return np.clip(img * std + mean, 0, 1)


def _attn_to_image_size(attn, height, width):
    """Smooth an attention map with a 3x3 box blur and upsample it to ``(height, width)``.

    ``attn`` is a tensor slice (presumably 2-D, one map per view; confirm
    against extract_grad_cam). Returns a numpy array.
    """
    attn = cv2.blur(attn.cpu().numpy(), (3, 3))
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(attn, (width, height), interpolation=cv2.INTER_CUBIC)


def main(n_last_batches=50):
    """Dump Grad-CAM attention maps of two networks on the last gallery batches.

    Both networks (``args.net1`` / ``args.net2``) are restored in eval mode and
    a forward hook is attached to ``backbone.features_layers[4]`` of each.
    Only the last ``n_last_batches`` batches of the gallery image loader are
    processed. Attention is computed through ``extract_grad_cam``. For every
    (sample, view) pair, four PNGs with a shared running index are written
    under ``args.dest_path / '<net1>__vs__<net2>'``:

    * ``<net1>/<idx>.png``: image + net1 attention, via ``save_img``
    * ``<net2>/<idx>.png``: image + net2 attention, via ``save_img``
    * ``orig/<idx>.png``: the de-normalised image alone (attention ``None``)
    * ``both/<idx>.png``: both images/attentions concatenated along width

    If both networks resolve to the same folder name (e.g. two checkpoints of
    one model), the per-network subfolders get ``_net1`` / ``_net2`` suffixes.
    Otherwise the second network's images would overwrite the first's.

    Args:
        n_last_batches: how many trailing batches of the gallery image loader
            to render (default 50, as before).
    """
    conf = Conf()
    conf.suppress_random()
    device = conf.get_device()
    args = parse(conf)

    name1, name2 = Path(args.net1).name, Path(args.net2).name
    dest_path = args.dest_path / (name1 + '__vs__' + name2)
    if name1 == name2:
        # Same model name on both sides: keep the two outputs in separate folders.
        name1, name2 = name1 + '_net1', name2 + '_net2'
    both_path = dest_path / 'both'
    net1_path = dest_path / name1
    net2_path = dest_path / name2
    orig_path = dest_path / 'orig'
    for out_dir in (dest_path, both_path, net1_path, net2_path, orig_path):
        out_dir.mkdir(exist_ok=True, parents=True)

    # ---- Restore nets
    net1 = Saver.load_net(args.net1, args.chk_net1, args.dataset_name).to(device)
    net2 = Saver.load_net(args.net2, args.chk_net2, args.dataset_name).to(device)
    net1.eval()
    net2.eval()

    # Only the gallery image loader is needed here.
    _, _, _, _, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    # Hooks on backbone.features_layers[4] feed extract_grad_cam.
    hook_net_1, hook_net_2 = Hook(), Hook()
    hook_handles = [
        net1.backbone.features_layers[4].register_forward_hook(hook_net_1),
        net2.backbone.features_layers[4].register_forward_hook(hook_net_2),
    ]

    first_batch = len(galleryimg_loader) - n_last_batches
    dst_idx = 0
    try:
        for idx_batch, (vids, *_) in enumerate(tqdm(galleryimg_loader, 'iterating..')):
            # NOTE: skipped batches are still produced by the loader; only the
            # Grad-CAM work is avoided for them.
            if idx_batch < first_batch:
                continue

            # Clear gradients and captured activations from the previous batch.
            net1.zero_grad()
            net2.zero_grad()
            hook_net_1.reset()
            hook_net_2.reset()

            vids = vids.to(device)
            attn_1 = extract_grad_cam(net1, vids, device, hook_net_1)
            attn_2 = extract_grad_cam(net2, vids, device, hook_net_2)

            n_samples, n_views = attn_1.shape[0], attn_1.shape[1]
            for idx_b in range(n_samples):
                for idx_v in range(n_views):
                    el_img = _denormalize_image(vids[idx_b, idx_v])
                    height, width = el_img.shape[0], el_img.shape[1]
                    el_attn_1 = _attn_to_image_size(attn_1[idx_b, idx_v], height, width)
                    el_attn_2 = _attn_to_image_size(attn_2[idx_b, idx_v], height, width)

                    save_img(el_img, el_attn_1, net1_path / f'{dst_idx}.png')
                    save_img(el_img, el_attn_2, net2_path / f'{dst_idx}.png')
                    save_img(el_img, None, orig_path / f'{dst_idx}.png')
                    save_img(np.concatenate([el_img, el_img], 1),
                             np.concatenate([el_attn_1, el_attn_2], 1),
                             both_path / f'{dst_idx}.png')
                    dst_idx += 1
    finally:
        # Detach the forward hooks so the networks are left unmodified.
        for handle in hook_handles:
            handle.remove()
self._gen += 1 return student_net if __name__ == '__main__': conf = Conf() device = conf.get_device() args = parse(conf) conf.suppress_random(set_determinism=args.set_determinism) train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \ get_dataloaders(args.dataset_name, conf.nas_path, device, args) teacher_net: TriNet = Saver.load_net(args.teacher, args.teacher_chk_name, args.dataset_name).to(device) student_net: TriNet = deepcopy(teacher_net) if args.student is None \ else Saver.load_net(args.student, args.student_chk_name, args.dataset_name) student_net = student_net.to(device) ev = Evaluator(student_net, query_loader, gallery_loader, queryimg_loader, galleryimg_loader, DATA_CONFS[args.dataset_name], device) print('v' * 100) ev.eval(saver=None, iteration=None, verbose=True, do_tb=False) print('v' * 100) student_net.reinit_layers(args.reinit_l4, args.reinit_l3) saver = Saver(conf.log_path, args.exp_name)