Example #1
def load_model(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        model_params.parse_json(f.read())

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
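
The three hyperparameter sets returned here parameterize three separate graphs: the loaded training settings, the evaluation settings with dropout disabled, and single-step sampling settings. A minimal sketch of consuming them, mirroring the model construction in Example #4; the checkpoint directory path is hypothetical and reset_graph() is assumed to be the same helper used there:

model_dir = '/tmp/sketch_rnn/models/default'  # hypothetical directory containing model_config.json
hps_model, eval_hps_model, sample_hps_model = load_model(model_dir)

reset_graph()  # assumed helper, as used in Example #4
model = sketch_rnn_model.Model(hps_model)
eval_model = sketch_rnn_model.Model(eval_hps_model, reuse=True)
sample_model = sketch_rnn_model.Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())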
Example #3
def load_dataset(data_dir, model_params, inference_mode=False):
    """Loads the .npz file, and splits the set into train/valid/test."""

    # normalizes the x and y columns using the training set.
    # applies same scaling factor to valid and test set.

    if isinstance(model_params.data_set, list):
        datasets = model_params.data_set
    else:
        datasets = [model_params.data_set]

    train_strokes = None
    valid_strokes = None
    test_strokes = None

    for dataset in datasets:
        if data_dir.startswith('http://') or data_dir.startswith('https://'):
            data_filepath = '/'.join([data_dir, dataset])
            tf.logging.info('Downloading %s', data_filepath)
            response = requests.get(data_filepath)
            data = np.load(six.BytesIO(response.content), encoding='latin1')
        else:
            data_filepath = os.path.join(data_dir, dataset)
            data = np.load(data_filepath, encoding='latin1', allow_pickle=True)
        tf.logging.info('Loaded {}/{}/{} from {}'.format(
            len(data['train']), len(data['valid']), len(data['test']),
            dataset))
        if train_strokes is None:
            train_strokes = data['train']
            valid_strokes = data['valid']
            test_strokes = data['test']
        else:
            train_strokes = np.concatenate((train_strokes, data['train']))
            valid_strokes = np.concatenate((valid_strokes, data['valid']))
            test_strokes = np.concatenate((test_strokes, data['test']))

    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    tf.logging.info('Dataset combined: {} ({}/{}/{}), avg len {}'.format(
        len(all_strokes), len(train_strokes), len(valid_strokes),
        len(test_strokes), int(avg_len)))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len

    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

    eval_model_params = sketch_rnn_model.copy_hparams(model_params)

    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1

    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)

    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(valid_strokes,
                                 eval_model_params.batch_size,
                                 max_seq_length=eval_model_params.max_seq_len,
                                 random_scale_factor=0.0,
                                 augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(test_strokes,
                                eval_model_params.batch_size,
                                max_seq_length=eval_model_params.max_seq_len,
                                random_scale_factor=0.0,
                                augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    tf.logging.info('normalizing_scale_factor %4.4f.',
                    normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result
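
The returned list packs the three DataLoader objects and the three hyperparameter sets in a fixed order, which callers unpack positionally (Example #4's trainer does exactly this). A minimal usage sketch; the data directory is hypothetical:

model_params = sketch_rnn_model.get_default_hparams()
datasets = load_dataset('/tmp/sketch_rnn/datasets', model_params)
(train_set, valid_set, test_set,
 model_params, eval_model_params, sample_model_params) = datasets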
Example #4
def trainer(model_params, datasets):
    """Train a sketch-rnn model."""

    train_set = datasets[0]
    valid_set = datasets[1]
    test_set = datasets[2]
    model_params = datasets[3]
    eval_model_params = datasets[4]

    reset_graph()
    model = sketch_rnn_model.Model(model_params)
    eval_model = sketch_rnn_model.Model(eval_model_params, reuse=True)
    sample_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_params.max_seq_len = 1
    sample_model = sketch_rnn_model.Model(sample_params, reuse=True)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    train(sess, model, eval_model, train_set, valid_set, test_set)

    # Pull the learned decoder weights out of the TF graph so sampling can run
    # through the standalone SketchLSTMCell constructed below.
    output_w_ = [
        v for v in tf.trainable_variables()
        if v.name == "vector_rnn/RNN/output_w:0"
    ][0].eval()
    output_b_ = [
        v for v in tf.trainable_variables()
        if v.name == "vector_rnn/RNN/output_b:0"
    ][0].eval()
    lstm_W_xh_ = [
        v for v in tf.trainable_variables()
        if v.name == "vector_rnn/RNN/LSTMCell/W_xh:0"
    ][0].eval()
    lstm_W_hh_ = [
        v for v in tf.trainable_variables()
        if v.name == "vector_rnn/RNN/LSTMCell/W_hh:0"
    ][0].eval()
    lstm_bias_ = [
        v for v in tf.trainable_variables()
        if v.name == "vector_rnn/RNN/LSTMCell/bias:0"
    ][0].eval()

    dec_output_w = output_w_
    dec_output_b = output_b_
    dec_lstm_W_xh = lstm_W_xh_
    dec_lstm_W_hh = lstm_W_hh_
    dec_lstm_bias = lstm_bias_
    dec_num_units = dec_lstm_W_hh.shape[0]
    dec_input_size = dec_lstm_W_xh.shape[0]
    dec_lstm = SketchLSTMCell(dec_num_units, dec_input_size, dec_lstm_W_xh,
                              dec_lstm_W_hh, dec_lstm_bias)

    num_of_boats = 10000
    train_result = []
    for j in range(num_of_boats):
        output = []
        count = 0  # reset per sketch so the end-of-sequence check below applies to each sample
        result = generate(dec_lstm, dec_output_w, dec_output_b)
        total = len(result)
        for i in result:
            entry = []
            if i[2] != 0:
                pen = 0
            else:
                if i[3] == 1:
                    pen = i[3]
                else:
                    pen = i[4]
            if count == total - 2:
                pen = 0
            entry.extend(i[:2])
            entry.append(pen)
            output.append(entry)
            count += 1
        print("++++++++++++print_sketch+++++++++++")
        print(j)
        print(output)
        print("+++++++++++++++++++++++++++++++++++")
        train_result.append(output)
    print("==========print_train_result===========")
    print(train_result)
    print("===================================")
    return train_result
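
A minimal sketch of how this trainer pairs with the load_dataset of Example #3: the datasets list is passed straight through and unpacked positionally inside trainer. The data directory is hypothetical.

model_params = sketch_rnn_model.get_default_hparams()
datasets = load_dataset('/tmp/sketch_rnn/datasets', model_params)
sketches = trainer(model_params, datasets)  # list of [dx, dy, pen] point sequences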
Example #6
def load_dataset(data_dir, model_params, inference_mode=False):
  """Loads the .npz file, and splits the set into train/valid/test."""

  # normalizes the x and y columns using the training set.
  # applies same scaling factor to valid and test set.

  data_filepath = os.path.join(data_dir, model_params.data_set)
  if data_dir.startswith('http://') or data_dir.startswith('https://'):
    tf.logging.info('Downloading %s', data_filepath)
    response = requests.get(data_filepath)
    data = np.load(StringIO(response.content))
  else:
    data = np.load(data_filepath)  # load this into dictionary

  train_strokes = data['train']
  valid_strokes = data['valid']
  test_strokes = data['test']

  all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))

  # calculate the max strokes we need.
  max_seq_len = utils.get_max_len(all_strokes)
  # overwrite the hps with this calculation.
  model_params.max_seq_len = max_seq_len

  tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

  eval_model_params = sketch_rnn_model.copy_hparams(model_params)

  if inference_mode:
    eval_model_params.batch_size = 1

  eval_model_params.use_input_dropout = 0
  eval_model_params.use_recurrent_dropout = 0
  eval_model_params.use_output_dropout = 0
  eval_model_params.is_training = 0

  sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
  sample_model_params.batch_size = 1  # only sample one at a time
  sample_model_params.max_seq_len = 1  # sample one point at a time

  train_set = utils.DataLoader(
      train_strokes,
      model_params.batch_size,
      max_seq_length=model_params.max_seq_len,
      random_scale_factor=model_params.random_scale_factor,
      augment_stroke_prob=model_params.augment_stroke_prob)

  normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
  train_set.normalize(normalizing_scale_factor)

  valid_set = utils.DataLoader(
      valid_strokes,
      eval_model_params.batch_size,
      max_seq_length=eval_model_params.max_seq_len,
      random_scale_factor=0.0,
      augment_stroke_prob=0.0)
  valid_set.normalize(normalizing_scale_factor)

  test_set = utils.DataLoader(
      test_strokes,
      eval_model_params.batch_size,
      max_seq_length=eval_model_params.max_seq_len,
      random_scale_factor=0.0,
      augment_stroke_prob=0.0)
  test_set.normalize(normalizing_scale_factor)

  tf.logging.info('normalizing_scale_factor %4.4f.', normalizing_scale_factor)

  result = [
      train_set, valid_set, test_set, model_params, eval_model_params,
      sample_model_params
  ]
  return result
Example #7
def load_dataset(data_dir, model_params, inference_mode=False):
  """Loads the .npz file, and splits the set into train/valid/test."""

  # normalizes the x and y columns using the training set.
  # applies same scaling factor to valid and test set.

  datasets = []
  if isinstance(model_params.data_set, list):
    datasets = model_params.data_set
  else:
    datasets = [model_params.data_set]

  train_strokes = None
  valid_strokes = None
  test_strokes = None

  for dataset in datasets:
    data_filepath = os.path.join(data_dir, dataset)
    if data_dir.startswith('http://') or data_dir.startswith('https://'):
      tf.logging.info('Downloading %s', data_filepath)
      response = requests.get(data_filepath)
      data = np.load(StringIO(response.content))
    else:
      if six.PY3:
        data = np.load(data_filepath, encoding='latin1')
      else:
        data = np.load(data_filepath)
    tf.logging.info('Loaded {}/{}/{} from {}'.format(
        len(data['train']), len(data['valid']), len(data['test']),
        dataset))
    if train_strokes is None:
      train_strokes = data['train']
      valid_strokes = data['valid']
      test_strokes = data['test']
    else:
      train_strokes = np.concatenate((train_strokes, data['train']))
      valid_strokes = np.concatenate((valid_strokes, data['valid']))
      test_strokes = np.concatenate((test_strokes, data['test']))

  all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))
  num_points = 0
  for stroke in all_strokes:
    num_points += len(stroke)
  avg_len = num_points / len(all_strokes)
  tf.logging.info('Dataset combined: {} ({}/{}/{}), avg len {}'.format(
      len(all_strokes), len(train_strokes), len(valid_strokes),
      len(test_strokes), int(avg_len)))

  # calculate the max strokes we need.
  max_seq_len = utils.get_max_len(all_strokes)
  # overwrite the hps with this calculation.
  model_params.max_seq_len = max_seq_len

  tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

  eval_model_params = sketch_rnn_model.copy_hparams(model_params)

  eval_model_params.use_input_dropout = 0
  eval_model_params.use_recurrent_dropout = 0
  eval_model_params.use_output_dropout = 0
  eval_model_params.is_training = 1

  if inference_mode:
    eval_model_params.batch_size = 1
    eval_model_params.is_training = 0

  sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
  sample_model_params.batch_size = 1  # only sample one at a time
  sample_model_params.max_seq_len = 1  # sample one point at a time

  train_set = utils.DataLoader(
      train_strokes,
      model_params.batch_size,
      max_seq_length=model_params.max_seq_len,
      random_scale_factor=model_params.random_scale_factor,
      augment_stroke_prob=model_params.augment_stroke_prob)

  normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
  train_set.normalize(normalizing_scale_factor)

  valid_set = utils.DataLoader(
      valid_strokes,
      eval_model_params.batch_size,
      max_seq_length=eval_model_params.max_seq_len,
      random_scale_factor=0.0,
      augment_stroke_prob=0.0)
  valid_set.normalize(normalizing_scale_factor)

  test_set = utils.DataLoader(
      test_strokes,
      eval_model_params.batch_size,
      max_seq_length=eval_model_params.max_seq_len,
      random_scale_factor=0.0,
      augment_stroke_prob=0.0)
  test_set.normalize(normalizing_scale_factor)

  tf.logging.info('normalizing_scale_factor %4.4f.', normalizing_scale_factor)

  result = [
      train_set, valid_set, test_set, model_params, eval_model_params,
      sample_model_params
  ]
  return result
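
Because the loop above accepts either a single filename or a list, several .npz files can be concatenated into one combined dataset. A minimal sketch under that assumption; the filenames and directory are hypothetical:

model_params = sketch_rnn_model.get_default_hparams()
model_params.data_set = ['cat.npz', 'dog.npz']  # hypothetical class files, concatenated by the loop above
datasets = load_dataset('/tmp/sketch_rnn/datasets', model_params)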
Example #8
def load_dataset(data_dir, model_params, inference_mode=False):
    """Loads the .npz file, and splits the set into train/valid/test."""

    # normalizes the x and y columns using the training set.
    # applies same scaling factor to valid and test set.

    data_filepath = os.path.join(data_dir, model_params.data_set)
    if data_dir.startswith('http://') or data_dir.startswith('https://'):
        tf.logging.info('Downloading %s', data_filepath)
        response = requests.get(data_filepath)
        data = np.load(StringIO(response.content))
    else:
        data = np.load(data_filepath)  # load this into dictionary

    train_strokes = data['train']
    valid_strokes = data['valid']
    test_strokes = data['test']

    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len

    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

    eval_model_params = sketch_rnn_model.copy_hparams(model_params)

    if inference_mode:
        eval_model_params.batch_size = 1

    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)

    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(valid_strokes,
                                 eval_model_params.batch_size,
                                 max_seq_length=eval_model_params.max_seq_len,
                                 random_scale_factor=0.0,
                                 augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(test_strokes,
                                eval_model_params.batch_size,
                                max_seq_length=eval_model_params.max_seq_len,
                                random_scale_factor=0.0,
                                augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    tf.logging.info('normalizing_scale_factor %4.4f.',
                    normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result