def _hparams_from_flags():
    """Build a Hyperparameters object from the values currently in FLAGS.

    Each whitespace-separated name listed below is looked up as an attribute
    on FLAGS and forwarded as a keyword argument of the same name.
    """
    flag_names = (
        " dataset quantization_level num_instruments separate_instruments"
        " crop_piece_len architecture num_layers num_filters use_residual"
        " batch_size maskout_method mask_indicates_context optimize_mask_only"
        " rescale_loss patience corrupt_ratio eval_freq run_id "
    ).split()
    flag_values = {name: getattr(FLAGS, name) for name in flag_names}
    return lib_hparams.Hyperparameters(**flag_values)
def _hparams_from_flags():
    """Build a Hyperparameters object from the values currently in FLAGS.

    Each whitespace-separated name listed below is looked up as an attribute
    on FLAGS and forwarded as a keyword argument of the same name.
    """
    flag_names = (
        " dataset quantization_level num_instruments separate_instruments"
        " crop_piece_len architecture use_sep_conv"
        " num_initial_regular_conv_layers sep_conv_depth_multiplier"
        " num_dilation_blocks dilate_time_only num_layers num_filters"
        " use_residual batch_size maskout_method mask_indicates_context"
        " optimize_mask_only rescale_loss patience corrupt_ratio eval_freq"
        " run_id "
    ).split()
    flag_values = {name: getattr(FLAGS, name) for name in flag_names}
    return lib_hparams.Hyperparameters(**flag_values)
def save_checkpoint(self):
    """Write a freshly initialized model checkpoint to a new temp directory.

    A Hyperparameters object is constructed with no keyword overrides and
    dumped to a file named 'config' in the directory. A training graph is
    then built from those hparams, its variables are initialized, and they
    are saved under the 'model.ckpt' prefix.

    Returns:
      Path of the temporary directory holding the config file and the
      checkpoint. This method does not delete the directory.
    """
    logdir = tempfile.mkdtemp()
    save_path = os.path.join(logdir, 'model.ckpt')

    hparams = lib_hparams.Hyperparameters()
    tf.gfile.MakeDirs(logdir)
    config_fpath = os.path.join(logdir, 'config')
    with tf.gfile.Open(config_fpath, 'w') as p:
        hparams.dump(p)

    with tf.Graph().as_default():
        lib_graph.build_graph(is_training=True, hparams=hparams)
        # Use the session as a context manager so it is closed once the
        # checkpoint has been written instead of being left open.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.save(sess, save_path)

    return logdir