示例#1
0
def main(argv=None):
    """Set up a fresh experiment directory and start training.

    The model is selected by the first command-line argument
    (``'vgg16'`` or ``'vgg19'``). Without an argument, the VGG16 model file
    is imported under the module name ``'model'``.

    Args:
        argv: Command-line argument list (program name first). ``None`` or a
            list with only the program name selects the default model.

    Raises:
        ValueError: If the requested model name is not supported, or if
            ``FLAGS.train_dir`` already exists.
    """
    if argv is not None and len(argv) > 1:
        model_str = argv[1]
        # Flags are looked up lazily so that only the selected model's flag
        # has to be defined.
        if model_str == 'vgg16':
            model_path = FLAGS.vgg16_model_path
        elif model_str == 'vgg19':
            model_path = FLAGS.vgg19_model_path
        else:
            raise ValueError('Unknown model: ' + str(model_str))
    else:
        # Default: VGG16 definition imported under the generic name 'model'.
        model_str = 'model'
        model_path = FLAGS.vgg16_model_path
    model = helper.import_module(model_str, model_path)

    # Refuse to overwrite the results of an earlier experiment.
    if tf.gfile.Exists(FLAGS.train_dir):
        raise ValueError('Train dir exists: ' + FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)

    stats_dir = os.path.join(FLAGS.train_dir, 'stats')
    tf.gfile.MakeDirs(stats_dir)
    tf.gfile.MakeDirs(os.path.join(FLAGS.train_dir, 'visualize'))
    # The log file is intentionally left open: it backs the stdout tee for
    # the whole training run.
    f = open(os.path.join(stats_dir, 'log.txt'), 'w')
    sys.stdout = train_helper.Logger(sys.stdout, f)

    # Snapshot the model definition and config next to the results.
    copyfile(model_path, os.path.join(FLAGS.train_dir, model_str + '.py'))
    copyfile(FLAGS.config_path, os.path.join(FLAGS.train_dir, 'config.py'))

    print('Experiment dir: ' + FLAGS.train_dir)
    train(model)
示例#2
0
def main(argv=None):
    """Build a one-image input pipeline and run a single evaluation on it.

    The model module is imported from ``models/model_s.py``, one serialized
    example is read through a ``TFRecordReader`` and its ``'rgb'`` field is
    decoded, reshaped to ``[1, 256, 512, 3]`` and cast to float before being
    handed to ``evalone`` together with the checkpoint path.
    """
    #ipdb.set_trace()
    # import model
    net = helper.import_module('model', 'models/model_s.py')

    # load image
    '''
    filename_queue = tf.train.string_input_producer(['rgb_img.png'])
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    image = tf.image.decode_png(value)    
    print (tf.shape(image))
    exit()
    
    image = Image.open('rgb_img.png')
    image = np.asarray(image)
    image = image.astype('uint8')    
    '''
    input_queue = tf.train.string_input_producer(['rgb_img.png'])
    record_reader = tf.TFRecordReader()
    _, record = record_reader.read(input_queue)

    # Scalar int64 fields first, then the scalar string fields.
    int_fields = ('height', 'width', 'depth', 'num_labels')
    str_fields = ('img_name', 'rgb', 'label_weights', 'labels')
    feature_spec = {}
    for field in int_fields:
        feature_spec[field] = tf.FixedLenFeature([], tf.int64)
    for field in str_fields:
        feature_spec[field] = tf.FixedLenFeature([], tf.string)
    parsed = tf.parse_single_example(record, features=feature_spec)

    # Raw bytes -> uint8 -> fixed-size batch of one image -> float.
    pixels = tf.decode_raw(parsed['rgb'], tf.uint8)
    batch = tf.to_float(tf.reshape(pixels, shape=[1, 256, 512, 3]))

    # define path where model checkpoint is stored
    checkpoint = '/home/stjepan/FAKS/zavrad/VGG_TX1/model.ckpt'

    # evaluate one image
    evalone(net, checkpoint, batch)
示例#3
0
def main(argv=None):
    """Create a new experiment directory, snapshot model/config and train.

    Raises:
        ValueError: If ``FLAGS.train_dir`` already exists.
    """
    model = helper.import_module('model', FLAGS.model_path)

    out_dir = FLAGS.train_dir
    # An existing directory means an earlier run; do not overwrite it.
    if tf.gfile.Exists(out_dir):
        raise ValueError('Train dir exists: ' + out_dir)
    tf.gfile.MakeDirs(out_dir)

    # Sub-directories for statistics and visualizations.
    for sub_dir in ('stats', 'visualize'):
        tf.gfile.MakeDirs(os.path.join(out_dir, sub_dir))

    # Tee everything printed to stdout into stats/log.txt; the file stays
    # open for the rest of the run because the logger keeps writing to it.
    log_file = open(os.path.join(out_dir, 'stats', 'log.txt'), 'w')
    sys.stdout = train_helper.Logger(sys.stdout, log_file)

    # Keep copies of the exact model and config used for this experiment.
    snapshots = (
        (FLAGS.model_path, 'model.py'),
        (FLAGS.config_path, 'config.py'),
    )
    for src_path, dst_name in snapshots:
        copyfile(src_path, os.path.join(out_dir, dst_name))

    print('Experiment dir: ' + out_dir)
    train(model)
示例#4
0
import eval_helper
import numpy as np

import helper

import sys

from shutil import copyfile

# TODO: replace this relative default path; it only resolves when the script is
# started from the directory that contains `config/`.
tf.app.flags.DEFINE_string('config_path', "config/cityscapes.py",
                           """Path to experiment config.""")
FLAGS = tf.app.flags.FLAGS

# Import the experiment config module from the path given by --config_path.
# NOTE(review): presumably this is what defines the other flags read below
# (batch_size, img_height, img_width) — confirm in helper / the config file.
helper.import_module('config', FLAGS.config_path)

# Human-readable names of the 19 classes.
# NOTE(review): presumably ordered by label id — verify against the dataset.
class_names = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle'
]


def collect_confusion(logits, labels, conf_mat):
    """Compute flattened per-pixel class predictions from ``logits``.

    The arg-max over axis 3 of ``logits`` is cast to int32 and resized to a
    1-D array of ``batch_size * img_height * img_width`` elements.

    NOTE(review): in the visible source ``labels`` and ``conf_mat`` are never
    used and nothing is returned — the function body looks truncated.
    """
    class_ids = logits.argmax(axis=3).astype(np.int32, copy=False)

    pixel_count = FLAGS.batch_size * FLAGS.img_height * FLAGS.img_width
    class_ids = np.resize(class_ids, (pixel_count,))
def main(argv=None):
    """Evaluate the model from ``models/model_s.py`` on a fixed checkpoint."""
    # The model module is imported first, then evaluated against the
    # checkpoint stored under trained_model/.
    evaluate(
        helper.import_module('model', 'models/model_s.py'),
        'trained_model/model.ckpt',
    )
def main(argv=None):  # pylint: disable=unused-argument
  model = helper.import_module('model', os.path.join(FLAGS.model_dir, 'model.py'))

  if not tf.gfile.Exists(FLAGS.model_dir):
from datasets.voc2012.dataset import Dataset

# Allow up to 250 characters per line when NumPy prints arrays.
np.set_printoptions(linewidth=250)

#DATA_DIR = '/home/kivan/datasets/voc2012_aug/data/'
#split = 'val'

# Hard-coded, machine-specific root directory of the VOC2012 test data.
DATA_DIR = '/home/kivan/datasets/VOC2012/test_data'

# NOTE(review): this DEFINE_string call is never closed in the visible source.
# Both candidate default values below are commented out, so the default and
# help-string arguments and the closing parenthesis are missing. Restore the
# intended default before running this script.
tf.app.flags.DEFINE_string('model_dir',
    #'/home/kivan/datasets/results/tmp/voc2012/25_7_16-38-50/', '')
    #'/home/kivan/datasets/results/tmp/voc2012/25_5_22-30-16', '')
FLAGS = tf.app.flags.FLAGS


# Import the config module stored as config.py inside the model directory.
helper.import_module('config', os.path.join(FLAGS.model_dir, 'config.py'))


def forward_pass(model, save_dir):
  img_dir = join(DATA_DIR, 'JPEGImages')
  file_path = join(DATA_DIR, 'ImageSets', 'Segmentation', 'test.txt')
  fp = open(file_path)
  file_list = [line.strip() for line in fp]

  save_dir_rgb = join(save_dir, 'rgb')
  tf.gfile.MakeDirs(save_dir_rgb)
  save_dir_submit = join(save_dir, 'submit')
  tf.gfile.MakeDirs(save_dir_submit)
  #sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
  config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
  #config.gpu_options.per_process_gpu_memory_fraction = 0.5 # don't hog all vRAM