def eval(self, saver: Saver, iteration: int, verbose: bool, do_tb: bool = True):
    """Run the enabled evaluations and optionally log their metrics.

    When ``self.perform_x2v`` is truthy the ``i2v`` and ``v2v`` evaluations
    run, in that order; when ``self.perform_x2i`` is truthy the ``i2i`` and
    ``v2i`` evaluations run. Each one is a call to the matching
    ``self.evaluate_<tag>(verbose=verbose)``, whose result is unpacked as a
    ``(cmc_scores, mAP)`` pair.

    Args:
        saver: Metric logger; ``dump_metric_tb`` is called on it here and it
            is also handed to ``self.tb_cmc``.
        iteration: Step value attached to every logged metric.
        verbose: Forwarded unchanged to every ``evaluate_*`` call.
        do_tb: When false the evaluations still run but nothing is logged.

    Returns:
        A dict mapping each executed tag (``'i2v'``, ``'v2v'``, ``'i2i'``,
        ``'v2i'``) to its ``(cmc_scores, mAP)`` pair, in execution order.
        It is empty when both ``perform_*`` flags are off. Callers that
        ignore the return value are unaffected.
    """
    results = {}

    def run(tag, evaluate):
        # One evaluate-then-log step; shared by all four settings so the
        # logging calls cannot drift apart between them.
        cmc_scores, mean_ap = evaluate(verbose=verbose)
        if do_tb:
            saver.dump_metric_tb(mean_ap, iteration, tag, 'mAP')
            self.tb_cmc(saver, cmc_scores, iteration, tag)
        results[tag] = (cmc_scores, mean_ap)

    if self.perform_x2v:
        run('i2v', self.evaluate_i2v)
        run('v2v', self.evaluate_v2v)

    if self.perform_x2i:
        run('i2i', self.evaluate_i2i)
        run('v2i', self.evaluate_v2i)

    return results
def main():
    """Train the network with a triplet + cross-entropy loss and evaluate it.

    Steps, in order:

    1. Build the configuration, parse arguments, pick the device and set up
       the random-seed handling.
    2. Create the ``Saver``, the data loaders, the ``nn.DataParallel``-wrapped
       model, the Adam optimizer and a ``MultiStepLR`` schedule.
    3. For every epoch: optionally evaluate and checkpoint (never at epoch 0),
       train on every batch of ``train_loader``, optionally print/log the
       averaged losses and the learning rate, then step the scheduler.
    4. Run a final evaluation, save the ``'chk_end'`` checkpoint and close the
       saver's writer.

    The writer is closed even when training raises, so a crash does not skip
    ``saver.writer.close()``.
    """
    conf = Conf()
    args = parse(conf)
    device = conf.get_device()
    conf.suppress_random(set_determinism=args.set_determinism)

    saver = Saver(conf.log_path, args.exp_name)
    try:
        (train_loader, query_loader, gallery_loader,
         queryimg_loader, galleryimg_loader) = get_dataloaders(
            args.dataset_name, conf.nas_path, device, args)
        num_pids = train_loader.dataset.get_num_pids()

        net = nn.DataParallel(get_model(args, num_pids))
        net = net.to(device)
        # Log/save the wrapped module, not the DataParallel container.
        saver.write_logs(net.module, vars(args))

        opt = Adam(net.parameters(), lr=1e-4, weight_decay=args.wd)
        milestones = list(
            range(args.first_milestone, args.num_epochs, args.step_milestone))
        scheduler = lr_scheduler.MultiStepLR(opt, milestones=milestones,
                                             gamma=args.gamma)

        triplet_loss = OnlineTripletLoss('soft', True,
                                         reduction='mean').to(device)
        class_loss = nn.CrossEntropyLoss(reduction='mean').to(device)

        def evaluate(step):
            # A fresh Evaluator is built for every evaluation run.
            evaluator = Evaluator(net, query_loader, gallery_loader,
                                  queryimg_loader, galleryimg_loader,
                                  DATA_CONFS[args.dataset_name], device)
            evaluator.eval(saver, step, args.verbose)

        print("EXP_NAME: ", args.exp_name)

        for e in range(args.num_epochs):
            # Periodic evaluation / checkpointing happen *before* training
            # epoch ``e`` and are skipped at epoch 0. The modulo is evaluated
            # first so a zero interval fails immediately, not after an epoch.
            if e % args.eval_epoch_interval == 0 and e > 0:
                evaluate(e)

            if e % args.save_epoch_interval == 0 and e > 0:
                saver.save_net(net.module,
                               f'chk_{e // args.save_epoch_interval}')

            avm = AvgMeter(['triplet', 'class'])

            # Set training mode once per epoch instead of once per batch.
            net.train()
            for x, y, _ in train_loader:
                x, y = x.to(device), y.to(device)

                opt.zero_grad()
                embeddings, f_class = net(x, return_logits=True)

                triplet_loss_batch = triplet_loss(embeddings, y)
                class_loss_batch = class_loss(f_class, y)
                loss = triplet_loss_batch + class_loss_batch

                avm.add([triplet_loss_batch.item(), class_loss_batch.item()])

                loss.backward()
                opt.step()

            if e % args.print_epoch_interval == 0:
                summary = [f"Epoch: {e}"]
                for name, value in avm():
                    summary.append(f"{name} {value:.2f}")
                    saver.dump_metric_tb(value, e, 'losses', f"avg_{name}")
                saver.dump_metric_tb(opt.param_groups[0]['lr'], e, 'lr', 'lr')
                print(" - ".join(summary))

            # Milestones are expressed in epochs, so step once per epoch.
            scheduler.step()

        # Index of the last epoch that ran. Computed explicitly so that
        # ``num_epochs <= 0`` no longer fails on an unbound loop variable.
        # NOTE(review): the final evaluation is logged at ``num_epochs - 1``,
        # which can coincide with a periodic evaluation logged at the same
        # step before that epoch was trained - confirm this is intended.
        final_step = max(args.num_epochs - 1, 0)
        evaluate(final_step)

        saver.save_net(net.module, 'chk_end')
    finally:
        saver.writer.close()
def tb_cmc(saver: Saver, cmc_scores, it, method, ranks=(1, 5, 10)):
    """Log selected ranks of ``cmc_scores`` through ``saver.dump_metric_tb``.

    For every rank ``k`` in ``ranks`` the value ``cmc_scores[k - 1]`` is
    logged under the name ``cmc<k>`` with ``method`` as the group tag.

    Args:
        saver: Metric logger providing ``dump_metric_tb``.
        cmc_scores: Indexable scores; position ``k - 1`` holds the rank-``k``
            value.
        it: Step value attached to every logged point.
        method: Tag identifying the evaluation setting (e.g. ``'i2v'``).
        ranks: 1-based ranks to log. The default reproduces the previously
            hard-coded ranks 1, 5 and 10.

    Raises:
        ValueError: If any rank is smaller than 1 (it would otherwise wrap
            around to a negative index). Nothing is logged in that case.

    Indexing errors raised by ``cmc_scores`` itself propagate unchanged.
    """
    ranks = tuple(ranks)
    # Validate up front so an invalid rank cannot leave a partial log behind.
    if any(rank < 1 for rank in ranks):
        raise ValueError(f'ranks must be >= 1, got {ranks}')

    for rank in ranks:
        saver.dump_metric_tb(cmc_scores[rank - 1], it, f'{method}',
                             f'cmc{rank}')