def partitioned_additive_weight_decay(
    conv_weight_decay: float,  # wd1 from original repo
    linear_weight_decay: float,  # wd2 from original repo
) -> optax.GradientTransformation:
    """Apply a different additive weight decay to linear vs. non-linear layers.

    Parameters living under a layer whose last path component starts with
    "linear" receive ``linear_weight_decay``; every other parameter receives
    ``conv_weight_decay``.
    """

    def is_linear_layer(layer_name, param_name, value):
        # Only the layer name drives the split; the parameter name and the
        # value itself are irrelevant here.
        del param_name, value
        last_component = layer_name.split("/")[-1]
        return last_component.startswith("linear")

    return partition(
        is_linear_layer,
        optax.additive_weight_decay(linear_weight_decay),
        optax.additive_weight_decay(conv_weight_decay),
    )
def partitioned_additive_weight_decay(weight_decay: float):
    """Additive weight decay applied only to parameters literally named "w"."""

    def is_weight_param(layer_name, param_name, value):
        # The decision depends solely on the parameter's name.
        del layer_name, value
        return param_name == "w"

    return optax_utils.partition(is_weight_param, optax.additive_weight_decay(weight_decay))
def create_train_state(rng, config: ml_collections.ConfigDict, model):
    """Create the initial training state.

    Builds SGD-with-momentum chained with additive weight decay, using the
    learning rate, momentum, and decay read from ``config``, and wraps the
    model's initial parameters in a ``TrainState``.
    """
    initial_params = get_initial_params(rng, model)
    optimizer = optax.chain(
        optax.sgd(learning_rate=config.learning_rate, momentum=config.momentum),
        optax.additive_weight_decay(weight_decay=config.weight_decay),
    )
    return TrainState.create(apply_fn=model.apply, params=initial_params, tx=optimizer)
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay, max_norm):
    """Create a ``TrainState`` with a clipped, scheduled Adam optimizer.

    The update pipeline is: global-norm clipping (``max_norm``), Adam moment
    scaling, additive weight decay, then scaling by ``lr_schedule_fn``.
    Parameters are initialized from a dummy all-ones RGB image batch of shape
    ``(1, img_size, img_size, 3)`` with ``is_training=False``.
    """
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_norm),
        optax.scale_by_adam(),
        optax.additive_weight_decay(weight_decay),
        optax.scale_by_schedule(lr_schedule_fn),
    )
    dummy_batch = jax.numpy.ones((1, img_size, img_size, 3))
    variables = model.init(rng, dummy_batch, is_training=False)
    return TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=optimizer,
    )
# One model replica must fit on a single TPU host (at most 8 cores).
assert cores_per_replica <= 8

# Unpack checkpoint/model configuration from the params dict.
bucket = params["bucket"]        # GCS bucket holding the checkpoints
model_dir = params["model_dir"]  # checkpoint directory inside the bucket
layers = params["layers"]
d_model = params["d_model"]
n_heads = params["n_heads"]
n_vocab = params["n_vocab"]
seq = params["seq"]
norm = params["norm"]

# NOTE(review): "nucleaus" is a typo carried over from the upstream repo's
# function name — do not "fix" the spelling here without renaming the function.
params["sampler"] = nucleaus_sample

# Optimizer chain: average gradients over accumulation steps, clip to unit
# global norm, apply Adam statistics, add weight decay of 0 (i.e. disabled),
# multiply by -1 (descent direction), then scale by the schedule.
# NOTE(review): util.gpt3_schedule(0, 1, 0, 0) presumably produces a constant
# zero learning rate (checkpoint loading / inference, not training) — confirm
# against util.gpt3_schedule.
# NOTE(review): clip_by_global_norm is unqualified here — assumes a bare
# import elsewhere in the file; verify.
opt = optax.chain(
    optax.scale(1 / gradient_accumulation_steps),
    clip_by_global_norm(1),
    optax.scale_by_adam(),
    optax.additive_weight_decay(0),
    optax.scale(-1),
    optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)),
)
params["optimizer"] = opt

start = time.time()
print(f"jax devices: {jax.device_count()}")
print(f"jax runtime initialized in {time.time() - start:.06}s")

# Arrange all devices as (data-parallel replicas, cores within one replica).
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

# NOTE(review): reading a gs:// path with open() assumes smart_open (or a
# similar shim) is bound to `open` in this module — plain builtin open cannot
# read GCS paths; confirm.
with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
    meta = json.load(f)

# Resume from the most recently written checkpoint step.
ckpt_step = meta["checkpoints"][-1]
# Unpack checkpoint/model configuration from the params dict.
bucket = params["bucket"]        # GCS bucket holding the checkpoints
model_dir = params["model_dir"]  # checkpoint directory inside the bucket
layers = params["layers"]
d_model = params["d_model"]
n_heads = params["n_heads"]
n_vocab = params["n_vocab"]
seq = params["seq"]
norm = params["norm"]

# NOTE(review): "nucleaus" is a typo carried over from the upstream repo's
# function name — do not "fix" the spelling here without renaming the function.
params["sampler"] = nucleaus_sample

# Optimizer chain: average gradients over accumulation steps, clip to unit
# global norm, apply Adam statistics, add weight decay of 0 (i.e. disabled),
# multiply by -1 (descent direction), then scale by the schedule.
# NOTE(review): util.gpt3_schedule(0, 1, 0, 0) presumably produces a constant
# zero learning rate (checkpoint loading / inference, not training) — confirm
# against util.gpt3_schedule.
# NOTE(review): clip_by_global_norm is unqualified here — assumes a bare
# import elsewhere in the file; verify.
opt = optax.chain(
    optax.scale(1 / gradient_accumulation_steps),
    clip_by_global_norm(1),
    optax.scale_by_adam(),
    optax.additive_weight_decay(0),
    optax.scale(-1),
    optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
)
params["optimizer"] = opt

start = time.time()
print(f"jax devices: {jax.device_count()}")
print(f"jax runtime initialized in {time.time() - start:.06}s")

# Arrange all devices as (data-parallel replicas, cores within one replica).
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

# NOTE(review): reading a gs:// path with open() assumes smart_open (or a
# similar shim) is bound to `open` in this module — plain builtin open cannot
# read GCS paths; confirm.
with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
    meta = json.load(f)
def selective_additive_weight_decay(predicate, weight_decay: float):
    """Additive weight decay restricted to parameters selected by ``predicate``.

    NOTE(review): ``partition`` is invoked with a single transformation here,
    while other call sites in this file pass two (matching / non-matching
    branches) — confirm the helper defaults the non-matching branch to the
    identity transformation.
    """
    decay = optax.additive_weight_decay(weight_decay)
    return partition(predicate, decay)