Example #1
def partitioned_additive_weight_decay(
        conv_weight_decay: float,  # wd1 from original repo
        linear_weight_decay: float,  # wd2 from original repo
) -> optax.GradientTransformation:
    def predicate(layer_name, param_name, value):
        del param_name, value
        return layer_name.split("/")[-1].startswith("linear")

    return partition(
        predicate,
        optax.additive_weight_decay(linear_weight_decay),
        optax.additive_weight_decay(conv_weight_decay),
    )
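The `partition` helper used here (and in the examples below) is a project-local utility rather than part of optax itself. As a rough illustration only, here is a minimal sketch of how such a helper could be built from `optax.masked` and a haiku-style predicate, assuming it routes parameters matched by the predicate to one transformation and all remaining parameters to an optional second one:

import haiku as hk
import optax


def partition(predicate,
              if_true: optax.GradientTransformation,
              if_false: optax.GradientTransformation = optax.identity()):
    # Build boolean masks with the same (module_name, name, value) predicate
    # signature used in the examples above.
    def true_mask(tree):
        return hk.data_structures.map(
            lambda module_name, name, value: bool(predicate(module_name, name, value)),
            tree)

    def false_mask(tree):
        return hk.data_structures.map(
            lambda module_name, name, value: not predicate(module_name, name, value),
            tree)

    # Apply `if_true` only to matched parameters and `if_false` to the rest.
    return optax.chain(
        optax.masked(if_true, true_mask),
        optax.masked(if_false, false_mask),
    )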
Example #2
def partitioned_additive_weight_decay(weight_decay: float):
    def predicate(layer_name, param_name, value):
        del layer_name, value
        return param_name == "w"

    return optax_utils.partition(predicate,
                                 optax.additive_weight_decay(weight_decay))
Example #3
def create_train_state(rng, config: ml_collections.ConfigDict, model):
    """Create initial training state."""
    params = get_initial_params(rng, model)
    tx = optax.chain(
        optax.sgd(learning_rate=config.learning_rate,
                  momentum=config.momentum),
        optax.additive_weight_decay(weight_decay=config.weight_decay))
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state
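The resulting TrainState can be consumed by a standard flax training step. A hypothetical sketch, assuming a `loss_fn(params, batch)` defined elsewhere; `apply_gradients` passes the current params through the optax chain, which `additive_weight_decay` requires:

import jax


@jax.jit
def train_step(state, batch):
    # Differentiate the (assumed) loss with respect to the parameters.
    grads = jax.grad(loss_fn)(state.params, batch)
    # apply_gradients runs the optax chain (SGD + additive weight decay)
    # and returns a new state with updated params and optimizer state.
    return state.apply_gradients(grads=grads)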
Example #4
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay,
                       max_norm):

    tx = optax.chain(optax.clip_by_global_norm(max_norm),
                     optax.scale_by_adam(),
                     optax.additive_weight_decay(weight_decay),
                     optax.scale_by_schedule(lr_schedule_fn))

    params = model.init(rng,
                        jax.numpy.ones((1, img_size, img_size, 3)),
                        is_training=False)

    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
    return train_state
Example #5
    assert cores_per_replica <= 8

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    params["sampler"] = nucleaus_sample
    opt = optax.chain(optax.scale(1 / gradient_accumulation_steps),
                      clip_by_global_norm(1), optax.scale_by_adam(),
                      optax.additive_weight_decay(0), optax.scale(-1),
                      optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)))

    params["optimizer"] = opt

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
        meta = json.load(f)

    ckpt_step = meta["checkpoints"][-1]
Example #6
    bucket = params["bucket"]
    model_dir = params["model_dir"]
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    params["sampler"] = nucleaus_sample
    opt = optax.chain(
        # Rescale for gradients accumulated over several micro-batches.
        optax.scale(1 / gradient_accumulation_steps),
        # Clip the global gradient norm to 1.
        clip_by_global_norm(1),
        # Adam-style moment rescaling (without the learning rate).
        optax.scale_by_adam(),
        # Additive weight decay, disabled here with a coefficient of 0.
        optax.additive_weight_decay(0),
        # Negate so the chain performs gradient descent rather than ascent.
        optax.scale(-1),
        # Multiply by the GPT-3-style learning-rate schedule.
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
    )

    params["optimizer"] = opt

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
        meta = json.load(f)
Example #7
def selective_additive_weight_decay(predicate, weight_decay: float):
    return partition(
        predicate,
        optax.additive_weight_decay(weight_decay),
    )
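A hypothetical end-to-end usage sketch for this helper, assuming the `partition` implementation sketched under Example #1 and a toy haiku-style parameter tree; only parameters named "w" receive the decay term:

import jax
import jax.numpy as jnp
import optax


def decay_weights_only(module_name, name, value):
    del module_name, value
    return name == "w"


params = {"mlp/linear_0": {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}}
grads = jax.tree_util.tree_map(jnp.ones_like, params)

tx = optax.chain(
    selective_additive_weight_decay(decay_weights_only, 1e-4),
    optax.sgd(learning_rate=1e-2),
)
opt_state = tx.init(params)
# additive_weight_decay needs the current params, so they are passed to update.
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)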