def testDeferredSlotRestoration(self):
  with self.test_session():
    checkpoint_directory = self.get_temp_dir()

    root = trackable_utils.Checkpoint()
    root.var = trackable_utils.add_variable(root, name="var", initializer=0.)
    optimizer = adam.Adam(0.1)
    variables = [root.var]
    gradients = [1.]
    train_op = optimizer.apply_gradients(zip(gradients, variables))
    # Note that `optimizer` has not been added as a dependency of `root`.
    # Create a one-off grouping so that slot variables for `root.var` get
    # initialized too.
    self.evaluate(
        trackable_utils.gather_initializers(
            trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
    self.evaluate(train_op)
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(
        state_ops.assign(
            optimizer.get_slot(slot_name="m", var=root.var), 14.))
    slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))

    new_root = trackable_utils.Checkpoint()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = new_root.restore(slots_path)
    no_slot_status = new_root.restore(no_slots_path)
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = trackable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.Adam(0.1)
    slot_status.assert_existing_objects_matched()
    if not context.executing_eagerly():
      with self.assertRaisesRegex(AssertionError, "Unresolved object"):
        slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.executing_eagerly():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(
          14.,
          self.evaluate(
              new_root.optimizer.get_slot(slot_name="m", var=new_root.var)))
    else:
      # Slot variables are not created eagerly when graph building.
      with self.assertRaises(KeyError):
        new_root.optimizer.get_slot(slot_name="m", var=new_root.var)
    variables = [new_root.var]
    gradients = [1.]
    train_op = new_root.optimizer.apply_gradients(zip(gradients, variables))
    # The slot variable now exists; restore() didn't create it, but we
    # should now have a restore op for it.
    slot_status.run_restore_ops()
    if not context.executing_eagerly():
      # The train op hasn't run when graph building, so the slot variable
      # has its restored value. It has run in eager, so the value will be
      # different.
      self.assertEqual(
          14.,
          self.evaluate(
              new_root.optimizer.get_slot(slot_name="m", var=new_root.var)))
    self.evaluate(train_op)
    slot_status.assert_consumed()
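# A minimal, self-contained sketch of the deferred ("restore-on-create")
# mechanism the test above exercises, using only the public tf.train API.
# This is an illustration under assumptions (the path and variable name are
# made up here), not the test's own code.
import tensorflow as tf

ckpt = tf.train.Checkpoint(v=tf.Variable(3.))
path = ckpt.save('/tmp/deferred_demo')

restored = tf.train.Checkpoint()  # No `v` dependency exists yet.
status = restored.restore(path)   # Restoration is queued, not applied.
restored.v = tf.Variable(0.)      # Attaching `v` triggers the deferred
                                  # restore; the saved value is applied.
assert restored.v.numpy() == 3.0
status.assert_existing_objects_matched()  # Every existing object matched.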
def start(self):
  """Starts the evaluation loop."""
  optimizer_checkpoint = tracking_util.Checkpoint(iter=self._iterations)
  checkpoint = tracking_util.Checkpoint(
      model=self.model, optimizer=optimizer_checkpoint)

  for latest_checkpoint in checkpoint_utils.checkpoints_iterator(
      self.checkpoint_dir):
    try:
      # `expect_partial` because the checkpoint can have other `Trackable`s
      # such as `optimizer`.
      checkpoint.restore(latest_checkpoint).expect_partial()
      checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
      # The checkpoint should contain model and optimizer for
      # SidecarEvaluator to work. But the model weights saved by the
      # ModelCheckpoint callback do not contain the model as an attribute.
      # To make SidecarEvaluator work in this case as well, use
      # model.load_weights to load the model's weights, while
      # self._iterations is still restored by the checkpoint variable.
      if 'model' not in checkpoint_attributes:
        self.model.load_weights(latest_checkpoint)
      # The model checkpoint might not include the optimizer in some cases,
      # e.g. when using a custom training loop. Directly assign the
      # iterations property to be used in callbacks.
      if self.model.optimizer:
        self.model.optimizer.iterations.assign(self._iterations)
    except (errors_impl.OpError,) as e:
      # A couple of errors can happen here with the coordinator racing to
      # write the checkpoint:
      # 1) OpError: open failed for <file path>: No such file or directory
      # 2) NotFoundError (subclass of OpError): Unsuccessful
      #    TensorSliceReader constructor.
      # TODO(rchao): Remove this except block once b/150954027 is resolved.
      logging.info(
          'SidecarEvaluator encountered an error while loading '
          'checkpoint: %s. Retrying. Error: %s: %s', latest_checkpoint,
          e.__class__.__name__, e)
      continue

    if self._iterations.numpy() == _ITERATIONS_UNINITIALIZED:
      raise RuntimeError(
          '`iterations` cannot be loaded from the '
          'checkpoint file. Please ensure `iterations` is '
          'tracked in the `checkpoint` saved by the coordinator.')

    logging.info(
        'Evaluation starts: Model weights loaded from latest '
        'checkpoint file: %s.', latest_checkpoint)

    self.model.evaluate(
        self.data, steps=self.steps, callbacks=self.callbacks, verbose=2)

    return_metrics = {}
    for metric in self.model.metrics:
      result = metric.result()
      if isinstance(result, dict):
        return_metrics.update(result)
      else:
        return_metrics[metric.name] = result

    logging.info(
        'End of evaluation. Metrics: %s', ' '.join([
            '{}={}'.format(name, value.numpy())
            for name, value in return_metrics.items()
        ]))

    # TODO(rchao): Make the max evaluation robust in case users save the
    # checkpoints with epoch format {epoch:03d}.
    if (self.max_evaluations and
        latest_checkpoint.endswith('-{}'.format(self.max_evaluations))):
      # Exit the loop because we have evaluated the final checkpoint file.
      logging.info('Last checkpoint evaluated. SidecarEvaluator stops.')
      return
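# `list_checkpoint_attributes` is a helper defined elsewhere in this module.
# Below is a hedged sketch of what it plausibly does, assuming it derives the
# set of top-level attribute names (e.g. 'model', 'optimizer') from the
# checkpoint's variable keys via the public tf.train.list_variables API; the
# real helper may differ in detail.
import tensorflow as tf

def list_checkpoint_attributes_sketch(ckpt_dir_or_file):
  """Returns the top-level attribute names stored in a checkpoint."""
  variable_names = [
      name for name, _ in tf.train.list_variables(ckpt_dir_or_file)
  ]
  return {name.split('/')[0] for name in variable_names}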
def worker_fn(self,
              checkpoint_dir,
              cluster_spec,
              maintenance_event=None,
              training_finished=None):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  def mock_request_compute_metadata(*args, **kwargs):
    del kwargs  # Unused.
    if args[0] == 'instance/maintenance-event':
      if (not maintenance_event.is_set()) and (
          strategy.cluster_resolver.task_id == 1) and (
              random.randrange(0, 9) > 6):
        maintenance_event.set()
        logging.info('Maintenance notice available.')
        return 'TERMINATE_ON_HOST_MAINTENANCE'
      else:
        return 'NONE'
    return ''

  with mock.patch.object(gce_util, 'request_compute_metadata',
                         mock_request_compute_metadata), mock.patch.object(
                             gce_util, 'detect_platform',
                             lambda: gce_util.PlatformDevice.GCE_GPU):

    class Model(module.Module):

      def __init__(self):
        self.v = variables_lib.Variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_WRITE,
            aggregation=variables_lib.VariableAggregation.SUM)

      @def_function.function(input_signature=[])
      def __call__(self):
        return self.v.read_value()

    with strategy.scope():
      model = Model()
      fh_ckpt = tracking_util.Checkpoint(model=model)

      failure_handler = failure_handling.CoordinatedCheckpointManager(
          strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)
      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Restored training at %d', failure_handler.total_runs)
    for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
                       EPOCHS_TO_RUN):
      for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
        failure_handler.run(distributed_train_step, epoch, step)

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

    training_finished.set()

    pre_del_thread_count = threading.activeCount()
    failure_handler.__del__()
    self.assertLessEqual(threading.activeCount(), pre_del_thread_count - 1)
def testNamingWithOptimizer(self):
  input_value = constant_op.constant([[3.]])
  model = MyModel()
  # A nuisance Model using the same optimizer. Its slot variables should
  # not go in the checkpoint, since it is never depended on.
  other_model = MyModel()
  optimizer = adam.AdamOptimizer(0.001)
  optimizer_step = training_util.get_or_create_global_step()
  root_trackable = trackable_utils.Checkpoint(
      optimizer=optimizer, model=model, optimizer_step=optimizer_step)
  if context.executing_eagerly():
    optimizer.minimize(
        lambda: model(input_value), global_step=optimizer_step)
    optimizer.minimize(
        lambda: other_model(input_value), global_step=optimizer_step)
  else:
    train_op = optimizer.minimize(
        model(input_value), global_step=optimizer_step)
    optimizer.minimize(
        other_model(input_value), global_step=optimizer_step)
    # `train_op` only exists when graph building, so initialization and the
    # training step run inside this branch.
    self.evaluate(trackable_utils.gather_initializers(root_trackable))
    self.evaluate(train_op)
  named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
      root_trackable).serialize_object_graph()
  expected_checkpoint_names = (
      # Created in the root node, so no prefix.
      "optimizer_step",
      "model/_second/kernel",
      "model/_named_dense/kernel",
      "model/_named_dense/bias",
      # non-Layer dependency of the model
      "model/_non_layer/a_variable",
      # The optimizer creates two non-slot variables
      "optimizer/beta1_power",
      "optimizer/beta2_power",
      # Slot variables
      "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
      "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
      "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
      "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
      "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
      "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
  )
  suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
  expected_checkpoint_names = [
      name + suffix for name in expected_checkpoint_names
  ]
  named_variables = {v.name: v for v in named_variables}
  six.assertCountEqual(self, expected_checkpoint_names,
                       named_variables.keys())
  # Check that we've mapped to the right variable objects (not exhaustive)
  self.assertEqual("global_step",
                   named_variables["optimizer_step" + suffix].full_name)
  self.assertEqual(
      "my_model/dense_1/kernel",
      named_variables["model/_second/kernel" + suffix].full_name)
  self.assertEqual(
      "my_model/dense/kernel",
      named_variables["model/_named_dense/kernel" + suffix].full_name)
  self.assertEqual(
      "beta1_power",
      named_variables["optimizer/beta1_power" + suffix].full_name)
  self.assertEqual(
      "beta2_power",
      named_variables["optimizer/beta2_power" + suffix].full_name)
  # Spot check the generated protocol buffers.
  self.assertEqual("optimizer",
                   serialized_graph.nodes[0].children[1].local_name)
  optimizer_node = serialized_graph.nodes[
      serialized_graph.nodes[0].children[1].node_id]
  self.assertEqual("beta1_power", optimizer_node.children[0].local_name)
  self.assertEqual(
      "beta1_power",
      serialized_graph.nodes[
          optimizer_node.children[0].node_id].attributes[0].full_name)
  self.assertEqual(
      "my_model/dense/kernel",
      serialized_graph.nodes[optimizer_node.slot_variables[0]
                             .original_variable_node_id]
      .attributes[0].full_name)
  # We strip off the ":0" suffix, as variable.name-based saving does.
  self.assertEqual(
      "my_model/dense/kernel/Adam",
      serialized_graph.nodes[optimizer_node.slot_variables[0]
                             .slot_variable_node_id]
      .attributes[0].full_name)
  self.assertEqual(
      "my_model/dense/kernel/Adam:0",
      optimizer.get_slot(var=model._named_dense.kernel, name="m").name)
  self.assertEqual(
      "model/_named_dense/kernel" + suffix,
      serialized_graph.nodes[optimizer_node.slot_variables[0]
                             .original_variable_node_id]
      .attributes[0].checkpoint_key)
  self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
  self.assertEqual(
      "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
      serialized_graph.nodes[optimizer_node.slot_variables[0]
                             .slot_variable_node_id]
      .attributes[0].checkpoint_key)
def _create_and_call():
  checkpoint = util.Checkpoint(m=_LazyTrivialObjects())
  checkpoint.m()
  checkpoint.restore(checkpoint_path)
def worker_fn(self,
              checkpoint_dir,
              cluster_spec,
              training_started_event=None,
              raise_app_error_on_worker=None,
              termination_config=failure_handling.TerminationConfig()):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  class Model(module.Module):

    def __init__(self):
      self.v = variables_lib.Variable(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.SUM)

    @def_function.function(input_signature=[])
    def __call__(self):
      return self.v.read_value()

  with mock.patch.object(gce_util, 'on_gcp', lambda: False):
    with strategy.scope():
      model = Model()
      # Named fh_ckpt because it is better for the user to keep their
      # regular checkpoint separate from the checkpoint for
      # WorkerPreemptionHandler, since we will create a CheckpointManager
      # to manage the checkpoint and only one CheckpointManager should be
      # active in a particular directory at a time.
      fh_ckpt = tracking_util.Checkpoint(model=model)

      worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
          strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
          termination_config)

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        if distribution_strategy_context.get_distribution_strategy(
        ).cluster_resolver.task_id == raise_app_error_on_worker:
          raise errors_impl.ResourceExhaustedError(
              node_def=None, op=None, message='Running out of resources')

        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Restored training at %d',
                 worker_preemption_watcher.total_runs)
    for epoch in range(
        worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
        EPOCHS_TO_RUN):

      for step in range(
          worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
          STEPS_PER_EPOCH):
        worker_preemption_watcher.run(distributed_train_step, epoch, step)

      # Add some randomness to when preemption actually happens. We should
      # trigger it for sure if the training is coming to an end and it
      # hasn't been triggered yet.
      if epoch >= EPOCHS_TO_RUN - 2:
        trigger_it = True
      else:
        trigger_it = False

      self._maybe_trigger_a_preemption(training_started_event, trigger_it)

    logging.info('Training finished.')

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)
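# The epoch/step resume arithmetic above is worth spelling out. A tiny
# pure-Python illustration with assumed values (STEPS_PER_EPOCH and
# total_runs here are stand-ins for the test's constants):
STEPS_PER_EPOCH = 5
total_runs = 12                                # runs completed before preemption
resume_epoch = total_runs // STEPS_PER_EPOCH   # -> 2
resume_step = total_runs % STEPS_PER_EPOCH     # -> 2
# Training resumes at epoch 2, step 2, so no step is repeated or skipped.
# Because `total_runs` is re-read at each epoch boundary and grows by one
# per completed step, the inner range starts back at 0 for later epochs.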
def testSaveRestore(self):
  with self.test_session():
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    root_trackable = trackable_utils.Checkpoint(
        optimizer=optimizer, model=model)
    input_value = constant_op.constant([[3.]])
    if context.executing_eagerly():
      optimizer.minimize(lambda: model(input_value))
    else:
      train_op = optimizer.minimize(model(input_value))
      # `train_op` only exists when graph building, so initialization and
      # the training step run inside this branch.
      # TODO(allenl): Make initialization more pleasant when graph building.
      root_trackable.save_counter  # pylint: disable=pointless-statement
      self.evaluate(trackable_utils.gather_initializers(root_trackable))
      self.evaluate(train_op)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
    m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
    self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
    save_path = root_trackable.save(file_prefix=prefix)
    self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
    self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
    optimizer_variables = self.evaluate(optimizer.variables())
    self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
    # Immediate restoration
    status = root_trackable.restore(save_path=save_path).assert_consumed()
    status.run_restore_ops()
    self.assertAllEqual([42.],
                        self.evaluate(model._named_dense.variables[1]))
    self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
    self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
    if not context.executing_eagerly():
      return  # Restore-on-create is only supported when executing eagerly
    on_create_model = MyModel()
    on_create_optimizer = adam.AdamOptimizer(
        0.001,
        # Preserve beta1_power and beta2_power when applying gradients
        # so we can test that they've been restored correctly.
        beta1=1.0,
        beta2=1.0)
    on_create_root = trackable_utils.Checkpoint(
        optimizer=on_create_optimizer, model=on_create_model)
    # Deferred restoration
    status = on_create_root.restore(save_path=save_path)
    status.assert_nontrivial_match()
    status.assert_existing_objects_matched()
    with self.assertRaises(AssertionError):
      status.assert_consumed()
    on_create_model(constant_op.constant([[3.]]))  # create variables
    self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
    self.assertAllEqual([42.],
                        self.evaluate(
                            on_create_model._named_dense.variables[1]))
    on_create_m_bias_slot = on_create_optimizer.get_slot(
        on_create_model._named_dense.variables[1], "m")
    status.assert_existing_objects_matched()
    with self.assertRaises(AssertionError):
      status.assert_consumed()
    # Optimizer slot variables are created when the original variable is
    # restored.
    self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
    self.assertAllEqual(optimizer_variables[2:],
                        self.evaluate(on_create_optimizer.variables()))
    dummy_var = variables.Variable([1.])
    on_create_optimizer.minimize(loss=dummy_var.read_value)
    status.assert_existing_objects_matched()
    status.assert_consumed()
    beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
    self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
    self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
def test_checkpoint_restore_before_variable_creation(self):
  self.skip_if_oss()

  class TestModule(module.Module):

    def __init__(self, initializer, rows):
      self._initializer = initializer
      self._rows = rows

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=self._rows,
          dim=4,
          initializer=self._initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      optimizer = tpu_embedding_v2_utils.SGD()

      self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
          feature_config, optimizer)

    def create_embedding(self):
      # We aren't training, so batch_size here doesn't matter.
      self.tpu_embedding.build(64)

  strategy = self._get_strategy()
  with strategy.scope():
    module1 = TestModule(init_ops_v2.Ones(),
                         strategy.num_replicas_in_sync * 2)
    module1.create_embedding()

  checkpoint = util.Checkpoint(test_module=module1)
  checkpoint.save(self._get_tmpdir('restore_before_create', 'save'))

  # Reinitialize the TPU.
  strategy = self._get_strategy()

  with strategy.scope():
    module2 = TestModule(init_ops_v2.Zeros(),
                         strategy.num_replicas_in_sync * 2)

  checkpoint = util.Checkpoint(test_module=module2)
  checkpoint.restore(self._get_tmpdir('restore_before_create', 'save-1'))

  with strategy.scope():
    module2.create_embedding()

  def get_values(mid):
    return mid._variables['table']['parameters'].variables[0].numpy()

  self.assertAllClose(
      np.ones((strategy.num_replicas_in_sync * 2, 4)),
      get_values(module2.tpu_embedding))

  # Fetch the values from the TPU to check that they are the same.
  module2.tpu_embedding._retrieve_variables()

  self.assertAllClose(
      np.ones((strategy.num_replicas_in_sync * 2, 4)),
      get_values(module2.tpu_embedding))
def test_model_export_cpu(self):
  strategy = self._get_strategy()
  num_rows = strategy.num_replicas_in_sync

  with strategy.scope():
    first_mid_level_contents = np.ones((num_rows, 4))
    first_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    initializer = init_ops_v2.Constant(first_mid_level_contents)

    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=num_rows,
        dim=4,
        initializer=initializer,
        combiner='sum',
        name='table')
    feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=table, name='feature'),)

    first_mid_level = tpu_embedding_v2.TPUEmbedding(
        feature_config, first_mid_level_optimizer)
    first_mid_level.build(64)

  cpu_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
  cpu_mid_level = tpu_embedding_v2.TPUEmbedding(feature_config,
                                                cpu_mid_level_optimizer)
  cpu_mid_level.build(64)

  first_mid_level._load_variables()

  tpu_checkpoint = util.Checkpoint(model=first_mid_level)
  tpu_checkpoint.save(self._get_tmpdir('export_cpu', 'save'))

  # We restore the checkpoint of our TPU mid level onto our CPU mid level.
  cpu_checkpoint = util.Checkpoint(model=cpu_mid_level)
  cpu_checkpoint.restore(self._get_tmpdir('export_cpu', 'save-1'))

  @def_function.function
  def serve_tensors(features):
    features = tpu_embedding_v2.cpu_embedding_lookup(
        features, None, cpu_mid_level.embedding_tables,
        cpu_mid_level._feature_config)
    return features[0]

  signatures = {
      'serving_default':
          serve_tensors.get_concrete_function((tensor_spec.TensorSpec(
              shape=(2,), dtype=dtypes.int32, name='feature'),))
  }
  save.save(
      cpu_mid_level,
      export_dir=self._get_tmpdir('export_cpu', 'exported_model'),
      signatures=signatures)

  imported = load.load(self._get_tmpdir('export_cpu', 'exported_model'))
  predict_fn = imported.signatures['serving_default']

  input_feature_value = np.array([1, 0])
  input_batch = (constant_op.constant(
      input_feature_value, dtype=dtypes.int32),)
  prediction = predict_fn(*input_batch)['output_0']
  self.assertAllClose(prediction.numpy(),
                      first_mid_level_contents[input_feature_value])
def testCheckpoint(self, strategy_fn, save_with_ls, restore_with_ls):

  class MySGD(gradient_descent.SGD):
    """A custom optimizer that tracks an extra variable."""

    def __init__(self, *args, **kwargs):
      super(MySGD, self).__init__(*args, **kwargs)
      self.my_var = variables.Variable(0.)
      self._track_trackable(self.my_var, 'my_var')

  strategy = strategy_fn()
  replicas = strategy.num_replicas_in_sync
  if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
      not context.executing_eagerly()):
    # TODO(b/121381184): Enable running the test in this case.
    return

  with self.test_session(), strategy.scope():
    # Build and run a simple model.
    var = variables.Variable([2.0])
    opt = inner_opt = MySGD(1., momentum=1.)
    if save_with_ls:
      loss_scale = loss_scale_module.DynamicLossScale(
          initial_loss_scale=1., increment_period=2., multiplier=2.)
      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
    run_fn = lambda: opt.minimize(lambda: var / replicas + 1.,
                                  var_list=[var])
    opt_op = strategy.experimental_run(run_fn)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(strategy.experimental_local_results(opt_op))

    # Assert values.
    self.assertEqual(self.evaluate(var), 1.)
    if save_with_ls:
      self.assertEqual(self.evaluate(loss_scale()), 1.)
      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
    slot_var = opt.get_slot(var, 'momentum')
    self.assertEqual(self.evaluate(slot_var).item(), -1)
    self.assertEqual(self.evaluate(opt.iterations), 1)

    # Set an optimizer variable to check that arbitrary optimizer
    # attributes can be saved/restored.
    self.evaluate(inner_opt.my_var.assign(1.))

    # Save a checkpoint.
    checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
    prefix = os.path.join(self.get_temp_dir(), 'ckpt')
    save_path = checkpoint.save(prefix)

    # Create a new model.
    var = variables.Variable([2.0])
    opt = inner_opt = MySGD(1., momentum=1.)
    if restore_with_ls:
      loss_scale = loss_scale_module.DynamicLossScale(
          initial_loss_scale=1., increment_period=2., multiplier=2.)
      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)

    # Restore the new model.
    checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
    status = checkpoint.restore(save_path)
    if save_with_ls:
      status.assert_existing_objects_matched()
    else:
      status.assert_nontrivial_match()

    # Assert restored values. We can only assert in eager mode since the
    # variables are uninitialized in graph mode.
    if context.executing_eagerly():
      self.assertEqual(self.evaluate(var), 1.)
      if save_with_ls and restore_with_ls:
        self.assertEqual(self.evaluate(loss_scale()), 1.)
        self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
      elif restore_with_ls:
        self.assertEqual(self.evaluate(loss_scale()), 1.)
        self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
      self.assertEqual(self.evaluate(opt.iterations), 1)

    # Run the model again.
    run_fn = lambda: opt.minimize(lambda: var / replicas + 1.,
                                  var_list=[var])
    opt_op = strategy.experimental_run(run_fn)

    # Assert new values.
    self.evaluate(variables.global_variables_initializer())
    status.run_restore_ops()
    self.evaluate(strategy.experimental_local_results(opt_op))
    self.assertEqual(self.evaluate(var), -1)
    slot_var = opt.get_slot(var, 'momentum')
    self.assertEqual(self.evaluate(slot_var).item(), -2)
    self.assertEqual(self.evaluate(opt.iterations), 2)
    self.assertEqual(self.evaluate(inner_opt.my_var), 1)

    # Restore the model again to test restoring after slots are created.
    status = checkpoint.restore(save_path)
    if save_with_ls and restore_with_ls:
      status.assert_consumed()
    elif save_with_ls:
      status.assert_existing_objects_matched()
    elif restore_with_ls:
      status.assert_nontrivial_match()
    status.run_restore_ops()
    self.assertEqual(self.evaluate(var), 1)
    self.assertEqual(self.evaluate(slot_var).item(), -1)
def test_checkpoint_restore_loads(self):
  strategy = self._get_strategy()
  num_rows = strategy.num_replicas_in_sync

  def get_values(mid):
    return ops.convert_to_tensor(
        mid._variables['table']['parameters'].variables[0])

  with strategy.scope():
    first_mid_level_contents = np.ones((num_rows, 4))
    first_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    initializer = init_ops_v2.Constant(first_mid_level_contents)

    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=num_rows,
        dim=4,
        initializer=initializer,
        combiner='sum',
        name='table')
    feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=table, name='feature'),)

    first_mid_level = tpu_embedding_v2.TPUEmbedding(
        feature_config, first_mid_level_optimizer)
    first_mid_level.build(64)

  first_mid_level._load_variables()

  first_checkpoint = util.Checkpoint(model=first_mid_level)
  first_checkpoint.save(self._get_tmpdir('restore', 'save'))

  tpu_strategy_util.initialize_tpu_system(self.resolver)

  with strategy.scope():
    second_mid_level_contents = np.ones((num_rows, 4)) * 2
    second_mid_level_optimizer = tpu_embedding_v2_utils.SGD(
        learning_rate=0.1)
    initializer = init_ops_v2.Constant(second_mid_level_contents)

    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=num_rows,
        dim=4,
        initializer=initializer,
        combiner='sum',
        name='table')
    feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=table, name='feature'),)

    second_mid_level = tpu_embedding_v2.TPUEmbedding(
        feature_config, second_mid_level_optimizer)
    second_mid_level.build(64)

  second_mid_level._load_variables()

  self.assertAllClose(
      second_mid_level_contents,
      get_values(second_mid_level),
      msg='Second mid level api should contain its initial values.',
  )

  # We restore the checkpoint of our first model into our second model.
  # This should load the first mid level API object onto the TPU.
  second_checkpoint = util.Checkpoint(model=second_mid_level)
  second_checkpoint.restore(self._get_tmpdir('restore', 'save-1'))

  # Call retrieve here as a way to check what the TPU contains.
  # Calling the retrieve ops directly might make for a cleaner separation
  # of test and module, though.
  second_mid_level._retrieve_variables()

  self.assertAllClose(
      first_mid_level_contents,
      get_values(second_mid_level),
      msg='Second mid level api should have retrieved the first model '
      'values.')
def worker_fn(self,
              checkpoint_dir,
              cluster_spec,
              maintenance_event=None,
              training_finished=None,
              frequent_send=False,
              training_restarted=None,
              termination_config=failure_handling.TerminationConfig(
                  time_till_termination=0)):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  def mock_termination_watcher_function_gce(*args, **kwargs):
    del args, kwargs
    if not frequent_send:
      time.sleep(1)
      if (not maintenance_event.is_set()) and (random.randrange(0, 7) == 5):
        maintenance_event.set()
        logging.info('Termination notice available.')
        return True
    elif frequent_send and not maintenance_event.is_set():
      logging.info('Termination notice available.')
      return True
    return False

  with mock.patch.object(
      gce_util, 'termination_watcher_function_gce',
      mock_termination_watcher_function_gce), mock.patch.object(
          gce_util, 'detect_platform',
          lambda: gce_util.PlatformDevice.GCE_GPU):

    class Model(module.Module):

      def __init__(self):
        self.v = variables_lib.Variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_WRITE,
            aggregation=variables_lib.VariableAggregation.SUM)

      @def_function.function(input_signature=[])
      def __call__(self):
        return self.v.read_value()

    with strategy.scope():
      model = Model()
      fh_ckpt = tracking_util.Checkpoint(model=model)

      worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
          strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
          termination_config)

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Start training at %d',
                 worker_preemption_watcher.total_runs)

    # If the training process has been restarted, verify that the expected
    # number of checkpoints has been written.
    # We also want to check training_finished, because there's a corner case
    # where the signal is sent quite late and training finishes before the
    # grace period ends.
    if training_restarted.is_set() and not training_finished.is_set():
      match_group = [
          re.search(r'.*ckpt-(\d+).index', a_file)
          for a_file in gfile.ListDirectory(checkpoint_dir)
      ]
      checkpoint_index = [
          a_match.group(1) for a_match in match_group if a_match
      ]
      if termination_config.time_till_termination > 0:
        # Two checkpoints were saved for the extended grace period.
        self.assertEqual(int(checkpoint_index[0]), 2)
      else:
        self.assertEqual(int(checkpoint_index[0]), 1)

    for epoch in range(
        worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
        EPOCHS_TO_RUN):
      for step in range(
          worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
          STEPS_PER_EPOCH):
        worker_preemption_watcher.run(distributed_train_step, epoch, step)

    logging.info('Training finished.')
    training_finished.set()

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

    running_threads = test_util.get_running_threads()
    if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                            running_threads) and test_util.has_thread(
                                _LOCAL_WATCHER_THREAD_PREFIX,
                                running_threads):
      try:
        # Explicitly call __del__ since setting the handler to None and
        # running gc.collect does not invoke __del__ here.
        worker_preemption_watcher.__del__()

        time.sleep(2)
        running_threads = test_util.get_running_threads()
        self.assertFalse(
            test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                 running_threads))
        self.assertFalse(
            test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                 running_threads))

      except urllib.error.URLError as e:
        if 'Temporary failure in name resolution' in e.message:
          # This is caused by a weird flakiness where mock.patch does not
          # correctly patch gce_util.request_compute_metadata, a real
          # request is attempted, and an error is hit in
          # gce_util.request_compute_metadata.
          logging.warning('Hit a mock issue.')
          return
def worker_fn(self, checkpoint_dir, cluster_spec, maintenance_event,
              training_finished, frequent_send=False):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  def mock_request_compute_metadata(*args, **kwargs):
    del kwargs  # Unused.
    if args[0] == 'instance/maintenance-event':
      if not frequent_send:
        time.sleep(1)
        if (not maintenance_event.is_set()) and (random.randrange(0, 20)
                                                 > 18):
          maintenance_event.set()
          logging.info('Maintenance notice available.')
          return 'TERMINATE_ON_HOST_MAINTENANCE'
      elif frequent_send and not maintenance_event.is_set():
        return 'TERMINATE_ON_HOST_MAINTENANCE'
    return 'NONE'

  with mock.patch.object(
      gce_util, 'request_compute_metadata',
      mock_request_compute_metadata), mock.patch.object(
          gce_util, 'detect_platform',
          lambda: gce_util.PlatformDevice.GCE_GPU):

    class Model(module.Module):

      def __init__(self):
        self.v = variables_lib.Variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_WRITE,
            aggregation=variables_lib.VariableAggregation.SUM)

      @def_function.function(input_signature=[])
      def __call__(self):
        return self.v.read_value()

    with strategy.scope():
      model = Model()
      fh_ckpt = tracking_util.Checkpoint(model=model)

      failure_handler = failure_handling.CoordinatedCheckpointManager(
          strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Start training at %d', failure_handler.total_runs)
    for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
                       EPOCHS_TO_RUN):
      for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
        failure_handler.run(distributed_train_step, epoch, step)

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

    training_finished.set()

    running_threads = test_util.get_running_threads()
    strategy.gather(constant_op.constant([10]), axis=0)
    self.assertTrue(
        test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX, running_threads))
    self.assertTrue(
        test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX, running_threads))
    strategy.gather(constant_op.constant([10]), axis=0)

    # Explicitly call __del__ since making it None and gc.collect does
    # not invoke __del__ here.
    failure_handler.__del__()

    time.sleep(2)
    running_threads = test_util.get_running_threads()
    self.assertFalse(
        test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX, running_threads))
    self.assertFalse(
        test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX, running_threads))
def start(self):
  """Starts the evaluation loop."""
  optimizer_checkpoint = tracking_util.Checkpoint(iter=self._iterations)
  checkpoint = tracking_util.Checkpoint(
      model=self.model, optimizer=optimizer_checkpoint)

  for latest_checkpoint in checkpoint_utils.checkpoints_iterator(
      self.checkpoint_dir):
    try:
      # `expect_partial` because the checkpoint can have other `Trackable`s
      # such as `optimizer`.
      checkpoint.restore(latest_checkpoint).expect_partial()
      checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
      # The checkpoint should contain model and optimizer for
      # SidecarEvaluator to work. But the model weights saved by the
      # ModelCheckpoint callback do not contain the model as an attribute.
      # To make SidecarEvaluator work in this case as well, if the model
      # attribute is not found but a layer_with_weights attribute is, use
      # model.load_weights to load the model's weights, while
      # self._iterations is still restored by the checkpoint variable.
      if 'model' not in checkpoint_attributes:
        for attribute in checkpoint_attributes:
          # Check whether the checkpoint has the required attributes for
          # model.load_weights to work.
          if re.match(r'^layer_with_weights-\d+', attribute) is not None:
            self.model.load_weights(latest_checkpoint)
            break
    except (errors_impl.OpError,) as e:
      # A couple of errors can happen here with the coordinator racing to
      # write the checkpoint:
      # 1) OpError: open failed for <file path>: No such file or directory
      # 2) NotFoundError (subclass of OpError): Unsuccessful
      #    TensorSliceReader constructor.
      # TODO(rchao): Remove this except block once b/150954027 is resolved.
      logging.info(
          'SidecarEvaluator encountered an error while loading '
          'checkpoint: %s. Retrying. Error: %s: %s', latest_checkpoint,
          e.__class__.__name__, e)
      continue

    if self._iterations.numpy() == _ITERATIONS_UNINITIALIZED:
      raise RuntimeError(
          '`iterations` cannot be loaded from the '
          'checkpoint file. Please ensure `iterations` is '
          'tracked in the `checkpoint` saved by the coordinator.')

    logging.info(
        'Evaluation starts: Model weights loaded from latest '
        'checkpoint file: %s.', latest_checkpoint)

    # TODO(rchao): Support arbitrary callback for extensibility.
    self.model.evaluate(self.data, steps=self.steps)

    logging.info(
        'End of evaluation. Metrics: %s', ' '.join([
            '{}={}'.format(metric.name,
                           metric.result().numpy())
            for metric in self.model.metrics
        ]))

    if self._summary_writer:
      with summary_ops_v2.record_if(True), self._summary_writer.as_default():
        for metric in self.model.metrics:
          summary_ops_v2.scalar(
              metric.name,
              metric.result(),
              step=self._iterations.read_value())

    # TODO(rchao): Make the max evaluation robust in case users save the
    # checkpoints with epoch format {epoch:03d}.
    if (self.max_evaluations and
        latest_checkpoint.endswith('-{}'.format(self.max_evaluations))):
      # Exit the loop because we have evaluated the final checkpoint file.
      logging.info('Last checkpoint evaluated. SidecarEvaluator stops.')
      return
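# The summary-writing pattern above, shown with the public tf.summary API.
# A hedged usage sketch; the log directory, metric name, value, and step are
# assumptions, not values from this module:
import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/sidecar_eval_logs')
with writer.as_default(), tf.summary.record_if(True):
  # One scalar per metric, keyed to the training iteration at which the
  # evaluated checkpoint was saved.
  tf.summary.scalar('accuracy', 0.91, step=100)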
def test_initialize_if_not_restoring(self):
  with self.test_session():
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.Adam(0.001)
      root = trackable_utils.Checkpoint(
          model=model)  # Do not save the optimizer with the checkpoint.
      optimizer_checkpoint = trackable_utils.Checkpoint(optimizer=optimizer)

      checkpoint_path = checkpoint_management.latest_checkpoint(
          checkpoint_directory)
      status = root.restore(save_path=checkpoint_path)
      input_value = constant_op.constant([[3.]])

      def train_fn():
        with backprop.GradientTape() as tape:
          loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        return optimizer.apply_gradients(zip(gradients, variables))

      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      # TODO(tanzheny): Add hyper variables to .variables(), and set them
      # with set_weights etc.
      variables_not_in_the_variables_property = [
          obj for obj in optimizer._hyper.values()
          if isinstance(obj, variables_lib.Variable)
      ]
      self.evaluate([
          v.initializer for v in optimizer.variables() +
          variables_not_in_the_variables_property
      ])
      train_fn()
      model_save_path = root.save(file_prefix=checkpoint_prefix)
      self.evaluate(optimizer.beta_1.assign(42.))
      optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)
    del train_fn

    # Restore into a graph with the optimizer
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.Adam(0.001)
      root = trackable_utils.Checkpoint(optimizer=optimizer, model=model)
      status = root.restore(save_path=model_save_path)
      input_value = constant_op.constant([[3.]])

      def train_fn1():
        with backprop.GradientTape() as tape:
          loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        return optimizer.apply_gradients(zip(gradients, variables))

      if not context.executing_eagerly():
        train_fn1 = functools.partial(self.evaluate, train_fn1())
      status.initialize_or_restore()
      train_fn1()
      with self.assertRaises(AssertionError):
        status.assert_existing_objects_matched()
      with self.assertRaises(AssertionError):
        status.assert_consumed()
    del train_fn1

    # Make sure initialization doesn't clobber later restores
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.Adam(0.001, beta_1=1.0)
      root = trackable_utils.Checkpoint(optimizer=optimizer, model=model)
      opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
      status = root.restore(save_path=model_save_path)
      init_only_optimizer_status = opt_root.restore(save_path=None)
      optimizer_status = opt_root.restore(save_path=optimizer_save_path)
      input_value = constant_op.constant([[3.]])

      def train_fn2():
        with backprop.GradientTape() as tape:
          loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        return optimizer.apply_gradients(zip(gradients, variables))

      if not context.executing_eagerly():
        train_fn2 = functools.partial(self.evaluate, train_fn2())
      optimizer_status.run_restore_ops()
      status.initialize_or_restore()
      init_only_optimizer_status.initialize_or_restore()
      train_fn2()
      self.assertEqual(42., self.evaluate(optimizer.beta_1))
def testDeferredSlotRestoration(self):
  checkpoint_directory = self.get_temp_dir()

  root = trackable_utils.Checkpoint()
  root.var = trackable_utils.add_variable(root, name="var", initializer=0.)
  optimizer = adam.AdamOptimizer(0.1)
  if context.executing_eagerly():
    optimizer.minimize(root.var.read_value)
  else:
    train_op = optimizer.minimize(root.var)
    # Note that `optimizer` has not been added as a dependency of `root`.
    # Create a one-off grouping so that slot variables for `root.var` get
    # initialized too.
    self.evaluate(
        trackable_utils.gather_initializers(
            trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
    self.evaluate(train_op)
  self.evaluate(state_ops.assign(root.var, 12.))
  no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
  root.optimizer = optimizer
  self.evaluate(state_ops.assign(root.var, 13.))
  self.evaluate(
      state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.))
  slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))

  new_root = trackable_utils.Checkpoint()
  # Load the slot-containing checkpoint (deferred), then immediately
  # overwrite the non-slot variable (also deferred).
  slot_status = new_root.restore(slots_path)
  no_slot_status = new_root.restore(no_slots_path)
  with self.assertRaises(AssertionError):
    no_slot_status.assert_consumed()
  new_root.var = trackable_utils.add_variable(new_root, name="var", shape=[])
  no_slot_status.assert_consumed()
  no_slot_status.run_restore_ops()
  self.assertEqual(12., self.evaluate(new_root.var))
  new_root.optimizer = adam.AdamOptimizer(0.1)
  slot_status.assert_existing_objects_matched()
  with self.assertRaisesRegex(AssertionError, "beta1_power"):
    slot_status.assert_consumed()
  self.assertEqual(12., self.evaluate(new_root.var))
  if context.executing_eagerly():
    # Slot variables are only created with restoring initializers when
    # executing eagerly.
    self.assertEqual(
        14.,
        self.evaluate(new_root.optimizer.get_slot(name="m",
                                                  var=new_root.var)))
  else:
    self.assertIs(
        new_root.optimizer.get_slot(name="m", var=new_root.var), None)
  if context.executing_eagerly():
    new_root.optimizer.minimize(new_root.var.read_value)
  else:
    train_op = new_root.optimizer.minimize(new_root.var)
    # The slot variable now exists; restore() didn't create it, but we
    # should now have a restore op for it.
    slot_status.run_restore_ops()
    self.assertEqual(
        14.,
        self.evaluate(new_root.optimizer.get_slot(name="m",
                                                  var=new_root.var)))
    self.evaluate(train_op)
  slot_status.assert_consumed()
def test_trackable_save_restore(self):
  with self.test_session():

    def _templated():
      v = variable_scope.get_variable(
          "v",
          shape=[1],
          initializer=init_ops.zeros_initializer(),
          use_resource=True)
      v2 = variable_scope.get_variable(
          "v2",
          shape=[1],
          initializer=init_ops.zeros_initializer(),
          use_resource=True)
      manual = _ManualScope()
      return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    six.assertCountEqual(
        self,
        [
            id(v1_save),
            id(v2_save),
            id(manual_scope),
            id(manual_scope_v),
            id(save_template)
        ],
        map(id, trackable_utils.list_objects(save_template)))
    manual_dep, = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)
    optimizer = adam.Adam(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value, var_list=[v1_save])
    self.evaluate([v.initializer for v in save_template.variables])
    optimizer_variables = optimizer.variables() + list(
        optimizer._hyper.values())
    self.evaluate([v.initializer for v in optimizer_variables])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.Adam(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value, var_list=[var])
    self.assertLen(load_template._checkpoint_dependencies, 3)
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    self.assertEqual("ManualScope",
                     load_template._checkpoint_dependencies[2].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
def testMultipleGraphsNonSlotVariables(self):
  with context.graph_mode():
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    optimizer = adam.AdamOptimizer(0.001)
    # Construct a model in one graph
    first_graph = ops.Graph()
    first_session = session_lib.Session(graph=first_graph)
    with first_graph.as_default(), first_session.as_default():
      first_variable = resource_variable_ops.ResourceVariable([1.])
      first_root_trackable = trackable_utils.Checkpoint(
          optimizer=optimizer, variable=first_variable)
      train_op = optimizer.minimize(first_variable.read_value)
      self.evaluate(
          trackable_utils.gather_initializers(first_root_trackable))
      self.evaluate(train_op)
      self.evaluate(first_variable.assign([1.]))
      self.evaluate(
          optimizer.get_slot(var=first_variable, name="m").assign([2.]))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.evaluate(beta1_power.assign(3.))

    # Save and load in a second graph
    second_graph = ops.Graph()
    with second_graph.as_default(), session_lib.Session(graph=second_graph):
      second_variable = resource_variable_ops.ResourceVariable([1.])
      second_root_trackable = trackable_utils.Checkpoint(
          optimizer=optimizer, variable=second_variable)
      train_op = optimizer.minimize(second_variable.read_value)
      second_root_trackable.restore(None).initialize_or_restore()
      self.evaluate(train_op)
      self.evaluate(second_variable.assign([4.]))
      self.evaluate(
          optimizer.get_slot(var=second_variable, name="m").assign([5.]))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.evaluate(beta1_power.assign(6.))
      save_path = second_root_trackable.save(checkpoint_prefix)
      self.evaluate(second_variable.assign([7.]))
      self.evaluate(
          optimizer.get_slot(var=second_variable, name="m").assign([8.]))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.assertAllEqual(6., self.evaluate(beta1_power))
      status = second_root_trackable.restore(save_path)
      status.assert_consumed().run_restore_ops()
      self.assertAllEqual([4.], self.evaluate(second_variable))
      self.assertAllEqual([5.],
                          self.evaluate(
                              optimizer.get_slot(
                                  var=second_variable, name="m")))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.assertAllEqual(6., self.evaluate(beta1_power))

    # Check that the first graph is unmolested
    with first_graph.as_default(), first_session.as_default():
      self.assertAllEqual([1.], self.evaluate(first_variable))
      self.assertAllEqual([2.],
                          self.evaluate(
                              optimizer.get_slot(
                                  var=first_variable, name="m")))
      beta1_power, _ = optimizer._get_beta_accumulators()
      self.assertAllEqual(3., self.evaluate(beta1_power))
def worker_fn(self,
              checkpoint_dir,
              cluster_spec,
              maintenance_event=None,
              training_finished=None,
              frequent_send=False):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  def mock_termination_watcher_function_gce(*args, **kwargs):
    del args, kwargs
    if not frequent_send:
      time.sleep(1)
      if (not maintenance_event.is_set()) and (random.randrange(0, 20) > 18):
        maintenance_event.set()
        logging.info('Termination notice available.')
        return True
    elif frequent_send and not maintenance_event.is_set():
      logging.info('Termination notice available.')
      return True
    return False

  with mock.patch.object(
      gce_util, 'termination_watcher_function_gce',
      mock_termination_watcher_function_gce), mock.patch.object(
          gce_util, 'detect_platform',
          lambda: gce_util.PlatformDevice.GCE_GPU):

    class Model(module.Module):

      def __init__(self):
        self.v = variables_lib.Variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_WRITE,
            aggregation=variables_lib.VariableAggregation.SUM)

      @def_function.function(input_signature=[])
      def __call__(self):
        return self.v.read_value()

    with strategy.scope():
      model = Model()
      fh_ckpt = tracking_util.Checkpoint(model=model)

      worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
          strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Start training at %d',
                 worker_preemption_watcher.total_runs)
    for epoch in range(
        worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
        EPOCHS_TO_RUN):
      for step in range(
          worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
          STEPS_PER_EPOCH):
        worker_preemption_watcher.run(distributed_train_step, epoch, step)

    training_finished.set()

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

    running_threads = test_util.get_running_threads()
    if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                            running_threads) and test_util.has_thread(
                                _LOCAL_WATCHER_THREAD_PREFIX,
                                running_threads):
      try:
        # Explicitly call __del__ since making it None and gc.collect does
        # not invoke __del__ here.
        worker_preemption_watcher.__del__()

        time.sleep(2)
        running_threads = test_util.get_running_threads()
        self.assertFalse(
            test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                 running_threads))
        self.assertFalse(
            test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                 running_threads))

      except urllib.error.URLError as e:
        if 'Temporary failure in name resolution' in e.message:
          # This is caused by a weird flakiness where mock.patch does not
          # correctly patch gce_util.request_compute_metadata, a real
          # request is attempted, and an error is hit in
          # gce_util.request_compute_metadata.
          logging.warning('Hit a mock issue.')
          return
def testSaveRestoreState(self, mock_time):
  directory = self.get_temp_dir()
  mock_time.time.return_value = 3.
  checkpoint = util.Checkpoint()
  first_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory, max_to_keep=2)
  first_time = 10000.
  first_name = os.path.join(directory, "ckpt-1")
  mock_time.time.return_value = first_time
  first_manager.save()
  state = checkpoint_management.get_checkpoint_state(directory)
  second_time = first_time + 3610.
  second_name = os.path.join(directory, "ckpt-2")
  mock_time.time.return_value = second_time
  first_manager.save()
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual([first_time, second_time],
                   state.all_model_checkpoint_timestamps)
  self.assertEqual([first_name, second_name], first_manager.checkpoints)
  self.assertEqual(second_name, first_manager.latest_checkpoint)
  del first_manager

  second_manager = checkpoint_management.CheckpointManager(
      checkpoint,
      directory,
      max_to_keep=2,
      keep_checkpoint_every_n_hours=1.5)
  self.assertEqual([first_name, second_name], second_manager.checkpoints)
  self.assertEqual(second_name, second_manager.latest_checkpoint)
  third_name = os.path.join(directory, "ckpt-3")
  third_time = second_time + 3600. * 0.2
  mock_time.time.return_value = third_time
  second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
  self.assertEqual([second_name, third_name], second_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(first_time, state.last_preserved_timestamp)
  fourth_time = third_time + 3600. * 0.5
  mock_time.time.return_value = fourth_time
  fourth_name = os.path.join(directory, "ckpt-4")
  second_manager.save()
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
  self.assertEqual([third_name, fourth_name], second_manager.checkpoints)
  fifth_time = fourth_time + 3600. * 0.5
  mock_time.time.return_value = fifth_time
  fifth_name = os.path.join(directory, "ckpt-5")
  second_manager.save()
  self.assertEqual([fourth_name, fifth_name], second_manager.checkpoints)
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(first_time, state.last_preserved_timestamp)
  del second_manager

  third_manager = checkpoint_management.CheckpointManager(
      checkpoint,
      directory,
      max_to_keep=2,
      keep_checkpoint_every_n_hours=1.5)
  self.assertEqual(fifth_name, third_manager.latest_checkpoint)
  mock_time.time.return_value += 10.
  third_manager.save()
  sixth_name = os.path.join(directory, "ckpt-6")
  state = checkpoint_management.get_checkpoint_state(directory)
  self.assertEqual(fourth_time, state.last_preserved_timestamp)
  self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
  self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
  self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
  self.assertEqual([fifth_name, sixth_name], third_manager.checkpoints)
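# A minimal usage sketch of the CheckpointManager policy the test above
# exercises, via the public tf.train API (the directory and variable are
# assumptions). `max_to_keep` bounds the rolling window of checkpoints,
# while `keep_checkpoint_every_n_hours` preserves one older checkpoint per
# period even after it falls out of that window.
import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(0))
manager = tf.train.CheckpointManager(
    ckpt, '/tmp/ckpt_demo', max_to_keep=2,
    keep_checkpoint_every_n_hours=1.5)
for _ in range(5):
  ckpt.step.assign_add(1)
  manager.save()  # Writes ckpt-1 ... ckpt-5; only the newest two remain
                  # managed, plus any file preserved by the n-hours rule.
print(manager.checkpoints)  # e.g. [..., '/tmp/ckpt_demo/ckpt-5']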
def test_initialize_if_not_restoring(self):
  with self.test_session():
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001)
      root = trackable_utils.Checkpoint(
          model=model,  # Do not save the optimizer with the checkpoint.
          global_step=training_util.get_or_create_global_step())
      optimizer_checkpoint = trackable_utils.Checkpoint(optimizer=optimizer)

      checkpoint_path = checkpoint_management.latest_checkpoint(
          checkpoint_directory)
      status = root.restore(save_path=checkpoint_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      self.evaluate([v.initializer for v in optimizer.variables()])
      train_fn()
      model_save_path = root.save(file_prefix=checkpoint_prefix)
      self.evaluate(optimizer.variables()[0].assign(42.))
      optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)

    # Restore into a graph with the optimizer
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001)
      root = trackable_utils.Checkpoint(
          optimizer=optimizer,
          model=model,
          global_step=training_util.get_or_create_global_step())
      status = root.restore(save_path=model_save_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      train_fn()
      with self.assertRaises(AssertionError):
        status.assert_existing_objects_matched()
      with self.assertRaises(AssertionError):
        status.assert_consumed()

    # Make sure initialization doesn't clobber later restores
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
      root = trackable_utils.Checkpoint(
          optimizer=optimizer,
          model=model,
          global_step=training_util.get_or_create_global_step())
      opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
      status = root.restore(save_path=model_save_path)
      init_only_optimizer_status = opt_root.restore(save_path=None)
      optimizer_status = opt_root.restore(save_path=optimizer_save_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      optimizer_status.run_restore_ops()
      status.initialize_or_restore()
      init_only_optimizer_status.initialize_or_restore()
      train_fn()
      self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
def test_unbuilt_model_does_not_prevent_saving(self):
  root = util.Checkpoint(model=sequential.Sequential([core.Dense(2)]))
  save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
def _save_checkpoint():
  original_checkpoint = util.Checkpoint(m=_LazyTrivialObjects())
  original_checkpoint.m()
  return original_checkpoint.write(
      os.path.join(test.get_temp_dir(), "ckpt"))
def testNamingWithOptimizer(self):
  input_value = constant_op.constant([[3.]])
  model = MyModel()
  # A nuisance Model using the same optimizer. Its slot variables should
  # not go in the checkpoint, since it is never depended on.
  other_model = MyModel()
  optimizer = adam.Adam(0.001)
  step = training_util.get_or_create_global_step()
  root_trackable = trackable_utils.Checkpoint(
      optimizer=optimizer, model=model, step=step)

  with backprop.GradientTape() as tape:
    loss = model(input_value)
  variables = model.trainable_variables
  gradients = tape.gradient(loss, variables)
  train_op = control_flow_ops.group(
      optimizer.apply_gradients(zip(gradients, variables)),
      step.assign_add(1))

  with backprop.GradientTape() as tape:
    loss = other_model(input_value)
  variables = other_model.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))

  self.evaluate(trackable_utils.gather_initializers(root_trackable))
  self.evaluate(train_op)
  named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
      root_trackable).serialize_object_graph()
  expected_slot_keys = (
      "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
      "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
      "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
      "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
      "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
      "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
  )
  expected_checkpoint_names = (
      # Created in the root node, so no prefix.
      "step",
      "model/_second/kernel",
      "model/_named_dense/kernel",
      "model/_named_dense/bias",
      # non-Layer dependency of the model
      "model/_non_layer/a_variable",
      "optimizer/learning_rate",
      "optimizer/beta_1",
      "optimizer/beta_2",
      "optimizer/iter",
      "optimizer/decay",
  ) + expected_slot_keys
  suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
  expected_checkpoint_names = [
      name + suffix for name in expected_checkpoint_names
  ]
  named_variables = {v.name: v for v in named_variables}
  six.assertCountEqual(self, expected_checkpoint_names,
                       named_variables.keys())
  # Check that we've mapped to the right variable objects (not exhaustive)
  self.assertEqual("global_step", named_variables["step" + suffix].full_name)
  self.assertEqual(
      "my_model/dense_1/kernel",
      named_variables["model/_second/kernel" + suffix].full_name)
  self.assertEqual(
      "my_model/dense/kernel",
      named_variables["model/_named_dense/kernel" + suffix].full_name)
  self.assertEqual(
      "Adam/beta_1", named_variables["optimizer/beta_1" + suffix].full_name)
  self.assertEqual(
      "Adam/beta_2", named_variables["optimizer/beta_2" + suffix].full_name)
  # Spot check the generated protocol buffers.
  self.assertEqual("optimizer",
                   serialized_graph.nodes[0].children[1].local_name)
  optimizer_node = serialized_graph.nodes[
      serialized_graph.nodes[0].children[1].node_id]
  children = [node.local_name for node in optimizer_node.children]
  six.assertCountEqual(
      self,
      # hyper variable dependencies
      ["beta_1", "beta_2", "iter", "decay", "learning_rate"],
      children)
  serialized_slot_keys = []
  for slot in optimizer_node.slot_variables:
    for attribute in (
        serialized_graph.nodes[slot.slot_variable_node_id].attributes):
      serialized_slot_keys.append(attribute.checkpoint_key)
  six.assertCountEqual(self,
                       [key + suffix for key in expected_slot_keys],
                       serialized_slot_keys)
def gen_outputs(self,
                ds_fn,
                break_points,
                num_outputs,
                ckpt_saved=False,
                sparse_tensors=False,
                verify_exhausted=True,
                save_checkpoint_at_end=True):
  """Generates elements from input dataset while stopping at break points.

  Produces `num_outputs` outputs and saves the state of the iterator in the
  Saver checkpoint.

  Args:
    ds_fn: 0-argument function that returns the dataset.
    break_points: A list of integers. For each `break_point` in
      `break_points`, we produce outputs until `break_point` items have been
      produced and then checkpoint the state. The current graph and session
      are destroyed and a new graph and session are used to produce outputs
      until the next checkpoint or until `num_outputs` elements have been
      produced. `break_point` must be <= `num_outputs`.
    num_outputs: The total number of outputs to produce from the iterator.
    ckpt_saved: Whether a checkpoint already exists.
    sparse_tensors: Whether the dataset is built from SparseTensor(s).
    verify_exhausted: Whether to verify that the iterator has been exhausted
      after producing `num_outputs` elements.
    save_checkpoint_at_end: Whether to save a checkpoint after producing all
      outputs. If False, checkpoints are saved at each break point but not
      at the end. Note that checkpoints overwrite each other, so there is
      always only a single checkpoint available. Defaults to True.

  Returns:
    A list of `num_outputs` items.
  """
  outputs = []
  if context.executing_eagerly():
    for i in range(len(break_points) + 1):
      iterator = iter(ds_fn())
      ckpt = tracking_util.Checkpoint(iterator=iterator)
      if ckpt_saved:
        ckpt_path = self._latest_ckpt()
        ckpt.restore(ckpt_path)
      start = break_points[i - 1] if i > 0 else 0
      end = break_points[i] if i < len(break_points) else num_outputs
      num_iters = end - start
      for _ in range(num_iters):
        outputs.append(self.evaluate(next(iterator)))
      if i == len(break_points) and verify_exhausted:
        with self.assertRaises(StopIteration):
          next(iterator)
      if save_checkpoint_at_end or i < len(break_points):
        ckpt_path = ckpt.save(self._ckpt_path())
        ckpt_saved = True
  else:

    def get_ops():
      if ckpt_saved:
        saver = self._import_meta_graph()
        init_op, get_next_op = self._get_iterator_ops_from_collection(
            ds_fn, sparse_tensors=sparse_tensors)
      else:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
      return init_op, get_next_op, saver

    for i in range(len(break_points) + 1):
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = get_ops()
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          if ckpt_saved:
            self._initialize(init_op, sess)
            self._restore(saver, sess)
          else:
            self._initialize(init_op, sess)
          start = break_points[i - 1] if i > 0 else 0
          end = break_points[i] if i < len(break_points) else num_outputs
          num_iters = end - start
          for _ in range(num_iters):
            outputs.append(sess.run(get_next_op))
          if i == len(break_points) and verify_exhausted:
            with self.assertRaises(errors.OutOfRangeError):
              sess.run(get_next_op)
          if save_checkpoint_at_end or i < len(break_points):
            self._save(sess, saver)
            ckpt_saved = True
  return outputs
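# Hypothetical usage of gen_outputs (test and dataset names assumed, not part
# of the original suite): drive a 10-element dataset, checkpointing and
# rebuilding the iterator after the 4th and 8th elements, then verify the
# restored iterators resume exactly where they left off.
def testCheckpointAcrossBreakPoints(self):
  ds_fn = lambda: dataset_ops.Dataset.range(10)
  outputs = self.gen_outputs(ds_fn, break_points=[4, 8], num_outputs=10)
  self.assertAllEqual(list(range(10)), outputs)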
def testSaveRestore(self):
  with self.test_session():
    model = MyModel()
    optimizer = adam.Adam(0.001)
    root_trackable = trackable_utils.Checkpoint(
        optimizer=optimizer, model=model)
    input_value = constant_op.constant([[3.]])
    with backprop.GradientTape() as tape:
      loss = model(input_value)
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    train_op = optimizer.apply_gradients(zip(gradients, variables))
    self.assertFalse(root_trackable.save_counter.trainable)
    self.evaluate(trackable_utils.gather_initializers(root_trackable))
    self.evaluate(train_op)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
    m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
    self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
    save_path = root_trackable.save(file_prefix=prefix)
    self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
    self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
    optimizer_variables = self.evaluate(
        sorted(optimizer.variables(), key=lambda v: v.name))
    self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
    # Immediate restoration
    status = root_trackable.restore(save_path=save_path).assert_consumed()
    status.run_restore_ops()
    self.assertAllEqual([42.],
                        self.evaluate(model._named_dense.variables[1]))
    self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
    self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
    if not context.executing_eagerly():
      return  # Restore-on-create is only supported when executing eagerly
    on_create_model = MyModel()
    on_create_optimizer = adam.Adam(0.001)
    on_create_root = trackable_utils.Checkpoint(
        optimizer=on_create_optimizer, model=on_create_model)
    # Deferred restoration
    status = on_create_root.restore(save_path=save_path)
    status.assert_nontrivial_match()
    status.assert_existing_objects_matched()
    with self.assertRaises(AssertionError):
      status.assert_consumed()
    on_create_model(constant_op.constant([[3.]]))  # create variables
    self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
    self.assertAllEqual([42.],
                        self.evaluate(
                            on_create_model._named_dense.variables[1]))
    on_create_m_bias_slot = on_create_optimizer.get_slot(
        on_create_model._named_dense.variables[1], "m")
    status.assert_existing_objects_matched()
    if not context.executing_eagerly():
      with self.assertRaises(AssertionError):
        status.assert_consumed()
    # Optimizer slot variables are created when the original variable is
    # restored.
    self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
    dummy_var = resource_variable_ops.ResourceVariable([1.])
    on_create_optimizer.minimize(
        loss=dummy_var.read_value, var_list=[dummy_var])
    status.assert_existing_objects_matched()
    status.assert_consumed()
    self.assertAllEqual(
        optimizer_variables,
        # Creation order is different, so .variables() needs to be
        # re-sorted.
        self.evaluate(sorted(optimizer.variables(), key=lambda v: v.name)))
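# A minimal public-API sketch of the deferred "restore-on-create" pattern
# tested above (hypothetical, not part of the test): calling restore() before
# the variables exist only records the mapping; values are written when the
# matching variables are later created.
#
#   import os
#   import tempfile
#   import tensorflow as tf
#
#   model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
#   model(tf.constant([[3.]]))  # Build variables before saving.
#   save_path = tf.train.Checkpoint(model=model).save(
#       os.path.join(tempfile.mkdtemp(), "ckpt"))
#
#   fresh_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
#   status = tf.train.Checkpoint(model=fresh_model).restore(save_path)
#   fresh_model(tf.constant([[3.]]))  # Variables created: restoration fires.
#   status.assert_existing_objects_matched()
#   assert (fresh_model.layers[0].kernel.numpy()
#           == model.layers[0].kernel.numpy()).all()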
def fn(model_path, checkpoint_dir):
  global_batch_size = per_worker_batch_size * num_workers
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
  with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

  callbacks = [
      keras.callbacks.ModelCheckpoint(
          filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
  ]

  multi_worker_dataset = mnist_dataset(global_batch_size)
  if shard_policy:
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = shard_policy
    multi_worker_dataset = multi_worker_dataset.with_options(options)

  multi_worker_model.fit(
      multi_worker_dataset,
      epochs=2,
      steps_per_epoch=20,
      callbacks=callbacks)

  def _is_chief(task_type, task_id):
    return task_type is None or task_type == 'chief' or (
        task_type == 'worker' and task_id == 0)

  def _get_temp_dir(dirpath, task_id):
    base_dirpath = 'workertemp_' + str(task_id)
    temp_dir = os.path.join(dirpath, base_dirpath)
    file_io.recursive_create_dir_v2(temp_dir)
    return temp_dir

  def write_filepath(filepath, task_type, task_id):
    dirpath = os.path.dirname(filepath)
    base = os.path.basename(filepath)
    if not _is_chief(task_type, task_id):
      dirpath = _get_temp_dir(dirpath, task_id)
    return os.path.join(dirpath, base)

  task_type, task_id = (strategy.cluster_resolver.task_type,
                        strategy.cluster_resolver.task_id)
  write_model_path = write_filepath(model_path, task_type, task_id)

  multi_worker_model.save(write_model_path)
  if not _is_chief(task_type, task_id):
    file_io.delete_recursively_v2(os.path.dirname(write_model_path))

  # Make sure chief finishes saving before non-chief's assertions.
  multi_process_runner.get_barrier().wait()

  if not file_io.file_exists_v2(model_path):
    raise RuntimeError()
  if file_io.file_exists_v2(write_model_path) != _is_chief(
      task_type, task_id):
    raise RuntimeError()

  loaded_model = keras.saving.save.load_model(model_path)
  loaded_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

  checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
  write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
  checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

  checkpoint_manager.save()
  if not _is_chief(task_type, task_id):
    file_io.delete_recursively_v2(write_checkpoint_dir)

  # Make sure chief finishes saving before non-chief's assertions.
  multi_process_runner.get_barrier().wait()

  if not file_io.file_exists_v2(checkpoint_dir):
    raise RuntimeError()
  if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
      task_type, task_id):
    raise RuntimeError()

  latest_checkpoint = checkpoint_management.latest_checkpoint(
      checkpoint_dir)
  checkpoint.restore(latest_checkpoint)
  multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

  logging.info('testMultiWorkerTutorial successfully ends')
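# The chief/non-chief saving pattern used above, as a standalone sketch
# (helper names assumed): every worker calls save(), but non-chief workers
# write to a per-worker temporary directory that is deleted afterwards, so
# only the chief's copy remains at the canonical path.
import os


def is_chief(task_type, task_id):
  # With MultiWorkerMirroredStrategy, worker 0 acts as chief when no
  # dedicated chief task is configured.
  return task_type is None or task_type == 'chief' or (
      task_type == 'worker' and task_id == 0)


def write_filepath(filepath, task_type, task_id):
  dirpath, base = os.path.dirname(filepath), os.path.basename(filepath)
  if not is_chief(task_type, task_id):
    dirpath = os.path.join(dirpath, 'workertemp_' + str(task_id))
    os.makedirs(dirpath, exist_ok=True)
  return os.path.join(dirpath, base)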
def worker_fn(self,
              checkpoint_dir,
              cluster_spec,
              training_started_event=None,
              raise_app_error_on_worker=None,
              training_restarted=None,
              training_finished=None,
              termination_config=failure_handling.TerminationConfig()):

  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  class Model(module.Module):

    def __init__(self):
      self.v = variables_lib.Variable(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.SUM)

    @def_function.function(input_signature=[])
    def __call__(self):
      return self.v.read_value()

  with mock.patch.object(gce_util, 'on_gcp', lambda: False):

    with strategy.scope():
      model = Model()
      # Named fh_ckpt because the user should keep their regular checkpoint
      # separate from the checkpoint used by WorkerPreemptionHandler: we
      # create a CheckpointManager to manage this checkpoint, and only one
      # CheckpointManager should be active in a particular directory at a
      # time.
      fh_ckpt = tracking_util.Checkpoint(model=model)

      worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
          strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
          termination_config)

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        if distribution_strategy_context.get_distribution_strategy(
        ).cluster_resolver.task_id == raise_app_error_on_worker:
          raise errors_impl.ResourceExhaustedError(
              node_def=None, op=None, message='Running out of resources')

        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Start training at %d',
                 worker_preemption_watcher.total_runs)

    # If the training process has been restarted, verify that the expected
    # number of checkpoints has been written. We also check
    # training_finished, because there is a corner case where the signal is
    # sent quite late and training finishes before the grace period ends.
    if training_restarted and training_restarted.is_set(
    ) and not training_finished.is_set():
      logging.info('training restarted')
      match_group = [
          re.search(r'.*ckpt-(\d+).index', a_file)
          for a_file in gfile.ListDirectory(checkpoint_dir)
      ]
      checkpoint_index = [
          a_match.group(1) for a_match in match_group if a_match
      ]
      if getattr(termination_config, 'time_till_termination', 0):
        # Two checkpoints were saved for the extended grace period.
        self.assertEqual(int(checkpoint_index[0]), 2)
      else:
        self.assertEqual(int(checkpoint_index[0]), 1)

    for epoch in range(
        worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
        EPOCHS_TO_RUN):

      for step in range(
          worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
          STEPS_PER_EPOCH):
        worker_preemption_watcher.run(distributed_train_step, epoch, step)

      # Add some randomness to when preemption actually happens. We should
      # trigger it for sure if the training is coming to an end and it
      # hasn't been triggered yet.
      if epoch >= EPOCHS_TO_RUN - 2:
        trigger_it = True
      else:
        trigger_it = False

      self._maybe_trigger_a_preemption(training_started_event, trigger_it)

    training_finished.set()

    logging.info('Training finished.')

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)
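# The resumption arithmetic used in the loop above, worked in isolation
# (hypothetical values, not part of the test): with total_runs completed
# steps, training resumes at epoch total_runs // STEPS_PER_EPOCH and at step
# total_runs % STEPS_PER_EPOCH within that epoch.
STEPS_PER_EPOCH = 5
total_runs = 12
resume_epoch, resume_step = divmod(total_runs, STEPS_PER_EPOCH)
assert (resume_epoch, resume_step) == (2, 2)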