def setUp(self):
  super().setUp()
  self._batch_size = 128  # Note: Tests are run on GPU/TPU.
  self._batch_size_test = 128
  self._shuffle_buffer_size = 1024
  self._rng = jax.random.PRNGKey(42)
  self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
  self._num_classes = 10
  self._num_epochs = 1
  self._learning_rate_fn = lambda _: 0.01
  self._weight_decay = 0.0001
  self._momentum = 0.9
  self._min_loss = jnp.finfo(float).eps
  self._max_loss = 2.0 * math.log(self._num_classes)
  self._dataset_name = 'MNIST'
  self._model_name = 'MNIST_CNN'
  self._summarywriter = tensorboard.SummaryWriter('/tmp/')
  self._dataset = dataset_factory.create_dataset(
      self._dataset_name,
      self._batch_size,
      self._batch_size_test,
      shuffle_buffer_size=self._shuffle_buffer_size)
  self._model, self._state = model_factory.create_model(
      self._model_name,
      self._rng, (self._input_shape,),
      num_classes=self._num_classes)
  self._optimizer = flax.optim.Momentum(
      learning_rate=self._learning_rate_fn(0),
      beta=self._momentum,
      weight_decay=self._weight_decay)
if jax.host_id() == 0:
  summary_writer = tensorboard.SummaryWriter(experiment_dir)

dataset = dataset_factory.create_dataset(
    FLAGS.dataset,
    FLAGS.batch_size,
    FLAGS.batch_size_test,
    shuffle_buffer_size=FLAGS.shuffle_buffer_size)

logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
base_model, _ = model_factory.create_model(
    FLAGS.model,
    rng, ((input_shape, jnp.float32),),
    num_classes=dataset.num_classes)

initial_model, initial_state = model_factory.create_model(
    FLAGS.model,
    rng, ((input_shape, jnp.float32),),
    num_classes=dataset.num_classes,
    masked_layer_indices=FLAGS.masked_layer_indices)

if FLAGS.optimizer == 'Adam':
  optimizer = flax.optim.Adam(
      learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
elif FLAGS.optimizer == 'Momentum':
  optimizer = flax.optim.Momentum(
      learning_rate=FLAGS.lr,
      beta=FLAGS.momentum,
      weight_decay=FLAGS.weight_decay,
      nesterov=False)
if jax.host_id() == 0:
  summary_writer = tensorboard.SummaryWriter(experiment_dir)

dataset = dataset_factory.create_dataset(
    FLAGS.dataset,
    FLAGS.batch_size,
    FLAGS.batch_size_test,
    shuffle_buffer_size=FLAGS.shuffle_buffer_size)

logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
base_model, _ = model_factory.create_model(
    FLAGS.model,
    rng, ((input_shape, jnp.float32),),
    num_classes=dataset.num_classes)

logging.info('Generating random mask based on model')
# Re-initialize the RNG to maintain same training pattern (as in prune code).
mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed)
mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng,
                                FLAGS.mask_sparsity)

if jax.host_id() == 0:
  mask_stats = symmetry.get_mask_stats(mask)
  logging.info('Mask stats: %s', str(mask_stats))
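# A minimal, hedged sketch of how the overall sparsity of the mask created above could be
# checked. It assumes (not shown in this excerpt) that mask_factory.create_mask returns a
# nested pytree of 0/1 arrays mirroring the model parameters; `overall_sparsity` is an
# illustrative helper, not part of the symmetry module.
def overall_sparsity(mask):
  """Fraction of weights that are masked out (zero) across all mask arrays."""
  leaves = jax.tree_util.tree_leaves(mask)
  total = sum(leaf.size for leaf in leaves)
  kept = sum(int(jnp.sum(leaf)) for leaf in leaves)
  return 1.0 - kept / total

# Hypothetical usage: a random mask at sparsity FLAGS.mask_sparsity should keep roughly
# (1 - FLAGS.mask_sparsity) of the weights.
# assert abs(overall_sparsity(mask) - FLAGS.mask_sparsity) < 0.05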
def run_training():
  """Trains a model."""
  print('Logging to {}'.format(FLAGS.log_dir))

  work_unit_id = uuid.uuid4()
  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
  logging.info('Saving experimental results to %s', experiment_dir)

  host_count = jax.host_count()
  local_device_count = jax.local_device_count()
  logging.info('Device count: %d, host count: %d, local device count: %d',
               jax.device_count(), host_count, local_device_count)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(experiment_dir)

  dataset = dataset_factory.create_dataset(
      FLAGS.dataset,
      FLAGS.batch_size,
      FLAGS.batch_size_test,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

  rng = jax.random.PRNGKey(FLAGS.random_seed)
  input_shape = (1,) + dataset.shape
  model, initial_state = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, np.float32),),
      num_classes=dataset.num_classes)

  if FLAGS.optimizer == 'Adam':
    optimizer = flax.optim.Adam(
        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
  elif FLAGS.optimizer == 'Momentum':
    optimizer = flax.optim.Momentum(
        learning_rate=FLAGS.lr,
        beta=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        nesterov=False)

  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
  if FLAGS.lr_schedule == 'constant':
    lr_fn = lr_schedule.create_constant_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch)
  elif FLAGS.lr_schedule == 'stepped':
    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, lr_schedule_steps)
  elif FLAGS.lr_schedule == 'cosine':
    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, FLAGS.epochs)
  else:
    raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))

  if jax.host_id() == 0:
    trainer = training.Trainer(
        optimizer,
        model,
        initial_state,
        dataset,
        rng,
        summary_writer=summary_writer)
  else:
    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)

  _, best_metrics = trainer.train(
      FLAGS.epochs,
      lr_fn=lr_fn,
      update_iter=FLAGS.update_iterations,
      update_epoch=FLAGS.update_epoch)

  logging.info('Best metrics: %s', str(best_metrics))

  if jax.host_id() == 0:
    for label, value in best_metrics.items():
      summary_writer.scalar('best/{}'.format(label), value,
                            FLAGS.epochs * steps_per_epoch)
    summary_writer.close()
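# For reference, a minimal sketch of the shape of a cosine learning-rate schedule over
# FLAGS.epochs * steps_per_epoch steps. This is an illustrative stand-in, not the
# implementation of lr_schedule.create_cosine_learning_rate_schedule used above, whose
# internals are not shown in this excerpt.
import math


def cosine_lr_schedule(base_lr, steps_per_epoch, num_epochs):
  """Returns a step -> learning-rate function that decays from base_lr toward 0."""
  total_steps = steps_per_epoch * num_epochs

  def lr_fn(step):
    progress = min(step / total_steps, 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

  return lr_fn

# Hypothetical usage: lr_fn = cosine_lr_schedule(FLAGS.lr, steps_per_epoch, FLAGS.epochs)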
def _create_model(self, model_name):
  return model_factory.create_model(
      model_name,
      self._rng, (self._input_shape,),
      num_classes=self._num_classes)
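# Hypothetical usage of the helper above in a test case; the test name and assertions are
# illustrative and only assume that model_factory.create_model returns a (model, state)
# pair, as it does in setUp.
def test_create_model(self):
  model, state = self._create_model(self._model_name)
  self.assertIsNotNone(model)
  self.assertIsNotNone(state)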
rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
input_len = functools.reduce(operator.mul, dataset.shape)
features = mnist_fc.feature_dim_for_param(input_len, FLAGS.param_count,
                                          FLAGS.depth)
logging.info('Model Configuration: %s', str(features))

base_model, _ = model_factory.create_model(
    MODEL,
    rng, ((input_shape, jnp.float32),),
    num_classes=dataset.num_classes,
    features=features)

model_param_count = utils.count_param(base_model, ('kernel',))
logging.info(
    'Model Config: param.: %d, depth: %d, max width: %d, min width: %d',
    model_param_count, len(features), max(features), min(features))

logging.info('Generating random mask based on model')
# Re-initialize the RNG to maintain same training pattern (as in prune code).
mask_rng = jax.random.PRNGKey(FLAGS.random_seed)
mask = masked.shuffled_mask(
    base_model,