    def testDeferredSlotRestoration(self):
        with self.test_session():
            checkpoint_directory = self.get_temp_dir()

            root = trackable_utils.Checkpoint()
            root.var = trackable_utils.add_variable(root,
                                                    name="var",
                                                    initializer=0.)
            optimizer = adam.Adam(0.1)
            variables = [root.var]
            gradients = [1.]
            train_op = optimizer.apply_gradients(zip(gradients, variables))
            # Note that `optimizer` has not been added as a dependency of
            # `root`. Create a one-off grouping so that slot variables for `root.var`
            # get initialized too.
            self.evaluate(
                trackable_utils.gather_initializers(
                    trackable_utils.Checkpoint(root=root,
                                               optimizer=optimizer)))
            self.evaluate(train_op)
            self.evaluate(state_ops.assign(root.var, 12.))
            no_slots_path = root.save(
                os.path.join(checkpoint_directory, "no_slots"))
            root.optimizer = optimizer
            self.evaluate(state_ops.assign(root.var, 13.))
            self.evaluate(
                state_ops.assign(
                    optimizer.get_slot(slot_name="m", var=root.var), 14.))
            slots_path = root.save(
                os.path.join(checkpoint_directory, "with_slots"))
            new_root = trackable_utils.Checkpoint()
            # Load the slot-containing checkpoint (deferred), then immediately
            # overwrite the non-slot variable (also deferred).
            slot_status = new_root.restore(slots_path)
            no_slot_status = new_root.restore(no_slots_path)
            with self.assertRaises(AssertionError):
                no_slot_status.assert_consumed()
            new_root.var = trackable_utils.add_variable(new_root,
                                                        name="var",
                                                        shape=[])
            no_slot_status.assert_consumed()
            no_slot_status.run_restore_ops()
            self.assertEqual(12., self.evaluate(new_root.var))
            new_root.optimizer = adam.Adam(0.1)
            slot_status.assert_existing_objects_matched()
            if not context.executing_eagerly():
                with self.assertRaisesRegex(AssertionError,
                                            "Unresolved object"):
                    slot_status.assert_consumed()
            self.assertEqual(12., self.evaluate(new_root.var))
            if context.executing_eagerly():
                # Slot variables are only created with restoring initializers when
                # executing eagerly.
                self.assertEqual(
                    14.,
                    self.evaluate(
                        new_root.optimizer.get_slot(slot_name="m",
                                                    var=new_root.var)))
            else:
                # Slot variables are not created eagerly when graph building.
                with self.assertRaises(KeyError):
                    new_root.optimizer.get_slot(slot_name="m",
                                                var=new_root.var)
            variables = [new_root.var]
            gradients = [1.]
            train_op = new_root.optimizer.apply_gradients(
                zip(gradients, variables))
            # The slot variable now exists; restore() didn't create it, but we should
            # now have a restore op for it.
            slot_status.run_restore_ops()
            if not context.executing_eagerly():
                # The train op hasn't run when graph building, so the slot variable has
                # its restored value. It has run in eager, so the value will
                # be different.
                self.assertEqual(
                    14.,
                    self.evaluate(
                        new_root.optimizer.get_slot(slot_name="m",
                                                    var=new_root.var)))
            self.evaluate(train_op)
            slot_status.assert_consumed()
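
The test above exercises deferred restoration: values read from a checkpoint are held until a matching object is attached to the dependency graph. A minimal sketch of the same behavior with the public tf.train.Checkpoint API, assuming eager execution (the TF2 default) and a throwaway directory:

import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # throwaway directory for the sketch
root = tf.train.Checkpoint(var=tf.Variable(12.0))
save_path = root.save(os.path.join(ckpt_dir, "ckpt"))

new_root = tf.train.Checkpoint()
status = new_root.restore(save_path)  # nothing to assign yet: restoration is deferred
new_root.var = tf.Variable(0.0)       # the saved value 12.0 is applied on attachment
status.assert_consumed()              # every saved value now has a matching object
assert new_root.var.numpy() == 12.0
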
Example #2
  def start(self):
    """Starts the evaluation loop."""
    optimizer_checkpoint = tracking_util.Checkpoint(iter=self._iterations)
    checkpoint = tracking_util.Checkpoint(
        model=self.model, optimizer=optimizer_checkpoint)

    for latest_checkpoint in checkpoint_utils.checkpoints_iterator(
        self.checkpoint_dir):
      try:
        # `expect_partial` because the checkpoint can have other `Trackable`s
        # such as `optimizer`.
        checkpoint.restore(latest_checkpoint).expect_partial()
        checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
        # The checkpoint should contain the model and optimizer for
        # SidecarEvaluator to work. However, the model weights saved by the
        # ModelCheckpoint callback do not contain the model as an attribute.
        # To keep SidecarEvaluator working in that case, use model.load_weights
        # to load the model's weights, while self._iterations is still
        # restored from the checkpoint variable.
        if 'model' not in checkpoint_attributes:
          self.model.load_weights(latest_checkpoint)
        # The model checkpoint might not include the optimizer in some cases,
        # e.g. when using a custom training loop. Directly assign the
        # iterations property so it can be used in callbacks.
        if self.model.optimizer:
          self.model.optimizer.iterations.assign(self._iterations)
      except (errors_impl.OpError,) as e:
        # A couple errors can happen here with the coordinator racing to write
        # checkpoint:
        # 1) OpError: open failed for <file path>: No such file or directory
        # 2) NotFoundError (subclass of OpError): Unsuccessful
        # TensorSliceReader constructor.
        # TODO(rchao): Remove this except block once b/150954027 is resolved.
        logging.info(
            'SidecarEvaluator has an error loading '
            'checkpoint: %s. Retrying. Error: %s: %s', latest_checkpoint,
            e.__class__.__name__, e)
        continue

      if self._iterations.numpy() == _ITERATIONS_UNINITIALIZED:
        raise RuntimeError(
            '`iterations` cannot be loaded from the '
            'checkpoint file. Please ensure `iterations` is '
            'tracked in the `checkpoint` saved by the coordinator.')

      logging.info(
          'Evaluation starts: Model weights loaded from latest '
          'checkpoint file: %s.', latest_checkpoint)

      self.model.evaluate(
          self.data, steps=self.steps, callbacks=self.callbacks, verbose=2)

      return_metrics = {}
      for metric in self.model.metrics:
        result = metric.result()
        if isinstance(result, dict):
          return_metrics.update(result)
        else:
          return_metrics[metric.name] = result

      logging.info(
          'End of evaluation. Metrics: %s', ' '.join([
              '{}={}'.format(name, value.numpy())
              for name, value in return_metrics.items()
          ]))

      # TODO(rchao): Make the max evaluation robust in case users save the
      # checkpoints with epoch format {epoch:03d}.
      if (self.max_evaluations and
          latest_checkpoint.endswith('-{}'.format(self.max_evaluations))):
        # Exit the loop because we have evaluated the final checkpoint file.
        logging.info('Last checkpoint evaluated. SidecarEvaluator stops.')
        return
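
The loop above is built on tf.train.checkpoints_iterator, which blocks until a new checkpoint appears in the directory. A stripped-down sketch of the polling pattern with only public APIs; `model` (a compiled Keras model), `data`, and `checkpoint_dir` are placeholders:

import tensorflow as tf

def sidecar_style_evaluation(model, data, checkpoint_dir):
  checkpoint = tf.train.Checkpoint(model=model)
  for latest in tf.train.checkpoints_iterator(checkpoint_dir):
    # The training checkpoint may contain extra objects (e.g. the optimizer)
    # that this Checkpoint does not track; expect_partial() tolerates them.
    checkpoint.restore(latest).expect_partial()
    model.evaluate(data, verbose=2)
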
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                maintenance_event=None,
                training_finished=None):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    def mock_request_compute_metadata(*args, **kwargs):
      del kwargs  # Unused.
      if args[0] == 'instance/maintenance-event':
        if (not maintenance_event.is_set()) and (
            strategy.cluster_resolver.task_id
            == 1) and (random.randrange(0, 9) > 6):
          maintenance_event.set()

          logging.info('Maintenance notice available.')
          return 'TERMINATE_ON_HOST_MAINTENANCE'
        else:
          return 'NONE'

      return ''

    with mock.patch.object(gce_util, 'request_compute_metadata',
                           mock_request_compute_metadata), mock.patch.object(
                               gce_util, 'detect_platform',
                               lambda: gce_util.PlatformDevice.GCE_GPU):

      class Model(module.Module):

        def __init__(self):
          self.v = variables_lib.Variable(
              0.,
              synchronization=variables_lib.VariableSynchronization.ON_WRITE,
              aggregation=variables_lib.VariableAggregation.SUM)

        @def_function.function(input_signature=[])
        def __call__(self):
          return self.v.read_value()

      with strategy.scope():
        model = Model()
        fh_ckpt = tracking_util.Checkpoint(model=model)

        failure_handler = failure_handling.CoordinatedCheckpointManager(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Restored training at %d', failure_handler.total_runs)
      for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
                         EPOCHS_TO_RUN):

        for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
                          STEPS_PER_EPOCH):
          failure_handler.run(distributed_train_step, epoch, step)

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

      training_finished.set()

      pre_del_thread_count = threading.activeCount()
      failure_handler.__del__()
      self.assertLessEqual(threading.activeCount(), pre_del_thread_count - 1)
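
The coordinated-checkpointing classes exercised here (CoordinatedCheckpointManager above, WorkerPreemptionHandler in later examples) are internal names. As far as I know the corresponding public entry point is tf.distribute.experimental.PreemptionCheckpointHandler; the sketch below shows the resume-from-total-runs pattern under that assumption, so treat the class name and the total_run_calls attribute as unverified assumptions, and the directory and constants as placeholders.

import tensorflow as tf

STEPS_PER_EPOCH = 5                        # placeholder values for the sketch
EPOCHS_TO_RUN = 2
checkpoint_dir = '/tmp/preemption_ckpts'   # placeholder directory

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
  v = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_WRITE,
      aggregation=tf.VariableAggregation.SUM)
  checkpoint = tf.train.Checkpoint(v=v)

preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(
    strategy.cluster_resolver, checkpoint, checkpoint_dir)

def distributed_train_step(epoch, step):

  @tf.function
  def train_step():
    v.assign_add(1.0)

  strategy.run(train_step)

# Resume from wherever the previous (possibly preempted) run stopped, the same
# way the test above uses failure_handler.total_runs.
for epoch in range(preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                   EPOCHS_TO_RUN):
  for step in range(preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                    STEPS_PER_EPOCH):
    preemption_handler.run(distributed_train_step, epoch, step)
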
 def testNamingWithOptimizer(self):
     input_value = constant_op.constant([[3.]])
     model = MyModel()
     # A nuisance Model using the same optimizer. Its slot variables should not
     # go in the checkpoint, since it is never depended on.
     other_model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = training_util.get_or_create_global_step()
     root_trackable = trackable_utils.Checkpoint(
         optimizer=optimizer, model=model, optimizer_step=optimizer_step)
     if context.executing_eagerly():
         optimizer.minimize(lambda: model(input_value),
                            global_step=optimizer_step)
         optimizer.minimize(lambda: other_model(input_value),
                            global_step=optimizer_step)
     else:
         train_op = optimizer.minimize(model(input_value),
                                       global_step=optimizer_step)
         optimizer.minimize(other_model(input_value),
                            global_step=optimizer_step)
         self.evaluate(trackable_utils.gather_initializers(root_trackable))
         self.evaluate(train_op)
     named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
         root_trackable).serialize_object_graph()
     expected_checkpoint_names = (
         # Created in the root node, so no prefix.
         "optimizer_step",
         "model/_second/kernel",
         "model/_named_dense/kernel",
         "model/_named_dense/bias",
         # non-Layer dependency of the model
         "model/_non_layer/a_variable",
         # The optimizer creates two non-slot variables
         "optimizer/beta1_power",
         "optimizer/beta2_power",
         # Slot variables
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
     )
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names
     ]
     named_variables = {v.name: v for v in named_variables}
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
     # Check that we've mapped to the right variable objects (not exhaustive)
     self.assertEqual("global_step",
                      named_variables["optimizer_step" + suffix].full_name)
     self.assertEqual(
         "my_model/dense_1/kernel",
         named_variables["model/_second/kernel" + suffix].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         named_variables["model/_named_dense/kernel" + suffix].full_name)
     self.assertEqual(
         "beta1_power",
         named_variables["optimizer/beta1_power" + suffix].full_name)
     self.assertEqual(
         "beta2_power",
         named_variables["optimizer/beta2_power" + suffix].full_name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
     optimizer_node = serialized_graph.nodes[
         serialized_graph.nodes[0].children[1].node_id]
     self.assertEqual("beta1_power", optimizer_node.children[0].local_name)
     self.assertEqual(
         "beta1_power", serialized_graph.nodes[
             optimizer_node.children[0].node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].full_name)
     # We strip off the :0 suffix, as variable.name-based saving does.
     self.assertEqual(
         "my_model/dense/kernel/Adam",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel/Adam:0",
         optimizer.get_slot(var=model._named_dense.kernel, name="m").name)
     self.assertEqual(
         "model/_named_dense/kernel" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].checkpoint_key)
     self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
     self.assertEqual(
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].checkpoint_key)
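
The checkpoint keys asserted above (the `.ATTRIBUTES/VARIABLE_VALUE` suffix and, once slots exist, the `.OPTIMIZER_SLOT` segments) can be inspected directly with tf.train.list_variables. A small hedged sketch; the exact keys depend on the TF/Keras version, so the printed example is illustrative only:

import os
import tempfile

import tensorflow as tf

prefix = os.path.join(tempfile.mkdtemp(), "ckpt")  # throwaway prefix
dense = tf.keras.layers.Dense(1)
dense(tf.constant([[3.0]]))  # build the layer so kernel/bias exist

ckpt = tf.train.Checkpoint(model=dense)
save_path = ckpt.save(prefix)

for key, shape in tf.train.list_variables(save_path):
  print(key, shape)  # e.g. model/kernel/.ATTRIBUTES/VARIABLE_VALUE [1, 1]
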
Example #5
 def _create_and_call():
     checkpoint = util.Checkpoint(m=_LazyTrivialObjects())
     checkpoint.m()
     checkpoint.restore(checkpoint_path)
Example #6
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  training_started_event=None,
                  raise_app_error_on_worker=None,
                  termination_config=failure_handling.TerminationConfig()):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        class Model(module.Module):
            def __init__(self):
                self.v = variables_lib.Variable(
                    0.,
                    synchronization=variables_lib.VariableSynchronization.
                    ON_WRITE,
                    aggregation=variables_lib.VariableAggregation.SUM)

            @def_function.function(input_signature=[])
            def __call__(self):
                return self.v.read_value()

        with mock.patch.object(gce_util, 'on_gcp', lambda: False):

            with strategy.scope():
                model = Model()
                # Named fh_ckpt because users should keep their regular
                # checkpoint separate from the one used by
                # WorkerPreemptionHandler: we create a CheckpointManager to
                # manage it, and only one CheckpointManager should be active
                # in a particular directory at a time.
                fh_ckpt = tracking_util.Checkpoint(model=model)

                worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
                    strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
                    termination_config)

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    if distribution_strategy_context.get_distribution_strategy(
                    ).cluster_resolver.task_id == raise_app_error_on_worker:
                        raise errors_impl.ResourceExhaustedError(
                            node_def=None,
                            op=None,
                            message='Running out of resources')

                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Restored training at %d',
                         worker_preemption_watcher.total_runs)
            for epoch in range(
                    worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    worker_preemption_watcher.run(distributed_train_step,
                                                  epoch, step)
                # Add some randomness to when preemption actually happens. We should
                # trigger it for sure if the training is coming to an end and it hasn't
                # been triggered yet.
                if epoch >= EPOCHS_TO_RUN - 2:
                    trigger_it = True
                else:
                    trigger_it = False

                self._maybe_trigger_a_preemption(training_started_event,
                                                 trigger_it)

            logging.info('Training finished.')

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)
 def testSaveRestore(self):
     with self.test_session():
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
         root_trackable = trackable_utils.Checkpoint(optimizer=optimizer,
                                                     model=model)
         input_value = constant_op.constant([[3.]])
         if context.executing_eagerly():
             optimizer.minimize(lambda: model(input_value))
         else:
             train_op = optimizer.minimize(model(input_value))
             # TODO(allenl): Make initialization more pleasant when graph building.
             root_trackable.save_counter  # pylint: disable=pointless-statement
             self.evaluate(
                 trackable_utils.gather_initializers(root_trackable))
             self.evaluate(train_op)
         prefix = os.path.join(self.get_temp_dir(), "ckpt")
         self.evaluate(
             state_ops.assign(model._named_dense.variables[1], [42.]))
         m_bias_slot = optimizer.get_slot(model._named_dense.variables[1],
                                          "m")
         self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
         save_path = root_trackable.save(file_prefix=prefix)
         self.evaluate(
             state_ops.assign(model._named_dense.variables[1], [43.]))
         self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
         optimizer_variables = self.evaluate(optimizer.variables())
         self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
         # Immediate restoration
         status = root_trackable.restore(
             save_path=save_path).assert_consumed()
         status.run_restore_ops()
         self.assertAllEqual([42.],
                             self.evaluate(model._named_dense.variables[1]))
         self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
         self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
         if not context.executing_eagerly():
             return  # Restore-on-create is only supported when executing eagerly
         on_create_model = MyModel()
         on_create_optimizer = adam.AdamOptimizer(
             0.001,
             # Preserve beta1_power and beta2_power when applying gradients
             # so we can test that they've been restored correctly.
             beta1=1.0,
             beta2=1.0)
         on_create_root = trackable_utils.Checkpoint(
             optimizer=on_create_optimizer, model=on_create_model)
         # Deferred restoration
         status = on_create_root.restore(save_path=save_path)
         status.assert_nontrivial_match()
         status.assert_existing_objects_matched()
         with self.assertRaises(AssertionError):
             status.assert_consumed()
         on_create_model(constant_op.constant([[3.]]))  # create variables
         self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
         self.assertAllEqual([42.],
                             self.evaluate(
                                 on_create_model._named_dense.variables[1]))
         on_create_m_bias_slot = on_create_optimizer.get_slot(
             on_create_model._named_dense.variables[1], "m")
         status.assert_existing_objects_matched()
         with self.assertRaises(AssertionError):
             status.assert_consumed()
         # Optimizer slot variables are created when the original variable is
         # restored.
         self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
         self.assertAllEqual(optimizer_variables[2:],
                             self.evaluate(on_create_optimizer.variables()))
         dummy_var = variables.Variable([1.])
         on_create_optimizer.minimize(loss=dummy_var.read_value)
         status.assert_existing_objects_matched()
         status.assert_consumed()
         beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators(
         )
         self.assertAllEqual(optimizer_variables[0],
                             self.evaluate(beta1_power))
         self.assertAllEqual(optimizer_variables[1],
                             self.evaluate(beta2_power))
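
For reference, the same save/clobber/restore round trip can be driven through tf.train.CheckpointManager, which tracks the latest checkpoint path and prunes old files. A short sketch assuming eager execution, with a throwaway directory:

import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # throwaway directory
var = tf.Variable([42.0])
ckpt = tf.train.Checkpoint(var=var)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

manager.save()                 # writes e.g. <ckpt_dir>/ckpt-1
var.assign([-2.0])             # clobber the value...
ckpt.restore(manager.latest_checkpoint)
assert var.numpy()[0] == 42.0  # ...and recover it from the latest checkpoint
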
  def test_checkpoint_restore_before_variable_creation(self):
    self.skip_if_oss()

    class TestModule(module.Module):

      def __init__(self, initializer, rows):
        self._initializer = initializer
        self._rows = rows

        table = tpu_embedding_v2_utils.TableConfig(
            vocabulary_size=self._rows,
            dim=4,
            initializer=self._initializer,
            combiner='sum',
            name='table')
        feature_config = (tpu_embedding_v2_utils.FeatureConfig(
            table=table, name='feature'),)
        optimizer = tpu_embedding_v2_utils.SGD()

        self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
            feature_config, optimizer)

      def create_embedding(self):
        # We aren't training so batch_size here doesn't matter.
        self.tpu_embedding.build(64)

    strategy = self._get_strategy()
    with strategy.scope():
      module1 = TestModule(init_ops_v2.Ones(),
                           strategy.num_replicas_in_sync * 2)
      module1.create_embedding()

    checkpoint = util.Checkpoint(test_module=module1)
    checkpoint.save(self._get_tmpdir('restore_before_create', 'save'))

    # Reinitialize the tpu
    strategy = self._get_strategy()

    with strategy.scope():
      module2 = TestModule(init_ops_v2.Zeros(),
                           strategy.num_replicas_in_sync * 2)

    checkpoint = util.Checkpoint(test_module=module2)
    checkpoint.restore(self._get_tmpdir('restore_before_create', 'save-1'))

    with strategy.scope():
      module2.create_embedding()

    def get_values(mid):
      return mid._variables['table']['parameters'].variables[0].numpy()

    self.assertAllClose(
        np.ones((strategy.num_replicas_in_sync * 2, 4)),
        get_values(module2.tpu_embedding))

    # Fetch the values from the TPU to check that they are the same.
    module2.tpu_embedding._retrieve_variables()

    self.assertAllClose(
        np.ones((strategy.num_replicas_in_sync * 2, 4)),
        get_values(module2.tpu_embedding))
  def test_model_export_cpu(self):
    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      initializer = init_ops_v2.Constant(first_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)

      first_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, first_mid_level_optimizer)

      first_mid_level.build(64)

    cpu_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    cpu_mid_level = tpu_embedding_v2.TPUEmbedding(feature_config,
                                                  cpu_mid_level_optimizer)

    cpu_mid_level.build(64)

    first_mid_level._load_variables()

    tpu_checkpoint = util.Checkpoint(model=first_mid_level)
    tpu_checkpoint.save(self._get_tmpdir('export_cpu', 'save'))

    # We restore the checkpoint of our tpu mid level onto our cpu mid level.
    cpu_checkpoint = util.Checkpoint(model=cpu_mid_level)
    cpu_checkpoint.restore(self._get_tmpdir('export_cpu', 'save-1'))

    @def_function.function
    def serve_tensors(features):
      features = tpu_embedding_v2.cpu_embedding_lookup(
          features, None, cpu_mid_level.embedding_tables,
          cpu_mid_level._feature_config)
      return features[0]

    signatures = {
        'serving_default':
            serve_tensors.get_concrete_function((tensor_spec.TensorSpec(
                shape=(2,), dtype=dtypes.int32, name='feature'),))
    }
    save.save(
        cpu_mid_level,
        export_dir=self._get_tmpdir('export_cpu', 'exported_model'),
        signatures=signatures)

    imported = load.load(self._get_tmpdir('export_cpu', 'exported_model'))
    predict_fn = imported.signatures['serving_default']

    input_feature_value = np.array([1, 0])
    input_batch = (constant_op.constant(
        input_feature_value, dtype=dtypes.int32),)
    prediction = predict_fn(*input_batch)['output_0']
    self.assertAllClose(prediction.numpy(),
                        first_mid_level_contents[input_feature_value])
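
The export-and-serve half of this test does not depend on the TPU embedding internals. A sketch of the same SavedModel round trip with a plain tf.Module; the class, export directory, and input name below are placeholders:

import tempfile

import numpy as np
import tensorflow as tf

class Lookup(tf.Module):

  def __init__(self, table):
    self.table = tf.Variable(table, dtype=tf.float32)

  @tf.function(input_signature=[
      tf.TensorSpec(shape=(2,), dtype=tf.int32, name='feature')])
  def serve(self, ids):
    return tf.gather(self.table, ids)

module = Lookup(np.ones((4, 4)))
export_dir = tempfile.mkdtemp()  # throwaway export directory
tf.saved_model.save(
    module, export_dir,
    signatures={'serving_default': module.serve.get_concrete_function()})

imported = tf.saved_model.load(export_dir)
predict_fn = imported.signatures['serving_default']
prediction = predict_fn(feature=tf.constant([1, 0], dtype=tf.int32))['output_0']
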
Example #10
    def testCheckpoint(self, strategy_fn, save_with_ls, restore_with_ls):
        class MySGD(gradient_descent.SGD):
            """A custom optimizer that tracks an extra variable."""
            def __init__(self, *args, **kwargs):
                super(MySGD, self).__init__(*args, **kwargs)
                self.my_var = variables.Variable(0.)
                self._track_trackable(self.my_var, 'my_var')

        strategy = strategy_fn()
        replicas = strategy.num_replicas_in_sync
        if (isinstance(strategy, mirrored_strategy.MirroredStrategy)
                and not context.executing_eagerly()):
            # TODO(b/121381184): Enable running the test in this case.
            return

        with self.test_session(), strategy.scope():
            # Build and run a simple model.
            var = variables.Variable([2.0])
            opt = inner_opt = MySGD(1., momentum=1.)
            if save_with_ls:
                loss_scale = loss_scale_module.DynamicLossScale(
                    initial_loss_scale=1., increment_period=2., multiplier=2.)
                opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
            run_fn = lambda: opt.minimize(lambda: var / replicas + 1.,
                                          var_list=[var])
            opt_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self.evaluate(strategy.experimental_local_results(opt_op))

            # Assert values.
            self.assertEqual(self.evaluate(var), 1.)
            if save_with_ls:
                self.assertEqual(self.evaluate(loss_scale()), 1.)
                self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
            slot_var = opt.get_slot(var, 'momentum')
            self.assertEqual(self.evaluate(slot_var).item(), -1)
            self.assertEqual(self.evaluate(opt.iterations), 1)

            # Set an optimizer variable to check that arbitrary optimizer
            # attributes can be saved/restored.
            self.evaluate(inner_opt.my_var.assign(1.))

            # Save a checkpoint.
            checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
            prefix = os.path.join(self.get_temp_dir(), 'ckpt')
            save_path = checkpoint.save(prefix)

            # Create new model
            var = variables.Variable([2.0])
            opt = inner_opt = MySGD(1., momentum=1.)
            if restore_with_ls:
                loss_scale = loss_scale_module.DynamicLossScale(
                    initial_loss_scale=1., increment_period=2., multiplier=2.)
                opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)

            # Restore new model.
            checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
            status = checkpoint.restore(save_path)
            if save_with_ls:
                status.assert_existing_objects_matched()
            else:
                status.assert_nontrivial_match()

            # Assert restored values. We can only assert in eager mode since the
            # variables are uninitialized in graph mode
            if context.executing_eagerly():
                self.assertEqual(self.evaluate(var), 1.)
                if save_with_ls and restore_with_ls:
                    self.assertEqual(self.evaluate(loss_scale()), 1.)
                    self.assertEqual(self.evaluate(loss_scale._num_good_steps),
                                     1)
                elif restore_with_ls:
                    self.assertEqual(self.evaluate(loss_scale()), 1.)
                    self.assertEqual(self.evaluate(loss_scale._num_good_steps),
                                     0)
                self.assertEqual(self.evaluate(opt.iterations), 1)

            # Run the model again.
            run_fn = lambda: opt.minimize(lambda: var / replicas + 1.,
                                          var_list=[var])
            opt_op = strategy.experimental_run(run_fn)

            # Assert new values.
            self.evaluate(variables.global_variables_initializer())
            status.run_restore_ops()
            self.evaluate(strategy.experimental_local_results(opt_op))
            self.assertEqual(self.evaluate(var), -1)
            slot_var = opt.get_slot(var, 'momentum')
            self.assertEqual(self.evaluate(slot_var).item(), -2)
            self.assertEqual(self.evaluate(opt.iterations), 2)
            self.assertEqual(self.evaluate(inner_opt.my_var), 1)

            # Restore model again to test restoring after slots are created
            status = checkpoint.restore(save_path)
            if save_with_ls and restore_with_ls:
                status.assert_consumed()
            elif save_with_ls:
                status.assert_existing_objects_matched()
            elif restore_with_ls:
                status.assert_nontrivial_match()
            status.run_restore_ops()
            self.assertEqual(self.evaluate(var), 1)
            self.assertEqual(self.evaluate(slot_var).item(), -1)
  def test_checkpoint_restore_loads(self):
    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    def get_values(mid):
      return ops.convert_to_tensor(
          mid._variables['table']['parameters'].variables[0])

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      initializer = init_ops_v2.Constant(first_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)

      first_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, first_mid_level_optimizer)
      first_mid_level.build(64)

    first_mid_level._load_variables()

    first_checkpoint = util.Checkpoint(model=first_mid_level)
    first_checkpoint.save(self._get_tmpdir('restore', 'save'))

    tpu_strategy_util.initialize_tpu_system(self.resolver)

    with strategy.scope():
      second_mid_level_contents = np.ones((num_rows, 4)) * 2
      second_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      initializer = init_ops_v2.Constant(second_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      second_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, second_mid_level_optimizer)
      second_mid_level.build(64)

    second_mid_level._load_variables()

    self.assertAllClose(
        second_mid_level_contents,
        get_values(second_mid_level),
        msg='Second mid level api should contain its initial values.',
    )
    # We restore the checkpoint of our first model into our second model.
    # This should load the first mid level API object onto the TPU.
    second_checkpoint = util.Checkpoint(model=second_mid_level)
    second_checkpoint.restore(self._get_tmpdir('restore', 'save-1'))

    # Call retrieve here as a way to check what the TPU contains.
    # Calling the retrieve ops directly might make for a cleaner separation of
    # test and module, though.
    second_mid_level._retrieve_variables()

    self.assertAllClose(
        first_mid_level_contents,
        get_values(second_mid_level),
        msg='Second mid level api should have retrieved the first model values.'
    )
Example #12
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                maintenance_event=None,
                training_finished=None,
                frequent_send=False,
                training_restarted=None,
                termination_config=failure_handling.TerminationConfig(
                    time_till_termination=0)):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    def mock_termination_watcher_function_gce(*args, **kwargs):
      del args, kwargs
      if not frequent_send:
        time.sleep(1)
        if (not maintenance_event.is_set()) and (random.randrange(0, 7) == 5):
          maintenance_event.set()
          logging.info('Termination notice available.')
          return True

      elif frequent_send and not maintenance_event.is_set():
        logging.info('Termination notice available.')
        return True

      return False

    with mock.patch.object(
        gce_util, 'termination_watcher_function_gce',
        mock_termination_watcher_function_gce), mock.patch.object(
            gce_util, 'detect_platform',
            lambda: gce_util.PlatformDevice.GCE_GPU):

      class Model(module.Module):

        def __init__(self):
          self.v = variables_lib.Variable(
              0.,
              synchronization=variables_lib.VariableSynchronization.ON_WRITE,
              aggregation=variables_lib.VariableAggregation.SUM)

        @def_function.function(input_signature=[])
        def __call__(self):
          return self.v.read_value()

      with strategy.scope():
        model = Model()
        fh_ckpt = tracking_util.Checkpoint(model=model)

        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
            termination_config)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Start training at %d', worker_preemption_watcher.total_runs)

      # If the training process has been restarted, verify that the expected
      # number of checkpoints have been written.
      # We also want to check training_finished, because there's a corner case
      # where the signal is sent quite late and training finishes before the
      # grace period ends.
      if training_restarted.is_set() and not training_finished.is_set():
        match_group = [
            re.search(r'.*ckpt-(\d+).index', a_file)
            for a_file in gfile.ListDirectory(checkpoint_dir)
        ]
        checkpoint_index = [
            a_match.group(1) for a_match in match_group if a_match
        ]
        if termination_config.time_till_termination > 0:
          # Two checkpoints were saved for the extended grace period.
          self.assertEqual(int(checkpoint_index[0]), 2)
        else:
          self.assertEqual(int(checkpoint_index[0]), 1)

      for epoch in range(
          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
          EPOCHS_TO_RUN):

        for step in range(
            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
            STEPS_PER_EPOCH):
          worker_preemption_watcher.run(distributed_train_step, epoch, step)

      logging.info('Training finished.')
      training_finished.set()

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

      running_threads = test_util.get_running_threads()
      if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                              running_threads) and test_util.has_thread(
                                  _LOCAL_WATCHER_THREAD_PREFIX,
                                  running_threads):
        try:
          # Explicitly call __del__ since making it None and gc.collect does
          # not invoke __del__ here.
          worker_preemption_watcher.__del__()

          time.sleep(2)

          running_threads = test_util.get_running_threads()
          self.assertFalse(
              test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                   running_threads))
          self.assertFalse(
              test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                   running_threads))

        except urllib.error.URLError as e:
          if 'Temporary failure in name resolution' in e.message:
            # This is caused by a flaky issue in which mock.patch does not
            # correctly patch gce_util.request_compute_metadata, so a real
            # request is attempted and fails inside
            # gce_util.request_compute_metadata.
            logging.warning('Hit a mock issue.')
            return
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  maintenance_event,
                  training_finished,
                  frequent_send=False):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        def mock_request_compute_metadata(*args, **kwargs):
            del kwargs  # Unused.
            if args[0] == 'instance/maintenance-event':
                if not frequent_send:
                    time.sleep(1)
                    if (not maintenance_event.is_set()) and (random.randrange(
                            0, 20) > 18):
                        maintenance_event.set()
                        logging.info('Maintenance notice available.')
                        return 'TERMINATE_ON_HOST_MAINTENANCE'
                elif frequent_send and not maintenance_event.is_set():
                    return 'TERMINATE_ON_HOST_MAINTENANCE'

            return 'NONE'

        with mock.patch.object(
                gce_util, 'request_compute_metadata',
                mock_request_compute_metadata), mock.patch.object(
                    gce_util, 'detect_platform',
                    lambda: gce_util.PlatformDevice.GCE_GPU):

            class Model(module.Module):
                def __init__(self):
                    self.v = variables_lib.Variable(
                        0.,
                        synchronization=variables_lib.VariableSynchronization.
                        ON_WRITE,
                        aggregation=variables_lib.VariableAggregation.SUM)

                @def_function.function(input_signature=[])
                def __call__(self):
                    return self.v.read_value()

            with strategy.scope():
                model = Model()
                fh_ckpt = tracking_util.Checkpoint(model=model)

                failure_handler = failure_handling.CoordinatedCheckpointManager(
                    strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d', failure_handler.total_runs)
            for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
                               EPOCHS_TO_RUN):

                for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
                                  STEPS_PER_EPOCH):
                    failure_handler.run(distributed_train_step, epoch, step)

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)

            training_finished.set()

            running_threads = test_util.get_running_threads()
            strategy.gather(constant_op.constant([10]), axis=0)
            self.assertTrue(
                test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                     running_threads))
            self.assertTrue(
                test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                     running_threads))

            strategy.gather(constant_op.constant([10]), axis=0)

            # Explicitly call __del__ since making it None and gc.collect does
            # not invoke __del__ here.
            failure_handler.__del__()

            time.sleep(2)

            running_threads = test_util.get_running_threads()
            self.assertFalse(
                test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                     running_threads))
            self.assertFalse(
                test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                     running_threads))
Example #14
  def start(self):
    """Starts the evaluation loop."""
    optimizer_checkpoint = tracking_util.Checkpoint(iter=self._iterations)
    checkpoint = tracking_util.Checkpoint(
        model=self.model, optimizer=optimizer_checkpoint)

    for latest_checkpoint in checkpoint_utils.checkpoints_iterator(
        self.checkpoint_dir):
      try:
        # `expect_partial` because the checkpoint can have other `Trackable`s
        # such as `optimizer`.
        checkpoint.restore(latest_checkpoint).expect_partial()
        checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
        # The checkpoint should contain the model and optimizer for
        # SidecarEvaluator to work. However, the model weights saved by the
        # ModelCheckpoint callback do not contain the model as an attribute.
        # To keep SidecarEvaluator working in that case, if the model
        # attribute is not found but a layer_with_weights attribute is found,
        # use model.load_weights to load the model's weights, while
        # self._iterations is still restored from the checkpoint variable.
        if 'model' not in checkpoint_attributes:
          for attribute in checkpoint_attributes:
            # check whether the checkpoint has the required attributes for
            # model.load_weights to work.
            if re.match(r'^layer_with_weights-\d+', attribute) is not None:
              self.model.load_weights(latest_checkpoint)
              break
      except (errors_impl.OpError,) as e:
        # A couple errors can happen here with the coordinator racing to write
        # checkpoint:
        # 1) OpError: open failed for <file path>: No such file or directory
        # 2) NotFoundError (subclass of OpError): Unsuccessful
        # TensorSliceReader constructor.
        # TODO(rchao): Remove this except block once b/150954027 is resolved.
        logging.info(
            'SidecarEvaluator has an error loading '
            'checkpoint: %s. Retrying. Error: %s: %s', latest_checkpoint,
            e.__class__.__name__, e)
        continue

      if self._iterations.numpy() == _ITERATIONS_UNINITIALIZED:
        raise RuntimeError(
            '`iterations` cannot be loaded from the '
            'checkpoint file. Please ensure `iterations` is '
            'tracked in the `checkpoint` saved by the coordinator.')

      logging.info(
          'Evaluation starts: Model weights loaded from latest '
          'checkpoint file: %s.', latest_checkpoint)

      # TODO(rchao): Support arbitrary callback for extensibility.
      self.model.evaluate(self.data, steps=self.steps)

      logging.info(
          'End of evaluation. Metrics: %s', ' '.join([
              '{}={}'.format(metric.name,
                             metric.result().numpy())
              for metric in self.model.metrics
          ]))

      if self._summary_writer:
        with summary_ops_v2.record_if(True), self._summary_writer.as_default():
          for metric in self.model.metrics:
            summary_ops_v2.scalar(
                metric.name,
                metric.result(),
                step=self._iterations.read_value())

      # TODO(rchao): Make the max evaluation robust in case users save the
      # checkpoints with epoch format {epoch:03d}.
      if (self.max_evaluations and
          latest_checkpoint.endswith('-{}'.format(self.max_evaluations))):
        # Exit the loop because we have evaluated the final checkpoint file.
        logging.info('Last checkpoint evaluated. SidecarEvaluator stops.')
        return
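
The metric-logging step at the end of the loop uses the tf.summary API. A standalone sketch of just that piece, with a throwaway log directory, placeholder metric values, and an explicit step:

import tempfile

import tensorflow as tf

logdir = tempfile.mkdtemp()                  # throwaway log directory
writer = tf.summary.create_file_writer(logdir)
metrics = {'accuracy': 0.91, 'loss': 0.35}   # placeholder metric values
step = 100                                   # e.g. the restored iteration count

with tf.summary.record_if(True), writer.as_default(step=step):
  for name, value in metrics.items():
    tf.summary.scalar(name, value)
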
    def test_initialize_if_not_restoring(self):
        with self.test_session():
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
            with test_util.device(use_gpu=True):
                model = MyModel()
                optimizer = adam.Adam(0.001)
                root = trackable_utils.Checkpoint(
                    model=model
                )  # Do not save the optimizer with the checkpoint.
                optimizer_checkpoint = trackable_utils.Checkpoint(
                    optimizer=optimizer)

                checkpoint_path = checkpoint_management.latest_checkpoint(
                    checkpoint_directory)
                status = root.restore(save_path=checkpoint_path)
                input_value = constant_op.constant([[3.]])

                def train_fn():
                    with backprop.GradientTape() as tape:
                        loss = model(input_value)
                    variables = model.trainable_variables
                    gradients = tape.gradient(loss, variables)
                    return optimizer.apply_gradients(zip(gradients, variables))

                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                status.initialize_or_restore()
                # TODO(tanzheny): Add hyper variables to .variables(), and set them with
                # set_weights etc.
                variables_not_in_the_variables_property = [
                    obj for obj in optimizer._hyper.values()
                    if isinstance(obj, variables_lib.Variable)
                ]
                self.evaluate([
                    v.initializer for v in optimizer.variables() +
                    variables_not_in_the_variables_property
                ])
                train_fn()
                model_save_path = root.save(file_prefix=checkpoint_prefix)
                self.evaluate(optimizer.beta_1.assign(42.))
                optimizer_save_path = optimizer_checkpoint.save(
                    optimizer_only_prefix)
            del train_fn

            # Restore into a graph with the optimizer
            with test_util.device(use_gpu=True):
                model = MyModel()
                optimizer = adam.Adam(0.001)
                root = trackable_utils.Checkpoint(optimizer=optimizer,
                                                  model=model)
                status = root.restore(save_path=model_save_path)
                input_value = constant_op.constant([[3.]])

                def train_fn1():
                    with backprop.GradientTape() as tape:
                        loss = model(input_value)
                    variables = model.trainable_variables
                    gradients = tape.gradient(loss, variables)
                    return optimizer.apply_gradients(zip(gradients, variables))

                if not context.executing_eagerly():
                    train_fn1 = functools.partial(self.evaluate, train_fn1())
                status.initialize_or_restore()
                train_fn1()
                with self.assertRaises(AssertionError):
                    status.assert_existing_objects_matched()
                with self.assertRaises(AssertionError):
                    status.assert_consumed()
            del train_fn1

            # Make sure initialization doesn't clobber later restores
            with test_util.device(use_gpu=True):
                model = MyModel()
                optimizer = adam.Adam(0.001, beta_1=1.0)
                root = trackable_utils.Checkpoint(optimizer=optimizer,
                                                  model=model)
                opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
                status = root.restore(save_path=model_save_path)
                init_only_optimizer_status = opt_root.restore(save_path=None)
                optimizer_status = opt_root.restore(
                    save_path=optimizer_save_path)
                input_value = constant_op.constant([[3.]])

                def train_fn2():
                    with backprop.GradientTape() as tape:
                        loss = model(input_value)
                    variables = model.trainable_variables
                    gradients = tape.gradient(loss, variables)
                    return optimizer.apply_gradients(zip(gradients, variables))

                if not context.executing_eagerly():
                    train_fn2 = functools.partial(self.evaluate, train_fn2())
                optimizer_status.run_restore_ops()
                status.initialize_or_restore()
                init_only_optimizer_status.initialize_or_restore()
                train_fn2()
                self.assertEqual(42., self.evaluate(optimizer.beta_1))
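
Outside a test harness, the usual way to express "restore if a checkpoint exists, otherwise start fresh" is to check CheckpointManager.latest_checkpoint; in eager mode variables are initialized at creation, so no explicit initializer run is needed. A sketch with a throwaway directory:

import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # throwaway directory
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)

if manager.latest_checkpoint:
  ckpt.restore(manager.latest_checkpoint)  # resume previous training state
  print('Restored from', manager.latest_checkpoint)
else:
  print('Initializing from scratch.')

step.assign_add(1)
manager.save()
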
Example #16
    def testDeferredSlotRestoration(self):
        checkpoint_directory = self.get_temp_dir()

        root = trackable_utils.Checkpoint()
        root.var = trackable_utils.add_variable(root,
                                                name="var",
                                                initializer=0.)
        optimizer = adam.AdamOptimizer(0.1)
        if context.executing_eagerly():
            optimizer.minimize(root.var.read_value)
        else:
            train_op = optimizer.minimize(root.var)
            # Note that `optimizer` has not been added as a dependency of
            # `root`. Create a one-off grouping so that slot variables for `root.var`
            # get initialized too.
            self.evaluate(
                trackable_utils.gather_initializers(
                    trackable_utils.Checkpoint(root=root,
                                               optimizer=optimizer)))
            self.evaluate(train_op)
        self.evaluate(state_ops.assign(root.var, 12.))
        no_slots_path = root.save(
            os.path.join(checkpoint_directory, "no_slots"))
        root.optimizer = optimizer
        self.evaluate(state_ops.assign(root.var, 13.))
        self.evaluate(
            state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.))
        slots_path = root.save(os.path.join(checkpoint_directory,
                                            "with_slots"))
        new_root = trackable_utils.Checkpoint()
        # Load the slot-containing checkpoint (deferred), then immediately overwrite
        # the non-slot variable (also deferred).
        slot_status = new_root.restore(slots_path)
        no_slot_status = new_root.restore(no_slots_path)
        with self.assertRaises(AssertionError):
            no_slot_status.assert_consumed()
        new_root.var = trackable_utils.add_variable(new_root,
                                                    name="var",
                                                    shape=[])
        no_slot_status.assert_consumed()
        no_slot_status.run_restore_ops()
        self.assertEqual(12., self.evaluate(new_root.var))
        new_root.optimizer = adam.AdamOptimizer(0.1)
        slot_status.assert_existing_objects_matched()
        with self.assertRaisesRegexp(AssertionError, "beta1_power"):
            slot_status.assert_consumed()
        self.assertEqual(12., self.evaluate(new_root.var))
        if context.executing_eagerly():
            # Slot variables are only created with restoring initializers when
            # executing eagerly.
            self.assertEqual(
                14.,
                self.evaluate(
                    new_root.optimizer.get_slot(name="m", var=new_root.var)))
        else:
            self.assertIs(
                new_root.optimizer.get_slot(name="m", var=new_root.var), None)
        if context.executing_eagerly():
            new_root.optimizer.minimize(new_root.var.read_value)
        else:
            train_op = new_root.optimizer.minimize(new_root.var)
            # The slot variable now exists; restore() didn't create it, but we should
            # now have a restore op for it.
            slot_status.run_restore_ops()
            self.assertEqual(
                14.,
                self.evaluate(
                    new_root.optimizer.get_slot(name="m", var=new_root.var)))
            self.evaluate(train_op)
        slot_status.assert_consumed()
    def test_trackable_save_restore(self):
        with self.test_session():

            def _templated():
                v = variable_scope.get_variable(
                    "v",
                    shape=[1],
                    initializer=init_ops.zeros_initializer(),
                    use_resource=True)
                v2 = variable_scope.get_variable(
                    "v2",
                    shape=[1],
                    initializer=init_ops.zeros_initializer(),
                    use_resource=True)
                manual = _ManualScope()
                return v, v + 1., v2, manual, manual()

            save_template = template.make_template("s1", _templated)
            v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
            six.assertCountEqual(
                self, [
                    id(v1_save),
                    id(v2_save),
                    id(manual_scope),
                    id(manual_scope_v),
                    id(save_template)
                ], map(id, trackable_utils.list_objects(save_template)))
            manual_dep, = manual_scope._checkpoint_dependencies
            self.assertEqual("in_manual_scope", manual_dep.name)
            self.assertIs(manual_scope_v, manual_dep.ref)
            optimizer = adam.Adam(0.0)
            save_root = trackable_utils.Checkpoint(my_template=save_template,
                                                   optimizer=optimizer)
            optimizer.minimize(v1_save.read_value, var_list=[v1_save])
            self.evaluate([v.initializer for v in save_template.variables])
            optimizer_variables = optimizer.variables() + list(
                optimizer._hyper.values())
            self.evaluate([v.initializer for v in optimizer_variables])
            self.evaluate(v1_save.assign([12.]))
            self.evaluate(v2_save.assign([14.]))
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            save_path = save_root.save(checkpoint_prefix)

            load_template = template.make_template("s2", _templated)
            load_optimizer = adam.Adam(0.0)
            load_root = trackable_utils.Checkpoint(my_template=load_template,
                                                   optimizer=load_optimizer)
            status = load_root.restore(save_path)
            var, var_plus_one, var2, _, _ = load_template()
            load_optimizer.minimize(var.read_value, var_list=[var])
            self.assertLen(load_template._checkpoint_dependencies, 3)
            self.assertEqual("v",
                             load_template._checkpoint_dependencies[0].name)
            self.assertEqual("v2",
                             load_template._checkpoint_dependencies[1].name)
            self.assertEqual("ManualScope",
                             load_template._checkpoint_dependencies[2].name)
            status.assert_consumed().run_restore_ops()
            self.assertAllEqual([12.], self.evaluate(var))
            self.assertAllEqual([13.], self.evaluate(var_plus_one))
            self.assertAllEqual([14.], self.evaluate(var2))
    def testMultipleGraphsNonSlotVariables(self):
        with context.graph_mode():
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            optimizer = adam.AdamOptimizer(0.001)
            # Construct a model in one graph
            first_graph = ops.Graph()
            first_session = session_lib.Session(graph=first_graph)
            with first_graph.as_default(), first_session.as_default():
                first_variable = resource_variable_ops.ResourceVariable([1.])
                first_root_trackable = trackable_utils.Checkpoint(
                    optimizer=optimizer, variable=first_variable)
                train_op = optimizer.minimize(first_variable.read_value)
                self.evaluate(
                    trackable_utils.gather_initializers(first_root_trackable))
                self.evaluate(train_op)
                self.evaluate(first_variable.assign([1.]))
                self.evaluate(
                    optimizer.get_slot(var=first_variable,
                                       name="m").assign([2.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.evaluate(beta1_power.assign(3.))

            # Save and load in a second graph
            second_graph = ops.Graph()
            with second_graph.as_default(), session_lib.Session(
                    graph=second_graph):
                second_variable = resource_variable_ops.ResourceVariable([1.])
                second_root_trackable = trackable_utils.Checkpoint(
                    optimizer=optimizer, variable=second_variable)
                train_op = optimizer.minimize(second_variable.read_value)
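                # restore(None) yields an initialization-only status, so
                # initialize_or_restore simply runs the initializers for this
                # graph's variables.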
                second_root_trackable.restore(None).initialize_or_restore()
                self.evaluate(train_op)
                self.evaluate(second_variable.assign([4.]))
                self.evaluate(
                    optimizer.get_slot(var=second_variable,
                                       name="m").assign([5.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.evaluate(beta1_power.assign(6.))
                save_path = second_root_trackable.save(checkpoint_prefix)
                self.evaluate(second_variable.assign([7.]))
                self.evaluate(
                    optimizer.get_slot(var=second_variable,
                                       name="m").assign([8.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(6., self.evaluate(beta1_power))
                status = second_root_trackable.restore(save_path)
                status.assert_consumed().run_restore_ops()
                self.assertAllEqual([4.], self.evaluate(second_variable))
                self.assertAllEqual([5.],
                                    self.evaluate(
                                        optimizer.get_slot(var=second_variable,
                                                           name="m")))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(6., self.evaluate(beta1_power))

            # Check that the first graph is unmolested
            with first_graph.as_default(), first_session.as_default():
                self.assertAllEqual([1.], self.evaluate(first_variable))
                self.assertAllEqual([2.],
                                    self.evaluate(
                                        optimizer.get_slot(var=first_variable,
                                                           name="m")))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(3., self.evaluate(beta1_power))
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                maintenance_event=None,
                training_finished=None,
                frequent_send=False):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    def mock_termination_watcher_function_gce(*args, **kwargs):
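      # Simulated GCE termination-notice poll: with frequent_send, report a
      # notice whenever maintenance_event is still unset; otherwise wait a
      # second and report one with low probability (~5% per poll), setting
      # maintenance_event the first time.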
      del args, kwargs
      if not frequent_send:
        time.sleep(1)
        if (not maintenance_event.is_set()) and (random.randrange(0, 20) > 18):
          maintenance_event.set()
          logging.info('Termination notice available.')
          return True

      elif frequent_send and not maintenance_event.is_set():
        logging.info('Termination notice available.')
        return True

      return False

    with mock.patch.object(
        gce_util, 'termination_watcher_function_gce',
        mock_termination_watcher_function_gce), mock.patch.object(
            gce_util, 'detect_platform',
            lambda: gce_util.PlatformDevice.GCE_GPU):

      class Model(module.Module):

        def __init__(self):
          self.v = variables_lib.Variable(
              0.,
              synchronization=variables_lib.VariableSynchronization.ON_WRITE,
              aggregation=variables_lib.VariableAggregation.SUM)

        @def_function.function(input_signature=[])
        def __call__(self):
          return self.v.read_value()

      with strategy.scope():
        model = Model()
        fh_ckpt = tracking_util.Checkpoint(model=model)

        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Start training at %d', worker_preemption_watcher.total_runs)
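      # total_runs counts the train steps the handler has already run, so the
      # floor-division and modulo below resume the epoch/step counters from
      # where a preempted run left off.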
      for epoch in range(
          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
          EPOCHS_TO_RUN):

        for step in range(
            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
            STEPS_PER_EPOCH):
          worker_preemption_watcher.run(distributed_train_step, epoch, step)

      training_finished.set()

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

      running_threads = test_util.get_running_threads()
      if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                              running_threads) and test_util.has_thread(
                                  _LOCAL_WATCHER_THREAD_PREFIX,
                                  running_threads):
        try:
          # Explicitly call __del__ since setting the reference to None and
          # calling gc.collect() does not invoke __del__ here.
          worker_preemption_watcher.__del__()

          time.sleep(2)

          running_threads = test_util.get_running_threads()
          self.assertFalse(
              test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                   running_threads))
          self.assertFalse(
              test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                   running_threads))

        except urllib.error.URLError as e:
          if 'Temporary failure in name resolution' in str(e):
            # This is caused by a flaky issue where mock.patch does not
            # correctly patch gce_util.request_compute_metadata; a real request
            # is then attempted and fails inside
            # gce_util.request_compute_metadata.
            logging.warning('Hit a mock issue.')
            return
    def testSaveRestoreState(self, mock_time):
        directory = self.get_temp_dir()
        mock_time.time.return_value = 3.
        checkpoint = util.Checkpoint()
        first_manager = checkpoint_management.CheckpointManager(checkpoint,
                                                                directory,
                                                                max_to_keep=2)
        first_time = 10000.
        first_name = os.path.join(directory, "ckpt-1")
        mock_time.time.return_value = first_time
        first_manager.save()
        state = checkpoint_management.get_checkpoint_state(directory)
        second_time = first_time + 3610.
        second_name = os.path.join(directory, "ckpt-2")
        mock_time.time.return_value = second_time
        first_manager.save()
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual([first_time, second_time],
                         state.all_model_checkpoint_timestamps)
        self.assertEqual([first_name, second_name], first_manager.checkpoints)
        self.assertEqual(second_name, first_manager.latest_checkpoint)
        del first_manager

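        # A new manager over the same directory recovers the existing
        # checkpoints and their timestamps from the on-disk checkpoint state.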
        second_manager = checkpoint_management.CheckpointManager(
            checkpoint,
            directory,
            max_to_keep=2,
            keep_checkpoint_every_n_hours=1.5)
        self.assertEqual([first_name, second_name], second_manager.checkpoints)
        self.assertEqual(second_name, second_manager.latest_checkpoint)
        third_name = os.path.join(directory, "ckpt-3")
        third_time = second_time + 3600. * 0.2
        mock_time.time.return_value = third_time
        second_manager.save()
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
        self.assertEqual([second_name, third_name], second_manager.checkpoints)
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual(first_time, state.last_preserved_timestamp)
        fourth_time = third_time + 3600. * 0.5
        mock_time.time.return_value = fourth_time
        fourth_name = os.path.join(directory, "ckpt-4")
        second_manager.save()
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
        self.assertEqual([third_name, fourth_name], second_manager.checkpoints)
        fifth_time = fourth_time + 3600. * 0.5
        mock_time.time.return_value = fifth_time
        fifth_name = os.path.join(directory, "ckpt-5")
        second_manager.save()
        self.assertEqual([fourth_name, fifth_name], second_manager.checkpoints)
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual(first_time, state.last_preserved_timestamp)
        del second_manager
        third_manager = checkpoint_management.CheckpointManager(
            checkpoint,
            directory,
            max_to_keep=2,
            keep_checkpoint_every_n_hours=1.5)
        self.assertEqual(fifth_name, third_manager.latest_checkpoint)
        mock_time.time.return_value += 10.
        third_manager.save()
        sixth_name = os.path.join(directory, "ckpt-6")
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual(fourth_time, state.last_preserved_timestamp)
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
        self.assertEqual([fifth_name, sixth_name], third_manager.checkpoints)
    def test_initialize_if_not_restoring(self):
        with self.test_session():
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
            with test_util.device(use_gpu=True):
                model = MyModel()
                optimizer = adam.AdamOptimizer(0.001)
                root = trackable_utils.Checkpoint(
                    # Do not save the optimizer with the checkpoint.
                    model=model,
                    global_step=training_util.get_or_create_global_step())
                optimizer_checkpoint = trackable_utils.Checkpoint(
                    optimizer=optimizer)

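                # No checkpoint exists yet, so latest_checkpoint returns None
                # and initialize_or_restore below falls back to running the
                # initializers.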
                checkpoint_path = checkpoint_management.latest_checkpoint(
                    checkpoint_directory)
                status = root.restore(save_path=checkpoint_path)
                input_value = constant_op.constant([[3.]])
                train_fn = functools.partial(optimizer.minimize,
                                             functools.partial(
                                                 model, input_value),
                                             global_step=root.global_step)
                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                status.initialize_or_restore()
                self.evaluate([v.initializer for v in optimizer.variables()])
                train_fn()
                model_save_path = root.save(file_prefix=checkpoint_prefix)
                self.evaluate(optimizer.variables()[0].assign(42.))
                optimizer_save_path = optimizer_checkpoint.save(
                    optimizer_only_prefix)

            # Restore into a graph with the optimizer
            with test_util.device(use_gpu=True):
                model = MyModel()
                optimizer = adam.AdamOptimizer(0.001)
                root = trackable_utils.Checkpoint(
                    optimizer=optimizer,
                    model=model,
                    global_step=training_util.get_or_create_global_step())
                status = root.restore(save_path=model_save_path)
                input_value = constant_op.constant([[3.]])
                train_fn = functools.partial(optimizer.minimize,
                                             functools.partial(
                                                 model, input_value),
                                             global_step=root.global_step)
                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                status.initialize_or_restore()
                train_fn()
                with self.assertRaises(AssertionError):
                    status.assert_existing_objects_matched()
                with self.assertRaises(AssertionError):
                    status.assert_consumed()

            # Make sure initialization doesn't clobber later restores
            with test_util.device(use_gpu=True):
                model = MyModel()
                optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
                root = trackable_utils.Checkpoint(
                    optimizer=optimizer,
                    model=model,
                    global_step=training_util.get_or_create_global_step())
                opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
                status = root.restore(save_path=model_save_path)
                init_only_optimizer_status = opt_root.restore(save_path=None)
                optimizer_status = opt_root.restore(
                    save_path=optimizer_save_path)
                input_value = constant_op.constant([[3.]])
                train_fn = functools.partial(optimizer.minimize,
                                             functools.partial(
                                                 model, input_value),
                                             global_step=root.global_step)
                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                optimizer_status.run_restore_ops()
                status.initialize_or_restore()
                init_only_optimizer_status.initialize_or_restore()
                train_fn()
                self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
 def test_unbuilt_model_does_not_prevent_saving(self):
   root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)]))
   save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
def _save_checkpoint():
    original_checkpoint = util.Checkpoint(m=_LazyTrivialObjects())
    original_checkpoint.m()
    return original_checkpoint.write(os.path.join(test.get_temp_dir(), "ckpt"))
    def testNamingWithOptimizer(self):
        input_value = constant_op.constant([[3.]])
        model = MyModel()
        # A nuisance Model using the same optimizer. Its slot variables should not
        # go in the checkpoint, since it is never depended on.
        other_model = MyModel()
        optimizer = adam.Adam(0.001)
        step = training_util.get_or_create_global_step()
        root_trackable = trackable_utils.Checkpoint(optimizer=optimizer,
                                                    model=model,
                                                    step=step)

        with backprop.GradientTape() as tape:
            loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        train_op = control_flow_ops.group(
            optimizer.apply_gradients(zip(gradients, variables)),
            step.assign_add(1))

        with backprop.GradientTape() as tape:
            loss = other_model(input_value)
        variables = other_model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        self.evaluate(trackable_utils.gather_initializers(root_trackable))
        self.evaluate(train_op)
        named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
            root_trackable).serialize_object_graph()
        expected_slot_keys = (
            "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
            "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
            "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
            "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
        )
        expected_checkpoint_names = (
            # Created in the root node, so no prefix.
            "step",
            "model/_second/kernel",
            "model/_named_dense/kernel",
            "model/_named_dense/bias",
            # non-Layer dependency of the model
            "model/_non_layer/a_variable",
            "optimizer/learning_rate",
            "optimizer/beta_1",
            "optimizer/beta_2",
            "optimizer/iter",
            "optimizer/decay",
        ) + expected_slot_keys
        suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
        expected_checkpoint_names = [
            name + suffix for name in expected_checkpoint_names
        ]
        named_variables = {v.name: v for v in named_variables}
        six.assertCountEqual(self, expected_checkpoint_names,
                             named_variables.keys())
        # Check that we've mapped to the right variable objects (not exhaustive)
        self.assertEqual("global_step",
                         named_variables["step" + suffix].full_name)
        self.assertEqual(
            "my_model/dense_1/kernel",
            named_variables["model/_second/kernel" + suffix].full_name)
        self.assertEqual(
            "my_model/dense/kernel",
            named_variables["model/_named_dense/kernel" + suffix].full_name)
        self.assertEqual(
            "Adam/beta_1",
            named_variables["optimizer/beta_1" + suffix].full_name)
        self.assertEqual(
            "Adam/beta_2",
            named_variables["optimizer/beta_2" + suffix].full_name)
        # Spot check the generated protocol buffers.
        self.assertEqual("optimizer",
                         serialized_graph.nodes[0].children[1].local_name)
        optimizer_node = serialized_graph.nodes[
            serialized_graph.nodes[0].children[1].node_id]
        children = [node.local_name for node in optimizer_node.children]
        six.assertCountEqual(
            self,
            # hyper variable dependencies
            ["beta_1", "beta_2", "iter", "decay", "learning_rate"],
            children)
        serialized_slot_keys = []
        for slot in optimizer_node.slot_variables:
            for attribute in (serialized_graph.nodes[
                    slot.slot_variable_node_id].attributes):
                serialized_slot_keys.append(attribute.checkpoint_key)
        six.assertCountEqual(self,
                             [key + suffix for key in expected_slot_keys],
                             serialized_slot_keys)
    def gen_outputs(self,
                    ds_fn,
                    break_points,
                    num_outputs,
                    ckpt_saved=False,
                    sparse_tensors=False,
                    verify_exhausted=True,
                    save_checkpoint_at_end=True):
        """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, we produce outputs till `break_point` number of items
        have been produced and then checkpoint the state. The current graph and
        session are destroyed and a new graph and session are used to produce
        outputs till next checkpoint or till `num_outputs` elements have been
        produced. `break_point` must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists.
      sparse_tensors: Whether the dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved at each break point but not at
        the end. Note that checkpoints overwrite each other, so there is always
        only a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items. (An illustrative call is sketched after
      this method.)
    """
        outputs = []

        if context.executing_eagerly():
            for i in range(len(break_points) + 1):
                iterator = iter(ds_fn())
                ckpt = tracking_util.Checkpoint(iterator=iterator)
                if ckpt_saved:
                    ckpt_path = self._latest_ckpt()
                    ckpt.restore(ckpt_path)
                start = break_points[i - 1] if i > 0 else 0
                end = break_points[i] if i < len(break_points) else num_outputs
                num_iters = end - start
                for _ in range(num_iters):
                    outputs.append(self.evaluate(next(iterator)))
                if i == len(break_points) and verify_exhausted:
                    with self.assertRaises(StopIteration):
                        next(iterator)
                if save_checkpoint_at_end or i < len(break_points):
                    ckpt_path = ckpt.save(self._ckpt_path())
                    ckpt_saved = True
        else:

            def get_ops():
                if ckpt_saved:
                    saver = self._import_meta_graph()
                    init_op, get_next_op = self._get_iterator_ops_from_collection(
                        ds_fn, sparse_tensors=sparse_tensors)
                else:
                    init_op, get_next_op, saver = self._build_graph(
                        ds_fn, sparse_tensors=sparse_tensors)
                return init_op, get_next_op, saver

            for i in range(len(break_points) + 1):
                with ops.Graph().as_default() as g:
                    init_op, get_next_op, saver = get_ops()
                    get_next_op = remove_variants(get_next_op)
                    with self.session(graph=g) as sess:
                        if ckpt_saved:
                            self._initialize(init_op, sess)
                            self._restore(saver, sess)
                        else:
                            self._initialize(init_op, sess)
                        start = break_points[i - 1] if i > 0 else 0
                        end = break_points[i] if i < len(
                            break_points) else num_outputs
                        num_iters = end - start
                        for _ in range(num_iters):
                            outputs.append(sess.run(get_next_op))
                        if i == len(break_points) and verify_exhausted:
                            with self.assertRaises(errors.OutOfRangeError):
                                sess.run(get_next_op)
                        if save_checkpoint_at_end or i < len(break_points):
                            self._save(sess, saver)
                            ckpt_saved = True

        return outputs
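    # Illustrative sketch (not part of the original tests): a hypothetical
    # caller of `gen_outputs` above, assuming `dataset_ops` (from
    # tensorflow.python.data.ops) is imported and that the surrounding test
    # base provides the graph-mode helpers used by `gen_outputs`. With
    # break_points=[5], elements 0-4 come from the first iterator, a checkpoint
    # is written, and elements 5-9 come from an iterator restored from it.
    def _example_gen_outputs_usage(self):
        outputs = self.gen_outputs(
            lambda: dataset_ops.Dataset.range(10),  # assumed dataset function
            break_points=[5],
            num_outputs=10)
        self.assertAllEqual(list(range(10)), outputs)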
 def testSaveRestore(self):
     with self.test_session():
         model = MyModel()
         optimizer = adam.Adam(0.001)
         root_trackable = trackable_utils.Checkpoint(optimizer=optimizer,
                                                     model=model)
         input_value = constant_op.constant([[3.]])
         with backprop.GradientTape() as tape:
             loss = model(input_value)
         variables = model.trainable_variables
         gradients = tape.gradient(loss, variables)
         train_op = optimizer.apply_gradients(zip(gradients, variables))
         self.assertFalse(root_trackable.save_counter.trainable)
         self.evaluate(trackable_utils.gather_initializers(root_trackable))
         self.evaluate(train_op)
         prefix = os.path.join(self.get_temp_dir(), "ckpt")
         self.evaluate(
             state_ops.assign(model._named_dense.variables[1], [42.]))
         m_bias_slot = optimizer.get_slot(model._named_dense.variables[1],
                                          "m")
         self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
         save_path = root_trackable.save(file_prefix=prefix)
         self.evaluate(
             state_ops.assign(model._named_dense.variables[1], [43.]))
         self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
         optimizer_variables = self.evaluate(
             sorted(optimizer.variables(), key=lambda v: v.name))
         self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
         # Immediate restoration
         status = root_trackable.restore(
             save_path=save_path).assert_consumed()
         status.run_restore_ops()
         self.assertAllEqual([42.],
                             self.evaluate(model._named_dense.variables[1]))
         self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
         self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
         if not context.executing_eagerly():
             return  # Restore-on-create is only supported when executing eagerly
         on_create_model = MyModel()
         on_create_optimizer = adam.Adam(0.001)
         on_create_root = trackable_utils.Checkpoint(
             optimizer=on_create_optimizer, model=on_create_model)
         # Deferred restoration
         status = on_create_root.restore(save_path=save_path)
         status.assert_nontrivial_match()
         status.assert_existing_objects_matched()
         with self.assertRaises(AssertionError):
             status.assert_consumed()
         on_create_model(constant_op.constant([[3.]]))  # create variables
         self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
         self.assertAllEqual([42.],
                             self.evaluate(
                                 on_create_model._named_dense.variables[1]))
         on_create_m_bias_slot = on_create_optimizer.get_slot(
             on_create_model._named_dense.variables[1], "m")
         status.assert_existing_objects_matched()
         if not context.executing_eagerly():
             with self.assertRaises(AssertionError):
                 status.assert_consumed()
         # Optimizer slot variables are created when the original variable is
         # restored.
         self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
         dummy_var = resource_variable_ops.ResourceVariable([1.])
         on_create_optimizer.minimize(loss=dummy_var.read_value,
                                      var_list=[dummy_var])
         status.assert_existing_objects_matched()
         status.assert_consumed()
         self.assertAllEqual(
             optimizer_variables,
             # Creation order is different, so .variables() needs to be re-sorted.
             self.evaluate(
                 sorted(optimizer.variables(), key=lambda v: v.name)))
        def fn(model_path, checkpoint_dir):
            global_batch_size = per_worker_batch_size * num_workers
            strategy = (
                collective_all_reduce_strategy.CollectiveAllReduceStrategy())
            with strategy.scope():
                multi_worker_model = build_and_compile_cnn_model()

            callbacks = [
                keras.callbacks.ModelCheckpoint(
                    filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
            ]

            multi_worker_dataset = mnist_dataset(global_batch_size)
            if shard_policy:
                options = dataset_ops.Options()
                options.experimental_distribute.auto_shard_policy = shard_policy
                multi_worker_dataset = multi_worker_dataset.with_options(
                    options)

            multi_worker_model.fit(multi_worker_dataset,
                                   epochs=2,
                                   steps_per_epoch=20,
                                   callbacks=callbacks)

            def _is_chief(task_type, task_id):
                return task_type is None or task_type == 'chief' or (
                    task_type == 'worker' and task_id == 0)

            def _get_temp_dir(dirpath, task_id):
                base_dirpath = 'workertemp_' + str(task_id)
                temp_dir = os.path.join(dirpath, base_dirpath)
                file_io.recursive_create_dir_v2(temp_dir)
                return temp_dir

            def write_filepath(filepath, task_type, task_id):
                dirpath = os.path.dirname(filepath)
                base = os.path.basename(filepath)
                if not _is_chief(task_type, task_id):
                    dirpath = _get_temp_dir(dirpath, task_id)
                return os.path.join(dirpath, base)

            task_type, task_id = (strategy.cluster_resolver.task_type,
                                  strategy.cluster_resolver.task_id)
            write_model_path = write_filepath(model_path, task_type, task_id)

            multi_worker_model.save(write_model_path)
            if not _is_chief(task_type, task_id):
                file_io.delete_recursively_v2(
                    os.path.dirname(write_model_path))

            # Make sure chief finishes saving before non-chief's assertions.
            multi_process_runner.get_barrier().wait()

            if not file_io.file_exists_v2(model_path):
                raise RuntimeError()
            if file_io.file_exists_v2(write_model_path) != _is_chief(
                    task_type, task_id):
                raise RuntimeError()

            loaded_model = keras.saving.save.load_model(model_path)
            loaded_model.fit(multi_worker_dataset,
                             epochs=2,
                             steps_per_epoch=20)

            checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
            write_checkpoint_dir = write_filepath(checkpoint_dir, task_type,
                                                  task_id)
            checkpoint_manager = checkpoint_management.CheckpointManager(
                checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

            checkpoint_manager.save()
            if not _is_chief(task_type, task_id):
                file_io.delete_recursively_v2(write_checkpoint_dir)

            # Make sure chief finishes saving before non-chief's assertions.
            multi_process_runner.get_barrier().wait()

            if not file_io.file_exists_v2(checkpoint_dir):
                raise RuntimeError()
            if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
                    task_type, task_id):
                raise RuntimeError()

            latest_checkpoint = checkpoint_management.latest_checkpoint(
                checkpoint_dir)
            checkpoint.restore(latest_checkpoint)
            multi_worker_model.fit(multi_worker_dataset,
                                   epochs=2,
                                   steps_per_epoch=20)

            logging.info('testMultiWorkerTutorial successfully ends')
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  training_started_event=None,
                  raise_app_error_on_worker=None,
                  training_restarted=None,
                  training_finished=None,
                  termination_config=failure_handling.TerminationConfig()):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        class Model(module.Module):
            def __init__(self):
                self.v = variables_lib.Variable(
                    0.,
                    synchronization=(
                        variables_lib.VariableSynchronization.ON_WRITE),
                    aggregation=variables_lib.VariableAggregation.SUM)

            @def_function.function(input_signature=[])
            def __call__(self):
                return self.v.read_value()

        with mock.patch.object(gce_util, 'on_gcp', lambda: False):

            with strategy.scope():
                model = Model()
                # The checkpoint is named fh_ckpt because users should keep
                # their regular checkpoint separate from the one used by
                # WorkerPreemptionHandler: a CheckpointManager is created to
                # manage it, and only one CheckpointManager should be active in
                # a given directory at a time.
                fh_ckpt = tracking_util.Checkpoint(model=model)

                worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
                    strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
                    termination_config)

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    if distribution_strategy_context.get_distribution_strategy(
                    ).cluster_resolver.task_id == raise_app_error_on_worker:
                        raise errors_impl.ResourceExhaustedError(
                            node_def=None,
                            op=None,
                            message='Running out of resources')

                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         worker_preemption_watcher.total_runs)

            # If the training process has been restarted, verify that the
            # expected number of checkpoints has been written. We also check
            # training_finished because there is a corner case where the signal
            # is sent quite late and training finishes before the grace period
            # ends.
            if (training_restarted and training_restarted.is_set() and
                    not training_finished.is_set()):
                logging.info('training restarted')
                match_group = [
                    re.search(r'.*ckpt-(\d+).index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if getattr(termination_config, 'time_till_termination', 0):
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(int(checkpoint_index[0]), 2)
                else:
                    self.assertEqual(int(checkpoint_index[0]), 1)

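            # Resume the epoch/step counters from the handler's total_runs so a
            # restarted worker continues from where the previous run was
            # preempted.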
            for epoch in range(
                    worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    worker_preemption_watcher.run(distributed_train_step,
                                                  epoch, step)
                # Add some randomness to when preemption actually happens. We should
                # trigger it for sure if the training is coming to an end and it hasn't
                # been triggered yet.
                if epoch >= EPOCHS_TO_RUN - 2:
                    trigger_it = True
                else:
                    trigger_it = False

                self._maybe_trigger_a_preemption(training_started_event,
                                                 trigger_it)

            training_finished.set()

            logging.info('Training finished.')

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)