def __call__(self, predictions, masks, separate_instruments=True):
    """Select the next variable(s) to sample, uniformly among masked entries.

    Args:
      predictions: Model outputs. Not read by this selector; presumably
        accepted so the signature matches other selectors -- confirm against
        the base class.
      masks: Rank-4 array unpacked as (bb, tt, pp, ii); presumably
        (batch, time, pitch, instrument) -- TODO confirm. Entries still to be
        sampled carry equal positive weight (presumably 1), others 0.
      separate_instruments: If True, choose one (time, instrument) cell per
        example and select every masked pitch in that cell. Otherwise choose a
        single (time, pitch, instrument) entry per example.

    Returns:
      Array broadcastable to / shaped like ``masks``: the sampled one-hot
      selection multiplied elementwise by ``masks``.
    """
    bb, tt, pp, ii = masks.shape
    if separate_instruments:
        # Collapse the pitch axis: every (time, instrument) cell that contains
        # a masked variable gets the same positive weight, so sampling from
        # the normalized weights is uniform over those cells.
        weights = masks.max(axis=2).reshape([bb, tt * ii])
        selection = lib_util.sample(weights, axis=1, onehot=True)
        # Keep a singleton pitch axis so the product below broadcasts the
        # chosen cell across all pitches.
        selection = selection.reshape([bb, tt, 1, ii])
    else:
        # Flatten time, pitch AND instrument so this also works when ii > 1
        # (for ii == 1 it is the same as flattening time x pitch).
        weights = masks.reshape([bb, tt * pp * ii])
        selection = lib_util.sample(weights, axis=1, onehot=True)
        selection = selection.reshape([bb, tt, pp, ii])
    # Intersect with mask to avoid selecting outside of the mask, e.g. in case
    # some masks[b] is zero everywhere.
    # This can happen inside blocked Gibbs, where different examples have
    # different block sizes.
    return selection * masks
def sample_predictions(self, predictions, temperature=None):
    """Sample from model outputs.

    Args:
      predictions: Model outputs to sample from.
      temperature: Sampling temperature; when None, ``self.temperature`` is
        used instead.

    Returns:
      If ``self.separate_instruments`` is false, the result of
      ``lib_util.sample_bernoulli`` applied to half of ``predictions``.
      Otherwise, a one-hot sample drawn along axis 2 via ``lib_util.sample``.
    """
    if temperature is None:
        temperature = self.temperature
    if not self.separate_instruments:
        # Independent Bernoulli draws on the halved outputs.
        halved = 0.5 * predictions
        return lib_util.sample_bernoulli(halved, temperature=temperature)
    # One categorical (one-hot) draw along axis 2.
    return lib_util.sample(
        predictions, axis=2, onehot=True, temperature=temperature)