def load_env(data_dir, model_dir):
  """Loads environment for inference mode, used in jupyter notebook."""
  model_params = sketch_rnn_model.get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    model_config = json.load(f)
  model_params.update(model_config)
  print("data_dir", data_dir)
  return load_dataset(data_dir, model_params, inference_mode=True)
def main(unused_argv):
  """Load model params, save config file and start trainer."""
  tf.logging.info('train-main: loading model params =======================================')
  model_params = sketch_rnn_model.get_default_hparams()  # get the default network hyperparameters
  if FLAGS.hparams:
    model_params.parse(FLAGS.hparams)  # parse overrides given as key=value pairs
  tf.logging.info('train-main: starting trainer')
  trainer(model_params)
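# For reference, a usage sketch (the flag format follows the magenta
# sketch-rnn README; the paths are hypothetical): FLAGS.hparams is a
# comma-separated key=value string consumed by HParams.parse, e.g.
#
#   python sketch_rnn_train.py --log_root=/tmp/sketch_rnn/models \
#       --data_dir=/tmp/sketch_rnn/datasets \
#       --hparams="data_set=[cat.npz],num_steps=10000"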
def main(**kwargs):
  data_base_dir = kwargs['data_base_dir']
  render_mode = kwargs['render_mode']
  npz_dir = os.path.join(data_base_dir, 'npz')
  svg_dir = os.path.join(data_base_dir, 'svg')
  png_dir = os.path.join(data_base_dir, 'png')

  model_params = sketch_rnn_model.get_default_hparams()
  for dataset_i in range(len(model_params.data_set)):
    assert model_params.data_set[dataset_i][-4:] == '.npz'
    cate_svg_dir = os.path.join(svg_dir, model_params.data_set[dataset_i][:-4])
    cate_png_dir = os.path.join(png_dir, model_params.data_set[dataset_i][:-4])

    datasets = load_dataset(data_base_dir, model_params)

    data_types = ['train', 'valid', 'test']
    for d_i, data_type in enumerate(data_types):
      split_cate_svg_dir = os.path.join(cate_svg_dir, data_type)
      split_cate_png_dir = os.path.join(
          cate_png_dir, data_type,
          str(model_params.img_H) + 'x' + str(model_params.img_W))
      os.makedirs(split_cate_svg_dir, exist_ok=True)
      os.makedirs(split_cate_png_dir, exist_ok=True)

      split_dataset = datasets[d_i]
      for ex_idx in range(len(split_dataset.strokes)):
        stroke = np.copy(split_dataset.strokes[ex_idx])
        print('example_idx', ex_idx, 'stroke.shape', stroke.shape)
        png_path = split_dataset.png_paths[ex_idx]
        assert split_cate_png_dir == png_path[:len(split_cate_png_dir)]
        actual_idx = png_path[len(split_cate_png_dir) + 1:-4]
        svg_path = os.path.join(split_cate_svg_dir, str(actual_idx) + '.svg')

        svg_size, dwg_bytestring = draw_strokes(stroke, svg_path, padding=10)  # (w, h)
        if render_mode == 'v1':
          svg2png_v1(svg_path, svg_size,
                     (model_params.img_W, model_params.img_H),
                     png_path, padding=True)
        elif render_mode == 'v2':
          svg2png_v2(dwg_bytestring, svg_size,
                     (model_params.img_W, model_params.img_H),
                     png_path, padding=True)
        else:
          raise Exception('Error: unknown rendering mode.')
def main(unused_argv): """Load model params, save config file and start trainer.""" print("USING GPU " + str(FLAGS.gpu)) os.environ["CUDA_VISIBLE_DEVICES"] = str( FLAGS.gpu) # make TF only use the first GPU. model_params = sketch_rnn_model.get_default_hparams() if FLAGS.hparams: model_params.parse(FLAGS.hparams) trainer(model_params)
def get_srm_hparams(hps):
  hps_srm = sketch_rnn.get_default_hparams()
  path = os.path.join(hps.save_dir, hps.srm_name, "model_config.json")
  with tf.gfile.Open(path, "r") as input_stream:
    hps_srm.parse_json(input_stream.read())
  # Disable training behavior and all dropout for inference.
  hps_srm.is_training = False
  hps_srm.use_input_dropout = 0
  hps_srm.use_recurrent_dropout = 0
  hps_srm.use_output_dropout = 0
  return hps_srm
def load_env_compatible(data_dir, model_dir):
  """Loads environment for inference mode, used in jupyter notebook."""
  # Modified from https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
  # to work with deprecated tf.HParams functionality.
  model_params = sketch_rnn_model.get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    data = json.load(f)
  fix_list = ['conditional', 'is_training', 'use_input_dropout',
              'use_output_dropout', 'use_recurrent_dropout']
  for fix in fix_list:
    data[fix] = (data[fix] == 1)  # convert 0/1 ints back to booleans
  model_params.parse_json(json.dumps(data))
  return load_dataset(data_dir, model_params, inference_mode=True)
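# A minimal usage sketch for load_env_compatible, assuming this fork's
# load_dataset returns the magenta-style list of train/valid/test sets plus
# the three HParams objects; the paths below are hypothetical:
[train_set, valid_set, test_set, hps_model, eval_hps_model,
 sample_hps_model] = load_env_compatible(
     '/tmp/sketch_rnn/datasets', '/tmp/sketch_rnn/models/aaron_sheep/layer_norm')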
def load_model(model_dir):
  """Loads model for inference mode, used in jupyter notebook."""
  model_params = get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    model_params.parse_json(f.read())

  model_params.batch_size = 1  # only sample one at a time
  eval_model_params = copy_hparams(model_params)
  eval_model_params.use_input_dropout = 0
  eval_model_params.use_recurrent_dropout = 0  # also off at eval time, as in the magenta original
  eval_model_params.use_output_dropout = 0
  eval_model_params.is_training = 0

  sample_model_params = copy_hparams(eval_model_params)
  sample_model_params.max_seq_len = 1  # sample one point at a time
  return [model_params, eval_model_params, sample_model_params]
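# A minimal usage sketch, assuming the magenta-style sketch-rnn Model class
# that builds a graph from an HParams object (reuse=True shares variables
# between the eval and sample graphs); the path is hypothetical:
hps_model, eval_hps_model, sample_hps_model = load_model(
    '/tmp/sketch_rnn/models/aaron_sheep/layer_norm')
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)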
def main(unused_argv): """Load model params, save config file and start trainer.""" model_params = sketch_rnn_model.get_default_hparams() if FLAGS.hparams: model_params.parse(FLAGS.hparams) trainer(model_params) # Plot graphs print("Plotting the graphs\n") t = np.arange(0, 20 * len(cost_list), 20) plt.subplot(211) plt.plot(t, cost_list, 'bo') plt.subplot(212) plt.plot(t, time_taken_list, 'bo') plt.savefig("cost_time.png")
def main():
  data_base_dir = 'datasets/QuickDraw'

  model_params = sketch_p2s_model.get_default_hparams()
  for dataset_i in range(len(model_params.data_set)):
    data_set = model_params.data_set[dataset_i]
    sub_data_base_dir = os.path.join(data_base_dir, data_set)
    cate_npz_dir = os.path.join(sub_data_base_dir, 'npz')
    cate_svg_dir = os.path.join(sub_data_base_dir, 'svg')
    cate_png_dir = os.path.join(sub_data_base_dir, 'png')

    datasets = load_dataset(data_base_dir, cate_png_dir, model_params)

    data_splits = ['train', 'valid', 'test']
    for d_i, data_split in enumerate(data_splits):
      if data_split == 'valid':
        continue
      split_cate_svg_dir = os.path.join(cate_svg_dir, data_split)
      split_cate_png_dir = os.path.join(
          cate_png_dir, data_split,
          str(model_params.image_size) + 'x' + str(model_params.image_size))
      os.makedirs(split_cate_svg_dir, exist_ok=True)
      os.makedirs(split_cate_png_dir, exist_ok=True)

      split_dataset = datasets[d_i]
      for ex_idx in range(len(split_dataset.strokes)):
        stroke = np.copy(split_dataset.strokes[ex_idx])
        print('example_idx', ex_idx, 'stroke.shape', stroke.shape)
        img_path = split_dataset.img_paths[ex_idx]
        assert split_cate_png_dir == img_path[:len(split_cate_png_dir)]
        actual_idx = img_path[len(split_cate_png_dir) + 1:-4]
        svg_path = os.path.join(split_cate_svg_dir, str(actual_idx) + '.svg')

        svg_size, dwg_bytestring = draw_strokes(
            stroke, svg_path, padding=10, make_png=False)  # (w, h)
        # svg2png_v1(svg_path, svg_size,
        #            (model_params.image_size, model_params.image_size),
        #            png_path, padding=True)
        svg2png_v2(dwg_bytestring, svg_size,
                   pngsize=(model_params.image_size, model_params.image_size),
                   png_filename=img_path, padding=True)
def load_model(model_dir):
  """Loads model for inference mode, used in jupyter notebook."""
  model_params = sketch_rnn_model.get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    model_params.parse_json(f.read())

  model_params.batch_size = 1  # only sample one at a time
  eval_model_params = sketch_rnn_model.copy_hparams(model_params)
  eval_model_params.use_input_dropout = 0
  eval_model_params.use_recurrent_dropout = 0
  eval_model_params.use_output_dropout = 0
  eval_model_params.is_training = 0

  sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
  sample_model_params.max_seq_len = 1  # sample one point at a time
  return [model_params, eval_model_params, sample_model_params]
def load_model(model_dir):
  """Loads model for inference mode, used in jupyter notebook."""
  model_params = sketch_rnn_model.get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    model_params.parse_json(f.read())

  model_params.batch_size = 1  # only sample one at a time
  eval_model_params = sketch_rnn_model.copy_hparams(model_params)
  eval_model_params.use_input_dropout = 0
  eval_model_params.use_recurrent_dropout = 0
  eval_model_params.use_output_dropout = 0
  eval_model_params.is_training = 0

  sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
  sample_model_params.max_seq_len = 1  # sample one point at a time

  if six.PY3:
    pretrained_model_params = np.load(model_dir + '/model', encoding='latin1')
  else:
    pretrained_model_params = np.load(model_dir + '/model')
  return [model_params, eval_model_params, sample_model_params,
          pretrained_model_params]
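# Note, an assumption about the caller's environment: the weight file is a
# pickled object array, so with numpy >= 1.16.3 (where allow_pickle defaults
# to False) the load above additionally needs allow_pickle=True:
#
#   pretrained_model_params = np.load(
#       model_dir + '/model', encoding='latin1', allow_pickle=True)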
def load_model_compatible(model_dir):
  """Loads model for inference mode, used in jupyter notebook."""
  # Modified from https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
  # to work with deprecated tf.HParams functionality.
  model_params = sketch_rnn_model.get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    data = json.load(f)
  fix_list = ['conditional', 'is_training', 'use_input_dropout',
              'use_output_dropout', 'use_recurrent_dropout']
  for fix in fix_list:
    data[fix] = (data[fix] == 1)  # convert 0/1 ints back to booleans
  model_params.parse_json(json.dumps(data))

  model_params.batch_size = 1  # only sample one at a time
  eval_model_params = sketch_rnn_model.copy_hparams(model_params)
  eval_model_params.use_input_dropout = 0
  eval_model_params.use_recurrent_dropout = 0
  eval_model_params.use_output_dropout = 0
  eval_model_params.is_training = 0

  sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
  sample_model_params.max_seq_len = 1  # sample one point at a time
  return [model_params, eval_model_params, sample_model_params]
def main(unused_argv): """Load model params, save config file and start trainer.""" model_params = sketch_rnn_model.get_default_hparams() if FLAGS.hparams: model_params.parse(FLAGS.hparams) trainer(model_params)
def main(unused_argv): """Load model params, save config file and start training.""" model_params = get_default_hparams() # load defualt params if FLAGS.hparams: model_params.parse(FLAGS.hparams) # reload params trainer(model_params)
def load_env(data_dir, model_dir):
  """Loads environment for inference mode, used in jupyter notebook."""
  model_params = get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    model_params.parse_json(f.read())
  return load_dataset(data_dir, model_params, testing_mode=True)
def main(unused_argv):
  model_params = Model.get_default_hparams()
  trainer(model_params)