コード例 #1
0
    def _create_params_savable(self, params, model):
        """Register a saveable wrapping the CudnnRNN weight and bias parameters.

        Builds a `cudnn_rnn_ops.RNNParamsSaveable` from the model's
        `params_to_canonical` / `canonical_to_params` converters and `params`,
        then adds it to the `ops.GraphKeys.SAVEABLE_OBJECTS` collection.

        Args:
          params: a Variable for weight and bias parameters.
          model: a CudnnRNN model.
        """
        saveable = cudnn_rnn_ops.RNNParamsSaveable(
            model.params_to_canonical,
            model.canonical_to_params,
            params,
        )
        collection_key = ops.GraphKeys.SAVEABLE_OBJECTS
        ops.add_to_collection(collection_key, saveable)
コード例 #2
0
def _CreateParamsSavable(params,
                         model,
                         base_variable_scope="rnn",
                         name="params_canonical"):
    """Register a saveable wrapping the CudnnRNN weight and bias parameters.

    Builds a `cudnn_rnn_ops.RNNParamsSaveable` for `model` and adds it to the
    `ops.GraphKeys.SAVEABLE_OBJECTS` collection.

    Args:
      params: a Variable for weight and bias parameters.
      model: a CudnnRNN model.
      base_variable_scope: a string, prefix of names of saved variables.
      name: a string, name of the RNNParamsSaveable object.
    """
    # The saveable takes a list of parameter variables; here there is one.
    param_variables = [params]
    saveable = cudnn_rnn_ops.RNNParamsSaveable(
        model,
        model.params_to_canonical,
        model.canonical_to_params,
        param_variables,
        base_variable_scope=base_variable_scope,
        name=name,
    )
    collection_key = ops.GraphKeys.SAVEABLE_OBJECTS
    ops.add_to_collection(collection_key, saveable)
コード例 #3
0
def _cudnn_spec_to_compat_name(spec_name):
    """Rewrite a canonical CuDNN spec name into the bidirectional-RNN layout.

    The "/"-separated name is edited as follows: the "cell_0" component is
    removed; the "forward" (or, failing that, "backward") component is
    dropped and a "fw" (or "bw") tag is placed after the component that
    followed it; finally "multi_rnn_cell" is replaced by "bidirectional_rnn".

    Illustrative example:
      "s/forward/multi_rnn_cell/cell_0/w" -> "s/bidirectional_rnn/fw/w"

    Args:
      spec_name: the `name` of a spec from an `RNNParamsSaveable`.

    Returns:
      The rewritten variable name as a string.

    Raises:
      ValueError: if "cell_0", "multi_rnn_cell", or both direction
        components are missing from the name.
    """
    parts = spec_name.split("/")
    parts.remove("cell_0")
    if "forward" in parts:
        direction, tag = "forward", "fw"
    else:
        direction, tag = "backward", "bw"
    ix = parts.index(direction)
    # Insert before deleting so the tag lands after the component that
    # follows the direction component.
    parts.insert(ix + 2, tag)
    del parts[ix]
    parts[parts.index("multi_rnn_cell")] = "bidirectional_rnn"
    return "/".join(parts)


def convert(model_dir, output_dir, best_weights=False):
    """Convert a model using CuDNN GRU parameters into a compat-cell model.

    Steps, in order:
      1. Load the model from `model_dir`, build its prediction graph and
         restore its checkpoint, then record the logits for a fixed test
         question.
      2. Copy the regular files of `model_dir` (except "model.npy") into
         `output_dir`.
      3. For every trainable variable named ".../gru_parameters:0", expand it
         into canonical weights/biases via `RNNParamsSaveable`, create new
         variables with renamed scopes, and save them together with all other
         variables to `<output_dir>/save`.
      4. Replace the model's recurrent layers with
         `BiRecurrentMapper(CompatGruCellSpec(dim))`, pickle the model to
         `<output_dir>/model.pkl`, reload it, restore the new checkpoint and
         print whether the new logits are close to the ones from step 1.

    Args:
      model_dir: directory of the source model (passed to `ModelDir`).
      output_dir: directory to write the converted model into; created if it
        does not exist.
      best_weights: currently unused; kept for interface compatibility.

    Raises:
      ValueError: if a canonical parameter name does not have the expected
        components (see `_cudnn_spec_to_compat_name`).
    """
    print("Load model")
    md = ModelDir(model_dir)
    model = md.get_model()
    dim = model.embed_mapper.layers[1].n_units
    global_step = tf.get_variable('global_step',
                                  shape=[],
                                  dtype='int32',
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    print("Setting up cudnn version")
    sess = tf.Session()
    try:
        sess.run(global_step.assign(0))
        with sess.as_default():
            model.set_input_spec(
                ParagraphAndQuestionSpec(1, None, None, 14), {"the"},
                ResourceLoader(
                    lambda a, b: {"the": np.zeros(300, np.float32)}))

            print("Buiding graph")
            prediction = model.get_prediction()

        test_questions = ParagraphAndQuestion(
            ["Harry", "Potter", "was", "written", "by", "JK"],
            ["Who", "wrote", "Harry", "Potter", "?"], None, "test_questions")

        print("Load vars")
        md.restore_checkpoint(sess)
        print("Restore finished")

        feed = model.encode([test_questions], False)
        cudnn_out = sess.run([prediction.start_logits, prediction.end_logits],
                             feed_dict=feed)

        print("Done, copying files...")
        if not exists(output_dir):
            mkdir(output_dir)
        for fname in listdir(model_dir):
            # listdir() yields bare names, so the file check must be made on
            # the path inside model_dir, not relative to the working directory.
            src = join(model_dir, fname)
            if isfile(src) and fname != "model.npy":
                copyfile(src, join(output_dir, fname))

        print("Done, mapping tensors...")
        to_save = []
        to_init = []
        for x in tf.trainable_variables():
            if x.name.endswith("/gru_parameters:0"):
                key = x.name[:-len("/gru_parameters:0")]
                fw_params = x
                # The third CudnnGRU argument depends on which layer the
                # parameters belong to.
                # NOTE(review): presumably (num_layers, num_units, input_size)
                # - confirm against the CudnnGRU signature in use.
                if "map_embed" in x.name:
                    c = cudnn_rnn_ops.CudnnGRU(1, dim, 400)
                elif "chained-out" in x.name:
                    c = cudnn_rnn_ops.CudnnGRU(1, dim, dim * 4)
                else:
                    c = cudnn_rnn_ops.CudnnGRU(1, dim, dim * 2)
                params_saveable = cudnn_rnn_ops.RNNParamsSaveable(
                    c, c.params_to_canonical, c.canonical_to_params,
                    [fw_params], key)

                for spec in params_saveable.specs:
                    # NOTE(review): these two bias specs are dropped from the
                    # converted checkpoint; the original author was unsure of
                    # their purpose - confirm they are safe to discard.
                    if spec.name.endswith(("bias_cudnn 0", "bias_cudnn 1")):
                        continue
                    name = _cudnn_spec_to_compat_name(spec.name)
                    v = tf.Variable(sess.run(spec.tensor), name=name)
                    to_init.append(v)
                    to_save.append(v)

            else:
                to_save.append(x)

        # Evaluate the trainable collection once instead of once per variable.
        trainable = tf.trainable_variables()
        non_trainable = [
            x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if x not in trainable
        ]
        print(non_trainable)
        sess.run(tf.variables_initializer(to_init))
        saver = tf.train.Saver(to_save + non_trainable)
        save_dir = join(output_dir, "save")
        if not exists(save_dir):
            mkdir(save_dir)

        saver.save(sess, join(save_dir, "checkpoint"), sess.run(global_step))
    finally:
        # Always release the session, even if conversion fails part-way.
        sess.close()
    tf.reset_default_graph()

    print("Updating model...")
    model.embed_mapper.layers = [
        model.embed_mapper.layers[0],
        BiRecurrentMapper(CompatGruCellSpec(dim))
    ]
    model.match_encoder.layers = list(model.match_encoder.layers)
    inner = model.match_encoder.layers[1].other
    inner.layers = list(inner.layers)
    inner.layers[1] = BiRecurrentMapper(CompatGruCellSpec(dim))

    predictor = model.predictor.predictor
    predictor.first_layer = BiRecurrentMapper(CompatGruCellSpec(dim))
    predictor.second_layer = BiRecurrentMapper(CompatGruCellSpec(dim))

    with open(join(output_dir, "model.pkl"), "wb") as f:
        pickle.dump(model, f)

    print("Testing...")
    # Reload the pickle that was just written to check it round-trips.
    with open(join(output_dir, "model.pkl"), "rb") as f:
        model = pickle.load(f)

    sess = tf.Session()
    try:
        model.set_input_spec(
            ParagraphAndQuestionSpec(1, None, None, 14), {"the"},
            ResourceLoader(lambda a, b: {"the": np.zeros(300, np.float32)}))
        prediction = model.get_prediction()

        print("Rebuilding")
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(save_dir))

        feed = model.encode([test_questions], False)
        cpu_out = sess.run([prediction.start_logits, prediction.end_logits],
                           feed_dict=feed)
    finally:
        # The verification session was previously never closed.
        sess.close()

    print("These should be close:")
    print([np.allclose(a, b) for a, b in zip(cpu_out, cudnn_out)])
    print(cpu_out)
    print(cudnn_out)