示例#1
0
def wt2_opt():
    """Assemble the option sets for this experiment configuration.

    Each returned object is a ``ChainMap`` whose first map holds the
    experiment-specific overrides written below. Its second map is the library
    default returned by ``reader.default_reader_opt()``,
    ``lm.default_rnnlm_opt()`` or ``lm.default_train_opt()``. Lookups hit the
    overrides first and fall back to the defaults.

    The name suggests a WikiText-2 setup (presumably; not confirmed here).

    Returns:
        tuple: ``(reader_opt, model_opt, train_opt)``, three ``ChainMap``s.
    """
    reader_overrides = dict(
        sentences=False,
        # min/max_seq_len are grouped as relevant only when sentences is False.
        min_seq_len=35,
        max_seq_len=35,
        shuffle=False,
        batch_size=64,
        # Empty paths are placeholders, presumably filled in from the data
        # directory by the caller -- confirm.
        vocab_path='',
        text_path='',
    )
    model_overrides = dict(
        emb_dim=650,
        rnn_dim=650,
        rnn_layers=2,
        rnn_variational=True,
        rnn_input_keep_prob=0.5,
        rnn_layer_keep_prob=0.7,
        rnn_output_keep_prob=0.5,
        rnn_state_keep_prob=0.7,
        logit_weight_tying=True,
        # -1 is a placeholder, presumably replaced by the vocabulary size
        # derived from the data.
        vocab_size=-1,
    )
    train_overrides = dict(
        loss_key='mean_token_nll',  # alternative noted in source: sum_token_nll
        ngram_loss_coeff=0.1,
        init_learning_rate=0.003,
        decay_rate=0.85,
        staircase=True,
        optim='tensorflow.train.AdamOptimizer',
        # The optim_beta*/optim_epsilon keys are grouped as Adam-specific.
        optim_beta1=0.0,
        optim_beta2=0.999,
        optim_epsilon=1e-8,
        clip_gradients=5.0,
        max_epochs=40,
        # Placeholder, presumably redirected into the experiment directory.
        checkpoint_path='tmp',
        # Per the original note, -1 means "decay once per epoch".
        decay_steps=-1,
    )

    # Layer each override dict on top of its library defaults. The defaults
    # are fetched in the same order as before: reader, then model, then train.
    reader_opt = ChainMap(reader_overrides, reader.default_reader_opt())
    model_opt = ChainMap(model_overrides, lm.default_rnnlm_opt())
    train_opt = ChainMap(train_overrides, lm.default_train_opt())
    return reader_opt, model_opt, train_opt
示例#2
0
        'loss_key': 'mean_l2',
        'init_learning_rate': 0.001,
        'decay_rate': 0.8,
        'staircase': True,
        'optim': 'tensorflow.train.AdamOptimizer',
        # if adam
        'optim_beta1': 0.9,
        'optim_beta2': 0.999,
        'optim_epsilon': 1e-8,
        # fi
        'clip_gradients': 5.0,
        'max_epochs': 10,
        'checkpoint_path': 'tmp',  # auto to exp_dir
        'decay_steps': -1  # if -1 auto to an epoch
    },
    lm.default_train_opt())

# Command-line interface: two positional directories plus a log-level option
# and a --delete flag.
parser = argparse.ArgumentParser()
parser.add_argument('data_dir', type=str)
parser.add_argument('exp_dir', type=str)
parser.add_argument('--log_level', type=str, default='info')
parser.add_argument('--delete', action='store_true')
args = vars(parser.parse_args())

# Path builders. Each call joins its arguments onto a fixed base directory,
# e.g. exp_path('reader_opt.json') -> <exp_dir>/reader_opt.json.
exp_path = partial(os.path.join, args['exp_dir'])
# The encoder directory is derived from exp_path exactly once. The directory
# created below and the one encoder_path() resolves into therefore cannot
# drift apart.
encoder_path = partial(os.path.join, exp_path('encoder'))
data_path = partial(os.path.join, args['data_dir'])
# encoder_path() with no extra components is the encoder directory itself.
# The --delete flag is forwarded to util.ensure_dir; its exact effect is
# defined there (presumably it clears existing contents -- confirm).
util.ensure_dir(encoder_path(), delete=args['delete'])

with open(exp_path('reader_opt.json')) as fp: