Example #1
def load_vaes(H, logprint):
    rng = random.PRNGKey(H.seed_init)
    init_rng, init_eval_rng = random.split(rng)
    init_batch = jnp.zeros((1, H.image_size, H.image_size, H.image_channels))
    # Initialize the model; the EMA copy starts out identical to the params.
    ema = params = VAE(H).init({'params': init_rng}, init_batch, init_batch,
                               init_eval_rng)['params']
    optimizer = Adam(weight_decay=H.wd, beta1=H.adam_beta1,
                     beta2=H.adam_beta2).create(params)
    epoch = 0
    if H.restore_path:
        logprint(f'Restoring vae from {H.restore_path}')
        optimizer, epoch = checkpoints.restore_checkpoint(
            H.restore_path, (optimizer, epoch))
        ema = checkpoints.restore_checkpoint(H.restore_path + '_ema', ema)
    # Count parameters by flattening the optimizer's target pytree.
    total_params = 0
    for p in jax.tree_flatten(optimizer.target)[0]:
        total_params += np.prod(p.shape)
    logprint(total_params=total_params, readable=f'{total_params:,}')
    # Replicate optimizer and EMA params across local devices for pmap.
    optimizer, ema = jax_utils.replicate((optimizer, ema))
    return optimizer, ema, epoch
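The function above only builds and replicates the optimizer. As a minimal sketch of how an Optimizer created with Adam(...).create(params) is then stepped under the legacy flax.optim API (compute_loss and batch are hypothetical placeholders, not part of the example):

import jax

def train_step(optimizer, batch):
    # optimizer.target holds the current parameters
    def loss_fn(params):
        return compute_loss(params, batch)  # hypothetical loss helper
    loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
    # apply_gradient returns a new Optimizer with updated params and state
    return optimizer.apply_gradient(grad), loss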
Example #2
def regression_oracle(net: Module, x, y, key, hyperparams: ServerHyperParams):
    x_init = jnp.zeros((1, hyperparams.image_size, hyperparams.image_size,
                        hyperparams.num_channels))
    # One model replica per leading element of x/y, each initialized with its
    # own PRNG key via vmap.
    num_vmap_batching = x.shape[0]
    assert num_vmap_batching == y.shape[0]
    keys = random.split(key, num_vmap_batching + 1)
    batch_params_init = vmap(net.init, in_axes=[0, None])(keys[1:], x_init)
    key = keys[0]
    opt_def = Adam(learning_rate=hyperparams.oracle_lr)
    # opt_def = Momentum(learning_rate=1e-3, weight_decay=1e-4, nesterov=True)
    # Optimizer is a pytree, so its creation can also be vmapped.
    opts = vmap(opt_def.create)(batch_params_init)
    # Unroll several updates per call to amortize dispatch overhead.
    num_loop_unrolling = 10
    for step in range(hyperparams.oracle_num_steps // num_loop_unrolling):
        keys = random.split(key, num_vmap_batching * num_loop_unrolling + 1)
        key = keys[num_vmap_batching * num_loop_unrolling]
        keys = jnp.reshape(keys[:num_vmap_batching * num_loop_unrolling],
                           (num_vmap_batching, num_loop_unrolling, 2))
        v, opts = jv_train_op_n(net, opts, x, y, keys, hyperparams,
                                num_loop_unrolling)
        if step % 1000 == 0:
            print("step %5d oracle error %.2f" % (step, v))

    return opts.target
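jv_train_op_n itself is not part of this snippet. A sketch of the kind of per-replica update it presumably vmaps, again under the legacy flax.optim API (the squared-error loss and the unused key argument are assumptions):

from jax import value_and_grad, vmap
import jax.numpy as jnp

def train_op(opt, x, y, key):
    # key would seed any stochastic layers; unused in this sketch
    def loss_fn(params):
        preds = net.apply(params, x)       # net as in the example above
        return jnp.mean((preds - y) ** 2)  # squared-error oracle loss
    v, grad = value_and_grad(loss_fn)(opt.target)
    return v, opt.apply_gradient(grad)

# one optimizer, one (x, y) slice, and one key per replica
jv_train_op = vmap(train_op, in_axes=(0, 0, 0, 0))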
Example #3

    # Enable tensorboard only on the master node
    if has_tensorboard and jax.host_id() == 0:
        summary_writer = SummaryWriter(
            log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

    # Setup optimizer
    optimizer = Adam(
        learning_rate=training_args.learning_rate,
        weight_decay=training_args.weight_decay,
        beta1=training_args.adam_beta1,
        beta2=training_args.adam_beta2,
    ).create(model.params)

    # Create learning rate scheduler
    lr_scheduler_fn = create_learning_rate_scheduler(
        base_learning_rate=training_args.learning_rate,
        warmup_steps=max(training_args.warmup_steps, 1))

    # Create parallel version of the training and evaluation steps
    p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the optimizer on each device
    optimizer = jax_utils.replicate(optimizer)
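training_step is not shown in this snippet either. A minimal sketch of a step suitable for jax.pmap with axis name "batch" (compute_mlm_loss is a hypothetical helper):

import jax

def training_step(optimizer, batch):
    def loss_fn(params):
        return compute_mlm_loss(params, batch)  # hypothetical loss helper
    loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
    # average gradients and loss across devices along the "batch" axis
    grad = jax.lax.pmean(grad, axis_name="batch")
    optimizer = optimizer.apply_gradient(grad)
    return optimizer, jax.lax.pmean(loss, axis_name="batch")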
Example #4
File: main.py Project: Cli212/flax-BERT
 parser.add_argument("--seed", type=int, default=0)
 parser.add_argument("--model", type=str, required=True)
 parser.add_argument("--dataset", type=str, default="wikitext")
 parser.add_argument("--dataset_config", type=str, default="wikitext-2-v1")
 parser.add_argument("--train", default=False, type=ast.literal_eval)
 parser.add_argument("--batch_size", type=int, default=32)
 parser.add_argument("--warmup_steps", type=int, default=5)
 args = parser.parse_args()
 logger = logger_config("log.txt", logging_name="log")
 #
 tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 rng = random.PRNGKey(args.seed)
 data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                     mlm_probability=0.15)
 model = FlaxBertForPretrained.from_pretrained(args.model)
 optimizer = Adam().create(model.params)
 lr_scheduler_fn = create_learning_rate_scheduler(
     base_learning_rate=1e-3,
     warmup_steps=max(args.warmup_steps, 1))
 logger.info("Load dataset")
 datasets = load_dataset(args.dataset, args.dataset_config)
 column_names = (datasets['train'].column_names
                 if args.train else datasets['validation'].column_names)
 # Pick the text column; adjust this for datasets with other column names.
 text_column_name = ("review_body" if "review_body" in column_names
                     else column_names[0])
 tokenized_datasets = datasets.map(tokenized_function,
                                   input_columns=[text_column_name],
                                   batched=True,
                                   remove_columns=column_names)
 p_training_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
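Because p_training_step is pmapped, each host batch must be split across local devices before the call. A sketch using flax's common_utils.shard (batch and the step's exact signature are assumptions):

from flax.training.common_utils import shard

# reshape the leading batch axis to (num_local_devices, per_device_batch)
sharded_batch = shard(batch)
optimizer, loss = p_training_step(optimizer, sharded_batch)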
Example #5
File: SAC.py Project: ethanabrooks/jax-rl
 def reset_optimizer(adam: Adam, model: nn.Model) -> Optimizer:
     # Build fresh optimizer state (e.g. Adam moments) for the model's
     # current weights and place it on device.
     return jax.device_put(adam.create(model))
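A usage sketch for reset_optimizer, e.g. to discard stale Adam moments after an actor network is reinitialized (actor is a hypothetical nn.Model):

adam = Adam(learning_rate=3e-4)
actor_optimizer = reset_optimizer(adam, actor)  # fresh first/second moments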