def trainer(model_params):
    """Train a sketch-rnn model.

    Loads the data from ``FLAGS.data_dir``, builds a training model plus an
    evaluation model created with ``reuse=True``, optionally restores a
    checkpoint from ``FLAGS.log_root``, saves the hyperparameters to
    ``FLAGS.log_root/model_config.json`` and then calls ``train``.

    NOTE(review): this file also contains ``def trainer(model_params,
    datasets)``; if both definitions sit at module scope the later one
    rebinds the name and this one can no longer be called by name — confirm
    which is intended.

    Args:
      model_params: hyperparameters handed to ``load_dataset``. Inside the
        function the name is rebound to ``datasets[3]`` as returned by
        ``load_dataset``. Assumed to be a TF ``HParams`` object whose
        ``values()`` returns a name -> value dict — TODO confirm.
    """
    np.set_printoptions(
        precision=8, edgeitems=6, linewidth=200, suppress=True)

    tf.logging.info('sketch-rnn')
    tf.logging.info('Hyperparams:')
    # Previously the header above was logged with nothing after it.
    for key, val in sorted(model_params.values().items()):
        tf.logging.info('%s = %s', key, str(val))
    tf.logging.info('Loading data files.')
    datasets = load_dataset(FLAGS.data_dir, model_params)

    train_set = datasets[0]
    valid_set = datasets[1]
    test_set = datasets[2]
    # load_dataset hands back params that supersede the argument.
    model_params = datasets[3]
    eval_model_params = datasets[4]

    reset_graph()
    model = sketch_rnn_model.Model(model_params)
    eval_model = sketch_rnn_model.Model(eval_model_params, reuse=True)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    if FLAGS.resume_training:
        load_checkpoint(sess, FLAGS.log_root)

    # Write config file to json file.
    tf.gfile.MakeDirs(FLAGS.log_root)
    with tf.gfile.Open(
            os.path.join(FLAGS.log_root, 'model_config.json'), 'w') as f:
        # values() is already a name -> value mapping; the former list(...)
        # wrapper kept only the names and dropped every value.
        json.dump(model_params.values(), f, indent=True)

    train(sess, model, eval_model, train_set, valid_set, test_set)
def _eval_trainable_variable(name):
    """Return the value of the trainable variable called ``name``.

    The variable is evaluated with ``.eval()``, i.e. in the default session.

    Raises:
      ValueError: if no trainable variable has that name.
    """
    matches = [v for v in tf.trainable_variables() if v.name == name]
    if not matches:
        raise ValueError('Trainable variable not found: %s' % name)
    return matches[0].eval()


def _generated_to_stroke3(strokes):
    """Convert one generated sketch into a list of ``[a, b, pen]`` entries.

    ``a, b`` are the first two values of each generated step (presumably
    dx, dy — TODO confirm against ``generate``). ``pen`` is:
      * 0 when ``step[2] != 0``;
      * otherwise ``step[3]`` when it equals 1;
      * otherwise ``step[4]``.
    The second-to-last entry of the sketch always gets ``pen = 0``.
    NOTE(review): the reason for that override is not visible here.
    """
    total = len(strokes)
    converted = []
    # Index per sketch: the old code used a counter that was never reset
    # between sketches, so the total - 2 override only hit the first sketch.
    for index, step in enumerate(strokes):
        if step[2] != 0:
            pen = 0
        elif step[3] == 1:
            pen = step[3]
        else:
            pen = step[4]
        if index == total - 2:
            pen = 0
        entry = []
        entry.extend(step[:2])
        entry.append(pen)
        converted.append(entry)
    return converted


def trainer(model_params, datasets, num_samples=10000):
    """Train a sketch-rnn model on pre-loaded data, then sample sketches.

    After ``train`` returns, the decoder weights are read back from the graph
    and wrapped in a ``SketchLSTMCell``. ``generate`` is then called
    ``num_samples`` times, and each result is converted with
    ``_generated_to_stroke3``. Every sketch, and finally the whole list, is
    printed to stdout.

    Args:
      model_params: not used; the function reads ``datasets[3]`` instead.
        Kept for interface compatibility.
      datasets: indexable whose items 0-4 are used as the train set, valid
        set, test set, model params and eval-model params, in that order.
      num_samples: how many sketches to generate (default 10000, the value
        that used to be hard-coded).

    Returns:
      A list with one converted sketch per sample, each a list of
      ``[a, b, pen]`` entries.
    """
    train_set = datasets[0]
    valid_set = datasets[1]
    test_set = datasets[2]
    model_params = datasets[3]
    eval_model_params = datasets[4]

    reset_graph()
    model = sketch_rnn_model.Model(model_params)
    eval_model = sketch_rnn_model.Model(eval_model_params, reuse=True)
    # NOTE(review): this single-step model is never used below; it is still
    # built because constructing it adds ops to the graph — verify whether
    # it can be dropped.
    sample_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_params.max_seq_len = 1
    sketch_rnn_model.Model(sample_params, reuse=True)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    train(sess, model, eval_model, train_set, valid_set, test_set)

    # Pull the trained decoder weights out of the graph for SketchLSTMCell.
    dec_output_w = _eval_trainable_variable("vector_rnn/RNN/output_w:0")
    dec_output_b = _eval_trainable_variable("vector_rnn/RNN/output_b:0")
    dec_lstm_W_xh = _eval_trainable_variable(
        "vector_rnn/RNN/LSTMCell/W_xh:0")
    dec_lstm_W_hh = _eval_trainable_variable(
        "vector_rnn/RNN/LSTMCell/W_hh:0")
    dec_lstm_bias = _eval_trainable_variable(
        "vector_rnn/RNN/LSTMCell/bias:0")

    dec_num_units = dec_lstm_W_hh.shape[0]
    dec_input_size = dec_lstm_W_xh.shape[0]
    dec_lstm = SketchLSTMCell(dec_num_units, dec_input_size, dec_lstm_W_xh,
                              dec_lstm_W_hh, dec_lstm_bias)

    train_result = []
    for j in range(num_samples):
        output = _generated_to_stroke3(
            generate(dec_lstm, dec_output_w, dec_output_b))
        print("++++++++++++print_sketch+++++++++++")
        print(j)
        print(output)
        print("+++++++++++++++++++++++++++++++++++")
        train_result.append(output)

    print("==========print_train_result===========")
    print(train_result)
    print("===================================")
    return train_result