Ejemplo n.º 1
0
def create_model(session, forward_only):
    """Build the seq2seq model with hyper-parameters read from FLAGS.

    Args:
        session: Accepted for call-site compatibility; not used here.
        forward_only: Forwarded unchanged to ``s2s_model.S2SModel``
            (presumably True builds an inference-only model -- confirm).

    Returns:
        The newly constructed ``s2s_model.S2SModel`` instance.
    """
    # Half precision only when explicitly requested on the command line.
    if FLAGS.use_fp16:
        dtype = tf.float16
    else:
        dtype = tf.float32
    # The same dictionary size is passed for both of the first two arguments.
    vocab_size = data_utils.dim
    return s2s_model.S2SModel(
        vocab_size,
        vocab_size,
        buckets,
        FLAGS.size,
        FLAGS.dropout,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.num_samples,
        forward_only,
        dtype
    )
Ejemplo n.º 2
0
def create_model(session, forward_only, size=512, dropout=1.0, num_layers=1,
                 max_gradient_norm=5.0, batch_size=64, learning_rate=0.0003,
                 num_samples=512, dtype=None):
    """Build the seq2seq model with explicit, overridable hyper-parameters.

    The defaults are exactly the values this function used to hard-code,
    so existing ``create_model(session, forward_only)`` calls behave as
    before. Callers can now override any value by keyword instead of
    editing unlabelled magic numbers.

    Args:
        session: Accepted for call-site compatibility; not used here.
        forward_only: Forwarded unchanged to ``s2s_model.S2SModel``.
        size: Units per layer passed to the model (default 512).
        dropout: Dropout setting passed to the model (default 1.0;
            presumably a keep probability, i.e. no dropout -- confirm
            against ``S2SModel``).
        num_layers: Number of stacked layers (default 1).
        max_gradient_norm: Gradient-clipping threshold (default 5.0).
        batch_size: Mini-batch size (default 64).
        learning_rate: Optimiser learning rate (default 0.0003).
        num_samples: Sample count passed to the model, presumably for
            sampled softmax (default 512).
        dtype: TensorFlow dtype for the model; ``None`` means
            ``tf.float32``, the previous hard-coded value.

    Returns:
        The newly constructed ``s2s_model.S2SModel`` instance.
    """
    # Resolve the dtype default at call time rather than in the signature.
    if dtype is None:
        dtype = tf.float32
    model = s2s_model.S2SModel(
        data_utils.dim,
        data_utils.dim,
        buckets,
        size,
        dropout,
        num_layers,
        max_gradient_norm,
        batch_size,
        learning_rate,
        num_samples,
        forward_only,
        dtype
    )
    return model
Ejemplo n.º 3
0
def create_model(session, forward_only):
    """Construct the seq2seq model from the command-line FLAGS.

    Args:
        session: Accepted for call-site compatibility; not used here.
        forward_only: Whether the model runs the forward pass only.
            Prediction needs forward propagation only (True); training
            also back-propagates (False).

    Returns:
        The newly constructed ``s2s_model.S2SModel`` instance.
    """
    precision = tf.float16 if FLAGS.use_fp16 else tf.float32
    constructor_args = [
        data_utils.dim,  # total number of characters in the dictionary
        data_utils.dim,  # same dictionary size for the second size argument
        # Presumably the (question, answer) length buckets, e.g.
        # [(5, 15), (10, 20), (15, 25), (20, 30)] -- confirm at definition.
        buckets,
        FLAGS.size,
        FLAGS.dropout,
        # Depth of stacked LSTM layers; across time the cells are unrolled.
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,  # gradient-clipping threshold
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.num_samples,
        forward_only,  # True -> prediction/testing, False -> training
        precision,
    ]
    return s2s_model.S2SModel(*constructor_args)
Ejemplo n.º 4
0
def create_model(session, forward_only):
    """Build the S2S model with hyper-parameters taken from FLAGS.

    Args:
        session: Accepted for call-site compatibility; not used here.
        forward_only: True to build the model for forward propagation only
            (testing); False to build it for training.

    Returns:
        The newly constructed ``s2s_model.S2SModel`` instance.
    """
    # 32-bit floats by default; 16-bit floats when use_fp16 is set.
    dtype = tf.float32
    if FLAGS.use_fp16:
        dtype = tf.float16

    # Dictionary size (dictionary.json; author noted 6865 entries -- confirm)
    # for both size arguments, followed by the length buckets
    # (author noted [(5, 15), (10, 20), (15, 25), (20, 30)]).
    data_args = (data_utils.dim, data_utils.dim, buckets)
    # Network / optimiser settings; values in parentheses are the defaults
    # the author noted -- check the flag definitions to confirm.
    hyper_args = (
        FLAGS.size,               # LSTM units per layer (512)
        FLAGS.dropout,            # dropout setting for each layer's output (1.0)
        FLAGS.num_layers,         # number of LSTM layers (2)
        FLAGS.max_gradient_norm,  # maximum gradient threshold (5.0)
        FLAGS.batch_size,         # mini-batch size (64)
        FLAGS.learning_rate,      # learning rate (0.0003)
        FLAGS.num_samples,        # sample count for sampled softmax (512)
    )
    mode_args = (forward_only, dtype)
    return s2s_model.S2SModel(*(data_args + hyper_args + mode_args))