예제 #1
0
def infer_fn(params, features):
    """Build the inference graph and return everything but its first output.

    Args:
        params: hyper-parameter object. It is shallow-copied and passed
            through ``util.closing_dropout`` (presumably switches dropout
            off for inference -- confirm against ``util``).
        features: model inputs, forwarded untouched to ``graph``.

    Returns:
        ``graph(features, params)[1:]``, i.e. the graph outputs with the
        first element dropped.
    """
    # Work on a copy so rebinding attributes does not touch the caller's
    # object.
    params = util.closing_dropout(copy.copy(params))

    # NOTE(review): copy.copy is shallow, so params.bert is presumably the
    # same object the caller holds; the return value is ignored here, which
    # suggests closing_dropout updates it in place -- confirm.
    if params.enable_bert:
        util.closing_dropout(params.bert)

    scope_name = params.model_name or "model"
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        return graph(features, params)[1:]
예제 #2
0
def train_fn(features, params, initializer=None):
    """Build the training graph and expose its loss.

    Args:
        features: model inputs, forwarded untouched to ``graph``.
        params: hyper-parameter object; a shallow copy is used internally.
        initializer: optional default variable initializer handed to the
            variable scope.

    Returns:
        A dict with the single key ``"loss"`` taken from the first element
        returned by ``graph``.
    """
    params = copy.copy(params)

    # Only the BERT sub-parameters go through closing_dropout here; the
    # top-level params are left as they are.
    if params.enable_bert:
        util.closing_dropout(params.bert)

    scope_name = params.model_name or "model"
    with tf.variable_scope(scope_name,
                           initializer=initializer,
                           reuse=tf.AUTO_REUSE):
        # graph is expected to return exactly three values; only the first
        # one is used.
        model_outputs, _, _ = graph(features, params)
        return {"loss": model_outputs['loss']}
예제 #3
0
def infer_fn(params):
    """Create the encoding / decoding step functions used for inference.

    Args:
        params: hyper-parameter object. It is shallow-copied and passed
            through ``util.closing_dropout`` (presumably switches dropout
            off for inference -- confirm against ``util``).

    Returns:
        A ``(encoding_fn, decoding_fn)`` pair:

        * ``encoding_fn(source)`` runs ``encoder`` and seeds
          ``state["decoder"]`` from ``state["decoder_initializer"]``.
        * ``decoding_fn(target, state, time)`` runs one ``decoder`` step
          and returns ``(step_logits, step_state)``.
    """
    params = copy.copy(params)
    params = util.closing_dropout(params)

    def _model_scope():
        # Built on every call (not once up front) so that the scope name
        # and float type are read at the same moment as before.
        return tf.variable_scope(
            params.scope_name or "model",
            reuse=tf.AUTO_REUSE,
            dtype=tf.as_dtype(dtype.floatx()),
            custom_getter=dtype.float32_variable_storage_getter)

    def encoding_fn(source):
        with _model_scope():
            state = encoder(source, params)
            state["decoder"] = {"state": state["decoder_initializer"]}
            return state

    def decoding_fn(target, state, time):
        with _model_scope():
            if params.search_mode == "cache":
                # The current step index is passed to the decoder through
                # the state dict. It is removed again in `finally` so the
                # caller's dict is not left with a stale 'time' entry if
                # the decoder raises.
                state['time'] = time
                try:
                    _, step_logits, step_state, _ = decoder(
                        target, state, params)
                finally:
                    del state['time']
            else:
                # Non-cache mode: re-run the encoder on `state` at every
                # step and hand the incoming state back unchanged.
                estate = encoder(state, params)
                estate['dev_decode'] = True
                _, step_logits, _, _ = decoder(target, estate, params)
                step_state = state

            return step_logits, step_state

    return encoding_fn, decoding_fn
예제 #4
0
def infer_fn(params):
    """Create the encoding / decoding step functions used for inference.

    Args:
        params: hyper-parameter object. It is shallow-copied and passed
            through ``util.closing_dropout`` (presumably switches dropout
            off for inference -- confirm against ``util``).

    Returns:
        A ``(encoding_fn, decoding_fn)`` pair. ``decoding_fn`` accepts a
        ``time`` argument but does not use it in this variant.
    """
    params = util.closing_dropout(copy.copy(params))

    def _model_scope():
        # Evaluated on every call so the scope name is read at call time.
        return tf.variable_scope(params.model_name or "model",
                                 reuse=tf.AUTO_REUSE)

    def encoding_fn(source):
        with _model_scope():
            enc_state = encoder(source, params)
            # Seed the decoder slot from the encoder-provided initializer.
            enc_state["decoder"] = {"state": enc_state["decoder_initializer"]}
            return enc_state

    def decoding_fn(target, state, time):
        with _model_scope():
            if params.search_mode != "cache":
                # Non-cache mode: re-run the encoder on `state` and return
                # the incoming state unchanged.
                fresh_state = encoder(state, params)
                fresh_state['dev_decode'] = True
                _, logits, _ = decoder(target, fresh_state, params)
                return logits, state

            _, logits, next_state = decoder(target, state, params)
            return logits, next_state

    return encoding_fn, decoding_fn
예제 #5
0
def score_fn(features, params, initializer=None):
    """Build the scoring graph for source / target pairs.

    Args:
        features: mapping that provides the ``'source'`` and ``'target'``
            entries fed to ``encoder`` and ``decoder``.
        params: hyper-parameter object. It is shallow-copied, passed through
            ``util.closing_dropout`` and has ``label_smooth`` forced to 0.0
            on the copy.
        initializer: optional default variable initializer handed to the
            variable scope.

    Returns:
        ``{"score": scores}`` where ``scores`` is the fourth value returned
        by ``decoder``.
    """
    params = util.closing_dropout(copy.copy(params))
    params.label_smooth = 0.0

    model_scope = tf.variable_scope(
        params.scope_name or "model",
        initializer=initializer,
        reuse=tf.AUTO_REUSE,
        dtype=tf.as_dtype(dtype.floatx()),
        custom_getter=dtype.float32_variable_storage_getter)

    with model_scope:
        enc_state = encoder(features['source'], params)
        # decoder is expected to return exactly four values; only the last
        # one is used here.
        _, _, _, pair_scores = decoder(features['target'], enc_state, params)
        return {"score": pair_scores}
예제 #6
0
def infer_fn(params):
    """Create the encoding / decoding step functions used for inference.

    Args:
        params: hyper-parameter object. It is shallow-copied and passed
            through ``util.closing_dropout`` (presumably switches dropout
            off for inference -- confirm against ``util``).

    Returns:
        A ``(encoding_fn, decoding_fn)`` pair. ``decoding_fn`` accepts a
        ``time`` argument but does not use it in this variant.
    """
    params = util.closing_dropout(copy.copy(params))

    def _model_scope():
        # Evaluated on every call so the scope name is read at call time.
        return tf.variable_scope(params.model_name or "model",
                                 reuse=tf.AUTO_REUSE)

    def encoding_fn(source):
        with _model_scope():
            enc_state = encoder(source, params)
            # Seed the decoder slot from the encoder-provided initializer.
            enc_state["decoder"] = {"state": enc_state["decoder_initializer"]}
            return enc_state

    def decoding_fn(target, state, time):
        with _model_scope():
            _, logits, next_state = decoder(target, state, params)
            # Replace the decoder state with the result of
            # util.merge_neighbor_dims(..., axis=0) before handing it back.
            merged = util.merge_neighbor_dims(
                next_state["decoder"]["state"], axis=0)
            next_state["decoder"]["state"] = merged
            return logits, next_state

    return encoding_fn, decoding_fn