Exemplo n.º 1
0
def prep_eval_gt_tracks_worker(args):
    """Build ground-truth tracks for one evaluation graph and pickle them.

    ``args`` is the tuple ``(model, name, scene, logdir, split_on_no_edge)``.
    The graph stored at ``name + '-<model.feature_name>-eval_graph'`` is
    loaded, promoted and matched against ``scene.ground_truth()`` (with
    ``add_gt_class=True``). If ``split_on_no_edge`` is true, the tracks are
    passed through ``split_track_on_missing_edge``.

    Every detection's attributes are reset so that only ``cls`` is kept.
    ``gt_cls`` is also kept, but only when ``cls`` exists and ``gt_cls`` is
    not None. The tracks are then pickled to
    ``<logdir>/tracks/<basename(name)>``. A histogram mapping track length
    to number of tracks is pickled next to it, under the ``-stats`` suffix
    as ``{'track_length': ...}``.

    Returns:
        The path the tracks were written to.
    """
    model, name, scene, logdir, split_on_no_edge = args
    ofn = os.path.join(logdir, "tracks", os.path.basename(name))

    graph, _detection_weight_features, _connection_batch = torch.load(
        name + '-%s-eval_graph' % model.feature_name)
    promote_graph(graph)
    tracks, _gt_graph_frames = ground_truth_tracks(scene.ground_truth(),
                                                   graph,
                                                   add_gt_class=True)
    if split_on_no_edge:
        tracks = split_track_on_missing_edge(tracks)

    # Histogram: track length -> number of tracks of that length.
    len_stats = defaultdict(int)
    for tr in tracks:
        # Count each track exactly once, not once per detection it contains.
        len_stats[len(tr)] += 1
        for det in tr:
            # Reset the detection to a bare object that only carries its
            # class labels before pickling.
            if hasattr(det, 'cls'):
                cls = det.cls
                gt_cls = det.gt_cls if hasattr(det, 'gt_cls') else None
                det.__dict__ = {}
                det.cls = cls
                if gt_cls is not None:
                    det.gt_cls = gt_cls
            else:
                det.__dict__ = {}

    save_pickle(tracks, ofn)
    save_pickle({'track_length': len_stats}, ofn + '-stats')

    return ofn
Exemplo n.º 2
0
    def __getitem__(self, item):
        """Load one stored graph and attach ground-truth training labels.

        The entry ``self.entries[item]`` is a ``(name, cam)`` pair. The graph
        is loaded from ``name + self.suffix`` and promoted. Its detections are
        then labelled from the scene's ground-truth tracks, after splitting
        those tracks wherever an edge is missing:

        * ``gt_present``: 1.0 if the detection has a ``track_id``, else 0.0.
        * ``gt_entry``: 1.0 only for the first detection of a track.
        * ``gt_next``: a list parallel to ``det.next``. It holds 1.0 at the
          position of the successor in the same ground-truth track and 0.0
          elsewhere.

        The graph is demoted again before it is returned.

        Returns:
            ``(graph, detection_weight_features, connection_batch)``.
        """
        name, cam = self.entries[item]
        scene = self.dataset.scene(cam)

        graph, detection_weight_features, connection_batch = torch.load(
            name + self.suffix)
        promote_graph(graph)

        gt_tracks, _ = ground_truth_tracks(scene.ground_truth(), graph)
        gt_tracks = split_track_on_missing_edge(gt_tracks)

        # Negative defaults for every detection.
        for node in graph:
            node.gt_entry = 0.0
            node.gt_present = 1.0 if node.track_id is not None else 0.0
            node.gt_next = [0.0 for _ in node.next]

        # Positive labels taken from the ground-truth tracks.
        for track in gt_tracks:
            track[0].gt_entry = 1.0
            for src, dst in zip(track, track[1:]):
                src.gt_next[src.next.index(dst)] = 1.0

        demote_graph(graph)
        return graph, detection_weight_features, connection_batch
Exemplo n.º 3
0
 def worker(work):
     """Solve the tracking LP for one queued work item.

     ``work`` is the 6-tuple ``(graph, connection_weights, detection_weights,
     entry_weight, connection_batch, detection_weight_features)``. The graph
     is promoted, passed to ``lp_track_weights`` with
     ``add_gt_hamming=True``, and demoted again.

     Returns:
         ``(graph, connection_batch, detection_weight_features)``, where the
         two batches are passed through unchanged.
     """
     (graph, connection_weights, detection_weights, entry_weight,
      connection_batch, detection_weight_features) = work

     promote_graph(graph)
     lp_track_weights(graph, connection_weights, detection_weights,
                      entry_weight, add_gt_hamming=True)
     demote_graph(graph)

     result = (graph, connection_batch, detection_weight_features)
     return result
Exemplo n.º 4
0
    def test_demote_promote(self):
        """Check that demote_graph followed by promote_graph restores a graph.

        A stored promoted graph is loaded and deep-copied as the reference.
        The original is round-tripped through ``demote_graph`` and
        ``promote_graph``. Every detection's ``prev`` and
        ``next_weight_data`` must then match the reference in both value
        and class.
        """
        g1 = load_pickle(os.path.join(mydir, "data", "promoted_graph.pck"))
        g2 = deepcopy(g1)
        demote_graph(g1)
        promote_graph(g1)

        # The round-trip must neither drop nor add detections. zip() below
        # would otherwise silently ignore any surplus entries.
        assert len(g1) == len(g2)
        for det, ref in zip(g1, g2):
            assert det.prev == ref.prev
            assert det.next_weight_data == ref.next_weight_data
            assert det.prev.__class__ is ref.prev.__class__
            assert (det.next_weight_data.__class__
                    is ref.next_weight_data.__class__)
Exemplo n.º 5
0
def main():
    """Report how often long ground-truth connections stay within the bound.

    Loads the last ``snapshot_???.pyt`` (in sorted order) from the Duke
    dataset's log directory into a ``NNModelGraphresPerConnection``. Then,
    for every evaluation graph, it considers each connection whose two
    detections share a non-None ``track_id`` and are more than one frame
    apart ("long" connections). For each one it calls
    ``get_upper_bound_from_gt`` and counts how often:

    * the bound is None (counted as not in the graph), or
    * the predicted weight lies strictly between 0 and the bound.

    Running totals and the within-bound ratio are printed after each graph.

    Raises:
        FileNotFoundError: If the log directory contains no snapshot.
    """
    dataset = Duke('data', cachedir="cachedir")  #_mc5")
    model = NNModelGraphresPerConnection()
    logdir = dataset.logdir
    print(logdir)
    snapshots = sorted(glob("%s/snapshot_???.pyt" % logdir))
    if not snapshots:
        raise FileNotFoundError("No snapshot_???.pyt found in %s" % logdir)
    fn = snapshots[-1]
    model.load_state_dict(torch.load(fn)['model_state'])
    model.eval()

    gt_not_in_graph = long_connections = long_connections_within_bound = 0

    for name, cam in tqdm(graph_names(dataset, "eval"),
                          "Estimating long structure"):
        # Strip a hard-coded absolute prefix so the graph path resolves
        # relative to the working directory. Presumably the prefix comes from
        # the machine the graph list was built on.
        name = name.replace("/lunarc/nobackup/projects/lu-haar/ggdtrack/", "")
        graph, detection_weight_features, connection_batch = torch.load(
            name + '-%s-eval_graph' % model.feature_name)
        promote_graph(graph)
        connection_weights = model.connection_batch_forward(connection_batch)
        detection_weights = model.detection_model(detection_weight_features)

        scene = dataset.scene(cam)
        gt_tracks, _ = ground_truth_tracks(scene.ground_truth(), graph)
        # Link every ground-truth detection to its successor in the track.
        # zip() also copes with empty tracks.
        for tr in gt_tracks:
            for prv, det in zip(tr, tr[1:]):
                prv.gt_next = det

        for det in graph:
            for i, nxt in zip(det.weight_index, det.next):
                same_track = (nxt.track_id is not None
                              and det.track_id == nxt.track_id)
                if same_track and nxt.frame - det.frame > 1:
                    long_connections += 1
                    upper = get_upper_bound_from_gt(det, nxt,
                                                    connection_weights,
                                                    detection_weights)
                    if upper is None:
                        gt_not_in_graph += 1
                    elif 0 < connection_weights[i] < upper:
                        long_connections_within_bound += 1

        considered = long_connections - gt_not_in_graph
        print()
        print(gt_not_in_graph, long_connections, long_connections_within_bound)
        # Until a long connection with a bound has been seen the ratio is
        # undefined; print nan instead of raising ZeroDivisionError.
        print(long_connections_within_bound / considered
              if considered else float('nan'))
        print()
Exemplo n.º 6
0
def eval_hamming_worker(work):
    """Solve one graph and count how many LP variables differ from ground truth.

    ``work`` is ``(graph, connection_weights, detection_weights,
    entry_weight)``. The graph is promoted and passed to
    ``lp_track_weights`` with ``add_gt_hamming=True``.

    The mismatch count sums, over every detection:

    * the ``present`` variable against ``gt_present``,
    * the ``entry`` variable against ``gt_entry``,
    * each ``outgoing`` variable against the matching ``gt_next`` entry.

    Returns:
        The total number of mismatches.
    """
    graph, connection_weights, detection_weights, entry_weight = work
    promote_graph(graph)
    lp_track_weights(graph, connection_weights, detection_weights,
                     entry_weight, add_gt_hamming=True)

    distance = 0
    for node in graph:
        assert len(node.next) == len(node.outgoing)
        distance += node.present.value != node.gt_present
        distance += node.entry.value != node.gt_entry
        edge_errors = sum(edge.value != target
                          for edge, target in zip(node.outgoing, node.gt_next))
        distance += edge_errors
    return distance
Exemplo n.º 7
0
def prep_eval_tracks_worker(args):
    """Track one evaluation graph with ``model`` and pickle the result.

    ``args`` is ``(model, name, device, logdir)``. The graph stored at
    ``name + '-<model.feature_name>-eval_graph'`` is loaded and promoted.
    Its feature batches are moved to ``device`` and ``lp_track`` produces
    the tracks. Each tracked detection is then reset to carry only ``cls``
    (when it had one) before the tracks are pickled to
    ``<logdir>/tracks/<basename(name)>``.

    Returns:
        The path the tracks were written to.
    """
    model, name, device, logdir = args
    out_path = os.path.join(logdir, "tracks", os.path.basename(name))
    graph_path = name + '-%s-eval_graph' % model.feature_name

    graph, det_features, conn_batch = torch.load(graph_path)
    promote_graph(graph)
    det_features = det_features.to(device)
    conn_batch = conn_batch.to(device)
    tracks = lp_track(graph, conn_batch, det_features, model)

    absent = object()  # sentinel: "detection has no cls attribute"
    for track in tracks:
        for det in track:
            label = getattr(det, 'cls', absent)
            det.__dict__ = {}
            if label is not absent:
                det.cls = label
    save_pickle(tracks, out_path)

    return out_path
Exemplo n.º 8
0
def train_frossard(dataset,
                   logdir,
                   model,
                   mean_from=None,
                   device=default_torch_device,
                   limit=None,
                   epochs=1000,
                   resume_from=None,
                   save_every=None):
    """Train ``model`` by comparing LP tracking solutions with ground truth.

    For every training graph the current model scores detections and
    connections, and ``lp_track_weights(..., add_gt_hamming=True)`` solves
    the tracking LP. The loss is then formed from the presence, entry and
    connection variables: each ``(solution_value - ground_truth)`` term is
    multiplied by the corresponding model weight, and the terms are summed.
    The loss is back-propagated with Adam (lr 1e-3). The summed Hamming
    distance between solution and ground truth is printed once per epoch.

    Args:
        dataset: Dataset handed to ``graph_names`` and ``EvalGtGraphs``.
        logdir: Output directory for per-epoch ``snapshot_NNN.pyt`` files
            and timed ``model_NNNN.pyt`` files. It is wiped first unless
            training resumes from that directory or from a snapshot file
            stored in it.
        model: Model to train; moved to ``device``.
        mean_from: Snapshot whose ``mean``/``std`` attributes are copied
            into the freshly initialised model. Required unless
            ``resume_from`` is given.
        device: Device used for the model's forward passes.
        limit: If given, train on a random subset of this many graphs.
        epochs: Number of epochs to run, counted from the start epoch.
        resume_from: Snapshot file, or a directory whose last
            ``snapshot_???.pyt`` (in sorted order) is loaded.
        save_every: If truthy, additionally save the model state every
            ``save_every`` seconds.

    Raises:
        NotImplementedError: If neither ``mean_from`` nor ``resume_from`` is
            given.
        FileNotFoundError: If ``resume_from`` is a directory without any
            snapshot.
    """

    if mean_from is None and resume_from is None:
        raise NotImplementedError(
            "train_frossard needs either mean_from or resume_from")

    model.to(device)
    optimizer = optim.Adam(model.parameters(), 1e-3)

    if resume_from is not None:
        if os.path.isdir(resume_from):
            snapshots = sorted(glob("%s/snapshot_???.pyt" % (resume_from)))
            if not snapshots:
                raise FileNotFoundError(
                    "No snapshot_???.pyt found in %s" % resume_from)
            fn = snapshots[-1]
        else:
            fn = resume_from
        print("Resuming from", fn)
        snapshot = torch.load(fn)
        if isinstance(snapshot, dict) and 'model_state' in snapshot:
            model.load_state_dict(snapshot['model_state'])
            # optimizer.load_state_dict(snapshot['optimizer_state'])
            start_epoch = snapshot['epoch'] + 1
        else:
            model.load_state_dict(snapshot)
            start_epoch = 0
        model.to(device)
    else:
        start_epoch = 0
        for t in model.parameters():
            torch.nn.init.normal_(t, 0, 1e-3)

        # Reuse the input mean/std statistics of a previously trained model.
        mean_model = model.__class__()
        mean_model.load_state_dict(torch.load(mean_from)['model_state'])
        model.detection_model.mean = mean_model.detection_model.mean
        model.detection_model.std = mean_model.detection_model.std
        model.edge_model.klt_model.mean = mean_model.edge_model.klt_model.mean
        model.edge_model.klt_model.std = mean_model.edge_model.klt_model.std
        model.edge_model.long_model.mean = mean_model.edge_model.long_model.mean
        model.edge_model.long_model.std = mean_model.edge_model.long_model.std
        model.to(device)

    # Only start from an empty logdir when we are not resuming into it.
    # Paths are compared after normalisation, so "run/" and "run" match.
    # Resuming from a snapshot file stored inside logdir also counts, which
    # keeps the snapshot we just loaded from being deleted.
    resuming_in_place = False
    if resume_from is not None:
        resume_dir = (resume_from if os.path.isdir(resume_from)
                      else os.path.dirname(resume_from))
        resuming_in_place = os.path.abspath(logdir) in (
            os.path.abspath(resume_from), os.path.abspath(resume_dir))
    if not resuming_in_place:
        if os.path.exists(logdir):
            rmtree(logdir)
        os.makedirs(logdir)

    entries = graph_names(dataset, "train")
    if limit is not None:
        shuffle(entries)
        entries = entries[:limit]

    # cpu_count() - 2 is negative on single-core machines, and DataLoader
    # rejects a negative num_workers.
    threads = max(0, multiprocessing.cpu_count() - 2)

    train_data = EvalGtGraphs(dataset, entries,
                              '-%s-eval_graph' % model.feature_name)
    train_loader = DataLoader(train_data,
                              1,
                              True,
                              collate_fn=single_example_passthrough,
                              num_workers=threads)

    def worker(work):
        # Task for the (currently disabled) WorkerPool: solve the LP for one
        # graph and hand back what the training step needs.
        graph, connection_weights, detection_weights, entry_weight, connection_batch, detection_weight_features = work
        promote_graph(graph)
        lp_track_weights(graph,
                         connection_weights,
                         detection_weights,
                         entry_weight,
                         add_gt_hamming=True)
        demote_graph(graph)
        return graph, connection_batch, detection_weight_features

    # lp_tracker_pool = WorkerPool(threads, worker)
    lp_tracker_pool = None

    save_count = 0
    last_save = time.time()
    epoch_hamming_distance = batch_count = 0
    epoch = start_epoch
    while True:
        for graph, detection_weight_features, connection_batch in train_loader:
            # Score the graph with the current model and solve the LP.
            model.eval()
            connection_weights = model.connection_batch_forward(
                connection_batch.to(device))
            detection_weights = model.detection_model(
                detection_weight_features.to(device))

            if lp_tracker_pool is not None:
                lp_tracker_pool.put(
                    (graph, connection_weights.detach().cpu(),
                     detection_weights.detach().cpu(), model.entry_weight,
                     connection_batch, detection_weight_features))
            else:
                promote_graph(graph)
                lp_track_weights(graph,
                                 connection_weights,
                                 detection_weights,
                                 model.entry_weight,
                                 add_gt_hamming=True)

            # Train on every solved graph that is available: with a pool,
            # drain its finished results; without one, handle the current
            # graph once.
            while True:
                if lp_tracker_pool is not None:
                    try:
                        graph, connection_batch, detection_weight_features = lp_tracker_pool.get(
                            block=False)
                    except Empty:
                        break
                    promote_graph(graph)
                if not graph:
                    # Nothing to learn from an empty graph. Without a pool
                    # there is no other graph to fetch, so looping again
                    # would spin forever on this one.
                    if lp_tracker_pool is None:
                        break
                    continue

                if batch_count >= len(train_data):
                    # A full epoch has been processed: snapshot, report and
                    # start the next epoch. The current graph is then
                    # trained on below as the first graph of the new epoch.
                    snapp = {
                        'epoch': epoch,
                        'model_state': model.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'train_hamming': epoch_hamming_distance,
                    }
                    torch.save(
                        snapp, os.path.join(logdir,
                                            "snapshot_%.3d.pyt" % epoch))
                    print('%3d Hamming:' % epoch, epoch_hamming_distance)
                    epoch_hamming_distance = batch_count = 0
                    epoch += 1
                    if epoch >= start_epoch + epochs:
                        return
                batch_count += 1

                if save_every and time.time() - last_save > save_every:
                    last_save = time.time()
                    torch.save(
                        model.state_dict(),
                        os.path.join(logdir, "model_%.4d.pyt" % save_count))
                    save_count += 1

                model.train()
                optimizer.zero_grad()

                hamming_distance_present = hamming_distance_entry = hamming_distance_connect = loss = 0
                connection_weights = model.connection_batch_forward(
                    connection_batch.to(device))
                detection_weights = model.detection_model(
                    detection_weight_features.to(device))
                for det in graph:
                    assert len(det.next) == len(det.outgoing)
                    # Each term is (LP solution - ground truth) * model weight.
                    loss += (det.present.value -
                             det.gt_present) * detection_weights[det.index]
                    loss += (det.entry.value -
                             det.gt_entry) * model.entry_weight_parameter
                    loss += sum(
                        connection_weights[i] * (v.value - gt)
                        for v, i, gt in zip(det.outgoing, det.weight_index,
                                            det.gt_next))
                    hamming_distance_present += det.present.value != det.gt_present
                    hamming_distance_entry += det.entry.value != det.gt_entry
                    hamming_distance_connect += sum(
                        v.value != gt
                        for v, gt in zip(det.outgoing, det.gt_next))
                epoch_hamming_distance += hamming_distance_present + hamming_distance_entry + hamming_distance_connect

                loss.backward()
                optimizer.step()

                if lp_tracker_pool is None:
                    break