def collect_params(all_params, params):
    """Return a new ``utils.HParams`` restricted to the names in ``params``.

    Only the *names* of ``params`` are used; every value is looked up on
    ``all_params`` with ``getattr``, so a name missing from ``all_params``
    raises ``AttributeError``.

    Args:
        all_params: object supplying the values (read with ``getattr``).
        params: HParams-like object whose ``values()`` keys select which
            hyper-parameters to copy.

    Returns:
        A freshly built ``utils.HParams`` with one entry per selected name.
    """
    subset = utils.HParams()
    wanted = params.values()

    # Iterating the mapping yields its keys in order.
    for name in wanted:
        value = getattr(all_params, name)
        subset.add_hparam(name, value)

    return subset
def default_params():
    """Return the default hyper-parameters as a new ``utils.HParams``.

    The settings are assembled group by group into one ordered mapping and
    then passed to ``utils.HParams`` as keyword arguments. All list values
    are created anew on every call.
    """
    settings = {}

    # Files, model name, vocabulary files and special symbols
    settings.update({
        "input": "",
        "output": "",
        "model": "transformer",
        "vocab": ["", ""],
        "pad": "<pad>",
        "bos": "<eos>",
        "eos": "<eos>",
        "unk": "<unk>",
    })

    # Dataset
    settings.update({
        "batch_size": 4096,
        "fixed_batch_size": False,
        "min_length": 1,
        "max_length": 256,
        "buffer_size": 10000,
    })

    # Initialization
    settings.update({
        "initializer_gain": 1.0,
        "initializer": "uniform_unit_scaling",
    })

    # Regularization
    settings.update({
        "scale_l1": 0.0,
        "scale_l2": 0.0,
    })

    # Training
    settings.update({
        "script": "",
        "warmup_steps": 4000,
        "train_steps": 100000,
        "update_cycle": 1,
        "optimizer": "Adam",
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "adam_epsilon": 1e-8,
        "adadelta_rho": 0.95,
        "adadelta_epsilon": 1e-6,
        "clipping": "global_norm",
        "clip_grad_norm": 5.0,
        "learning_rate": 1.0,
        "learning_rate_schedule": "linear_warmup_rsqrt_decay",
        "learning_rate_boundaries": [0],
        "learning_rate_values": [0.0],
        "device_list": [0],
        "embedding": "",
    })

    # Validation
    settings.update({
        "keep_top_k": 50,
        "frequency": 10,
    })

    # Checkpoint Saving
    settings.update({
        "keep_checkpoint_max": 20,
        "keep_top_checkpoint_max": 5,
        "save_summary": True,
        "save_checkpoint_secs": 0,
        "save_checkpoint_steps": 1000,
    })

    return utils.HParams(**settings)
Example #3
0
def default_params():
    """Return this script's default settings as a new ``utils.HParams``.

    ``input``, ``output`` and ``vocabulary`` default to ``None``. The
    special-symbol strings and the integer ``device`` / ``decode_batch_size``
    values are fixed defaults.
    """
    defaults = {
        "input": None,
        "output": None,
        "vocabulary": None,
        "embedding": "",
        # Vocabulary-specific special symbols
        "pad": "<pad>",
        "bos": "<bos>",
        "eos": "<eos>",
        "unk": "<unk>",
        "device": 0,
        "decode_batch_size": 128,
    }
    return utils.HParams(**defaults)
Example #4
0
def merge_params(params1, params2):
    """Merge two HParams-like objects into a new ``utils.HParams``.

    Every entry of ``params1`` is copied first. Entries of ``params2`` then
    either override an entry with the same name or are appended as new
    hyper-parameters. Neither input object is modified, but values are
    shared (not copied), so mutable values such as lists are the same
    objects as in the inputs.

    Args:
        params1: base hyper-parameters (read through ``values()``).
        params2: hyper-parameters that take precedence over ``params1``.

    Returns:
        A new ``utils.HParams`` holding the merged entries.
    """
    params = utils.HParams()

    for (k, v) in six.iteritems(params1.values()):
        params.add_hparam(k, v)

    # Snapshot of the names copied from params1. The keys of params2 come
    # from a mapping and are therefore distinct, so this snapshot is enough
    # to decide between overriding and adding.
    params_dict = params.values()

    for (k, v) in six.iteritems(params2.values()):
        if k in params_dict:
            # Override the value copied from params1. setattr is used
            # because add_hparam presumably rejects an existing name (as
            # TF's HParams does) -- TODO confirm for utils.HParams.
            setattr(params, k, v)
        else:
            params.add_hparam(k, v)

    return params