def test_dump_dict_json(self):
  """Tests that `utils.dump_dict_json` writes a JSON file that loads back.

  Builds a dict of NumPy and JAX scalar / 0-d array values, dumps it to a
  temporary file with `utils.dump_dict_json`, reads the file back with
  `json.load`, and checks the loaded dict equals the original dict with each
  value passed through `utils._np_converter`.
  """
  data_dict = {
      'np_float': np.dtype('float32').type(1.0),
      'jnp_float': jnp.dtype('float32').type(1.0),
      'np_int': np.dtype('int32').type(1),
      'jnp_int': jnp.dtype('int32').type(1),
      'np_array': np.array(1.0, dtype=np.float32),
      'jnp_array': jnp.array(1.0, dtype=jnp.float32),
  }
  # Expected contents: the same values after the converter that
  # dump_dict_json is expected to apply when serializing.
  converted_dict = {
      key: utils._np_converter(value) for key, value in data_dict.items()
  }

  # The context manager closes and deletes the temporary file deterministically,
  # even if dumping, loading, or the assertion below raises; previously the
  # handle was left open until garbage collection.
  with tempfile.NamedTemporaryFile() as json_file:
    utils.dump_dict_json(data_dict, json_file.name)
    with open(json_file.name, 'r') as input_file:
      loaded_dict = json.load(input_file)

  self.assertDictEqual(loaded_dict, converted_dict)
trainer = training.Trainer( optimizer, initial_model, initial_state, dataset, rng) _, best_metrics = trainer.train( FLAGS.epochs, lr_fn=lr_fn, pruning_rate_fn=pruning_rate_fn, update_iter=FLAGS.update_iterations, update_epoch=FLAGS.update_epoch, ) logging.info('Best metrics: %s', str(best_metrics)) if jax.host_id() == 0: if FLAGS.dump_json: utils.dump_dict_json(best_metrics, path.join(experiment_dir, 'best_metrics.json')) for label, value in best_metrics.items(): summary_writer.scalar(f'best/{label}', value, FLAGS.epochs * steps_per_epoch) summary_writer.close() def main(argv: List[str]): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') run_training() if __name__ == '__main__': app.run(main)
mask_stats = symmetry.get_mask_stats(mask) logging.info('Mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. except (OverflowError, ValueError): summary_writer.text(f'mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32' ', writing as string instead.', label) if FLAGS.dump_json: mask_stats['permutations'] = str(mask_stats['permutations']) utils.dump_dict_json( mask_stats, path.join(experiment_dir, 'mask_stats.json')) mask = masked.propagate_masks(mask) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Propagated mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'propagated_mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. except (OverflowError, ValueError): summary_writer.text(f'propagated_mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32'