def main():
    """Evaluate a previously trained network checkpoint on the query/gallery split.

    The ``backbone`` and ``metric`` stored in the checkpoint folder's logs
    (read through ``Saver.load_logs``) override the command-line values, so the
    rebuilt network matches the saved weights. Evaluation runs once, verbosely,
    without TensorBoard logging.

    Raises:
        ValueError: if the dataset's identity count differs from the
            ``num_classes`` hyper-parameter stored with the checkpoint.
    """
    conf = Conf()
    conf.suppress_random()
    device = conf.get_device()
    args = parse(conf)

    # ---- Restore the hyper-parameters the checkpoint was trained with
    trinet_folder = Path(args.trinet_folder)
    saver_trinet = Saver(trinet_folder.parent, trinet_folder.name)
    old_params, old_hparams = saver_trinet.load_logs()
    args.backbone = old_params['backbone']
    args.metric = old_params['metric']

    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    # get_model is parameterised by num_pids, so the identity count must match
    # the one the checkpoint was trained with. An explicit raise (instead of
    # `assert`) keeps this check active under `python -O`.
    num_pids = train_loader.dataset.get_num_pids()
    if num_pids != old_hparams['num_classes']:
        raise ValueError(
            f"dataset has {num_pids} identities but the checkpoint was "
            f"trained with {old_hparams['num_classes']} classes")

    net = get_model(args, num_pids).to(device)
    # map_location lets a checkpoint saved on one device (e.g. GPU) be restored
    # on another (e.g. a CPU-only machine).
    state_dict = torch.load(trinet_folder / 'chk' / args.trinet_chk_name,
                            map_location=device)
    net.load_state_dict(state_dict)

    evaluator = Evaluator(net, query_loader, gallery_loader, queryimg_loader,
                          galleryimg_loader, device=device,
                          data_conf=DATA_CONFS[args.dataset_name])
    evaluator.eval(None, 0, verbose=True, do_tb=False)
def _run_training(conf, args, device, saver):
    """Build data, model and optimiser, then run the whole training schedule.

    Split out of ``main`` so that ``main`` can guarantee the Saver's writer is
    closed however training exits.
    """
    train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    num_pids = train_loader.dataset.get_num_pids()
    net = nn.DataParallel(get_model(args, num_pids))
    net = net.to(device)
    # Log the unwrapped module (not the DataParallel wrapper) together with the args.
    saver.write_logs(net.module, vars(args))

    opt = Adam(net.parameters(), lr=1e-4, weight_decay=args.wd)
    milestones = list(
        range(args.first_milestone, args.num_epochs, args.step_milestone))
    scheduler = lr_scheduler.MultiStepLR(opt, milestones=milestones, gamma=args.gamma)

    triplet_loss = OnlineTripletLoss('soft', True, reduction='mean').to(device)
    class_loss = nn.CrossEntropyLoss(reduction='mean').to(device)

    def evaluate(epoch):
        # A fresh Evaluator is built for every evaluation round.
        ev = Evaluator(net, query_loader, gallery_loader, queryimg_loader,
                       galleryimg_loader, DATA_CONFS[args.dataset_name], device)
        ev.eval(saver, epoch, args.verbose)

    print("EXP_NAME: ", args.exp_name)

    for e in range(args.num_epochs):
        # Periodic evaluation / checkpointing, skipped at epoch 0.
        if e % args.eval_epoch_interval == 0 and e > 0:
            evaluate(e)
        if e % args.save_epoch_interval == 0 and e > 0:
            saver.save_net(net.module, f'chk_{e // args.save_epoch_interval}')

        avm = AvgMeter(['triplet', 'class'])
        # Evaluation above may leave the net in eval mode; nothing inside the
        # batch loop changes the mode, so switching back once per epoch suffices.
        net.train()
        for x, y, _cams in train_loader:
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            embeddings, f_class = net(x, return_logits=True)
            triplet_loss_batch = triplet_loss(embeddings, y)
            class_loss_batch = class_loss(f_class, y)
            loss = triplet_loss_batch + class_loss_batch

            avm.add([triplet_loss_batch.item(), class_loss_batch.item()])
            loss.backward()
            opt.step()

        if e % args.print_epoch_interval == 0:
            stats = avm()
            str_ = f"Epoch: {e}"
            for (l, m) in stats:
                str_ += f" - {l} {m:.2f}"
                saver.dump_metric_tb(m, e, 'losses', f"avg_{l}")
            saver.dump_metric_tb(opt.param_groups[0]['lr'], e, 'lr', 'lr')
            print(str_)

        scheduler.step()

    # Final evaluation at the last epoch index. Computed explicitly because the
    # loop variable is unbound when num_epochs == 0.
    final_epoch = max(args.num_epochs - 1, 0)
    evaluate(final_epoch)
    saver.save_net(net.module, 'chk_end')


def main():
    """Train a network with triplet + cross-entropy losses.

    Evaluates and checkpoints periodically (per the ``*_epoch_interval`` args).
    Logs averaged losses and the learning rate through the Saver. After the last
    epoch, evaluates once more and saves a ``chk_end`` checkpoint.
    """
    conf = Conf()
    args = parse(conf)
    device = conf.get_device()
    conf.suppress_random(set_determinism=args.set_determinism)
    saver = Saver(conf.log_path, args.exp_name)
    try:
        _run_training(conf, args, device, saver)
    finally:
        # Close the Saver's writer (used by dump_metric_tb) even if training
        # raises or is interrupted, so logged metrics are not lost.
        saver.writer.close()
def _resize_attention(attn, height, width):
    """Smooth an attention map with a 3x3 box blur and resize it to (height, width).

    ``attn`` is presumably a 2-D float array (one map per image/view); confirm
    against extract_grad_cam.
    """
    attn = cv2.blur(attn, (3, 3))
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(attn, (width, height), interpolation=cv2.INTER_CUBIC)


def main():
    """Dump Grad-CAM maps of two networks for the last gallery batches.

    Files are written under ``<dest_path>/<net1 name>__vs__<net2 name>/``:
      - ``<net1 name>/`` and ``<net2 name>/``: each image saved with that net's map
      - ``orig/``: the de-normalised image alone
      - ``both/``: both image+map pairs concatenated along axis 1 (side by side)
    Files are numbered consecutively across all batches and views.
    """
    conf = Conf()
    conf.suppress_random()
    device = conf.get_device()
    args = parse(conf)

    net1_name = Path(args.net1).name
    net2_name = Path(args.net2).name
    # Path() also accepts a plain string, in case the parser leaves dest_path as str.
    dest_path = Path(args.dest_path) / (net1_name + '__vs__' + net2_name)
    both_path = dest_path / 'both'
    net1_path = dest_path / net1_name
    net2_path = dest_path / net2_name
    orig_path = dest_path / 'orig'
    for folder in (dest_path, both_path, net1_path, net2_path, orig_path):
        folder.mkdir(exist_ok=True, parents=True)

    # ---- Restore nets
    net1 = Saver.load_net(args.net1, args.chk_net1, args.dataset_name).to(device)
    net2 = Saver.load_net(args.net2, args.chk_net2, args.dataset_name).to(device)
    net1.eval()
    net2.eval()

    # Only the gallery image loader is used here.
    _, _, _, _, galleryimg_loader = \
        get_dataloaders(args.dataset_name, conf.nas_path, device, args)

    # register hooks on the same backbone stage of both nets
    hook_net_1, hook_net_2 = Hook(), Hook()
    net1.backbone.features_layers[4].register_forward_hook(hook_net_1)
    net2.backbone.features_layers[4].register_forward_hook(hook_net_2)

    # Per-channel mean/std (the standard ImageNet values), used to undo the
    # input normalisation before saving. Presumably they match the dataloader's
    # transform; confirm.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Only the last `num_last_batches` batches are visualised.
    # NOTE(review): earlier batches are still loaded by the DataLoader and then
    # skipped; restricting the sampler would avoid that cost.
    num_last_batches = 50
    first_batch = len(galleryimg_loader) - num_last_batches

    dst_idx = 0
    for idx_batch, (vids, *_) in enumerate(tqdm(galleryimg_loader, 'iterating..')):
        if idx_batch < first_batch:
            continue

        # Clear gradients and hook state left over from the previous batch.
        net1.zero_grad()
        net2.zero_grad()
        hook_net_1.reset()
        hook_net_2.reset()

        vids = vids.to(device)
        attn_1 = extract_grad_cam(net1, vids, device, hook_net_1)
        attn_2 = extract_grad_cam(net2, vids, device, hook_net_2)

        B, N_VIEWS = attn_1.shape[0], attn_1.shape[1]
        for idx_b in range(B):
            for idx_v in range(N_VIEWS):
                # Channel-first tensor -> channel-last array, then de-normalise.
                el_img = vids[idx_b, idx_v].cpu().numpy().transpose(1, 2, 0)
                el_img = np.clip(el_img * std + mean, 0, 1)
                height, width = el_img.shape[0], el_img.shape[1]

                el_attn_1 = _resize_attention(attn_1[idx_b, idx_v].cpu().numpy(), height, width)
                el_attn_2 = _resize_attention(attn_2[idx_b, idx_v].cpu().numpy(), height, width)

                save_img(el_img, el_attn_1, net1_path / f'{dst_idx}.png')
                save_img(el_img, el_attn_2, net2_path / f'{dst_idx}.png')
                save_img(el_img, None, orig_path / f'{dst_idx}.png')
                save_img(np.concatenate([el_img, el_img], 1),
                         np.concatenate([el_attn_1, el_attn_2], 1),
                         both_path / f'{dst_idx}.png')

                dst_idx += 1
for (l, m) in stats: str_ += f" - {l} {m:.2f}" self.saver.dump_metric_tb(m, self._epoch, 'losses', f"avg_{l}") self.saver.dump_metric_tb(opt.defaults['lr'], self._epoch, 'lr', 'lr') print(str_) self._epoch += 1 self._gen += 1 return student_net if __name__ == '__main__': conf = Conf() device = conf.get_device() args = parse(conf) conf.suppress_random(set_determinism=args.set_determinism) train_loader, query_loader, gallery_loader, queryimg_loader, galleryimg_loader = \ get_dataloaders(args.dataset_name, conf.nas_path, device, args) teacher_net: TriNet = Saver.load_net(args.teacher, args.teacher_chk_name, args.dataset_name).to(device) student_net: TriNet = deepcopy(teacher_net) if args.student is None \ else Saver.load_net(args.student, args.student_chk_name, args.dataset_name) student_net = student_net.to(device) ev = Evaluator(student_net, query_loader, gallery_loader, queryimg_loader, galleryimg_loader,