コード例 #1
0
ファイル: coconet_sample.py プロジェクト: vanton/magenta
    def run(self, tuple_in):
        """Fills in the silences of a MIDI input using Gibbs sampling.

        Args:
          tuple_in: a pair ``(shape, midi_in)``.  Both values are forwarded to
              ``self.decoder.encode_midi_to_pianoroll(midi_in, shape)``.

        Returns:
          The pianorolls returned by the Gibbs sampler.  Before sampling, the
          masked input is logged under the "context" section; the sampled
          output is logged under the "result" section.
        """
        target_shape, midi = tuple_in
        encoded = self.decoder.encode_midi_to_pianoroll(midi, target_shape)

        # The CompletionMasker selects the positions to fill in (the silences).
        completion_masks = lib_sampling.CompletionMasker()(encoded)

        # Build the Gibbs sampler's components one at a time, in the same
        # order a single nested call would evaluate its keyword arguments.
        bernoulli_masker = lib_sampling.BernoulliMasker()
        inner_sampler = self.make_sampler("independent",
                                          temperature=FLAGS.temperature)
        yao_schedule = lib_sampling.YaoSchedule()
        gibbs_sampler = self.make_sampler("gibbs",
                                          masker=bernoulli_masker,
                                          sampler=inner_sampler,
                                          schedule=yao_schedule)

        with self.logger.section("context"):
            # Apply each mask to its pianoroll.  The masked input is logged
            # both as the pianoroll and as the "prediction".
            masked_input = np.array(
                list(map(lib_mask.apply_mask, encoded, completion_masks)))
            self.logger.log(pianorolls=masked_input,
                            masks=completion_masks,
                            predictions=masked_input)

        sampled = gibbs_sampler(encoded, completion_masks)

        with self.logger.section("result"):
            self.logger.log(pianorolls=sampled,
                            masks=completion_masks,
                            predictions=sampled)
        return sampled
コード例 #2
0
ファイル: lib_tfsampling.py プロジェクト: yati989/magenta
    def run(self,
            pianorolls,
            masks=None,
            sample_steps=0,
            current_step=0,
            total_gibbs_steps=0,
            temperature=0.99):
        """Given input pianorolls, runs Gibbs sampling to fill in the rest.

        When total_gibbs_steps is 0, total_gibbs_steps is set to
        time * instruments.  If faster sampling is desired at the expense of
        sample quality, total_gibbs_steps can be explicitly set to a lower
        number, possibly to the value of sample_steps if you do not plan on
        stopping sampling early to obtain intermediate results.

        This function can be used to return intermediate results by setting
        sample_steps to the step at which results should be returned and
        leaving total_gibbs_steps as 0.

        To continue sampling from intermediate results, set current_step to
        the number of steps already taken and feed in the intermediate
        pianorolls, again leaving total_gibbs_steps as 0.

        Builds the graph and restores the checkpoint if necessary.

        Args:
          pianorolls: a 4D numpy array of shape (batch, time, pitch, instrument).
          masks: a 4D numpy array of the same shape as pianorolls, with 1s
              indicating positions to mask out.  If None, masks are computed
              with lib_sampling.CompletionMasker, which places 1s where there
              are no notes, indicating to the model they should be filled in.
          sample_steps: an integer indicating the number of steps to sample in
              this call.  If set to 0, it defaults to total_gibbs_steps.
          current_step: an integer indicating how many steps may already have
              been sampled before this call.
          total_gibbs_steps: an integer indicating the total number of steps
              that a complete sampling procedure would take.
          temperature: a float indicating the temperature for sampling from
              the softmax.

        Returns:
          A dictionary with two entries.  "pianorolls" is a 4D numpy array of
          the sampled results.  "time_taken" is the wall-clock time spent in
          sampling, in minutes.
        """
        if self.sess is None:
            # Build the graph and restore the checkpoint lazily on first use.
            self.instantiate_sess_and_restore_checkpoint()

        outer_masks = masks
        if outer_masks is None:
            # Default: ask the model to fill in every position without notes.
            outer_masks = lib_sampling.CompletionMasker()(pianorolls)

        start_time = time.time()
        new_piece = self.sess.run(
            self.samples,
            feed_dict={
                self.placeholders["pianorolls"]: pianorolls,
                self.placeholders["outer_masks"]: outer_masks,
                self.placeholders["sample_steps"]: sample_steps,
                self.placeholders["total_gibbs_steps"]: total_gibbs_steps,
                self.placeholders["current_step"]: current_step,
                self.placeholders["temperature"]: temperature
            })

        label = "independent blocked gibbs"
        time_taken = (time.time() - start_time) / 60.0  # seconds -> minutes
        # Pass the arguments to the logger instead of pre-formatting with `%`,
        # so the message is only built when INFO-level logging is enabled.
        tf.logging.info("exit  %s (%.3fmin)", label, time_taken)
        return dict(pianorolls=new_piece, time_taken=time_taken)