Exemplo n.º 1
0
def load_env(data_dir, model_dir):
  """Builds the inference-mode environment from a saved model config.

  Starts from the default sketch-rnn hyperparameters, overlays the values
  stored in `<model_dir>/model_config.json`, and hands the result to
  `load_dataset` with `inference_mode=True`.

  Args:
    data_dir: Dataset location, passed through to `load_dataset`.
    model_dir: Directory that contains `model_config.json`.

  Returns:
    Whatever `load_dataset` returns for the merged hyperparameters.
  """
  hps = sketch_rnn_model.get_default_hparams()
  config_path = os.path.join(model_dir, 'model_config.json')
  with tf.gfile.Open(config_path, 'r') as config_file:
    overrides = json.load(config_file)
  # Saved values take precedence over the defaults.
  hps.update(overrides)
  return load_dataset(data_dir, hps, inference_mode=True)
Exemplo n.º 2
0
def load_env(data_dir, model_dir):
  """Prepares datasets for inference using a stored model configuration.

  The default hyperparameters are updated with the contents of
  `model_config.json` found in `model_dir`; the datasets are then loaded
  through `load_dataset` in inference mode.

  Args:
    data_dir: Forwarded unchanged to `load_dataset`.
    model_dir: Directory holding the `model_config.json` file.

  Returns:
    The return value of `load_dataset(..., inference_mode=True)`.
  """
  params = sketch_rnn_model.get_default_hparams()
  json_path = os.path.join(model_dir, 'model_config.json')
  with tf.gfile.Open(json_path, 'r') as handle:
    saved_config = json.load(handle)
  params.update(saved_config)
  env = load_dataset(data_dir, params, inference_mode=True)
  return env
Exemplo n.º 3
0
def load_model(model_dir):
    """Builds the three hyperparameter sets used for inference.

    Default hyperparameters are overridden with the JSON text stored in
    `<model_dir>/model_config.json`, then two derived copies are made.

    Args:
      model_dir: Directory that contains `model_config.json`.

    Returns:
      A list `[model_params, eval_model_params, sample_model_params]`:
        * model_params: the loaded hyperparameters with `batch_size` set to 1.
        * eval_model_params: a copy with the three dropout switches and
          `is_training` set to 0.
        * sample_model_params: a copy of the eval set with `max_seq_len`
          set to 1.
    """
    base_hps = sketch_rnn_model.get_default_hparams()
    config_path = os.path.join(model_dir, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as config_file:
        base_hps.parse_json(config_file.read())
    base_hps.batch_size = 1  # only sample one at a time

    # Evaluation copy: every dropout switch and the training flag are off.
    eval_hps = sketch_rnn_model.copy_hparams(base_hps)
    for switch in ('use_input_dropout', 'use_recurrent_dropout',
                   'use_output_dropout', 'is_training'):
        setattr(eval_hps, switch, 0)

    # Sampling copy: identical to the eval copy except for sequence length.
    sample_hps = sketch_rnn_model.copy_hparams(eval_hps)
    sample_hps.max_seq_len = 1  # sample one point at a time
    return [base_hps, eval_hps, sample_hps]
Exemplo n.º 4
0
def load_model(model_dir):
  """Reads the saved model config and derives eval and sampling variants.

  Args:
    model_dir: Directory holding the `model_config.json` file.

  Returns:
    A three-element list, in this order:
      1. The default hyperparameters overridden by the JSON config, with
         `batch_size` forced to 1.
      2. A copy of (1) with `use_input_dropout`, `use_recurrent_dropout`,
         `use_output_dropout` and `is_training` all set to 0.
      3. A copy of (2) with `max_seq_len` set to 1.
  """
  json_path = os.path.join(model_dir, 'model_config.json')

  params = sketch_rnn_model.get_default_hparams()
  with tf.gfile.Open(json_path, 'r') as handle:
    raw_json = handle.read()
    params.parse_json(raw_json)
  params.batch_size = 1  # only sample one at a time

  eval_params = sketch_rnn_model.copy_hparams(params)
  eval_params.use_input_dropout = 0
  eval_params.use_recurrent_dropout = 0
  eval_params.use_output_dropout = 0
  eval_params.is_training = 0

  sample_params = sketch_rnn_model.copy_hparams(eval_params)
  sample_params.max_seq_len = 1  # sample one point at a time

  return [params, eval_params, sample_params]
Exemplo n.º 5
0
def main(unused_argv):
    """Entry point: build hyperparameters and hand them to the trainer.

    Default sketch-rnn hyperparameters are created first; when the
    `hparams` flag is non-empty its value is parsed on top of them.

    Args:
      unused_argv: Not used.
    """
    hps = sketch_rnn_model.get_default_hparams()
    overrides = FLAGS.hparams
    if overrides:
        hps.parse(overrides)
    trainer(hps)
Exemplo n.º 6
0
def main(unused_argv):
    """Train sketch-rnn repeatedly, feeding a result back into the data.

    Builds the default hyperparameters (applying `FLAGS.hparams` overrides
    when given), loads the datasets from `FLAGS.data_dir`, and then calls
    `trainer` `retrain_times` times. After every run the first element of
    the trainer's result overwrites one stroke of the training set, and the
    first four training strokes are printed for inspection.

    Args:
      unused_argv: Not used.
    """
    model_params = sketch_rnn_model.get_default_hparams()
    if FLAGS.hparams:
        model_params.parse(FLAGS.hparams)

    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    tf.logging.info('sketch-rnn')
    datasets = load_dataset(FLAGS.data_dir, model_params)

    # The code below indexes datasets as [0]=train, [1]=valid, [2]=test.
    train = datasets[0].strokes
    valid = datasets[1].strokes
    test = datasets[2].strokes
    print("\n\ntrain length = %d, valid_length = %d, test length = %d\n\n" %
          (len(train), len(valid), len(test)))
    total_data_size = len(train) + len(valid) + len(test)

    # NOTE(review): the shuffled permutation is never read. The call is kept
    # only because it advances NumPy's global RNG, which later code may
    # depend on; drop it once that has been confirmed to be unnecessary.
    np.random.shuffle(np.arange(total_data_size))

    retrain_times = 2
    for i in range(retrain_times):
        result = trainer(model_params, datasets)

        print("#########final_result###########")
        print(result)
        print("result number [0]")
        print(result[0])
        print("#########final_result###########")

        # Replace training stroke number `i` (the retrain round index) with
        # the first result so the next round trains on the modified set.
        # NOTE(review): presumably result[0] has the same format as a
        # training stroke -- confirm against `trainer`.
        datasets[0].strokes[i] = result[0]
        for j in range(4):
            print(datasets[0].strokes[j])
Exemplo n.º 7
0
def main(unused_argv):
  """Creates the hyperparameters and starts training.

  The defaults from `sketch_rnn_model` are used unless the `hparams` flag
  is set, in which case the flag value is parsed into them before
  `trainer` is invoked.

  Args:
    unused_argv: Not used.
  """
  params = sketch_rnn_model.get_default_hparams()
  flag_overrides = FLAGS.hparams
  if flag_overrides:
    params.parse(flag_overrides)
  trainer(params)