import os

import h5py
import numpy as np
import tensorflow as tf
from scipy import misc  # used by the single-image inference() variant below

# NOTE: `model`, the dataset object, `args`, `log_string` and the upper-case
# constants (BATCH_SIZE, IM_DIM, VOL_DIM, TRAIN_DIR, OUTPUT_DIR, GPU_INDEX,
# TRAIN_EPOCHS, ...) are defined elsewhere in the project this excerpt is
# taken from.


def inference(dataset_):
  is_train_pl = tf.placeholder(tf.bool)
  img_pl, _ = model.placeholder_inputs(BATCH_SIZE, IM_DIM, VOL_DIM)
  pred = model.get_model(img_pl, is_train_pl)
  pred = tf.sigmoid(pred)

  config = tf.ConfigProto()
  config.gpu_options.allocator_type = 'BFC'
  config.gpu_options.allow_growth = True
  config.allow_soft_placement = True

  with tf.Session(config=config) as sess:
    model_path = os.path.join(TRAIN_DIR, "trained_models")
    ckpt = tf.train.get_checkpoint_state(model_path)
    restorer = tf.train.Saver()
    restorer.restore(sess, ckpt.model_checkpoint_path)

    test_samples = dataset_.getTestSampleSize()

    for batch_idx in range(test_samples):
      imgs, view_names = dataset_.next_test_batch(batch_idx, 1)  # one test sample per step

      feed_dict = {img_pl: imgs, is_train_pl: False}
      pred_res = sess.run(pred, feed_dict=feed_dict)

      for i in range(len(view_names)):
        vol_ = pred_res[i]

        cloth = view_names[i][0]
        mesh = view_names[i][1]
        name_ = view_names[i][2][:-4]  # drop the 4-character extension (e.g. ".png")

        save_path = os.path.join(OUTPUT_DIR, cloth, mesh)
        if not os.path.exists(save_path): 
          os.makedirs(save_path)

        save_path_name = os.path.join(save_path, name_+".h5")
        if os.path.exists(save_path_name):
          os.remove(save_path_name)

        h5_fout = h5py.File(save_path_name, 'w')
        h5_fout.create_dataset(
                'data', data=vol_,
                compression='gzip', compression_opts=4,
                dtype='float32')
        h5_fout.close()

        print(batch_idx, save_path_name)
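

# Usage sketch (added for illustration, not part of the original example):
# a minimal way to drive inference(dataset_) above. `DatasetLoader` and
# `DATA_DIR` are hypothetical placeholders; any object that exposes
# getTestSampleSize() and next_test_batch() as used in the loop will do.
if __name__ == '__main__':
  test_dataset = DatasetLoader(DATA_DIR)  # hypothetical dataset wrapper
  with tf.Graph().as_default():
    inference(test_dataset)
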
def inference():
    # Single-image variant: predicts one volume from `img_path`, which is
    # expected to be provided elsewhere (cf. the FLAGS.img comment below).
    is_train_pl = tf.placeholder(tf.bool)
    img_pl, _ = model.placeholder_inputs(BATCH_SIZE, IM_DIM, VOL_DIM)
    pred = model.get_model(img_pl, is_train_pl)
    pred = tf.sigmoid(pred)

    config = tf.ConfigProto(device_count={'CPU': 1})
    with tf.Session(config=config) as sess:
        model_path = os.path.join(TRAIN_DIR, "trained_models")
        ckpt = tf.train.get_checkpoint_state(model_path)
        restorer = tf.train.Saver()
        restorer.restore(sess, ckpt.model_checkpoint_path)

        # scipy.misc.imread is removed in recent SciPy; imageio.imread is a
        # drop-in replacement. Pixel values are scaled to [0, 1].
        img_1 = np.array(misc.imread(img_path) / 255.0)
        img_1 = img_1.reshape((1, 128, 128, 3))  # assumes IM_DIM == 128
        feed_dict = {img_pl: img_1, is_train_pl: False}
        pred_res = sess.run(pred, feed_dict=feed_dict)

        vol_ = pred_res[0]  # (vol_dim, vol_dim, vol_dim, 1)
        name_ = '001'  # FLAGS.img.strip().split('.')[0] # xx.xxx.png

        save_path = OUTPUT_DIR
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_path_name = os.path.join(save_path, name_ + ".h5")
        if os.path.exists(save_path_name):
            os.remove(save_path_name)

        h5_fout = h5py.File(save_path_name, 'w')
        h5_fout.create_dataset('data',
                               data=vol_,
                               compression='gzip',
                               compression_opts=4,
                               dtype='float32')
        h5_fout.close()

        print('%s.h5 is predicted into %s' % (name_, save_path_name))
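

# Read-back sketch (added for illustration): the saved prediction can be
# loaded again with h5py and thresholded to a binary occupancy grid. The
# 0.5 threshold is a common choice for sigmoid outputs, not a value taken
# from the original example.
def load_predicted_volume(h5_path, threshold=0.5):
    with h5py.File(h5_path, 'r') as f:
        vol = f['data'][:]  # (vol_dim, vol_dim, vol_dim, 1), float32 in [0, 1]
    return (vol > threshold).astype(np.uint8)
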
def train(dataset_):
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(GPU_INDEX)):
            is_train_pl = tf.placeholder(tf.bool)
            img_pl, vol_pl = model.placeholder_inputs(BATCH_SIZE, IM_DIM,
                                                      VOL_DIM)

            # global training step, incremented once per optimizer update
            global_step = tf.Variable(0, trainable=False, name='global_step')
            bn_decay = get_bn_decay(global_step)
            tf.summary.scalar('bn_decay', bn_decay)

            # get prediction and loss
            pred = model.get_model(img_pl,
                                   is_train_pl,
                                   weight_decay=args.wd,
                                   bn_decay=bn_decay)
            loss = model.get_MSFE_cross_entropy_loss(pred, vol_pl)
            tf.summary.scalar('loss', loss)

            # Get training operator
            learning_rate = get_learning_rate(global_step)
            tf.summary.scalar('learning_rate', learning_rate)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            train_op = optimizer.minimize(loss, global_step=global_step)

            summary_op = tf.summary.merge_all()

            saver = tf.train.Saver()

        config = tf.ConfigProto()
        config.gpu_options.allocator_type = 'BFC'
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        with tf.Session(config=config) as sess:
            model_path = os.path.join(TRAIN_DIR, "trained_models")
            if tf.gfile.Exists(os.path.join(model_path, "checkpoint")):
                ckpt = tf.train.get_checkpoint_state(model_path)
                restorer = tf.train.Saver()
                restorer.restore(sess, ckpt.model_checkpoint_path)
                print("Load parameters from checkpoint.")
            else:
                sess.run(tf.global_variables_initializer())

            train_summary_writer = tf.summary.FileWriter(model_path,
                                                         graph=sess.graph)

            train_sample_size = dataset_.getTrainSampleSize()
            train_batches = train_sample_size // BATCH_SIZE

            for epoch in range(TRAIN_EPOCHS):
                dataset_.shuffleTrainNames()

                for batch_idx in range(train_batches):
                    imgs, vols_clr = dataset_.next_batch(
                        batch_idx * BATCH_SIZE, BATCH_SIZE)
                    # A voxel counts as occupied only if all of its color
                    # channels are above -0.5.
                    vols_occu = np.prod(
                        vols_clr > -0.5, axis=-1,
                        keepdims=True)  # (batch, vol_dim, vol_dim, vol_dim, 1)
                    vols_occu = vols_occu.astype(np.float32)

                    feed_dict = {
                        img_pl: imgs,
                        vol_pl: vols_occu,
                        is_train_pl: True
                    }

                    _, loss_val, step, summary = sess.run(
                        [train_op, loss, global_step, summary_op],
                        feed_dict=feed_dict)
                    # Write the merged summaries so the scalars defined above
                    # (loss, learning_rate, bn_decay) show up in TensorBoard.
                    train_summary_writer.add_summary(summary, step)

                    log_string("<TRAIN> Epoch {} - Batch {}: loss: {}.".format(
                        epoch, batch_idx, loss_val))

                if epoch % args.epochs_to_save == 0:
                    checkpoint_path = os.path.join(model_path, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=epoch)
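

# train() above calls get_learning_rate() and get_bn_decay(), which are not
# part of this excerpt. Below is a minimal sketch of what such schedules
# commonly look like in TF1 training scripts; the upper-case constants are
# assumed hyperparameters, not values taken from the original project.
def get_learning_rate(global_step):
    # Exponentially decayed learning rate with a small floor.
    learning_rate = tf.train.exponential_decay(
        BASE_LEARNING_RATE,  # assumed base rate, e.g. 1e-4
        global_step,
        DECAY_STEP,          # assumed decay interval in steps
        DECAY_RATE,          # assumed decay factor, e.g. 0.7
        staircase=True)
    return tf.maximum(learning_rate, 1e-6)


def get_bn_decay(global_step):
    # Batch-norm momentum schedule that ramps towards BN_DECAY_CLIP.
    bn_momentum = tf.train.exponential_decay(
        BN_INIT_DECAY,       # assumed initial value, e.g. 0.5
        global_step,
        BN_DECAY_STEP,       # assumed decay interval in steps
        BN_DECAY_RATE,       # assumed decay factor, e.g. 0.5
        staircase=True)
    return tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)  # clip near 0.99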