Beispiel #1
0
    def select_action(self, state, training=True):
        """Choose an action epsilon-greedily using the policy ensemble.

        When ``training`` is true, epsilon is first annealed linearly by
        ``(INITIAL - FINAL) / EPS_DECAY`` and floored at the final value.

        With probability ``1 - eps`` every ensemble member is evaluated on
        ``state``; the per-action mean and variance of the Q-values are
        computed across members, each action is scored against the action
        with the highest mean via ``1 - cdf(gap / summed variance)`` and the
        action is drawn from a multinomial over those scores. Otherwise a
        uniformly random action is taken.

        :param state: input forwarded to ``self._policy_net`` (moved to
            ``self._device`` first).
        :param training: whether to anneal epsilon on this call.
        :return: ``(action, eps, action_mean, action_var)``; the last two are
            ``None`` when the action was chosen at random.
        """
        draw = random.random()
        if training:
            # Linear annealing, never dropping below the final epsilon.
            step = (self._INITIAL_EPSILON -
                    self._FINAL_EPSILON) / self._EPS_DECAY
            self._eps = max(self._eps - step, self._FINAL_EPSILON)

        if not (draw > self._eps):
            # Exploration: uniform random action, no ensemble statistics.
            return random.randrange(self._n_actions), self._eps, None, None

        with torch.no_grad():
            n_members = self._policy_net.get_num_ensembles()
            q_vals = torch.zeros((n_members, self._n_actions))
            state = state.to(self._device)
            for member in range(n_members):
                member_q = self._policy_net(state, ens_num=member)
                q_vals[member, :] = member_q.to('cpu').squeeze(0)

            action_mean = torch.mean(q_vals, 0)
            action_var = torch.var(q_vals, 0)
            best = torch.argmax(action_mean)

            # Score each action relative to the best-mean action.
            score = torch.zeros((self._n_actions))
            for idx in range(self._n_actions):
                gap = action_mean[best] - action_mean[idx]
                spread = action_var[best] + action_var[idx]
                score[idx] = 1. - self._dist.cdf(gap / spread)

            one_hot = Multinomial(1, score).sample()
            a = one_hot.argmax().item()

        return a, self._eps, action_mean, action_var
 def choosePathByAlphas(self):
     """Sample a layer partition from the alphas and install it as the path.

     The first entry of ``self.alphas()`` is used as multinomial logits and
     ``self.nLayers()`` as the number of draws; the resulting per-category
     counts are cast to ``int32`` and handed to ``self._setPartitionPath``.
     """
     layer_logits = self.alphas()[0]
     n_draws = self.nLayers()
     # One multinomial draw yields counts per category that sum to n_draws.
     counts = Multinomial(total_count=n_draws, logits=layer_logits).sample()
     self._setPartitionPath(counts.type(int32))
Beispiel #3
0
def generate_text(model, input_string, num_char=50, top_k=1, num_draws=50):
    """Extend ``input_string`` by ``num_char`` sampled characters.

    The lower-cased input is encoded with ``dataset_test.to_int`` and fed to
    ``model.predict``. At every step ``num_draws`` multinomial draws are taken
    from the predicted probabilities, the ``top_k`` most frequently drawn
    classes are kept, and one of them is picked uniformly at random. The
    chosen class is appended to the model input and decoded with
    ``dataset_test.to_char``.

    :param model: object exposing ``eval()``, ``reset_hidden(batch_size)`` and
        ``predict(tensor)``; ``predict`` is assumed to return class
        probabilities with a leading batch axis of size 1 -- TODO confirm.
    :param input_string: seed text.
    :param num_char: number of characters to generate.
    :param top_k: how many of the most-drawn classes to choose between.
    :param num_draws: number of multinomial draws per step (previously the
        hard-coded constant 50).
    :return: ``input_string`` followed by the generated characters.
    """
    num_string = np.array(
        [dataset_test.to_int(x) for x in input_string.lower()])
    tensor_string = torch.from_numpy(num_string).unsqueeze(0)
    model.eval()
    generated_text = []

    model.reset_hidden(1)

    # Pure inference: no autograd graph is needed while sampling.
    with torch.no_grad():
        for _ in range(num_char):
            probs = model.predict(tensor_string).squeeze(0)

            counts = Multinomial(total_count=num_draws, probs=probs).sample()

            # Classes with the highest *sampled counts* (not raw probability).
            indices = torch.topk(counts, k=top_k).indices
            # Uniformly pick one of the top_k candidates.
            rd = np.random.randint(low=0, high=top_k, size=1)
            chosen = indices[rd]

            tensor_string = torch.cat(
                (tensor_string, chosen.unsqueeze(1)), dim=1)
            generated_text.append(dataset_test.to_char(chosen.numpy()[0]))

    return input_string + ''.join(generated_text)
Beispiel #4
0
    def compute_loss_for_batch(self,
                               data,
                               model,
                               K=K,
                               testing_mode=False,
                               alpha=alpha):
        """Compute the K-sample variational loss for one batch.

        ``data`` is unpacked as ``(B, 1, H, W)``; every observation is
        replicated ``K`` times and flattened to ``(B*K, H*W)`` before being
        passed through ``model.encode`` / ``reparameterize`` / ``decode``.
        The per-sample log weight ``log p(z) + log p(x|z) - log q(z|x)`` is
        then arranged according to the module-level ``model_type``
        ('iwae', 'vae', 'general_alpha', 'vralpha' or 'vrmax');
        ``testing_mode`` always selects the 'iwae' arrangement.

        :return: ``(decoded, mu, logstd, loss)`` where ``loss`` is the
            negated batch sum of the per-datapoint objective.
        """
        B, _, H, W = data.shape
        # (B, 1, H, W) -> (B, K, H, W) -> (B*K, H*W)
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)

        mu, logstd = model.encode(data_k_vec)
        z = model.reparameterize(mu, logstd)

        # log q(z|x) under the encoder's Gaussian
        log_q = compute_log_probabitility_gaussian(z, mu, logstd)

        # log p(z): the same helper called with all-zero second and third
        # arguments (built here as ``requires_grad=False`` tensors).
        prior_mu = torch.zeros_like(z, requires_grad=False)
        prior_logstd = torch.zeros_like(z, requires_grad=False)
        log_p_z = compute_log_probabitility_gaussian(z, prior_mu, prior_logstd)

        decoded = model.decode(z)
        if discrete_data:
            log_p = compute_log_probabitility_bernoulli(decoded, data_k_vec)
        else:
            # Gaussian likelihood with an all-zero third argument; sigma is
            # not predicted by the model at the moment.
            log_p = compute_log_probabitility_gaussian(
                decoded, data_k_vec, torch.zeros_like(decoded))

        log_w = log_p_z + log_p - log_q
        if model_type == 'iwae' or testing_mode:
            log_w_matrix = log_w.view(B, K)
        elif model_type == 'vae':
            # Each sample is its own row; the 1/K factor averages over the
            # K replicas once the rows are summed.
            log_w_matrix = log_w.view(B * K, 1) * 1 / K
        elif model_type in ('general_alpha', 'vralpha'):
            log_w_matrix = log_w.view(-1, K) * (1 - alpha)
        elif model_type == 'vrmax':
            log_w_matrix = log_w.view(-1, K).max(axis=1, keepdim=True).values

        # Row-wise normalised weights, shifted by the row max for stability.
        row_max = torch.max(log_w_matrix, 1, keepdim=True)[0]
        ws_matrix = torch.exp(log_w_matrix - row_max)
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if model_type == 'vralpha' and not testing_mode:
            # Pick one sample per row with probability given by its weight.
            picked = Multinomial(1, ws_norm).sample().argmax(1, keepdim=True)
            ws_sum_per_datapoint = log_w_matrix.gather(1, picked)
        else:
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if model_type in ["general_alpha", "vralpha"] and not testing_mode:
            # Undo the (1 - alpha) scaling applied above.
            ws_sum_per_datapoint /= (1 - alpha)

        loss = -torch.sum(ws_sum_per_datapoint)

        return decoded, mu, logstd, loss
def sampled_kwinners2d(x, k, temperature, relu=False, inplace=False):
    """
    A stochastic K-winner take all function for Conv2d layers with sparse output.
    Winners are sampled from a softmax distribution over the activations using
    ``k`` multinomial draws; every unit that is never drawn is set to zero.
    Because the draws are made with replacement, fewer than ``k`` distinct units
    may remain active.

    :param x:
      Current activity of each unit, a 4-dimensional tensor batched along the
      0th dimension.

    :param k:
      Number of multinomial draws, i.e. the maximum number of units per sample
      that are allowed to remain; the rest are set to zero.

    :param temperature:
      Temperature to use when computing the softmax distribution over activations.
      Higher temperatures increases the entropy, lower temperatures decrease entropy.

    :param relu:
      Whether to simulate the effect of applying ReLU before KWinners, i.e.
      non-positive activations are always zeroed, even when they were sampled.

    :param inplace:
      Whether to modify x in place

    :return:
      A tensor representing the activity of x after k-winner take all.
    """
    if k == 0:
        return torch.zeros_like(x)
    shape2 = (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
    logits = x.view(shape2)
    probs = softmax(logits / temperature, dim=-1)

    # Default to 2dkwinners when sampling k is not possible. A low enough
    # temperature leaves fewer than k non-zero entries in the softmax
    # distribution, and only non-zero probabilities can be sampled.
    if ((probs != 0).sum(dim=1) < k).any():
        # Use kwinners as default; at low temperatures, this is equivalent
        return kwinners2d(
            x,
            duty_cycles=None,
            k=k,
            boost_strength=0.0,  # no boosting
            local=False,
            break_ties=False,
            relu=relu,
            inplace=inplace)

    dist = Multinomial(total_count=k, probs=probs)
    # A unit is switched off when it received no draw at all.
    off_mask = ~dist.sample().bool()
    if relu:
        # ReLU would zero these units, so they must be switched off as well.
        # (The previous code OR-ed this into the *on* mask, which kept them.)
        off_mask |= (logits <= 0)
    off_mask = off_mask.view(x.shape)
    if inplace:
        return x.masked_fill_(off_mask, 0)
    else:
        return x.masked_fill(off_mask, 0)
    def __init__(self, q_init_xa_conds, x_support, a_domain):
        """Store the initial conditionals and domains; start with empty state.

        :param q_init_xa_conds: stored as given (presumably one conditional
            per element of ``a_domain`` -- confirm against callers).
        :param x_support: tuple of lower and upper bound pairs.
        :param a_domain: enumeration of possible a's.
        """
        # Constructor arguments are kept verbatim.
        self.q_init_xa_conds = q_init_xa_conds
        self.x_support = x_support
        self.a_domain = a_domain

        # Three independent lists, all empty at construction time.
        self.models, self.thetas, self.logzs = [], [], []

        # Equal-weight multinomial with one category per element of a_domain.
        equal_weights = torch.ones(len(a_domain))
        self.a_pdf = Multinomial(probs=equal_weights)
class KDE(object):
    """Kernel density estimate built on a set of landmark points.

    The per-dimension bandwidth ``h`` follows Silverman's rule of thumb,
    scaled by the landmarks' standard deviation and by ``coeff``.
    """

    def __init__(self, landmarks, coeff=1.0):
        """Compute the bandwidth and set up a uniform landmark sampler.

        :param landmarks: 2-D tensor indexed as ``(num_points, dim)``.
        :param coeff: extra multiplier applied to the bandwidth.
        """
        self.landmarks = landmarks
        with torch.no_grad():
            self.num = landmarks.shape[0]
            self.dim = landmarks.shape[1]
            spread = torch.std(landmarks, dim=0, keepdim=True)
            # Silverman's rule of thumb
            exponent = 1.0 / (self.dim + 4)
            rule = np.power(4.0 / (self.dim + 2.0), exponent)
            rule *= np.power(self.landmarks.shape[0], -exponent)
            self.h = (rule * spread * coeff).to(DEVICE)
        self.landmarks = self.landmarks.to(DEVICE)

        # Every landmark is equally likely to be chosen as a sample centre.
        count = self.landmarks.shape[0]
        uniform = torch.ones(count, dtype=t_float) / float(count)
        self.idx_sampler = Multinomial(probs=uniform)

    def log_pdf(self, x):
        """Return ``mix_gauss_pdf`` evaluated at ``x`` for this estimate."""
        return mix_gauss_pdf(x, self.landmarks, self.h)

    def get_samples(self, num_samples):
        """Draw ``num_samples`` points: a random landmark plus scaled noise."""
        # Each row of the single-draw multinomial sample selects one
        # landmark through the matrix product.
        selector = self.idx_sampler.sample(
            sample_shape=[num_samples]).to(DEVICE)
        centers = torch.matmul(selector, self.landmarks)

        noise = torch_randn2d(num_samples, self.dim)
        return noise * self.h + centers
 def __init__(self, landmarks, coeff=1.0):
     """Compute the kernel bandwidth and set up a uniform landmark sampler.

     :param landmarks: 2-D tensor indexed as ``(num_points, dim)``.
     :param coeff: extra multiplier applied to the bandwidth.
     """
     self.landmarks = landmarks
     with torch.no_grad():
         self.num = landmarks.shape[0]
         self.dim = landmarks.shape[1]
         spread = torch.std(landmarks, dim=0, keepdim=True)
         # Silverman's rule of thumb
         exponent = 1.0 / (self.dim + 4)
         rule = np.power(4.0 / (self.dim + 2.0), exponent)
         rule *= np.power(self.landmarks.shape[0], -exponent)
         self.h = (rule * spread * coeff).to(DEVICE)
     self.landmarks = self.landmarks.to(DEVICE)

     # Every landmark gets the same selection probability.
     count = self.landmarks.shape[0]
     uniform = torch.ones(count, dtype=t_float) / float(count)
     self.idx_sampler = Multinomial(probs=uniform)
Beispiel #9
0
 def get_action(obs, epsilon=0.01):
     """Pick an action from the Q-network's values for ``obs``.

     For a 1-D ``obs`` the choice is epsilon-greedy: the greedy action gets
     probability ``1 - epsilon`` and the remaining ``epsilon`` is split
     evenly over the other actions; the result is returned via ``.numpy()``.
     For any other rank the greedy action per row is returned as a tensor
     with ``keepdim=True`` and ``epsilon`` is not used.
     """
     with torch.no_grad():
         q_values = q_learner_net(obs)
         if len(obs.shape) != 1:
             # for batch actions
             return torch.argmax(q_values, dim=1, keepdim=True)

         # for single action policy rollout
         greedy = torch.argmax(q_values)
         act_probs = torch.ones((n_acts)) * epsilon / (n_acts - 1)
         act_probs[greedy] = 1 - epsilon
         drawn = Multinomial(1, probs=act_probs).sample()
         return torch.argmax(drawn).numpy()
Beispiel #10
0
    def forward(self, inputs):
        """
        If not self.is_valid, just return inputs.
        If self.is_valid apply SAP to inputs and return the result tensor.
        (The check on self.training is currently disabled, so pruning also
        happens in training mode.)
        Parameters
        ----------
        inputs torch.Tensor : input tensor whose shape is [b, c, h, w].
        Returns
        -------
        outputs torch.Tensor : just return inputs or stochastically pruned inputs.
        """
        if not self.is_valid:
            return inputs

        b, c, h, w = inputs.shape
        flat = inputs.reshape([b, c * h * w])  # [b, c * h * w]

        # Sampling probability of each unit is its share of the total
        # absolute activation of its sample.
        magnitude = torch.abs(flat)
        keep_prob = magnitude / torch.sum(magnitude, dim=-1, keepdim=True)

        # Number of multinomial draws per sample.
        num_sample = int(c * h * w * self.ratio)
        counts = Multinomial(num_sample, keep_prob).sample()

        # Units drawn at least once keep their value; all others stay zero.
        drawn = counts.nonzero(as_tuple=True)
        pruned = torch.zeros_like(flat)
        pruned[drawn] = flat[drawn]

        # Scale up by the inverse probability of being drawn at least once.
        pruned = pruned / (1 - (1 - keep_prob)**num_sample + 1e-12)
        return pruned.reshape([b, c, h, w])  # [b, c, h, w]
    def sample_intention(self, h, schedule_decisions=None) -> torch.Tensor:
        """Sample an index from a softmax over Q-values at step ``h``.

        For ``h == 0`` the Q-values are ``self.Q_table[0]``; for ``h == 1``
        they are ``self.Q_table[1][schedule_decisions[0]]``. They are divided
        by ``self.temperature``, turned into probabilities with a softmax
        over the last dimension, and a single multinomial draw is made.

        :param h: step index, 0 or 1.
        :param schedule_decisions: earlier decisions; only element 0 is read,
            and only when ``h == 1``.
        :return: tensor holding the position(s) along dimension 0 where the
            one-hot sample equals 1 (for 1-D Q-values, the sampled index).
        :raises ValueError: if ``h`` is neither 0 nor 1.
        """
        if h == 0:
            q_values = self.Q_table[0]
        elif h == 1:
            q_values = self.Q_table[1][schedule_decisions[0]]
        else:
            raise ValueError("Invalid number of tasks per episode.")

        # Explicit dim: the implicit choice is deprecated and emits a
        # warning; the last axis is also the one Multinomial normalises over.
        Ps = F.softmax(q_values / self.temperature, dim=-1)

        return torch.where(Multinomial(probs=Ps).sample() == 1)[0]
    def evaluate(self, possible_boards):
        """Sample a move among ``possible_boards`` and record training data.

        The boards are run through the network, the outputs are turned into
        probabilities with ``self.pg_plugin.softmax`` and one move is drawn
        from a single-trial multinomial. The log-probability of the draw and
        the value estimate of the chosen board's output are appended to
        ``self.saved_log_probabilities`` and ``self.saved_value_estimations``.

        :return: index of the sampled move along dimension 0.
        """
        scores = self.run_through_neural_network(possible_boards)

        # Draw a one-hot move and keep its log-probability for backward.
        policy = Multinomial(1, self.pg_plugin.softmax(scores))
        one_hot = policy.sample()
        self.saved_log_probabilities.append(policy.log_prob(one_hot))

        # Position of the single 1 in the one-hot sample.
        move = one_hot.max(0)[1]

        # Value estimate for the chosen board, saved for backward.
        self.saved_value_estimations.append(
            self.pg_plugin.value_function(scores[move]))
        return move
Beispiel #13
0
    def prune(self, size_pruned):
        """Prune across neurons within the layer to the desired size.

        When ``size_pruned`` is positive, the number of draws is obtained
        from ``adapt_sample_size`` and per-filter sample counts are drawn
        from a multinomial over ``self.probability``. Otherwise an all-zero
        int tensor with one entry per filter is returned on
        ``size_pruned``'s device.
        """
        probs = self.probability
        n_filters = probs.shape[0]

        if size_pruned > 0:
            budget = int(
                adapt_sample_size(probs.view(-1), size_pruned, self.uniform))
            return Multinomial(budget, probs).sample().int()

        # Nothing requested: every filter receives zero samples.
        return torch.zeros(n_filters).int().to(size_pruned.device)
Beispiel #14
0
            cluster_knn_sent_dist, cluster_knn_sent_idx = \
                    torch.topk(distances, k=args.truncate_cluster_nn, dim=0, largest=False)

            # gathering all possible sentences to be sampled or ranked
            # [k x H x nSamples]
            sampled_neighbourhoods = cluster_knn_sent_idx.gather(
                2, sampled_clusters)

            if args.sample_sentences:
                # [k x H x nSamples]
                sampled_neighbourhood_dist = \
                        cluster_knn_sent_dist.gather(2, sampled_clusters)

                # multinomial over sentences
                multi = Multinomial(
                    total_count=args.num_sent_samples,
                    logits=-sampled_neighbourhood_dist.transpose(0, 2) /
                    args.temp)

                # [k x H x C]
                samples = multi.sample().transpose(0, 2)

                # [k x H x nSent]
                scattered_sampled_sentences_counts = \
                        torch.zeros(args.truncate_cluster_nn, used_nheads, nsent).to(device)
                scattered_sampled_sentences_counts.scatter_add_(
                    2, sampled_neighbourhoods, samples)

                # [nSent]
                sampled_sentences_counts = scattered_sampled_sentences_counts.sum(
                    dim=[0, 1])
            else:
Beispiel #15
0
    def compute_loss_for_batch(self,
                               data,
                               model,
                               K=K,
                               test=False,
                               alpha=alpha):
        """Compute the K-sample variational loss for one batch.

        Every row of ``data`` is repeated ``K`` times, encoded, sampled via
        the reparameterisation trick and decoded. The per-sample log weight
        ``log p(z) + log p(x|z) - log q(z|x)`` is then arranged according to
        the module-level ``model_type`` ('iwae', 'vae', 'general_alpha',
        'vralpha' or 'vrmax'); ``test`` always selects the 'iwae'
        arrangement.

        An ``alpha`` within 1e-3 of 1 is snapped to exactly 1. For the
        'general_alpha' / 'vralpha' objectives that value would otherwise
        multiply and then divide the weights by ``(1 - alpha) == 0`` and yield
        a NaN loss, so in that case the 'vae' (ELBO) arrangement, which is the
        alpha -> 1 limit of the bound, is used instead.

        :return: ``(decoded, mu, logstd, loss)`` where ``loss`` is the
            negated batch sum of the per-datapoint objective.
        """
        renyi_types = ('general_alpha', 'vralpha')

        if model_type == 'vae':
            alpha = 1
        elif model_type in ('iwae', 'vrmax'):
            alpha = 0
        elif abs(alpha - 1) <= 1e-3:
            # use whatever alpha is defined in hyperparameters, snapping
            # values numerically indistinguishable from 1
            alpha = 1

        # True when the plain ELBO arrangement has to be used.
        use_elbo = model_type == 'vae' or (model_type in renyi_types
                                           and alpha == 1)

        data_k_vec = data.repeat_interleave(K, 0)

        mu, logstd = model.encode(data_k_vec)
        # (B*K, #latents)
        z = model.reparameterize(mu, logstd)

        # summing over latents due to independence assumption
        # (B*K)
        log_q = compute_log_probabitility_gaussian(z, mu, logstd)

        # log density of z under a standard normal prior
        log_p_z = torch.sum(
            -0.5 * z**2, 1) - .5 * z.shape[1] * T.log(torch.tensor(2 * np.pi))
        decoded = model.decode(z)
        log_p = compute_log_probabitility_bernoulli(decoded, data_k_vec)

        log_w = log_p_z + log_p - log_q
        if model_type == 'iwae' or test:
            log_w_matrix = log_w.view(-1, K)
        elif use_elbo:
            # treat each sample for a given data point as you would treat all
            # samples in the minibatch; 1/K averages over the K replicas
            log_w_matrix = log_w.view(-1, 1) * 1 / K
        elif model_type in renyi_types:
            log_w_matrix = log_w.view(-1, K) * (1 - alpha)
        elif model_type == 'vrmax':
            log_w_matrix = log_w.view(-1, K).max(axis=1, keepdim=True).values

        # Row-wise normalised weights, shifted by the row max for stability.
        log_w_minus_max = log_w_matrix - torch.max(
            log_w_matrix, 1, keepdim=True)[0]
        ws_matrix = torch.exp(log_w_minus_max)
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        # The (1 - alpha) rescaling only applies on the genuine Renyi path.
        renyi_training = (model_type in renyi_types and not test
                          and not use_elbo)

        if model_type == 'vralpha' and renyi_training:
            # Pick one sample per row with probability given by its weight.
            sample_dist = Multinomial(1, ws_norm)
            ws_sum_per_datapoint = log_w_matrix.gather(
                1,
                sample_dist.sample().argmax(1, keepdim=True))
        else:
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if renyi_training:
            ws_sum_per_datapoint /= (1 - alpha)

        loss = -torch.sum(ws_sum_per_datapoint)

        return decoded, mu, logstd, loss
Beispiel #16
0
 def choosePathByAlphas(self):
     """Sample a filter partition from ``self.alphas`` and apply it.

     ``self.alphas`` is used as multinomial logits with ``self.nFilters()``
     draws; the per-category counts are cast to ``int32`` and passed to
     ``self.setFiltersPartition``.
     """
     n_draws = self.nFilters()
     # The counts of a single multinomial draw sum to n_draws.
     counts = Multinomial(total_count=n_draws, logits=self.alphas).sample()
     self.setFiltersPartition(counts.type(int32))
Beispiel #17
0
def random_sample(cosine, sampling_num, temprature=0.5):
    """Draw multinomial counts over the spatial positions of each channel.

    ``cosine`` must be 4-dimensional ``(batch, channels, h, w)``. For every
    (batch, channel) pair, ``sampling_num`` draws are made over the ``h * w``
    positions with unnormalised weights ``exp(cosine / temprature)``.

    :return: tensor of the same shape as ``cosine`` holding how often each
        position was drawn.
    """
    batch_size, channels, _, _ = cosine.shape
    # Flatten the spatial axes so they form the categorical dimension.
    weights = torch.exp(cosine / temprature).view(batch_size, channels, -1)
    counts = Multinomial(sampling_num, weights).sample()
    return counts.view(cosine.shape)
class FairDensity(Distribution):
    """Joint density over ``(x, a)`` reweighted by a sequence of models.

    The unnormalised log-density is the initial conditional's
    ``log_prob(x)`` for the given ``a``, minus ``log(len(a_domain))``, plus
    ``theta * m(x)`` for every appended ``(m, theta)`` pair. The
    log-normaliser is recomputed over ``x_support`` x ``a_domain`` each time
    a pair is appended.
    """

    def __init__(self, q_init_xa_conds, x_support, a_domain):
        """Store the initial conditionals and domains; start with no models."""
        self.q_init_xa_conds = q_init_xa_conds
        self.x_support = x_support  # Tuple of lower and upper bound pairs
        self.a_domain = a_domain  # Enumeration of possible a's
        # Parallel lists: models, their weights, and the log-normaliser
        # recorded after each append.
        self.models, self.thetas, self.logzs = [], [], []

        self.a_pdf = Multinomial(probs=torch.ones(len(a_domain)))

    def normalise(self, abstol=1e-3):
        """Return the log of the unnormalised mass summed over all (x, a).

        ``abstol`` is accepted but not used.
        """
        total = 0
        for a in self.a_domain:
            mass_a = 0
            for x in self.x_support:
                mass_a += math.exp(self._unnorm_log_prob(torch.Tensor(x), a))
            total += mass_a

        return math.log(total)

    def log_prob(self, x, a):
        """Unnormalised log-density minus the latest log-normaliser (0 if none)."""
        log_z = self.logzs[-1] if self.logzs else 0
        return self._unnorm_log_prob(x, a) - log_z

    def _unnorm_log_prob(self, x, a):
        """Initial log-density of ``(x, a)`` plus all weighted model terms."""
        conditional = self.q_init_xa_conds[self.a_domain.index(a)]
        dens = conditional.log_prob(x) - math.log(len(self.a_domain))

        for m, theta in zip(self.models, self.thetas):
            dens += theta * m(x)

        return dens

    def append(self, m, theta):
        """Add a model with weight ``theta`` and record the new log-normaliser."""
        self.models.append(m)
        self.thetas.append(theta)
        self.logzs.append(self.normalise())

    def gen_loss_function(self):
        """Build a loss separating label-1 samples from reweighted label-0 ones.

        The returned ``loss(classifier, samples)`` expects ``samples`` as
        ``(x, a, label)`` triples and requires at least one sample of each
        label.
        """
        def loss(classifier, samples):
            positives = [tuple([s[0], s[1]]) for s in samples if s[2] == 1]
            negatives = [tuple([s[0], s[1]]) for s in samples if s[2] == 0]

            # Only the x components are used below.
            p_xs, _ = zip(*positives)
            q_xs, _ = zip(*negatives)
            p_x_samples = torch.stack(p_xs)
            q_x_samples = torch.stack(q_xs)

            # Weights for the label-0 samples: inverse normaliser times the
            # exponentiated model terms.
            weights = math.exp(-self.logzs[-1]) if self.logzs else 1
            for m, theta in zip(self.models, self.thetas):
                weights *= torch.exp(theta * m(q_x_samples))

            p_scores = torch.sigmoid(classifier(p_x_samples))
            q_scores = torch.sigmoid(classifier(q_x_samples))
            p_expectation = torch.mean(torch.log(p_scores))
            q_expectation = torch.mean(torch.log(1 - q_scores) * weights)

            return -(p_expectation + q_expectation)

        return loss

    def representation_rate(self):
        """Return the smallest ratio between the masses of any two a's."""
        a_probs = {}
        for a in self.a_domain:
            mass = 0
            for x in self.x_support:
                mass += math.exp(self.log_prob(torch.Tensor(x), a))
            a_probs[a] = mass

        ratios = [
            a_probs[ai] / a_probs[aj]
            for ai, aj in itertools.product(self.a_domain, repeat=2)
        ]
        return min(ratios)

    def sample_q_init(self, n):
        """Draw ``n`` ``(x, a)`` pairs from the initial distribution."""
        # One-hot draws of a; the column sums give the count for each a.
        a_draws = self.a_pdf.sample([int(n)])
        a_counts = torch.sum(a_draws, axis=0)

        sample_list = []
        for i, count in enumerate(a_counts):
            x_samples = self.q_init_xa_conds[i].sample([int(count)])
            sample_list += [(x, self.a_domain[i]) for x in x_samples]

        return sample_list

    def get_prob_array(self):
        """Return the probabilities of every (x, a) pair, x-major order."""
        return [
            math.exp(self.log_prob(torch.Tensor(x), a))
            for x, a in itertools.product(self.x_support, self.a_domain)
        ]

    def rsample(self, sample_shape=torch.Size([])):
        """Sample ``(x, a)`` values, returned as two tensors.

        For each ``a`` an ``EmpiricalDistribution`` over ``x_support`` is
        built from the renormalised probabilities; the number of draws per
        ``a`` comes from one-hot categorical samples.
        """
        a_sampler = D.OneHotCategorical(probs=torch.ones(len(self.a_domain)))

        probs = {}
        x_a_dists = {}
        for a in self.a_domain:
            cond = probs[a] = {}
            for x in self.x_support:
                cond[x] = math.exp(self.log_prob(x, a))

            # Renormalise so the probabilities for this a sum to one.
            total = sum(cond.values())
            for x in self.x_support:
                cond[x] = cond[x] / total

            x_a_dists[a] = EmpiricalDistribution(None, domain=self.x_support, probs=cond)
            # NOTE(review): looks like a debugging leftover; kept because it
            # is an externally visible attribute.
            self.test = EmpiricalDistribution(None, domain=self.x_support, probs=cond)

        one_hots = a_sampler.sample_n(sample_shape.numel())
        a_counts = torch.sum(one_hots, axis=0)

        x_samples = []
        a_samples = []
        for a_c, a_value in zip(a_counts, self.a_domain):
            a_c = int(a_c)
            x_samples.append(x_a_dists[a_value].sample_n(a_c))
            a_samples.append(torch.Tensor([a_value] * a_c))

        x_samples = torch.cat(x_samples).view(*sample_shape, -1)
        a_samples = torch.cat(a_samples).view(*sample_shape, -1)

        return x_samples, a_samples
Beispiel #19
0
    def compute_loss_for_batch(self, data, model, K=K, test=False):
        """Return the negated K-sample variational objective for one batch.

        ``data`` is unpacked as ``(B, 1, H, W)``. Each observation is
        replicated ``K`` times, encoded, sampled with the reparameterisation
        trick and decoded; the log importance weight of every sample is
        ``log p(z) + log p(x|z) - log q(z|x)``. How the weights are combined
        depends on the module-level ``model_type`` and ``alpha``; ``test``
        forces the IWAE combination.
        """
        B, _, H, W = data.shape

        # K copies of every observation, flattened to (B*K, H*W).
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)

        # Posterior parameters and one reparameterised sample per copy.
        mu, logstd = model.encode(data_k_vec)
        z = model.reparameterize(mu, logstd)

        # log q(z|x) under the encoder, and log p(z) from the same helper
        # called with all-zero second and third arguments.
        log_q = compute_log_probabitility_gaussian(z, mu, logstd)
        log_p_z = compute_log_probabitility_gaussian(
            z, torch.zeros_like(z, requires_grad=False),
            torch.zeros_like(z, requires_grad=False))

        # log p(x|z) with a Bernoulli likelihood on the reconstruction.
        decoded = model.decode(z)
        log_p = compute_log_probabitility_bernoulli(decoded, data_k_vec)

        # log(p(x, z_i) / q(z_i|x)) for every importance sample.
        log_w = log_p_z + log_p - log_q

        if model_type == 'iwae' or test:
            # One row per observation holding its K importance samples.
            log_w_matrix = log_w.view(B, K)
        elif model_type == 'vae':
            # One row per sample; 1/K turns the later batch sum into a mean
            # over the K samples.
            log_w_matrix = log_w.view(B * K, 1) * 1 / K
        elif model_type in ('general_alpha', 'vralpha'):
            # (1 - alpha) * log w == log(w ** (1 - alpha))
            log_w_matrix = log_w.view(-1, K) * (1 - alpha)
        elif model_type == 'vrmax':
            # Only the maximum-weight sample of each observation counts.
            best = log_w.view(-1, K).max(axis=1, keepdim=True).values
            return -torch.sum(best)

        # "Max trick": shift by the row maximum before exponentiating, then
        # normalise each row; the shift cancels in the ratio.
        row_max = torch.max(log_w_matrix, 1, keepdim=True)[0]
        ws_matrix = torch.exp(log_w_matrix - row_max)
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if model_type == 'vralpha' and not test:
            # VR-alpha: back-propagate through one sample per row, chosen
            # with probability equal to its normalised weight.
            chosen = Multinomial(1, ws_norm).sample().argmax(1, keepdim=True)
            ws_sum_per_datapoint = log_w_matrix.gather(1, chosen)
        else:
            # Every other objective uses the weighted sum over the row.
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if model_type in ["general_alpha", "vralpha"] and not test:
            # Undo the (1 - alpha) scaling applied to the log weights.
            ws_sum_per_datapoint /= (1 - alpha)

        # loss = -L_alpha summed over the batch.
        return -torch.sum(ws_sum_per_datapoint)
Beispiel #20
0
    def compute_loss_for_batch(self, data, model, K=K, test=False):
        """Return the negated K-sample objective of the two-layer model.

        ``data`` is unpacked as a 4-D tensor whose last two axes are
        ``(H, W)``. Each observation is replicated ``K`` times and encoded
        with ``self.encode``, which also returns the encoder input ``x`` and
        the first-layer latents ``z1``. The log importance weight of each
        sample is::

            log p(z) + log p(h1|z) + log p(x|h1) - log q(z|h1) - log q(h1|x)

        How the weights are combined depends on the module-level
        ``model_type``; ``test`` forces the IWAE combination. The Renyi
        objectives ('general_alpha', 'vralpha') scale by ``1 - self.alpha``
        and divide by the same factor afterwards.

        :return: scalar loss, the negated batch sum of the objective.
        """
        _, _, H, W = data.shape

        # K copies of every observation, flattened to (B*K, H*W).
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)

        # Top-layer posterior parameters; z1 holds the latent samples drawn
        # at the first stochastic layer.
        mu, log_std, [x, z1] = self.encode(data_k_vec)

        # Reparameterised sample of the top-layer latent.
        z = model.reparameterize(mu, log_std)

        # log p(z) under a standard normal prior.
        log_p_z = torch.sum(
            -0.5 * z**2, 1) - .5 * z.shape[1] * T.log(torch.tensor(2 * np.pi))

        # log q(z|h1)
        log_qz_h1 = compute_log_probabitility_gaussian(z, mu, log_std)

        # Re-generate the parameters that produced z1 to get log q(h1|x).
        h1 = torch.tanh(self.fc1(x))
        h2 = torch.tanh(self.fc2(h1))
        mu, log_std = self.fc31(h2), self.fc32(h2)
        log_qh1_x = compute_log_probabitility_gaussian(z1, mu, log_std)

        # Decoder parameters for the first-layer latents give log p(h1|z).
        h5 = torch.tanh(self.fc7(z))
        h6 = torch.tanh(self.fc8(h5))
        mu, log_std = self.fc81(h6), self.fc82(h6)
        log_ph1_z = compute_log_probabitility_gaussian(z1, mu, log_std)

        # Reconstruction from z1 and its Bernoulli log-likelihood log p(x|h1).
        h7 = torch.tanh(self.fc9(z1))
        h8 = torch.tanh(self.fc10(h7))
        decoded = torch.sigmoid(self.fc11(h8))
        log_px_h1 = compute_log_probabitility_bernoulli(decoded, x)

        # log(p(x, z, h1) / q(z, h1 | x)) for every importance sample.
        log_w = log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x

        # Note that when test is set the IWAE objective is always used.
        if model_type == 'iwae' or test:
            # One row per observation holding its K importance samples.
            log_w_matrix = log_w.view(-1, K)

        elif model_type == 'vae':
            # (1/K) * sum of log weights, summed over the batch.
            return -torch.sum(log_w.view(-1, 1) * 1 / K)

        elif model_type == 'general_alpha' or model_type == 'vralpha':
            # (1 - alpha) * log w == log(w ** (1 - alpha))
            log_w_matrix = log_w.view(-1, K) * (1 - self.alpha)

        elif model_type == 'vrmax':
            # Only the maximum-weight sample of each observation counts.
            return -torch.sum(
                log_w.view(-1, K).max(axis=1, keepdim=True).values)

        # "Max trick": shift by the row maximum before exponentiating, then
        # normalise each row; the shift cancels in the ratio.
        log_w_minus_max = log_w_matrix - torch.max(
            log_w_matrix, 1, keepdim=True)[0]
        ws_matrix = torch.exp(log_w_minus_max)
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if model_type == 'vralpha' and not test:
            # VR-alpha: back-propagate through one sample per row, chosen
            # with probability equal to its normalised weight.
            sample_dist = Multinomial(1, ws_norm)
            ws_sum_per_datapoint = log_w_matrix.gather(
                1,
                sample_dist.sample().argmax(1, keepdim=True))
        else:
            # Every other objective uses the weighted sum over the row.
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if model_type in ["general_alpha", "vralpha"] and not test:
            # Undo the scaling with the SAME alpha that was applied above
            # (self.alpha, not a module-level alpha).
            ws_sum_per_datapoint /= (1 - self.alpha)

        loss = -torch.sum(ws_sum_per_datapoint)

        return loss
Beispiel #21
0
    def forward(self,
                inputs,
                padding_mask=None,
                commitment_cost=None,
                temp=None):
        """Soft-quantize ``inputs`` by sampling codebook entries.

        Every input vector is compared against all codebook vectors using
        squared Euclidean distance; ``self.num_samples`` codebook indices are
        then drawn per vector from a multinomial whose logits are
        ``-distance / temp``. The quantized vector is the average of the
        drawn codebook entries. In training mode the codebook is refreshed
        with an exponential moving average of the assigned inputs.

        Args:
            inputs: tensor that is flattened to ``[-1, self.d_model]``
                (per the original notes: ``[B x E]``, ``[B x T x E]`` or
                ``[B x S x T x E]``).
            padding_mask: optional mask over the leading dimensions of
                ``inputs``. Positions where it is true are forced onto
                codebook entry 0 and every other position is barred from
                entry 0.
                TODO: generalize this to any padding_idx.
            commitment_cost: weight of the latent loss; falls back to
                ``self.commitment_cost`` when ``None``.
            temp: sampling temperature; falls back to ``self.temp`` when
                ``None``.

        Returns:
            Tuple ``(quantized, samples, loss, perplexity)``:
            ``quantized`` has the shape of ``inputs`` and passes gradients
            straight through to ``inputs``; ``samples`` holds per-position
            draw counts with shape ``inputs.shape[:-1] + [codebook_size]``;
            ``loss`` is the weighted squared error between the detached
            quantization and ``inputs``; ``perplexity`` is the exponentiated
            entropy of the batch-averaged draw proportions.
        """
        target_device = inputs.device
        beta = self.commitment_cost if commitment_cost is None else commitment_cost
        temperature = self.temp if temp is None else temp

        original_shape = inputs.size()

        # Collapse every leading dimension so each row is one d_model vector.
        vectors = inputs.reshape(-1, self.d_model)
        codes = self.codebook.weight

        # ||x||^2 + ||c||^2 - 2 * x.c  for every (vector, code) pair.
        sq_dist = (vectors.pow(2).sum(dim=1, keepdim=True) +
                   codes.pow(2).sum(dim=1) -
                   2 * (vectors @ codes.t()))

        if padding_mask is not None:
            is_pad = padding_mask.reshape(-1)
            # Real tokens may never select code 0 ...
            sq_dist[~is_pad, 0] = np.inf
            # ... and pad tokens may select nothing but code 0.
            sq_dist[is_pad, 1:] = np.inf

        # Draw num_samples code indices per row; closer codes are likelier.
        sampler = Multinomial(total_count=self.num_samples,
                              logits=-sq_dist / temperature)
        samples = sampler.sample().to(target_device)

        # Average of the drawn code vectors, restored to the input layout.
        quantized = torch.matmul(samples,
                                 codes).view(original_shape) / self.num_samples

        # Commitment-style loss; with a mask, normalise by non-pad elements.
        squared_error = (quantized.detach() - inputs)**2
        if padding_mask is None:
            e_latent_loss = torch.mean(squared_error)
        else:
            nonpad_count = torch.sum(~padding_mask) * self.d_model
            e_latent_loss = torch.sum(squared_error) / nonpad_count
        loss = beta * e_latent_loss

        if self.training:
            # Per-code share of draws in this batch.
            batch_cluster_size = torch.sum(samples, 0) / self.num_samples
            if self.discard_ema_cluster_sizes:
                # First update after a reset: take the batch statistics as-is.
                self._ema_cluster_size = batch_cluster_size
                self.discard_ema_cluster_sizes = False
            else:
                self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                    (1 - self._decay) * batch_cluster_size

            # Laplace smoothing of the cluster sizes.
            total = torch.sum(self._ema_cluster_size.data)
            smoothed = self._ema_cluster_size + self._epsilon
            self._ema_cluster_size = (
                smoothed / (total + self.codebook_size * self._epsilon) * total)

            # EMA of the summed inputs assigned to each code.
            batch_sums = torch.matmul(samples.t(), vectors) / self.num_samples
            self._ema_w = nn.Parameter(self._ema_w * self._decay +
                                       (1 - self._decay) * batch_sums)

            refreshed = self._ema_w / self._ema_cluster_size.unsqueeze(1)
            if self.padding_idx is not None:
                refreshed[self.padding_idx] = 0
            self.codebook.weight = nn.Parameter(refreshed)

        # Straight-through estimator: value of `quantized`, gradient of `inputs`.
        quantized = inputs + (quantized - inputs).detach()

        # Perplexity of the average code usage across all flattened rows.
        usage = torch.mean(samples, dim=0) / self.num_samples
        entropy = -torch.sum(usage * torch.log(usage + 1e-10))
        perplexity = torch.exp(entropy)

        samples = samples.reshape(
            list(original_shape[:-1]) + [self.codebook_size])

        return quantized, samples, loss, perplexity
Beispiel #22
0
    """
    arms = 3
    trials = 1
    pulls = 1

    mu = torch.empty(arms, dtype=torch.float).uniform_(1, 3)
    sigma = torch.empty(arms, dtype=torch.float).uniform_(0, 1)
    distrib = Normal(mu, sigma)

    total_actions = torch.zeros(arms)
    total_rewards = torch.zeros(arms)

    for t in range(10):
        outcomes = distrib.sample(torch.Size((trials, )))

        uniform_probs = torch.softmax(torch.ones(arms), dim=0)
        uniform_random_action_distribution = Multinomial(pulls,
                                                         probs=uniform_probs)
        actions = uniform_random_action_distribution.sample(
            torch.Size((trials, )))

        reward = actions * outcomes

        total_rewards = total_rewards + reward
        total_actions = total_actions + actions
        greed_score = total_rewards / total_actions
        greed_score[torch.isnan(greed_score)] = 0
        greed_probs = torch.softmax(greed_score, dim=1)
        pass
    #print('reward for action %i was %f' % (action, reward.item()))