def test_save_profile(self):
    """End-to-end check that profiler.start/stop exports all expected files.

    Profiles a small eager computation and verifies the on-disk layout:
    logdir contains exactly a `.profile-empty` marker file and a `plugins`
    directory, and `plugins/profile/<run>/` contains the per-host tool
    protos plus the gzipped trace.
    """
    logdir = self.get_temp_dir()
    profiler.start(logdir)
    with traceme.TraceMe('three_times_five'):
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
    self.assertAllEqual(15, product)
    profiler.stop()

    file_list = gfile.ListDirectory(logdir)
    self.assertEqual(len(file_list), 2)
    # Reuse the listing already fetched above instead of re-reading the
    # directory a second time.
    for file_name in file_list:
        if gfile.IsDirectory(os.path.join(logdir, file_name)):
            self.assertEqual(file_name, 'plugins')
        else:
            self.assertTrue(file_name.endswith('.profile-empty'))

    profile_dir = os.path.join(logdir, 'plugins', 'profile')
    run = gfile.ListDirectory(profile_dir)[0]
    run_dir = os.path.join(profile_dir, run)
    hostname = socket.gethostname()
    # Every profiler tool should have produced its per-host output file.
    for suffix in ('.overview_page.pb', '.input_pipeline.pb',
                   '.tensorflow_stats.pb', '.kernel_stats.pb',
                   '.trace.json.gz'):
        self.assertTrue(
            gfile.Exists(os.path.join(run_dir, hostname + suffix)))
def test_profile(self):
    """Programmatic start/stop returns a serialized, parseable trace proto."""
    profiler.start()
    with traceme.TraceMe('three_times_five'):
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
    self.assertAllEqual(15, product)
    # Starting a second session while one is already active must raise.
    with self.assertRaises(profiler.ProfilerAlreadyRunningError):
        profiler.start()
    serialized = profiler.stop()

    trace = trace_events_pb2.Trace()
    trace.ParseFromString(serialized)
    device_names = {device.name for device in trace.devices.values()}
    self.assertIn('/host:CPU', device_names)
    if not test_util.IsBuiltWithROCm() and config.list_physical_devices(
        'GPU'):
        # device tracing is not yet supported on the ROCm platform
        self.assertIn('/device:GPU:0', device_names)
    event_names = {event.name for event in trace.trace_events}
    self.assertIn('three_times_five', event_names)
    self.assertIn('Mul', event_names)
    # Stopping when no session is running must also raise.
    with self.assertRaises(profiler.ProfilerNotRunningError):
        profiler.stop()
def _constant_impl(
    value, dtype, shape, name, verify_shape, allow_broadcast):
  """Implementation of constant."""
  ctx = context.context()
  if ctx.executing_eagerly():
    # Only open a TraceMe scope when profiling is enabled, so the common
    # (unprofiled) path pays no tracing overhead.
    if _pywrap_traceme.enabled:
      with traceme.TraceMe("tf.constant"):
        return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)

  # Graph mode: build a Const node whose `value` attr holds the tensor proto.
  graph = ops.get_default_graph()
  value_attr = attr_value_pb2.AttrValue()
  value_attr.tensor.CopyFrom(
      tensor_util.make_tensor_proto(
          value,
          dtype=dtype,
          shape=shape,
          verify_shape=verify_shape,
          allow_broadcast=allow_broadcast))
  dtype_attr = attr_value_pb2.AttrValue(type=value_attr.tensor.dtype)
  attrs = {"value": value_attr, "dtype": dtype_attr}
  const_tensor = graph._create_op_internal(  # pylint: disable=protected-access
      "Const", [], [dtype_attr.type], attrs=attrs, name=name).outputs[0]

  if op_callbacks.should_invoke_op_callbacks():
    # TODO(b/147670703): Once the special-op creation code paths
    # are unified. Remove this `if` block.
    callback_outputs = op_callbacks.invoke_op_callbacks(
        "Const", tuple(), attrs, (const_tensor,), op_name=name, graph=graph)
    if callback_outputs is not None:
      const_tensor, = callback_outputs
  return const_tensor
def test_context_manager_with_options(self):
    """The Profile context manager honors explicit ProfilerOptions."""
    logdir = self.get_temp_dir()
    opts = profiler.ProfilerOptions(
        host_tracer_level=3, python_tracer_level=1)
    with profiler.Profile(logdir, opts):
        with traceme.TraceMe('three_times_five'):
            three = constant_op.constant(3)
            five = constant_op.constant(5)
            product = three * five
        self.assertAllEqual(15, product)
    # Exiting the context exports the profile: the logdir should hold
    # exactly two entries (marker file + plugins directory).
    listing = gfile.ListDirectory(logdir)
    self.assertEqual(len(listing), 2)
def on_batch(self, step=0, mode=ModeKeys.TRAIN, size=1):
    """Provide a scope for running one batch."""
    with traceme.TraceMe(
        'TraceContext', graph_type=mode, step_num=step, batch_size=size):
        logs = {'batch': step, 'size': size}
        self.callbacks._call_batch_hook(mode, 'begin', step, logs)
        self.progbar.on_batch_begin(step, logs)
        try:
            # Hand the per-batch logs dict to the caller's `with` body.
            yield logs
        finally:
            # Skip the end-of-batch hooks if the body flagged that the
            # data source ran dry during this batch.
            if not logs.pop('data_exhausted', False):
                self.callbacks._call_batch_hook(mode, 'end', step, logs)
                self.progbar.on_batch_end(step, logs)
def test_save_profile(self):
    """Checks exported profile files and the contents of the saved trace.

    Profiles a small eager computation, verifies the logdir layout, then
    parses the `.trace` proto and asserts the expected devices and events
    are present.
    """
    logdir = self.get_temp_dir()
    profiler.start(logdir)
    with traceme.TraceMe('three_times_five'):
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
    self.assertAllEqual(15, product)
    profiler.stop()

    file_list = gfile.ListDirectory(logdir)
    self.assertEqual(len(file_list), 2)
    # Reuse the listing already fetched above instead of re-reading the
    # directory a second time.
    for file_name in file_list:
        if gfile.IsDirectory(os.path.join(logdir, file_name)):
            self.assertEqual(file_name, 'plugins')
        else:
            self.assertTrue(file_name.endswith('.profile-empty'))

    profile_dir = os.path.join(logdir, 'plugins', 'profile')
    run = gfile.ListDirectory(profile_dir)[0]
    run_dir = os.path.join(profile_dir, run)
    hostname = socket.gethostname()
    # Each profiler tool should have produced its per-host output proto.
    for suffix in ('.overview_page.pb', '.input_pipeline.pb',
                   '.tensorflow_stats.pb'):
        self.assertTrue(
            gfile.Exists(os.path.join(run_dir, hostname + suffix)))

    trace_file = os.path.join(run_dir, hostname + '.trace')
    self.assertTrue(gfile.Exists(trace_file))
    with gfile.Open(trace_file, 'rb') as f:
        profile_pb = trace_events_pb2.Trace()
        profile_pb.ParseFromString(f.read())
    devices = frozenset(
        device.name for device in profile_pb.devices.values())
    self.assertIn('/host:CPU', devices)
    if config.list_physical_devices('GPU'):
        self.assertIn('/device:GPU:0', devices)
    events = frozenset(event.name for event in profile_pb.trace_events)
    self.assertIn('three_times_five', events)
    self.assertIn('Mul:Mul', events)
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        auto_switch=True,
        retry_fit=True,
        absorb=True,
        train_after_switch=True,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        revert_after_fit=False):
    """Custom fit function for the context model.

    Mirrors `tf.keras.Model.fit`, but wraps each epoch in a
    context-switching loop: after every fit attempt, `update_and_switch`
    reports whether the model switched context; on a switch, the weights
    captured at the start of the epoch are restored and — when `retry_fit`
    is set — the epoch is re-attempted under the new context.

    Args:
      auto_switch: Enable/disable autonomous context switching.
      retry_fit: Locate the next fitting context by re-performing fit.
      absorb: Reset the switch sequence counter upon successful training.
        This is mainly used to maintain switch sequencing for
        temporally-extended tasks.
      train_after_switch: When False, weights learned after a switch has
        occurred during this epoch are also reverted.
      revert_after_fit: This is a debug parameter to revert weights after
        performing a fit. This is used to calculate the context deltas
        without incorrectly learning while auto switching is disabled.
      (Remaining arguments match `tf.keras.Model.fit`.)

    Returns:
      `self.history`, the `History` callback populated during training.
    """
    training._keras_api_gauge.get_cell('fit').set(True)
    # Legacy graph support is contained in `training_v1.Model`.
    version_utils.disallow_legacy_graph('Model', 'fit')
    self._assert_compile_was_called()
    self._check_call_args('fit')

    if validation_split:
        # Create the validation data using the training data. Only supported
        # for `Tensor` and `NumPy` input.
        (x, y, sample_weight), validation_data = (
            data_adapter.train_validation_split(
                (x, y, sample_weight),
                validation_split=validation_split,
                shuffle=False))

    with self.distribute_strategy.scope(
    ), training_utils.RespectCompiledTrainableState(self):
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        data_handler = WindowedDataHandler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self)

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                # NOTE(review): `verbose` is treated as a bit field here
                # (`Verbosity.Progress` flag), not a plain 0/1/2 level.
                add_progbar=bool(verbose & Verbosity.Progress),
                model=self,
                verbose=verbose,
                epochs=epochs,
                steps=data_handler.inferred_steps)

        self.stop_training = False
        train_function = self.make_train_function()
        callbacks.on_train_begin()
        self.initialize_fit()
        # Handle fault-tolerance for multi-worker.
        # TODO(omalleyt): Fix the ordering issues that mean this has to
        # happen after `callbacks.on_train_begin`.
        data_handler._initial_epoch = (
            self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
        for epoch, window_iterator in data_handler.enumerate_epochs():
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)
            dataset = tf.data.Dataset.zip(next(window_iterator))
            # Indicate if the model has attempted at least one switch during
            # this epoch.
            switched_during_epoch = False
            # Indicate if the model switched on the most recent fit iteration.
            switched = True
            # Snapshot of the weights at epoch start, used to revert
            # training performed under a non-fitting context.
            weights = backend.batch_get_value(self.trainable_variables)
            # Perform a 'fit call'. Assuming retry_fit, this call is
            # re-attempted after each switch until a context fits.
            while switched and (retry_fit or not switched_during_epoch):
                self.initialize_epoch(epoch)
                iterator = iter(dataset)
                # Perform a fit call
                with data_handler.catch_stop_iteration():
                    for step in data_handler.steps():
                        with traceme.TraceMe(
                            'TraceContext',
                            graph_type='train',
                            epoch_num=epoch,
                            step_num=step,
                            batch_size=batch_size):
                            callbacks.on_train_batch_begin(step)
                            tmp_logs = train_function(iterator)
                            # Catch OutOfRangeError for Datasets of unknown
                            # size. This blocks until the batch has finished
                            # executing.
                            # TODO(b/150292341): Allow multiple async steps
                            # here.
                            if not data_handler.inferred_steps:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            callbacks.on_train_batch_end(step, logs)
                # `update_and_switch` returns truthy when the current context
                # fits (no switch needed); invert to get "switched".
                switched = not self.update_and_switch(
                    epoch, auto_switch, absorb, retry_fit, verbose)
                switched_during_epoch |= switched
                # If a switch occurred, we need to restore the weights
                if switched or (switched_during_epoch and
                                not train_after_switch) or revert_after_fit:
                    backend.batch_set_value(
                        zip(self.trainable_variables, weights))
                    self.reset_metrics()

            epoch_logs = copy.copy(logs)

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                val_x, val_y, val_sample_weight = (
                    data_adapter.unpack_x_y_sample_weight(validation_data))
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    max_queue_size=max_queue_size,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    return_dict=True)
                val_logs = {
                    'val_' + name: val for name, val in val_logs.items()
                }
                epoch_logs.update(val_logs)

            callbacks.on_epoch_end(epoch, epoch_logs)
            if self.stop_training:
                break

        callbacks.on_train_end()
        return self.history
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        dynamic_switch=True,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False):
    """Custom fit with dynamic context switching.

    Mirrors `tf.keras.Model.fit`, but each epoch runs in a loop: after a
    fit pass, `update_and_switch` reports whether the model switched
    context; on a switch the epoch-start weights are restored and the
    epoch is re-attempted.

    Args:
      dynamic_switch: Enable/disable context switching between fit passes.
      (Remaining arguments match `tf.keras.Model.fit`.)

    Returns:
      `self.history`, the `History` callback populated during training.
    """
    training._keras_api_gauge.get_cell('fit').set(True)
    # Legacy graph support is contained in `training_v1.Model`.
    version_utils.disallow_legacy_graph('Model', 'fit')
    self._assert_compile_was_called()
    self._check_call_args('fit')

    if validation_split:
        # Create the validation data using the training data. Only supported
        # for `Tensor` and `NumPy` input.
        (x, y, sample_weight), validation_data = (
            data_adapter.train_validation_split(
                (x, y, sample_weight),
                validation_split=validation_split,
                shuffle=False))

    with self.distribute_strategy.scope(
    ), training_utils.RespectCompiledTrainableState(self):
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        data_handler = WindowedDataHandler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self)

        # Container that configures and calls `tf.keras.Callback`s.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                # NOTE(review): `verbose` is treated as a bit field here
                # (`Verbosity.Progress` flag), not a plain 0/1/2 level.
                add_progbar=bool(verbose & Verbosity.Progress),
                model=self,
                verbose=verbose,
                epochs=epochs,
                steps=data_handler.inferred_steps)

        self.stop_training = False
        train_function = self.make_train_function()
        callbacks.on_train_begin()
        # Handle fault-tolerance for multi-worker.
        # TODO(omalleyt): Fix the ordering issues that mean this has to
        # happen after `callbacks.on_train_begin`.
        data_handler._initial_epoch = (
            self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
        for epoch, window_iterator in data_handler.enumerate_epochs():
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)
            dataset = tf.data.Dataset.zip(next(window_iterator))
            # `switched` is True while the model keeps switching context;
            # the epoch is retried until a fit pass completes with no switch.
            switched = True
            # Snapshot of the weights at epoch start, used to revert
            # training performed under a non-fitting context.
            weights = backend.batch_get_value(self.trainable_variables)
            while switched:
                self.initialize_epoch(epoch)
                iterator = iter(dataset)
                with data_handler.catch_stop_iteration():
                    for step in data_handler.steps():
                        with traceme.TraceMe(
                            'TraceContext',
                            graph_type='train',
                            epoch_num=epoch,
                            step_num=step,
                            batch_size=batch_size):
                            callbacks.on_train_batch_begin(step)
                            tmp_logs = train_function(iterator)
                            # Catch OutOfRangeError for Datasets of unknown
                            # size. This blocks until the batch has finished
                            # executing.
                            # TODO(b/150292341): Allow multiple async steps
                            # here.
                            if not data_handler.inferred_steps:
                                context.async_wait()
                            logs = tmp_logs  # No error, now safe to assign to logs.
                            callbacks.on_train_batch_end(step, logs)
                # `update_and_switch` returns truthy when no switch was
                # needed; invert to get "switched".
                switched = not self.update_and_switch(
                    epoch, dynamic_switch, verbose)
                # If a switch occurred, we need to restore the weights
                if switched:
                    backend.batch_set_value(
                        zip(self.trainable_variables, weights))
                    self.reset_metrics()

            epoch_logs = copy.copy(logs)
            # NOTE(review): assumes `self.accumulate_gradients` and
            # `self.accumulated_gradients` are maintained elsewhere on the
            # model — applied once per epoch after a successful fit pass.
            if self.accumulate_gradients:
                self.optimizer.apply_gradients(
                    zip(self.accumulated_gradients, self.trainable_variables))

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                val_x, val_y, val_sample_weight = (
                    data_adapter.unpack_x_y_sample_weight(validation_data))
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    max_queue_size=max_queue_size,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    return_dict=True)
                val_logs = {
                    'val_' + name: val for name, val in val_logs.items()
                }
                epoch_logs.update(val_logs)

            callbacks.on_epoch_end(epoch, epoch_logs)
            if self.stop_training:
                break

        callbacks.on_train_end()
        return self.history