Exemple #1
0
def main(_):
    """Run the periodic evaluation loop for the classification model.

    Builds the evaluation input pipeline and model graph from FLAGS, then hands
    control to ``slim.evaluation.evaluation_loop``, which evaluates checkpoints
    from ``FLAGS.train_dir`` and writes its output to ``FLAGS.eval_log_dir``
    (created here if it does not exist).
    """
    log_dir = FLAGS.eval_log_dir
    if not tf.gfile.Exists(log_dir):
        tf.gfile.MakeDirs(log_dir)

    eval_dataset = common_flags.create_dataset(FLAGS.dataset_name,
                                               FLAGS.dataset_split_name)
    net = common_flags.create_model(num_classes=FLAGS.num_classes)
    batch = data_provider.get_data(eval_dataset,
                                   FLAGS.model_name,
                                   FLAGS.batch_size,
                                   is_training=False,
                                   height=FLAGS.height,
                                   width=FLAGS.width)
    # The second return value (endpoints) is not needed for evaluation.
    logits, _endpoints = net.create_model(batch.images,
                                          num_classes=FLAGS.num_classes,
                                          is_training=False)
    summary_ops = net.create_summary(batch, logits, is_training=False)
    slim.get_or_create_global_step()

    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True

    loop_args = dict(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.train_dir,
        logdir=log_dir,
        eval_op=summary_ops,
        num_evals=FLAGS.num_evals,
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=FLAGS.number_of_steps,
        session_config=config)
    slim.evaluation.evaluation_loop(**loop_args)
Exemple #2
0
def main(_):
    """Build the OCR training graph from FLAGS and run training.

    Variables are placed through ``tf.train.replica_device_setter`` so the same
    code works locally (``FLAGS.ps_tasks == 0``) and with parameter servers.
    """
    prepare_training_dir()

    train_set = common_flags.create_dataset(split_name=FLAGS.split_name)
    ocr_model = common_flags.create_model(train_set.num_char_classes,
                                          train_set.max_sequence_length,
                                          train_set.num_of_views,
                                          train_set.null_code)
    hparams = get_training_hparams()

    # With ps_tasks == 0 everything stays on the local device; with multiple
    # (non-local) replicas the setter distributes variables across devices.
    setter = tf.train.replica_device_setter(FLAGS.ps_tasks, merge_devices=True)
    with tf.device(setter):
        inputs = data_provider.get_data(
            train_set,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=common_flags.get_crop_size())
        endpoints = ocr_model.create_base(inputs.images, inputs.labels_one_hot)
        loss = ocr_model.create_loss(inputs, endpoints)
        ocr_model.create_summaries(inputs, endpoints, train_set.charset,
                                   is_training=True)
        restore_fn = ocr_model.create_init_fn_to_restore(
            FLAGS.checkpoint, FLAGS.checkpoint_inception)
        if FLAGS.show_graph_stats:
            tf.logging.info('Total number of weights in the graph: %s',
                            calculate_graph_metrics())
        train(loss, restore_fn, hparams)
Exemple #3
0
def main(_):
    """Evaluate OCR checkpoints from FLAGS.train_log_dir once, on CPU.

    Runs ``slim.evaluation.evaluation_loop`` with
    ``max_number_of_evaluations=1`` and writes its output to
    ``FLAGS.eval_log_dir`` (created here if missing).
    """
    eval_dir = FLAGS.eval_log_dir
    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    ds = common_flags.create_dataset(split_name=FLAGS.split_name)
    ocr = common_flags.create_model(ds.num_char_classes,
                                    ds.max_sequence_length,
                                    ds.num_of_views,
                                    ds.null_code)
    batch = data_provider.get_data(
        ds,
        FLAGS.batch_size,
        augment=False,
        central_crop_size=common_flags.get_crop_size())
    endpoints = ocr.create_base(batch.images, labels_one_hot=None)
    # Called for its effect on the graph; the returned value is not used.
    ocr.create_loss(batch, endpoints)
    metric_ops = ocr.create_summaries(batch, endpoints, ds.charset,
                                      is_training=False)
    slim.get_or_create_global_step()

    # device_count={"GPU": 0} exposes no GPU devices, so evaluation runs on CPU.
    cfg = tf.ConfigProto(device_count={"GPU": 0})
    cfg.gpu_options.allow_growth = True
    cfg.log_device_placement = False

    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.train_log_dir,
        logdir=eval_dir,
        eval_op=metric_ops,
        num_evals=FLAGS.num_batches,
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=1,
        session_config=cfg)
Exemple #4
0
def main(_):
    """Evaluate classification checkpoints from FLAGS.train_dir in a loop.

    The evaluation output directory ``FLAGS.eval_log_dir`` is created if it
    does not exist; at most ``FLAGS.number_of_steps`` evaluations are run.
    """
    eval_dir = FLAGS.eval_log_dir
    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    split = common_flags.create_dataset(FLAGS.dataset_name,
                                        FLAGS.dataset_split_name)
    classifier = common_flags.create_model(num_classes=FLAGS.num_classes)
    inputs = data_provider.get_data(split,
                                    FLAGS.model_name,
                                    FLAGS.batch_size,
                                    is_training=False,
                                    height=FLAGS.height,
                                    width=FLAGS.width)
    logits, _unused_endpoints = classifier.create_model(
        inputs.images, num_classes=FLAGS.num_classes, is_training=False)
    eval_ops = classifier.create_summary(inputs, logits, is_training=False)
    slim.get_or_create_global_step()

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True  # grow GPU memory on demand

    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.train_dir,
        logdir=eval_dir,
        eval_op=eval_ops,
        num_evals=FLAGS.num_evals,
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=FLAGS.number_of_steps,
        session_config=gpu_config)
Exemple #5
0
def main(_):
    """Evaluate one pinned OCR checkpoint on CPU and write its summaries.

    Unlike a polling evaluation loop, this calls
    ``slim.evaluation.evaluate_once`` on ``model.ckpt-90482`` inside
    ``FLAGS.train_log_dir``. Output goes to ``FLAGS.eval_log_dir``, which is
    created if missing.
    """
    if not tf.gfile.Exists(FLAGS.eval_log_dir):
        tf.gfile.MakeDirs(FLAGS.eval_log_dir)

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=False,
        central_crop_size=common_flags.get_crop_size())
    # Diagnostic output goes through the logger instead of stray prints.
    tf.logging.info('Evaluation input images: %s', data.images)
    endpoints = model.create_base(data.images, labels_one_hot=None)
    # Called for its effect on the graph; the returned value is not used.
    model.create_loss(data, endpoints)
    eval_ops = model.create_summaries(data,
                                      endpoints,
                                      dataset.charset,
                                      is_training=False)
    slim.get_or_create_global_step()
    # No GPU devices are exposed, so evaluation runs on CPU.
    session_config = tf.ConfigProto(device_count={"GPU": 0})
    # NOTE(review): the checkpoint step is hard-coded; change it here (or
    # switch to tf.train.latest_checkpoint) to evaluate a different one.
    checkpoint_path = "%s/model.ckpt-90482" % FLAGS.train_log_dir
    # NOTE(review): slim's evaluate_once defaults num_evals to 1, so only a
    # single batch is evaluated unless num_evals is passed — confirm intent.
    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_log_dir,
        eval_op=eval_ops,
        session_config=session_config)
Exemple #6
0
def main(_):
    """Train the image classifier configured through FLAGS.

    Builds the training input pipeline, the network, its loss and summaries,
    prepares the checkpoint-restore function and the list of trainable
    variables, then calls ``train``.
    """
    prepare_training_dir()
    logging.info('dataset_name: {}, split_name: {}'.format(
        FLAGS.dataset_name, FLAGS.dataset_split_name))
    ds = common_flags.create_dataset(FLAGS.dataset_name,
                                     FLAGS.dataset_split_name)
    classifier = common_flags.create_model(num_classes=FLAGS.num_classes)

    train_data = data_provider.get_data(ds,
                                        FLAGS.model_name,
                                        batch_size=FLAGS.batch_size,
                                        is_training=True,
                                        height=FLAGS.height,
                                        width=FLAGS.width)

    # NOTE(review): the network is sized with ds.num_classes while the model
    # object above was created with FLAGS.num_classes — confirm they agree.
    logits, endpoints = classifier.create_model(
        train_data.images,
        num_classes=ds.num_classes,
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    loss = classifier.create_loss(logits, endpoints,
                                  train_data.labels_one_hot,
                                  FLAGS.label_smoothing)
    classifier.create_summary(train_data, logits, is_training=True)
    restore_fn = classifier.create_init_fn_to_restore(
        FLAGS.checkpoint_path,
        FLAGS.checkpoint_inception,
        FLAGS.checkpoint_exclude_scopes)
    trainable = classifier.get_variables_to_train(FLAGS.trainable_scopes)
    # NOTE(review): sibling scripts use FLAGS.show_graph_stats; confirm the
    # flag really is named show_graph_state here.
    if FLAGS.show_graph_state:
        logging.info('Total number of weights in the graph: %s',
                     utils.calculate_graph_metrics())
    train(loss, restore_fn, trainable)
    def initial(self):
        """Build a single-image inference graph and restore its weights.

        Reads the expected image geometry from the data pipeline and creates a
        float32 placeholder of shape (1, height, width, channels). It then
        builds the model on that placeholder, opens ``self.sess``, initializes
        tables and restores ``FLAGS.checkpoint``.

        Sets ``self.image_height``, ``self.image_width``,
        ``self.image_channel``, ``self.num_of_view``, ``self.placeholder``,
        ``self.endpoint`` and ``self.sess``.
        """
        dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
        model = common_flags.create_model(dataset.num_char_classes,
                                          dataset.max_sequence_length,
                                          dataset.num_of_views,
                                          dataset.null_code,
                                          charset=dataset.charset)
        # The data pipeline is only used here to read the input tensor shape.
        data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=False,
            central_crop_size=common_flags.get_crop_size())

        self.image_height = int(data.images.shape[1])
        self.image_width = int(data.images.shape[2])
        self.image_channel = int(data.images.shape[3])
        self.num_of_view = dataset.num_of_views
        placeholder_shape = (1, self.image_height, self.image_width,
                             self.image_channel)
        # Function-call form prints the same tuple on Python 2 and is valid
        # syntax on Python 3 (the bare `print x` statement is not).
        print(placeholder_shape)
        self.placeholder = tf.placeholder(tf.float32, shape=placeholder_shape)
        self.endpoint = model.create_base(self.placeholder,
                                          labels_one_hot=None)
        init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint)

        # NOTE(review): `config` is not defined in this method; it must be a
        # module-level tf.ConfigProto or this line raises NameError — confirm.
        self.sess = tf.Session(config=config)
        tf.tables_initializer().run(session=self.sess)
        init_fn(self.sess)
Exemple #8
0
def main(_):
  """Periodically evaluate OCR checkpoints from FLAGS.train_log_dir on CPU.

  Output is written to FLAGS.eval_log_dir, which is created if missing.
  """
  log_dir = FLAGS.eval_log_dir
  if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

  eval_set = common_flags.create_dataset(split_name=FLAGS.split_name)
  ocr_model = common_flags.create_model(eval_set.num_char_classes,
                                        eval_set.max_sequence_length,
                                        eval_set.num_of_views,
                                        eval_set.null_code)
  batch = data_provider.get_data(
      eval_set,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())
  endpoints = ocr_model.create_base(batch.images, labels_one_hot=None)
  # Called for its effect on the graph; the returned loss is not used.
  ocr_model.create_loss(batch, endpoints)
  metric_ops = ocr_model.create_summaries(batch, endpoints, eval_set.charset,
                                          is_training=False)
  slim.get_or_create_global_step()

  # No GPU devices are exposed, so evaluation runs on CPU only.
  cpu_only = tf.ConfigProto(device_count={"GPU": 0})
  slim.evaluation.evaluation_loop(
      master=FLAGS.master,
      checkpoint_dir=FLAGS.train_log_dir,
      logdir=log_dir,
      eval_op=metric_ops,
      num_evals=FLAGS.num_batches,
      eval_interval_secs=FLAGS.eval_interval_secs,
      max_number_of_evaluations=FLAGS.number_of_steps,
      session_config=cpu_only)
Exemple #9
0
def string_to_int64(raw_labels, batch_size=32, seq_length=37):
    """Map a batch of label strings to lists of integer character ids.

    Args:
      raw_labels: Indexable batch of labels; each element is turned into a
        list of characters with ``list()``.
      batch_size: Number of labels to convert (was hard-coded to 32).
      seq_length: Number of characters converted per label (was hard-coded
        to 37).

    Returns:
      A list of ``batch_size`` lists, each holding ``seq_length`` ids looked up
      in the inverse of the dataset's ``charset``.

    Raises:
      IndexError: If ``raw_labels`` has fewer than ``batch_size`` entries or a
        label is shorter than ``seq_length``.
      KeyError: If a character is not present in the dataset charset.
    """
    labels = [list(raw_labels[i]) for i in range(batch_size)]

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    # charset appears to map id -> character; invert it for char -> id lookup.
    inv_charset = {char: char_id for char_id, char in dataset.charset.items()}

    return [[inv_charset[label[j]] for j in range(seq_length)]
            for label in labels]
Exemple #10
0
def load_model(checkpoint, batch_size, dataset_name):
  """Build an OCR inference graph on a float32 image placeholder.

  Args:
    checkpoint: Checkpoint passed to ``model.create_init_fn_to_restore``.
    batch_size: Leading dimension of the input placeholder.
    dataset_name: Passed to ``get_dataset_image_size`` to obtain the input
      (width, height).

  Returns:
    A tuple ``(images_placeholder, endpoints, init_fn)``.
  """
  width, height = get_dataset_image_size(dataset_name)
  ds = common_flags.create_dataset(split_name=FLAGS.split_name)
  ocr = common_flags.create_model(
      num_char_classes=ds.num_char_classes,
      seq_length=ds.max_sequence_length,
      num_views=ds.num_of_views,
      null_code=ds.null_code,
      charset=ds.charset)
  input_shape = [batch_size, height, width, 3]
  images_ph = tf.placeholder(tf.float32, shape=input_shape)
  endpoints = ocr.create_base(images_ph, labels_one_hot=None)
  restore_fn = ocr.create_init_fn_to_restore(checkpoint)
  return images_ph, endpoints, restore_fn
def load_model(checkpoint, batch_size, dataset_name):
    """Create the OCR inference graph and its checkpoint-restore function.

    Args:
        checkpoint: Checkpoint handed to ``create_init_fn_to_restore``.
        batch_size: Batch dimension of the float32 input placeholder.
        dataset_name: Used to look up the input (width, height).

    Returns:
        ``(images_placeholder, endpoints, init_fn)``.
    """
    width, height = get_dataset_image_size(dataset_name)
    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model_params = dict(
        num_char_classes=dataset.num_char_classes,
        seq_length=dataset.max_sequence_length,
        num_views=dataset.num_of_views,
        null_code=dataset.null_code,
        charset=dataset.charset)
    model = common_flags.create_model(**model_params)
    placeholder = tf.placeholder(tf.float32,
                                 shape=[batch_size, height, width, 3])
    # Tuple elements evaluate left to right: graph first, then restore fn.
    return (placeholder,
            model.create_base(placeholder, labels_one_hot=None),
            model.create_init_fn_to_restore(checkpoint))
def create_model(batch_size, dataset_name):
  """Build an OCR inference graph that accepts raw uint8 images.

  A uint8 placeholder of shape [batch_size, height, width, 3] is mapped
  element-wise through ``data_provider.preprocess_image`` (producing float32)
  before the model is built on it.

  Returns:
    A tuple ``(raw_images, endpoints)``.
  """
  width, height = get_dataset_image_size(dataset_name)
  ds = common_flags.create_dataset(split_name=FLAGS.split_name)
  model_kwargs = dict(
      num_char_classes=ds.num_char_classes,
      seq_length=ds.max_sequence_length,
      num_views=ds.num_of_views,
      null_code=ds.null_code,
      charset=ds.charset)
  ocr = common_flags.create_model(**model_kwargs)
  raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
  float_images = tf.map_fn(data_provider.preprocess_image, raw_images,
                           dtype=tf.float32)
  return raw_images, ocr.create_base(float_images, labels_one_hot=None)
Exemple #13
0
 def test_provided_data_has_correct_shape(self):
     """Batches from the flowers 'train' split have the requested shapes."""
     dataset_name = 'flowers'
     model_name = 'inception_v3'
     batch_size = 4
     for split_name in ['train']:
         # is_training is enabled only for the 'train' split.
         is_training = split_name == 'train'
         dataset = common_flags.create_dataset(dataset_name, split_name)
         data = data_provider.get_data(dataset,
                                       model_name,
                                       batch_size=batch_size,
                                       is_training=is_training,
                                       height=224,
                                       width=224)
         with self.test_session() as sess, slim.queues.QueueRunners(sess):
             images_np, labels_np = sess.run([data.images, data.labels])
         self.assertEqual(images_np.shape, (batch_size, 224, 224, 3))
         # Labels are fetched alongside images; their batch dim must match.
         self.assertEqual(labels_np.shape[0], batch_size)
Exemple #14
0
    def create_input_feed(self, graph_def, serving):
        """Return a feed_dict of two random images for inference.

        Random uint8 images sized by the dataset's ``image_shape`` are stored
        in ``self.images`` under 'img1' and 'img2'. They are fed as serialized
        tf.Example strings when ``serving`` is true, otherwise as a stacked
        image batch.

        Args:
          graph_def: Graph definition of the loaded model; its
            ``signature_def`` map is read.
          serving: Whether the model was exported for Serving.

        Returns:
          The feed_dict suitable for model inference, keyed by the input
          tensor name from the default serving signature.
        """
        # Dataset for split 'test'; only its image_shape is used here.
        self.dataset = common_flags.create_dataset('test')
        shape = self.dataset.image_shape
        # Random images let inference be tested for any dataset.
        first = np.random.uniform(low=64, high=192, size=shape).astype('uint8')
        second = np.random.uniform(low=32, high=224,
                                   size=shape).astype('uint8')
        self.images = {'img1': first, 'img2': second}

        signature_def = graph_def.signature_def[
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        if not serving:
            # Direct-use export: inputs['images'] = 'original_image:0'
            input_name = signature_def.inputs['images'].name
            batch = np.stack([self.images['img1'], self.images['img2']])
            return {input_name: batch}

        # Serving export: inputs['inputs'] = 'tf_example:0'
        input_name = signature_def.inputs[tf.saved_model.CLASSIFY_INPUTS].name
        examples = [_create_tf_example_string(self.images[key])
                    for key in ('img1', 'img2')]
        return {input_name: examples}
def main(_):
    """Restore FLAGS.checkpoint into an inference graph and re-export it.

    The model is built on a uint8 placeholder of fixed shape (1, 50, 750, 3)
    mapped through ``data_provider.preprocess_image``. After restoring, the
    graph is written to './train.pbtxt' and the variables to 'ckpt/model'.
    """
    ds = common_flags.create_dataset(split_name=FLAGS.split_name)
    ocr = common_flags.create_model(ds.num_char_classes,
                                    ds.max_sequence_length,
                                    ds.num_of_views,
                                    ds.null_code)
    # Fixed input geometry (batch_size, height, width, channel); adjust here.
    image_input = tf.placeholder(dtype=tf.uint8, shape=[1, 50, 750, 3])
    float_images = tf.map_fn(data_provider.preprocess_image, image_input,
                             dtype=tf.float32)
    endpoints = ocr.create_base(float_images, labels_one_hot=None)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, FLAGS.checkpoint)
        print(image_input, endpoints.predicted_chars)
        tf.train.write_graph(sess.graph_def, '.', 'train.pbtxt')
        saver.save(sess, 'ckpt/model')
Exemple #16
0
def main(_):
  """Run the OCR model on FLAGS.input_image and print the predicted text.

  The image is converted to RGB and resized to one view's size: the pipeline
  width divided by ``dataset.num_of_views``, by the pipeline height. It is
  repeated across all views to fill the model input, then fed through a graph
  restored from FLAGS.checkpoint.
  """
  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code,
                                    charset=dataset.charset)
  # The data pipeline is only used here to read the expected input geometry.
  data = data_provider.get_data(
      dataset,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())

  # int() + floor division hands PIL plain ints on Python 2 and 3 (TF1's
  # Dimension does not support true division with `/` on Python 3).
  view_width = int(data.images.shape[2]) // dataset.num_of_views
  view_height = int(data.images.shape[1])
  input_image = Image.open(FLAGS.input_image).convert("RGB").resize(
      (view_width, view_height))
  # NOTE(review): pixels are fed as raw 0-255 float32 values without
  # data_provider.preprocess_image — confirm create_base expects that.
  input_array = np.array(input_image).astype(np.float32)
  # The placeholder spans all views, so repeat the single view along the
  # width axis; otherwise the feed shape mismatches when num_of_views > 1.
  input_array = np.concatenate([input_array] * dataset.num_of_views, axis=1)
  input_array = np.expand_dims(input_array, axis=0)  # add batch dimension
  print(input_array.shape)
  print(input_array.dtype)

  placeholder_shape = (1, data.images.shape[1], data.images.shape[2],
                       data.images.shape[3])
  print(placeholder_shape)
  image_placeholder = tf.placeholder(tf.float32, shape=placeholder_shape)
  endpoints = model.create_base(image_placeholder, labels_one_hot=None)
  init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint)
  with tf.Session() as sess:
    tf.tables_initializer().run()  # required by the CharsetMapper
    init_fn(sess)
    predictions = sess.run(endpoints.predicted_text,
                           feed_dict={image_placeholder: input_array})
  print("Predicted strings:")
  for line in predictions:
    print(line)
Exemple #17
0
def main(_):
    """Build the OCR training graph from FLAGS and start training.

    Prepares the training directory and creates the dataset and model. Under
    a replica device setter it then builds the input pipeline, model
    endpoints, loss and summaries, and finally calls ``train``.
    """
    # Check / prepare the training directory.
    prepare_training_dir()

    # Build the dataset; split_name is e.g. 'train' or 'test'.
    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)

    # Create the model object. The original author noted
    # max_sequence_length=37, num_of_views=4 and null_code=133 for their
    # dataset. Per the original note, this call does not build the graph yet;
    # it only returns the model object with its parameters initialized.
    ocr_model = common_flags.create_model(dataset.num_char_classes,
                                          dataset.max_sequence_length,
                                          dataset.num_of_views,
                                          dataset.null_code)
    hparams = get_training_hparams()

    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    placement = tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)
    with tf.device(placement):
        # Fetch the training batches.
        batch = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=common_flags.get_crop_size())

        # Build the model graph on the batch.
        endpoints = ocr_model.create_base(batch.images, batch.labels_one_hot)
        # Build the loss.
        loss = ocr_model.create_loss(batch, endpoints)
        ocr_model.create_summaries(batch, endpoints, dataset.charset,
                                   is_training=True)

        # Function that restores weights from the given checkpoints.
        restore_fn = ocr_model.create_init_fn_to_restore(
            FLAGS.checkpoint, FLAGS.checkpoint_inception)
        if FLAGS.show_graph_stats:
            logging.info('Total number of weights in the graph: %s',
                         calculate_graph_metrics())
        train(loss, restore_fn, hparams)
Exemple #18
0
def export_model(export_dir,
                 export_for_serving,
                 batch_size=None,
                 crop_image_width=None,
                 crop_image_height=None):
    """Exports a model to the named directory.

    Note that --dataset_name and --checkpoint are required and parsed by the
    underlying module common_flags.

    Args:
      export_dir: The output dir where model is exported to.
      export_for_serving: If True, expects a serialized image as input and
        attaches image normalization as part of the exported graph.
      batch_size: For non-serving export, the input batch_size needs to be
        specified.
      crop_image_width: Width of the input image. Uses the dataset default if
        None.
      crop_image_height: Height of the input image. Uses the dataset default
        if None.

    Returns:
      Returns the model signature_def.

    Raises:
      ValueError: If the dataset's charset file does not exist.
    """
    # Dataset object used only to get all parameters for the model.
    dataset = common_flags.create_dataset(split_name='test')
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views,
                                      dataset.null_code,
                                      charset=dataset.charset)
    dataset_image_height, dataset_image_width, image_depth = dataset.image_shape

    # Check for the charmap file before building the graph. The message names
    # the missing file path (charset_file), not the charset mapping itself.
    if not os.path.exists(dataset.charset_file):
        raise ValueError('No charset defined at {}: export will fail'.format(
            dataset.charset_file))

    # Default to dataset dimensions, otherwise use provided dimensions.
    image_width = crop_image_width or dataset_image_width
    image_height = crop_image_height or dataset_image_height

    if export_for_serving:
        images_orig = tf.placeholder(tf.string,
                                     shape=[batch_size],
                                     name='tf_example')
        images_orig_float = model_export_lib.generate_tfexample_image(
            images_orig,
            image_height,
            image_width,
            image_depth,
            name='float_images')
    else:
        images_shape = (batch_size, image_height, image_width, image_depth)
        images_orig = tf.placeholder(tf.uint8,
                                     shape=images_shape,
                                     name='original_image')
        images_orig_float = tf.image.convert_image_dtype(images_orig,
                                                         dtype=tf.float32,
                                                         name='float_images')

    endpoints = model.create_base(images_orig_float, labels_one_hot=None)

    sess = tf.Session()
    # The session is local to this function and nothing returned depends on
    # it, so close it on every path (success or failure) instead of leaking.
    try:
        saver = tf.train.Saver(slim.get_variables_to_restore(), sharded=True)
        saver.restore(sess, get_checkpoint_path())
        tf.logging.info('Model restored successfully.')

        # Create model signature.
        if export_for_serving:
            input_tensors = {
                tf.saved_model.signature_constants.CLASSIFY_INPUTS: images_orig
            }
        else:
            input_tensors = {'images': images_orig}
        signature_inputs = model_export_lib.build_tensor_info(input_tensors)
        # NOTE: Tensors 'images_float' and 'chars_logit' are used by the
        # inference or to compute saliency maps.
        output_tensors = {
            'images_float': images_orig_float,
            'predictions': endpoints.predicted_chars,
            'scores': endpoints.predicted_scores,
            'chars_logit': endpoints.chars_logit,
            'predicted_length': endpoints.predicted_length,
            'predicted_text': endpoints.predicted_text,
            'predicted_conf': endpoints.predicted_conf,
            'normalized_seq_conf': endpoints.normalized_seq_conf
        }
        for i, t in enumerate(
                model_export_lib.attention_ocr_attention_masks(
                    dataset.max_sequence_length)):
            output_tensors['attention_mask_%d' % i] = t
        signature_outputs = model_export_lib.build_tensor_info(output_tensors)
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            signature_inputs, signature_outputs,
            tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)
        # Save model.
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_def
            },
            main_op=tf.tables_initializer(),
            strip_default_attrs=True)
        builder.save()
    finally:
        sess.close()
    tf.logging.info('Model has been exported to %s', export_dir)

    return signature_def
def main(_):
    """Train the OCR model with data-parallel clones on FLAGS.num_clones GPUs.

    Input batches are produced on the CPU and buffered in a prefetch queue
    with capacity 2 * num_clones. Each clone ``i`` dequeues its own batch,
    builds the model on ``/gpu:i`` (reusing clone 0's variables for i > 0)
    and contributes ``(loss, i)`` to ``train_multigpu``.

    Raises:
      ValueError: If FLAGS.num_clones is smaller than 1.
    """
    if FLAGS.num_clones < 1:
        # With no clones no model (and no init_fn) would ever be built.
        raise ValueError('num_clones must be >= 1, got %d' % FLAGS.num_clones)

    prepare_training_dir()

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    hparams = get_training_hparams()

    # The input pipeline and its prefetch queue are pinned to the CPU; each
    # GPU clone dequeues from this shared queue.
    with tf.device("/cpu:0"):
        provider = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=common_flags.get_crop_size())
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [provider.images, provider.images_orig, provider.labels,
             provider.labels_one_hot],
            capacity=2 * FLAGS.num_clones)

    losses = []
    # range (not the Python-2-only xrange) keeps this runnable on Python 3.
    for i in range(FLAGS.num_clones):
        with tf.name_scope("clone_{0}".format(i)):
            with tf.device("/gpu:{0}".format(i)):
                images, images_orig, labels, labels_one_hot = (
                    batch_queue.dequeue())
                if i == 0:
                    endpoints = model.create_base(images, labels_one_hot)
                else:
                    # Later clones share the variables created by clone 0.
                    endpoints = model.create_base(images,
                                                  labels_one_hot,
                                                  reuse=True)
                # NOTE(review): rebuilt per clone; only the last clone's
                # init_fn reaches train_multigpu.
                init_fn = model.create_init_fn_to_restore(
                    FLAGS.checkpoint, FLAGS.checkpoint_inception)
                if FLAGS.show_graph_stats:
                    logging.info('Total number of weights in the graph: %s',
                                 calculate_graph_metrics())

                data = InputEndpoints(images=images,
                                      images_orig=images_orig,
                                      labels=labels,
                                      labels_one_hot=labels_one_hot)

                # The first (total) loss is not used per clone.
                _, single_model_loss = model.create_loss(data, endpoints)
                losses.append((single_model_loss, i))
                with tf.device("/cpu:0"):
                    # The enclosing clone_<i> name scope keeps tags distinct.
                    tf.summary.scalar('model_loss', single_model_loss)
                    model.create_summaries_multigpu(data,
                                                    endpoints,
                                                    dataset.charset,
                                                    i,
                                                    is_training=True)
    train_multigpu(losses, init_fn, hparams)