Ejemplo n.º 1
0
def classify_db_all(conf,
                    db_file,
                    model_files,
                    model_type,
                    name='deepnet',
                    distort=False,
                    return_hm=False,
                    hm_dec=1,
                    hm_floor=0.1,
                    hm_nclustermax=1):
    """Classify every record of ``db_file`` with each model in ``model_files``.

    For each model a fresh ``multiResData.tf_reader`` is opened on ``db_file``
    (no shuffling, batch size 1), a predictor is built with
    ``apt.get_pred_fn`` and the DB is classified with ``apt.classify_db``.

    Args:
        conf: Configuration object forwarded to the reader, predictor and
            classifier.
        db_file: Path of the tfrecords file to classify.
        model_files: Model paths. For nets other than ``'leap'`` and
            ``'openpose'`` the file ``<model>.index`` must exist (its mtime is
            recorded).
        model_type: Network type string; ``'mdn'`` additionally keeps the
            extra values returned by ``apt.classify_db``.
        name, distort: Forwarded to ``apt.get_pred_fn``.
        return_hm, hm_dec, hm_floor, hm_nclustermax: Forwarded to
            ``apt.classify_db``.

    Returns:
        One list per model file:
        ``[pred, label, gt_list, model_file, extra_stuff, mtime]``.
        ``gt_list`` is converted to a NumPy array. ``extra_stuff`` is
        ``ret_list[3:]`` for ``'mdn'`` models and ``0`` otherwise. ``mtime``
        is the modification time of the model file (or of its ``.index``
        file).
    """
    # Only leap/openpose models are stat'ed directly; for the other nets the
    # timestamp comes from "<model>.index".
    extra_str = '' if model_type in ['leap', 'openpose'] else '.index'
    # Stat all model files up front so a missing file fails before any
    # (expensive) classification starts.
    ts = [os.path.getmtime(f + extra_str) for f in model_files]

    cur_out = []
    for m, model_ts in zip(model_files, ts):
        tf_iterator = multiResData.tf_reader(conf, db_file, False)
        tf_iterator.batch_size = 1
        read_fn = tf_iterator.next
        pred_fn, close_fn, _ = apt.get_pred_fn(model_type,
                                               conf,
                                               m,
                                               name=name,
                                               distort=distort)
        try:
            ret_list = apt.classify_db(conf,
                                       read_fn,
                                       pred_fn,
                                       tf_iterator.N,
                                       return_hm=return_hm,
                                       hm_dec=hm_dec,
                                       hm_floor=hm_floor,
                                       hm_nclustermax=hm_nclustermax)
        finally:
            # Release the predictor even if classification raises.
            close_fn()

        pred, label, gt_list = ret_list[:3]
        extra_stuff = ret_list[3:] if model_type == 'mdn' else 0
        gt_list = np.array(gt_list)
        cur_out.append([pred, label, gt_list, m, extra_stuff, model_ts])

    return cur_out
Ejemplo n.º 2
0
Archivo: nbHG.py Proyecto: tsizemo2/APT
def predsingle(lbl_file, cachedir, view, exp_name, net, model_file_short,
               data_type):
    """Return raw predictor output for the first record of the validation DB.

    Builds a configuration with ``apt.create_conf`` and opens
    ``<conf.cachedir>/val_TF.tfrecords`` (no shuffling, batch size 1). It then
    reads a single record and runs it through the predictor from
    ``op.get_pred_fn(..., rawpred=True)``.

    NOTE(review): the predictor always comes from ``op.get_pred_fn``,
    whatever ``net`` is; only the openpose path configures an affinity graph.

    Args:
        lbl_file: Label file passed to ``apt.create_conf``.
        cachedir: Cache directory passed to ``apt.create_conf``.
        view: View index.
        exp_name: Experiment name.
        net: Network type string.
        model_file_short: Model file name relative to ``conf.cachedir``.
        data_type: Dataset tag; for ``net == 'openpose'`` only ``'bub'`` is
            supported.

    Returns:
        ``(predmaps, im, locs, info)``: the predictor output and the image,
        locations and info of the record that was read.

    Raises:
        AssertionError: if ``net == 'openpose'`` and ``data_type`` has no
            affinity graph defined.
    """
    conf_pvlist = None
    if net == 'openpose':
        if data_type == 'bub':
            conf_pvlist = ['op_affinity_graph', op_af_graph_bub_noslash]
        else:
            # Explicit raise (not ``assert False``) so the check survives
            # ``python -O``; same exception type as before.
            raise AssertionError("define aff graph")

    conf = apt.create_conf(lbl_file,
                           view,
                           exp_name,
                           cachedir,
                           net,
                           conf_params=conf_pvlist)
    db_file = os.path.join(conf.cachedir, 'val_TF.tfrecords')
    model_file = os.path.join(conf.cachedir, model_file_short)

    tf_iterator = multiResData.tf_reader(conf, db_file, False)
    tf_iterator.batch_size = 1
    read_fn = tf_iterator.next
    pred_fn, close_fn, _ = op.get_pred_fn(conf,
                                          model_file,
                                          name=None,
                                          rawpred=True)
    try:
        im, locs, info, _ = read_fn()
        # Function-call form works under both Python 2 and Python 3.
        print("im.shape is {}".format(im.shape))
        predmaps = pred_fn(im)
    finally:
        # Always release the predictor, even if reading/prediction fails.
        close_fn()

    return predmaps, im, locs, info
Ejemplo n.º 3
0
# Step numbers parsed from checkpoint file names of the form "...-<digits>...".
aa = [int(re.search(r'-(\d*)', f).group(1)) for f in files]
# Successive differences of the step numbers; a negative difference means
# the counter went backwards (presumably a restarted training run).
aa = [b - a for a, b in zip(aa[:-1], aa[1:])]
if any(a < 0 for a in aa):
    # Keep only the files after the LAST backwards jump. Selecting [-1]
    # explicitly also avoids int() on a multi-element array, which raises
    # TypeError when several jumps are present.
    bb = int(np.where(np.array(aa) < 0)[0][-1]) + 1
    files = files[bb:]

# Subsample to at most n_max evenly spaced checkpoints.
n_max = 6
if len(files) > n_max:
    gg = len(files)
    sel = np.linspace(0, len(files) - 1, n_max).astype('int')
    files = [files[s] for s in sel]

out_file = os.path.join(conf.cachedir, train_name + '_results.p')
# Model prefixes: file names with the '.index' suffix removed.
afiles = [f.replace('.index', '') for f in files]

# Only the most recent selected checkpoint is evaluated.
for m in afiles[-1:]:
    tf_iterator = multiResData.tf_reader(conf, gt_file, False)
    tf_iterator.batch_size = 1
    read_fn = tf_iterator.next
    pred_fn, close_fn, _ = apt.get_pred_fn(train_type,
                                           conf,
                                           m,
                                           name=train_name)
    bsize = conf.batch_size
    all_f = np.zeros((bsize, ) + conf.imsz + (conf.img_dim, ))
    n = tf_iterator.N
    # Per-record output buffers: n records x n_classes points x (x, y).
    pred_locs = np.zeros([n, conf.n_classes, 2])
    unet_locs = np.zeros([n, conf.n_classes, 2])
    mdn_locs = np.zeros([n, conf.n_classes, 2])
    n_batches = int(math.ceil(float(n) / bsize))
    labeled_locs = np.zeros([n, conf.n_classes, 2])
    all_ims = np.zeros([n, conf.imsz[0], conf.imsz[1], conf.img_dim])
Ejemplo n.º 4
0
    def read_image_thread(self, sess, db_type, distort, shuffle, scale):
        """Read, preprocess and enqueue batches until told to stop.

        Selects the tfrecords file: the validation or training file when
        ``self.train_type == 0``, otherwise always the training file. Then it
        loops until ``self.coord.should_stop()``. Each iteration:

        1. reads the next batch;
        2. preprocesses it with ``PoseTools.preprocess_ims``;
        3. passes the batch dict through every function in ``self.q_fns``;
        4. feeds it via ``self.q_placeholders`` (an iterable of
           ``(key, placeholder)`` pairs) to the val/train enqueue op.

        Enqueue attempts use a 30 s timeout and are retried until they
        succeed, the session is closed, or the coordinator requests a stop.

        Args:
            sess: Session used to run the enqueue op.
            db_type: ``self.DBType.Val`` or ``self.DBType.Train``.
            distort: Forwarded to ``PoseTools.preprocess_ims``.
            shuffle: Forwarded to ``multiResData.tf_reader``.
            scale: Forwarded to ``PoseTools.preprocess_ims``.

        Raises:
            IOError: if ``self.train_type == 0`` and ``db_type`` is neither
                ``DBType.Val`` nor ``DBType.Train``.
        """
        if self.train_type == 0:
            if db_type == self.DBType.Val:
                db_name = self.conf.valfilename
            elif db_type == self.DBType.Train:
                db_name = self.conf.trainfilename
            else:
                # Equivalent to the former raise_(IOError, ..., exc_info()[2]),
                # whose traceback is None outside an except block.
                raise IOError("Unspecified DB Type")
        else:
            db_name = self.conf.trainfilename
        filename = os.path.join(self.conf.cachedir, db_name) + '.tfrecords'

        cur_db = multiResData.tf_reader(self.conf, filename, shuffle)
        placeholders = self.q_placeholders
        # Loop-invariant: a bounded timeout lets the retry loop below re-check
        # for shutdown instead of blocking indefinitely inside sess.run.
        run_options = tf.RunOptions(timeout_in_ms=30000)

        print('Starting preloading thread of type ... {}'.format(db_type))
        batch_np = {}
        while not self.coord.should_stop():
            batch_in = cur_db.next()
            batch_np['orig_images'] = batch_in[0]
            batch_np['orig_locs'] = batch_in[1]
            batch_np['info'] = batch_in[2]
            batch_np['extra_info'] = batch_in[3]
            xs, locs = PoseTools.preprocess_ims(batch_np['orig_images'],
                                                batch_np['orig_locs'],
                                                self.conf, distort, scale)

            batch_np['images'] = xs
            batch_np['locs'] = locs

            # Hooks may add/modify entries of batch_np before feeding.
            for fn in self.q_fns:
                fn(batch_np)

            food = {pl: batch_np[name] for (name, pl) in placeholders}

            success = False
            try:
                while not success:
                    if sess._closed or self.coord.should_stop():
                        return

                    try:
                        if db_type == self.DBType.Val:
                            sess.run(self.val_enqueue_op,
                                     feed_dict=food,
                                     options=run_options)
                        elif db_type == self.DBType.Train:
                            sess.run(self.train_enqueue_op,
                                     feed_dict=food,
                                     options=run_options)
                        success = True

                    except tf.errors.DeadlineExceededError:
                        # Timed out: loop round, re-check for shutdown, retry.
                        pass

            except tf.errors.CancelledError:
                return
            except Exception:
                logging.exception('Error in preloading thread')
                self.close_cursors()
                # NOTE(review): when run in a worker thread, sys.exit only
                # raises SystemExit, which ends this thread, not the process.
                sys.exit(1)
Ejemplo n.º 5
0
    def read_image_thread(self, sess, db_type, distort, shuffle, scale):
        """Read, preprocess and enqueue batches until told to stop.

        Selects the tfrecords file: the validation or training file when
        ``self.train_type == 0``, otherwise always the training file. Then it
        loops until ``self.coord.should_stop()``. Each iteration:

        1. reads the next batch;
        2. preprocesses it with ``PoseTools.preprocess_ims``;
        3. passes the batch dict through every function in ``self.q_fns``;
        4. feeds it via ``self.q_placeholders`` (an iterable of
           ``(key, placeholder)`` pairs) to the val/train enqueue op.

        Enqueue attempts use a 30 s timeout and are retried until they
        succeed, the session is closed, or the coordinator requests a stop.

        Args:
            sess: Session used to run the enqueue op.
            db_type: ``self.DBType.Val`` or ``self.DBType.Train``.
            distort: Forwarded to ``PoseTools.preprocess_ims``.
            shuffle: Forwarded to ``multiResData.tf_reader``.
            scale: Forwarded to ``PoseTools.preprocess_ims``.

        Raises:
            IOError: if ``self.train_type == 0`` and ``db_type`` is neither
                ``DBType.Val`` nor ``DBType.Train``.
        """
        if self.train_type == 0:
            if db_type == self.DBType.Val:
                db_name = self.conf.valfilename
            elif db_type == self.DBType.Train:
                db_name = self.conf.trainfilename
            else:
                # Equivalent to the former raise_(IOError, ..., exc_info()[2]),
                # whose traceback is None outside an except block.
                raise IOError("Unspecified DB Type")
        else:
            db_name = self.conf.trainfilename
        filename = os.path.join(self.conf.cachedir, db_name) + '.tfrecords'

        cur_db = multiResData.tf_reader(self.conf, filename, shuffle)
        placeholders = self.q_placeholders
        # Loop-invariant: a bounded timeout lets the retry loop below re-check
        # for shutdown instead of blocking indefinitely inside sess.run.
        run_options = tf.RunOptions(timeout_in_ms=30000)

        print('Starting preloading thread of type ... {}'.format(db_type))
        batch_np = {}
        while not self.coord.should_stop():
            batch_in = cur_db.next()
            batch_np['orig_images'] = batch_in[0]
            batch_np['orig_locs'] = batch_in[1]
            batch_np['info'] = batch_in[2]
            batch_np['extra_info'] = batch_in[3]
            xs, locs = PoseTools.preprocess_ims(batch_np['orig_images'], batch_np['orig_locs'], self.conf,
                                                distort, scale)

            batch_np['images'] = xs
            batch_np['locs'] = locs

            # Hooks may add/modify entries of batch_np before feeding.
            for fn in self.q_fns:
                fn(batch_np)

            food = {pl: batch_np[name] for (name, pl) in placeholders}

            success = False
            try:
                while not success:
                    if sess._closed or self.coord.should_stop():
                        return

                    try:
                        if db_type == self.DBType.Val:
                            sess.run(self.val_enqueue_op, feed_dict=food, options=run_options)
                        elif db_type == self.DBType.Train:
                            sess.run(self.train_enqueue_op, feed_dict=food, options=run_options)
                        success = True

                    except tf.errors.DeadlineExceededError:
                        # Timed out: loop round, re-check for shutdown, retry.
                        pass

            except tf.errors.CancelledError:
                return
            except Exception:
                logging.exception('Error in preloading thread')
                self.close_cursors()
                # NOTE(review): when run in a worker thread, sys.exit only
                # raises SystemExit, which ends this thread, not the process.
                sys.exit(1)
Ejemplo n.º 6
0
def _gt_db_file(args, view):
    """Return the path of the ground-truth tfrecords file for ``view``."""
    return os.path.join(out_dir, args.name,
                        args.gt_name) + '_view{}.tfrecords'.format(view)


def _load_or_classify(conf, curm, m, db_file, out_file):
    """Return ``(pred, label, gt_list)`` for model ``m`` evaluated on ``db_file``.

    If ``out_file + '.mat'`` exists and is newer than the model file, the
    cached results are loaded from it. Otherwise the DB is classified with
    ``apt.classify_db`` and the results are written to ``out_file`` via
    ``scipy.io.savemat``.

    Locations are stored shifted by +1 (presumably for 1-based MATLAB
    indexing) and shifted back on load. ``'list'`` is stored unshifted, so it
    is read back unshifted. The previous code subtracted 1 on load, which made
    cached results disagree with freshly computed ones.
    NOTE(review): confirm no other writer of these files stores 'list' 1-based.
    """
    # For unet/deeplabcut the freshness check uses "<model>.index".
    mm = m + '.index' if curm in ('unet', 'deeplabcut') else m
    cache_file = out_file + '.mat'
    if os.path.exists(cache_file) and \
            os.path.getmtime(cache_file) > os.path.getmtime(mm):
        mat = sio.loadmat(out_file)
        pred = mat['pred_locs'] - 1
        label = mat['labeled_locs'] - 1
        gt_list = mat['list']
        return pred, label, gt_list

    tf_iterator = multiResData.tf_reader(conf, db_file, False)
    tf_iterator.batch_size = 1
    read_fn = tf_iterator.next
    pred_fn, close_fn, _ = apt.get_pred_fn(curm, conf, m)
    try:
        pred, label, gt_list = apt.classify_db(conf, read_fn, pred_fn,
                                               tf_iterator.N)
    finally:
        # Release the predictor even if classification raises.
        close_fn()

    sio.savemat(
        out_file, {
            'pred_locs': pred + 1,
            'labeled_locs': np.array(label) + 1,
            'list': gt_list
        })
    return pred, label, gt_list


def compute_peformance(args):
    """Evaluate every network/view on ground-truth data and pickle the results.

    Uses these attributes of ``args``:
    ``lbl_file``, ``whose``, ``nets``, ``name``, ``gt_name``, ``gt_lbl``,
    ``skip_gt_db`` and ``split_type``.

    1. For each view, the GT tfrecords DB is (re)built from ``args.gt_lbl``
       unless it already exists and ``args.skip_gt_db`` is set.
    2. For each net and view, every model file returned by
       ``get_model_files`` is evaluated (or loaded from its cached ``.mat``).
       This uses the GT DB when ``args.split_type is None``; otherwise it
       runs per cross-validation split on that split's ``val_TF.tfrecords``.
    3. ``{net: [per-view list of
       [pred, label, gt_list, model_file, out_file, view, split, mtime]]}``
       is pickled to ``<out_dir>/<name>/<whose>/<gt_name>_results.p``.
    """
    # Close the label file as soon as the view count has been read.
    with h5py.File(args.lbl_file, 'r') as H:
        nviews = int(apt.read_entry(H['cfg']['NumViews']))
    dir_name = args.whose

    all_nets = methods if len(args.nets) == 0 else args.nets

    for view in range(nviews):
        db_file = _gt_db_file(args, view)
        conf = apt.create_conf(args.lbl_file,
                               view,
                               name='a',
                               net_type=all_nets[0],
                               cache_dir=os.path.join(out_dir, args.name,
                                                      dir_name))
        conf.labelfile = args.gt_lbl
        # Build the GT DB unless it exists AND the caller asked to skip.
        if not (os.path.exists(db_file) and args.skip_gt_db):
            print('Creating GT DB file {}'.format(db_file))
            apt.create_tfrecord(conf,
                                split=False,
                                on_gt=True,
                                db_files=(db_file, ))

    all_preds = {}
    for curm in all_nets:
        all_preds[curm] = []
        for view in range(nviews):
            cur_out = []
            if args.split_type is None:
                db_file = _gt_db_file(args, view)
                cachedir = os.path.join(out_dir, args.name, dir_name,
                                        '{}_view_{}'.format(curm, view),
                                        'full')
                conf = apt.create_conf(args.lbl_file,
                                       view,
                                       name='a',
                                       net_type=curm,
                                       cache_dir=cachedir)
                model_files, ts = get_model_files(conf, cachedir, curm)
                for mndx, m in enumerate(model_files):
                    out_file = m + '_' + args.gt_name
                    pred, label, gt_list = _load_or_classify(
                        conf, curm, m, db_file, out_file)
                    cur_out.append(
                        [pred, label, gt_list, m, out_file, view, 0, ts[mndx]])
            else:
                for cur_split in range(nsplits):
                    # NOTE(review): unlike the 'full' branch this path omits
                    # dir_name; confirm that is intended.
                    cachedir = os.path.join(out_dir, args.name,
                                            '{}_view_{}'.format(curm, view),
                                            'cv_{}'.format(cur_split))
                    conf = apt.create_conf(args.lbl_file,
                                           view,
                                           name='a',
                                           net_type=curm,
                                           cache_dir=cachedir)
                    model_files, ts = get_model_files(conf, cachedir, curm)
                    db_file = os.path.join(cachedir, 'val_TF.tfrecords')
                    for mndx, m in enumerate(model_files):
                        out_file = m + '.gt_data'
                        pred, label, gt_list = _load_or_classify(
                            conf, curm, m, db_file, out_file)
                        cur_out.append([
                            pred, label, gt_list, m, out_file, view, cur_split,
                            ts[mndx]
                        ])

            all_preds[curm].append(cur_out)

    # pickle writes bytes: the file must be opened in binary mode.
    with open(
            os.path.join(out_dir, args.name, dir_name,
                         args.gt_name + '_results.p'), 'wb') as f:
        pickle.dump(all_preds, f)