def test_save_profile(self):
        """Profile a small eager computation into a temp logdir and check outputs.

        The logdir is expected to hold two entries, each either the `plugins`
        directory or a `*.profile-empty` marker file. The first run directory
        under `plugins/profile` must contain the per-host tool output files.
        """
        logdir = self.get_temp_dir()
        profiler.start(logdir)
        try:
            with traceme.TraceMe('three_times_five'):
                three = constant_op.constant(3)
                five = constant_op.constant(5)
                product = three * five
            self.assertAllEqual(15, product)
        finally:
            # Stop even if the assertion above fails; otherwise the profiler
            # stays running and a later profiler.start() raises
            # ProfilerAlreadyRunningError.
            profiler.stop()

        file_list = gfile.ListDirectory(logdir)
        self.assertEqual(len(file_list), 2)
        # Reuse the listing instead of hitting the filesystem a second time.
        for file_name in file_list:
            if gfile.IsDirectory(os.path.join(logdir, file_name)):
                self.assertEqual(file_name, 'plugins')
            else:
                self.assertTrue(file_name.endswith('.profile-empty'))

        profile_dir = os.path.join(logdir, 'plugins', 'profile')
        run = gfile.ListDirectory(profile_dir)[0]
        hostname = socket.gethostname()
        # Each tool writes one "<hostname><suffix>" file into the run directory.
        for suffix in ('.overview_page.pb', '.input_pipeline.pb',
                       '.tensorflow_stats.pb', '.kernel_stats.pb',
                       '.trace.json.gz'):
            tool_file = os.path.join(profile_dir, run, hostname + suffix)
            self.assertTrue(gfile.Exists(tool_file),
                            msg='missing profiler output: %s' % tool_file)
Пример #2
0
    def test_profile(self):
        """Profile in memory and check the returned trace's devices and events.

        Also checks that starting twice raises ProfilerAlreadyRunningError and
        that stopping an idle profiler raises ProfilerNotRunningError.
        """
        profiler.start()
        try:
            with traceme.TraceMe('three_times_five'):
                three = constant_op.constant(3)
                five = constant_op.constant(5)
                product = three * five
            self.assertAllEqual(15, product)
            with self.assertRaises(profiler.ProfilerAlreadyRunningError):
                profiler.start()
        finally:
            # Always stop, so a failed assertion above does not leave the
            # profiler running for later tests.
            profile_result = profiler.stop()

        profile_pb = trace_events_pb2.Trace()
        profile_pb.ParseFromString(profile_result)
        devices = frozenset(device.name
                            for device in profile_pb.devices.values())
        self.assertIn('/host:CPU', devices)
        # Device tracing is not yet supported on the ROCm platform.
        if not test_util.IsBuiltWithROCm() and config.list_physical_devices(
                'GPU'):
            self.assertIn('/device:GPU:0', devices)
        events = frozenset(event.name for event in profile_pb.trace_events)
        self.assertIn('three_times_five', events)
        self.assertIn('Mul', events)
        with self.assertRaises(profiler.ProfilerNotRunningError):
            profiler.stop()
Пример #3
0
def _constant_impl(
    value, dtype, shape, name, verify_shape, allow_broadcast):
  """Create a constant tensor from `value`.

  Eager mode: delegates to `_constant_eager_impl`, wrapped in a "tf.constant"
  TraceMe region when tracing is enabled. `name` and `allow_broadcast` are not
  forwarded on this path.

  Graph mode: serializes `value` into a TensorProto, adds a `Const` op to the
  default graph, and lets any registered op callbacks replace its output.

  Returns:
    The result of `_constant_eager_impl` in eager mode; otherwise the `Const`
    op's first output, or the single tensor returned by the op callbacks.
  """
  ctx = context.context()
  if ctx.executing_eagerly():
    if not _pywrap_traceme.enabled:
      return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    with traceme.TraceMe("tf.constant"):
      return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)

  graph = ops.get_default_graph()
  tensor_proto = tensor_util.make_tensor_proto(
      value, dtype=dtype, shape=shape, verify_shape=verify_shape,
      allow_broadcast=allow_broadcast)
  value_attr = attr_value_pb2.AttrValue()
  value_attr.tensor.CopyFrom(tensor_proto)
  # The op's dtype attribute is taken from the proto actually built.
  dtype_attr = attr_value_pb2.AttrValue(type=value_attr.tensor.dtype)
  attrs = {"value": value_attr, "dtype": dtype_attr}
  const_op = g_op = graph._create_op_internal(  # pylint: disable=protected-access
      "Const", [], [dtype_attr.type], attrs=attrs, name=name)
  result = g_op.outputs[0]

  if not op_callbacks.should_invoke_op_callbacks():
    return result

  # TODO(b/147670703): Once the special-op creation code paths
  # are unified. Remove this callback handling.
  replaced = op_callbacks.invoke_op_callbacks(
      "Const", (), attrs, (result,), op_name=name, graph=graph)
  if replaced is not None:
    (result,) = replaced
  del const_op
  return result
  def test_context_manager_with_options(self):
    """Profile via the `Profile` context manager with custom tracer levels.

    Runs a small eager multiply under a TraceMe region and checks that the
    log directory ends up with exactly two entries.
    """
    log_dir = self.get_temp_dir()
    profile_options = profiler.ProfilerOptions(
        host_tracer_level=3, python_tracer_level=1)
    with profiler.Profile(log_dir, profile_options):
      with traceme.TraceMe('three_times_five'):
        result = constant_op.constant(3) * constant_op.constant(5)
      self.assertAllEqual(15, result)

    entries = gfile.ListDirectory(log_dir)
    self.assertEqual(len(entries), 2)
Пример #5
0
 def on_batch(self, step=0, mode=ModeKeys.TRAIN, size=1):
   """Yield the log dict for one batch, bracketed by batch begin/end hooks.

   Everything runs inside a 'TraceContext' TraceMe region tagged with `mode`,
   `step` and `size`. Begin hooks fire before the yield. End hooks fire
   afterwards, even if the caller's body raised, unless the caller set
   `batch_logs['data_exhausted']` to a truthy value. That key is removed
   before the end hooks see the dict. Presumably wrapped as a context
   manager by a decorator outside this view -- confirm.

   Args:
     step: Batch index passed to the hooks and stored under 'batch'.
     mode: Mode passed to the callback hooks and used as the trace graph_type.
     size: Batch size stored under 'size' and recorded in the trace.

   Yields:
     The mutable batch-log dict shared with the begin/end hooks.
   """
   trace_metadata = {'graph_type': mode, 'step_num': step, 'batch_size': size}
   with traceme.TraceMe('TraceContext', **trace_metadata):
     logs = {'batch': step, 'size': size}
     self.callbacks._call_batch_hook(mode, 'begin', step, logs)
     self.progbar.on_batch_begin(step, logs)
     try:
       yield logs
     finally:
       data_exhausted = logs.pop('data_exhausted', False)
       if not data_exhausted:
         self.callbacks._call_batch_hook(mode, 'end', step, logs)
         self.progbar.on_batch_end(step, logs)
    def test_save_profile(self):
        """Profile a small eager computation into a temp logdir and check outputs.

        The logdir is expected to hold two entries, each either the `plugins`
        directory or a `*.profile-empty` marker file. The run directory must
        contain the per-host tool outputs plus a `.trace` file whose parsed
        Trace proto includes the expected devices and events.
        """
        logdir = self.get_temp_dir()
        profiler.start(logdir)
        try:
            with traceme.TraceMe('three_times_five'):
                three = constant_op.constant(3)
                five = constant_op.constant(5)
                product = three * five
            self.assertAllEqual(15, product)
        finally:
            # Stop even if the assertion above fails, so the profiler is not
            # left running for later tests.
            profiler.stop()

        file_list = gfile.ListDirectory(logdir)
        self.assertEqual(len(file_list), 2)
        # Reuse the listing instead of hitting the filesystem a second time.
        for file_name in file_list:
            if gfile.IsDirectory(os.path.join(logdir, file_name)):
                self.assertEqual(file_name, 'plugins')
            else:
                self.assertTrue(file_name.endswith('.profile-empty'))

        profile_dir = os.path.join(logdir, 'plugins', 'profile')
        run = gfile.ListDirectory(profile_dir)[0]
        hostname = socket.gethostname()
        for suffix in ('.overview_page.pb', '.input_pipeline.pb',
                       '.tensorflow_stats.pb'):
            tool_file = os.path.join(profile_dir, run, hostname + suffix)
            self.assertTrue(gfile.Exists(tool_file),
                            msg='missing profiler output: %s' % tool_file)

        trace_file = os.path.join(profile_dir, run, hostname + '.trace')
        self.assertTrue(gfile.Exists(trace_file),
                        msg='missing profiler output: %s' % trace_file)
        with gfile.Open(trace_file, 'rb') as f:
            profile_pb = trace_events_pb2.Trace()
            profile_pb.ParseFromString(f.read())
        devices = frozenset(device.name
                            for device in profile_pb.devices.values())
        self.assertIn('/host:CPU', devices)
        if config.list_physical_devices('GPU'):
            self.assertIn('/device:GPU:0', devices)
        events = frozenset(event.name for event in profile_pb.trace_events)
        self.assertIn('three_times_five', events)
        self.assertIn('Mul:Mul', events)
Пример #7
0
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            auto_switch=True,
            retry_fit=True,
            absorb=True,
            train_after_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False,
            revert_after_fit=False):
        """
        Custom fit function for the context model.

        Each epoch zips the next window from a `WindowedDataHandler` into a
        dataset and runs one pass over it (a "fit call"). After every fit
        call, `update_and_switch` is consulted. A falsy result is treated as
        a context switch: the trainable weights are restored to their
        start-of-epoch values, metrics are reset, and, if `retry_fit` is
        enabled, the fit call is repeated.

        auto_switch:        Enable/disable autonomous context switching
                            (forwarded to `update_and_switch`).
        retry_fit:          Locate the next fitting context by re-performing fit.
                            When False, at most one switch is attempted per epoch.
        absorb:             Reset the switch sequence counter upon successful training.
                            This is mainly used to maintain switch sequencing for temporally-extended tasks
        train_after_switch: When False, once a switch has occurred during an epoch,
                            weights are restored to their start-of-epoch values after
                            every subsequent fit call of that epoch, including the
                            final one, so no training from that epoch is kept.
        revert_after_fit:   This is a debug parameter to revert weights after performing a fit. This is used
                            to calculate the context deltas without incorrectly learning while auto switching
                            is disabled

        Remaining arguments mirror those of `tf.keras.Model.fit`.

        Returns:
            `self.history`.

        Raises:
            ValueError: If no training step ran, e.g. `x` produced no batches.
        """

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight),
                    validation_split=validation_split,
                    shuffle=False))

        with self.distribute_strategy.scope(
        ), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            self.initialize_fit()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            # `logs` is only bound once a training step completes. Initialise
            # it so an empty dataset gives a clear error below instead of an
            # UnboundLocalError at `copy.copy(logs)`.
            logs = None
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched_during_epoch = False  # Indicate if the model has attempted at least one switch during this epoch
                switched = True  # Indicate if the model switched on the most recent fit iteration
                # Snapshot of the weights so a switch/revert can undo this epoch's training.
                weights = backend.batch_get_value(self.trainable_variables)
                # Perform a 'fit call'. Assuming retry_fit, this call is re-attempted after each switch until a context fits
                while switched and (retry_fit or not switched_during_epoch):
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)

                    # Perform a fit call
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe('TraceContext',
                                                 graph_type='train',
                                                 epoch_num=epoch,
                                                 step_num=step,
                                                 batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)

                        # A falsy result from update_and_switch means a switch occurred.
                        switched = not self.update_and_switch(
                            epoch, auto_switch, absorb, retry_fit, verbose)
                        switched_during_epoch |= switched

                        # If a switch occurred, we need to restore the weights
                        if switched or (switched_during_epoch
                                        and not train_after_switch
                                        ) or revert_after_fit:
                            backend.batch_set_value(
                                zip(self.trainable_variables, weights))
                            self.reset_metrics()

                if logs is None:
                    raise ValueError(
                        'Expect x to be a non-empty array or dataset.')
                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(
                        epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {
                        'val_' + name: val
                        for name, val in val_logs.items()
                    }
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            dynamic_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
        """Custom fit loop for the context model.

        Each epoch zips the next window from a `WindowedDataHandler` into a
        dataset and trains one pass over it. After each pass,
        `update_and_switch(epoch, dynamic_switch, verbose)` is consulted. A
        falsy result counts as a switch: the trainable weights are restored
        to their start-of-epoch values, metrics are reset, and the pass is
        repeated until no switch occurs.

        When `self.accumulate_gradients` is set, `self.accumulated_gradients`
        are applied to the trainable variables through `self.optimizer` once
        per epoch, after the passes finish and before validation. Presumably
        the train function fills `accumulated_gradients` -- confirm.

        dynamic_switch: Forwarded to `update_and_switch`.

        Remaining arguments mirror those of `tf.keras.Model.fit`.

        Returns:
            `self.history`.

        Raises:
            ValueError: If no training step ran, e.g. `x` produced no batches.
        """

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight),
                    validation_split=validation_split,
                    shuffle=False))

        with self.distribute_strategy.scope(
        ), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            # `logs` is only bound once a training step completes. Initialise
            # it so an empty dataset gives a clear error below instead of an
            # UnboundLocalError at `copy.copy(logs)`.
            logs = None
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched = True
                # Snapshot of the weights so a switch can undo this epoch's training.
                weights = backend.batch_get_value(self.trainable_variables)
                while switched:
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe('TraceContext',
                                                 graph_type='train',
                                                 epoch_num=epoch,
                                                 step_num=step,
                                                 batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)

                        # A falsy result from update_and_switch means a switch occurred.
                        switched = not self.update_and_switch(
                            epoch, dynamic_switch, verbose)
                        # If a switch occurred, we need to restore the weights
                        if switched:
                            backend.batch_set_value(
                                zip(self.trainable_variables, weights))
                            self.reset_metrics()

                if logs is None:
                    raise ValueError(
                        'Expect x to be a non-empty array or dataset.')
                epoch_logs = copy.copy(logs)

                if self.accumulate_gradients:
                    self.optimizer.apply_gradients(
                        zip(self.accumulated_gradients,
                            self.trainable_variables))

                # Run validation.
                if validation_data and self._should_eval(
                        epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {
                        'val_' + name: val
                        for name, val in val_logs.items()
                    }
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history