Example #1
    def forward(self, dependency_tree, input_seq):

        embedded_seq = []

        for elt in input_seq:
            embedded_seq.append(
                self.embedding(Variable(torch.LongTensor([elt]))).unsqueeze(0))

        current_level_c = []
        current_level_h = []

        first = True
        for level in dependency_tree:
            next_level_c = []
            next_level_h = []

            for node in level:

                # node[0] indexes the token; node[1] lists the node's children
                # (indices into the previous level's outputs).
                x_j = embedded_seq[node[0]]
                h_tilde_j = Variable(torch.zeros(self.hidden_size))
                sum_term_cj = Variable(torch.zeros(self.hidden_size))

                # First-level leaves have no processed children, so h_tilde_j
                # and the forget-gated cell sum stay zero.
                if not (first and len(node[1]) == 1):

                    for child in node[1]:
                        h_k = current_level_h[child]
                        c_k = current_level_c[child]
                        f_jk = torch.sigmoid(self.w_f(x_j).add(self.u_f(h_k)))
                        h_tilde_j = h_tilde_j.add(h_k)
                        sum_term_cj = sum_term_cj.add(torch.mul(f_jk, c_k))

                sum_term_cj = sum_term_cj.reshape(-1)
                h_tilde_j = h_tilde_j.reshape(-1)
                i_j = torch.sigmoid(self.w_i(x_j).add(
                    self.u_i(h_tilde_j))).reshape(-1)
                o_j = torch.sigmoid(self.w_o(x_j).add(
                    self.u_o(h_tilde_j))).reshape(-1)
                u_j = torch.tanh(self.w_u(x_j).add(
                    self.u_u(h_tilde_j))).reshape(-1)

                c_j = torch.mul(i_j, u_j).add(sum_term_cj)
                h_j = torch.mul(o_j, torch.tanh(c_j))

                next_level_c.append(c_j)
                next_level_h.append(h_j)

            first = False

            current_level_c = next_level_c
            current_level_h = next_level_h

        return current_level_h, current_level_c
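The forward pass above is a Child-Sum Tree-LSTM-style update (Tai et al., 2015). Below is a minimal, self-contained sketch of that cell update for a single node with two children, with plain tensors and stand-in linear layers; all names here are illustrative, not taken from the module above.

import torch

hidden = 4
x_j = torch.randn(1, hidden)                          # input embedding at node j
child_h = [torch.randn(1, hidden) for _ in range(2)]  # children's hidden states
child_c = [torch.randn(1, hidden) for _ in range(2)]  # children's cell states

# Stand-ins for the learned layers self.w_* and self.u_* above.
w_i, u_i = torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden)
w_f, u_f = torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden)
w_o, u_o = torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden)
w_u, u_u = torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden)

h_tilde = sum(child_h)                                # sum of children's hidden states
i = torch.sigmoid(w_i(x_j) + u_i(h_tilde))
o = torch.sigmoid(w_o(x_j) + u_o(h_tilde))
u = torch.tanh(w_u(x_j) + u_u(h_tilde))
# One forget gate per child, applied to that child's cell state.
c = i * u + sum(torch.sigmoid(w_f(x_j) + u_f(h_k)) * c_k
                for h_k, c_k in zip(child_h, child_c))
h = o * torch.tanh(c)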
Example #2
  def _pad(self, sentences, pad_id, volatile=False, 
           raml=False, raml_tau=1., vocab_size=None):
    """Pad all instances in [data] to the longest length.

    Args:
      sentences: list of [batch_size] lists.

    Returns:
      padded_sentences: Variable of size [batch_size, max_len], the sentences.
      mask: Variable of size [batch_size, max_len]. 1 means to ignore.
      pos_emb_indices: Variable of size [batch_size, max_len]. indices to use
        when computing positional embedding.
      sum_len: total words
    """

    lengths = [len(sentence) for sentence in sentences]
    sum_len = sum(lengths)
    max_len = max(lengths)
    padded_sentences = [
      sentence + ([pad_id] * (max_len - len(sentence)))
      for sentence in sentences]
    mask = [
      ([0] * len(sentence)) + ([1] * (max_len - len(sentence)))
      for sentence in sentences]

    padded_sentences = Variable(torch.LongTensor(padded_sentences))
    mask = torch.ByteTensor(mask)
    #l = Variable(torch.FloatTensor(lengths))

    if self.hparams.cuda:
      padded_sentences = padded_sentences.cuda()
      mask = mask.cuda()

    if not raml:
      return padded_sentences, mask, lengths, sum_len

    assert vocab_size is not None
    logits = torch.arange(max_len)
    if self.hparams.cuda:
      logits = logits.cuda()
    probs = self.softmax(logits.mul_(raml_tau))
    num_words = torch.distributions.Categorical(probs).sample()

    lengths = torch.FloatTensor(lengths)
    if self.hparams.cuda:
      lengths = lengths.cuda()
    # Per-position corruption probability; padded positions get probability 0
    # so they are never corrupted (torch.bernoulli needs values in [0, 1]).
    corrupt_pos = num_words.data.float().div_(lengths).unsqueeze(1).expand_as(
        padded_sentences).contiguous().masked_fill_(mask, 0.)
    corrupt_pos = torch.bernoulli(corrupt_pos, out=corrupt_pos).byte()
    total_words = int(corrupt_pos.sum())
    
    # Draw random replacement offsets for the corrupted positions.
    corrupt_val = torch.LongTensor(total_words).random_(1, vocab_size)
    corrupts = torch.zeros(len(sentences), max_len).long()

    if self.hparams.cuda:
      corrupt_val = corrupt_val.long().cuda()
      corrupts = corrupts.cuda()
      corrupt_pos = corrupt_pos.cuda()
    corrupts = corrupts.masked_scatter_(corrupt_pos, corrupt_val)
    sample_sentences = padded_sentences.add(Variable(corrupts)).remainder_(vocab_size).masked_fill_(Variable(mask), pad_id)
    return sample_sentences, mask, lengths, sum_len
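For clarity, here is a toy, standalone illustration of the padding and mask layout that _pad produces (modern tensor API; the values are made up).

import torch

sentences = [[5, 6, 7], [8, 9]]
pad_id = 0
max_len = max(len(s) for s in sentences)
padded = torch.tensor([s + [pad_id] * (max_len - len(s)) for s in sentences])
mask = torch.tensor([[0] * len(s) + [1] * (max_len - len(s)) for s in sentences],
                    dtype=torch.bool)  # True marks padding to ignore
print(padded)  # tensor([[5, 6, 7], [8, 9, 0]])
print(mask)    # tensor([[False, False, False], [False, False,  True]])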
Example #3
    def forward(self, h, Q, u):
        # h is mean, Q is var, u is action
        batch_size = h.size()[0]
        # torch chunk method: split into 2 chunks
        v, r = self.trans(h).chunk(2, dim=1)
        v1 = v.unsqueeze(2)
        rT = r.unsqueeze(1)
        I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
        if rT.data.is_cuda:
            I = I.cuda()
        A = I.add(v1.bmm(rT))
        B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
        o = self.fc_o(h)

        # need to compute the parameters for distributions
        # as well as for the samples
        u = u.unsqueeze(2)
        # Q.mu.unsqueeze(2) (z^bar): torch.Size([128, 100, 1])
        d = A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(
            o.unsqueeze(2)).squeeze(2)
        # print("d", d.size()) # ([128, 100])
        sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(
            o.unsqueeze(2)).squeeze(2)

        return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
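The transition above is the locally linear E2C-style step z' = A z + B u + o with A = I + v r^T, a batched rank-one update. A minimal self-contained sketch of just that computation (shapes and names are illustrative):

import torch

batch, dim_z, dim_u = 2, 3, 1
v = torch.randn(batch, dim_z)
r = torch.randn(batch, dim_z)
I = torch.eye(dim_z).expand(batch, dim_z, dim_z)
A = I + v.unsqueeze(2).bmm(r.unsqueeze(1))  # batch x dim_z x dim_z

z = torch.randn(batch, dim_z)
u = torch.randn(batch, dim_u)
B = torch.randn(batch, dim_z, dim_u)
o = torch.randn(batch, dim_z)
z_next = (A.bmm(z.unsqueeze(2)) + B.bmm(u.unsqueeze(2))
          + o.unsqueeze(2)).squeeze(2)
print(z_next.shape)  # torch.Size([2, 3])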
Example #4
    def forward(self, h, Q, u):
        batch_size = h.size()[0]

        # computes the new basis vector for the embedded dynamics
        v, r = self.trans(h).chunk(2, dim=1)
        v1 = v.unsqueeze(2)
        rT = r.unsqueeze(1)
        I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
        if rT.data.is_cuda:
            I = I.to(rT.device)
        # A is batch_size X z_size X z_size
        A = I.add(v1.bmm(rT))
        # B is batch_size X z_size X input_size
        if self.dim_u != 0:
            B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
            u = u.unsqueeze(2)
        # o (constant terms) is batch_size X z_size
        o = self.fc_o(h).unsqueeze(2)
        # need to compute the parameters for distributions
        # as well as for the samples

        if self.dim_u != 0:
            d = A.bmm(Q.mean.float().unsqueeze(2)).add(
                B.bmm(u)).add(o).squeeze(2)
            sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
        else:
            d = A.bmm(Q.mean.float().unsqueeze(2)).add(o).squeeze(2)
            sample = A.bmm(h.unsqueeze(2)).add(o).squeeze(2)

        Qz_next_cov = A.double().bmm(Q.covariance_matrix.double()).bmm(
            A.double().transpose(1, 2))
        return sample, distributions.MultivariateNormal(
            d.double(), Qz_next_cov)
Example #5
    def eval(self, values, grad=False, grad_loop=False):
        ''' Takes a map of variable names, to variable values '''
        assert(isinstance(values, Variable))
        ################## Start FOPPL input ##########
        values = Variable(values.data, requires_grad=True)
        a = VariableCast(0.0)
        b = VariableCast(1.0)

        normal_obj = dis.Normal(a,b)
        c = VariableCast(3.0)
        i = 0
        logp_x  = normal_obj.logpdf(values)
        grad1 = self.calc_grad(logp_x, values)
        while i < 10:
            normal_obj_while = dis.Normal(values, c)
            values = Variable(normal_obj_while.sample().data, requires_grad=True)
            i = i + 1
        logp_x_g_x = normal_obj_while.logpdf(values)
        grad2 = self.calc_grad(logp_x_g_x, values)
        gradients = grad1 + grad2

        logjoint = Variable.add(logp_x, logp_x_g_x)
        if grad:
            return gradients
        elif grad_loop:
            return logjoint, gradients
        else:
            return logjoint, values
Example #6
    def eval(self, values, grad=False, grad_loop=False):
        ''' Takes a map of variable names, to variable values '''
        assert (isinstance(values, Variable))
        ################## Start FOPPL input ##########
        values = Variable(values.data, requires_grad=True)
        a = VariableCast(0.0)
        b = VariableCast(2)
        normal_obj1 = dis.Normal(a, b)
        # log of prior p(x)
        logp_x = normal_obj1.logpdf(values)
        # else:
        #     x = normal_object.sample()
        #     x = Variable(x.data, requires_grad = True)
        if torch.gt(values.data, torch.zeros(values.size()))[0][0]:
            y = VariableCast(5)
            normal_obj2 = dis.Normal(values + b, b)
            logp_y_x = normal_obj2.logpdf(y)
        else:
            y = VariableCast(-5)
            normal_obj3 = dis.Normal(values-b, b)
            logp_y_x = normal_obj3.logpdf(y)

        logjoint = Variable.add(logp_x, logp_y_x)
        if grad:
            gradients = self.calc_grad(logjoint, values)
            return gradients
        elif grad_loop:
            gradients = self.calc_grad(logjoint, values)
            return logjoint, gradients
        else:
            return logjoint, values
Example #7
    def eval(self, values, grad=False, grad2=False):
        ''' Takes a map of variable names, to variable values '''
        a = VariableCast(0.0)
        b = VariableCast(1)
        c1 = VariableCast(-1)
        normal_obj1 = dis.Normal(a, b)
        values = Variable(values.data, requires_grad=True)
        logp_x = normal_obj1.logpdf(values)
        # else:
        #     x = normal_object.sample()
        #     x = Variable(x.data, requires_grad = True)
        if torch.gt(values.data, torch.zeros(values.size()))[0][0]:
            y = VariableCast(1)
            normal_obj2 = dis.Normal(b, b)
            logp_y_x = normal_obj2.logpdf(y)
        else:
            y = VariableCast(1)
            normal_obj3 = dis.Normal(c1, b)
            logp_y_x = normal_obj3.logpdf(y)

        logjoint = Variable.add(logp_x, logp_y_x)
        if grad:
            gradients = self.calc_grad(logjoint, values)
            return VariableCast(gradients)
        elif grad2:
            gradients = self.calc_grad(logjoint, values)
            return logjoint, VariableCast(gradients)
        else:
            return logjoint, values
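The eval methods above all follow the same pattern: build a log joint from log-densities, then differentiate it with respect to the latent values. A small modern-autograd analogue, using torch.distributions in place of the custom Normal/VariableCast helpers (values here are illustrative):

import torch
from torch.distributions import Normal

x = torch.tensor([0.5], requires_grad=True)
logp_x = Normal(0.0, 1.0).log_prob(x)      # log prior p(x)
y = torch.tensor([1.0])
logp_y_x = Normal(x, 1.0).log_prob(y)      # log likelihood p(y | x)
logjoint = logp_x + logp_y_x
grad, = torch.autograd.grad(logjoint.sum(), x)
print(logjoint.item(), grad)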
Example #8
    def universal(self, args):
        self.set_mode('eval')

        init = False

        correct = 0
        cost = 0
        total = 0

        data_loader = self.data_loader['test']
        for e in range(100000):
            for batch_idx, (images, labels) in enumerate(data_loader):

                x = Variable(cuda(images, self.cuda))
                y = Variable(cuda(labels, self.cuda))

                if not init:
                    sz = x.size()[1:]
                    r = torch.zeros(sz)
                    r = Variable(cuda(r, self.cuda), requires_grad=True)
                    init = True

                logit = self.net(x+r)
                p_ygx = F.softmax(logit, dim=1)
                H_ygx = (-p_ygx*torch.log(self.eps+p_ygx)).sum(1).mean(0)
                prediction_cost = H_ygx
                #prediction_cost = F.cross_entropy(logit,y)
                #perceptual_cost = -F.l1_loss(x+r,x)
                #perceptual_cost = -F.mse_loss(x+r,x)
                #perceptual_cost = -F.mse_loss(x+r,x) -r.norm()
                perceptual_cost = -F.mse_loss(x+r, x) -F.relu(r.norm()-5)
                #perceptual_cost = -F.relu(r.norm()-5.)
                #if perceptual_cost.data[0] < 10: perceptual_cost.data.fill_(0)
                cost = prediction_cost + perceptual_cost
                #cost = prediction_cost

                self.net.zero_grad()
                if r.grad is not None:
                    r.grad.fill_(0)
                cost.backward()

                #r = r + args.eps*r.grad.sign()
                r = r + r.grad*1e-1
                r = Variable(cuda(r.data, self.cuda), requires_grad=True)



                prediction = logit.max(1)[1]
                correct = torch.eq(prediction, y).float().mean().data[0]
                if batch_idx % 100 == 0:
                    if self.visdom:
                        self.vf.imshow_multi(x.add(r).data)
                        #self.vf.imshow_multi(r.unsqueeze(0).data,factor=4)
                    print(correct*100, prediction_cost.data[0], perceptual_cost.data[0],\
                            r.norm().data[0])

        self.set_mode('train')
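The loop above optimizes a single perturbation r shared across all inputs (a universal adversarial perturbation), ascending the gradient of an entropy-plus-perceptual cost. A compact, self-contained sketch of one such update step with a toy network (all names and sizes are illustrative):

import torch
import torch.nn.functional as F

net = torch.nn.Linear(8, 3)
x = torch.randn(4, 8)                   # a batch of inputs
r = torch.zeros(8, requires_grad=True)  # one perturbation shared by the batch

logit = net(x + r)
p_ygx = F.softmax(logit, dim=1)
H_ygx = (-p_ygx * torch.log(p_ygx + 1e-12)).sum(1).mean(0)  # prediction cost
cost = H_ygx - F.mse_loss(x + r, x)                         # plus perceptual term

net.zero_grad()
cost.backward()
with torch.no_grad():
    r += 1e-1 * r.grad                  # gradient ascent step on r
r.grad = None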
Example #9
    def forward_single_graph(self, graph):
        '''
        Args:
            graph: Graph object
            vtx_features: dict of dicts
            IE: vtx_features[l][v] gives a torch.Tensor of the vertex v's representation at level l

            Compute the vertex representations at each level for each vertex in the
            graph and return the graph representation of this graph.
        '''

        vtx_features = self.init_base_features(graph)
        rfields = compute_receptive_fields(graph, self.lvls+1)

        for lvl in range(1, self.lvls+1):
            for v in graph.vertices:
                v_rfield = rfields[lvl][v] # receptive field of vertex v at level lvl
                v_nbrs = graph.neighborhood(v, 1)
                k = len(v_rfield)
                n = len(v_nbrs)
                if n == 0:
                    # an isolated vertex doesn't contribute to the graph representation
                    continue
                in_channels = self.w_sizes[lvl]['in']
                out_channels = self.w_sizes[lvl]['out']
                aggregate = Variable(torch.zeros((n, in_channels, k, k)), requires_grad=False)
                reduced_adj_mat = Variable(torch.Tensor(graph.sub_adj(v_rfield)), requires_grad=False)

                for index, w in enumerate(v_nbrs):
                    w_rfield_prev = rfields[lvl-1][w] # receptive field of vertex w
                    chi = Variable(chi_matrix(v_rfield, w_rfield_prev), requires_grad=False)
                    nbr_feat = vtx_features[lvl-1][w] # should be a Variable already
                    aggregate[index] = aggregate[index].add(chi.matmul(nbr_feat).matmul(chi.t()))

                aggregate = aggregate.sum(dim=0)  # collapse over the neighbors
                aggregate = aggregate.add(self.adj_param(lvl) * reduced_adj_mat)
                # Before reshaping, aggregate has shape (in_channels, k, k), where k
                # is the size of the receptive field of vertex v.

                # After mixing channels via the w matrix, we get a tensor
                # of shape (out_channels, k*k)
                new_features = self.nonlinearity(self.linear_transform(lvl, aggregate.view(k*k, -1)))

                # new features will be of size k*k, new_channels
                # Reshape it to be of size (k, k, out_channels)
                vtx_features[lvl][v] = new_features.view(out_channels, k, k)

        graph_repr = self.collapse_vtx_features(vtx_features)
        return graph_repr
Example #10
    def forward(self, q, sf0, sf1, sf2, sf3, sf4, sf5, sf6, sf7):

        q = self.do(self.embedding(q).permute(1, 0, 2))

        sfs = [sf0, sf1, sf2, sf3, sf4, sf5, sf6, sf7]
        sfs = [self.do(self.embedding(sf).permute(1, 0, 2)) for sf in sfs]

        _, (emb_q, _) = self.q_rnn(q)

        # Keep only the final hidden state of each supporting fact.
        emb_sfs = [self.sf_rnn(sf)[1][0] for sf in sfs]

        # Accumulator for the relational sums; the batch size (32) and the
        # g-output width (256) are hardcoded here.
        g_o = Variable(torch.zeros(32, 256))

        if torch.cuda.is_available():
            g_o = g_o.cuda()

        emb_q = emb_q.squeeze(0)

        for i, _ in enumerate(emb_sfs):
            o_i = emb_sfs[i].squeeze(0)
            pos = Variable(torch.FloatTensor(o_i.shape[0], 1).fill_(i))
            if torch.cuda.is_available():
                pos = pos.cuda()
            x = self.do(F.relu(self.g1(torch.cat((o_i, emb_q, pos), dim=1))))
            x = self.do(F.relu(self.g2(x)))
            x = self.do(F.relu(self.g3(x)))
            x = self.do(F.relu(self.g4(x)))
            g_o = g_o.add(x)

        x = self.do(F.relu(self.f1(g_o)))
        x = self.do(F.relu(self.f2(x)))
        x = self.f3(x)

        return x
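The forward pass above is the Relation Network pattern: run the g-MLP over every (object, question) pair, sum the results, then run the f-MLP. A compact sketch of that flow, with single toy layers standing in for g1..g4 and f1..f3 (names and sizes are illustrative):

import torch
import torch.nn.functional as F

emb_q = torch.randn(32, 128)                         # question embedding
objects = [torch.randn(32, 128) for _ in range(8)]   # supporting-fact embeddings

g = torch.nn.Linear(128 + 128 + 1, 256)  # object + question + position tag
f = torch.nn.Linear(256, 10)

g_o = torch.zeros(32, 256)
for i, o_i in enumerate(objects):
    pos = torch.full((32, 1), float(i))  # which supporting fact this is
    g_o = g_o + F.relu(g(torch.cat((o_i, emb_q, pos), dim=1)))
out = f(g_o)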
Example #11
    def collapse_vtx_features(self, vtx_features):
        '''
        Args:
            vtx_features: a dict of dicts
            IE: vtx_features[l][v] gives a torch.Tensor of the vertex v's representation at level l
        Collapse the final-level vertex features to nchannels (the number of
        channels of the final layer) and sum these over all vertices in the graph.
        '''
        graph_repr = Variable(torch.zeros(self.w_sizes[self.lvls]['out']),
                              requires_grad=True)

        for v, vtx_repr in vtx_features[self.lvls].items():
            graph_repr = graph_repr.add(vtx_repr.sum(-1).sum(-1))

        return graph_repr
Example #12
 def eval(self, values):
     ''' Takes a map of variable names, to variable values '''
     a = VariableCast(1.0)
     b = VariableCast(1.41)
     normal_object = Normal(a, b)
     if values['x'] is not None:
         x = Variable(values['x'], requires_grad=True)
     else:
         # fall back to sampling from the prior so x is always defined
         x = normal_object.sample()
         x = Variable(x.data, requires_grad=True)
     logp_x = normal_object.logpdf(x)
     std = VariableCast(1.73)
     p_y_g_x = Normal(x, std)
     obs2 = VariableCast(7.0)
     logp_y_g_x = p_y_g_x.logpdf(obs2)
     logp_x_y = Variable.add(logp_x, logp_y_g_x)
     return logp_x_y, {'x': x.data}
Example #13
    def transition(self, h, Q, u):
        batch_size = h.size()[0]
        v, r = self.trans(h).chunk(2, dim=1)
        v1 = v.unsqueeze(2)
        rT = r.unsqueeze(1)
        I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
        if rT.data.is_cuda:
            I = I.cuda()
        A = I.add(v1.bmm(rT))

        B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
        o = self.fc_o(h).reshape((-1, self.dim_z, 1))

        u = u.unsqueeze(2)

        d = A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
        sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
        return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
Example #14
def loss_function(recon_xs, x, mu, logvar, y_class, categorical, epoch):
    BCE = 0
    for recon_x in recon_xs:
        # BCE += reconstruction_function(recon_x, x) * max(math.sqrt(model.z_size / float(epoch ** 2)), 1.)
        BCE += reconstruction_function(recon_x, x) * math.sqrt(model.z_size)
    # BCE /= len(recon_xs)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # 0.5 * sum ((mu - mu_clusters) * 1/sigma_clusters * (mu - mu-clusters) + 1/sigma_clusters * logVar.exp() + logVar - log det sigma_clusters)
    y_class = torch.from_numpy(y_class).type(torch.FloatTensor)
    mu2 = Variable(y_class.matmul(mu_clusters))
    sigma2 = Variable(y_class.matmul(sigma_clusters))
    theta2 = Variable(y_class.matmul(theta_clusters))

    KLD_element = mu.sub(mu2).pow(2).div(sigma2)
    # print(logvar)
    # print(torch.sum(theta2.mul(theta2.log()),dim=1).view((args.batch_size,-1)).expand_as(logvar))
    KLD_element = KLD_element.add_(
        logvar.exp().div(sigma2)).mul_(-1).add_(1).add_(logvar).mul(
            (-torch.sum(theta2.mul(theta2.log()), dim=1).view(
                (logvar.size()[0], -1)).expand_as(logvar)
             ))  #.sub_(sigma2.log()) #/ (2*model.z_size)
    KLD = torch.sum(KLD_element).mul_(-0.5) * model.z_size / 2.
    # KLD *= torch.sum(theta2.mul(theta2.log()),dim=1)

    # print(KLD)
    # KLD_element2 = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    # KLD2 = torch.sum(KLD_element2).mul_(-0.5) * model.z_size
    # print(KLD2)

    c = F.softmax(categorical, dim=1)  # explicit dim; the implicit dim is deprecated

    # KLD += kl_loss(categorical, theta2)
    KLD += torch.sum(c.mul(c.add(1e-9).log() - (theta2.add(1e-9)).log()))
    # KLD += kl_loss(theta2, categorical)


    return BCE + KLD
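For reference, the standard (unconditioned) VAE KL term that the comment block above generalizes is the closed form from Kingma & Welling (2014); a standalone computation with made-up shapes:

import torch

mu = torch.randn(4, 10)
logvar = torch.randn(4, 10)
# KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
print(kld.item())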
Example #15
 def forward(self, input):
     batch_size = input.size(0)
     b = Variable(torch.zeros(batch_size, self.input_features,
                              self.output_features),
                  requires_grad=False)
     if input.is_cuda:
         b = b.cuda()
     input = input.unsqueeze(dim=2).expand(
         -1, -1, self.output_features,
         -1).contiguous().view(batch_size,
                               self.input_features * self.output_features,
                               self.input_feature_length, 1)
     hat_u = torch.matmul(self.weight,
                          input).view(batch_size, self.input_features,
                                      self.output_features,
                                      self.output_feature_length)
     for r in range(self.routing_iterators):
         c = F.softmax(b.view(batch_size, self.input_features,
                              self.output_features),
                       dim=-1).view(batch_size, self.input_features,
                                    self.output_features)
         hat_u_ = torch.mul(
             c.view(-1, 1).expand(-1, self.output_feature_length),
             hat_u.view(-1, self.output_feature_length)).view(
                 batch_size, self.input_features, self.output_features,
                 self.output_feature_length)
         s = torch.sum(hat_u_, dim=1)
         s = s.view(batch_size * self.output_features,
                    self.output_feature_length)
         s_norm = torch.norm(s, p=2, dim=1)
         s_norm_ = torch.div(s_norm, s_norm.pow(2).add(1))
         v = torch.mul(
             s_norm_.view(-1, 1).expand(-1, self.output_feature_length), s)
         v = v.view(batch_size, self.output_features,
                    self.output_feature_length)
         if r == self.routing_iterators - 1:
             return v
         v = v.expand(self.input_features, batch_size, self.output_features,
                      self.output_feature_length).transpose(0, 1)
         b = b.add(torch.mul(hat_u, v).sum(-1))
     return v
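The s_norm_ expression inside the loop is the capsule "squash" nonlinearity of Sabour et al. (2017), written in expanded form. Isolated for clarity (the helper name is ours, not the module's):

import torch

def squash(s, dim=-1, eps=1e-8):
    # v = (|s|^2 / (1 + |s|^2)) * (s / |s|), applied along `dim`
    sq_norm = s.pow(2).sum(dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / (sq_norm.sqrt() + eps)

s = torch.randn(2, 10, 16)   # batch x output capsules x capsule length
v = squash(s)
print(v.norm(dim=-1).max())  # all capsule lengths are < 1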
Example #16
    def forward(self, h, Q, u):
        batch_size = h.size()[0]
        v, r = self.trans(h).chunk(2, dim=1)
        v1 = v.unsqueeze(2)
        rT = r.unsqueeze(1)
        I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
        if rT.data.is_cuda:
            I = I.cuda()
        A = I.add(v1.bmm(rT))

        B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
        o = self.fc_o(h)

        # need to compute the parameters for distributions
        # as well as for the samples
        u = u.unsqueeze(2)

        d = A.bmm(Q.mu.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)
        sample = A.bmm(h.unsqueeze(2)).add(B.bmm(u)).add(o).squeeze(2)

        return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
Example #17
File: validate.py  Project: NaiveXu/Master
def validate(model, epoch, optimizer, test_loader, args, writer, reinforcement_learner, request_dict, accuracy_dict, episode):

    # Initialize training:
    model.eval()

    # Collect a random batch:
    image_batch, label_batch = next(iter(test_loader))

    # Episode Statistics:
    episode_correct = 0.0
    episode_predict = 0.0
    episode_request = 0.0
    episode_reward = 0.0
    episode_loss = 0.0

    # Create initial state:
    state = []
    label_dict = []
    for i in range(args.batch_size):
        state.append([0 for i in range(args.class_vector_size)])
        label_dict.append({})

    # Initialize model between each episode:
    hidden = model.reset_hidden(args.batch_size)

    # Statistics again:    
    for v in request_dict.values():
        v.append([])
    for v in accuracy_dict.values():
        v.append([])

    # Placeholder for loss Variable:
    if args.cuda:
        loss = Variable(torch.zeros(1)).cuda()
    else:
        loss = Variable(torch.zeros(1))

    # EPISODE LOOP:
    for i_e in range(len(label_batch)):
        episode_labels = label_batch[i_e]
        episode_images = image_batch[i_e]

        # Tensoring the state:
        state = torch.FloatTensor(state)
        
        # Need to add image to the state vector:
        flat_images = episode_images.squeeze().view(args.batch_size, -1)

        # Concatenating possible labels/zero vector with image, to create the environment state:
        state = torch.cat((state, flat_images), 1)
        
        one_hot_labels = []
        for i in range(args.batch_size):
            true_label = episode_labels[i]

            # Creating one hot labels:
            one_hot_labels.append([1 if j == true_label else 0 for j in range(args.class_vector_size)])

            # Logging statistics:
            if (true_label not in label_dict[i]):
                label_dict[i][true_label] = 1
            else:
                label_dict[i][true_label] += 1

        # Selecting an action to perform (Epsilon Greedy):
        if (args.cuda):
            q_values, hidden = model(Variable(state, volatile=True).type(torch.FloatTensor).cuda(), hidden)
        else:
            q_values, hidden = model(Variable(state, volatile=True).type(torch.FloatTensor), hidden)

        # Choosing the largest Q-values:
        model_actions = q_values.data.max(1)[1].view(args.batch_size)

        # Performing Epsilon Greedy Exploration:
        agent_actions = model_actions
        
        # Collect rewards:
        rewards = reinforcement_learner.collect_reward_batch(agent_actions, one_hot_labels, args.batch_size)

        # Collecting average reward at time t:
        episode_reward += float(sum(rewards)/args.batch_size)

        # Just some statistics logging:
        for i in range(args.batch_size):

            true_label = episode_labels[i]

            # Statistics:
            reward = rewards[i]
            if (reward == reinforcement_learner.request_reward):
                episode_request += 1
                episode_predict += 1
                if (label_dict[i][true_label] in request_dict):
                    request_dict[label_dict[i][true_label]][-1].append(1)
                if (label_dict[i][true_label] in accuracy_dict):
                    accuracy_dict[label_dict[i][true_label]][-1].append(0)
            elif (reward == reinforcement_learner.prediction_reward):
                episode_correct += 1.0
                episode_predict += 1.0
                if (label_dict[i][true_label] in request_dict):
                    request_dict[label_dict[i][true_label]][-1].append(0)
                if (label_dict[i][true_label] in accuracy_dict):
                    accuracy_dict[label_dict[i][true_label]][-1].append(1)
            else:
                episode_predict += 1.0
                if (label_dict[i][true_label] in request_dict):
                    request_dict[label_dict[i][true_label]][-1].append(0)
                if (label_dict[i][true_label] in accuracy_dict):
                    accuracy_dict[label_dict[i][true_label]][-1].append(0)

        
        # Observe next state and images:
        next_state_start = reinforcement_learner.next_state_batch(agent_actions, one_hot_labels, args.batch_size)

        # Tensoring the reward:
        rewards = Variable(torch.Tensor([rewards]))

        # Need to collect the representative Q-values:
        agent_actions = Variable(torch.LongTensor(agent_actions)).unsqueeze(1)
        current_q_values = q_values.gather(1, agent_actions)

        # Non-final state:
        if (i_e < args.episode_size - 1):
            # Collect next image:
            next_flat_images = image_batch[i_e + 1].squeeze().view(args.batch_size, -1)

            # Create next state:
            next_state = torch.cat((torch.FloatTensor(next_state_start), next_flat_images), 1)

            # Get target value for next state:
            target_value = model(Variable(next_state, volatile=True), hidden)[0].max(1)[0]

            # Make it un-volatile again:
            target_value.volatile = False

            # Discounting the next state + reward collected in this state:
            discounted_target_value = (GAMMA*target_value) + rewards

        # Final state:
        else:
            # As there is no next state, we only have the rewards:
            discounted_target_value = rewards

        discounted_target_value = discounted_target_value.view(args.batch_size, -1)

        # Calculating Bellman error:
        bellman_loss = F.mse_loss(current_q_values, discounted_target_value)

        # Backprop:
        loss = loss.add(bellman_loss)
        
        # Update current state:
        state = next_state_start

        ### END TRAIN LOOP ###

    print("\n---Validation Statistics---\n")

    print("\n--- Epoch " + str(epoch) + ", Episode " + str(episode + i + 1) + " Statistics ---")
    print("Instance\tAccuracy\tRequests")       
    for key in accuracy_dict.keys():
        predictions = accuracy_dict[key][-1]
        requests = request_dict[key][-1]
        
        accuracy = float(sum(predictions)/len(predictions))
        request_percentage = float(sum(requests)/len(requests))
        
        print("Instance " + str(key) + ":\t" + str(100.0*accuracy)[0:4] + " %" + "\t\t" + str(100.0*request_percentage)[0:4] + " %")
    

    # Even more status update:
    print("\n+------------------STATISTICS----------------------+")
    total_prediction_accuracy = float((100.0 * episode_correct) / max(1, episode_predict-episode_request))
    print("Batch Average Prediction Accuracy = " + str(total_prediction_accuracy)[:5] +  " %")
    total_accuracy = float((100.0 * episode_correct) / episode_predict)
    print("Batch Average Accuracy = " + str(total_accuracy)[:5] +  " %")
    total_loss = loss.data[0]
    print("Batch Average Loss = " + str(total_loss)[:5])
    total_requests = float((100.0 * episode_request) / (args.batch_size*args.episode_size))
    print("Batch Average Requests = " + str(total_requests)[:5] + " %")
    total_reward = float(episode_reward)
    print("Batch Average Reward = " + str(total_reward)[:5])
    print("+--------------------------------------------------+\n")

    ### LOGGING TO TENSORBOARD ###
    data = {
        'validation_total_requests': total_requests,
        'validation_total_accuracy': total_accuracy,
        'validation_total_loss': total_loss,
        'validation_average_reward': total_reward
    }

    for tag, value in data.items():
        writer.scalar_summary(tag, value, epoch)
    ### DONE LOGGING ###

    return total_prediction_accuracy, total_requests, total_accuracy, total_reward, request_dict, accuracy_dict
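The core of the loop above is a one-step Q-learning (Bellman) target: the chosen action's Q-value is regressed toward reward + GAMMA * max next Q. A minimal standalone sketch of that computation (toy shapes; GAMMA and the tensors are illustrative):

import torch
import torch.nn.functional as F

GAMMA = 0.99
q_values = torch.randn(4, 3, requires_grad=True)    # batch x actions
actions = q_values.detach().max(1)[1].unsqueeze(1)  # greedy action per row
current_q = q_values.gather(1, actions)

rewards = torch.randn(4, 1)
next_q = torch.randn(4, 3).max(1)[0].unsqueeze(1)   # target values (no grad)
target = GAMMA * next_q + rewards

bellman_loss = F.mse_loss(current_q, target)
bellman_loss.backward()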
Example #18
for k in vars.keys():
    for kk in vars[k]:
        if kk == 'Parameters of interest':
            # Re-wrap each parameter as a leaf Variable that requires grad.
            params = vars[k][kk]
            vars[k][kk] = []
            for i in params:
                if isinstance(i, Variable):
                    temp = Variable(i.data, requires_grad=True)
                else:
                    temp = Variable(i, requires_grad=True)
                vars[k][kk].append(temp)
        else:
            print()  # Deal with distribution parameters here

c24039= VariableCast(1.0)
c24040= VariableCast(2.0)
x24041 = Normal(c24039, c24040)
x22542 = Variable(torch.Tensor([0.0]),  requires_grad = True)
# x22542.detach()
# x22542 = x24041.sample()   #sample
p24042 = x24041.logpdf( x22542)
c24043 = VariableCast(3.0)
x24044 = Normal(x22542, c24043)
c24045 = VariableCast(7.0)
y22543 = c24045
p24046 = x24044.logpdf( y22543)
p24047 = Variable.add(p24042,p24046)

print(x22542)
print(p24047)
grad_x22542 = torch.autograd.grad([p24047], [x22542])
print('gradient of x22542 ', grad_x22542)
Example #19
    def write(self,
              lang_h,
              processed,
              max_words,
              temperature,
              loss,
              tgt,
              stop_tokens=STOP_TOKENS,
              resume=False):
        """Generate a sentence word by word and feed the output of the
        previous timestep as input to the next.
        """
        encoded_pad = self.word_dict.w2i(['<pad>'])
        btz_size = lang_h.size()[1]
        outs_btz = torch.LongTensor(max_words, btz_size)
        # scores_loss = Variable(torch.FloatTensor(size=[max_words, btz_total, len(self.word_dict.idx2word)]))
        for btz in range(btz_size):
            tgt_btz = torch.LongTensor(
                [tgt[i] for i in range(btz, tgt.shape[0], btz_size)])
            processed_btz = processed[:, btz, :]
            outs, logprobs, lang_hs = [], [], []
            # remove batch dimension from the language and context hidden states
            if resume:
                inpt = None
            else:
                inpt = Variable(torch.LongTensor(1))
                inpt.data.fill_(self.word_dict.get_idx('Hi'))
                inpt = self.to_device(inpt)
            # generate words until max_words have been generated or <eos>
            for word_idx in range(max_words):
                if inpt is not None:
                    # add the context to the word embedding
                    inpt_emb = torch.cat(
                        [self.word_encoder(inpt), processed_btz], 1)
                    # update RNN state with last word
                    lang_h[:, btz, :] = self.writer(inpt_emb, lang_h[:,
                                                                     btz, :])
                    lang_hs.append(lang_h[:, btz, :])
                # decode words using the inverse of the word embedding matrix

                out = self.decoder(lang_h[:, btz, :])
                scores = Variable(
                    F.linear(out, self.word_encoder.weight).div(temperature))
                # subtract constant to avoid overflows in exponentiation
                scores = Variable(scores.add(-scores.max().item()).squeeze(0))
                # disable special tokens from being generated in normal turns
                if not resume:
                    mask = Variable(self.special_token_mask)
                    scores = scores.add(mask)
                prob = F.softmax(scores, dim=0)
                logprob = F.log_softmax(scores, dim=0)

                # explicitly defining num_samples for pytorch 0.4.1
                word = prob.multinomial(num_samples=1).detach()
                # logprob = logprob.gather(0, word)

                # logprobs.append(logprob)

                outs.append(word.view(word.size()[0], 1))
                if self.mode == 'train_em':
                    loss += self.crit(scores.view(-1, len(self.word_dict)),
                                      tgt_btz[word_idx].unsqueeze(0))
                    inpt = tgt_btz[word_idx].unsqueeze(0)
                else:
                    inpt = word
                # scores_loss[word_idx, btz] = Variable(scores)

                # check if we generated an <eos> token
                if self.word_dict.get_word(word.data[0]) in stop_tokens:
                    break
            # update the hidden state with the <eos> token
            inpt_emb = torch.cat([self.word_encoder(inpt), processed_btz], 1)
            lang_h[:, btz, :] = self.writer(inpt_emb, lang_h[:, btz, :])
            lang_hs.append(lang_h[:, btz, :])

            # add batch dimension back
            # lang_h = lang_h.unsqueeze(1)
            if len(outs) < max_words:
                outs = [outs[i].item() for i in range(len(outs))]
                outs = outs + encoded_pad * (max_words - len(outs))
                outs = [torch.LongTensor([[i]]) for i in outs]
            outs_btz[:, btz] = torch.cat(outs).squeeze(1)
            lang_hs += [torch.cat(lang_hs)]
        return outs_btz, lang_h, lang_hs, loss
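The sampling step inside write() is standard temperature sampling: scale the logits, subtract the max for numerical stability, softmax, then draw one token. Standalone (the vocabulary size is made up):

import torch
import torch.nn.functional as F

logits = torch.randn(50)         # scores over a 50-word vocabulary
temperature = 0.7
scores = logits / temperature
scores = scores - scores.max()   # subtract a constant to avoid overflow
prob = F.softmax(scores, dim=0)
word = prob.multinomial(num_samples=1)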
Example #20
def train(args, lvae):
    best_val_loss = 1000
    # train
    for epoch in range(args.epochs):
        lvae.train()
        print("Training... Epoch = %d" % epoch)
        correct_train = 0
        open('lvae%d/train_fea.txt' % args.lamda, 'w').close()
        open('lvae%d/train_tar.txt' % args.lamda, 'w').close()
        open('lvae%d/train_rec.txt' % args.lamda, 'w').close()
        if epoch in decreasing_lr:
            optimizer.param_groups[0]['lr'] *= 0.1
            print("~~~learning rate:", optimizer.param_groups[0]['lr'])
        for batch_idx, (data, target) in enumerate(train_loader):
            target_en = torch.Tensor(target.shape[0], args.num_classes)
            target_en.zero_()
            target_en.scatter_(1, target.view(-1, 1), 1)  # one-hot encoding
            target_en = target_en.to(device)
            if args.cuda:
                data = data.cuda()
                target = target.cuda()
            data, target = Variable(data), Variable(target)

            loss, mu, output, output_mu, x_re, rec, kl, ce = lvae.loss(
                data, target, target_en, next(beta), args.lamda)
            rec_loss = (x_re - data).pow(2).sum((3, 2, 1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            outlabel = output.data.max(1)[
                1]  # get the index of the max log-probability
            correct_train += outlabel.eq(target.view_as(outlabel)).sum().item()

            cor_fea = mu[(outlabel == target)]
            cor_tar = target[(outlabel == target)]
            cor_fea = torch.Tensor.cpu(cor_fea).detach().numpy()
            cor_tar = torch.Tensor.cpu(cor_tar).detach().numpy()
            rec_loss = torch.Tensor.cpu(rec_loss).detach().numpy()
            with open('lvae%d/train_fea.txt' % args.lamda, 'ab') as f:
                np.savetxt(f, cor_fea, fmt='%f', delimiter=' ', newline='\r')
                f.write(b'\n')
            with open('lvae%d/train_tar.txt' % args.lamda, 'ab') as t:
                np.savetxt(t, cor_tar, fmt='%d', delimiter=' ', newline='\r')
                t.write(b'\n')
            with open('lvae%d/train_rec.txt' % args.lamda, 'ab') as m:
                np.savetxt(m, rec_loss, fmt='%f', delimiter=' ', newline='\r')
                m.write(b'\n')

            if batch_idx % args.log_interval == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)] train_batch_loss: {:.6f}={:.6f}+{:.6f}+{:.6f}'
                    .format(
                        epoch, batch_idx * len(data),
                        len(train_loader.dataset), 100. * batch_idx *
                        len(data) / len(train_loader.dataset),
                        loss.data / (len(data)), rec.data / (len(data)),
                        kl.data / (len(data)), ce.data / (len(data))))

        train_acc = float(100 * correct_train) / len(train_loader.dataset)
        print('Train_Acc: {}/{} ({:.2f}%)'.format(correct_train,
                                                  len(train_loader.dataset),
                                                  train_acc))
        # val
        if epoch % args.val_interval == 0 and epoch >= 0:
            lvae.eval()
            correct_val = 0
            total_val_loss = 0
            total_val_rec = 0
            total_val_kl = 0
            total_val_ce = 0
            for data_val, target_val in val_loader:
                target_val_en = torch.Tensor(target_val.shape[0],
                                             args.num_classes)
                target_val_en.zero_()
                target_val_en.scatter_(1, target_val.view(-1, 1),
                                       1)  # one-hot encoding
                target_val_en = target_val_en.to(device)
                if args.cuda:
                    data_val, target_val = data_val.cuda(), target_val.cuda()
                with torch.no_grad():
                    data_val, target_val = Variable(data_val), Variable(
                        target_val)

                loss_val, mu_val, output_val, output_mu_val, val_re, rec_val, kl_val, ce_val = lvae.loss(
                    data_val, target_val, target_val_en, next(beta),
                    args.lamda)
                total_val_loss += loss_val.data.detach().item()
                total_val_rec += rec_val.data.detach().item()
                total_val_kl += kl_val.data.detach().item()
                total_val_ce += ce_val.data.detach().item()

                vallabel = output_val.data.max(1)[
                    1]  # get the index of the max log-probability
                correct_val += vallabel.eq(
                    target_val.view_as(vallabel)).sum().item()

            val_loss = total_val_loss / len(val_loader.dataset)
            val_rec = total_val_rec / len(val_loader.dataset)
            val_kl = total_val_kl / len(val_loader.dataset)
            val_ce = total_val_ce / len(val_loader.dataset)
            print(
                '====> Epoch: {} Val loss: {}/{} ({:.4f}={:.4f}+{:.4f}+{:.4f})'
                .format(epoch, total_val_loss, len(val_loader.dataset),
                        val_loss, val_rec, val_kl, val_ce))
            val_acc = float(100 * correct_val) / len(val_loader.dataset)
            print('Val_Acc: {}/{} ({:.2f}%)'.format(correct_val,
                                                    len(val_loader.dataset),
                                                    val_acc))

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_epoch = epoch
                train_fea = np.loadtxt('lvae%d/train_fea.txt' % args.lamda)
                train_tar = np.loadtxt('lvae%d/train_tar.txt' % args.lamda)
                train_rec = np.loadtxt('lvae%d/train_rec.txt' % args.lamda)
                print('!!!Best Val Epoch: {}, Best Val Loss:{:.4f}'.format(
                    best_val_epoch, best_val_loss))
                #torch.save(lvae, 'lvae%d.pt' % args.lamda)
                # test
                open('lvae%d/omn_fea.txt' % args.lamda, 'w').close()
                open('lvae%d/omn_tar.txt' % args.lamda, 'w').close()
                open('lvae%d/omn_pre.txt' % args.lamda, 'w').close()
                open('lvae%d/omn_rec.txt' % args.lamda, 'w').close()

                open('lvae%d/mnist_noise_fea.txt' % args.lamda, 'w').close()
                open('lvae%d/mnist_noise_tar.txt' % args.lamda, 'w').close()
                open('lvae%d/mnist_noise_pre.txt' % args.lamda, 'w').close()
                open('lvae%d/mnist_noise_rec.txt' % args.lamda, 'w').close()

                open('lvae%d/noise_fea.txt' % args.lamda, 'w').close()
                open('lvae%d/noise_tar.txt' % args.lamda, 'w').close()
                open('lvae%d/noise_pre.txt' % args.lamda, 'w').close()
                open('lvae%d/noise_rec.txt' % args.lamda, 'w').close()

                for data_test, target_test in val_loader:
                    target_test_en = torch.Tensor(target_test.shape[0],
                                                  args.num_classes)
                    target_test_en.zero_()
                    target_test_en.scatter_(1, target_test.view(-1, 1),
                                            1)  # one-hot encoding
                    target_test_en = target_test_en.to(device)
                    if args.cuda:
                        data_test, target_test = data_test.cuda(
                        ), target_test.cuda()
                    with torch.no_grad():
                        data_test, target_test = Variable(data_test), Variable(
                            target_test)

                    mu_test, output_test, de_test = lvae.test(
                        data_test, target_test_en)
                    output_test = torch.exp(output_test)
                    prob_test = output_test.max(1)[
                        0]  # get the value of the max probability
                    pre_test = output_test.max(1, keepdim=True)[
                        1]  # get the index of the max log-probability
                    rec_test = (de_test - data_test).pow(2).sum((3, 2, 1))
                    mu_test = torch.Tensor.cpu(mu_test).detach().numpy()
                    target_test = torch.Tensor.cpu(
                        target_test).detach().numpy()
                    pre_test = torch.Tensor.cpu(pre_test).detach().numpy()
                    rec_test = torch.Tensor.cpu(rec_test).detach().numpy()

                    with open('lvae%d/omn_fea.txt' % args.lamda,
                              'ab') as f_test:
                        np.savetxt(f_test,
                                   mu_test,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        f_test.write(b'\n')
                    with open('lvae%d/omn_tar.txt' % args.lamda,
                              'ab') as t_test:
                        np.savetxt(t_test,
                                   target_test,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        t_test.write(b'\n')
                    with open('lvae%d/omn_pre.txt' % args.lamda,
                              'ab') as p_test:
                        np.savetxt(p_test,
                                   pre_test,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        p_test.write(b'\n')
                    with open('lvae%d/omn_rec.txt' % args.lamda,
                              'ab') as l_test:
                        np.savetxt(l_test,
                                   rec_test,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        l_test.write(b'\n')

                    with open('lvae%d/mnist_noise_fea.txt' % args.lamda,
                              'ab') as f_test:
                        np.savetxt(f_test,
                                   mu_test,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        f_test.write(b'\n')
                    with open('lvae%d/mnist_noise_tar.txt' % args.lamda,
                              'ab') as t_test:
                        np.savetxt(t_test,
                                   target_test,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        t_test.write(b'\n')
                    with open('lvae%d/mnist_noise_pre.txt' % args.lamda,
                              'ab') as p_test:
                        np.savetxt(p_test,
                                   pre_test,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        p_test.write(b'\n')
                    with open('lvae%d/mnist_noise_rec.txt' % args.lamda,
                              'ab') as l_test:
                        np.savetxt(l_test,
                                   rec_test,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        l_test.write(b'\n')

                    with open('lvae%d/noise_fea.txt' % args.lamda,
                              'ab') as f_test:
                        np.savetxt(f_test,
                                   mu_test,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        f_test.write(b'\n')
                    with open('lvae%d/noise_tar.txt' % args.lamda,
                              'ab') as t_test:
                        np.savetxt(t_test,
                                   target_test,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        t_test.write(b'\n')
                    with open('lvae%d/noise_pre.txt' % args.lamda,
                              'ab') as p_test:
                        np.savetxt(p_test,
                                   pre_test,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        p_test.write(b'\n')
                    with open('lvae%d/noise_rec.txt' % args.lamda,
                              'ab') as l_test:
                        np.savetxt(l_test,
                                   rec_test,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        l_test.write(b'\n')
# omn_test
                i_omn = 0
                for data_omn, target_omn in omn_loader:
                    i_omn += 1
                    tar_omn = torch.from_numpy(args.num_classes *
                                               np.ones(target_omn.shape[0]))
                    if i_omn <= 158:  #158*64=10112>10000
                        if args.cuda:
                            data_omn = data_omn.cuda()
                        with torch.no_grad():
                            data_omn = Variable(data_omn)
                    else:
                        break

                    mu_omn, output_omn, de_omn = lvae.test(
                        data_omn, target_test_en)
                    output_omn = torch.exp(output_omn)
                    prob_omn = output_omn.max(1)[
                        0]  # get the value of the max probability
                    pre_omn = output_omn.max(1, keepdim=True)[
                        1]  # get the index of the max log-probability
                    rec_omn = (de_omn - data_omn).pow(2).sum((3, 2, 1))
                    mu_omn = torch.Tensor.cpu(mu_omn).detach().numpy()
                    tar_omn = torch.Tensor.cpu(tar_omn).detach().numpy()
                    pre_omn = torch.Tensor.cpu(pre_omn).detach().numpy()
                    rec_omn = torch.Tensor.cpu(rec_omn).detach().numpy()

                    with open('lvae%d/omn_fea.txt' % args.lamda,
                              'ab') as f_test:
                        np.savetxt(f_test,
                                   mu_omn,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        f_test.write(b'\n')
                    with open('lvae%d/omn_tar.txt' % args.lamda,
                              'ab') as t_test:
                        np.savetxt(t_test,
                                   tar_omn,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        t_test.write(b'\n')
                    with open('lvae%d/omn_pre.txt' % args.lamda,
                              'ab') as p_test:
                        np.savetxt(p_test,
                                   pre_omn,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        p_test.write(b'\n')
                    with open('lvae%d/omn_rec.txt' % args.lamda,
                              'ab') as l_test:
                        np.savetxt(l_test,
                                   rec_omn,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        l_test.write(b'\n')
# mnist_noise_test
                for data_test, target_test in val_loader:
                    tar_mnist_noise = torch.from_numpy(
                        args.num_classes * np.ones(target_test.shape[0]))
                    noise = torch.from_numpy(
                        np.random.rand(data_test.shape[0], 1, 28, 28)).float()
                    data_mnist_noise = data_test.add(noise)
                    if args.cuda:
                        data_mnist_noise = data_mnist_noise.cuda()
                    with torch.no_grad():
                        data_mnist_noise = Variable(data_mnist_noise)

                    mu_mnist_noise, output_mnist_noise, de_mnist_noise = lvae.test(
                        data_mnist_noise, target_test_en)
                    output_mnist_noise = torch.exp(output_mnist_noise)
                    prob_mnist_noise = output_mnist_noise.max(1)[
                        0]  # get the value of the max probability
                    pre_mnist_noise = output_mnist_noise.max(1, keepdim=True)[
                        1]  # get the index of the max log-probability
                    rec_mnist_noise = (de_mnist_noise -
                                       data_mnist_noise).pow(2).sum((3, 2, 1))
                    mu_mnist_noise = torch.Tensor.cpu(
                        mu_mnist_noise).detach().numpy()
                    tar_mnist_noise = torch.Tensor.cpu(
                        tar_mnist_noise).detach().numpy()
                    pre_mnist_noise = torch.Tensor.cpu(
                        pre_mnist_noise).detach().numpy()
                    rec_mnist_noise = torch.Tensor.cpu(
                        rec_mnist_noise).detach().numpy()

                    with open('lvae%d/mnist_noise_fea.txt' % args.lamda,
                              'ab') as f_test:
                        np.savetxt(f_test,
                                   mu_mnist_noise,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        f_test.write(b'\n')
                    with open('lvae%d/mnist_noise_tar.txt' % args.lamda,
                              'ab') as t_test:
                        np.savetxt(t_test,
                                   tar_mnist_noise,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        t_test.write(b'\n')
                    with open('lvae%d/mnist_noise_pre.txt' % args.lamda,
                              'ab') as p_test:
                        np.savetxt(p_test,
                                   pre_mnist_noise,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        p_test.write(b'\n')
                    with open('lvae%d/mnist_noise_rec.txt' % args.lamda,
                              'ab') as l_test:
                        np.savetxt(l_test,
                                   rec_mnist_noise,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        l_test.write(b'\n')


# noise_test
                for data_test, target_test in val_loader:
                    tar_noise = torch.from_numpy(args.num_classes *
                                                 np.ones(target_test.shape[0]))
                    data_noise = torch.from_numpy(
                        np.random.rand(data_test.shape[0], 1, 28, 28)).float()
                    if args.cuda:
                        data_noise = data_noise.cuda()
                    with torch.no_grad():
                        data_noise = Variable(data_noise)

                    mu_noise, output_noise, de_noise = lvae.test(
                        data_noise, target_test_en)
                    output_noise = torch.exp(output_noise)
                    prob_noise = output_noise.max(1)[
                        0]  # get the value of the max probability
                    pre_noise = output_noise.max(1, keepdim=True)[
                        1]  # get the index of the max log-probability
                    rec_noise = (de_noise - data_noise).pow(2).sum((3, 2, 1))
                    mu_noise = torch.Tensor.cpu(mu_noise).detach().numpy()
                    tar_noise = torch.Tensor.cpu(tar_noise).detach().numpy()
                    pre_noise = torch.Tensor.cpu(pre_noise).detach().numpy()
                    rec_noise = torch.Tensor.cpu(rec_noise).detach().numpy()

                    with open('lvae%d/noise_fea.txt' % args.lamda,
                              'ab') as f_test:
                        np.savetxt(f_test,
                                   mu_noise,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        f_test.write(b'\n')
                    with open('lvae%d/noise_tar.txt' % args.lamda,
                              'ab') as t_test:
                        np.savetxt(t_test,
                                   tar_noise,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        t_test.write(b'\n')
                    with open('lvae%d/noise_pre.txt' % args.lamda,
                              'ab') as p_test:
                        np.savetxt(p_test,
                                   pre_noise,
                                   fmt='%d',
                                   delimiter=' ',
                                   newline='\r')
                        p_test.write(b'\n')
                    with open('lvae%d/noise_rec.txt' % args.lamda,
                              'ab') as l_test:
                        np.savetxt(l_test,
                                   rec_noise,
                                   fmt='%f',
                                   delimiter=' ',
                                   newline='\r')
                        l_test.write(b'\n')

    open('lvae%d/train_fea.txt' % args.lamda, 'w').close()  # clear
    np.savetxt('lvae%d/train_fea.txt' % args.lamda,
               train_fea,
               delimiter=' ',
               fmt='%f')
    open('lvae%d/train_tar.txt' % args.lamda, 'w').close()
    np.savetxt('lvae%d/train_tar.txt' % args.lamda,
               train_tar,
               delimiter=' ',
               fmt='%d')
    open('lvae%d/train_rec.txt' % args.lamda, 'w').close()
    np.savetxt('lvae%d/train_rec.txt' % args.lamda,
               train_rec,
               delimiter=' ',
               fmt='%f')

    fea_omn = np.loadtxt('lvae%d/omn_fea.txt' % args.lamda)
    tar_omn = np.loadtxt('lvae%d/omn_tar.txt' % args.lamda)
    pre_omn = np.loadtxt('lvae%d/omn_pre.txt' % args.lamda)
    rec_omn = np.loadtxt('lvae%d/omn_rec.txt' % args.lamda)
    fea_omn = fea_omn[:20000, :]
    tar_omn = tar_omn[:20000]
    pre_omn = pre_omn[:20000]
    rec_omn = rec_omn[:20000]
    open('lvae%d/omn_fea.txt' % args.lamda, 'w').close()  # clear
    np.savetxt('lvae%d/omn_fea.txt' % args.lamda,
               fea_omn,
               delimiter=' ',
               fmt='%f')
    open('lvae%d/omn_tar.txt' % args.lamda, 'w').close()
    np.savetxt('lvae%d/omn_tar.txt' % args.lamda,
               tar_omn,
               delimiter=' ',
               fmt='%d')
    open('lvae%d/omn_pre.txt' % args.lamda, 'w').close()
    np.savetxt('lvae%d/omn_pre.txt' % args.lamda,
               pre_omn,
               delimiter=' ',
               fmt='%d')
    open('lvae%d/omn_rec.txt' % args.lamda, 'w').close()
    np.savetxt('lvae%d/omn_rec.txt' % args.lamda,
               rec_omn,
               delimiter=' ',
               fmt='%f')

    return best_val_loss, best_val_epoch
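The save blocks above repeat the same tensor-to-NumPy conversion for every
array they write out. A small helper (hypothetical, not part of the original
project) would keep those blocks shorter:

import torch

def to_numpy(t):
    # detach from the autograd graph, move to host memory, convert to NumPy
    return t.detach().cpu().numpy()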
Example #21
File: sample.py  Project: samgriesemer/avme
import torch
import torchvision
import matplotlib.pyplot as plt
from torch.autograd import Variable

# bsz, mid_size, and ninp are assumed to be defined elsewhere in the project

def imshow(img):
    # minimal stand-in for the project's display helper (an assumption):
    # make_grid returns a CHW tensor, matplotlib expects HWC
    plt.imshow(img.permute(1, 2, 0).numpy())

latent_size = 100

# load pre-trained models
with open('pretrained_networks/avme.pt', 'rb') as f:
    ame_dec_v = torch.load(f)
    
with open('pretrained_networks/avme_enc.pt', 'rb') as f:
    ame_enc_v = torch.load(f)
    
# lists for output series
canvas_list = []

# initialize canvas
canvas = Variable(torch.zeros(bsz,784))
output = Variable(torch.zeros(bsz,784))

# initialize LSTM hidden state
h_dec_1 = (Variable(torch.zeros(bsz,mid_size)), Variable(torch.zeros(bsz,mid_size)))
h_dec_2 = (Variable(torch.zeros(bsz,ninp)), Variable(torch.zeros(bsz,ninp)))

# generate image series of length 12
for i in range(12):
    sample = Variable(torch.randn(bsz,latent_size))
    output, h_dec_1, h_dec_2 = ame_dec_v(sample, h_dec_1, h_dec_2)
    canvas = canvas.add(output)
    canvas_list.append(canvas.sigmoid().view(bsz,1,28,28).data[:1])

# matplotlib series plot
pic = torch.cat(canvas_list,3)
plt.figure(figsize=(20,10))
imshow(torchvision.utils.make_grid(pic))
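For reference, a minimal self-contained sketch of the additive-canvas
generation loop above, with a stub linear decoder standing in for ame_dec_v
(sizes are illustrative and the LSTM hidden-state plumbing is omitted):

import torch

bsz, latent_size, steps = 4, 100, 12
decoder = torch.nn.Linear(latent_size, 784)   # stand-in for ame_dec_v

canvas = torch.zeros(bsz, 784)
frames = []
for _ in range(steps):
    z = torch.randn(bsz, latent_size)         # fresh prior sample per step
    canvas = canvas + decoder(z)              # canvas accumulates additively
    frames.append(torch.sigmoid(canvas))      # sigmoid only for display
series = torch.stack(frames)                  # (steps, bsz, 784)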
Example #22
File: run.py  Project: samgriesemer/avme
def train(epoch):

    # iterator variables
    count = 1
    total_loss = 0

    # loop through training data
    for i in range(len(data_list)):

        batch = Variable(data_list[i])
        label = Variable(onehot_labels[i])

        canvas = Variable(torch.zeros(bsz, 784))

        output_v = Variable(torch.zeros(bsz, 784))

        # instantiate encoder and decoder network layers
        h_enc_v_1 = (Variable(torch.zeros(bsz, mid_size)),
                     Variable(torch.zeros(bsz, mid_size)))
        h_enc_v_2 = (Variable(torch.zeros(bsz, latent_size)),
                     Variable(torch.zeros(bsz, latent_size)))
        h_enc_v_3 = (Variable(torch.zeros(bsz, latent_size)),
                     Variable(torch.zeros(bsz, latent_size)))

        h_dec_v_1 = (Variable(torch.zeros(bsz, mid_size)),
                     Variable(torch.zeros(bsz, mid_size)))
        h_dec_v_2 = (Variable(torch.zeros(bsz, ninp)),
                     Variable(torch.zeros(bsz, ninp)))

        # zero gradients and reset KLD each data batch iteration
        ame_enc_v.zero_grad()
        ame_dec_v.zero_grad()
        KLD = 0

        # begin loop through K specified generative iterations
        for t in range(sample_iters):

            # push batch through encoder and store hidden states
            mu_v, logvar_v, h_enc_v_1, h_enc_v_2, h_enc_v_3 = ame_enc_v(
                torch.cat(
                    [batch,
                     batch.sub(canvas.sigmoid()), output_v, label], 1),
                h_enc_v_1,
                h_enc_v_2,
                enc=h_enc_v_3)

            # compute params of latent space
            std_v = logvar_v.mul(0.5).exp_()
            eps_v = torch.FloatTensor(std_v.size()).normal_()
            eps_v = Variable(eps_v)
            sample_v = eps_v.mul(std_v).add_(mu_v)

            # compute KL divergence
            KLD_v_elem = mu_v.pow(2).add_(
                logvar_v.exp()).mul_(-1).add_(1).add_(logvar_v)
            KLD += torch.sum(KLD_v_elem).mul_(-0.5)

            # decode latent space sample and map back to sample space
            output_v, h_dec_v_1, h_dec_v_2 = ame_dec_v(
                torch.cat([sample_v, label], 1), h_dec_v_1, h_dec_v_2)

            # add decoder output to recurring canvas element
            canvas = output_v.add(canvas)

        canvas.sigmoid_()

        # compute binary cross entropy loss and add KLD component
        loss = criterion(canvas, batch)
        loss += KLD
        loss.backward()

        # clip encoder gradients
        clipped_lr = lr * clip_gradient(ame_enc_v, clip)
        for p in ame_enc_v.parameters():
            p.data.add_(-clipped_lr, p.grad.data)

        # clip decoder gradients
        clipped_lr = lr * clip_gradient(ame_dec_v, clip)
        for p in ame_dec_v.parameters():
            p.data.add_(-clipped_lr, p.grad.data)

        # store loss data (mainly for viz)
        loss_list.append(loss.data[0])
        total_loss += loss.data[0]

        # print loss stats
        if count % 10 == 0:
            avg_loss = total_loss / count
            print('Epoch:', epoch, 'Iter:', count, 'Avg Loss:', avg_loss,
                  'Cur Loss:', loss.data[0])
        count += 1
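The two clip-and-step blocks above rely on a project-specific clip_gradient
helper that apparently returns a rescaling factor for the learning rate. A
sketch of the same effect with PyTorch's built-in utility (an assumed
equivalent, not the project's code):

import torch

def sgd_step_with_clipping(model, lr, clip):
    # rescale gradients in place so their global norm is at most `clip`,
    # then take a plain SGD step
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p -= lr * p.grad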
Example #23
File: validate.py  Project: NaiveXu/Master
def validate(model, epoch, optimizer, test_loader, args, writer, accuracy_dict,
             episode, criterion):

    # Initialize training:
    model.eval()

    # Collect all episode images w/labels:
    image_batch, label_batch = next(iter(test_loader))

    # Episode Statistics:
    episode_loss = 0.0
    episode_optimized = 0.0
    episode_correct = 0.0
    episode_predict = 0.0
    episode_iter = 0.0

    # Create initial state:
    state = []
    label_dict = []
    for i in range(args.batch_size):
        label_dict.append({})
        state.append([0 for i in range(args.class_vector_size)])

    # Initialize model between each episode:
    hidden = model.reset_hidden(args.batch_size)

    # Accuracy statistics:
    for v in accuracy_dict.values():
        v.append([])

    # Initiate empty loss Variable:
    loss = Variable(torch.zeros(args.batch_size))
    if args.cuda:
        loss = loss.cuda()

    predictions = []
    test_labels = []

    for i_e in range(len(label_batch)):

        # Collect timestep images/labels:
        episode_images = image_batch[i_e]
        episode_labels = label_batch[i_e]

        # Tensoring the state:
        state = torch.FloatTensor(state)

        # Need to add image to the state vector:
        flat_images = episode_images.squeeze().view(args.batch_size, -1)

        # Concatenating possible labels/zero vector with image, thus creating the state:
        state = torch.cat((state, flat_images), 1)

        # Generating actions to choose from the model:
        if (args.cuda):
            actions, hidden = model(
                Variable(state).type(torch.FloatTensor).cuda(), hidden)
        else:
            actions, hidden = model(
                Variable(state).type(torch.FloatTensor), hidden)

        predictions.append([pred for pred in F.softmax(actions[0]).data])
        test_labels.append(episode_labels[0])

        if (args.cuda):
            current_loss = criterion(actions, Variable(episode_labels).cuda())
        else:
            current_loss = criterion(actions, Variable(episode_labels))

        loss = loss.add(current_loss)

        actions = actions.data.max(1)[1].squeeze()

        one_hot_labels = []
        for i in range(args.batch_size):
            true_label = episode_labels[i]

            # Creating one hot labels:
            one_hot_labels.append([
                1 if j == true_label else 0
                for j in range(args.class_vector_size)
            ])

            # Logging label occurences:
            if (true_label not in label_dict[i]):
                label_dict[i][true_label] = 1
            else:
                label_dict[i][true_label] += 1

            # Logging accuracy:
            if (actions[i] == true_label):
                episode_correct += 1.0
                episode_predict += 1.0
                if (label_dict[i][true_label] in accuracy_dict):
                    accuracy_dict[label_dict[i][true_label]][-1].append(1)
            else:
                episode_predict += 1.0
                if (label_dict[i][true_label] in accuracy_dict):
                    accuracy_dict[label_dict[i][true_label]][-1].append(0)

        # Update next state:
        state = one_hot_labels

        ### END EPISODE LOOP ###

    # Averaging the loss over the batch (SGD):
    avg_loss = torch.div(loss.sum(), args.batch_size)

    # More status update:
    total_loss = avg_loss.data[0]

    print("\n--- Epoch " + str(epoch) + ", Episode " + str(episode + i + 1) +
          " Statistics ---")
    print("Instance\tAccuracy")
    for key in accuracy_dict.keys():
        prob_list = accuracy_dict[key]

        latest = prob_list[-1:]
        probs = 0.0
        prob = 0.0
        for l in latest:
            prob += sum(l)
            probs += len(l)
        prob /= probs
        print("Instance " + str(key) + ":\t" + str(100.0 * prob)[0:4] + " %")

    # Even more status update:
    print("\n+------------------STATISTICS----------------------+")
    total_accuracy = float((100.0 * episode_correct) / episode_predict)
    print("Batch Average Accuracy = " + str(total_accuracy)[:5] + " %")
    total_loss = float(total_loss)
    print("Batch Average Loss = " + str(total_loss)[:5])
    print("+--------------------------------------------------+\n")

    ### LOGGING TO TENSORBOARD ###
    data = {
        'test_total_accuracy': total_accuracy,
        'test_total_loss': total_loss,
    }

    for tag, value in data.items():
        writer.scalar_summary(tag, value, epoch)
    ### DONE LOGGING ###

    return total_accuracy, total_loss, accuracy_dict, predictions, test_labels
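The per-sample one-hot loop in the episode body above can be vectorized in
recent PyTorch versions; a sketch, assuming episode_labels is a LongTensor of
shape [batch_size]:

import torch
import torch.nn.functional as F

episode_labels = torch.tensor([2, 0, 1])
one_hot = F.one_hot(episode_labels, num_classes=3).float()
# tensor([[0., 0., 1.],
#         [1., 0., 0.],
#         [0., 1., 0.]])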
Example #24
    def _pad(self,
             sentences,
             pad_id,
             volatile=False,
             raml=False,
             raml_tau=1.,
             vocab_size=None,
             pad_corrupt=False,
             dist_corrupt=False,
             no_corrupt_mask=None):
        """Pad all instances in [data] to the longest length.

    Args:
      sentences: list of [batch_size] lists.

    Returns:
      padded_sentences: Variable of size [batch_size, max_len], the sentences.
      mask: Variable of size [batch_size, max_len]. 1 means to ignore.
      pos_emb_indices: Variable of size [batch_size, max_len]. indices to use
        when computing positional embedding.
      sum_len: total number of words.
      raml: if True, the sentences will be replaced by their samples according
        to the exp payoff distribution, ie. exp(-HammingDist(s, sents)).
      vocab_size: if raml is True, vocab_size has to be specified for negative
        sampling.
      dist_corrupt: if is set to True, sample corrupt words based on word 
        embedding distance.
      no_corrupt_mask: [batch_size, 1] if is set, then 1 means not to corrupt 
        the sentence.
    """

        batch_size = len(sentences)
        lengths = [len(sentence) for sentence in sentences]
        sum_len = sum(lengths)
        max_len = max(lengths)

        padded_sentences = [
            sentence + ([pad_id] * (max_len - len(sentence)))
            for sentence in sentences
        ]
        mask = [([0] * len(sentence)) + ([1] * (max_len - len(sentence)))
                for sentence in sentences]
        pos_emb_indices = [[i + 1 for i in range(len(sentence))] +
                           ([0] * (max_len - len(sentence)))
                           for sentence in sentences]
        if raml:
            raml_mask = [[1] + ([0] * (len(sentence) - 2)) +
                         ([1] * (max_len - len(sentence) + 1))
                         for sentence in sentences]
            raml_mask = torch.ByteTensor(raml_mask)
            if self.hparams.cuda:
                raml_mask = raml_mask.cuda()
        padded_sentences = Variable(torch.LongTensor(padded_sentences))
        mask = torch.ByteTensor(mask)
        pos_emb_indices = Variable(torch.FloatTensor(pos_emb_indices))

        if self.hparams.cuda:
            padded_sentences = padded_sentences.cuda()
            mask = mask.cuda()
            pos_emb_indices = pos_emb_indices.cuda()

        if not raml:
            return padded_sentences, mask, pos_emb_indices, sum_len

        assert vocab_size is not None
        # first, sample the number of words to corrupt for each sentence
        logits = torch.arange(max_len)
        if self.hparams.cuda:
            logits = logits.cuda()
        logits = logits.mul_(-1).unsqueeze(0).expand_as(
            padded_sentences).contiguous().masked_fill_(
                mask, -self.hparams.inf)
        logits = Variable(logits, volatile=True)
        probs = self.softmax(logits.mul_(raml_tau))
        num_words = torch.distributions.Categorical(probs).sample()

        # sample the indices
        lengths = torch.FloatTensor(lengths)
        if self.hparams.cuda:
            lengths = lengths.cuda()

        # mask out bos, eos
        corrupt_pos = num_words.data.float().div_(lengths - 2).unsqueeze(
            1).expand_as(padded_sentences).contiguous().masked_fill_(
                raml_mask, 0)
        if no_corrupt_mask is not None:
            corrupt_pos.masked_fill_(no_corrupt_mask, 0)
        corrupt_pos = torch.bernoulli(corrupt_pos, out=corrupt_pos).byte()
        total_words = int(corrupt_pos.sum())
        if total_words == 0:
            return padded_sentences, padded_sentences, mask, pos_emb_indices, sum_len
        if dist_corrupt:
            # sample words according to distance
            words_to_corrupt = padded_sentences.data.masked_select(corrupt_pos)
            # (num_words, dim)
            words_emb_to_corrupt = torch.index_select(self.glove,
                                                      dim=0,
                                                      index=words_to_corrupt)
            # (num_words, num_vocab)
            w12 = torch.mm(words_emb_to_corrupt, self.glove.permute(1, 0))
            # (num_words, 1)
            w1 = torch.norm(words_emb_to_corrupt, 2, dim=1, keepdim=True)
            # (num_vocab, 1)
            w2 = torch.norm(self.glove, 2, dim=1, keepdim=True)
            # (num_words, num_vocab)
            distance = w12 / ((torch.mm(w1, w2.permute(1, 0))).clamp(min=1e-8))
            distance = Variable(distance)
            if self.hparams.cuda:
                distance = distance.cuda()
            # mask out the original words
            add = torch.arange(total_words).long() * vocab_size
            if self.hparams.cuda:
                add = add.cuda()
            corrupt_mask_index = (words_to_corrupt + add)
            distance = distance.view(-1)
            distance.data.index_fill_(0, corrupt_mask_index, -self.hparams.inf)
            distance = distance.view(total_words, -1)
            probs = torch.nn.functional.softmax(distance.mul_(
                self.hparams.dist_corrupt_tau),
                                                dim=1)
            corrupt_val = torch.distributions.Categorical(probs).sample().view(
                -1)
            #_, corrupt_val = torch.topk(probs, 1)

            if self.hparams.cuda:
                corrupt_val = corrupt_val.long().cuda()
                corrupt_pos = corrupt_pos.cuda()
            sample_sentences = padded_sentences.masked_scatter(
                Variable(corrupt_pos), corrupt_val)
        else:
            # sample the corrupts, which will be added to padded_sentences
            corrupt_val = torch.LongTensor(total_words)
            if pad_corrupt:
                corrupt_val.fill_(self.hparams.pad_id + vocab_size)
            else:
                corrupt_val = corrupt_val.random_(1, vocab_size)
            corrupts = torch.zeros(batch_size, max_len).long()
            if self.hparams.cuda:
                corrupt_val = corrupt_val.long().cuda()
                corrupts = corrupts.cuda()
                corrupt_pos = corrupt_pos.cuda()
            corrupts = corrupts.masked_scatter_(corrupt_pos, corrupt_val)

            sample_sentences = padded_sentences.add(
                Variable(corrupts)).remainder_(vocab_size).masked_fill_(
                    Variable(mask), pad_id)
        return sample_sentences, padded_sentences, mask, pos_emb_indices, sum_len
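The corruption-count sampling above reduces to: corrupt k tokens with
probability proportional to exp(-k * raml_tau), masking positions past the
sentence length. A standalone sketch of that step:

import torch

max_len, sent_len, raml_tau = 10, 7, 0.8
logits = -torch.arange(max_len, dtype=torch.float)
logits[sent_len:] = float('-inf')                     # mask padded positions
probs = torch.softmax(logits * raml_tau, dim=0)
k = torch.distributions.Categorical(probs).sample()   # tokens to corrupt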
Example #25
    def interpolate(self, x_grid, x_target, interp_points=range(-2, 2)):
        # Do some boundary checking
        grid_mins = x_grid.min(1)[0]
        grid_maxs = x_grid.max(1)[0]
        x_target_min = x_target.min(0)[0]
        x_target_max = x_target.max(0)[0]
        lt_min_mask = (x_target_min - grid_mins).lt(-1e-7)
        gt_max_mask = (x_target_max - grid_maxs).gt(1e-7)
        if lt_min_mask.data.sum():
            first_out_of_range = lt_min_mask.nonzero().squeeze(1)[0].data
            raise RuntimeError((
                "Received data that was out of bounds for the specified grid. "
                "Grid bounds were ({0:.3f}, {1:.3f}), but min = {2:.3f}, "
                "max = {3:.3f}").format(
                    grid_mins[first_out_of_range].data[0],
                    grid_maxs[first_out_of_range].data[0],
                    x_target_min[first_out_of_range].data[0],
                    x_target_max[first_out_of_range].data[0],
                ))
        if gt_max_mask.data.sum():
            first_out_of_range = gt_max_mask.nonzero().squeeze(1)[0].data
            raise RuntimeError((
                "Received data that was out of bounds for the specified grid. "
                "Grid bounds were ({0:.3f}, {1:.3f}), but min = {2:.3f}, "
                "max = {3:.3f}").format(
                    grid_mins[first_out_of_range].data[0],
                    grid_maxs[first_out_of_range].data[0],
                    x_target_min[first_out_of_range].data[0],
                    x_target_max[first_out_of_range].data[0],
                ))

        # Now do interpolation
        interp_points_flip = Variable(x_grid.data.new(interp_points[::-1]))
        interp_points = Variable(x_grid.data.new(interp_points))

        num_grid_points = x_grid.size(1)
        num_target_points = x_target.size(0)
        num_dim = x_target.size(-1)
        num_coefficients = len(interp_points)

        interp_values = Variable(
            x_target.data.new(num_target_points,
                              num_coefficients**num_dim).fill_(1))
        interp_indices = Variable(
            x_grid.data.new(num_target_points,
                            num_coefficients**num_dim).long().zero_())

        for i in range(num_dim):
            grid_delta = x_grid[i, 1] - x_grid[i, 0]
            lower_grid_pt_idxs = torch.floor(
                (x_target[:, i] - x_grid[i, 0]) / grid_delta).squeeze()
            lower_pt_rel_dists = (x_target[:, i] - x_grid[i, 0]
                                  ) / grid_delta - lower_grid_pt_idxs
            lower_grid_pt_idxs = lower_grid_pt_idxs - interp_points.max()
            lower_grid_pt_idxs.detach_()

            scaled_dist = lower_pt_rel_dists.unsqueeze(
                -1) + interp_points_flip.unsqueeze(-2)
            dim_interp_values = self._cubic_interpolation_kernel(scaled_dist)

            # Find points whose closest lower grid point is the first grid point
            # This corresponds to a boundary condition that we must fix manually.
            left_boundary_pts = torch.nonzero(lower_grid_pt_idxs < 1)
            num_left = len(left_boundary_pts)

            if num_left > 0:
                left_boundary_pts.squeeze_(1)
                x_grid_first = x_grid[i, :num_coefficients].unsqueeze(
                    1).t().expand(num_left, num_coefficients)

                grid_targets = x_target.select(
                    1, i)[left_boundary_pts].unsqueeze(1).expand(
                        num_left, num_coefficients)
                dists = torch.abs(x_grid_first - grid_targets)
                closest_from_first = torch.min(dists, 1)[1]

                for j in range(num_left):
                    dim_interp_values[left_boundary_pts[j], :] = 0
                    dim_interp_values[left_boundary_pts[j],
                                      closest_from_first[j]] = 1
                    lower_grid_pt_idxs[left_boundary_pts[j]] = 0

            right_boundary_pts = torch.nonzero(
                lower_grid_pt_idxs > num_grid_points - num_coefficients)
            num_right = len(right_boundary_pts)

            if num_right > 0:
                right_boundary_pts.squeeze_(1)
                x_grid_last = x_grid[i, -num_coefficients:].unsqueeze(
                    1).t().expand(num_right, num_coefficients)

                grid_targets = x_target.select(
                    1, i)[right_boundary_pts].unsqueeze(1)
                grid_targets = grid_targets.expand(num_right, num_coefficients)
                dists = torch.abs(x_grid_last - grid_targets)
                closest_from_last = torch.min(dists, 1)[1]

                for j in range(num_right):
                    dim_interp_values[right_boundary_pts[j], :] = 0
                    dim_interp_values[right_boundary_pts[j],
                                      closest_from_last[j]] = 1
                    lower_grid_pt_idxs[right_boundary_pts[
                        j]] = num_grid_points - num_coefficients

            offset = (interp_points - interp_points.min()).long().unsqueeze(-2)
            dim_interp_indices = lower_grid_pt_idxs.long().unsqueeze(
                -1) + offset

            n_inner_repeat = num_coefficients**i
            n_outer_repeat = num_coefficients**(num_dim - i - 1)
            index_coeff = num_grid_points**(num_dim - i - 1)
            dim_interp_indices = dim_interp_indices.unsqueeze(-1).repeat(
                1, n_inner_repeat, n_outer_repeat)
            dim_interp_values = dim_interp_values.unsqueeze(-1).repeat(
                1, n_inner_repeat, n_outer_repeat)
            interp_indices = interp_indices.add(
                dim_interp_indices.view(num_target_points,
                                        -1).mul(index_coeff))
            interp_values = interp_values.mul(
                dim_interp_values.view(num_target_points, -1))

        return interp_indices, interp_values
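A sketch of the Keys cubic convolution kernel that _cubic_interpolation_kernel
presumably implements (a = -0.5 is the conventional choice; this is an
assumption, not the project's code):

import torch

def cubic_kernel(s, a=-0.5):
    # piecewise cubic weights for relative distances s; zero for |s| >= 2
    s = s.abs()
    w = torch.zeros_like(s)
    near = s <= 1
    mid = (s > 1) & (s < 2)
    w[near] = (a + 2) * s[near]**3 - (a + 3) * s[near]**2 + 1
    w[mid] = a * (s[mid]**3 - 5 * s[mid]**2 + 8 * s[mid] - 4)
    return w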
Example #26
    def forward(self, constituency_tree, dependency_tags, input_seq):

        embedded_seq = []

        for elt in input_seq:
            embedded_seq.append(
                self.embedding(Variable(torch.LongTensor([elt]))).unsqueeze(0))

        current_level_c = []
        current_level_h = []

        first_level = True

        for level_idx in range(len(constituency_tree)):

            level = constituency_tree[level_idx]

            next_level_c = []
            next_level_h = []

            for node_idx in range(len(level)):

                node = level[node_idx]

                if first_level:
                    x = embedded_seq[node[0]]
                    h_left = Variable(torch.zeros(self.hidden_size))
                    h_right = Variable(torch.zeros(self.hidden_size))
                    c_left = Variable(torch.zeros(self.hidden_size))
                    c_right = Variable(torch.zeros(self.hidden_size))

                elif len(node) == 1:
                    x = embedded_seq[dependency_tags[level_idx - 1][node_idx]]
                    c_left = current_level_c[node[0]]
                    c_right = Variable(torch.zeros(self.hidden_size))
                    h_left = current_level_h[node[0]]
                    h_right = Variable(torch.zeros(self.hidden_size))

                else:
                    x = embedded_seq[dependency_tags[level_idx - 1][node_idx]]
                    c_left = current_level_c[node[0]]
                    c_right = current_level_c[node[1]]
                    h_left = current_level_h[node[0]]
                    h_right = current_level_h[node[1]]

                sum_term_c = Variable(torch.zeros(self.hidden_size))

                i = torch.sigmoid(
                    self.w_i(x).add(self.u_i_l(h_left)).add(
                        self.u_i_r(h_right))).reshape(-1)

                f_l = torch.sigmoid(
                    self.w_f(x).add(self.u_f_ll(h_left)).add(
                        self.u_f_lr(h_right))).reshape(-1)
                f_r = torch.sigmoid(
                    self.w_f(x).add(self.u_f_rl(h_right)).add(
                        self.u_f_rr(h_left))).reshape(-1)

                o = torch.sigmoid(
                    self.w_o(x).add(self.u_o_l(h_left)).add(
                        self.u_o_r(h_right))).reshape(-1)

                u = torch.tanh(
                    self.w_u(x).add(self.u_u_l(h_left)).add(
                        self.u_u_r(h_right))).reshape(-1)

                sum_term_c = sum_term_c.add(torch.mul(f_l, c_left)).add(
                    torch.mul(f_r, c_right))

                c = torch.mul(i, u).add(sum_term_c)
                h = torch.mul(o, torch.tanh(c))

                next_level_c.append(c)
                next_level_h.append(h)

            first_level = False
            current_level_c = next_level_c
            current_level_h = next_level_h

        return current_level_h, current_level_c
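In equations, the cell above matches the binary (N = 2) case of the N-ary
Tree-LSTM of Tai et al. (2015); note that, as in the code, the right forget
gate applies u_f_rl to h_right and u_f_rr to h_left:

i     = \sigma(W^{(i)} x + U^{(i)}_L h_L + U^{(i)}_R h_R)
f_L   = \sigma(W^{(f)} x + U^{(f)}_{LL} h_L + U^{(f)}_{LR} h_R)
f_R   = \sigma(W^{(f)} x + U^{(f)}_{RL} h_R + U^{(f)}_{RR} h_L)
o     = \sigma(W^{(o)} x + U^{(o)}_L h_L + U^{(o)}_R h_R)
u     = \tanh(W^{(u)} x + U^{(u)}_L h_L + U^{(u)}_R h_R)
c     = i \odot u + f_L \odot c_L + f_R \odot c_R
h     = o \odot \tanh(c)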