def __init__(self, wmodel, temperature=1, logger=None, **unused_kwargs):
    """Set up the state shared by BaseSampler instances.

    Args:
      wmodel: WrappedModel instance; its `sess` and `model` attributes are
        used to compute predictions.
      temperature: sampling temperature, stored as `self.temperature`.
      logger: Logger instance; when None, a `lib_logging.NoLogger()` is used.
      **unused_kwargs: any extra keyword arguments; accepted and ignored.
    """
    self.wmodel = wmodel
    self.temperature = temperature
    if logger is None:
        logger = lib_logging.NoLogger()
    self.logger = logger

    def run_model(pianorolls, masks):
        # Feed the pianorolls and masks into the model's placeholders and
        # fetch its prediction tensor through the wrapped session.
        model = self.wmodel.model
        feed = {model.pianorolls: pianorolls, model.masks: masks}
        return self.wmodel.sess.run(model.predictions, feed)

    # Callers use the RobustPredictor wrapper rather than the raw session call.
    self.predictor = lib_tfutil.RobustPredictor(run_model)
def __init__(self, wmodel, chronological):
    """Set up the state shared by BaseEvaluator instances.

    Args:
      wmodel: WrappedModel instance; its `sess` and `model` attributes are
        used to compute predictions.
      chronological: whether evaluation proceeds in chronological order
        (True) or in any order (False); stored as `self.chronological`.
    """
    self.wmodel = wmodel
    self.chronological = chronological

    def evaluate_batch(pianorolls, masks):
        # Run the model's prediction tensor with the given pianorolls and
        # masks fed into its placeholders.
        inputs = {
            self.wmodel.model.pianorolls: pianorolls,
            self.wmodel.model.masks: masks,
        }
        return self.wmodel.sess.run(self.wmodel.model.predictions,
                                    feed_dict=inputs)

    # Callers use the RobustPredictor wrapper rather than the raw session call.
    self.predictor = lib_tfutil.RobustPredictor(evaluate_batch)