Ejemplo n.º 1
0
  def __init__(self,
               reward_range,
               observation_space,
               action_space,
               frame_stack_size,
               initial_frame_chooser,
               batch_size,
               model_name,
               model_hparams,
               model_dir,
               intrinsic_reward_scale=0.0):
    """Set up a batch of simulated environments inside the TensorFlow graph.

    Args:
      reward_range: indexable; element 0 is stored as the minimum reward and
        the whole value is handed to DummyWorldModelProblem.
      observation_space: forwarded to the parent constructor.
      action_space: forwarded to the parent constructor and to
        DummyWorldModelProblem.
      frame_stack_size: stored as `_num_frames` and passed to HistoryBuffer.
      initial_frame_chooser: passed through to HistoryBuffer.
      batch_size: leading dimension of the observation variable; also passed
        to HistoryBuffer.
      model_name: registry name of the model to instantiate.
      model_hparams: hparams for the model; shallow-copied before being
        modified here.
      model_dir: stored on the instance as `_model_dir`.
      intrinsic_reward_scale: stored on the instance.
    """
    super(SimulatedBatchEnv, self).__init__(observation_space, action_space)

    self.batch_size = batch_size
    self._min_reward = reward_range[0]
    self._num_frames = frame_stack_size
    self._intrinsic_reward_scale = intrinsic_reward_scale

    # Shallow copy: attributes assigned below land on the copy rather than on
    # the object the caller passed in.
    world_model_hparams = copy.copy(model_hparams)
    world_model_problem = DummyWorldModelProblem(action_space, reward_range)
    trainer_lib.add_problem_hparams(world_model_hparams, world_model_problem)
    world_model_hparams.force_full_predict = True
    model_cls = registry.model(model_name)
    self._model = model_cls(
        world_model_hparams, tf.estimator.ModeKeys.PREDICT)

    self.history_buffer = HistoryBuffer(
        initial_frame_chooser,
        self.observ_shape,
        self.observ_dtype,
        self._num_frames,
        self.batch_size)

    # Non-trainable variable holding the current observation for every
    # environment in the batch, initialised to zeros.
    initial_observ = tf.zeros(
        (batch_size,) + self.observ_shape, self.observ_dtype)
    self._observ = tf.Variable(initial_observ, trainable=False)

    self._reset_model = tf.get_variable(
        "reset_model",
        [],
        trainable=False,
        initializer=tf.zeros_initializer())

    self._model_dir = model_dir
Ejemplo n.º 2
0
    def __init__(self, environment_spec, length):
        """Set up a batch of simulated environments inside the TensorFlow graph.

        Args:
            environment_spec: object providing `observation_space`,
                `action_space`, `reward_range`, `video_num_input_frames`,
                `intrinsic_reward_scale`, `model_hparams`, `model_name` and
                `initial_frame_chooser`.
            length: stored as `self.length` and passed to HistoryBuffer.
        """
        spec = environment_spec
        super(SimulatedBatchEnv, self).__init__(
            spec.observation_space, spec.action_space)

        self.length = length
        self._min_reward = spec.reward_range[0]
        self._num_frames = spec.video_num_input_frames
        self._intrinsic_reward_scale = spec.intrinsic_reward_scale

        # Shallow copy: attributes assigned below land on the copy rather
        # than on the hparams object held by the spec.
        world_model_hparams = copy.copy(spec.model_hparams)
        world_model_problem = DummyWorldModelProblem(
            spec.action_space, spec.reward_range)
        trainer_lib.add_problem_hparams(
            world_model_hparams, world_model_problem)
        world_model_hparams.force_full_predict = True
        model_cls = registry.model(spec.model_name)
        self._model = model_cls(
            world_model_hparams, tf.estimator.ModeKeys.PREDICT)

        self.history_buffer = HistoryBuffer(
            spec.initial_frame_chooser,
            self.observ_shape,
            self.observ_dtype,
            self._num_frames,
            self.length)

        # Non-trainable observation variable; the leading dimension comes
        # from len(self).
        initial_observ = tf.zeros(
            (len(self),) + self.observ_shape, self.observ_dtype)
        self._observ = tf.Variable(initial_observ, trainable=False)
Ejemplo n.º 3
0
def get_policy(observations, hparams, action_space):
    """Build the policy network and return its policy and value outputs.

    Args:
        observations: observation tensor; its leading dimensions size the
            zero-filled feature tensors fed to the model.
        hparams: hyperparameters; `hparams.policy_network` names the model
            in the registry. The object is modified in place.
        action_space: must be a `gym.spaces.Discrete`.

    Returns:
        Tuple (action logits, value).

    Raises:
        ValueError: if `action_space` is not a discrete space.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise ValueError("Expecting discrete action space.")

    trainer_lib.add_problem_hparams(hparams, DummyPolicyProblem(action_space))
    hparams.force_full_predict = True
    policy_cls = registry.model(hparams.policy_network)
    policy_model = policy_cls(hparams, tf.estimator.ModeKeys.TRAIN)

    dims = common_layers.shape_list(observations)
    first_dim = dims[:1]
    first_two_dims = dims[:2]

    # Only "inputs" carries real data; every other feature is zero-filled.
    features = {
        "inputs": observations,
        "input_action": tf.zeros(first_two_dims + [1], dtype=tf.int32),
        "input_reward": tf.zeros(first_two_dims + [1], dtype=tf.int32),
        "targets": tf.zeros(first_dim + [1] + dims[2:]),
        "target_action": tf.zeros(first_dim + [1, 1], dtype=tf.int32),
        "target_reward": tf.zeros(first_dim + [1, 1], dtype=tf.int32),
        "target_policy": tf.zeros(first_dim + [1] + [action_space.n]),
        "target_value": tf.zeros(first_dim + [1]),
    }

    current_scope = tf.get_variable_scope()
    with tf.variable_scope(current_scope, reuse=tf.AUTO_REUSE):
        t2t_model.create_dummy_vars()
        predictions, _ = policy_model(features)

    return (predictions["target_policy"], predictions["target_value"])
Ejemplo n.º 4
0
def run():
    """Build the Transformer estimator from flags and start sampling.

    Validates the required flags, optionally encodes a primer sequence,
    configures model and decoding hyperparameters, creates the estimator and
    hands everything to `generate`.

    Raises:
        ValueError: if a required flag is missing or invalid, or if the
            encoded primer is at least as long as `FLAGS.decode_length`.
    """
    if FLAGS.model_path is None:
        raise ValueError('Required Transformer pre-trained model path.')
    if FLAGS.output_dir is None:
        raise ValueError('Required Midi output directory.')
    if FLAGS.decode_length <= 0:
        raise ValueError('Decode length must be > 0.')

    problem = utils.PianoPerformanceLanguageModelProblem()
    unconditional_encoders = problem.get_feature_encoders()

    # Defaults for the unprimed case: empty note sequence, no target events.
    primer_ns = music_pb2.NoteSequence()
    targets = []
    if FLAGS.primer_path is not None:
        if FLAGS.max_primer_second <= 0:
            raise ValueError('Max primer second must be > 0.')

        primer_ns = utils.get_primer_ns(
            FLAGS.primer_path, FLAGS.max_primer_second)
        encoded_primer = unconditional_encoders[
            'targets'].encode_note_sequence(primer_ns)

        # Drop the trailing end token from the encoded primer.
        targets = encoded_primer[:-1]
        if len(targets) >= FLAGS.decode_length:
            raise ValueError(
                'Primer has more or equal events than maximum sequence length:'
                ' %d >= %d; Aborting' % (len(targets), FLAGS.decode_length))

    # The primer events count against the total decode budget.
    decode_length = FLAGS.decode_length - len(targets)

    # Model hyperparameters.
    hparams = trainer_lib.create_hparams(hparams_set=FLAGS.hparams_set)
    trainer_lib.add_problem_hparams(hparams, problem)
    hparams.num_hidden_layers = FLAGS.layers
    hparams.sampling_method = FLAGS.sample

    # Decoding hyperparameters.
    decode_hparams = decoding.decode_hparams()
    decode_hparams.alpha = FLAGS.alpha
    decode_hparams.beam_size = FLAGS.beam_size

    # Estimator.
    utils.LOGGER.info('Loading model')
    run_config = trainer_lib.create_run_config(hparams)
    estimator = trainer_lib.create_estimator(
        FLAGS.model_name, hparams, run_config, decode_hparams=decode_hparams)

    generate(
        estimator, unconditional_encoders, decode_length, targets, primer_ns)
Ejemplo n.º 5
0
def get_policy(observations, hparams, action_space):
    """Build the policy network and return first-frame policy and value.

    Args:
        observations: observation tensor; dimensions 2 and 3 of its shape
            are used as frame height and width.
        hparams: hyperparameters; `policy_problem_name` selects the dummy
            problem, `policy_network` names the model in the registry and
            `video_num_target_frames` (default 1 when absent) sizes the
            target features. The object is modified in place.
        action_space: must be a `gym.spaces.Discrete`.

    Returns:
        Tuple (action logits, value), both taken at target frame index 0.

    Raises:
        ValueError: if `action_space` is not a discrete space.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise ValueError("Expecting discrete action space.")

    dims = common_layers.shape_list(observations)
    frame_height, frame_width = dims[2:4]

    # TODO(afrozm): We have these dummy problems mainly for hparams, so cleanup
    # when possible and do this properly.
    if hparams.policy_problem_name == "dummy_policy_problem_ttt":
        tf.logging.info("Using DummyPolicyProblemTTT for the policy.")
        policy_problem = tic_tac_toe_env.DummyPolicyProblemTTT()
    else:
        tf.logging.info("Using DummyPolicyProblem for the policy.")
        policy_problem = DummyPolicyProblem(
            action_space, frame_height, frame_width)

    trainer_lib.add_problem_hparams(hparams, policy_problem)
    hparams.force_full_predict = True
    policy_cls = registry.model(hparams.policy_network)
    policy_model = policy_cls(hparams, tf.estimator.ModeKeys.TRAIN)

    # Fall back to a single target frame when the hparam is not defined.
    num_target_frames = getattr(hparams, "video_num_target_frames", 1)

    first_dim = dims[:1]
    first_two_dims = dims[:2]

    # Only "inputs" carries real data; every other feature is zero-filled.
    features = {
        "inputs": observations,
        "input_action": tf.zeros(first_two_dims + [1], dtype=tf.int32),
        "input_reward": tf.zeros(first_two_dims + [1], dtype=tf.int32),
        "targets": tf.zeros(first_dim + [num_target_frames] + dims[2:]),
        "target_action": tf.zeros(
            first_dim + [num_target_frames, 1], dtype=tf.int32),
        "target_reward": tf.zeros(
            first_dim + [num_target_frames, 1], dtype=tf.int32),
        "target_policy": tf.zeros(
            first_dim + [num_target_frames] + [action_space.n]),
        "target_value": tf.zeros(first_dim + [num_target_frames]),
    }

    current_scope = tf.get_variable_scope()
    with tf.variable_scope(current_scope, reuse=tf.AUTO_REUSE):
        t2t_model.create_dummy_vars()
        predictions, _ = policy_model(features)

    policy_logits = predictions["target_policy"][:, 0, :]
    value = predictions["target_value"][:, 0]
    return (policy_logits, value)
Ejemplo n.º 6
0
def main(argv):
  """Restore a trained T2T model and run `pruning_utils.sparsify` on it.

  Builds the EVAL-mode model graph on a repeating one-shot iterator over the
  problem's eval data, restores the latest checkpoint found under
  `FLAGS.output_dir` (or `FLAGS.checkpoint_path` when that is unset), and
  passes the session, an accuracy callback and the pruning strategy/params
  to `pruning_utils.sparsify`.

  Args:
    argv: command-line arguments; everything after argv[0] is forwarded to
      `t2t_trainer.set_hparams_from_args`.

  Raises:
    ValueError: from the accuracy callback when `FLAGS.eval_steps` is not
      positive.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # Build the EVAL input pipeline for the problem named by --problem.
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  def eval_model():
    """Return the streaming accuracy after `FLAGS.eval_steps` batches.

    NOTE: every call adds a fresh set of metric ops to the graph.
    """
    if FLAGS.eval_steps <= 0:
      # With no iterations the loop below never binds `acc`; fail with a
      # clear message instead of an UnboundLocalError.
      raise ValueError(
          "--eval_steps must be positive, got %r." % (FLAGS.eval_steps,))
    preds = spec.predictions["predictions"]
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)
    # tf.metrics keeps its accumulators in local variables; (re)initialize
    # them so each call starts counting from zero.
    sess.run(tf.local_variables_initializer())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
Ejemplo n.º 7
0
def main(argv):
  """Restore a trained T2T model and run `pruning_utils.sparsify` on it.

  Builds the EVAL-mode model graph on a repeating one-shot iterator over the
  problem's eval data, restores the latest checkpoint found under
  `FLAGS.output_dir` (or `FLAGS.checkpoint_path` when that is unset), and
  passes the session, an accuracy callback and the pruning strategy/params
  to `pruning_utils.sparsify`.

  Args:
    argv: command-line arguments; everything after argv[0] is forwarded to
      `t2t_trainer.set_hparams_from_args`.

  Raises:
    ValueError: from the accuracy callback when `FLAGS.eval_steps` is not
      positive.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  pruning_params = create_pruning_params()
  pruning_strategy = create_pruning_strategy(pruning_params.strategy)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # Build the EVAL input pipeline for the problem named by --problem.
  problem = registry.problem(FLAGS.problem)
  input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                             hparams)
  dataset = input_fn(params, config).repeat()
  features, labels = dataset.make_one_shot_iterator().get_next()

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  spec = model_fn(
      features,
      labels,
      tf.estimator.ModeKeys.EVAL,
      params=hparams,
      config=config)

  # Restore weights
  saver = tf.train.Saver()
  checkpoint_path = os.path.expanduser(FLAGS.output_dir or
                                       FLAGS.checkpoint_path)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  def eval_model():
    """Return the streaming accuracy after `FLAGS.eval_steps` batches.

    NOTE: every call adds a fresh set of metric ops to the graph.
    """
    if FLAGS.eval_steps <= 0:
      # With no iterations the loop below never binds `acc`; fail with a
      # clear message instead of an UnboundLocalError.
      raise ValueError(
          "--eval_steps must be positive, got %r." % (FLAGS.eval_steps,))
    preds = spec.predictions["predictions"]
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(labels=labels, predictions=preds)
    # tf.metrics keeps its accumulators in local variables; (re)initialize
    # them so each call starts counting from zero.
    sess.run(tf.local_variables_initializer())
    for _ in range(FLAGS.eval_steps):
      acc = sess.run(acc_update_op)
    return acc

  pruning_utils.sparsify(sess, eval_model, pruning_strategy, pruning_params)
Ejemplo n.º 8
0
def main(argv):
    """Train a T2T model through a plain Estimator under EPL replication.

    Initialises EPL, copies the cluster's worker index and GPU count into
    FLAGS, builds hparams and runs `estimator.train`. Only the "train"
    schedule is accepted.

    Args:
        argv: command-line arguments; everything after argv[0] is forwarded
            to `set_hparams_from_args`.

    Raises:
        RuntimeError: if `FLAGS.schedule` is anything other than "train".
    """
    epl.init(epl.Config({"cluster.colocate_split_and_replicate": True}))
    FLAGS.worker_id = epl.Env.get().cluster.worker_index
    FLAGS.worker_gpu = epl.Env.get().cluster.total_gpu_num
    epl.set_default_strategy(epl.replicate(FLAGS.worker_gpu))

    # Create HParams.
    if argv:
        set_hparams_from_args(argv[1:])
    if FLAGS.schedule != "run_std_server":
        hparams = create_hparams()

    # Everything below assumes the training schedule.
    if FLAGS.schedule != "train":
        raise RuntimeError(
            "Support training tasks only for now, you can define tasks in other modes."
        )
    mlperf_log.transformer_print(key=mlperf_log.RUN_START)
    trainer_lib.set_random_seed(FLAGS.random_seed)

    extra_hparams = (
        ("data_dir", FLAGS.data_dir),
        ("schedule", FLAGS.schedule),
        ("train_steps", FLAGS.train_steps),
        ("warm_start_from", None),
    )
    for hparam_name, hparam_value in extra_hparams:
        hparams.add_hparam(hparam_name, hparam_value)
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

    # Dataset generation.
    if FLAGS.generate_data:
        generate_data()

    def model_fn_replicate(features, labels, mode):
        """Build the T2T estimator model_fn and apply it to this call."""
        t2t_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
            FLAGS.model, hparams)
        return t2t_model_fn(features, labels, mode)

    if is_chief():
        save_metadata(hparams)

    estimator = tf.estimator.Estimator(
        model_fn=model_fn_replicate, config=create_run_config())
    hooks = [
        tf.train.StepCounterHook(every_n_steps=FLAGS.log_step_count_steps),
    ]

    optimize.log_variable_sizes(verbose=True)

    train_input_fn = hparams.problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.TRAIN, hparams)

    estimator.train(train_input_fn, max_steps=hparams.train_steps, hooks=hooks)
Ejemplo n.º 9
0
  def __init__(self,
               reward_range,
               observation_space,
               action_space,
               frame_stack_size,
               frame_height,
               frame_width,
               initial_frame_chooser,
               batch_size,
               model_name,
               model_hparams,
               model_dir,
               intrinsic_reward_scale=0.0,
               sim_video_dir=None):
    """Set up a batch of simulated environments inside the TensorFlow graph.

    Args:
      reward_range: indexable; element 0 is stored as the minimum reward and
        the whole value is handed to DummyWorldModelProblem.
      observation_space: forwarded to the parent constructor.
      action_space: forwarded to the parent constructor and to
        DummyWorldModelProblem.
      frame_stack_size: stored as `_num_frames` and passed to HistoryBuffer.
      frame_height: passed to DummyWorldModelProblem.
      frame_width: passed to DummyWorldModelProblem.
      initial_frame_chooser: passed through to HistoryBuffer.
      batch_size: leading dimension of the observation variable; also passed
        to HistoryBuffer.
      model_name: registry name of the model to instantiate.
      model_hparams: hparams for the model; shallow-copied before being
        modified here.
      model_dir: stored on the instance as `_model_dir`.
      intrinsic_reward_scale: stored on the instance.
      sim_video_dir: when truthy, the directory is created and the video
        attributes and condition are set up; otherwise the video condition
        is a constant False.
    """
    super(SimulatedBatchEnv, self).__init__(observation_space, action_space)

    self._ffmpeg_works = common_video.ffmpeg_works()
    self.batch_size = batch_size
    self._min_reward = reward_range[0]
    self._num_frames = frame_stack_size
    self._intrinsic_reward_scale = intrinsic_reward_scale
    self._episode_counter = tf.get_variable(
        "episode_counter",
        initializer=tf.zeros((), dtype=tf.int32),
        trainable=False,
        dtype=tf.int32)

    if not sim_video_dir:
      self._video_condition = tf.constant(False, dtype=tf.bool, shape=())
    else:
      self._video_every_epochs = 100
      self._video_dir = sim_video_dir
      self._video_writer = None
      self._video_counter = 0
      tf.gfile.MakeDirs(self._video_dir)
      # True whenever the episode counter is a multiple of
      # `_video_every_epochs`.
      counter_mod = (
          self._episode_counter.read_value() % self._video_every_epochs)
      self._video_condition = tf.equal(counter_mod, 0)

    # Shallow copy: attributes assigned below land on the copy rather than on
    # the object the caller passed in.
    world_model_hparams = copy.copy(model_hparams)
    world_model_problem = DummyWorldModelProblem(
        action_space, reward_range, frame_height, frame_width)
    trainer_lib.add_problem_hparams(world_model_hparams, world_model_problem)
    world_model_hparams.force_full_predict = True
    model_cls = registry.model(model_name)
    self._model = model_cls(
        world_model_hparams, tf.estimator.ModeKeys.PREDICT)

    self.history_buffer = HistoryBuffer(
        initial_frame_chooser,
        self.observ_shape,
        self.observ_dtype,
        self._num_frames,
        self.batch_size)

    # Non-trainable variable holding the current observation for every
    # environment in the batch, initialised to zeros.
    initial_observ = tf.zeros(
        (batch_size,) + self.observ_shape, self.observ_dtype)
    self._observ = tf.Variable(initial_observ, trainable=False)

    self._reset_model = tf.get_variable(
        "reset_model",
        [],
        trainable=False,
        initializer=tf.zeros_initializer())

    self._model_dir = model_dir
Ejemplo n.º 10
0
        def _load_hparams(path):
            """Load `hparams.json` from `path` into an HParams with dropout off.

            Python 2 code (`unicode`, `dict.iteritems`). Keys and unicode
            values are encoded to UTF-8 byte strings, `data_dir` is pointed
            at `path`, the 'translate_mmt' problem hparams are attached, and
            every hparam whose name ends in "dropout" is set to 0.0.
            """
            with open(os.path.join(path, 'hparams.json'), 'rb') as json_file:
                hparams_dict = {}
                for name, value in json.load(json_file).iteritems():
                    if type(value) == unicode:
                        value = value.encode('utf-8')
                    hparams_dict[name.encode('utf-8')] = value

                hparams = HParams(**hparams_dict)
                hparams.set_hparam('data_dir', path)

            trainer_lib.add_problem_hparams(hparams, 'translate_mmt')

            # Removing dropout from HParams even on TRAIN mode
            dropout_names = [
                name for name in hparams.values() if name.endswith("dropout")
            ]
            for name in dropout_names:
                setattr(hparams, name, 0.0)

            return hparams
Ejemplo n.º 11
0
    def initialize(self, is_conditioned=False):
        """Create the estimator and start its prediction generator.

        Args:
            is_conditioned: when true, use the melody-conditioned checkpoint
                and problem and reset `self.inputs`; otherwise use the
                unconditional ones and reset `self.targets`.
        """
        self.model_name = 'transformer'
        self.hparams_set = 'transformer_tpu'
        self.conditioned = is_conditioned

        if not self.conditioned:
            self.ckpt_path = 'models/checkpoints/unconditional_model_16.ckpt'
            problem = PianoPerformanceLanguageModelProblem()
        else:
            self.ckpt_path = 'models/checkpoints/melody_conditioned_model_16.ckpt'
            problem = MelodyToPianoPerformanceProblem()

        self.encoders = problem.get_feature_encoders()

        # Model hyperparameters.
        hparams = trainer_lib.create_hparams(hparams_set=self.hparams_set)
        trainer_lib.add_problem_hparams(hparams, problem)
        hparams.num_hidden_layers = 16
        hparams.sampling_method = 'random'

        # Decoding hyperparameters.
        decode_hparams = decoding.decode_hparams()
        decode_hparams.alpha = 0.0
        decode_hparams.beam_size = 1

        if not self.conditioned:
            self.targets = []
        else:
            self.inputs = []

        self.decode_length = 0
        run_config = trainer_lib.create_run_config(hparams)
        estimator = trainer_lib.create_estimator(
            self.model_name,
            hparams,
            run_config,
            decode_hparams=decode_hparams)

        # Choose the input generator that matches the conditioning mode.
        if self.conditioned:
            generator_factory = self.input_generation_conditional
        else:
            generator_factory = self.input_generator_unconditional
        input_fn = decoding.make_input_fn_from_generator(generator_factory())

        self.samples = estimator.predict(
            input_fn, checkpoint_path=self.ckpt_path)
        # Consume and discard the first prediction.
        next(self.samples)
Ejemplo n.º 12
0
  def __init__(self,
               reward_range,
               observation_space,
               action_space,
               frame_stack_size,
               frame_height,
               frame_width,
               initial_frame_chooser,
               batch_size,
               model_name,
               model_hparams,
               model_dir,
               intrinsic_reward_scale=0.0):
    """Set up a batch of simulated environments inside the TensorFlow graph.

    Args:
      reward_range: indexable; element 0 is stored as the minimum reward and
        the whole value is handed to DummyWorldModelProblem.
      observation_space: forwarded to the parent constructor.
      action_space: forwarded to the parent constructor and to
        DummyWorldModelProblem.
      frame_stack_size: stored as `_num_frames` and passed to HistoryBuffer.
      frame_height: passed to DummyWorldModelProblem.
      frame_width: passed to DummyWorldModelProblem.
      initial_frame_chooser: passed through to HistoryBuffer.
      batch_size: leading dimension of the observation variable; also passed
        to HistoryBuffer.
      model_name: registry name of the model to instantiate.
      model_hparams: hparams for the model; shallow-copied before being
        modified here.
      model_dir: stored on the instance as `_model_dir`.
      intrinsic_reward_scale: stored on the instance.
    """
    super(SimulatedBatchEnv, self).__init__(observation_space, action_space)

    self.batch_size = batch_size
    self._min_reward = reward_range[0]
    self._num_frames = frame_stack_size
    self._intrinsic_reward_scale = intrinsic_reward_scale

    # Shallow copy: attributes assigned below land on the copy rather than on
    # the object the caller passed in.
    world_model_hparams = copy.copy(model_hparams)
    world_model_problem = DummyWorldModelProblem(
        action_space, reward_range, frame_height, frame_width)
    trainer_lib.add_problem_hparams(world_model_hparams, world_model_problem)
    world_model_hparams.force_full_predict = True
    model_cls = registry.model(model_name)
    self._model = model_cls(
        world_model_hparams, tf.estimator.ModeKeys.PREDICT)

    self.history_buffer = HistoryBuffer(
        initial_frame_chooser,
        self.observ_shape,
        self.observ_dtype,
        self._num_frames,
        self.batch_size)

    # Non-trainable variable holding the current observation for every
    # environment in the batch, initialised to zeros.
    initial_observ = tf.zeros(
        (batch_size,) + self.observ_shape, self.observ_dtype)
    self._observ = tf.Variable(initial_observ, trainable=False)

    self._reset_model = tf.get_variable(
        "reset_model",
        [],
        trainable=False,
        initializer=tf.zeros_initializer())

    self._model_dir = model_dir
# Name of the registered hparams set passed to trainer_lib.create_hparams.
hparams_set = 'transformer_tpu'
# Checkpoint location, relative to the working directory.
# NOTE(review): not referenced in the visible part of this script - confirm
# where it is consumed.
ckpt_path = '../assets/checkpoints/unconditional_model_16.ckpt'


class PianoPerformanceLanguageModelProblem(score2perf.Score2PerfProblem):
    """Score2Perf problem variant that overrides `add_eos_symbol` to True."""

    @property
    def add_eos_symbol(self):
        """Always True for this problem."""
        return True


# Module-level setup: instantiate the problem, fetch its feature encoders and
# build an estimator from fixed hyperparameters.
problem = PianoPerformanceLanguageModelProblem()
unconditional_encoders = problem.get_feature_encoders()

# Set up HParams.
hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
trainer_lib.add_problem_hparams(hparams, problem)
hparams.num_hidden_layers = 16
hparams.sampling_method = 'random'

# Set up decoding HParams.
decode_hparams = decoding.decode_hparams()
decode_hparams.alpha = 0.0
decode_hparams.beam_size = 1

# Create Estimator.
# NOTE(review): `model_name` is not defined in the visible part of this
# script - it must be assigned before this point.
run_config = trainer_lib.create_run_config(hparams)
estimator = trainer_lib.create_estimator(model_name,
                                         hparams,
                                         run_config,
                                         decode_hparams=decode_hparams)
Ejemplo n.º 14
0
def main(argv):
  """Attack a trained T2T model at several epsilons and log its accuracy.

  Builds the attacked model (and, with --surrogate_attack, a second
  "surrogate" model that the attack is generated against), restores weights
  from the latest checkpoints, then for each value in
  `attack_params.attack_epsilons` generates adversarial inputs and measures
  accuracy on them over `FLAGS.eval_steps` batches.

  Args:
    argv: command-line arguments; everything after argv[0] is forwarded to
      `t2t_trainer.set_hparams_from_args`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()


  # When launching on Cloud ML Engine, hand off and do nothing locally.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  # A Cloud ML Engine job directory, when present, overrides --output_dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  # Register the epsilon hparam (under the attack-specific name) so it can be
  # updated with set_hparam inside the epsilon loop below.
  attack_params = create_attack_params()
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    # Map checkpoint variables onto the "surrogate/" scope, then run the
    # initializers so the surrogate variables take their checkpoint values.
    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.train.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  # Snapshot of the variables that exist before the attacked model is built;
  # used below to restore only the variables created after this point.
  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    # Weight mask that is 1.0 where the clean prediction equals the label,
    # so examples the model already gets wrong are ignored in the accuracy.
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  # With a surrogate, the attack is generated against the surrogate model.
  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy on inputs `x` against labels `l`.

    Args:
      x: model inputs to evaluate.
      l: labels; its dtype is used for the argmax output.
      mask: weights passed to tf.metrics.accuracy (may be None).

    Returns:
      The value fetched from the last accuracy update; a pair
      (model accuracy, surrogate accuracy) under --surrogate_attack.
    """
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      # Also track the surrogate's accuracy and fetch both update ops together.
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    # Initialize the local variables backing the metrics before evaluating.
    sess.run(tf.initialize_local_variables())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  # Evaluate accuracy under attack at each epsilon, then log a summary.
  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
Ejemplo n.º 15
0
def create_experiment(
    run_config,
    hparams,
    model_name,
    problem_name,
    data_dir,
    train_steps,
    eval_steps,
    min_eval_frequency=2000,
    eval_throttle_seconds=600,
    schedule="train_and_evaluate",
    export=False,
    decode_hparams=None,
    use_tfdbg=False,
    use_dbgprofile=False,
    eval_early_stopping_steps=None,
    eval_early_stopping_metric=None,
    eval_early_stopping_metric_delta=None,
    eval_early_stopping_metric_minimize=True,
    eval_timeout_mins=240,
    use_tpu=False,
    use_tpu_estimator=False,
    use_xla=False,
    additional_train_hooks=None,
    additional_eval_hooks=None,
    warm_start_from=None,
    decode_from_file=None,
    decode_to_file=None,
    decode_reference=None,
    std_server_protocol=None):
  """Create a T2TExperiment from a run config, hparams and a problem.

  Adds run-level settings to `hparams` (in place), builds the estimator, the
  train/eval input functions, an optional best-model exporter and the
  train/eval hooks, and packages them into a `trainer_lib.T2TExperiment`.

  Args:
    run_config: run configuration; its `model_dir` (and, with `use_tpu`, its
      `tpu_config`) is read here.
    hparams: model hyperparameters; extended in place via `add_hparam`.
    model_name: registered model name.
    problem_name: problem passed to `trainer_lib.add_problem_hparams`.
    data_dir: stored in hparams as `data_dir`.
    train_steps: maximum number of training steps.
    eval_steps: number of steps per evaluation.
    min_eval_frequency: evaluation interval in steps, used by the validation
      monitor and early-stopping hooks.
    eval_throttle_seconds: `throttle_secs` of the EvalSpec.
    schedule: schedule name; stored in hparams and used to pick hooks.
    export: whether to attach a `tf.estimator.BestExporter`.
    decode_hparams: optional decoding hparams; when given, the three
      `decode_*` arguments are added to it.
    use_tfdbg: forwarded to `trainer_lib.create_hooks`.
    use_dbgprofile: forwarded to `trainer_lib.create_hooks`.
    eval_early_stopping_steps: early-stopping patience in steps.
    eval_early_stopping_metric: metric name for early stopping and for the
      exporter's comparison ("loss" is used by the exporter when unset).
    eval_early_stopping_metric_delta: plateau delta for early stopping.
    eval_early_stopping_metric_minimize: whether a lower value of
      `eval_early_stopping_metric` is better.
    eval_timeout_mins: stored in hparams as `eval_timeout_mins`.
    use_tpu: forwarded to `create_estimator`; also enables the TPU Pod check.
    use_tpu_estimator: forwarded to `create_estimator`.
    use_xla: forwarded to `create_estimator`.
    additional_train_hooks: extra hooks appended to the train hooks.
    additional_eval_hooks: extra hooks appended to the eval hooks.
    warm_start_from: stored in hparams as `warm_start_from`.
    decode_from_file: added to `decode_hparams` when that is given.
    decode_to_file: added to `decode_hparams` when that is given.
    decode_reference: added to `decode_hparams` when that is given.
    std_server_protocol: stored in hparams as `std_server_protocol`.

  Returns:
    A `trainer_lib.T2TExperiment`.

  Raises:
    ValueError: if eval is requested on a TPU Pod (more than 8 shards).
  """
  # HParams
  hparams.add_hparam("model_dir", run_config.model_dir)
  hparams.add_hparam("data_dir", data_dir)
  hparams.add_hparam("train_steps", train_steps)
  hparams.add_hparam("eval_steps", eval_steps)
  hparams.add_hparam("schedule", schedule)
  hparams.add_hparam("warm_start_from", warm_start_from)
  hparams.add_hparam("std_server_protocol", std_server_protocol)
  hparams.add_hparam("eval_freq_in_steps", min_eval_frequency)
  hparams.add_hparam("eval_timeout_mins", eval_timeout_mins)
  if decode_hparams is not None:
    decode_hparams.add_hparam("decode_from_file", decode_from_file)
    decode_hparams.add_hparam("decode_to_file", decode_to_file)
    decode_hparams.add_hparam("decode_reference", decode_reference)
  trainer_lib.add_problem_hparams(hparams, problem_name)

  # Estimator
  estimator = trainer_lib.create_estimator(
      model_name,
      hparams,
      run_config,
      schedule=schedule,
      decode_hparams=decode_hparams,
      use_tpu=use_tpu,
      use_tpu_estimator=use_tpu_estimator,
      use_xla=use_xla)

  # Input fns from Problem
  problem = hparams.problem
  train_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN, hparams,
      dataset_kwargs={"max_records": FLAGS.train_data_size})
  eval_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)

  # Export
  exporter = None
  if export:
    # Without an explicit metric the exporter tracks "loss", which is always
    # minimized; the minimize flag only describes a user-supplied metric.
    export_metric = eval_early_stopping_metric or "loss"
    export_minimize = (
        eval_early_stopping_metric_minimize or not eval_early_stopping_metric)

    def compare_fn(best_eval_result, current_eval_result):
      """Return True when the current eval result beats the best so far."""
      current = current_eval_result[export_metric]
      best = best_eval_result[export_metric]
      # Honor the metric direction: previously `<` was used unconditionally,
      # so a maximized metric would export the worse model.
      if export_minimize:
        return current < best
      return current > best

    exporter = tf.estimator.BestExporter(
        name="best",
        serving_input_receiver_fn=lambda: problem.serving_input_fn(hparams),
        compare_fn=compare_fn,
        assets_extra=problem.export_assets)

  # Hooks
  validation_monitor_kwargs = dict(
      input_fn=eval_input_fn,
      eval_steps=eval_steps,
      every_n_steps=min_eval_frequency,
      early_stopping_rounds=eval_early_stopping_steps,
      early_stopping_metric=eval_early_stopping_metric,
      early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
  dbgprofile_kwargs = {"output_dir": run_config.model_dir}
  early_stopping_kwargs = dict(
      events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
      tag=eval_early_stopping_metric,
      num_plateau_steps=eval_early_stopping_steps,
      plateau_decrease=eval_early_stopping_metric_minimize,
      plateau_delta=eval_early_stopping_metric_delta,
      every_n_steps=min_eval_frequency)

  # Eval on TPU Pods is not supported yet
  if use_tpu and run_config.tpu_config.num_shards > 8 and "eval" in schedule:
    raise ValueError("Eval is not currently supported on a TPU Pod")

  # In-process eval (and possible early stopping)
  if schedule == "continuous_train_and_eval" and min_eval_frequency:
    tf.logging.warn("ValidationMonitor only works with "
                    "--schedule=train_and_evaluate")
  use_validation_monitor = (
      schedule == "train_and_evaluate" and min_eval_frequency)
  # Distributed early stopping
  local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
  use_early_stopping = (
      schedule not in local_schedules and eval_early_stopping_steps)
  train_hooks, eval_hooks = trainer_lib.create_hooks(
      use_tfdbg=use_tfdbg,
      use_dbgprofile=use_dbgprofile,
      dbgprofile_kwargs=dbgprofile_kwargs,
      use_validation_monitor=use_validation_monitor,
      validation_monitor_kwargs=validation_monitor_kwargs,
      use_early_stopping=use_early_stopping,
      early_stopping_kwargs=early_stopping_kwargs)

  hook_context = trainer_lib.HookContext(
      estimator=estimator, problem=problem, hparams=hparams)

  train_hooks += t2t_model.T2TModel.get_train_hooks(model_name, hook_context)
  eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name, hook_context)
  if additional_train_hooks:
    train_hooks += additional_train_hooks
  if additional_eval_hooks:
    eval_hooks += additional_eval_hooks

  # Convert any tf.contrib.learn monitors in the lists into session hooks.
  train_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      train_hooks, estimator)
  eval_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      eval_hooks, estimator)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn, max_steps=train_steps, hooks=train_hooks)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=eval_steps,
      hooks=eval_hooks,
      # A pure "evaluate" schedule starts immediately; otherwise wait 120s.
      start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
      throttle_secs=eval_throttle_seconds,
      exporters=exporter)

  return trainer_lib.T2TExperiment(estimator, hparams, train_spec, eval_spec,
                                   use_validation_monitor, decode_hparams)

# class MelodyToPianoPerformanceProblem(score2perf.AbsoluteMelody2PerfProblem):
#     @property
#     def add_eos_symbol(self):
#         return True

# Module-level setup for the unconditional model: instantiate the problem,
# fetch its feature encoders and build an estimator from fixed
# hyperparameters.
uncondi_problem = PianoPerformanceLanguageModelProblem()
uncondi_encoders = uncondi_problem.get_feature_encoders()

# melody_problem = MelodyToPianoPerformanceProblem()
# melody_encoders = melody_problem.get_feature_encoders()

# Set up HParams.
# NOTE(review): `hparams_set` is not defined in the visible part of this
# script - it must be assigned before this point.
hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
trainer_lib.add_problem_hparams(hparams, uncondi_problem)
hparams.num_hidden_layers = 16
hparams.sampling_method = 'random'

# Set up decoding HParams.
decode_hparams = decoding.decode_hparams()
decode_hparams.alpha = 0.0
decode_hparams.beam_size = 1

# Create Estimator.
# NOTE(review): `model_name` is likewise expected to be defined earlier.
run_config = trainer_lib.create_run_config(hparams)
estimator = trainer_lib.create_estimator(model_name,
                                         hparams,
                                         run_config,
                                         decode_hparams=decode_hparams)
Ejemplo n.º 17
0
def main(argv):
  """Measure a trained model's accuracy under an adversarial attack.

  Builds the evaluation input pipeline for `FLAGS.problem`, restores the
  latest checkpoint found in `FLAGS.output_dir` (falling back to
  `FLAGS.checkpoint_path`), then logs the accuracy on the clean inputs
  followed by the accuracy on adversarial inputs for every value in
  `attack_params.attack_epsilons`.

  Returns early, right after launching the job, when `FLAGS.cloud_mlengine`
  is set.

  Args:
    argv: Command-line arguments; everything after the first entry is passed
      to `t2t_trainer.set_hparams_from_args`.

  Raises:
    ValueError: If no checkpoint is found in the checkpoint directory, or if
      `FLAGS.eval_steps` is less than 1.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])
  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
  attack_params = create_attack_params()
  attack_params.add_hparam("eps", 0.0)

  config = t2t_trainer.create_run_config(hparams)
  params = {"batch_size": hparams.batch_size}

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)
  dataset = input_fn(params, config).repeat()
  features, _ = dataset.make_one_shot_iterator().get_next()
  # The swap below is deliberate: presumably the "_rev" problem exposes the
  # model inputs under "targets" and the labels under "inputs" -- confirm.
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  labels = tf.squeeze(labels)

  sess = tf.Session()

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

  # With --ignore_incorrect, examples the model already misclassifies on clean
  # inputs get weight 0 in every accuracy computation below.
  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1)
    preds = tf.squeeze(preds)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  # Restore weights. Fail with the directory in the message instead of
  # handing `None` to the saver when no checkpoint exists.
  saver = tf.train.Saver()
  checkpoint_dir = os.path.expanduser(FLAGS.output_dir or
                                      FLAGS.checkpoint_path)
  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
  if latest_checkpoint is None:
    raise ValueError("No checkpoint found in %r." % checkpoint_dir)
  saver.restore(sess, latest_checkpoint)

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, labels, mask):
    """Return the streaming accuracy of the model on `x` over eval_steps runs."""
    num_batches = FLAGS.eval_steps
    if not num_batches or num_batches < 1:
      # Without at least one batch there is no accuracy value to return.
      raise ValueError(
          "--eval_steps must be >= 1 to compute accuracy, got %r." %
          num_batches)
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=labels.dtype)
    _, acc_update_op = tf.metrics.accuracy(
        labels=labels, predictions=preds, weights=mask)
    # Metric accumulators are local variables; (re)initialize them so every
    # call starts from zero.
    sess.run(tf.local_variables_initializer())
    for _ in range(num_batches):
      acc = sess.run(acc_update_op)
    return acc

  # Clean accuracy first, recorded as epsilon 0.
  acc = compute_accuracy(inputs, labels, acc_mask)
  epsilon_acc_pairs = [(0.0, acc)]
  for epsilon in attack_params.attack_epsilons:
    attack_params.eps = epsilon
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    tf.logging.info("Accuracy @ eps=%f: %f", epsilon, acc)
Ejemplo n.º 18
0
def run():
    """
    Validate the flags, build the Transformer estimator and start sampling.

    When ``FLAGS.primer_path`` is set, the primer is encoded (without its end
    token) and its length is subtracted from ``FLAGS.decode_length``; otherwise
    generation starts from an empty sequence.
    :raises:
        ValueError: if required flags are missing or invalid, or if the primer
            has at least ``FLAGS.decode_length`` events
    """
    if FLAGS.model_path is None:
        raise ValueError("Required Transformer pre-trained model path.")
    if FLAGS.output_dir is None:
        raise ValueError("Required MIDI output directory.")
    if FLAGS.decode_length <= 0:
        raise ValueError("Decode length must be > 0.")

    lm_problem = PianoPerformanceLanguageModelProblem()
    encoders = lm_problem.get_feature_encoders()

    # Defaults for the "no primer" case; overridden below when one is given.
    primer_ns = music_pb2.NoteSequence()
    primer_events = []
    if FLAGS.primer_path is not None:
        primer_ns = get_primer_ns(FLAGS.primer_path)
        encoded_primer = encoders["targets"].encode_note_sequence(primer_ns)
        # The encoder appends an end token; drop it from the primer.
        primer_events = encoded_primer[:-1]
        if not len(primer_events) < FLAGS.decode_length:
            raise ValueError(
                "Primer has more or equal events than max sequence length.")

    events_to_generate = FLAGS.decode_length - len(primer_events)

    # Model hyperparameters (the hparams set is hard-coded; see "Add flag").
    model_hparams = trainer_lib.create_hparams(
        hparams_set="transformer_tpu")  # Add flag
    trainer_lib.add_problem_hparams(model_hparams, lm_problem)
    model_hparams.num_hidden_layers = NUM_HIDDEN_LAYERS
    model_hparams.sampling_method = SAMPLING_METHOD

    # Decoding hyperparameters.
    decoding_hparams = decoding.decode_hparams()
    decoding_hparams.alpha = ALPHA
    decoding_hparams.beam_size = BEAM_SIZE

    LOGGER.info("Loading model")
    estimator = trainer_lib.create_estimator(
        MODEL_NAME,
        model_hparams,
        trainer_lib.create_run_config(model_hparams),
        decode_hparams=decoding_hparams,
    )

    generate(estimator, encoders, events_to_generate, primer_events, primer_ns)
def main(argv):
    """Measure a trained model's accuracy under an adversarial attack.

    Builds the evaluation input pipeline for `FLAGS.problem`, restores the
    latest checkpoint found in `FLAGS.output_dir` (falling back to
    `FLAGS.checkpoint_path`), then logs the accuracy on the clean inputs
    followed by the accuracy on adversarial inputs for every value in
    `attack_params.attack_epsilons`.

    Returns early, right after launching the job, when
    `FLAGS.cloud_mlengine` is set.

    Args:
        argv: Command-line arguments; everything after the first entry is
            passed to `t2t_trainer.set_hparams_from_args`.

    Raises:
        ValueError: If no checkpoint is found in the checkpoint directory, or
            if `FLAGS.eval_steps` is less than 1.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])
    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)
    attack_params = create_attack_params()
    attack_params.add_hparam("eps", 0.0)

    config = t2t_trainer.create_run_config(hparams)
    params = {"batch_size": hparams.batch_size}

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")
    input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                               hparams)
    dataset = input_fn(params, config).repeat()
    features, _ = dataset.make_one_shot_iterator().get_next()
    # The swap below is deliberate: presumably the "_rev" problem exposes the
    # model inputs under "targets" and the labels under "inputs" -- confirm.
    inputs, labels = features["targets"], features["inputs"]
    inputs = tf.to_float(inputs)
    labels = tf.squeeze(labels)

    sess = tf.Session()

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams, use_tpu=FLAGS.use_tpu)
    ch_model = adv_attack_utils.T2TAttackModel(model_fn, params, config)

    # With --ignore_incorrect, examples the model already misclassifies on
    # clean inputs get weight 0 in every accuracy computation below.
    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        preds = tf.argmax(probs, -1)
        preds = tf.squeeze(preds)
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    # Restore weights. Fail with the directory in the message instead of
    # handing `None` to the saver when no checkpoint exists.
    saver = tf.train.Saver()
    checkpoint_dir = os.path.expanduser(FLAGS.output_dir
                                        or FLAGS.checkpoint_path)
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is None:
        raise ValueError("No checkpoint found in %r." % checkpoint_dir)
    saver.restore(sess, latest_checkpoint)

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, labels, mask):
        """Return the streaming accuracy on `x` over eval_steps runs."""
        num_batches = FLAGS.eval_steps
        if not num_batches or num_batches < 1:
            # Without at least one batch there is no accuracy to return.
            raise ValueError(
                "--eval_steps must be >= 1 to compute accuracy, got %r." %
                num_batches)
        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=labels.dtype)
        _, acc_update_op = tf.metrics.accuracy(labels=labels,
                                               predictions=preds,
                                               weights=mask)
        # Metric accumulators are local variables; (re)initialize them so
        # every call starts from zero.
        sess.run(tf.local_variables_initializer())
        for _ in range(num_batches):
            acc = sess.run(acc_update_op)
        return acc

    # Clean accuracy first, recorded as epsilon 0.
    acc = compute_accuracy(inputs, labels, acc_mask)
    epsilon_acc_pairs = [(0.0, acc)]
    for epsilon in attack_params.attack_epsilons:
        attack_params.eps = epsilon
        adv_x = attack.generate(inputs,
                                y=one_hot_labels,
                                **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        tf.logging.info("Accuracy @ eps=%f: %f", epsilon, acc)
Ejemplo n.º 20
0
def music_generator(primer='erik_gnossienne',
                    primer_begin_buffer=10,
                    primer_length=90,
                    output_path='.',
                    filename='./public/output'):
    """Generate a piano continuation of one of the built-in primer MIDI files.

    Loads the unconditional transformer checkpoint, encodes the chosen primer,
    samples one sequence from the model and writes three files under
    `output_path`: the processed primer (MIDI), the sampled sequence (MIDI)
    and the sampled sequence rendered to audio (WAV).

    Args:
        primer: Key into the primer table defined in this function.
        primer_begin_buffer: Start time, in seconds, of the excerpt taken
            from the primer when it is longer than `primer_length`.
        primer_length: Maximum primer duration in seconds.
        output_path: Directory the output file names are joined onto.
        filename: Output name without extension; may include a directory
            part, which is kept for all three output files.

    Raises:
        KeyError: If `primer` is not a key of the primer table.
    """
    import os  # Local import: only needed to split `filename` below.

    SF2_PATH = './models/Yamaha-C5-Salamander-JNv5.1.sf2'
    SAMPLE_RATE = 16000

    # Decode a list of IDs, ignoring everything from the first EOS onwards.
    def decode(ids, encoder):
        ids = list(ids)
        if text_encoder.EOS_ID in ids:
            ids = ids[:ids.index(text_encoder.EOS_ID)]
        return encoder.decode(ids)

    # Build '<output_path>/<dir of filename>/<prefix>_<base of filename>.mid'.
    # The prefix goes on the base name: prefixing the whole `filename` would
    # turn './public/output' into the bogus path 'primer_./public/output.mid'.
    def prefixed_midi_path(prefix):
        directory, base = os.path.split(filename)
        return join(output_path, directory, '{}_{}.mid'.format(prefix, base))

    model_name = 'transformer'
    hparams_set = 'transformer_tpu'
    ckpt_path = './models/checkpoints/unconditional_model_16.ckpt'

    class PianoPerformanceLanguageModelProblem(score2perf.Score2PerfProblem):
        @property
        def add_eos_symbol(self):
            return True

    problem = PianoPerformanceLanguageModelProblem()
    unconditional_encoders = problem.get_feature_encoders()

    # Set up HParams.
    hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
    trainer_lib.add_problem_hparams(hparams, problem)
    hparams.num_hidden_layers = 16
    hparams.sampling_method = 'random'

    # Set up decoding HParams.
    decode_hparams = decoding.decode_hparams()
    decode_hparams.alpha = 0.0
    decode_hparams.beam_size = 1

    # Create Estimator.
    run_config = trainer_lib.create_run_config(hparams)
    estimator = trainer_lib.create_estimator(model_name,
                                             hparams,
                                             run_config,
                                             decode_hparams=decode_hparams)

    # Rebound further down once the primer has been encoded.
    targets = []
    decode_length = 0

    # Create input generator (so we can adjust priming and decode length on
    # the fly). `targets` and `decode_length` are locals of this function, so
    # they are read through the closure on every iteration; declaring them
    # `global` would make the generator look them up at module level instead.
    def input_generator():
        while True:
            yield {
                'targets': np.array([targets], dtype=np.int32),
                'decode_length': np.array(decode_length, dtype=np.int32)
            }

    # Start the Estimator, loading from the specified checkpoint.
    input_fn = decoding.make_input_fn_from_generator(input_generator())
    unconditional_samples = estimator.predict(input_fn,
                                              checkpoint_path=ckpt_path)

    # "Burn" one.
    _ = next(unconditional_samples)

    filenames = {
        'C major arpeggio': './models/primers/c_major_arpeggio.mid',
        'C major scale': './models/primers/c_major_scale.mid',
        'Clair de Lune': './models/primers/clair_de_lune.mid',
        'Classical':
        'audio_midi/Classical_Piano_piano-midi.de_MIDIRip/bach/bach_846_format0.mid',
        'erik_gymnopedie': 'audio_midi/erik_satie/gymnopedie_1_(c)oguri.mid',
        'erik_gymnopedie_2': 'audio_midi/erik_satie/gymnopedie_2_(c)oguri.mid',
        'erik_gymnopedie_3': 'audio_midi/erik_satie/gymnopedie_3_(c)oguri.mid',
        'erik_gnossienne': 'audio_midi/erik_satie/gnossienne_1_(c)oguri.mid',
        'erik_gnossienne_2': 'audio_midi/erik_satie/gnossienne_2_(c)oguri.mid',
        'erik_gnossienne_3': 'audio_midi/erik_satie/gnossienne_3_(c)oguri.mid',
        'erik_gnossienne_dery':
        'audio_midi/erik_satie/gnossienne_1_(c)dery.mid',
        'erik_gnossienne_dery_2':
        'audio_midi/erik_satie/gnossienne_2_(c)dery.mid',
        'erik_gnossienne_dery_3':
        'audio_midi/erik_satie/gnossienne_3_(c)dery.mid',
        'erik_gnossienne_dery_5':
        'audio_midi/erik_satie/gnossienne_5_(c)dery.mid',
        'erik_gnossienne_dery_6':
        'audio_midi/erik_satie/gnossienne_6_(c)dery.mid',
        '1': 'audio_midi/erik_satie/1.mid',
        '2': 'audio_midi/erik_satie/2.mid',
        '3': 'audio_midi/erik_satie/3.mid',
        '4': 'audio_midi/erik_satie/4.mid',
        '5': 'audio_midi/erik_satie/5.mid',
        '6': 'audio_midi/erik_satie/6.mid',
        '7': 'audio_midi/erik_satie/7.mid',
        '8': 'audio_midi/erik_satie/8.mid',
        '9': 'audio_midi/erik_satie/9.mid',
        '10': 'audio_midi/erik_satie/10.mid',
    }

    primer_ns = mm.midi_file_to_note_sequence(filenames[primer])
    # Handle sustain pedal in the primer.
    primer_ns = mm.apply_sustain_control_changes(primer_ns)

    # Trim to desired number of seconds.
    max_primer_seconds = primer_length
    if primer_ns.total_time > max_primer_seconds:
        print('Primer is longer than %d seconds, truncating.' %
              max_primer_seconds)
        primer_ns = mm.extract_subsequence(
            primer_ns, primer_begin_buffer,
            max_primer_seconds + primer_begin_buffer)

    # Remove drums from primer if present.
    if any(note.is_drum for note in primer_ns.notes):
        print('Primer contains drums; they will be removed.')
        notes = [note for note in primer_ns.notes if not note.is_drum]
        del primer_ns.notes[:]
        primer_ns.notes.extend(notes)

    # Set primer instrument and program.
    for note in primer_ns.notes:
        note.instrument = 1
        note.program = 0

    mm.sequence_proto_to_midi_file(primer_ns, prefixed_midi_path('primer'))

    targets = unconditional_encoders['targets'].encode_note_sequence(primer_ns)

    # Remove the end token from the encoded primer.
    targets = targets[:-1]

    # NOTE(review): the decode budget uses 10000 while the warning below uses
    # 4096 as the maximum sequence length -- confirm which limit is intended.
    decode_length = max(0, 10000 - len(targets))
    if len(targets) >= 4096:
        print(
            'Primer has more events than maximum sequence length; nothing will be generated.'
        )

    # Generate sample events.
    sample_ids = next(unconditional_samples)['outputs']

    # Decode to NoteSequence (the decoded result is loaded as a MIDI file).
    midi_filename = decode(sample_ids,
                           encoder=unconditional_encoders['targets'])
    ns = mm.midi_file_to_note_sequence(midi_filename)
    print('Sample IDs: {}'.format(sample_ids))
    print('Sample IDs length: {}'.format(len(sample_ids)))
    print('Encoder: {}'.format(unconditional_encoders['targets']))
    print('Unconditional Samples: {}'.format(unconditional_samples))

    # Only the sampled sequence is rendered; the primer is not prepended.
    continuation_ns = ns
    audio = mm.fluidsynth(continuation_ns,
                          sample_rate=SAMPLE_RATE,
                          sf2_path=SF2_PATH)

    # Scale the audio by the int16 maximum and store it as 16-bit samples.
    normalizer = float(np.iinfo(np.int16).max)
    array_of_ints = np.array(np.asarray(audio) * normalizer, dtype=np.int16)

    wavfile.write(join(output_path, filename + '.wav'), SAMPLE_RATE,
                  array_of_ints)
    print('[+] Output stored as {}'.format(filename + '.wav'))
    mm.sequence_proto_to_midi_file(continuation_ns,
                                   prefixed_midi_path('continuation'))
Ejemplo n.º 21
0
def main(argv):
  """Measure a trained model's accuracy under an adversarial attack.

  Restores the attacked model from the latest checkpoint in
  `FLAGS.output_dir` and, for every value in `attack_params.attack_epsilons`,
  generates adversarial inputs and logs the resulting accuracy. With
  `FLAGS.surrogate_attack` the adversarial inputs are generated against a
  second, surrogate model (initialized from `FLAGS.surrogate_output_dir`
  under the "surrogate" scope) and accuracies are reported as a pair: the
  attacked model first, the surrogate second.

  Returns early, right after launching the job, when `FLAGS.cloud_mlengine`
  is set.

  Args:
    argv: Command-line arguments; everything after the first entry is passed
      to `t2t_trainer.set_hparams_from_args`.

  Raises:
    ValueError: If a required checkpoint cannot be found, or if
      `FLAGS.eval_steps` is less than 1.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  def latest_checkpoint_in(directory):
    """Return the newest checkpoint under `directory` or raise ValueError."""
    checkpoint = tf.train.latest_checkpoint(os.path.expanduser(directory))
    if checkpoint is None:
      raise ValueError("No checkpoint found in %r." % directory)
    return checkpoint

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    # Map the checkpoint's variables into the "surrogate/" scope; the values
    # are loaded when the initializer runs.
    tf.contrib.framework.init_from_checkpoint(
        latest_checkpoint_in(FLAGS.surrogate_output_dir), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  # Everything created so far belongs to the surrogate (if any); remember it
  # so that only the attacked model's variables are restored below.
  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  # With --ignore_incorrect, examples the model already misclassifies on clean
  # inputs get weight 0 in every accuracy computation below.
  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  saver.restore(sess, latest_checkpoint_in(FLAGS.output_dir))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy on `x` over `FLAGS.eval_steps` batches."""
    num_batches = FLAGS.eval_steps
    if not num_batches or num_batches < 1:
      # Without at least one batch there is no accuracy value to return.
      raise ValueError(
          "--eval_steps must be >= 1 to compute accuracy, got %r." %
          num_batches)

    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      # Also track the surrogate's accuracy; both update ops run together.
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    # Metric accumulators are local variables; (re)initialize them so every
    # call starts from zero.
    sess.run(tf.local_variables_initializer())
    for i in range(num_batches):
      tf.logging.info("\tEvaluating batch [%d / %d]", i + 1, num_batches)
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)", acc[0], acc[1])
    else:
      tf.logging.info("\tFinal acc: %.4f", acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f", epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)", epsilon, acc[0], acc[1])
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f", epsilon, acc)