Code Example #1
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_biasCNN.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        dataset_val = dataset_biasCNN.get_dataset(FLAGS.dataset_name,
                                                  'validation',
                                                  FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        network_fn_val = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_biasCNN.get_preprocessing(
            preprocessing_name,
            is_training=True,
            flipLR=FLAGS.flipLR,
            random_scale=FLAGS.random_scale,
            is_windowed=FLAGS.is_windowed)

        image_preprocessing_fn_val = preprocessing_biasCNN.get_preprocessing(
            preprocessing_name,
            is_training=False,
            flipLR=FLAGS.flipLR,
            random_scale=FLAGS.random_scale,
            is_windowed=FLAGS.is_windowed)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ############################################
        # Create a provider for the validation set #
        ############################################
        provider_val = slim.dataset_data_provider.DatasetDataProvider(
            dataset_val,
            shuffle=True,
            common_queue_capacity=2 * FLAGS.batch_size_val,
            common_queue_min=FLAGS.batch_size_val)
        [image_val, label_val] = provider_val.get(['image', 'label'])
        label_val -= FLAGS.labels_offset

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        image_val = image_preprocessing_fn_val(image_val, eval_image_size,
                                               eval_image_size)

        images_val, labels_val = tf.train.batch(
            [image_val, label_val],
            batch_size=FLAGS.batch_size_val,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size_val)

        ###############################
        # Define the model (training) #
        ###############################

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()

            with tf.variable_scope('my_scope'):
                logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            tf.contrib.quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        #################################
        # Define the model (validation) #
        #################################

        with tf.variable_scope('my_scope', reuse=True):
            logits_val, _ = network_fn_val(images_val)

        predictions_val = tf.argmax(logits_val, 1)
        labels_val = tf.squeeze(labels_val)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions_val, labels_val),
            'Recall_5':
            slim.metrics.streaming_recall_at_k(logits_val, labels_val, 5),
        })

        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection('summaries', op)

        # Gather validation summaries
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Create a non-default saver so we don't delete all the old checkpoints.
        my_saver = tf_saver.Saver(
            max_to_keep=FLAGS.max_checkpoints_to_keep,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        )

        # Create a non-default dictionary of options for train_step_fn.
        # This is a hack that lets us pass everything we need to run
        # evaluation into the training loop function.
        from tensorflow.python.framework import ops
        from tensorflow.python.framework import constant_op
        from tensorflow.python.ops import math_ops

        with ops.name_scope('train_step'):
            train_step_kwargs = {}

            if FLAGS.max_number_of_steps:
                should_stop_op = math_ops.greater_equal(
                    global_step, FLAGS.max_number_of_steps)
            else:
                should_stop_op = constant_op.constant(False)
            train_step_kwargs['should_stop'] = should_stop_op
            if FLAGS.log_every_n_steps > 0:
                train_step_kwargs['should_log'] = math_ops.equal(
                    math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)
            train_step_kwargs['should_val'] = math_ops.equal(
                math_ops.mod(global_step, FLAGS.val_every_n_steps), 0)
            train_step_kwargs['eval_op'] = list(names_to_updates.values())


        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            saver=my_saver,
            train_step_fn=learning_biasCNN.train_step_fn,
            train_step_kwargs=train_step_kwargs)
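The train_step_fn passed to slim.learning.train above comes from the project-local learning_biasCNN module, which is not shown here. A minimal sketch of what such a function might look like, assuming it simply wraps slim.learning.train_step and runs the metric update ops whenever the 'should_val' condition fires (all names mirror the train_step_kwargs built above):

from tensorflow.contrib import slim


def train_step_fn(session, train_op, global_step, train_step_kwargs):
    # Delegate the actual optimization step to TF-Slim; it returns the loss
    # and whether training should stop.
    total_loss, should_stop = slim.learning.train_step(
        session, train_op, global_step, train_step_kwargs)

    # Periodically run the streaming-metric update ops ('eval_op') on the
    # validation batch, as signalled by 'should_val'.
    if 'should_val' in train_step_kwargs and session.run(
            train_step_kwargs['should_val']):
        session.run(train_step_kwargs['eval_op'])

    return total_loss, should_stop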
Code Example #2
File: estimator.py Project: lengjia/RRL
    def export_savedmodel(self,
                          export_dir_base,
                          serving_input_receiver_fn,
                          assets_extra=None,
                          as_text=False,
                          checkpoint_path=None):
        """Exports inference graph as a SavedModel into given dir.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session.  Finally it creates
    a timestamped export directory below the given export_dir_base, and writes
    a `SavedModel` into it containing a single `MetaGraphDef` saved from this
    session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named using
    the same keys.  One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided by
    the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the assets_extra
    argument.  This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory.  The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported SavedModels.
      serving_input_receiver_fn: A function that takes no argument and
        returns a `ServingInputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported SavedModel, or `None` if no extra assets are needed.
      as_text: whether to write the SavedModel proto in text format.
      checkpoint_path: The checkpoint path to export.  If `None` (the default),
        the most recent checkpoint found within the model directory is chosen.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if no serving_input_receiver_fn is provided, no export_outputs
          are provided, or no checkpoint can be found.
    """
        if serving_input_receiver_fn is None:
            raise ValueError('serving_input_receiver_fn must be defined.')

        with ops.Graph().as_default() as g:
            self._create_and_assert_global_step(g)
            random_seed.set_random_seed(self._config.tf_random_seed)
            serving_input_receiver = serving_input_receiver_fn()

            # Call the model_fn and collect the export_outputs.
            estimator_spec = self._call_model_fn(
                features=serving_input_receiver.features,
                labels=None,
                mode=model_fn_lib.ModeKeys.PREDICT)

            # Build the SignatureDefs from receivers and all outputs
            signature_def_map = build_all_signature_defs(
                serving_input_receiver.receiver_tensors,
                estimator_spec.export_outputs)

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(self._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." %
                                 self._model_dir)

            export_dir = get_timestamped_export_dir(export_dir_base)
            temp_export_dir = get_temp_export_dir(export_dir)

            # TODO(soergel): Consider whether MonitoredSession makes sense here
            with tf_session.Session() as session:

                saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
                    sharded=True)
                saver_for_restore.restore(session, checkpoint_path)

                # TODO(b/36111876): replace legacy_init_op with main_op mechanism
                # pylint: disable=protected-access
                local_init_op = (
                    estimator_spec.scaffold.local_init_op
                    or monitored_session.Scaffold._default_local_init_op())
                # pylint: enable=protected-access

                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(
                    temp_export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=local_init_op)
                builder.save(as_text)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(
                    compat.as_bytes(temp_export_dir),
                    compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(
                        compat.as_bytes(assets_extra_path),
                        compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    gfile.MakeDirs(dest_path)
                    gfile.Copy(source, dest_absolute)

            gfile.Rename(temp_export_dir, export_dir)
            return export_dir
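A hedged usage sketch for the method above; the estimator instance (my_estimator), the feature name and the input shape are illustrative assumptions, not part of the original code:

import tensorflow as tf


def serving_input_receiver_fn():
    # Raw tensor input for serving; the feature spec here is hypothetical.
    features = {'x': tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')}
    return tf.estimator.export.ServingInputReceiver(features, features)


# my_estimator is assumed to be a previously constructed tf.estimator.Estimator.
export_path = my_estimator.export_savedmodel(
    export_dir_base='/tmp/exports',
    serving_input_receiver_fn=serving_input_receiver_fn,
    assets_extra={'my_asset_file.txt': '/path/to/my_asset_file.txt'})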
Code Example #3
def make_saver():
    return tf_saver.Saver(sharded=True,
                          max_to_keep=keep_checkpoint_max,
                          defer_build=True)
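With defer_build=True the Saver captures no variables at construction time; calling build() (or the first save) finalizes it once the graph is complete. A small usage sketch, assuming keep_checkpoint_max is in scope:

saver = make_saver()
# ... more variables may still be added to the graph here ...
saver.build()  # Resolves the variable list against the now-complete graph.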
Code Example #4
File: freeze_graph.py Project: zzhuuh2/tensorflow
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not checkpoint_management.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is heuristic
            # based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    print(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                    return -1
                else:
                    raise e

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
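An illustrative invocation of the function above; all paths and node names are placeholders, and the GraphDef is assumed to live in a binary .pb file next to a V2 checkpoint prefix:

from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile

graph_def = graph_pb2.GraphDef()
with gfile.GFile('/tmp/model/graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

frozen = freeze_graph_with_def_protos(
    input_graph_def=graph_def,
    input_saver_def=None,
    input_checkpoint='/tmp/model/model.ckpt-1000',
    output_node_names='logits',
    restore_op_name=None,  # Unused by the updated loading code above.
    filename_tensor_name=None,  # Likewise unused.
    output_graph='/tmp/model/frozen.pb',
    clear_devices=True,
    initializer_nodes='')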
Code Example #5
    def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
        logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
        if gfile.Exists(logdir):  # For running on jenkins.
            gfile.DeleteRecursively(logdir)

        # First, train only the weights of the model.
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            total_loss = self.ModelLoss()
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            weights = variables_lib.get_variables_by_name('weights')

            train_op = training.create_train_op(total_loss,
                                                optimizer,
                                                variables_to_train=weights)

            saver = saver_lib.Saver()
            loss = training.train(
                train_op,
                logdir,
                hooks=[
                    basic_session_run_hooks.CheckpointSaverHook(logdir,
                                                                save_steps=200,
                                                                saver=saver),
                    basic_session_run_hooks.StopAtStepHook(num_steps=200),
                ],
                save_checkpoint_secs=None,
                save_summaries_steps=None)
            self.assertGreater(loss, .015)
            self.assertLess(loss, .05)

        # Next, train the biases of the model.
        with ops.Graph().as_default():
            random_seed.set_random_seed(1)
            total_loss = self.ModelLoss()
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            biases = variables_lib.get_variables_by_name('biases')

            train_op = training.create_train_op(total_loss,
                                                optimizer,
                                                variables_to_train=biases)

            saver = saver_lib.Saver()
            loss = training.train(
                train_op,
                logdir,
                hooks=[
                    basic_session_run_hooks.CheckpointSaverHook(logdir,
                                                                save_steps=300,
                                                                saver=saver),
                    basic_session_run_hooks.StopAtStepHook(num_steps=300),
                ],
                save_checkpoint_secs=None,
                save_summaries_steps=None)
            self.assertGreater(loss, .015)
            self.assertLess(loss, .05)

        # Finally, train both weights and bias to get lower loss.
        with ops.Graph().as_default():
            random_seed.set_random_seed(2)
            total_loss = self.ModelLoss()
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = training.create_train_op(total_loss, optimizer)
            saver = saver_lib.Saver()
            loss = training.train(
                train_op,
                logdir,
                hooks=[
                    basic_session_run_hooks.StopAtStepHook(num_steps=400),
                ],
                save_checkpoint_secs=None,
                save_summaries_steps=None)
            self.assertIsNotNone(loss)
            self.assertLess(loss, .015)
Code Example #6
    def testInitCrossedColumnWeightsFromCkpt(self):
        sparse_col_1 = fc.sparse_column_with_hash_bucket(column_name="col_1",
                                                         hash_bucket_size=4)
        sparse_col_2 = fc.sparse_column_with_hash_bucket(column_name="col_2",
                                                         hash_bucket_size=4)

        crossed_col = fc.crossed_column(columns=[sparse_col_1, sparse_col_2],
                                        hash_bucket_size=4)

        input_tensor = sparse_tensor_lib.SparseTensor(indices=[[0, 0], [1, 1],
                                                               [2, 2], [3, 3]],
                                                      values=[0, 1, 2, 3],
                                                      dense_shape=[4, 4])

        # Invoking 'weighted_sum_from_feature_columns' will create the crossed
        # column weights variable.
        with variable_scope.variable_scope("run_1"):
            with variable_scope.variable_scope(crossed_col.name):
                # Returns looked up column weights which is same as crossed column
                # weights as well as actual references to weights variables.
                _, col_weights, _ = (
                    feature_column_ops.weighted_sum_from_feature_columns(
                        {
                            sparse_col_1.name: input_tensor,
                            sparse_col_2.name: input_tensor
                        }, [crossed_col], 1))
                # Update the weights since default initializer initializes all weights
                # to 0.0.
                for weight in col_weights.values():
                    assign_op = state_ops.assign(weight[0], weight[0] + 0.5)

        save = saver.Saver()
        ckpt_dir_prefix = os.path.join(self.get_temp_dir(),
                                       "init_crossed_col_w_from_ckpt")
        ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
        checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")

        with self.test_session() as sess:
            sess.run(variables.global_variables_initializer())
            sess.run(assign_op)
            saved_col_weights = col_weights[crossed_col][0].eval()
            save.save(sess, checkpoint_path)

        crossed_col_initialized = fc.crossed_column(
            columns=[sparse_col_1, sparse_col_2],
            hash_bucket_size=4,
            ckpt_to_load_from=checkpoint_path,
            tensor_name_in_ckpt=("run_1/col_1_X_col_2/"
                                 "weighted_sum_from_feature_columns/"
                                 "col_1_X_col_2/weights"))

        with variable_scope.variable_scope("run_2"):
            # This will initialize the crossed column weights from provided checkpoint
            # and return a [4, 1] tensor which is same as weights variable. Since we
            # won't modify weights, this should be same as 'saved_col_weights'.
            _, col_weights, _ = (
                feature_column_ops.weighted_sum_from_feature_columns(
                    {
                        sparse_col_1.name: input_tensor,
                        sparse_col_2.name: input_tensor
                    }, [crossed_col_initialized], 1))
            col_weights_from_ckpt = col_weights[crossed_col_initialized][0]

        with self.test_session() as sess:
            sess.run(variables.global_variables_initializer())
            loaded_col_weights = col_weights_from_ckpt.eval()

        self.assertAllClose(saved_col_weights, loaded_col_weights)
Code Example #7
File: quantize_graph.py Project: Xilinx/Vitis-AI
def CreateQuantizeDeployGraph(graph=None, checkpoint='', config=None):
  """Python wrapper for the decent_q create deploy graph tool.

  Args:
    graph: the graph to be quantized, default graph will be used if set None.
    checkpoint: the checkpoint path
    config: the QuantizeConfig

  Returns:
    Transformed Graph(as default) for quantize deploy.
  """
  if config is None:
    raise ValueError("Please set the QuantizeConfig.")
  elif not isinstance(config, QuantizeConfig):
    raise ValueError("Config shoulb be a QuantizeConfig object.")

  # Create the output_dir
  if not os.path.exists(config.output_dir):
    try:
      os.makedirs(config.output_dir)
    except Exception as e:
      print(e)

  if graph is None:
    graph = get_default_graph()
  quantize_eval_graph_def = graph.as_graph_def()

  if os.path.isdir(checkpoint):
    checkpoint = checkpoint_management.latest_checkpoint(checkpoint)
  print("INFO: Creating quantize eval model from: {}".format(checkpoint))
  step_in_ckpt = checkpoint.rsplit("-")[-1]

  # Freeze the checkpoint into the graph
  config.output_nodes = get_quantized_nodes(quantize_eval_graph_def,
                                            config.output_nodes)
  saver = saver_lib.Saver()
  with Session() as sess:
    saver.restore(sess, checkpoint)
    frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, quantize_eval_graph_def, config.output_nodes)

  # Convert folded batchnorms
  frozen_quantize_eval_graph_def = ConvertFoldedBatchnorms(
    frozen_graph_def, config)

  # Deploy
  # quantize_deploy_graph_def = CreateQuantizeDeployGraphDef(
  #   frozen_quantize_eval_graph_def, config)

  # Save the model
  # for quantize finetune model, replace input node with placeholder
  # replaced_graph_def = frozen_quantize_eval_graph_def
  for target_node_name, shape in zip(config.input_nodes, config.input_shapes):
    frozen_quantize_eval_graph_def = SetInputNodesAsPlaceholder(
      frozen_quantize_eval_graph_def, target_node_name, shape)

  frozen_quantize_eval_path = os.path.join(
    config.output_dir, "quantize_eval_model_{}_{}.pb".format(
      step_in_ckpt, time.strftime("%Y%m%d%H%M%S", time.localtime())))
  save_pb_file(frozen_quantize_eval_graph_def, frozen_quantize_eval_path)
  print("INFO: Quantize eval model is generated in: {}".format(
    frozen_quantize_eval_path))

  # deploy_path = os.path.join(
  #   config.output_dir, "deploy_model_{}_{}.pb".format(
  #     step_in_ckpt, time.strftime("%Y%m%d%H%M%S", time.localtime())))
  # save_pb_file(quantize_deploy_graph_def, deploy_path)
  # print("INFO: Deploy model is generated in: {}".format(deploy_path))
  return
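A hedged usage sketch; the attribute names follow those the function body reads from the config (output_dir, input_nodes, input_shapes, output_nodes), but the exact QuantizeConfig constructor signature and the node names are assumptions, not the documented Vitis-AI API:

# Node names and shapes below are hypothetical.
config = QuantizeConfig(input_nodes=['input'],
                        input_shapes=[[1, 224, 224, 3]],
                        output_nodes=['resnet_v1_50/predictions/Softmax'],
                        output_dir='./quantize_results')
CreateQuantizeDeployGraph(checkpoint='./checkpoints', config=config)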
Code Example #8
def evaluate_once(master,
                  checkpoint_path,
                  logdir,
                  num_evals=1,
                  initial_op=None,
                  initial_op_feed_dict=None,
                  eval_op=None,
                  eval_op_feed_dict=None,
                  final_op=None,
                  final_op_feed_dict=None,
                  summary_op=_USE_DEFAULT,
                  summary_op_feed_dict=None,
                  variables_to_restore=None,
                  session_config=None,
                  hooks=None):
    """Evaluates the model at the given checkpoint path.

  Args:
    master: The BNS address of the TensorFlow master.
    checkpoint_path: The path to a checkpoint to use for evaluation.
    logdir: The directory where the TensorFlow summaries are written to.
    num_evals: The number of times to run `eval_op`.
    initial_op: An operation run at the beginning of evaluation.
    initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: The summary_op to evaluate after running TF-Slim's metric ops. By
      default the summary_op is set to tf.summary.merge_all().
    summary_op_feed_dict: An optional feed dictionary to use when running the
      `summary_op`.
    variables_to_restore: A list of TensorFlow variables to restore during
      evaluation. If the argument is left as `None` then
      slim.variables.GetVariablesToRestore() is used.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    hooks: A list of additional `SessionRunHook` objects to pass during the
      evaluation.

  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.
  """
    if summary_op == _USE_DEFAULT:
        summary_op = summary.merge_all()

    all_hooks = [
        evaluation.StopAfterNEvalsHook(num_evals),
    ]

    if summary_op is not None:
        all_hooks.append(
            evaluation.SummaryAtEndHook(log_dir=logdir,
                                        summary_op=summary_op,
                                        feed_dict=summary_op_feed_dict))
    if hooks is not None:
        all_hooks.extend(hooks)

    saver = None
    if variables_to_restore is not None:
        saver = tf_saver.Saver(variables_to_restore)

    return evaluation.evaluate_once(checkpoint_path,
                                    master=master,
                                    scaffold=monitored_session.Scaffold(
                                        init_op=initial_op,
                                        init_feed_dict=initial_op_feed_dict,
                                        saver=saver),
                                    eval_ops=eval_op,
                                    feed_dict=eval_op_feed_dict,
                                    final_ops=final_op,
                                    final_ops_feed_dict=final_op_feed_dict,
                                    hooks=all_hooks,
                                    config=session_config)
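An illustrative call of the wrapper above (paths and metric names hypothetical, and the predictions/labels tensors are assumed to already exist in the graph); eval_op runs num_evals times and the final_op values are returned:

names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
})
metric_values = evaluate_once(
    master='',
    checkpoint_path='/tmp/model/model.ckpt-5000',
    logdir='/tmp/eval',
    num_evals=100,
    eval_op=list(names_to_updates.values()),
    final_op=list(names_to_values.values()))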
Code Example #9
def evaluation_loop(master,
                    checkpoint_dir,
                    logdir,
                    num_evals=1,
                    initial_op=None,
                    initial_op_feed_dict=None,
                    init_fn=None,
                    eval_op=None,
                    eval_op_feed_dict=None,
                    final_op=None,
                    final_op_feed_dict=None,
                    summary_op=_USE_DEFAULT,
                    summary_op_feed_dict=None,
                    variables_to_restore=None,
                    eval_interval_secs=60,
                    max_number_of_evaluations=None,
                    session_config=None,
                    timeout=None,
                    hooks=None):
    """Runs TF-Slim's Evaluation Loop.

  Args:
    master: The BNS address of the TensorFlow master.
    checkpoint_dir: The directory where checkpoints are stored.
    logdir: The directory where the TensorFlow summaries are written to.
    num_evals: The number of times to run `eval_op`.
    initial_op: An operation run at the beginning of evaluation.
    initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
    init_fn: An optional callable to be executed after `init_op` is called. The
      callable must accept one argument, the session being initialized.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: The summary_op to evaluate after running TF-Slim's metric ops. By
      default the summary_op is set to tf.summary.merge_all().
    summary_op_feed_dict: An optional feed dictionary to use when running the
      `summary_op`.
    variables_to_restore: A list of TensorFlow variables to restore during
      evaluation. If the argument is left as `None` then
      slim.variables.GetVariablesToRestore() is used.
    eval_interval_secs: The minimum number of seconds between evaluations.
    max_number_of_evaluations: the max number of iterations of the evaluation.
      If the value is left as 'None', the evaluation continues indefinitely.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    timeout: The maximum amount of time to wait between checkpoints. If left as
      `None`, then the process will wait indefinitely.
    hooks: A list of additional `SessionRunHook` objects to pass during
      repeated evaluations.

  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.
  """
    if summary_op == _USE_DEFAULT:
        summary_op = summary.merge_all()

    all_hooks = [
        evaluation.StopAfterNEvalsHook(num_evals),
    ]

    if summary_op is not None:
        all_hooks.append(
            evaluation.SummaryAtEndHook(log_dir=logdir,
                                        summary_op=summary_op,
                                        feed_dict=summary_op_feed_dict))

    if hooks is not None:
        # Add custom hooks if provided.
        all_hooks.extend(hooks)

    saver = None
    if variables_to_restore is not None:
        saver = tf_saver.Saver(variables_to_restore)

    return evaluation.evaluate_repeatedly(
        checkpoint_dir,
        master=master,
        scaffold=monitored_session.Scaffold(
            init_op=initial_op,
            init_feed_dict=initial_op_feed_dict,
            init_fn=init_fn,
            saver=saver),
        eval_ops=eval_op,
        feed_dict=eval_op_feed_dict,
        final_ops=final_op,
        final_ops_feed_dict=final_op_feed_dict,
        eval_interval_secs=eval_interval_secs,
        hooks=all_hooks,
        config=session_config,
        max_number_of_evaluations=max_number_of_evaluations,
        timeout=timeout)
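Compared with evaluate_once above, the loop variant keeps watching checkpoint_dir; a sketch (paths hypothetical, eval_op as in the previous sketch) that re-evaluates each new checkpoint, stopping after ten evaluations or a 30-minute wait for a new checkpoint:

evaluation_loop(
    master='',
    checkpoint_dir='/tmp/train',
    logdir='/tmp/eval',
    num_evals=100,
    eval_op=list(names_to_updates.values()),
    eval_interval_secs=60,
    max_number_of_evaluations=10,
    timeout=1800)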
Code Example #10
def test(queries=list(), location='./test'):
    """
    Test the system with the given queries. For each query, generate the list of image IDs to be returned.
    :param queries: list of image IDs. Each element is assumed to be an entry in the test set; hence, the image
    with id <id> is located on my computer at './test/pics/<id>.jpg'. Make sure this is the file you work with...
    :param location: The location of the test data folder hierarchy
    :return: a dictionary with keys equal to the image IDs in the queries list, and values a list of image IDs
    retrieved for that query
    """

    model_checkpoint = './data/model.checkpoint'
    size = 128
    n_components = 40
    device = '/cpu:0'
    debug = False

    data = load_data('./train/', size, debug)

    pca, transformed_data = load_pca('./data', data, 1000, n_components)

    my_return_dict = {}

    g = tf.Graph()
    with g.as_default():
        sess = tf.Session()
        with tf.device(device):
            network = create_combined_network(n_components)
            tf_saver.Saver().restore(sess, model_checkpoint)

            query_images = []
            print('Loading images...')
            for query in queries:
                file_name = '{}/pics/{}.jpg'.format(location, query)
                try:
                    image = Image.open(file_name)
                    query_images.append(
                        np.asarray(image.resize((299, 299), Image.ANTIALIAS),
                                   dtype=np.float32))
                except Exception:
                    print('Could not open image {}'.format(file_name))
                    continue

            print('Running images through model...')
            net_outputs = network.eval(query_images, sess)

            print('Generating candidates')
            for i, query in enumerate(queries):
                net_output = net_outputs[i]
                candidates = get_top_n(net_output, transformed_data,
                                       data['train']['image_names'], 200)
                try:
                    candidate_names = [
                        c.replace('.jpg', '') for c in candidates
                    ]
                except TypeError:
                    candidates = [c.decode('UTF-8') for c in candidates]
                    candidate_names = [
                        c.replace('.jpg', '') for c in candidates
                    ]
                my_return_dict[query] = candidate_names

            return my_return_dict
Code Example #11
File: evaluation.py Project: zhufengGNSS/tensorflow
def evaluation_loop(master, checkpoint_dir, logdir, num_evals=1,
                    eval_op=None, eval_op_feed_dict=None,
                    final_op=None, final_op_feed_dict=None,
                    summary_op=_USE_DEFAULT, summary_op_feed_dict=None,
                    variables_to_restore=None,
                    eval_interval_secs=60):
  """Runs TF-Slim's Evaluation Loop.

  Args:
    master: The BNS address of the TensorFlow master.
    checkpoint_dir: The directory where checkpoints are stored.
    logdir: The directory where the TensorFlow summaries are written to.
    num_evals: The number of times to run `eval_op`.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: The summary_op to evaluate after running TF-Slim's metric ops. By
      default the summary_op is set to tf.merge_all_summaries().
    summary_op_feed_dict: An optional feed dictionary to use when running the
      `summary_op`.
    variables_to_restore: A list of TensorFlow variables to restore during
      evaluation. If the argument is left as `None` then
      slim.variables.GetVariablesToRestore() is used.
    eval_interval_secs: The minimum number of seconds between evaluations.
  """
  if summary_op == _USE_DEFAULT:
    summary_op = logging_ops.merge_all_summaries()

  global_step = variables.get_or_create_global_step()

  init_op = control_flow_ops.group(
      tf_variables.initialize_all_variables(),
      tf_variables.initialize_local_variables(),
      data_flow_ops.initialize_all_tables())

  saver = tf_saver.Saver(
      variables_to_restore or variables.get_variables_to_restore())

  summary_writer = summary_io.SummaryWriter(logdir)

  sv = supervisor.Supervisor(
      graph=ops.get_default_graph(),
      logdir=logdir,
      init_op=init_op,
      summary_op=None,
      summary_writer=None,
      global_step=None,
      saver=saver)

  last_checkpoint = None
  while True:
    last_checkpoint = wait_for_new_checkpoint(checkpoint_dir, last_checkpoint)
    start = time.time()
    logging.info(
        'Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                  time.gmtime()))

    with sv.managed_session(master, start_standard_services=False) as sess:
      sv.saver.restore(sess, last_checkpoint)
      sv.start_queue_runners(sess)
      evaluation(
          sess,
          num_evals=num_evals,
          eval_op=eval_op,
          eval_op_feed_dict=eval_op_feed_dict,
          final_op=final_op,
          final_op_feed_dict=final_op_feed_dict,
          summary_op=summary_op,
          summary_op_feed_dict=summary_op_feed_dict,
          summary_writer=summary_writer,
          global_step=global_step)

    logging.info(
        'Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                  time.gmtime()))
    time_to_next_eval = start + eval_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)
Code Example #12
def train(location='./train/'):
    """
    The training procedure is triggered here. OPTIONAL to run; everything that is required for testing the model
    must be saved to file (e.g., pickle) so that the test procedure can load, execute and report
    :param location: The location of the training data folder hierarchy
    :return: nothing
    """

    inception_checkpoint = './data/model.ckpt'
    lower_checkpoint = './data/lower_graph.checkpoint'
    combined_model_checkpoint = './data/model.checkpoint'
    size = 128
    n_components = 40
    device = '/cpu:0'
    debug = False

    data = load_data(location, size, debug)
    pca, transformed_data = load_pca('./data', data, 1000, n_components)
    second_data = {
        'train': {
            'inputs': data['train']['bottlenecks'],
            'targets': transformed_data
        }
    }

    if debug:
        second_data['validate'] = {
            'inputs':
            data['validate']['bottlenecks'],
            'targets':
            get_transformed_pca_output(
                pca.transform(data['validate']['bottlenecks']))
        }

    with tf.device(device):
        g = tf.Graph()
        with g.as_default():
            with tf.Session() as sess:
                print('Training lower network')
                network = create_network(n_components)
                network = train_network(network,
                                        second_data,
                                        n_components,
                                        sess,
                                        epochs=100,
                                        validate=debug,
                                        device=device)
                saver = tf_saver.Saver()
                print('Saving lower network')
                saver.save(sess, lower_checkpoint)

        g = tf.Graph()
        with g.as_default():
            sess = tf.Session()

            network = create_combined_network(n_components)

            inception_scope = 'InceptionV3'
            lower_scope = 'Network'

            all_vars = tf.all_variables()
            inception_vars = [
                k for k in all_vars if k.name.startswith(inception_scope)
            ]
            lower_vars = [
                k for k in all_vars if k.name.startswith(lower_scope)
            ]

            print('Loading inception_v3 network')
            tf_saver.Saver(inception_vars).restore(sess, inception_checkpoint)
            print('Loading lower network')
            tf_saver.Saver(lower_vars).restore(sess, lower_checkpoint)
            print('Saving combined network')
            tf_saver.Saver().save(sess, combined_model_checkpoint)
Code Example #13
def get_transformed_data(img_paths,
                         size,
                         checkpoint='./data/model.ckpt',
                         num_classes=6012,
                         batch_size=32,
                         name_regex=None):
    """
    Loads the supplied image_paths, runs them through the trained Inception_v3 network and returns
    a dictionary containing, for each image, its bottleneck and a resized copy of the image.

    :param img_paths: iterable containing the paths to the images to be loaded
    :param size: image width and height in pixels
    :param checkpoint: location of the Inception_v3 weights
    :param num_classes:
    :param batch_size:
    :param name_regex: regex to obtain image name from image path
    :return:
    """
    if not os.path.exists(checkpoint):
        tf.logging.fatal(
            'Checkpoint %s does not exist. See README.md for more information',
            checkpoint)
    name_regex = name_regex if name_regex is not None else re.compile(
        r'.*/.*/(.*\.jpg)')
    g = tf.Graph()
    with g.as_default():
        input_images = tf.placeholder('float32', [None, 299, 299, 3])
        transformed_inputs = tf_transform_input_img(input_images)

        with slim.arg_scope(inception_v3_arg_scope()):
            logits, end_points = inception_v3(transformed_inputs,
                                              num_classes=num_classes,
                                              is_training=False)

        bottleneck = end_points['PreLogits']
        saver = tf_saver.Saver()
        data = {
            'image_names': [],
            'images': np.empty((len(img_paths), size, size, 3),
                               dtype=np.uint8),
            'bottlenecks': np.empty((len(img_paths), 2048), dtype=np.float32)
        }
        i = 0
        with tf.Session() as sess:
            saver.restore(sess, checkpoint)
            for c in chunks(img_paths, batch_size):
                batch = []
                inner_i = 0
                for file_name in c:
                    name = name_regex.match(file_name).group(1)
                    try:
                        image = Image.open(file_name)
                    except Exception:
                        print('Could not open image {}'.format(file_name))
                        continue
                    img, resized_img = preprocess_image(image, size)
                    batch.append(img)
                    data['images'][i + inner_i] = resized_img
                    data['image_names'].append(name)
                    inner_i += 1
                feed_dict = {input_images: np.asarray(batch)}
                # Run the evaluation on the image
                bottleneck_eval = sess.run(bottleneck, feed_dict=feed_dict)
                inner_i = 0
                for bn in bottleneck_eval:
                    # Rescale the bottlenecks to the 0 - 1 range
                    data['bottlenecks'][i + inner_i] = (np.squeeze(bn) +
                                                        0.3) * (1.0 / 8.0)
                    inner_i += 1
                i += inner_i
            if i < len(img_paths):
                for key in ['images', 'bottlenecks']:
                    prev_shape = list(data[key].shape)
                    prev_shape[0] = i
                    data[key].resize(prev_shape)
            return data
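An illustrative call (the directory layout and glob pattern are assumptions); note that the default name_regex expects paths of the form <dir>/<subdir>/<name>.jpg:

import glob

img_paths = sorted(glob.glob('./train/pics/*.jpg'))
data = get_transformed_data(img_paths, size=128)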
Code Example #14
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_blacklist="",
                                 input_meta_graph=False):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""

  _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph:
      restorer = saver_lib.import_meta_graph(
          input_checkpoint + ".meta", clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))

    variable_names_blacklist = (variable_names_blacklist.split(",")
                                if variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())
  # Print a summary of the frozen graph: each node's index, name and op,
  # followed by its inputs.
  for ind, val in enumerate(output_graph_def.node):
    print(ind, val.name, val.op)
    for i, n in enumerate(val.input):
      print(u'└─── %d ─ %s' % (i, n))

  return output_graph_def
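
A minimal driver sketch for this helper, assuming a binary GraphDef serialized at graph.pb and a Saver V2 checkpoint prefix model.ckpt (all paths and node names below are placeholders):

from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile

# Load the serialized GraphDef that will be frozen.
input_graph_def = graph_pb2.GraphDef()
with gfile.GFile('graph.pb', 'rb') as f:
    input_graph_def.ParseFromString(f.read())

# Variables from the checkpoint are folded into constants in the result.
frozen = freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def=None,           # read the checkpoint directly
    input_checkpoint='model.ckpt',  # V2 checkpoint prefix
    output_node_names='output_node',
    restore_op_name='',             # unused by the updated loading code
    filename_tensor_name='',        # unused as well
    output_graph='frozen_graph.pb',
    clear_devices=True,
    initializer_nodes='')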
Code example #15
    def _testFreezeGraph(self, saver_write_version):

        checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                         "saved_checkpoint")
        checkpoint_meta_graph_file = os.path.join(self.get_temp_dir(),
                                                  "saved_checkpoint.meta")
        checkpoint_state_name = "checkpoint_state"
        input_graph_name = "input_graph.pb"
        output_graph_name = "output_graph.pb"

        # We'll create an input graph that has a single variable containing 1.0,
        # and that then multiplies it by 2.
        with ops.Graph().as_default():
            variable_node = variables.VariableV1(1.0, name="variable_node")
            output_node = math_ops.multiply(variable_node,
                                            2.0,
                                            name="output_node")
            sess = session.Session()
            init = variables.global_variables_initializer()
            sess.run(init)
            output = sess.run(output_node)
            self.assertNear(2.0, output, 0.00001)
            saver = saver_lib.Saver(write_version=saver_write_version)
            checkpoint_path = saver.save(sess,
                                         checkpoint_prefix,
                                         global_step=0,
                                         latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                 input_graph_name)

        # We save out the graph to disk, and then call the const conversion
        # routine.
        input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
        input_saver_def_path = ""
        input_binary = False
        output_node_names = "output_node"
        restore_op_name = "save/restore_all"
        filename_tensor_name = "save/Const:0"
        output_graph_path = os.path.join(self.get_temp_dir(),
                                         output_graph_name)
        clear_devices = False
        input_meta_graph = checkpoint_meta_graph_file

        freeze_graph.freeze_graph(input_graph_path,
                                  input_saver_def_path,
                                  input_binary,
                                  checkpoint_path,
                                  output_node_names,
                                  restore_op_name,
                                  filename_tensor_name,
                                  output_graph_path,
                                  clear_devices,
                                  "",
                                  "",
                                  input_meta_graph,
                                  checkpoint_version=saver_write_version)

        # Now we make sure the variable has been converted to a constant, and
        # that the graph still produces the expected result.
        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                _ = importer.import_graph_def(output_graph_def, name="")

            self.assertEqual(4, len(output_graph_def.node))
            for node in output_graph_def.node:
                self.assertNotEqual("VariableV2", node.op)
                self.assertNotEqual("Variable", node.op)

            with session.Session() as sess:
                output_node = sess.graph.get_tensor_by_name("output_node:0")
                output = sess.run(output_node)
                self.assertNear(2.0, output, 0.00001)
Code example #16
 def testStandardServicesWithGlobalStep(self):
   logdir = self._test_dir("standard_services_with_global_step")
   # Create a checkpoint.
   with ops.Graph().as_default():
     v = variables.VariableV1([123], name="global_step")
     sv = supervisor.Supervisor(logdir=logdir)
     meta_graph_def = meta_graph.create_meta_graph_def(
         saver_def=sv.saver.saver_def)
     sess = sv.prepare_or_wait_for_session("")
     # This is where the checkpoint will appear, with step number 123.
     save_path = "%s-123" % sv.save_path
     self._wait_for_glob(save_path, 3.0)
     self._wait_for_glob(
         os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
     # Wait to make sure everything is written to file before stopping.
     time.sleep(1)
     sv.stop()
   # There should be an event file with a version number.
   rr = _summary_iterator(logdir)
   ev = next(rr)
   self.assertEquals("brain.Event:2", ev.file_version)
   ev = next(rr)
   ev_graph = graph_pb2.GraphDef()
   ev_graph.ParseFromString(ev.graph_def)
   self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
   ev = next(rr)
   ev_meta_graph = meta_graph_pb2.MetaGraphDef()
   ev_meta_graph.ParseFromString(ev.meta_graph_def)
   self.assertProtoEquals(meta_graph_def, ev_meta_graph)
   self.assertProtoEquals(
       sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
   ev = next(rr)
   # It is actually nondeterministic whether SessionLog.START gets written
   # before the summary or the checkpoint, but this ordering held when run
   # 10000 times.
   self.assertEquals(123, ev.step)
   self.assertEquals(event_pb2.SessionLog.START, ev.session_log.status)
   first = next(rr)
   second = next(rr)
   # It is nondeterministic whether the value gets written before the
   # checkpoint since they are on separate threads, so we check for both
   # conditions.
   if first.HasField("summary"):
     self.assertProtoEquals("""value { tag: 'global_step/sec'
                                       simple_value: 0.0 }""", first.summary)
     self.assertEquals(123, second.step)
     self.assertEquals(event_pb2.SessionLog.CHECKPOINT,
                       second.session_log.status)
   else:
     self.assertEquals(123, first.step)
     self.assertEquals(event_pb2.SessionLog.CHECKPOINT,
                       first.session_log.status)
     self.assertProtoEquals("""value { tag: 'global_step/sec'
                                       simple_value: 0.0 }""", second.summary)
   ev = next(rr)
   self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
   self.assertRaises(StopIteration, lambda: next(rr))
   # There should be a checkpoint file with the variable "global_step".
   with ops.Graph().as_default(), self.cached_session() as sess:
     v = variables.VariableV1([-12], name="global_step")
     sav = saver_lib.Saver([v])
     sav.restore(sess, save_path)
     self.assertEqual(123, v.eval()[0])
Code example #17
File: learning.py Project: neuroph12/CNNDDDD
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False):
    """Runs a training loop using a TensorFlow supervisor.

  When the sync_optimizer is supplied, gradient updates are applied
  synchronously. Otherwise, gradient updates are applied asynchronously.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where training logs are written to. If None, model
      checkpoints and summaries will not be written.
    train_step_fn: The function to call in order to execute a single gradient
      step. The function must take exactly four arguments: the current
      session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
    train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
      default, two `Boolean`, scalar ops called "should_stop" and "should_log"
      are provided.
    log_every_n_steps: The frequency, in terms of global steps, that the loss
      and global step are logged.
    graph: The graph to pass to the supervisor. If no graph is supplied the
      default graph is used.
    master: The address of the tensorflow master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    global_step: The `Tensor` representing the global step. If left as `None`,
      then training_util.get_or_create_global_step(), that is,
      tf.contrib.framework.global_step() is used.
    number_of_steps: The max number of gradient steps to take during training,
      as measured by 'global_step': training will stop if global_step is
      greater than 'number_of_steps'. If the value is left as None, training
      proceeds indefinitely.
    init_op: The initialization operation. If left to its default value, then
      the session is initialized by calling `tf.global_variables_initializer()`.
    init_feed_dict: A feed dictionary to use when executing the `init_op`.
    local_init_op: The local initialization operation. If left to its default
      value, then the session is initialized by calling
      `tf.local_variables_initializer()` and `tf.tables_initializer()`.
    init_fn: An optional callable to be executed after `init_op` is called. The
      callable must accept one argument, the session being initialized.
    ready_op: Operation to check if the model is ready to use. If left to its
      default value, then the session checks for readiness by calling
      `tf.report_uninitialized_variables()`.
    summary_op: The summary operation.
    save_summaries_secs: How often, in seconds, to save summaries.
    summary_writer: `SummaryWriter` to use.  Can be `None`
      to indicate that no summaries should be written. If unset, we
      create a SummaryWriter.
    startup_delay_steps: The number of steps to wait for before beginning. Note
      that this must be 0 if a sync_optimizer is supplied.
    saver: Saver to save checkpoints. If None, a default one will be created
      and used.
    save_interval_secs: How often, in seconds, to save the model to `logdir`.
    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of
      them. If the argument is supplied, gradient updates will be synchronous.
      If left as `None`, gradient updates will be asynchronous.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    session_wrapper: A function that takes a `tf.Session` object as the only
      argument and returns a wrapped session object that has the same methods
      that the original object has, or `None`. Iff not `None`, the wrapped
      object will be used for training.
    trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
      and add it to the summaries every `trace_every_n_steps`. If None, no trace
      information will be produced or saved.
    ignore_live_threads: If `True` ignores threads that remain running after
      a grace period when stopping the supervisor, instead of raising a
      RuntimeError.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `train_op` is empty or if `startup_delay_steps` is
      non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
      negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
      provided.
  """
    if train_op is None:
        raise ValueError('train_op cannot be None.')

    if logdir is None:
        if summary_op != _USE_DEFAULT:
            raise ValueError('Cannot provide summary_op because logdir=None')
        if saver is not None:
            raise ValueError('Cannot provide saver because logdir=None')
        if trace_every_n_steps is not None:
            raise ValueError('Cannot provide trace_every_n_steps because '
                             'logdir=None')

    if isinstance(sync_optimizer,
                  sync_replicas_optimizer.SyncReplicasOptimizer):
        sync_optimizer = [sync_optimizer]
    if sync_optimizer is not None and startup_delay_steps > 0:
        raise ValueError(
            'startup_delay_steps must be zero when sync_optimizer is supplied.'
        )

    if number_of_steps is not None and number_of_steps <= 0:
        raise ValueError(
            '`number_of_steps` must be either None or a positive number.')

    graph = graph or ops.get_default_graph()
    with graph.as_default():
        if global_step is None:
            global_step = training_util.get_or_create_global_step()
        saver = saver or tf_saver.Saver()

        if sync_optimizer is not None:
            for opt in sync_optimizer:
                if not isinstance(
                        opt, sync_replicas_optimizer.SyncReplicasOptimizer):
                    raise ValueError(
                        '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.'
                    )

        with ops.name_scope('init_ops'):
            if init_op == _USE_DEFAULT:
                init_op = variables.global_variables_initializer()

            if ready_op == _USE_DEFAULT:
                ready_op = variables.report_uninitialized_variables()

            if local_init_op == _USE_DEFAULT:
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

            if sync_optimizer is not None and isinstance(sync_optimizer, list):
                with ops.control_dependencies(
                    [local_init_op] if local_init_op is not None else []):
                    if is_chief:
                        local_init_op = control_flow_ops.group(
                            *[opt.chief_init_op for opt in sync_optimizer])
                    else:
                        local_init_op = control_flow_ops.group(
                            *[opt.local_step_init_op for opt in sync_optimizer])
                ready_for_local_init_op = control_flow_ops.group(
                    *[opt.ready_for_local_init_op for opt in sync_optimizer])
            else:
                ready_for_local_init_op = None

        if summary_op == _USE_DEFAULT:
            summary_op = summary.merge_all()

        if summary_writer == _USE_DEFAULT:
            summary_writer = supervisor.Supervisor.USE_DEFAULT

        if is_chief and sync_optimizer is not None:
            # Need to create these BEFORE the supervisor finalizes the graph:
            init_tokens_op = [
                opt.get_init_tokens_op() for opt in sync_optimizer
            ]
            chief_queue_runner = [
                opt.get_chief_queue_runner() for opt in sync_optimizer
            ]

        if train_step_kwargs == _USE_DEFAULT:
            with ops.name_scope('train_step'):
                train_step_kwargs = {}

                if number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, log_every_n_steps), 0)
                if is_chief and trace_every_n_steps is not None:
                    train_step_kwargs['should_trace'] = math_ops.equal(
                        math_ops.mod(global_step, trace_every_n_steps), 0)
                    train_step_kwargs['logdir'] = logdir

    sv = supervisor.Supervisor(graph=graph,
                               is_chief=is_chief,
                               logdir=logdir,
                               init_op=init_op,
                               init_feed_dict=init_feed_dict,
                               local_init_op=local_init_op,
                               ready_for_local_init_op=ready_for_local_init_op,
                               ready_op=ready_op,
                               summary_op=summary_op,
                               summary_writer=summary_writer,
                               global_step=global_step,
                               saver=saver,
                               save_summaries_secs=save_summaries_secs,
                               save_model_secs=save_interval_secs,
                               init_fn=init_fn)

    if summary_writer is not None:
        train_step_kwargs['summary_writer'] = sv.summary_writer

    total_loss = None
    should_retry = True
    while should_retry:
        try:
            should_retry = False
            with sv.managed_session(master,
                                    start_standard_services=False,
                                    config=session_config) as sess:
                logging.info('Starting Session.')
                if session_wrapper is not None:
                    logging.info('Wrapping session with wrapper function: %s',
                                 session_wrapper)
                    sess = session_wrapper(sess)
                if is_chief:
                    if logdir:
                        sv.start_standard_services(sess)
                elif startup_delay_steps > 0:
                    # (use sys.maxsize because sys.maxint doesn't exist in Python 3)
                    _wait_for_step(
                        sess, global_step,
                        min(startup_delay_steps, number_of_steps
                            or sys.maxsize))
                threads = sv.start_queue_runners(sess)
                logging.info('Starting Queues.')
                if is_chief and sync_optimizer is not None:
                    sv.start_queue_runners(sess, chief_queue_runner)
                    sess.run(init_tokens_op)
                try:
                    while not sv.should_stop():
                        total_loss, should_stop = train_step_fn(
                            sess, train_op, global_step, train_step_kwargs)
                        if should_stop:
                            logging.info('Stopping Training.')
                            sv.request_stop()
                            break
                except errors.OutOfRangeError as e:
                    # OutOfRangeError is thrown when epoch limit per
                    # tf.train.limit_epochs is reached.
                    logging.info(
                        'Caught OutOfRangeError. Stopping Training. %s', e)
                if logdir and sv.is_chief:
                    logging.info('Finished training! Saving model to disk.')
                    sv.saver.save(sess,
                                  sv.save_path,
                                  global_step=sv.global_step)
                    sv.stop(threads,
                            close_summary_writer=True,
                            ignore_live_threads=ignore_live_threads)

        except errors.AbortedError:
            # Always re-run on AbortedError as it indicates a restart of one of the
            # distributed tensorflow servers.
            logging.info('Retrying training!')
            should_retry = True

    return total_loss
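
As a usage sketch, assuming a TF 1.x environment where tf.contrib.slim is available (the toy regression, paths, and step counts are illustrative):

import tensorflow as tf

slim = tf.contrib.slim

# Toy problem: fit y = 2x with a single weight, just to exercise the loop.
x = tf.random_normal([32, 1])
y = 2.0 * x
w = tf.get_variable('w', shape=[1, 1])
loss = tf.losses.mean_squared_error(y, tf.matmul(x, w))

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = slim.learning.create_train_op(loss, optimizer)

# Runs the supervisor-based loop defined above for a fixed step budget.
final_loss = slim.learning.train(train_op,
                                 logdir='/tmp/train_logs',
                                 number_of_steps=100,
                                 log_every_n_steps=10)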
Code example #18
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=None,
          init_fn=None,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None):
    """Runs a training loop using a TensorFlow supervisor.

  When the sync_optimizer is supplied, gradient updates are applied
  synchronously. Otherwise, gradient updates are applied asynchronously.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where training logs are written to. If None, model
      checkpoints and summaries will not be written.
    train_step_fn: The function to call in order to execute a single gradient
      step. The function must take exactly four arguments: the current
      session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
    train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
      default, two `Boolean`, scalar ops called "should_stop" and "should_log"
      are provided.
    log_every_n_steps: The frequency, in terms of global steps, that the loss
      and global step are logged.
    graph: The graph to pass to the supervisor. If no graph is supplied the
      default graph is used.
    master: The BNS name of the tensorflow master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    global_step: The `Tensor` representing the global step. If left as `None`,
      then slim.variables.get_or_create_global_step() is used.
    number_of_steps: The max number of gradient steps to take during training.
      If the value is left as None, training proceeds indefinitely.
    init_op: The initialization operation. If left to its default value, then
      the session is initialized by calling `tf.initialize_all_variables()`.
    init_feed_dict: A feed dictionary to use when executing the `init_op`.
    local_init_op: The local initialization operation. If None,
      then the session is initialized by calling
      `tf.initialize_local_variables()` and `tf.initialize_all_tables()`.
    init_fn: An optional callable to be executed after `init_op` is called. The
      callable must accept one argument, the session being initialized.
    summary_op: The summary operation.
    save_summaries_secs: How often, in seconds, to save summaries.
    startup_delay_steps: The number of steps to wait for before beginning. Note
      that this must be 0 if a sync_optimizer is supplied.
    saver: Saver to save checkpoints. If None, a default one will be created
      and used.
    save_interval_secs: How often, in seconds, to save the model to `logdir`.
    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer. If the
      argument is supplied, gradient updates will be synchronous. If left as
      `None`, gradient updates will be asynchronous.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `train_op` is empty or if `startup_delay_steps` is
      non-zero when `sync_optimizer` is supplied, or if `number_of_steps` is
      negative.
  """
    if train_op is None:
        raise ValueError('train_op cannot be None.')

    if logdir is None:
        if summary_op != _USE_DEFAULT:
            raise ValueError('Cannot provide summary_op because logdir=None')
        if saver is not None:
            raise ValueError('Cannot provide saver because logdir=None')

    if sync_optimizer and startup_delay_steps > 0:
        raise ValueError(
            'startup_delay_steps must be zero when sync_optimizer is supplied.'
        )

    if number_of_steps is not None and number_of_steps <= 0:
        raise ValueError(
            '`number_of_steps` must be either None or a positive number.')

    graph = graph or ops.get_default_graph()
    with graph.as_default():
        if global_step is None:
            global_step = variables.get_or_create_global_step()
        saver = saver or tf_saver.Saver()

        if init_op == _USE_DEFAULT:
            init_op = tf_variables.initialize_all_variables()

        if summary_op == _USE_DEFAULT:
            summary_op = logging_ops.merge_all_summaries()

        cleanup_op = None

        if is_chief and sync_optimizer:
            if not isinstance(sync_optimizer,
                              sync_replicas_optimizer.SyncReplicasOptimizer):
                raise ValueError(
                    '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer'
                )

            # Need to create these BEFORE the supervisor finalizes the graph:
            with ops.control_dependencies([init_op]):
                init_tokens_op = sync_optimizer.get_init_tokens_op()
            init_op = init_tokens_op
            chief_queue_runner = sync_optimizer.get_chief_queue_runner()
            cleanup_op = sync_optimizer.get_clean_up_op()

        if train_step_kwargs == _USE_DEFAULT:
            train_step_kwargs = {}

            if number_of_steps:
                should_stop_op = math_ops.greater_equal(
                    global_step, number_of_steps)
            else:
                should_stop_op = constant_op.constant(False)
            train_step_kwargs['should_stop'] = should_stop_op
            train_step_kwargs['should_log'] = math_ops.equal(
                math_ops.mod(global_step, log_every_n_steps), 0)

    sv = supervisor.Supervisor(graph=graph,
                               is_chief=is_chief,
                               logdir=logdir,
                               init_op=init_op,
                               init_feed_dict=init_feed_dict,
                               local_init_op=local_init_op,
                               summary_op=summary_op,
                               global_step=global_step,
                               saver=saver,
                               save_summaries_secs=save_summaries_secs,
                               save_model_secs=save_interval_secs,
                               init_fn=init_fn)

    should_retry = True
    while should_retry:
        try:
            should_retry = False
            with sv.managed_session(master,
                                    start_standard_services=False) as sess:
                logging.info('Starting Session.')
                if is_chief:
                    if logdir:
                        sv.start_standard_services(sess)
                elif startup_delay_steps > 0:
                    _wait_for_step(
                        sess, global_step,
                        min(startup_delay_steps, number_of_steps
                            or sys.maxint))
                sv.start_queue_runners(sess)
                logging.info('Starting Queues.')
                if is_chief and sync_optimizer:
                    sv.start_queue_runners(sess, [chief_queue_runner])
                try:
                    while not sv.should_stop():
                        total_loss, should_stop = train_step_fn(
                            sess, train_op, global_step, train_step_kwargs)
                        if should_stop:
                            logging.info('Stopping Training.')
                            break
                    if logdir and sv.is_chief:
                        logging.info(
                            'Finished training! Saving model to disk.')
                        sv.saver.save(sess,
                                      sv.save_path,
                                      global_step=sv.global_step)
                finally:
                    if sv.is_chief and cleanup_op is not None:
                        logging.info('About to execute sync_clean_up_op!')
                        sess.run(cleanup_op)

        except errors.AbortedError:
            # Always re-run on AbortedError as it indicates a restart of one of the
            # distributed tensorflow servers.
            logging.info('Retrying training!')
            should_retry = True

    return total_loss
Code example #19
    def _testCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units,
                                     input_size, batch_size, rnn_mode,
                                     use_block_cell):
        has_state_c = rnn_mode == cudnn_rnn_ops.CUDNN_LSTM
        np.random.seed(0)
        # Train graph
        with ops.Graph().as_default():
            random_seed.set_random_seed(299)
            input_data = array_ops.placeholder(
                dtypes.float32, shape=[seq_length, batch_size, input_size])
            output_tuple, cudnn_model, cudnn_params = self._build_forward_cudnn_model(
                rnn_mode, num_layers, num_units, input_data, is_training=True)
            target_output = array_ops.placeholder(dtype=dtypes.float32,
                                                  shape=None)
            total_sum = sum(map(math_ops.reduce_sum, output_tuple))

            loss_op = losses.log_loss(labels=target_output,
                                      predictions=total_sum)
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1e-2)
            train_op = optimizer.minimize(loss_op)

            saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

            # Train Cudnn model
            with self.test_session(use_gpu=True,
                                   graph=ops.get_default_graph()) as sess:
                sess.run(variables.global_variables_initializer())
                # Train 128 steps
                num_steps = 128
                for _ in range(num_steps):
                    inputs = np.random.rand(seq_length, batch_size,
                                            input_size).astype(np.float32)
                    targets = np.random.rand()
                    sess.run(train_op,
                             feed_dict={
                                 input_data: inputs,
                                 target_output: targets
                             })

                save_path = os.path.join(self.get_temp_dir(),
                                         ("cudnn-rnn-%s-test" % rnn_mode))
                save_v = saver.save(sess, save_path)
                self.assertEqual(save_path, save_v)
                cudnn_params_v = sess.run(cudnn_params)

        # cuDNN inference graph
        with ops.Graph().as_default():
            random_seed.set_random_seed(299)
            cudnn_inputs = array_ops.placeholder(
                dtypes.float32, shape=[seq_length, batch_size, input_size])
            (cudnn_output_tuple, cudnn_model,
             cudnn_params) = self._build_forward_cudnn_model(rnn_mode,
                                                             num_layers,
                                                             num_units,
                                                             cudnn_inputs,
                                                             is_training=False)
            saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

            inference_input = np.random.rand(seq_length, batch_size,
                                             input_size).astype(np.float32)
            with self.test_session(use_gpu=True,
                                   graph=ops.get_default_graph()) as sess:
                sess.run(variables.global_variables_initializer())
                saver.restore(sess, save_path)
                restored_cudnn_params_v = sess.run(cudnn_params)
                self.assertAllEqual(cudnn_params_v, restored_cudnn_params_v)

                # Cudnn inference
                cudnn_output = sess.run(
                    cudnn_output_tuple,
                    feed_dict={cudnn_inputs: inference_input})

        # Canonical RNN inference graph
        with ops.Graph().as_default():
            random_seed.set_random_seed(299)
            cell_inputs = array_ops.placeholder(
                dtypes.float32, shape=[seq_length, batch_size, input_size])
            (output, states) = _create_cudnn_compatible_canonical_rnn(
                cudnn_model, cell_inputs, use_block_cell)
            saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

            with self.test_session(use_gpu=True,
                                   graph=ops.get_default_graph()) as sess:
                saver.restore(sess, save_path)

                # BlockCell inference
                output_v, states_v = sess.run(
                    [output, states], feed_dict={cell_inputs: inference_input})

                # Outputs across timesteps are packed into one tensor.
                self.assertAllClose(cudnn_output[0],
                                    output_v,
                                    atol=1e-6,
                                    rtol=1e-6)

                for i in range(num_layers):
                    if has_state_c:
                        # output_h
                        self.assertAllClose(cudnn_output[1][i, :],
                                            states_v[i].h,
                                            atol=1e-6,
                                            rtol=1e-6)
                        # output_c
                        self.assertAllClose(cudnn_output[2][i, :],
                                            states_v[i].c,
                                            atol=1e-6,
                                            rtol=1e-6)
                    else:
                        self.assertAllClose(cudnn_output[1][i, :],
                                            states_v[i],
                                            atol=1e-6,
                                            rtol=1e-6)
Code example #20
    def _test_loading_variable_with_max_rows(self, np_value, partitioner,
                                             max_rows_in_memory):
        """Helper function for various tests using max_rows_in_memory."""
        ops.reset_default_graph()
        old_tensor_name = 'matrix_to_load_and_remap'
        matrix = variable_scope.get_variable(old_tensor_name,
                                             dtype=dtypes.float32,
                                             initializer=constant_op.constant(
                                                 np_value,
                                                 dtype=dtypes.float32),
                                             partitioner=partitioner)

        with self.test_session() as sess:
            ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
            save = saver.Saver([matrix])
            variables.global_variables_initializer().run()
            save.save(sess, ckpt_path)
            num_rows, num_cols = np_value.shape

            # Tests loading the entire tensor (except reversed).
            remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
                ckpt_path=ckpt_path,
                old_tensor_name=old_tensor_name,
                # Simply reverses the rows of the matrix.
                row_remapping=list(range(num_rows - 1, -1, -1)),
                col_remapping=[],
                initializing_values=[],
                num_rows=num_rows,
                num_cols=num_cols,
                max_rows_in_memory=max_rows_in_memory)
            self.assertAllClose(np_value[::-1], remapped_matrix.eval())

            # Tests loading the tensor (except for the first and last rows), with
            # uninitialized values. Requires num_rows to be at least 3 since we're
            # skipping the first and last rows.
            self.assertGreater(num_rows, 2)
            prefix_rows = 2
            suffix_rows = 3
            remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
                ckpt_path=ckpt_path,
                old_tensor_name=old_tensor_name,
                # Reverses the rows of the matrix, then prepends and appends
                # uninitialized rows.
                row_remapping=([-1] * prefix_rows +
                               list(range(1, num_rows - 1)) +
                               [-1] * suffix_rows),
                col_remapping=[],
                initializing_values=[42] * (prefix_rows + suffix_rows) *
                num_cols,
                num_rows=num_rows - 2 + prefix_rows + suffix_rows,
                num_cols=num_cols,
                max_rows_in_memory=max_rows_in_memory)
            self.assertAllClose(
                np.vstack([
                    np.tile(42, [prefix_rows, num_cols]), np_value[1:-1],
                    np.tile(42, [suffix_rows, num_cols])
                ]), remapped_matrix.eval())

            # Tests when everything is taken from initializing_values.
            new_rows = 7
            initializing_values = [42] * new_rows * num_cols
            remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
                ckpt_path=ckpt_path,
                old_tensor_name=old_tensor_name,
                # Nothing is loaded from the old tensor.
                row_remapping=[-1] * new_rows,
                col_remapping=[],
                initializing_values=initializing_values,
                num_rows=new_rows,
                num_cols=num_cols,
                max_rows_in_memory=max_rows_in_memory)
            self.assertAllClose(
                np.reshape(initializing_values, (new_rows, num_cols)),
                remapped_matrix.eval())
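
The row_remapping convention exercised above is worth spelling out: entry j loads row j of the old tensor, while -1 consumes the next num_cols values from initializing_values. A NumPy model of that contract (purely illustrative, not the op itself):

import numpy as np

def remap_rows(old, row_remapping, initializing_values, num_cols):
    # Mimics load_and_remap_matrix's row semantics, for intuition only.
    out, init = [], iter(initializing_values)
    for j in row_remapping:
        if j == -1:
            out.append([next(init) for _ in range(num_cols)])
        else:
            out.append(old[j])
    return np.asarray(out, dtype=np.float32)

old = np.arange(6, dtype=np.float32).reshape(3, 2)
print(remap_rows(old, [-1, 1, -1], [42] * 4, num_cols=2))
# [[42. 42.]
#  [ 2.  3.]
#  [42. 42.]]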
Code example #21
    def testTrainWithAlteredGradients(self):
        # Use the same learning rate but different gradient multipliers to
        # train two models. The model with the equivalently larger learning
        # rate (i.e., learning_rate * gradient_multiplier) has a smaller
        # training loss.
        logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
        logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')

        if gfile.Exists(logdir1):
            gfile.DeleteRecursively(logdir1)
        if gfile.Exists(logdir2):
            gfile.DeleteRecursively(logdir2)

        multipliers = [1., 1000.]
        number_of_steps = 10
        losses = []
        learning_rate = 0.001

        # First, train the model with equivalently smaller learning rate.
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            train_op = self.create_train_op(learning_rate=learning_rate,
                                            gradient_multiplier=multipliers[0])

            saver = saver_lib.Saver()

            loss = training.train(
                train_op,
                logdir1,
                hooks=[
                    basic_session_run_hooks.StopAtStepHook(
                        num_steps=number_of_steps),
                    basic_session_run_hooks.CheckpointSaverHook(logdir1,
                                                                save_steps=50,
                                                                saver=saver),
                ])

            losses.append(loss)
            self.assertGreater(loss, .5)

        # Second, train the model with equivalently larger learning rate.
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            train_op = self.create_train_op(learning_rate=learning_rate,
                                            gradient_multiplier=multipliers[1])
            saver = saver_lib.Saver()

            loss = training.train(
                train_op,
                logdir2,
                hooks=[
                    basic_session_run_hooks.StopAtStepHook(
                        num_steps=number_of_steps),
                    basic_session_run_hooks.CheckpointSaverHook(logdir2,
                                                                save_steps=50,
                                                                saver=saver),
                ])

            losses.append(loss)
            self.assertIsNotNone(loss)
            self.assertLess(loss, .5)

        # The loss of the model trained with larger learning rate should
        # be smaller.
        self.assertGreater(losses[0], losses[1])
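
The create_train_op used in this test is defined elsewhere in the test class; with tf.contrib.slim the same effect can be wired in via gradient_multipliers, which scales selected gradients before they are applied and is therefore equivalent to a per-variable learning-rate change (a sketch under that assumption; the toy model is illustrative):

import tensorflow as tf

slim = tf.contrib.slim

x = tf.random_normal([32, 1])
y = 2.0 * x
w = tf.get_variable('w', shape=[1, 1])
loss = tf.losses.mean_squared_error(y, tf.matmul(x, w))

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
# Scaling w's gradient by 1000 mimics learning_rate * 1000 for w alone.
train_op = slim.learning.create_train_op(
    loss, optimizer, gradient_multipliers={w: 1000.0})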
Code example #22
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 clear_devices,
                                 initializer_nodes,
                                 optimize_graph=True,
                                 variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError('Input checkpoint "' + input_checkpoint +
                         '" does not exist!')

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')

        if optimize_graph:
            logging.info('Graph Rewriter optimizations enabled')
            rewrite_options = rewriter_config_pb2.RewriterConfig()
            rewrite_options.optimizers.append('pruning')
            rewrite_options.optimizers.append('constfold')
            rewrite_options.optimizers.append('layout')
            graph_options = tf.GraphOptions(rewrite_options=rewrite_options,
                                            infer_shapes=True)
        else:
            logging.info('Graph Rewriter optimizations disabled')
            graph_options = tf.GraphOptions()
        config = tf.ConfigProto(graph_options=graph_options)
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example it's
                        # 'global_step' or a similar housekeeping element) so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)

            variable_names_blacklist = (variable_names_blacklist.split(',')
                                        if variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)

    return output_graph_def
Code example #23
    def testTrainWithInitFromCheckpoint(self):
        logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
        logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

        if gfile.Exists(logdir1):  # For running on jenkins.
            gfile.DeleteRecursively(logdir1)
        if gfile.Exists(logdir2):  # For running on jenkins.
            gfile.DeleteRecursively(logdir2)

        # First, train the model one step (make sure the error is high).
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            train_op = self.create_train_op()
            saver = saver_lib.Saver()
            loss = training.train(
                train_op,
                logdir1,
                hooks=[
                    basic_session_run_hooks.CheckpointSaverHook(logdir1,
                                                                save_steps=1,
                                                                saver=saver),
                    basic_session_run_hooks.StopAtStepHook(num_steps=1),
                ],
                save_checkpoint_secs=None,
                save_summaries_steps=None)
            self.assertGreater(loss, .5)

        # Next, train the model to convergence.
        with ops.Graph().as_default():
            random_seed.set_random_seed(1)
            train_op = self.create_train_op()
            saver = saver_lib.Saver()
            loss = training.train(
                train_op,
                logdir1,
                hooks=[
                    basic_session_run_hooks.CheckpointSaverHook(logdir1,
                                                                save_steps=300,
                                                                saver=saver),
                    basic_session_run_hooks.StopAtStepHook(num_steps=300),
                ],
                save_checkpoint_secs=None,
                save_summaries_steps=None)
            self.assertIsNotNone(loss)
            self.assertLess(loss, .02)

        # Finally, advance the model a single step and validate that the loss is
        # still low.
        with ops.Graph().as_default():
            random_seed.set_random_seed(2)
            train_op = self.create_train_op()

            model_variables = variables_lib2.global_variables()
            model_path = checkpoint_management.latest_checkpoint(logdir1)

            assign_fn = variables_lib.assign_from_checkpoint_fn(
                model_path, model_variables)

            def init_fn(_, session):
                assign_fn(session)

            loss = training.train(
                train_op,
                None,
                scaffold=monitored_session.Scaffold(init_fn=init_fn),
                hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)],
                save_checkpoint_secs=None,
                save_summaries_steps=None)

            self.assertIsNotNone(loss)
            self.assertLess(loss, .02)
Code example #24
def freeze_graph_with_def_protos(
        input_graph_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        clear_devices,
        initializer_nodes,
        variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            'Input checkpoint "' + input_checkpoint + '" does not exist!')

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')
        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ':0')
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes)

            variable_names_blacklist = (variable_names_blacklist.split(',') if
                                        variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)

    return output_graph_def
Code example #25
    def model_fn(features, labels, mode):
        """model_fn for keras Estimator."""
        model = _clone_and_build_model(mode=mode,
                                       keras_model=keras_model,
                                       custom_objects=custom_objects,
                                       features=features,
                                       labels=labels,
                                       optimizer_config=optimizer_config)
        model_output_names = []
        # We need to make sure that the output names of the last layer in the model
        # is the same for each of the cloned models. This is required for mirrored
        # strategy when we call regroup.
        if distribution_strategy_context.has_strategy():
            for name in model.output_names:
                name = re.compile(r'_\d$').sub('', name)
                model_output_names.append(name)
        else:
            model_output_names = model.output_names

        # Get inputs to EstimatorSpec
        predictions = dict(zip(model_output_names, model.outputs))

        loss = None
        train_op = None
        eval_metric_ops = None

        # Set loss and metric only during train and evaluate.
        if mode is not ModeKeys.PREDICT:
            if mode is ModeKeys.TRAIN:
                model._make_train_function()  # pylint: disable=protected-access
            else:
                model._make_test_function()  # pylint: disable=protected-access
            loss = model.total_loss

            eval_metric_ops = _convert_keras_metrics_to_estimator(model)

        # Set train_op only during train.
        if mode is ModeKeys.TRAIN:
            train_op = model.train_function.updates_op

        if (not model._is_graph_network
                and hasattr(keras_model, '_original_attributes_cache')
                and keras_model._original_attributes_cache is not None):
            # To avoid `model_fn` being destructive for the initial model argument.
            models.in_place_subclassed_model_state_restoration(keras_model)

        scaffold = None
        if save_object_ckpt:
            model._track_trackable(training_util.get_global_step(),
                                   'estimator_global_step')
            # Create saver that maps variable names to object-checkpoint keys.
            object_graph = graph_view.ObjectGraphView(model)
            var_list = object_graph.frozen_saveable_objects()
            saver = saver_lib.Saver(var_list=var_list, sharded=True)
            saver._object_restore_saver = trackable_util.frozen_saver(model)
            scaffold = monitored_session.Scaffold(saver=saver)

        return model_fn_lib.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs={
                _DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)
            },
            scaffold=scaffold)
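
This model_fn is essentially the glue that tf.keras.estimator.model_to_estimator generates internally; a sketch of the user-facing call (the model itself is illustrative):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Wraps the compiled Keras model in an Estimator whose model_fn resembles
# the one above. '/tmp/keras_estimator' is a placeholder model_dir.
estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir='/tmp/keras_estimator')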
Code example #26
    def set_model(self, model):
        """Sets Keras model and creates summary ops."""

        self.model = model
        self._init_writer(model)
        # histogram summaries only enabled in graph mode
        if not context.executing_eagerly():
            self._make_histogram_ops(model)
            self.merged = tf_summary.merge_all()

        # If both embeddings_freq and embeddings_data are available, we will
        # visualize embeddings.
        if self.embeddings_freq and self.embeddings_data is not None:
            # Avoid circular dependency.
            from tensorflow.python.keras.engine import training_utils  # pylint: disable=g-import-not-at-top
            self.embeddings_data = training_utils.standardize_input_data(
                self.embeddings_data, model.input_names)

            # If embeddings_layer_names is not provided, get all of the
            # embedding layers from the model.
            embeddings_layer_names = self.embeddings_layer_names
            if not embeddings_layer_names:
                embeddings_layer_names = [
                    layer.name for layer in self.model.layers
                    if type(layer).__name__ == 'Embedding'
                ]

            self.assign_embeddings = []
            embeddings_vars = {}

            self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
            self.step = step = array_ops.placeholder(dtypes.int32)

            for layer in self.model.layers:
                if layer.name in embeddings_layer_names:
                    embedding_input = self.model.get_layer(layer.name).output
                    embedding_size = np.prod(embedding_input.shape[1:])
                    embedding_input = array_ops.reshape(
                        embedding_input, (step, int(embedding_size)))
                    shape = (self.embeddings_data[0].shape[0],
                             int(embedding_size))
                    embedding = variables.Variable(array_ops.zeros(shape),
                                                   name=layer.name +
                                                   '_embedding')
                    embeddings_vars[layer.name] = embedding
                    batch = state_ops.assign(
                        embedding[batch_id:batch_id + step], embedding_input)
                    self.assign_embeddings.append(batch)

            self.saver = saver.Saver(list(embeddings_vars.values()))

            # Create embeddings_metadata dictionary
            if isinstance(self.embeddings_metadata, str):
                embeddings_metadata = {
                    layer_name: self.embeddings_metadata
                    for layer_name in embeddings_vars.keys()
                }
            else:
                # If embeddings_metadata is already a dictionary, use it as-is.
                embeddings_metadata = self.embeddings_metadata

            try:
                from tensorboard.plugins import projector
            except ImportError:
                raise ImportError(
                    'Failed to import TensorBoard. Please make sure that '
                    'TensorBoard integration is complete.')

            # TODO(psv): Add integration tests to test embedding visualization
            # with TensorBoard callback. We are unable to write a unit test for this
            # because TensorBoard dependency assumes TensorFlow package is installed.
            config = projector.ProjectorConfig()
            for layer_name, tensor in embeddings_vars.items():
                embedding = config.embeddings.add()
                embedding.tensor_name = tensor.name

                if (embeddings_metadata is not None
                        and layer_name in embeddings_metadata):
                    embedding.metadata_path = embeddings_metadata[layer_name]

            projector.visualize_embeddings(self.writer, config)
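
To see set_model in action, here is a minimal sketch, assuming a TF 1.x tf.keras build where the TensorBoard callback still accepts the embeddings_* arguments; the model, data, and log_dir are illustrative.

import numpy as np
import tensorflow as tf

# Tiny model with an Embedding layer so set_model() has something to project.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(1000, 16, input_length=10),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

x = np.random.randint(0, 1000, size=(64, 10))
y = np.random.randint(0, 2, size=(64, 1))

tb = tf.keras.callbacks.TensorBoard(
    log_dir='/tmp/logs',
    embeddings_freq=1,        # enables the embeddings branch in set_model()
    embeddings_data=x[:32])   # inputs fed through the model for projection
model.fit(x, y, epochs=1, callbacks=[tb])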
Code Example #27
    def _save_return_saver(self, sess, var):
        saver = saver_lib.Saver(var_list=[var])
        test_dir = self.get_temp_dir()
        prefix = os.path.join(test_dir, "ckpt")
        return saver.save(sess, prefix), saver
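
The same save-and-return-path pattern as a hedged round-trip sketch with the public tf.compat.v1 API; the variable name and path are illustrative.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v = tf.get_variable("v", initializer=1.0)
saver = tf.train.Saver(var_list=[v])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    path = saver.save(sess, "/tmp/ckpt")  # save() returns the checkpoint prefix
    saver.restore(sess, path)             # the prefix can be restored directly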
Code Example #28
    def testSinglePartitionedVariable(self):
        """Ensures partitioned variables fail cleanly with freeze graph."""
        checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                         "saved_checkpoint")
        checkpoint_state_name = "checkpoint_state"
        input_graph_name = "input_graph.pb"
        output_graph_name = "output_graph.pb"

        # Create a graph with partitioned variables. When weights are
        # partitioned into a single partition, the weights variable is followed
        # by an identity -> identity chain (one additional identity node).
        partitioner = partitioned_variables.fixed_size_partitioner(1)
        with ops.Graph().as_default():
            with variable_scope.variable_scope("part",
                                               partitioner=partitioner):
                batch_size, height, width, depth = 5, 128, 128, 3
                input1 = array_ops.zeros((batch_size, height, width, depth),
                                         name="input1")
                input2 = array_ops.zeros((batch_size, height, width, depth),
                                         name="input2")

                num_nodes = depth
                filter1 = variable_scope.get_variable("filter",
                                                      [num_nodes, num_nodes])
                filter2 = array_ops.reshape(filter1,
                                            [1, 1, num_nodes, num_nodes])
                conv = nn.conv2d(input=input1,
                                 filter=filter2,
                                 strides=[1, 1, 1, 1],
                                 padding="SAME")
                node = math_ops.add(conv, input2, name="test/add")
                node = nn.relu6(node, name="test/relu6")

            # Save graph and checkpoints.
            sess = session.Session()
            sess.run(variables.global_variables_initializer())

            saver = saver_lib.Saver()
            checkpoint_path = saver.save(sess,
                                         checkpoint_prefix,
                                         global_step=0,
                                         latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                 input_graph_name)

            # Ensure this graph has partition variables.
            self.assertTrue([
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ])

        # Test freezing graph doesn't make it crash.
        output_node_names = "save/restore_all"
        output_graph_path = os.path.join(self.get_temp_dir(),
                                         output_graph_name)

        return_value = freeze_graph.freeze_graph_with_def_protos(
            input_graph_def=sess.graph_def,
            input_saver_def=None,
            input_checkpoint=checkpoint_path,
            output_node_names=output_node_names,
            restore_op_name="save/restore_all",  # default value
            filename_tensor_name="save/Const:0",  # default value
            output_graph=output_graph_path,
            clear_devices=False,
            initializer_nodes="")
        # freeze_graph_with_def_protos returns -1 when it encounters
        # partitioned variables, so check for that value explicitly.
        self.assertEqual(return_value, -1)
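
A quick sketch (public tf.compat.v1 API) of why the regex above matches: even with fixed_size_partitioner(1), the single concrete variable is still named with a "/part_0" suffix.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

partitioner = tf.fixed_size_partitioner(1)
with tf.variable_scope("part", partitioner=partitioner):
    tf.get_variable("filter", [3, 3])
# Prints something like ['part/filter/part_0:0'].
print([v.name for v in tf.global_variables()])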
Code Example #29
def save_variables_to_ckpt(model_dir):
  init_all_op = [variables.global_variables_initializer()]
  with tf_session.Session() as sess:
    sess.run(init_all_op)
    saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
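
A hypothetical call site for the helper above, assuming the usual `ops` alias (tensorflow.python.framework.ops) from the same test file; the graph contents and directory are illustrative.

with ops.Graph().as_default():
    variables.Variable([[1.0]], name='weight')
    save_variables_to_ckpt('/tmp/estimator_test')  # writes model.ckpt* files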
Code Example #30
File: graph_actions.py  Project: yaya213/tensorflow
def _make_saver(graph):
    vars_to_save = graph.get_collection(ops.GraphKeys.VARIABLES)
    if vars_to_save:
        return tf_saver.Saver(vars_to_save, sharded=True)
    else:
        return None
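
A hypothetical sketch of the None branch: a freshly created graph has nothing in the VARIABLES collection, so _make_saver returns None and callers can skip checkpointing for stateless graphs.

graph = ops.Graph()
assert _make_saver(graph) is None  # no variables registered yet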