def test_ggd_batches(self):
    """Batched GGD scoring must agree with per-example scoring.

    Loads a stored list of graph results and a model snapshot from the
    ``data`` directory next to this file. For every example in the largest
    prefix that is a whole number of batches, it records the scalar
    ``model(pos) - model(neg)`` and also appends the example to a
    ``GraphDiffList``. It then rebuilds batches of ``batch_size`` entries from
    that list via ``make_ggd_batch`` and asserts that
    ``model.ggd_batch_forward`` reproduces each recorded score to within 1e-3.
    """
    graph_path = os.path.join(mydir, "data", "basic-duke_graph_3_00190415.pck")
    graphres = torch.load(graph_path)

    with TemporaryDirectory() as workdir:
        model = NNModelGraphresPerConnection()
        snapshot_path = os.path.join(mydir, "data", "snapshot_009.pyt")
        model.load_state_dict(torch.load(snapshot_path)['model_state'])
        model.eval()

        diff_list = GraphDiffList(workdir, model)
        batch_size = 4
        # Only whole batches are compared; any trailing remainder is ignored.
        usable = len(graphres) - len(graphres) % batch_size

        expected = []
        for idx in range(usable):
            example = graphres[idx]
            expected.append((model(example.pos) - model(example.neg)).item())
            diff_list.append(example)

        for start in range(0, usable, batch_size):
            stop = start + batch_size
            batch = make_ggd_batch([diff_list[k] for k in range(start, stop)])
            scores = model.ggd_batch_forward(batch)
            for offset, want in enumerate(expected[start:stop]):
                assert abs(scores[offset].item() - want) < 1e-3
def main():
    """Estimate how often long ground-truth connections lie within their bound.

    Loads the most recent ``snapshot_???.pyt`` from the Duke dataset's log
    directory into an ``NNModelGraphresPerConnection`` and walks every
    evaluation graph. A "long connection" is an edge between two detections
    that share a (non-None) ground-truth ``track_id`` and are more than one
    frame apart. For each long connection, ``get_upper_bound_from_gt`` is
    queried:

    * ``None``  -> counted in ``gt_not_in_graph``;
    * otherwise -> counted as within bound when the connection weight lies
      strictly between 0 and the returned upper bound.

    Prints the three counters followed by the fraction of in-graph long
    connections that are within bound (``nan`` if there are none).

    Raises:
        FileNotFoundError: if the log directory holds no snapshot files.
    """
    # NOTE(review): an earlier variant apparently used cachedir="cachedir_mc5".
    dataset = Duke('data', cachedir="cachedir")
    model = NNModelGraphresPerConnection()
    logdir = dataset.logdir
    print(logdir)

    snapshots = sorted(glob("%s/snapshot_???.pyt" % logdir))
    if not snapshots:
        raise FileNotFoundError("No snapshot_???.pyt files found in %s" % logdir)
    # Zero-padded names sort lexicographically, so the last one is the newest.
    model.load_state_dict(torch.load(snapshots[-1])['model_state'])
    model.eval()

    gt_not_in_graph = long_connections = long_connections_within_bound = 0
    for name, cam in tqdm(graph_names(dataset, "eval"), "Estimating long structure"):
        # Strip a hard-coded cluster prefix from the stored graph name
        # (presumably so the path resolves relative to the working directory).
        name = name.replace("/lunarc/nobackup/projects/lu-haar/ggdtrack/", "")
        graph, detection_weight_features, connection_batch = torch.load(
            name + '-%s-eval_graph' % model.feature_name)
        promote_graph(graph)
        connection_weights = model.connection_batch_forward(connection_batch)
        detection_weights = model.detection_model(detection_weight_features)

        scene = dataset.scene(cam)
        gt_tracks, _ = ground_truth_tracks(scene.ground_truth(), graph)
        # Chain consecutive detections of each ground-truth track via
        # ``gt_next`` (presumably read by get_upper_bound_from_gt).
        for tr in gt_tracks:
            for prv, det in zip(tr, tr[1:]):
                prv.gt_next = det

        for det in graph:
            for i, nxt in zip(det.weight_index, det.next):
                same_track = nxt.track_id is not None and det.track_id == nxt.track_id
                if not same_track or nxt.frame - det.frame <= 1:
                    continue
                long_connections += 1
                upper = get_upper_bound_from_gt(det, nxt, connection_weights, detection_weights)
                if upper is None:
                    gt_not_in_graph += 1
                elif 0 < connection_weights[i] < upper:
                    long_connections_within_bound += 1

    print()
    print(gt_not_in_graph, long_connections, long_connections_within_bound)
    in_graph = long_connections - gt_not_in_graph
    # Avoid ZeroDivisionError when no long connection is present in the graph.
    print(long_connections_within_bound / in_graph if in_graph else float("nan"))
    print()