Example #1
0
def main(argv):
    """Run this module's tests under pytest and exit with pytest's status.

    Args:
      argv: Command-line arguments; argv[0] is the program name and nothing
        else is accepted.

    Raises:
      app.UsageError: If any arguments follow the program name.
    """
    extra_args = argv[1:]
    if extra_args:
        raise app.UsageError("Unknown arguments: '{}'.".format(' '.join(extra_args)))
    exit_code = pytest.main([__file__, '-vv'])
    sys.exit(exit_code)
Example #2
0
def main(argv):
    """Execute this file's test suite with pytest, then exit with its result.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line flags were supplied.
    """
    has_extra_flags = len(argv) > 1
    if has_extra_flags:
        raise app.UsageError('Unrecognized command line flags.')
    result = pytest.main([__file__, '-vv'])
    sys.exit(result)
Example #3
0
def main(argv):
    """Delegate to ``_generate_data``; rejects any positional arguments.

    Raises:
      app.UsageError: If arguments follow the program name.
    """
    if len(argv) <= 1:
        _generate_data()
        return
    raise app.UsageError("Too many command-line arguments.")
def main(argv):
    """Upload a TensorBoard log directory to a Vertex AI TensorBoard resource.

    Reads the target resource, blob-storage location and upload options from
    FLAGS. Creates the experiment and prints its web URL. Then uploads either
    once (``--one_shot``) or continuously.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: On extra arguments, a malformed
        ``--tensorboard_resource_name``, a TensorBoard resource that is not
        found, or an obsolete resource without a blob storage path prefix.
        The last two are raised with ``exitcode=0``.
      grpc.RpcError: For any RPC failure other than NOT_FOUND.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    aiplatform.constants.API_BASE_PATH = FLAGS.api_uri
    # Each captured component is a single path segment, so a stray "/" cannot
    # leak into the project id or region.
    m = re.match(
        r"projects/([^/]+)/locations/([^/]+)/tensorboards/.*",
        FLAGS.tensorboard_resource_name,
    )
    if m is None:
        # Without this check, a malformed name crashed below with an opaque
        # "'NoneType' object is not subscriptable" TypeError.
        raise app.UsageError(
            "Invalid tensorboard resource name %s; expected "
            "projects/{project}/locations/{region}/tensorboards/{id}"
            % FLAGS.tensorboard_resource_name
        )
    project_id = m[1]
    region = m[2]
    api_client = aiplatform.initializer.global_config.create_client(
        client_class=TensorboardClientWithOverride, location_override=region,
    )

    try:
        tensorboard = api_client.get_tensorboard(name=FLAGS.tensorboard_resource_name)
    except grpc.RpcError as rpc_error:
        if rpc_error.code() == grpc.StatusCode.NOT_FOUND:
            raise app.UsageError(
                "Tensorboard resource %s not found" % FLAGS.tensorboard_resource_name,
                exitcode=0,
            ) from rpc_error
        raise

    if not tensorboard.blob_storage_path_prefix:
        raise app.UsageError(
            "Tensorboard resource {} is obsolete. Please create a new one.".format(
                FLAGS.tensorboard_resource_name
            ),
            exitcode=0,
        )

    # The prefix has the form "<bucket>[/<folder>...]". Appending "/" guarantees
    # a separator, so the folder is either empty or ends with "/".
    path_prefix = tensorboard.blob_storage_path_prefix + "/"
    bucket_name, _, blob_storage_folder = path_prefix.partition("/")
    blob_storage_bucket = storage.Client(project=project_id).bucket(bucket_name)

    tb_uploader = uploader.TensorBoardUploader(
        experiment_name=FLAGS.experiment_name,
        experiment_display_name=FLAGS.experiment_display_name,
        tensorboard_resource_name=tensorboard.name,
        blob_storage_bucket=blob_storage_bucket,
        blob_storage_folder=blob_storage_folder,
        allowed_plugins=FLAGS.allowed_plugins,
        writer_client=api_client,
        logdir=FLAGS.logdir,
        one_shot=FLAGS.one_shot,
        event_file_inactive_secs=FLAGS.event_file_inactive_secs,
        run_name_prefix=FLAGS.run_name_prefix,
    )

    tb_uploader.create_experiment()

    # The web UI addresses experiments by resource name with "/" replaced by "+".
    print(
        "View your Tensorboard at https://{}.{}/experiment/{}".format(
            region,
            FLAGS.web_server_uri,
            tb_uploader.get_experiment_resource_name().replace("/", "+"),
        )
    )
    if FLAGS.one_shot:
        tb_uploader._upload_once()  # pylint: disable=protected-access
    else:
        tb_uploader.start_uploading()
Example #5
0
def _build_image(dock, dockerfile, repo_tag, build_args):
  """Build one image via docker's low-level API, streaming its log to stderr.

  Args:
    dock: Docker client; ``dock.api.build`` and ``dock.images.get`` are used.
    dockerfile: Dockerfile path, relative to the build context ``'.'``.
    repo_tag: ``repository:tag`` name applied to the built image.
    build_args: Mapping of build-arg names to values.

  Returns:
    The image object returned by ``dock.images.get`` for the built image id.

  Raises:
    docker.errors.BuildError: If no image id appears in the build output. The
      error carries the last streamed line and all collected log entries.
  """
  # Note that we are NOT using cache_from, which appears to limit
  # available cache layers to those from explicitly specified layers. Many
  # of our layers are similar between local builds, so we want to use the
  # implied local build cache.
  # Use low level APIClient in order to stream log output.
  resp = dock.api.build(
      timeout=FLAGS.hub_timeout,
      path='.',
      nocache=FLAGS.nocache,
      dockerfile=dockerfile,
      buildargs=build_args,
      tag=repo_tag)
  last_event = None
  image_id = None
  logs = []
  # Manually process log output, extracting build success and the image id,
  # in order to get the built image.
  for chunk in resp:
    # errors='replace' keeps a chunk with invalid UTF-8 parseable. Previously
    # the UnicodeDecodeError (a ValueError) was reported with the *previous*
    # chunk's text, or raised NameError on the first chunk.
    output = chunk.decode('utf-8', errors='replace')
    try:
      json_output = json.loads(output.strip('\r\n'))
    except ValueError:
      eprint('Error parsing from docker image build: {}'.format(output))
      continue
    if 'stream' in json_output:
      stream = json_output['stream']
      eprint(stream, end='')
      match = re.search(r'(^Successfully built |sha256:)([0-9a-f]+)$', stream)
      if match:
        image_id = match.group(2)
      last_event = stream
      # Collect all log lines so they can be attached to a BuildError.
      logs.append(json_output)
  eprint('Docker image build complete.')
  # If no image id was seen, the image failed to build properly. Raise an
  # error with the last log line and all logs.
  if not image_id:
    raise docker.errors.BuildError(last_event or 'Unknown', logs)
  return dock.images.get(image_id)


def _run_tag_tests(dock, image, tag_def, repo_tag):
  """Run each of ``tag_def['tests']`` inside ``image`` and print the results.

  Test scripts are taken from ``FLAGS.run_tests_path``, mounted read-only at
  ``/tests``. Could be improved by backgrounding, but would need better
  multiprocessing support to track failures properly.

  Returns:
    True if at least one test container exited with a non-zero status.

  Raises:
    SystemExit: With status 1 on the first failing test when
      ``--stop_on_failure`` is set.
  """
  if not tag_def['tests']:
    eprint('>>> No tests to run.')
  any_failed = False
  for test in tag_def['tests']:
    eprint('>> Testing {}...'.format(test))
    container = dock.containers.run(
        image,
        '/tests/' + test,
        working_dir='/',
        log_config={'type': 'journald'},
        detach=True,
        stderr=True,
        stdout=True,
        volumes={
            FLAGS.run_tests_path: {
                'bind': '/tests',
                'mode': 'ro'
            }
        },
        runtime=tag_def['test_runtime'])
    code = container.wait()['StatusCode']
    out = container.logs(stdout=True, stderr=False)
    err = container.logs(stdout=False, stderr=True)
    container.remove()
    if out:
      eprint('>>> Output stdout:')
      eprint(out.decode('utf-8'))
    else:
      eprint('>>> No test standard out.')
    if err:
      eprint('>>> Output stderr:')
      eprint(err.decode('utf-8'))
    else:
      eprint('>>> No test standard err.')
    if code != 0:
      eprint('>> {} failed tests with status: "{}"'.format(repo_tag, code))
      any_failed = True
      if FLAGS.stop_on_failure:
        eprint('>> ABORTING due to --stop_on_failure!')
        raise SystemExit(1)
    else:
      eprint('>> Tests look good!')
  return any_failed


def main(argv):
  """Assemble Dockerfiles from the spec and optionally build/test/upload them.

  Validates ``--spec_file`` against SCHEMA_TEXT. Writes Dockerfiles when
  ``--construct_dockerfiles`` is set. Builds images for the host architecture
  when ``--build_images`` is set, runs tests when ``--run_tests_path`` is set,
  and uploads passing images when ``--upload_to_hub`` is set. Successful
  ``repository:tag`` names are printed to stdout; progress goes through
  ``eprint``.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra arguments are passed.
    SystemExit: With status 1 on an invalid spec, missing Docker Hub flags,
      ``--stop_on_failure`` aborts, or if any tag failed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Read the full spec file, used for everything
  with open(FLAGS.spec_file, 'r') as spec_file:
    tag_spec = yaml.safe_load(spec_file)

  # Get existing partial contents
  partials = gather_existing_partials(FLAGS.partial_dir)

  # Abort if spec.yaml is invalid
  schema = yaml.safe_load(SCHEMA_TEXT)
  v = TfDockerTagValidator(schema, partials=partials)
  if not v.validate(tag_spec):
    eprint('> Error: {} is an invalid spec! The errors are:'.format(
        FLAGS.spec_file))
    eprint(yaml.dump(v.errors, indent=2))
    raise SystemExit(1)
  tag_spec = v.normalized(tag_spec)

  # Assemble tags and images used to build them
  all_tags = assemble_tags(tag_spec, FLAGS.arg, FLAGS.release, partials)

  # Empty Dockerfile directory if building new Dockerfiles
  if FLAGS.construct_dockerfiles:
    eprint('> Emptying Dockerfile dir "{}"'.format(FLAGS.dockerfile_dir))
    shutil.rmtree(FLAGS.dockerfile_dir, ignore_errors=True)
    mkdir_p(FLAGS.dockerfile_dir)

  # Set up Docker helper
  dock = docker.from_env()

  # Login to Docker if uploading images
  if FLAGS.upload_to_hub:
    for flag_name in ('hub_username', 'hub_repository', 'hub_password'):
      if not getattr(FLAGS, flag_name):
        eprint('> Error: please set --{} when uploading to Dockerhub.'.format(
            flag_name))
        raise SystemExit(1)
    dock.login(
        username=FLAGS.hub_username,
        password=FLAGS.hub_password,
    )

  # Only build images for the host architecture. The host cannot change
  # between tags, so detect it once.
  proc_arch = platform.processor()
  is_x86 = proc_arch.startswith('x86')
  non_x86_arches = ('ppc64le',)

  # Each tag has a name ('tag') and a definition consisting of the contents
  # of its Dockerfile, its build arg list, etc.
  failed_tags = []
  succeeded_tags = []
  for tag, tag_defs in all_tags.items():
    for tag_def in tag_defs:
      eprint('> Working on {}'.format(tag))

      if FLAGS.exclude_tags_matching and re.match(FLAGS.exclude_tags_matching,
                                                  tag):
        eprint('>> Excluded due to match against "{}".'.format(
            FLAGS.exclude_tags_matching))
        continue

      if FLAGS.only_tags_matching and not re.match(FLAGS.only_tags_matching,
                                                   tag):
        eprint('>> Excluded due to failure to match against "{}".'.format(
            FLAGS.only_tags_matching))
        continue

      # Write releases marked "is_dockerfiles" into the Dockerfile directory
      if FLAGS.construct_dockerfiles and tag_def['is_dockerfiles']:
        path = os.path.join(FLAGS.dockerfile_dir,
                            tag_def['dockerfile_subdirectory'],
                            tag + '.Dockerfile')
        eprint('>> Writing {}...'.format(path))
        if not FLAGS.dry_run:
          mkdir_p(os.path.dirname(path))
          with open(path, 'w') as f:
            f.write(tag_def['dockerfile_contents'])

      # Don't build any images for dockerfile-only releases
      if not FLAGS.build_images:
        continue

      tag_is_foreign_arch = (
          (is_x86 and any(arch in tag for arch in non_x86_arches)) or
          (not is_x86 and proc_arch not in tag))
      if tag_is_foreign_arch:
        continue

      # Generate a temporary Dockerfile to use to build, since docker-py
      # needs a filepath relative to the build context (i.e. the current
      # directory)
      dockerfile = os.path.join(FLAGS.dockerfile_dir, tag + '.temp.Dockerfile')
      if not FLAGS.dry_run:
        with open(dockerfile, 'w') as f:
          f.write(tag_def['dockerfile_contents'])
      eprint('>> (Temporary) writing {}...'.format(dockerfile))

      repo_tag = '{}:{}'.format(FLAGS.repository, tag)
      eprint('>> Building {} using build args:'.format(repo_tag))
      for arg, value in tag_def['cli_args'].items():
        eprint('>>> {}={}'.format(arg, value))

      tag_failed = False
      image = None
      if not FLAGS.dry_run:
        try:
          image = _build_image(dock, dockerfile, repo_tag, tag_def['cli_args'])
          # Run tests if requested. A tag is recorded as failed once, no
          # matter how many of its tests fail.
          if FLAGS.run_tests_path and _run_tag_tests(dock, image, tag_def,
                                                     repo_tag):
            failed_tags.append(tag)
            tag_failed = True
        except docker.errors.BuildError as e:
          eprint('>> {} failed to build with message: "{}"'.format(
              repo_tag, e.msg))
          eprint('>> Build logs follow:')
          log_lines = [entry.get('stream', '') for entry in e.build_log]
          eprint(''.join(log_lines))
          failed_tags.append(tag)
          tag_failed = True
          if FLAGS.stop_on_failure:
            eprint('>> ABORTING due to --stop_on_failure!')
            raise SystemExit(1)
        finally:
          # Clean temporary dockerfiles even when aborting or on unexpected
          # errors.
          if not FLAGS.keep_temp_dockerfiles:
            os.remove(dockerfile)

      if tag_failed:
        continue

      # Upload new images to DockerHub as long as they built + passed tests.
      # Tags not marked for upload still count as succeeded; previously they
      # were dropped from the output whenever --upload_to_hub was set.
      if FLAGS.upload_to_hub and tag_def['upload_images']:
        eprint('>> Uploading to {}:{}'.format(FLAGS.hub_repository, tag))
        if not FLAGS.dry_run:
          multiprocessing.Process(
              target=upload_in_background,
              args=(FLAGS.hub_repository, dock, image, tag)).start()

      succeeded_tags.append(tag)

  if failed_tags:
    eprint(
        '> Some tags failed to build or failed testing, check scrollback for '
        'errors: {}'.format(','.join(failed_tags)))
    raise SystemExit(1)

  eprint('> Writing built{} tags to standard out.'.format(
      ' and tested' if FLAGS.run_tests_path else ''))
  for tag in succeeded_tags:
    print('{}:{}'.format(FLAGS.repository, tag))
Example #6
0
def main(argv):
    """Train a Transformer that predicts the CoNLL XPOS attribute from FORM.

    Reads hyper-parameters and dataset paths from FLAGS. Trains with a
    pmap'd step for ``--num_train_steps`` steps. Every ``--eval_frequency``
    steps it logs train metrics, evaluates on the dev set and tracks the best
    dev accuracy. Summaries are written only on host 0.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: On extra arguments or a missing ``--dev``/``--train``.
      ValueError: If ``--batch_size`` is not divisible by the device count.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    batch_size = FLAGS.batch_size
    learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    eval_freq = FLAGS.eval_frequency
    random_seed = FLAGS.random_seed

    if not FLAGS.dev:
        raise app.UsageError('Please provide path to dev set.')
    if not FLAGS.train:
        raise app.UsageError('Please provide path to training set.')
    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    # Only host 0 writes summaries. Other hosts bind None so the names always
    # exist, even though they are only used under `jax.host_id() == 0`.
    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    # create the training and development dataset
    vocabs = input_pipeline.create_vocabs(FLAGS.train)
    config = models.TransformerConfig(vocab_size=len(vocabs['forms']),
                                      output_vocab_size=len(vocabs['xpos']),
                                      max_len=FLAGS.max_length)

    attributes_input = [input_pipeline.CoNLLAttributes.FORM]
    attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
    train_ds = input_pipeline.sentence_dataset_dict(FLAGS.train,
                                                    vocabs,
                                                    attributes_input,
                                                    attributes_target,
                                                    batch_size=batch_size,
                                                    bucket_size=config.max_len)
    train_iter = iter(train_ds)

    eval_ds = input_pipeline.sentence_dataset_dict(FLAGS.dev,
                                                   vocabs,
                                                   attributes_input,
                                                   attributes_target,
                                                   batch_size=batch_size,
                                                   bucket_size=config.max_len,
                                                   repeat=1)

    model = models.Transformer(config)

    rng = random.PRNGKey(random_seed)
    rng, init_rng = random.split(rng)

    # call a jitted initialization function to get the initial parameter tree
    @jax.jit
    def initialize_variables(init_rng):
        init_batch = jnp.ones((config.max_len, 1), jnp.float32)
        init_variables = model.init(init_rng, inputs=init_batch, train=False)
        return init_variables

    init_variables = initialize_variables(init_rng)

    optimizer_def = optim.Adam(learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=1e-1)
    optimizer = optimizer_def.create(init_variables['params'])
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate)

    p_train_step = jax.pmap(functools.partial(
        train_step, model=model, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')

    def eval_step(params, batch):
        """Calculate evaluation metrics on a batch."""
        inputs, targets = batch['inputs'], batch['targets']
        # Only positions with a non-zero target id contribute to the metrics.
        weights = jnp.where(targets > 0, 1.0, 0.0)
        logits = model.apply({'params': params}, inputs=inputs, train=False)
        return compute_metrics(logits, targets, weights)

    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())
    metrics_all = []
    tick = time.time()
    best_dev_score = 0
    for step, batch in zip(range(num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access

        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        if (step + 1) % eval_freq == 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            logging.info('train in step: %d, loss: %.4f', step,
                         summary['loss'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                train_summary_writer.scalar('steps per second', steps_per_sec,
                                            step)
                for key, val in summary.items():
                    train_summary_writer.scalar(key, val, step)
                train_summary_writer.flush()

            metrics_all = []  # reset metric accumulation for next evaluation cycle.

            eval_metrics = []
            eval_iter = iter(eval_ds)

            for eval_batch in eval_iter:
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                # Handle final odd-sized batch by padding instead of dropping it.
                # NOTE(review): padded rows only stay out of the metrics if
                # pad_examples pads targets with 0 (weight 0 in eval_step) —
                # confirm.
                cur_pred_batch_size = eval_batch['inputs'].shape[0]
                if cur_pred_batch_size != batch_size:
                    # pad up to batch size
                    eval_batch = jax.tree_map(
                        lambda x: pad_examples(x, batch_size), eval_batch)
                eval_batch = common_utils.shard(eval_batch)

                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)

            logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])

            if best_dev_score < eval_summary['accuracy']:
                best_dev_score = eval_summary['accuracy']
                # TODO: save model.
            eval_summary['best_dev_score'] = best_dev_score
            logging.info('best development model score %.4f', best_dev_score)
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(key, val, step)
                eval_summary_writer.flush()
Example #7
0
def main(argv):
  """Train a StackOverflow next-word model with federated averaging.

  Builds the recurrent model, word-level datasets and metrics from FLAGS.
  Optionally wraps aggregation in a differentially-private process when
  ``--noise_multiplier`` is set. Then hands everything to
  ``training_loop.run``.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra arguments are passed.
    ValueError: If differential privacy is requested without
      ``--uniform_weighting``.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))
  tff.backends.native.set_local_execution_context(max_fanout=10)

  model_builder = functools.partial(
      stackoverflow_models.create_recurrent_model,
      vocab_size=FLAGS.vocab_size,
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding)

  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)
  pad_token = special_tokens.pad
  oov_tokens = special_tokens.oov
  eos_token = special_tokens.eos

  def metrics_builder():
    """Return fresh accuracy/counter metrics masking pad (and OOV/EOS) ids."""
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
        # Notice BOS never appears in ground truth.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, eos_token] + oov_tokens),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
    ]

  datasets = stackoverflow_dataset.construct_word_level_datasets(
      FLAGS.vocab_size, FLAGS.client_batch_size, FLAGS.client_epochs_per_round,
      FLAGS.sequence_length, FLAGS.max_elements_per_user,
      FLAGS.num_validation_examples)
  train_dataset, validation_dataset, test_dataset = datasets

  if FLAGS.uniform_weighting:
    def client_weight_fn(local_outputs):
      del local_outputs
      return 1.0
  else:
    def client_weight_fn(local_outputs):
      # Weight each client by the number of (non-pad) tokens it trained on.
      return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=validation_dataset.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')

    dp_query = tff.utils.build_dp_query(
        clip=FLAGS.clip,
        noise_multiplier=FLAGS.noise_multiplier,
        expected_total_weight=FLAGS.clients_per_round,
        adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
        target_unclipped_quantile=FLAGS.target_unclipped_quantile,
        clipped_count_budget_allocation=FLAGS.clipped_count_budget_allocation,
        expected_num_clients=FLAGS.clients_per_round,
        per_vector_clipping=FLAGS.per_vector_clipping,
        model=model_fn())

    weights_type = tff.learning.framework.weights_type_from_model(model_fn)
    aggregation_process = tff.utils.build_dp_aggregate_process(
        weights_type.trainable, dp_query)
  else:
    aggregation_process = None

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weight_fn=client_weight_fn,
      client_optimizer_fn=client_optimizer_fn,
      aggregation_process=aggregation_process)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=validation_dataset,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

  test_fn = training_utils.build_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=validation_dataset.concatenate(test_dataset),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

  logging.info('Training model:')
  # Keras `summary()` prints and returns None, so logging its return value
  # only logged "None". Route each summary line through the logger instead.
  # `**_` tolerates extra keyword arguments some Keras versions pass to
  # print_fn.
  model_builder().summary(print_fn=lambda line, **_: logging.info('%s', line))

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

  training_loop.run(
      iterative_process=iterative_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=evaluate_fn,
      test_fn=test_fn,
      hparam_dict=hparam_dict,
      **training_loop_dict)
Example #8
0
def main(argv):
  """Generate the train/validation archive, then the test archive.

  Raises:
    app.UsageError: If arguments follow the program name.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  for generate_archive in (_generate_trainval_archive, _generate_test_archive):
    generate_archive()
Example #9
0
def main(argv):
    global BLEU_THRESHOLD_REACHED
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    init_mllogger()
    mllogger.event('cache_clear')
    mllogger.start('init_start')
    mllogger.event('submission_org', 'Google')
    mllogger.event('submission_platform',
                   'TPUv3-{}'.format(jax.device_count()))
    mllogger.event('submission_division', 'closed')
    mllogger.event('submission_status', 'research')
    mllogger.event('submission_benchmark', 'transformer')
    mllogger.event('train_samples', input_pipeline.N_TRAIN)
    mllogger.event('eval_samples', input_pipeline.N_EVAL)

    tf.enable_v2_behavior()

    # Use hardware RNG for bernoulli randoms in dropout mask creation.
    if FLAGS.hardware_rng:
        models.set_hardware_bernoulli()

    num_partitions = FLAGS.num_partitions
    batch_size = FLAGS.batch_size
    if batch_size is None:
        batch_size = min(16 * jax.device_count() // num_partitions, 2048)
    mllogger.event('global_batch_size', batch_size)

    num_eval_steps = FLAGS.num_eval_steps
    max_target_length = FLAGS.max_target_length
    max_eval_target_length = FLAGS.max_eval_target_length
    max_length = max(max_target_length, max_eval_target_length)
    mllogger.event('max_sequence_length',
                   max_length,
                   metadata={'method': 'discard'})
    if FLAGS.random_seed is not None:
        seed = FLAGS.random_seed
    else:
        seed = np.int32(time.time() if jax.host_id() == 0 else 0)
        seed = per_host_sum_pmap(seed)
    mllogger.event('seed', int(seed))
    steps_per_epoch = int(math.ceil(input_pipeline.N_TRAIN / batch_size))
    logging.info('steps per epoch: %d', steps_per_epoch)
    num_replicas = jax.local_device_count() // num_partitions
    device_train_input_shape = (batch_size //
                                (num_replicas * jax.host_count()),
                                max_target_length)
    # This is per-host; in principle 64/replica or more should fit
    eval_batch_size = min(
        32 * num_replicas,
        int(
            math.ceil(input_pipeline.N_EVAL /
                      (num_replicas * jax.host_count()))) * num_replicas)
    logging.info('eval batch size: %d', eval_batch_size)
    pred_batches = int(
        math.ceil(input_pipeline.N_EVAL /
                  (jax.host_count() * eval_batch_size)))
    logging.info('pred batches: %d', pred_batches)
    broadcast = functools.partial(_broadcast,
                                  num_replicas=num_replicas,
                                  num_partitions=num_partitions)

    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))
    else:
        train_summary_writer = None
        eval_summary_writer = None
    # Write summaries in background thread to avoid blocking on device sync
    summary_thread = thread.ThreadPoolExecutor(1, 'summary')
    if FLAGS.infeed:
        # Infeed is currently synchronous, so do it in a background thread too
        infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(),
                                                'infeed')

    def maybe_start_xprof(seconds):
        if jax.host_id() == 0 and FLAGS.xprof:
            xprof = xprof_session.XprofSession()
            xprof.start_session('REDACTED', True, 2)

            def sleep_and_end_xprof():
                time.sleep(seconds)
                logging.info(
                    'Xprof URL: %s',
                    xprof.end_session_and_get_url(
                        tag=
                        'flax transformer, {} devices, {}-way, batch {} per replica'
                        .format(jax.device_count(), num_partitions,
                                device_train_input_shape[0])))

            thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)

    # MLPerf 2020 WMT en-de dataset uses a custom T2T dataset:
    #   Shared 32K subword tokenization
    #   256-length packed training examples from WMT17
    #   97-length unpacked evaluation examples from WMT14
    train_keys = [
        'inputs', 'targets', 'inputs_position', 'targets_position',
        'inputs_segmentation', 'targets_segmentation'
    ]
    encoder = mlperf_encoder.SubwordTextEncoder(filename=FLAGS.vocab_path)
    input_encoder = encoder
    target_encoder = encoder
    vocab_size = input_encoder.vocab_size
    output_vocab_size = target_encoder.vocab_size

    input_shape = (batch_size, max_target_length)
    target_shape = (batch_size, max_target_length)

    transformer_kwargs = {
        'vocab_size': vocab_size,
        'output_vocab_size': output_vocab_size,
        'emb_dim': 1024,
        'num_heads': 16,
        'num_layers': 6,
        'qkv_dim': 1024,
        'mlp_dim': 4096,
        'max_len': max_length,
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
        'num_partitions': num_partitions,
    }

    rng = random.PRNGKey(seed)
    rng, init_rng = random.split(rng)
    model, cache_def = create_model(init_rng, tuple(input_shape),
                                    tuple(target_shape), transformer_kwargs)
    mllogger.event('opt_name', 'adam')
    if batch_size < 1024:
        learning_rate = 4.0  # 0.0625
        warmup_steps = 1000
        beta1 = 0.9
        beta2 = 0.98
    if batch_size < 2048:
        learning_rate = 2.0
        warmup_steps = 500  # ??
        beta1 = 0.9  # ??
        beta2 = 0.98  # ??
    else:
        learning_rate = 3.3092157691415953
        warmup_steps = 664
        beta1 = 0.9086575725261137
        beta2 = 0.9198719118104947
    epsilon = 1e-9
    if FLAGS.learning_rate is not None:
        learning_rate = FLAGS.learning_rate
    mllogger.event('opt_adam_beta_1', beta1)
    mllogger.event('opt_adam_beta_2', beta2)
    mllogger.event('opt_adam_epsilon', epsilon)
    optimizer_def = optim.Adam(learning_rate,
                               beta1=beta1,
                               beta2=beta2,
                               eps=epsilon,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(model)
    del model  # don't keep a copy of the initial model

    # Build parameter partition annotations for preserving partitions from train
    # to eval.
    partition_rules = [
        (('encoder', 'posembed_input'), partitions.empty_dict),
        (('decoder', 'posembed_targets'), partitions.empty_dict),
        (('embedding', ), partitions.spec(num_partitions, 1)),
        ((r'LayerNorm_\d+', '(bias|scale)'), None),
        ((r'encoder(decoder)?_norm', '(bias|scale)'), None),
        ((r'MultiHeadDotProductAttention_\d+', '(query|key|value)', 'kernel'),
         partitions.spec(1, num_partitions, 1)),
        ((r'MultiHeadDotProductAttention_\d+', 'out', 'kernel'),
         partitions.spec(num_partitions, 1, 1)),
        ((r'MlpBlock_\d+', r'Dense_\d+', 'bias'), None),
        ((r'MlpBlock_\d+', 'Dense_0', 'kernel'),
         partitions.spec(1, num_partitions)),
        ((r'MlpBlock_\d+', 'Dense_1', 'kernel'),
         partitions.spec(num_partitions, 1)),
        (('state', 'step'), None),
    ]
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(partition_rules, optimizer.state_dict()))

    optimizer = broadcast(optimizer)
    empty_metrics = broadcast({'loss': 0.0, 'accuracy': 0, 'denominator': 0})

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        hidden_size=transformer_kwargs['qkv_dim'])

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch',
                            in_axes=(None, 0, 0, 0))
    if num_partitions > 1:
        sharded_predict_step = sharded_jit(
            predict_step,
            in_parts=(None, optimizer_partitions.target, None),
            out_parts=None)
    else:
        sharded_predict_step = predict_step
    if FLAGS.extra_eval_metrics:
        p_eval_step = jax.pmap(eval_step, axis_name='batch', in_axes=(None, 0))
    p_pred_step = jax.pmap(sharded_predict_step,
                           axis_name='batch',
                           in_axes=(0, None, None))
    p_allreduce_metrics = jax.pmap(functools.partial(lax.psum,
                                                     axis_name='batch'),
                                   axis_name='batch')

    def device_train_loop_cond(args):
        """Predicate for the on-device `lax.while_loop`.

        `args` is the loop carry `(optimizer, dropout_rngs, metrics, token,
        step, epoch)`; only the trailing `step` and `epoch` are inspected.
        The loop continues while `step` still falls inside `epoch`.
        """
        *_, cur_step, cur_epoch = args
        return cur_step // steps_per_epoch == cur_epoch

    def device_train_loop_body(args):
        """One iteration of the on-device training `lax.while_loop`.

        Reads one batch from the device infeed, runs `train_step` on it and
        increments `step`. `args` is the loop carry `(optimizer, dropout_rngs,
        metrics, token, step, epoch)`; a carry of the same structure is
        returned so it can be fed to the next iteration.
        """
        optimizer, dropout_rngs, metrics, token, step, epoch = args
        # Receive one int32 array of shape `device_train_input_shape` per key
        # in `train_keys`, in `train_keys` order -- the host loop pushes each
        # replica's tuple via `transfer_to_infeed` in that same order. The
        # returned token replaces the old one in the carry.
        input_data, token = lax.infeed(token,
                                       shape=tuple([
                                           jax.ShapedArray(
                                               device_train_input_shape,
                                               jnp.int32) for _ in train_keys
                                       ]))
        # Re-associate the positional infeed arrays with their feature names.
        batch = {k: v for k, v in zip(train_keys, input_data)}
        optimizer, metrics, dropout_rngs = train_step(optimizer,
                                                      batch,
                                                      metrics,
                                                      learning_rate_fn,
                                                      dropout_rng=dropout_rngs)
        # The loop condition compares `step // steps_per_epoch` with `epoch`.
        step += 1
        return optimizer, dropout_rngs, metrics, token, step, epoch

    def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
        """Run all remaining training steps of `epoch` in one device loop.

        A fresh infeed token is created from `step`, then
        `device_train_loop_body` is iterated while `device_train_loop_cond`
        holds.

        Returns:
          `(optimizer, dropout_rngs, metrics, step)` from the final carry;
          the token and epoch entries of the carry are dropped.
        """
        init_carry = (optimizer, dropout_rngs, metrics,
                      lax.create_token(step), step, epoch)
        final_carry = lax.while_loop(device_train_loop_cond,
                                     device_train_loop_body, init_carry)
        new_optimizer, new_rngs, new_metrics, _, new_step, _ = final_carry
        return new_optimizer, new_rngs, new_metrics, new_step

    if num_partitions > 1:
        device_train_loop = sharded_jit(device_train_loop,
                                        in_parts=(optimizer_partitions, None,
                                                  None, None, None),
                                        out_parts=(optimizer_partitions, None,
                                                   None, None))
    p_train_epoch = jax.pmap(device_train_loop,
                             axis_name='batch',
                             in_axes=(None, 0, 0, None, None))

    p_allreduce_metrics_train = functools.partial(lax.psum, axis_name='batch')
    if num_partitions > 1:
        p_allreduce_metrics_train = sharded_jit(p_allreduce_metrics_train,
                                                in_parts=None,
                                                out_parts=None,
                                                num_partitions=num_partitions)
    p_allreduce_metrics_train = jax.pmap(p_allreduce_metrics_train,
                                         axis_name='batch')

    # Precompile all needed computations with fake data so as not to include
    # compilation time in MLPerf metrics.
    if FLAGS.precompile:
        logging.info('precompiling step/epoch functions')
        if FLAGS.infeed:
            # the device training loop condition will immediately be false, but
            # the optimizer tree will be resharded here
            optimizer, *_ = p_train_epoch(unbroadcast(optimizer),
                                          random.split(rng, num_replicas),
                                          empty_metrics,
                                          jnp.array(0, dtype=jnp.int32), 1)
        else:
            metrics = empty_metrics
            train_input_shape = (num_replicas, batch_size // num_replicas,
                                 input_pipeline.MAX_TRAIN_LEN)
            fake_batch = {
                k: jnp.ones(train_input_shape, jnp.int32)
                for k in train_keys
            }
            p_train_step(unbroadcast(optimizer),
                         fake_batch,
                         metrics,
                         dropout_rng=random.split(rng, num_replicas))
        eval_input_shape = (num_replicas, eval_batch_size // num_replicas,
                            input_pipeline.MAX_EVAL_LEN)
        fake_eval_batch = {
            'inputs': jnp.ones(eval_input_shape, jnp.int32),
            'targets': jnp.ones(eval_input_shape, jnp.int32),
        }
        if FLAGS.extra_eval_metrics:
            p_eval_step(unbroadcast(optimizer.target), fake_eval_batch)
        fake_cache = cache_def.initialize_cache(
            (eval_input_shape[1], FLAGS.max_predict_length))
        maybe_start_xprof(20)
        p_pred_step(fake_eval_batch['inputs'], unbroadcast(optimizer.target),
                    fake_cache)
        time.sleep(20)
        sync_devices()
        fake_bleu_1 = np.zeros((4, ), dtype=np.int32)
        fake_bleu_2 = np.zeros((), dtype=np.int32)
        per_host_sum_pmap((fake_bleu_1, fake_bleu_1, fake_bleu_2, fake_bleu_2))
        sync_devices()
        p_allreduce_metrics_train(empty_metrics)
        sync_devices()
        logging.info('finished precompiling step/epoch functions')

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, num_replicas)

    # Record time-0 metrics for proper tensorboard plot x-axis scaling.
    if jax.host_id() == 0:
        if FLAGS.compute_train_metrics:
            train_summary_writer.scalar('loss', 9.999, 0)
            train_summary_writer.scalar('accuracy', 0.0, 0)
            train_summary_writer.flush()
        eval_summary_writer.scalar('bleu', 0.0, 0)
        eval_summary_writer.flush()

    train_ds = input_pipeline.get_wmt_dataset(batch_size=batch_size //
                                              jax.host_count(),
                                              train=True)
    eval_ds = input_pipeline.get_wmt_dataset(batch_size=eval_batch_size,
                                             train=False)
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)
    local_devices = jax.local_devices()
    maybe_start_xprof(max(30, 60 / (jax.device_count() / 2048)))
    host_step, device_step = 0, broadcast(0)
    gc.disable()
    mllogger.end('init_stop')
    if jax.host_id() == 0:
        mllogger.start('run_start')
    for epoch in range(FLAGS.num_epochs):
        if jax.host_id() == 0 and not BLEU_THRESHOLD_REACHED:
            mllogger.start('block_start',
                           metadata={
                               'first_epoch_num': epoch + 1,
                               'epoch_count': 1
                           })
        metrics = empty_metrics
        if FLAGS.infeed:
            optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
                unbroadcast(optimizer), dropout_rngs, metrics,
                unbroadcast(device_step), epoch)
        while int(host_step // steps_per_epoch) == epoch:
            # pylint: disable=protected-access
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))
            # Shard data to devices and do a training step.
            batch = jax.tree_map(
                lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), batch)
            if FLAGS.infeed:
                for i, device in enumerate(local_devices):
                    replica_id = i // num_partitions
                    input_tuple = tuple(
                        [batch[k][replica_id] for k in train_keys])
                    assert input_tuple[0].shape == device_train_input_shape, (
                        'infeed shape error %s != %s' %
                        (input_tuple[0].shape, device_train_input_shape))
                    assert input_tuple[0].dtype == jnp.int32, (
                        'infeed dtype error %s != %s' %
                        (input_tuple[0].dtype, jnp.int32))
                    infeed_pool.submit(
                        functools.partial(device.transfer_to_infeed,
                                          input_tuple))
            else:
                optimizer, metrics, dropout_rngs = p_train_step(
                    unbroadcast(optimizer),
                    batch,
                    metrics,
                    dropout_rng=dropout_rngs)
            host_step += 1

        if FLAGS.compute_train_metrics:
            metrics = p_allreduce_metrics_train(metrics)
            # Schedule training metric handling.
            summary_thread.submit(
                functools.partial(write_train_summary, metrics,
                                  train_summary_writer, host_step))

        # Optional, extra evaluation metrics.
        if FLAGS.extra_eval_metrics:
            eval_metrics = []
            eval_iter = iter(eval_ds)
            for _, eval_batch in zip(range(num_eval_steps), eval_iter):
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(unbroadcast(optimizer.target),
                                      eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = p_allreduce_metrics(eval_metrics)
            # Schedule metric summarization/logging.
            summary_thread.submit(
                functools.partial(write_eval_summary, eval_metrics,
                                  eval_summary_writer, host_step))

        # Translation and BLEU Score.
        all_predicted, all_targets, all_bs = [], [], []
        for i in range(pred_batches):
            # pylint: disable=protected-access
            pred_batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))
            logging.info('Predicting on input of shape %s.',
                         str(pred_batch['inputs'].shape))
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size != eval_batch_size:
                logging.info('Translation: uneven batch size %d.',
                             cur_pred_batch_size)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, eval_batch_size), pred_batch)
            pred_batch = jax.tree_map(
                lambda x: x.reshape((num_replicas, -1) + x.shape[1:]),
                pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache = cache_def.initialize_cache(
                (per_device_batchsize, FLAGS.max_predict_length))
            all_predicted.append(
                p_pred_step(pred_batch['inputs'],
                            unbroadcast(optimizer.target), cache))
            all_targets.append(pred_batch['targets'])
            all_bs.append(cur_pred_batch_size)
        # Schedule BLEU calculation and summarization/logging.
        # We use the ICI as part of BLEU score computation, so we call this from the
        # main thread so the BLEU pmap runs before the next train epoch pmap
        write_predict_summary(all_predicted, all_targets, all_bs,
                              target_encoder, eval_summary_writer, epoch,
                              host_step, summary_thread)

    # Wait until computations are done before exiting
    sync_devices()
    if jax.host_id() == 0:
        summary_thread.shutdown()
        if not BLEU_THRESHOLD_REACHED:
            mllogger.end('run_stop', metadata={'status': 'aborted'})
Example #10
0
def _build_model_name(flags):
  """Return a run name encoding the experiment's main hyperparameters.

  The name is used as the leaf directory name for summaries, checkpoints and
  self-labeled data backups.

  Args:
    flags: Object exposing the experiment flags as attributes (normally the
      absl `FLAGS` object).

  Returns:
    The run name as a string.
  """
  model_name = flags.model_cls
  model_name += ('_' + flags.hidden_cls) if flags.model_cls == 'mlp' else ''
  model_name += '-' + flags.model_agr
  model_name += ('_' + flags.hidden_agr) if flags.model_agr == 'mlp' else ''
  model_name += '-aggr_' + flags.aggregation_agr_inputs
  model_name += ('_' + flags.hidden_aggreg) if flags.hidden_aggreg else ''
  model_name += (
      '-add_%d-conf_%.2f-iterCls_%d-iterAgr_%d-batchCls_%d' %
      (flags.num_samples_to_label, flags.min_confidence_new_label,
       flags.max_num_iter_cls, flags.max_num_iter_agr, flags.batch_size_cls))
  model_name += (('-wdecayCls_%.4f' %
                  flags.weight_decay_cls) if flags.weight_decay_cls else '')
  model_name += (('-wdecayAgr_%.4f' %
                  flags.weight_decay_agr) if flags.weight_decay_agr else '')
  model_name += '-LL_%s_LU_%s_UU_%s' % (str(
      flags.reg_weight_ll), str(flags.reg_weight_lu), str(flags.reg_weight_uu))
  model_name += '-perfAgr' if flags.use_perfect_agreement else ''
  model_name += '-perfCls' if flags.use_perfect_classifier else ''
  model_name += '-keepProp' if flags.keep_label_proportions else ''
  model_name += '-PenNegAgr' if flags.penalize_neg_agr else ''
  model_name += '-VAT' if flags.reg_weight_vat > 0 else ''
  # 'ENT' is appended directly after '-VAT' with no separator ('-VATENT').
  model_name += 'ENT' if flags.reg_weight_vat > 0 and flags.use_ent_min else ''
  model_name += '-transd' if not flags.inductive else ''
  model_name += '-L2' if flags.use_l2_cls else '-CE'
  model_name += '-graph' if flags.use_graph else '-noGraph'
  model_name += '-rowNorm' if flags.row_normalize else ''
  model_name += '-seed_' + str(flags.seed)
  model_name += flags.experiment_suffix
  return model_name


def _prepare_output_dirs(flags, model_name):
  """Compute the per-run output directories and create the ones needed here.

  Args:
    flags: Object exposing `output_dir`, `data_output_dir` and `dataset_name`.
    model_name: Run name, e.g. as returned by `_build_model_name`.

  Returns:
    A `(summary_dir, checkpoints_dir, data_dir)` tuple of paths. Only
    `checkpoints_dir` and `data_dir` are created; `summary_dir` is only
    computed.
  """
  summary_dir = os.path.join(flags.output_dir, 'summaries', flags.dataset_name,
                             model_name)
  checkpoints_dir = os.path.join(flags.output_dir, 'checkpoints',
                                 flags.dataset_name, model_name)
  data_dir = os.path.join(flags.data_output_dir, 'data_checkpoints',
                          flags.dataset_name, model_name)
  # `exist_ok=True` avoids the race in `if not exists(d): makedirs(d)`, where
  # another process creating `d` in between would raise FileExistsError.
  for directory in (checkpoints_dir, data_dir):
    os.makedirs(directory, exist_ok=True)
  return summary_dir, checkpoints_dir, data_dir


def main(argv):
  """Load the dataset, build the classifier and agreement models, and train.

  Args:
    argv: Command-line arguments left after flag parsing; only the program
      name is allowed.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if FLAGS.logging_config:
    print('Setting logging configuration: ', FLAGS.logging_config)
    config.fileConfig(FLAGS.logging_config)

  # Set random seed.
  np.random.seed(FLAGS.seed)
  tf.set_random_seed(FLAGS.seed)

  ############################################################################
  #                               DATA                                       #
  ############################################################################
  # Load data. The GCN classifier needs its own dataset container class.
  data_class = GCNDataset if FLAGS.model_cls == 'gcn' else PlanetoidDataset
  data = load_data_planetoid(
      name=FLAGS.dataset_name,
      path=FLAGS.data_path,
      row_normalize=FLAGS.row_normalize,
      data_container_class=data_class)

  # Potentially add noisy edges. This can be used to assess the robustness of
  # GAM to noisy edges. See `Robustness` section of our paper.
  if FLAGS.target_ratio_correct:
    data = add_noisy_edges(data, FLAGS.target_ratio_correct)

  ############################################################################
  #                            PREPARE OUTPUTS                               #
  ############################################################################
  # Put together parameters to create a model name.
  model_name = _build_model_name(FLAGS)
  logging.info('Model name: %s', model_name)

  # Create directories for model checkpoints and self-labeled data backup.
  summary_dir, checkpoints_dir, data_dir = _prepare_output_dirs(
      FLAGS, model_name)

  ############################################################################
  #                            MODEL SETUP                                   #
  ############################################################################
  # Create classification model.
  model_cls = get_model_cls(
      model_name=FLAGS.model_cls,
      data=data,
      dataset_name=FLAGS.dataset_name,
      hidden=FLAGS.hidden_cls)

  # Create agreement model.
  model_agr = get_model_agr(
      model_name=FLAGS.model_agr,
      dataset_name=FLAGS.dataset_name,
      hidden_aggreg=FLAGS.hidden_aggreg,
      aggregation_agr_inputs=FLAGS.aggregation_agr_inputs,
      hidden=FLAGS.hidden_agr)

  # Train.
  trainer = TrainerCotraining(
      model_cls=model_cls,
      model_agr=model_agr,
      max_num_iter_cotrain=FLAGS.max_num_iter_cotrain,
      min_num_iter_cls=FLAGS.min_num_iter_cls,
      max_num_iter_cls=FLAGS.max_num_iter_cls,
      num_iter_after_best_val_cls=FLAGS.num_iter_after_best_val_cls,
      min_num_iter_agr=FLAGS.min_num_iter_agr,
      max_num_iter_agr=FLAGS.max_num_iter_agr,
      num_iter_after_best_val_agr=FLAGS.num_iter_after_best_val_agr,
      num_samples_to_label=FLAGS.num_samples_to_label,
      min_confidence_new_label=FLAGS.min_confidence_new_label,
      keep_label_proportions=FLAGS.keep_label_proportions,
      num_warm_up_iter_agr=FLAGS.num_warm_up_iter_agr,
      optimizer=tf.train.AdamOptimizer,
      gradient_clip=FLAGS.gradient_clip,
      batch_size_agr=FLAGS.batch_size_agr,
      batch_size_cls=FLAGS.batch_size_cls,
      learning_rate_cls=FLAGS.learning_rate_cls,
      learning_rate_agr=FLAGS.learning_rate_agr,
      enable_summaries=True,
      enable_summaries_per_model=True,
      summary_dir=summary_dir,
      summary_step_cls=FLAGS.summary_step_cls,
      summary_step_agr=FLAGS.summary_step_agr,
      logging_step_cls=FLAGS.logging_step_cls,
      logging_step_agr=FLAGS.logging_step_agr,
      eval_step_cls=FLAGS.eval_step_cls,
      eval_step_agr=FLAGS.eval_step_agr,
      checkpoints_dir=checkpoints_dir,
      checkpoints_step=1,
      data_dir=data_dir,
      abs_loss_chg_tol=1e-10,
      rel_loss_chg_tol=1e-7,
      loss_chg_iter_below_tol=30,
      use_perfect_agr=FLAGS.use_perfect_agreement,
      use_perfect_cls=FLAGS.use_perfect_classifier,
      warm_start_cls=FLAGS.warm_start_cls,
      warm_start_agr=FLAGS.warm_start_agr,
      ratio_valid_agr=FLAGS.ratio_valid_agr,
      max_samples_valid_agr=FLAGS.max_samples_valid_agr,
      weight_decay_cls=FLAGS.weight_decay_cls,
      weight_decay_schedule_cls=FLAGS.weight_decay_schedule_cls,
      weight_decay_schedule_agr=FLAGS.weight_decay_schedule_agr,
      weight_decay_agr=FLAGS.weight_decay_agr,
      reg_weight_ll=FLAGS.reg_weight_ll,
      reg_weight_lu=FLAGS.reg_weight_lu,
      reg_weight_uu=FLAGS.reg_weight_uu,
      num_pairs_reg=FLAGS.num_pairs_reg,
      reg_weight_vat=FLAGS.reg_weight_vat,
      use_ent_min=FLAGS.use_ent_min,
      penalize_neg_agr=FLAGS.penalize_neg_agr,
      use_l2_cls=FLAGS.use_l2_cls,
      first_iter_original=FLAGS.first_iter_original,
      inductive=FLAGS.inductive,
      seed=FLAGS.seed,
      eval_acc_pred_by_agr=FLAGS.eval_acc_pred_by_agr,
      num_neighbors_pred_by_agr=FLAGS.num_neighbors_pred_by_agr,
      lr_decay_rate_cls=FLAGS.lr_decay_rate_cls,
      lr_decay_steps_cls=FLAGS.lr_decay_steps_cls,
      lr_decay_rate_agr=FLAGS.lr_decay_rate_agr,
      lr_decay_steps_agr=FLAGS.lr_decay_steps_agr,
      load_from_checkpoint=FLAGS.load_from_checkpoint,
      use_graph=FLAGS.use_graph,
      always_agree=FLAGS.always_agree,
      add_negative_edges_agr=FLAGS.add_negative_edges_agr)

  ############################################################################
  #                            TRAIN                                         #
  ############################################################################
  trainer.train(data)
Example #11
0
def main(argv):
  #######################################################################
  # Initial Setup. Logging, Flags, Random seeds.
  #######################################################################
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  absl_logging.use_python_logging()
  flags_dict = {
      flag.name: flag.value
      for flag in FLAGS.flags_by_module_dict()[argv[0]]
  }

  if FLAGS.use_subset:
    message = (f"{colorama.Back.RED}{colorama.Fore.WHITE}"
               f"{colorama.Style.BRIGHT}USING A SUBSET OF THE DATASET"
               f"{colorama.Style.RESET_ALL}")
    LOGGER.warning(
        message
    )

  utils.log_module_args(LOGGER, argv[0])
  if not FLAGS.output_dir.startswith("gs://"):
    utils.check_exists(FLAG_OUTPUT_DIR.value)
    if not tf.io.gfile.isdir(FLAG_OUTPUT_DIR.value):
      raise RuntimeError("Output dir needs to be a directory.")

  tf.random.set_seed(FLAG_RANDOM_SEED.value)
  np.random.seed(FLAG_RANDOM_SEED.value)

  # Prepare the instance output directory path and save the config there
  folder_name = time.strftime(
      f"{FLAG_RUN_NAME.value}_{FLAG_APPROACH_TYPE.value}_%Y%m%d-%H%M%S"
  )
  instance_output_dir = os.path.join(FLAG_OUTPUT_DIR.value, folder_name).strip()
  if not instance_output_dir.endswith("/"):
    instance_output_dir += "/"
  json_target = os.path.join(instance_output_dir, "training_params.json")
  if not json_target.strip().startswith("gs://"):
    subprocess.check_call(["mkdir", "-p", instance_output_dir])
  utils.to_json_file(json_target, instance_output_dir)

  ##############################################################################
  # Initialization and Configuration of the Devices.
  ##############################################################################
  tpu_setup = None
  # current_acelerator_type is always "CPU" in the beginning with TPUs
  if tf_utils.current_accelerator_type() == "CPU":
    tpu_setup = tf_utils.init_tpus()

  LOGGER.debug("Devices we are computing on:\n%s",
               utils.wrap_iterable(map(str, tf_utils.devices_to_use())))
  LOGGER.debug("All devices:")
  LOGGER.debug(tf_utils.device_mapping())

  if tf_utils.current_accelerator_type() == "GPU":
    tf.config.set_soft_device_placement(True)

  if tf_utils.current_accelerator_type() != "TPU":
    tf.debugging.set_log_device_placement(True)

  if FLAG_DISTRIBUTE_MODE.value in constants.PURE_DATA_PARALLEL_STRATEGIES:
    actual_num_replicas = len(tf_utils.devices_to_use())
  elif FLAG_DISTRIBUTE_MODE.value in constants.DATA_PARALLEL_DMC:
    actual_num_replicas = FLAG_NUM_REPLICAS.value
  else:
    actual_num_replicas = 1

  ##############################################################################
  # We load the retriever model if it is needed.
  ##############################################################################
  # Not currently used.

  retriever = None
  # if (FLAG_APPROACH_TYPE.value ==
  #     constants.ApproachTypeChoices.lm_and_realm):
  #   raise NotImplementedError("This part needs to be tested anew.")
    # config_path = FLAG_RETRIEVER_CONFIG_PATH.value
    # realm_save = tf_utils.REALMSave(**utils.from_json_file(config_path))
    #
    # # Approx 15 min when not in dev mode, on CPU
    # with utils.log_duration(LOGGER, "main",
    #                         "whole of BERTScaNNRetriever.__init__",
    #                         logging.INFO):
    #   scann_config = retrievers.ScannConfig(
    #       **utils.from_json_file(FLAG_SCANN_CONFIG_PATH.value))
    #   retriever = retrievers.BERTScaNNRetriever(
    #       retriever_module_path=realm_save.query_embedder_path,
    #       block_records_path=realm_save.text_records,
    #       num_block_records=realm_save.num_block_records,
    #       mode=tf.estimator.ModeKeys.EVAL,
    #       scann_config=scann_config)

  # elif (FLAG_APPROACH_TYPE.value ==
  #       constants.ApproachTypeChoices.cached_realm):
  #   raise NotImplementedError("This part needs to be tested anew.")
    # config_path = FLAG_RETRIEVER_CONFIG_PATH.value
    # realm_save = tf_utils.REALMSave(**utils.from_json_file(config_path))
    #
    # # Approx 15 min when not in dev mode, on CPU
    # with utils.log_duration(LOGGER, "main",
    #                         "whole of FullyCachedRetriever.__init__",
    #                         logging.INFO):
    #
    #   retriever = retrievers.FullyCachedRetriever(
    #       db_path=FLAG_FULLYCACHED_H5_PATH.value,
    #       block_records_path=realm_save.text_records,
    #       num_block_records=realm_save.num_block_records,
    #       )

  ##############################################################################
  # Distributed training task
  ##############################################################################
  if FLAG_TASK.value == constants.TaskChoices.train:
    with utils.log_duration(LOGGER, "main", "Load model"):
      utils.print_mem("before loading model", LOGGER)
      model_specific = task_specific.load_model(FLAG_MODEL_LOAD_PATH.value,
                                                FLAG_MODEL_KEY.value,
                                                FLAG_DISTRIBUTE_MODE.value,
                                                tpu_setup,
                                                FLAG_NUM_REPLICAS.value)
      utils.print_mem("after loading model", LOGGER)
      model_or_replicas = model_specific.model
      if isinstance(model_or_replicas, list):
        model_or_replicas: List[transformers.TFGPT2LMHeadModel]
      else:
        model_or_replicas: transformers.TFGPT2LMHeadModel

      tokenizer = model_specific.tokenizer

      def make_optimizer():
        return tensor2tensor.utils.adafactor.AdafactorOptimizer(
            learning_rate=FLAG_LEARNING_RATE.value)

      if model_specific.strategy:
        with model_specific.strategy.scope():
          optimizer = make_optimizer()
      else:
        optimizer = make_optimizer()

    ############################################################################
    # Prepare the dataset functions
    ############################################################################
    rg = np.random.default_rng(FLAG_RANDOM_SEED.value)

    def call_lm_preproc(
        repeat,
        split,
        random_seed
    ):
      """Using functools.partial prevents the linter from doing its job."""
      if FLAG_DATASET_NAME.value == constants.DatasetNameChoices.kilt_eli5:
        return task_specific.create_lm_ds_kilt_eli5(
            tokenizer=tokenizer,
            context_window_size=(
                model_or_replicas[0].config.n_positions
                if isinstance(model_or_replicas, list)
                else model_or_replicas.config.n_positions
            ),
            dataset_name=FLAG_DATASET_NAME.value,
            # Batches are split over the replicas:
            batch_size=FLAG_BATCH_SIZE.value * actual_num_replicas,
            db_path=FLAG_DB_PATH.value,
            random_seed=random_seed,
            use_subset=FLAG_USE_SUBSET.value,
            subset_size=FLAG_SUBSET_SIZE.value,
            use_helper_words=FLAG_USE_HELPER_WORDS.value,
            approach_type=FLAG_APPROACH_TYPE.value,
            num_retrievals=FLAG_NUM_RETRIEVALS.value,
            retrieval_temperature=FLAG_RETRIEVAL_TEMPERATURE.value,
            retriever=retriever,
            repeat=repeat,
            split=split,
            enable_debug_checks=FLAG_DATASET_DEBUG.value,
            retrieval_bank_size=FLAG_RETRIEVAL_BANK_SIZE.value,
            dataset_type=FLAG_DATASET_TYPE.value,
            qty_shuffle=FLAG_QTY_SHUFFLE.value,
            tfr_prefix=FLAG_TFR_PREFIX.value,
            max_length_generation=FLAG_MAX_LENGTH_GENERATION.value,
        )
      else:
        raise NotImplementedError(
            f"FLAG_DATASET_NAME.value unsupported: `{FLAG_DATASET_NAME.value}`"
        )

    make_training_dataset: Callable[Ellipsis, tf.data.Dataset] = functools.partial(
        call_lm_preproc,
        split="train",
        repeat=False,
    )
    make_eval_dataset: Callable[Ellipsis, tf.data.Dataset] = functools.partial(
        call_lm_preproc,
        split="eval",
        repeat=True,
    )

    ############################################################################
    # Prepare the step functions
    ############################################################################
    utils.check_contained(
        FLAG_DISTRIBUTE_MODE.value, constants.DistributeModeChoices.choices()
    )
    tf_function_flags = dict(
        experimental_compile=FLAG_EXPERIMENTAL_COMPILE.value,
        reduce_retracing=not FLAG_INPUT_FIXED_SIZE.value
    )

    if (FLAG_DISTRIBUTE_MODE.value ==
        constants.DistributeModeChoices.split_and_data_parallel):
      if not isinstance(model_or_replicas, list):
        raise RuntimeError(type(model_or_replicas))
      training_step = build_manual_data_parallel_training_step(
          model_or_replicas, optimizer, tf_function_flags
      )

    else:
      training_step = build_regular_training_step(
          model_or_replicas,
          optimizer,
          strategy=model_specific.strategy,
          tf_function_kwargs=tf_function_flags
      )

    evaluation_step = build_evaluation_step(
        model_or_replicas, tf_function_flags
    )

    secs_since_last_ckpt = time.time()
    # Model checkpoints are saved to the tmp_directory and then rsynced to GCS
    ##########################################################################
    # Prepare the different logging facilities
    ##########################################################################
    train_log_dir = os.path.join(instance_output_dir, "tensorboard", "train")
    eval_log_dir = os.path.join(instance_output_dir, "tensorboard", "eval")
    flags_log_dir = os.path.join(instance_output_dir, "tensorboard", "params")
    writers = dict(
        train=tf.summary.create_file_writer(train_log_dir),
        eval=tf.summary.create_file_writer(eval_log_dir),
        flags=tf.summary.create_file_writer(flags_log_dir)
    )
    with writers["flags"].as_default():
      tf.summary.text(
          "Flags",
          # Tensorboard takes Markdown:
          json.dumps(flags_dict, indent=4).replace("\n", "\n\n"),
          step=0
          )

    ma_loss = dict(
        train=utils.MovingAverage(0.9),
        eval=utils.MovingAverage(0.9)
        )
    step_counters = dict(train=0, eval=0)
    batch_counters = dict(train=0, eval=0)
    prev_batch_end = time.time()

    # The eval ds has no real concept of epoch, repeats forever, shuffling
    # each time it reaches its end
    with utils.log_duration(LOGGER, "main", "All of make_eval_dataset"):
      eval_ds_instance = make_eval_dataset(
          random_seed=rg.integers(-2**63, 2**63 - 1),
      )
    LOGGER.debug("Distributing the eval dataset to the replicas.")
    if FLAG_DATASET_TYPE.value == "tfr":
      eval_ds_instance = (
          model_specific.strategy.experimental_distribute_dataset(
              eval_ds_instance
          )
      )

    LOGGER.debug("Done distributing the eval dataset to the replcias.")
    eval_ds_instance = iter(eval_ds_instance)

    ##########################################################################
    # Training Loop
    ##########################################################################
    for epoch in itertools.count():
      ####################################################################
      # Epoch Setup
      ####################################################################
      LOGGER.debug("EPOCH %d START", epoch)
      # Shuffle differently every epoch
      with utils.log_duration(
          LOGGER, "main", "All of make_training_dataset"
      ):
        train_ds_instance = make_training_dataset(
            random_seed=rg.integers(-2**63, 2**63 - 1),
        )
      LOGGER.debug(
          "Attempting to distribute the training dataset to the replicas."
      )
      if FLAG_DATASET_TYPE.value == "tfr":
        train_ds_instance = (
            model_specific.strategy.experimental_distribute_dataset(
                train_ds_instance
            )
        )

      LOGGER.debug(
          "Done distributing the training dataset to the replicas."
      )
      train_ds_instance = iter(train_ds_instance)

      # This allows us to see if we reached the end of the training iterator,
      # in which case "did_at_least_one_training_batch == False".
      # We could also test that it did all the batches, to similar results.
      did_at_least_one_training_batch = True
      split = "eval"
      while did_at_least_one_training_batch:
        # Invert split
        if split == "train":
          split = "eval"
        else:
          split = "train"

        # Prepare to test if we did at least one training batch
        if split == "train":
          did_at_least_one_training_batch = False

        if split == "train":
          dataset_iterator = itertools.islice(
              train_ds_instance, FLAG_BATCHES_BETWEEN_EVALS.value
          )
        else:
          # The evaluation DS is tiny, so we reshuffle and take a random
          dataset_iterator = itertools.islice(
              eval_ds_instance, FLAG_NUMBER_EVAL_BATCHES.value
          )

        LOGGER.debug("Batching")
        for batch in dataset_iterator:
          # LOGGER.debug("Input sentence:\n\"%s\"",
          #              tokenizer.decode([x for x in batch["input_ids"][0]
          #                                if x != tokenizer.eos_token_id]))
          # LOGGER.debug("Label:\n\"%s\"",
          #              tokenizer.decode([(x if x != -100 else 0)
          #                                for x in batch["label_ids"][0]]))

          if FLAG_DATASET_TYPE.value != "tfr":
            batch = (
                model_specific.strategy
                .experimental_distribute_values_from_function(
                    tf_utils.make_dict_distribute_fn(batch)
                ))

          # We only care about training epochs as, obviously, we don't train
          # over eval samples; the number of  eval samples seen only
          # contributes to lowering the variance in the evaluation of when to
          # do early stopping.
          if split == "train":
            did_at_least_one_training_batch = True

          input_ids = batch["input_ids"]
          label_ids = batch["label_ids"]

          ####################################################################
          # Training Step
          ####################################################################
          step_counters[split] += (
              FLAG_BATCH_SIZE.value * actual_num_replicas
          )

          if split == "train":
            batch_counters[split] += 1
            training_kwargs = dict(
                input_ids=input_ids,
                label_ids=label_ids,
            )

            if model_specific.strategy:
              utils.print_mem("before running", LOGGER)

              LOGGER.debug("Training, Calling strategy.run")
              loss = model_specific.strategy.run(
                  training_step,
                  kwargs=training_kwargs
              )
              LOGGER.debug("Training, Done with strategy.run")
              utils.print_mem("after running", LOGGER)

            else:
              loss = training_step(**training_kwargs)  # pytype: disable=wrong-arg-count
              # If we are in the strategy-free data parallel mode, we need
              # to change the weights of all replicas to those of the model at
              # index 0
              if (
                  FLAG_DISTRIBUTE_MODE.value ==
                  constants.DistributeModeChoices.split_and_data_parallel
              ):
                for replica in model_or_replicas[1:]:
                  replica.set_weights(model_or_replicas[0].get_weights())

          ####################################################################
          # Evaluation Step
          ####################################################################
          elif split == "eval":
            evaluation_kwargs = dict(
                input_ids=input_ids,
                label_ids=label_ids,
            )

            if model_specific.strategy:
              loss = model_specific.strategy.run(
                  evaluation_step,
                  kwargs=evaluation_kwargs
              )
            else:
              loss = evaluation_step(**evaluation_kwargs)
          else:
            raise ValueError(f"Unexpected value for split: {split}")

          ####################################################################
          # Logging
          ####################################################################
          if (FLAG_DISTRIBUTE_MODE.value in
              constants.PURE_DATA_PARALLEL_STRATEGIES):
            utils.check_equal(len(loss.values), actual_num_replicas)
            LOGGER.debug("Split: %s", split)
            LOGGER.debug("Real num replicas: %s", actual_num_replicas)
            LOGGER.debug("Loss: %s", loss)
            LOGGER.debug("Loss values: %s", loss.values)

            average_loss = float(tf.math.reduce_mean(loss.values).numpy())
          else:
            average_loss = float(loss.numpy())

          # tf.debugging.check_numerics(loss)
          now = time.time()
          batch_duration = now - prev_batch_end
          prev_batch_end = now
          ma_loss[split].update(average_loss)

          # Actual logging
          LOGGER.info("Epoch: # %d", epoch)
          LOGGER.info("Tensorboard_dir: %s", instance_output_dir)
          LOGGER.info("Batch: %s # %d", split, batch_counters[split])
          LOGGER.info("Step: %s # %d", split, step_counters[split])
          if FLAG_USE_SUBSET.value:
            LOGGER.warning(">> USING A SUBSET OF THE DATASET <<")
          LOGGER.info(
              "%(split)s Batch loss:           %(metric)f",
              dict(split=split, metric=average_loss)
          )
          LOGGER.info(
              "%(split)s Moving average loss:  %(metric)f",
              dict(split=split, metric=ma_loss[split].average)
          )
          LOGGER.info(
              "%(split)s Moving average ppl:   %(metric)f",
              dict(split=split, metric=np.exp(ma_loss[split].average))
          )
          LOGGER.info(
              "%(split)s Batch duration:       %(duration)s",
              dict(
                  split=split,
                  duration=utils.TimeStamp.from_seconds(
                      batch_duration).format()
              )
          )
          if FLAG_DISTRIBUTE_MODE.value in constants.DATA_PARALLEL_DMC:
            LOGGER.info(
                "%(split)s Duration per sample:  %(duration)s",
                dict(
                    split=split,
                    duration=utils.TimeStamp.from_seconds(
                        batch_duration / (
                            FLAG_BATCH_SIZE.value * actual_num_replicas
                        )
                    )
                )
            )

          # Write to Tensorboard
          with writers[split].as_default():
            tf.summary.scalar(
                f"Loss/{split}", average_loss, step_counters[split]
            )
            tf.summary.scalar(
                f"PPL/{split}", np.exp(average_loss), step_counters[split]
            )
          writers[split].flush()

          # Save every 5 min
          if (time.time() - secs_since_last_ckpt) / (60 * 20) >= 1:
            secs_since_last_ckpt = time.time()
            save_model(
                train_steps=step_counters["train"],
                model_or_replicas=model_or_replicas,
                instance_output_dir=instance_output_dir
            )

        secs_since_last_ckpt = time.time()
        save_model(
            train_steps=step_counters["train"],
            model_or_replicas=model_or_replicas,
            instance_output_dir=instance_output_dir
        )
    #############################################################
    # Post Training Cleanup
    #######################################################################
    for writer in writers.values():
      writer.close()
Example #12
0
def _run_benchmarks(argv):
    """Run every benchmark selected through the command-line flags.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Returns:
      Whatever `_benchmark.RunSpecifiedBenchmarks()` returns.

    Raises:
      app.UsageError: If any positional argument besides the program name is
        supplied.
    """
    extra_args = argv[1:]
    if extra_args:
        raise app.UsageError("Too many command-line arguments.")
    return _benchmark.RunSpecifiedBenchmarks()
Example #13
0
def main(argv):
    """Run the CFQ pipeline: preprocess, t2t-datagen, train, decode, evaluate.

    Preprocessing is skipped when the token vocabulary file already exists in
    `FLAGS.save_path`. Each t2t tool runs as a subprocess; a non-zero exit
    status raises `subprocess.CalledProcessError`.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Pass every path flag through update_flag_value before anything uses it.
    for flag_name in ('save_path', 'eval_results_path', 'questions_path',
                      'golden_answers_path', 'inferred_answers_path'):
        setattr(FLAGS, flag_name,
                update_flag_value(getattr(FLAGS, flag_name)))

    vocab_file = os.path.join(FLAGS.save_path, 'vocab.cfq.tokens')
    if not os.path.exists(vocab_file):
        print_status('Running preprocessing')
        dataset = preprocessor.get_dataset_from_tfds(FLAGS.dataset,
                                                     FLAGS.split)
        preprocessor.write_dataset(dataset, FLAGS.save_path)
        token_vocab = preprocessor.get_token_vocab(FLAGS.save_path)
        preprocessor.write_token_vocab(token_vocab, FLAGS.save_path)
    else:
        print_status('Skipping preprocessing')

    script_dir = os.path.dirname(os.path.realpath(__file__))
    t2t_usr_dir = os.path.join(script_dir, 'cfq')
    output_dir = os.path.join(FLAGS.save_path, 'output')

    # Flags passed, in this order, to every t2t tool below.
    common_args = [
        '--t2t_usr_dir=' + t2t_usr_dir,
        '--data_dir=' + FLAGS.save_path,
        '--problem=' + T2T_PROBLEM,
    ]

    print_status('Running t2t-datagen')
    # NOTE: This one skips automatically if the files exist.
    # TODO(danielfurrer): Sometimes one of the steps here will crash with a
    #                     CUBLAS_STATUS_NOT_INITIALIZED error. I suspect this is
    #                     related to the subprocess calls here. Rerunning seems to
    #                     solve the problem (perhaps because t2t-datagen returns
    #                     quickly if it was completed before).
    datagen_cmd = ['t2t-datagen', *common_args, '--tmp_dir=/tmp/cfq_tmp']
    subprocess.run(datagen_cmd, check=True)

    print_status('Running t2t-trainer')
    model_args = ['--model=' + FLAGS.model,
                  '--hparams_set=' + FLAGS.hparams_set]
    trainer_cmd = ['t2t-trainer', *common_args, *model_args,
                   '--output_dir=' + output_dir,
                   '--train_steps=%s' % FLAGS.train_steps]
    subprocess.run(trainer_cmd, check=True)

    print_status('Running t2t-decoder')
    checkpoint_name = 'model.ckpt-%s' % FLAGS.train_steps
    checkpoint_path = os.path.join(output_dir, checkpoint_name)
    decoder_cmd = ['t2t-decoder', *common_args, *model_args,
                   '--checkpoint_path=' + checkpoint_path,
                   '--decode_from_file=' + FLAGS.questions_path,
                   '--decode_to_file=' + FLAGS.inferred_answers_path,
                   '--output_dir=' + output_dir]
    subprocess.run(decoder_cmd, check=True)

    print_status('Calculating accuracy')
    accuracy_result = evaluator.get_accuracy_result(
        FLAGS.questions_path, FLAGS.golden_answers_path,
        FLAGS.inferred_answers_path)
    evaluator.write_accuracy_result(accuracy_result,
                                    FLAGS.eval_results_path,
                                    print_output=True)
Example #14
0
def main(argv: Sequence[str]) -> None:
    """Export the EdgeTPU BERT SQuAD span labeler as a TFLite model.

    Builds the student model from `params.EdgeTPUBERTCustomParams` (overridden
    by flags) and optionally restores weights from `FLAGS.model_checkpoint`.
    It then wraps the model for serving with a fixed sequence length and batch
    size, and converts it via a temporary SavedModel with
    `tf.lite.TFLiteConverter`. Quantization follows `FLAGS.quantization_method`.
    The result is written to `<FLAGS.export_path>/model.tflite`.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Set up experiment params and load the configs from file/files.
    experiment_params = params.EdgeTPUBERTCustomParams()
    experiment_params = utils.config_override(experiment_params, FLAGS)

    # change the input mask type to tf.float32 to avoid additional casting op.
    experiment_params.student_model.encoder.mobilebert.input_mask_dtype = 'float32'

    # Experiments indicate using -120 as the mask value for Softmax is good enough
    # for both int8 and bfloat. So we set quantization_friendly to True for both
    # quant and float model.
    pretrainer_model = model_builder.build_bert_pretrainer(
        experiment_params.student_model,
        name='pretrainer',
        quantization_friendly=True)

    encoder_network = pretrainer_model.encoder_network
    model = models.BertSpanLabeler(
        network=encoder_network,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01))

    # Load model weights.
    if FLAGS.model_checkpoint is not None:
        checkpoint_dict = {'model': model}
        checkpoint = tf.train.Checkpoint(**checkpoint_dict)
        checkpoint.restore(
            FLAGS.model_checkpoint).assert_existing_objects_matched()

    model_for_serving = build_model_for_serving(model, FLAGS.sequence_length,
                                                FLAGS.batch_size)
    model_for_serving.summary()

    def _representative_dataset():
        """Yield up to 100 SQuAD training examples for int8 calibration."""
        dataset_params = question_answering_dataloader.QADataConfig()
        dataset_params.input_path = SQUAD_TRAIN_SPLIT
        dataset_params.drop_remainder = False
        dataset_params.global_batch_size = 1
        dataset_params.is_training = True

        dataset = orbit.utils.make_distributed_dataset(
            tf.distribute.get_strategy(), build_inputs, dataset_params)
        for example in dataset.take(100):
            inputs = example[0]
            input_word_ids = inputs['input_word_ids']
            input_mask = inputs['input_mask']
            input_type_ids = inputs['input_type_ids']
            yield [input_word_ids, input_mask, input_type_ids]

    # TODO(b/194449109): Need to save the model to file and then convert tflite
    # with 'tf.lite.TFLiteConverter.from_saved_model()' to get the expected
    # accuracy
    # Keep the TemporaryDirectory alive for as long as the SavedModel is needed.
    # Taking `.name` of an unreferenced TemporaryDirectory lets the object be
    # finalized at once: the directory is deleted, and whatever is saved to
    # that path later is never cleaned up. The converter reads the SavedModel
    # during convert(), so conversion must also happen inside this block.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_for_serving.save(tmp_dir)

        converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
        if FLAGS.quantization_method in ['full-integer', 'hybrid']:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if FLAGS.quantization_method == 'full-integer':
            # Integer-only kernels need a calibration dataset; the output stays
            # float32.
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS_INT8
            ]
            converter.inference_input_type = tf.int8
            converter.inference_output_type = tf.float32
            converter.representative_dataset = _representative_dataset

        tflite_quant_model = converter.convert()

    export_model_path = os.path.join(FLAGS.export_path, 'model.tflite')
    with tf.io.gfile.GFile(export_model_path, 'wb') as f:
        f.write(tflite_quant_model)
    logging.info('Successfully save the tflite to %s', export_model_path)
Example #15
0
def main(argv):  # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    """Manually create, clean up, or inspect Traffic Director test resources.

    The action is chosen by `_CMD`:
      * 'create' / 'cycle': set up Traffic Director for gRPC. If `_SECURITY`
        names a security mode, also configure server- and client-side
        security policies for it.
      * 'cleanup' / 'cycle': force-clean all Traffic Director resources.
      * 'backends-add': attach the test server's NEG backends and wait for
        them to become healthy.
      * 'backends-cleanup': remove all backends from the backend service.
      * 'unused-xds-port': log an unused forwarding rule port.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Must be called before KubernetesApiManager or GcpApiManager init.
    xds_flags.set_socket_default_timeout_from_flag()

    command = _CMD.value
    security_mode = _SECURITY.value

    project: str = xds_flags.PROJECT.value
    network: str = xds_flags.NETWORK.value

    # Resource names.
    resource_prefix: str = xds_flags.RESOURCE_PREFIX.value
    resource_suffix: str = xds_flags.RESOURCE_SUFFIX.value

    # Test server
    server_name = xds_flags.SERVER_NAME.value
    server_port = xds_flags.SERVER_PORT.value
    server_maintenance_port = xds_flags.SERVER_MAINTENANCE_PORT.value
    server_xds_host = xds_flags.SERVER_XDS_HOST.value
    server_xds_port = xds_flags.SERVER_XDS_PORT.value
    server_namespace = _KubernetesServerRunner.make_namespace_name(
        resource_prefix, resource_suffix)

    gcp_api_manager = gcp.api.GcpApiManager()

    if security_mode is None:
        td = traffic_director.TrafficDirectorManager(
            gcp_api_manager,
            project=project,
            network=network,
            resource_prefix=resource_prefix,
            resource_suffix=resource_suffix)
    else:
        td = traffic_director.TrafficDirectorSecureManager(
            gcp_api_manager,
            project=project,
            network=network,
            resource_prefix=resource_prefix,
            resource_suffix=resource_suffix)
        if server_maintenance_port is None:
            server_maintenance_port = \
                _KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT

    # Security mode -> (server_tls, server_mtls, client_tls, client_mtls).
    #   'mtls_error': error case; the server expects a client mTLS cert, but
    #       the client is configured only for TLS.
    #   'server_authz_error': error case; regular TLS, but the client policy
    #       uses an intentionally incorrect server_namespace (below), so the
    #       client does not authorize the server because of a mismatched SAN.
    security_settings = {
        'mtls': (True, True, True, True),
        'tls': (True, False, True, False),
        'plaintext': (False, False, False, False),
        'mtls_error': (True, True, True, False),
        'server_authz_error': (True, False, True, False),
    }

    try:
        if command in ('create', 'cycle'):
            logger.info('Create mode')
            if security_mode is None:
                logger.info('No security')
                td.setup_for_grpc(server_xds_host,
                                  server_xds_port,
                                  health_check_port=server_maintenance_port)
            elif security_mode in security_settings:
                server_tls, server_mtls, client_tls, client_mtls = (
                    security_settings[security_mode])
                # Each mode logs its own name (the 'server_authz_error'
                # branch previously mislabeled itself as 'mtls_error').
                logger.info('Setting up %s', security_mode)
                td.setup_for_grpc(server_xds_host,
                                  server_xds_port,
                                  health_check_port=server_maintenance_port)
                td.setup_server_security(server_namespace=server_namespace,
                                         server_name=server_name,
                                         server_port=server_port,
                                         tls=server_tls,
                                         mtls=server_mtls)
                if security_mode == 'server_authz_error':
                    client_namespace = (
                        f'incorrect-namespace-{rand.rand_string()}')
                else:
                    client_namespace = server_namespace
                td.setup_client_security(server_namespace=client_namespace,
                                         server_name=server_name,
                                         tls=client_tls,
                                         mtls=client_mtls)
            else:
                # Previously an unknown mode silently configured nothing.
                logger.warning('Unknown security mode: %s', security_mode)

            logger.info('Works!')
    except Exception:  # noqa pylint: disable=broad-except
        logger.exception('Got error during creation')

    if command in ('cleanup', 'cycle'):
        logger.info('Cleaning up')
        td.cleanup(force=True)

    if command == 'backends-add':
        logger.info('Adding backends')
        k8s_api_manager = k8s.KubernetesApiManager(
            xds_k8s_flags.KUBE_CONTEXT.value)
        k8s_namespace = k8s.KubernetesNamespace(k8s_api_manager,
                                                server_namespace)

        neg_name, neg_zones = k8s_namespace.get_service_neg(
            server_name, server_port)

        td.load_backend_service()
        td.backend_service_add_neg_backends(neg_name, neg_zones)
        td.wait_for_backends_healthy_status()
    elif command == 'backends-cleanup':
        td.load_backend_service()
        td.backend_service_remove_all_backends()
    elif command == 'unused-xds-port':
        try:
            unused_xds_port = td.find_unused_forwarding_rule_port()
            logger.info('Found unused forwarding rule port: %s',
                        unused_xds_port)
        except Exception:  # noqa pylint: disable=broad-except
            logger.exception("Couldn't find unused forwarding rule port")
Example #16
0
def main(argv: typing.List[str]):
    """Run this module's tests with pytest and exit with pytest's status.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If unexpected positional arguments are given; the
        message lists them.
    """
    unknown_args = argv[1:]
    if unknown_args:
        joined = ' '.join(unknown_args)
        raise app.UsageError("Unknown arguments: '{}'.".format(joined))
    exit_code = pytest.main([__file__, '-vv'])
    sys.exit(exit_code)
Example #17
0
def main(argv):
    """Create a HypeBot from `FLAGS.params` and run its interface loop.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if argv[1:]:
        raise app.UsageError('Too many command-line arguments.')
    bot = HypeBot(FLAGS.params)
    bot.interface.Loop()
Example #18
0
def main(argv):
    """Reject stray positional arguments, then delegate to `run()`.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    extra_args = argv[1:]
    if extra_args:
        raise app.UsageError("Too many command-line arguments.")
    run()
Example #19
0
def main(argv):
    """Train and/or evaluate a ResNet with a TPUEstimator on a TFDS dataset.

    `FLAGS.mode` selects the behaviour:
      * 'eval': evaluate each new checkpoint in `FLAGS.model_dir` until one
        reaches the total number of training steps.
      * 'train_then_eval': train, then run one evaluation.
      * anything else: train only.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Enable training summary.
    if FLAGS.train_summary_steps > 0:
        tf.config.set_soft_device_placement(True)

    builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
    builder.download_and_prepare()
    num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
    num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
    num_classes = builder.info.features['label'].num_classes

    train_steps = model_util.get_train_steps(num_train_examples)
    eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
    epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

    resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
    use_cifar_stem = FLAGS.image_size <= 32
    model = resnet.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                             width_multiplier=FLAGS.width_multiplier,
                             cifar_stem=use_cifar_stem)

    # An explicit (truthy) --checkpoint_steps wins; otherwise checkpoint every
    # --checkpoint_epochs epochs.
    checkpoint_steps = (FLAGS.checkpoint_steps
                        or FLAGS.checkpoint_epochs * epoch_steps)

    cluster = None
    if FLAGS.use_tpu and FLAGS.master is None:
        if not FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(cluster)
            tf.tpu.experimental.initialize_tpu_system(cluster)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    pipeline_config = tf.estimator.tpu.InputPipelineConfig
    if FLAGS.use_tpu:
        eval_input_mode = pipeline_config.SLICED
    else:
        eval_input_mode = pipeline_config.PER_HOST_V1
    tpu_config = tf.estimator.tpu.TPUConfig(
        iterations_per_loop=checkpoint_steps,
        eval_training_input_configuration=eval_input_mode)
    run_config = tf.estimator.tpu.RunConfig(
        tpu_config=tpu_config,
        model_dir=FLAGS.model_dir,
        save_summary_steps=checkpoint_steps,
        save_checkpoints_steps=checkpoint_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        master=FLAGS.master,
        cluster=cluster)

    model_fn = model_lib.build_model_fn(model, num_classes, num_train_examples)
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu)

    if FLAGS.mode != 'eval':
        estimator.train(data_lib.build_input_fn(builder, True),
                        max_steps=train_steps)
        if FLAGS.mode == 'train_then_eval':
            perform_evaluation(estimator=estimator,
                               input_fn=data_lib.build_input_fn(
                                   builder, False),
                               eval_steps=eval_steps,
                               model=model,
                               num_classes=num_classes)
        return

    # Continuous evaluation: poll for new checkpoints and stop once one has
    # reached the final training step.
    checkpoints = tf.train.checkpoints_iterator(run_config.model_dir,
                                                min_interval_secs=15)
    for ckpt in checkpoints:
        try:
            result = perform_evaluation(estimator=estimator,
                                        input_fn=data_lib.build_input_fn(
                                            builder, False),
                                        eval_steps=eval_steps,
                                        model=model,
                                        num_classes=num_classes,
                                        checkpoint_path=ckpt)
        except tf.errors.NotFoundError:
            # The checkpoint may disappear (e.g. garbage-collected) before we
            # read it; just wait for the next one.
            continue
        if result['global_step'] >= train_steps:
            return
Example #20
0
def main(argv):
    """Reject stray positional arguments, then start training via `train()`.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if argv[1:]:
        raise app.UsageError('Too many command-line arguments.')
    train()
Example #21
0
def main(argv):
    """Plot the CSV file named by `FLAGS.fname`.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    extra_args = argv[1:]
    if extra_args:
        raise app.UsageError('Too many command-line arguments.')
    csv_path = FLAGS.fname
    plot_csv_data(csv_path)
Example #22
0
def main(argv):
  """Build the frozen VGGish graph.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  extra_args = argv[1:]
  if extra_args:
    raise app.UsageError('Too many command-line arguments.')
  create_vggish_frozen_graph()
Example #23
0
def main(argv):
    """Run a federated-averaging experiment for the task named by `--task`.

    Client/server optimizers and learning-rate schedules come from flags and
    are captured by an iterative-process builder. That builder is passed,
    with the shared and task-specific flag values, to the selected task
    module's `run_federated`.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
      ValueError: If `FLAGS.task` is not a supported task.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    make_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags
    make_lr_schedule = optimizer_utils.create_lr_schedule_from_flags
    client_optimizer_fn = make_optimizer_fn('client')
    server_optimizer_fn = make_optimizer_fn('server')
    client_lr_schedule = make_lr_schedule('client')
    server_lr_schedule = make_lr_schedule('server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()

    # Each supported task's module exposes a `run_federated` entry point.
    task_modules = {
        'cifar100': federated_cifar100,
        'emnist_cr': federated_emnist,
        'emnist_ae': federated_emnist_ae,
        'shakespeare': federated_shakespeare,
        'stackoverflow_nwp': federated_stackoverflow,
        'stackoverflow_lr': federated_stackoverflow_lr,
    }
    task_module = task_modules.get(FLAGS.task)
    if task_module is None:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    task_module.run_federated(**shared_args, **task_args,
                              hparam_dict=hparam_dict)
Example #24
0
def main(argv):
    """Configure and run a federated-averaging training loop for `--task`.

    Optimizers and learning-rate schedules come from flags. The task's
    `configure_training` turns a `TaskSpec` into a runner spec, the
    hyperparameter flags are written out, and `training_loop.run` drives the
    rounds.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
      ValueError: If `FLAGS.task` is not a supported task.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    make_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags
    make_lr_schedule = optimizer_utils.create_lr_schedule_from_flags
    client_optimizer_fn = make_optimizer_fn('client')
    server_optimizer_fn = make_optimizer_fn('server')
    client_lr_schedule = make_lr_schedule('client')
    server_lr_schedule = make_lr_schedule('server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        if FLAGS.task not in ('shakespeare', 'stackoverflow_nwp'):
            client_weight_fn = None
        else:
            # Next-word-prediction tasks weight each client by the number of
            # tokens it processed.
            def client_weight_fn(local_outputs):
                return tf.cast(tf.squeeze(local_outputs['num_tokens']),
                               tf.float32)

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    # Lazily-evaluated configurers, so only the selected task's flags are read.
    task_configurers = {
        'cifar100': lambda: federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images),
        'emnist_cr': lambda: federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model),
        'emnist_ae': lambda: federated_emnist_ae.configure_training(
            task_spec),
        'shakespeare': lambda: federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length),
        'stackoverflow_nwp': lambda: federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples),
        'stackoverflow_lr': lambda: federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples),
    }
    configure_training = task_configurers.get(FLAGS.task)
    if configure_training is None:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    runner_spec = configure_training()

    _write_hparam_flags()

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
Example #25
0
def _evaluate_p13n_in_chunks(p13n_eval, model_weights, p13n_clients,
                             max_num_clients, clients_per_call):
  """Runs `p13n_eval` on up to `max_num_clients` clients, in chunks.

  `p13n_eval` is called once per chunk of at most `clients_per_call` clients
  taken from the front of `p13n_clients`. The metric leaves of successive
  results are concatenated along axis 0, so position `i` of every leaf still
  refers to the same client. Chunks that would be empty (when fewer than
  `max_num_clients` clients exist) are skipped rather than sent to
  `p13n_eval`.

  Args:
    p13n_eval: Computation built by `tff.learning.build_personalization_eval`,
      called as `p13n_eval(model_weights, client_datasets)`.
    model_weights: Global model weights passed unchanged to every call.
    p13n_clients: Sliceable sequence of per-client datasets.
    max_num_clients: Upper bound on the number of clients evaluated.
    clients_per_call: Maximum number of clients passed to a single call.

  Returns:
    The concatenated metrics as a nested dict (via `_asdict(recursive=True)`).

  Raises:
    ValueError: If there are no clients to evaluate.
  """
  num_clients = min(len(p13n_clients), max_num_clients)
  if num_clients <= 0:
    raise ValueError('No personalization clients to evaluate.')
  p13n_metrics = None
  for start in range(0, num_clients, clients_per_call):
    chunk = p13n_clients[start:min(start + clients_per_call, num_clients)]
    current_p13n_metrics = p13n_eval(model_weights, chunk)._asdict(
        recursive=True)
    if p13n_metrics is None:
      p13n_metrics = current_p13n_metrics
    else:
      p13n_metrics = tf.nest.map_structure(
          lambda a, b: tf.concat([a, b], axis=0), p13n_metrics,
          current_p13n_metrics)
  return p13n_metrics


def main(argv):
  """Trains a global EMNIST model with FedAvg and evaluates personalization.

  Every `FLAGS.num_rounds_per_p13n_eval` rounds, the current global model is
  personalized on held-out clients with two strategies ('sgd' and 'adam').
  The mean accuracy of the global model and of the SGD-personalized models
  after epoch 1 are then printed.

  Args:
    argv: Positional command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Personalization evaluation runs on at most `max_num_p13n_clients` clients,
  # sent to `p13n_eval` `p13n_clients_per_call` at a time (18 calls x 50).
  max_num_p13n_clients = 900
  p13n_clients_per_call = 50

  # Create datasets for training global and personalized models.
  federated_train_data, federated_p13n_data = _get_emnist_datasets()

  def model_fn():
    """Build a `tff.learning.Model` for training EMNIST."""
    keras_model = emnist_models.create_conv_dropout_model(only_digits=False)
    return tff.learning.from_keras_model(
        keras_model=keras_model,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        input_spec=federated_train_data[0].element_spec,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  # Build a standard federated averaging process for training the global model.
  client_opt = lambda: tf.keras.optimizers.SGD(learning_rate=0.02)
  server_opt = lambda: tf.keras.optimizers.SGD(learning_rate=1.0, momentum=0.9)
  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      client_optimizer_fn=client_opt,
      server_optimizer_fn=server_opt)

  # Initialize the server state of the FedAvg process.
  server_state = iterative_process.initialize()

  # Create a dictionary of two personalization strategies: one uses SGD while
  # the other one uses Adam optimizer to train a personalized model.
  #
  # Here `personalize_fn_dict` is an `OrderedDict` that maps a strategy name to
  # a no-argument function which, when being called, returns a `tf.function`
  # that represents a personalization strategy. Customers can define arbitrary
  # personalization strategies (i.e., `tf.function`s) that take a
  # `tff.learning.Model`, an unbatched `tf.data.Dataset` for train, an unbatched
  # `tf.data.Dataset` for test, (and an extra `context` object), and return the
  # personalization metrics (see `build_personalize_fn` for an example).
  personalize_fn_dict = collections.OrderedDict()
  sgd_opt = lambda: tf.keras.optimizers.SGD(learning_rate=0.02)
  personalize_fn_dict['sgd'] = functools.partial(
      build_personalize_fn,
      optimizer_fn=sgd_opt,
      train_batch_size=20,
      max_num_epochs=10,
      num_epochs_per_eval=1,
      test_batch_size=20)
  adam_opt = lambda: tf.keras.optimizers.Adam(learning_rate=0.02)
  personalize_fn_dict['adam'] = functools.partial(
      build_personalize_fn,
      optimizer_fn=adam_opt,
      train_batch_size=20,
      max_num_epochs=10,
      num_epochs_per_eval=1,
      test_batch_size=20)

  # Build the `tff.Computation` for evaluating the personalization strategies.
  # Here `p13n_eval` is a `tff.Computation` with the following type signature:
  # <model_weights@SERVER, datasets@CLIENTS> -> personalization_metrics@SERVER.
  # Each call receives at most `p13n_clients_per_call` clients, which is below
  # `max_num_samples`, so metrics from all p13n clients will be returned.
  p13n_eval = tff.learning.build_personalization_eval(
      model_fn=model_fn,
      personalize_fn_dict=personalize_fn_dict,
      baseline_evaluate_fn=functools.partial(evaluate_fn, batch_size=10),
      max_num_samples=max_num_p13n_clients)

  # Start the training loop.
  for round_idx in range(1, FLAGS.num_total_rounds + 1):
    sampled_train_data = list(
        np.random.choice(
            federated_train_data, FLAGS.num_clients_per_round, replace=False))
    server_state, _ = iterative_process.next(server_state, sampled_train_data)

    if round_idx % FLAGS.num_rounds_per_p13n_eval != 0:
      continue

    # Invoke the constructed `tff.Computation` in chunks of
    # `p13n_clients_per_call` clients. This will take some time to finish.
    # The returned `p13n_metrics` is a nested dictionary that stores all the
    # personalization metrics from the clients.
    #
    # Specifically, `p13n_metrics` is an `OrderedDict` that maps
    # key 'baseline_metrics' to the evaluation metrics of the initial global
    # model (computed by `baseline_evaluate_fn` argument in `p13n_eval`), and
    # maps keys (strategy names) in `personalize_fn_dict` to the evaluation
    # metrics of the corresponding personalization strategies.
    #
    # Only metrics from at most `max_num_samples` participating clients are
    # collected (clients are sampled without replacement). Each metric is
    # mapped to a list of scalars (each scalar comes from one client). Metric
    # values at the same position, e.g., metric_1[i], metric_2[i]..., come
    # from the same client.
    #
    # Users can save `p13n_metrics` to file for further analysis. For
    # simplicity, we extract and print two values here:
    # 1. mean accuracy of the initial global model;
    # 2. mean accuracy of the personalized models obtained at Epoch 1.
    p13n_metrics = _evaluate_p13n_in_chunks(
        p13n_eval,
        server_state.model,
        federated_p13n_data,
        max_num_clients=max_num_p13n_clients,
        clients_per_call=p13n_clients_per_call)

    print('Current Round {}'.format(round_idx))

    global_model_accuracies = np.array(
        p13n_metrics['baseline_metrics']['sparse_categorical_accuracy'])
    print('Mean accuracy of the global model: {}'.format(
        np.mean(global_model_accuracies).item()))

    personalized_models_accuracies = np.array(
        p13n_metrics['sgd']['epoch_1']['sparse_categorical_accuracy'])
    print('Mean accuracy of the personalized models at Epoch 1: {}'.format(
        np.mean(personalized_models_accuracies).item()))
Example #26
0
def main(argv):
    """Generates data for the train, val and test splits, in that order.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If any extra positional argument is supplied.
    """
    extra_args = argv[1:]
    if extra_args:
        raise app.UsageError('Too many command-line arguments.')
    for split_name in ('train', 'val', 'test'):
        _generate_data(split_name)
Example #27
0
def main(argv):
    """Trains and evaluates a Transformer translation model on WMT data.

    Runs a data-parallel (`jax.pmap`) training loop from `start_step` (restored
    from a checkpoint when `FLAGS.restore_checkpoints` is set) up to
    `FLAGS.num_train_steps`. At step 0 and then every `FLAGS.eval_frequency`
    steps it:
      * summarizes the training metrics gathered since the last report,
      * computes evaluation metrics (including loss) on up to
        `FLAGS.num_eval_steps` eval batches,
      * beam-search decodes the prediction dataset and computes corpus BLEU.
    TensorBoard summaries and checkpoints are written by host 0 only.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra positional arguments are given.
      ValueError: If `FLAGS.batch_size` is not divisible by the number of
        local devices.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.jax_backend_target:
        # Route computations to the given remote TPU driver backend.
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # This seems to be necessary even when importing TF2?
    tf.enable_v2_behavior()

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    # Only host 0 creates the output dir and summary writers; every later use
    # of the writers is guarded by the same `jax.host_id() == 0` check.
    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    # A bare filename has no directory component; only create one if present
    # (makedirs on an empty path is not meaningful).
    vocab_dir = os.path.dirname(vocab_path)
    if vocab_dir:
        tf.io.gfile.makedirs(vocab_dir)

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        """Detokenizes `toks` up to and including the first EOS token.

        If `toks` contains no EOS, the whole sequence is decoded. The previous
        `np.argmax(toks == eos_id)` returned 0 for an all-False mask, which
        kept only the first token in that case.
        """
        eos_positions = np.flatnonzero(toks == eos_id)
        end = eos_positions[0] + 1 if eos_positions.size else len(toks)
        valid_toks = toks[:end].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    # Eval uses the same architecture without dropout; prediction additionally
    # enables the decoding mode.
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    # call a jitted initialization function to get the initial parameter tree
    @jax.jit
    def initialize_variables(rng):
        return models.Transformer(eval_config).init(
            rng, jnp.ones(input_shape, jnp.float32),
            jnp.ones(target_shape, jnp.float32))

    initial_variables = initialize_variables(init_rng)

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['param'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(
        functools.partial(train_step,
                          config=train_config,
                          learning_rate_fn=learning_rate_fn,
                          label_smoothing=FLAGS.label_smoothing),
        axis_name='batch',
        donate_argnums=(0, ))
    p_eval_step = jax.pmap(
        functools.partial(eval_step,
                          config=eval_config,
                          label_smoothing=FLAGS.label_smoothing),
        axis_name='batch')
    p_init_cache = jax.pmap(
        functools.partial(initialize_cache,
                          max_decode_len=FLAGS.max_predict_length,
                          config=predict_config),
        axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0
                and step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling: report at step 0 and every eval_frequency.
        if step % FLAGS.eval_frequency != 0 and step > 0:
            continue

        # Training Metrics
        logging.info('Gathering training metrics.')
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        # Metrics are accumulated sums; normalize by the summed 'denominator'.
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for pred_batch in predict_iter:
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            cache = p_init_cache(pred_batch['inputs'])
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_id, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        # Partial match counts are summed across hosts, so every host must
        # reach this point even if it produced no predictions.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)
        # Save translation samples for tensorboard. np.random.choice raises on
        # an empty population, so only sample when this host has predictions.
        exemplars = ''
        if predictions:
            for n in np.random.choice(np.arange(len(predictions)), 8):
                exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
Example #28
0
def main(argv):
    """Trains an EMNIST CNN with differentially private federated averaging.

    Client updates are clipped to `FLAGS.clip_norm`. The server optimizer is
    configured with a noise standard deviation of
    `clip_norm * noise_multiplier / clients_per_round`. Training runs through
    `training_loop.run`.

    NOTE(review): validation and test evaluation both use the EMNIST test
    split (`test_data`); there is no separate validation set here.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: If extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, got: {}'.format(argv))

    # Clients run on the visible GPUs; the server runs on the first CPU.
    gpu_devices = tf.config.list_logical_devices('GPU')
    cpu_device = tf.config.list_logical_devices('CPU')[0]
    tff.backends.native.set_local_execution_context(
        max_fanout=2 * FLAGS.clients_per_round,
        server_tf_device=cpu_device,
        client_tf_devices=gpu_devices,
        clients_per_thread=FLAGS.clients_per_thread)

    logging.info('Show FLAGS for debugging:')
    for flag_name in HPARAM_FLAGS:
        logging.info('%s=%s', flag_name, FLAGS[flag_name].value)

    train_data, test_data = _get_emnist_dataset(
        FLAGS.only_digits, FLAGS.client_epochs_per_round,
        FLAGS.client_batch_size)

    def tff_model_fn():
        """Wraps a fresh CNN and its loss in a `dp_fedavg.KerasModelWrapper`."""
        cnn = _create_original_fedavg_cnn_model(FLAGS.only_digits)
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False)
        return dp_fedavg.KerasModelWrapper(cnn, test_data.element_spec,
                                           cross_entropy)

    # Noise is scaled down by the number of clients averaged per round.
    per_round = float(FLAGS.clients_per_round)
    noise_std = FLAGS.clip_norm * FLAGS.noise_multiplier / per_round
    server_optimizer_fn = functools.partial(
        _server_optimizer_fn,
        name=FLAGS.server_optimizer,
        learning_rate=FLAGS.server_lr,
        noise_std=noise_std)
    client_optimizer_fn = functools.partial(
        _client_optimizer_fn,
        name=FLAGS.client_optimizer,
        learning_rate=FLAGS.client_lr)
    iterative_process = dp_fedavg.build_federated_averaging_process(
        tff_model_fn,
        dp_clip_norm=FLAGS.clip_norm,
        server_optimizer_fn=server_optimizer_fn,
        client_optimizer_fn=client_optimizer_fn)

    # A single metric list and model instance are reused by every evaluation.
    eval_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    eval_model = tff_model_fn()

    def evaluate_fn(model_weights, dataset):
        """Loads `model_weights` into the shared model and evaluates it."""
        eval_model.from_weights(model_weights)
        evaluated = dp_fedavg.keras_evaluate(eval_model.keras_model, dataset,
                                             eval_metrics)
        return collections.OrderedDict(
            [(metric.name, metric.result().numpy()) for metric in evaluated])

    hparam_dict = collections.OrderedDict(
        (flag_name, FLAGS[flag_name].value) for flag_name in HPARAM_FLAGS)
    total_epochs = (FLAGS.total_epochs
                    if FLAGS.total_epochs is not None else 0)
    training_loop.run(
        iterative_process,
        client_datasets_fn=_get_client_datasets_fn(train_data),
        validation_fn=functools.partial(evaluate_fn, dataset=test_data),
        total_rounds=FLAGS.total_rounds,
        total_epochs=total_epochs,
        experiment_name=FLAGS.experiment_name,
        train_eval_fn=None,
        test_fn=functools.partial(evaluate_fn, dataset=test_data),
        root_output_dir=FLAGS.root_output_dir,
        hparam_dict=hparam_dict,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
        rounds_per_train_eval=2000)
Example #29
0
def main(argv):
  """Creates one N-bit multiplier file per width listed in `FLAGS.N`.

  Args:
    argv: Positional command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  # Widths arrive as strings; each is converted just before its file is made.
  for width_text in FLAGS.N:
    n_bit_mul_new_file(int(width_text))
Example #30
0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    config = FLAGS.config

    input_files = sum([glob.glob(pattern) for pattern in config.input_files],
                      [])
    assert input_files, "No input files!"
    print(f"Training with {len(input_files)} input files, including:")
    print(f" - {input_files[0]}")

    model = modeling.BertForPreTraining(config=config.model)
    initial_params = get_initial_params(model,
                                        init_checkpoint=config.init_checkpoint)
    optimizer = create_optimizer(config, initial_params)
    del initial_params  # the optimizer takes ownership of all params

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    if isinstance(optimizer.state, (list, tuple)):
        start_step = int(optimizer.state[0].step)
    else:
        start_step = int(optimizer.state.step)

    optimizer = optimizer.replicate()
    optimizer = training.harmonize_across_hosts(optimizer)

    data_pipeline = data.PretrainingDataPipeline(
        sum([glob.glob(pattern) for pattern in config.input_files], []),
        config.tokenizer,
        max_seq_length=config.max_seq_length,
        max_predictions_per_seq=config.max_predictions_per_seq,
    )

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_batch_size = config.train_batch_size
        if jax.host_count() > 1:
            assert (train_batch_size % jax.host_count() == 0
                    ), "train_batch_size must be divisible by number of hosts"
            train_batch_size = train_batch_size // jax.host_count()
        train_iter = data_pipeline.get_inputs(batch_size=train_batch_size,
                                              training=True)
        train_step_fn = training.create_train_step(
            model,
            compute_pretraining_loss_and_metrics,
            max_grad_norm=config.max_grad_norm,
        )

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch,
                                                   train_state)
            if jax.host_id() == 0 and (step % config.save_checkpoints_steps
                                       == 0
                                       or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(output_dir,
                                            optimizer.unreplicate(), step)
                config_path = os.path.join(output_dir, "config.json")
                if not os.path.exists(config_path):
                    with open(config_path, "w") as f:
                        json.dump({"model_type": "bert", **config.model}, f)
                tokenizer_path = os.path.join(output_dir,
                                              "sentencepiece.model")
                if not os.path.exists(tokenizer_path):
                    shutil.copy(config.tokenizer, tokenizer_path)

        # With the current Rust data pipeline code, running more than one pipeline
        # at a time will lead to a hang. A simple workaround is to fully delete the
        # training pipeline before potentially starting another for evaluation.
        del train_iter

    if config.do_eval:
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(model,
                                          compute_pretraining_stats,
                                          sample_feature_name="input_ids")
        eval_stats = eval_fn(optimizer, eval_iter)

        eval_metrics = {
            "loss":
            jnp.mean(eval_stats["loss"]),
            "masked_lm_loss":
            jnp.mean(eval_stats["masked_lm_loss"]),
            "next_sentence_loss":
            jnp.mean(eval_stats["next_sentence_loss"]),
            "masked_lm_accuracy":
            jnp.sum(eval_stats["masked_lm_correct"]) /
            jnp.sum(eval_stats["masked_lm_total"]),
            "next_sentence_accuracy":
            jnp.sum(eval_stats["next_sentence_correct"]) /
            jnp.sum(eval_stats["next_sentence_total"]),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f"{name} = {val:.06f}"
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, "eval_results.txt")
        with gfile.GFile(eval_results_path, "w") as f:
            for line in eval_results:
                f.write(line + "\n")