def forward(self, sents, sent_lengths):
     '''
         sents is (batch_size by padded_length)
         when evaluating sentence by sentence, use batch_size = 1 and the padded length,
         e.g. [[1, 2, 3, 4]].
     '''
     batch_size = sents.size()[0]
     sent_lengths = list(sent_lengths)
     # We sort and then do pad packed sequence here. 
     descending_lengths = [x for x, _ in sorted(zip(sent_lengths, range(len(sent_lengths))), reverse=True)]
     descending_indices = [x for _, x in sorted(zip(sent_lengths, range(len(sent_lengths))), reverse=True)]
     descending_lengths = torch.tensor(descending_lengths)
     descending_indices = torch.tensor(descending_indices).to(device)
     descending_sents = torch.index_select(sents, 0, descending_indices)
     
     # get embedding
     embed = self.embedding(descending_sents)
     # pack padded sequence
     embed = torch.nn.utils.rnn.pack_padded_sequence(embed, descending_lengths, batch_first=True)
     
     # fprop though RNN
     self.hidden = self.init_hidden(batch_size)
     rnn_out, self.hidden = self.gru(embed, self.hidden)
     rnn_out, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
     # rnn_out is 32 by 72 by 256
     
     # change the order back
     change_it_back = [x for _, x in sorted(zip(descending_indices, range(len(descending_indices))))]
     self.hidden = torch.index_select(self.hidden, 1, torch.LongTensor(change_it_back).to(device))  
     rnn_out = torch.index_select(rnn_out, 0, torch.LongTensor(change_it_back).to(device)) 
     
     return rnn_out, self.hidden
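For reference, a minimal sketch (toy tensors, illustrative names) of the sort / pack / restore-order pattern used above, with torch.index_select doing the reordering:

import torch

sents = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]])      # (batch_size, padded_length)
sent_lengths = torch.tensor([2, 3, 1])

# sort by length, descending, so the batch can be packed
descending_lengths, descending_indices = torch.sort(sent_lengths, descending=True)
descending_sents = torch.index_select(sents, 0, descending_indices)

# ... embed, pack_padded_sequence, and run the RNN here ...

# invert the permutation to restore the original batch order
change_it_back = torch.argsort(descending_indices)
restored = torch.index_select(descending_sents, 0, change_it_back)
assert torch.equal(restored, sents)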
    def get_triplet_loss(image_a_pred, image_b_pred, matches_a, matches_b, non_matches_a, non_matches_b, alpha):
        """
        Computes the loss function

        \sum_{triplets} ||D(I_a, u_a, I_b, u_{b,match})||_2^2 - ||D(I_a, u_a, I_b, u_{b,non-match})||_2^2 + alpha

        """
        num_matches = matches_a.size()[0]
        num_non_matches = non_matches_a.size()[0]
        multiplier = num_non_matches // num_matches

        ## non_matches_a is already replicated up to be the right size
        ## non_matches_b is also that size
        ## matches_a is just a smaller version of non_matches_a
        ## matches_b is the only thing that needs to be replicated up in size

        matches_b_long = torch.t(matches_b.repeat(multiplier, 1)).contiguous().view(-1)
                         
        matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a)
        matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b_long)
        non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b)

        triplet_losses = (matches_a_descriptors - matches_b_descriptors).pow(2) - (matches_a_descriptors - non_matches_b_descriptors).pow(2) + alpha
        triplet_loss = 1.0 / num_non_matches * torch.clamp(triplet_losses, min=0).sum()

        return triplet_loss
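A rough shape check for the replication step above, with made-up indices and descriptor sizes (nothing here comes from the original training data):

import torch

D, HW = 3, 16                                  # descriptor dim, flattened image size (assumed)
image_a_pred = torch.rand(1, HW, D)            # [1, W*H, D]
image_b_pred = torch.rand(1, HW, D)
matches_a = torch.tensor([0, 5])
matches_b = torch.tensor([1, 6])
non_matches_a = torch.tensor([0, 0, 5, 5])     # matches_a replicated up
non_matches_b = torch.tensor([2, 3, 7, 8])

multiplier = non_matches_a.size(0) // matches_a.size(0)
matches_b_long = torch.t(matches_b.repeat(multiplier, 1)).contiguous().view(-1)
print(matches_b_long)                          # tensor([1, 1, 6, 6]), aligned with non_matches_a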
Example #3
    def forward(cls, ctx, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                sparse=False):

        ctx.padding_idx = padding_idx
        ctx.scale_grad_by_freq = scale_grad_by_freq
        ctx._indices = None
        ctx.sparse = sparse

        assert indices.dim() <= 2
        assert not ctx.needs_input_grad[0], "Embedding doesn't " \
            "compute the gradient w.r.t. the indices"

        ctx._backend = type2backend[type(weight)]
        ctx._weight_size = weight.size()

        if not indices.is_contiguous():
            ctx._indices = indices.contiguous()
            indices = ctx._indices
        else:
            ctx.save_for_backward(indices)

        output = weight.new()
        if max_norm is not None:
            cls._renorm(ctx, indices, weight, max_norm, norm_type)

        if indices.dim() == 1:
            output = torch.index_select(weight, 0, indices)
        else:
            output = torch.index_select(weight, 0, indices.view(-1))
            output = output.view(indices.size(0), indices.size(1), weight.size(1))

        return output
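The lookup itself boils down to an index_select over the embedding matrix; a minimal sketch with arbitrary sizes:

import torch

weight = torch.randn(10, 4)                     # 10 embedding vectors of dimension 4
indices = torch.tensor([[1, 2, 3], [4, 5, 6]])  # (batch, seq_len)

output = torch.index_select(weight, 0, indices.view(-1))
output = output.view(indices.size(0), indices.size(1), weight.size(1))
print(output.shape)                             # torch.Size([2, 3, 4])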
    def get_loss(self, image_a_pred, image_b_pred, mask_a, mask_b):
        loss = 0

        # get the nonzero indices
        mask_a_indices_flat = torch.nonzero(mask_a)
        mask_b_indices_flat = torch.nonzero(mask_b)
        if len(mask_a_indices_flat) == 0:
            return Variable(torch.cuda.LongTensor([0]), requires_grad=True)
        if len(mask_b_indices_flat) == 0:
            return Variable(torch.cuda.LongTensor([0]), requires_grad=True)

        # take num_samples random pixel samples of the object, using the mask
        num_samples = 10000

        rand_numbers_a = (torch.rand(num_samples)*len(mask_a_indices_flat)).cuda()
        rand_indices_a = Variable(torch.floor(rand_numbers_a).type(torch.cuda.LongTensor), requires_grad=False)
        randomized_mask_a_indices_flat = torch.index_select(mask_a_indices_flat, 0, rand_indices_a).squeeze(1)

        rand_numbers_b = (torch.rand(num_samples)*len(mask_b_indices_flat)).cuda()
        rand_indices_b = Variable(torch.floor(rand_numbers_b).type(torch.cuda.LongTensor), requires_grad=False)
        randomized_mask_b_indices_flat = torch.index_select(mask_b_indices_flat, 0, rand_indices_b).squeeze(1)

        # index into the image and get descriptors
        M_margin = 0.5 # margin parameter
        random_img_a_object_descriptors = torch.index_select(image_a_pred, 1, randomized_mask_a_indices_flat)
        random_img_b_object_descriptors = torch.index_select(image_b_pred, 1, randomized_mask_b_indices_flat)
        pixel_wise_loss = (random_img_a_object_descriptors - random_img_b_object_descriptors).pow(2).sum(dim=2)
        pixel_wise_loss = torch.add(pixel_wise_loss, -2*M_margin)
        zeros_vec = torch.zeros_like(pixel_wise_loss)
        loss += torch.max(zeros_vec, pixel_wise_loss).sum()

        return loss
    def dual_OT_model(self, Xs_batch, i_t):

        batch_size = i_t.shape[0]
        u_batch = self.u(Xs_batch)
        v_batch = torch.index_select(self.v, dim=0, index=i_t)
        Xt_batch = torch.index_select(self.Xt, dim=0, index=i_t)

        return self.dual_OT_batch_loss(batch_size=batch_size, u_batch=u_batch, v_batch=v_batch, Xs_batch=Xs_batch, Xt_batch=Xt_batch)
    def get_loss_original(self, image_a_pred, image_b_pred, matches_a,
                          matches_b, non_matches_a, non_matches_b,
                          M_margin=0.5, non_match_loss_weight=1.0):

        # this is pegged to its implementation at sha 87abdb63bb5b99d9632f5c4360b5f6f1cf54245f
        """
        Computes the loss function
        DCN = Dense Correspondence Network
        num_images = number of images in this batch
        num_matches = number of matches
        num_non_matches = number of non-matches
        W = image width
        H = image height
        D = descriptor dimension
        match_loss = 1/num_matches \sum_{num_matches} ||descriptor_a - descriptor_b||_2^2
        non_match_loss = 1/num_non_matches \sum_{num_non_matches} max(0, M_margin - ||descriptor_a - descriptor_b||_2^2 )
        loss = match_loss + non_match_loss
        :param image_a_pred: Output of DCN network on image A.
        :type image_a_pred: torch.Variable(torch.FloatTensor) shape [1, W * H, D]
        :param image_b_pred: same as image_a_pred
        :type image_b_pred:
        :param matches_a: torch.Variable(torch.LongTensor) has shape [num_matches,],  a (u,v) pair is mapped
        to (u,v) ---> image_width * v + u, this matches the shape of one dimension of image_a_pred
        :type matches_a: torch.Variable(torch.LongTensor)
        :param matches_b: same as matches_a
        :type matches_b:
        :param non_matches_a: torch.Variable(torch.LongTensor) has shape [num_non_matches,],  a (u,v) pair is mapped
        to (u,v) ---> image_width * v + u, this matches the shape of image_a_pred
        :type non_matches_a: torch.Variable(torch.LongTensor)
        :param non_matches_b: same as non_matches_a
        :type non_matches_b:
        :return: loss, match_loss, non_match_loss
        :rtype: torch.Variable(torch.FloatTensor) each of shape torch.Size([1])
        """

        num_matches = matches_a.size()[0]
        num_non_matches = non_matches_a.size()[0]


        matches_a_descriptors = torch.index_select(image_a_pred, 1, matches_a)
        matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b)

        match_loss = 1.0/num_matches * (matches_a_descriptors - matches_b_descriptors).pow(2).sum()

        # add loss via non_matches
        non_matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a)
        non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b)
        pixel_wise_loss = (non_matches_a_descriptors - non_matches_b_descriptors).pow(2).sum(dim=2)
        pixel_wise_loss = torch.add(torch.neg(pixel_wise_loss), M_margin)
        zeros_vec = torch.zeros_like(pixel_wise_loss)
        non_match_loss = non_match_loss_weight * 1.0/num_non_matches * torch.max(zeros_vec, pixel_wise_loss).sum()

        loss = match_loss + non_match_loss

        return loss, match_loss, non_match_loss
Example #7
    def forward(self, input, output, input_lens=None, output_lens=None, lookup=None, **kwargs):
        h0 = self.h0.expand(1, input.size(0), self.hidden_dim).contiguous()
        c0 = self.c0.expand(1, input.size(0), self.hidden_dim).contiguous()
        input_encoded, input_h, input_c = self.encoder(input, h0, c0, lens=input_lens)

        if lookup is not None:
            input_h = th.index_select(input_h, 1, lookup)
            input_c = th.index_select(input_c, 1, lookup)
            
        transfer_h, transfer_c = self.transfer(input_h, input_c, **kwargs)
        log_probs, _, _ = self.decoder(output, transfer_h, transfer_c, lens=output_lens)
        return log_probs
Example #8
    def updateOutput(self, input):
        self.renorm(input)
        input = self._makeInputContiguous(input)
        if input.dim() == 1:
            torch.index_select(self.weight, 0, input, out=self.output)
        elif input.dim() == 2:
            torch.index_select(self.weight, 0, input.view(-1), out=self.output)
            self.output = self.output.view(input.size(0), input.size(1), self.weight.size(1))
        else:
            raise RuntimeError("input must be a vector or matrix")

        return self.output
Example #9
    def forward(self, base_feat, im_info, gt_boxes, num_boxes):

        batch_size = base_feat.size(0)

        # return feature map after convrelu layer
        rpn_conv1 = F.relu(self.RPN_Conv(base_feat), inplace=True)
        # get rpn classification score
        rpn_cls_score = self.RPN_cls_score(rpn_conv1)

        rpn_cls_score_reshape = self.reshape(rpn_cls_score, 2)
        rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape, dim=1)
        rpn_cls_prob = self.reshape(rpn_cls_prob_reshape, self.nc_score_out)

        # get rpn offsets to the anchor boxes
        rpn_bbox_pred = self.RPN_bbox_pred(rpn_conv1)

        # proposal layer
        cfg_key = 'TRAIN' if self.training else 'TEST'

        rois = self.RPN_proposal((rpn_cls_prob.data, rpn_bbox_pred.data,
                                 im_info, cfg_key))

        self.rpn_loss_cls = 0
        self.rpn_loss_box = 0

        # generating training labels and build the rpn loss
        if self.training:
            assert gt_boxes is not None

            rpn_data = self.RPN_anchor_target((rpn_cls_score.data, gt_boxes, im_info, num_boxes))

            # compute classification loss
            rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
            rpn_label = rpn_data[0].view(batch_size, -1)

            rpn_keep = Variable(rpn_label.view(-1).ne(-1).nonzero().view(-1))
            rpn_cls_score = torch.index_select(rpn_cls_score.view(-1,2), 0, rpn_keep)
            rpn_label = torch.index_select(rpn_label.view(-1), 0, rpn_keep.data)
            rpn_label = Variable(rpn_label.long())
            self.rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label)
            fg_cnt = torch.sum(rpn_label.data.ne(0))

            rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]

            # compute bbox regression loss
            rpn_bbox_inside_weights = Variable(rpn_bbox_inside_weights)
            rpn_bbox_outside_weights = Variable(rpn_bbox_outside_weights)
            rpn_bbox_targets = Variable(rpn_bbox_targets)

            self.rpn_loss_box = _smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights,
                                                            rpn_bbox_outside_weights, sigma=3, dim=[1,2,3])

        return rois, self.rpn_loss_cls, self.rpn_loss_box
    def barycentric_mapping_loss_model(self, neuralNet, Xs_batch, i_t):

        self.u.eval()
        self.v.requires_grad_(False)

        u_batch = self.u(Xs_batch)
        v_batch = torch.index_select(self.v, dim=0, index=i_t)
        Xt_batch = torch.index_select(self.Xt, dim=0, index=i_t)

        fXs_batch = neuralNet(Xs_batch)

        return self.barycentric_model_batch_loss(u_batch, v_batch, Xs_batch, Xt_batch, fXs_batch)
def random_sample_from_masked_image_torch(img_mask, num_samples):
    """

    :param img_mask: Numpy array [H,W] or torch.Tensor with shape [H,W]
    :type img_mask:
    :param num_samples: an integer
    :type num_samples:
    :return: tuple of torch.LongTensor in (u,v) format. Each torch.LongTensor has shape
    [num_samples]
    :rtype:
    """

    image_height, image_width = img_mask.shape

    if isinstance(img_mask, np.ndarray):
        img_mask_torch = torch.from_numpy(img_mask).float()
    else:
        img_mask_torch = img_mask

    # Randomly subsample from the mask
    mask = img_mask_torch.view(image_width*image_height,1).squeeze(1)
    mask_indices_flat = torch.nonzero(mask)
    if len(mask_indices_flat) == 0:
        return (None, None)

    rand_numbers = torch.rand(num_samples)*len(mask_indices_flat)
    rand_indices = torch.floor(rand_numbers).long()
    uv_vec_flattened = torch.index_select(mask_indices_flat, 0, rand_indices).squeeze(1)
    uv_vec = utils.flattened_pixel_locations_to_u_v(uv_vec_flattened, image_width)
    return uv_vec
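The sampling itself, stripped of the project-specific (u, v) conversion (toy mask; the printed result varies because the indices are random):

import torch

mask = torch.tensor([0., 1., 0., 1., 1., 0.])           # flattened mask
mask_indices_flat = torch.nonzero(mask)                  # shape [num_nonzero, 1]

num_samples = 4
rand_indices = torch.floor(torch.rand(num_samples) * len(mask_indices_flat)).long()
sampled = torch.index_select(mask_indices_flat, 0, rand_indices).squeeze(1)
print(sampled)                                           # e.g. tensor([4, 1, 3, 4])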
Example #12
    def next(self):
        if self.next_i+self.batch_size > len(self.data):
            raise StopIteration()
        else:
            x_idx = self.x_idx[self.next_i:self.next_i+self.batch_size]
            self.next_i += self.batch_size

        labels = {k: torch.index_select(self.labels[k], 0, x_idx) for k in self.labels}
        x = self.select_data(x_idx)

        inputs = {}
        sizes = {} 
        for k,v in labels.items():
            possibilities = [self.label_idxs[k][v[i].item()] for i in range(len(x_idx))]
            sizes[k] = [len(X) for X in possibilities]
            input_idx = [np.random.choice(X, size=self.k_shot[k]) for X in possibilities]
            _inputs = [
                self.select_data(torch.LongTensor([I[j] for I in input_idx]))
                for j in range(self.k_shot[k])]
            if self.mode == "tensor":
                inputs[k] = torch.cat([x.unsqueeze(1) for x in _inputs], dim=1)
            elif self.mode == "list":
                inputs[k] = [[_inputs[j][i] for j in range(self.k_shot[k])]
                        for i in range(len(_inputs[0]))]

        batch = VHEBatch(target=x, inputs=inputs, sizes=sizes)
        for transform in self.transforms:
            batch = transform.apply(batch)
        return batch
Example #13
 def forward(self, x):
     x = torch.index_select(x, 1, Variable(self.index))
     x = self.norm(x)
     x = self.relu(x)
     x = self.conv(x)
     x = ShuffleLayer(x, self.groups)
     return x
Example #14
        def model():
            p_latent = pyro.param("p1", Variable(torch.Tensor([[0.7], [0.3]])))
            p_obs = pyro.param("p2", Variable(torch.Tensor([[0.9], [0.1]])))

            latents = [Variable(torch.ones(1, 1))]
            observes = []
            for t in range(self.model_steps):

                latents.append(
                    pyro.sample("latent_{}".format(str(t)),
                                Bernoulli(torch.index_select(p_latent, 0, latents[-1].view(-1).long()))))

                observes.append(
                    pyro.observe("observe_{}".format(str(t)),
                                 Bernoulli(torch.index_select(p_obs, 0, latents[-1].view(-1).long())),
                                 self.data[t]))
            return torch.sum(torch.cat(latents))
Example #15
def deprocess_img(img):
    # BGR to RGB
    idx = torch.LongTensor([2, 1, 0])
    img = torch.index_select(img, 0, idx)

    # [-1,1] to [0,1]
    img = img.add_(1).div_(2)

    return img
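Usage sketch for the function above (assumes a 3xHxW tensor in BGR with values in [-1, 1]):

import torch

bgr = torch.rand(3, 8, 8) * 2 - 1          # fake BGR image in [-1, 1]
rgb = deprocess_img(bgr.clone())           # channels reordered to RGB, rescaled to [0, 1]
assert 0.0 <= rgb.min() and rgb.max() <= 1.0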
Example #16
def nms(boxes, scores, overlap=0.5, top_k=100):
    keep = scores.new(scores.size(0)).zero_().long()
    if boxes.numel() == 0: return keep
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    v, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]  # indices of the top-k largest vals
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1: break
        idx = idx[:-1]  # remove kept element from view
        # load bboxes of next highest vals
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        # store element-wise max with next highest score
        xx1 = torch.clamp(xx1, min=x1[i])
        yy1 = torch.clamp(yy1, min=y1[i])
        xx2 = torch.clamp(xx2, max=x2[i])
        yy2 = torch.clamp(yy2, max=y2[i])
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1
        h = yy2 - yy1
        # check sizes of xx1 and xx2.. after each iteration
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)
        inter = w*h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter/union  # store result in iou
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
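A quick smoke test with two heavily overlapping boxes and one separate box (toy values; assumes a PyTorch version where torch.clamp accepts tensor bounds, as used above):

import torch

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 10., 10.],
                      [20., 20., 30., 30.]])
scores = torch.tensor([0.9, 0.8, 0.7])

keep, count = nms(boxes, scores, overlap=0.5)
print(keep[:count])        # tensor([0, 2]) -- box 1 is suppressed by box 0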
Example #17
def log_sum_exp(vecs):
    n = len(vecs.size())
    if n == 1:
        vecs = vecs.view(1, -1)
    _, idx = torch.max(vecs, 1)
    max_score = torch.index_select(vecs, 1, idx.view(-1))
    ret = max_score + torch.log(torch.sum(torch.exp(vecs - max_score.expand_as(vecs))))
    if n == 1:
        return ret.view(-1)
    return ret
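A quick sanity check against torch.logsumexp for a 1-D input (the version above keeps the result as a 1-element tensor):

import torch

vecs = torch.tensor([0.5, 1.0, -2.0])
print(log_sum_exp(vecs))              # approximately 1.5046, as a 1-element tensor
print(torch.logsumexp(vecs, dim=0))   # approximately 1.5046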
Example #18
def crop1d(x,cutoff,dim):
    '''Crops tensor x by cutoff elements from the beginning and the end along dimension dim.
    Example:
    x=torch.FloatTensor([1,2,3,4,5,6,7,8]).cuda(1)
    crop1d(x,2,0) '''    
    idx = torch.arange(cutoff, x.shape[dim]-cutoff).long()
    if x.is_cuda:
        dev = x.get_device()
        idx = idx.cuda(dev)
    return torch.index_select(x,dim,idx)
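Usage on CPU (the docstring example assumes CUDA, but the function works either way):

import torch

x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
print(crop1d(x, 2, 0))        # tensor([3., 4., 5., 6.])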
Example #19
def flip(x,dim):
    '''Flip tensor along dimension dim.
    Example: 
    A=torch.FloatTensor([[1,2],[3,4]]).cuda()
    A,flip(A,1)'''
    inv_idx = torch.arange(x.shape[dim]-1, -1, -1).long()
    if x.is_cuda:
        dev = x.get_device()
        inv_idx = inv_idx.cuda(dev)
    return torch.index_select(x,dim,inv_idx)
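Usage on CPU; flipping dim 1 of a 2-D tensor reverses each row (newer PyTorch also offers torch.flip, which does the same thing):

import torch

A = torch.tensor([[1., 2.], [3., 4.]])
print(flip(A, 1))             # tensor([[2., 1.], [4., 3.]])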
    def non_match_descriptor_loss(image_a_pred, image_b_pred, non_matches_a, non_matches_b, M=0.5, invert=False):
        """
        Computes the max(0, M - D(I_a,I_b,u_a,u_b))^2 term

        This is effectively:       "a and b should be AT LEAST M away from each other"
        With invert=True, this is: "a and b should be AT MOST  M away from each other" 

        :param image_a_pred: Output of DCN network on image A.
        :type image_a_pred: torch.Variable(torch.FloatTensor) shape [1, W * H, D]
        :param image_b_pred: same as image_a_pred
        :type image_b_pred:
        :param non_matches_a: torch.Variable(torch.LongTensor) has shape [num_non_matches,],  a (u,v) pair is mapped
        to (u,v) ---> image_width * v + u, this matches the shape of image_a_pred
        :type non_matches_a: torch.Variable(torch.LongTensor)
        :param non_matches_b: same as non_matches_a
        :param M: the margin
        :type M: float
        :return: torch.FloatTensor with shape torch.Size([num_non_matches])
        :rtype:
        """

        non_matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a).squeeze()
        non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b).squeeze()

        # crazily enough, if there is only one element to index_select into
        # above, then the first dimension is collapsed down, and we end up 
        # with shape [D,], where we want [1,D]
        # this unsqueeze fixes that case
        if len(non_matches_a) == 1:
            non_matches_a_descriptors = non_matches_a_descriptors.unsqueeze(0)
            non_matches_b_descriptors = non_matches_b_descriptors.unsqueeze(0)

        norm_degree = 2
        non_match_loss = (non_matches_a_descriptors - non_matches_b_descriptors).norm(norm_degree, 1)
        if not invert:
            non_match_loss = torch.clamp(M - non_match_loss, min=0).pow(2)
        else:
            non_match_loss = torch.clamp(non_match_loss - M, min=0).pow(2)

        hard_negative_idxs = torch.nonzero(non_match_loss)
        num_hard_negatives = len(hard_negative_idxs)

        return non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors
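The single-element edge case that the unsqueeze above guards against, shown on toy tensors (shapes are illustrative):

import torch

image_a_pred = torch.rand(1, 16, 3)            # [1, W*H, D] with D = 3
one_non_match = torch.tensor([7])

descriptor = torch.index_select(image_a_pred, 1, one_non_match).squeeze()
print(descriptor.shape)                        # torch.Size([3]) -- collapsed to [D], hence unsqueeze(0)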
Example #21
    def updateOutput(self, input):
        self.output.set_(self.network.forward([input, self.partition]))
        if self.bias is not None:
            self.output.add_(torch.index_select(self.bias, 1, self.partition).expand_as(self.output))
            if self.addBuffer is None:
                self.addBuffer = input.new()
            if self.addBuffer.nelement() != input.size(0):
                self.addBuffer.resize_(input.size(0)).fill_(1)

        return self.output
def prune_matches_if_occluded(foreground_mask_numpy, background_matches_pair):
    """
    Checks if any of the matches have been occluded.

    If yes, prunes them from the list of matches.

    NOTE:
    - background_matches is a tuple
    - the first element of the tuple HAS to be the one that we are actually checking for occlusions
    - the second element of the tuple must also get pruned

    :param foreground_mask_numpy: The mask of the foreground image
    :type foreground_mask_numpy: numpy 2d array of shape (H,W)
    :param background_matches: a tuple of torch Tensors, each of length n, i.e:

        (u_pixel_positions, v_pixel_positions)

        Where each of the elements of the tuple are torch Tensors of length n

        Note: only supports torch.LongTensors
    """

    background_matches_a = background_matches_pair[0] 
    background_matches_b = background_matches_pair[1]

    idxs_to_keep  = []
    
    # this is slow but works
    for i in range(len(background_matches_a[0])):
        u = background_matches_a[0][i]
        v = background_matches_a[1][i]

        if foreground_mask_numpy[v,u] == 0:
            idxs_to_keep.append(i)

    if len(idxs_to_keep) == 0:
        return (None, None)

    idxs_to_keep = torch.LongTensor(idxs_to_keep)
    background_matches_a = (torch.index_select(background_matches_a[0], 0, idxs_to_keep), torch.index_select(background_matches_a[1], 0, idxs_to_keep))
    background_matches_b = (torch.index_select(background_matches_b[0], 0, idxs_to_keep), torch.index_select(background_matches_b[1], 0, idxs_to_keep))

    return (background_matches_a, background_matches_b)
Example #23
def reverse_sequences_torch(mini_batch, seq_lengths):
    reversed_mini_batch = mini_batch.new_zeros(mini_batch.size())
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        time_slice = np.arange(T - 1, -1, -1)
        time_slice = torch.cuda.LongTensor(time_slice) if 'cuda' in mini_batch.data.type() \
            else torch.LongTensor(time_slice)
        reversed_sequence = torch.index_select(mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    return reversed_mini_batch
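A small check with two padded sequences of lengths 3 and 2 (toy data, feature dimension 1; padding positions stay zero):

import numpy as np
import torch

mini_batch = torch.arange(1., 7.).view(2, 3, 1)   # (batch=2, max_len=3, feat=1)
seq_lengths = [3, 2]
rev = reverse_sequences_torch(mini_batch, seq_lengths)
print(rev[:, :, 0])
# tensor([[3., 2., 1.],
#         [5., 4., 0.]])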
def indice_pooling(x, indices):
    # For each example in the batch, keep only the time steps of x named by the matching row of indices.
    out = torch.cat([torch.index_select(x_, 0, i).unsqueeze(0) for x_, i in zip(x, indices)])
    return out
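What indice_pooling does, on toy tensors: for each example in the batch, it gathers the time steps named in the matching row of indices (sizes here are made up):

import torch

x = torch.arange(12.).view(2, 3, 2)          # (batch=2, seq_len=3, feat=2)
indices = torch.tensor([[2, 0], [1, 1]])     # per-example positions to keep

out = indice_pooling(x, indices)
print(out.shape)                             # torch.Size([2, 2, 2])
print(out[0])                                # rows 2 and 0 of x[0]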
Example #25
    def __call__(self, batch):
        image_batch, theta_batch = batch['image'], batch['theta'] 
        if self.use_cuda:
            image_batch = image_batch.cuda()
            theta_batch = theta_batch.cuda()
            
        b, c, h, w = image_batch.size()
              
        # generate symmetrically padded image for bigger sampling region
        image_batch = self.symmetricImagePad(image_batch,self.padding_factor)
        
        # convert to variables
        image_batch = Variable(image_batch,requires_grad=False)
        theta_batch =  Variable(theta_batch,requires_grad=False)        

        # get cropped image
        cropped_image_batch = self.rescalingTnf(image_batch=image_batch,
                                                theta_batch=None,
                                                padding_factor=self.padding_factor,
                                                crop_factor=self.crop_factor) # Identity is used as no theta given
        # get transformed image
        warped_image_batch = self.geometricTnf(image_batch=image_batch,
                                               theta_batch=theta_batch,
                                               padding_factor=self.padding_factor,
                                               crop_factor=self.crop_factor) # Identity is used as no theta given
        
        if self.supervision=='strong':
            return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT': theta_batch}
        
        elif self.supervision=='weak':
            pos_batch_idx = torch.LongTensor(range(int(b/2)))
            neg_batch_idx = torch.LongTensor(range(int(b/2),b))
            if self.use_cuda:
                pos_batch_idx = pos_batch_idx.cuda()
                neg_batch_idx = neg_batch_idx.cuda()
            source_image = torch.cat((torch.index_select(cropped_image_batch,0,pos_batch_idx),
                                      torch.index_select(cropped_image_batch,0,pos_batch_idx)),0)
            target_image = torch.cat((torch.index_select(warped_image_batch,0,pos_batch_idx),
                                      torch.index_select(cropped_image_batch,0,neg_batch_idx)),0)
            return {'source_image': source_image, 'target_image': target_image, 'theta_GT': theta_batch}
    def forward(self, sents, sent_lengths):
        '''
            sents is (batch_size by padded_length)
            when evaluating sentence by sentence, use batch_size = 1 and the padded length,
            e.g. [[1, 2, 3, 4]].
        '''
        batch_size = sents.size()[0]
        sent_lengths = list(sent_lengths)
        # We sort and then do pad packed sequence here. 
        descending_lengths = [x for x, _ in sorted(zip(sent_lengths, range(len(sent_lengths))), reverse=True)]
        descending_indices = [x for _, x in sorted(zip(sent_lengths, range(len(sent_lengths))), reverse=True)]
        descending_lengths = torch.tensor(descending_lengths)
        descending_indices = torch.tensor(descending_indices).to(device)
        descending_sents = torch.index_select(sents, 0, descending_indices)
        
        # get embedding
        embed = self.embedding(descending_sents)
        # pack padded sequence
        embed = torch.nn.utils.rnn.pack_padded_sequence(embed, descending_lengths, batch_first=True)
        
        # fprop though RNN
        self.hidden = self.init_hidden(batch_size)
        rnn_out, self.hidden = self.gru(embed, self.hidden)
        rnn_out, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
        # rnn_out is 32 by 72 by 256
        
        # change the order back
        change_it_back = [x for _, x in sorted(zip(descending_indices, range(len(descending_indices))))]
        self.hidden = torch.index_select(self.hidden, 1, torch.LongTensor(change_it_back).to(device))  
        rnn_out = torch.index_select(rnn_out, 0, torch.LongTensor(change_it_back).to(device)) 
 
        # self.hidden is 4 by 8 by 256
        # let's only use the top-most layer for the encoder output
        # so we want to return 8 by 512
        hidden_top = torch.cat((self.hidden[2], self.hidden[3]), dim=1)
        hidden_bottom = torch.cat((self.hidden[0], self.hidden[1]), dim=1)
        self.hidden = torch.stack((hidden_top, hidden_bottom))
        
        return rnn_out, self.hidden
Example #27
def expand_z_where(z_where):
    # Take a batch of three-vectors, and massages them into a batch of
    # 2x3 matrices with elements like so:
    # [s,x,y] -> [[s,0,x],
    #             [0,s,y]]
    n = z_where.size(0)
    out = torch.cat((ng_zeros([1, 1]).type_as(z_where).expand(n, 1), z_where), 1)
    ix = Variable(expansion_indices)
    if z_where.is_cuda:
        ix = ix.cuda()
    out = torch.index_select(out, 1, ix)
    out = out.view(n, 2, 3)
    return out
 def forward(ctx, input, weight):
     assert input.dim() == 2
     assert weight.dim() == 2
     ctx._weight_size = weight.size()
     # repeat each row by number of embeddings
     input = input.view(input.size(0), 1, -1)
     input = input.expand(input.size(0), weight.size(0), -1)
     # compute distance between all embeddings and input
     distance = torch.pow(input - weight.expand_as(input), 2).sum(2)
     # select embeddings with minimal distance
     _, indices = distance.min(1)
     ctx._indices = indices
     output = torch.index_select(weight, 0, indices)
     return output, indices
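The same nearest-embedding lookup without the autograd Function plumbing; a sketch with arbitrary sizes (names are illustrative):

import torch

weight = torch.randn(5, 3)        # 5 codebook vectors of dimension 3
inp = torch.randn(4, 3)           # 4 query vectors

# squared distance from every query to every codebook vector, then pick the closest
distance = torch.pow(inp.unsqueeze(1) - weight.unsqueeze(0), 2).sum(2)   # (4, 5)
_, indices = distance.min(1)
output = torch.index_select(weight, 0, indices)                          # (4, 3)
print(output.shape)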
Example #29
 def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
     cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
     styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
     cont_feat_view = cont_feat.view(cont_c, -1).clone()
     styl_feat_view = styl_feat.view(styl_c, -1).clone()
     
     if cont_seg.size == False or styl_seg.size == False:
         target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
     else:
         target_feature = cont_feat.view(cont_c, -1).clone()
 
         t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
         t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))
         
         for l in self.label_set:
             if self.label_indicator[l] == 0:
                 continue
             cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
             styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
             if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
                 continue
             
             cont_indi = torch.LongTensor(cont_mask[0])
             styl_indi = torch.LongTensor(styl_mask[0])
             if self.is_cuda:
                 cont_indi = cont_indi.cuda(0)
                 styl_indi = styl_indi.cuda(0)
             
             cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
             sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
             tmp_target_feature = self.__wct_core(cFFG, sFFG)
             target_feature.index_copy_(1, cont_indi, tmp_target_feature)
     
     target_feature = target_feature.view_as(cont_feat)
     ccsF = target_feature.float().unsqueeze(0)
     return ccsF
    def match_loss(image_a_pred, image_b_pred, matches_a, matches_b):
        """
        Computes the match loss given by

        1/num_matches * \sum_{matches} ||D(I_a, u_a, I_b, u_b)||_2^2

        :param image_a_pred: Output of DCN network on image A.
        :type image_a_pred: torch.Variable(torch.FloatTensor) shape [1, W * H, D]
        :param image_b_pred: same as image_a_pred
        :type image_b_pred:
        :param matches_a: torch.Variable(torch.LongTensor) has shape [num_matches,],  a (u,v) pair is mapped
        to (u,v) ---> image_width * v + u, this matches the shape of one dimension of image_a_pred
        :type matches_a: torch.Variable(torch.LongTensor)
        :param matches_b: same as matches_a

        :return: match_loss, matches_a_descriptors, matches_b_descriptors
        :rtype: torch.Variable(),

        matches_a_descriptors is torch.FloatTensor with shape torch.Size([num_matches, descriptor_dimension])
        """

        num_matches = matches_a.size()[0]
        matches_a_descriptors = torch.index_select(image_a_pred, 1, matches_a)
        matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b)

        # crazily enough, if there is only one element to index_select into
        # above, then the first dimension is collapsed down, and we end up 
        # with shape [D,], where we want [1,D]
        # this unsqueeze fixes that case
        if len(matches_a) == 1:
            matches_a_descriptors = matches_a_descriptors.unsqueeze(0)
            matches_b_descriptors = matches_b_descriptors.unsqueeze(0)

        match_loss = 1.0 / num_matches * (matches_a_descriptors - matches_b_descriptors).pow(2).sum()

        return match_loss, matches_a_descriptors, matches_b_descriptors
Example #31
    def decode_greed(self,
                     word_seqs,
                     init_tags,
                     lengths,
                     extFeats=None,
                     with_snt_classifier=False,
                     masked_output=None):
        minibatch_size = len(lengths)  # word_seqs.size(0) if self.encoder.batch_first else word_seqs.size(1)
        max_length = max(lengths)  # word_seqs.size(1) if self.encoder.batch_first else word_seqs.size(0)
        # encoder
        embeds = self.get_token_embeddings(word_seqs, lengths)
        if extFeats is not None:
            concat_input = torch.cat((embeds, self.extFeats_linear(extFeats)),
                                     2)
        else:
            concat_input = embeds
        concat_input = self.dropout_layer(concat_input)
        packed_word_embeds = rnn_utils.pack_padded_sequence(concat_input,
                                                            lengths,
                                                            batch_first=True)
        packed_word_lstm_out, (enc_h_t, enc_c_t) = self.encoder(
            packed_word_embeds)  # bsize x seqlen x dim
        word_lstm_out, unpacked_len = rnn_utils.pad_packed_sequence(
            packed_word_lstm_out, batch_first=True)

        # decoder
        if self.bidirectional:
            index_slices = [2 * i + 1 for i in range(self.num_layers)]  # generated from the reversed path
            index_slices = torch.tensor(index_slices, dtype=torch.long, device=self.device)
            h_t = torch.index_select(enc_h_t, 0, index_slices)
            c_t = torch.index_select(enc_c_t, 0, index_slices)
        else:
            h_t = enc_h_t
            c_t = enc_c_t

        top_path = []
        top_path_tag_scores = []
        top_dec_h_t, top_dec_c_t = [0] * minibatch_size, [0] * minibatch_size
        last_tags = init_tags  # bsize x 1
        for i in range(max_length):
            tag_embeds = self.dropout_layer(self.tag_embeddings(last_tags))
            decode_inputs = torch.cat(
                (self.dropout_layer(word_lstm_out[:, i:i + 1]), tag_embeds),
                2)  # bsize x 1 x insize
            tag_lstm_out, (dec_h_t, dec_c_t) = self.decoder(
                decode_inputs,
                (h_t, c_t))  # bsize x 1 x insize => bsize x 1 x hsize

            for j in range(minibatch_size):
                if lengths[j] == i + 1:
                    top_dec_h_t[j] = dec_h_t[:, j:j + 1, :]
                    top_dec_c_t[j] = dec_c_t[:, j:j + 1, :]

            tag_lstm_out_reshape = tag_lstm_out.contiguous().view(
                tag_lstm_out.size(0) * tag_lstm_out.size(1),
                tag_lstm_out.size(2))
            tag_space = self.hidden2tag(
                self.dropout_layer(tag_lstm_out_reshape))
            if masked_output is None:
                tag_scores = F.log_softmax(tag_space, dim=1)  # bsize x outsize
            else:
                tag_scores = masked_function.index_masked_log_softmax(
                    tag_space, masked_output, dim=1)
            top_path_tag_scores.append(torch.unsqueeze(tag_scores.data, 1))

            max_probs, decoder_argmax = torch.max(tag_scores, 1)
            last_tags = decoder_argmax
            if len(last_tags.size()) == 1:
                last_tags = last_tags.unsqueeze(1)
            h_t, c_t = dec_h_t, dec_c_t
            top_path.append(last_tags.data)
        top_path = torch.cat(top_path, 1)
        top_path_tag_scores = torch.cat(top_path_tag_scores, 1)

        top_dec_h_t = torch.cat(top_dec_h_t, 1)
        top_dec_c_t = torch.cat(top_dec_c_t, 1)

        if with_snt_classifier:
            return top_path_tag_scores, top_path, ((enc_h_t, enc_c_t),
                                                   word_lstm_out, lengths)
        else:
            return top_path_tag_scores, top_path
Example #32
def init_target_centers(net, target, records, data_dict, batch_size, beat_num,
                        fixed_len, num_workers, lead, thrs):

    dataset = MULTI_ECG_EVAL_DATASET(target,
                                     load_beat_with_rr,
                                     data_dict,
                                     test_records=records,
                                     beat_num=beat_num,
                                     fixed_len=fixed_len,
                                     lead=lead)

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers)

    features = {0: 0, 1: 0, 2: 0, 3: 0}
    counters = {0: 0, 1: 0, 2: 0, 3: 0}

    for idb, data_batch in enumerate(dataloader):

        s_batch, _ = data_batch
        s_batch = s_batch.cuda()

        feat, logits = net(s_batch)

        probs = F.softmax(logits, dim=1).detach().cpu().numpy()
        max_probs = np.max(probs, axis=1)

        indices = []
        for l in range(4):
            max_indices_l = np.argwhere(np.argmax(probs, axis=1) == l)
            if len(max_indices_l) > 0:
                max_indices_l = np.squeeze(max_indices_l, axis=1)
                max_probs_l = max_probs[max_indices_l]
                legal_indices_l = np.argwhere(max_probs_l >= thrs[l])
                if len(legal_indices_l) > 0:
                    legal_indices_l = np.squeeze(legal_indices_l, axis=1)
                    indices_l = max_indices_l[legal_indices_l]
                    indices.append(indices_l)

        indices = np.sort(np.concatenate(indices))
        print("batch index: {}, size of avaliable pesudo label: {}".format(
            idb, len(indices)))

        if len(indices) > 0:
            pseudo_labels = np.argmax(probs, axis=1)[indices]
            feat = torch.index_select(feat,
                                      dim=0,
                                      index=torch.LongTensor(indices).cuda())

            for l in range(4):
                _index = np.argwhere(pseudo_labels == l)
                if len(_index) > 0:
                    counters[l] += len(_index)
                    _index = np.squeeze(_index, axis=1)
                    _feat = torch.index_select(
                        feat, dim=0,
                        index=torch.LongTensor(_index).cuda()).detach().cpu()
                    _feat_sum = torch.sum(_feat, dim=0)
                    features[l] += _feat_sum
        torch.cuda.empty_cache()

    print('Procedure finished! Obtaining centers of target data!')
    print("The numbers of available pesudo labels:")
    for l in range(4):
        if counters[l] > 0:
            print("{}: {}".format(l, counters[l]))
            # features[l] = torch.cat(features[l], dim=0)
            # features[l] = torch.mean(features[l], dim=0)
            features[l] = features[l] / counters[l]
            features[l] = features[l].cuda()
            print(features[l].size())
        else:
            del features[l]
            print('No available centers')

    return features, counters
Example #33
    def update(self):
        # keep looping the whole dataset
        for i in range(self.num_batches):
            img, orig_img, im_name, im_dim_list = self.dataloder.getitem()
            if img is None:
                self.Q.put((None, None, None, None, None, None, None))
                return

            with torch.no_grad():
                # Human Detection
                img = img.cuda()
                prediction = self.det_model(img, CUDA=True)
                # NMS process
                dets = dynamic_write_results(prediction,
                                             opt.confidence,
                                             opt.num_classes,
                                             nms=True,
                                             nms_conf=opt.nms_thesh)
                if isinstance(dets, int) or dets.shape[0] == 0:
                    for k in range(len(orig_img)):
                        if self.Q.full():
                            time.sleep(2)
                        self.Q.put((orig_img[k], im_name[k], None, None, None,
                                    None, None))
                    continue
                dets = dets.cpu()
                im_dim_list = torch.index_select(im_dim_list, 0,
                                                 dets[:, 0].long())
                scaling_factor = torch.min(self.det_inp_dim / im_dim_list,
                                           1)[0].view(-1, 1)

                # coordinate transfer
                dets[:, [1, 3]] -= (self.det_inp_dim - scaling_factor *
                                    im_dim_list[:, 0].view(-1, 1)) / 2
                dets[:, [2, 4]] -= (self.det_inp_dim - scaling_factor *
                                    im_dim_list[:, 1].view(-1, 1)) / 2

                dets[:, 1:5] /= scaling_factor
                for j in range(dets.shape[0]):
                    dets[j, [1, 3]] = torch.clamp(dets[j, [1, 3]], 0.0,
                                                  im_dim_list[j, 0])
                    dets[j, [2, 4]] = torch.clamp(dets[j, [2, 4]], 0.0,
                                                  im_dim_list[j, 1])
                boxes = dets[:, 1:5]
                scores = dets[:, 5:6]

            for k in range(len(orig_img)):
                boxes_k = boxes[dets[:, 0] == k]
                if isinstance(boxes_k, int) or boxes_k.shape[0] == 0:
                    if self.Q.full():
                        time.sleep(2)
                    self.Q.put((orig_img[k], im_name[k], None, None, None,
                                None, None))
                    continue
                inps = torch.zeros(boxes_k.size(0), 3, opt.inputResH,
                                   opt.inputResW)
                pt1 = torch.zeros(boxes_k.size(0), 2)
                pt2 = torch.zeros(boxes_k.size(0), 2)
                if self.Q.full():
                    time.sleep(2)
                self.Q.put((orig_img[k], im_name[k], boxes_k,
                            scores[dets[:, 0] == k], inps, pt1, pt2))
Example #34
    def build_loss(self,
                   rpn_cls_score_reshape,
                   rpn_bbox_pred,
                   rpn_data,
                   is_region=False):
        # classification loss
        rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3,
                                                      1).contiguous().view(
                                                          -1, 2)
        rpn_label = rpn_data[0]

        # print rpn_label.size(), rpn_cls_score.size()

        rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().squeeze()).cuda()
        rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
        rpn_label = torch.index_select(rpn_label, 0, rpn_keep)

        fg_cnt = torch.sum(rpn_label.data.ne(0))
        bg_cnt = rpn_label.data.numel() - fg_cnt
        # ce_weights = torch.ones(rpn_cls_score.size()[1])
        # ce_weights[0] = float(fg_cnt) / bg_cnt
        # ce_weights = ce_weights.cuda()

        _, predict = torch.max(rpn_cls_score.data, 1)
        error = torch.sum(torch.abs(predict - rpn_label.data))
        #  try:
        if predict.size()[0] < 256:
            print(predict.size())
            print(rpn_label.size())
            print(fg_cnt)

        if is_region:
            self.tp_region = torch.sum(predict[:fg_cnt].eq(
                rpn_label.data[:fg_cnt]))
            self.tf_region = torch.sum(predict[fg_cnt:].eq(
                rpn_label.data[fg_cnt:]))
            self.fg_cnt_region = fg_cnt
            self.bg_cnt_region = bg_cnt
            if DEBUG:
                print('accuracy: %2.2f%%' %
                      ((self.tp + self.tf) / float(fg_cnt + bg_cnt) * 100))
        else:
            self.tp = torch.sum(predict[:fg_cnt].eq(rpn_label.data[:fg_cnt]))
            self.tf = torch.sum(predict[fg_cnt:].eq(rpn_label.data[fg_cnt:]))
            self.fg_cnt = fg_cnt
            self.bg_cnt = bg_cnt
            if DEBUG:
                print('accuracy: %2.2f%%' %
                      ((self.tp + self.tf) / float(fg_cnt + bg_cnt) * 100))

        rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)
        # print rpn_cross_entropy

        # box loss
        rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[
            1:]
        rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
        rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)

        # print 'Smooth L1 loss: ', F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False)
        # print 'fg_cnt', fg_cnt
        rpn_loss_box = F.smooth_l1_loss(
            rpn_bbox_pred, rpn_bbox_targets,
            size_average=False) / (fg_cnt.float() + 1e-4)
        # print 'rpn_loss_box', rpn_loss_box
        # print rpn_loss_box

        return rpn_cross_entropy, rpn_loss_box
Example #35
def compute_loss(model, epoch, batch_sample, plotdir=""):
    loss_function = nn.MSELoss(reduce=False)

    model.zero_grad()
    loss = 0
    norm = 0

    for num_spk in range(batch_sample.max_spk):
        if batch_sample.sub_batch_lens[num_spk] > 0:
            batch = batch_sample.sub_batch_lens[num_spk]
            combo = batch_sample[num_spk]['combo'].cuda()

            model.hidden = model.init_hidden(batch)

            sources = []
            for i in range(num_spk):
                source = batch_sample[num_spk]['source' + str(i + 1)].cuda()
                source, lens = pad_packed_sequence(source, batch_first=True)
                # source: tensor of shape (batch, seq_length, feat_dim)
                sources.append(source)
            source_usage = [[] for _ in range(num_spk)]
            for dnn_pass in range(num_spk):
                mask_out = model(combo)
                # mask_out: tensor of shape (batch, seq_length, feat_dim)

                combos, lens = pad_packed_sequence(combo, batch_first=True)
                # combos: tensor of shape (batch, seq_length, feat_dim*2)
                mixes = torch.index_select(
                    combos, 2,
                    torch.LongTensor(range(model.feat_dim)).cuda())
                # mixes: tensor of shape (batch, seq_length, feat_dim)
                lengths = lens.float().cuda()

                masked = mask_out * mixes
                losses = torch.stack([
                    torch.sum(loss_function(masked, source).view(batch, -1),
                              dim=1) for source in sources
                ])
                # losses: tensor of shape (num_spk, batch)

                for source_ind in range(num_spk):
                    for index in source_usage[source_ind]:
                        losses[source_ind][index] = float("Inf")

                min_losses, indices = torch.min(losses, 0)
                for sample_ind in range(batch):
                    source_usage[indices[sample_ind]].append(sample_ind)

                loss += torch.sum(min_losses) / num_spk
                norm += torch.sum(lengths) * model.feat_dim

                if plotdir:
                    os.system("mkdir -p " + plotdir)
                    if dnn_pass == 0:
                        plot.plot_spec(
                            combos[0].detach().cpu().numpy()[:,
                                                             0:model.feat_dim],
                            plotdir + '/' + str(num_spk) + '-Spk_Mix.png')
                    prefix = plotdir + '/' + str(num_spk) + '-Spk_Pass-' + str(
                        dnn_pass + 1) + '_'
                    plot.plot_spec(combos[0].detach().cpu().numpy(),
                                   prefix + 'Input.png')
                    plot.plot_spec(
                        combos[0].detach().cpu().numpy()[:, model.feat_dim +
                                                         1:model.feat_dim * 2],
                        prefix + 'Attenmask.png')
                    plot.plot_spec(mask_out[0].detach().cpu().numpy(),
                                   prefix + 'Mask_Out.png')
                    plot.plot_spec(masked[0].detach().cpu().numpy(),
                                   prefix + 'Masked_Mix.png')
                    plot.plot_spec(
                        sources[indices[0]][0].detach().cpu().numpy(),
                        prefix + 'Chosen_Source.png')

                spec_zeros = torch.zeros(mask_out.shape).cuda()
                residual_comp = torch.cat((spec_zeros, mask_out), 2)
                combos = F.relu_(combos - residual_comp)
                combo = pack_padded_sequence(combos, lens, batch_first=True)

    return loss / norm, norm
Example #36
    def fold(self, class_probabilities, num_classes):
        """
        This function is called every time we update the beam.
        It folds finished sentences and then takes the top-k over the rest, until there are
        no more finished sentences among the k options.
        """

        # the following tensors are used for masking
        stop_tensor = Variable(
            torch.LongTensor(self._beam_width,
                             1).fill_(self._end_index)).cpu()
        one_tensor = Variable(torch.ByteTensor(self._beam_width,
                                               1).fill_(1)).cpu()

        # stop_tensor = stop_tensor.cuda() if use_cuda else stop_tensor
        # one_tensor = one_tensor.cuda() if use_cuda else one_tensor
        # TODO: if want to change to a batched solution, probably need padding
        batch_next_indices = []
        batch_next_log_prob = []
        for i in range(self._batch_size):
            # loop through each batch
            stop = False
            # (batch_size, beam_width, num_classes)
            cur_batch = class_probabilities[i].cpu()

            cur_batch_seq = self.sequences[i]
            cur_batch_log = self.seq_log_prob[i]
            # keep topk unless there is no finished sentences in that batch
            while not stop:

                # align the log_prob in a (-1, 1) vector and then apply topk
                log_prob_cur, indices = torch.topk(cur_batch.view(-1, 1),
                                                   self._beam_width,
                                                   dim=0)
                log_prob_cur = log_prob_cur.view(-1, 1)
                # indices is the position in the long vector
                indices = indices.view(-1, 1)
                nth_class_per_beam = indices % num_classes  # (self._beam_width, 1)
                unfinished_tensor = nth_class_per_beam - stop_tensor != 0
                finished_tensor = nth_class_per_beam - stop_tensor == 0
                # dimension with value 1 means we predict a stop symbol

                if torch.equal(unfinished_tensor, one_tensor):
                    stop = True
                    # all the decoded sequence is not finished!
                else:
                    assert indices.size() == unfinished_tensor.size()
                    # mask the indices tensor
                    # unfinished_index = torch.masked_select(indices, unfinished_tensor)
                    finished_index = torch.masked_select(
                        indices, finished_tensor) / num_classes

                    # nth_beam that is finished in this decoding step
                    finished_index = finished_index.cpu()

                    finished_seq = torch.index_select(cur_batch_seq, 0,
                                                      finished_index)
                    finished_log_prob = torch.index_select(
                        cur_batch_log, 0, finished_index)

                    # add the decoded seq and corresponding log_prob to the list
                    for seq in finished_seq:
                        self.decoded_seq[i].append(seq)
                    for log_prob in finished_log_prob:
                        self.decoded_log_prob[i].append(log_prob)

                    for index in finished_index:
                        # we find the finished decoding in the batch and set the log_prob
                        # to -sys.float_info.max so that it won't affect out next topk
                        nth_beam = int(index / num_classes)
                        cur_batch_next = cur_batch.clone()

                        # cur_batch = (beam_width, num_classes)
                        cur_batch_next[nth_beam,
                                       self._end_index] = -sys.float_info.max
                        cur_batch = cur_batch_next

            assert log_prob_cur.size() == indices.size() and log_prob_cur.size(
            ) == (self._beam_width, 1)
            batch_next_indices.append(indices)
            batch_next_log_prob.append(log_prob_cur)

        batch_next_indices = torch.stack(batch_next_indices, 0)
        batch_next_log_prob = torch.stack(batch_next_log_prob, 0)
        assert len(batch_next_indices) == self._batch_size
        # print("batch_next_indices: " + str(batch_next_indices.size()))
        assert batch_next_indices.size(
            0) == self._batch_size and batch_next_indices.size(
                1) == self._beam_width

        # batch_next_indices = batch_next_indices.cuda() if use_cuda else batch_next_indices
        # batch_next_log_prob = batch_next_log_prob.cuda() if use_cuda else batch_next_log_prob

        return batch_next_indices, batch_next_log_prob
Example #37
    def pooling(self, blocks, verts_pos, debug=False):
        # convert vertex positions to x,y coordinates in the image, scaled to fractions of image dimension
        ext_verts_pos = torch.cat(
            (verts_pos,
             torch.FloatTensor(
                 np.ones([verts_pos.shape[0], verts_pos.shape[1], 1])).cuda()),
            dim=-1)
        ext_verts_pos = torch.matmul(ext_verts_pos, self.matrix.permute(1, 0))
        xs = ext_verts_pos[:, :, 1] / ext_verts_pos[:, :, 2] / 256.
        ys = ext_verts_pos[:, :, 0] / ext_verts_pos[:, :, 2] / 256.

        full_features = None
        batch_size = verts_pos.shape[0]

        # check camera project covers the image
        if debug:
            dim = 256
            xs = (torch.clamp(xs * dim, 0,
                              dim - 1).data.cpu().numpy()).astype(np.uint8)
            ys = (torch.clamp(ys * dim, 0,
                              dim - 1).data.cpu().numpy()).astype(np.uint8)
            for ex in range(blocks.shape[0]):
                img = blocks[ex].permute(1, 2, 0).data.cpu().numpy()[:, :, :3]
                for x, y in zip(xs[ex], ys[ex]):
                    img[x, y, 0] = 1
                    img[x, y, 1] = 0
                    img[x, y, 2] = 0

                from PIL import Image
                Image.fromarray(
                    (img * 255).astype(np.uint8)).save('results/temp.png')
                print('saved')
                input()

        for block in blocks:
            # scale projected vertex points to dimension of current feature map
            dim = block.shape[-1]
            cur_xs = torch.clamp(xs * dim, 0, dim - 1)
            cur_ys = torch.clamp(ys * dim, 0, dim - 1)

            # https://en.wikipedia.org/wiki/Bilinear_interpolation
            x1s, y1s, x2s, y2s = torch.floor(cur_xs), torch.floor(
                cur_ys), torch.ceil(cur_xs), torch.ceil(cur_ys)
            A = x2s - cur_xs
            B = cur_xs - x1s
            G = y2s - cur_ys
            H = cur_ys - y1s

            x1s = x1s.type(torch.cuda.LongTensor)
            y1s = y1s.type(torch.cuda.LongTensor)
            x2s = x2s.type(torch.cuda.LongTensor)
            y2s = y2s.type(torch.cuda.LongTensor)

            # flatten batch of feature maps to make vectorization easier
            flat_block = block.permute(1, 0, 2, 3).contiguous().view(
                block.shape[1], -1)
            block_idx = torch.arange(
                0, verts_pos.shape[0]).cuda().unsqueeze(-1).expand(
                    batch_size, verts_pos.shape[1])
            block_idx = block_idx * dim * dim

            selection = (block_idx + (x1s * dim) + y1s).view(-1)
            C = torch.index_select(flat_block, 1, selection)
            C = C.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
            selection = (block_idx + (x1s * dim) + y2s).view(-1)
            D = torch.index_select(flat_block, 1, selection)
            D = D.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
            selection = (block_idx + (x2s * dim) + y1s).view(-1)
            E = torch.index_select(flat_block, 1, selection)
            E = E.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
            selection = (block_idx + (x2s * dim) + y2s).view(-1)
            F = torch.index_select(flat_block, 1, selection)
            F = F.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)

            section1 = A.unsqueeze(1) * C * G.unsqueeze(1)
            section2 = H.unsqueeze(1) * D * A.unsqueeze(1)
            section3 = G.unsqueeze(1) * E * B.unsqueeze(1)
            section4 = B.unsqueeze(1) * F * H.unsqueeze(1)

            features = (section1 + section2 + section3 + section4)
            features = features.permute(0, 2, 1)

            if full_features is None:
                full_features = features
            else:
                full_features = torch.cat((full_features, features), dim=2)

        return full_features
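
# Hedged sketch of the bilinear-sampling idea above on a single feature map:
# flatten the map, turn the (x, y) corner coordinates into flat indices, gather
# the four corners with index_select, and blend them with the standard bilinear
# weights. The shapes here are toy values, not the ones used by `pooling`.
# (Like the original, a point landing exactly on integer coordinates gets zero weight.)
import torch

C, H, W = 8, 16, 16
fmap = torch.randn(C, H, W)
xs = torch.rand(100) * (H - 1)        # fractional row coordinates
ys = torch.rand(100) * (W - 1)        # fractional column coordinates

x1, y1 = torch.floor(xs), torch.floor(ys)
x2, y2 = torch.ceil(xs), torch.ceil(ys)
A, B = x2 - xs, xs - x1               # row weights
G, Hw = y2 - ys, ys - y1              # column weights

flat = fmap.view(C, -1)               # (C, H*W)

def corner(xi, yi):
    idx = (xi * W + yi).long()
    return torch.index_select(flat, 1, idx)   # (C, num_points)

feat = (A * G) * corner(x1, y1) + (A * Hw) * corner(x1, y2) \
     + (B * G) * corner(x2, y1) + (B * Hw) * corner(x2, y2)
print(feat.shape)                      # (C, 100)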
Example #38
def train_vae_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

        # sort input
        y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted,0,sort_index)
        y = torch.index_select(y_unsorted,0,sort_index)
        x = Variable(x).to(device)
        y = Variable(y).to(device)

        # if using ground truth to train
        h = rnn(x, pack=True, input_len=y_len)
        y_pred,z_mu,z_lsgms = output(h)
        y_pred = torch.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        z_mu = pack_padded_sequence(z_mu, y_len, batch_first=True)
        z_mu = pad_packed_sequence(z_mu, batch_first=True)[0]
        z_lsgms = pack_padded_sequence(z_lsgms, y_len, batch_first=True)
        z_lsgms = pad_packed_sequence(z_lsgms, batch_first=True)[0]
        # use cross entropy loss
        loss_bce = binary_cross_entropy_weight(y_pred, y)
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= y.size(0)*y.size(1)*sum(y_len) # normalize
        loss = loss_bce + loss_kl
        loss.backward()
        # update deterministic and lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()


        z_mu_mean = torch.mean(z_mu.data)
        z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data)
        z_mu_min = torch.min(z_mu.data)
        z_sgm_min = torch.min(z_lsgms.mul(0.5).exp_().data)
        z_mu_max = torch.max(z_mu.data)
        z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data)


        if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics
            print('Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs,loss_bce.item(), loss_kl.item(), args.graph_type, args.num_layers, args.hidden_size_rnn))
            print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean, 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max)

        # logging
        log_value('bce_loss_'+args.fname, loss_bce.item(), epoch*args.batch_ratio+batch_idx)
        log_value('kl_loss_' +args.fname, loss_kl.item(), epoch*args.batch_ratio + batch_idx)
        log_value('z_mu_mean_'+args.fname, z_mu_mean, epoch*args.batch_ratio + batch_idx)
        log_value('z_mu_min_'+args.fname, z_mu_min, epoch*args.batch_ratio + batch_idx)
        log_value('z_mu_max_'+args.fname, z_mu_max, epoch*args.batch_ratio + batch_idx)
        log_value('z_sgm_mean_'+args.fname, z_sgm_mean, epoch*args.batch_ratio + batch_idx)
        log_value('z_sgm_min_'+args.fname, z_sgm_min, epoch*args.batch_ratio + batch_idx)
        log_value('z_sgm_max_'+args.fname, z_sgm_max, epoch*args.batch_ratio + batch_idx)

        loss_sum += loss.item()
    return loss_sum/(batch_idx+1)
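
# Standalone sketch (toy sizes, not the GraphRNN tensors above) of the
# sort-by-length + index_select pattern that pack_padded_sequence expects,
# and how to undo the sort afterwards:
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

x_unsorted = torch.randn(3, 5, 7)                  # (batch, max_len, feat)
len_unsorted = torch.tensor([2, 5, 3])

lengths, sort_index = torch.sort(len_unsorted, 0, descending=True)
x = torch.index_select(x_unsorted, 0, sort_index)  # reorder the batch to match

packed = pack_padded_sequence(x, lengths.tolist(), batch_first=True)
unpacked, _ = pad_packed_sequence(packed, batch_first=True)

# undo the sort to restore the original batch order
_, inverse = torch.sort(sort_index)
restored = torch.index_select(unpacked, 0, inverse)
print(unpacked.shape, restored.shape)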
Example #39
"""
Note: difference between concatenate and stack is output shape
print(concatenated_tensor_0.shape)
print(stacked_tensor.shape)
"""

# Split a Tensor
split_tensor = torch.split(x_tensor, 1)

# Create tensor with Index select
indices_1 = torch.tensor([0,2])
indices_2 = torch.tensor([0,1])
indices_3 = torch.tensor([0])

tensor_index_1 = torch.index_select(x_tensor, 1, indices_1) # Select columns 0 and 2 along dimension 1.
tensor_index_2 = torch.index_select(x_tensor, 1, indices_2) # Select columns 0 and 1 along dimension 1.
tensor_index_3 = torch.index_select(x_tensor, 0, indices_3) # Select row 0 along dimension 0.

# Create mask and tensor with selected value by that mask
x = torch.randn(3, 4)
mask = x.ge(0.5)
mask_tensor = torch.masked_select(x, mask)

# Squeeze and unsqueeze

"""
Squeeze: Returns a tensor with all the dimensions of input of size 1 removed.
For example, if input is of shape: (A×1×B×C×1×D) then the out tensor will be of shape: (A×B×C×D)
When dim is given, a squeeze operation is done only in the given dimension. 
For example, If input is of shape: (A×1×B) , squeeze(input, 0) leaves the tensor unchanged, 
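
# Small hedged demo of squeeze/unsqueeze on a toy tensor (not part of the note above):
import torch

t = torch.zeros(2, 1, 3, 1)
print(torch.squeeze(t).shape)        # torch.Size([2, 3])       -> all size-1 dims removed
print(torch.squeeze(t, 1).shape)     # torch.Size([2, 3, 1])    -> only dim 1 removed
print(torch.squeeze(t, 0).shape)     # torch.Size([2, 1, 3, 1]) -> unchanged, dim 0 is not 1
print(t.unsqueeze(0).shape)          # torch.Size([1, 2, 1, 3, 1]) -> new size-1 dim at position 0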
Example #40
 def _variable(self, v):
     return torch.index_select(v, 0, self.batch_indices)
    def forward(self, base_feat, im_info, gt_boxes, num_boxes, file_rpn, file_proposal):
        self.f = file_rpn
        batch_size = base_feat.size(0)

        conv1_tic = time.time()
        # return feature map after convrelu layer
        rpn_conv1 = F.relu(self.RPN_Conv(base_feat), inplace=True)
        conv1_toc = time.time()
        conv1_time = conv1_toc - conv1_tic

        cls_score_tic = time.time()
        # get rpn classification score
        rpn_cls_score = self.RPN_cls_score(rpn_conv1)


        #reshape_tic = time.time()
        rpn_cls_score_reshape = self.reshape(rpn_cls_score, 2)
        rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape, 1)
        rpn_cls_prob = self.reshape(rpn_cls_prob_reshape, self.nc_score_out)
        #reshape_toc = time.time()
        #reshape_time = reshape_toc - reshape_tic
        cls_score_toc = time.time()
        cls_score_time = cls_score_toc - cls_score_tic

        bbox_pred_tic = time.time()
        # get rpn offsets to the anchor boxes
        rpn_bbox_pred = self.RPN_bbox_pred(rpn_conv1)
        bbox_pred_toc = time.time()
        bbox_pred_time = bbox_pred_toc - bbox_pred_tic

        # proposal layer
        cfg_key = 'TRAIN' if self.training else 'TEST'

        proposal_tic = time.time()

        rois, ship_time = self.RPN_proposal((rpn_cls_prob.data, rpn_bbox_pred.data,
                                 im_info, cfg_key), file_proposal)

        proposal_toc = time.time()
        proposal_time = proposal_toc - proposal_tic - ship_time
        total_rpn_time = proposal_toc - conv1_tic - ship_time

        self.rpn_loss_cls = 0
        self.rpn_loss_box = 0

        # generating training labels and build the rpn loss
        if self.training:
            assert gt_boxes is not None

            rpn_data = self.RPN_anchor_target((rpn_cls_score.data, gt_boxes, im_info, num_boxes))

            # compute classification loss
            rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
            rpn_label = rpn_data[0].view(batch_size, -1)

            rpn_keep = Variable(rpn_label.view(-1).ne(-1).nonzero().view(-1))
            rpn_cls_score = torch.index_select(rpn_cls_score.view(-1,2), 0, rpn_keep)
            rpn_label = torch.index_select(rpn_label.view(-1), 0, rpn_keep.data)
            rpn_label = Variable(rpn_label.long())
            self.rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label)
            fg_cnt = torch.sum(rpn_label.data.ne(0))

            rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]

            # compute bbox regression loss
            rpn_bbox_inside_weights = Variable(rpn_bbox_inside_weights)
            rpn_bbox_outside_weights = Variable(rpn_bbox_outside_weights)
            rpn_bbox_targets = Variable(rpn_bbox_targets)

            self.rpn_loss_box = _smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights,
                                                            rpn_bbox_outside_weights, sigma=3, dim=[1,2,3])


        rpn_string = str(total_rpn_time)+' '+str(conv1_time)+' '+str(cls_score_time)+' '+str(bbox_pred_time)+' '+str(proposal_time)+'\n'
        self.f.write(rpn_string)

        #print('RPN TIME DISTRIBUTION: ', ' total: ', total_rpn_time, '\n    conv1: ', conv1_time, ' get_cls_score: ',  cls_score_time, ' get bbox_pred: ', bbox_pred_time,' make_proposal: ', proposal_time)
        return rois, self.rpn_loss_cls, self.rpn_loss_box, proposal_time, ship_time
Example #42
    def trainModel(self,
                   targetmodel,
                   optimizer,
                   loss_function,
                   gamma=0.1,
                   minibatch=128):

        ln = self.memory.getMemoryLen()
        if ln < 5000:
            return -1

        self.trainCnt += 1

        indexlist = self.memory.sample()

        if ni.useRnn:
            input = torch.FloatTensor(len(indexlist), ni.actionDuration - 1,
                                      self.dims)
            inputtarget = torch.FloatTensor(len(indexlist),
                                            ni.actionDuration - 1, self.dims)
        else:
            input = torch.FloatTensor(len(indexlist), self.dims)
            inputtarget = torch.FloatTensor(len(indexlist), self.dims)

        target = torch.FloatTensor(len(indexlist), self.actions)

        for j in range(len(indexlist)):

            data = indexlist[j]
            s = data[0]
            a = data[1]
            sn = data[2]
            r = data[3]

            input[j, :] = torch.squeeze(s.data)
            inputtarget[j, :] = torch.squeeze(sn.data)

        input = autograd.Variable(input)
        inputtarget = autograd.Variable(inputtarget)

        if ni.useRnn:
            hidden = torch.zeros(1, len(indexlist), self.dims * 2)
            hidden = autograd.Variable(hidden)
            qs = self.model['qvalue'](input, hidden)
            qsn = targetmodel(inputtarget, hidden)
        else:
            qs = self.model['qvalue'](input)
            qsn = targetmodel(inputtarget)

        qsdata = torch.squeeze(qs.data)
        qsndata = torch.squeeze(qsn.data)

        target = qsdata.clone()

        for j in range(len(indexlist)):

            data = indexlist[j]

            a = data[1]
            r = data[3]
            actionset = data[4]
            target[j, a] = r + gamma * max(
                torch.index_select(qsndata[j, :], 0,
                                   torch.LongTensor(actionset)))
            #target[j,a] = r + gamma*max(qsndata[j,:])

        target = autograd.Variable(target)

        loss = loss_function(qs, target)
        loss.backward()

        for param in self.model['qvalue'].parameters():
            param.grad.data.clamp_(-1, 1)

        optimizer.step()
        optimizer.zero_grad()

        return loss.item()
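
# Hedged sketch of the target computation above: restrict the max over the
# next-state Q-values to a per-sample set of legal actions by gathering those
# entries with index_select. Toy numbers; the reward, gamma and action set are invented.
import torch

num_actions = 6
qsn_row = torch.randn(num_actions)            # Q(s', .) for one sample
actionset = torch.tensor([0, 2, 5])           # legal actions in state s'
r, gamma = 1.0, 0.9

best_next = torch.index_select(qsn_row, 0, actionset).max()
td_target = r + gamma * best_next
print(td_target)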
Example #43
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300, cls_exclusive=None):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.1 + 0.03 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        ## torch.index_select(torch.Tensor([8,9,10,11]).to(x.device), 0, x[:, 5].int())
        if cls_exclusive is not None:
            cls = torch.index_select(cls_exclusive.to(x.device), 0, x[:, 5].int())[:, None]
            c = cls * (0 if agnostic else max_wh)
        else:
            c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
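
# Hedged illustration of the class-offset trick used above for "batched" NMS:
# shifting each box by class_index * max_wh pushes boxes of different classes far
# apart, so a single torchvision.ops.nms call never suppresses across classes.
# Toy boxes; requires torchvision.
import torch
import torchvision

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],     # overlaps box 0
                      [0., 0., 10., 10.]])    # same box, but a different class
scores = torch.tensor([0.9, 0.8, 0.7])
cls = torch.tensor([0., 0., 1.])
max_wh = 7680

keep_agnostic = torchvision.ops.nms(boxes, scores, 0.5)
keep_per_class = torchvision.ops.nms(boxes + cls[:, None] * max_wh, scores, 0.5)
print(keep_agnostic.tolist())    # [0]    -> box 2 suppressed by box 0
print(keep_per_class.tolist())   # [0, 2] -> box 2 kept, it belongs to another class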
Example #44
    def forward(self, weight, indices, offsets):
        assert not self.needs_input_grad[1], "EmbeddingBag doesn't " \
            "compute the gradient w.r.t. the indices"

        assert not self.needs_input_grad[2], "EmbeddingBag doesn't " \
            "compute the gradient w.r.t. the offsets"

        assert indices.dim() == 1
        if offsets.dim() != 1:
            raise ValueError("offsets has to be a 1D Tensor")

        if offsets[0] != 0:
            raise ValueError("offsets[0] has to be 0, i.e. the first sequence"
                             " in the mini-batch has to start from position 0."
                             "However, got {}".format(offsets[0]))
        if offsets[-1] > indices.size(0):
            raise ValueError(
                "offsets[-1] has to be smaller than indices's length"
                " ({}), but got offsets[-1] of {}".format(
                    indices.size(0), offsets[-1]))

        self._backend = type2backend[type(weight)]
        self._weight_size = weight.size()
        self._offset2bag = offsets.new()

        self.save_for_backward(indices)

        indices = indices.contiguous().view(-1)
        output = weight.new()
        if self.max_norm is not None:
            self._renorm(indices, weight)

        if weight.is_cuda:
            if self.mode == MODE_MEAN:
                self.bag_size = offsets.new().resize_(offsets.size())
            else:
                self.bag_size = None

            self._backend.LookupTableBag_updateOutput(
                self._backend.library_state, indices, offsets, weight, output,
                self._offset2bag, self.mode, self.bag_size)
        else:
            # slow CPU implementation
            index_output = torch.index_select(weight, 0, indices)
            # indices = [1, 2, 30, 100, 12], offsets = [0, 2, 3]
            self._offset2bag.resize_(
                indices.size(0)).zero_()  # offset2bag = [0 0 0 0 0]
            self._offset2bag.index_fill_(0, offsets,
                                         1)  # offset2bag = [1 0 1 0 1]
            self._offset2bag[0] = 0  # offset2bag = [0 0 1 0 1]
            self._offset2bag = self._offset2bag.cumsum(
                0)  # offset2bag = [0 0 1 1 2]
            output.resize_(offsets.size(0), weight.size(1)).zero_()
            output.index_add_(0, self._offset2bag, index_output)
            if self.mode == MODE_MEAN:
                if offsets.size(0) == 1:
                    self.bag_size = indices.size(0)
                else:
                    self.bag_size = weight.new().resize_(offsets.size())
                    self.bag_size[:-1] = offsets[1:] - offsets[:-1]
                    self.bag_size[-1] = indices.size(0) - offsets[-1]
                    self.bag_size = self.bag_size[:, None].expand_as(output)
                output /= self.bag_size

        return output
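
# Hedged re-creation of the CPU "bag" trick above on plain tensors: gather the
# selected embedding rows with index_select, build offset2bag (which bag each
# index belongs to), then sum rows into bags with index_add_. Toy sizes only.
import torch

weight = torch.arange(12.).view(6, 2)          # 6 embeddings of dim 2
indices = torch.tensor([1, 2, 3, 0, 5])        # flat indices for all bags
offsets = torch.tensor([0, 2, 3])              # bags: [1, 2], [3], [0, 5]

index_output = torch.index_select(weight, 0, indices)

offset2bag = torch.zeros_like(indices)
offset2bag.index_fill_(0, offsets, 1)          # [1, 0, 1, 1, 0]
offset2bag[0] = 0                              # [0, 0, 1, 1, 0]
offset2bag = offset2bag.cumsum(0)              # [0, 0, 1, 2, 2] -> bag id per index

output = torch.zeros(offsets.size(0), weight.size(1))
output.index_add_(0, offset2bag, index_output)
print(output)                                  # sum-pooled bags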
Example #45
def test(device, model,name,epoch, X_val, y_val, \
         vid_lens_batch, mask_val, idxs_val):
    model.eval()

    y = y_val.reshape((-1, 1))
    y = y.repeat(mask_val.shape[-1], axis=-1)


    X=torch.from_numpy(X_val).float().to(device)
    y_val=torch.from_numpy(y_val).to(device)
    vid_lens_batch=torch.from_numpy(vid_lens_batch).to(device)
    mask_val=torch.from_numpy(mask_val).to(device)


    output,ordered_idx = model(X, vid_lens_batch)


    y_val=torch.index_select(y_val,0,ordered_idx)
    mask_val=torch.index_select(mask_val,0,ordered_idx)


    y=torch.from_numpy(y).long().to(device)
    target=torch.index_select(y,0,ordered_idx)


    m=mask_val.float()

    loss = temporal_ce_loss(output, target,m)

    output=output.cpu().detach().numpy()

    seq_len=output.shape[1]
    y_val=y_val[:].contiguous()
    mask_val=mask_val[:,:seq_len].contiguous()

    mask_val=mask_val.cpu().numpy()
    y_val=y_val.cpu().numpy()

    num_classes = output.shape[-1]

    ix = np.zeros((X_val.shape[0],), dtype='int')
    seq_lens = np.sum(mask_val, axis=-1)



    # for each example, we only consider argmax of the seq len
    votes = np.zeros((num_classes,), dtype='int')
    for i, eg in enumerate(output):
        predictions = np.argmax(eg[:seq_lens[i]], axis=-1)
#         print(predictions.shape)
        for cls in range(num_classes):
            count = (predictions == cls).sum(axis=-1)
            votes[cls] = count
        ix[i] = np.argmax(votes)


    c = ix == y_val
#     print(c,ix[:10],y_val[:10])
    classification_rate = np.sum(c == True) / float(len(c))


    print('{} Epoch: {} \tAcc: {:.6f} \tLoss: {:.6f}'.format(name,
                    epoch,classification_rate,loss.item() ))

    preds = ix
    true_labels = y_val

    return classification_rate,loss.item(), preds, true_labels
Example #46
def th_gather_nd(x, coords):
    x = x.contiguous()
    inds = coords.mv(th.LongTensor(x.stride()))
    x_gather = th.index_select(th_flatten(x), 0, inds)
    return x_gather
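
# Hedged usage sketch: for a contiguous tensor, dotting integer coordinates with
# the strides gives flat offsets, so index_select on the flattened tensor matches
# advanced indexing x[coords]. `th_flatten` is assumed to be x.view(-1); the dot
# product is written as multiply-and-sum here to stay dtype-agnostic.
import torch as th

x = th.arange(24).view(2, 3, 4)
coords = th.tensor([[0, 1, 2],
                    [1, 2, 3]])
inds = (coords * th.tensor(x.stride())).sum(dim=1)   # [0*12+1*4+2, 1*12+2*4+3] = [6, 23]
gathered = th.index_select(x.contiguous().view(-1), 0, inds)
print(gathered)                                      # tensor([ 6, 23]) == x[[0, 1], [1, 2], [2, 3]]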
Example #47
def execute_thinning_recipe(model, zeros_mask_dict, recipe, loaded_from_file=False):
    """Apply a thinning recipe to a model.

    This will remove filters and channels, as well as handle batch-normalization parameter
    adjustment, and thinning of weight tensors.
    """
    layers = {}
    for name, m in model.named_modules():
        layers[name] = m

    for layer_name, directives in recipe.modules.items():
        for attr, val in directives.items():
            if attr in ['running_mean', 'running_var']:
                running = getattr(layers[layer_name], attr)
                dim_to_trim = val[0]
                indices_to_select = val[1]
                # Check if we're trying to trim a parameter that is already "thin"
                if running.size(dim_to_trim) != len(indices_to_select):
                    msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, len(indices_to_select)))
                    setattr(layers[layer_name], attr,
                            torch.index_select(running, dim=dim_to_trim, index=indices_to_select))
            else:
                msglogger.info("[thinning] {}: setting {} to {}".format(layer_name, attr, val))
                setattr(layers[layer_name], attr, val)

    assert len(recipe.parameters) > 0

    for param_name, param_directives in recipe.parameters.items():
        param = distiller.model_find_param(model, param_name)
        for directive in param_directives:
            dim = directive[0]
            indices = directive[1]
            if len(directive) == 4:  # TODO: this code is hard to follow
                selection_view = param.view(*directive[2])
                # Check if we're trying to trim a parameter that is already "thin"
                if param.data.size(dim) != len(indices):
                    param.data = torch.index_select(selection_view, dim, indices)

                if param.grad is not None:
                    # We also need to change the dimensions of the gradient tensor.
                    grad_selection_view = param.grad.resize_(*directive[2])
                    if grad_selection_view.size(dim) != len(indices):
                        param.grad = torch.index_select(grad_selection_view, dim, indices)

                param.data = param.view(*directive[3])
                if param.grad is not None:
                    param.grad = param.grad.resize_(*directive[3])
            else:
                if param.data.size(dim) != len(indices):
                    param.data = torch.index_select(param.data, dim, indices)
                # We also need to change the dimensions of the gradient tensor.
                # If have not done a backward-pass thus far, then the gradient will
                # not exist, and therefore won't need to be re-dimensioned.
                if param.grad is not None and param.grad.size(dim) != len(indices):
                        param.grad = torch.index_select(param.grad, dim, indices)
                msglogger.info("[thinning] changing param {} shape: {}".format(param_name, len(indices)))

            if not loaded_from_file:
                # If the masks are loaded from a checkpoint file, then we don't need to change
                # their shape, because they are already correctly shaped
                mask = zeros_mask_dict[param_name].mask
                if mask is not None and (mask.size(dim) != len(indices)):
                    zeros_mask_dict[param_name].mask = torch.index_select(mask, dim, indices)
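
# Minimal hedged sketch of the core thinning step above: keep only a subset of
# filters (dim 0 of a conv weight) with index_select and shrink the matching
# BatchNorm buffers the same way. Sizes and surviving indices are invented;
# a full thinning pass would also fix out_channels and the bn affine parameters.
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)
bn = nn.BatchNorm2d(8)
keep = torch.tensor([0, 2, 5, 7])              # filters that survive pruning

conv.weight.data = torch.index_select(conv.weight.data, 0, keep)
conv.bias.data = torch.index_select(conv.bias.data, 0, keep)
bn.running_mean = torch.index_select(bn.running_mean, 0, keep)
bn.running_var = torch.index_select(bn.running_var, 0, keep)
print(conv.weight.shape, bn.running_mean.shape)   # (4, 3, 3, 3), (4,)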
Example #48
    def construct_graph(
            self, state: State
    ) -> Tuple[torch_geometric.data.Batch, bool, torch.Tensor]:
        """
        Construct a batched graph object from the state
        """
        partial_trees, tokens_emb, next_token_pos, batch_size = (
            state.partial_trees,
            state.tokens_emb,
            state.n_step,
            state.batch_size,
        )
        next_token_features = state.tokens_emb[:, next_token_pos]
        device = tokens_emb.device

        if next_token_pos == 0:  # the first step
            return (
                self.construct_init_graph(batch_size, device),
                True,
                next_token_features,
            )

        node_token_pos = [
        ]  # positions of tokens in sentences, -1 if not token
        node_label_idx = [
        ]  # labels of internal nodes, -1 if not internal node
        node_label_left = []
        node_label_right = []
        node_batch_idx = []
        edges: List[Tuple[int, int]] = []
        on_rightmost_chain: List[bool] = []

        num_nodes = 0
        for i, tree in enumerate(partial_trees):
            assert isinstance(tree, InternalParseNode)
            x, edge_index, rightmost_chain_i = self.tree2graph_preorder(
                tree, num_nodes, device)
            edges.extend(edge_index)
            on_rightmost_chain.extend(rightmost_chain_i)
            for node in x:
                if "label" in node:  # internal node
                    node_token_pos.append(-1)
                    node_label_idx.append(self.label_idx_map[node["label"]])
                    node_label_left.append(node["left"])
                    node_label_right.append(node["right"])
                else:  # leaf node
                    node_token_pos.append(node["token_pos"])
                    node_label_idx.append(-1)
                    node_label_left.append(-1)
                    node_label_right.append(-1)
            tree_size = len(x)
            node_batch_idx.extend([i] * tree_size)
            num_nodes += tree_size

        node_token_pos = torch.tensor(node_token_pos,
                                      device=device)  # type: ignore
        node_is_token = node_token_pos >= 0  # type: ignore
        node_label_idx = node_token_pos.new_tensor(
            node_label_idx)  # type: ignore
        node_label_left = node_token_pos.new_tensor(
            node_label_left)  # type: ignore
        node_label_right = node_token_pos.new_tensor(
            node_label_right)  # type: ignore
        node_is_label = node_label_idx >= 0  # type: ignore
        node_batch_idx = node_token_pos.new_tensor(
            node_batch_idx)  # type: ignore

        d_model = self.cfg.d_model
        node_emb = tokens_emb.new_zeros((len(node_token_pos), d_model))
        flattened_tokens_emb = tokens_emb.view(-1, d_model)
        node_emb[node_is_token] = torch.index_select(
            flattened_tokens_emb,
            0,
            node_batch_idx[node_is_token] * tokens_emb.size(1) +
            node_token_pos[node_is_token],
        )
        label_position_emb = (
            self.position_table[node_label_left[node_is_label]] +
            self.position_table[node_label_right[node_is_label]])
        node_emb[node_is_label] = torch.cat(
            [
                self.label_embedding(node_label_idx[node_is_label]),
                label_position_emb
            ],
            dim=-1,
        )

        all_edges_index = (torch.tensor(edges, device=device).t()
                           if edges != [] else torch.empty(
                               2, 0, dtype=torch.int64, device=device))
        graph = torch_geometric.data.Batch(
            batch=node_batch_idx,
            x=node_emb,
            edge_index=all_edges_index,
        )

        graph.on_rightmost_chain = torch.tensor(on_rightmost_chain,
                                                dtype=torch.bool,
                                                device=device)

        return graph, False, next_token_features
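
# Hedged sketch of the embedding gather above: token embeddings of shape
# (batch, seq, d) are flattened to (batch*seq, d) so a single index_select with
# batch_idx * seq_len + token_pos fetches one row per node. Toy sizes.
import torch

batch, seq_len, d_model = 2, 5, 4
tokens_emb = torch.randn(batch, seq_len, d_model)

node_batch_idx = torch.tensor([0, 0, 1])       # which sentence each node comes from
node_token_pos = torch.tensor([1, 3, 0])       # which token within that sentence

flat = tokens_emb.view(-1, d_model)
node_emb = torch.index_select(flat, 0, node_batch_idx * seq_len + node_token_pos)
print(torch.equal(node_emb[2], tokens_emb[1, 0]))   # True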
Example #49
    def _generate(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(self.model.models_size)
            ],
        )
        net_input = sample["net_input"]

        if "src_tokens" in net_input:
            src_tokens = net_input["src_tokens"]
            # length of the source text being the character length except EndOfSentence and pad
            src_lengths = (
                (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
            )
        elif "source" in net_input:
            src_tokens = net_input["source"]
            src_lengths = (
                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
                if net_input["padding_mask"] is not None
                else torch.tensor(src_tokens.size(-1)).to(src_tokens)
            )
        else:
            raise Exception("expected src_tokens or source in net input")

        # bsz: total number of sentences in beam
        # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
        bsz, src_len = src_tokens.size()[:2]
        beam_size = self.beam_size

        if constraints is not None and not self.search.supports_constraints:
            raise NotImplementedError(
                "Target-side constraints were provided, but search method doesn't support them"
            )

        # Initialize constraints, when active
        self.search.init_constraints(constraints, beam_size)

        max_len: int = -1
        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                self.model.max_decoder_positions() - 1,
            )
        assert (
            self.min_len <= max_len
        ), "min_len cannot be larger than max_len, please adjust these!"
        # compute the encoder output for each beam
        encoder_outs = self.model.forward_encoder(net_input)

        # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = self.model.reorder_encoder_out(encoder_outs[0], new_order)
        # ensure encoder_outs is a List.
        assert encoder_outs is not None

        # initialize buffers
        scores = (
            torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
        )  # +1 for eos; pad is never chosen for scoring
        tokens = (
            torch.zeros(bsz * beam_size, max_len + 2)
            .to(src_tokens)
            .long()
            .fill_(self.pad)
        )  # +2 for eos and pad
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        attn: Optional[Tensor] = None

        # A list that indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then cands_to_ignore would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        cands_to_ignore = (
            torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
        )  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = torch.jit.annotate(
            List[List[Dict[str, Tensor]]],
            [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
        )  # contains lists of dictionaries of information about the hypothesis being finalized at each step

        finished = [
            False for i in range(bsz)
        ]  # a boolean array indicating if the sentence at the index is finished or not
        num_remaining_sent = bsz  # number of sentences remaining

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        reorder_state: Optional[Tensor] = None
        batch_idxs: Optional[Tensor] = None

        original_batch_idxs: Optional[Tensor] = None
        if "id" in sample and isinstance(sample["id"], Tensor):
            original_batch_idxs = sample["id"]
        else:
            original_batch_idxs = torch.arange(0, bsz).type_as(tokens)

        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            # print(f'step: {step}')
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
                        batch_idxs
                    )
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size
                    )
                    original_batch_idxs = original_batch_idxs[batch_idxs]
                self.model.reorder_incremental_state(incremental_states, reorder_state)
                encoder_outs = self.model.reorder_encoder_out(
                    encoder_outs, reorder_state
                )

            lprobs, avg_attn_scores = self.model.forward_decoder(
                tokens[:, : step + 1],
                encoder_outs,
                incremental_states,
                self.temperature,
            )

            if self.lm_model is not None:
                lm_out = self.lm_model(tokens[:, : step + 1])
                probs = self.lm_model.get_normalized_probs(
                    lm_out, log_probs=True, sample=None
                )
                probs = probs[:, -1, :] * self.lm_weight
                lprobs += probs

            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # handle max length constraint
            if step >= max_len:
                lprobs[:, : self.eos] = -math.inf
                lprobs[:, self.eos + 1 :] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            if (
                prefix_tokens is not None
                and step < prefix_tokens.size(1)
                and step < max_len
            ):
                lprobs, tokens, scores = self._prefix_tokens(
                    step, lprobs, scores, tokens, prefix_tokens, beam_size
                )
            elif step < self.min_len:
                # minimum length constraint (does not apply if using prefix_tokens)
                lprobs[:, self.eos] = -math.inf

            # Record attention scores, only support avg_attn_scores is a Tensor
            if avg_attn_scores is not None:
                if attn is None:
                    attn = torch.empty(
                        bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                    ).to(scores)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            eos_bbsz_idx = torch.empty(0).to(
                tokens
            )  # indices of hypothesis ending with eos (finished sentences)
            eos_scores = torch.empty(0).to(
                scores
            )  # scores of hypothesis ending with eos (finished sentences)

            if self.should_set_src_lengths:
                self.search.set_src_lengths(src_lengths)

            if self.no_repeat_ngram_size > 0:
                lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step)

            # Shape: (batch, cand_size)
            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
                tokens[:, : step + 1],
                original_batch_idxs,
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            # Shape of eos_mask: (batch size, beam size)
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)

            # only consider eos when it's among the top beam_size indices
            # Now we know what beam item(s) to finish
            # Shape: 1d list of absolute-numbered
            eos_bbsz_idx = torch.masked_select(
                cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            finalized_sents: List[int] = []
            if eos_bbsz_idx.numel() > 0:
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
                )

                finalized_sents = self.finalize_hypos(
                    step,
                    eos_bbsz_idx,
                    eos_scores,
                    tokens,
                    scores,
                    finalized,
                    finished,
                    beam_size,
                    attn,
                    src_lengths,
                    max_len,
                )
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            if self.search.stop_on_max_len and step >= max_len:
                break
            assert step < max_len

            # Remove finalized sentences (ones for which {beam_size}
            # finished hypotheses have been generated) from the batch.
            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = torch.ones(
                    bsz, dtype=torch.bool, device=cand_indices.device
                )
                batch_mask[finalized_sents] = False
                # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
                batch_idxs = torch.arange(
                    bsz, device=cand_indices.device
                ).masked_select(batch_mask)

                # Choose the subset of the hypothesized constraints that will continue
                self.search.prune_sentences(batch_idxs)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                cands_to_ignore = cands_to_ignore[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1
                    )
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos

            # Rewrite the operator since the element wise or is not supported in torchscript.

            eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[: eos_mask.size(1)],
            )

            # get the top beam_size active hypotheses, which are just
            # the hypos with the smallest values in active_mask.
            # {active_hypos} indicates which {beam_size} hypotheses
            # from the list of {2 * beam_size} candidates were
            # selected. Shapes: (batch size, beam size)
            new_cands_to_ignore, active_hypos = torch.topk(
                active_mask, k=beam_size, dim=1, largest=False
            )

            # update cands_to_ignore to ignore any finalized hypos.
            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
            # Make sure there is at least one active item for each sentence in the batch.
            assert (~cands_to_ignore).any(dim=1).all()

            # update cands_to_ignore to ignore any finalized hypos

            # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
            # can be selected more than once).
            active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
            active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses

            # Set the tokens for each beam (can select the same row more than once)
            tokens[:, : step + 1] = torch.index_select(
                tokens[:, : step + 1], dim=0, index=active_bbsz_idx
            )
            # Select the next token for each of them
            tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
                cand_indices, dim=1, index=active_hypos
            )
            if step > 0:
                scores[:, :step] = torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx
                )
            scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
                cand_scores, dim=1, index=active_hypos
            )

            # Update constraints based on which candidates were selected for the next beam
            self.search.update_constraints(active_hypos)

            # copy attention for active hypotheses
            if attn is not None:
                attn[:, :, : step + 2] = torch.index_select(
                    attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
                )

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            scores = torch.tensor(
                [float(elem["score"].item()) for elem in finalized[sent]]
            )
            _, sorted_scores_indices = torch.sort(scores, descending=True)
            finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
            finalized[sent] = torch.jit.annotate(
                List[Dict[str, Tensor]], finalized[sent]
            )
        return finalized
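
# Hedged toy illustration of the reorder step above: every (bsz*beam) row of the
# token buffer is re-gathered with index_select so that row k continues the beam
# chosen at this step (the same parent row may be picked more than once).
import torch

bsz, beam_size, step = 2, 3, 4
tokens = torch.arange(bsz * beam_size * (step + 1)).view(bsz * beam_size, step + 1)
active_bbsz_idx = torch.tensor([0, 0, 2, 3, 5, 5])   # parent row for each new hypothesis

tokens[:, : step + 1] = torch.index_select(tokens[:, : step + 1], 0, active_bbsz_idx)
print(tokens[1].tolist() == tokens[0].tolist())      # True: both rows continue old row 0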
Example #50
    def recognize_beam_batch(self,
                             h,
                             hlens,
                             lpz,
                             recog_args,
                             char_list,
                             rnnlm=None,
                             normalize_score=True,
                             strm_idx=0):
        logging.info('input lengths: ' + str(h.size(1)))
        att_idx = min(strm_idx, len(self.att) - 1)
        h = mask_by_length(h, hlens, 0.0)

        # search params
        batch = len(hlens)
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight
        att_weight = 1.0 - ctc_weight

        n_bb = batch * beam
        n_bo = beam * self.odim
        n_bbo = n_bb * self.odim
        pad_b = to_device(
            self,
            torch.LongTensor([i * beam
                              for i in six.moves.range(batch)]).view(-1, 1))
        pad_bo = to_device(
            self,
            torch.LongTensor([i * n_bo
                              for i in six.moves.range(batch)]).view(-1, 1))
        pad_o = to_device(
            self,
            torch.LongTensor([i * self.odim
                              for i in six.moves.range(n_bb)]).view(-1, 1))

        max_hlen = int(max(hlens))
        if recog_args.maxlenratio == 0:
            maxlen = max_hlen
        else:
            maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
        minlen = int(recog_args.minlenratio * max_hlen)
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialization
        c_prev = [
            to_device(self, torch.zeros(n_bb, self.dunits))
            for _ in range(self.dlayers)
        ]
        z_prev = [
            to_device(self, torch.zeros(n_bb, self.dunits))
            for _ in range(self.dlayers)
        ]
        c_list = [
            to_device(self, torch.zeros(n_bb, self.dunits))
            for _ in range(self.dlayers)
        ]
        z_list = [
            to_device(self, torch.zeros(n_bb, self.dunits))
            for _ in range(self.dlayers)
        ]
        vscores = to_device(self, torch.zeros(batch, beam))

        a_prev = None
        rnnlm_prev = None

        self.att[att_idx].reset()  # reset pre-computation of h

        yseq = [[self.sos] for _ in six.moves.range(n_bb)]
        accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
        stop_search = [False for _ in six.moves.range(batch)]
        nbest_hyps = [[] for _ in six.moves.range(batch)]
        ended_hyps = [[] for _ in range(batch)]

        exp_hlens = hlens.repeat(beam).view(beam,
                                            batch).transpose(0,
                                                             1).contiguous()
        exp_hlens = exp_hlens.view(-1).tolist()
        exp_h = h.unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
        exp_h = exp_h.view(n_bb, h.size()[1], h.size()[2])

        if lpz is not None:
            device_id = torch.cuda.device_of(next(self.parameters()).data).idx
            ctc_prefix_score = CTCPrefixScoreTH(lpz, 0, self.eos, beam,
                                                exp_hlens, device_id)
            ctc_states_prev = ctc_prefix_score.initial_state()
            ctc_scores_prev = to_device(self, torch.zeros(batch, n_bo))

        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            vy = to_device(self, torch.LongTensor(self._get_last_yseq(yseq)))
            ey = self.dropout_emb(self.embed(vy))
            att_c, att_w = self.att[att_idx](exp_h, exp_hlens,
                                             self.dropout_dec[0](z_prev[0]),
                                             a_prev)
            ey = torch.cat((ey, att_c), dim=1)

            # attention decoder
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev,
                                              c_prev)
            local_scores = att_weight * F.log_softmax(
                self.output(self.dropout_dec[-1](z_list[-1])), dim=1)

            # rnnlm
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.buff_predict(
                    rnnlm_prev, vy, n_bb)
                local_scores = local_scores + recog_args.lm_weight * local_lm_scores
            local_scores = local_scores.view(batch, n_bo)

            # ctc
            if lpz is not None:
                ctc_scores, ctc_states = ctc_prefix_score(
                    yseq, ctc_states_prev, accum_odim_ids)
                ctc_scores = ctc_scores.view(batch, n_bo)
                local_scores = local_scores + ctc_weight * (ctc_scores -
                                                            ctc_scores_prev)
            local_scores = local_scores.view(batch, beam, self.odim)

            if i == 0:
                local_scores[:, 1:, :] = self.logzero
            local_best_scores, local_best_odims = torch.topk(
                local_scores.view(batch, beam, self.odim), beam, 2)
            # local pruning (via xp)
            local_scores = np.full((n_bbo, ), self.logzero)
            _best_odims = local_best_odims.view(n_bb, beam) + pad_o
            _best_odims = _best_odims.view(-1).cpu().numpy()
            _best_score = local_best_scores.view(-1).cpu().detach().numpy()
            local_scores[_best_odims] = _best_score
            local_scores = to_device(
                self,
                torch.from_numpy(local_scores).float()).view(
                    batch, beam, self.odim)

            # (or indexing)
            # local_scores = to_cuda(self, torch.full((batch, beam, self.odim), self.logzero))
            # _best_odims = local_best_odims
            # _best_score = local_best_scores
            # for si in six.moves.range(batch):
            # for bj in six.moves.range(beam):
            # for bk in six.moves.range(beam):
            # local_scores[si, bj, _best_odims[si, bj, bk]] = _best_score[si, bj, bk]

            eos_vscores = local_scores[:, :, self.eos] + vscores
            vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
            vscores[:, :, self.eos] = self.logzero
            vscores = (vscores + local_scores).view(batch, n_bo)

            # global pruning
            accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
            accum_odim_ids = torch.fmod(
                accum_best_ids, self.odim).view(-1).data.cpu().tolist()
            accum_padded_odim_ids = (torch.fmod(accum_best_ids, n_bo) +
                                     pad_bo).view(-1).data.cpu().tolist()
            accum_padded_beam_ids = (torch.div(accum_best_ids, self.odim) +
                                     pad_b).view(-1).data.cpu().tolist()

            y_prev = yseq[:][:]
            yseq = self._index_select_list(yseq, accum_padded_beam_ids)
            yseq = self._append_ids(yseq, accum_odim_ids)
            vscores = accum_best_scores
            vidx = to_device(self, torch.LongTensor(accum_padded_beam_ids))

            if isinstance(att_w, torch.Tensor):
                a_prev = torch.index_select(att_w.view(n_bb, *att_w.shape[1:]),
                                            0, vidx)
            elif isinstance(att_w, list):
                # handle the case of multi-head attention
                a_prev = [
                    torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                    for att_w_one in att_w
                ]
            else:
                # handle the case of location_recurrent when return is a tuple
                a_prev_ = torch.index_select(att_w[0].view(n_bb, -1), 0, vidx)
                h_prev_ = torch.index_select(att_w[1][0].view(n_bb, -1), 0,
                                             vidx)
                c_prev_ = torch.index_select(att_w[1][1].view(n_bb, -1), 0,
                                             vidx)
                a_prev = (a_prev_, (h_prev_, c_prev_))
            z_prev = [
                torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
                for li in range(self.dlayers)
            ]
            c_prev = [
                torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
                for li in range(self.dlayers)
            ]

            if rnnlm:
                rnnlm_prev = self._index_select_lm_state(rnnlm_state, 0, vidx)
            if lpz is not None:
                ctc_vidx = to_device(self,
                                     torch.LongTensor(accum_padded_odim_ids))
                ctc_scores_prev = torch.index_select(ctc_scores.view(-1), 0,
                                                     ctc_vidx)
                ctc_scores_prev = ctc_scores_prev.view(-1, 1).repeat(
                    1, self.odim).view(batch, n_bo)

                ctc_states = torch.transpose(ctc_states, 1, 3).contiguous()
                ctc_states = ctc_states.view(n_bbo, 2, -1)
                ctc_states_prev = torch.index_select(ctc_states, 0,
                                                     ctc_vidx).view(
                                                         n_bb, 2, -1)
                ctc_states_prev = torch.transpose(ctc_states_prev, 1, 2)

            # pick ended hyps
            if i > minlen:
                k = 0
                penalty_i = (i + 1) * penalty
                # thr = accum_best_scores[:, -1]
                for samp_i in six.moves.range(batch):
                    if stop_search[samp_i]:
                        k = k + beam
                        continue
                    for beam_j in six.moves.range(beam):
                        yk = y_prev[k][:]
                        yk.append(self.eos)
                        if len(yk) < hlens[samp_i]:
                            _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                            if normalize_score:
                                _vscore = _vscore / len(yk)
                            _score = _vscore.data.cpu().numpy()
                            ended_hyps[samp_i].append({
                                'yseq': yk,
                                'vscore': _vscore,
                                'score': _score
                            })
                        k = k + 1

            # end detection
            stop_search = [
                stop_search[samp_i] for samp_i in six.moves.range(batch)
            ]
            stop_search_summary = list(set(stop_search))
            if len(stop_search_summary) == 1 and stop_search_summary[0]:
                break

            torch.cuda.empty_cache()

        dummy_hyps = [{
            'yseq': [self.sos, self.eos],
            'score': np.array([-float('inf')])
        }]
        ended_hyps = [
            ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
            for samp_i in six.moves.range(batch)
        ]
        nbest_hyps = [
            sorted(
                ended_hyps[samp_i], key=lambda x: x['score'],
                reverse=True)[:min(len(ended_hyps[samp_i]), recog_args.nbest)]
            for samp_i in six.moves.range(batch)
        ]

        return nbest_hyps
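
The batch beam search above repeats one pattern: every per-hypothesis tensor (attention weights, decoder hidden/cell states, CTC scores and states) is flattened to a leading batch * beam dimension and reordered with torch.index_select along dim 0 using the ids of the surviving hypotheses. A minimal sketch of that reordering step, with hypothetical sizes:

# minimal sketch (not the API above): reorder flattened per-hypothesis states
import torch

n_bb = 6                                      # hypothetical batch * beam
hidden = torch.randn(n_bb, 4)                 # one state row per hypothesis
vidx = torch.tensor([0, 0, 2, 3, 3, 5])       # surviving hypothesis per slot
hidden = torch.index_select(hidden, 0, vidx)  # rows now follow the kept beams
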
Example #51
    def compute_edges(self, keypoints):
        start = torch.index_select(keypoints, 1, self.connections[:, 0])
        end = torch.index_select(keypoints, 1, self.connections[:, 1])
        return start - end
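
compute_edges gathers both endpoints of every skeleton connection along the keypoint dimension and subtracts them; a self-contained sketch with a hypothetical connection table:

import torch

keypoints = torch.randn(2, 5, 3)                      # (batch, num_keypoints, xyz)
connections = torch.tensor([[0, 1], [1, 2], [3, 4]])  # hypothetical edge list
start = torch.index_select(keypoints, 1, connections[:, 0])
end = torch.index_select(keypoints, 1, connections[:, 1])
edges = start - end                                   # (batch, num_edges, xyz)
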
Example #52
def subsample(x, mask):
    # sampling: keep only the rows of x at the indices given in mask
    # (index_select drops the other rows rather than zeroing them)
    x = torch.index_select(x, 0, mask.cuda())
    return x
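
subsample keeps only the rows whose indices appear in mask; torch.index_select drops the unlisted rows rather than zeroing them, which a tiny CPU-only example makes visible:

import torch

x = torch.arange(12).view(4, 3)
mask = torch.tensor([0, 2])               # indices of the rows to keep
print(torch.index_select(x, 0, mask))     # rows 0 and 2 only; rows 1 and 3 are gone
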
Example #53
    def update(self):
        # keep looping the whole video
        for i in range(self.num_batches):
            img = []
            inp = []
            orig_img = []
            im_name = []
            im_dim_list = []
            for k in range(i * self.batchSize,
                           min((i + 1) * self.batchSize, self.datalen)):
                (grabbed, frame) = self.stream.read()
                # if the `grabbed` boolean is `False`, then we have
                # reached the end of the video file
                if not grabbed:
                    self.stop()
                    return
                # process and add the frame to the queue
                inp_dim = int(opt.inp_dim)
                img_k, orig_img_k, im_dim_list_k = prep_frame(frame, inp_dim)
                inp_k = im_to_torch(orig_img_k)

                img.append(img_k)
                inp.append(inp_k)
                orig_img.append(orig_img_k)
                im_dim_list.append(im_dim_list_k)

            with torch.no_grad():
                ht = inp[0].size(1)
                wd = inp[0].size(2)
                # Human Detection
                img = Variable(torch.cat(img)).cuda()
                im_dim_list = torch.FloatTensor(im_dim_list).repeat(1, 2)
                im_dim_list = im_dim_list.cuda()

                prediction = self.det_model(img, CUDA=True)
                # NMS process
                dets = dynamic_write_results(prediction,
                                             opt.confidence,
                                             opt.num_classes,
                                             nms=True,
                                             nms_conf=opt.nms_thesh)
                if isinstance(dets, int) or dets.shape[0] == 0:
                    for k in range(len(inp)):
                        while self.Q.full():
                            time.sleep(0.2)
                        self.Q.put((inp[k], orig_img[k], None, None))
                    continue

                im_dim_list = torch.index_select(im_dim_list, 0,
                                                 dets[:, 0].long())
                scaling_factor = torch.min(self.det_inp_dim / im_dim_list,
                                           1)[0].view(-1, 1)

                # coordinate transfer
                dets[:, [1, 3]] -= (self.det_inp_dim - scaling_factor *
                                    im_dim_list[:, 0].view(-1, 1)) / 2
                dets[:, [2, 4]] -= (self.det_inp_dim - scaling_factor *
                                    im_dim_list[:, 1].view(-1, 1)) / 2

                dets[:, 1:5] /= scaling_factor
                for j in range(dets.shape[0]):
                    dets[j, [1, 3]] = torch.clamp(dets[j, [1, 3]], 0.0,
                                                  im_dim_list[j, 0])
                    dets[j, [2, 4]] = torch.clamp(dets[j, [2, 4]], 0.0,
                                                  im_dim_list[j, 1])
                boxes = dets[:, 1:5].cpu()
                scores = dets[:, 5:6].cpu()

            for k in range(len(inp)):
                while self.Q.full():
                    time.sleep(0.2)
                self.Q.put((inp[k], orig_img[k], boxes[dets[:, 0] == k],
                            scores[dets[:, 0] == k]))
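
In the detection loop above, dets[:, 0] holds the index of the source image for each detection, so index_select repeats each image's (w, h, w, h) row once per detection before the boxes are rescaled; a sketch of that step with made-up numbers:

import torch

im_dim_list = torch.tensor([[640., 480., 640., 480.],
                            [320., 240., 320., 240.]])  # one row per image
det_img_ids = torch.tensor([0, 0, 1])                   # image id of each detection
per_det_dims = torch.index_select(im_dim_list, 0, det_img_ids)
# per_det_dims now has one row per detection, ready for per-box rescaling
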
Example #54
    def _decode_target(
        self,
        encoder_input,
        encoder_outs,
        incremental_states,
        diversity_sibling_gamma=0.0,
        beam_size=None,
        maxlen=None,
        prefix_tokens=None,
    ):
        src_tokens_tensor = pytorch_translate_utils.get_source_tokens_tensor(
            encoder_input["src_tokens"])
        beam_size = beam_size if beam_size is not None else self.beam_size
        bsz = src_tokens_tensor.size(0)
        reorder_indices = (torch.arange(bsz).view(-1, 1).repeat(
            1, beam_size).view(-1).long())
        for i, model in enumerate(self.models):
            encoder_outs[i] = model.encoder.reorder_encoder_out(
                encoder_out=encoder_outs[i],
                new_order=reorder_indices.type_as(src_tokens_tensor),
            )
        maxlen = min(maxlen,
                     self.maxlen) if maxlen is not None else self.maxlen
        # initialize buffers
        scores = src_tokens_tensor.new(bsz * beam_size,
                                       maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens_tensor.new(bsz * beam_size,
                                       maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos

        # may differ from input length
        if isinstance(encoder_outs[0], (list, tuple)):
            src_encoding_len = encoder_outs[0][0].size(0)
        elif isinstance(encoder_outs[0], dict):
            if isinstance(encoder_outs[0]["encoder_out"], tuple):
                # Fairseq compatibility
                src_encoding_len = encoder_outs[0]["encoder_out"][0].size(1)
            else:
                src_encoding_len = encoder_outs[0]["encoder_out"].size(0)

        attn = scores.new(bsz * beam_size, src_encoding_len, maxlen + 2)
        attn_buf = attn.clone()

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{
            "idx": None,
            "score": -math.inf
        } for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) *
                        beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        # init constraints
        constraints = self._build_constraints(src_tokens_tensor, beam_size)

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= (maxlen + 1)**self.len_penalty
                if worst_finalized[sent]["score"] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step,
                           bbsz_idx,
                           eos_scores,
                           unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step +
                                        2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2]

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1)**self.len_penalty

            sents_seen = set()
            for i, (idx, score) in enumerate(
                    zip(bbsz_idx.tolist(), eos_scores.tolist())):
                sent = idx // beam_size
                sents_seen.add(sent)

                def get_hypo():
                    _, alignment = attn_clone[i].max(dim=0)
                    return {
                        "tokens": tokens_clone[i],
                        "score": score,
                        "attention": attn_clone[i],  # src_len x tgt_len
                        "alignment": alignment,
                        "positional_scores": pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent][
                        "score"]:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]["idx"]
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]),
                                 key=lambda r: r[1]["score"])
                    worst_finalized[sent] = {"score": s["score"], "idx": idx}

            # return number of hypotheses finished this step
            num_finished = 0
            for sent in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step,
                                                      unfinalized_scores):
                    finished[sent] = True
                    num_finished += 1
            return num_finished

        reorder_state = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                for model in self.models:
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(
                            incremental_states[model], reorder_state)
            # Run decoder for one step
            logprobs, avg_attn, possible_translation_tokens = self._decode(
                tokens[:, :step + 1], encoder_outs, incremental_states)

            logprobs[:, self.pad] = -math.inf  # never select pad
            # apply unk reward
            if possible_translation_tokens is None:
                # No vocab reduction, so unk is represented by self.unk at
                # position self.unk
                unk_index = self.unk
                logprobs[:, unk_index] += self.unk_reward
            else:
                # When we use vocab reduction, the token value self.unk may not
                # be at the position self.unk, but somewhere else in the list
                # of possible_translation_tokens. It's also possible not to
                # show up in possible_translation_tokens at all, meaning we
                # can't generate an unk.
                unk_pos = torch.nonzero(
                    possible_translation_tokens == self.unk)
                if unk_pos.size()[0] != 0:
                    # only add unk_reward if unk index appears in
                    # possible_translation_tokens
                    unk_index = unk_pos[0][0]
                    logprobs[:, unk_index] += self.unk_reward
            # external lexicon reward
            logprobs[:, self.lexicon_indices] += self.lexicon_reward

            logprobs += self.word_reward
            logprobs[:, self.eos] -= self.word_reward
            # Record attention scores
            attn[:, :, step + 1].copy_(avg_attn)

            cand_scores = buffer("cand_scores", type_of=scores)
            cand_indices = buffer("cand_indices")
            cand_beams = buffer("cand_beams")
            eos_bbsz_idx = buffer("eos_bbsz_idx")
            eos_scores = buffer("eos_scores", type_of=scores)
            scores = scores.type_as(logprobs)
            scores_buf = scores_buf.type_as(logprobs)

            if step < maxlen:
                self._apply_constraint_penalty(scores)  # stub call
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    logprobs_slice = logprobs.view(bsz, -1,
                                                   logprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        logprobs_slice,
                        dim=1,
                        index=prefix_tokens[:, step].view(-1, 1)).expand(
                            -1, cand_size)
                    cand_indices = (prefix_tokens[:, step].view(-1, 1).expand(
                        bsz, cand_size))
                    cand_beams.resize_as_(cand_indices).fill_(0)
                else:
                    possible_tokens_size = self.vocab_size
                    if possible_translation_tokens is not None:
                        possible_tokens_size = possible_translation_tokens.size(
                            0)
                    if diversity_sibling_gamma > 0:
                        logprobs = self.diversity_sibling_rank(
                            logprobs.view(bsz, -1, possible_tokens_size),
                            diversity_sibling_gamma,
                        )
                    cand_scores, cand_indices, cand_beams = self.search.step(
                        step,
                        logprobs.view(bsz, -1, possible_tokens_size),
                        scores.view(bsz, beam_size, -1)[:, :, :step],
                    )
                    # vocabulary reduction
                    if possible_translation_tokens is not None:
                        possible_translation_tokens = possible_translation_tokens.view(
                            1, possible_tokens_size).expand(
                                cand_indices.size(0), possible_tokens_size)
                        cand_indices = torch.gather(
                            possible_translation_tokens,
                            dim=1,
                            index=cand_indices,
                            out=cand_indices,
                        )
            else:
                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest log prob of EOS right now
                logprobs.add_(scores[:, step - 1].view(-1, 1))
                torch.sort(
                    logprobs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx,
                                                     eos_scores)
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add_(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    self._apply_eos_constraints(constraints, eos_bbsz_idx,
                                                eos_scores)
                    num_remaining_sent -= finalize_hypos(
                        step, eos_bbsz_idx, eos_scores, cand_scores)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer("active_mask")
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer("active_hypos"), buffer("_ignore")
            torch.topk(
                active_mask,
                k=beam_size,
                dim=1,
                largest=False,
                out=(_ignore, active_hypos),
            )
            active_bbsz_idx = buffer("active_bbsz_idx")
            torch.gather(cand_bbsz_idx,
                         dim=1,
                         index=active_hypos,
                         out=active_bbsz_idx)
            active_scores = torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )
            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            # update constraints for next step
            constraints = self._reorder_constraints(constraints,
                                                    active_bbsz_idx)
            self._update_constraints(constraints, tokens_buf[:, step + 1],
                                     step)
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            torch.index_select(
                attn[:, :, :step + 2],
                dim=0,
                index=active_bbsz_idx,
                out=attn_buf[:, :, :step + 2],
            )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(bsz):
            finalized[sent] = sorted(finalized[sent],
                                     key=lambda r: r["score"],
                                     reverse=True)
        self._finalize_constrained_results(finalized, scores.device)
        return finalized
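
_decode_target avoids reallocating tensors each step by writing index_select results into preallocated buffers through the out= argument and then swapping buffer and target; a stripped-down version of that step with made-up shapes:

import torch

tokens = torch.randint(0, 10, (4, 6))
tokens_buf = tokens.clone()
active_bbsz_idx = torch.tensor([2, 0, 3, 1])            # reordered hypothesis ids
torch.index_select(tokens, 0, active_bbsz_idx, out=tokens_buf)
tokens, tokens_buf = tokens_buf, tokens                 # swap instead of copying back
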
Example #55
    def forward(self, x):
        """ Only the needed indices are kept. """
        self.keep_tensor = self.keep_tensor.to("cuda" if x.is_cuda else "cpu")
        y = torch.index_select(x, 1, self.keep_tensor)
        return y
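
This forward prunes activations down to a fixed subset of channels; the effect of indexing dim 1 with a kept-channel tensor can be checked in isolation (the keep indices here are hypothetical):

import torch

x = torch.randn(2, 6, 8, 8)            # (batch, channels, height, width)
keep = torch.tensor([0, 3, 5])         # hypothetical channel indices to keep
y = torch.index_select(x, 1, keep)     # -> (2, 3, 8, 8)
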
Example #56
def compute_scores_from_embeds(av,all_embeds,query_nodes,list_test_edges,list_test_non_edges):
  cos = nn.CosineSimilarity(dim=1, eps=1e-6)
  # per-query-node metrics: AP and reciprocal rank computed for each query node
  #all_qnode_auc = [] 
  all_qnode_ap = []
  all_qnode_rr = []
  #all_qnode_ndcg = []
  for qnode in query_nodes : 
    qnode_edges = list(filter(lambda x: x[0]==qnode or x[1]==qnode, list_test_edges))
    qnode_non_edges = list(filter(lambda x: x[0]==qnode or x[1]==qnode, list_test_non_edges))
    if len(qnode_edges)==0 or len(qnode_non_edges)==0: 
      continue
    a,b = zip(*qnode_edges)
    self_tensors = torch.index_select(all_embeds,dim=0,index=cudavar(av,torch.tensor(a)))
    nbr_tensors  = torch.index_select(all_embeds,dim=0,index=cudavar(av,torch.tensor(b)))
    pos_scores   = cos(self_tensors,nbr_tensors)

    a,b = zip(*qnode_non_edges)
    self_tensors = torch.index_select(all_embeds,dim=0,index=cudavar(av,torch.tensor(a)))
    nbr_tensors  = torch.index_select(all_embeds,dim=0,index=cudavar(av,torch.tensor(b)))
    neg_scores   = cos(self_tensors,nbr_tensors)

    if av.has_cuda and av.want_cuda:
      all_scores = torch.cat((pos_scores,neg_scores)).cpu().numpy()
    else:
      all_scores = torch.cat((pos_scores,neg_scores)).numpy()

    all_labels = np.hstack([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
    auc_score  = roc_auc_score(all_labels, all_scores)
    ap_score   = average_precision_score(all_labels, all_scores)
    #ndcg       = ndcg_score([all_labels],[all_scores])

    so = np.argsort(all_scores)[::-1]
    labels_rearranged = all_labels[so]
    rr_score = 1/(labels_rearranged.tolist().index(1)+1)
    
    #all_qnode_auc.append(auc_score)
    all_qnode_ap.append(ap_score)
    all_qnode_rr.append(rr_score)
    #all_qnode_ndcg.append(ndcg)
  # aggregate metrics over all test edges and non-edges
  pos_scores = []
  neg_scores = []

  a,b = zip(*list_test_edges)
  self_tensors = torch.index_select(all_embeds,dim=0,index=cudavar(av,torch.tensor(a)))
  nbr_tensors  = torch.index_select(all_embeds,dim=0,index=cudavar(av,torch.tensor(b)))
  pos_scores   = cos(self_tensors,nbr_tensors)

  a,b = zip(*list_test_non_edges)
  self_tensors = torch.index_select(all_embeds,dim=0,index=cudavar(av,torch.tensor(a)))
  nbr_tensors  = torch.index_select(all_embeds,dim=0,index=cudavar(av,torch.tensor(b)))
  neg_scores   = cos(self_tensors,nbr_tensors)

  if av.has_cuda and av.want_cuda:
    all_scores = torch.cat((pos_scores,neg_scores)).cpu().numpy()
  else:
    all_scores = torch.cat((pos_scores,neg_scores)).numpy()

  all_labels = np.hstack([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
  auc_score  = roc_auc_score(all_labels, all_scores)
  ap_score   = average_precision_score(all_labels, all_scores)
  #ndcg       = ndcg_score([all_labels],[all_scores])
  
  #so = np.argsort(all_scores)[::-1]
  #labels_rearranged = all_labels[so]
  #rr_score = 1/(labels_rearranged.tolist().index(1)+1)

  return auc_score, ap_score, np.mean(all_qnode_ap), np.mean(all_qnode_rr)
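
compute_scores_from_embeds gathers the two endpoint embeddings of every test edge with index_select along dim 0 and scores each pair with cosine similarity; a self-contained sketch with a hypothetical edge list:

import torch
import torch.nn as nn

all_embeds = torch.randn(6, 16)        # one embedding per node
edges = [(0, 1), (2, 3), (4, 5)]       # hypothetical test edges
a, b = zip(*edges)
src = torch.index_select(all_embeds, 0, torch.tensor(a))
dst = torch.index_select(all_embeds, 0, torch.tensor(b))
scores = nn.CosineSimilarity(dim=1, eps=1e-6)(src, dst)  # one score per edge
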
Example #57
    def decode_beam_search(self,
                           word_seqs,
                           lengths,
                           beam_size,
                           tag2idx,
                           extFeats=None,
                           with_snt_classifier=False,
                           masked_output=None):
        minibatch_size = len(
            lengths
        )  #word_seqs.size(0) if self.encoder.batch_first else word_seqs.size(1)
        max_length = max(
            lengths
        )  #word_seqs.size(1) if self.encoder.batch_first else word_seqs.size(0)
        # encoder
        embeds = self.get_token_embeddings(word_seqs, lengths)
        if type(extFeats) != type(None):
            concat_input = torch.cat((embeds, self.extFeats_linear(extFeats)),
                                     2)
        else:
            concat_input = embeds
        concat_input = self.dropout_layer(concat_input)
        packed_word_embeds = rnn_utils.pack_padded_sequence(concat_input,
                                                            lengths,
                                                            batch_first=True)
        packed_word_lstm_out, (enc_h_t, enc_c_t) = self.encoder(
            packed_word_embeds)  # bsize x seqlen x dim
        enc_word_lstm_out, unpacked_len = rnn_utils.pad_packed_sequence(
            packed_word_lstm_out, batch_first=True)

        # decoder
        if self.bidirectional:
            index_slices = [2 * i + 1 for i in range(self.num_layers)
                            ]  # generated from the reversed path
            index_slices = torch.tensor(index_slices,
                                        dtype=torch.long,
                                        device=self.device)
            h_t = torch.index_select(enc_h_t, 0, index_slices)
            c_t = torch.index_select(enc_c_t, 0, index_slices)
        else:
            h_t = enc_h_t
            c_t = enc_c_t

        h_t = h_t.repeat(1, beam_size, 1)
        c_t = c_t.repeat(1, beam_size, 1)
        word_lstm_out = enc_word_lstm_out.repeat(beam_size, 1, 1)

        beam = [
            Beam(beam_size, tag2idx, device=self.device)
            for k in range(minibatch_size)
        ]
        batch_idx = list(range(minibatch_size))
        remaining_sents = minibatch_size

        top_dec_h_t, top_dec_c_t = [0] * minibatch_size, [0] * minibatch_size
        for i in range(max_length):
            last_tags = torch.stack([
                b.get_current_state() for b in beam if not b.done
            ]).t().contiguous().view(-1,
                                     1)  # after t() -> beam_size * batch_size
            last_tags = last_tags.to(self.device)
            tag_embeds = self.dropout_layer(self.tag_embeddings(last_tags))
            decode_inputs = torch.cat(
                (self.dropout_layer(word_lstm_out[:, i:i + 1]), tag_embeds),
                2)  # (batch*beam) x 1 x insize
            tag_lstm_out, (dec_h_t, dec_c_t) = self.decoder(
                decode_inputs,
                (h_t,
                 c_t))  # (batch*beam) x 1 x insize => (batch*beam) x 1 x hsize

            tag_lstm_out_reshape = tag_lstm_out.contiguous().view(
                tag_lstm_out.size(0) * tag_lstm_out.size(1),
                tag_lstm_out.size(2))
            tag_space = self.hidden2tag(
                self.dropout_layer(tag_lstm_out_reshape))
            if masked_output is None:
                out = F.log_softmax(tag_space, dim=1)  # (batch*beam) x outsize
            else:
                out = masked_function.index_masked_log_softmax(tag_space,
                                                               masked_output,
                                                               dim=1)

            word_lk = out.view(beam_size, remaining_sents,
                               -1).transpose(0, 1).contiguous()

            active = []
            for b in range(minibatch_size):
                if beam[b].done:
                    continue
                if lengths[b] == i + 1:
                    beam[b].done = True
                    top_dec_h_t[b] = dec_h_t[:, b:b + beam_size, :]
                    top_dec_c_t[b] = dec_c_t[:, b:b + beam_size, :]
                idx = batch_idx[b]
                beam[b].advance(word_lk.data[idx])
                if not beam[b].done:
                    active.append(b)
                for dec_state in (dec_h_t, dec_c_t):
                    # (layer*direction) x beam*sent x Hdim
                    sent_states = dec_state.view(-1, beam_size,
                                                 remaining_sents,
                                                 dec_state.size(2))[:, :, idx]
                    sent_states.data.copy_(
                        sent_states.data.index_select(
                            1, beam[b].get_current_origin()))
            if not active:
                break

            active_idx = torch.tensor([batch_idx[k] for k in active],
                                      dtype=torch.long,
                                      device=self.device)
            batch_idx = {beam: idx for idx, beam in enumerate(active)}

            def update_active(t, hidden_dim):
                #t_reshape = t.data.view(-1, remaining_sents, hidden_dim)
                t_reshape = t.contiguous().view(-1, remaining_sents,
                                                hidden_dim)
                new_size = list(t.size())
                new_size[-2] = new_size[-2] * len(
                    active_idx) // remaining_sents  # beam*len(active_idx)
                return t_reshape.index_select(1, active_idx).view(*new_size)

            h_t = update_active(dec_h_t, self.hidden_dim)
            c_t = update_active(dec_c_t, self.hidden_dim)
            word_lstm_out = update_active(
                word_lstm_out.transpose(0, 1),
                self.num_directions * self.hidden_dim).transpose(0, 1)

            remaining_sents = len(active)

        allHyp, allScores = [], []
        n_best = 1
        for b in range(minibatch_size):
            scores, ks = beam[b].sort_best()
            allScores += [scores[:n_best]]
            hyps = zip(*[beam[b].get_hyp(k) for k in ks[:n_best]])
            allHyp += [hyps]
            top_dec_h_t[b] = top_dec_h_t[b].data.index_select(1, ks[:n_best])
            top_dec_c_t[b] = top_dec_c_t[b].data.index_select(1, ks[:n_best])
        top_dec_h_t = torch.cat(top_dec_h_t, 1)
        top_dec_c_t = torch.cat(top_dec_c_t, 1)
        allScores = torch.cat(allScores)

        if with_snt_classifier:
            return allScores, allHyp, ((enc_h_t, enc_c_t), enc_word_lstm_out,
                                       lengths)
        else:
            return allScores, allHyp
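
When the encoder is bidirectional, the decoder above is initialized from the reversed-direction layers only, selected with index_select on dim 0 of the (num_layers * num_directions, batch, hidden) state; a small sketch with assumed sizes:

import torch

num_layers, batch, hidden = 2, 3, 8
enc_h_t = torch.randn(num_layers * 2, batch, hidden)  # fwd/bwd interleaved per layer
bwd_slices = torch.tensor([2 * i + 1 for i in range(num_layers)])
h_t = torch.index_select(enc_h_t, 0, bwd_slices)      # (num_layers, batch, hidden)
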
Example #58
    def run(self,
            inp_transitions,
            run_internal_parser=False,
            use_internal_parser=False,
            validate_transitions=True):
        transition_loss = None
        transition_acc = 0.0
        num_transitions = inp_transitions.shape[1]
        batch_size = inp_transitions.shape[0]
        invalid_count = np.zeros(batch_size)
        # Transition Loop
        # ===============
        attended = [[] for i in range(batch_size)]
        for t_step in range(num_transitions):
            transitions = inp_transitions[:, t_step]
            transition_arr = list(transitions)

            # A mask based on SKIP transitions.
            cant_skip = np.array(transitions) != T_SKIP
            must_skip = np.array(transitions) == T_SKIP

            # Memories
            # ========
            # Keep track of key values to determine accuracy and loss.
            self.memory = {}

            # Prepare tracker input.
            if self.debug and any(
                    len(buf) < 1 or len(stack)
                    for buf, stack in zip(self.bufs, self.stacks)):
                # To elaborate on this exception, when cropping examples it is possible
                # that your first 1 or 2 actions is a reduce action. It is unclear if this
                # is a bug in cropping or a bug in how we think about cropping. In the meantime,
                # turn on the truncate batch flag, and set the eval_seq_length
                # very high.
                raise IndexError(
                    "Warning: You are probably trying to encode examples "
                    "with cropped transitions. Although this is a reasonable "
                    "feature, when predicting/validating transitions, you "
                    "probably will not get the behavior that you expect. "
                    "Disable this exception if you dare.")
            self.memory['top_buf'] = self.wrap_items(
                [buf[-1] if len(buf) > 0 else self.zeros for buf in self.bufs])
            self.memory['top_stack_1'] = self.wrap_items([
                stack[-1] if len(stack) > 0 else self.zeros
                for stack in self.stacks
            ])
            self.memory['top_stack_2'] = self.wrap_items([
                stack[-2] if len(stack) > 1 else self.zeros
                for stack in self.stacks
            ])

            # Run if:
            # A. We have a tracking component and,
            # B. There is at least one transition that will not be skipped.
            if hasattr(self, 'tracker') and sum(cant_skip) > 0:

                # Get hidden output from the tracker. Used to predict
                # transitions.
                tracker_h, tracker_c = self.tracker(
                    self.extract_h(self.memory['top_buf']),
                    self.extract_h(self.memory['top_stack_1']),
                    self.extract_h(self.memory['top_stack_2']))

                if hasattr(self, 'transition_net'):
                    transition_inp = [tracker_h]
                    if self.tracker.lateral_tracking and self.predict_use_cell:
                        transition_inp += [tracker_c]
                    transition_inp = torch.cat(transition_inp, 1)

                    transition_output = self.transition_net(transition_inp)

                if hasattr(self, 'transition_net') and run_internal_parser:

                    # Predict Actions
                    # ===============

                    # TODO: Mask before predicting. This should simplify things and reduce computation.
                    # The downside is that in the Action Phase, need to be smarter about which stacks/bufs
                    # are selected.
                    transition_logdist, transition_preds = self.predict_actions(
                        transition_output)

                    # Distribution of transitions use to calculate transition
                    # loss.
                    self.memory["t_logprobs"] = transition_logdist

                    # Given transitions.
                    self.memory["t_given"] = transitions

                    # Constrain to valid actions
                    # ==========================

                    validated_preds, invalid_mask = self.validate(
                        transition_arr, transition_preds, self.stacks,
                        self.bufs)
                    if validate_transitions:
                        transition_preds = validated_preds

                    # Keep track of which predictions have been valid.
                    self.memory["t_valid_mask"] = np.logical_not(invalid_mask)
                    invalid_count += invalid_mask

                    # If the given action is skip, then must skip.
                    transition_preds[must_skip] = T_SKIP

                    # Actual transition predictions. Used to measure transition
                    # accuracy.
                    self.memory["t_preds"] = transition_preds

                    # Binary mask of examples that have a transition.
                    self.memory["t_mask"] = cant_skip

                    # If this FLAG is set, then use the predicted actions
                    # rather than the given.
                    if use_internal_parser:
                        transition_arr = transition_preds.tolist()

            # Pre-Action Phase
            # ================

            # TODO: See if PyTorch's 'Advanced Indexing for Tensors and
            # Variables' features would simplify this.

            # For SHIFT
            s_stacks, s_tops, s_trackings, s_idxs, r_idxs = [], [], [], [], []

            # For REDUCE
            r_stacks, r_lefts, r_rights, r_trackings = [], [], [], []

            batch = list(
                zip(
                    transition_arr, self.bufs, self.stacks, self.tracker.states
                    if hasattr(self, 'tracker') and self.tracker.h is not None
                    else itertools.repeat(None)))
            reduced_idxs = []

            for batch_idx, (transition, buf, stack,
                            tracking) in enumerate(batch):
                if transition == T_SHIFT:  # shift
                    #attended[batch_idx].append(buf[-1][0].unsqueeze(0))
                    self.t_shift(buf, stack, tracking, s_tops, s_trackings)
                    s_idxs.append(batch_idx)
                    s_stacks.append(stack)
                elif transition == T_REDUCE:  # reduce
                    self.t_reduce(buf, stack, tracking, r_lefts, r_rights,
                                  r_trackings)
                    reduced_idxs.append(batch_idx)
                    attended[batch_idx].append(r_lefts[-1][0].unsqueeze(0))
                    attended[batch_idx].append(r_rights[-1][0].unsqueeze(0))
                    #self.attended[batch_idx].append(buf[-1])
                    r_stacks.append(stack)
                    r_idxs.append(batch_idx)
                elif transition == T_SKIP:  # skip
                    self.t_skip()

            # Action Phase
            # ============
            self.shift_phase(s_tops, s_trackings, s_stacks)
            self.reduce_phase(r_lefts, r_rights, r_trackings, r_stacks)
            self.reduce_phase_hook(r_lefts, r_rights, r_trackings, r_stacks)
            # Memory Phase
            # ============

            # APPEND ALL MEMORIES. MASK LATER.

            self.memories.append(self.memory)

            # Update number of reduces seen so far.
            self.n_reduces += (np.array(transition_arr) == T_REDUCE)

            # Update number of non-skip actions seen so far.
            self.n_steps += (np.array(transition_arr) != T_SKIP)

        # Loss Phase
        # ==========
        if hasattr(self, 'tracker') and hasattr(self, 'transition_net'):
            t_preds = np.concatenate(
                [m['t_preds'] for m in self.memories if 't_preds' in m])
            t_given = np.concatenate(
                [m['t_given'] for m in self.memories if 't_given' in m])
            t_mask = np.concatenate(
                [m['t_mask'] for m in self.memories if 't_mask' in m])
            t_logprobs = torch.cat(
                [m['t_logprobs'] for m in self.memories if 't_logprobs' in m],
                0)

            # We compute accuracy and loss after all transitions have complete,
            # since examples can have different lengths when not using skips.

            # Transition Accuracy.
            n = t_mask.shape[0]
            n_skips = n - t_mask.sum()
            n_total = n - n_skips
            n_correct = (t_preds == t_given).sum() - n_skips
            transition_acc = n_correct / float(n_total)

            # Transition Loss.
            index = to_gpu(
                Variable(torch.from_numpy(np.arange(
                    t_mask.shape[0])[t_mask])).long())
            select_t_given = to_gpu(
                Variable(torch.from_numpy(t_given[t_mask])).long())
            select_t_logprobs = torch.index_select(t_logprobs, 0, index)
            transition_loss = nn.NLLLoss()(select_t_logprobs, select_t_given) * \
                self.transition_weight

            self.n_invalid = (invalid_count > 0).sum()
            self.invalid = self.n_invalid / float(batch_size)

        self.loss_phase_hook()

        if self.debug:
            assert all(len(stack) == 3 for stack in self.stacks), \
                "Stacks should be fully reduced and have 3 elements: " \
                "two zeros and the sentence encoding."
            assert all(len(buf) == 1 for buf in self.bufs), \
                "Buffers should be fully shifted and have 1 zero."

        for i in range(batch_size):
            attended[i].append(self.stacks[i][-1][0].unsqueeze(0))

        return [stack[-1] for stack in self.stacks
                ], transition_acc, transition_loss, attended
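
The transition loss above keeps only the non-skip steps by turning a boolean mask into integer indices and selecting those rows of the log-probabilities; a minimal CPU-only sketch of that masking (the values are made up):

import numpy as np
import torch
import torch.nn as nn

t_logprobs = torch.log_softmax(torch.randn(5, 3), dim=1)  # per-step action log-probs
t_given = np.array([0, 2, 1, 0, 2])                       # given transitions
t_mask = np.array([True, False, True, True, False])       # non-skip positions

index = torch.from_numpy(np.arange(t_mask.shape[0])[t_mask]).long()
select_t_logprobs = torch.index_select(t_logprobs, 0, index)
select_t_given = torch.from_numpy(t_given[t_mask]).long()
loss = nn.NLLLoss()(select_t_logprobs, select_t_given)
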
Example #59
def ed_train(opt, model, train_data, optimizer, criterion, clip):
    if opt.dataset_name == 'r252':
        ret_INPUT_DIM = r252_src_vocab_size
        ret_OUTPUT_DIM = r252_trg_vocab_size
        ed_input_dim = r252_src_vocab_size
        ed_output_dim = r252_trg_vocab_size
        max_comment_len = r252_max_comment_len
        max_code_len = r252_max_code_len
        src_vocab_size = r252_src_vocab_size
        trg_vocab_size = r252_trg_vocab_size
    if opt.dataset_name == 'hstone':
        ret_INPUT_DIM = hstone_src_vocab_size
        ret_OUTPUT_DIM = hstone_trg_vocab_size
        ed_input_dim = hstone_src_vocab_size
        ed_output_dim = hstone_trg_vocab_size
        max_comment_len = hstone_max_comment_len
        max_code_len = hstone_max_code_len
        src_vocab_size = hstone_src_vocab_size
        trg_vocab_size = hstone_trg_vocab_size
    model.train()
    epoch_loss = 0

    batch_num = 0

    if torch.cuda.is_available():
        model.cuda()
        train_data = train_data.cuda()

    for j in range(
            0, train_data.shape[0],
            batch_size):  #TODO: replace batch_size with train_data.shape[0]
        if j + batch_size < train_data.shape[0]:
            batch_num += 1
            interval = [
                x for x in range(j, min(train_data.shape[0], j + batch_size))
            ]
            interval = torch.LongTensor(interval)
            if torch.cuda.is_available():
                interval = interval.cuda()
            batch = Variable(index_select(train_data, 0, interval))

            x = batch[:, :max_comment_len]
            xprime = batch[:, max_comment_len +
                           max_code_len:max_comment_len * 2 + max_code_len]
            yprime = batch[:, max_comment_len * 2 + max_code_len:]

            trg = batch[:, max_comment_len:max_comment_len + max_code_len]  # y

            optimizer.zero_grad()

            x = torch.transpose(x, 0, 1)
            xprime = torch.transpose(xprime, 0, 1)
            yprime = torch.transpose(yprime, 0, 1)
            trg = torch.transpose(trg, 0, 1)

            output = model(x, xprime, yprime, trg)
            # output shape is code_len, batch, trg_vocab_size

            output = torch.reshape(output,
                                   (batch_size * max_code_len, trg_vocab_size))
            trg = torch.reshape(trg, (batch_size * max_code_len, ))

            loss = criterion(output, trg)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            epoch_loss += loss.item()
            print("Ed batch: {0:3d} | Loss: {1:.3f}".format(
                batch_num, loss.item()))
    return epoch_loss / batch_num, output
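
Both ed_train and ed_evaluate build each minibatch by materializing a list of row indices and passing it to index_select, which picks the same rows as a plain slice would; a compact sketch with hypothetical sizes:

import torch

data = torch.randn(10, 7)
batch_size, j = 4, 0
interval = torch.arange(j, min(data.shape[0], j + batch_size))
batch = torch.index_select(data, 0, interval)  # same rows as data[j:j + batch_size]
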
Example #60
def ed_evaluate(model, valid_data, criterion):
    if opt.dataset_name == 'r252':
        ret_INPUT_DIM = r252_src_vocab_size
        ret_OUTPUT_DIM = r252_trg_vocab_size
        ed_input_dim = r252_src_vocab_size
        ed_output_dim = r252_trg_vocab_size
        max_comment_len = r252_max_comment_len
        max_code_len = r252_max_code_len
        src_vocab_size = r252_src_vocab_size
        trg_vocab_size = r252_trg_vocab_size
    if opt.dataset_name == 'hstone':
        ret_INPUT_DIM = hstone_src_vocab_size
        ret_OUTPUT_DIM = hstone_trg_vocab_size
        ed_input_dim = hstone_src_vocab_size
        ed_output_dim = hstone_trg_vocab_size
        max_comment_len = hstone_max_comment_len
        max_code_len = hstone_max_code_len
        src_vocab_size = hstone_src_vocab_size
        trg_vocab_size = hstone_trg_vocab_size
    model.eval()
    epoch_loss = 0

    if torch.cuda.is_available():
        model.cuda()
        valid_data = valid_data.cuda()

    ref_code = valid_data[:, max_comment_len:max_comment_len + max_code_len]
    candidate_code = torch.zeros_like(ref_code)

    with torch.no_grad():
        batch_num = 0
        for j in range(0, valid_data.shape[0], batch_size):
            if j + batch_size < valid_data.shape[0]:
                batch_num += 1
                interval = [
                    x
                    for x in range(j, min(valid_data.shape[0], j + batch_size))
                ]
                interval = torch.LongTensor(interval)
                if torch.cuda.is_available():
                    interval = interval.cuda()
                batch = Variable(index_select(valid_data, 0, interval))

                x = batch[:, :max_comment_len]
                xprime = batch[:, max_comment_len +
                               max_code_len:max_comment_len * 2 + max_code_len]
                yprime = batch[:, max_comment_len * 2 + max_code_len:]

                trg = batch[:, max_comment_len:max_comment_len +
                            max_code_len]  # y

                x = torch.transpose(x, 0, 1)
                xprime = torch.transpose(xprime, 0, 1)
                yprime = torch.transpose(yprime, 0, 1)
                trg = torch.transpose(trg, 0, 1)

                output = model(x, xprime, yprime, trg)
                # output shape is code_len, batch, trg_vocab_size

                #trg = [trg sent len, batch size]
                #output = [trg sent len, batch size, output dim]

                for cpj in range(batch_size):
                    # output is (code_len, batch, vocab): take sample cpj's argmax
                    # over the vocabulary to get its predicted token sequence
                    candidate_code[j + cpj] = torch.argmax(output[:, cpj, :], dim=1)

                output = torch.reshape(
                    output, (batch_size * max_code_len, trg_vocab_size))
                trg = torch.reshape(trg, (batch_size * max_code_len, ))

                loss = criterion(output, trg)
                epoch_loss += loss.item()
                print("Ed batch: {0:3d} | Loss: {1:.3f}".format(
                    batch_num, loss.item()))
    return epoch_loss / batch_num, candidate_code