Example #1
    def testSaveByDict(self):
        with ops.device(self._dev()):
            v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
            v2 = resource_variable_ops.ResourceVariable(1.0, name='v2')

            def model():
                return array_ops.constant(2.0) * v1 * v2

            ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')

            # Save the variables under different names.
            _ = model()
            saver = _saver.Saver({'ckpt/v1': v1, 'ckpt/v2': v2})
            saver.save(ckpt_prefix)
            v1.assign(2.0)
            v2.assign(2.0)
            self.assertEqual(v1.read_value().numpy(), 2.0)
            self.assertEqual(v2.read_value().numpy(), 2.0)
            # Can still restore it.
            saver.restore(ckpt_prefix)
            self.assertEqual(v1.read_value().numpy(), 1.0)
            self.assertEqual(v2.read_value().numpy(), 1.0)
            # However, restoring with the default names fails.
            with self.assertRaisesOpError('not found in checkpoint'):
                saver = _saver.Saver([v1, v2]).restore(ckpt_prefix)

            # Can specify which variable in ckpt to restore to which variable.
            def map_func(x):
                return {'v3': 'ckpt/v1', 'v4': 'ckpt/v2'}.get(x, x)

            with _saver.restore_variables_on_create(ckpt_prefix, map_func):
                v3 = resource_variable_ops.ResourceVariable(2.0, name='v3')
                v4 = resource_variable_ops.ResourceVariable(2.0, name='v4')
            self.assertEqual(v3.read_value().numpy(), 1.0)
            self.assertEqual(v4.read_value().numpy(), 1.0)
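
These test snippets omit their imports. They appear to be drawn from the TF 1.x contrib eager Saver tests; a minimal sketch of the imports they would need, assuming TF 1.x-era module paths (which shifted between releases):

import os

from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test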
Example #2
    def testSaveRestoreGraphCallable(self):
        with ops.device(self._dev()):

            @graph_callable.graph_callable(
                [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
            def model(x):
                v = variable_scope.get_variable(
                    'v', initializer=init_ops.zeros_initializer(), shape=())
                return v + x

            # Default 2 + 0 = 2
            self.assertEqual(
                2,
                model(array_ops.constant(2, dtype=dtypes.float32)).numpy())

            # Save the variable value 0.
            ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
            _saver.Saver(model.variables).save(ckpt_prefix)

            # update variable to 1, so that 2 + 1 = 3
            model.variables[0].assign(1.)
            self.assertEqual(
                3,
                model(array_ops.constant(2, dtype=dtypes.float32)).numpy())

            # load the variable value 0, so that 2 + 0 = 2
            _saver.Saver(model.variables).restore(ckpt_prefix)
            self.assertEqual(
                2,
                model(array_ops.constant(2, dtype=dtypes.float32)).numpy())

            # update the checkpoint variable to 1 and the in-memory value to 2.
            model.variables[0].assign(1.)
            _saver.Saver(model.variables).save(ckpt_prefix)
            model.variables[0].assign(2.)
            self.assertEqual(
                4,
                model(array_ops.constant(2, dtype=dtypes.float32)).numpy())

            # reset the graph and reload on create, so that 1 + 2 = 3
            with ops.Graph().as_default():
                with _saver.restore_variables_on_create(ckpt_prefix):

                    @graph_callable.graph_callable([
                        graph_callable.ShapeAndDtype(shape=(),
                                                     dtype=dtypes.float32)
                    ])
                    def model2(x):
                        v = variable_scope.get_variable(
                            'v',
                            initializer=init_ops.zeros_initializer(),
                            shape=())
                        return v + x

                    self.assertEqual(
                        3,
                        model2(array_ops.constant(
                            2, dtype=dtypes.float32)).numpy())
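
Example #2 additionally exercises the experimental graph_callable API (removed in later releases) plus a few op and framework modules; plausible extra imports, again assuming TF 1.x paths:

from tensorflow.python.eager import graph_callable
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope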
Example #3
def train(training_dataset, preprocessing_type, base_model, optimizer, summary_dir,
        ckpt_dir, logging_every_n_steps, summary_every_n_steps, save_every_n_steps,
        restore_ckpt_file_path):
    # Gather the pretrained model's variables plus the global step.
    variables = base_model.variables + [tf.train.get_or_create_global_step()]
    saver = model_saver.Saver(variables)

    # Checkpoint file explicitly specified on the command line.
    if restore_ckpt_file_path is not None:
        saver.restore(restore_ckpt_file_path)

    # Resume from the latest checkpoint in ckpt_dir, if one exists.
    if tf.train.latest_checkpoint(ckpt_dir) is not None:
        saver.restore(tf.train.latest_checkpoint(ckpt_dir))

    # writer = tf.summary.FileWriter(train_dir, flush_secs=100)
    writer = summary.create_file_writer(summary_dir, flush_millis=100000)
    for i in range(configs['epoches']):
        tf.compat.v1.logging.info('epoch %d starting...' % (i + 1))
        start_time = time.time()
        with writer.as_default(), summary.always_record_summaries():
            train_one_epoch(dataset=training_dataset, base_model=base_model,
                            optimizer=optimizer, preprocessing_type=preprocessing_type,
                            logging_every_n_steps=logging_every_n_steps,
                            summary_every_n_steps=summary_every_n_steps,
                            save_path=ckpt_dir, saver=saver,
                            save_every_n_steps=save_every_n_steps)
        tf.set_random_seed(1)
        end_time = time.time()
        tf.compat.v1.logging.info('epoch %d training finished in %d seconds' % (i + 1, end_time - start_time))
Example #4
def _load_from_ckpt_file(model, ckpt_file_path):
    saver = eager_saver.Saver(model.variables)
    for var in model.variables:
        tf.logging.info('restore var {}'.format(var.name))
    if tf.train.latest_checkpoint(ckpt_file_path) is not None:
        saver.restore(tf.train.latest_checkpoint(ckpt_file_path))
    else:
        raise ValueError('unknown ckpt file {}'.format(ckpt_file_path))
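
Note that despite the parameter name, ckpt_file_path is handed to tf.train.latest_checkpoint, which expects a checkpoint directory, not a single file. A hypothetical call site (the model input shape and the directory are illustrative only):

# Hypothetical usage sketch: the variables must already exist before the
# Saver can restore them, so run the model once to build them first.
_ = model(tf.zeros([1, 224, 224, 3]))  # illustrative dummy input
_load_from_ckpt_file(model, '/path/to/ckpt_dir')  # a directory, despite the name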
Example #5
    def testSameObjectOK(self):
        with context.eager_mode(), ops.device(self._dev()):
            v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
            # While different objects with the same shared_name are not good,
            # passing in the same object multiple times is fine.
            saver = _saver.Saver([v1, v1])
            ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
            saver.save(ckpt_prefix)
Example #6
  def testRestoreOnCreate(self):
    with ops.device(self._dev()):
      def model(init_val):
        v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
        return array_ops.constant(1.0) * v1, v1

      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
      _, v1 = model(1.0)
      saver = _saver.Saver([v1])
      saver.save(ckpt_prefix)

      with ops.Graph().as_default():
        saver = _saver.Saver([v1])
        with _saver.restore_variables_on_create(ckpt_prefix):
          # The value comes from the checkpoint, not from the argument.
          ret, _ = model(2.0)
          self.assertEqual(ret.numpy(), 1.0)
Example #7
    def testSameNameNoClobbering(self):
        with ops.device(self._dev()):
            v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
            v2 = resource_variable_ops.ResourceVariable(2.0, name='v1')
            saver = _saver.Saver([v1, v2])
            ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
            with self.assertRaisesRegexp(ValueError, 'v1'):
                saver.save(ckpt_prefix)
Example #8
    def testDifferentGraphError(self):
        with context.eager_mode(), ops.device(self._dev()):
            with ops.Graph().as_default():
                v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
            with ops.Graph().as_default():
                saver = _saver.Saver([v1])
                ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
                with self.assertRaisesRegexp(ValueError, 'Graph'):
                    saver.save(ckpt_prefix)
Example #9
  def testRestoreOnCreate(self):
    with context.eager_mode():
      def model(init_val):
        v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
        return array_ops.constant(1.0) * v1, v1

      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
      _, v1 = model(1.0)
      saver = _saver.Saver([v1])
      saver.save(ckpt_prefix)

      with ops.Graph().as_default():
        saver = _saver.Saver([v1])
        with saver.maybe_restore_on_create(ckpt_prefix):
          # The value comes from the checkpoint, not from the argument.
          ret, _ = model(2.0)
          self.assertEqual(ret.numpy(), 1.0)
          # Creating it a second time won't re-assign the checkpoint value.
          v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1')
          self.assertEqual(v1_2.read_value().numpy(), 3.0)
Example #10
  def testSameNameNoClobbering(self):
    with ops.device(self._dev()):
      # Note that this test purposefully uses Graphs rather than
      # IsolateTest. Users are more likely to accidentally create the same
      # variable name this way.
      first_graph = ops.Graph()
      with first_graph.as_default():
        v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
      with ops.Graph().as_default():
        v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
        saver = _saver.Saver([v1_first_graph, v1_second_graph])
      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
      with self.assertRaisesRegexp(ValueError, 'v1'):
        saver.save(ckpt_prefix)
Example #11
  def testRestoreNotFound(self):
    with context.eager_mode():
      def model(v):
        return array_ops.constant(1.0) * v

      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
      v = resource_variable_ops.ResourceVariable(1.0, name='v1')
      _ = model(v)
      saver = _saver.Saver([v])
      saver.save(ckpt_prefix)

      with self.assertRaisesRegexp(errors.NotFoundError,
                                   'v2 not found in checkpoint'):
        with saver.maybe_restore_on_create(ckpt_prefix):
          _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
Example #12
    def _optimizer_test_template(self, optimizer):
        """Checks save and restore. Returns the optimizer variables."""
        v = resource_variable_ops.ResourceVariable([[2., 3.]], name='v')
        loss_fn = lambda: v[0, 0]**2 + v[0, 1]**2
        optimizer.minimize(loss_fn)
        optimizer_variables = _saver.get_optimizer_variables(optimizer)
        saver = _saver.Saver(optimizer_variables + [v])
        checkpoint_path = saver.save(self.get_temp_dir())
        optimizer.minimize(loss_fn)
        after_first_minimize = v.numpy()
        # After we restore, the next step should be exactly the same as the
        # one we just did.
        saver.restore(checkpoint_path)
        optimizer.minimize(loss_fn)
        self.assertAllEqual(after_first_minimize, v.numpy())
        return optimizer_variables
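
The template above is meant to be driven by per-optimizer test methods. A hypothetical driver, assuming the TF 1.x module is imported as from tensorflow.python.training import adam:

    def testAdamOptimizer(self):
        # Hypothetical caller of the template above; also checks that the
        # optimizer created at least one slot/bookkeeping variable.
        optimizer_variables = self._optimizer_test_template(
            adam.AdamOptimizer(learning_rate=0.1))
        self.assertTrue(optimizer_variables)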
Example #13
  def testBasics(self):
    with context.eager_mode():
      v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
      def model():
        return array_ops.constant(2.0) * v1

      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')

      _ = model()
      saver = _saver.Saver([v1])
      saver.save(ckpt_prefix)
      v1.assign(2.0)
      self.assertEqual(v1.read_value().numpy(), 2.0)

      saver.restore(ckpt_prefix)
      self.assertEqual(v1.read_value().numpy(), 1.0)
Example #14
def train(
    training_dataset,
    preprocessing_type,
    base_model,
    optimizer,
    logging_every_n_steps,
    save_every_n_steps,
    summary_every_n_steps,
    train_dir,
    ckpt_dir,
    restore_ckpt_file_path,
):
    # Gather the pretrained model's variables plus the global step.
    variables = base_model.variables + [tf.train.get_or_create_global_step()]
    saver = eager_saver.Saver(variables)

    # Checkpoint file explicitly specified on the command line.
    if restore_ckpt_file_path is not None:
        saver.restore(restore_ckpt_file_path)

    # Pretrained model in the current logs_dir, used to resume training.
    if tf.train.latest_checkpoint(ckpt_dir) is not None:
        saver.restore(tf.train.latest_checkpoint(ckpt_dir))

    train_writer = tf.contrib.summary.create_file_writer(train_dir,
                                                         flush_millis=100000)
    for i in range(CONFIG['epochs']):
        tf_logging.info('epoch %d starting...' % (i + 1))
        start = time.time()
        with train_writer.as_default(), summary.always_record_summaries():
            train_one_epoch(
                dataset=training_dataset,
                base_model=base_model,
                optimizer=optimizer,
                preprocessing_type=preprocessing_type,
                logging_every_n_steps=logging_every_n_steps,
                summary_every_n_steps=summary_every_n_steps,
                saver=saver,
                save_every_n_steps=save_every_n_steps,
                save_path=ckpt_dir,
            )
        tf.set_random_seed(1)
        train_end = time.time()
        tf_logging.info('epoch %d training finished in %d seconds' %
                        (i + 1, train_end - start))