def load_dataset(data_dir, model_params, inference_mode=False):
    tf.logging.info('load_dataset - start data processing ============================')
    source = model_params.source  # source.npy
    target = model_params.target  # target.npy
    source_data = np.load(os.path.join(data_dir, source))
    target_data = np.load(os.path.join(data_dir, target))
    tf.logging.info('source data length %i.', len(source_data))
    tf.logging.info('target data length %i.', len(target_data))

    num_points = 68
    model_params.max_seq_len = num_points
    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

    # Copy the model hparams to the evaluation model and adjust a few of them.
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1
    if inference_mode:  # defaults to False
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    # Copy the evaluation hparams to the sampling model.
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    # After a random shuffle, draw one batch of train/test/valid data.
    # The x, y values are normalized, and sequences shorter than Nmax are padded with (0, 0, 0, 0, 1).
    tf.logging.info('Processing data to feed into the network.')
    indices = np.random.permutation(range(0, len(source_data)))[0:model_params.batch_size]
    # Use the same indices so that source and target stay in the same order.
    source_set = utils.DataLoader(
        source_data,
        indices,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor)
    target_set = utils.DataLoader(
        target_data,
        indices,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor)
    factor_source = source_set.calculate_normalizing_scale_factor()
    source_set.normalize(factor_source)  # then normalize the source data
    factor_target = target_set.calculate_normalizing_scale_factor()
    target_set.normalize(factor_target)  # then normalize the target data
    tf.logging.info('source normalizing_scale_factor is %4.4f.', factor_source)
    tf.logging.info('target normalizing_scale_factor is %4.4f.', factor_target)

    result = [source_set, target_set, model_params, eval_model_params, sample_model_params]
    return result
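# A minimal usage sketch for the paired source/target loader above (not part of the
# original code). The 'data/' directory is hypothetical; only the unpacking order follows
# from the `result` list built inside load_dataset().
# source_set, target_set, hps, eval_hps, sample_hps = load_dataset(
#     'data/', sketch_rnn_model.get_default_hparams())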
def load_model(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    model_params = get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        model_params.parse_json(f.read())

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
def load_model(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        model_params.parse_json(f.read())

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
def load_model(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        model_params.parse_json(f.read())

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time

    if six.PY3:
        pretrained_model_params = np.load(model_dir + '/model', encoding='latin1')
    else:
        pretrained_model_params = np.load(model_dir + '/model')
    return [model_params, eval_model_params, sample_model_params, pretrained_model_params]
def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout',
                'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
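# A minimal usage sketch (not part of the original code): the checkpoint directory name is
# hypothetical; it only needs to contain the model_config.json that load_model_compatible reads.
# hps_model, eval_hps_model, sample_hps_model = load_model_compatible('/tmp/sketch_rnn/models/aaron_sheep/layer_norm')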
def load_dataset(data_dir, model_params, inference_mode=False): """Loads the .npz file, and splits the set into train/valid/test.""" # normalizes the x and y columns usint the training set. # applies same scaling factor to valid and test set. datasets = [] if isinstance(model_params.data_set, list): datasets = model_params.data_set else: datasets = [model_params.data_set] train_strokes = None valid_strokes = None test_strokes = None for dataset in datasets: data_filepath = os.path.join(data_dir, dataset) if data_dir.startswith('http://') or data_dir.startswith('https://'): tf.logging.info('Downloading %s', data_filepath) response = requests.get(data_filepath) data = np.load(StringIO(response.content)) else: data = np.load(data_filepath) # load this into dictionary tf.logging.info('Loaded {}/{}/{} from {}'.format( len(data['train']), len(data['valid']), len(data['test']), dataset)) if train_strokes is None: train_strokes = data['train'] valid_strokes = data['valid'] test_strokes = data['test'] else: train_strokes = np.concatenate((train_strokes, data['train'])) valid_strokes = np.concatenate((valid_strokes, data['valid'])) test_strokes = np.concatenate((test_strokes, data['test'])) all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes)) num_points = 0 for stroke in all_strokes: num_points += len(stroke) avg_len = num_points / len(all_strokes) tf.logging.info('Dataset combined: {} ({}/{}/{}), avg len {}'.format( len(all_strokes), len(train_strokes), len(valid_strokes), len(test_strokes), int(avg_len))) # calculate the max strokes we need. max_seq_len = utils.get_max_len(all_strokes) # overwrite the hps with this calculation. model_params.max_seq_len = max_seq_len tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len) eval_model_params = sketch_rnn_model.copy_hparams(model_params) eval_model_params.use_input_dropout = 0 eval_model_params.use_recurrent_dropout = 0 eval_model_params.use_output_dropout = 0 eval_model_params.is_training = 1 if inference_mode: eval_model_params.batch_size = 1 eval_model_params.is_training = 0 sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params) sample_model_params.batch_size = 1 # only sample one at a time sample_model_params.max_seq_len = 1 # sample one point at a time train_set = utils.DataLoader( train_strokes, model_params.batch_size, max_seq_length=model_params.max_seq_len, random_scale_factor=model_params.random_scale_factor, augment_stroke_prob=model_params.augment_stroke_prob) normalizing_scale_factor = train_set.calculate_normalizing_scale_factor() train_set.normalize(normalizing_scale_factor) print('Length original', len(train_strokes), len(valid_strokes), len(test_strokes)) valid_set = utils.DataLoader(valid_strokes, eval_model_params.batch_size, max_seq_length=eval_model_params.max_seq_len, random_scale_factor=0.0, augment_stroke_prob=0.0) valid_set.normalize(normalizing_scale_factor) test_set = utils.DataLoader(test_strokes, eval_model_params.batch_size, max_seq_length=eval_model_params.max_seq_len, random_scale_factor=0.0, augment_stroke_prob=0.0) test_set.normalize(normalizing_scale_factor) tf.logging.info('normalizing_scale_factor %4.4f.', normalizing_scale_factor) result = [ train_set, valid_set, test_set, model_params, eval_model_params, sample_model_params ] return result
def load_dataset(data_dir, model_params, testing_mode=False):
    """Loads the .npz file, and splits the set into train/valid/test."""
    # normalizes the x and y columns using scale_factor.
    dataset = model_params.data_set
    data_filepath = os.path.join(data_dir, dataset)
    data = np.load(data_filepath, allow_pickle=True, encoding='latin1')

    # target data
    train_strokes = data['train']
    valid_strokes = data['valid']
    test_strokes = data['test']
    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))

    # standard data (reference data in paper)
    std_train_strokes = data['std_train']
    std_valid_strokes = data['std_valid']
    std_test_strokes = data['std_test']
    all_std_strokes = np.concatenate(
        (std_train_strokes, std_valid_strokes, std_test_strokes))

    print('Dataset combined: %d (train=%d/validate=%d/test=%d)' %
          (len(all_strokes), len(train_strokes), len(valid_strokes), len(test_strokes)))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    max_std_seq_len = utils.get_max_len(all_std_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max(max_seq_len, max_std_seq_len)
    print('model_params.max_seq_len set to %d.' % model_params.max_seq_len)

    eval_model_params = copy_hparams(model_params)
    eval_model_params.rnn_dropout_keep_prob = 1.0
    eval_model_params.is_training = True
    if testing_mode:  # for testing
        eval_model_params.batch_size = 1
        eval_model_params.is_training = False  # sample mode

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)
    normalizing_scale_factor = model_params.scale_factor
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(
        valid_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(
        test_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    # process the reference dataset
    std_train_set = utils.DataLoader(
        std_train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)
    std_train_set.normalize(normalizing_scale_factor)

    std_valid_set = utils.DataLoader(
        std_valid_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    std_valid_set.normalize(normalizing_scale_factor)

    std_test_set = utils.DataLoader(
        std_test_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    std_test_set.normalize(normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, std_train_set, std_valid_set,
        std_test_set, model_params, eval_model_params
    ]
    return result
def load_datasets(data_dir, model_params, inference_mode=False):
    """Load and preprocess data."""
    data = utils.load_dataset(data_dir)
    train_strokes = data['train']
    valid_strokes = data['valid']
    test_strokes = data['test']

    all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    tf.logging.info('{} Shapes / {} Total points'.format(len(all_strokes), num_points))
    tf.logging.info('Dataset combined: {} ({}/{}/{}), avg len {}'.format(
        len(all_strokes), len(train_strokes), len(valid_strokes),
        len(test_strokes), int(avg_len)))

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len
    tf.logging.info('model_params.max_seq_len %i.', model_params.max_seq_len)

    eval_model_params = derender_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1
    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = derender_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)
    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    valid_set = utils.DataLoader(
        valid_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(
        test_strokes,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    tf.logging.info('normalizing_scale_factor %4.4f.', normalizing_scale_factor)

    result = [
        train_set, valid_set, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result
def trainer(model_params):
    # model_params is the default parameters from model.py
    if FLAGS.test != False:
        model_params.is_training = False
        model_params.batch_size = 1
    # set the output format
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    ''' print the parameters '''
    tf.logging.info('Hyperparams:')
    for key, val in six.iteritems(model_params.values()):
        tf.logging.info('%s = %s', key, str(val))
    tf.logging.info('Loading data files.')

    '''
    get train dataset, valid dataset, test dataset ready
    data format:
        input:  [data_num * game_round (should be 20) * dimension (should be 7)]
        output: [data_num * game_round (should be 20) * dimension (should be 1)]
    data usage:
        train_set: used to train the network
        valid_set: used to check the network; if the network with the current parameters
                   performs better than in all previous epochs, we consider the current
                   parameters to be valid ones and save them.
        test_set:  used to evaluate the final performance of our trained network
    '''
    _ = np.load(model_params.train_data_set, encoding="latin1")
    train_set = {}
    train_set[0] = _["input"].copy()
    train_set[1] = _["output"].copy()

    _ = np.load(model_params.valid_data_set, encoding="latin1")
    valid_set = {}
    valid_set[0] = _["input"].copy()
    valid_set[1] = _["output"].copy()

    _ = np.load(model_params.test_data_set, encoding="latin1")
    test_set = {}
    test_set[0] = _["input"].copy()
    test_set[1] = _["output"].copy()

    '''eval_model_params is for evaluation and starts with the same parameters as the
    training model (although some of them will be changed very soon)'''
    eval_model_params = Model.copy_hparams(model_params)

    # reset_graph()
    model = Model.Model(model_params)
    eval_model = Model.Model(eval_model_params, reuse=True)

    '''set and initialize the graph; for the difference between tf.InteractiveSession()
    and tf.Session(), refer to
    https://stackoverflow.com/questions/47721792/tensorflow-tf-get-default-session-after-sess-tf-session-is-none
    '''
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    if FLAGS.resume_training or FLAGS.test:
        load_checkpoint(sess, FLAGS.log_root)

    # Write config file to json file.
    tf.gfile.MakeDirs(FLAGS.log_root)
    with tf.gfile.Open(os.path.join(FLAGS.log_root, 'model_config.json'), 'w') as f:
        json.dump(model_params.values(), f, indent=True)

    train(sess, model, eval_model, train_set, valid_set, test_set)
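# A minimal sketch (not part of the original code) of how trainer() might be wired up from a
# main() entry point. Model.get_default_hparams() and FLAGS.hparams are assumptions modeled
# on the sketch-rnn-style trainers elsewhere in this file, not names confirmed by the source.
def main(unused_argv):
    model_params = Model.get_default_hparams()
    if FLAGS.hparams is not None:
        # e.g. --hparams="batch_size=64,num_steps=10000"
        model_params.parse(FLAGS.hparams)
    trainer(model_params)

if __name__ == '__main__':
    tf.app.run(main)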
def load_dataset(data_dir, model_params, inference_mode=False): """Loads the .npz file, and splits the set into train/valid/test.""" # normalizes the x and y columns using the training set. # applies same scaling factor to valid and test set. if isinstance(model_params.data_set, list): datasets = model_params.data_set else: datasets = [model_params.data_set] train_strokes = None valid_strokes = None test_strokes = None png_paths_map = {'train': [], 'valid': [], 'test': []} for dataset in datasets: if data_dir.startswith('http://') or data_dir.startswith('https://'): data_filepath = '/'.join([data_dir, dataset]) print('Downloading %s' % data_filepath) response = requests.get(data_filepath) data = np.load(six.BytesIO(response.content), encoding='latin') else: data_filepath = os.path.join(data_dir, 'npz', dataset) if six.PY3: data = np.load(data_filepath, encoding='latin1') else: data = np.load(data_filepath) print('Loaded {}/{}/{} from {}'.format(len(data['train']), len(data['valid']), len(data['test']), dataset)) if train_strokes is None: train_strokes = data[ 'train'] # [N (#sketches),], each with [S (#points), 3] valid_strokes = data['valid'] test_strokes = data['test'] else: train_strokes = np.concatenate((train_strokes, data['train'])) valid_strokes = np.concatenate((valid_strokes, data['valid'])) test_strokes = np.concatenate((test_strokes, data['test'])) splits = ['train', 'valid', 'test'] for split in splits: for im_idx in range(len(data[split])): png_path = os.path.join( data_dir, 'png', dataset[:-4], split, str(model_params.img_H) + 'x' + str(model_params.img_W), str(im_idx) + '.png') png_paths_map[split].append(png_path) all_strokes = np.concatenate((train_strokes, valid_strokes, test_strokes)) num_points = 0 for stroke in all_strokes: num_points += len(stroke) avg_len = num_points / len(all_strokes) print('Dataset combined: {} ({}/{}/{}), avg len {}'.format( len(all_strokes), len(train_strokes), len(valid_strokes), len(test_strokes), int(avg_len))) assert len(train_strokes) == len(png_paths_map['train']) assert len(valid_strokes) == len(png_paths_map['valid']) assert len(test_strokes) == len(png_paths_map['test']) # calculate the max strokes we need. max_seq_len = utils.get_max_len(all_strokes) # overwrite the hps with this calculation. model_params.max_seq_len = max_seq_len print('model_params.max_seq_len %i.' 
% model_params.max_seq_len) eval_model_params = sketch_rnn_model.copy_hparams(model_params) eval_model_params.use_input_dropout = 0 eval_model_params.use_recurrent_dropout = 0 eval_model_params.use_output_dropout = 0 eval_model_params.is_training = 1 if inference_mode: eval_model_params.batch_size = 1 eval_model_params.is_training = 0 sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params) sample_model_params.batch_size = 1 # only sample one at a time sample_model_params.max_seq_len = 1 # sample one point at a time train_set = utils.DataLoader( train_strokes, png_paths_map['train'], model_params.img_H, model_params.img_W, model_params.batch_size, max_seq_length=model_params.max_seq_len, random_scale_factor=model_params.random_scale_factor, augment_stroke_prob=model_params.augment_stroke_prob) normalizing_scale_factor = train_set.calculate_normalizing_scale_factor() train_set.normalize(normalizing_scale_factor) valid_set = utils.DataLoader(valid_strokes, png_paths_map['valid'], eval_model_params.img_H, eval_model_params.img_W, eval_model_params.batch_size, max_seq_length=eval_model_params.max_seq_len, random_scale_factor=0.0, augment_stroke_prob=0.0) valid_set.normalize(normalizing_scale_factor) test_set = utils.DataLoader(test_strokes, png_paths_map['test'], eval_model_params.img_H, eval_model_params.img_W, eval_model_params.batch_size, max_seq_length=eval_model_params.max_seq_len, random_scale_factor=0.0, augment_stroke_prob=0.0) test_set.normalize(normalizing_scale_factor) print('normalizing_scale_factor %4.4f.' % normalizing_scale_factor) result = [ train_set, valid_set, test_set, model_params, eval_model_params, sample_model_params ] return result
def __init__(self, hps, hps_srm=None, hps_cls=None, input_k=None):
    """Build the model."""
    self.hps = hps
    if "auto" in hps.comparison_classes:
        hps.comparison_classes = hps_srm.data_set

    if hps_srm is not None:
        hps_srms = sketch_rnn.copy_hparams(hps_srm)
        if not self.hps.is_training:
            hps_srms.max_seq_len = 1  # slow sampling
        self.srm_s = sketch_rnn.Model(hps_srms, fast_mode=self.hps.is_training)

    if hps_cls is not None:
        # load in classifier
        self.clsm = classifier.Classifier(hps_cls)
        self.hps.label_width = self.clsm.hps.label_width
        self.hps.labels = self.clsm.hps.labels
        self.hps.data_set = self.clsm.hps.data_set

    self.label_width = self.hps.label_width
    assert isinstance(self.label_width, int)

    with tf.variable_scope("amortizer"):
        self.global_step = tf.get_variable(
            "global_step", dtype=tf.int32, initializer=0, trainable=False
        )

        ######
        # Inputs
        ###
        if input_k is None:
            self.input_k = tf.placeholder_with_default(
                input=tf.random_uniform(
                    [self.hps.batch_size, hps_srm.z_size], minval=-3, maxval=3
                ),
                name="input_k",
                shape=[self.hps.batch_size, hps_srm.z_size],
            )
        else:
            self.input_k = input_k

        ######
        # Amortizer, fully connected layers
        ###
        hl = self.input_k
        for _ in range(self.hps.critic_layers):
            hl = leaky_relu(tf.layers.dense(hl, self.hps.critic_size))
        self.logits = tf.layers.dense(hl, self.label_width)

        ######
        # Losses
        ###
        if self.hps.is_training:
            self.input_logits = tf.placeholder_with_default(
                input=self.clsm.logits_c,
                name="input_logits",
                shape=[self.hps.batch_size, self.label_width],
            )

            # cross entropy loss
            labels = tf.nn.sigmoid(self.input_logits)
            entropy = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels, logits=self.logits
            )
            self.loss = tf.reduce_mean(entropy)

            # dynamic learning rate
            self.lr = (hps.learning_rate - hps.min_learning_rate) * (
                hps.decay_rate
            ) ** tf.cast(self.global_step, tf.float32) + hps.min_learning_rate

            # l2 regularization
            keys = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="amortizer"
            )
            self.loss_l2 = tf.add_n([tf.nn.l2_loss(t) for t in keys]) * 0.0001

            # final loss function
            loss = self.loss + self.loss_l2

            # underlying code uses Adam optimizer
            self.train_op = self.build_optimizer(
                loss, keys, self.lr, "train_op", global_step=self.global_step
            )
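# A standalone illustration (not part of the original code) of the dynamic learning-rate
# schedule computed above: lr(step) = (lr0 - lr_min) * decay_rate ** step + lr_min. The
# default values below are hypothetical, chosen only to show the shape of the curve.
def _decayed_lr(step, lr0=0.001, lr_min=0.00001, decay_rate=0.9999):
    return (lr0 - lr_min) * decay_rate ** step + lr_min

# _decayed_lr(0)      -> 0.001
# _decayed_lr(10000)  -> roughly 0.00037
# _decayed_lr(100000) -> approaches lr_min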
def load_dataset(sketch_data_dir, photo_data_dir, model_params, inference_mode=False):
    """Loads the .npz file, and splits the set into train/test."""
    # normalizes the x and y columns using the training set.
    # applies same scaling factor to test set.
    if isinstance(model_params.data_set, list):
        datasets = model_params.data_set
    else:
        datasets = [model_params.data_set]

    train_strokes = None
    test_strokes = None
    train_image_paths = []
    test_image_paths = []

    for dataset in datasets:
        if model_params.data_type == 'QMUL':
            train_data_filepath = os.path.join(sketch_data_dir, dataset,
                                               'train_svg_sim_spa_png.h5')
            test_data_filepath = os.path.join(sketch_data_dir, dataset,
                                              'test_svg_sim_spa_png.h5')
            train_data_dict = utils.load_hdf5(train_data_filepath)
            test_data_dict = utils.load_hdf5(test_data_filepath)

            train_sketch_data = utils.reassemble_data(
                train_data_dict['image_data'],
                train_data_dict['data_offset'])  # list of [N_sketches], each [N_points, 4]
            train_photo_names = train_data_dict['image_base_name']  # [N_sketches, 1], byte
            train_photo_paths = [
                os.path.join(photo_data_dir, train_photo_names[i, 0].decode() + '.png')
                for i in range(train_photo_names.shape[0])
            ]  # [N_sketches], str

            test_sketch_data = utils.reassemble_data(
                test_data_dict['image_data'],
                test_data_dict['data_offset'])  # list of [N_sketches], each [N_points, 4]
            test_photo_names = test_data_dict['image_base_name']  # [N_sketches, 1], byte
            test_photo_paths = [
                os.path.join(photo_data_dir, test_photo_names[i, 0].decode() + '.png')
                for i in range(test_photo_names.shape[0])
            ]  # [N_sketches], str

            # transfer stroke-4 to stroke-3
            train_sketch_data = utils.to_normal_strokes_4to3(train_sketch_data)
            test_sketch_data = utils.to_normal_strokes_4to3(
                test_sketch_data)  # [N_sketches,], each with [N_points, 3]

            if train_strokes is None:
                train_strokes = train_sketch_data
                test_strokes = test_sketch_data
            else:
                train_strokes = np.concatenate((train_strokes, train_sketch_data))
                test_strokes = np.concatenate((test_strokes, test_sketch_data))
        elif model_params.data_type == 'QuickDraw':
            data_filepath = os.path.join(sketch_data_dir, dataset, 'npz',
                                         'sketchrnn_' + dataset + '.npz')
            if six.PY3:
                data = np.load(data_filepath, encoding='latin1')
            else:
                data = np.load(data_filepath)

            if train_strokes is None:
                train_strokes = data['train']  # [N_sketches,], each with [N_points, 3]
                test_strokes = data['test']
            else:
                train_strokes = np.concatenate((train_strokes, data['train']))
                test_strokes = np.concatenate((test_strokes, data['test']))

            train_photo_paths = [
                os.path.join(
                    sketch_data_dir, dataset, 'png', 'train',
                    str(model_params.image_size) + 'x' + str(model_params.image_size),
                    str(im_idx) + '.png') for im_idx in range(len(data['train']))
            ]
            test_photo_paths = [
                os.path.join(
                    sketch_data_dir, dataset, 'png', 'test',
                    str(model_params.image_size) + 'x' + str(model_params.image_size),
                    str(im_idx) + '.png') for im_idx in range(len(data['test']))
            ]
        else:
            raise Exception('Unknown data type:', model_params.data_type)

        print('Loaded {}/{} from {} {}'.format(len(train_photo_paths),
                                               len(test_photo_paths),
                                               model_params.data_type, dataset))
        train_image_paths += train_photo_paths
        test_image_paths += test_photo_paths

    all_strokes = np.concatenate((train_strokes, test_strokes))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    print('Dataset combined: {} ({}/{}), avg len {}'.format(
        len(all_strokes), len(train_strokes), len(test_strokes), int(avg_len)))
    assert len(train_image_paths) == len(train_strokes)
    assert len(test_image_paths) == len(test_strokes)

    # calculate the max strokes we need.
    max_seq_len = utils.get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len
    print('model_params.max_seq_len %i.' % model_params.max_seq_len)

    eval_model_params = sketch_p2s_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 1
    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_p2s_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    train_set = utils.DataLoader(
        train_strokes,
        train_image_paths,
        model_params.image_size,
        model_params.image_size,
        model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob)
    normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
    train_set.normalize(normalizing_scale_factor)

    # valid_set = utils.DataLoader(
    #     valid_strokes,
    #     eval_model_params.batch_size,
    #     max_seq_length=eval_model_params.max_seq_len,
    #     random_scale_factor=0.0,
    #     augment_stroke_prob=0.0)
    # valid_set.normalize(normalizing_scale_factor)

    test_set = utils.DataLoader(
        test_strokes,
        test_image_paths,
        model_params.image_size,
        model_params.image_size,
        eval_model_params.batch_size,
        max_seq_length=eval_model_params.max_seq_len,
        random_scale_factor=0.0,
        augment_stroke_prob=0.0)
    test_set.normalize(normalizing_scale_factor)

    print('normalizing_scale_factor %4.4f.' % normalizing_scale_factor)

    result = [
        train_set, None, test_set, model_params, eval_model_params,
        sample_model_params
    ]
    return result
def load_dataset(root_dir, dataset, model_params, inference_mode=False):
    """Loads the .npz file, and splits the set into train/valid/test."""
    # normalizes the x and y columns using the training set.
    # applies same scaling factor to valid and test set.
    if dataset in ['shoesv2', 'chairsv2', 'shoesv2f_sup', 'shoesv2f_train']:
        data_dir = '%s/%s/' % (root_dir, dataset.split('v')[0])
    elif 'quickdraw' in str(dataset).lower():
        data_dir = os.path.join(root_dir, 'quickdraw')
    else:
        raise Exception('Dataset error')
    print(data_dir)

    sketch_data, sketch_png_data, image_data, sbir_data, skh_img_id, img_skh_id = {}, {}, {}, {}, {}, {}
    subset = 'train' if not inference_mode else 'test'
    rgb_str = ''
    view_type = 'photo'
    if dataset in ['shoesv2', 'chairsv2']:
        rgb_str = '_rgb'
    print('Prepared data for %s set' % subset)

    photo_png_dir = os.path.join(data_dir, '%s/%s%s.h5' % (view_type, subset, rgb_str))
    photo_png_dir_train = os.path.join(data_dir, '%s/train%s.h5' % (view_type, rgb_str))
    photo_png_dir_test = os.path.join(data_dir, '%s/test%s.h5' % (view_type, rgb_str))

    if 'quickdraw' in str(dataset).lower():
        # NOTE: `category` is assumed to be defined in the enclosing scope
        # (e.g. a module-level QuickDraw category name).
        sketch_data_dir = os.path.join(data_dir, 'npz_data/%s.npz' % category)
        sketch_png_dir_train = os.path.join(data_dir, 'hdf5_data/%s_train.h5' % category)
        sketch_png_dir_test = os.path.join(data_dir, 'hdf5_data/%s_test.h5' % category)

        # load data w/o label into dictionary
        # sketch_data[subset] = np.copy(np.load(sketch_data_dir)[subset])
        sketch_data['train'] = np.copy(np.load(sketch_data_dir)['train'])
        sketch_data['valid'] = np.copy(np.load(sketch_data_dir)['valid'])
        sketch_data['test'] = np.copy(np.load(sketch_data_dir)['test'])
        sketch_png_data['train'] = load_hdf5(sketch_png_dir_train)['image_data']
        sketch_png_data['test'] = load_hdf5(sketch_png_dir_test)['image_data']

        # image_data[subset] = None
        image_data['train'], image_data['test'] = None, None
        skh_img_id['train'], img_skh_id['train'] = None, None
        skh_img_id['test'], img_skh_id['test'] = None, None
    elif dataset in ['shoesv2', 'chairsv2']:
        sketch_data_dir_train = os.path.join(data_dir, 'svg_trimmed/train_svg_sim_spa_png.h5')
        sketch_data_dir_test = os.path.join(data_dir, 'svg_trimmed/test_svg_sim_spa_png.h5')

        # load data w/o label into dictionary
        # sketch_data[subset] = get_sketch_data(sketch_data_dir)
        sketch_data['train'] = get_sketch_data(sketch_data_dir_train)
        sketch_data['test'] = get_sketch_data(sketch_data_dir_test)
        sketch_data_info_train = load_hdf5(sketch_data_dir_train)
        sketch_data_info_test = load_hdf5(sketch_data_dir_test)
        sketch_png_data['train'] = sketch_data_info_train['png_data']
        sketch_png_data['test'] = sketch_data_info_test['png_data']

        # image_data[subset] = load_hdf5(photo_png_dir)['image_data']
        image_data['train'] = load_hdf5(photo_png_dir_train)['image_data']
        image_data['test'] = load_hdf5(photo_png_dir_test)['image_data']

        # skh_img_id[subset], img_skh_id[subset] = get_skh_img_ids(sketch_data_info['image_id'], sketch_data_info['instance_id'])
        skh_img_id['train'], img_skh_id['train'] = get_skh_img_ids(
            sketch_data_info_train['image_id'],
            sketch_data_info_train['instance_id'])
        skh_img_id['test'], img_skh_id['test'] = get_skh_img_ids(
            sketch_data_info_test['image_id'],
            sketch_data_info_test['instance_id'])

    if model_params.enc_type != 'cnn':
        # free the memory if not needed
        # sketch_png_data[subset], image_data[subset] = None, None
        sketch_png_data['train'], image_data['train'] = None, None
        sketch_png_data['test'], image_data['test'] = None, None

    if sketch_data[subset] is not None:
        sketch_size = len(sketch_data[subset])
    else:
        sketch_size = 0
    if image_data[subset] is not None:
        photo_size = len(image_data[subset])
    else:
        photo_size = sketch_size
    print('Loaded {} set, {} sketches and {} images'.format(subset, sketch_size, photo_size))

    if 'quickdraw' in str(dataset).lower():
        all_strokes = np.concatenate(
            (sketch_data['train'], sketch_data['valid'], sketch_data['test']))
    else:
        all_strokes = np.concatenate((sketch_data['train'], sketch_data['test']))
    num_points = 0
    for stroke in all_strokes:
        num_points += len(stroke)
    avg_len = num_points / len(all_strokes)
    print('Dataset train/valid/test, avg len {}'.format(int(avg_len)))

    # calculate the max strokes we need.
    max_seq_len = get_max_len(all_strokes)
    # overwrite the hps with this calculation.
    model_params.max_seq_len = max_seq_len
    model_params.rnn_enc_max_seq_len = max_seq_len
    print('model_params.max_seq_len %d' % model_params.max_seq_len)

    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    # eval_model_params.is_training = 1
    eval_model_params.is_training = 0
    if inference_mode:
        eval_model_params.batch_size = 1
        eval_model_params.is_training = 0

    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.batch_size = 1  # only sample one at a time
    sample_model_params.max_seq_len = 1  # sample one point at a time

    gen_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    gen_model_params.batch_size = 1  # only sample one at a time

    print("Create DataLoader for %s subset" % subset)
    sbir_data['train_set'] = DataLoader(
        sketch_data['train'],
        sketch_png_data['train'],
        image_data['train'],
        skh_img_ids=skh_img_id['train'],
        img_skh_ids=img_skh_id['train'],
        dataset=dataset,
        enc_type=model_params.enc_type,
        vae_type=model_params.vae_type,
        batch_size=model_params.batch_size,
        max_seq_length=model_params.max_seq_len,
        random_scale_factor=model_params.random_scale_factor,
        augment_stroke_prob=model_params.augment_stroke_prob,
        augment_flipr_flag=model_params.flip_aug)
    if inference_mode:
        sbir_data['test_set'] = DataLoader(
            sketch_data['test'],
            sketch_png_data['test'],
            image_data['test'],
            skh_img_ids=skh_img_id['test'],
            img_skh_ids=img_skh_id['test'],
            dataset=dataset,
            enc_type=model_params.enc_type,
            vae_type=model_params.vae_type,
            batch_size=model_params.batch_size,
            max_seq_length=model_params.max_seq_len,
            random_scale_factor=model_params.random_scale_factor,
            augment_stroke_prob=model_params.augment_stroke_prob,
            augment_flipr_flag=model_params.flip_aug)

    normalizing_scale_factor = sbir_data['train_set'].calculate_normalizing_scale_factor()
    sbir_data['%s_set' % subset].normalize(normalizing_scale_factor)
    print('normalizing_scale_factor %4.4f.' % normalizing_scale_factor)

    # return_model_params = eval_model_params if inference_mode else model_params
    # return sbir_data['%s_set' % subset], return_model_params
    if not inference_mode:
        return sbir_data['%s_set' % subset], model_params
    else:
        return sbir_data['%s_set' % subset], sample_model_params, gen_model_params
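# A minimal usage sketch (not part of the original code) showing the two return shapes of the
# loader above; the 'datasets/' root directory and the hparams source are hypothetical, while
# the unpacking mirrors the train-mode and inference-mode returns at the end of load_dataset().
# hps = sketch_rnn_model.get_default_hparams()
# train_set, hps = load_dataset('datasets/', 'shoesv2', hps)                                       # training mode
# test_set, sample_hps, gen_hps = load_dataset('datasets/', 'shoesv2', hps, inference_mode=True)   # inference mode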