Ejemplo n.º 1
0
def get_pred_fn(cfg, model_file=None, name='deepnet', tmr_pred=None):
    """Build a prediction closure over a restored pose-estimation model.

    Args:
        cfg: Project configuration object; a shallow ``edict`` copy is
            converted to deepcut format, so the caller's object is not mutated.
        model_file: Path to a checkpoint to restore. If ``None``, the latest
            checkpoint named ``<name>_ckpt`` in ``cfg.cachedir`` is used.
        name: Base name of the checkpoint file (default ``'deepnet'``).
        tmr_pred: Optional context manager used to time the ``sess.run`` call;
            defaults to a no-op ``contextlib.suppress()``.

    Returns:
        Tuple ``(pred_fn, close_fn, model_file)`` where ``pred_fn`` maps a
        batch of images to a dict with keys ``'locs'``, ``'hmaps'`` and
        ``'conf'``, ``close_fn`` releases the TF session, and ``model_file``
        is the checkpoint path actually restored.
    """
    if tmr_pred is None:
        tmr_pred = contextlib.suppress()

    # Work on a converted copy so the caller's cfg is untouched.
    cfg = edict(cfg.__dict__)
    cfg = config.convert_to_deepcut(cfg)

    if model_file is None:
        # Resolve the most recent checkpoint recorded in <name>_ckpt.
        ckpt_file = os.path.join(cfg.cachedir, name + '_ckpt')
        latest_ckpt = tf.train.get_checkpoint_state(cfg.cachedir, ckpt_file)
        model_file = latest_ckpt.model_checkpoint_path
    # Fix: both branches assigned init_weights = model_file; collapse the
    # redundant else branch into a single assignment.
    init_weights = model_file

    tf.reset_default_graph()
    sess, inputs, outputs = predict.setup_pose_prediction(cfg, init_weights)

    def pred_fn(all_f):
        """Run the network on a batch of frames and return predictions."""
        # The network expects 3-channel input; replicate grayscale frames.
        if cfg.img_dim == 1:
            cur_im = np.tile(all_f, [1, 1, 1, 3])
        else:
            cur_im = all_f
        cur_im, _ = PoseTools.preprocess_ims(
            cur_im,
            in_locs=np.zeros([cur_im.shape[0], cfg.n_classes, 2]),
            conf=cfg,
            distort=False,
            scale=cfg.dlc_rescale)

        with tmr_pred:
            cur_out = sess.run(outputs, feed_dict={inputs: cur_im})
        scmap, locref = predict.extract_cnn_output(cur_out, cfg)
        pose = predict.argmax_pose_predict(scmap, locref, cfg.stride)
        # Undo the preprocessing rescale so locs are in original image coords.
        pose = pose[:, :, :2] * cfg.dlc_rescale
        ret_dict = {}
        ret_dict['locs'] = pose
        ret_dict['hmaps'] = scmap
        # Peak scoremap value per landmark as a confidence proxy.
        ret_dict['conf'] = np.max(scmap, axis=(1, 2))
        return ret_dict

    def close_fn():
        sess.close()

    return pred_fn, close_fn, model_file
Ejemplo n.º 2
0
def get_read_fn(cfg, data_path):
    """Return ``(read_fn, n)``: a single-example reader over *data_path*
    and the number of images in the dataset.

    The reader yields ``(ims, locs, info)`` tuples; ``info`` is a fixed
    ``[0, 0, 0]`` placeholder. Reading is sequential (no shuffling) with
    batch size 1, and grayscale projects keep only the first channel.
    """
    # Convert a copy of the config so the caller's object is unchanged.
    cfg = edict(cfg.__dict__)
    cfg = config.convert_to_deepcut(cfg)
    cfg.batch_size = 1
    cfg.shuffle = False

    dataset = PoseDataset(cfg, data_path, distort=False)

    def read_fn():
        batch = dataset.next_batch()
        images = batch[Batch.inputs]
        locations = batch[Batch.locs]
        if cfg.img_dim == 1:
            # Inputs are replicated to 3 channels; keep just one.
            images = images[:, :, :, 0:1]
        return images, locations, [0, 0, 0]

    return read_fn, dataset.num_images
Ejemplo n.º 3
0
def train(cfg, name='deepnet'):
    """Train the pose network described by *cfg*, checkpointing under *name*.

    Saves the (converted) config to a traindata pickle, downloads the
    pretrained ResNet-50 weights if missing, runs the TF1 queue-fed training
    loop for ``cfg.dl_steps`` iterations, logs loss/distance every
    ``cfg.display_step`` iterations, and snapshots checkpoints either on a
    wall-clock schedule (``cfg.save_time``, minutes) or every
    ``cfg.save_step`` iterations.
    """
    #    setup_logging()

    # Work on a converted copy so the caller's cfg object is not mutated.
    cfg = edict(cfg.__dict__)
    cfg = config.convert_to_deepcut(cfg)

    # The default net name keeps the legacy 'traindata' filename.
    if name == 'deepnet':
        train_data_file = os.path.join(cfg.cachedir, 'traindata')
    else:
        train_data_file = os.path.join(cfg.cachedir, name + '_traindata')

    # protocol=2 — presumably kept for Python 2 readers; verify before bumping.
    with open(train_data_file, 'wb') as td_file:
        pickle.dump(cfg, td_file, protocol=2)
    logging.info('Saved config to {}'.format(train_data_file))

    # Pretrained backbone lives in <repo>/pretrained next to this package.
    dirname = os.path.dirname(os.path.dirname(__file__))
    init_weights = os.path.join(dirname, 'pretrained', 'resnet_v1_50.ckpt')

    if not os.path.exists(init_weights):
        # Download and save the pretrained resnet weights.
        # NOTE(review): urllib.urlretrieve is the Python 2 API; on Python 3
        # this would need urllib.request.urlretrieve — confirm target version.
        logging.info('Downloading pretrained resnet 50 weights ...')
        urllib.urlretrieve(
            'http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz',
            os.path.join(dirname, 'pretrained',
                         'resnet_v1_50_2016_08_28.tar.gz'))
        tar = tarfile.open(
            os.path.join(dirname, 'pretrained',
                         'resnet_v1_50_2016_08_28.tar.gz'))
        tar.extractall(path=os.path.join(dirname, 'pretrained'))
        tar.close()
        logging.info('Done downloading pretrained weights')

    db_file_name = os.path.join(cfg.cachedir, 'train_data.p')
    dataset = PoseDataset(cfg, db_file_name, distort=True)
    # Accumulated per-display-step metrics; train and val entries are the
    # same values here (no separate validation pass in this loop).
    train_info = {
        'train_dist': [],
        'train_loss': [],
        'val_dist': [],
        'val_loss': [],
        'step': []
    }

    # TF1 queue-based input pipeline: placeholders are fed by a background
    # thread started below via start_preloading.
    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)

    net = pose_net(cfg)
    losses = net.train(batch)
    total_loss = losses['total_loss']
    # Heatmap and location-refinement heads, fetched for display metrics.
    outputs = [net.heads['part_pred'], net.heads['locref']]

    for k, t in losses.items():
        tf.summary.scalar(k, t)
    merged_summaries = tf.summary.merge_all()

    # restorer loads only the ResNet backbone from the pretrained checkpoint;
    # saver snapshots the full model during training.
    variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
    restorer = tf.train.Saver(variables_to_restore)
    saver = tf.train.Saver(max_to_keep=50, save_relative_paths=True)

    sess = tf.Session()

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)

    learning_rate, train_op = get_optimizer(total_loss, cfg)

    # Initialize everything first, then overwrite the backbone variables
    # from the pretrained checkpoint — order matters here.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, init_weights)

    #max_iter = int(cfg.multi_step[-1][1])
    max_iter = int(cfg.dl_steps)
    display_iters = cfg.display_step
    cum_loss = 0.0
    lr_gen = LearningRate(cfg)

    model_name = os.path.join(cfg.cachedir, name)
    ckpt_file = os.path.join(cfg.cachedir, name + '_ckpt')

    start = time.time()
    save_start = time.time()
    # +1 so the final iteration (== max_iter) also hits the save branch below.
    for it in range(max_iter + 1):
        current_lr = lr_gen.get_lr(it)
        [_, loss_val] = sess.run(
            [train_op, total_loss],  # merged_summaries],
            feed_dict={learning_rate: current_lr})
        cum_loss += loss_val
        #       train_writer.add_summary(summary, it)

        if it % display_iters == 0:
            # Fetch head outputs and the batch together so predictions and
            # labels come from the same queued batch.
            cur_out, batch_out = sess.run(
                [outputs, batch], feed_dict={learning_rate: current_lr})
            scmap, locref = predict.extract_cnn_output(cur_out, cfg)

            # Extract maximum scoring location from the heatmap, assume 1 person
            loc_pred = predict.argmax_pose_predict(scmap, locref, cfg.stride)
            loc_in = batch_out[Batch.locs]
            # Euclidean pixel distance per landmark, scaled back to
            # original image coordinates.
            dd = np.sqrt(
                np.sum(np.square(loc_pred[:, :, :2] - loc_in), axis=-1))
            dd = dd * cfg.dlc_rescale
            average_loss = cum_loss / display_iters
            cum_loss = 0.0
            end = time.time()
            # print('Time to train: {}'.format(end-start))
            start = end
            print("iteration: {} loss: {} dist: {}  lr: {}".format(
                it, "{0:.4f}".format(average_loss),
                '{0:.2f}'.format(dd.mean()), current_lr))
            train_info['step'].append(it)
            train_info['train_loss'].append(loss_val)
            train_info['val_loss'].append(loss_val)
            train_info['val_dist'].append(dd.mean())
            train_info['train_dist'].append(dd.mean())

            save_td(cfg, train_info, name)

        # Save snapshot
        # Wall-clock-based saving (cfg.save_time in minutes) takes priority
        # over iteration-count-based saving.
        if 'save_time' in cfg.keys() and cfg['save_time'] is not None:
            if (time.time() - save_start) > cfg['save_time'] * 60:
                saver.save(sess,
                           model_name,
                           global_step=it,
                           latest_filename=os.path.basename(ckpt_file))
                save_start = time.time()
        else:
            if (it % cfg.save_step == 0) or it == max_iter:
                saver.save(sess,
                           model_name,
                           global_step=it,
                           latest_filename=os.path.basename(ckpt_file))

    # Shut down the preloading thread and release the session.
    coord.request_stop()
    coord.join([thread], stop_grace_period_secs=60)
    sess.close()