pre_image = torch.from_numpy(prev_frame[np.newaxis]) pre_image = pre_image.to(args.device, non_blocking=True) with torch.no_grad(): out = model(torch.cat([image, pre_image], axis=1)) res = [] for i, (head, x) in enumerate(zip(args.heads.keys(), out)): if head in ('hm', ): res.append(x.sigmoid_()) else: res.append(x) dets = generic_decode({k: res[i] for (i, k) in enumerate(args.heads)}, 10, args) for k in dets: dets[k] = dets[k].detach().cpu().numpy() if not tracker.init and len(dets) > 0: tracker.init_track(dets) elif len(dets) > 0: tracker.step(dets) with open(save_fp, "w") as f: for track in tracker.tracks: x1, y1, x2, y2 = args.down_ratio * track['bbox'] x1, x2 = x1 * s_w, x2 * s_w y1, y2 = s_h * y1, s_h * y2 score = track['score'] f.write("{} {} {} {} {} {}\n".format(score, track['tracking_id'], x1, y1, x2, y2))
def run_epoch(self, phase, epoch, data_loader, rank):
    """Run one pass over ``data_loader`` for training or validation.

    In the ``'train'`` phase the model is put in train mode and, for every
    batch, an optimizer step is taken after clipping the gradient norm to
    ``self.args.clip_value``. In any other phase the model is put in eval
    mode and the CUDA cache is emptied once. Running averages of the
    ``'tot'``, ``'hm'``, ``'wh'`` and ``'tracking'`` entries of
    ``self.loss_stats`` are maintained and reported via the progress bar or
    periodic prints.

    Tracking results are written only when all of these hold: ``rank == 0``,
    ``phase == 'val'``, ``self.args.write_mota_metrics`` is set, and
    ``epoch in self.args.save_point``. In that case every frame whose index
    is a multiple of ``self.args.val_select_frame`` is decoded and fed to a
    ``Tracker``. The resulting tracks are written to
    ``<self.args.res_dir>/<frame stem>.txt``, and ``self.compute_map(epoch)``
    is called once after the pass.

    Args:
        phase: ``'train'`` or another phase name such as ``'val'``. It is
            also used as the key into ``self.args.num_iters``.
        epoch: Current epoch number, used in progress output and in the
            ``save_point`` check.
        data_loader: Iterable of batch dicts. Entries other than ``'fpath'``
            and ``'prev_fpath'`` are moved to ``self.args.device``.
        rank: Process rank. Only rank 0 writes tracking results and uses
            the periodic ``print`` output.

    Returns:
        ``(ret, results)``. ``ret`` maps each averaged loss name to its
        epoch mean and also holds ``'time'``, the elapsed minutes.
        ``results`` is an empty dict, kept for interface compatibility.
    """
    model_with_loss = self.model_with_loss
    if phase == 'train':
        model_with_loss.train()
    else:
        model_with_loss.eval()
        torch.cuda.empty_cache()

    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {
        l: AverageMeter()
        for l in self.loss_stats if l in ('tot', 'hm', 'wh', 'tracking')
    }
    max_iters = self.args.num_iters[phase]
    num_iters = len(data_loader) if max_iters < 0 else max_iters
    bar = Bar('{}'.format("tracking"), max=num_iters)
    device = self.args.device

    # Loop-invariant: whether this pass dumps per-frame tracking results
    # (and later computes metrics from them).
    write_tracks = (rank == 0 and phase == 'val'
                    and self.args.write_mota_metrics
                    and epoch in self.args.save_point)

    end = time.time()
    for iter_id, batch in enumerate(data_loader):
        if iter_id >= num_iters:
            break
        data_time.update(time.time() - end)

        # Move every tensor (or list of tensors) to the target device; the
        # file-path entries are plain strings and stay on the host.
        for key, value in batch.items():
            if key in ('fpath', 'prev_fpath'):
                continue
            if isinstance(value, list):
                for idx, item in enumerate(value):
                    value[idx] = item.to(device, non_blocking=True)
            else:
                batch[key] = value.to(device, non_blocking=True)

        output, loss, loss_stats = model_with_loss(batch)
        loss = loss.mean()
        if phase == 'train':
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_with_loss.parameters(),
                                           self.args.clip_value)
            self.optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()

        Bar.suffix = '{phase}: [{0}][{1}/{2}]| '.format(
            epoch, iter_id, num_iters, phase=phase)
        for l in avg_loss_stats:
            avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                     batch['image'].size(0))
            Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                l, avg_loss_stats[l].avg)
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
            '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)

        if write_tracks:
            # NOTE(review): the tracker is reset at every batch boundary, so
            # tracking IDs restart each batch even if a sequence continues
            # into the next batch -- confirm this is intended.
            curr_name = None
            tracker = None
            # The last batch may be shorter than batch_size; slicing stops
            # at the real batch length.
            for i, fpath in enumerate(batch['fpath'][:self.args.batch_size]):
                # Strip directories *before* the extension: splitting the
                # full path on '.' first yields '' for './...' or '../...'
                # paths and for any directory whose name contains a dot.
                stem = os.path.basename(fpath).split('.')[0]
                # rsplit tolerates a sequence name that itself contains
                # "_frame_".
                name, num = stem.rsplit("_frame_", 1)
                if int(num) % self.args.val_select_frame != 0:
                    continue
                if name != curr_name:
                    curr_name = name
                    tracker = Tracker(self.args)

                # Slice sample i out of every head output, keeping a batch dim.
                out = [x[i][None] for x in output]
                # NOTE(review): unlike the standalone inference path, no
                # sigmoid is applied to the 'hm' head here; presumably
                # model_with_loss already returns activated heatmaps -- verify.
                dets = generic_decode(
                    {k: out[t] for t, k in enumerate(self.args.heads)},
                    self.args.max_objs, self.args)
                for k in dets:
                    dets[k] = dets[k].detach().cpu().numpy()
                if len(dets) > 0:
                    if tracker.init:
                        tracker.step(dets)
                    else:
                        tracker.init_track(dets)

                res_path = os.path.join(self.args.res_dir, stem + '.txt')
                with open(res_path, "w") as f:
                    for track in tracker.tracks:
                        x1, y1, x2, y2 = track['bbox']
                        f.write("{} {} {} {} {} {}\n".format(
                            track['score'], track['tracking_id'],
                            x1, y1, x2, y2))

        # NOTE(review): ranks other than 0 always take the bar.next() branch,
        # so they also render the progress bar -- confirm this is intended.
        if rank == 0 and self.args.print_iter > 0:  # If not using progress bar
            if iter_id % self.args.print_iter == 0:
                print('{}| {}'.format("tracking", Bar.suffix))
        else:
            bar.next()
        del output, loss, loss_stats

    if write_tracks:
        self.compute_map(epoch)

    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results