Example #1
    def __call__(self, resources, test_env, proc_func, args, kwargs,
                 use_dill_for_args):
        """The wrapper function that actually gets run in child process(es)."""

        global _barrier

        self._resources = resources
        _barrier = self._resources.barrier
        proc_func = dill.loads(proc_func)
        if use_dill_for_args:
            args = dill.loads(args)
            kwargs = dill.loads(kwargs)

        if faulthandler is not None:
            faulthandler.enable()
            faulthandler.register(signal.SIGTERM, chain=True)

        # All logging should go to stderr to be streamed to the main process.
        logging.set_stderrthreshold(logging.DEBUG)

        # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
        # print() and logging.*() write directly to `streaming_pipe_w`.
        # Unfortunately, since we cannot prepend task_type and task_id
        # information to the streamed logs, we need one thread per subprocess
        # to tell which subprocess each piece of output came from.
        os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
        os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())

        pid = os.getpid()
        logging.info('Subprocess with PID %d (%s, %d) is now being started.',
                     pid, test_env.task_type, test_env.task_id)

        # The thread will be dedicated to checking messages from the parent process.
        threading.Thread(  # pylint: disable=unexpected-keyword-arg
            target=self._message_checking_func,
            args=(test_env.task_type, test_env.task_id),
            daemon=True).start()

        if test_env.v2_enabled:
            v2_compat.enable_v2_behavior()

        with self._runtime_mode(test_env.executing_eagerly):
            info = _run_contained(test_env.task_type, test_env.task_id,
                                  proc_func, args, kwargs)
            self._resources.process_status_queue.put(info)

            # Re-raise the exception in addition to reporting it to the parent
            # process, so that even if `--test_timeout` flag is set and the
            # error doesn't make it to be shown in parent process before bazel's
            # timeout, the log would still show what happens in this subprocess,
            # instead of silently suppressing the error due to early bazel
            # timeout. Raising an error in the subprocess produces a stack
            # trace in the log, while the program continues running.
            if not info.is_successful:
                six.reraise(*info.exc_info)

            self._close_streaming()

        # Exit with code 0 as it's considered successful exit at this point.
        sys.exit(0)
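
A minimal sketch of the parent-side serialization this wrapper expects. Only the dill round-trip is shown; the multi-process plumbing (resources, pipes, barrier) is assumed to exist elsewhere, and the function name is hypothetical.

import dill

def work(x, y=1):
    # Hypothetical stand-in for the user's proc_func.
    return x + y

# Parent side: serialize the callable and (optionally) its arguments.
proc_func_bytes = dill.dumps(work)
args_bytes = dill.dumps((2,))
kwargs_bytes = dill.dumps({'y': 3})

# Child side, mirroring the wrapper above.
fn = dill.loads(proc_func_bytes)
assert fn(*dill.loads(args_bytes), **dill.loads(kwargs_bytes)) == 5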
Example #2
def _run(flags):
    """Runs the main uploader program given parsed flags.

    Args:
      flags: An `argparse.Namespace`.
    """

    logging.set_stderrthreshold(logging.WARNING)
    intent = _get_intent(flags)

    store = auth.CredentialsStore()
    if isinstance(intent, _AuthRevokeIntent):
        store.clear()
        sys.stderr.write("Logged out of uploader.\n")
        sys.stderr.flush()
        return
    # TODO(b/141723268): maybe reconfirm Google Account prior to reuse.
    credentials = store.read_credentials()
    if not credentials:
        _prompt_for_user_ack(intent)
        client_config = json.loads(auth.OAUTH_CLIENT_CONFIG)
        flow = auth.build_installed_app_flow(client_config)
        credentials = flow.run(force_console=flags.auth_force_console)
        sys.stderr.write("\n")  # Extra newline after auth flow messages.
        store.write_credentials(credentials)

    channel_options = None
    if flags.grpc_creds_type == "local":
        channel_creds = grpc.local_channel_credentials()
    elif flags.grpc_creds_type == "ssl":
        channel_creds = grpc.ssl_channel_credentials()
    elif flags.grpc_creds_type == "ssl_dev":
        channel_creds = grpc.ssl_channel_credentials(dev_creds.DEV_SSL_CERT)
        channel_options = [("grpc.ssl_target_name_override", "localhost")]
    else:
        msg = "Invalid --grpc_creds_type %s" % flags.grpc_creds_type
        raise base_plugin.FlagsError(msg)

    try:
        server_info = _get_server_info(flags)
    except server_info_lib.CommunicationError as e:
        _die(str(e))
    _handle_server_info(server_info)

    if not server_info.api_server.endpoint:
        logging.error("Server info response: %s", server_info)
        _die("Internal error: frontend did not specify an API server")
    composite_channel_creds = grpc.composite_channel_credentials(
        channel_creds, auth.id_token_call_credentials(credentials)
    )

    # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until
    # logdir exists to open channel.
    channel = grpc.secure_channel(
        server_info.api_server.endpoint,
        composite_channel_creds,
        options=channel_options,
    )
    with channel:
        intent.execute(server_info, channel)
Example #3
def main(_):
    opt = FLAGS
    # logging
    logging.set_verbosity(logging.INFO)
    logging.set_stderrthreshold(logging.INFO)
    if FLAGS.log_dir:
        if not os.path.exists(FLAGS.log_dir):
            os.makedirs(FLAGS.log_dir)
        logging.get_absl_handler().use_absl_log_file(FLAGS.dataset,
                                                     log_dir=FLAGS.log_dir)
    # dataset
    if opt.dataset == 'mnist':
        data_train, data_test = tf.keras.datasets.mnist.load_data()
    elif opt.dataset == 'cifar10':
        data_train, data_test = tf.keras.datasets.cifar10.load_data()
    else:
        raise NotImplementedError(
            'Unsupported dataset: {}'.format(opt.dataset))
    x_train, y_train = data_train
    x_test, y_test = data_test
    x_train = x_train.astype(np.float32)
    x_test = x_test.astype(np.float32)
    y_train = y_train.reshape([
        -1,
    ])
    y_test = y_test.reshape([
        -1,
    ])
    # resize to (32, 32)
    if opt.dataset == 'mnist':
        x_train = batch_resize(x_train, (32, 32))[..., None]
        x_test = batch_resize(x_test, (32, 32))[..., None]
    # normalization
    mean = x_train.mean()
    stddev = x_train.std()
    x_train = (x_train - mean) / stddev
    x_test = (x_test - mean) / stddev
    logging.info('{}, {}'.format(x_train.shape, x_test.shape))
    # define abnormal and normal data;
    # the training data contains only normal samples
    x_train = x_train[y_train != opt.anomaly, ...]
    y_train = y_train[y_train != opt.anomaly, ...]
    y_test = (y_test == opt.anomaly).astype(np.float32)
    # tf.data.Dataset
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    train_dataset = train_dataset.shuffle(opt.shuffle_buffer_size).batch(
        opt.batch_size, drop_remainder=True)
    test_dataset = test_dataset.batch(opt.batch_size, drop_remainder=False)

    # training
    ganomaly = GANomaly(opt,
                        train_dataset,
                        valid_dataset=None,
                        test_dataset=test_dataset)
    ganomaly.fit(opt.niter)

    # evaluating
    ganomaly.evaluate_best(test_dataset)
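
batch_resize is not defined in this snippet; below is a plausible sketch inferred from the call sites, which pass (N, 28, 28) arrays and expect (N, 32, 32) back. The tf.image.resize-based implementation is an assumption, not the original helper.

import numpy as np
import tensorflow as tf

def batch_resize(images, size):
    # Hypothetical helper: resize a batch of 2-D grayscale images.
    # tf.image.resize expects a channel axis, so add one and strip it after.
    resized = tf.image.resize(images[..., None], size)
    return np.squeeze(resized.numpy(), axis=-1)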
Example #4
def main(argv):
  logging.set_stderrthreshold('info')
  with errors.clean_commandline_error_exit():
    if len(argv) > 1:
      errors.log_and_raise(
          'Command line parsing failure: show_examples.py does not accept '
          'positional arguments but some are present on the command line: '
          '"{}".'.format(str(argv[1:])), errors.CommandLineError)
    run()
Example #5
def do_test(create_module_fn, exported_names=None, show_debug_info=False):
    """Runs test.

  1. Performs absl and tf "main"-like initialization that must run before almost
     anything else.
  2. Converts `tf.Module` to SavedModel
  3. Converts SavedModel to MLIR
  4. Prints the textual MLIR to stdout (it is expected that the caller will have
     FileCheck checks in its file to check this output).

  This is only for use by the MLIR SavedModel importer tests.

  Args:
    create_module_fn: A callable taking no arguments, which returns the
      `tf.Module` to be converted and printed.
    exported_names: A set of exported names for the MLIR converter (default is
      "export all").
    show_debug_info: If true, shows debug locations in the resulting MLIR.
  """
    if exported_names is None:
        exported_names = []

    # Make LOG(ERROR) in C++ code show up on the console.
    # All `Status` passed around in the C++ API seem to eventually go into
    # `LOG(ERROR)`, so this makes them print out by default.
    logging.set_stderrthreshold('error')

    # In true TF2 releases, v2 behavior is enabled as part of module __init__. In
    # TF1 builds, it must be enabled manually. If you get an error here,
    # it means that TF was used in V1 mode prior to calling this.
    tf.enable_v2_behavior()

    def app_main(argv):
        """Function passed to absl.app.run."""
        if len(argv) > 1:
            raise app.UsageError('Too many command-line arguments.')
        if FLAGS.save_model_path:
            save_model_path = FLAGS.save_model_path
        else:
            save_model_path = tempfile.mktemp(suffix='.saved_model')
        save_options = tf.saved_model.SaveOptions(
            save_debug_info=show_debug_info)
        tf.saved_model.save(create_module_fn(),
                            save_model_path,
                            options=save_options)
        logging.info('Saved model to: %s', save_model_path)
        mlir = pywrap_mlir.experimental_convert_saved_model_to_mlir(
            save_model_path, ','.join(exported_names), show_debug_info)
        # We don't strictly need this, but it serves as a handy sanity check
        # for that API, which is otherwise a bit annoying to test.
        # The canonicalization shouldn't affect these tests in any way.
        mlir = pywrap_mlir.experimental_run_pass_pipeline(
            mlir, 'canonicalize', show_debug_info)
        print(mlir)

    app.run(app_main)
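
A hedged usage sketch for this harness with a trivial module; the module body and exported name are hypothetical, and real tests pair the printed MLIR with FileCheck directives.

import tensorflow.compat.v2 as tf

class TestModule(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
    def plus_one(self, x):
        return x + 1.0

if __name__ == '__main__':
    do_test(TestModule, exported_names=['plus_one'])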
Example #6
def do_test(signature_def_map, show_debug_info=False):
    """Runs test.

  1. Performs absl and tf "main"-like initialization that must run before almost
     anything else.
  2. Converts signature_def_map to SavedModel V1
  3. Converts SavedModel V1 to MLIR
  4. Prints the textual MLIR to stdout (it is expected that the caller will have
     FileCheck checks in its file to check this output).

  This is only for use by the MLIR SavedModel importer tests.

  Args:
    signature_def_map: A map from string key to signature_def. The key will be
      used as function name in the resulting MLIR.
    show_debug_info: If true, shows debug locations in the resulting MLIR.
  """

    # Make LOG(ERROR) in C++ code show up on the console.
    # All `Status` passed around in the C++ API seem to eventually go into
    # `LOG(ERROR)`, so this makes them print out by default.
    logging.set_stderrthreshold('error')

    def app_main(argv):
        """Function passed to absl.app.run."""
        if len(argv) > 1:
            raise app.UsageError('Too many command-line arguments.')
        if FLAGS.save_model_path:
            save_model_path = FLAGS.save_model_path
        else:
            save_model_path = tempfile.mktemp(suffix='.saved_model')

        sess = tf.Session()
        sess.run(tf.initializers.global_variables())
        builder = tf.saved_model.builder.SavedModelBuilder(save_model_path)
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map,
            main_op=tf.tables_initializer(),
            strip_default_attrs=True)
        builder.save()

        logging.info('Saved model to: %s', save_model_path)
        mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
            save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
            show_debug_info)
        # We don't strictly need this, but it serves as a handy sanity check
        # for that API, which is otherwise a bit annoying to test.
        # The canonicalization shouldn't affect these tests in any way.
        mlir = pywrap_mlir.experimental_run_pass_pipeline(
            mlir, 'tf-standard-pipeline', show_debug_info)
        print(mlir)

    app.run(app_main)
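
A hedged sketch of building the signature_def_map this V1 harness consumes, assuming TF1 (compat.v1) semantics as the harness itself does; the tensor names and the 'serve' key are hypothetical.

import tensorflow.compat.v1 as tf

x = tf.placeholder(tf.float32, shape=(), name='x')
y = tf.identity(x + 1.0, name='y')
signature = tf.saved_model.signature_def_utils.predict_signature_def(
    inputs={'x': x}, outputs={'y': y})

if __name__ == '__main__':
    do_test({'serve': signature})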
Example #7
def _run(flags, experiment_url_callback=None):
    """Runs the main uploader program given parsed flags.

    Args:
      flags: An `argparse.Namespace`.
      experiment_url_callback: A function accepting a single string argument
        containing the full TB.dev URL of the uploaded experiment.
    """

    logging.set_stderrthreshold(logging.WARNING)
    intent = _get_intent(flags, experiment_url_callback)

    store = auth.CredentialsStore()
    if isinstance(intent, _AuthRevokeIntent):
        store.clear()
        sys.stderr.write("Logged out of uploader.\n")
        sys.stderr.flush()
        return
    # TODO(b/141723268): maybe reconfirm Google Account prior to reuse.
    credentials = store.read_credentials()
    if not credentials:
        _prompt_for_user_ack(intent)
        client_config = json.loads(auth.OAUTH_CLIENT_CONFIG)
        flow = auth.build_installed_app_flow(client_config)
        credentials = flow.run(force_console=flags.auth_force_console)
        sys.stderr.write("\n")  # Extra newline after auth flow messages.
        store.write_credentials(credentials)

    (channel_creds, channel_options) = flags.grpc_creds_type.channel_config()

    try:
        server_info = _get_server_info(flags)
    except server_info_lib.CommunicationError as e:
        _die(str(e))
    _handle_server_info(server_info)
    logging.info("Received server info: <%r>", server_info)

    if not server_info.api_server.endpoint:
        logging.error("Server info response: %s", server_info)
        _die("Internal error: frontend did not specify an API server")
    composite_channel_creds = grpc.composite_channel_credentials(
        channel_creds, auth.id_token_call_credentials(credentials))

    # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until
    # logdir exists to open channel.
    channel = grpc.secure_channel(
        server_info.api_server.endpoint,
        composite_channel_creds,
        options=channel_options,
    )
    with channel:
        intent.execute(server_info, channel)
Example #8
def set_logging_level(log_level):
    if log_level.lower() not in _ABSL_LOGGING_LEVEL:
        raise ValueError("Logging initialize error. Can not recognize value of"
                         " <log_level> which by given '{}' , except 'debug',"
                         " 'info', 'warn', 'error'.".format(log_level))
    absl_handler = absl_logging.get_absl_handler()

    if absl_handler in logging.root.handlers:
        logging.root.removeHandler(absl_handler)

    absl_logging.set_verbosity(_ABSL_LOGGING_LEVEL[log_level])
    absl_logging.set_stderrthreshold(_ABSL_LOGGING_LEVEL[log_level])

    absl_logging._warn_preinit_stderr = False
    logging.root.addHandler(absl_handler)
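
The _ABSL_LOGGING_LEVEL map is not shown in this snippet; the sketch below is one plausible definition (the exact keys and values are an assumption), followed by a call.

import logging
from absl import logging as absl_logging

# Hypothetical mapping; the original module defines its own version.
_ABSL_LOGGING_LEVEL = {
    'debug': absl_logging.DEBUG,
    'info': absl_logging.INFO,
    'warn': absl_logging.WARNING,
    'error': absl_logging.ERROR,
}

set_logging_level('info')  # absl handler now emits INFO and above to stderr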
Example #9
def _run(flags):
    """Runs the main uploader program given parsed flags.

  Args:
    flags: An `argparse.Namespace`.
  """

    logging.set_stderrthreshold(logging.WARNING)
    intent = _get_intent(flags)

    store = auth.CredentialsStore()
    if isinstance(intent, _AuthRevokeIntent):
        store.clear()
        sys.stderr.write('Logged out of uploader.\n')
        sys.stderr.flush()
        return
    # TODO(b/141723268): maybe reconfirm Google Account prior to reuse.
    credentials = store.read_credentials()
    if not credentials:
        _prompt_for_user_ack(intent)
        client_config = json.loads(auth.OAUTH_CLIENT_CONFIG)
        flow = auth.build_installed_app_flow(client_config)
        credentials = flow.run(force_console=flags.auth_force_console)
        sys.stderr.write('\n')  # Extra newline after auth flow messages.
        store.write_credentials(credentials)

    channel_options = None
    if flags.grpc_creds_type == 'local':
        channel_creds = grpc.local_channel_credentials()
    elif flags.grpc_creds_type == 'ssl':
        channel_creds = grpc.ssl_channel_credentials()
    elif flags.grpc_creds_type == 'ssl_dev':
        channel_creds = grpc.ssl_channel_credentials(dev_creds.DEV_SSL_CERT)
        channel_options = [('grpc.ssl_target_name_override', 'localhost')]
    else:
        msg = 'Invalid --grpc_creds_type %s' % flags.grpc_creds_type
        raise base_plugin.FlagsError(msg)

    composite_channel_creds = grpc.composite_channel_credentials(
        channel_creds, auth.id_token_call_credentials(credentials))

    # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until
    # logdir exists to open channel.
    channel = grpc.secure_channel(flags.endpoint,
                                  composite_channel_creds,
                                  options=channel_options)
    with channel:
        intent.execute(channel)
Example #10
def initialize(output_dir, seed):
    """Initilialize output directory and logging levels."""
    # Initialize output directory.
    os.makedirs(output_dir, exist_ok=True)
    log_dir = os.path.join(output_dir, 'logs')
    os.makedirs(log_dir, exist_ok=True)

    # Set logging levels.
    logging.get_absl_handler().use_absl_log_file('log_file', log_dir)
    logging.set_stderrthreshold(logging.INFO)
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    # Fix seeds.
    random.seed(seed)
    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)
    logging.info(f'Random seed is {seed}.')
Example #11
def do_test(create_module_fn, exported_names=None):
    """Runs test.

  1. Performs absl and tf "main"-like initialization that must run before almost
     anything else.
  2. Converts `tf.Module` to SavedModel
  3. Converts SavedModel to MLIR
  4. Prints the textual MLIR to stdout (it is expected that the caller will have
     FileCheck checks in its file to check this output).

  This is only for use by the MLIR SavedModel importer tests.

  Args:
    create_module_fn: A callable taking no arguments, which returns the
      `tf.Module` to be converted and printed.
    exported_names: A set of exported names for the MLIR converter (default is
      "export all").
  """
    if exported_names is None:
        exported_names = []

    # Make LOG(ERROR) in C++ code show up on the console.
    # All `Status` passed around in the C++ API seem to eventually go into
    # `LOG(ERROR)`, so this makes them print out by default.
    logging.set_stderrthreshold('error')

    # In true TF2 releases, v2 behavior is enabled as part of module __init__. In
    # TF1 builds, it must be enabled manually. If you get an error here,
    # it means that TF was used in V1 mode prior to calling this.
    tf.enable_v2_behavior()

    def app_main(argv):
        """Function passed to absl.app.run."""
        if len(argv) > 1:
            raise app.UsageError('Too many command-line arguments.')
        save_model_path = FLAGS.save_model_path
        tf.saved_model.save(create_module_fn(), save_model_path)
        logging.info('Saved model to: %s', save_model_path)
        mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
            save_model_path, ','.join(exported_names))
        print(mlir)

    app.run(app_main)
Example #12
def StartTeeLogsToFile(program_name: str = None,
                       log_dir: str = None,
                       file_log_level: int = logging.DEBUG) -> None:
    """Log messages to file as well as stderr.

  Args:
    program_name: The name of the program.
    log_dir: The directory to log to.
    file_log_level: The minimum verbosity level to log to file.

  Raises:
    FileNotFoundError: If the requested log_dir does not exist.
  """
    if not pathlib.Path(log_dir).is_dir():
        raise FileNotFoundError(f"Log directory not found: '{log_dir}'")
    old_verbosity = logging.get_verbosity()
    logging.set_verbosity(file_log_level)
    logging.set_stderrthreshold(old_verbosity)
    logging.get_absl_handler().start_logging_to_file(program_name, log_dir)
    # The Absl logging handler function start_logging_to_file() sets logtostderr
    # to False. Re-enable whatever value it was before the call.
    FLAGS.logtostderr = False
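
A hedged usage sketch: the directory must exist beforehand (per the FileNotFoundError contract), absl flags are assumed to be initialized, and the path and program name are hypothetical.

import pathlib
from absl import logging

log_dir = '/tmp/my_logs'  # hypothetical location
pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
StartTeeLogsToFile('my_program', log_dir, file_log_level=logging.DEBUG)
logging.info('This line goes to stderr and to the log file.')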
Example #13
def do_test(signature_def_map,
            init_op=None,
            canonicalize=False,
            show_debug_info=False):
    """Runs test.

  1. Performs absl and tf "main"-like initialization that must run before almost
     anything else.
  2. Converts signature_def_map to SavedModel V1
  3. Converts SavedModel V1 to MLIR
  4. Prints the textual MLIR to stdout (it is expected that the caller will have
     FileCheck checks in its file to check this output).

  This is only for use by the MLIR SavedModel importer tests.

  Args:
    signature_def_map: A map from string key to signature_def. The key will be
      used as function name in the resulting MLIR.
    init_op: The initializer op for the saved model. If set, it will generate
      an initializer graph in the resulting MLIR.
    canonicalize: If true, canonicalizer will be run on the resulting MLIR.
    show_debug_info: If true, shows debug locations in the resulting MLIR.
  """

    # Make LOG(ERROR) in C++ code show up on the console.
    # All `Status` passed around in the C++ API seem to eventually go into
    # `LOG(ERROR)`, so this makes them print out by default.
    logging.set_stderrthreshold('error')

    def app_main(argv):
        """Function passed to absl.app.run."""
        if len(argv) > 1:
            raise app.UsageError('Too many command-line arguments.')
        if FLAGS.save_model_path:
            save_model_path = FLAGS.save_model_path
        else:
            save_model_path = tempfile.mktemp(suffix='.saved_model')

        sess = tf.Session()
        sess.run(tf.initializers.global_variables())
        builder = tf.saved_model.builder.SavedModelBuilder(save_model_path)
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map,
            main_op=init_op,
            strip_default_attrs=True)
        builder.save()

        logging.info('Saved model to: %s', save_model_path)
        # TODO(b/153507667): Set the following boolean flag once the hoisting
        #                    variables logic from SavedModel importer is removed.
        lift_variables = False
        upgrade_legacy = True
        mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
            save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
            lift_variables, upgrade_legacy, show_debug_info)

        if canonicalize:
            mlir = pywrap_mlir.experimental_run_pass_pipeline(
                mlir, 'canonicalize', show_debug_info)
        print(mlir)

    app.run(app_main)
Example #14
flags.DEFINE_integer(
    'shuffle_buffer_size', 2097152,
    'Input shuffle buffer size. A large number benefits shuffling quality.')

flags.DEFINE_float('learning_rate', 2e-2, 'Initial learning rate.')

flags.DEFINE_integer('batch_size', 256, 'Batch size for training.')

flags.DEFINE_integer('num_iterations', 5000000,
                     'Number of training iterations.')

flags.DEFINE_boolean('compile', True,
                     'Compiles functions for faster tf training.')

logging.set_verbosity('info')
logging.set_stderrthreshold('info')


def _validate(common_module):
  """Validates training configurations."""
  # Validate flags.
  validate_flag = common_module.validate
  validate_flag(FLAGS.model_input_keypoint_type,
                common_module.SUPPORTED_TRAINING_MODEL_INPUT_KEYPOINT_TYPES)


def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator, keypoint_preprocessor_3d):
  """Runs training pipeline.

  Args:
Example #15
                lambda x: utils.reshape_latents_conv_to_flat(x,
                                                             axis_n_to_keep=2),
                model_data)

        # Create mask to get rid of uninformative latents
        latent_mask = eval_metric.create_latent_mask(z0)
        informative_dim_n = np.sum(latent_mask)

        model_data = model_data[:, :, latent_mask]
        logging.info("Masking out model data, leaving dim_n=%d dimensions.",
                     model_data.shape[-1])

        gt_trajectory = np.reshape(
            gt_data, [np.prod(gt_data.shape[:-1]), gt_data.shape[-1]])

        model_trajectory = np.reshape(
            model_data,
            [np.prod(model_data.shape[:-1]), model_data.shape[-1]])

        # Standardize data
        gt_trajectory = eval_metric.standardize_data(gt_trajectory)
        model_trajectory = eval_metric.standardize_data(model_trajectory)

        return gt_trajectory, model_trajectory, informative_dim_n


if __name__ == "__main__":
    flags.mark_flag_as_required("config")
    logging.set_stderrthreshold(logging.INFO)
    app.run(functools.partial(platform.main, HGNExperiment))
Example #16
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.ERROR)
    logging.set_verbosity(logging.ERROR)
    logging.set_stderrthreshold(logging.ERROR)
    logging.get_absl_handler().use_absl_log_file()

    # Load training and test data.
    train_data, train_labels, test_data, test_labels = load_mnist()

    # Instantiate the tf.Estimator.
    mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
                                              model_dir=FLAGS.model_dir)

    # A function to construct input_fn given (data, label), to be used by the
    # membership inference training hook.
    def input_fn_constructor(x, y):
        return tf.estimator.inputs.numpy_input_fn(x={'x': x},
                                                  y=y,
                                                  shuffle=False)

    with tf.Graph().as_default():
        # Get a summary writer for the hook to write to tensorboard.
        # Can set summary_writer to None if not needed.
        if FLAGS.model_dir:
            summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
        else:
            summary_writer = None
        mia_hook = MembershipInferenceTrainingHook(
            mnist_classifier, (train_data, train_labels),
            (test_data, test_labels),
            input_fn_constructor,
            attack_types=[AttackType.THRESHOLD_ATTACK],
            writer=summary_writer)

    # Create tf.Estimator input functions for the training and test data.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': train_data},
        y=train_labels,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.epochs,
        shuffle=True)
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': test_data},
                                                       y=test_labels,
                                                       num_epochs=1,
                                                       shuffle=False)

    # Training loop.
    steps_per_epoch = 60000 // FLAGS.batch_size
    for epoch in range(1, FLAGS.epochs + 1):
        # Train the model, with the membership inference hook.
        mnist_classifier.train(input_fn=train_input_fn,
                               steps=steps_per_epoch,
                               hooks=[mia_hook])

        # Evaluate the model and print results
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        test_accuracy = eval_results['accuracy']
        print('Test accuracy after %d epochs is: %.3f' %
              (epoch, test_accuracy))

    print('End of training attack')
    attack_results = run_attack_on_tf_estimator_model(
        mnist_classifier, (train_data, train_labels), (test_data, test_labels),
        input_fn_constructor,
        slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
        attack_types=[
            AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
        ])
    attack_properties, attack_values = get_all_attack_results(attack_results)
    print('\n'.join([
        '  %s: %.4f' % (', '.join(p), r)
        for p, r in zip(attack_properties, attack_values)
    ]))
Example #17
def main(argv):

    # set level of verbosity
    if FLAGS.verbosity_level == 'DEBUG':
        logging.set_verbosity(logging.DEBUG)
        print('logging.DEBUG')
    elif FLAGS.verbosity_level == 'INFO':
        logging.set_verbosity(logging.INFO)
        print('logging.INFO')
    elif FLAGS.verbosity_level == 'WARNING':
        logging.set_verbosity(logging.WARNING)
        print('logging.WARNING')
    elif FLAGS.verbosity_level == 'ERROR':
        logging.set_verbosity(logging.ERROR)
        print('logging.ERROR')
    elif FLAGS.verbosity_level == 'FATAL':
        logging.set_verbosity(logging.FATAL)
        print('logging.FATAL')
    else:
        logging.set_verbosity(logging.INFO)
        print('logging.DEFAULT -> INFO')

    # logging.get_absl_handler().python_handler.stream = sys.stdout

    # test cloud storage python lib
    from google.cloud import storage
    storage_client = storage.Client()
    buckets = storage_client.list_buckets()
    print('GCP Buckets:', buckets)
    for bucket in buckets:
        print('items:', bucket.name)

    # Instantiates a client
    client = google.cloud.logging.Client()

    # Connects the logger to the root logging handler; by default this captures
    # all logs at INFO level and higher
    client.setup_logging()

    fmt = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
    formatter = logger.Formatter(fmt)
    logging.get_absl_handler().setFormatter(formatter)

    # set level of verbosity
    # logging.set_verbosity(logging.DEBUG)

    # logging.set_stderrthreshold(logging.WARNING)
    # logging._warn_preinit_stderr = True
    # loggers = [logger.getLogger()]  # get the root logger

    # for handler in loggers:
    #    print("handler ", handler)
    #    print("       handler.level-->  ", handler.level)
    #    print("       handler.name-->  ", handler.name)
    #    print("       handler.propagate-->  ", handler.propagate)
    #    print("       handler.parent-->  ", handler.parent )
    #    print(dir(handler))
    # level_log = 'INFO'
    # root_logger = logger.getLogger()
    # root_logger.handlers=[handler for handler in root_logger.handlers if isinstance(handler, (CloudLoggingHandler, ContainerEngineHandler, logging.ABSLHandler))]
    #
    # for handler in root_logger.handlers:
    #    print("----- handler ", handler)
    #    print("---------class ", handler.__class__)
    #    if handler.__class__ == logging.ABSLHandler:
    #        handler.python_handler.stream = sys.stdout
    #        #handler.handler.setStream(sys.stdout)
    tf.get_logger().propagate = False
    root_logger = logger.getLogger()
    print(' root_logger :', root_logger)
    print(' root_logger.handlers :', root_logger.handlers)
    print(' len(root_logger) :', len(root_logger.handlers))
    for h in root_logger.handlers:
        print('handlers:', h)
        print("---------class ", h.__class__)
        if h.__class__ == logging.ABSLHandler:
            print('++logging.ABSLHandler')
            h.python_handler.stream = sys.stdout
            h.setLevel(logger.INFO)
        if h.__class__ == google.cloud.logging.handlers.handlers.CloudLoggingHandler:
            print('++CloudLoggingHandler')
            h.setLevel(logger.CRITICAL)
            h.setStream(sys.stdout)
            logger.getLogger().addHandler(h)
        if h.__class__ == logger.StreamHandler:
            print('++logging.StreamHandler')
            h.setLevel(logger.CRITICAL)
            h.setStream(sys.stdout)
            logger.getLogger().addHandler(h)

    logging.set_stderrthreshold(logging.WARNING)
    # handler = client.get_default_handler()
    # print('hhh', handler)
    # logger.getLogger().setLevel(logger.INFO)
    # logger.getLogger().addHandler(handler)

    # handler = logger.StreamHandler(sys.stderr)
    # handler.setLevel(logger.CRITICAL)
    # logger.getLogger().addHandler(handler)

    # handler = logger.StreamHandler(sys.stdout)
    # handler.setLevel(logger.CRITICAL)
    # logger.getLogger().addHandler(handler)

    print(' 0 print --- ')
    logging.info(' 1 logging:')
    logging.info(' 2 logging:')

    print(' 3 print --- ')
    logging.debug(' 4 logging-test-debug')
    logging.info(' 5 logging-test-info')
    logging.warning(' 6 logging-test-warning')
    logging.error(' 7 logging test-error')
    print(' 8 print --- ')
    _ = BertTokenizer.from_pretrained('bert-base-uncased')
    print(' 9 print --- ')
    _ = tf.distribute.MirroredStrategy()
    print('10 print --- ')
Example #18
def main(argv):
    import tensorflow as tf  # need to be here to have the env variables defined
    tf.get_logger().propagate = False

    # mask errors related to the API discovery cache
    logger.getLogger('googleapiclient.discovery_cache').setLevel(logger.ERROR)

    # set level of verbosity
    if FLAGS.verbosity_level == 'DEBUG':
        logging.set_verbosity(logging.DEBUG)
        print('logging.DEBUG')
    elif FLAGS.verbosity_level == 'INFO':
        logging.set_verbosity(logging.INFO)
    elif FLAGS.verbosity_level == 'WARNING':
        logging.set_verbosity(logging.WARNING)
    elif FLAGS.verbosity_level == 'ERROR':
        logging.set_verbosity(logging.ERROR)
    elif FLAGS.verbosity_level == 'FATAL':
        logging.set_verbosity(logging.FATAL)
    else:
        logging.set_verbosity(logging.INFO)

    # set level of verbosity for Tensorflow
    if FLAGS.verbosity_level == 'VERBOSE':
        tf.debugging.set_log_device_placement(True)
        tf.autograph.set_verbosity(10, alsologtostdout=False)

    # logger.getLogger('googleapiclient.discovery_cache').setLevel(logging.ERROR)

    # fmt = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
    fmt = "[%(levelname)s] %(message)s"
    formatter = logger.Formatter(fmt)
    logging.get_absl_handler().setFormatter(formatter)
    logging.get_absl_handler().python_handler.stream = sys.stdout
    logging.set_stderrthreshold(logging.WARNING)

    # level_log = 'INFO'

    # # Instantiates a client
    # client = google.cloud.logging.Client()
    #
    # # Connects the logger to the root logging handler; by default this captures
    # # all logs at INFO level and higher
    # client.setup_logging(log_level=FLAGS.verbosity)
    #
    # print('loggerDict:', logger.root.manager.loggerDict.keys())
    #
    # for i in logger.root.manager.loggerDict.keys():
    #     if i=='tensorflow':
    #        #print('-> propagate False')
    #         logger.getLogger(i).propagate = False  # needed
    #     elif i=='google.auth':
    #         logger.getLogger(i).propagate = False  # needed
    #     elif i=='google_auth_httplib2':
    #         logger.getLogger(i).propagate = False  # needed
    #     elif i=='pyasn1':
    #         logger.getLogger(i).propagate = False  # needed
    #     elif i=='sklearn':
    #         logger.getLogger(i).propagate = False  # needed
    #     elif i=='google.cloud':
    #         logger.getLogger(i).propagate = False  # needed
    #     else:
    #         logger.getLogger(i).propagate = True # needed
    #     handler = logger.getLogger(i).handlers
    #     if handler != []:
    #         #print("logger's name=", i,handler)
    #         for h in handler:
    #             #print('    -> ', h)
    #             if h.__class__ == logger.StreamHandler:
    #                 #print('    -> name=', h.__class__)
    #                 h.setStream(sys.stdout)
    #                 h.setLevel(level_log)
    #                 #print('    --> handlers =', h)
    #
    root_logger = logger.getLogger()
    # root_logger.handlers=[handler for handler in root_logger.handlers if isinstance(handler, (CloudLoggingHandler, ContainerEngineHandler, logging.ABSLHandler))]
    #
    for handler in root_logger.handlers:
        print("----- handler ", handler)
        print("---------class ", handler.__class__)

    #     if handler.__class__ == CloudLoggingHandler:
    #         handler.setStream(sys.stdout)
    #         handler.setLevel(level_log)
    #     if handler.__class__ == logging.ABSLHandler:
    #         handler.python_handler.stream = sys.stdout
    #         handler.setLevel(level_log)
    # #        handler.handler.setStream(sys.stdout)
    #
    # for handler in root_logger.handlers:
    #     print("----- handler ", handler)
    #
    # # Instantiates a client
    # #client = google.cloud.logging.Client()
    #
    # # Connects the logger to the root logging handler; by default this captures
    # # all logs at INFO level and higher
    # #client.setup_logging()
    #
    # # redirect abseil logging messages to the stdout stream
    # #logging.get_absl_handler().python_handler.stream = sys.stdout
    #
    # # some test
    # #tf.get_logger().addHandler(logger.StreamHandler(sys.stdout))
    # #tf.get_logger().disabled = True
    # #tf.autograph.set_verbosity(5 ,alsologtostdout=True)
    #
    # ## DEBUG
    # #fmt = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
    # #formatter = logger.Formatter(fmt)
    # #logging.get_absl_handler().setFormatter(formatter)
    #
    # # set level of verbosity
    # #logging.set_verbosity(logging.DEBUG)
    #
    # print(' 0 print --- ')
    # logging.info(' 1 logging:')
    # logging.info(' 2 logging:')
    #
    # print(' 3 print --- ')
    # logging.debug(' 4 logging-test-debug')
    # logging.info(' 5 logging-test-info')
    # logging.warning(' 6 logging-test-warning')
    # logging.error(' 7 logging test-error')
    # print(' 8 print --- ')
    # #_=BertTokenizer.from_pretrained('bert-base-uncased')
    # print(' 9 print --- ')
    # _= tf.distribute.MirroredStrategy()
    # print('10 print --- ')
    # ## DEBUG

    print('logging.get_verbosity()', logging.get_verbosity())

    # print flags
    abseil_flags = [
        'logtostderr', 'alsologtostderr', 'log_dir', 'v', 'verbosity',
        'stderrthreshold', 'showprefixforinfo', 'run_with_pdb',
        'pdb_post_mortem', 'run_with_profiling', 'profile_file',
        'use_cprofile_for_profiling', 'only_check_args', 'flagfile', 'undefok'
    ]
    logging.info('-- Custom flags:')
    for name in list(FLAGS):
        if name not in abseil_flags:
            logging.info('custom flags: {:40} with value: {:50}'.format(
                name, str(FLAGS[name].value)))
    logging.info('\n-- Abseil flags:')
    for name in list(FLAGS):
        if name in abseil_flags:
            logging.info('abseil flags: {:40} with value: {:50}'.format(
                name, str(FLAGS[name].value)))

    if os.environ.get('LOG_FILE_TO_WRITE') is not None:
        logging.info('os.environ[LOG_FILE_TO_WRITE]: {}'.format(
            os.environ['LOG_FILE_TO_WRITE']))
        # split_path = os.environ['LOG_FILE_TO_WRITE'].split('/')
        # logging.get_absl_handler().use_absl_log_file(split_path[-1], '/'.join(split_path[:-1]))

    # fmt = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
    # formatter = logger.Formatter(fmt)
    # logging.get_absl_handler().setFormatter(formatter)

    # set level of verbosity
    # logging.set_verbosity(FLAGS.verbosity)
    # logging.set_stderrthreshold(FLAGS.verbosity)

    logging.info(tf.__version__)
    logging.info(tf.keras.__version__)
    logging.info(list(FLAGS))
    logging.debug('flags: \n {}'.format(FLAGS))
    logging.debug('env variables: \n{}'.format(os.environ))
    logging.debug('current dir: {}'.format(os.getcwd()))
    logging.debug('__package__: {}'.format(__package__))
    logging.debug('__name__: {}'.format(__name__))
    logging.debug('__file__: {}'.format(__file__))

    # only for HP tuning!
    if os.environ.get('CLOUD_ML_HP_METRIC_TAG') is not None:
        logging.info('this is a hyperparameter tuning job!')

        # setup the hp flag
        FLAGS.is_hyperparameter_tuning = True
        logging.info('FLAGS.is_hyperparameter_tuning: {}'.format(
            FLAGS.is_hyperparameter_tuning))

        logging.info('os.environ[CLOUD_ML_HP_METRIC_TAG]: {}'.format(
            os.environ['CLOUD_ML_HP_METRIC_TAG']))
        logging.info('os.environ[CLOUD_ML_HP_METRIC_FILE]: {}'.format(
            os.environ['CLOUD_ML_HP_METRIC_FILE']))
        logging.info('os.environ[CLOUD_ML_TRIAL_ID]: {}'.format(
            os.environ['CLOUD_ML_TRIAL_ID']))

        # variable name for hyper parameter tuning
        metric_accuracy = os.environ['CLOUD_ML_HP_METRIC_TAG']
        logging.info('metric accuracy name: {}'.format(metric_accuracy))
    else:
        metric_accuracy = 'NotDefined'

    if os.environ.get('TF_CONFIG') is not None:
        logging.info('os.environ[TF_CONFIG]: {}'.format(
            os.environ['TF_CONFIG']))
    else:
        logging.error('os.environ[TF_CONFIG] does not exist!')

    if FLAGS.use_tpu:
        # Check or update the TensorFlow version on the TPU cluster to match the VM's
        logging.info(
            'setting up TPU: check that TensorFlow version is the same on the VM and on the TPU cluster'
        )
        client_tpu = Client()

        # define TPU strategy before any ops
        client_tpu.configure_tpu_version(tf.__version__,
                                         restart_type='ifNeeded')
        logging.info('setting up TPU: cluster resolver')
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        )
        logging.info('setting up TPU: \n {}'.format(tpu_cluster_resolver))
        logging.info('running on TPU: \n {}'.format(
            tpu_cluster_resolver.cluster_spec().as_dict()['worker']))
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        strategy = tf.distribute.experimental.TPUStrategy(tpu_cluster_resolver)
    else:
        strategy = tf.distribute.MirroredStrategy()
        print('do nothing')
    logging.info('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    # choose language's model and tokenizer
    MODELS = [(TFBertModel, BertTokenizer, 'bert-base-multilingual-uncased')]
    model_index = 0  # BERT
    # model_class = MODELS[model_index][0]  # i.e TFBertModel
    # tokenizer_class = MODELS[model_index][1]  # i.e BertTokenizer
    pretrained_weights = MODELS[model_index][
        2]  # i.e. 'bert-base-multilingual-uncased'

    # download pretrained model:
    if FLAGS.pretrained_model_dir:
        # download the pretrained model from a bucket
        logging.info('downloading pretrained model!')
        search = re.search('gs://(.*?)/(.*)', FLAGS.pretrained_model_dir)
        if search is not None:
            bucket_name = search.group(1)
            blob_name = search.group(2)
            local_path = '.'
            mu.download_blob(bucket_name, blob_name, local_path)
            pretrained_model_dir = local_path + '/' + blob_name
        else:
            pretrained_model_dir = FLAGS.pretrained_model_dir
    else:
        # download the pretrained model from the internet
        pretrained_model_dir = '.'

    # some check
    logging.info('Batch size:            {:6}/{:6}'.format(
        FLAGS.batch_size_train, FLAGS.batch_size_eval))
    logging.info('Step per epoch:        {:6}/{:6}'.format(
        FLAGS.steps_per_epoch_train, FLAGS.steps_per_epoch_eval))
    logging.info('Total number of batch: {:6}/{:6}'.format(
        FLAGS.steps_per_epoch_train * (FLAGS.epochs + 1),
        FLAGS.steps_per_epoch_eval * 1))

    # with tf.summary.create_file_writer(FLAGS.output_dir,
    #                                   filename_suffix='.oup',
    #                                   name='test').as_default():
    #    tf.summary.scalar('metric_accuracy', 1.0, step=1)
    # print('-- 00001')
    #  read TFRecords files, shuffle, map and batch size
    train_dataset = tf_bert.build_dataset(FLAGS.input_train_tfrecords,
                                          FLAGS.batch_size_train, 2048)
    valid_dataset = tf_bert.build_dataset(FLAGS.input_eval_tfrecords,
                                          FLAGS.batch_size_eval, 2048)

    # set repeat
    train_dataset = train_dataset.repeat(FLAGS.epochs + 1)
    valid_dataset = valid_dataset.repeat(2)

    # reset all variables used by Keras
    tf.keras.backend.clear_session()

    # create and compile the Keras model in the context of strategy.scope
    with strategy.scope():
        logging.debug('pretrained_model_dir={}'.format(pretrained_model_dir))
        model = tf_bert.create_model(pretrained_weights,
                                     pretrained_model_dir=pretrained_model_dir,
                                     num_labels=FLAGS.num_classes,
                                     learning_rate=FLAGS.learning_rate,
                                     epsilon=FLAGS.epsilon)
    # train the model
    tf_bert.train_and_evaluate(model,
                               num_epochs=FLAGS.epochs,
                               steps_per_epoch=FLAGS.steps_per_epoch_train,
                               train_data=train_dataset,
                               validation_steps=FLAGS.steps_per_epoch_eval,
                               eval_data=valid_dataset,
                               output_dir=FLAGS.output_dir,
                               n_steps_history=FLAGS.n_steps_history,
                               FLAGS=FLAGS,
                               decay_type=FLAGS.decay_type,
                               learning_rate=FLAGS.learning_rate,
                               s=FLAGS.decay_learning_rate,
                               n_batch_decay=FLAGS.n_batch_decay,
                               metric_accuracy=metric_accuracy)
Example #19
def StartTeeLogsToFile(program_name: str = None,
                       log_dir: str = None,
                       file_log_level: int = logging.DEBUG) -> None:
  """Log messages to file as well as stderr.

  Args:
    program_name: The name of the program.
    log_dir: The directory to log to.
    file_log_level: The minimum verbosity level to log to file.

  Raises:
    FileNotFoundError: If the requested log_dir does not exist.
  """
  if not pathlib.Path(log_dir).is_dir():
    raise FileNotFoundError(f"Log directory not found: '{log_dir}'")
  old_verbosity = logging.get_verbosity()
  logging.set_verbosity(file_log_level)
  logging.set_stderrthreshold(old_verbosity)
  logging.get_absl_handler().start_logging_to_file(program_name, log_dir)
  # The Absl logging handler function start_logging_to_file() sets logtostderr
  # to False. Re-enable whatever value it was before the call.
  FLAGS.logtostderr = False


def StopTeeLogsToFile():
  """Stop logging messages to file as well as stderr."""
  logging.get_absl_handler().flush()
  logging.get_absl_handler().stream = sys.stderr
  FLAGS.logtostderr = True


@contextlib.contextmanager
def TeeLogsToFile(
Example #20
def do_test(create_signature,
            canonicalize=False,
            show_debug_info=False,
            use_lite=False):
    """Runs test.

  1. Performs absl and tf "main"-like initialization that must run before almost
     anything else.
  2. Converts signature_def_map to SavedModel V1
  3. Converts SavedModel V1 to MLIR
  4. Prints the textual MLIR to stdout (it is expected that the caller will have
     FileCheck checks in its file to check this output).

  This is only for use by the MLIR SavedModel importer tests.

  Args:
    create_signature: A functor that returns signature_def_map, init_op and
      assets_collection. signature_def_map is a map from string key to
      signature_def. The key will be used as function name in the resulting
      MLIR.
    canonicalize: If true, canonicalizer will be run on the resulting MLIR.
    show_debug_info: If true, shows debug locations in the resulting MLIR.
    use_lite: If true, importer will not do any graph transformation such as
      lift variables.
  """

    # Make LOG(ERROR) in C++ code show up on the console.
    # All `Status` passed around in the C++ API seem to eventually go into
    # `LOG(ERROR)`, so this makes them print out by default.
    logging.set_stderrthreshold('error')

    def app_main(argv):
        """Function passed to absl.app.run."""
        if len(argv) > 1:
            raise app.UsageError('Too many command-line arguments.')
        if FLAGS.save_model_path:
            save_model_path = FLAGS.save_model_path
        else:
            save_model_path = tempfile.mktemp(suffix='.saved_model')

        signature_def_map, init_op, assets_collection = create_signature()

        sess = tf.Session()
        sess.run(tf.initializers.global_variables())
        builder = tf.saved_model.builder.SavedModelBuilder(save_model_path)
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map,
            main_op=init_op,
            assets_collection=assets_collection,
            strip_default_attrs=True)
        builder.save()

        logging.info('Saved model to: %s', save_model_path)
        # TODO(b/153507667): Set the following boolean flag once the hoisting
        #                    variables logic from SavedModel importer is removed.
        lift_variables = False
        upgrade_legacy = True
        if use_lite:
            mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
                save_model_path,
                ','.join([tf.saved_model.tag_constants.SERVING]),
                upgrade_legacy, show_debug_info)
            # We don't strictly need this, but it serves as a handy sanity check
            # for that API, which is otherwise a bit annoying to test.
            # The canonicalization shouldn't affect these tests in any way.
            mlir = pywrap_mlir.experimental_run_pass_pipeline(
                mlir, 'tf-standard-pipeline', show_debug_info)
        else:
            mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
                save_model_path,
                ','.join([tf.saved_model.tag_constants.SERVING]),
                lift_variables, upgrade_legacy, show_debug_info)

        if canonicalize:
            mlir = pywrap_mlir.experimental_run_pass_pipeline(
                mlir, 'canonicalize', show_debug_info)
        print(mlir)

    app.run(app_main)
Example #21
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.ERROR)
    logging.set_verbosity(logging.ERROR)
    logging.set_stderrthreshold(logging.ERROR)
    logging.get_absl_handler().use_absl_log_file()

    # Load training and test data.
    x_train, y_train, x_test, y_test = load_cifar10()

    # Instantiate the tf.Estimator.
    classifier = tf.estimator.Estimator(model_fn=small_cnn_fn,
                                        model_dir=FLAGS.model_dir)

    # A function to construct input_fn given (data, label), to be used by the
    # membership inference training hook.
    def input_fn_constructor(x, y):
        return tf.estimator.inputs.numpy_input_fn(x={'x': x},
                                                  y=y,
                                                  shuffle=False)

    # Get hook for membership inference attack.
    mia_hook = MembershipInferenceTrainingHook(
        classifier, (x_train, y_train), (x_test, y_test),
        input_fn_constructor,
        slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
        attack_types=[
            AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
        ],
        tensorboard_dir=FLAGS.model_dir,
        tensorboard_merge_classifiers=FLAGS.tensorboard_merge_classifiers)

    # Create tf.Estimator input functions for the training and test data.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': x_train},
        y=y_train,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.epochs,
        shuffle=True)
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': x_test},
                                                       y=y_test,
                                                       num_epochs=1,
                                                       shuffle=False)

    # Training loop.
    steps_per_epoch = 50000 // FLAGS.batch_size
    for epoch in range(1, FLAGS.epochs + 1):
        # Train the model, with the membership inference hook.
        classifier.train(input_fn=train_input_fn,
                         steps=steps_per_epoch,
                         hooks=[mia_hook])

        # Evaluate the model and print results
        eval_results = classifier.evaluate(input_fn=eval_input_fn)
        test_accuracy = eval_results['accuracy']
        print('Test accuracy after %d epochs is: %.3f' %
              (epoch, test_accuracy))

    print('End of training attack')
    attack_results = run_attack_on_tf_estimator_model(
        classifier, (x_train, y_train), (x_test, y_test),
        input_fn_constructor,
        slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
        attack_types=[
            AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
        ])
    att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
        attack_results)
    print('\n'.join([
        '  %s: %.4f' % (', '.join([s, t, m]), v)
        for t, s, m, v in zip(att_types, att_slices, att_metrics, att_values)
    ]))
Example #22
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    logging.set_verbosity(logging.INFO)
    logging.set_stderrthreshold(logging.INFO)
    logging.get_absl_handler().use_absl_log_file()

    # Load training and test data.
    train_data, train_labels, test_data, test_labels = load_mnist()

    # Instantiate the tf.Estimator.
    mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
                                              model_dir=FLAGS.model_dir)

    # A function to construct input_fn given (data, label), to be used by the
    # membership inference training hook.
    def input_fn_constructor(x, y):
        return tf.estimator.inputs.numpy_input_fn(x={'x': x},
                                                  y=y,
                                                  shuffle=False)

    with tf.Graph().as_default():
        # Get a summary writer for the hook to write to tensorboard.
        # Can set summary_writer to None if not needed.
        if FLAGS.model_dir:
            summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
        else:
            summary_writer = None
        mia_hook = MembershipInferenceTrainingHook(mnist_classifier,
                                                   (train_data, train_labels),
                                                   (test_data, test_labels),
                                                   input_fn_constructor, [],
                                                   summary_writer)

    # Create tf.Estimator input functions for the training and test data.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': train_data},
        y=train_labels,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.epochs,
        shuffle=True)
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': test_data},
                                                       y=test_labels,
                                                       num_epochs=1,
                                                       shuffle=False)

    # Training loop.
    steps_per_epoch = 60000 // FLAGS.batch_size
    for epoch in range(1, FLAGS.epochs + 1):
        # Train the model, with the membership inference hook.
        mnist_classifier.train(input_fn=train_input_fn,
                               steps=steps_per_epoch,
                               hooks=[mia_hook])

        # Evaluate the model and print results
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        test_accuracy = eval_results['accuracy']
        print('Test accuracy after %d epochs is: %.3f' %
              (epoch, test_accuracy))

    print('End of training attack')
    run_attack_on_tf_estimator_model(mnist_classifier,
                                     (train_data, train_labels),
                                     (test_data, test_labels),
                                     input_fn_constructor, ['lr'])