def prep_eval_gt_tracks_worker(args):
    """Extract ground-truth tracks for one evaluation graph and pickle them.

    ``args`` is packed into a single tuple
    ``(model, name, scene, logdir, split_on_no_edge)`` so the function can be
    mapped over a worker pool.

    The graph stored at ``name + '-<model.feature_name>-eval_graph'`` is loaded
    with ``torch.load``, promoted, and matched against
    ``scene.ground_truth()`` (with ``add_gt_class=True``).  When
    ``split_on_no_edge`` is true, tracks are additionally split by
    ``split_track_on_missing_edge``.

    Every detection in the resulting tracks has its attribute dict cleared,
    keeping only ``cls`` (if present) and ``gt_cls`` (if present and not None),
    so the pickle stays small.

    Two pickles are written: ``<logdir>/tracks/<basename(name)>`` holding the
    tracks, and the same path with a ``-stats`` suffix holding
    ``{'track_length': {length: count}}``.

    Returns the tracks pickle path (without the ``-stats`` suffix).
    """
    model, name, scene, logdir, split_on_no_edge = args
    out_path = os.path.join(logdir, "tracks", os.path.basename(name))

    graph_path = name + '-%s-eval_graph' % model.feature_name
    graph, _detection_weight_features, _connection_batch = torch.load(graph_path)
    promote_graph(graph)

    tracks, _gt_graph_frames = ground_truth_tracks(
        scene.ground_truth(), graph, add_gt_class=True)
    if split_on_no_edge:
        tracks = split_track_on_missing_edge(tracks)

    length_histogram = defaultdict(int)
    for track in tracks:
        for detection in track:
            # Detections without a class label are stripped completely.
            if not hasattr(detection, 'cls'):
                detection.__dict__ = {}
                continue
            label = detection.cls
            gt_label = getattr(detection, 'gt_cls', None)
            detection.__dict__ = {}
            detection.cls = label
            if gt_label is not None:
                detection.gt_cls = gt_label
        length_histogram[len(track)] += 1

    save_pickle(tracks, out_path)
    save_pickle({'track_length': length_histogram}, out_path + '-stats')
    return out_path
def __getitem__(self, item):
    """Load graph ``item`` and annotate every detection with training targets.

    ``self.entries[item]`` is a ``(name, cam)`` pair.  The graph tuple is read
    from ``name + self.suffix`` and promoted; ground-truth tracks for
    ``self.dataset.scene(cam)`` are matched to it and split wherever an edge is
    missing.  Each detection then receives:

    * ``gt_present`` -- 1.0 if ``track_id`` is not None, else 0.0
    * ``gt_entry``   -- 1.0 for the first detection of a ground-truth track,
      else 0.0
    * ``gt_next``    -- one float per entry of ``det.next``; 1.0 marks the
      edge to the following detection on the same ground-truth track

    The graph is demoted again before being returned.

    Returns ``(graph, detection_weight_features, connection_batch)``.
    """
    name, cam = self.entries[item]
    scene = self.dataset.scene(cam)
    graph, detection_weight_features, connection_batch = torch.load(name + self.suffix)
    promote_graph(graph)

    gt_tracks, _ = ground_truth_tracks(scene.ground_truth(), graph)
    gt_tracks = split_track_on_missing_edge(gt_tracks)

    # Default targets: nothing enters, nothing is linked.
    for detection in graph:
        detection.gt_entry = 0.0
        detection.gt_present = 1.0 if detection.track_id is not None else 0.0
        detection.gt_next = len(detection.next) * [0.0]

    for track in gt_tracks:
        track[0].gt_entry = 1.0
        # Mark the outgoing edge between each consecutive pair on the track.
        for source, target in zip(track, track[1:]):
            source.gt_next[source.next.index(target)] = 1.0

    demote_graph(graph)
    return graph, detection_weight_features, connection_batch
def worker(work):
    """Solve the LP tracker for one graph using precomputed weights.

    ``work`` is the tuple ``(graph, connection_weights, detection_weights,
    entry_weight, connection_batch, detection_weight_features)``.  The graph is
    promoted, passed to ``lp_track_weights`` with ``add_gt_hamming=True``, and
    demoted again before being handed back.

    Returns ``(graph, connection_batch, detection_weight_features)``.
    """
    (graph, conn_weights, det_weights, entry_w,
     conn_batch, det_features) = work
    promote_graph(graph)
    lp_track_weights(graph, conn_weights, det_weights, entry_w,
                     add_gt_hamming=True)
    demote_graph(graph)
    return graph, conn_batch, det_features
def test_demote_promote(self):
    """A promoted graph must survive a demote -> promote round trip unchanged.

    Loads a pickled promoted graph, keeps a deep copy as the reference, demotes
    and re-promotes the original in place, then checks that each detection's
    ``prev`` and ``next_weight_data`` compare equal to, and have exactly the
    same type as, the reference copy.
    """
    g1 = load_pickle(os.path.join(mydir, "data", "promoted_graph.pck"))
    g2 = deepcopy(g1)
    demote_graph(g1)
    promote_graph(g1)
    # Without this, detections present in only one of the graphs would go
    # unchecked (zip stops at the shorter sequence).
    assert len(g1) == len(g2)
    for i, (restored, expected) in enumerate(zip(g1, g2)):
        assert restored.prev == expected.prev, i
        assert restored.next_weight_data == expected.next_weight_data, i
        assert type(restored.prev) is type(expected.prev), i
        assert type(restored.next_weight_data) is type(expected.next_weight_data), i
def main():
    """Report how often long ground-truth connections fall inside their bound.

    Loads the newest ``snapshot_???.pyt`` from the Duke dataset's logdir into
    an ``NNModelGraphresPerConnection`` and scores every evaluation graph.  A
    "long connection" is a graph edge between two detections that share a
    non-None ``track_id`` and are more than one frame apart.  For each such
    edge, ``get_upper_bound_from_gt`` is queried.  The function counts:

    * ``gt_not_in_graph`` -- edges for which it returned None
    * ``long_connections_within_bound`` -- edges whose model weight lies
      strictly between 0 and the returned bound

    It prints the three counters and the fraction of bounded long connections
    that are within bound.

    Raises:
        FileNotFoundError: if the logdir contains no snapshot to load.
    """
    dataset = Duke('data', cachedir="cachedir")
    model = NNModelGraphresPerConnection()
    logdir = dataset.logdir
    print(logdir)
    snapshots = sorted(glob("%s/snapshot_???.pyt" % logdir))
    if not snapshots:
        raise FileNotFoundError("No snapshot_???.pyt files found in %s" % logdir)
    model.load_state_dict(torch.load(snapshots[-1])['model_state'])
    model.eval()

    gt_not_in_graph = long_connections = long_connections_within_bound = 0
    for name, cam in tqdm(graph_names(dataset, "eval"), "Estimating long structure"):
        # Graph paths were recorded with an absolute cluster prefix; strip it so
        # they resolve relative to the working directory.
        name = name.replace("/lunarc/nobackup/projects/lu-haar/ggdtrack/", "")
        graph, detection_weight_features, connection_batch = torch.load(
            name + '-%s-eval_graph' % model.feature_name)
        promote_graph(graph)
        connection_weights = model.connection_batch_forward(connection_batch)
        detection_weights = model.detection_model(detection_weight_features)

        scene = dataset.scene(cam)
        gt_tracks, _gt_graph_frames = ground_truth_tracks(
            scene.ground_truth(), graph)
        # Link each ground-truth detection to its successor on the same track
        # (zip also tolerates empty tracks, unlike indexing tr[0]).
        for tr in gt_tracks:
            for prv, det in zip(tr, tr[1:]):
                prv.gt_next = det

        for det in graph:
            for i, nxt in zip(det.weight_index, det.next):
                same_track = (nxt.track_id is not None
                              and det.track_id == nxt.track_id)
                if same_track and nxt.frame - det.frame > 1:
                    long_connections += 1
                    upper = get_upper_bound_from_gt(
                        det, nxt, connection_weights, detection_weights)
                    if upper is None:
                        gt_not_in_graph += 1
                    elif 0 < connection_weights[i] < upper:
                        long_connections_within_bound += 1

    print()
    print(gt_not_in_graph, long_connections, long_connections_within_bound)
    bounded = long_connections - gt_not_in_graph
    if bounded:
        print(long_connections_within_bound / bounded)
    else:
        # Avoid ZeroDivisionError when no long connection had a bound.
        print("n/a (no long connections with a ground-truth bound)")
    print()
def eval_hamming_worker(work):
    """Solve the LP tracker for one graph and count mismatches against GT.

    ``work`` is ``(graph, connection_weights, detection_weights,
    entry_weight)``.  After promoting the graph and running
    ``lp_track_weights`` with ``add_gt_hamming=True``, every detection
    contributes one mismatch each for:

    * ``present.value`` differing from ``gt_present``,
    * ``entry.value`` differing from ``gt_entry``, and
    * every outgoing edge variable whose value differs from its ``gt_next``
      target.

    Returns the total mismatch count (0 for an empty graph).
    """
    graph, connection_weights, detection_weights, entry_weight = work
    promote_graph(graph)
    lp_track_weights(graph, connection_weights, detection_weights,
                     entry_weight, add_gt_hamming=True)

    def _mismatch_flags(det):
        # Yields one boolean per decision variable of ``det``, in the order
        # present, entry, then each outgoing edge.
        assert len(det.next) == len(det.outgoing)
        yield det.present.value != det.gt_present
        yield det.entry.value != det.gt_entry
        for edge_var, target in zip(det.outgoing, det.gt_next):
            yield edge_var.value != target

    return sum(flag for det in graph for flag in _mismatch_flags(det))
def prep_eval_tracks_worker(args):
    """Track one evaluation graph with the LP tracker and pickle the result.

    ``args`` is packed into a single tuple ``(model, name, device, logdir)`` so
    the function can be mapped over a worker pool.  The graph stored at
    ``name + '-<model.feature_name>-eval_graph'`` is loaded and promoted, its
    feature tensors are moved to ``device``, and ``lp_track`` produces the
    tracks.

    Each detection in the output is reduced to just its ``cls`` attribute (or
    to no attributes at all when it has none) before the tracks are pickled
    to ``<logdir>/tracks/<basename(name)>``.

    Returns the path of the written pickle.
    """
    model, name, device, logdir = args
    out_path = os.path.join(logdir, "tracks", os.path.basename(name))

    graph, features, batch = torch.load(name + '-%s-eval_graph' % model.feature_name)
    promote_graph(graph)
    features = features.to(device)
    batch = batch.to(device)
    tracks = lp_track(graph, batch, features, model)

    for track in tracks:
        for detection in track:
            if not hasattr(detection, 'cls'):
                detection.__dict__ = {}
                continue
            label = detection.cls
            detection.__dict__ = {}
            detection.cls = label

    save_pickle(tracks, out_path)
    return out_path
def train_frossard(dataset, logdir, model, mean_from=None, device=default_torch_device, limit=None,
                   epochs=1000, resume_from=None, save_every=None):
    """Train ``model`` from LP tracking solutions compared with ground truth.

    For every training graph, the model scores connections and detections in
    eval mode.  ``lp_track_weights`` solves the tracking LP with those weights
    (``add_gt_hamming=True``).  The model is then switched to train mode and
    re-scored, and the loss ``sum((solution - ground_truth) * weight)`` over
    detection presence, track entry and outgoing edges is back-propagated with
    Adam (lr 1e-3).

    Args:
        dataset: dataset passed to ``graph_names(dataset, "train")``.
        logdir: output directory for ``snapshot_%.3d.pyt`` files (one per
            epoch) and optional ``model_%.4d.pyt`` files.  It is wiped and
            recreated unless it is the directory being resumed from.
        model: model to train.  When not resuming, its parameters are
            re-initialised from N(0, 1e-3) and the mean/std normalisation
            tensors are copied from ``mean_from``.
        mean_from: snapshot path whose ``model_state`` provides mean/std.
        device: torch device used for the forward/backward passes.
        limit: if set, train on a random subset of this many graphs.
        epochs: number of epochs to run, counted from the start epoch.
        resume_from: snapshot file, or directory whose newest
            ``snapshot_???.pyt`` is loaded, to continue training from.
        save_every: if set, additionally save ``model.state_dict()`` whenever
            more than this many seconds have passed since the last such save.

    Raises:
        NotImplementedError: if neither ``mean_from`` nor ``resume_from`` is
            given.

    Note:
        Epoch boundaries count processed non-empty graphs, so empty graphs in
        the training set shift where an epoch ends.
    """
    if mean_from is None and resume_from is None:
        raise NotImplementedError
    model.to(device)
    optimizer = optim.Adam(model.parameters(), 1e-3)

    if resume_from is not None:
        if os.path.isdir(resume_from):
            fn = sorted(glob("%s/snapshot_???.pyt" % (resume_from)))[-1]
        else:
            fn = resume_from
        print("Resuming from", fn)
        snapshot = torch.load(fn)
        if isinstance(snapshot, dict) and 'model_state' in snapshot:
            model.load_state_dict(snapshot['model_state'])
            # NOTE: optimizer state is deliberately not restored:
            # optimizer.load_state_dict(snapshot['optimizer_state'])
            start_epoch = snapshot['epoch'] + 1
        else:
            model.load_state_dict(snapshot)
            start_epoch = 0
        model.to(device)
    else:
        start_epoch = 0
        for t in model.parameters():
            torch.nn.init.normal_(t, 0, 1e-3)
        mean_model = model.__class__()
        mean_model.load_state_dict(torch.load(mean_from)['model_state'])
        model.detection_model.mean = mean_model.detection_model.mean
        model.detection_model.std = mean_model.detection_model.std
        model.edge_model.klt_model.mean = mean_model.edge_model.klt_model.mean
        model.edge_model.klt_model.std = mean_model.edge_model.klt_model.std
        model.edge_model.long_model.mean = mean_model.edge_model.long_model.mean
        model.edge_model.long_model.std = mean_model.edge_model.long_model.std
        model.to(device)

    # Only start from a clean logdir when we are not resuming into it.  Paths
    # are normalised so "run/" == "run", and resuming from a snapshot *file*
    # inside logdir does not wipe the run being resumed.
    resume_dir = None
    if resume_from is not None:
        resume_dir = resume_from if os.path.isdir(resume_from) else os.path.dirname(resume_from)
    if resume_dir is None or os.path.realpath(resume_dir) != os.path.realpath(logdir):
        if os.path.exists(logdir):
            rmtree(logdir)
        os.makedirs(logdir)

    entries = graph_names(dataset, "train")
    if limit is not None:
        shuffle(entries)
        entries = entries[:limit]

    # DataLoader rejects negative num_workers; on <= 2 CPUs fall back to
    # loading in the main process.
    threads = max(0, multiprocessing.cpu_count() - 2)
    train_data = EvalGtGraphs(dataset, entries, '-%s-eval_graph' % model.feature_name)
    train_loader = DataLoader(train_data, 1, True, collate_fn=single_example_passthrough,
                              num_workers=threads)

    def worker(work):
        # Runs the LP in a pool worker; the graph is demoted before returning
        # and must be promoted again by the consumer.
        graph, connection_weights, detection_weights, entry_weight, connection_batch, detection_weight_features = work
        promote_graph(graph)
        lp_track_weights(graph, connection_weights, detection_weights, entry_weight,
                         add_gt_hamming=True)
        demote_graph(graph)
        return graph, connection_batch, detection_weight_features

    # To solve the LPs asynchronously, use:
    # lp_tracker_pool = WorkerPool(threads, worker)
    lp_tracker_pool = None

    save_count = 0
    last_save = time.time()
    epoch_hamming_distance = batch_count = 0
    epoch = start_epoch
    while True:
        for graph, detection_weight_features, connection_batch in train_loader:
            # Score with the current model and solve the tracking LP.
            model.eval()
            connection_weights = model.connection_batch_forward(
                connection_batch.to(device))
            detection_weights = model.detection_model(
                detection_weight_features.to(device))
            if lp_tracker_pool is not None:
                lp_tracker_pool.put(
                    (graph, connection_weights.detach().cpu(), detection_weights.detach().cpu(),
                     model.entry_weight, connection_batch, detection_weight_features))
            else:
                promote_graph(graph)
                lp_track_weights(graph, connection_weights, detection_weights, model.entry_weight,
                                 add_gt_hamming=True)

            # Consume solved graphs: exactly one without a pool, or every
            # result currently available from the pool.
            while True:
                if lp_tracker_pool is not None:
                    try:
                        graph, connection_batch, detection_weight_features = lp_tracker_pool.get(
                            block=False)
                    except Empty:
                        break
                    promote_graph(graph)
                if not graph:
                    # Nothing to learn from an empty graph.  Without a pool there
                    # is no further result to fetch, so `continue` would spin
                    # forever on this same graph.
                    if lp_tracker_pool is None:
                        break
                    continue

                # End of epoch: snapshot, reset counters, and re-process the
                # current graph as the first batch of the next epoch.
                if batch_count >= len(train_data):
                    snapp = {
                        'epoch': epoch,
                        'model_state': model.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'train_hamming': epoch_hamming_distance,
                    }
                    torch.save(snapp, os.path.join(logdir, "snapshot_%.3d.pyt" % epoch))
                    print('%3d Hamming:' % epoch, epoch_hamming_distance)
                    epoch_hamming_distance = batch_count = 0
                    epoch += 1
                    if epoch >= start_epoch + epochs:
                        return
                    continue
                batch_count += 1

                if save_every and time.time() - last_save > save_every:
                    last_save = time.time()
                    torch.save(model.state_dict(), os.path.join(logdir, "model_%.4d.pyt" % save_count))
                    save_count += 1

                # Gradient step: re-score in train mode and weight each
                # variable's (solution - ground truth) difference by its score.
                model.train()
                optimizer.zero_grad()
                hamming_distance_present = hamming_distance_entry = hamming_distance_connect = loss = 0
                connection_weights = model.connection_batch_forward(
                    connection_batch.to(device))
                detection_weights = model.detection_model(
                    detection_weight_features.to(device))
                for det in graph:
                    assert len(det.next) == len(det.outgoing)
                    loss += (det.present.value - det.gt_present) * detection_weights[det.index]
                    loss += (det.entry.value - det.gt_entry) * model.entry_weight_parameter
                    loss += sum(connection_weights[i] * (v.value - gt)
                                for v, i, gt in zip(det.outgoing, det.weight_index, det.gt_next))
                    hamming_distance_present += det.present.value != det.gt_present
                    hamming_distance_entry += det.entry.value != det.gt_entry
                    hamming_distance_connect += sum(
                        v.value != gt for v, gt in zip(det.outgoing, det.gt_next))
                epoch_hamming_distance += hamming_distance_present + hamming_distance_entry + hamming_distance_connect
                loss.backward()
                optimizer.step()

                if lp_tracker_pool is None:
                    break