class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
      ('momentum', alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum', alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
      ('centered_rmsprop', alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam', alias.adam(LR), optim.Adam(LR)),
      ('adam_w', alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad', alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb', alias.lamb(LR), optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(
          self.per_step_updates, state, optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
def test_optimizer_serialization(self):
  rng = random.PRNGKey(0)
  module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
  _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
  model = nn.Model(module, initial_params)
  optim_def = optim.Momentum(learning_rate=1.)
  optimizer = optim_def.create(model)
  state = serialization.to_state_dict(optimizer)
  expected_state = {
      'target': {
          'params': {
              'kernel': onp.ones((1, 1)),
              'bias': onp.zeros((1,)),
          }
      },
      'state': {
          'step': 0,
          'param_states': {
              'params': {
                  'kernel': {'momentum': onp.zeros((1, 1))},
                  'bias': {'momentum': onp.zeros((1,))},
              }
          }
      },
  }
  self.assertEqual(state, expected_state)
  state = jax.tree_map(lambda x: x + 1, expected_state)
  restored_optimizer = serialization.from_state_dict(optimizer, state)
  optimizer_plus1 = jax.tree_map(lambda x: x + 1, optimizer)
  self.assertEqual(restored_optimizer, optimizer_plus1)
def fit_model_test():
  # We just check we can run the functions and that they return "something".
  N = 3
  D = 5
  C = 10
  model = ModelTest(nhidden=0, nclasses=C)
  rng = jax.random.PRNGKey(0)
  X = np.random.randn(N, D)
  y = np.random.choice(C, size=N, p=(1 / C) * np.ones(C))
  batch = {'X': X, 'y': y}
  params = model.init(rng, X)['params']
  metrics = eval_batch(model, params, batch)

  make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
  optimizer = make_optimizer.create(params)
  optimizer, metrics = train_batch(model, optimizer, batch)
  # print(optimizer)

  num_steps = 2

  def make_iter():
    while True:
      yield batch

  train_iter = make_iter()
  test_iter = make_iter()
  params, history = fit_model(model, train_iter, test_iter, rng, num_steps,
                              make_optimizer, train_batch, eval_batch,
                              print_every=1)
  print('test passed')
def create_optimizer(params, learning_rate, beta, weight_decay):
  """Returns a momentum optimizer."""
  optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta,
                                 weight_decay=weight_decay)
  optimizer = optimizer_def.create(params)
  return optimizer
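# A minimal usage sketch (not part of the original code) showing how the
# optimizer returned by create_optimizer above is typically driven. It only
# uses calls that appear elsewhere in this section (`compute_gradient`,
# `apply_gradient`, `target`) and assumes the legacy `flax.optim` API; the
# parameter pytree and loss function here are hypothetical placeholders.
import jax.numpy as jnp
from flax import optim

params = {'w': jnp.ones((3,))}
optimizer = create_optimizer(params, learning_rate=0.1, beta=0.9,
                             weight_decay=1e-4)

def loss_fn(p):
  return jnp.sum(p['w'] ** 2)

for _ in range(10):
  loss, grad = optimizer.compute_gradient(loss_fn)  # loss and grad w.r.t. target
  optimizer = optimizer.apply_gradient(grad)        # returns a new Optimizer
params = optimizer.target                           # updated parameters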
def test_empty_optimizer(self):
  params = {}
  optimizer_def = optim.Momentum(learning_rate=0.1)
  optimizer = optimizer_def.create(params)
  new_optimizer = optimizer.apply_gradient({})
  expected_state = optim.OptimizerState(1, {})
  self.assertEqual(new_optimizer.state, expected_state)
def create_optimizer(model: flax.nn.Model,
                     learning_rate: float,
                     beta: float = 0.9) -> flax.optim.Optimizer:
  """Creates an optimizer.

  The learning rate will be ignored when using a learning rate schedule.

  Args:
    model: The Flax model to optimize.
    learning_rate: Learning rate for the gradient descent.
    beta: Momentum parameter.

  Returns:
    An SGD (or RMSProp) optimizer that targets the model.
  """
  if FLAGS.use_rmsprop:
    # We set beta2 and epsilon to the values used in the EfficientNet paper.
    optimizer_def = efficientnet_optim.RMSProp(
        learning_rate=learning_rate, beta=beta, beta2=0.9, eps=0.001)
  else:
    optimizer_def = optim.Momentum(
        learning_rate=learning_rate, beta=beta, nesterov=True)
  optimizer = optimizer_def.create(model)
  return optimizer
def create_optimizer(model, learning_rate, beta):
  optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta,
                                 nesterov=True)
  optimizer = optimizer_def.create(model)
  optimizer = jax_utils.replicate(optimizer)
  return optimizer
def test_init_state(self):
  params = onp.zeros((1,))
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  state = optimizer_def.init_state(params)
  expected_hyper_params = _MomentumHyperParams(0.1, 0.2, 0, False)
  self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
  expected_state = optim.OptimizerState(
      0, _MomentumParamState(onp.zeros((1,))))
  self.assertEqual(state, expected_state)
def test_create(self):
  params = onp.ones((1,))
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  optimizer = optimizer_def.create(params)
  expected_state = optim.OptimizerState(
      0, _MomentumParamState(onp.zeros((1,))))
  self.assertEqual(optimizer.optimizer_def, optimizer_def)
  self.assertEqual(optimizer.state, expected_state)
  self.assertEqual(optimizer.target, params)
def test_optimizer_serialization_to_bytes(self):
  rng = random.PRNGKey(0)
  module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones)
  _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)])
  model = nn.Model(module, initial_params)
  optim_def = optim.Momentum(learning_rate=1.)
  optimizer = optim_def.create(model)
  serialized_bytes = serialization.to_bytes(optimizer)
  restored_optimizer = serialization.from_bytes(optimizer, serialized_bytes)
  self.assertEqual(restored_optimizer, optimizer)
def test_apply_gradient(self):
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  params = np.ones((1,))
  state = optim.OptimizerState(0, _MomentumParamState(np.array([1.])))
  grads = np.array([3.])
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  expected_new_state = optim.OptimizerState(
      1, _MomentumParamState(np.array([3.2])))
  expected_new_params = np.array([1. - 0.32])
  self.assertEqual(new_params, expected_new_params)
  self.assertEqual(new_state, expected_new_state)
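# The expected values in test_apply_gradient follow from the classic
# heavy-ball momentum update; a small sketch (an illustration, not part of the
# original test file), assuming the rule
#   m_new = beta * m + g,   p_new = p - lr * m_new,
# which reproduces the asserted numbers.
lr, beta = 0.1, 0.2
momentum, param, grad = 1.0, 1.0, 3.0
new_momentum = beta * momentum + grad  # 0.2 * 1.0 + 3.0 = 3.2
new_param = param - lr * new_momentum  # 1.0 - 0.1 * 3.2 = 1.0 - 0.32
assert abs(new_momentum - 3.2) < 1e-12
assert abs(new_param - 0.68) < 1e-12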
def get_optimizer(hparams):
  """Constructs the optimizer from the given HParams.

  Args:
    hparams: Hyper parameters.

  Returns:
    A flax optimizer.
  """
  if hparams.optimizer == 'sgd':
    return optimizers.GradientDescent(
        learning_rate=hparams.lr_hparams['initial_learning_rate'])
  if hparams.optimizer == 'nesterov':
    return optimizers.Momentum(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta=hparams.opt_hparams.get('momentum', 0.9),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        nesterov=True)
  if hparams.optimizer == 'momentum':
    return optimizers.Momentum(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta=hparams.opt_hparams.get('momentum', 0.9),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
        nesterov=False)
  if hparams.optimizer == 'adam':
    return optimizers.Adam(
        learning_rate=hparams.lr_hparams['initial_learning_rate'],
        beta1=hparams.opt_hparams.get('beta1', 0.9),
        beta2=hparams.opt_hparams.get('beta2', 0.999),
        eps=hparams.opt_hparams.get('epsilon', 1e-8),
        weight_decay=hparams.opt_hparams.get('weight_decay', 0.0),
    )
  if hparams.optimizer == 'rmsprop':
    return optimizers.RMSProp(
        learning_rate=hparams.lr_hparams.get('initial_learning_rate'),
        beta2=hparams.opt_hparams.get('beta2', 0.9),
        eps=hparams.opt_hparams.get('epsilon', 1e-8))
  else:
    raise NotImplementedError(
        'Optimizer {} not implemented'.format(hparams.optimizer))
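# A minimal sketch (not from the original code) of the hparams shape that
# get_optimizer above expects, inferred from the attribute and key accesses it
# performs. SimpleNamespace is a stand-in for whatever HParams container the
# project actually uses, and `optimizers` is assumed to be the same optimizer
# module the function refers to.
from types import SimpleNamespace

hparams = SimpleNamespace(
    optimizer='nesterov',
    lr_hparams={'initial_learning_rate': 0.1},
    opt_hparams={'momentum': 0.9, 'weight_decay': 5e-4},
)
optimizer_def = get_optimizer(hparams)  # optimizers.Momentum(..., nesterov=True)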
def fit_model(model, rng, num_steps, train_iter, test_iter=None,
              train_fn=update_classifier, test_fn=eval_classifier,
              make_optimizer=None, preprocess_train_batch=None,
              preprocess_test_batch=None, print_every=1, test_every=None):
  batch = next(train_iter)
  if preprocess_train_batch is not None:
    batch = preprocess_train_batch(batch, rng)
  X = batch['X']
  params = model.init(rng, X)['params']
  if make_optimizer is None:
    make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
  optimizer = make_optimizer.create(params)
  history = {
      'train_loss': [],
      'train_accuracy': [],
      'test_loss': [],
      'test_accuracy': []
  }
  if test_iter is None:
    test_every = 0
  if test_every is None:
    test_every = print_every
  for step in range(num_steps):
    batch = next(train_iter)
    if preprocess_train_batch is not None:
      batch = preprocess_train_batch(batch, rng)
    optimizer, train_metrics = train_fn(model, optimizer, batch)
    if (print_every > 0) and (step % print_every == 0):
      print('train step: {:d}, loss: {:0.4f}, accuracy: {:0.2f}'.format(
          step, train_metrics['loss'], train_metrics['accuracy']))
    if (test_every > 0) and (step % test_every == 0):
      batch = next(test_iter)
      if preprocess_test_batch is not None:
        batch = preprocess_test_batch(batch, rng)
      test_metrics = test_fn(model, optimizer.target, batch)
      history = append_history(history, train_metrics, test_metrics)
  params = optimizer.target
  return params, history
def make_optimizer(optimizer: str, learning_rate: float,
                   weight_decay: float = 5.0e-4):
  if optimizer == 'SGD':
    # Plain SGD: flax.optim exposes this as GradientDescent.
    return optim.GradientDescent(learning_rate=learning_rate)
  elif optimizer == 'Momentum':
    return optim.Momentum(learning_rate=learning_rate,
                          weight_decay=weight_decay)
  elif optimizer == 'Adam':
    return optim.Adam(learning_rate=learning_rate, weight_decay=weight_decay)
  else:
    raise ValueError('Unknown optimizer spec.')
def test_compute_grad(self):
  params = onp.ones(())
  optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2)
  optimizer = optimizer_def.create(params)

  def loss_fn(x):
    return 2. * x

  loss, grad = optimizer.compute_gradient(loss_fn)
  self.assertEqual(loss, 2.)
  self.assertEqual(grad, 2.)

  def loss_aux_fn(x):
    return 3. * x, 4.

  loss, aux, grad = optimizer.compute_gradient(loss_aux_fn)
  self.assertEqual(loss, 3.)
  self.assertEqual(grad, 3.)
  self.assertEqual(aux, 4.)
def test():
  # We just check we can run the functions and that they return "something".
  print('testing fit-flax')
  N = 3
  D = 5
  C = 10
  model = ModelTest(nhidden=0, nclasses=C)
  rng = jax.random.PRNGKey(0)
  X = np.random.randn(N, D)
  y = np.random.choice(C, size=N, p=(1 / C) * np.ones(C))
  batch = {'X': X, 'y': y}
  params = model.init(rng, X)['params']

  # test apply
  logprobs = model.apply({'params': params}, batch['X'])
  assert logprobs.shape == (N, C)

  # test loss
  labels = batch['y']
  loss = softmax_cross_entropy(logprobs, labels)
  assert loss.shape == ()

  # test test_fn
  metrics = eval_classifier(model, params, batch)
  assert np.allclose(loss, metrics['loss'])

  # test train_fn
  make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
  optimizer = make_optimizer.create(params)
  optimizer, metrics = update_classifier(model, optimizer, batch)

  # test fit_model
  num_steps = 2
  train_iter = make_iterator_from_batch(batch)
  test_iter = make_iterator_from_batch(batch)
  params_init = params
  params_new, history = fit_model(model, rng, num_steps, train_iter, test_iter)
  diff = tree_util.tree_multimap(lambda x, y: x - y, params_init, params_new)
  print(diff)
  norm = l2norm_sq(diff)
  assert norm > 0  # check that parameters have changed :)
  print(history)
  print('test passed')
def test_momentum_with_weight_norm(self):
  params = onp.ones((2, 2)) * 2.
  optimizer_def = optim.WeightNorm(optim.Momentum(0.1))
  state = optimizer_def.init_state(params)
  self.assertEqual(
      jax.tree_map(onp.shape, state),
      optim.OptimizerState(
          step=(),
          param_states=_WeightNormParamState(
              direction_state=_MomentumParamState(momentum=(2, 2)),
              scale_state=_MomentumParamState(momentum=(1, 2)),
              mult=(1, 2),
          )))
  grads = onp.ones((2, 2))
  new_params, new_state = optimizer_def.apply_gradient(
      optimizer_def.hyper_params, params, state, grads)
  onp.testing.assert_allclose(new_params, onp.full_like(params, 1.9))
  onp.testing.assert_allclose(new_state.param_states.mult, 1.9 * 2 ** 0.5)
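# In the test above, optim.WeightNorm re-parameterizes the (2, 2) kernel as a
# scale times a direction, and the `mult` state holds the per-slice norm
# (hence its (1, 2) shape). A small numpy check of the asserted value,
# assuming an L2 norm taken over axis 0; this check is an illustration, not
# part of the original test, and the exact axis convention is an assumption
# (the value is the same either way for this symmetric example).
import numpy as onp

new_params = onp.full((2, 2), 1.9)
per_slice_norm = onp.linalg.norm(new_params, axis=0, keepdims=True)  # (1, 2)
onp.testing.assert_allclose(per_slice_norm, 1.9 * 2 ** 0.5)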
def create_train_state(rng, config: ml_collections.ConfigDict, model,
                       image_size):
  """Create initial training state."""
  dynamic_scale = None
  platform = jax.local_devices()[0].platform
  if config.half_precision and platform == 'gpu':
    dynamic_scale = optim.DynamicScale()
  else:
    dynamic_scale = None
  params, model_state = initialized(rng, image_size, model)
  optimizer = optim.Momentum(
      beta=config.momentum, nesterov=True).create(params)
  state = TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale)
  return state
def create_optimizer(model, learning_rate, beta=0.9):
  """Creates an SGD (Nesterov momentum) optimizer.

  The learning rate will be ignored when using a learning rate schedule.

  Args:
    model: The Flax model to optimize.
    learning_rate: Learning rate for the gradient descent.
    beta: Momentum parameter.

  Returns:
    An SGD optimizer that targets the model.
  """
  optimizer_def = optim.Momentum(
      learning_rate=learning_rate, beta=beta, nesterov=True)
  optimizer = optimizer_def.create(model)
  return optimizer
def create_optimizer(config, model, learning_rate, train_size, sampler_rng):
  """Create optimizer definition based on config flags."""
  if config.optimizer == 'adam':
    optimizer_def = optim.Adam(
        learning_rate=learning_rate, beta1=config.momentum)
  elif config.optimizer == 'momentum':
    optimizer_def = optim.Momentum(
        learning_rate=learning_rate, beta=config.momentum)
  elif config.optimizer == 'sym_euler':
    optimizer_def = sym_euler_sgmcmc.SymEulerSGMCMC(
        train_size,
        sampler_rng,
        learning_rate=learning_rate,
        beta=config.momentum,
        temperature=config.base_temp,
        step_size_factor=1.)
  else:
    raise ValueError('Invalid value %s for config.optimizer.' %
                     config.optimizer)

  if config.weight_norm == 'none':
    pass
  elif config.weight_norm == 'learned':
    optimizer_def = optim.WeightNorm(optimizer_def)
  elif config.weight_norm in ['fixed', 'ws_sqrt', 'learned_b', 'ws']:
    # Applied in layers directly.
    pass
  else:
    raise ValueError('Invalid value %s for config.weight_norm.' %
                     config.weight_norm)

  optimizer = optimizer_def.create(model)
  if not config.debug_run:
    optimizer = optimizer.replicate()
  return optimizer
def test():
  # We just check we can run the functions and that they return "something".
  print('testing fit-flax')
  N = 3
  D = 5
  C = 10
  model = ModelTest(nhidden=0, nclasses=C)
  rng = jax.random.PRNGKey(0)
  X = np.random.randn(N, D)
  y = np.random.choice(C, size=N, p=(1 / C) * np.ones(C))
  batch = {'X': X, 'y': y}
  params = model.init(rng, X)['params']

  logprobs = model.apply({'params': params}, batch['X'])
  assert logprobs.shape == (N, C)

  labels = batch['y']
  loss = softmax_cross_entropy(logprobs, labels)
  assert loss.shape == ()

  metrics = eval_batch(model, params, batch)
  assert np.allclose(loss, metrics['loss'])

  make_optimizer = optim.Momentum(learning_rate=0.1, beta=0.9)
  optimizer = make_optimizer.create(params)
  optimizer, metrics = train_batch(model, optimizer, batch)

  num_steps = 2
  train_iter = make_iterator_from_batch(batch)
  test_iter = make_iterator_from_batch(batch)
  params_init = params
  params_new, history = fit_model(model, train_iter, test_iter, rng, num_steps,
                                  make_optimizer, train_batch, eval_batch,
                                  print_every=1)
  diff = tree_util.tree_multimap(lambda x, y: x - y, params_init, params_new)
  diff_max = tree_util.tree_map(lambda x: jnp.max(x), diff)
  assert jnp.abs(diff_max['Dense_0']['kernel']) > 0  # has changed
  print('test passed')
def create_optimizer(params, learning_rate, beta):
  optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
  optimizer = optimizer_def.create(params)
  return optimizer
def experiment(model_dir='.', # pylint: disable=dangerous-default-value imagenet_subset_dir=None, dataset='cifar10', batch_size=256, eval_batch_size=1024, num_epochs=200, learning_rate=0.1, aug_imagenet_apply_colour_jitter=False, aug_imagenet_greyscale_prob=0.0, sgd_momentum=0.9, sgd_nesterov=True, lr_schedule='stepped', lr_sched_steps=[[60, 0.2], [120, 0.04], [160, 0.008]], lr_sched_halfcoslength=400.0, lr_sched_warmup=5.0, l2_reg=0.0005, weight_decay=0.0, architecture='wrn22_10', n_val=5000, n_sup=1000, teacher_alpha=0.999, anneal_teacher_alpha=False, unsupervised_regularizer='none', cons_weight=1.0, conf_thresh=0.97, conf_avg=False, cut_backg_noise=1.0, cut_prob=1.0, box_reg_scale_mode='fixed', box_reg_scale=0.25, box_reg_random_aspect_ratio=False, cow_sigma_range=(4.0, 8.0), cow_prop_range=(0.25, 1.0), mix_regularizer='none', mix_aug_separately=False, mix_logits=True, mix_weight=1.0, mix_conf_thresh=0.97, mix_conf_avg=True, mix_conf_mode='mix_prob', ict_alpha=0.1, mix_box_reg_scale_mode='fixed', mix_box_reg_scale=0.25, mix_box_reg_random_aspect_ratio=False, mix_cow_sigma_range=(4.0, 8.0), mix_cow_prop_range=(0.0, 1.0), subset_seed=12345, val_seed=131, run_seed=None, log_fn=print, checkpoints='on', on_epoch_finished_fn=None, debug=False): """Run experiment.""" if checkpoints not in {'none', 'on', 'retain'}: raise ValueError('checkpoints should be one of (none|on|retain)') if checkpoints != 'none': checkpoint_path = os.path.join(model_dir, 'checkpoint.pkl') checkpoint_new_path = os.path.join(model_dir, 'checkpoint.pkl.new') else: checkpoint_path = None checkpoint_new_path = None if dataset not in {'svhn', 'cifar10', 'cifar100', 'imagenet'}: raise ValueError('Unknown dataset \'{}\''.format(dataset)) if architecture not in {'wrn20_10', 'wrn26_10', 'wrn26_2', 'wrn20_6_shakeshake', 'wrn26_6_shakeshake', 'wrn26_2_shakeshake', 'pyramid', 'resnet50', 'resnet101', 'resnet152', 'resnet50x2', 'resnet101x2', 'resnet152x2', 'resnet50x4', 'resnet101x4', 'resnet152x4', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext152_32x4d'}: raise ValueError('Unknown architecture \'{}\''.format(architecture)) if lr_schedule not in {'constant', 'stepped', 'cosine'}: raise ValueError('Unknown LR schedule \'{}\''.format(lr_schedule)) if mix_conf_mode not in {'mix_prob', 'mix_conf'}: raise ValueError('Unknown mix_conf_mode \'{}\''.format(mix_conf_mode)) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(model_dir) else: summary_writer = None unsup_reg, augment_twice = build_pert_reg( unsupervised_regularizer, cut_backg_noise=cut_backg_noise, cut_prob=cut_prob, box_reg_scale_mode=box_reg_scale_mode, box_reg_scale=box_reg_scale, box_reg_random_aspect_ratio=box_reg_random_aspect_ratio, cow_sigma_range=cow_sigma_range, cow_prop_range=cow_prop_range) mix_reg = build_mix_reg( mix_regularizer, ict_alpha=ict_alpha, box_reg_scale_mode=mix_box_reg_scale_mode, box_reg_scale=mix_box_reg_scale, box_reg_random_aspect_ratio=mix_box_reg_random_aspect_ratio, cow_sigma_range=mix_cow_sigma_range, cow_prop_range=mix_cow_prop_range) if run_seed is None: run_seed = subset_seed << 32 | n_val train_rng = jax.random.PRNGKey(run_seed) init_rng, train_rng = jax.random.split(train_rng) if batch_size % jax.device_count() > 0: raise ValueError('Train batch size must be divisible by the number of ' 'devices') if eval_batch_size % jax.device_count() > 0: raise ValueError('Eval batch size must be divisible by the number of ' 'devices') local_batch_size = batch_size // jax.host_count() local_eval_batch_size = eval_batch_size // 
jax.host_count() device_batch_size = batch_size // jax.device_count() if dataset == 'svhn': image_size = 32 top5_err_required = False data_source = small_image_data_source.SVHNDataSource( n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size, eval_batch_size=local_eval_batch_size, augment_twice=augment_twice, subset_seed=subset_seed, val_seed=val_seed) elif dataset == 'cifar10': image_size = 32 top5_err_required = False data_source = small_image_data_source.CIFAR10DataSource( n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size, eval_batch_size=local_eval_batch_size, augment_twice=augment_twice, subset_seed=subset_seed, val_seed=val_seed) elif dataset == 'cifar100': image_size = 32 top5_err_required = False data_source = small_image_data_source.CIFAR100DataSource( n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size, eval_batch_size=local_eval_batch_size, augment_twice=augment_twice, subset_seed=subset_seed, val_seed=val_seed) elif dataset == 'imagenet': image_size = 224 top5_err_required = True if imagenet_subset_dir is None: raise ValueError('Please provide a directory to the imagenet_subset_dir ' 'command line arg to specify where the ImageNet ' 'subsets are stored') data_source = imagenet_data_source.ImageNetDataSource( imagenet_subset_dir, n_val, n_sup, local_batch_size, local_eval_batch_size, augment_twice, apply_colour_jitter=aug_imagenet_apply_colour_jitter, greyscale_prob=aug_imagenet_greyscale_prob, load_test_set=(n_val == 0), image_size=image_size, subset_seed=subset_seed, val_seed=val_seed) else: raise RuntimeError n_train = data_source.n_train train_ds = data_source.train_semisup_ds if n_val == 0: eval_ds = data_source.test_ds n_eval = data_source.n_test else: eval_ds = data_source.val_ds n_eval = data_source.n_val log_fn('DATA: |train|={}, |sup|={}, |eval|={}, (|val|={}, |test|={})'.format( data_source.n_train, data_source.n_sup, n_eval, data_source.n_val, data_source.n_test)) log_fn('Loaded dataset') steps_per_epoch = n_train // batch_size steps_per_eval = n_eval // eval_batch_size if n_eval % eval_batch_size > 0: steps_per_eval += 1 num_steps = steps_per_epoch * num_epochs # Create model model_stu, state_stu = create_model( init_rng, architecture, device_batch_size, image_size, data_source.n_classes) state_stu = jax_utils.replicate(state_stu) log_fn('Built model') # Create optimizer optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=sgd_momentum, nesterov=sgd_nesterov) optimizer_stu = optimizer_def.create(model_stu) optimizer_stu = optimizer_stu.replicate() del model_stu # don't keep a copy of the initial model # Create learning rate function base_learning_rate = learning_rate * batch_size / 256. 
if lr_schedule == 'constant': learning_rate_fn = create_constant_learning_rate_fn(base_learning_rate) elif lr_schedule == 'stepped': learning_rate_fn = create_stepped_learning_rate_fn( base_learning_rate, steps_per_epoch, lr_sched_steps=lr_sched_steps, warmup_length=lr_sched_warmup) elif lr_schedule == 'cosine': learning_rate_fn = create_cosine_learning_rate_fn( base_learning_rate, steps_per_epoch, halfcoslength_epochs=lr_sched_halfcoslength, warmup_length=lr_sched_warmup) else: raise RuntimeError if anneal_teacher_alpha: if lr_schedule == 'constant': one_minus_alpha_fn = create_constant_learning_rate_fn(1.0 - teacher_alpha) elif lr_schedule == 'stepped': one_minus_alpha_fn = create_stepped_learning_rate_fn( 1.0 - teacher_alpha, steps_per_epoch, lr_sched_steps=lr_sched_steps) elif lr_schedule == 'cosine': one_minus_alpha_fn = create_cosine_learning_rate_fn( 1.0 - teacher_alpha, steps_per_epoch, halfcoslength_epochs=lr_sched_halfcoslength) else: raise RuntimeError teacher_alpha_fn = lambda step: 1.0 - one_minus_alpha_fn(step) else: teacher_alpha_fn = lambda step: teacher_alpha log_fn('Built optimizer') # Teacher model is just the student as we duplicate it when we modify it model_tea = optimizer_stu.target # Replicate batch stats state_tea = jax.tree_map(lambda x: x, state_stu) # Set up epoch and step counter # Load existing checkpoint if available epoch = 1 step = 0 if checkpoints != 'none': if tf.io.gfile.exists(checkpoint_path): with tf.io.gfile.GFile(checkpoint_path, 'rb') as f_in: check = pickle.load(f_in) # Student optimizer and batch stats optimizer_stu = util.restore_state_list( optimizer_stu, check['optimizer_stu']) state_stu = util.restore_state_list( state_stu, check['state_stu']) # Teacher model and batch stats model_tea = util.restore_state_list( model_tea, check['model_tea']) state_tea = util.restore_state_list( state_tea, check['state_tea']) epoch = check['epoch'] step = check['step'] log_fn('Loaded checkpoint from {}'.format(checkpoint_path)) # # Training and evaluation step functions # p_train_step = jax.pmap( functools.partial(train_step, learning_rate_fn=learning_rate_fn, l2_reg=l2_reg, weight_decay=weight_decay, teacher_alpha_fn=teacher_alpha_fn, unsup_reg=unsup_reg, cons_weight=cons_weight, conf_thresh=conf_thresh, conf_avg=conf_avg, mix_reg=mix_reg, mix_aug_separately=mix_aug_separately, mix_logits=mix_logits, mix_weight=mix_weight, mix_conf_thresh=mix_conf_thresh, mix_conf_avg=mix_conf_avg, mix_conf_mode=mix_conf_mode), axis_name='batch') p_eval_step = jax.pmap( functools.partial(eval_step, eval_top_5=top5_err_required), axis_name='batch') # Create dataset batch iterators train_iter = iter(train_ds) eval_iter = iter(eval_ds) # # Training loop # log_fn('Training...') epoch_metrics_stu = [] t1 = time.time() while step < num_steps: train_rng, iter_rng = jax.random.split(train_rng) batch = next(train_iter) batch = jax.tree_map(lambda x: x._numpy(), batch) # pylint: disable=protected-access batch = shard(batch, iter_rng) optimizer_stu, state_stu, metrics_stu, model_tea, state_tea = p_train_step( optimizer_stu, state_stu, model_tea, state_tea, batch) if debug: log_fn('Step {} time {}'.format(step, time.time()-t1)) epoch_metrics_stu.append(metrics_stu) if (step + 1) % steps_per_epoch == 0: epoch_metrics_stu = util.get_metrics(epoch_metrics_stu) train_epoch_metrics = jax.tree_map(lambda x: x.mean(), epoch_metrics_stu) if summary_writer is not None: for key, vals in epoch_metrics_stu.items(): tag = 'train_%s' % key for i, val in enumerate(vals): summary_writer.scalar(tag, val, 
step - len(vals) + i + 1) epoch_metrics_stu = [] eval_stu_metrics = [] eval_tea_metrics = [] for _ in range(steps_per_eval): eval_batch = next(eval_iter) # TF to NumPy eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access # Pad short batches eval_batch = util.pad_classification_batch( eval_batch, local_eval_batch_size) # Shard across local devices eval_batch = shard(eval_batch) metrics_stu = p_eval_step(optimizer_stu.target, state_stu, eval_batch) metrics_tea = p_eval_step(model_tea, state_tea, eval_batch) eval_stu_metrics.append(metrics_stu) eval_tea_metrics.append(metrics_tea) eval_stu_metrics = util.get_metrics(eval_stu_metrics) eval_tea_metrics = util.get_metrics(eval_tea_metrics) eval_stu_epoch_metrics = jax.tree_map(lambda x: x.sum(), eval_stu_metrics) eval_tea_epoch_metrics = jax.tree_map(lambda x: x.sum(), eval_tea_metrics) eval_stu_epoch_metrics = avg_eval_metrics(eval_stu_epoch_metrics) eval_tea_epoch_metrics = avg_eval_metrics(eval_tea_epoch_metrics) t2 = time.time() if top5_err_required: log_fn('EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, ' 'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, ' 'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, ' 'top-5-err={:.3%}, TEA Eval loss={:.6f}, err={:.3%}, ' 'top-5-err={:.3%}'.format( epoch, t2 - t1, train_epoch_metrics['loss'], train_epoch_metrics['error_rate'], train_epoch_metrics['cons_loss'], train_epoch_metrics['conf_rate'], train_epoch_metrics['mix_loss'], train_epoch_metrics['mix_conf_rate'], eval_stu_epoch_metrics['loss'], eval_stu_epoch_metrics['error_rate'], eval_stu_epoch_metrics['top5_error_rate'], eval_tea_epoch_metrics['loss'], eval_tea_epoch_metrics['error_rate'], eval_tea_epoch_metrics['top5_error_rate'],)) else: log_fn('EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, ' 'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, ' 'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, ' 'TEA Eval loss={:.6f}, err={:.3%}'.format( epoch, t2 - t1, train_epoch_metrics['loss'], train_epoch_metrics['error_rate'], train_epoch_metrics['cons_loss'], train_epoch_metrics['conf_rate'], train_epoch_metrics['mix_loss'], train_epoch_metrics['mix_conf_rate'], eval_stu_epoch_metrics['loss'], eval_stu_epoch_metrics['error_rate'], eval_tea_epoch_metrics['loss'], eval_tea_epoch_metrics['error_rate'],)) if on_epoch_finished_fn is not None: if top5_err_required: on_epoch_finished_fn( epoch, eval_stu_err=eval_stu_epoch_metrics['error_rate'], eval_tea_err=eval_tea_epoch_metrics['error_rate'], eval_stu_top5_err=eval_stu_epoch_metrics['top5_error_rate'], eval_tea_top5_err=eval_tea_epoch_metrics['top5_error_rate'], ) else: on_epoch_finished_fn( epoch, eval_stu_err=eval_stu_epoch_metrics['error_rate'], eval_tea_err=eval_tea_epoch_metrics['error_rate'], ) t1 = t2 if summary_writer is not None: summary_writer.scalar( 'eval_stu_loss', eval_stu_epoch_metrics['loss'], epoch) summary_writer.scalar( 'eval_stu_error_rate', eval_stu_epoch_metrics['error_rate'], epoch) summary_writer.scalar( 'eval_tea_loss', eval_tea_epoch_metrics['loss'], epoch) summary_writer.scalar( 'eval_tea_error_rate', eval_tea_epoch_metrics['error_rate'], epoch) if top5_err_required: summary_writer.scalar( 'eval_stu_top5_error_rate', eval_stu_epoch_metrics['top5_error_rate'], epoch) summary_writer.scalar( 'eval_tea_top5_error_rate', eval_tea_epoch_metrics['top5_error_rate'], epoch) summary_writer.flush() epoch += 1 if checkpoints != 'none': if jax.host_id() == 0: # Write to new checkpoint file so that we don't 
immediately # overwrite the old one with tf.io.gfile.GFile(checkpoint_new_path, 'wb') as f_out: check = dict( optimizer_stu=util.to_state_list(optimizer_stu), state_stu=util.to_state_list(state_stu), model_tea=util.to_state_list(model_tea), state_tea=util.to_state_list(state_tea), epoch=epoch, step=step + 1, ) pickle.dump(check, f_out) del check # Remove old checkpoint and rename if tf.io.gfile.exists(checkpoint_path): tf.io.gfile.remove(checkpoint_path) tf.io.gfile.rename(checkpoint_new_path, checkpoint_path) step += 1 if checkpoints == 'on': if jax.host_id() == 0: if tf.io.gfile.exists(checkpoint_path): tf.io.gfile.remove(checkpoint_path)
def train(module, model_dir, batch_size, eval_batch_size, num_moco_epochs, num_clf_epochs, moco_learning_rate, clf_learning_rate, sgd_momentum=0.9, sgd_nesterov=True, make_moco_lr_fun=None, make_clf_lr_fun=None, moco_l2_reg=0.0001, clf_l2_reg=0.0, feature_size=64 * 8 * 4, moco_momentum=0.999, emb_size=128, moco_temperature=0.07, dictionary_size=65536, run_seed=0, steps_per_epoch=None, steps_per_eval=None): """Train MoCo model.""" if make_moco_lr_fun is None: def make_moco_lr_fun(base_lr, steps_per_epoch): # pylint: disable=function-redefined return lr_schedule.create_stepped_learning_rate_schedule( base_lr, steps_per_epoch, [[120, 0.1], [160, 0.01]]) if make_clf_lr_fun is None: def make_clf_lr_fun(base_lr, steps_per_epoch): # pylint: disable=function-redefined return lr_schedule.create_stepped_learning_rate_schedule( base_lr, steps_per_epoch, [[60, 0.2], [75, 0.04], [90, 0.008]]) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(model_dir) else: summary_writer = None # # # If using more than 1 host, warn the user # # if jax.host_count() > 1: logging.info( 'WARNING: the all_to_all collective used by this program is ' 'not yet supported in multi-host environments') train_rng = jax.random.PRNGKey(run_seed) (init_moco_rng, init_clf_rng, init_dictionary_rng, train_rng) = jax.random.split(train_rng, num=4) if batch_size % jax.device_count() > 0: raise ValueError('Train batch size must be divisible by the number ' 'of devices') if eval_batch_size % jax.device_count() > 0: raise ValueError('Eval batch size must be divisible by the number ' 'of devices') local_batch_size = batch_size // jax.host_count() local_eval_batch_size = eval_batch_size // jax.host_count() n_devices = jax.device_count() n_local_devices = jax.local_device_count() device_batch_size = batch_size // n_devices image_size = 224 data_source = imagenet_data_source.load_imagenet( train_batch_size=local_batch_size, eval_batch_size=local_eval_batch_size, greyscale_prob=0.1) n_train = data_source.n_train train_moco_ds = data_source.train_moco_ds train_clf_ds = data_source.train_clf_ds eval_ds = data_source.test_ds n_eval = data_source.n_test logging.info('DATA: |train|=%d, |eval|=%d', data_source.n_train, n_eval) if steps_per_epoch is None: steps_per_epoch = n_train // batch_size if steps_per_eval is None: steps_per_eval = n_eval // eval_batch_size num_moco_steps = steps_per_epoch * num_moco_epochs num_clf_steps = steps_per_epoch * num_clf_epochs logging.info('Loaded dataset') # # Create query model # model_query, state_query = create_model(init_moco_rng, device_batch_size, image_size, module) state_query = jax_utils.replicate(state_query) # Create linear classifier feat_model_clf = create_linear_classifier(init_clf_rng, device_batch_size, feature_size, data_source.n_classes) # Randomly initialise dictionary moco_dictionary = jax.random.normal(init_dictionary_rng, (dictionary_size, emb_size), dtype=jnp.float32) moco_dictionary = normalize_embeddings(moco_dictionary) logging.info('Built model') # # Create optimizer # optimizer_def = optim.Momentum(learning_rate=moco_learning_rate, beta=sgd_momentum, nesterov=sgd_nesterov) optimizer_query = optimizer_def.create(model_query) optimizer_query = optimizer_query.replicate() del model_query # don't keep a copy of the initial model feat_clf_optimizer_def = optim.Momentum(learning_rate=clf_learning_rate, beta=sgd_momentum, nesterov=sgd_nesterov) feat_clf_optimizer = feat_clf_optimizer_def.create(feat_model_clf) feat_clf_optimizer = feat_clf_optimizer.replicate() logging.info('Built 
optimizer') # # Learning rate schedule # base_moco_learning_rate = moco_learning_rate * batch_size / 256. base_clf_learning_rate = clf_learning_rate * batch_size / 256. moco_learning_rate_fn = make_moco_lr_fun(base_moco_learning_rate, steps_per_epoch) clf_learning_rate_fn = make_clf_lr_fun(base_clf_learning_rate, steps_per_epoch) # The key model is a replica of the query model. Since Flax models are # immutable, we can start with the query model model_key = optimizer_query.target # Replicate batch stats state_key = jax.tree_map(lambda x: x, state_query) # Set up epoch and step counter # Load existing checkpoint if available moco_epoch = 1 clf_epoch = 1 moco_step = 0 clf_step = 0 # # Training and eval functions # p_moco_key_step = jax.pmap(functools.partial(moco_key_step), axis_name='batch') p_moco_train_step = jax.pmap(functools.partial( moco_train_step, n_devices=n_devices, moco_temperature=moco_temperature, learning_rate_fn=moco_learning_rate_fn, l2_reg=moco_l2_reg, moco_momentum=moco_momentum), axis_name='batch') p_classifier_train_step = jax.pmap(functools.partial( classifier_train_step, learning_rate_fn=clf_learning_rate_fn, l2_reg=clf_l2_reg), axis_name='batch') p_eval_step = jax.pmap(functools.partial(eval_step), axis_name='batch') # Create MoCo dataset batch iterator train_moco_it = iter(train_moco_ds) # # Training loop # logging.info('Training MoCo...') epoch_metrics_moco = [] t1 = time.time() while moco_step < num_moco_steps: (train_rng, shuffle_rng) = jax.random.split(train_rng, num=2) batch = next(train_moco_it) # TF to NumPy batch = jax.tree_map(lambda x: x._numpy(), batch) # pylint: disable=protected-access # Compute key embeddings # We have to shuffle the batch to prevent the network from cheating using # batch stats shuffle_forward = jax.random.shuffle(shuffle_rng, jnp.arange(local_batch_size)) shuffle_backward = jnp.zeros((local_batch_size, ), dtype=int) shuffle_backward = jax.ops.index_update(shuffle_backward, shuffle_forward, jnp.arange(local_batch_size)) key_batch = dict(x_key=batch['key_image'][shuffle_forward, Ellipsis]) key_batch_sharded = common_utils.shard(key_batch) emb_key, state_key = p_moco_key_step(model_key, state_key, key_batch_sharded) emb_key = emb_key.reshape((-1, emb_size)) emb_key = emb_key[shuffle_backward, Ellipsis] # # Main MoCo training step # moco_batch = batch.copy() moco_batch['emb_key'] = emb_key sharded_moco_batch = common_utils.shard(moco_batch) # Repeat the MoCo dictionary across shards sharded_dict = jnp.repeat(moco_dictionary[None, Ellipsis], n_local_devices, axis=0) # The main train step function is applied slightly differently in # multi-host environments optimizer_query, state_query, metrics_moco, model_key, code_batch = \ p_moco_train_step(optimizer_query, state_query, model_key, sharded_moco_batch, sharded_dict) code_batch = code_batch[0].reshape((-1, emb_size)) moco_dictionary = jnp.append(code_batch, moco_dictionary, axis=0)[:dictionary_size] epoch_metrics_moco.append(metrics_moco) if (moco_step + 1) % steps_per_epoch == 0: epoch_metrics_moco = common_utils.get_metrics(epoch_metrics_moco) train_epoch_metrics = jax.tree_map(lambda x: x.mean(), epoch_metrics_moco) if summary_writer is not None: for key, vals in epoch_metrics_moco.items(): tag = 'train_%s' % key for i, val in enumerate(vals): summary_writer.scalar(tag, val, moco_step - len(vals) + i + 1) epoch_metrics_moco = [] t2 = time.time() logging.info('MoCo EPOCH %d: (took %.3fs): MoCo loss=%.6f', moco_epoch, t2 - t1, train_epoch_metrics['moco_loss']) t1 = t2 if summary_writer is 
not None: summary_writer.flush() moco_epoch += 1 moco_step += 1 del train_moco_it # # # Unsupervised MoCo training complete # Train classifier # # logging.info('Training Linear Classifier...') train_clf_it = iter(train_clf_ds) eval_iter = iter(eval_ds) epoch_feat_metrics = [] t1 = time.time() while clf_step < num_clf_steps: batch = next(train_clf_it) # TF to NumPy batch = jax.tree_map(lambda x: x._numpy(), batch) # pylint: disable=protected-access batch = common_utils.shard(batch) feat_clf_optimizer, feat_metrics = p_classifier_train_step( feat_clf_optimizer, model_key, state_key, batch) epoch_feat_metrics.append(feat_metrics) if (clf_step + 1) % steps_per_epoch == 0: epoch_feat_metrics = common_utils.get_metrics(epoch_feat_metrics) train_epoch_feat_metrics = jax.tree_map(lambda x: x.mean(), epoch_feat_metrics) if summary_writer is not None: for key, vals in epoch_feat_metrics.items(): tag = 'train_feat_%s' % key for i, val in enumerate(vals): summary_writer.scalar(tag, val, clf_step - len(vals) + i + 1) epoch_feat_metrics = [] eval_feat_metrics = [] for _ in range(steps_per_eval): eval_batch = next(eval_iter) # TF to NumPy eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access # Shard across local devices eval_batch = common_utils.shard(eval_batch) feat_metrics = p_eval_step(model_key, state_key, feat_clf_optimizer.target, eval_batch) eval_feat_metrics.append(feat_metrics) eval_feat_metrics = common_utils.get_metrics(eval_feat_metrics) eval_epoch_feat_metrics = jax.tree_map(lambda x: x.mean(), eval_feat_metrics) t2 = time.time() logging.info( 'Linear classifier EPOCH %d: (took %.3fs): TRAIN FEAT loss=%.6f, ' 'err=%.3f; EVAL FEAT loss=%.6f, err=%.3f', clf_epoch, t2 - t1, train_epoch_feat_metrics['loss'], train_epoch_feat_metrics['error_rate'] * 100.0, eval_epoch_feat_metrics['loss'], eval_epoch_feat_metrics['error_rate'] * 100.0, ) t1 = t2 if summary_writer is not None: summary_writer.scalar('eval_feat_loss', eval_epoch_feat_metrics['loss'], clf_epoch) summary_writer.scalar('eval_feat_error_rate', eval_epoch_feat_metrics['error_rate'], clf_epoch) summary_writer.flush() clf_epoch += 1 clf_step += 1 return eval_epoch_feat_metrics
def create_optimizer(model, learning_rate, beta):
  optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
  optimizer = optimizer_def.create(model)
  return optimizer
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') tf.enable_v2_behavior() # make sure tf does not allocate gpu memory tf.config.experimental.set_visible_devices([], 'GPU') if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir) image_size = 224 batch_size = FLAGS.batch_size if batch_size % jax.device_count() > 0: raise ValueError('Batch size must be divisible by the number of devices') local_batch_size = batch_size // jax.host_count() device_batch_size = batch_size // jax.device_count() platform = jax.local_devices()[0].platform dynamic_scale = None if FLAGS.half_precision: if platform == 'tpu': model_dtype = jnp.bfloat16 input_dtype = tf.bfloat16 else: model_dtype = jnp.float16 input_dtype = tf.float16 dynamic_scale = optim.DynamicScale() else: model_dtype = jnp.float32 input_dtype = tf.float32 train_iter = imagenet_train_utils.create_input_iter( local_batch_size, FLAGS.data_dir, image_size, input_dtype, train=True, cache=FLAGS.cache) eval_iter = imagenet_train_utils.create_input_iter( local_batch_size, FLAGS.data_dir, image_size, input_dtype, train=False, cache=FLAGS.cache) # Create the hyperparameter object if FLAGS.hparams_config_dict: # In this case, there are multiple training configs defined in the config # dict, so we pull out the one this training run should use. if 'configs' in FLAGS.hparams_config_dict: hparams_config_dict = FLAGS.hparams_config_dict.configs[FLAGS.config_idx] else: hparams_config_dict = FLAGS.hparams_config_dict hparams = os_hparams_utils.load_hparams_from_config_dict( hparams_config.TrainingHParams, models.ResNet.HParams, hparams_config_dict) else: raise ValueError('Please provide a base config dict.') os_hparams_utils.write_hparams_to_file_with_host_id_check( hparams, FLAGS.model_dir) # get num_epochs from hparam instead of FLAGS num_epochs = hparams.lr_scheduler.num_epochs steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size steps_per_checkpoint = steps_per_epoch * 10 num_steps = steps_per_epoch * num_epochs # Estimate compute / memory costs if jax.host_id() == 0 and FLAGS.estimate_compute_and_memory_cost: estimate_compute_and_memory_cost( image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams) logging.info('Writing training HLO and estimating compute/memory costs.') rng = random.PRNGKey(hparams.seed) model, variables = imagenet_train_utils.create_model( rng, device_batch_size, image_size, model_dtype, hparams=hparams.model_hparams, train=True, is_teacher=hparams.is_teacher) # pylint: disable=g-long-lambda if hparams.teacher_model == 'resnet50-8bit': teacher_config = w8a8auto_paper_config() teacher_hparams = os_hparams_utils.load_hparams_from_config_dict( hparams_config.TrainingHParams, models.ResNet.HParams, teacher_config) teacher_model, _ = imagenet_train_utils.create_model( rng, device_batch_size, image_size, model_dtype, hparams=teacher_hparams.model_hparams, train=False, is_teacher=True) # teacher model does not need to be trainable # Directory where checkpoints are saved ckpt_model_dir = FLAGS.resnet508b_ckpt_path # will restore to best checkpoint state_load = checkpoints.restore_checkpoint(ckpt_model_dir, None) teacher_variables = {'params': state_load['optimizer']['target']} teacher_variables.update(state_load['model_state']) # create a dictionary for better argument passing teacher = { 'model': lambda var, img, labels: jax.nn.softmax( teacher_model.apply(var, img)), 'variables': teacher_variables, } elif 
hparams.teacher_model == 'labels': teacher = { 'model': lambda var, img, labels: common_utils.onehot( labels, num_classes=1000), 'variables': {}, # no need of variables in this case } else: raise ValueError('The specified teacher model is not supported.') model_state, params = variables.pop('params') if hparams.optimizer == 'sgd': optimizer = optim.Momentum( beta=hparams.momentum, nesterov=True).create(params) elif hparams.optimizer == 'adam': optimizer = optim.Adam( beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params) else: raise ValueError('Optimizer type is not supported.') state = imagenet_train_utils.TrainState( step=0, optimizer=optimizer, model_state=model_state, dynamic_scale=dynamic_scale) del params, model_state # do not keep a copy of the initial model state = restore_checkpoint(state) step_offset = int(state.step) # step_offset > 0 if restarting from checkpoint state = jax_utils.replicate(state) base_learning_rate = hparams.base_learning_rate * batch_size / 256. learning_rate_fn = create_learning_rate_fn(base_learning_rate, steps_per_epoch, hparams.lr_scheduler, batch_size) p_train_step = jax.pmap( functools.partial( imagenet_train_utils.train_step, model, learning_rate_fn=learning_rate_fn, teacher=teacher), axis_name='batch', static_broadcasted_argnums=(2, 3, 4)) p_eval_step = jax.pmap( functools.partial(imagenet_train_utils.eval_step, model), axis_name='batch', static_broadcasted_argnums=(2,)) epoch_metrics = [] state_dict_summary_all = [] state_dict_keys = _get_state_dict_keys_from_flags() t_loop_start = time.time() last_log_step = 0 for step, batch in zip(range(step_offset, num_steps), train_iter): if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps * steps_per_epoch: break update_bounds = train_utils.should_update_bounds( hparams.activation_bound_update_freq, hparams.activation_bound_start_step, step) # and pass the result bool value to p_train_step # The function should take hparams.weight_quant_start_step as inputs quantize_weights = train_utils.should_quantize_weights( hparams.weight_quant_start_step, step // steps_per_epoch) state, metrics = p_train_step(state, batch, hparams, update_bounds, quantize_weights) state_dict_summary = summary_utils.get_state_dict_summary( state.model_state, state_dict_keys) state_dict_summary_all.append(state_dict_summary) epoch_metrics.append(metrics) def should_log(step): epoch_no = step // steps_per_epoch step_in_epoch = step - epoch_no * steps_per_epoch do_log = False do_log = do_log or (step + 1 == num_steps) # log at the end end_of_train = step / num_steps > 0.9 do_log = do_log or ((step_in_epoch % (steps_per_epoch // 4) == 0) and not end_of_train) do_log = do_log or ((step_in_epoch % (steps_per_epoch // 16) == 0) and end_of_train) return do_log if should_log(step): epoch = step // steps_per_epoch epoch_metrics = common_utils.get_metrics(epoch_metrics) summary = jax.tree_map(lambda x: x.mean(), epoch_metrics) logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) steps_per_sec = (step - last_log_step) / (time.time() - t_loop_start) last_log_step = step t_loop_start = time.time() # Write to TensorBoard state_dict_summary_all = common_utils.get_metrics(state_dict_summary_all) if jax.host_id() == 0: for key, vals in epoch_metrics.items(): tag = 'train_%s' % key for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) summary_writer.scalar('steps per second', steps_per_sec, step) if FLAGS.write_summary: 
summary_utils.write_state_dict_summaries_to_tb( state_dict_summary_all, summary_writer, FLAGS.state_dict_summary_freq, step) state_dict_summary_all = [] epoch_metrics = [] eval_metrics = [] # sync batch statistics across replicas state = imagenet_train_utils.sync_batch_stats(state) for _ in range(steps_per_eval): eval_batch = next(eval_iter) metrics = p_eval_step(state, eval_batch, quantize_weights) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) summary = jax.tree_map(lambda x: x.mean(), eval_metrics) logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) if jax.host_id() == 0: for key, val in eval_metrics.items(): tag = 'eval_%s' % key summary_writer.scalar(tag, val.mean(), step) summary_writer.flush() if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps: state = imagenet_train_utils.sync_batch_stats(state) save_checkpoint(state) # Wait until computations are done before exiting jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') tf.enable_v2_behavior() # make sure tf does not allocate gpu memory tf.config.experimental.set_visible_devices([], 'GPU') if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir) rng = random.PRNGKey(0) image_size = 224 batch_size = FLAGS.batch_size if batch_size % jax.device_count() > 0: raise ValueError( 'Batch size must be divisible by the number of devices') local_batch_size = batch_size // jax.host_count() device_batch_size = batch_size // jax.device_count() platform = jax.local_devices()[0].platform dynamic_scale = None if FLAGS.half_precision: if platform == 'tpu': model_dtype = jnp.bfloat16 input_dtype = tf.bfloat16 else: model_dtype = jnp.float16 input_dtype = tf.float16 dynamic_scale = optim.DynamicScale() else: model_dtype = jnp.float32 input_dtype = tf.float32 train_iter = imagenet_train_utils.create_input_iter(local_batch_size, FLAGS.data_dir, image_size, input_dtype, train=True, cache=FLAGS.cache) eval_iter = imagenet_train_utils.create_input_iter(local_batch_size, FLAGS.data_dir, image_size, input_dtype, train=False, cache=FLAGS.cache) # Create the hyperparameter object if FLAGS.hparams_config_dict: # In this case, there are multiple training configs defined in the config # dict, so we pull out the one this training run should use. if 'configs' in FLAGS.hparams_config_dict: hparams_config_dict = FLAGS.hparams_config_dict.configs[ FLAGS.config_idx] else: hparams_config_dict = FLAGS.hparams_config_dict hparams = os_hparams_utils.load_hparams_from_config_dict( hparams_config.TrainingHParams, models.ResNet.HParams, hparams_config_dict) else: raise ValueError('Please provide a base config dict.') os_hparams_utils.write_hparams_to_file_with_host_id_check( hparams, FLAGS.model_dir) # get num_epochs from hparam instead of FLAGS num_epochs = hparams.lr_scheduler.num_epochs steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size steps_per_checkpoint = steps_per_epoch * 10 num_steps = steps_per_epoch * num_epochs # Estimate compute / memory costs if jax.host_id() == 0: estimate_compute_and_memory_cost(image_size=image_size, model_dir=FLAGS.model_dir, hparams=hparams) logging.info( 'Writing training HLO and estimating compute/memory costs.') model, variables = imagenet_train_utils.create_model( rng, device_batch_size, image_size, model_dtype, hparams=hparams.model_hparams, train=True) model_state, params = variables.pop('params') if hparams.optimizer == 'sgd': optimizer = optim.Momentum(beta=hparams.momentum, nesterov=True).create(params) elif hparams.optimizer == 'adam': optimizer = optim.Adam(beta1=hparams.adam.beta1, beta2=hparams.adam.beta2).create(params) else: raise ValueError('Optimizer type is not supported.') state = imagenet_train_utils.TrainState(step=0, optimizer=optimizer, model_state=model_state, dynamic_scale=dynamic_scale) del params, model_state # do not keep a copy of the initial model state = restore_checkpoint(state) step_offset = int( state.step) # step_offset > 0 if restarting from checkpoint state = jax_utils.replicate(state) base_learning_rate = hparams.base_learning_rate * batch_size / 256. 
learning_rate_fn = create_learning_rate_fn(base_learning_rate, steps_per_epoch, hparams.lr_scheduler) p_train_step = jax.pmap(functools.partial( imagenet_train_utils.train_step, model, learning_rate_fn=learning_rate_fn), axis_name='batch', static_broadcasted_argnums=(2, 3)) p_eval_step = jax.pmap(functools.partial(imagenet_train_utils.eval_step, model), axis_name='batch') epoch_metrics = [] state_dict_summary_all = [] state_dict_keys = _get_state_dict_keys_from_flags() t_loop_start = time.time() for step, batch in zip(range(step_offset, num_steps), train_iter): if hparams.early_stop_steps >= 0 and step > hparams.early_stop_steps: break update_bounds = train_utils.should_update_bounds( hparams.activation_bound_update_freq, hparams.activation_bound_start_step, step) state, metrics = p_train_step(state, batch, hparams, update_bounds) state_dict_summary = summary_utils.get_state_dict_summary( state.model_state, state_dict_keys) state_dict_summary_all.append(state_dict_summary) epoch_metrics.append(metrics) if (step + 1) % steps_per_epoch == 0: epoch = step // steps_per_epoch epoch_metrics = common_utils.get_metrics(epoch_metrics) summary = jax.tree_map(lambda x: x.mean(), epoch_metrics) logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) steps_per_sec = steps_per_epoch / (time.time() - t_loop_start) t_loop_start = time.time() # Write to TensorBoard state_dict_summary_all = common_utils.get_metrics( state_dict_summary_all) if jax.host_id() == 0: for key, vals in epoch_metrics.items(): tag = 'train_%s' % key for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) summary_writer.scalar('steps per second', steps_per_sec, step) summary_utils.write_state_dict_summaries_to_tb( state_dict_summary_all, summary_writer, FLAGS.state_dict_summary_freq, step) state_dict_summary_all = [] epoch_metrics = [] eval_metrics = [] # sync batch statistics across replicas state = imagenet_train_utils.sync_batch_stats(state) for _ in range(steps_per_eval): eval_batch = next(eval_iter) metrics = p_eval_step(state, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) summary = jax.tree_map(lambda x: x.mean(), eval_metrics) logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) if jax.host_id() == 0: for key, val in eval_metrics.items(): tag = 'eval_%s' % key summary_writer.scalar(tag, val.mean(), step) summary_writer.flush() if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps: state = imagenet_train_utils.sync_batch_stats(state) save_checkpoint(state) # Wait until computations are done before exiting jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main(argv): del argv # BEGIN GOOGLE-INTERNAL xm.setup_work_unit() # END GOOGLE-INTERNAL tf.enable_v2_behavior() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir) # Write summaries in background thread to avoid blocking on device sync summary_thread = thread.ThreadPoolExecutor(1, 'summary') if FLAGS.infeed: # Infeed is currently synchronous, so do it in a background thread too infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed') rng = random.PRNGKey(0) image_size = 224 batch_size = FLAGS.batch_size if batch_size is None: batch_size = min(128 * jax.device_count(), 32768) eval_batch_size = 128 * jax.device_count() local_batch_size = batch_size // jax.host_count() local_eval_batch_size = eval_batch_size // jax.host_count() device_batch_size = batch_size // jax.device_count() device_eval_batch_size = eval_batch_size // jax.device_count() device_last_eval_batch_size = (input_pipeline.EVAL_IMAGES % eval_batch_size) // jax.device_count() model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32 input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32 if FLAGS.transpose_images: train_input_shape = (224, 224, 3, device_batch_size) eval_input_shapes = [(224, 224, 3, bs) for bs in (device_eval_batch_size, device_last_eval_batch_size)] else: train_input_shape = (device_batch_size, 224, 224, 3) eval_input_shapes = [(bs, 224, 224, 3) for bs in (device_eval_batch_size, device_last_eval_batch_size)] num_epochs = FLAGS.num_epochs steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size logging.info('steps_per_epoch: %f', steps_per_epoch) steps_per_eval = int(np.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size)) logging.info('steps_per_eval: %d', steps_per_eval) base_learning_rate = FLAGS.learning_rate * batch_size / 256. 
  beta = FLAGS.momentum
  weight_decay = FLAGS.weight_decay

  logging.info('creating and initializing model and optimizer')
  model, state = create_model(rng, device_batch_size, image_size, model_dtype)
  state = jax_utils.replicate(state)
  if FLAGS.lars:
    weight_opt_def = optim.LARS(
        base_learning_rate, beta, weight_decay=weight_decay)
    other_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=0, nesterov=False)
    learning_rate_fn = polynomial_learning_rate_fn(
        batch_size, steps_per_epoch, num_epochs)
  else:
    weight_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=weight_decay, nesterov=True)
    other_opt_def = optim.Momentum(
        base_learning_rate, beta, weight_decay=0, nesterov=True)
    learning_rate_fn = piecewise_learning_rate_fn(
        base_learning_rate, steps_per_epoch, num_epochs)

  def filter_weights(key, _):
    return 'bias' not in key and 'scale' not in key

  def filter_other(key, _):
    return 'bias' in key or 'scale' in key

  weight_traversal = optim.ModelParamTraversal(filter_weights)
  other_traversal = optim.ModelParamTraversal(filter_other)
  optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                       (other_traversal, other_opt_def))
  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  del model  # do not keep a copy of the initial model

  p_train_step = jax.pmap(
      partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  def device_train_loop_cond(args):
    _, _, _, _, step, epoch = args
    return step // steps_per_epoch == epoch

  def device_train_loop_body(args):
    optimizer, state, metrics, token, step, epoch = args
    (images, labels), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    batch = {'image': images, 'label': labels}
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn)
    step += 1
    return optimizer, state, metrics, token, step, epoch

  def device_train_loop(optimizer, state, metrics, step, epoch):
    token = lax.create_token(step)
    optimizer, state, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, state, metrics, token, step, epoch))
    return optimizer, state, metrics, step

  p_train_epoch = jax.pmap(device_train_loop, axis_name='batch')

  if FLAGS.precompile:
    logging.info('precompiling step/epoch functions')
    if FLAGS.infeed:
      # the device training loop condition will immediately be false
      p_train_epoch(optimizer, state, empty_metrics(),
                    jax_utils.replicate(0), jax_utils.replicate(1))
    else:
      batch = {
          'image': jnp.zeros(
              (jax.local_device_count(),) + train_input_shape, model_dtype),
          'label': jnp.zeros(
              (jax.local_device_count(),) + (device_batch_size,), jnp.int32)
      }
      p_train_step(optimizer, state, batch, empty_metrics())
    for dbs, eis in zip(
        [device_eval_batch_size, device_last_eval_batch_size],
        eval_input_shapes):
      batch = {
          'image': jnp.zeros((jax.local_device_count(),) + eis, model_dtype),
          'label': jnp.zeros((jax.local_device_count(),) + (dbs,), jnp.int32)
      }
      p_eval_step(optimizer.target, state, batch, empty_metrics())
    allreduce_metrics(empty_metrics())
    pmean = functools.partial(jax.lax.pmean, axis_name='batch')
    jax.pmap(pmean, axis_name='batch')(state)

  logging.info('constructing datasets')
  # pylint: disable=g-complex-comprehension
  train_ds, eval_ds = [
      input_pipeline.load_split(
          local_batch_size if train else local_eval_batch_size,
          image_size=image_size,
          dtype=input_dtype,
          train=train,
          transpose_images=FLAGS.transpose_images) for train in (True, False)
  ]
  # pylint: enable=g-complex-comprehension
  logging.info('constructing dataset iterators')
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  logging.info('beginning training')
  host_step, device_step = 0, jax_utils.replicate(0)
  for epoch in range(num_epochs):
    device_epoch = jax_utils.replicate(epoch)
    metrics = empty_metrics()
    if FLAGS.infeed:
      optimizer, state, metrics, device_step = p_train_epoch(
          optimizer, state, metrics, device_step, device_epoch)
    while int(host_step // steps_per_epoch) == epoch:
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      if FLAGS.infeed:
        for i, device in enumerate(jax.local_devices()):
          images, labels = batch['image'][i], batch['label'][i]
          assert images.shape == train_input_shape and labels.dtype == jnp.int32
          infeed_pool.submit(
              partial(device.transfer_to_infeed, (images, labels)))
      else:
        optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                                 metrics)
      host_step += 1
    if FLAGS.train_metrics:
      metrics = allreduce_metrics(metrics)
      if jax.host_id() == 0:
        summary_thread.submit(
            partial(write_summary, summary_writer, metrics, 'train',
                    epoch + 1))
    if not FLAGS.distributed_batchnorm:  # otherwise it's already synced
      pmean = functools.partial(jax.lax.pmean, axis_name='batch')
      state = jax.pmap(pmean, axis_name='batch')(state)
    metrics = empty_metrics()
    for _ in range(steps_per_eval):
      batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))  # pylint: disable=protected-access
      metrics = p_eval_step(optimizer.target, state, batch, metrics)
    metrics = allreduce_metrics(metrics)
    if jax.host_id() == 0:
      summary_thread.submit(
          partial(write_summary, summary_writer, metrics, 'eval', epoch + 1))
    # TODO(deveci): do something like this from the summary thread:
    # if summary['accuracy'] > TARGET_ACCURACY:
    #   break
  if jax.host_id() == 0:
    summary_thread.shutdown()
  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
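
# --- Hedged illustration (not part of the training script above): a minimal
# sketch of how optim.ModelParamTraversal and optim.MultiOptimizer route
# different parameter subsets to different optimizer definitions, using a toy
# hypothetical param dict instead of the ResNet model and FLAGS used above.
import jax
import jax.numpy as jnp
from flax import optim

toy_params = {'dense': {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros((3,))}}

# Parameter paths containing 'bias' go to the second optimizer definition,
# everything else goes to the first, mirroring the weight/other split above.
kernel_traversal = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)
bias_traversal = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
toy_optimizer = optim.MultiOptimizer(
    (kernel_traversal, optim.Momentum(0.1, beta=0.9, weight_decay=1e-4)),
    (bias_traversal, optim.Momentum(0.1, beta=0.9, weight_decay=0.))
).create(toy_params)

toy_grads = jax.tree_map(jnp.ones_like, toy_params)
toy_optimizer = toy_optimizer.apply_gradient(toy_grads)  # updates both subsets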
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # make sure tf does not allocate gpu memory
  tf.config.experimental.set_visible_devices([], 'GPU')

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  rng = random.PRNGKey(0)

  image_size = 224

  batch_size = FLAGS.batch_size
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  platform = jax.local_devices()[0].platform

  if FLAGS.half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
      input_dtype = tf.bfloat16
    else:
      model_dtype = jnp.float16
      input_dtype = tf.float16
  else:
    model_dtype = jnp.float32
    input_dtype = tf.float32

  train_iter = create_input_iter(
      local_batch_size, image_size, input_dtype, train=True, cache=FLAGS.cache)
  eval_iter = create_input_iter(
      local_batch_size, image_size, input_dtype, train=False,
      cache=FLAGS.cache)

  num_epochs = FLAGS.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = steps_per_epoch * 10
  num_steps = steps_per_epoch * num_epochs

  base_learning_rate = FLAGS.learning_rate * batch_size / 256.
  base_learning_rate = base_learning_rate / FLAGS.loss_scaling

  model, model_state = create_model(rng, device_batch_size, image_size,
                                    model_dtype)
  optimizer = optim.Momentum(beta=FLAGS.momentum, nesterov=True).create(model)
  state = TrainState(step=0, optimizer=optimizer, model_state=model_state)
  del model, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  step_offset = int(state.step)  # step_offset > 0 if restarting from checkpoint
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                             steps_per_epoch, num_epochs)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  epoch_metrics = []
  t_loop_start = time.time()
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    epoch_metrics.append(metrics)
    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      epoch_metrics = common_utils.get_metrics(epoch_metrics)
      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
      t_loop_start = time.time()
      if jax.host_id() == 0:
        for key, vals in epoch_metrics.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
        summary_writer.scalar('steps per second', steps_per_sec, step)
      epoch_metrics = []
      eval_metrics = []
      # sync batch statistics across replicas
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      if jax.host_id() == 0:
        for key, val in eval_metrics.items():
          tag = 'eval_%s' % key
          summary_writer.scalar(tag, val.mean(), step)
        summary_writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
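
# --- Hedged illustration (assumption about how train_step consumes
# learning_rate_fn): flax.optim hyper-parameters can be overridden per
# apply_gradient call, which is why the Momentum definition above is created
# without a fixed learning_rate. Toy params, a toy loss, and a toy schedule
# stand in for the real model and create_learning_rate_fn.
import jax
import jax.numpy as jnp
from flax import optim

toy_params = {'w': jnp.ones((3,))}
toy_optimizer = optim.Momentum(beta=0.9, nesterov=True).create(toy_params)

def toy_loss(params):
  return jnp.sum(params['w'] ** 2)

def toy_learning_rate_fn(step):
  return 0.1 * (0.5 ** (step // 10))  # placeholder decay, not the real schedule

grads = jax.grad(toy_loss)(toy_optimizer.target)
lr = toy_learning_rate_fn(toy_optimizer.state.step)
# Per-call override of the learning_rate hyper-parameter.
toy_optimizer = toy_optimizer.apply_gradient(grads, learning_rate=lr)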
class FlaxOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([0., 0.3, 500., 5.]),
                             jnp.array([300., 3.]))

  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
      ('momentum',
       alias.sgd(LR, momentum=0.9),
       optim.Momentum(LR, beta=0.9)),  # Different names.
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam', alias.adam(LR), optim.Adam(LR)),
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),  # Different name.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),  # Different default!
       optim.Adagrad(LR)),
      ('lamb', alias.lamb(LR), optim.LAMB(LR)),
      ('lars',
       alias.lars(LR, weight_decay=.5, trust_coefficient=0.003,
                  momentum=0.9, eps=1e-3),
       optim.LARS(LR, weight_decay=.5, trust_coefficient=0.003,
                  beta=0.9, eps=1e-3)),
      ('adafactor',
       alias.adafactor(learning_rate=LR / 10., factored=True,
                       multiply_by_parameter_scale=True,
                       clipping_threshold=1.0,
                       decay_rate=0.8, min_dim_size_to_factor=2),
       optim.Adafactor(learning_rate=LR / 10., factored=True,
                       multiply_by_parameter_scale=True,
                       clipping_threshold=1.0,
                       decay_rate=0.8, min_dim_size_to_factor=2)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

    # flax/optim
    flax_params = self.init_params
    flax_optimizer = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_optimizer = flax_optimizer.apply_gradient(self.per_step_updates)
      flax_params = flax_optimizer.target

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      updates, state = optax_optimizer.update(self.per_step_updates, state,
                                              optax_params)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
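
# --- Hedged illustration: outside the equivalence test, the optax side of one
# of the pairs above (here the 'nesterov_momentum' case) would be used in a
# plain training loop like this sketch. The toy learning rate, params, and
# loss are placeholders, not values taken from the test.
import jax
import jax.numpy as jnp
import optax

toy_lr = 0.01
params = {'w': jnp.ones((4,))}
tx = optax.sgd(toy_lr, momentum=0.9, nesterov=True)
opt_state = tx.init(params)

def loss_fn(p):
  return jnp.sum(p['w'] ** 2)

for _ in range(3):
  grads = jax.grad(loss_fn)(params)
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)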