def __init__(
    self, reward_range, observation_space, action_space, frame_stack_size,
    initial_frame_chooser, batch_size, model_name, model_hparams, model_dir,
    intrinsic_reward_scale=0.0
):
  """Batch of environments inside the TensorFlow graph.

  Instantiates the world model registered under `model_name` in PREDICT mode
  and the graph state (history buffer, observation variable, reset flag) for
  `batch_size` simulated environments.

  Args:
    reward_range: sequence whose first element is stored as the minimum
      reward; also passed to DummyWorldModelProblem.
    observation_space: forwarded to the base class.
    action_space: forwarded to the base class and DummyWorldModelProblem.
    frame_stack_size: number of frames handed to HistoryBuffer.
    initial_frame_chooser: passed to HistoryBuffer (presumably supplies the
      frames used on reset — confirm in HistoryBuffer).
    batch_size: number of environments in the batch.
    model_name: registry name of the world model.
    model_hparams: world-model hparams; copied before being modified here.
    model_dir: stored as `self._model_dir` (presumably the checkpoint
      directory — confirm against callers).
    intrinsic_reward_scale: stored as `self._intrinsic_reward_scale`.
  """
  super(SimulatedBatchEnv, self).__init__(observation_space, action_space)
  self.batch_size = batch_size
  self._min_reward = reward_range[0]
  self._num_frames = frame_stack_size
  self._intrinsic_reward_scale = intrinsic_reward_scale
  # Copy so the problem hparams added below land on this object, not on the
  # caller's hparams.
  model_hparams = copy.copy(model_hparams)
  problem = DummyWorldModelProblem(action_space, reward_range)
  trainer_lib.add_problem_hparams(model_hparams, problem)
  model_hparams.force_full_predict = True
  self._model = registry.model(model_name)(
      model_hparams, tf.estimator.ModeKeys.PREDICT
  )
  self.history_buffer = HistoryBuffer(
      initial_frame_chooser, self.observ_shape, self.observ_dtype,
      self._num_frames, self.batch_size
  )
  # Non-trainable holder for the current observation of every environment.
  self._observ = tf.Variable(
      tf.zeros((batch_size,) + self.observ_shape, self.observ_dtype),
      trainable=False
  )
  # Scalar flag variable, initialised to zero.
  self._reset_model = tf.get_variable(
      "reset_model", [], trainable=False,
      initializer=tf.zeros_initializer())
  self._model_dir = model_dir
def __init__(self, environment_spec, length):
  """Batch of environments inside the TensorFlow graph.

  Args:
    environment_spec: object read for observation_space, action_space,
      reward_range, video_num_input_frames, intrinsic_reward_scale,
      model_hparams, model_name and initial_frame_chooser.
    length: number of environments in the batch (presumably what `len(self)`
      returns, since it sizes the observation variable — confirm).
  """
  super(SimulatedBatchEnv, self).__init__(environment_spec.observation_space,
                                          environment_spec.action_space)
  self.length = length
  self._min_reward = environment_spec.reward_range[0]
  self._num_frames = environment_spec.video_num_input_frames
  self._intrinsic_reward_scale = environment_spec.intrinsic_reward_scale

  # Copy so the problem hparams added below do not touch the spec's hparams.
  model_hparams = copy.copy(environment_spec.model_hparams)
  problem = DummyWorldModelProblem(environment_spec.action_space,
                                   environment_spec.reward_range)
  trainer_lib.add_problem_hparams(model_hparams, problem)
  model_hparams.force_full_predict = True
  self._model = registry.model(environment_spec.model_name)(
      model_hparams, tf.estimator.ModeKeys.PREDICT)

  self.history_buffer = HistoryBuffer(
      environment_spec.initial_frame_chooser, self.observ_shape,
      self.observ_dtype, self._num_frames, self.length)

  # Non-trainable holder for the current observation of every environment.
  self._observ = tf.Variable(tf.zeros((len(self), ) + self.observ_shape,
                                      self.observ_dtype), trainable=False)
def get_policy(observations, hparams, action_space):
  """Build the policy network on `observations`.

  The model named by `hparams.policy_network` is instantiated in TRAIN mode
  and fed dummy zero tensors for every feature except `inputs`.

  Args:
    observations: observation tensor; its first two dimensions are used as
      the (batch, input-frames) prefix of the dummy input features.
    hparams: model hparams; mutated in place (problem hparams are added and
      `force_full_predict` is set).
    action_space: must be a `gym.spaces.Discrete`.

  Returns:
    Tuple (action logits, value).

  Raises:
    ValueError: if `action_space` is not discrete.
  """
  is_discrete = isinstance(action_space, gym.spaces.Discrete)
  if not is_discrete:
    raise ValueError("Expecting discrete action space.")

  trainer_lib.add_problem_hparams(hparams, DummyPolicyProblem(action_space))
  hparams.force_full_predict = True
  model_cls = registry.model(hparams.policy_network)
  model = model_cls(hparams, tf.estimator.ModeKeys.TRAIN)

  shape = common_layers.shape_list(observations)
  batch_and_frames = shape[:2]
  # Targets describe a single step: [batch, 1, ...].
  single_step = shape[:1] + [1]
  features = dict(
      inputs=observations,
      input_action=tf.zeros(batch_and_frames + [1], dtype=tf.int32),
      input_reward=tf.zeros(batch_and_frames + [1], dtype=tf.int32),
      targets=tf.zeros(single_step + shape[2:]),
      target_action=tf.zeros(single_step + [1], dtype=tf.int32),
      target_reward=tf.zeros(single_step + [1], dtype=tf.int32),
      target_policy=tf.zeros(single_step + [action_space.n]),
      target_value=tf.zeros(single_step),
  )

  current_scope = tf.get_variable_scope()
  with tf.variable_scope(current_scope, reuse=tf.AUTO_REUSE):
    t2t_model.create_dummy_vars()
    outputs, _ = model(features)
  return outputs["target_policy"], outputs["target_value"]
def run():
  """Validate the flags, build the Transformer estimator and start sampling.

  :raises: ValueError: if a required flag is missing or out of range, or if
    the encoded primer is at least as long as --decode_length.
  """
  if FLAGS.model_path is None:
    raise ValueError('Required Transformer pre-trained model path.')
  if FLAGS.output_dir is None:
    raise ValueError('Required Midi output directory.')
  if FLAGS.decode_length <= 0:
    raise ValueError('Decode length must be > 0.')

  problem = utils.PianoPerformanceLanguageModelProblem()
  encoders = problem.get_feature_encoders()

  if FLAGS.primer_path is None:
    primer_ns = music_pb2.NoteSequence()
    primer_events = []
  else:
    if FLAGS.max_primer_second <= 0:
      raise ValueError('Max primer second must be > 0.')
    primer_ns = utils.get_primer_ns(FLAGS.primer_path,
                                    FLAGS.max_primer_second)
    # Drop the trailing end token from the encoded primer.
    primer_events = encoders['targets'].encode_note_sequence(primer_ns)[:-1]
    if len(primer_events) >= FLAGS.decode_length:
      raise ValueError(
          'Primer has more or equal events than maximum sequence length:'
          ' %d >= %d; Aborting' % (len(primer_events), FLAGS.decode_length))
  remaining_length = FLAGS.decode_length - len(primer_events)

  # Model hyperparameters.
  hparams = trainer_lib.create_hparams(hparams_set=FLAGS.hparams_set)
  trainer_lib.add_problem_hparams(hparams, problem)
  hparams.num_hidden_layers = FLAGS.layers
  hparams.sampling_method = FLAGS.sample

  # Decoding hyperparameters.
  decode_hparams = decoding.decode_hparams()
  decode_hparams.alpha = FLAGS.alpha
  decode_hparams.beam_size = FLAGS.beam_size

  utils.LOGGER.info('Loading model')
  estimator = trainer_lib.create_estimator(
      FLAGS.model_name, hparams, trainer_lib.create_run_config(hparams),
      decode_hparams=decode_hparams)
  generate(estimator, encoders, remaining_length, primer_events, primer_ns)
def get_policy(observations, hparams, action_space):
  """Build the policy network and return its first-step outputs.

  Args:
    observations: observation tensor; dimensions 2 and 3 are read as the
      frame height and width.
    hparams: model hparams; mutated in place (problem hparams are added and
      `force_full_predict` is set). `video_num_target_frames` is optional and
      defaults to 1.
    action_space: must be a `gym.spaces.Discrete`.

  Returns:
    Tuple (action logits, value), both sliced to the first target frame.

  Raises:
    ValueError: if `action_space` is not discrete.
  """
  if not isinstance(action_space, gym.spaces.Discrete):
    raise ValueError("Expecting discrete action space.")

  obs_shape = common_layers.shape_list(observations)
  frame_height, frame_width = obs_shape[2:4]

  # TODO(afrozm): We have these dummy problems mainly for hparams, so cleanup
  # when possible and do this properly.
  use_ttt_problem = hparams.policy_problem_name == "dummy_policy_problem_ttt"
  if use_ttt_problem:
    tf.logging.info("Using DummyPolicyProblemTTT for the policy.")
    policy_problem = tic_tac_toe_env.DummyPolicyProblemTTT()
  else:
    tf.logging.info("Using DummyPolicyProblem for the policy.")
    policy_problem = DummyPolicyProblem(action_space, frame_height,
                                        frame_width)

  trainer_lib.add_problem_hparams(hparams, policy_problem)
  hparams.force_full_predict = True
  model_cls = registry.model(hparams.policy_network)
  model = model_cls(hparams, tf.estimator.ModeKeys.TRAIN)

  # Hparams without this attribute fall back to a single target frame.
  num_target_frames = getattr(hparams, "video_num_target_frames", 1)

  batch_and_frames = obs_shape[:2]
  target_prefix = obs_shape[:1] + [num_target_frames]
  features = dict(
      inputs=observations,
      input_action=tf.zeros(batch_and_frames + [1], dtype=tf.int32),
      input_reward=tf.zeros(batch_and_frames + [1], dtype=tf.int32),
      targets=tf.zeros(target_prefix + obs_shape[2:]),
      target_action=tf.zeros(target_prefix + [1], dtype=tf.int32),
      target_reward=tf.zeros(target_prefix + [1], dtype=tf.int32),
      target_policy=tf.zeros(target_prefix + [action_space.n]),
      target_value=tf.zeros(target_prefix),
  )

  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    t2t_model.create_dummy_vars()
    outputs, _ = model(features)
  first_policy = outputs["target_policy"][:, 0, :]
  first_value = outputs["target_value"][:, 0]
  return first_policy, first_value
def main(argv):
  """Restore a trained model and prune it, scoring each step by accuracy.

  Args:
    argv: command-line arguments; when non-empty, argv[1:] is passed to
      t2t_trainer.set_hparams_from_args as hparams overrides.

  Raises:
    ValueError: from the evaluation callback if --eval_steps is not positive
      (previously this surfaced as an UnboundLocalError).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # The registered problem is used as-is here (no "_rev" variant).
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  # repeat() makes the input endless so repeated evaluations never exhaust it.
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # Build the accuracy subgraph once. Building it inside eval_model() added a
  # new argmax/metric subgraph (with fresh metric variables) to the graph on
  # every call, so the graph grew with each pruning evaluation.
  preds = spec.predictions["predictions"]
  preds = tf.argmax(preds, -1, output_type=labels.dtype)
  _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)
  # Zeroes the streaming-accuracy counters (local variables) before each eval.
  reset_metrics_op = tf.local_variables_initializer()

  def eval_model():
    """Return the streaming accuracy over FLAGS.eval_steps batches."""
    if FLAGS.eval_steps <= 0:
      raise ValueError("--eval_steps must be positive, got %r" %
                       FLAGS.eval_steps)
    sess.run(reset_metrics_op)
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
def main(argv):
  """Restore a trained model and prune it, scoring each step by accuracy.

  Args:
    argv: command-line arguments; when non-empty, argv[1:] is passed to
      t2t_trainer.set_hparams_from_args as hparams overrides.

  Raises:
    ValueError: from the evaluation callback if --eval_steps is not positive
      (previously this surfaced as an UnboundLocalError).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # The registered problem is used as-is here (no "_rev" variant).
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  # repeat() makes the input endless so repeated evaluations never exhaust it.
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # Build the accuracy subgraph once. Building it inside eval_model() added a
  # new argmax/metric subgraph (with fresh metric variables) to the graph on
  # every call, so the graph grew with each pruning evaluation.
  preds = spec.predictions["predictions"]
  preds = tf.argmax(preds, -1, output_type=labels.dtype)
  _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)
  # Zeroes the streaming-accuracy counters (local variables) before each eval.
  reset_metrics_op = tf.local_variables_initializer()

  def eval_model():
    """Return the streaming accuracy over FLAGS.eval_steps batches."""
    if FLAGS.eval_steps <= 0:
      raise ValueError("--eval_steps must be positive, got %r" %
                       FLAGS.eval_steps)
    sess.run(reset_metrics_op)
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
def main(argv):
  """Train a T2T model under EPL data-parallel replication.

  Only the "train" schedule is supported; any other schedule raises
  RuntimeError.

  Args:
    argv: command-line arguments; when non-empty, argv[1:] is passed to
      set_hparams_from_args.

  Raises:
    RuntimeError: if FLAGS.schedule is not "train".
  """
  config = epl.Config({"cluster.colocate_split_and_replicate": True})
  epl.init(config)
  # Worker index and GPU count come from the EPL cluster, overriding flags.
  FLAGS.worker_id = epl.Env.get().cluster.worker_index
  FLAGS.worker_gpu = epl.Env.get().cluster.total_gpu_num
  epl.set_default_strategy(epl.replicate(FLAGS.worker_gpu))

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  if FLAGS.schedule != "run_std_server":
    hparams = create_hparams()
  if FLAGS.schedule == "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START)
  else:
    raise RuntimeError(
        "Support training tasks only for now, you can define tasks in other modes."
    )

  trainer_lib.set_random_seed(FLAGS.random_seed)
  hparams.add_hparam("data_dir", FLAGS.data_dir)
  hparams.add_hparam("schedule", FLAGS.schedule)
  hparams.add_hparam("train_steps", FLAGS.train_steps)
  hparams.add_hparam("warm_start_from", None)
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  # Dataset generation.
  if FLAGS.generate_data:
    generate_data()

  def model_fn_replicate(features, labels, mode):
    # Builds the T2T estimator model_fn on each call and delegates to it.
    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams)
    return model_fn(features, labels, mode)

  if is_chief():
    save_metadata(hparams)

  estimator = tf.estimator.Estimator(model_fn=model_fn_replicate,
                                     config=create_run_config())
  hooks = []
  hooks.append(
      tf.train.StepCounterHook(every_n_steps=FLAGS.log_step_count_steps))

  optimize.log_variable_sizes(verbose=True)

  # add_problem_hparams above set hparams.problem to the Problem instance.
  problem = hparams.problem
  train_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN, hparams)

  estimator.train(train_input_fn, max_steps=hparams.train_steps, hooks=hooks)
def __init__(
    self, reward_range, observation_space, action_space, frame_stack_size,
    frame_height, frame_width, initial_frame_chooser, batch_size, model_name,
    model_hparams, model_dir, intrinsic_reward_scale=0.0, sim_video_dir=None
):
  """Batch of environments inside the TensorFlow graph.

  Besides the world model and graph state, this sets up optional video
  recording of simulated episodes when `sim_video_dir` is given.

  Args:
    reward_range: sequence whose first element is stored as the minimum
      reward; also passed to DummyWorldModelProblem.
    observation_space: forwarded to the base class.
    action_space: forwarded to the base class and DummyWorldModelProblem.
    frame_stack_size: number of frames handed to HistoryBuffer.
    frame_height: frame height passed to DummyWorldModelProblem.
    frame_width: frame width passed to DummyWorldModelProblem.
    initial_frame_chooser: passed to HistoryBuffer.
    batch_size: number of environments in the batch.
    model_name: registry name of the world model.
    model_hparams: world-model hparams; copied before being modified here.
    model_dir: stored as `self._model_dir`.
    intrinsic_reward_scale: stored as `self._intrinsic_reward_scale`.
    sim_video_dir: if truthy, directory created for recorded videos; if
      falsy, `self._video_condition` is the constant False.
  """
  super(SimulatedBatchEnv, self).__init__(observation_space, action_space)

  self._ffmpeg_works = common_video.ffmpeg_works()
  self.batch_size = batch_size
  self._min_reward = reward_range[0]
  self._num_frames = frame_stack_size
  self._intrinsic_reward_scale = intrinsic_reward_scale
  # Scalar int32 counter variable, starting at 0.
  self._episode_counter = tf.get_variable(
      "episode_counter", initializer=tf.zeros((), dtype=tf.int32),
      trainable=False, dtype=tf.int32)
  if sim_video_dir:
    self._video_every_epochs = 100
    self._video_dir = sim_video_dir
    self._video_writer = None
    self._video_counter = 0
    tf.gfile.MakeDirs(self._video_dir)
    # True whenever the counter is a multiple of _video_every_epochs (100).
    self._video_condition = tf.equal(
        self._episode_counter.read_value() % self._video_every_epochs, 0)
  else:
    self._video_condition = tf.constant(False, dtype=tf.bool, shape=())

  # Copy so the problem hparams added below do not touch the caller's object.
  model_hparams = copy.copy(model_hparams)
  problem = DummyWorldModelProblem(action_space, reward_range,
                                   frame_height, frame_width)
  trainer_lib.add_problem_hparams(model_hparams, problem)
  model_hparams.force_full_predict = True
  self._model = registry.model(model_name)(
      model_hparams, tf.estimator.ModeKeys.PREDICT
  )

  self.history_buffer = HistoryBuffer(
      initial_frame_chooser, self.observ_shape, self.observ_dtype,
      self._num_frames, self.batch_size
  )

  # Non-trainable holder for the current observation of every environment.
  self._observ = tf.Variable(
      tf.zeros((batch_size,) + self.observ_shape, self.observ_dtype),
      trainable=False
  )

  self._reset_model = tf.get_variable(
      "reset_model", [], trainable=False,
      initializer=tf.zeros_initializer())

  self._model_dir = model_dir
def _load_hparams(path):
  """Load `hparams.json` from `path` as HParams with all dropout disabled.

  Works on both Python 2 and Python 3. On Python 2, JSON text values come
  back as `unicode` and are encoded to UTF-8 native `str`, matching the old
  behaviour. On Python 3 they are already native `str` and are left as-is.

  Args:
    path: directory containing `hparams.json`; also stored as the `data_dir`
      hparam.

  Returns:
    An HParams object with the 'translate_mmt' problem hparams added and
    every hparam whose name ends in "dropout" set to 0.0.
  """
  # `type(u'')` is the text type on both Python versions. `str` is bytes on
  # Python 2 and text on Python 3, so only Python 2 text gets encoded.
  text_type = type(u'')

  def _to_native_str(value):
    if isinstance(value, text_type) and not isinstance(value, str):
      return value.encode('utf-8')
    return value

  with open(os.path.join(path, 'hparams.json'), 'rb') as json_file:
    raw_hparams = json.load(json_file)
  # dict.items() exists on both Python versions (iteritems() is Python 2 only).
  hparams_dict = {
      _to_native_str(k): _to_native_str(v) for k, v in raw_hparams.items()
  }
  hparams = HParams(**hparams_dict)
  hparams.set_hparam('data_dir', path)
  trainer_lib.add_problem_hparams(hparams, 'translate_mmt')

  # Removing dropout from HParams even on TRAIN mode
  for key in hparams.values():
    if key.endswith("dropout"):
      setattr(hparams, key, 0.0)
  return hparams
def initialize(self, is_conditioned=False):
  """Build the Music Transformer estimator and start its prediction stream.

  Args:
    is_conditioned: if True, use the melody-conditioned checkpoint and
      MelodyToPianoPerformanceProblem; otherwise use the unconditional
      checkpoint and PianoPerformanceLanguageModelProblem.
  """
  self.model_name = 'transformer'
  self.hparams_set = 'transformer_tpu'
  self.conditioned = is_conditioned
  if self.conditioned:
    self.ckpt_path = 'models/checkpoints/melody_conditioned_model_16.ckpt'
    problem = MelodyToPianoPerformanceProblem()
  else:
    self.ckpt_path = 'models/checkpoints/unconditional_model_16.ckpt'
    problem = PianoPerformanceLanguageModelProblem()

  self.encoders = problem.get_feature_encoders()

  # Set up hyperparams
  hparams = trainer_lib.create_hparams(hparams_set=self.hparams_set)
  trainer_lib.add_problem_hparams(hparams, problem)
  hparams.num_hidden_layers = 16
  hparams.sampling_method = 'random'

  # Set up decoding hyperparams
  decode_hparams = decoding.decode_hparams()
  decode_hparams.alpha = 0.0
  decode_hparams.beam_size = 1

  # Start with empty inputs/targets; presumably the generator methods below
  # read these attributes, so callers fill them in before requesting samples
  # (confirm against the generator implementations).
  if self.conditioned:
    self.inputs = []
  else:
    self.targets = []

  self.decode_length = 0
  run_config = trainer_lib.create_run_config(hparams)
  estimator = trainer_lib.create_estimator(
      self.model_name, hparams, run_config,
      decode_hparams=decode_hparams)
  # NOTE(review): the two method names differ in form
  # ("input_generation_conditional" vs "input_generator_unconditional");
  # confirm both exist on the class, or the conditioned path raises
  # AttributeError.
  fnc = self.input_generation_conditional if self.conditioned else self.input_generator_unconditional
  input_fn = decoding.make_input_fn_from_generator(fnc())
  self.samples = estimator.predict(
      input_fn, checkpoint_path=self.ckpt_path)
  # Pull and discard one sample so the prediction generator starts running
  # now (model loading happens here rather than on the first real request).
  _ = next(self.samples)
def __init__(
    self, reward_range, observation_space, action_space, frame_stack_size,
    frame_height, frame_width, initial_frame_chooser, batch_size, model_name,
    model_hparams, model_dir, intrinsic_reward_scale=0.0
):
  """Batch of environments inside the TensorFlow graph.

  Args:
    reward_range: sequence whose first element is stored as the minimum
      reward; also passed to DummyWorldModelProblem.
    observation_space: forwarded to the base class.
    action_space: forwarded to the base class and DummyWorldModelProblem.
    frame_stack_size: number of frames handed to HistoryBuffer.
    frame_height: frame height passed to DummyWorldModelProblem.
    frame_width: frame width passed to DummyWorldModelProblem.
    initial_frame_chooser: passed to HistoryBuffer.
    batch_size: number of environments in the batch.
    model_name: registry name of the world model.
    model_hparams: world-model hparams; copied before being modified here.
    model_dir: stored as `self._model_dir`.
    intrinsic_reward_scale: stored as `self._intrinsic_reward_scale`.
  """
  super(SimulatedBatchEnv, self).__init__(observation_space, action_space)
  self.batch_size = batch_size
  self._min_reward = reward_range[0]
  self._num_frames = frame_stack_size
  self._intrinsic_reward_scale = intrinsic_reward_scale

  # Copy so the problem hparams added below do not touch the caller's object.
  model_hparams = copy.copy(model_hparams)
  problem = DummyWorldModelProblem(action_space, reward_range,
                                   frame_height, frame_width)
  trainer_lib.add_problem_hparams(model_hparams, problem)
  model_hparams.force_full_predict = True
  self._model = registry.model(model_name)(
      model_hparams, tf.estimator.ModeKeys.PREDICT
  )

  self.history_buffer = HistoryBuffer(
      initial_frame_chooser, self.observ_shape, self.observ_dtype,
      self._num_frames, self.batch_size
  )

  # Non-trainable holder for the current observation of every environment.
  self._observ = tf.Variable(
      tf.zeros((batch_size,) + self.observ_shape, self.observ_dtype),
      trainable=False
  )

  self._reset_model = tf.get_variable(
      "reset_model", [], trainable=False,
      initializer=tf.zeros_initializer())

  self._model_dir = model_dir
# Unconditional Music Transformer setup (script section).
# NOTE(review): `model_name` is used below but not defined in this section;
# presumably it is set earlier in the file — confirm.
hparams_set = 'transformer_tpu'
ckpt_path = '../assets/checkpoints/unconditional_model_16.ckpt'


class PianoPerformanceLanguageModelProblem(score2perf.Score2PerfProblem):
  """Score2Perf performance problem with `add_eos_symbol` fixed to True."""

  @property
  def add_eos_symbol(self):
    return True


problem = PianoPerformanceLanguageModelProblem()
unconditional_encoders = problem.get_feature_encoders()

# Set up HParams.
hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
trainer_lib.add_problem_hparams(hparams, problem)
hparams.num_hidden_layers = 16
hparams.sampling_method = 'random'

# Set up decoding HParams.
decode_hparams = decoding.decode_hparams()
decode_hparams.alpha = 0.0
decode_hparams.beam_size = 1

# Create Estimator.
run_config = trainer_lib.create_run_config(hparams)
estimator = trainer_lib.create_estimator(model_name, hparams, run_config,
                                         decode_hparams=decode_hparams)
def main(argv):
  """Run adversarial attacks on a trained model and log accuracy per epsilon.

  With --surrogate_attack, adversarial examples are generated against a
  surrogate model (loaded under the "surrogate/" scope). Accuracy is then
  reported for both the target model and the surrogate.

  Args:
    argv: command-line arguments; when non-empty, argv[1:] is passed to
      t2t_trainer.set_hparams_from_args.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  # Placeholder; overwritten with each epsilon in the attack loop below.
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    # Map checkpoint variables into the "surrogate/" scope, then initialise.
    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.train.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  # Snapshot of existing variables. Everything created after this point (the
  # target model and the attack) is restored from FLAGS.output_dir below.
  other_vars = set(tf.global_variables())

  # NOTE(review): unlike the surrogate, use_tpu is not passed here — confirm
  # that is intended.
  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Mask is 1.0 where the clean prediction already matches the label.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy."""
    # Returns a scalar, or a (target, surrogate) pair with --surrogate_attack.
    # NOTE(review): `acc` is unbound if FLAGS.eval_steps <= 0.
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    sess.run(tf.initialize_local_variables())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
def create_experiment(
    run_config,
    hparams,
    model_name,
    problem_name,
    data_dir,
    train_steps,
    eval_steps,
    min_eval_frequency=2000,
    eval_throttle_seconds=600,
    schedule="train_and_evaluate",
    export=False,
    decode_hparams=None,
    use_tfdbg=False,
    use_dbgprofile=False,
    eval_early_stopping_steps=None,
    eval_early_stopping_metric=None,
    eval_early_stopping_metric_delta=None,
    eval_early_stopping_metric_minimize=True,
    eval_timeout_mins=240,
    use_tpu=False,
    use_tpu_estimator=False,
    use_xla=False,
    additional_train_hooks=None,
    additional_eval_hooks=None,
    warm_start_from=None,
    decode_from_file=None,
    decode_to_file=None,
    decode_reference=None,
    std_server_protocol=None):
  """Create Experiment.

  Adds run settings to `hparams` (and decode settings to `decode_hparams`),
  builds the estimator, train/eval input fns, hooks and, if `export` is set,
  a BestExporter. These are wrapped in a T2TExperiment.

  When exporting, "best" is judged on `eval_early_stopping_metric` using
  `eval_early_stopping_metric_minimize`. If no metric is given, "loss" is
  used and always minimized.

  Returns:
    trainer_lib.T2TExperiment.

  Raises:
    ValueError: if evaluation is scheduled on a TPU Pod (more than 8 shards).
  """
  # HParams
  hparams.add_hparam("model_dir", run_config.model_dir)
  hparams.add_hparam("data_dir", data_dir)
  hparams.add_hparam("train_steps", train_steps)
  hparams.add_hparam("eval_steps", eval_steps)
  hparams.add_hparam("schedule", schedule)
  hparams.add_hparam("warm_start_from", warm_start_from)
  hparams.add_hparam("std_server_protocol", std_server_protocol)
  hparams.add_hparam("eval_freq_in_steps", min_eval_frequency)
  hparams.add_hparam("eval_timeout_mins", eval_timeout_mins)
  if decode_hparams is not None:
    decode_hparams.add_hparam("decode_from_file", decode_from_file)
    decode_hparams.add_hparam("decode_to_file", decode_to_file)
    decode_hparams.add_hparam("decode_reference", decode_reference)
  trainer_lib.add_problem_hparams(hparams, problem_name)

  # Estimator
  estimator = trainer_lib.create_estimator(
      model_name,
      hparams,
      run_config,
      schedule=schedule,
      decode_hparams=decode_hparams,
      use_tpu=use_tpu,
      use_tpu_estimator=use_tpu_estimator,
      use_xla=use_xla)

  # Input fns from Problem
  problem = hparams.problem
  # NOTE(review): reads the global FLAGS.train_data_size rather than a
  # parameter; callers without that flag defined will fail here.
  train_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN, hparams,
      dataset_kwargs={"max_records": FLAGS.train_data_size})
  eval_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)

  # Export
  exporter = None
  if export:
    if eval_early_stopping_metric:
      best_metric = eval_early_stopping_metric
      best_minimize = eval_early_stopping_metric_minimize
    else:
      # Default metric is the loss, which is always better when lower.
      best_metric = "loss"
      best_minimize = True

    def compare_fn(best_eval_result, current_eval_result):
      # Honour the metric direction. The previous version always used "<",
      # which kept the worst checkpoint for metrics that should be maximized.
      current = current_eval_result[best_metric]
      best = best_eval_result[best_metric]
      return current < best if best_minimize else current > best

    exporter = tf.estimator.BestExporter(
        name="best",
        serving_input_receiver_fn=lambda: problem.serving_input_fn(hparams),
        compare_fn=compare_fn,
        assets_extra=problem.export_assets)

  # Hooks
  validation_monitor_kwargs = dict(
      input_fn=eval_input_fn,
      eval_steps=eval_steps,
      every_n_steps=min_eval_frequency,
      early_stopping_rounds=eval_early_stopping_steps,
      early_stopping_metric=eval_early_stopping_metric,
      early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
  dbgprofile_kwargs = {"output_dir": run_config.model_dir}
  early_stopping_kwargs = dict(
      events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
      tag=eval_early_stopping_metric,
      num_plateau_steps=eval_early_stopping_steps,
      plateau_decrease=eval_early_stopping_metric_minimize,
      plateau_delta=eval_early_stopping_metric_delta,
      every_n_steps=min_eval_frequency)

  # Eval on TPU Pods is not supported yet
  if use_tpu and run_config.tpu_config.num_shards > 8 and "eval" in schedule:
    raise ValueError("Eval is not currently supported on a TPU Pod")

  # In-process eval (and possible early stopping)
  if schedule == "continuous_train_and_eval" and min_eval_frequency:
    tf.logging.warn("ValidationMonitor only works with "
                    "--schedule=train_and_evaluate")
  use_validation_monitor = (
      schedule == "train_and_evaluate" and min_eval_frequency)
  # Distributed early stopping
  local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
  use_early_stopping = (
      schedule not in local_schedules and eval_early_stopping_steps)
  train_hooks, eval_hooks = trainer_lib.create_hooks(
      use_tfdbg=use_tfdbg,
      use_dbgprofile=use_dbgprofile,
      dbgprofile_kwargs=dbgprofile_kwargs,
      use_validation_monitor=use_validation_monitor,
      validation_monitor_kwargs=validation_monitor_kwargs,
      use_early_stopping=use_early_stopping,
      early_stopping_kwargs=early_stopping_kwargs)

  hook_context = trainer_lib.HookContext(
      estimator=estimator, problem=problem, hparams=hparams)

  train_hooks += t2t_model.T2TModel.get_train_hooks(model_name, hook_context)
  eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name, hook_context)
  if additional_train_hooks:
    train_hooks += additional_train_hooks
  if additional_eval_hooks:
    eval_hooks += additional_eval_hooks

  train_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      train_hooks, estimator)
  eval_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      eval_hooks, estimator)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn, max_steps=train_steps, hooks=train_hooks)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=eval_steps,
      hooks=eval_hooks,
      start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
      throttle_secs=eval_throttle_seconds,
      exporters=exporter)

  return trainer_lib.T2TExperiment(estimator, hparams, train_spec, eval_spec,
                                   use_validation_monitor, decode_hparams)
# class MelodyToPianoPerformanceProblem(score2perf.AbsoluteMelody2PerfProblem): # @property # def add_eos_symbol(self): # return True uncondi_problem = PianoPerformanceLanguageModelProblem() uncondi_encoders = uncondi_problem.get_feature_encoders() # melody_problem = MelodyToPianoPerformanceProblem() # melody_encoders = melody_problem.get_feature_encoders() # Set up HParams. hparams = trainer_lib.create_hparams(hparams_set=hparams_set) trainer_lib.add_problem_hparams(hparams, uncondi_problem) hparams.num_hidden_layers = 16 hparams.sampling_method = 'random' # Set up decoding HParams. decode_hparams = decoding.decode_hparams() decode_hparams.alpha = 0.0 decode_hparams.beam_size = 1 # Create Estimator. run_config = trainer_lib.create_run_config(hparams) estimator = trainer_lib.create_estimator(model_name, hparams, run_config, decode_hparams=decode_hparams)
def main(argv):
  """Evaluate a trained model's accuracy under adversarial attacks.

  Logs accuracy on clean inputs (reported as eps=0.0) and for every epsilon
  in `attack_params.attack_epsilons`.

  Args:
    argv: command-line arguments; when non-empty, argv[1:] is passed to
      t2t_trainer.set_hparams_from_args.

  Raises:
    ValueError: if --eval_steps is not positive.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  attack_params = create_attack_params()
  # Placeholder; overwritten with each epsilon in the attack loop below.
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  # The "_rev" problem swaps features, so the images are under "targets" and
  # the class labels under "inputs".
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Match the label dtype: tf.argmax defaults to int64, and tf.equal fails
    # when the labels are int32.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, labels, mask):
    """Return streaming accuracy of the model on `x` over eval_steps batches."""
    if FLAGS.eval_steps <= 0:
      raise ValueError("--eval_steps must be positive, got %r" %
                       FLAGS.eval_steps)
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=labels, predictions=preds, weights=mask)
    # Reset the metric's local counters before accumulating.
    sess.run(tf.local_variables_initializer())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))
def run():
  """Check the flags, build the Transformer estimator and generate samples.

  :raises: ValueError: if --model_path or --output_dir is missing,
    --decode_length is not positive, or the encoded primer does not fit
    within --decode_length.
  """
  if FLAGS.model_path is None:
    raise ValueError("Required Transformer pre-trained model path.")
  if FLAGS.output_dir is None:
    raise ValueError("Required MIDI output directory.")
  if FLAGS.decode_length <= 0:
    raise ValueError("Decode length must be > 0.")

  problem = PianoPerformanceLanguageModelProblem()
  encoders = problem.get_feature_encoders()

  # Running with no primer at all is allowed: start from an empty sequence.
  if FLAGS.primer_path is None:
    primer = music_pb2.NoteSequence()
    primer_tokens = []
  else:
    primer = get_primer_ns(FLAGS.primer_path)
    # Drop the trailing end token from the encoded primer.
    primer_tokens = encoders["targets"].encode_note_sequence(primer)[:-1]
    if len(primer_tokens) >= FLAGS.decode_length:
      raise ValueError(
          "Primer has more or equal events than max sequence length.")
  tokens_to_generate = FLAGS.decode_length - len(primer_tokens)

  # Model hyperparameters. The hparams set is hard-coded (marked in the
  # original as a candidate for a flag).
  hparams = trainer_lib.create_hparams(hparams_set="transformer_tpu")
  trainer_lib.add_problem_hparams(hparams, problem)
  hparams.num_hidden_layers = NUM_HIDDEN_LAYERS
  hparams.sampling_method = SAMPLING_METHOD

  # Decoding hyperparameters.
  decode_hparams = decoding.decode_hparams()
  decode_hparams.alpha = ALPHA
  decode_hparams.beam_size = BEAM_SIZE

  LOGGER.info("Loading model")
  estimator = trainer_lib.create_estimator(
      MODEL_NAME,
      hparams,
      trainer_lib.create_run_config(hparams),
      decode_hparams=decode_hparams,
  )

  generate(
      estimator,
      encoders,
      tokens_to_generate,
      primer_tokens,
      primer,
  )
def main(argv):
    """Report a T2T classifier's accuracy on clean and adversarial inputs.

    Builds the evaluation input pipeline for ``FLAGS.problem`` (via its
    ``"_rev"`` variant), restores the model from the latest checkpoint, and
    logs accuracy at eps=0 followed by every epsilon in
    ``attack_params.attack_epsilons``.

    Args:
        argv: command-line arguments; everything after ``argv[0]`` is handed
            to ``t2t_trainer.set_hparams_from_args``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

    attack_params = create_attack_params()
    attack_params.add_hparam("eps", 0.0)

    config = t2t_trainer.create_run_config(hparams)
    params = {"batch_size": hparams.batch_size}

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")

    input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                               hparams)
    dataset = input_fn(params, config).repeat()
    features, _ = dataset.make_one_shot_iterator().get_next()
    # Deliberate cross-over: model inputs are read from "targets" and labels
    # from "inputs" -- presumably because the "_rev" problem swaps the two
    # feature roles (confirm against the registered "_rev" problem).
    inputs, labels = features["targets"], features["inputs"]
    inputs = tf.to_float(inputs)
    labels = tf.squeeze(labels)

    sess = tf.Session()

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
    ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        # tf.argmax defaults to int64 and tf.equal does not cast, so the
        # prediction dtype must be pinned to the label dtype (as
        # compute_accuracy below already does) or the comparison fails for
        # int32 labels.
        preds = tf.argmax(probs, -1, output_type=labels.dtype)
        preds = tf.squeeze(preds)
        # Mask out examples the model already gets wrong on clean input.
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    # Restore weights
    saver = tf.train.Saver()
    checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                         FLAGS.checkpoint_path)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, labels, mask):
        """Run FLAGS.eval_steps batches and return the streaming accuracy."""
        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=labels.dtype)
        _, acc_update_op = tf.metrics.accuracy(
            labels=labels, predictions=preds, weights=mask)
        # Reset the metric's local (count/total) variables for this run.
        sess.run(tf.local_variables_initializer())
        for _ in range(FLAGS.eval_steps):
            acc = sess.run(acc_update_op)
        return acc

    acc = compute_accuracy(inputs, labels, acc_mask)
    epsilon_acc_pairs = [(0.0, acc)]
    for epsilon in attack_params.attack_epsilons:
        attack_params.eps = epsilon
        adv_x = attack.generate(inputs, y=one_hot_labels,
                                **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        tf.logging.info("Accuracy @ eps=%f: %f" % (epsilon, acc))
def _prefixed_output_path(output_path, filename, prefix, extension):
    """Return ``output_path/<dir of filename>/<prefix><basename><extension>``.

    The prefix goes on the basename only, so a ``filename`` containing
    directories (e.g. ``'./public/output'``) maps to
    ``./public/primer_output.mid`` rather than the broken
    ``primer_./public/output.mid``. A bare stem such as ``'song'`` maps to
    ``output_path/primer_song.mid`` exactly as before.
    """
    import os

    directory, base = os.path.split(filename)
    return os.path.join(output_path, directory, prefix + base + extension)


def music_generator(primer='erik_gnossienne',
                    primer_begin_buffer=10,
                    primer_length=90,
                    output_path='.',
                    filename='./public/output'):
    """Sample a piano continuation of a primer with the unconditional model.

    Loads the named primer MIDI, trims and cleans it, samples from the
    checkpoint under ``./models/checkpoints`` and writes three files:

    * the trimmed primer as ``primer_<basename of filename>.mid``,
    * the sample rendered with FluidSynth as ``<filename>.wav``,
    * the sample as ``continuation_<basename of filename>.mid``.

    The two MIDI files go to ``output_path`` joined with any directory part
    of ``filename``; the WAV goes to ``join(output_path, filename + '.wav')``.

    Args:
        primer: key into the built-in ``filenames`` table below.
        primer_begin_buffer: seconds skipped at the start of the primer when
            it is truncated.
        primer_length: maximum primer length, in seconds.
        output_path: directory the outputs are written under.
        filename: output stem; may contain sub-directories.

    Raises:
        KeyError: if ``primer`` is not a key of the primer table.
    """
    SF2_PATH = './models/Yamaha-C5-Salamander-JNv5.1.sf2'
    SAMPLE_RATE = 16000

    # Decode a list of IDs (truncated at the first EOS, if any).
    def decode(ids, encoder):
        ids = list(ids)
        if text_encoder.EOS_ID in ids:
            ids = ids[:ids.index(text_encoder.EOS_ID)]
        return encoder.decode(ids)

    model_name = 'transformer'
    hparams_set = 'transformer_tpu'
    ckpt_path = './models/checkpoints/unconditional_model_16.ckpt'

    class PianoPerformanceLanguageModelProblem(score2perf.Score2PerfProblem):
        @property
        def add_eos_symbol(self):
            return True

    problem = PianoPerformanceLanguageModelProblem()
    unconditional_encoders = problem.get_feature_encoders()

    # Set up HParams.
    hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
    trainer_lib.add_problem_hparams(hparams, problem)
    hparams.num_hidden_layers = 16
    hparams.sampling_method = 'random'

    # Set up decoding HParams.
    decode_hparams = decoding.decode_hparams()
    decode_hparams.alpha = 0.0
    decode_hparams.beam_size = 1

    # Create Estimator.
    run_config = trainer_lib.create_run_config(hparams)
    estimator = trainer_lib.create_estimator(
        model_name, hparams, run_config, decode_hparams=decode_hparams)

    # Placeholder values; both are reassigned below once the primer has been
    # encoded, and input_generator picks up the new values lazily.
    targets = []
    decode_length = 0

    # Input generator, so priming and decode length can change on the fly.
    def input_generator():
        # No ``global`` declarations: ``targets`` and ``decode_length`` must
        # resolve to this function's variables through the closure.
        # Declaring them global would read unrelated (or undefined) module
        # globals and silently ignore the primer encoded below.
        while True:
            yield {
                'targets': np.array([targets], dtype=np.int32),
                'decode_length': np.array(decode_length, dtype=np.int32)
            }

    # Start the Estimator, loading from the specified checkpoint.
    input_fn = decoding.make_input_fn_from_generator(input_generator())
    unconditional_samples = estimator.predict(
        input_fn, checkpoint_path=ckpt_path)

    # "Burn" one.
    _ = next(unconditional_samples)

    filenames = {
        'C major arpeggio': './models/primers/c_major_arpeggio.mid',
        'C major scale': './models/primers/c_major_scale.mid',
        'Clair de Lune': './models/primers/clair_de_lune.mid',
        'Classical': 'audio_midi/Classical_Piano_piano-midi.de_MIDIRip/bach/bach_846_format0.mid',
        'erik_gymnopedie': 'audio_midi/erik_satie/gymnopedie_1_(c)oguri.mid',
        'erik_gymnopedie_2': 'audio_midi/erik_satie/gymnopedie_2_(c)oguri.mid',
        'erik_gymnopedie_3': 'audio_midi/erik_satie/gymnopedie_3_(c)oguri.mid',
        'erik_gnossienne': 'audio_midi/erik_satie/gnossienne_1_(c)oguri.mid',
        'erik_gnossienne_2': 'audio_midi/erik_satie/gnossienne_2_(c)oguri.mid',
        'erik_gnossienne_3': 'audio_midi/erik_satie/gnossienne_3_(c)oguri.mid',
        'erik_gnossienne_dery': 'audio_midi/erik_satie/gnossienne_1_(c)dery.mid',
        'erik_gnossienne_dery_2': 'audio_midi/erik_satie/gnossienne_2_(c)dery.mid',
        'erik_gnossienne_dery_3': 'audio_midi/erik_satie/gnossienne_3_(c)dery.mid',
        'erik_gnossienne_dery_5': 'audio_midi/erik_satie/gnossienne_5_(c)dery.mid',
        'erik_gnossienne_dery_6': 'audio_midi/erik_satie/gnossienne_6_(c)dery.mid',
        '1': 'audio_midi/erik_satie/1.mid',
        '2': 'audio_midi/erik_satie/2.mid',
        '3': 'audio_midi/erik_satie/3.mid',
        '4': 'audio_midi/erik_satie/4.mid',
        '5': 'audio_midi/erik_satie/5.mid',
        '6': 'audio_midi/erik_satie/6.mid',
        '7': 'audio_midi/erik_satie/7.mid',
        '8': 'audio_midi/erik_satie/8.mid',
        '9': 'audio_midi/erik_satie/9.mid',
        '10': 'audio_midi/erik_satie/10.mid',
    }

    primer_ns = mm.midi_file_to_note_sequence(filenames[primer])

    # Handle sustain pedal in the primer.
    primer_ns = mm.apply_sustain_control_changes(primer_ns)

    # Trim to desired number of seconds (the begin buffer is only applied
    # when the primer is actually truncated).
    max_primer_seconds = primer_length
    if primer_ns.total_time > max_primer_seconds:
        print('Primer is longer than %d seconds, truncating.' %
              max_primer_seconds)
        primer_ns = mm.extract_subsequence(
            primer_ns, primer_begin_buffer,
            max_primer_seconds + primer_begin_buffer)

    # Remove drums from primer if present.
    if any(note.is_drum for note in primer_ns.notes):
        print('Primer contains drums; they will be removed.')
        notes = [note for note in primer_ns.notes if not note.is_drum]
        del primer_ns.notes[:]
        primer_ns.notes.extend(notes)

    # Set primer instrument and program.
    for note in primer_ns.notes:
        note.instrument = 1
        note.program = 0

    mm.sequence_proto_to_midi_file(
        primer_ns,
        _prefixed_output_path(output_path, filename, 'primer_', '.mid'))

    targets = unconditional_encoders['targets'].encode_note_sequence(primer_ns)

    # Remove the end token from the encoded primer.
    targets = targets[:-1]

    # NOTE(review): the decode budget (10000 events) and the warning
    # threshold (4096 events) disagree -- confirm which limit the checkpoint
    # actually supports.
    decode_length = max(0, 10000 - len(targets))
    if len(targets) >= 4096:
        print(
            'Primer has more events than maximum sequence length; nothing will be generated.'
        )

    # Generate sample events.
    sample_ids = next(unconditional_samples)['outputs']

    # Decode to NoteSequence (decode() yields a MIDI file name here).
    midi_filename = decode(sample_ids,
                           encoder=unconditional_encoders['targets'])
    ns = mm.midi_file_to_note_sequence(midi_filename)

    print('Sample IDs: {}'.format(sample_ids))
    print('Sample IDs length: {}'.format(len(sample_ids)))
    print('Encoder: {}'.format(unconditional_encoders['targets']))
    print('Unconditional Samples: {}'.format(unconditional_samples))

    # The decoded sample is used as-is; it is not concatenated with the primer.
    continuation_ns = ns

    audio = mm.fluidsynth(continuation_ns, sample_rate=SAMPLE_RATE,
                          sf2_path=SF2_PATH)
    # Scale to the full int16 range for a 16-bit WAV file.
    normalizer = float(np.iinfo(np.int16).max)
    array_of_ints = np.array(np.asarray(audio) * normalizer, dtype=np.int16)
    wavfile.write(join(output_path, filename + '.wav'), SAMPLE_RATE,
                  array_of_ints)
    print('[+] Output stored as {}'.format(filename + '.wav'))

    mm.sequence_proto_to_midi_file(
        continuation_ns,
        _prefixed_output_path(output_path, filename, 'continuation_', '.mid'))
def main(argv):
    """Measure a T2T classifier's accuracy under adversarial attack.

    With ``FLAGS.surrogate_attack`` the adversarial examples are generated
    against a separately trained surrogate model (restored under the
    ``surrogate/`` scope) and accuracy is logged as a pair:
    (target model, surrogate model). Otherwise the attack is generated
    against the target model itself and a single accuracy is logged per
    epsilon in ``attack_params.attack_epsilons``.

    Args:
        argv: command-line arguments; everything after ``argv[0]`` is handed
            to ``t2t_trainer.set_hparams_from_args``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    if FLAGS.surrogate_attack:
        # tf.logging.warn is a deprecated alias of tf.logging.warning.
        tf.logging.warning("Performing surrogate model attack.")
        sur_hparams = create_surrogate_hparams()
        trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

    attack_params = create_attack_params()
    attack_params.add_hparam(attack_params.epsilon_name, 0.0)

    if FLAGS.surrogate_attack:
        sur_config = create_surrogate_run_config(sur_hparams)
    config = t2t_trainer.create_run_config(hparams)
    params = {
        "batch_size": hparams.batch_size,
        "use_tpu": FLAGS.use_tpu,
    }

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")

    inputs, labels, features = prepare_data(problem, hparams, params, config)

    sess = tf.Session()

    if FLAGS.surrogate_attack:
        sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
            FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
        sur_ch_model = adv_attack_utils.T2TAttackModel(
            sur_model_fn, features, params, sur_config, scope="surrogate")
        # Dummy call to construct graph
        sur_ch_model.get_probs(inputs)

        checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
        tf.contrib.framework.init_from_checkpoint(
            tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
        sess.run(tf.global_variables_initializer())

    # Snapshot the variables that exist before the target model is built so
    # the Saver below restores only the target model's variables. This must
    # run on BOTH paths: new_vars below depends on it unconditionally.
    other_vars = set(tf.global_variables())

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams)
    ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params,
                                               config)

    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        # Mask out examples the target model already misclassifies.
        preds = tf.argmax(probs, -1, output_type=labels.dtype)
        preds = tf.reshape(preds, labels.shape)
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    if FLAGS.surrogate_attack:
        attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
    else:
        attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    new_vars = set(tf.global_variables()) - other_vars

    # Restore weights
    saver = tf.train.Saver(new_vars)
    checkpoint_path = os.path.expanduser(FLAGS.output_dir)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, lbls, mask):
        """Compute model accuracy."""
        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=lbls.dtype)

        _, acc_update_op = tf.metrics.accuracy(lbls, preds, weights=mask)

        if FLAGS.surrogate_attack:
            preds = sur_ch_model.get_probs(x)
            preds = tf.squeeze(preds)
            preds = tf.argmax(preds, -1, output_type=lbls.dtype)
            # Evaluate both metrics together: (target acc, surrogate acc).
            acc_update_op = tf.tuple(
                (acc_update_op,
                 tf.metrics.accuracy(lbls, preds, weights=mask)[1]))

        # Reset the metrics' local (count/total) variables for this run.
        sess.run(tf.local_variables_initializer())
        for i in range(FLAGS.eval_steps):
            tf.logging.info(
                "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
            acc = sess.run(acc_update_op)
        if FLAGS.surrogate_attack:
            tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
        else:
            tf.logging.info("\tFinal acc: %.4f" % acc)
        return acc

    epsilon_acc_pairs = []
    for epsilon in attack_params.attack_epsilons:
        tf.logging.info("Attacking @ eps=%.4f" % epsilon)
        attack_params.set_hparam(attack_params.epsilon_name, epsilon)
        adv_x = attack.generate(inputs, y=one_hot_labels,
                                **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        if FLAGS.surrogate_attack:
            tf.logging.info(
                "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
        else:
            tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))