Example #1
def load_dataset(data_dir, model_params, inference_mode=False):
    tf.logging.info('load_dataset - start preprocessing =================================================')
    source = model_params.source  # source.npy
    target = model_params.target  # target.npy
    source_data = np.load(os.path.join(data_dir, source))
    target_data = np.load(os.path.join(data_dir, target))
    tf.logging.info('source input length %i.', len(source_data))
    tf.logging.info('target input length %i.', len(target_data))
    num_points = 68
    model_params.max_seq_len = num_points
    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)  # and print it
    
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)  # copy the model hparams to the evaluation model
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0  # and tweak a few parameters
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1

    if inference_mode:  # defaults to False
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)  # copy the hparams to the sample model
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    # shuffle, then take one batch each of train / test / validation data
    # x and y have been normalized; sequences shorter than Nmax are padded with (0, 0, 0, 0, 1)
    tf.logging.info('Preprocessing data to feed into the network')
    indices = np.random.permutation(range(0, len(source_data)))[0:model_params.batch_size]
    # use the same indices so that source and target keep the same order
    source_set = utils.DataLoader( source_data, indices,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor)
    target_set = utils.DataLoader( target_data, indices,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor)
    
    factor_source = source_set.calculate_normalizing_scale_factor()
    source_set.normalize(factor_source)  # then normalize the source data
    factor_target = target_set.calculate_normalizing_scale_factor()
    target_set.normalize(factor_target)  # then normalize the target data

    tf.logging.info('source normalizing_scale_factor is %4.4f.',factor_source) 
    tf.logging.info('target normalizing_scale_factor is %4.4f.',factor_target)

    result = [source_set, target_set, model_params, eval_model_params, sample_model_params]
    return result
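# Note on the padding comment above: the "(0, 0, 0, 0, 1)" padding refers to the
# stroke-5 representation used by sketch-rnn-style models, where each point is
# (dx, dy, pen_down, pen_up, end_of_sketch). A minimal sketch of that convention,
# independent of the DataLoader internals (names and shapes here are assumptions):
import numpy as np

def pad_to_stroke5(strokes_3, max_len):
    """Sketch: turn an [N, 3] stroke-3 array (dx, dy, pen_lift) into a
    [max_len, 5] stroke-5 array, padding the tail with (0, 0, 0, 0, 1)."""
    assert len(strokes_3) <= max_len
    result = np.zeros((max_len, 5), dtype=np.float32)
    n = len(strokes_3)
    result[:n, 0:2] = strokes_3[:, 0:2]   # dx, dy
    result[:n, 3] = strokes_3[:, 2]       # pen lifted after this point
    result[:n, 2] = 1 - result[:n, 3]     # pen stays down
    result[n:, 4] = 1                     # end-of-sketch padding
    return result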
Example #2
def load_model(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    model_params = get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        model_params.parse_json(f.read())

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
Example #3
def load_model(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        model_params.parse_json(f.read())

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
Example #4
def load_model(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        model_params.parse_json(f.read())

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    if six.PY3:
        pretrained_model_params = np.load(model_dir+'/model', encoding='latin1')
    else:
        pretrained_model_params = np.load(model_dir+'/model')
    return [model_params, eval_model_params, sample_model_params, pretrained_model_params]
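# The encoding='latin1' argument only matters under Python 3, where NumPy needs it
# to unpickle weight arrays that were written from Python 2. A hedged usage sketch
# (the model_dir path and its contents are assumptions, not part of the snippet):
hps, eval_hps, sample_hps, weights = load_model('/tmp/sketch_rnn/pretrained')
print('batch_size:', hps.batch_size, '| number of weight tensors:', len(weights))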
Example #5
def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
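# The fix_list loop above converts 0/1 integers into real booleans before calling
# parse_json, presumably because the default hparams declare these flags as bools and
# the HParams parser rejects mismatched types. A minimal before/after sketch of just
# that transformation (the config values are made up for illustration):
import json

data = {'conditional': 1, 'is_training': 0, 'use_input_dropout': 0,
        'use_output_dropout': 0, 'use_recurrent_dropout': 1}  # hypothetical config
for fix in ['conditional', 'is_training', 'use_input_dropout',
            'use_output_dropout', 'use_recurrent_dropout']:
    data[fix] = (data[fix] == 1)
print(json.dumps(data))  # {"conditional": true, "is_training": false, ...}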
Example #6
def load_dataset(data_dir, model_params, inference_mode=False):
    """Loads the .npz file, and splits the set into train/valid/test."""

    # normalizes the x and y columns using the training set.
    # applies same scaling factor to valid and test set.

    datasets = []
    if isinstance(model_params.data_set, list):
        datasets = model_params.data_set
    else:
        datasets = [model_params.data_set]

    train_strokes = None
    valid_strokes = None
    test_strokes = None

    for dataset in datasets:
        data_filepath = os.path.join(data_dir, dataset)
        if data_dir.startswith('http://') or data_dir.startswith('https://'):
            tf.logging.info('Downloading %s', data_filepath)
            response = requests.get(data_filepath)
            data = np.load(StringIO(response.content))
        else:
            data = np.load(data_filepath)  # load this into dictionary
        tf.logging.info('Loaded {}/{}/{} from {}'.format(
            len(data['train']), len(data['valid']), len(data['test']),
            dataset))
        if train_strokes is None:
            train_strokes = data['train']
            valid_strokes = data['valid']
            test_strokes = data['test']
        else:
            train_strokes = np.concatenate((train_strokes, data['train']))
            valid_strokes = np.concatenate((valid_strokes, data['valid']))
            test_strokes = np.concatenate((test_strokes, data['test']))

    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    tf.logging.info('Dataset combined: {} ({}/{}/{}), avg len {}'.format(
        len(all_strokes), len(train_strokes), len(valid_strokes),
        len(test_strokes), int(avg_len)))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len

    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

    eval_model_params = sketch_rnn_model.copy_hparams(model_params)

    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1

    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)

    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)
    print('Length original', len(train_strokes), len(valid_strokes),
          len(test_strokes))
    valid_set = utils.DataLoader(valid_strokes,
                                 eval_model_params.batch_size,
                                 max_seq_length=eval_model_params.max_seq_len,
                                 random_scale_factor=0.0,
                                 augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(test_strokes,
                                eval_model_params.batch_size,
                                max_seq_length=eval_model_params.max_seq_len,
                                random_scale_factor=0.0,
                                augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    tf.logging.info('normalizing_scale_factor %4.4f.',
                    normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result
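# A hedged usage sketch for the function above; the directory, dataset name and
# hparams object are assumptions (model_params.data_set may be a single .npz
# filename or a list of them):
model_params = sketch_rnn_model.get_default_hparams()
model_params.data_set = ['cat.npz']  # hypothetical dataset file under data_dir
(train_set, valid_set, test_set,
 hps, eval_hps, sample_hps) = load_dataset('/tmp/sketch_rnn/datasets', model_params)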
Example #7
def load_dataset(data_dir, model_params, testing_mode=False):
    """Loads the .npz file, and splits the set into train/valid/test."""
    # normalizes the x and y columns using scale_factor.

    dataset = model_params.data_set
    data_filepath = os.path.join(data_dir, dataset)
    data = np.load(data_filepath, allow_pickle=True, encoding='latin1')

    # target data
    train_strokes = data['train']
    valid_strokes = data['valid']
    test_strokes = data['test']
    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))

    # standard data (reference data in paper)
    std_train_strokes = data['std_train']
    std_valid_strokes = data['std_valid']
    std_test_strokes = data['std_test']
    all_std_strokes = np.concatenate(
        (std_train_strokes, std_valid_strokes, std_test_strokes))

    print('Dataset combined: %d (train=%d/validate=%d/test=%d)' %
          (len(all_strokes), len(train_strokes), len(valid_strokes),
           len(test_strokes)))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    max_std_seq_len = utils.get_max_len(all_std_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max(max_seq_len, max_std_seq_len)
    print('model_params.max_seq_len set to %d.' % model_params.max_seq_len)

    eval_model_params = copy_hparams(model_params)
    eval_model_params.rnn_dropout_keep_prob = 1.0
    eval_model_params.is_training = True

    if testing_mode:  # for testing
        eval_model_params.batch_size = 1
        eval_model_params.is_training = False  # sample mode

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)
    normalizing_scale_factor = model_params.scale_factor
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(valid_strokes,
                                 eval_model_params.batch_size,
                                 max_seq_length=eval_model_params.max_seq_len,
                                 random_scale_factor=0.0,
                                 augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(test_strokes,
                                eval_model_params.batch_size,
                                max_seq_length=eval_model_params.max_seq_len,
                                random_scale_factor=0.0,
                                augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    # process the reference dataset
    std_train_set = utils.DataLoader(
        std_train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)
    std_train_set.normalize(normalizing_scale_factor)

    std_valid_set = utils.DataLoader(
        std_valid_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    std_valid_set.normalize(normalizing_scale_factor)

    std_test_set = utils.DataLoader(
        std_test_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    std_test_set.normalize(normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, std_train_set, std_valid_set,
        std_test_set, model_params, eval_model_params
    ]
    return result
Example #8
def load_datasets(data_dir, model_params, inference_mode=False):
    """Load and preprocess data"""
    data = utils.load_dataset(data_dir)
    train_strokes = data['train']
    valid_strokes = data['valid']
    test_strokes = data['test']

    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    tf.logging.info('{} Shapes / {} Total points'.format(len(all_strokes), num_points))
    tf.logging.info('Dataset combined: {} ({}/{}/{}), avg len {}'.format(
        len(all_strokes), len(train_strokes), len(valid_strokes),
        len(test_strokes), int(avg_len)))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len

    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

    eval_model_params = derender_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1

    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = derender_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)

    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(
        valid_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(
        test_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    tf.logging.info('normalizing_scale_factor %4.4f.', normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, model_params, eval_model_params,
        sample_model_params
    ]

    return result
Example #9
def trainer(model_params):
    # model_params is the default parameters from model.py
    if FLAGS.test:
        model_params.is_training = False
        model_params.batch_size = 1
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200,
                        suppress=True)  # set the output format
    ''' print the parameters '''
    tf.logging.info('Hyperparams:')
    for key, val in six.iteritems(model_params.values()):
        tf.logging.info('%s = %s', key, str(val))
    tf.logging.info('Loading data files.')
    ''' get train dataset, valid dataset, test dataset ready
      data format:
        input:  [data_num * game_round (should be 20) * dimension (should be 7)]
        output: [data_num * game_round (should be 20) * dimension (should be 1)]

      data usage:
        train_set: used to train the network
        valid_set: used to validate the network; if the current parameters perform better than in all previous epochs, we keep them as the best ones and save a checkpoint
        test_set:  used to evaluate the final performance of the trained network
    '''
    _ = np.load(model_params.train_data_set, encoding="latin1")
    train_set = {}
    train_set[0] = _["input"].copy()
    train_set[1] = _["output"].copy()

    _ = np.load(model_params.valid_data_set, encoding="latin1")
    valid_set = {}
    valid_set[0] = _["input"].copy()
    valid_set[1] = _["output"].copy()

    _ = np.load(model_params.test_data_set, encoding="latin1")
    test_set = {}
    test_set[0] = _["input"].copy()
    test_set[1] = _["output"].copy()
    '''eval_model_params is for evaluation; it starts with the same parameters as the training model (although some of them will be changed very soon)'''
    eval_model_params = Model.copy_hparams(model_params)

    # reset_graph()
    model = Model.Model(model_params)
    eval_model = Model.Model(eval_model_params, reuse=True)
    '''set up and initialize the session; for the difference between tf.InteractiveSession() and tf.Session(),
       see https://stackoverflow.com/questions/47721792/tensorflow-tf-get-default-session-after-sess-tf-session-is-none
    '''

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    if FLAGS.resume_training or FLAGS.test:
        load_checkpoint(sess, FLAGS.log_root)

    # Write config file to json file.
    tf.gfile.MakeDirs(FLAGS.log_root)
    with tf.gfile.Open(os.path.join(FLAGS.log_root, 'model_config.json'),
                       'w') as f:
        json.dump(model_params.values(), f, indent=True)

    train(sess, model, eval_model, train_set, valid_set, test_set)
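# Given the data format described in the docstring above, a minimal sketch of
# producing a compatible .npz file (shapes and values are illustrative assumptions):
import numpy as np

num_games, rounds, in_dim = 128, 20, 7          # sizes taken from the docstring
inputs = np.random.rand(num_games, rounds, in_dim).astype(np.float32)
outputs = np.random.rand(num_games, rounds, 1).astype(np.float32)
np.savez('train_data.npz', input=inputs, output=outputs)  # keys match _["input"] / _["output"]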
Example #10
def load_dataset(data_dir, model_params, inference_mode=False):
    """Loads the .npz file, and splits the set into train/valid/test."""

    # normalizes the x and y columns using the training set.
    # applies same scaling factor to valid and test set.

    if isinstance(model_params.data_set, list):
        datasets = model_params.data_set
    else:
        datasets = [model_params.data_set]

    train_strokes = None
    valid_strokes = None
    test_strokes = None

    png_paths_map = {'train': [], 'valid': [], 'test': []}

    for dataset in datasets:
        if data_dir.startswith('http://') or data_dir.startswith('https://'):
            data_filepath = '/'.join([data_dir, dataset])
            print('Downloading %s' % data_filepath)
            response = requests.get(data_filepath)
            data = np.load(six.BytesIO(response.content), encoding='latin1')
        else:
            data_filepath = os.path.join(data_dir, 'npz', dataset)
            if six.PY3:
                data = np.load(data_filepath, encoding='latin1')
            else:
                data = np.load(data_filepath)
        print('Loaded {}/{}/{} from {}'.format(len(data['train']),
                                               len(data['valid']),
                                               len(data['test']), dataset))
        if train_strokes is None:
            train_strokes = data[
                'train']  # [N (#sketches),], each with [S (#points), 3]
            valid_strokes = data['valid']
            test_strokes = data['test']
        else:
            train_strokes = np.concatenate((train_strokes, data['train']))
            valid_strokes = np.concatenate((valid_strokes, data['valid']))
            test_strokes = np.concatenate((test_strokes, data['test']))

        splits = ['train', 'valid', 'test']
        for split in splits:
            for im_idx in range(len(data[split])):
                png_path = os.path.join(
                    data_dir, 'png', dataset[:-4], split,
                    str(model_params.img_H) + 'x' + str(model_params.img_W),
                    str(im_idx) + '.png')
                png_paths_map[split].append(png_path)

    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    print('Dataset combined: {} ({}/{}/{}), avg len {}'.format(
        len(all_strokes), len(train_strokes), len(valid_strokes),
        len(test_strokes), int(avg_len)))
    assert len(train_strokes) == len(png_paths_map['train'])
    assert len(valid_strokes) == len(png_paths_map['valid'])
    assert len(test_strokes) == len(png_paths_map['test'])

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)

    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len
    print('model_params.max_seq_len %i.' % model_params.max_seq_len)

    eval_model_params = sketch_rnn_model.copy_hparams(model_params)

    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1

    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        png_paths_map['train'],
        model_params.img_H,
        model_params.img_W,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)

    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(valid_strokes,
                                 png_paths_map['valid'],
                                 eval_model_params.img_H,
                                 eval_model_params.img_W,
                                 eval_model_params.batch_size,
                                 max_seq_length=eval_model_params.max_seq_len,
                                 random_scale_factor=0.0,
                                 augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(test_strokes,
                                png_paths_map['test'],
                                eval_model_params.img_H,
                                eval_model_params.img_W,
                                eval_model_params.batch_size,
                                max_seq_length=eval_model_params.max_seq_len,
                                random_scale_factor=0.0,
                                augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    print('normalizing_scale_factor %4.4f.' % normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result
Example #11
    def __init__(self, hps, hps_srm=None, hps_cls=None, input_k=None):
        """
        Build the model.
        """
        self.hps = hps

        if "auto" in hps.comparison_classes:
            hps.comparison_classes = hps_srm.data_set

        if hps_srm is not None:
            hps_srms = sketch_rnn.copy_hparams(hps_srm)
            if not self.hps.is_training:
                hps_srms.max_seq_len = 1  # slow sampling
            self.srm_s = sketch_rnn.Model(hps_srms, fast_mode=self.hps.is_training)

        if hps_cls is not None:
            # load in classifier
            self.clsm = classifier.Classifier(hps_cls)

            self.hps.label_width = self.clsm.hps.label_width
            self.hps.labels = self.clsm.hps.labels
            self.hps.data_set = self.clsm.hps.data_set

        self.label_width = self.hps.label_width
        assert isinstance(self.label_width, int)

        with tf.variable_scope("amortizer"):
            self.global_step = tf.get_variable(
                "global_step", dtype=tf.int32, initializer=0, trainable=False
            )

            ######
            # Inputs
            ###

            if input_k is None:
                self.input_k = tf.placeholder_with_default(
                    input=tf.random_uniform(
                        [self.hps.batch_size, hps_srm.z_size], minval=-3, maxval=3
                    ),
                    name="input_k",
                    shape=[self.hps.batch_size, hps_srm.z_size],
                )
            else:
                self.input_k = input_k

            ######
            # Amortizer, fully connected layers
            ###
            hl = self.input_k
            for _ in range(self.hps.critic_layers):
                hl = leaky_relu(tf.layers.dense(hl, self.hps.critic_size))
            self.logits = tf.layers.dense(hl, self.label_width)

            ######
            # Losses
            ###
            if self.hps.is_training:
                self.input_logits = tf.placeholder_with_default(
                    input=self.clsm.logits_c,
                    name="input_logits",
                    shape=[self.hps.batch_size, self.label_width],
                )

                # cross entropy loss
                labels = tf.nn.sigmoid(self.input_logits)
                entropy = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=labels, logits=self.logits
                )
                self.loss = tf.reduce_mean(entropy)

                # dynamic learning rate
                self.lr = (hps.learning_rate - hps.min_learning_rate) * (
                    hps.decay_rate
                ) ** tf.cast(self.global_step, tf.float32) + hps.min_learning_rate

                # l2 regularization
                keys = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope="amortizer"
                )
                self.loss_l2 = tf.add_n([tf.nn.l2_loss(t) for t in keys]) * 0.0001

                # final loss function
                loss = self.loss + self.loss_l2

                # underlying code uses Adam optimizer
                self.train_op = self.build_optimizer(
                    loss, keys, self.lr, "train_op", global_step=self.global_step
                )
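# The dynamic learning rate above decays exponentially from hps.learning_rate toward
# hps.min_learning_rate as global_step grows. The same schedule evaluated eagerly,
# with hypothetical hyperparameter values:
learning_rate, min_learning_rate, decay_rate = 1e-3, 1e-5, 0.9999  # assumed values

def decayed_lr(step):
    # Same formula as self.lr in the constructor above.
    return (learning_rate - min_learning_rate) * decay_rate ** step + min_learning_rate

print(decayed_lr(0))      # 0.001
print(decayed_lr(10000))  # ~3.7e-4 (0.99e-3 * 0.9999**10000 + 1e-5)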
Example #12
def load_dataset(sketch_data_dir,
                 photo_data_dir,
                 model_params,
                 inference_mode=False):
    """Loads the .npz file, and splits the set into train/test."""

    # normalizes the x and y columns using the training set.
    # applies same scaling factor to test set.

    if isinstance(model_params.data_set, list):
        datasets = model_params.data_set
    else:
        datasets = [model_params.data_set]

    train_strokes = None
    test_strokes = None
    train_image_paths = []
    test_image_paths = []

    for dataset in datasets:
        if model_params.data_type == 'QMUL':
            train_data_filepath = os.path.join(sketch_data_dir, dataset,
                                               'train_svg_sim_spa_png.h5')
            test_data_filepath = os.path.join(sketch_data_dir, dataset,
                                              'test_svg_sim_spa_png.h5')

            train_data_dict = utils.load_hdf5(train_data_filepath)
            test_data_dict = utils.load_hdf5(test_data_filepath)

            train_sketch_data = utils.reassemble_data(
                train_data_dict['image_data'], train_data_dict['data_offset']
            )  # list of [N_sketches], each [N_points, 4]
            train_photo_names = train_data_dict[
                'image_base_name']  # [N_sketches, 1], byte
            train_photo_paths = [
                os.path.join(photo_data_dir,
                             train_photo_names[i, 0].decode() + '.png')
                for i in range(train_photo_names.shape[0])
            ]  # [N_sketches], str
            test_sketch_data = utils.reassemble_data(
                test_data_dict['image_data'], test_data_dict['data_offset']
            )  # list of [N_sketches], each [N_points, 4]
            test_photo_names = test_data_dict[
                'image_base_name']  # [N_sketches, 1], byte
            test_photo_paths = [
                os.path.join(photo_data_dir,
                             test_photo_names[i, 0].decode() + '.png')
                for i in range(test_photo_names.shape[0])
            ]  # [N_sketches], str

            # transfer stroke-4 to stroke-3
            train_sketch_data = utils.to_normal_strokes_4to3(train_sketch_data)
            test_sketch_data = utils.to_normal_strokes_4to3(
                test_sketch_data)  # [N_sketches,], each with [N_points, 3]

            if train_strokes is None:
                train_strokes = train_sketch_data
                test_strokes = test_sketch_data
            else:
                train_strokes = np.concatenate(
                    (train_strokes, train_sketch_data))
                test_strokes = np.concatenate((test_strokes, test_sketch_data))

        elif model_params.data_type == 'QuickDraw':
            data_filepath = os.path.join(sketch_data_dir, dataset, 'npz',
                                         'sketchrnn_' + dataset + '.npz')
            if six.PY3:
                data = np.load(data_filepath, encoding='latin1')
            else:
                data = np.load(data_filepath)

            if train_strokes is None:
                train_strokes = data[
                    'train']  # [N_sketches,], each with [N_points, 3]
                test_strokes = data['test']
            else:
                train_strokes = np.concatenate((train_strokes, data['train']))
                test_strokes = np.concatenate((test_strokes, data['test']))

            train_photo_paths = [
                os.path.join(
                    sketch_data_dir, dataset, 'png', 'train',
                    str(model_params.image_size) + 'x' +
                    str(model_params.image_size),
                    str(im_idx) + '.png')
                for im_idx in range(len(data['train']))
            ]
            test_photo_paths = [
                os.path.join(
                    sketch_data_dir, dataset, 'png', 'test',
                    str(model_params.image_size) + 'x' +
                    str(model_params.image_size),
                    str(im_idx) + '.png')
                for im_idx in range(len(data['test']))
            ]
        else:
            raise Exception('Unknown data type:', model_params.data_type)

        print('Loaded {}/{} from {} {}'.format(len(train_photo_paths),
                                               len(test_photo_paths),
                                               model_params.data_type,
                                               dataset))
        train_image_paths += train_photo_paths
        test_image_paths += test_photo_paths

    all_strokes = np.concatenate((train_strokes, test_strokes))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    print('Dataset combined: {} ({}/{}), avg len {}'.format(
        len(all_strokes), len(train_strokes), len(test_strokes), int(avg_len)))
    assert len(train_image_paths) == len(train_strokes)
    assert len(test_image_paths) == len(test_strokes)

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)

    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len
    print('model_params.max_seq_len %i.' % model_params.max_seq_len)

    eval_model_params = sketch_p2s_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1

    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_p2s_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        train_image_paths,
        model_params.image_size,
        model_params.image_size,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)

    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    # valid_set = utils.DataLoader(
    #     valid_strokes,
    #     eval_model_params.batch_size,
    #     max_seq_length=eval_model_params.max_seq_len,
    #     random_scale_factor=0.0,
    #     augment_stroke_prob=0.0)
    # valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(test_strokes,
                                test_image_paths,
                                model_params.image_size,
                                model_params.image_size,
                                eval_model_params.batch_size,
                                max_seq_length=eval_model_params.max_seq_len,
                                random_scale_factor=0.0,
                                augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    print('normalizing_scale_factor %4.4f.' % normalizing_scale_factor)

    result = [
        train_set, None, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result
Example #13
def load_dataset(root_dir, dataset, model_params, inference_mode=False):
    """Loads the .npz file, and splits the set into train/valid/test."""

    # normalizes the x and y columns using the training set.
    # applies same scaling factor to valid and test set.

    if dataset in ['shoesv2', 'chairsv2', 'shoesv2f_sup', 'shoesv2f_train']:
        data_dir = '%s/%s/' % (root_dir, dataset.split('v')[0])
    elif 'quickdraw' in str(dataset).lower():
        data_dir = os.path.join(root_dir, 'quickdraw')
    else:
        raise Exception('Dataset error')

    print(data_dir)

    sketch_data, sketch_png_data, image_data, sbir_data, skh_img_id, img_skh_id = {}, {}, {}, {}, {}, {}
    subset = 'train' if not inference_mode else 'test'

    rgb_str = ''

    view_type = 'photo'
    if dataset in ['shoesv2', 'chairsv2']:
        rgb_str = '_rgb'

    print('Prepared data for %s set' % subset)

    photo_png_dir = os.path.join(data_dir,
                                 '%s/%s%s.h5' % (view_type, subset, rgb_str))
    photo_png_dir_train = os.path.join(data_dir,
                                       '%s/train%s.h5' % (view_type, rgb_str))
    photo_png_dir_test = os.path.join(data_dir,
                                      '%s/test%s.h5' % (view_type, rgb_str))

    if 'quickdraw' in str(dataset).lower():
        sketch_data_dir = os.path.join(data_dir, 'npz_data/%s.npz' % category)  # NOTE: `category` is not defined in this function; it must come from the enclosing module scope
        sketch_png_dir_train = os.path.join(data_dir,
                                            'hdf5_data/%s_train.h5' % category)
        sketch_png_dir_test = os.path.join(data_dir,
                                           'hdf5_data/%s_test.h5' % category)

        # load data w/o label into dictionary
        # sketch_data[subset] = np.copy(np.load(sketch_data_dir)[subset])
        sketch_data['train'] = np.copy(np.load(sketch_data_dir)['train'])
        sketch_data['valid'] = np.copy(np.load(sketch_data_dir)['valid'])
        sketch_data['test'] = np.copy(np.load(sketch_data_dir)['test'])
        sketch_png_data['train'] = load_hdf5(
            sketch_png_dir_train)['image_data']
        sketch_png_data['test'] = load_hdf5(sketch_png_dir_test)['image_data']
        # image_data[subset] = None
        image_data['train'], image_data['test'] = None, None
        skh_img_id['train'], img_skh_id['train'] = None, None
        skh_img_id['test'], img_skh_id['test'] = None, None

    elif dataset in ['shoesv2', 'chairsv2']:

        sketch_data_dir_train = os.path.join(
            data_dir, 'svg_trimmed/train_svg_sim_spa_png.h5')
        sketch_data_dir_test = os.path.join(
            data_dir, 'svg_trimmed/test_svg_sim_spa_png.h5')

        # load data w/o label into dictionary
        # sketch_data[subset] = get_sketch_data(sketch_data_dir)
        sketch_data['train'] = get_sketch_data(sketch_data_dir_train)
        sketch_data['test'] = get_sketch_data(sketch_data_dir_test)
        sketch_data_info_train = load_hdf5(sketch_data_dir_train)
        sketch_data_info_test = load_hdf5(sketch_data_dir_test)
        sketch_png_data['train'] = sketch_data_info_train['png_data']
        sketch_png_data['test'] = sketch_data_info_test['png_data']
        # image_data[subset] = load_hdf5(photo_png_dir)['image_data']
        image_data['train'] = load_hdf5(photo_png_dir_train)['image_data']
        image_data['test'] = load_hdf5(photo_png_dir_test)['image_data']
        # skh_img_id[subset], img_skh_id[subset] = get_skh_img_ids(sketch_data_info['image_id'], sketch_data_info['instance_id'])
        skh_img_id['train'], img_skh_id['train'] = get_skh_img_ids(
            sketch_data_info_train['image_id'],
            sketch_data_info_train['instance_id'])
        skh_img_id['test'], img_skh_id['test'] = get_skh_img_ids(
            sketch_data_info_test['image_id'],
            sketch_data_info_test['instance_id'])

    if model_params.enc_type != 'cnn':
        # free the memory if not need
        # sketch_png_data[subset], image_data[subset] = None, None
        sketch_png_data['train'], image_data['train'] = None, None
        sketch_png_data['test'], image_data['test'] = None, None

    if sketch_data[subset] is not None:
        sketch_size = len(sketch_data[subset])
    else:
        sketch_size = 0
    if image_data[subset] is not None:
        photo_size = len(image_data[subset])
    else:
        photo_size = sketch_size

    print('Loaded {} set, {} sketches and {} images'.format(
        subset, sketch_size, photo_size))

    if 'quickdraw' in str(dataset).lower():
        all_strokes = np.concatenate(
            (sketch_data['train'], sketch_data['valid'], sketch_data['test']))
    else:
        all_strokes = np.concatenate(
            (sketch_data['train'], sketch_data['test']))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    print('Dataset train/valid/test, avg len {}'.format(int(avg_len)))

    # calculate the max strokes we need.
    max_seq_len = get_max_len(all_strokes)

    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len
    model_params.rnn_enc_max_seq_len = max_seq_len

    print('model_params.max_seq_len %d' % model_params.max_seq_len)

    eval_model_params = sketch_rnn_model.copy_hparams(model_params)

    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    # eval_model_params.is_training = 1
    eval_model_params.is_training = 0

    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    gen_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    gen_model_params.batch_size = 1  # only sample one at a time

    print("Create DataLoader for %s subset" % subset)
    sbir_data['train_set'] = DataLoader(
        sketch_data['train'],
        sketch_png_data['train'],
        image_data['train'],
        skh_img_ids=skh_img_id['train'],
        img_skh_ids=img_skh_id['train'],
        dataset=dataset,
        enc_type=model_params.enc_type,
        vae_type=model_params.vae_type,
        batch_size=model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob,
        augment_flipr_flag=model_params.flip_aug)

    if inference_mode:
        sbir_data['test_set'] = DataLoader(
            sketch_data['test'],
            sketch_png_data['test'],
            image_data['test'],
            skh_img_ids=skh_img_id['test'],
            img_skh_ids=img_skh_id['test'],
            dataset=dataset,
            enc_type=model_params.enc_type,
            vae_type=model_params.vae_type,
            batch_size=model_params.batch_size,
            max_seq_length=model_params.max_seq_len,
            random_scale_factor=model_params.random_scale_factor,
            augment_stroke_prob=model_params.augment_stroke_prob,
            augment_flipr_flag=model_params.flip_aug)

    normalizing_scale_factor = sbir_data[
        'train_set'].calculate_normalizing_scale_factor()

    sbir_data['%s_set' % subset].normalize(normalizing_scale_factor)

    print('normalizing_scale_factor %4.4f.' % normalizing_scale_factor)

    # return_model_params = eval_model_params if inference_mode else model_params

    # return sbir_data['%s_set' % subset], return_model_params
    if not inference_mode:
        return sbir_data['%s_set' % subset], model_params
    else:
        return sbir_data['%s_set' %
                         subset], sample_model_params, gen_model_params
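# Note that load_dataset above returns different tuples depending on inference_mode.
# A hedged usage sketch (root_dir and dataset name are assumptions):
train_set, hps = load_dataset('/tmp/data', 'shoesv2', model_params,
                              inference_mode=False)
test_set, sample_hps, gen_hps = load_dataset('/tmp/data', 'shoesv2', model_params,
                                             inference_mode=True)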