Пример #1
0
def load_env(data_dir, model_dir):
    """Load the dataset environment for inference (e.g. from a Jupyter notebook).

    The default sketch-rnn hyperparameters are updated with the contents of
    ``<model_dir>/model_config.json`` before the dataset is loaded.

    Args:
      data_dir: Directory handed to ``load_dataset``.
      model_dir: Directory containing ``model_config.json``.

    Returns:
      Whatever ``load_dataset(data_dir, params, inference_mode=True)`` returns.
    """
    hparams = sketch_rnn_model.get_default_hparams()
    config_path = os.path.join(model_dir, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as config_file:
        # NOTE(review): TF's HParams class has no `update` method; this relies on
        # get_default_hparams() returning a dict-like object — confirm.
        hparams.update(json.load(config_file))
    print("data_dir", data_dir)
    return load_dataset(data_dir, hparams, inference_mode=True)
Пример #2
0
def main(unused_argv):
    """Load model params, apply any --hparams overrides and start the trainer.

    Args:
      unused_argv: Positional command-line arguments (ignored).
    """
    # Log message: "importing model parameters".
    tf.logging.info('train-main-导入模型参数=======================================')
    model_params = sketch_rnn_model.get_default_hparams()  # default network hyperparameters
    if FLAGS.hparams:
        # Override the defaults from the --hparams string (parsed by HParams.parse).
        model_params.parse(FLAGS.hparams)
    # Log message: "starting trainer".
    tf.logging.info('train-main-开始trainer')
    trainer(model_params)
Пример #3
0
def main(**kwargs):
    """Render every stroke example of each configured category to SVG and PNG.

    For each ``.npz`` entry in ``model_params.data_set`` and each split
    ('train', 'valid', 'test'), every stroke sequence is drawn to an SVG file
    under ``<data_base_dir>/svg/<category>/<split>/`` and rasterised to the PNG
    path recorded in the loaded dataset (``png_paths``).

    Keyword Args:
      data_base_dir: Root directory holding the ``svg``/``png`` folders; also
        passed to ``load_dataset``.
      render_mode: 'v1' rasterises from the written SVG file (``svg2png_v1``);
        'v2' rasterises from the in-memory SVG bytes (``svg2png_v2``).

    Raises:
      KeyError: if either keyword argument is missing.
      ValueError: if ``render_mode`` is neither 'v1' nor 'v2'. This is checked
        before any directory, dataset or file is touched.
      AssertionError: if a dataset entry is not an ``.npz`` file, or a PNG path
        does not lie under the expected split/size directory.
    """
    data_base_dir = kwargs['data_base_dir']
    render_mode = kwargs['render_mode']
    # Fail fast: otherwise a bad mode is only noticed after work has been done,
    # and never at all when the dataset is empty.
    if render_mode not in ('v1', 'v2'):
        raise ValueError('Error: unknown rendering mode.')

    svg_dir = os.path.join(data_base_dir, 'svg')
    png_dir = os.path.join(data_base_dir, 'png')

    model_params = sketch_rnn_model.get_default_hparams()
    for data_set in model_params.data_set:
        assert data_set[-4:] == '.npz'
        category = data_set[:-4]
        cate_svg_dir = os.path.join(svg_dir, category)
        cate_png_dir = os.path.join(png_dir, category)

        # NOTE(review): these arguments do not depend on the category, so the
        # same data is reloaded on every iteration; consider hoisting this call
        # if load_dataset has no per-call side effects.
        datasets = load_dataset(data_base_dir, model_params)

        for d_i, data_type in enumerate(('train', 'valid', 'test')):
            split_cate_svg_dir = os.path.join(cate_svg_dir, data_type)
            split_cate_png_dir = os.path.join(
                cate_png_dir, data_type,
                str(model_params.img_H) + 'x' + str(model_params.img_W))

            os.makedirs(split_cate_svg_dir, exist_ok=True)
            os.makedirs(split_cate_png_dir, exist_ok=True)

            split_dataset = datasets[d_i]
            png_size = (model_params.img_W, model_params.img_H)

            for ex_idx, raw_stroke in enumerate(split_dataset.strokes):
                stroke = np.copy(raw_stroke)
                print('example_idx', ex_idx, 'stroke.shape', stroke.shape)

                png_path = split_dataset.png_paths[ex_idx]
                assert png_path.startswith(split_cate_png_dir)
                # Strip "<split_cate_png_dir>/" and the last four characters
                # (presumably the ".png" extension) to get the example id.
                actual_idx = png_path[len(split_cate_png_dir) + 1:-4]
                svg_path = os.path.join(split_cate_svg_dir, actual_idx + '.svg')

                svg_size, dwg_bytestring = draw_strokes(stroke,
                                                        svg_path,
                                                        padding=10)  # (w, h)

                if render_mode == 'v1':
                    svg2png_v1(svg_path, svg_size, png_size, png_path,
                               padding=True)
                else:  # 'v2' (validated above)
                    svg2png_v2(dwg_bytestring, svg_size, png_size, png_path,
                               padding=True)
Пример #4
0
def main(unused_argv):
    """Select the GPU given by --gpu, load hyperparameters and start training.

    ``CUDA_VISIBLE_DEVICES`` is set to ``str(FLAGS.gpu)`` before the model
    hyperparameters are built. The defaults are then overridden by ``--hparams``
    (when non-empty) and handed to ``trainer``.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    gpu_id = str(FLAGS.gpu)
    print("USING GPU " + gpu_id)
    # Expose only the requested device to CUDA (not necessarily the first GPU).
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    hparams = sketch_rnn_model.get_default_hparams()
    if FLAGS.hparams:
        hparams.parse(FLAGS.hparams)
    trainer(hparams)
Пример #5
0
def get_srm_hparams(hps):
    """Load a saved sketch-rnn model's hparams and switch them to inference mode.

    The defaults from ``sketch_rnn.get_default_hparams()`` are overridden with
    ``<hps.save_dir>/<hps.srm_name>/model_config.json``. Dropout and training
    mode are then disabled.

    Args:
      hps: Hyperparameters of the calling model; only ``save_dir`` and
        ``srm_name`` are read.

    Returns:
      The loaded hparams object with the inference-mode overrides applied.
    """
    hps_srm = sketch_rnn.get_default_hparams()
    path = os.path.join(hps.save_dir, hps.srm_name, "model_config.json")
    with tf.gfile.Open(path, "r") as input_stream:
        hps_srm.parse_json(input_stream.read())
    # Inference: no dropout of any kind and training mode off (each set once).
    hps_srm.use_input_dropout = 0
    hps_srm.use_recurrent_dropout = 0
    hps_srm.use_output_dropout = 0
    hps_srm.is_training = 0
    return hps_srm
Пример #6
0
def load_env_compatible(data_dir, model_dir):
    """Load the dataset environment for inference, tolerating 0/1 boolean flags.

    Adapted from
    https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    to work with the deprecated tf.HParams functionality. Saved configs may
    store boolean hparams as integers. Those keys are converted to real booleans
    before being parsed into the default hparams.

    Args:
      data_dir: Directory handed to ``load_dataset``.
      model_dir: Directory containing ``model_config.json``.

    Returns:
      Whatever ``load_dataset(data_dir, params, inference_mode=True)`` returns.
    """
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    # Coerce 0/1 (or bool) values to bool; presumably parse_json type-checks them
    # against the boolean defaults. Keys absent from the saved config are
    # skipped so the defaults apply instead of raising KeyError.
    fix_list = ['conditional', 'is_training', 'use_input_dropout',
                'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        if fix in data:
            data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))

    return load_dataset(data_dir, model_params, inference_mode=True)
Пример #7
0
def load_model(model_dir):
    """Build the three hparam sets used for inference (e.g. in a Jupyter notebook).

    Args:
      model_dir: Directory containing ``model_config.json``.

    Returns:
      ``[model_params, eval_model_params, sample_model_params]``:
        model_params -- saved config with ``batch_size`` forced to 1;
        eval_model_params -- a copy with input/output dropout and
          ``is_training`` set to 0;
        sample_model_params -- a copy of the eval params with
          ``max_seq_len`` set to 1.
    """
    base = get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as cfg:
        base.parse_json(cfg.read())
    base.batch_size = 1  # sample a single sketch at a time

    evaluation = copy_hparams(base)
    # NOTE(review): use_recurrent_dropout is not reset here, unlike the other
    # load_model variants — confirm whether this model defines it.
    for flag in ('use_input_dropout', 'use_output_dropout', 'is_training'):
        setattr(evaluation, flag, 0)

    sampling = copy_hparams(evaluation)
    sampling.max_seq_len = 1  # emit one point per step
    return [base, evaluation, sampling]
def main(unused_argv):
    """Load model params, save config file, start trainer, then plot progress.

    After ``trainer`` returns, the module-level ``cost_list`` (top panel) and
    ``time_taken_list`` (bottom panel) are plotted as points. The figure is
    saved to ``cost_time.png``.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    model_params = sketch_rnn_model.get_default_hparams()
    if FLAGS.hparams:
        model_params.parse(FLAGS.hparams)
    trainer(model_params)

    # Plot graphs
    print("Plotting the graphs\n")
    # Use a dedicated figure so nothing already open in pyplot is drawn over,
    # and close it afterwards so it is not leaked.
    fig = plt.figure()
    try:
        for position, series in ((211, cost_list), (212, time_taken_list)):
            # Assumes one entry was recorded every 20 training steps — TODO
            # confirm against trainer. The x-axis is built per series so that
            # lists of different lengths do not cause a plotting error.
            steps = np.arange(0, 20 * len(series), 20)
            fig.add_subplot(position).plot(steps, series, 'bo')
        fig.savefig("cost_time.png")
    finally:
        plt.close(fig)
Пример #9
0
def main(data_base_dir='datasets/QuickDraw'):
    """Render the train and test splits of each QuickDraw category to SVG/PNG.

    For every category in ``model_params.data_set``, the strokes returned by
    ``load_dataset`` are drawn to
    ``<data_base_dir>/<category>/svg/<split>/<idx>.svg``. Each drawing is then
    rasterised with ``svg2png_v2`` to the image path recorded by the dataset
    (``img_paths``) at ``image_size`` x ``image_size``.

    Args:
      data_base_dir: Root directory of the per-category data.
        Defaults to 'datasets/QuickDraw'.
    """
    model_params = sketch_p2s_model.get_default_hparams()
    for data_set in model_params.data_set:
        sub_data_base_dir = os.path.join(data_base_dir, data_set)
        cate_svg_dir = os.path.join(sub_data_base_dir, 'svg')
        cate_png_dir = os.path.join(sub_data_base_dir, 'png')

        datasets = load_dataset(data_base_dir, cate_png_dir, model_params)

        for d_i, data_split in enumerate(('train', 'valid', 'test')):
            if data_split == 'valid':
                # The validation split is not rendered (reason not visible
                # here — confirm it is not needed downstream).
                continue

            image_size = model_params.image_size
            split_cate_svg_dir = os.path.join(cate_svg_dir, data_split)
            split_cate_png_dir = os.path.join(
                cate_png_dir, data_split,
                str(image_size) + 'x' + str(image_size))
            os.makedirs(split_cate_svg_dir, exist_ok=True)
            os.makedirs(split_cate_png_dir, exist_ok=True)

            split_dataset = datasets[d_i]

            for ex_idx, raw_stroke in enumerate(split_dataset.strokes):
                stroke = np.copy(raw_stroke)
                print('example_idx', ex_idx, 'stroke.shape', stroke.shape)

                img_path = split_dataset.img_paths[ex_idx]
                assert img_path.startswith(split_cate_png_dir)
                # Strip "<split_cate_png_dir>/" and the last four characters
                # (presumably the image extension) to get the example id.
                actual_idx = img_path[len(split_cate_png_dir) + 1:-4]
                svg_path = os.path.join(split_cate_svg_dir, actual_idx + '.svg')

                svg_size, dwg_bytestring = draw_strokes(
                    stroke, svg_path, padding=10, make_png=False)  # (w, h)

                svg2png_v2(dwg_bytestring,
                           svg_size,
                           pngsize=(image_size, image_size),
                           png_filename=img_path,
                           padding=True)
Пример #10
0
def load_model(model_dir):
    """Build the three sketch-rnn hparam sets used for inference in a notebook.

    Args:
      model_dir: Directory containing ``model_config.json``.

    Returns:
      ``[model_params, eval_model_params, sample_model_params]``:
        model_params -- saved config with ``batch_size`` forced to 1;
        eval_model_params -- a copy with all dropout flags and
          ``is_training`` set to 0;
        sample_model_params -- a copy of the eval params with
          ``max_seq_len`` set to 1.
    """
    hparams = sketch_rnn_model.get_default_hparams()
    config_path = os.path.join(model_dir, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as config_file:
        hparams.parse_json(config_file.read())
    hparams.batch_size = 1  # one sketch per batch when sampling

    eval_hparams = sketch_rnn_model.copy_hparams(hparams)
    # Evaluation: disable every dropout variant and training mode.
    for name in ('use_input_dropout', 'use_recurrent_dropout',
                 'use_output_dropout', 'is_training'):
        setattr(eval_hparams, name, 0)

    sample_hparams = sketch_rnn_model.copy_hparams(eval_hparams)
    sample_hparams.max_seq_len = 1  # generate one point per step
    return [hparams, eval_hparams, sample_hparams]
Пример #11
0
def load_model(model_dir):
    """Load inference hparams plus pretrained weights (used in a Jupyter notebook).

    Args:
      model_dir: Directory containing ``model_config.json`` and the NumPy
        weights file ``model``.

    Returns:
      ``[model_params, eval_model_params, sample_model_params,
      pretrained_model_params]``, where the last item is the result of
      ``np.load`` on ``<model_dir>/model``.
    """
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        model_params.parse_json(f.read())

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time

    weights_path = os.path.join(model_dir, 'model')
    # encoding='latin1' only matters for pickled Python 2 data, so the weights
    # file presumably contains pickled objects. NumPy >= 1.16.3 refuses to
    # unpickle unless allow_pickle=True is passed explicitly.
    # SECURITY: unpickling can execute arbitrary code — only load trusted
    # model directories.
    if six.PY3:
        pretrained_model_params = np.load(weights_path, encoding='latin1',
                                          allow_pickle=True)
    else:
        pretrained_model_params = np.load(weights_path, allow_pickle=True)
    return [model_params, eval_model_params, sample_model_params,
            pretrained_model_params]
Пример #12
0
def load_model_compatible(model_dir):
    """Build inference hparam sets from a config that may store booleans as 0/1.

    Adapted from
    https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    to work with the deprecated tf.HParams functionality.

    Args:
      model_dir: Directory containing ``model_config.json``.

    Returns:
      ``[model_params, eval_model_params, sample_model_params]``:
        model_params -- saved config with ``batch_size`` forced to 1;
        eval_model_params -- a copy with all dropout flags and
          ``is_training`` set to 0;
        sample_model_params -- a copy of the eval params with
          ``max_seq_len`` set to 1.
    """
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    # Coerce 0/1 (or bool) values to bool; presumably parse_json type-checks them
    # against the boolean defaults. Keys absent from the saved config are
    # skipped so the defaults apply instead of raising KeyError.
    fix_list = ['conditional', 'is_training', 'use_input_dropout',
                'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        if fix in data:
            data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
Пример #13
0
def main(unused_argv):
    """Entry point: build hparams from defaults plus --hparams, then train.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    hparams = sketch_rnn_model.get_default_hparams()
    overrides = FLAGS.hparams
    if overrides:
        hparams.parse(overrides)
    trainer(hparams)
Пример #14
0
def main(unused_argv):
    """Load the default hyperparameters, apply --hparams overrides, then train.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    params = get_default_hparams()  # start from the default hyperparameters
    hparam_overrides = FLAGS.hparams
    if hparam_overrides:
        # Override selected defaults from the --hparams string.
        params.parse(hparam_overrides)
    trainer(params)
Пример #15
0
def load_env(data_dir, model_dir):
    """Load the dataset for inference (e.g. from a Jupyter notebook).

    Default hyperparameters are overridden by ``<model_dir>/model_config.json``.

    Args:
      data_dir: Directory handed to ``load_dataset``.
      model_dir: Directory containing ``model_config.json``.

    Returns:
      Whatever ``load_dataset(data_dir, params, testing_mode=True)`` returns.
    """
    params = get_default_hparams()
    config_path = os.path.join(model_dir, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as config_file:
        config_json = config_file.read()
    params.parse_json(config_json)
    return load_dataset(data_dir, params, testing_mode=True)
Пример #16
0
def main(unused_argv):
    """Train with the model's default hyperparameters; argv is ignored."""
    trainer(Model.get_default_hparams())