コード例 #1
0
def trainer(model_params):
    """Train a sketch-rnn model.

    Loads the datasets from ``FLAGS.data_dir``, builds a training model plus an
    evaluation model created with ``reuse=True``, optionally restores a
    checkpoint from ``FLAGS.log_root`` (when ``FLAGS.resume_training`` is set),
    writes the hyperparameters to ``FLAGS.log_root/model_config.json`` and then
    runs ``train``.

    Args:
      model_params: hyperparameter object handed to ``load_dataset``. Its
        ``values()`` method is assumed to return a name -> value dict
        (``tf.contrib.training.HParams`` semantics -- TODO confirm). Note that
        the params used to build the models and written to the config file are
        the ones returned by ``load_dataset`` (``datasets[3]``), which rebind
        this name.
    """
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    tf.logging.info('sketch-rnn')
    tf.logging.info('Hyperparams:')
    # Emit one line per hyperparameter under the header logged above.
    for key, val in sorted(model_params.values().items()):
        tf.logging.info('%s = %s', key, str(val))
    tf.logging.info('Loading data files.')
    datasets = load_dataset(FLAGS.data_dir, model_params)

    train_set = datasets[0]
    valid_set = datasets[1]
    test_set = datasets[2]
    model_params = datasets[3]
    eval_model_params = datasets[4]

    reset_graph()
    model = sketch_rnn_model.Model(model_params)
    eval_model = sketch_rnn_model.Model(eval_model_params, reuse=True)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    if FLAGS.resume_training:
        load_checkpoint(sess, FLAGS.log_root)

    # Write config file to json file.
    tf.gfile.MakeDirs(FLAGS.log_root)
    with tf.gfile.Open(os.path.join(FLAGS.log_root, 'model_config.json'),
                       'w') as f:
        # values() already yields a name -> value dict; wrapping it in list()
        # would keep only the names and drop every hyperparameter value.
        json.dump(model_params.values(), f, indent=True)

    train(sess, model, eval_model, train_set, valid_set, test_set)
コード例 #2
0
def _trainable_values(names):
    """Return the current values of the named trainable variables, in order.

    Evaluates each variable with ``Variable.eval()`` (i.e. in the default
    session). Raises ``KeyError`` listing any name that is not among
    ``tf.trainable_variables()``.
    """
    by_name = {v.name: v for v in tf.trainable_variables()}
    missing = [name for name in names if name not in by_name]
    if missing:
        raise KeyError('Trainable variables not found: %s' % ', '.join(missing))
    return [by_name[name].eval() for name in names]


def _to_stroke3(points):
    """Convert one generated sketch into a list of ``[a, b, pen]`` rows.

    ``a, b`` are the first two values of each point (presumably the x/y
    offsets). Each point is assumed to be stroke-5 style ``[.., .., p1, p2,
    p3]`` as produced by ``generate`` -- TODO confirm. ``pen`` is 0 when
    ``p1 != 0``, ``p2`` when ``p2 == 1``, and ``p3`` otherwise. The
    second-to-last row of the sketch always gets ``pen = 0``.
    """
    forced_index = len(points) - 2
    rows = []
    for idx, point in enumerate(points):
        if point[2] != 0:
            pen = 0
        elif point[3] == 1:
            pen = point[3]
        else:
            pen = point[4]
        # The index is local to this sketch. A counter shared across sketches
        # would only hit this row correctly for the first one.
        if idx == forced_index:
            pen = 0
        row = list(point[:2])
        row.append(pen)
        rows.append(row)
    return rows


def trainer(model_params, datasets, num_samples=10000, verbose=True):
    """Train a sketch-rnn model, then sample sketches from its decoder.

    Args:
      model_params: unused -- immediately replaced by ``datasets[3]``; kept
        only so existing callers keep working.
      datasets: indexable with ``[train_set, valid_set, test_set,
        model_params, eval_model_params]`` at positions 0-4.
      num_samples: number of sketches to generate after training.
      verbose: when True, print every generated sketch and the full result
        list (the original behaviour).

    Returns:
      A list with ``num_samples`` entries; each entry is the list of
      ``[a, b, pen]`` rows produced by ``_to_stroke3`` for one sketch.
    """
    train_set = datasets[0]
    valid_set = datasets[1]
    test_set = datasets[2]
    model_params = datasets[3]
    eval_model_params = datasets[4]

    reset_graph()
    model = sketch_rnn_model.Model(model_params)
    eval_model = sketch_rnn_model.Model(eval_model_params, reuse=True)
    sample_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_params.max_seq_len = 1
    # Not referenced below; constructed so the graph is built as before.
    sketch_rnn_model.Model(sample_params, reuse=True)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    train(sess, model, eval_model, train_set, valid_set, test_set)

    # Pull the trained decoder weights out of the graph so `generate` can run
    # the decoder through SketchLSTMCell from plain values.
    (dec_output_w, dec_output_b, dec_lstm_W_xh, dec_lstm_W_hh,
     dec_lstm_bias) = _trainable_values([
         "vector_rnn/RNN/output_w:0",
         "vector_rnn/RNN/output_b:0",
         "vector_rnn/RNN/LSTMCell/W_xh:0",
         "vector_rnn/RNN/LSTMCell/W_hh:0",
         "vector_rnn/RNN/LSTMCell/bias:0",
     ])
    dec_num_units = dec_lstm_W_hh.shape[0]
    dec_input_size = dec_lstm_W_xh.shape[0]
    dec_lstm = SketchLSTMCell(dec_num_units, dec_input_size, dec_lstm_W_xh,
                              dec_lstm_W_hh, dec_lstm_bias)

    train_result = []
    for j in range(num_samples):
        output = _to_stroke3(generate(dec_lstm, dec_output_w, dec_output_b))
        if verbose:
            print("++++++++++++print_sketch+++++++++++")
            print(j)
            print(output)
            print("+++++++++++++++++++++++++++++++++++")
        train_result.append(output)
    if verbose:
        print("==========print_train_result===========")
        print(train_result)
        print("===================================")
    return train_result