Example #1
0
    def test_dump_dict_json(self):
        """Tests JSON dumping function.

        Writes a dict of NumPy / JAX scalars and 0-d arrays with
        `utils.dump_dict_json`, reads the file back with `json.load`, and
        checks the result equals the dict obtained by passing each value
        through `utils._np_converter`.
        """
        data_dict = {
            'np_float': np.dtype('float32').type(1.0),
            'jnp_float': jnp.dtype('float32').type(1.0),
            'np_int': np.dtype('int32').type(1),
            'jnp_int': jnp.dtype('int32').type(1),
            'np_array': np.array(1.0, dtype=np.float32),
            'jnp_array': jnp.array(1.0, dtype=jnp.float32),
        }
        # Expected on-disk representation: each value as produced by the same
        # converter the dumper is presumed to use (assumption — confirm in utils).
        converted_dict = {
            key: utils._np_converter(value)
            for key, value in data_dict.items()
        }
        # Scope the temporary file with a context manager so its handle is
        # closed (and the file removed) deterministically when the test ends,
        # instead of whenever the object happens to be garbage-collected.
        # The read-back must happen inside the block: closing deletes the file.
        with tempfile.NamedTemporaryFile() as json_file:
            utils.dump_dict_json(data_dict, json_file.name)
            with open(json_file.name, 'r') as input_file:
                loaded_dict = json.load(input_file)
        self.assertDictEqual(loaded_dict, converted_dict)
Example #2
0
    trainer = training.Trainer(
        optimizer, initial_model, initial_state, dataset, rng)

  _, best_metrics = trainer.train(
      FLAGS.epochs,
      lr_fn=lr_fn,
      pruning_rate_fn=pruning_rate_fn,
      update_iter=FLAGS.update_iterations,
      update_epoch=FLAGS.update_epoch,
  )

  logging.info('Best metrics: %s', str(best_metrics))

  if jax.host_id() == 0:
    if FLAGS.dump_json:
      utils.dump_dict_json(best_metrics,
                           path.join(experiment_dir, 'best_metrics.json'))

    for label, value in best_metrics.items():
      summary_writer.scalar(f'best/{label}', value,
                            FLAGS.epochs * steps_per_epoch)
    summary_writer.close()


def main(argv: List[str]):
  """Script entry point handed to `app.run`.

  Args:
    argv: Command-line arguments; element 0 is the program name.

  Raises:
    app.UsageError: If any positional argument beyond the program name
      was supplied.
  """
  unexpected_args = argv[1:]
  if unexpected_args:
    raise app.UsageError('Too many command-line arguments.')
  run_training()


if __name__ == '__main__':
  app.run(main)
Example #3
0
    mask_stats = symmetry.get_mask_stats(mask)
    logging.info('Mask stats: %s', str(mask_stats))


    for label, value in mask_stats.items():
      try:
        summary_writer.scalar(f'mask/{label}', value, 0)
      # This is needed because permutations (long int) can't be cast to float32.
      except (OverflowError, ValueError):
        summary_writer.text(f'mask/{label}', str(value), 0)
        logging.error('Could not write mask/%s to tensorflow summary as float32'
                      ', writing as string instead.', label)

    if FLAGS.dump_json:
      mask_stats['permutations'] = str(mask_stats['permutations'])
      utils.dump_dict_json(
          mask_stats, path.join(experiment_dir, 'mask_stats.json'))

  mask = masked.propagate_masks(mask)

  if jax.host_id() == 0:
    mask_stats = symmetry.get_mask_stats(mask)
    logging.info('Propagated mask stats: %s', str(mask_stats))


    for label, value in mask_stats.items():
      try:
        summary_writer.scalar(f'propagated_mask/{label}', value, 0)
      # This is needed because permutations (long int) can't be cast to float32.
      except (OverflowError, ValueError):
        summary_writer.text(f'propagated_mask/{label}', str(value), 0)
        logging.error('Could not write mask/%s to tensorflow summary as float32'