def testWarmStart_SparseColumnIntegerized(self):
    """Warm-starts the linear weights of an identity categorical column.

    A checkpoint holding all-ones [10, 1] weights for `sc_int` is written
    first. A fresh graph that does not warm-start must see the default
    (zeros) initialization; a fresh graph that calls `ws_util.warm_start`
    must see the checkpointed values.
    """
    int_column = fc.categorical_column_with_identity("sc_int", num_buckets=10)

    _, ckpt_weights = self._create_prev_run_var(
        "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
    # Sanity-check the checkpointed values before relying on them below.
    self.assertAllEqual(np.ones([10, 1]), ckpt_weights)

    def partitioner(shape, dtype):  # pylint: disable=unused-argument
      # One partition along every dimension.
      return [1] * len(shape)

    # Baseline graph: no warm-start, so the default zeros initializer applies.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([int_column], partitioner)
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(
          col_vars, {int_column: [np.zeros([10, 1])]}, sess)

    # Warm-started graph: variables matching ".*sc_int.*" come from the
    # checkpoint in the test's temp dir.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([int_column], partitioner)
      ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(col_vars, {int_column: [ckpt_weights]}, sess)
  def testWarmStart_SparseColumnVocabulary(self):
    """Warm-starts a vocabulary-file column whose vocabulary did not change.

    No VocabInfo is passed to `warm_start`, so the checkpoint's vocabulary
    is assumed to be identical to the current one and rows are restored
    one-to-one.
    """
    vocab_file = self._write_vocab(["apple", "banana", "guava", "orange"],
                                   "vocab")
    vocab_column = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=vocab_file, vocabulary_size=4)

    _, ckpt_weights = self._create_prev_run_var(
        "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())

    partitioner = lambda shape, dtype: [1] * len(shape)
    expected_cold = {vocab_column: [np.zeros([4, 1])]}
    expected_warm = {vocab_column: [ckpt_weights]}

    # Without warm-starting the default zeros initializer applies.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([vocab_column], partitioner)
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(col_vars, expected_cold, sess)

    # With warm-starting the checkpointed weights are restored unchanged.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([vocab_column], partitioner)
      ws_util.warm_start(
          self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*")
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(col_vars, expected_warm, sess)
    def test_from_object_based_checkpoint(self):
        """Warm-starts a dynamic embedding variable from a tf.train.Checkpoint.

        The previous run saves `outer/prefix/devar` through the object-based
        `tracking_util.Checkpoint` API; the new run selects it by name in
        `vars_to_warm_start` and must end up with every saved key/value pair.
        """
        dim = 10
        keys = [0, 1, 2, 3]
        values = [[key] * dim for key in keys]
        checkpoint_prefix = os.path.join(self.get_temp_dir(), "checkpoint")

        with ops.Graph().as_default() as g:
            with self.session(graph=g), variable_scope.variable_scope("outer"):
                saved_devar = self._add_and_initialize_devar(
                    "prefix/devar", keys, values, dim)
                tracking_util.Checkpoint(v=saved_devar).save(checkpoint_prefix)

        with ops.Graph().as_default() as g:
            with self.session(graph=g), variable_scope.variable_scope("outer"):
                restored_devar = self._add_devar("prefix/devar", dim)
                ws_util.warm_start(self.get_temp_dir(),
                                   vars_to_warm_start=["outer/prefix/devar"])
                self.evaluate(deo.dynamic_embedding_variables_initializer())
                got_keys, got_values = self._export_sorted_keys_and_values(
                    restored_devar)
                self.assertAllEqual(keys, got_keys)
                self.assertAllEqual(values, got_values)
  def testWarmStart_BucketizedColumn(self):
    """Warm-starts the linear weights of a bucketized numeric column.

    Four boundaries give five buckets, hence [5, 1] weights. They are zeros
    without warm-starting and equal the checkpointed values (written with
    `norms()`) after warm-starting.
    """
    bucket_column = fc.bucketized_column(
        fc.numeric_column("real"), boundaries=[0., 1., 2., 3.])

    _, ckpt_weights = self._create_prev_run_var(
        "linear_model/real_bucketized/weights",
        shape=[5, 1],
        initializer=norms())

    def partitioner(shape, dtype):  # pylint: disable=unused-argument
      return [1] * len(shape)

    # No warm-start: the default zeros initializer applies.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([bucket_column], partitioner)
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(
          col_vars, {bucket_column: [np.zeros([5, 1])]}, sess)

    # Warm-start: the checkpointed weights replace the initializer.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([bucket_column], partitioner)
      ws_util.warm_start(
          self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(
          col_vars, {bucket_column: [ckpt_weights]}, sess)
  def testWarmStart_SparseColumnHashed(self):
    """Warm-starts the linear weights of a hash-bucket categorical column.

    The [15, 1] weights (one row per hash bucket) are zeros without
    warm-starting and equal the checkpointed values afterwards.
    """
    hash_column = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)

    _, ckpt_weights = self._create_prev_run_var(
        "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())

    def partitioner(shape, dtype):  # pylint: disable=unused-argument
      return [1] * len(shape)

    # No warm-start: the default zeros initializer applies.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([hash_column], partitioner)
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(
          col_vars, {hash_column: [np.zeros([15, 1])]}, sess)

    # Warm-start: the checkpointed weights replace the initializer.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([hash_column], partitioner)
      ws_util.warm_start(
          self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(
          col_vars, {hash_column: [ckpt_weights]}, sess)
  def testWarmStartMoreSettingsNoPartitioning(self):
    """Warm-starts a subset of unpartitioned linear weights with remapping.

    Three settings are exercised together:
      * `sc_hash` does not match the regex, so it keeps its zeros init.
      * `sc_keys` is restored from a checkpoint variable with a different
        name ("some_other_name") via `var_name_to_prev_var_name`.
      * `sc_vocab` is restored through a `VocabInfo` that maps the old
        4-entry vocabulary onto the reordered 6-entry one; the two new
        words end up as zeros.
    """
    old_vocab_file = self._write_vocab(["apple", "banana", "guava", "orange"],
                                       "old_vocab")
    new_vocab_file = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
        "new_vocab")

    hash_column = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)
    keys_column = fc.categorical_column_with_vocabulary_list(
        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    vocab_column = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=new_vocab_file, vocabulary_size=6)

    # Write the checkpoint to warm-start from. The sc_keys weights are stored
    # under an unrelated name on purpose.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      variable_scope.get_variable(
          "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
      renamed_keys_var = variable_scope.get_variable(
          "some_other_name", shape=[4, 1], initializer=rand())
      variable_scope.get_variable(
          "linear_model/sc_vocab/weights",
          initializer=[[0.5], [1.], [2.], [3.]])
      self._write_checkpoint(sess)
      ckpt_keys_weights = sess.run(renamed_keys_var)

    # Fresh graph that warm-starts only sc_keys and sc_vocab.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model(
          [hash_column, keys_column, vocab_column], partitioner=None)
      vocab_var_name = ws_util._infer_var_name(col_vars[vocab_column])
      keys_var_name = ws_util._infer_var_name(col_vars[keys_column])
      vocab_info = ws_util.VocabInfo(
          new_vocab=vocab_column.vocabulary_file,
          new_vocab_size=vocab_column.vocabulary_size,
          num_oov_buckets=vocab_column.num_oov_buckets,
          old_vocab=old_vocab_file)
      ws_util.warm_start(
          self.get_temp_dir(),
          vars_to_warm_start=".*(sc_keys|sc_vocab).*",
          var_name_to_vocab_info={vocab_var_name: vocab_info},
          var_name_to_prev_var_name={keys_var_name: "some_other_name"})
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(col_vars, {
          keys_column: [ckpt_keys_weights],
          hash_column: [np.zeros([15, 1])],
          vocab_column: [np.array([[3.], [2.], [1.], [0.5], [0.], [0.]])],
      }, sess)
 def warm_start(g):
   """Warm-start `var_name` inside graph `g` and check it matches the checkpoint.

   NOTE(review): `self`, `var_name` and `original_value` are free variables
   expected from an enclosing test method that is not visible here — confirm.
   """
   with self.test_session(graph=g) as sess:
     # Created as 2x2 zeros so that a successful warm-start is observable.
     zeros_2x2 = [[0., 0.], [0., 0.]]
     warm_var = variable_scope.get_variable(var_name, initializer=zeros_2x2)
     ws_util.warm_start(self.get_temp_dir())
     sess.run(variables.global_variables_initializer())
     self.assertAllEqual(original_value, self.evaluate(warm_var))
    def test_checkpoint_overwrite_warm_start(self):
        """Checks that restoring from `checkpoint_dir` wins over warm-starting.

        For every configuration from `_next_run_step_config()`, a first graph
        trains `run_step` steps and saves to `ws_ckpt_dir`. It then trains
        `extra_run_step` more steps and saves to `final_ckpt_dir`. A second
        graph warm-starts from `ws_ckpt_dir` but opens a
        `MonitoredTrainingSession` on `final_ckpt_dir`. Its losses must match
        the final checkpoint, not the warm-start one.
        """
        extra_run_step = 2
        ws_ckpt_dir = tempfile.mkdtemp(
            prefix=os.path.join(self.get_temp_dir(), "warm_start"))
        final_ckpt_dir = tempfile.mkdtemp(
            prefix=os.path.join(self.get_temp_dir(), "final"))

        for (run_id, num_shards, k_dtype, d_dtype, init_mode, dim,
             run_step) in _next_run_step_config():
            error_msg = "Cond:{},{},{},{},{},{}".format(
                num_shards, k_dtype, d_dtype, init_mode, dim, run_step)

            # Phase 1: train, save the warm-start checkpoint, train more and
            # save the final checkpoint.
            with ops.Graph().as_default() as g:
                with self.session(graph=g,
                                  use_gpu=test_util.is_gpu_available(),
                                  config=default_config) as sess:
                    training_util.create_global_step()
                    graph = TestGraph(k_dtype, d_dtype, dim, num_shards, 'var',
                                      'devar', run_id)
                    train_ops = [graph.var_opt_op, graph.devar_opt_op]
                    loss_ops = [graph.var_loss, graph.devar_loss]

                    self.evaluate(variables.global_variables_initializer())
                    sess.run([graph.devar_init_op])
                    prev_x = sess.run([graph.x])[0]

                    for _ in range(run_step):
                        sess.run(train_ops)
                    saver_lib.Saver().save(sess,
                                           os.path.join(ws_ckpt_dir, "model"))
                    ws_var_loss, ws_devar_loss = sess.run(loss_ops)
                    self.assertAllCloseAccordingToType(ws_var_loss,
                                                       ws_devar_loss,
                                                       msg=error_msg)

                    for _ in range(extra_run_step):
                        sess.run(train_ops)
                    saver_lib.Saver().save(
                        sess, os.path.join(final_ckpt_dir, "model"))
                    final_var_loss, final_devar_loss = sess.run(loss_ops)
                    self.assertAllCloseAccordingToType(final_var_loss,
                                                       final_devar_loss,
                                                       msg=error_msg)

            # Phase 2: warm-start from the earlier checkpoint, but let the
            # monitored session restore from the final one.
            with ops.Graph().as_default():
                training_util.create_global_step()
                graph = TestGraph(k_dtype, d_dtype, dim, num_shards, 'var',
                                  'devar', run_id, prev_x)
                ws_util.warm_start(ws_ckpt_dir, vars_to_warm_start=['.*'])
                with monitored_session.MonitoredTrainingSession(
                        config=default_config,
                        is_chief=True,
                        checkpoint_dir=final_ckpt_dir) as sess:
                    var_loss, devar_loss = sess.run(
                        [graph.var_loss, graph.devar_loss])
                    self.assertAllCloseAccordingToType(var_loss,
                                                       final_var_loss,
                                                       msg=error_msg)
                    self.assertAllCloseAccordingToType(devar_loss,
                                                       final_devar_loss,
                                                       msg=error_msg)
 def warm_start(g):
     """Warm-start `var_name` inside graph `g` and check it matches the checkpoint.

     NOTE(review): `self`, `var_name` and `original_value` are free variables
     expected from an enclosing test method that is not visible here — confirm.
     """
     with self.session(graph=g) as sess:
         # Created as 2x2 zeros so that a successful warm-start is observable.
         zeros_2x2 = [[0., 0.], [0., 0.]]
         warm_var = variable_scope.get_variable(var_name,
                                                initializer=zeros_2x2)
         ws_util.warm_start(self.get_temp_dir())
         sess.run(variables.global_variables_initializer())
         self.assertAllEqual(original_value, self.evaluate(warm_var))
    def test_use_var_name_to_prev_var_name(self):
        """Warm-starts devars whose checkpoint names differ from current names.

        Dynamic embedding variables saved under "old_outer" are restored into
        variables under "new_outer". The mapping is passed through
        `var_name_to_prev_var_name`, keyed by each table's
        `saveable.full_name`. The test also checks that `ws_util.warm_start`
        raises ValueError when the mapping is missing, or when it names
        tables that are not being warm-started.
        """
        dim = 2
        keys = [0, 1]
        values = [[0, 0], [1, 1]]

        # Save checkpoint from which to warm-start.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
                with variable_scope.variable_scope("old_outer"):
                    prev_v1 = self._add_and_initialize_devar(
                        "v1", keys, values, dim)
                    prev_v2 = self._add_and_initialize_devar(
                        "v2", keys, values, dim)
                    _write_checkpoint(self, sess)

        # New graph, new session with warm-starting.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
                with variable_scope.variable_scope("new_outer"):
                    v1 = self._add_devar("v1", dim)
                    v2 = self._add_devar("v2", dim)
                    self.assertAllEqual(0, self.evaluate(v1.size()))

                    # Map each current table's saveable name to the name of
                    # the matching table saved under "old_outer".
                    ms_name_to_prev_ms_name = {}
                    for v, prev_v in zip([v1, v2], [prev_v1, prev_v2]):
                        for table, prev_table in zip(v.tables, prev_v.tables):
                            ms_name_to_prev_ms_name[table.saveable.full_name] = (
                                prev_table.saveable.full_name)

                    # Without a mapping, the "new_outer" saveable names are
                    # not found in the checkpoint: ValueError.
                    self.assertRaises(
                        ValueError,
                        ws_util.warm_start,
                        self.get_temp_dir(),
                        vars_to_warm_start=["new_outer/v1", "new_outer/v2"])

                    # A mapping that also covers v2, which is not being
                    # warm-started, leaves unused entries: ValueError.
                    self.assertRaises(
                        ValueError,
                        ws_util.warm_start,
                        self.get_temp_dir(),
                        vars_to_warm_start=["new_outer/v1"],
                        var_name_to_prev_var_name=ms_name_to_prev_ms_name)

                    ws_util.warm_start(
                        self.get_temp_dir(),
                        vars_to_warm_start=["new_outer/v1", "new_outer/v2"],
                        var_name_to_prev_var_name=ms_name_to_prev_ms_name)
                    self.evaluate(
                        deo.dynamic_embedding_variables_initializer())
                    # Both variables must now hold exactly the key/value
                    # pairs that were saved under "old_outer". Distinct names
                    # are used so the expected `keys`/`values` are not
                    # shadowed by the exported ones.
                    for v in [v1, v2]:
                        got_keys, got_values = (
                            self._export_sorted_keys_and_values(v))
                        self.assertAllEqual(keys, got_keys)
                        self.assertAllEqual(values, got_values)
  def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
    """Warm-starts a vocab column when both vocab sizes truncate their files.

    Only the first 2 entries of each vocabulary file are used. The old vocab
    is ['apple', 'guava'] and the new one is ['apple', 'banana']. After
    warm-starting, 'apple' keeps its checkpointed weight (1). 'banana' is
    absent from the truncated old vocab, so it stays at the default 0.
    """
    old_vocab_file = self._write_vocab(["apple", "guava", "banana"],
                                       "old_vocab")
    old_vocab_size = 2
    new_vocab_file = self._write_vocab(
        ["apple", "banana", "guava", "orange"], "current_vocab")
    vocab_column = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=new_vocab_file, vocabulary_size=2)

    # Checkpoint holding all-ones [2, 1] weights.
    self._create_prev_run_var(
        "linear_model/sc_vocab/weights", shape=[2, 1], initializer=ones())

    def partitioner(shape, dtype):  # pylint: disable=unused-argument
      return [1] * len(shape)

    # Baseline without warm-starting: the default zeros initializer applies.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([vocab_column], partitioner)
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(
          col_vars, {vocab_column: [np.zeros([2, 1])]}, sess)

    # Warm-start with an explicit VocabInfo describing both truncations.
    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      col_vars = self._create_linear_model([vocab_column], partitioner)
      vocab_info = ws_util.VocabInfo(
          new_vocab=vocab_column.vocabulary_file,
          new_vocab_size=vocab_column.vocabulary_size,
          num_oov_buckets=vocab_column.num_oov_buckets,
          old_vocab=old_vocab_file,
          old_vocab_size=old_vocab_size)
      ws_util.warm_start(
          ckpt_to_initialize_from=self.get_temp_dir(),
          vars_to_warm_start=".*sc_vocab.*",
          var_name_to_vocab_info={"linear_model/sc_vocab/weights": vocab_info})
      sess.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(col_vars, {vocab_column: [[[1], [0]]]}, sess)
    def test_warm_start_optimizers(self):
        """Warm-starts every variable and checks training resumes identically.

        For every configuration from `_next_run_step_config()`, a first graph
        trains `run_step` steps, writes a checkpoint, and trains
        `extra_run_step` more. A second graph warm-starts all variables
        (regex '.*') from that checkpoint and replays the extra steps. It must
        reach the same dense and dynamic-embedding losses as the first graph.
        """
        extra_run_step = 2
        for (run_id, num_shards, k_dtype, d_dtype, init_mode, dim,
             run_step) in _next_run_step_config():
            error_msg = "Cond:{},{},{},{},{},{}".format(
                num_shards, k_dtype, d_dtype, init_mode, dim, run_step)

            # First run: train, checkpoint, then keep training.
            with ops.Graph().as_default() as g:
                with self.session(graph=g,
                                  use_gpu=test_util.is_gpu_available(),
                                  config=default_config) as sess:
                    graph = TestGraph(k_dtype, d_dtype, dim, num_shards, 'var',
                                      'devar', run_id)
                    train_ops = [graph.var_opt_op, graph.devar_opt_op]
                    loss_ops = [graph.var_loss, graph.devar_loss]

                    self.evaluate(variables.global_variables_initializer())
                    sess.run([graph.devar_init_op])
                    prev_x = sess.run([graph.x])[0]
                    for _ in range(run_step):
                        sess.run(train_ops)
                    # NOTE(review): these losses are fetched but never
                    # checked; confirm whether an assertion was intended.
                    sess.run(loss_ops)
                    _write_checkpoint(self, sess)
                    for _ in range(extra_run_step):
                        sess.run(train_ops)
                    expected_var_loss, expected_devar_loss = sess.run(loss_ops)
                    self.assertAllCloseAccordingToType(expected_var_loss,
                                                       expected_devar_loss,
                                                       msg=error_msg)

            # Second run: warm-start from the checkpoint and replay the extra
            # steps; the losses must match the first run.
            with ops.Graph().as_default() as g:
                with self.session(graph=g,
                                  use_gpu=test_util.is_gpu_available(),
                                  config=default_config) as sess:
                    graph = TestGraph(k_dtype, d_dtype, dim, num_shards, 'var',
                                      'devar', run_id, prev_x)
                    ws_util.warm_start(self.get_temp_dir(),
                                       vars_to_warm_start=['.*'])
                    self.evaluate(variables.global_variables_initializer())
                    self.evaluate(
                        deo.dynamic_embedding_variables_initializer())
                    for _ in range(extra_run_step):
                        sess.run([graph.var_opt_op, graph.devar_opt_op])
                    var_loss, devar_loss = sess.run(
                        [graph.var_loss, graph.devar_loss])
                    self.assertAllCloseAccordingToType(var_loss,
                                                       expected_var_loss,
                                                       msg=error_msg)
                    self.assertAllCloseAccordingToType(devar_loss,
                                                       expected_devar_loss,
                                                       msg=error_msg)
    def test_both_vars_and_devars(self):
        """Warm-starts a dense variable and one of two dynamic embedding vars.

        `vars_to_warm_start` is given as variable objects `[dense_var,
        devar1]`. Dense `v1` and `devar1` are therefore restored from the
        checkpoint, while `devar2` stays empty.
        """
        dim1, dim2 = 10, 20
        keys1 = [0, 1, 2, 3]
        keys2 = [4, 5, 6]
        values1 = [[k] * dim1 for k in keys1]
        values2 = [[k] * dim2 for k in keys2]

        # Save checkpoint from which to warm-start.
        with ops.Graph().as_default() as g, self.session(graph=g) as sess:
            saved_dense = variable_scope.get_variable(
                "v1",
                shape=[10, 1],
                initializer=init_ops.ones_initializer())
            self.evaluate(variables.global_variables_initializer())
            saved_dense_value = self.evaluate(saved_dense)
            self.assertAllEqual(np.ones([10, 1]), saved_dense_value)
            saved_devar1 = self._add_and_initialize_devar(
                "devar1", keys1, values1, dim1)
            self.assertAllEqual(4, self.evaluate(saved_devar1.size()))
            saved_devar2 = self._add_and_initialize_devar(
                "devar2", keys2, values2, dim2)
            self.assertAllEqual(3, self.evaluate(saved_devar2.size()))
            _write_checkpoint(self, sess)

        # New graph: dense var starts at zeros, devars start empty.
        with ops.Graph().as_default() as g, self.session(graph=g):
            dense_var = variable_scope.get_variable(
                "v1",
                shape=[10, 1],
                initializer=init_ops.zeros_initializer())
            devar1 = self._add_devar("devar1", dim1)
            devar2 = self._add_devar("devar2", dim2)
            self.assertAllEqual(0, self.evaluate(devar1.size()))
            self.assertAllEqual(0, self.evaluate(devar2.size()))

            ws_util.warm_start(self.get_temp_dir(),
                               vars_to_warm_start=[dense_var, devar1])
            self.evaluate(variables.global_variables_initializer())
            self.evaluate(deo.dynamic_embedding_variables_initializer())

            self.assertAllEqual(dense_var.eval(), saved_dense_value)
            self.assertAllEqual(4, self.evaluate(devar1.size()))
            got_keys, got_values = self._export_sorted_keys_and_values(devar1)
            self.assertAllEqual(keys1, got_keys)
            self.assertAllEqual(values1, got_values)
            # devar2 was not selected, so nothing was loaded into it.
            self.assertAllEqual(0, self.evaluate(devar2.size()))
  def testWarmStart_ListOfStrings(self):
    """Warm-starts a variable selected by a list of names instead of a regex."""
    _, ckpt_value = self._create_prev_run_var(
        "v1", shape=[10, 1], initializer=ones())
    # Sanity-check the checkpointed values before relying on them below.
    self.assertAllEqual(np.ones([10, 1]), ckpt_value)

    with ops.Graph().as_default() as g, self.session(graph=g) as sess:
      # Created with zeros; warm-starting must override that with the
      # checkpointed ones.
      fresh_var = variable_scope.get_variable(
          "v1", shape=[10, 1], initializer=zeros())
      ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"])
      sess.run(variables.global_variables_initializer())
      self.assertAllEqual(fresh_var.eval(), ckpt_value)
    def test_list_of_regexes(self):
        """Selects dynamic embedding variables with a list of regexes.

        "outer/v1" matches both `outer/v1` and `outer/v1/Momentum`, whereas
        the anchored "outer/v2$" matches only `outer/v2`. The three selected
        variables must hold the checkpointed key/value pairs, and
        `outer/v2/Momentum` must stay empty.
        """
        dim = 2
        keys = [0, 1]
        values = [[0, 0], [1, 1]]

        # Save checkpoint from which to warm-start.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
                with variable_scope.variable_scope("outer"):
                    self._add_and_initialize_devar("v1", keys, values, dim)
                    self._add_and_initialize_devar("v1/Momentum", keys, values,
                                                   dim)
                    self._add_and_initialize_devar("v2", keys, values, dim)
                    self._add_and_initialize_devar("v2/Momentum", keys, values,
                                                   dim)
                    _write_checkpoint(self, sess)

        # New graph, new session with warm-starting.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
                with variable_scope.variable_scope("outer"):
                    v1 = self._add_devar("v1", dim)
                    v1_momentum = self._add_devar("v1/Momentum", dim)
                    v2 = self._add_devar("v2", dim)
                    v2_momentum = self._add_devar("v2/Momentum", dim)
                    self.assertAllEqual(0, self.evaluate(v1.size()))

                    ws_util.warm_start(
                        self.get_temp_dir(),
                        # This warm-starts both v1 and v1/Momentum, but only
                        # v2 (and not v2/Momentum).
                        vars_to_warm_start=["outer/v1", "outer/v2$"])
                    self.evaluate(
                        deo.dynamic_embedding_variables_initializer())
                    # The selected variables must hold exactly the saved
                    # key/value pairs. Distinct names are used so the expected
                    # `keys`/`values` are not shadowed by the exported ones.
                    for v in [v1, v1_momentum, v2]:
                        got_keys, got_values = (
                            self._export_sorted_keys_and_values(v))
                        self.assertAllEqual(keys, got_keys)
                        self.assertAllEqual(values, got_values)
                    # Not matched by any regex, so nothing was loaded.
                    self.assertAllEqual(0, self.evaluate(v2_momentum.size()))
# Example #16
# 0
    def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                   global_step_tensor, saving_listeners):
        """Train a model with the given Estimator Spec.

        Optionally warm-starts, registers the loss, sets up saver and
        checkpoint hooks, then runs `estimator_spec.train_op` in a
        `MonitoredTrainingSession` until the session asks to stop.

        Args:
          estimator_spec: Spec providing `train_op`, `loss`, `scaffold`,
            `training_hooks` and `training_chief_hooks`.
          worker_hooks: List of hooks run on every worker. It is extended in
            place with `hooks` and `estimator_spec.training_hooks`.
          hooks: Additional hooks appended to `worker_hooks`.
          global_step_tensor: Not used by this method; kept for interface
            compatibility.
          saving_listeners: Listeners attached to the first
            `CheckpointSaverHook` found or created.

        Returns:
          The loss fetched on the last training step, or None if no step ran.

        Raises:
          ValueError: If `saving_listeners` is non-empty but there is no
            `CheckpointSaverHook` to attach them to.
        """
        if self._warm_start_settings:
            # Lazy %-args: formatting happens only if the record is emitted.
            logging.info('Warm-starting with WarmStartSettings: %s',
                         self._warm_start_settings)
            warm_starting_util.warm_start(*self._warm_start_settings)

        # The loss is only registered in the LOSSES collection. Unlike the
        # stock Estimator, no 'loss' scalar summary and no NanTensorHook are
        # added here.
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
        worker_hooks.extend(hooks)
        worker_hooks.extend(estimator_spec.training_hooks)

        # Register a default sharded saver unless the scaffold or the graph
        # already provides one.
        if not (estimator_spec.scaffold.saver
                or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(sharded=True,
                               max_to_keep=self._config.keep_checkpoint_max,
                               keep_checkpoint_every_n_hours=(
                                   self._config.keep_checkpoint_every_n_hours),
                               defer_build=True,
                               save_relative_paths=True))

        chief_hooks = []
        all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
        saver_hooks = [
            h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)
        ]
        checkpointing_enabled = (self._config.save_checkpoints_secs
                                 or self._config.save_checkpoints_steps)
        if checkpointing_enabled and not saver_hooks:
            # No user-supplied saver hook: create one for the chief.
            chief_hooks = [
                training.CheckpointSaverHook(
                    self._model_dir,
                    save_secs=self._config.save_checkpoints_secs,
                    save_steps=self._config.save_checkpoints_steps,
                    scaffold=estimator_spec.scaffold)
            ]
            saver_hooks = [chief_hooks[0]]

        if saving_listeners:
            if not saver_hooks:
                raise ValueError(
                    'There should be a CheckpointSaverHook to use saving_listeners. '
                    'Please set one of the RunConfig.save_checkpoints_steps or '
                    'RunConfig.save_checkpoints_secs.')
            # It is expected to have one CheckpointSaverHook. If multiple, we
            # pick up the first one to add listener.
            saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

        # Decide chief status once so the chief-only settings below and the
        # session's `is_chief` flag cannot disagree.
        is_chief = is_rank0()
        if is_chief:
            log_step_count_steps = self._config.log_step_count_steps
            checkpoint_dir = self.model_dir
            chief_only_hooks = (tuple(chief_hooks) +
                                tuple(estimator_spec.training_chief_hooks))
        else:
            log_step_count_steps = None
            checkpoint_dir = None
            chief_only_hooks = None

        with MonitoredTrainingSession(
                master=self._config.master,
                is_chief=is_chief,
                checkpoint_dir=checkpoint_dir,
                scaffold=estimator_spec.scaffold,
                hooks=worker_hooks,
                chief_only_hooks=chief_only_hooks,
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=log_step_count_steps) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss])
        return loss
  def testWarmStartEmbeddingColumnLinearModel(self):
    """Warm-starts a linear model over an embedding column with vocab remap.

    The checkpoint holds a 4-row embedding table (old vocab) and 2x1 linear
    weights. The new model uses a reordered, extended 6-entry vocab: rows that
    exist in the old vocab are remapped from the checkpoint, and the two new
    tokens are filled by the VocabInfo `backup_initializer` (always 0.42).
    """
    # Create old and new vocabs for the categorical column "sc_vocab" that
    # backs the embedding column.
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
        "new_vocab")

    # Save checkpoint from which to warm-start.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        variable_scope.get_variable(
            "linear_model/sc_vocab_embedding/embedding_weights",
            initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
        variable_scope.get_variable(
            "linear_model/sc_vocab_embedding/weights",
            initializer=[[0.69], [0.71]])
        self._write_checkpoint(sess)

    def _partitioner(shape, dtype):  # pylint:disable=unused-argument
      # Partition each var into (at most) 2 slices along the first dimension.
      partitions = [1] * len(shape)
      partitions[0] = min(2, shape[0].value)
      return partitions

    # Create feature columns.
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
    emb_vocab = fc.embedding_column(
        categorical_column=sc_vocab,
        dimension=2)
    # The embedding column is fed to a *linear* model here.
    all_linear_cols = [emb_vocab]
    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        cols_to_vars = {}
        with variable_scope.variable_scope("", partitioner=_partitioner):
          # Create the variables.
          fc.linear_model(
              features=self._create_dummy_inputs(),
              feature_columns=all_linear_cols,
              cols_to_vars=cols_to_vars)

        # Construct the vocab_info for the embedding weight.
        vocab_info = ws_util.VocabInfo(
            new_vocab=sc_vocab.vocabulary_file,
            new_vocab_size=sc_vocab.vocabulary_size,
            num_oov_buckets=sc_vocab.num_oov_buckets,
            old_vocab=prev_vocab_path,
            # Can't use constant_initializer with load_and_remap.  A uniform
            # initializer with minval == maxval yields a deterministic 0.42;
            # in practice, use a truncated normal initializer.
            backup_initializer=init_ops.random_uniform_initializer(
                minval=0.42, maxval=0.42))
        ws_util.warm_start(
            self.get_temp_dir(),
            vars_to_warm_start=".*sc_vocab.*",
            var_name_to_vocab_info={
                "linear_model/sc_vocab_embedding/embedding_weights": vocab_info
            })
        sess.run(variables.global_variables_initializer())
        # Verify weights were correctly warm-started. The linear weights match
        # ".*sc_vocab.*" and are copied as-is; the embedding table is remapped
        # by vocab. Rows for tokens absent from the old vocab (raspberry,
        # blueberry) are filled by the VocabInfo backup_initializer (0.42).
        self._assert_cols_to_vars(
            cols_to_vars,
            {
                emb_vocab: [
                    # linear weights part 0.
                    np.array([[0.69]]),
                    # linear weights part 1.
                    np.array([[0.71]]),
                    # embedding_weights part 0: orange, guava, banana.
                    np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
                    # embedding_weights part 1: apple, raspberry, blueberry.
                    np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
                ]
            },
            sess)
  def testWarmStartVarsToWarmstartIsNone(self):
    """Checks that vars_to_warm_start=None limits warm-starting to vocab vars.

    Only the variable named in `var_name_to_vocab_info` (the sc_vocab linear
    weights) is restored; the sc_keys mapping in `var_name_to_prev_var_name`
    and the sc_hash weights present in the checkpoint are both ignored.
    """
    # Create old and new vocabs for sparse column "sc_vocab".
    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry",
         "blueberry"], "new_vocab")
    # Create feature columns.
    sc_hash = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)
    sc_keys = fc.categorical_column_with_vocabulary_list(
        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
    all_linear_cols = [sc_hash, sc_keys, sc_vocab]

    # Save checkpoint from which to warm-start.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        variable_scope.get_variable(
            "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
        variable_scope.get_variable(
            "some_other_name", shape=[4, 1], initializer=rand())
        variable_scope.get_variable(
            "linear_model/sc_vocab/weights",
            initializer=[[0.5], [1.], [2.], [3.]])
        self._write_checkpoint(sess)

    def _partitioner(shape, dtype):  # pylint:disable=unused-argument
      # Partition each var into 2 equal slices.
      partitions = [1] * len(shape)
      partitions[0] = min(2, shape[0].value)
      return partitions

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner)
        vocab_info = ws_util.VocabInfo(
            new_vocab=sc_vocab.vocabulary_file,
            new_vocab_size=sc_vocab.vocabulary_size,
            num_oov_buckets=sc_vocab.num_oov_buckets,
            old_vocab=prev_vocab_path)
        ws_util.warm_start(
            self.get_temp_dir(),
            # The special value of None here will ensure that only the variable
            # specified in var_name_to_vocab_info (the sc_vocab linear
            # weights) is warm-started.
            vars_to_warm_start=None,
            var_name_to_vocab_info={
                ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
            },
            # Even though this is provided, the None value for
            # vars_to_warm_start overrides the logic, and this will not be
            # warm-started.
            var_name_to_prev_var_name={
                ws_util._infer_var_name(cols_to_vars[sc_keys]):
                    "some_other_name"
            })
        sess.run(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.  Var corresponding to
        # sc_vocab should be correctly warm-started after vocab remapping
        # (rows for new tokens stay at the default zeros), and neither of the
        # other two should be warm-started.
        self._assert_cols_to_vars(cols_to_vars, {
            sc_keys: [np.zeros([2, 1]), np.zeros([2, 1])],
            sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])],
            sc_vocab: [
                np.array([[3.], [2.], [1.]]),
                np.array([[0.5], [0.], [0.]])
            ]
        }, sess)
  def testWarmStart_MultipleCols(self):
    """Warm-starts every kind of linear column, plus the bias, in one call.

    First verifies that a freshly built model (no warm-start) is all zeros,
    including the bias, then that `ws_util.warm_start` (called without a
    `vars_to_warm_start` filter) restores each column's weights and the bias
    from the checkpoint.
    """
    # Create vocab for sparse column "sc_vocab".
    vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                   "vocab")

    # Create feature columns.
    sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)
    sc_hash = fc.categorical_column_with_hash_bucket(
        "sc_hash", hash_bucket_size=15)
    sc_keys = fc.categorical_column_with_vocabulary_list(
        "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    sc_vocab = fc.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
    real = fc.numeric_column("real")
    real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])
    cross = fc.crossed_column([sc_keys, sc_vocab], hash_bucket_size=20)
    all_linear_cols = [sc_int, sc_hash, sc_keys, sc_vocab, real_bucket, cross]

    # Save checkpoint from which to warm-start.  Also create a bias variable,
    # so we can check that it's also warm-started.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        sc_int_weights = variable_scope.get_variable(
            "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
        sc_hash_weights = variable_scope.get_variable(
            "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
        sc_keys_weights = variable_scope.get_variable(
            "linear_model/sc_keys/weights", shape=[4, 1], initializer=rand())
        sc_vocab_weights = variable_scope.get_variable(
            "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
        real_bucket_weights = variable_scope.get_variable(
            "linear_model/real_bucketized/weights",
            shape=[5, 1],
            initializer=norms())
        cross_weights = variable_scope.get_variable(
            "linear_model/sc_keys_X_sc_vocab/weights",
            shape=[20, 1],
            initializer=rand())
        bias = variable_scope.get_variable(
            "linear_model/bias_weights",
            shape=[1],
            initializer=rand())
        self._write_checkpoint(sess)
        (prev_int_val, prev_hash_val, prev_keys_val, prev_vocab_val,
         prev_bucket_val, prev_cross_val, prev_bias_val) = sess.run([
             sc_int_weights, sc_hash_weights, sc_keys_weights, sc_vocab_weights,
             real_bucket_weights, cross_weights, bias
         ])

    def partitioner(shape, dtype):  # pylint:disable=unused-argument
      # One partition per dimension, i.e. variables are not actually split.
      return [1] * len(shape)

    # New graph, new session WITHOUT warm-starting.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
        sess.run(variables.global_variables_initializer())
        # Without warm-starting, all weights should be initialized using default
        # initializer (which is init_ops.zeros_initializer).  The bias is
        # checked too, so the warm-start assertion below cannot pass by
        # accident.
        self._assert_cols_to_vars(cols_to_vars, {
            sc_int: [np.zeros([10, 1])],
            sc_hash: [np.zeros([15, 1])],
            sc_keys: [np.zeros([4, 1])],
            sc_vocab: [np.zeros([4, 1])],
            real_bucket: [np.zeros([5, 1])],
            cross: [np.zeros([20, 1])],
            "bias": [np.zeros([1])],
        }, sess)

    # New graph, new session with warm-starting.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
        vocab_info = ws_util.VocabInfo(
            new_vocab=sc_vocab.vocabulary_file,
            new_vocab_size=sc_vocab.vocabulary_size,
            num_oov_buckets=sc_vocab.num_oov_buckets,
            old_vocab=vocab_path)
        ws_util.warm_start(
            self.get_temp_dir(),
            var_name_to_vocab_info={
                "linear_model/sc_vocab/weights": vocab_info
            })
        sess.run(variables.global_variables_initializer())
        # Verify weights were correctly warm-started.
        self._assert_cols_to_vars(cols_to_vars, {
            sc_int: [prev_int_val],
            sc_hash: [prev_hash_val],
            sc_keys: [prev_keys_val],
            sc_vocab: [prev_vocab_val],
            real_bucket: [prev_bucket_val],
            cross: [prev_cross_val],
            "bias": [prev_bias_val],
        }, sess)
예제 #20
0
    def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                   global_step_tensor, saving_listeners,
                                   save_best_ckpt):
        """Train a model with the given Estimator Spec.

        Optionally warm-starts, assembles the hook lists (NaN check, loss
        logging, checkpoint saving, optional train/eval iterator-handle hook,
        optional best-checkpoint hook), then runs `estimator_spec.train_op`
        in a `MonitoredTrainingSession` until the session asks to stop.

        Args:
            estimator_spec: spec returned by the model function; its
                `train_op`, `loss`, `predictions`, `scaffold`,
                `training_hooks` and `training_chief_hooks` are used.
            worker_hooks: list of hooks run on every worker; extended in place.
            hooks: extra hooks appended to `worker_hooks`.
            global_step_tensor: logged as "step" next to the loss.
            saving_listeners: listeners attached to the first
                `CheckpointSaverHook`.
            save_best_ckpt: if true, add a `BestCheckpointSaverHook` driven
                by the evaluator class in `self._params["evaluator"]`.

        Returns:
            The loss of the last training step that ran, or None if none ran.

        Raises:
            ValueError: if `saving_listeners` is given but there is no
                `CheckpointSaverHook`.
            TypeError: if `save_best_ckpt` is set and `self._params["evaluator"]`
                is not a subclass of `EvaluateBase`.
        """
        if self._warm_start_settings:
            logging.info('Warm-starting with WarmStartSettings: %s',
                         self._warm_start_settings)
            warm_starting_util.warm_start(*self._warm_start_settings)
        worker_hooks.extend(hooks)
        worker_hooks.append(training.NanTensorHook(estimator_spec.loss))
        if self._config.log_step_count_steps is not None:
            tensors = {"loss": estimator_spec.loss, "step": global_step_tensor}
            predictions = estimator_spec.predictions
            # Also log prediction tensors whose key contains "/" (logged under
            # the key with slashes removed). Non-dict predictions (e.g. a
            # single tensor or None) have no keys and are skipped.
            if isinstance(predictions, dict):
                tensors.update({
                    key.replace("/", ""): val
                    for key, val in predictions.items()
                    if "/" in key
                })
            worker_hooks.append(
                training.LoggingTensorHook(
                    tensors, every_n_iter=self._config.log_step_count_steps))
        worker_hooks.extend(estimator_spec.training_hooks)

        # Create Saver object
        if not (estimator_spec.scaffold.saver
                or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(sharded=True,
                               max_to_keep=self._config.keep_checkpoint_max,
                               keep_checkpoint_every_n_hours=(
                                   self._config.keep_checkpoint_every_n_hours),
                               defer_build=True,
                               save_relative_paths=True))

        chief_hooks = []
        all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
        saver_hooks = [
            h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)
        ]
        if (self._config.save_checkpoints_secs
                or self._config.save_checkpoints_steps):
            if not saver_hooks:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=estimator_spec.scaffold)
                ]
                saver_hooks = [chief_hooks[0]]
        if saving_listeners:
            if not saver_hooks:
                raise ValueError(
                    'There should be a CheckpointSaverHook to use saving_listeners. '
                    'Please set one of the RunConfig.save_checkpoints_steps or '
                    'RunConfig.save_checkpoints_secs.')
            else:
                # It is expected to have one CheckpointSaverHook. If multiple, we pick
                # up the first one to add listener.
                saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

        if self._train_with_eval:
            self.dataset_handle_hook = IteratorStringHandleHook(
                self.train_iterator, self.eval_iterator)
            worker_hooks.append(self.dataset_handle_hook)
            self._predict_keys = estimator_spec.predictions

        if save_best_ckpt:
            evaluator_cls = self._params.get("evaluator", None)
            # Check for a class first: issubclass() itself raises an
            # unhelpful TypeError when given None or an instance.
            if not (isinstance(evaluator_cls, type)
                    and issubclass(evaluator_cls, EvaluateBase)):
                raise TypeError(
                    "Parameter `evaluator` must be a subclass of EvaluateBase, "
                    "but got {!r}".format(evaluator_cls))
            eval_kwargs = self._params.get("eval_kwargs", {})
            eval_steps = self._params.get("eval_steps", 2500)
            primary_metric = self._params.get("primary_metric", None)
            secondary_metric = self._params.get("secondary_metric", None)

            # We must construct Evaluator inside a graph scope
            evaluator = evaluator_cls(self, **eval_kwargs)

            worker_hooks.append(
                BestCheckpointSaverHook(evaluator=evaluator,
                                        checkpoint_dir=self._model_dir,
                                        compare_fn=partial(
                                            evaluator.compare,
                                            primary_metric=primary_metric,
                                            secondary_metric=secondary_metric),
                                        tag=self._params["args"].tag,
                                        save_steps=eval_steps))

        # Defined before the session so the return below always has a value.
        loss = None
        # Training session monitor
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=estimator_spec.scaffold,
                hooks=worker_hooks,
                chief_only_hooks=(tuple(chief_hooks) +
                                  tuple(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=self._config.log_step_count_steps
        ) as mon_sess:
            # The handle hook only exists when training with eval (created
            # above). Its handles must be read after the MonitoredSession
            # has been created, i.e. inside this `with` block.
            if self._train_with_eval:
                self._feed_dict = _add_key_value(
                    self._feed_dict, self.handler,
                    self.dataset_handle_hook.train_handle)
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss],
                    self._feed_dict)
        # Return outside the `with`: MonitoredSession.__exit__ swallows
        # OutOfRangeError / StopIteration (end of input). A return inside the
        # block would then be skipped and this method would return None.
        return loss