Example #1

import os

import numpy as np
import requests
import six
import tensorflow as tf  # TF 1.x API (tf.logging)

# Module layout as in the Magenta repository.
from magenta.models.sketch_rnn import model as sketch_rnn_model
from magenta.models.sketch_rnn import utils


def load_dataset(data_dir, model_params, inference_mode=False):
    """Loads the .npz files and splits the data into train/valid/test sets."""

    # normalizes the x and y columns using the training set.
    # applies same scaling factor to valid and test set.

    if isinstance(model_params.data_set, list):
        datasets = model_params.data_set
    else:
        datasets = [model_params.data_set]

    train_strokes = None
    valid_strokes = None
    test_strokes = None

    for dataset in datasets:
        if data_dir.startswith('http://') or data_dir.startswith('https://'):
            data_filepath = '/'.join([data_dir, dataset])
            tf.logging.info('Downloading %s', data_filepath)
            response = requests.get(data_filepath)
            # allow_pickle is required for the object arrays of strokes
            # on NumPy >= 1.16.3.
            data = np.load(six.BytesIO(response.content), encoding='latin1',
                           allow_pickle=True)
        else:
            data_filepath = os.path.join(data_dir, dataset)
            data = np.load(data_filepath, encoding='latin1', allow_pickle=True)
        tf.logging.info('Loaded {}/{}/{} from {}'.format(
            len(data['train']), len(data['valid']), len(data['test']),
            dataset))
        if train_strokes is None:
            train_strokes = data['train']
            valid_strokes = data['valid']
            test_strokes = data['test']
        else:
            train_strokes = np.concatenate((train_strokes, data['train']))
            valid_strokes = np.concatenate((valid_strokes, data['valid']))
            test_strokes = np.concatenate((test_strokes, data['test']))

    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    tf.logging.info('Dataset combined: {} ({}/{}/{}), avg len {}'.format(
        len(all_strokes), len(train_strokes), len(valid_strokes),
        len(test_strokes), int(avg_len)))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len

    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

    eval_model_params = sketch_rnn_model.copy_hparams(model_params)

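    # The eval model reuses the training hparams but disables all dropout.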
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1

    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)

    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(valid_strokes,
                                 eval_model_params.batch_size,
                                 max_seq_length=eval_model_params.max_seq_len,
                                 random_scale_factor=0.0,
                                 augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(test_strokes,
                                eval_model_params.batch_size,
                                max_seq_length=eval_model_params.max_seq_len,
                                random_scale_factor=0.0,
                                augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    tf.logging.info('normalizing_scale_factor %4.4f.',
                    normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result
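A rough usage sketch (the data directory and dataset filename below are placeholders, not values taken from the example; get_default_hparams is assumed from Magenta's sketch-rnn model module):

model_params = sketch_rnn_model.get_default_hparams()
model_params.data_set = ['cat.npz']  # hypothetical .npz file of stroke data

(train_set, valid_set, test_set, hps, eval_hps,
 sample_hps) = load_dataset('/tmp/sketch_rnn/data', model_params)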
Example #2

An older, single-dataset variant of the same function:

def load_dataset(data_dir, model_params, inference_mode=False):
    """Loads the .npz file and splits the data into train/valid/test sets."""

    # normalizes the x and y columns using the training set.
    # applies same scaling factor to valid and test set.

    data_filepath = os.path.join(data_dir, model_params.data_set)
    if data_dir.startswith('http://') or data_dir.startswith('https://'):
        tf.logging.info('Downloading %s', data_filepath)
        response = requests.get(data_filepath)
        # response.content is bytes, so wrap it in BytesIO (StringIO breaks
        # under Python 3); latin1 encoding matches the pickled stroke arrays.
        data = np.load(six.BytesIO(response.content), encoding='latin1',
                       allow_pickle=True)
    else:
        # loads a dict-like NpzFile of train/valid/test stroke arrays.
        data = np.load(data_filepath, encoding='latin1', allow_pickle=True)

    train_strokes = data['train']
    valid_strokes = data['valid']
    test_strokes = data['test']

    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len

    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

    eval_model_params = sketch_rnn_model.copy_hparams(model_params)

    if inference_mode:
        eval_model_params.batch_size = 1

    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)

    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(valid_strokes,
                                 eval_model_params.batch_size,
                                 max_seq_length=eval_model_params.max_seq_len,
                                 random_scale_factor=0.0,
                                 augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(test_strokes,
                                eval_model_params.batch_size,
                                max_seq_length=eval_model_params.max_seq_len,
                                random_scale_factor=0.0,
                                augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    tf.logging.info('normalizing_scale_factor %4.4f.',
                    normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result
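For reference, the normalizing scale factor computed by calculate_normalizing_scale_factor is the standard deviation of every delta-x/delta-y offset in the training data, so dividing by it puts the offsets at roughly unit variance. A minimal sketch of the idea (not the Magenta implementation itself, which also skips over-length sequences):

import numpy as np

def normalizing_scale_factor(strokes):
    # Pool every delta-x and delta-y offset across all sequences and use
    # their standard deviation as the scale factor.
    offsets = np.concatenate([s[:, 0:2].reshape(-1) for s in strokes])
    return float(np.std(offsets))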