Example #1
def main(_):
    prepare_training_dir()

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    hparams = get_training_hparams()

    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    device_setter = tf.train.replica_device_setter(FLAGS.ps_tasks,
                                                   merge_devices=True)
    with tf.device(device_setter):
        data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=common_flags.get_crop_size())
        endpoints = model.create_base(data.images, data.labels_one_hot)
        total_loss = model.create_loss(data, endpoints)
        model.create_summaries(data,
                               endpoints,
                               dataset.charset,
                               is_training=True)
        init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint,
                                                  FLAGS.checkpoint_inception)
        if FLAGS.show_graph_stats:
            tf.logging.info('Total number of weights in the graph: %s',
                            calculate_graph_metrics())
        train(total_loss, init_fn, hparams)
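The train() helper called on the last line is defined elsewhere in the
training script. A minimal sketch of what it presumably wraps, assuming the
standard slim training loop and hypothetical flags train_log_dir and
max_number_of_steps:

def train(loss, init_fn, hparams):
    # Optimizer choice is illustrative; hparams is assumed to carry
    # learning_rate and momentum.
    optimizer = tf.train.MomentumOptimizer(hparams.learning_rate,
                                           momentum=hparams.momentum)
    train_op = slim.learning.create_train_op(loss, optimizer)
    # Run the slim training loop, restoring checkpoints through init_fn.
    slim.learning.train(train_op,
                        logdir=FLAGS.train_log_dir,
                        init_fn=init_fn,
                        number_of_steps=FLAGS.max_number_of_steps)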
Example #2
def main(_):
  if not tf.gfile.Exists(FLAGS.eval_log_dir):
    tf.gfile.MakeDirs(FLAGS.eval_log_dir)

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
  data = data_provider.get_data(
      dataset,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())
  endpoints = model.create_base(data.images, labels_one_hot=None)
  model.create_loss(data, endpoints)
  eval_ops = model.create_summaries(
      data, endpoints, dataset.charset, is_training=False)
  slim.get_or_create_global_step()
  session_config = tf.ConfigProto(device_count={"GPU": 0})
  slim.evaluation.evaluation_loop(
      master=FLAGS.master,
      checkpoint_dir=FLAGS.train_log_dir,
      logdir=FLAGS.eval_log_dir,
      eval_op=eval_ops,
      num_evals=FLAGS.num_batches,
      eval_interval_secs=FLAGS.eval_interval_secs,
      max_number_of_evaluations=FLAGS.number_of_steps,
      session_config=session_config)
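These evaluation scripts rely on command-line flags defined elsewhere in the
module. A hypothetical sketch of those definitions (flag names are taken from
the snippet; defaults are illustrative only):

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('split_name', 'test', 'Dataset split to evaluate on.')
flags.DEFINE_integer('batch_size', 32, 'Batch size.')
flags.DEFINE_string('train_log_dir', '/tmp/attention_ocr/train',
                    'Directory with the training checkpoints.')
flags.DEFINE_string('eval_log_dir', '/tmp/attention_ocr/eval',
                    'Directory to write evaluation summaries into.')
flags.DEFINE_integer('num_batches', 100, 'Batches to run per evaluation.')
flags.DEFINE_integer('eval_interval_secs', 60,
                     'Seconds between consecutive evaluations.')
flags.DEFINE_integer('number_of_steps', None,
                     'Maximum number of evaluation runs; None means forever.')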
Example #3
def main(_):
    if not tf.gfile.Exists(FLAGS.eval_log_dir):
        tf.gfile.MakeDirs(FLAGS.eval_log_dir)

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=False,
        central_crop_size=common_flags.get_crop_size())
    endpoints = model.create_base(data.images, labels_one_hot=None)
    model.create_loss(data, endpoints)
    eval_ops = model.create_summaries(data,
                                      endpoints,
                                      dataset.charset,
                                      is_training=False)
    slim.get_or_create_global_step()
    session_config = tf.ConfigProto(device_count={"GPU": 0})
    # Compared with Example #2, this config also enables GPU memory growth
    # and disables device-placement logging. Note that with GPUs masked by
    # device_count={"GPU": 0} above, the gpu_options setting has no
    # practical effect here.
    session_config.gpu_options.allow_growth = True
    session_config.log_device_placement = False
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.train_log_dir,
        logdir=FLAGS.eval_log_dir,
        eval_op=eval_ops,
        num_evals=FLAGS.num_batches,
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=1,
        session_config=session_config)
Example #4
def main(_):
    if not tf.gfile.Exists(FLAGS.eval_log_dir):
        tf.gfile.MakeDirs(FLAGS.eval_log_dir)

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=False,
        central_crop_size=common_flags.get_crop_size())
    # Debug output left in by the author:
    print("JIBIRI!")
    print(data.images)
    endpoints = model.create_base(data.images, labels_one_hot=None)
    model.create_loss(data, endpoints)
    eval_ops = model.create_summaries(data,
                                      endpoints,
                                      dataset.charset,
                                      is_training=False)
    slim.get_or_create_global_step()
    session_config = tf.ConfigProto(device_count={"GPU": 0})
    checkpoint_path = "%s/model.ckpt-90482" % FLAGS.train_log_dir
    eval_result = slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_log_dir,
        eval_op=eval_ops,
        session_config=session_config)
Example #5
    def initial(self):
        dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
        model = common_flags.create_model(dataset.num_char_classes,
                                          dataset.max_sequence_length,
                                          dataset.num_of_views,
                                          dataset.null_code,
                                          charset=dataset.charset)
        data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=False,
            central_crop_size=common_flags.get_crop_size())

        self.image_height = int(data.images.shape[1])
        self.image_width = int(data.images.shape[2])
        self.image_channel = int(data.images.shape[3])
        self.num_of_view = dataset.num_of_views
        placeholder_shape = (1, self.image_height, self.image_width,
                             self.image_channel)
        print(placeholder_shape)
        self.placeholder = tf.placeholder(tf.float32, shape=placeholder_shape)
        self.endpoint = model.create_base(self.placeholder,
                                          labels_one_hot=None)
        init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint)

        # `config` is not defined in this snippet; a plain allow-soft-placement
        # session config is assumed here.
        config = tf.ConfigProto(allow_soft_placement=True)
        self.sess = tf.Session(config=config)
        tf.tables_initializer().run(session=self.sess)
        init_fn(self.sess)
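The class that initial() belongs to is not shown. A hypothetical companion
predict() method, reusing the placeholder and session set up above and the
predicted_text endpoint that Example #6 demonstrates:

    def predict(self, image):
        # image: float32 array of shape (image_height, image_width,
        # image_channel), matching the placeholder defined in initial().
        batch = np.expand_dims(image.astype(np.float32), axis=0)
        return self.sess.run(self.endpoint.predicted_text,
                             feed_dict={self.placeholder: batch})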
Example #6
def main(_):
  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code,
                                    charset=dataset.charset)
  data = data_provider.get_data(
      dataset,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())


  # Resize the input to a single view's width; the placeholder below expects
  # the full num_of_views-wide image, which the commented-out np.concatenate
  # line would produce by tiling this view.
  view_width = int(data.images.shape[2]) // dataset.num_of_views
  view_height = int(data.images.shape[1])
  input_image = Image.open(FLAGS.input_image).convert("RGB").resize(
      (view_width, view_height))
  input_array = np.array(input_image).astype(np.float32)
  #input_array = np.concatenate((input_array, input_array, input_array, input_array), axis=1)
  #Image.fromarray(input_array.astype(np.uint8), "RGB").save("test_input_1.jpg")
  input_array = np.expand_dims(input_array, axis=0)
  print(input_array.shape)
  print(input_array.dtype)
  #input_image.save("test_input.jpg")
  #return
  
  placeholder_shape = (1, data.images.shape[1], data.images.shape[2], data.images.shape[3])
  print(placeholder_shape)
  image_placeholder = tf.placeholder(tf.float32, shape=placeholder_shape)
  endpoints = model.create_base(image_placeholder, labels_one_hot=None)
  init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint)
  with tf.Session() as sess:
    tf.tables_initializer().run()  # required by the CharsetMapper
    init_fn(sess)
    predictions = sess.run(endpoints.predicted_text,
                           feed_dict={image_placeholder: input_array})
  print("Predicted strings:")
  for line in predictions:
    print(line)
Example #7
def main(_):
    prepare_training_dir()

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    hparams = get_training_hparams()

    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    #device_setter = tf.train.replica_device_setter(
    #    FLAGS.ps_tasks, merge_devices=True)
    with tf.device("/cpu:0"):
        provider = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=common_flags.get_crop_size())
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [provider.images, provider.images_orig, provider.labels,
             provider.labels_one_hot],
            capacity=2 * FLAGS.num_clones)

    losses = []
    for i in range(FLAGS.num_clones):
        with tf.name_scope("clone_{0}".format(i)):
            with tf.device("/gpu:{0}".format(i)):
                #if i == 1:
                #  continue
                images, images_orig, labels, labels_one_hot = (
                    batch_queue.dequeue())
                if i == 0:
                    endpoints = model.create_base(images, labels_one_hot)
                else:
                    endpoints = model.create_base(images,
                                                  labels_one_hot,
                                                  reuse=True)
                init_fn = model.create_init_fn_to_restore(
                    FLAGS.checkpoint, FLAGS.checkpoint_inception)
                if FLAGS.show_graph_stats:
                    logging.info('Total number of weights in the graph: %s',
                                 calculate_graph_metrics())

                data = InputEndpoints(images=images,
                                      images_orig=images_orig,
                                      labels=labels,
                                      labels_one_hot=labels_one_hot)

                total_loss, single_model_loss = model.create_loss(
                    data, endpoints)
                losses.append((single_model_loss, i))
                with tf.device("/cpu:0"):
                    tf.summary.scalar('model_loss_{0}'.format(i),
                                      single_model_loss)
                    model.create_summaries_multigpu(data,
                                                    endpoints,
                                                    dataset.charset,
                                                    i,
                                                    is_training=True)
    train_multigpu(losses, init_fn, hparams)
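train_multigpu() is not included in the snippet. A minimal sketch of what such
a helper might do, assuming losses is the list of (single_model_loss,
clone_index) pairs built above and that hparams carries learning_rate and
momentum:

def train_multigpu(losses, init_fn, hparams):
    # Average the per-clone losses into a single objective.
    total_loss = tf.add_n([loss for loss, _ in losses]) / len(losses)
    optimizer = tf.train.MomentumOptimizer(hparams.learning_rate,
                                           momentum=hparams.momentum)
    # colocate_gradients_with_ops keeps each clone's gradients on its GPU.
    train_op = slim.learning.create_train_op(
        total_loss, optimizer, colocate_gradients_with_ops=True)
    slim.learning.train(train_op,
                        logdir=FLAGS.train_log_dir,
                        init_fn=init_fn)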
Example #8
def main(_):
    # Check the training directory
    prepare_training_dir()

    # Build the dataset. split_name: train / test
    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)

    # Set up the model. max_sequence_length: 37, num_of_views: 4, null_code: 133
    # The model is not actually created here; this just returns the model
    # class and initializes the model-related parameters.
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    hparams = get_training_hparams()

    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    device_setter = tf.train.replica_device_setter(FLAGS.ps_tasks,
                                                   merge_devices=True)
    with tf.device(device_setter):
        # Fetch the training data
        data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=common_flags.get_crop_size())

        # Debug: print the dataset tensors to inspect them
        # print("#######################")
        # print("images:", data.images)
        # print("labels:", data.labels)
        # print(dir(data.labels))
        # print("labels_one_hot:", data.labels_one_hot.shape)
        # print("labels_0:", data.labels[0])
        # print("labels_one_host_0:", data.labels_one_hot[0])
        # print("#######################")
        # init = tf.global_variables_initializer()
        # with tf.Session() as session:
        #     session.run(init)
        #     coord = tf.train.Coordinator()
        #     threads = tf.train.start_queue_runners(coord=coord)
        #     labels = session.run(data.labels)
        #     print(labels[0])
        #     labels = session.run(data.labels_one_hot)
        #     print(labels[0])
        #     coord.request_stop()
        #     coord.join(threads)
        #     return

        # Create the model
        endpoints = model.create_base(data.images, data.labels_one_hot)
        # Create the loss function
        total_loss = model.create_loss(data, endpoints)
        model.create_summaries(data,
                               endpoints,
                               dataset.charset,
                               is_training=True)

        # Restore the weights from checkpoints
        init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint,
                                                  FLAGS.checkpoint_inception)
        if FLAGS.show_graph_stats:
            logging.info('Total number of weights in the graph: %s',
                         calculate_graph_metrics())
        train(total_loss, init_fn, hparams)
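Note on restoring weights: in these training examples FLAGS.checkpoint points
at a full previously trained attention model, while FLAGS.checkpoint_inception
restores only the Inception-v3 feature-extractor weights (for example from an
ImageNet checkpoint). create_init_fn_to_restore builds an init function that
loads whichever of the two is provided, which is the usual way to warm-start
training.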