Example no. 1
    def postprocess_sequence(self, X):
        """Embed (variable-length) sequences

        Parameters
        ----------
        X : list
            List of input sequences

        Returns
        -------
        fX : numpy array
            Batch of sequence embeddings.
        """

        lengths = torch.tensor([len(x) for x in X])
        sorted_lengths, sort = torch.sort(lengths, descending=True)
        _, unsort = torch.sort(sort)

        sequences = [torch.tensor(X[i],
                                  dtype=torch.float32,
                                  device=self.device) for i in sort]
        padded = pad_sequence(sequences, batch_first=True, padding_value=0)
        packed = pack_padded_sequence(padded, sorted_lengths,
                                      batch_first=True)

        cpu = torch.device('cpu')
        fX = self.model(packed).detach().to(cpu).numpy()
        return fX[unsort]
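
A minimal, self-contained sketch of the sort/pack/unsort idiom the method above relies on; the GRU here is a hypothetical stand-in for self.model, and shapes are made up for illustration:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

X = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(7, 8)]   # variable-length sequences
lengths = torch.tensor([len(x) for x in X])
sorted_lengths, sort = torch.sort(lengths, descending=True)
_, unsort = torch.sort(sort)                                    # inverse permutation

padded = pad_sequence([X[i] for i in sort], batch_first=True)
packed = pack_padded_sequence(padded, sorted_lengths, batch_first=True)

gru = nn.GRU(input_size=8, hidden_size=4, batch_first=True)     # stand-in for self.model
_, h = gru(packed)                                              # h: (num_layers, batch, hidden)
embeddings = h[-1][unsort]                                      # restore the original batch order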
Example no. 2
def test_BiLSTM(x, x_mask):
    # batch_size = 2, seq_length = 6, hidden = 3
    lstm = nn.LSTM(input_size=4, hidden_size=2, num_layers=1, bidirectional=True)

    # sequence_output = np.array([[[0.1, 0.2, 0.1],
    #                              [0.4, 0.5, 0.6],
    #                              [0.1, 0.1, 0.1],
    #                              [0.2, 0.2, 0.2],
    #                              [0.3, 0.1, 0.2],
    #                              [0.4, 0.1, 0.3]], [[0, 0.1, 0.1],
    #                                                 [0.2, 0.3, 0.6],
    #                                                 [0.3, 0.2, 0.1],
    #                                                 [0.1, 0.1, 0.2],
    #                                                 [0.2, 0.2, 0.3],
    #                                                 [0.2, 0.3, 0.4]]])
    # x = torch.from_numpy(sequence_output)
    # x = x.type(torch.float32)
    # x_mask = torch.from_numpy(np.array([[1, 1, 1, 1, 0 ,0], [1, 1, 1, 1, 1, 0]]))
    x_lengths = x_mask.eq(1).sum(1)
    _, idx_sort = torch.sort(x_lengths, dim=0, descending=True) #
    _, idx_unsort = torch.sort(idx_sort, dim=0)
    x = x.index_select(0, idx_sort)
    x_lengths = x_lengths[idx_sort]
    x_packed = nn.utils.rnn.pack_padded_sequence(x, x_lengths, batch_first=True)
    y_packed, _ = lstm(x_packed)
    y_unpacked, _ = nn.utils.rnn.pad_packed_sequence(y_packed, batch_first=True)
    y_unpacked = y_unpacked.index_select(0, idx_unsort)

    if y_unpacked.size(1) != x_mask.size(1):
        padding = torch.zeros(y_unpacked.size(0), x_mask.size(1) - y_unpacked.size(1), y_unpacked.size(2)).type(
            y_unpacked.type())
        y_unpacked = torch.cat((y_unpacked, padding), dim=1)
    #return y_unpacked
    #print(idx_sort, idx_unsort, x_lengths, x_packed, y_packed, y_unpacked)
    return y_unpacked
Example no. 3
def sort_batch(seqbatch):
    """Sorts torch tensor of integer indices by decreasing order."""
    # 0 is padding_idx
    omask = (seqbatch != 0)
    olens = omask.sum(0)
    slens, sidxs = torch.sort(olens, descending=True)
    oidxs = torch.sort(sidxs)[1]
    return (oidxs, sidxs, slens.data.tolist(), omask.float())
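
A hypothetical usage sketch for sort_batch on a small time-major batch, assuming 0 marks padding as in the docstring:

import torch

seqbatch = torch.tensor([[4, 7, 2],
                         [5, 0, 3],
                         [6, 0, 0]])                 # (seq_len=3, batch=3), 0 = padding
oidxs, sidxs, slens, mask = sort_batch(seqbatch)
sorted_batch = seqbatch[:, sidxs]                    # columns ordered by decreasing length
# ... pack/encode sorted_batch with slens, then restore the order with [:, oidxs]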
Example no. 4
    def _PyramidRoI_Feat(self, feat_maps, rois, im_info):
        ''' roi pool on pyramid feature maps'''
        # do roi pooling based on predicted rois
        img_area = im_info[0][0] * im_info[0][1]
        h = rois.data[:, 4] - rois.data[:, 2] + 1
        w = rois.data[:, 3] - rois.data[:, 1] + 1
        roi_level = torch.log(torch.sqrt(h * w) / 224.0) / np.log(2)
        roi_level = torch.floor(roi_level + 4)
        # --------
        # roi_level = torch.log(torch.sqrt(h * w) / 224.0)
        # roi_level = torch.round(roi_level + 4)
        # ------
        roi_level[roi_level < 2] = 2
        roi_level[roi_level > 5] = 5
        # roi_level.fill_(5)
        if cfg.POOLING_MODE == 'crop':
            # pdb.set_trace()
            # pooled_feat_anchor = _crop_pool_layer(base_feat, rois.view(-1, 5))
            # NOTE: need to add pyramid
            grid_xy = _affine_grid_gen(rois, feat_maps.size()[2:], self.grid_size)  ##
            grid_yx = torch.stack([grid_xy.data[:,:,:,1], grid_xy.data[:,:,:,0]], 3).contiguous()
            roi_pool_feat = self.RCNN_roi_crop(feat_maps, Variable(grid_yx).detach()) ##
            if cfg.CROP_RESIZE_WITH_MAX_POOL:
                roi_pool_feat = F.max_pool2d(roi_pool_feat, 2, 2)

        elif cfg.POOLING_MODE == 'align':
            roi_pool_feats = []
            box_to_levels = []
            for i, l in enumerate(range(2, 6)):
                if (roi_level == l).sum() == 0:
                    continue
                idx_l = (roi_level == l).nonzero().squeeze()
                box_to_levels.append(idx_l)
                scale = feat_maps[i].size(2) / im_info[0][0]
                feat = self.RCNN_roi_align(feat_maps[i], rois[idx_l], scale)
                roi_pool_feats.append(feat)
            roi_pool_feat = torch.cat(roi_pool_feats, 0)
            box_to_level = torch.cat(box_to_levels, 0)
            idx_sorted, order = torch.sort(box_to_level)
            roi_pool_feat = roi_pool_feat[order]

        elif cfg.POOLING_MODE == 'pool':
            roi_pool_feats = []
            box_to_levels = []
            for i, l in enumerate(range(2, 6)):
                if (roi_level == l).sum() == 0:
                    continue
                idx_l = (roi_level == l).nonzero().squeeze()
                box_to_levels.append(idx_l)
                scale = feat_maps[i].size(2) / im_info[0][0]
                feat = self.RCNN_roi_pool(feat_maps[i], rois[idx_l], scale)
                roi_pool_feats.append(feat)
            roi_pool_feat = torch.cat(roi_pool_feats, 0)
            box_to_level = torch.cat(box_to_levels, 0)
            idx_sorted, order = torch.sort(box_to_level)
            roi_pool_feat = roi_pool_feat[order]
            
        return roi_pool_feat
Example no. 5
 def forward(self, inputs, lengths, hidden=None):
     lens, indices = torch.sort(inputs.data.new(lengths).long(), 0, True)
     inputs = inputs[indices] if self.batch_first else inputs[:, indices] 
     outputs, (h, c) = self.rnn(pack(inputs, lens.tolist(), 
         batch_first=self.batch_first), hidden)
     outputs = unpack(outputs, batch_first=self.batch_first)[0]
     _, _indices = torch.sort(indices, 0)
     outputs = outputs[_indices] if self.batch_first else outputs[:, _indices]
     h, c = h[:, _indices, :], c[:, _indices, :]
     return outputs, (h, c)
Example no. 6
    def _forward_padded(self, x, x_mask):
        """Slower (significantly), but more precise,
        encoding that handles padding."""
        # Compute sorted sequence lengths
        lengths = x_mask.data.eq(0).long().sum(1).squeeze()
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)

        lengths = list(lengths[idx_sort])
        idx_sort = Variable(idx_sort)
        idx_unsort = Variable(idx_unsort)

        # Sort x
        x = x.index_select(0, idx_sort)

        # Transpose batch and sequence dims
        x = x.transpose(0, 1)

        # Pack it up
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths)

        # Encode all layers
        outputs = [rnn_input]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]

            # Apply dropout to input
            if self.dropout_rate > 0:
                dropout_input = F.dropout(rnn_input.data,
                                          p=self.dropout_rate,
                                          training=self.training)
                rnn_input = nn.utils.rnn.PackedSequence(dropout_input,
                                                        rnn_input.batch_sizes)
            outputs.append(self.rnns[i](rnn_input)[0])

        # Unpack everything
        for i, o in enumerate(outputs[1:], 1):
            outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0]

        # Concat hidden layers or take final
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]

        # Transpose and unsort
        output = output.transpose(0, 1)
        output = output.index_select(0, idx_unsort)

        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output
Example no. 7
    def forward(self, words, seq_lens):
        emb = self.embedder(words)
        seq_lens, idx_sort = torch.sort(seq_lens, descending=True)
        _, idx_unsort = torch.sort(idx_sort, descending=False)

        sorted_input = emb.index_select(1, torch.autograd.Variable(idx_sort))
        packed_input = nn.utils.rnn.pack_padded_sequence(sorted_input, seq_lens.tolist())
        packed_output = self.rnn(packed_input)[0]
        sorted_rnn_output = nn.utils.rnn.pad_packed_sequence(packed_output)[0]
        rnn_output = sorted_rnn_output.index_select(1, torch.autograd.Variable(idx_unsort))

        return rnn_output.max(0)[0]
Example no. 8
    def value(self):
        # case when number of elements added are 0
        if self.scores.shape[0] == 0:
            return 0.5

        # sorting the arrays
        scores, sortind = torch.sort(torch.from_numpy(self.scores), dim=0, descending=True)
        scores = scores.numpy()
        sortind = sortind.numpy()

        # creating the roc curve
        tpr = np.zeros(shape=(scores.size + 1), dtype=np.float64)
        fpr = np.zeros(shape=(scores.size + 1), dtype=np.float64)

        for i in range(1, scores.size + 1):
            if self.targets[sortind[i - 1]] == 1:
                tpr[i] = tpr[i - 1] + 1
                fpr[i] = fpr[i - 1]
            else:
                tpr[i] = tpr[i - 1]
                fpr[i] = fpr[i - 1] + 1

        tpr /= (self.targets.sum() * 1.0)
        fpr /= ((self.targets - 1.0).sum() * -1.0)

        # calculating area under curve using trapezoidal rule
        n = tpr.shape[0]
        h = fpr[1:n] - fpr[0:n - 1]
        sum_h = np.zeros(fpr.shape)
        sum_h[0:n - 1] = h
        sum_h[1:n] += h
        area = (sum_h * tpr).sum() / 2.0

        return (area, tpr, fpr)
Example no. 9
def validate():
    softmaxer = torch.nn.Softmax(dim=1)
    model.eval()
    correct = total = 0
    precisionmat = (1/np.arange(1,21))[::-1].cumsum()[::-1]
    precisionmat = torch.cuda.FloatTensor(precisionmat.copy())
    precision = 0
    crossentropy = 0
    hidden = model.initHidden()
    for batch in iter(val_iter):
        sentences = batch.text # n=32,bs
        if torch.cuda.is_available():
            sentences = sentences.cuda()
        out, hidden = model(sentences, hidden)
        for j in range(sentences.size(0)-1):
            outj = out[j] # bs,|V|
            labelsj = sentences[j+1] # bs
            # cross entropy
            crossentropy += F.cross_entropy(outj,labelsj,size_average=False,ignore_index=padidx)
            # precision
            outj, labelsj = softmaxer(outj).data, labelsj.data
            _, outsort = torch.sort(outj,dim=1,descending=True)
            outsort = outsort[:,:20]
            inds = (outsort-labelsj.unsqueeze(1)==0)
            inds = inds.sum(dim=0).type(torch.cuda.FloatTensor)
            precision += inds.dot(precisionmat)
            # plain ol accuracy
            _, predicted = torch.max(outj, 1)
            total += labelsj.ne(padidx).int().sum()
            correct += (predicted==labelsj).sum()
            # DEBUGGING: see the rest in trigram.py
        hidden = repackage_hidden(hidden)
    return correct/total, precision/total, torch.exp(crossentropy/total).data[0]
Example no. 10
 def eliminate_rows(self, prob_sc, ind, phis):
     """ eliminate rows of phis and prob_matrix scale """
     length = prob_sc.size()[1]
     mask = (prob_sc[:, :, 0] > 0.85).type(dtype)
     rang = (Variable(torch.range(0, length - 1).unsqueeze(0)
             .expand_as(mask)).
             type(dtype))
     ind_sc = torch.sort(rang * (1-mask) + length * mask, 1)[1]
     # permute prob_sc
     m = mask.unsqueeze(2).expand_as(prob_sc)
     mm = m.clone()
     mm[:, :, 1:] = 0
     prob_sc = (torch.gather(prob_sc * (1 - m) + mm, 1,
                ind_sc.unsqueeze(2).expand_as(prob_sc)))
     # compose permutations
     ind = torch.gather(ind, 1, ind_sc)
     active = torch.gather(1-mask, 1, ind_sc)
     # permute phis
     active1 = active.unsqueeze(2).expand_as(phis)
     ind1 = ind.unsqueeze(2).expand_as(phis)
     active2 = active.unsqueeze(1).expand_as(phis)
     ind2 = ind.unsqueeze(1).expand_as(phis)
     phis_out = torch.gather(phis, 1, ind1) * active1
     phis_out = torch.gather(phis_out, 2, ind2) * active2
     return prob_sc, ind, phis_out, active
Example no. 11
def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
    """
    Removes detections with lower object confidence score than 'conf_thres' and performs
    Non-Maximum Suppression to further filter detections.
    Returns detections with shape:
        (x1, y1, x2, y2, object_conf, class_score, class_pred)
    """

    # From (center x, center y, width, height) to (x1, y1, x2, y2)
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for image_i, image_pred in enumerate(prediction):
        # Filter out confidence scores below threshold
        conf_mask = (image_pred[:, 4] >= conf_thres).squeeze()
        image_pred = image_pred[conf_mask]
        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True)
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
        # Iterate through all predicted classes
        unique_labels = detections[:, -1].cpu().unique()
        if prediction.is_cuda:
            unique_labels = unique_labels.cuda()
        for c in unique_labels:
            # Get the detections with the particular class
            detections_class = detections[detections[:, -1] == c]
            # Sort the detections by maximum objectness confidence
            _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)
            detections_class = detections_class[conf_sort_index]
            # Perform non-maximum suppression
            max_detections = []
            while detections_class.size(0):
                # Get detection with highest confidence and save as max detection
                max_detections.append(detections_class[0].unsqueeze(0))
                # Stop if we're at the last detection
                if len(detections_class) == 1:
                    break
                # Get the IOUs for all boxes with lower confidence
                ious = bbox_iou(max_detections[-1], detections_class[1:])
                # Remove detections with IoU >= NMS threshold
                detections_class = detections_class[1:][ious < nms_thres]

            max_detections = torch.cat(max_detections).data
            # Add max detections to outputs
            output[image_i] = (
                max_detections if output[image_i] is None else torch.cat((output[image_i], max_detections))
            )

    return output
Example no. 12
 def sample(self):
     # print(torch.sort(torch.randn(200, self.latent_variable_size), dim=0))
     sample = Variable(torch.sort(torch.randn(200, self.latent_variable_size), dim=0)[0])
     sample = self.decode(sample)
     # print(sample)
     # print(sample.size())
     # save_image(sample.data.view(64, 1, 15, 20)*255,
     #         'sample_' + str(epoch) + '.png')
     return sample.data.numpy()
Example no. 13
def find_neighbors(im, image_set, image_names, num_neighbors=num_neighbors):
    # compute the L2 distances
    dists = compute_L2_dists(im, image_set)
    # sort in the order of increasing distance
    sorted, indices = torch.sort(dists, dim=0, descending=False)
    indices = indices.cpu()
    # pick the nearest neighbors
    nn_names = [image_names[i] for i in indices[:num_neighbors]]
    return nn_names, indices[:num_neighbors]
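
compute_L2_dists is not defined in this snippet; a plausible sketch, assuming im is a single feature vector and image_set is an (N, D) matrix of features:

import torch

def compute_L2_dists(im, image_set):
    # Hypothetical helper: squared L2 distance from `im` to every row of `image_set`,
    # returning a length-N tensor suitable for torch.sort(dists, dim=0).
    return ((image_set - im.unsqueeze(0)) ** 2).sum(dim=1)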
Example no. 14
    def from_batch(self, translation_batch):
        batch = translation_batch["batch"]
        assert(len(translation_batch["gold_score"]) ==
               len(translation_batch["predictions"]))
        batch_size = batch.batch_size

        preds, pred_score, attn, gold_score, indices = list(zip(
            *sorted(zip(translation_batch["predictions"],
                        translation_batch["scores"],
                        translation_batch["attention"],
                        translation_batch["gold_score"],
                        batch.indices.data),
                    key=lambda x: x[-1])))

        # Sorting
        inds, perm = torch.sort(batch.indices.data)
        data_type = self.data.data_type
        if data_type == 'text':
            src = batch.src[0].data.index_select(1, perm)
        else:
            src = None

        if self.has_tgt:
            tgt = batch.tgt.data.index_select(1, perm)
        else:
            tgt = None

        translations = []
        for b in range(batch_size):
            if data_type == 'text':
                src_vocab = self.data.src_vocabs[inds[b]] \
                  if self.data.src_vocabs else None
                src_raw = self.data.examples[inds[b]].src
            else:
                src_vocab = None
                src_raw = None
            pred_sents = [self._build_target_tokens(
                src[:, b] if src is not None else None,
                src_vocab, src_raw,
                preds[b][n], attn[b][n])
                          for n in range(self.n_best)]
            gold_sent = None
            if tgt is not None:
                gold_sent = self._build_target_tokens(
                    src[:, b] if src is not None else None,
                    src_vocab, src_raw,
                    tgt[1:, b] if tgt is not None else None, None)

            translation = Translation(src[:, b] if src is not None else None,
                                      src_raw, pred_sents,
                                      attn[b], pred_score[b], gold_sent,
                                      gold_score[b])
            translations.append(translation)

        return translations
Example no. 15
 def isCompact(tensor):
     # isn't it enough to check if strides == size.cumprod(0)?
     sortedStride, perm = torch.sort(torch.LongTensor(tensor.stride()), 0, True)
     sortedSize = torch.LongTensor(list(tensor.size())).index_select(0, perm)
     nRealDim = int(torch.clamp(sortedStride, 0, 1).sum())
     sortedStride = sortedStride.narrow(0, 0, nRealDim).clone()
     sortedSize = sortedSize.narrow(0, 0, nRealDim).clone()
     t = tensor.new().set_(tensor.storage(), 0,
                           tuple(sortedSize),
                           tuple(sortedStride))
     return t.is_contiguous()
Example no. 16
 def sort_by_embeddings(self, Phis, Inputs_N, e):
     ind = torch.sort(e, 1)[1].squeeze()
     for i, phis in enumerate(Phis):
         # rearange phis
         phis_out = (torch.gather(Phis[i], 1, ind.unsqueeze(2)
                     .expand_as(phis)))
         Phis[i] = (torch.gather(phis_out, 2, ind.unsqueeze(1)
                    .expand_as(phis)))
         # rearange inputs
         Inputs_N[i] = torch.gather(Inputs_N[i], 1,
                                    ind.unsqueeze(2).expand_as(Inputs_N[i]))
     return Phis, Inputs_N
Example no. 17
 def reindex_target(self, target, e):
     """ Reindex target by embedding to be coherent. We have to invert
     a permutation and add some padding to do it correctly. """
     ind = torch.sort(e, 1)[1].squeeze()
     first = Variable(torch.zeros(self.batch_size, 1)).type(dtype_l)
     ind = torch.cat((first, ind + 1), 1)
     # target = new_target(ind) -> new_target = target(ind_inv)
     # invert permutation
     ind_inv = torch.sort(ind, 1)[1]
     last = np.zeros((self.batch_size, 1))
     target = np.concatenate((target, last), axis=1)
     for example in range(self.batch_size):
         tar = target[example].astype(int)
         ind_inv_n = ind_inv[example].data.cpu().numpy()
         tar = ind_inv_n[tar]
         tar_aux = tar[np.where(tar > 0)[0]]
         argmin = np.argsort(tar_aux)[0]
         tar_aux = np.array(list(tar_aux[argmin:]) + list(tar_aux[:argmin]))
         tar[:tar_aux.shape[0]] = tar_aux
         target[example] = tar
     return target[:, :-1]
Example no. 18
    def rank_c2i(self, img_features, cap_features, cap_ix_to_img_ix):
        scores = torch.mm(cap_features, img_features.t()).cpu()
        _, indices = torch.sort(scores, 1, descending=True)

        rank = np.zeros(cap_features.size(0))
        for j in range(cap_features.size(0)):
            for k in range(img_features.size(0)):
                if indices[j, k] == cap_ix_to_img_ix[j]:
                    rank[j] = k + 1
                    break

        return self.get_rank_stats(rank)
Example no. 19
def infer(data_dir, model_name, sentence=None):
    """
    """

    # Load components
    with open(os.path.join(basedir, data_dir, 'char2index.json'), 'r') as f:
        char2index = json.load(f)
    with open(os.path.join(basedir, data_dir, 'index2class.json'), 'r') as f:
        index2class = json.load(f)

    # Enter the sentence
    print ("Classes:", index2class.values())
    if not sentence:
        sentence = input("Please enter the sentence: ")

    # Normalize the sentence
    sentence = normalize_string(sentence)

    # Convert sentence(s) to indexes
    input_ = convert_sentence(
        sentence=sentence,
        char2index=char2index,
        )

    # Convert to model input
    input_, _ = pad(
        inputs=[input_],
        char2index=char2index,
        max_length=len(input_),
        )

    # Convert to Variable
    X = Variable(torch.FloatTensor(input_), requires_grad=False)

    # Load the model
    model = torch.load(os.path.join(
        basedir, "data", data_dir.split("/")[-1], model_name))

    # Feed through model
    model.eval()
    scores = model(X)
    probabilities = F.softmax(scores)

    # Sorted probabilities
    sorted_, indices = torch.sort(probabilities, descending=True)
    for index in indices[0]:
        print ("%s - %i%%" % (
            index2class[str(index.data[0])],
            100.0*probabilities.data[0][index.data[0]]))
Example no. 20
        def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobsf: probabilities augmented after diversity
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions
            #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            #beam_logprobs_sum : joint log-probability of each beam

            ys,ix = torch.sort(logprobsf,1,True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols): # for each column (word, essentially)
                for q in range(rows): # for each beam expansion
                    #compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q,c]
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_logprob})
            candidates = sorted(candidates,  key=lambda x: -x['p'])
            
            new_state = [_.clone() for _ in state]
            #beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                #we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                #fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                #rearrange recurrent states
                for state_ix in range(len(new_state)):
                #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
                #append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c'] # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
                beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates
Example no. 21
def validate():
    softmaxer = torch.nn.Softmax(dim=1)
    model.eval()
    fLSTM.eval()
    fGRU.eval()
    fNNLM.eval()
    correct = total = 0
    precisionmat = (1/np.arange(1,21))[::-1].cumsum()[::-1]
    precisionmat = torch.cuda.FloatTensor(precisionmat.copy())
    precision = 0
    crossentropy = 0
    LSTMhidden = fLSTM.initHidden()
    GRUhidden = fGRU.initHidden()
    for batch in iter(val_iter):
        sentences = batch.text # n=32,bs
        if torch.cuda.is_available():
            sentences = sentences.cuda()
        LSTMout, LSTMhidden = fLSTM(sentences, LSTMhidden)
        GRUout, GRUhidden = fGRU(sentences, GRUhidden)
        word_pad = (32 + n - 1) - sentences.size(0)
        pads = Variable(torch.zeros(word_pad,sentences.size(1))).type(torch.cuda.LongTensor)
        padsentences = torch.cat([pads,sentences],dim=0)
        #print("sentence_dim: {}\npadded_dim: {}".format(sentences.size(),padsentences.size()))
        NNLMout = torch.stack([ fNNLM(torch.cat([ padsentences[:,a:a+1][b:b+n,:] for b in range(32) ],dim=1).t()) for a in range(sentences.size(1)) ],dim=1)
        #eOUT = torch.cat([LSTMout,GRUout,NNLMout],dim=2)
        NNLMout = NNLMout[-sentences.size(0):,:sentences.size(1),:len(TEXT.vocab)]
        tOUT = model(sentences.t(),LSTMout,GRUout,NNLMout)
        out  = tOUT
        for j in range(sentences.size(0)-1):
            outj = out[j] # bs,|V|
            labelsj = sentences[j+1] # bs
            # cross entropy
            crossentropy += F.cross_entropy(outj,labelsj,size_average=False,ignore_index=padidx)
            # precision
            outj, labelsj = softmaxer(outj).data, labelsj.data
            _, outsort = torch.sort(outj,dim=1,descending=True)
            outsort = outsort[:,:20]
            inds = (outsort-labelsj.unsqueeze(1)==0)
            inds = inds.sum(dim=0).type(torch.cuda.FloatTensor)
            precision += inds.dot(precisionmat)
            # plain ol accuracy
            _, predicted = torch.max(outj, 1)
            total += labelsj.ne(padidx).int().sum()
            correct += (predicted==labelsj).sum()
            # DEBUGGING: see the rest in trigram.py
        LSTMhidden = repackage_hidden(LSTMhidden)
        GRUhidden  = repackage_hidden(GRUhidden)
    return correct/total, precision/total, torch.exp(crossentropy/total).data[0]
Example no. 22
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss
Example no. 23
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float() # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)
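
Both losses above assume a lovasz_grad helper; the standard gradient of the Lovász extension over sorted errors (Berman et al.) is sketched here for reference:

def lovasz_grad(gt_sorted):
    # Gradient of the Lovasz extension w.r.t. sorted errors.
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover the single-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard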
Example no. 24
def _threshold_and_support(input, dim=0):
    """Sparsemax building block: compute the threshold

    Args:
        input: any dimension
        dim: dimension along which to apply the sparsemax

    Returns:
        the threshold value and the support size
    """

    input_srt, _ = torch.sort(input, descending=True, dim=dim)
    input_cumsum = input_srt.cumsum(dim) - 1
    rhos = _make_ix_like(input, dim)
    support = rhos * input_srt > input_cumsum

    support_size = support.sum(dim=dim).unsqueeze(dim)
    tau = input_cumsum.gather(dim, support_size - 1)
    tau /= support_size.to(input.dtype)
    return tau, support_size
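
The helper _make_ix_like is assumed here; a minimal sketch that builds the index vector 1..d, shaped to broadcast against the sorted input along dim:

import torch

def _make_ix_like(input, dim=0):
    # 1..d along `dim`, reshaped so it broadcasts against the sorted input.
    d = input.size(dim)
    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
    view = [1] * input.dim()
    view[0] = -1
    return rho.view(view).transpose(0, dim)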
Example no. 25
    def average_precision(output, target, difficult_examples=True):

        # sort examples
        sorted, indices = torch.sort(output, dim=0, descending=True)

        # Computes prec@i
        pos_count = 0.
        total_count = 0.
        precision_at_i = 0.
        for i in indices:
            label = target[i]
            if difficult_examples and label == 0:
                continue
            if label == 1:
                pos_count += 1
            total_count += 1
            if label == 1:
                precision_at_i += pos_count / total_count
        precision_at_i /= pos_count
        return precision_at_i
Example no. 26
 def plot_norm_points(self, Inputs_N, e, Perms, scales, fig=1):
     input = Inputs_N[0][0].data.cpu().numpy()
     e = torch.sort(e, 1)[0][0].data.cpu().numpy()
     Perms = [perm[0].data.cpu().numpy() for perm in Perms]
     plt.figure(fig)
     plt.clf()
     ee = e.copy()
     for i, perm in enumerate(Perms):
         plt.subplot(1, len(Perms), i + 1)
         colors = cm.rainbow(np.linspace(0, 1, 2 ** (scales - i)))
         perm = perm[np.where(perm > 0)[0]] - 1
         points = input[perm]
         e_scale = ee[perm]
         for node in range(2 ** (scales - i)):
             ind = np.where(e_scale == node)[0]
             pts = points[ind]
             plt.scatter(pts[:, 0], pts[:, 1], c=colors[node])
         ee //= 2
     path = os.path.join(self.path, 'visualize_example.png')
     plt.savefig(path)
Example no. 27
def nms(boxes, nms_thresh):
    if len(boxes) == 0:
        return boxes

    det_confs = torch.zeros(len(boxes))
    for i in range(len(boxes)):
        det_confs[i] = 1-boxes[i][4]                

    _,sortIds = torch.sort(det_confs)
    out_boxes = []
    for i in range(len(boxes)):
        box_i = boxes[sortIds[i]]
        if box_i[4] > 0:
            out_boxes.append(box_i)
            for j in range(i+1, len(boxes)):
                box_j = boxes[sortIds[j]]
                if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh:
                    #print(box_i, box_j, bbox_iou(box_i, box_j, x1y1x2y2=False))
                    box_j[4] = 0
    return out_boxes
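
bbox_iou is assumed by this NMS loop; a minimal sketch handling both corner and center-size formats (x1y1x2y2=False means the boxes are (cx, cy, w, h), as used above):

def bbox_iou(box1, box2, x1y1x2y2=True):
    # IoU of two boxes given as (x1, y1, x2, y2), or (cx, cy, w, h) if x1y1x2y2=False.
    if not x1y1x2y2:
        b1 = [box1[0] - box1[2] / 2.0, box1[1] - box1[3] / 2.0,
              box1[0] + box1[2] / 2.0, box1[1] + box1[3] / 2.0]
        b2 = [box2[0] - box2[2] / 2.0, box2[1] - box2[3] / 2.0,
              box2[0] + box2[2] / 2.0, box2[1] + box2[3] / 2.0]
    else:
        b1, b2 = box1[:4], box2[:4]
    inter_w = max(0.0, min(b1[2], b2[2]) - max(b1[0], b2[0]))
    inter_h = max(0.0, min(b1[3], b2[3]) - max(b1[1], b2[1]))
    inter = inter_w * inter_h
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    return inter / (area1 + area2 - inter + 1e-16)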
Example no. 28
def validate():
    model.eval()
    correct = total = 0
    precisionmat = (1/np.arange(1,21))[::-1].cumsum()[::-1]
    precisionmat = torch.cuda.FloatTensor(precisionmat.copy()) # hm
    precision = 0
    crossentropy = 0
    for batch in iter(val_iter):
        sentences = batch.text.transpose(1,0).cuda() # bs, n
        if sentences.size(1) < n+1: # make sure sentence length is long enough
            pads = Variable(torch.zeros(sentences.size(0),n+1-sentences.size(1))).type(torch.cuda.LongTensor)
            sentences = torch.cat([pads,sentences],dim=1)
        for j in range(n,sentences.size(1)):
            out = model(sentences[:,j-n:j]) # bs,|V|
            labels = sentences[:,j] # bs
            # cross entropy
            crossentropy += F.cross_entropy(out,labels,size_average=False,ignore_index=padidx)
            # precision
            out, labels = out.data, labels.data
            _, outsort = torch.sort(out,dim=1,descending=True)
            outsort = outsort[:,:20]
            inds = (outsort-labels.unsqueeze(1)==0)
            inds = inds.sum(dim=0).type(torch.cuda.FloatTensor)
            precision += inds.dot(precisionmat)
            # plain ol accuracy
            _, predicted = torch.max(out, 1)
            total += labels.ne(padidx).int().sum()
            correct += (predicted==labels).sum()
            # if total % 500 == 0:
                # DEBUGGING: see the rest in trigram.py
                # print('we are on example', total)
                # for s in range(bs):
                #     print([TEXT.vocab.itos[w] for w in sentences[s,j-n:j].data])
                #     print(TEXT.vocab.itos[labels[s]])
                #     print([TEXT.vocab.itos[w] for w in outsort[s]])
                # print('Test Accuracy', correct/total)
                # print('Precision',precision/total)
                # print('Perplexity',torch.exp(bs*crossentropy/total).data[0])
    return correct/total, precision/total, torch.exp(crossentropy/total).data[0]
Example no. 29
    def forward(ctx, pred, labels, is_positive, ohem_ratio, group_size):
        n_sample = pred.size()[0]
        assert n_sample == len(labels), "mismatch between sample size and label size"
        losses = torch.zeros(n_sample)
        slopes = torch.zeros(n_sample)
        for i in range(n_sample):
            losses[i] = max(0, 1 - is_positive * pred[i, labels[i] - 1])
            slopes[i] = -is_positive if losses[i] != 0 else 0

        losses = losses.view(-1, group_size).contiguous()
        sorted_losses, indices = torch.sort(losses, dim=1, descending=True)
        keep_num = int(group_size * ohem_ratio)
        loss = torch.zeros(1).cuda()
        for i in range(losses.size(0)):
            loss += sorted_losses[i, :keep_num].sum()
        ctx.loss_ind = indices[:, :keep_num]
        ctx.labels = labels
        ctx.slopes = slopes
        ctx.shape = pred.size()
        ctx.group_size = group_size
        ctx.num_group = losses.size(0)
        return loss
Example no. 30
    def value(self):
        """Returns the model's average precision for each class

        Return:
            ap (FloatTensor): 1xK tensor, with avg precision for each class k
        """

        if self.scores.numel() == 0:
            return 0
        ap = torch.zeros(self.scores.size(1))
        rg = torch.range(1, self.scores.size(0)).float()
        if self.weights.numel() > 0:
            weight = self.weights.new(self.weights.size())
            weighted_truth = self.weights.new(self.weights.size())

        # compute average precision for each class
        for k in range(self.scores.size(1)):
            # sort scores
            scores = self.scores[:, k]
            targets = self.targets[:, k]
            _, sortind = torch.sort(scores, 0, True)
            truth = targets[sortind]
            if self.weights.numel() > 0:
                weight = self.weights[sortind]
                weighted_truth = truth.float() * weight
                rg = weight.cumsum(0)

            # compute true positive sums
            if self.weights.numel() > 0:
                tp = weighted_truth.cumsum(0)
            else:
                tp = truth.float().cumsum(0)

            # compute precision curve
            precision = tp.div(rg)

            # compute average precision
            ap[k] = precision[truth.byte()].sum() / max(truth.sum(), 1)
        return ap
Example no. 31
def gen_sentence_tensors(sentence_list, device, data_url):
    """ generate input tensors from sentence list

    Args:
        sentence_list: list of raw sentence
        device: torch device
        data_url: raw data url to locate the vocab url

    Returns:
        sentences, tensor
        sentence_lengths, tensor
        sentence_words, list of tensor
        sentence_word_lengths, list of tensor
        sentence_word_indices, list of tensor

    """
    vocab = ju.load(dirname(data_url) + '/vocab.json')
    char_vocab = ju.load(dirname(data_url) + '/char_vocab.json')

    sentences = list()
    sentence_words = list()
    sentence_word_lengths = list()
    sentence_word_indices = list()

    unk_idx = 1
    for sent in sentence_list:
        # word to word id
        sentence = torch.LongTensor([
            vocab[word] if word in vocab else unk_idx for word in sent
        ]).to(device)

        # char of word to char id
        words = list()
        for word in sent:
            words.append([
                char_vocab[ch] if ch in char_vocab else unk_idx for ch in word
            ])

        # save word lengths
        word_lengths = torch.LongTensor([len(word)
                                         for word in words]).to(device)

        # sorting lengths according to length
        word_lengths, word_indices = torch.sort(word_lengths, descending=True)

        # sorting word according word length
        words = np.array(words)[word_indices.cpu().numpy()]
        word_indices = word_indices.to(device)
        words = [torch.LongTensor(word).to(device) for word in words]

        # padding char tensor of words
        words = pad_sequence(words, batch_first=True).to(device)
        # (sent_len, max_word_len)

        sentences.append(sentence)
        sentence_words.append(words)
        sentence_word_lengths.append(word_lengths)
        sentence_word_indices.append(word_indices)

    # record sentence length and padding sentences
    sentence_lengths = [len(sentence) for sentence in sentences]
    # (batch_size)
    sentences = pad_sequence(sentences, batch_first=True).to(device)
    # (batch_size, max_sent_len)

    return sentences, sentence_lengths, sentence_words, sentence_word_lengths, sentence_word_indices
Example no. 32
def non_max_suppression(prediction,
                        conf_thres,
                        num_classes,
                        cuda,
                        nms_thres=0.4):
    """ Applies thresholding based on objectness score and non-maximum suppression """
    # Transform (center x, center y, height, width) attributes of the bounding boxes to
    # (top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y)
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = (prediction[:, :, 0] - prediction[:, :, 2] / 2)
    box_corner[:, :, 1] = (prediction[:, :, 1] - prediction[:, :, 3] / 2)
    box_corner[:, :, 2] = (prediction[:, :, 0] + prediction[:, :, 2] / 2)
    box_corner[:, :, 3] = (prediction[:, :, 1] + prediction[:, :, 3] / 2)
    prediction[:, :, :4] = box_corner[:, :, :4]  # Apply transformation to prediction

    # Toggle CUDA
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    output = FloatTensor()

    # Loop over images in a batch
    for index in range(prediction.size(0)):
        img_pred = prediction[index]

        # Filter out object confidence scores below threshold
        img_pred = img_pred[img_pred[:, 4] > conf_thres]

        # Skip to next image if no scores are left
        if img_pred.shape[0] == 0:
            continue

        # Get index of class with the highest value and its score
        class_conf, class_score = torch.max(img_pred[:, 5:], 1)  # Offset by 5
        class_conf = class_conf.float().unsqueeze(1)
        class_score = class_score.float().unsqueeze(1)
        img_pred = torch.cat((img_pred[:, :5], class_conf, class_score), 1)

        # Get unique classes detected in an image
        img_classes = img_pred[:, -1].unique()

        # Perform non-maximum suppression class-wise
        for cl in img_classes:
            # Get the detections for one class
            cl_mask = img_pred * (img_pred[:, -1] == cl).float().unsqueeze(1)
            cl_mask_idx = torch.nonzero(cl_mask[:, -2]).squeeze()
            img_pred_class = img_pred[cl_mask_idx].view(-1, 7)

            # Sort detections by objectness score in descending order
            conf_sort_idx = torch.sort(img_pred_class[:, 4],
                                       descending=True)[1]
            img_pred_class = img_pred_class[conf_sort_idx]
            idx = img_pred_class.size(0)  # Number of detections

            for i in range(idx):
                # Get IOUs of all later bounding boxes
                try:
                    ious = bbox_iou(img_pred_class[i].unsqueeze(0),
                                    img_pred_class[i + 1:])

                except ValueError:
                    break

                except IndexError:
                    break

                # Zero out all the detections that have IoU > threshold
                iou_mask = (ious < nms_thres).float().unsqueeze(1)
                img_pred_class[i + 1:] *= iou_mask

                # Remove the non-zero entries
                non_zero_idx = torch.nonzero(img_pred_class[:, 4]).squeeze()
                img_pred_class = img_pred_class[non_zero_idx].view(-1, 7)

            batch_idx = img_pred_class.new(img_pred_class.size(0),
                                           1).fill_(index)

            # Repeat the batch_id for as many detections of the class cl in the image
            seq = torch.cat((batch_idx, img_pred_class), 1)

            output = torch.cat((output, seq))

    return output
Example no. 33
import torch

# ranks = [54, 67, 59, 46, 2, 1, 100]
ranks = [5, 6, 4, 3, 2, 1, 7]
predicted_ranks = torch.tensor(ranks).float()
predicted_ranks = predicted_ranks.unsqueeze(0)
print(predicted_ranks.size())
new_ranks, rankings = torch.sort(predicted_ranks, dim=-1)

print("Original: ", predicted_ranks)
print("Rankings", rankings)
print("New ranks", new_ranks)
Example no. 34
def test_net(model=None, image=None, params=None, bg=None, cls=None):
    blob, scale, label = params
    with torch.no_grad():  # pre-processing data for passing net
        im_data = Variable(torch.FloatTensor(1).cuda())
        im_info = Variable(torch.FloatTensor(1).cuda())
        num_boxes = Variable(torch.LongTensor(1).cuda())
        gt_boxes = Variable(torch.FloatTensor(1).cuda())

    im_info_np = np.array([[blob.shape[1], blob.shape[2], scale[0]]], dtype=np.float32)
    im_data_pt = torch.from_numpy(blob)
    im_data_pt = im_data_pt.permute(0, 3, 1, 2)
    im_info_pt = torch.from_numpy(im_info_np)

    with torch.no_grad():  # resize
        im_data.resize_(im_data_pt.size()).copy_(im_data_pt)
        im_info.resize_(im_info_pt.size()).copy_(im_info_pt)
        gt_boxes.resize_(1, 1, 5).zero_()
        num_boxes.resize_(1).zero_()

    rois, cls_prob, bbox_pred, \
    rpn_loss_cls, rpn_loss_box, \
    RCNN_loss_cls, RCNN_loss_bbox, \
    rois_label = model(im_data, im_info, gt_boxes, num_boxes)  # predict

    scores = cls_prob.data
    boxes = rois.data[:, :, 1:5]

    if opt.TEST_BBOX_REG:
        box_deltas = bbox_pred.data
        if opt.TRAIN_BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
            if opt.cuda:
                box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(opt.TRAIN_BBOX_NORMALIZE_STDS).cuda() \
                             + torch.FloatTensor(opt.TRAIN_BBOX_NORMALIZE_MEANS).cuda()
            else:
                box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(opt.TRAIN_BBOX_NORMALIZE_STDS) \
                             + torch.FloatTensor(opt.TRAIN_BBOX_NORMALIZE_MEANS)

            box_deltas = box_deltas.view(1, -1, 4 * len(label))

        pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
        pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)

    pred_boxes /= scale[0]

    scores = scores.squeeze()
    pred_boxes = pred_boxes.squeeze()

    image = np.copy(image[:, :, ::-1])
    demo = image.copy()
    bubbles = []
    dets_bubbles = []

    for j in range(1, len(label)):
        inds = torch.nonzero(scores[:, j] > opt.THRESH).view(-1)
        if inds.numel() > 0:
            cls_scores = scores[:, j][inds]
            _, order = torch.sort(cls_scores, 0, True)
            cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]

            cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
            cls_dets = cls_dets[order]
            keep = nms(cls_boxes[order, :], cls_scores[order], opt.TEST_NMS)
            cls_dets = cls_dets[keep.view(-1).long()].cpu().numpy()

            #  post-processing : get contours of speech bubble
            demo, image, bubbles, dets_bubbles = bubble_utils.get_cnt_bubble(image, image.copy(), label[j], cls_dets,
                                                                             cls, bg=bg)
    return demo, image, bubbles, dets_bubbles
Example no. 35
batch_transform = torch.unsqueeze(img_transform, 0)

mnasnet.eval()

inference_times = np.empty(inference_count)

out = mnasnet(batch_transform)  #warm-up

for i in range(inference_count):
    start_time = time.time()
    out = mnasnet(batch_transform)
    end_time = (time.time() - start_time) * 1000
    np.put(inference_times, i, end_time)

total_time = np.sum(inference_times)
average_time = np.mean(inference_times)
standard_deviation = np.std(inference_times)

print("Total time = {0:.2f}ms".format(total_time))
print("Average time = {0:.2f}ms".format(average_time))
print("Standard deviation = {0:.2f}ms".format(standard_deviation))

percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100

_, indices = torch.sort(out, descending=True)
for idx in indices[0][:5]:
    print((labels[idx], percentage[idx].item()))

print("AI_BENCHMARK_MARKER,PyTorch v1.5.1,{0},Float,{1:.2f},{2:.2f},".format(
    model_location.strip(), average_time, standard_deviation))
Example no. 36
    def step(self):
        """ Update the pruning masks """
        self.prune_out = []
        self.prune_in = []
        self.prune_both = []
        self.skip = []
        self.basicBlock = collections.defaultdict(int)
        key = self.keyword
        convList = []
        bnList = []

        bn_total = 0
        convList = []
        total_channel = 0
        for name, m in self.model.named_modules():
            if isinstance(m, nn.Conv2d):
                convList.append((name, m))
                self.basicBlock[getBlockID(name)] = self.basicBlock[getBlockID(name)] + 1
                total_channel += m.out_channels
            if isinstance(m, nn.BatchNorm2d):
                bnList.append(m)
                # bn_total += m.weight.data.shape[0]

        print(len(bnList), len(convList))
        assert len(bnList) == len(convList), "unequal bn and conv"

        for index in range(len(convList)):

            name, layer = convList[index]

            outchannel = layer.out_channels
            inchannel = layer.in_channels
            # print("index:{}, name: {}, layer: {}, inchannel :{}, outchannel: {} ".format(index, name, layer, outchannel, inchannel))

            if ifcontain(name, key) or index == 0:
                self.skip.append(index)
                continue

            if index + 1 >= len(convList):
                if inchannel == convList[index - 1][1].out_channels:
                    self.prune_in.append(index)
                    continue
                else:
                    self.skip.append(index)
                    continue
            if outchannel != convList[index + 1][1].in_channels and inchannel == convList[index - 1][1].out_channels:
                self.prune_in.append(index)
                continue

            if outchannel == convList[index + 1][1].in_channels and inchannel != convList[index - 1][1].out_channels:
                self.prune_out.append(index)
                continue

            if outchannel == convList[index + 1][1].in_channels and inchannel == convList[index - 1][1].out_channels:
                if index - 1 in self.prune_in or index - 1 in self.skip:
                    # if ifcontain(convList[index+1][0], key) or getBlockID(name) != getBlockID(convList[index+1][0]):
                    if ifcontain(convList[index + 1][0], key):
                        self.skip.append(index)
                        continue
                    self.prune_out.append(index)
                    continue

                if self.basicBlock[getBlockID(name)] > 1 and getBlockID(name) != getBlockID(convList[index + 1][0]):
                    # if getBlockID(name) != getBlockID(convList[index+1][0]):
                    self.prune_in.append(index)
                    continue

                self.prune_both.append(index)
                continue
            self.skip.append(index)

        # print("prune_skip : {}".format(self.skip))
        # print("prune_out : {}".format(self.prune_out))
        # print("prune_in : {}".format(self.prune_in))
        # print("prune_both : {}".format(self.prune_both))

        for i in range(len(bnList)):
            if i not in self.skip and i not in self.prune_in:
                bn_total += bnList[i].weight.data.shape[0]

        bn = torch.zeros(bn_total)
        index = 0
        for i in range(len(bnList)):
            if i not in self.skip and i not in self.prune_in:
                m = bnList[i]
                size = m.weight.data.shape[0]
                bn[index:(index + size)] = m.weight.data.abs().clone()
                index += size
        # for m in bnList:
        #     size = m.weight.data.shape[0]
        #     bn[index:(index+size)] = m.weight.data.abs().clone()
        #     index += size

        y, i = torch.sort(bn)

        thre_index = int(bn_total * self.pruning_rate)
        thre = y[thre_index]

        self.total = bn_total
        self.cfg = []
        self.cfg_mask = []
        for index in range(len(convList)):
            _, m = convList[index]
            in_channels = m.weight.data.shape[1]
            out_channels = m.weight.data.shape[0]

            if index in self.skip:
                self.cfg.append((in_channels, out_channels))
                continue

            if index in self.prune_in:
                new_inchannel = self.cfg[-1][1]
                self.cfg.append((new_inchannel, out_channels))

                self.cfg_mask.append(self.cfg_mask[-1])
                continue

            if index in self.prune_out or index in self.prune_both:

                m_bn = bnList[index]
                weight_copy = m_bn.weight.data.abs().clone()
                mask = weight_copy.gt(thre).float().cuda()

                if int(torch.sum(mask)) < self.bottleneck:
                    mask = getNewMask(m_bn, self.bottleneck)
                m_bn.weight.data.mul_(mask)
                m_bn.bias.data.mul_(mask)

                self.cfg_mask.append(mask)
                if index in self.prune_both:
                    self.cfg.append((self.cfg[-1][1], int(torch.sum(mask))))
                else:
                    self.cfg.append((in_channels, int(torch.sum(mask))))
Example no. 37
def train(model, data_loader, optimizer,
          init_lr=0.002,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None,
          clip_thresh=1.0):
    model.train()
    if use_cuda:
        model = model.cuda()
    linear_dim = model.linear_dim

    criterion = nn.L1Loss()
    #criterion_noavg = nn.L1Loss(reduction='none')

    global global_step, global_epoch
    while global_epoch < nepochs:
        running_loss = 0.
        for step, (x, input_lengths, mel, y) in tqdm(enumerate(data_loader)):
            # Decay learning rate
            current_lr = learning_rate_decay(init_lr, global_step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.zero_grad()

            # Sort by length
            sorted_lengths, indices = torch.sort(
                input_lengths.view(-1), dim=0, descending=True)
            sorted_lengths = sorted_lengths.long().numpy()

            x, mel, y = x[indices], mel[indices], y[indices]

            # Feed data
            x, mel, y = Variable(x), Variable(mel), Variable(y)
            if use_cuda:
                x, mel, y = x.cuda(), mel.cuda(), y.cuda()
            mel_outputs, linear_outputs, attn = model.forward_nomasking(
                x, mel, input_lengths=sorted_lengths)

            # Loss
            mel_loss = criterion(mel_outputs, mel)
            n_priority_freq = int(3000 / (fs * 0.5) * linear_dim)
            linear_loss = 0.5 * criterion(linear_outputs, y) \
                + 0.5 * criterion(linear_outputs[:, :, :n_priority_freq],
                                  y[:, :, :n_priority_freq])
            loss = mel_loss + linear_loss

            if global_step > 0 and global_step % checkpoint_interval == 0:
                save_states(
                    global_step, mel_outputs, linear_outputs, attn, y,
                    sorted_lengths, checkpoint_dir)
                save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir, global_epoch)

            # Update
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm(
                model.parameters(), clip_thresh)
            optimizer.step()

            # Logs
            log_value("loss", float(loss.item()), global_step)
            log_value("mel loss", float(mel_loss.item()), global_step)
            log_value("linear loss", float(linear_loss.item()), global_step)
            log_value("gradient norm", grad_norm, global_step)
            log_value("learning rate", current_lr, global_step)

            global_step += 1
            running_loss += loss.item()

        averaged_loss = running_loss / (len(data_loader))
        log_value("loss (per epoch)", averaged_loss, global_epoch)
        #noavg_loss = criterion_noavg(mel_outputs, mel)
        #noavg_loss = noavg_loss.reshape(noavg_loss.shape[0], -1)
        #print("No avg output: ", torch.mean(noavg_loss, dim=-1))
        print("Loss: {}".format(running_loss / (len(data_loader))))


        global_epoch += 1
Example no. 38
def split_target(args):
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
    ])

    txt_tar = open(args.t_dset_path).readlines()
    if not args.da == 'uda':
        label_map_s = {}
        for i in range(len(args.src_classes)):
            label_map_s[args.src_classes[i]] = i

        new_tar = []
        for i in range(len(txt_tar)):
            rec = txt_tar[i]
            reci = rec.strip().split(' ')
            if int(reci[1]) in args.tar_classes:
                if int(reci[1]) in args.src_classes:
                    line = reci[0] + ' ' + str(label_map_s[int(
                        reci[1])]) + '\n'
                    new_tar.append(line)
                else:
                    line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
                    new_tar.append(line)
        txt_tar = new_tar.copy()

    dset_loaders = {}
    test_set = ImageList(txt_tar, transform=test_transform)
    dset_loaders["target"] = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size * 3,
        shuffle=False,
        num_workers=args.worker,
        drop_last=False)

    netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == "source":
        modelpath = args.output_dir + "/source_F.pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_C.pt"
        netC.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_C_" + args.savename + ".pt"
        netC.load_state_dict(torch.load(modelpath))

    netF.eval()
    netB.eval()
    netC.eval()

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target'])
        for i in range(len(dset_loaders['target'])):
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(
        torch.squeeze(predict).float() == all_label).item() / float(
            all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))

    if args.dset == 'VISDA-C':
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        matrix = matrix[np.unique(all_label).astype(int), :]
        all_acc = matrix.diagonal() / matrix.sum(axis=1) * 100
        acc = all_acc.mean()
        aa = [str(np.round(i, 2)) for i in all_acc]
        acc_list = ' '.join(aa)
        print(acc_list)
        args.out_file.write(acc_list + '\n')
        args.out_file.flush()

    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.name, 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    if args.ps == 0:
        est_p = (mean_ent < mean_ent.mean()).sum().item() / mean_ent.size(0)
        log_str = 'Estimated split ratio PS: {:.2f}'.format(est_p)
        print(log_str + '\n')
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        PS = est_p
    else:
        PS = args.ps

    if args.choice == "ent":
        value = mean_ent
    elif args.choice == "maxp":
        value = -top_pred
    elif args.choice == "marginp":
        pred, _ = torch.sort(all_output, 1)
        value = pred[:, 1] - pred[:, 0]
    else:
        value = torch.rand(len(mean_ent))

    ori_target = txt_tar.copy()
    new_tar = []
    new_src = []

    predict = predict.numpy()

    cls_k = args.class_num
    for c in range(cls_k):
        c_idx = np.where(predict == c)
        c_idx = c_idx[0]
        c_value = value[c_idx]

        _, idx_ = torch.sort(c_value)
        c_num = len(idx_)
        c_num_s = int(c_num * PS)

        for ei in range(0, c_num_s):
            ee = c_idx[idx_[ei]]
            reci = ori_target[ee].strip().split(' ')
            line = reci[0] + ' ' + str(c) + '\n'
            new_src.append(line)
        for ei in range(c_num_s, c_num):
            ee = c_idx[idx_[ei]]
            reci = ori_target[ee].strip().split(' ')
            line = reci[0] + ' ' + str(c) + '\n'
            new_tar.append(line)

    return new_src.copy(), new_tar.copy()
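
The per-class loop above keeps, for each predicted class, the `PS` fraction of samples with the lowest uncertainty value as pseudo-source data and routes the remainder back to the target split. A stripped-down sketch of that selection with invented predictions and uncertainty scores:

import numpy as np
import torch

predict = np.array([0, 1, 0, 1, 0, 1])                  # hypothetical class predictions
value = torch.tensor([0.9, 0.1, 0.2, 0.8, 0.5, 0.3])    # lower = more confident
PS = 0.5

src_idx, tar_idx = [], []
for c in range(2):
    c_idx = np.where(predict == c)[0]
    _, order = torch.sort(value[c_idx])                  # ascending: most confident first
    c_num_s = int(len(order) * PS)
    src_idx += [c_idx[i] for i in order[:c_num_s].tolist()]
    tar_idx += [c_idx[i] for i in order[c_num_s:].tolist()]

print(src_idx, tar_idx)                                  # pseudo-source vs. remaining target samples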
Esempio n. 39
0
    def _generate(self,
                  encoder_input,
                  beam_size=None,
                  maxlen=None,
                  prefix_tokens=None):
        """See generate"""
        src_tokens = encoder_input['src_tokens']
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen,
                     self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(**encoder_input)
            new_order = torch.arange(bsz,
                                     dtype=torch.int64).view(-1, 1).repeat(
                                         1, beam_size).view(-1)
            new_order = new_order.to(src_tokens.device)
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, new_order)
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = src_tokens.data.new(bsz * beam_size,
                                     maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size,
                                     maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn, attn_buf = None, None
        nonpad_idxs = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{
            'idx': None,
            'score': -math.inf
        } for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) *
                        beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen**self.len_penalty
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step,
                           bbsz_idx,
                           eos_scores,
                           unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.
            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step +
                                        2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(
                0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1)**self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(
                    zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                        _, alignment = hypo_attn.max(dim=0)
                    else:
                        hypo_attn = None
                        alignment = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent][
                        'score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]),
                                 key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step,
                                                      unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(
                        batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(
                            incremental_states[model], reorder_state)
                    encoder_outs[i] = model.encoder.reorder_encoder_out(
                        encoder_outs[i], reorder_state)

            lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1],
                                                   encoder_outs,
                                                   incremental_states)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1),
                                      maxlen + 2)
                    attn_buf = attn.clone()
                    nonpad_idxs = src_tokens.ne(self.pad)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:,
                                                                        0, :]
                    cand_scores = torch.gather(
                        probs_slice,
                        dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data).expand(
                            -1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(
                        bsz, cand_size).data
                    cand_beams = torch.zeros_like(cand_indices)
                else:
                    cand_scores, cand_indices, cand_beams = self.search.step(
                        step,
                        lprobs.view(bsz, -1, self.vocab_size),
                        scores.view(bsz, beam_size, -1)[:, :, :step],
                    )
            else:
                # make probs contain cumulative scores for each hypothesis
                lprobs.add_(scores[:, step - 1].unsqueeze(-1))

                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    lprobs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(
                    finalize_hypos(step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents = set()
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    finalized_sents = finalize_hypos(step, eos_bbsz_idx,
                                                     eos_scores, cand_scores)
                    num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)

                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(active_mask,
                       k=beam_size,
                       dim=1,
                       largest=False,
                       out=(_ignore, active_hypos))

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx,
                dim=1,
                index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2],
                    dim=0,
                    index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent],
                                     key=lambda r: r['score'],
                                     reverse=True)

        return finalized
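
One non-obvious step above is how the next set of active hypotheses is picked: EOS candidates are pushed past `cand_size` in `active_mask`, so a smallest-k `torch.topk` keeps only non-EOS candidates. A tiny numeric sketch of that trick with made-up values:

import torch

beam_size, cand_size = 2, 4
# Hypothetical per-sentence candidates: True where the candidate token is EOS.
eos_mask = torch.tensor([[True, False, False, False]])
cand_offsets = torch.arange(0, cand_size)

# EOS candidates get offset by cand_size, so they lose to any active candidate.
active_mask = eos_mask.long() * cand_size + cand_offsets
# -> tensor([[4, 1, 2, 3]])

# The beam_size smallest entries are the surviving (non-EOS) hypotheses.
_, active_hypos = torch.topk(active_mask, k=beam_size, dim=1, largest=False)
print(active_hypos)  # tensor([[1, 2]])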
Esempio n. 40
0
def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
    ##### parallel rendering of a single image
    ray_batch = ray_sampler.get_all()

    if (ray_batch['ray_d'].shape[0] //
            world_size) * world_size != ray_batch['ray_d'].shape[0]:
        raise Exception(
            'Number of pixels in the image is not divisible by the number of GPUs!\n\t# pixels: {}\n\t# GPUs: {}'
            .format(ray_batch['ray_d'].shape[0], world_size))

    # split into ranks; make sure different processes don't overlap
    rank_split_sizes = [
        ray_batch['ray_d'].shape[0] // world_size,
    ] * world_size
    rank_split_sizes[-1] = ray_batch['ray_d'].shape[0] - sum(
        rank_split_sizes[:-1])
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch[key] = torch.split(ray_batch[key],
                                         rank_split_sizes)[rank].to(rank)

    # split into chunks and render inside each process
    ray_batch_split = OrderedDict()
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)

    # render each chunk (forward passes only; no gradients are needed here)
    ret_merge_chunk = [OrderedDict() for _ in range(models['cascade_level'])]
    for s in range(len(ray_batch_split['ray_d'])):
        ray_o = ray_batch_split['ray_o'][s]
        ray_d = ray_batch_split['ray_d'][s]
        min_depth = ray_batch_split['min_depth'][s]

        dots_sh = list(ray_d.shape[:-1])
        for m in range(models['cascade_level']):
            net = models['net_{}'.format(m)]
            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # foreground depth
                fg_far_depth = intersect_sphere(ray_o, ray_d)  # [...,]
                fg_near_depth = min_depth  # [..., ]
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack(
                    [fg_near_depth + i * step for i in range(N_samples)],
                    dim=-1)  # [..., N_samples]

                # background depth
                bg_depth = torch.linspace(0., 1., N_samples).view([
                    1,
                ] * len(dots_sh) + [
                    N_samples,
                ]).expand(dots_sh + [
                    N_samples,
                ]).to(rank)

                # delete unused memory
                del fg_near_depth
                del step
                torch.cuda.empty_cache()
            else:
                # sample pdf and concat with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]
                                     )  # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]  # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid,
                                              weights=fg_weights,
                                              N_samples=N_samples,
                                              det=True)  # [..., N_samples]
                fg_depth, _ = torch.sort(
                    torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # sample pdf and concat with earlier samples
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]  # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid,
                                              weights=bg_weights,
                                              N_samples=N_samples,
                                              det=True)  # [..., N_samples]
                bg_depth, _ = torch.sort(
                    torch.cat((bg_depth, bg_depth_samples), dim=-1))

                # delete unused memory
                del fg_weights
                del fg_depth_mid
                del fg_depth_samples
                del bg_weights
                del bg_depth_mid
                del bg_depth_samples
                torch.cuda.empty_cache()

            with torch.no_grad():
                ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)

            for key in ret:
                if key not in ['fg_weights', 'bg_weights']:
                    if torch.is_tensor(ret[key]):
                        if key not in ret_merge_chunk[m]:
                            ret_merge_chunk[m][key] = [
                                ret[key].cpu(),
                            ]
                        else:
                            ret_merge_chunk[m][key].append(ret[key].cpu())

                        ret[key] = None

            # clean unused memory
            torch.cuda.empty_cache()

    # merge results from different chunks
    for m in range(len(ret_merge_chunk)):
        for key in ret_merge_chunk[m]:
            ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)

    # merge results from different processes
    if rank == 0:
        ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                # generate tensors to store results from other processes
                sh = list(ret_merge_chunk[m][key].shape[1:])
                ret_merge_rank[m][key] = [
                    torch.zeros(*[
                        size,
                    ] + sh, dtype=torch.float32) for size in rank_split_sizes
                ]
                torch.distributed.gather(ret_merge_chunk[m][key],
                                         ret_merge_rank[m][key])
                ret_merge_rank[m][key] = torch.cat(ret_merge_rank[m][key],
                                                   dim=0).reshape(
                                                       (ray_sampler.H,
                                                        ray_sampler.W,
                                                        -1)).squeeze()
                # print(m, key, ret_merge_rank[m][key].shape)
    else:  # send results to main process
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                torch.distributed.gather(ret_merge_chunk[m][key])

    # only rank 0 program returns
    if rank == 0:
        return ret_merge_rank
    else:
        return None
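
The fine pass above merges the coarse depth samples with new importance samples and relies on `torch.sort` to keep the combined depths monotonically increasing along each ray. A minimal sketch of that merge with fabricated samples:

import torch

# Hypothetical coarse samples (already sorted) and new importance samples for one ray.
fg_depth = torch.tensor([[0.1, 0.4, 0.7, 1.0]])
fg_depth_samples = torch.tensor([[0.35, 0.42, 0.65, 0.9]])

# Concatenate and re-sort so the renderer sees strictly increasing depths.
fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))
print(fg_depth)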
Esempio n. 41
0
    def decode(self, x_tree_vecs, prob_decode):
        assert x_tree_vecs.size(0) == 1

        stack = []
        init_hiddens = create_var( torch.zeros(1, self.hidden_size) )
        zero_pad = create_var(torch.zeros(1,1,self.hidden_size))
        contexts = create_var( torch.LongTensor(1).zero_() )

        #Root Prediction
        root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs, 'word')
        _,root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        stack.append( (root, self.vocab.get_slots(root.wid)) )

        all_nodes = [root]
        h = {}
        for step in range(MAX_DECODE_LEN):
            node_x,fa_slot = stack[-1]
            cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.embedding(cur_x)

            #Predict stop
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x,cur_h], dim=1)
            stop_hiddens = F.relu( self.U_i(stop_hiddens) )
            stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs, 'stop')

            if prob_decode:
                backtrack = (torch.bernoulli( torch.sigmoid(stop_score) ).item() == 0)
            else:
                backtrack = (stop_score.item() < 0)

            if not backtrack: #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
                pred_score = self.aggregate(new_h, contexts, x_tree_vecs, 'word')

                if prob_decode:
                    sort_wid = torch.multinomial(F.softmax(pred_score, dim=1).squeeze(), 5)
                else:
                    _,sort_wid = torch.sort(pred_score, dim=1, descending=True)
                    sort_wid = sort_wid.data.squeeze()

                next_wid = None
                for wid in sort_wid[:5]:
                    slots = self.vocab.get_slots(wid)
                    node_y = MolTreeNode(self.vocab.get_smiles(wid))
                    if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True #No more children can be added
                else:
                    node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    node_y.idx = len(all_nodes)
                    node_y.neighbors.append(node_x)
                    h[(node_x.idx,node_y.idx)] = new_h[0]
                    stack.append( (node_y,next_slots) )
                    all_nodes.append(node_y)

            if backtrack: #Backtrack, use if instead of else
                if len(stack) == 1:
                    break #At root, terminate

                node_fa,_ = stack[-2]
                cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
                h[(node_x.idx,node_fa.idx)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes
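
The decoder above either samples candidate cliques with `torch.multinomial` or ranks them all with a descending `torch.sort`, then walks the top five until one passes the slot/assembly checks. A toy sketch of the two ranking modes (the scores are made up):

import torch
import torch.nn.functional as F

pred_score = torch.tensor([[2.0, 0.5, 1.2, -0.3, 0.9]])  # hypothetical vocab scores
prob_decode = False

if prob_decode:
    # Stochastic: sample 5 candidate ids in proportion to their softmax probability.
    sort_wid = torch.multinomial(F.softmax(pred_score, dim=1).squeeze(), 5)
else:
    # Greedy: rank all ids from best to worst score.
    _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
    sort_wid = sort_wid.data.squeeze()

print(sort_wid[:5])  # candidate ids tried in order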
Esempio n. 42
0
    def add_point(self, verb_predict, gt_verbs, labels_predict, gt_labels,
                  roles):
        #encoded predictions should be batch x verbs x values #assumes the are the same order as the references
        #encoded reference should be batch x 1+ references*roles,values (sorted)

        batch_size = verb_predict.size()[0]
        for i in range(batch_size):
            verb_pred = verb_predict[i]
            gt_verb = gt_verbs[i]
            label_pred = self.rearrange_label_pred(labels_predict[i])
            gt_label = gt_labels[i]
            role_set = roles[i]
            #print('check sizes:', verb_pred.size(), gt_verb.size(), label_pred.size(), gt_label.size())
            sorted_idx = torch.sort(verb_pred, 0, True)[1]

            gt_v = gt_verb
            #print('sorted idx:',self.topk, sorted_idx[:self.topk], gt_v)
            #print('groud truth verb id:', gt_v)

            new_card = {
                "verb": 0.0,
                "value": 0.0,
                "value*": 0.0,
                "n_value": 0.0,
                "value-all": 0.0,
                "value-all*": 0.0
            }

            score_card = new_card  # alias of new_card: updates below land in the dict appended at the end

            verb_found = (torch.sum(sorted_idx[0:self.topk] == gt_v) == 1)
            if verb_found: score_card["verb"] += 1

            gt_role_count = self.encoder.get_role_count(gt_v)
            gt_role_list = self.encoder.verb2_role_dict[
                self.encoder.verb_list[gt_v]]
            score_card["n_value"] += gt_role_count

            all_found = True
            pred_list = []
            for k in range(0, self.encoder.get_max_role_count()):
                role_id = role_set[k]
                if role_id == len(self.encoder.role_list):
                    continue
                current_role = self.encoder.role_list[role_id]
                if current_role not in gt_role_list:
                    continue

                label_id = torch.max(label_pred[k], 0)[1]
                pred_list.append(label_id.item())
                found = False
                for r in range(0, self.nref):
                    gt_label_id = gt_label[r][k]
                    #print('ground truth label id = ', gt_label_id)
                    if label_id == gt_label_id:
                        found = True
                        break
                if not found: all_found = False
                #both verb and at least one val found
                if found and verb_found: score_card["value"] += 1
                #at least one val found
                if found: score_card["value*"] += 1
            '''if self.topk == 1:
                print('predicted labels :',pred_list)'''
            #both verb and all values found
            score_card["value*"] /= gt_role_count
            score_card["value"] /= gt_role_count
            if all_found and verb_found: score_card["value-all"] += 1
            #all values found
            if all_found: score_card["value-all*"] += 1

            self.score_cards.append(new_card)
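
The `verb_found` check above boils down to a top-k test: sort the verb scores in descending order and see whether the ground-truth id appears among the first `topk` indices. A minimal sketch with toy scores:

import torch

verb_pred = torch.tensor([0.1, 2.3, 0.7, 1.5])   # hypothetical scores over 4 verbs
gt_v = 3
topk = 2

sorted_idx = torch.sort(verb_pred, 0, True)[1]   # verb ids from best to worst
verb_found = (torch.sum(sorted_idx[0:topk] == gt_v) == 1)
print(bool(verb_found))  # True: verb 3 is ranked second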
Esempio n. 43
0
        scores = scores.squeeze()
        pred_boxes = pred_boxes.squeeze()
        pred_center = pred_center.squeeze()
        det_toc = time.time()
        detect_time = det_toc - det_tic
        misc_tic = time.time()
        if vis:
            im = cv2.imread(imdb.image_path_at(i))
            im2show = np.copy(im)
        for j in range(1, imdb.num_classes):
            inds = torch.nonzero(scores[:, j] > thresh).view(-1)
            # if there is det
            if inds.numel() > 0:
                cls_scores = scores[:, j][inds]
                _, order = torch.sort(cls_scores, 0, True)
                if args.class_agnostic:
                    cls_boxes = pred_boxes[inds, :]
                    cls_centers = pred_center[inds, :]
                else:
                    cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]
                    cls_centers = pred_center[inds][:, j * 2:(j + 1) * 2]

                cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
                cls_dets_with_center = torch.cat(
                    (cls_boxes, cls_centers, cls_scores.unsqueeze(1)), 1)
                # cls_dets = torch.cat((cls_boxes, cls_scores), 1)
                cls_dets = cls_dets[order]
                cls_dets_with_center = cls_dets_with_center[order]
                keep = torch.arange(
                    cls_dets.shape[0])  # nms(cls_dets, cfg.TEST.NMS)
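
Before the (commented-out) NMS step, the detections of each class are reordered by descending confidence so that suppression always keeps the higher-scoring box first. A small sketch of that reordering with invented boxes and scores:

import torch

cls_boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],
                          [50., 50., 60., 60.]])
cls_scores = torch.tensor([0.3, 0.9, 0.6])

_, order = torch.sort(cls_scores, 0, True)              # descending confidence
cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
cls_dets = cls_dets[order]
print(cls_dets)                                          # highest-scoring detection first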
Esempio n. 44
0
def build_targets_max(target, anchor_wh, nA, nC, nGh, nGw):
    """
    returns tconf, tbox, tid
    """
    nB = len(target)  # number of images in batch

    txy = torch.zeros(nB, nA, nGh, nGw,
                      2).cuda()  # batch size, anchors, grid size
    twh = torch.zeros(nB, nA, nGh, nGw, 2).cuda()
    tconf = torch.LongTensor(nB, nA, nGh, nGw).fill_(0).cuda()
    tcls = torch.ByteTensor(nB, nA, nGh, nGw,
                            nC).fill_(0).cuda()  # nC = number of classes
    tid = torch.LongTensor(nB, nA, nGh, nGw, 1).fill_(-1).cuda()
    for b in range(nB):
        t = target[b]
        t_id = t[:, 1].clone().long().cuda()
        t = t[:, [0, 2, 3, 4, 5]]
        nTb = len(t)  # number of targets
        if nTb == 0:
            continue

        #gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG
        gxy, gwh = t[:, 1:3].clone(), t[:, 3:5].clone()
        gxy[:, 0] = gxy[:, 0] * nGw
        gxy[:, 1] = gxy[:, 1] * nGh
        gwh[:, 0] = gwh[:, 0] * nGw
        gwh[:, 1] = gwh[:, 1] * nGh
        gi = torch.clamp(gxy[:, 0], min=0, max=nGw - 1).long()
        gj = torch.clamp(gxy[:, 1], min=0, max=nGh - 1).long()

        # Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors)
        #gi, gj = torch.clamp(gxy.long(), min=0, max=nG - 1).t()
        #gi, gj = gxy.long().t()

        # iou of targets-anchors (using wh only)
        box1 = gwh
        box2 = anchor_wh.unsqueeze(1)
        inter_area = torch.min(box1, box2).prod(2)
        iou = inter_area / (box1.prod(1) + box2.prod(2) - inter_area + 1e-16)

        # Select best iou_pred and anchor
        iou_best, a = iou.max(0)  # best anchor [0-2] for each target

        # Select best unique target-anchor combinations
        if nTb > 1:
            _, iou_order = torch.sort(-iou_best)  # best to worst

            # Unique anchor selection
            u = torch.stack((gi, gj, a), 0)[:, iou_order]
            # _, first_unique = np.unique(u, axis=1, return_index=True)  # first unique indices
            first_unique = return_torch_unique_index(u, torch.unique(
                u, dim=1))  # torch alternative
            i = iou_order[first_unique]
            # best anchor must share significant commonality (iou) with target
            i = i[iou_best[i] > 0.60]  # TODO: examine arbitrary threshold
            if len(i) == 0:
                continue

            a, gj, gi, t = a[i], gj[i], gi[i], t[i]
            t_id = t_id[i]
            if len(t.shape) == 1:
                t = t.view(1, 5)
        else:
            if iou_best < 0.60:
                continue

        tc, gxy, gwh = t[:, 0].long(), t[:, 1:3].clone(), t[:, 3:5].clone()
        gxy[:, 0] = gxy[:, 0] * nGw
        gxy[:, 1] = gxy[:, 1] * nGh
        gwh[:, 0] = gwh[:, 0] * nGw
        gwh[:, 1] = gwh[:, 1] * nGh

        # XY coordinates
        txy[b, a, gj, gi] = gxy - gxy.floor()

        # Width and height
        twh[b, a, gj, gi] = torch.log(gwh / anchor_wh[a])  # yolo method
        # twh[b, a, gj, gi] = torch.sqrt(gwh / anchor_wh[a]) / 2 # power method

        # One-hot encoding of label
        tcls[b, a, gj, gi, tc] = 1
        tconf[b, a, gj, gi] = 1
        tid[b, a, gj, gi] = t_id.unsqueeze(1)
    tbox = torch.cat([txy, twh], -1)
    return tconf, tbox, tid
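
When several targets land on the same (grid cell, anchor) slot, the code above sorts them by IoU and keeps only the first occurrence of each slot. A toy re-creation of that de-duplication (the IoUs and cell/anchor keys are invented, and the plain loop stands in for `return_torch_unique_index`):

import torch

iou_best = torch.tensor([0.7, 0.9, 0.8, 0.65])   # best-anchor IoU per target
keys = torch.tensor([[3, 3, 5, 3],                # grid j per target
                     [4, 4, 2, 4],                # grid i per target
                     [1, 1, 0, 1]])               # anchor index per target

_, iou_order = torch.sort(-iou_best)              # best to worst, as in the snippet
keys_sorted = keys[:, iou_order]

seen, keep = set(), []
for col in range(keys_sorted.shape[1]):
    k = tuple(keys_sorted[:, col].tolist())
    if k not in seen:                             # first (highest-IoU) hit wins the slot
        seen.add(k)
        keep.append(iou_order[col].item())

print(keep)  # target indices that keep their (cell, anchor) assignment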
Esempio n. 45
0
def _evaluate_box_proposals(dataset_predictions,
                            coco_api,
                            thresholds=None,
                            area="all",
                            limit=None):
    """
    Evaluate detection proposal recall metrics. This function is a much
    faster alternative to the official COCO API recall evaluation code. However,
    it produces slightly different results.
    """
    # Record max overlap value for each gt box
    # Return vector of overlap values
    areas = {
        "all": 0,
        "small": 1,
        "medium": 2,
        "large": 3,
        "96-128": 4,
        "128-256": 5,
        "256-512": 6,
        "512-inf": 7,
    }
    area_ranges = [
        [0**2, 1e5**2],  # all
        [0**2, 32**2],  # small
        [32**2, 96**2],  # medium
        [96**2, 1e5**2],  # large
        [96**2, 128**2],  # 96-128
        [128**2, 256**2],  # 128-256
        [256**2, 512**2],  # 256-512
        [512**2, 1e5**2],  # 512-inf
    ]
    assert area in areas, "Unknown area range: {}".format(area)
    area_range = area_ranges[areas[area]]
    gt_overlaps = []
    num_pos = 0

    for prediction_dict in dataset_predictions:
        predictions = prediction_dict["proposals"]

        # sort predictions in descending order
        # TODO maybe remove this and make it explicit in the documentation
        inds = predictions.objectness_logits.sort(descending=True)[1]
        predictions = predictions[inds]

        ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"])
        anno = coco_api.loadAnns(ann_ids)
        gt_boxes = [
            BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
            for obj in anno if obj["iscrowd"] == 0
        ]
        gt_boxes = torch.as_tensor(gt_boxes).reshape(
            -1, 4)  # guard against no boxes
        gt_boxes = Boxes(gt_boxes)
        gt_areas = torch.as_tensor(
            [obj["area"] for obj in anno if obj["iscrowd"] == 0])

        if len(gt_boxes) == 0 or len(predictions) == 0:
            continue

        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <=
                                                       area_range[1])
        gt_boxes = gt_boxes[valid_gt_inds]

        num_pos += len(gt_boxes)

        if len(gt_boxes) == 0:
            continue

        if limit is not None and len(predictions) > limit:
            predictions = predictions[:limit]

        overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)

        _gt_overlaps = torch.zeros(len(gt_boxes))
        for j in range(min(len(predictions), len(gt_boxes))):
            # find which proposal box maximally covers each gt box
            # and get the iou amount of coverage for each gt box
            max_overlaps, argmax_overlaps = overlaps.max(dim=0)

            # find which gt box is 'best' covered (i.e. 'best' = most iou)
            gt_ovr, gt_ind = max_overlaps.max(dim=0)
            assert gt_ovr >= 0
            # find the proposal box that covers the best covered gt box
            box_ind = argmax_overlaps[gt_ind]
            # record the iou coverage of this gt box
            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
            assert _gt_overlaps[j] == gt_ovr
            # mark the proposal box and the gt box as used
            overlaps[box_ind, :] = -1
            overlaps[:, gt_ind] = -1

        # append recorded iou coverage level
        gt_overlaps.append(_gt_overlaps)
    gt_overlaps = torch.cat(gt_overlaps, dim=0)
    gt_overlaps, _ = torch.sort(gt_overlaps)

    if thresholds is None:
        step = 0.05
        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
    recalls = torch.zeros_like(thresholds)
    # compute recall for each iou threshold
    for i, t in enumerate(thresholds):
        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
    # ar = 2 * np.trapz(recalls, thresholds)
    ar = recalls.mean()
    return {
        "ar": ar,
        "recalls": recalls,
        "thresholds": thresholds,
        "gt_overlaps": gt_overlaps,
        "num_pos": num_pos,
    }
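
The final recall numbers come straight from the sorted per-ground-truth overlaps: for each IoU threshold, recall is the fraction of ground-truth boxes whose best matched proposal meets that threshold. A minimal sketch with fabricated overlap values:

import torch

gt_overlaps = torch.tensor([0.2, 0.55, 0.8, 0.95])   # best IoU per gt box (toy values)
num_pos = gt_overlaps.numel()
gt_overlaps, _ = torch.sort(gt_overlaps)

thresholds = torch.arange(0.5, 0.95 + 1e-5, 0.05)
recalls = torch.zeros_like(thresholds)
for i, t in enumerate(thresholds):
    recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)

print(recalls)  # e.g. 0.75 recall at IoU 0.5, dropping as the threshold rises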
Esempio n. 46
0
def pcl_to_prim_loss(
    y_hat,
    X_transformed,
    D,
    use_cuboids=False,
    use_sq=False,
    use_chamfer=False
):
    """
    Arguments:
    ----------
        y_hat: List of Tensors containing the predictions of the network
        X_transformed: Tensor with size BxNxMx3 with the N points from the
                       target object transformed in the M primitive-centric
                       coordinate systems
        D: Tensor of size BxMxSxN that contains the pairwise distances between
           points on the surface of the SQ to the points on the target object
        use_cuboids: when True use cuboids as geometric primitives
        use_sq: when True use superquadrics as geometric primitives
        use_chamfer: when True compute the Chamfer distance
    """
    # Declare some variables
    B = X_transformed.shape[0]  # batch size
    N = X_transformed.shape[1]  # number of points per sample
    M = X_transformed.shape[2]  # number of primitives

    shapes = y_hat[3].view(B, M, 3)
    epsilons = y_hat[4].view(B, M, 2)
    probs = y_hat[0]

    # Get the relative position of points with respect to the SQs using the
    # inside-outside function
    F = shapes.new_tensor(0)
    inside = None

    # XXX
    # if not use_chamfer: # you should still calculate the F's regardless...
    if True:
        if use_cuboids:
            F = points_to_cuboid_distances(X_transformed, shapes)
            inside = F <= 0
        elif use_sq:
            F = inside_outside_function(
                X_transformed,
                shapes,
                epsilons
            )
            inside = F <= 1
        else:
            # If no argument is given (use_sq and use_cuboids) the default
            # geometric primitives are cuboidal superquadrics, namely
            # with \epsilon_1=\epsilon_2=0.25
            F = cuboid_inside_outside_function(
                X_transformed,
                shapes,
                epsilon=0.25
            )
            inside = F <= 1

    D = torch.min(D, 2)[0].permute(0, 2, 1)  # size BxNxM
    assert D.shape == (B, N, M)

    if not use_chamfer:
        D[inside] = 0.0
    distances, idxs = torch.sort(D, dim=-1)

    # Start by computing the cumulative product
    # Sort based on the indices
    probs = torch.cat([
        probs[i].take(idxs[i]).unsqueeze(0) for i in range(len(idxs))
    ])
    neg_cumprod = torch.cumprod(1-probs, dim=-1)
    neg_cumprod = torch.cat(
        [neg_cumprod.new_ones((B, N, 1)), neg_cumprod[:, :, :-1]],
        dim=-1
    )

    # minprob[i, j, k] is the probability that for sample i and point j the
    # k-th primitive has the minimum loss
    minprob = probs.mul(neg_cumprod)

    loss = torch.einsum("ijk,ijk->", [distances, minprob])
    loss = loss / B / N

    # Return some debug statistics
    debug_stats = {}
    debug_stats["F"] = F
    debug_stats["distances"] = distances
    debug_stats["minprob"] = minprob
    debug_stats["neg_cumprod"] = neg_cumprod
    return loss, inside, debug_stats
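
The `minprob` trick above computes, for every point, the probability that the k-th closest primitive is the first one that actually exists: distances are sorted ascending, the existence probabilities are permuted accordingly, and a shifted cumulative product of (1 - p) encodes "no closer primitive exists". A toy sketch for a single point (probabilities and distances are made up):

import torch

d = torch.tensor([0.3, 0.1, 0.6])        # point-to-primitive distances
p = torch.tensor([0.9, 0.2, 0.7])        # primitive existence probabilities

distances, idxs = torch.sort(d)           # ascending: nearest primitive first
probs = p[idxs]

neg_cumprod = torch.cumprod(1 - probs, dim=-1)
neg_cumprod = torch.cat([neg_cumprod.new_ones(1), neg_cumprod[:-1]])
minprob = probs * neg_cumprod             # prob. this primitive is the nearest existing one

expected_min_dist = (distances * minprob).sum()
print(minprob, expected_min_dist)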
Esempio n. 47
0
                                             batch_size=args.size,
                                             drop_last=True,
                                             shuffle=False,
                                             num_workers=int(cfg.WORKERS))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load("./networks/pretrain/text_encoder.pth",
                            map_location=torch.device('cpu'))
    text_encoder.load_state_dict(state_dict)

    for step, data in enumerate(dataloader):
        shape, cap, cap_len, cls_id, key = data

        sorted_cap_lens, sorted_cap_indices = torch.sort(cap_len, 0, True)

        shapes = shape[sorted_cap_indices].squeeze()
        captions = cap[sorted_cap_indices].squeeze()
        cap_len = cap_len[sorted_cap_indices].squeeze()
        key = np.asarray(key)
        key = key[sorted_cap_indices].squeeze()
        class_ids = cls_id[sorted_cap_indices].squeeze().numpy()

        hidden = text_encoder.init_hidden(args.size)
        words_emb, sent_emb = text_encoder(captions, sorted_cap_lens, hidden)

        sent = sent_emb.cpu().detach().numpy()

        # np.savetxt("./plts/train.tsv", sent, delimiter="\t")
        # np.savetxt("./plts/label.tsv", key, delimiter="\t", fmt="%s")
Esempio n. 48
0
    def _generate(self,
                  src_tokens,
                  src_lengths,
                  beam_size=None,
                  maxlen=None,
                  prefix_tokens=None):
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen,
                     self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(
                src_tokens.repeat(1, beam_size).view(-1, srclen),
                src_lengths.repeat(beam_size),
            )
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = src_tokens.data.new(bsz * beam_size,
                                     maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size,
                                     maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
        attn_buf = attn.clone()

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{
            'idx': None,
            'score': -math.inf
        } for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) *
                        beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step,
                           bbsz_idx,
                           eos_scores,
                           unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step +
                                        2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2]

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1)**self.len_penalty

            sents_seen = set()
            for i, (idx, score) in enumerate(
                    zip(
                        bbsz_idx.tolist(),
                        eos_scores.tolist(),
                    ), ):
                sent = idx // beam_size
                sents_seen.add(sent)

                def get_hypo():
                    _, alignment = attn_clone[i].max(dim=0)
                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': attn_clone[i],  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent][
                        'score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(
                        enumerate(finalized[sent]),
                        key=lambda r: r[1]['score'],
                    )
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            # return number of hypotheses finished this step
            num_finished = 0
            for sent in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step,
                                                      unfinalized_scores):
                    finished[sent] = True
                    num_finished += 1
            return num_finished

        reorder_state = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                for model in self.models:
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(
                            incremental_states[model], reorder_state)

            decode_output = self._decode(tokens[:, :step + 1], encoder_outs,
                                         incremental_states)
            if len(decode_output) == 3:
                probs, avg_attn_scores, possible_translation_tokens = decode_output
            else:
                probs, avg_attn_scores = decode_output
                possible_translation_tokens = None
            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
                scores = scores.type_as(probs)
                scores_buf = scores_buf.type_as(probs)
            else:
                # make probs contain cumulative scores for each hypothesis
                probs.add_(scores[:, step - 1].view(-1, 1))
            probs[:, self.pad] = -math.inf  # never select pad
            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty
            # external lexicon penalty
            probs[:, self.lexicon_indices] -= self.lexicon_penalty

            probs += self.word_reward
            probs[:, self.eos] -= self.word_reward

            # Record attention scores
            attn[:, :, step + 1].copy_(avg_attn_scores)

            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
            cand_beams = buffer('cand_beams')
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice,
                        dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data,
                    ).expand(-1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(
                        bsz, cand_size).data
                    cand_beams.resize_as_(cand_indices).fill_(0)
                else:
                    # take the best 2 x beam_size predictions. We'll choose the first
                    # beam_size of these which don't predict eos to continue with.
                    torch.topk(
                        probs.view(bsz, -1),
                        k=min(cand_size,
                              probs.view(bsz, -1).size(1) -
                              1),  # -1 so we never select pad
                        out=(cand_scores, cand_indices),
                    )
                    possible_tokens_size = self.vocab_size
                    if possible_translation_tokens is not None:
                        possible_tokens_size = possible_translation_tokens.size(
                            0)
                    torch.div(cand_indices,
                              possible_tokens_size,
                              out=cand_beams)
                    cand_indices.fmod_(possible_tokens_size)
                    if possible_translation_tokens is not None:
                        possible_translation_tokens = possible_translation_tokens.view(
                            1,
                            possible_tokens_size,
                        ).expand(
                            cand_indices.size(0),
                            possible_tokens_size,
                        ).data
                        cand_indices = torch.gather(
                            possible_translation_tokens,
                            dim=1,
                            index=cand_indices,
                            out=cand_indices,
                        )
            else:
                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    probs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx,
                                                     eos_scores)
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add_(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    num_remaining_sent -= finalize_hypos(
                        step, eos_bbsz_idx, eos_scores, cand_scores)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(
                active_mask,
                k=beam_size,
                dim=1,
                largest=False,
                out=(_ignore, active_hypos),
            )
            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx,
                dim=1,
                index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )
            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            torch.index_select(
                attn[:, :, :step + 2],
                dim=0,
                index=active_bbsz_idx,
                out=attn_buf[:, :, :step + 2],
            )

            # swap buffers
            old_tokens = tokens
            tokens = tokens_buf
            tokens_buf = old_tokens
            old_scores = scores
            scores = scores_buf
            scores_buf = old_scores
            old_attn = attn
            attn = attn_buf
            attn_buf = old_attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(bsz):
            finalized[sent] = sorted(
                finalized[sent],
                key=lambda r: r['score'],
                reverse=True,
            )

        return finalized
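
The active-hypothesis selection above hinges on a small trick: candidates that emitted EOS are pushed past cand_size so the beam_size smallest entries per row are exactly the surviving hypotheses. A minimal, self-contained sketch with hypothetical sizes (not the fairseq buffers themselves):

import torch

beam_size = 3
cand_size = 2 * beam_size

cand_offsets = torch.arange(cand_size)          # [0, 1, ..., cand_size - 1]
eos_mask = torch.tensor([[1, 0, 0, 1, 0, 0],
                         [0, 0, 1, 0, 0, 1]])   # candidates that emitted EOS

# EOS candidates get an extra offset of cand_size, so they sort after the active
# ones; the indices of the beam_size smallest values are the surviving hypotheses.
active_mask = eos_mask * cand_size + cand_offsets
_, active_hypos = torch.topk(active_mask, k=beam_size, dim=1, largest=False)
print(active_hypos)  # tensor([[1, 2, 4], [0, 1, 3]])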
Esempio n. 49
0
def train_model(model, data, optim, epoch, params, config, device, writer):
    model.train()
    train_loader = data["train_loader"]
    log_vars = defaultdict(float)

    for src, tgt, src_len, tgt_len, original_src, original_tgt, knowledge, knowledge_len in train_loader:
        # put the tensors on cuda devices
        src, tgt = src.to(device), tgt.to(device)
        src_len, tgt_len = src_len.to(device), tgt_len.to(device)
        if config.knowledge:
            knowledge, knowledge_len = knowledge.to(device), knowledge_len.to(
                device)
        # original_src, original_tgt = original_src.to(device), original_tgt.to(device)

        model.zero_grad()

        # sort the lengths in descending order for the RNN
        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        # select by the indices
        src = torch.index_select(src, dim=0, index=indices)  # [batch, len]
        tgt = torch.index_select(tgt, dim=0, index=indices)  # [batch, len]
        if config.knowledge:
            knowledge = torch.index_select(knowledge, dim=0, index=indices)
            knowledge_len = torch.index_select(knowledge_len,
                                               dim=0,
                                               index=indices)
        dec = tgt[:, :-1]  # [batch, len]
        targets = tgt[:, 1:]  # [batch, len]

        try:
            if config.schesamp:
                if epoch > 8:
                    e = epoch - 8
                    return_dict, outputs = model(src,
                                                 lengths,
                                                 dec,
                                                 targets,
                                                 knowledge,
                                                 knowledge_len,
                                                 teacher_ratio=0.9**e)
                else:
                    return_dict, outputs = model(src, lengths, dec, targets,
                                                 knowledge, knowledge_len)
            else:
                return_dict, outputs = model(src, lengths, dec, targets,
                                             knowledge, knowledge_len)
            # outputs: [len, batch, size]
            pred = outputs.max(2)[1]
            targets = targets.t()
            num_correct = (pred.eq(targets).masked_select(targets.ne(
                utils.PAD)).sum().item())
            num_total = targets.ne(utils.PAD).sum().item()

            return_dict["mle_loss"] = torch.sum(
                return_dict["mle_loss"]) / num_total
            if config.rl:
                return_dict["total_loss"] = (
                    return_dict["mle_loss"] +
                    config.rl_coef * return_dict["rl_loss"])
            else:
                return_dict["total_loss"] = return_dict["mle_loss"]
            return_dict["total_loss"].backward()
            optim.step()

            for key in return_dict:
                log_vars[key] += return_dict[key].item()
            params["report_total_loss"] += return_dict["total_loss"].item()
            params["report_correct"] += num_correct
            params["report_total"] += num_total

        except RuntimeError as e:
            if "out of memory" in str(e):
                print("| WARNING: ran out of memory")
                if hasattr(torch.cuda, "empty_cache"):
                    torch.cuda.empty_cache()
            else:
                raise e

        # utils.progress_bar(params['updates'], config.eval_interval)
        params["updates"] += 1

        if params["updates"] % config.report_interval == 0:
            # print(
            #     "epoch: %3d, loss: %6.3f, time: %6.3f, updates: %8d, accuracy: %2.2f\n"
            #     % (
            #         epoch,
            #         params["report_total_loss"] / config.report_interval,
            #         time.time() - params["report_time"],
            #         params["updates"],
            #         params["report_correct"] * 100.0 / params["report_total"],
            #     )
            # )

            for key in return_dict:
                writer.add_scalar(
                    f"train/{key}",
                    log_vars[key] / config.report_interval,
                    params["updates"],
                )
            # writer.add_scalar("train" + "/lr", optim.lr, params['updates'])
            writer.add_scalar(
                "train" + "/accuracy",
                params["report_correct"] / params["report_total"],
                params["updates"],
            )

            log_vars = defaultdict(float)
            params["report_total_loss"], params["report_time"] = 0, time.time()
            params["report_correct"], params["report_total"] = 0, 0

        if params["updates"] % config.eval_interval == 0:
            print("evaluating after %d updates...\r" % params["updates"])
            score = eval_model(model, data, params, config, device, writer)
            for metric in config.metrics:
                params[metric].append(score[metric])
                if score[metric] >= max(params[metric]):
                    with codecs.open(
                            params["log_path"] + "best_" + metric +
                            "_prediction.txt",
                            "w",
                            "utf-8",
                    ) as f:
                        f.write(
                            codecs.open(params["log_path"] + "candidate.txt",
                                        "r", "utf-8").read())
                    save_model(
                        params["log_path"] + "best_" + metric +
                        "_checkpoint.pt",
                        model,
                        optim,
                        params["updates"],
                        config,
                    )
                writer.add_scalar("valid" + "/" + metric, score[metric],
                                  params["updates"])
            model.train()

        if params["updates"] % config.save_interval == 0:
            if config.save_individual:
                save_model(
                    params["log_path"] + str(params["updates"]) +
                    "checkpoint.pt",
                    model,
                    optim,
                    params["updates"],
                    config,
                )
            save_model(
                params["log_path"] + "checkpoint.pt",
                model,
                optim,
                params["updates"],
                config,
            )

    if config.epoch_decay:
        optim.updateLearningRate(epoch)
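
The sort-then-restore pattern used before packing above can be checked in isolation; a minimal sketch with hypothetical lengths, showing that sorting the sort indices yields the inverse permutation:

import torch

src_len = torch.tensor([3, 7, 5, 2])                       # hypothetical lengths
lengths, indices = torch.sort(src_len, dim=0, descending=True)

# after the model has run on the length-sorted batch, the original order can be
# recovered by sorting the sort indices themselves
_, inverse = torch.sort(indices, dim=0)
assert torch.equal(lengths[inverse], src_len)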
Esempio n. 50
0
def write_results(prediction, confidence, num_classes, nms_conf=0.4):
    conf_mask = (prediction[:, :, 4] > confidence).float().unsqueeze(2)
    prediction = prediction * conf_mask

    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = (prediction[:, :, 0] - prediction[:, :, 2] / 2)
    box_corner[:, :, 1] = (prediction[:, :, 1] - prediction[:, :, 3] / 2)
    box_corner[:, :, 2] = (prediction[:, :, 0] + prediction[:, :, 2] / 2)
    box_corner[:, :, 3] = (prediction[:, :, 1] + prediction[:, :, 3] / 2)
    prediction[:, :, :4] = box_corner[:, :, :4]

    batch_size = prediction.size(0)

    write = False

    for ind in range(batch_size):
        image_pred = prediction[ind]  #image Tensor
        #confidence thresholding
        #NMS

        max_conf, max_conf_score = torch.max(image_pred[:, 5:5 + num_classes],
                                             1)
        max_conf = max_conf.float().unsqueeze(1)
        max_conf_score = max_conf_score.float().unsqueeze(1)
        seq = (image_pred[:, :5], max_conf, max_conf_score)
        image_pred = torch.cat(seq, 1)

        non_zero_ind = (torch.nonzero(image_pred[:, 4]))
        try:
            image_pred_ = image_pred[non_zero_ind.squeeze(), :].view(-1, 7)
        except Exception:
            continue

        if image_pred_.shape[0] == 0:
            continue

        #Get the various classes detected in the image
        img_classes = unique(image_pred_[:, -1])  # -1 index holds the class index

        for cls in img_classes:
            #perform NMS

            #get the detections with one particular class
            cls_mask = image_pred_ * (image_pred_[:, -1]
                                      == cls).float().unsqueeze(1)
            class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze()
            image_pred_class = image_pred_[class_mask_ind].view(-1, 7)

            #sort the detections such that the entry with the maximum objectness
            #confidence is at the top
            conf_sort_index = torch.sort(image_pred_class[:, 4],
                                         descending=True)[1]
            image_pred_class = image_pred_class[conf_sort_index]
            idx = image_pred_class.size(0)  #Number of detections

            for i in range(idx):
                #Get the IOUs of all boxes that come after the one we are looking at
                #in the loop
                try:
                    ious = bbox_iou(image_pred_class[i].unsqueeze(0),
                                    image_pred_class[i + 1:])
                except ValueError:
                    break

                except IndexError:
                    break

                #Zero out all the detections that have IoU > threshold
                iou_mask = (ious < nms_conf).float().unsqueeze(1)
                image_pred_class[i + 1:] *= iou_mask

                #Remove the non-zero entries
                non_zero_ind = torch.nonzero(image_pred_class[:, 4]).squeeze()
                image_pred_class = image_pred_class[non_zero_ind].view(-1, 7)

            batch_ind = image_pred_class.new(
                image_pred_class.size(0), 1
            ).fill_(
                ind
            )  #Repeat the batch_id for as many detections of the class cls in the image
            seq = batch_ind, image_pred_class

            if not write:
                output = torch.cat(seq, 1)
                write = True
            else:
                out = torch.cat(seq, 1)
                output = torch.cat((output, out))

    try:
        return output
    except NameError:
        # no detections survived for any image in the batch, so `output` was never created
        return 0
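
The `bbox_iou` and `unique` helpers are assumed to exist elsewhere in this codebase and are not shown here; as a rough illustration only, a minimal corner-format IoU sketch (hypothetical name, not the repository's implementation):

import torch

def bbox_iou_sketch(box, boxes, eps=1e-16):
    # IoU between one (x1, y1, x2, y2) box and a set of boxes of the same format
    x1 = torch.max(box[0], boxes[:, 0])
    y1 = torch.max(box[1], boxes[:, 1])
    x2 = torch.min(box[2], boxes[:, 2])
    y2 = torch.min(box[3], boxes[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area_a = (box[2] - box[0]) * (box[3] - box[1])
    area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area_a + area_b - inter + eps)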
Esempio n. 51
0
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
    """
    Removes detections with lower object confidence score than 'conf_thres'
    Non-Maximum Suppression to further filter detections.
    Returns detections with shape:
        (x1, y1, x2, y2, object_conf, class_score, class_pred)
    """

    output = [None for _ in range(len(prediction))]
    for image_i, pred in enumerate(prediction):
        # Experiment: Prior class size rejection
        # x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
        # a = w * h  # area
        # ar = w / (h + 1e-16)  # aspect ratio
        # n = len(w)
        # log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
        # shape_likelihood = np.zeros((n, 60), dtype=np.float32)
        # x = np.concatenate((log_w.reshape(-1, 1), log_h.reshape(-1, 1)), 1)
        # from scipy.stats import multivariate_normal
        # for c in range(60):
        # shape_likelihood[:, c] =
        #   multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2])

        # Filter out confidence scores below threshold
        class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
        v = pred[:, 4] > conf_thres
        v = v.nonzero().squeeze()
        if len(v.shape) == 0:
            v = v.unsqueeze(0)

        pred = pred[v]
        class_prob = class_prob[v]
        class_pred = class_pred[v]

        # If none are remaining => process next image
        nP = pred.shape[0]
        if not nP:
            continue

        # From (center x, center y, width, height) to (x1, y1, x2, y2)
        pred[:, :4] = xywh2xyxy(pred[:, :4])

        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_prob, class_pred)
        detections = torch.cat((pred[:, :5], class_prob.float().unsqueeze(1), class_pred.float().unsqueeze(1)), 1)
        # Iterate through all predicted classes
        unique_labels = detections[:, -1].cpu().unique().to(prediction.device)

        nms_style = 'OR'  # 'OR' (default), 'AND', 'MERGE' (experimental)
        for c in unique_labels:
            # Get the detections with class c
            dc = detections[detections[:, -1] == c]
            # Sort the detections by maximum object confidence
            _, conf_sort_index = torch.sort(dc[:, 4] * dc[:, 5], descending=True)
            dc = dc[conf_sort_index]

            # Non-maximum suppression
            det_max = []
            ind = list(range(len(dc)))
            if nms_style == 'OR':  # default
                while len(ind):
                    j = ind[0]
                    det_max.append(dc[j:j + 1])  # save highest conf detection
                    reject = bbox_iou(dc[j], dc[ind]) > nms_thres
                    [ind.pop(i) for i in reversed(reject.nonzero())]
                # while dc.shape[0]:  # SLOWER METHOD
                #     det_max.append(dc[:1])  # save highest conf detection
                #     if len(dc) == 1:  # Stop if we're at the last detection
                #         break
                #     iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                #     dc = dc[1:][iou < nms_thres]  # remove ious > threshold

                # Image      Total          P          R        mAP
                #  4964       5000      0.629      0.594      0.586

            elif nms_style == 'AND':  # requires overlap, single boxes erased
                while len(dc) > 1:
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    if iou.max() > 0.5:
                        det_max.append(dc[:1])
                    dc = dc[1:][iou < nms_thres]  # remove ious > threshold

            elif nms_style == 'MERGE':  # weighted mixture box
                while len(dc) > 0:
                    iou = bbox_iou(dc[0], dc[0:])  # iou with other boxes
                    i = iou > nms_thres

                    weights = dc[i, 4:5] * dc[i, 5:6]
                    dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
                    det_max.append(dc[:1])
                    dc = dc[iou < nms_thres]

                # Image      Total          P          R        mAP
                #  4964       5000      0.633      0.598      0.589  # normal

            if len(det_max) > 0:
                det_max = torch.cat(det_max)
                # Add max detections to outputs
                output[image_i] = det_max if output[image_i] is None else torch.cat((output[image_i], det_max))

    return output
Esempio n. 52
0
def write_results_half(prediction,
                       confidence,
                       num_classes,
                       nms=True,
                       nms_conf=0.4):
    conf_mask = (prediction[:, :, 4] > confidence).half().unsqueeze(2)
    prediction = prediction * conf_mask

    try:
        ind_nz = torch.nonzero(prediction[:, :, 4]).transpose(0,
                                                              1).contiguous()
    except Exception:
        return 0

    box_a = prediction.new(prediction.shape)
    box_a[:, :, 0] = (prediction[:, :, 0] - prediction[:, :, 2] / 2)
    box_a[:, :, 1] = (prediction[:, :, 1] - prediction[:, :, 3] / 2)
    box_a[:, :, 2] = (prediction[:, :, 0] + prediction[:, :, 2] / 2)
    box_a[:, :, 3] = (prediction[:, :, 1] + prediction[:, :, 3] / 2)
    prediction[:, :, :4] = box_a[:, :, :4]

    batch_size = prediction.size(0)

    output = prediction.new(1, prediction.size(2) + 1)
    write = False

    for ind in range(batch_size):
        #select the image from the batch
        image_pred = prediction[ind]

        #Get the class having maximum score, and the index of that class
        #Get rid of num_classes softmax scores
        #Add the class index and the class score of class having maximum score
        max_conf, max_conf_score = torch.max(image_pred[:, 5:5 + num_classes],
                                             1)
        max_conf = max_conf.half().unsqueeze(1)
        max_conf_score = max_conf_score.half().unsqueeze(1)
        seq = (image_pred[:, :5], max_conf, max_conf_score)
        image_pred = torch.cat(seq, 1)

        #Get rid of the zero entries
        non_zero_ind = (torch.nonzero(image_pred[:, 4]))
        try:
            image_pred_ = image_pred[non_zero_ind.squeeze(), :]
        except Exception:
            continue

        #Get the various classes detected in the image
        img_classes = unique(image_pred_[:, -1].long()).half()

        #We will do NMS class-wise
        for cls in img_classes:
            #get the detections with one particular class
            cls_mask = image_pred_ * (image_pred_[:, -1]
                                      == cls).half().unsqueeze(1)
            class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze()

            image_pred_class = image_pred_[class_mask_ind]

            #sort the detections such that the entry with the maximum objectness
            #confidence is at the top
            conf_sort_index = torch.sort(image_pred_class[:, 4],
                                         descending=True)[1]
            image_pred_class = image_pred_class[conf_sort_index]
            idx = image_pred_class.size(0)

            #if nms has to be done
            if nms:
                #For each detection
                for i in range(idx):
                    #Get the IOUs of all boxes that come after the one we are looking at
                    #in the loop
                    try:
                        ious = bbox_iou(image_pred_class[i].unsqueeze(0),
                                        image_pred_class[i + 1:])
                    except ValueError:
                        break

                    except IndexError:
                        break

                    #Zero out all the detections that have IoU > threshold
                    iou_mask = (ious < nms_conf).half().unsqueeze(1)
                    image_pred_class[i + 1:] *= iou_mask

                    #Remove the non-zero entries
                    non_zero_ind = torch.nonzero(
                        image_pred_class[:, 4]).squeeze()
                    image_pred_class = image_pred_class[non_zero_ind]

            #Concatenate the batch_id of the image to the detection
            #this helps us identify which image the detection corresponds to
            #We use a linear structure to hold ALL the detections from the batch
            #the batch_dim is flattened
            #batch is identified by an extra batch column
            batch_ind = image_pred_class.new(image_pred_class.size(0),
                                             1).fill_(ind)
            seq = batch_ind, image_pred_class

            if not write:
                output = torch.cat(seq, 1)
                write = True
            else:
                out = torch.cat(seq, 1)
                output = torch.cat((output, out))

    return output
Esempio n. 53
0
def sort_pack_padded_sequence(input, lengths):
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
    inv_ix = indices.clone()
    inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
    return tmp, inv_ix
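
A hypothetical companion sketch showing how `inv_ix` restores the original batch order after the packed sequence has been fed through an RNN (the GRU and tensor sizes here are illustrative only):

import torch
from torch.nn.utils.rnn import pad_packed_sequence

def pad_unsort_packed_sequence(packed, inv_ix):
    # unpack to [batch, max_len, hidden] and undo the length sort
    padded, _ = pad_packed_sequence(packed, batch_first=True)
    return padded[inv_ix]

lengths = torch.tensor([2, 5, 3])
inputs = torch.randn(3, 5, 4)                   # [batch, max_len, feat]
packed, inv_ix = sort_pack_padded_sequence(inputs, lengths)
out_packed, _ = torch.nn.GRU(4, 8, batch_first=True)(packed)
out = pad_unsort_packed_sequence(out_packed, inv_ix)
print(out.shape)                                # torch.Size([3, 5, 8])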
Esempio n. 54
0
    def __call__(self, batch):
        """Collate's training batch from normalized text and mel-spectrogram
        PARAMS
        ------
        batch: [text_normalized, mel_normalized]
        """
        # Right zero-pad all one-hot text sequences to max input length

        # print('batch : ', batch)
        # print('======batch len======', len(batch))


        # gather the style images so their shapes can be inspected
        img_list = []
        for m in range(len(batch)):
            img = batch[m][4]
            img_list.append(img)
            # print('image.shape : ========> ', img.shape)

        shape_list = [x.shape for x in img_list]
        # print('shape_list : ', shape_list)
        img_level = shape_list[0][0]
        y_max = max(shape_list, key=lambda x: x[1])[1]  # maximum height
        x_max = max(shape_list, key=lambda x: x[2])[2]  # maximum width

        # input_lengths: the input text lengths, used to find the maximum input length, e.g. [100, 80, 50, 30]
        # ids_sorted_decreasing: the ordering of the inputs when sorted by length, e.g. [2, 3, 0, 1]
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]


        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]][0]
            text_padded[i, :text.size(0)] = text

        # Right zero-pad mel-spec
        num_mels = batch[0][1].size(0)
        max_target_len = max([x[1].size(1) for x in batch])
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
            assert max_target_len % self.n_frames_per_step == 0

        # include mel padded and gate padded and speaker ids
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        gate_padded = torch.FloatTensor(len(batch), max_target_len)
        gate_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        speaker_ids = torch.LongTensor(len(batch))
        style_img = torch.FloatTensor(len(batch), img_level, y_max, x_max)
        style_img.zero_()


        # print('data_function, TextMelCollate1 ====> ', style_img)
        # print('data_function, TextMelCollate1 shape ====> ', style_img.shape)
        for i in range(len(ids_sorted_decreasing)):
            mel = batch[ids_sorted_decreasing[i]][1]
            mel_padded[i, :, :mel.size(1)] = mel
            gate_padded[i, mel.size(1)-1:] = 1
            output_lengths[i] = mel.size(1)
            speaker_ids[i] = batch[ids_sorted_decreasing[i]][3]
            img = batch[ids_sorted_decreasing[i]][4]
            style_img[i, :, :img.shape[1], :img.shape[2]] = img


        # count number of items - characters in text
        len_x = [x[2] for x in batch]
        len_x = torch.Tensor(len_x)

        # for i in range(len(ids_sorted_decreasing)):
        #     print('mel_padded.shape : ====> ', mel_padded[i].shape)
        # for j in range(len(ids_sorted_decreasing)):
        #     print('style_img.shape : =======> ', style_img[j].shape)

        # print('data_function, TextMelCollate2 ====> ', style_img)
        # print('data_function, TextMelCollate1 shape ====> ', style_img.shape)
        return text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths, len_x, speaker_ids, style_img
Esempio n. 55
0
    eval_model = lambda model: test(model=model, cfg=opt.cfg, data=opt.data)
    obtain_num_parameters = lambda model: sum(
        [param.nelement() for param in model.parameters()])

    with torch.no_grad():
        print("\nlet's test the original model first:")
        origin_model_metric = eval_model(model)
    origin_nparameters = obtain_num_parameters(model)

    CBL_idx, Conv_idx, shortcut_idx = parse_module_defs4(model.module_defs)
    print('all shortcut_idx:', [i + 1 for i in shortcut_idx])

    bn_weights = gather_bn_weights(model.module_list, shortcut_idx)

    sorted_bn = torch.sort(bn_weights)[0]

    # highest_thre = torch.zeros(len(shortcut_idx))
    # for i, idx in enumerate(shortcut_idx):
    #     highest_thre[i] = model.module_list[idx][1].weight.data.abs().max().clone()
    # _, sorted_index_thre = torch.sort(highest_thre)

    # The layer-selection strategy is changed here: instead of sorting by the maximum value,
    # we sort by the mean, which usually performs slightly better (though not always; feel free
    # to switch back and try). The four commented lines above are the original strategy.
    bn_mean = torch.zeros(len(shortcut_idx))
    for i, idx in enumerate(shortcut_idx):
        bn_mean[i] = model.module_list[idx][1].weight.data.abs().mean().clone()
    _, sorted_index_thre = torch.sort(bn_mean)

    prune_shortcuts = torch.tensor(shortcut_idx)[[
        sorted_index_thre[:opt.shortcuts]
    ]]
Esempio n. 56
0
    def sort_scores(self):
        "Sort the scores."
        return torch.sort(self.scores, 0, True)
Esempio n. 57
0
    def forward(self, v, mask):
        if self.batch_first:
            v = v.transpose(0, 1)

        # layer normalization
        if self.enable_layer_norm:
            seq_len, batch, input_size = v.shape
            v = v.view(-1, input_size)
            v = self.layer_norm(v)
            v = v.view(seq_len, batch, input_size)

        # get sorted v
        lengths = mask.eq(1).long().sum(1)
        lengths_sort, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)

        v_sort = v.index_select(1, idx_sort)

        # remove zero-length sequences
        zero_idx = lengths_sort.nonzero()[-1][0].item() + 1
        zeros_len = lengths_sort.shape[0] - zero_idx

        lengths_sort = lengths_sort[:zero_idx]
        v_sort = v_sort[:, :zero_idx, :]

        # rnn
        v_pack = torch.nn.utils.rnn.pack_padded_sequence(v_sort, lengths_sort)
        v_dropout = self.dropout.forward(v_pack.data)
        v_pack_dropout = torch.nn.utils.rnn.PackedSequence(
            v_dropout, v_pack.batch_sizes)

        o_pack_dropout, o_last = self.hidden.forward(v_pack_dropout)
        o, _ = torch.nn.utils.rnn.pad_packed_sequence(o_pack_dropout)

        # get the last time state
        if isinstance(o_last, tuple):
            o_last = o_last[0]  # if LSTM cell used

        _, batch, hidden_size = o_last.size()
        o_last = o_last.view(self.num_layers, -1, batch, hidden_size)
        o_last = o_last[-1, :].transpose(0, 1).contiguous().view(batch, -1)

        # len_idx = (lengths_sort - 1).view(-1, 1).expand(-1, o.size(2)).unsqueeze(0)
        # o_last = o.gather(0, len_idx)
        # o_last = o_last.squeeze(0)

        # padding for output and output last state
        if zeros_len > 0:
            o_padding_zeros = o.new_zeros(o.shape[0], zeros_len, o.shape[2])
            o = torch.cat([o, o_padding_zeros], dim=1)

            o_last_padding_zeros = o_last.new_zeros(zeros_len, o_last.shape[1])
            o_last = torch.cat([o_last, o_last_padding_zeros], dim=0)

        # unsort o back to the original batch order
        o_unsort = o.index_select(
            1, idx_unsort)  # Note that here first dim is seq_len
        o_last_unsort = o_last.index_select(0, idx_unsort)

        if self.batch_first:
            o_unsort = o_unsort.transpose(0, 1)

        return o_unsort, o_last_unsort
Esempio n. 58
0
def ranking_and_hits(model, dev_rank_batcher, vocab, name):
    log.info('')
    log.info('-' * 50)
    log.info(name)
    log.info('-' * 50)
    log.info('')
    hits_left = []
    hits_right = []
    hits = []
    ranks = []
    ranks_left = []
    ranks_right = []
    mrr_left = []
    mrr_right = []
    rel2ranks = {}
    for i in range(10):
        hits_left.append([])
        hits_right.append([])
        hits.append([])

    for i, str2var in enumerate(dev_rank_batcher):
        e1 = str2var['e1']
        e2 = str2var['e2']
        rel = str2var['rel']
        e2_multi1 = str2var['e2_multi1'].float()
        e2_multi2 = str2var['e2_multi2'].float()
        pred1 = model.forward(e1, rel)
        pred2 = model.forward(e2, rel)
        pred1, pred2 = pred1.data, pred2.data
        e1, e2 = e1.data, e2.data
        e2_multi1, e2_multi2 = e2_multi1.data, e2_multi2.data
        for i in range(Config.batch_size):
            # these filters contain ALL labels
            filter1 = e2_multi1[i].long()
            filter2 = e2_multi2[i].long()

            # save the prediction that is relevant
            target_value1 = pred1[i, e2[i, 0]]
            target_value2 = pred2[i, e1[i, 0]]

            # zero all known cases (these are not interesting)
            # this corresponds to the filtered setting
            pred1[i][filter1] = 0.0
            pred2[i][filter2] = 0.0
            # write back the saved values
            pred1[i][e2[i]] = target_value1
            pred2[i][e1[i]] = target_value2

        # sort and rank
        max_values, argsort1 = torch.sort(pred1, 1, descending=True)
        max_values, argsort2 = torch.sort(pred2, 1, descending=True)

        argsort1 = argsort1.cpu().numpy()
        argsort2 = argsort2.cpu().numpy()
        for i in range(Config.batch_size):
            # find the rank of the target entities
            rank1 = np.where(argsort1[i] == e2[i, 0])[0][0]
            rank2 = np.where(argsort2[i] == e1[i, 0])[0][0]
            # rank+1, since the lowest rank is rank 1 not rank 0
            ranks.append(rank1 + 1)
            ranks_left.append(rank1 + 1)
            ranks.append(rank2 + 1)
            ranks_right.append(rank2 + 1)

            # this could be done more elegantly, but here you go
            for hits_level in range(10):
                if rank1 <= hits_level:
                    hits[hits_level].append(1.0)
                    hits_left[hits_level].append(1.0)
                else:
                    hits[hits_level].append(0.0)
                    hits_left[hits_level].append(0.0)

                if rank2 <= hits_level:
                    hits[hits_level].append(1.0)
                    hits_right[hits_level].append(1.0)
                else:
                    hits[hits_level].append(0.0)
                    hits_right[hits_level].append(0.0)

        dev_rank_batcher.state.loss = [0]

    for i in range(10):
        log.info('Hits left @{0}: {1}'.format(i + 1, np.mean(hits_left[i])))
        log.info('Hits right @{0}: {1}'.format(i + 1, np.mean(hits_right[i])))
        log.info('Hits @{0}: {1}'.format(i + 1, np.mean(hits[i])))
    log.info('Mean rank left: {0}'.format(np.mean(ranks_left)))
    log.info('Mean rank right: {0}'.format(np.mean(ranks_right)))
    log.info('Mean rank: {0}'.format(np.mean(ranks)))
    log.info('Mean reciprocal rank left: {0}'.format(
        np.mean(1. / np.array(ranks_left))))
    log.info('Mean reciprocal rank right: {0}'.format(
        np.mean(1. / np.array(ranks_right))))
    log.info('Mean reciprocal rank: {0}'.format(np.mean(1. / np.array(ranks))))
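
The filtered rank computed above can be reproduced for a single row purely in torch; a minimal sketch with hypothetical scores:

import torch

scores = torch.tensor([[0.1, 0.7, 0.2, 0.9]])   # hypothetical scores for 4 entities
target = 1                                      # index of the true entity
_, argsort = torch.sort(scores, 1, descending=True)
rank = (argsort[0] == target).nonzero(as_tuple=False)[0, 0].item() + 1
print(rank)  # 2: only one entity (score 0.9) outranks the target's 0.7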
Esempio n. 59
0
def generate_requests(
    iters: int,
    B: int,
    T: int,
    L: int,
    E: int,
    # inter-batch indices reuse rate
    reuse: float = 0.0,
    # alpha <= 1.0: use uniform distribution
    # alpha > 1.0: use zipf distribution
    alpha: float = 1.0,
    weights_precision: SparseType = SparseType.FP32,
    weighted: bool = False,
) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]]:
    if alpha <= 1.0:
        all_indices = torch.randint(
            low=0,
            high=E,
            size=(iters, T, B, L),
            device=get_device(),
            dtype=torch.int32,
        )
        # each bag is usually sorted
        (all_indices, _) = torch.sort(all_indices)
        all_indices = all_indices.reshape(iters, T, B * L)
    else:
        assert E >= L, "num-embeddings must be greater than or equal to bag-size"
        # oversample and then remove duplicates to obtain sampling without
        # replacement
        all_indices = (np.random.zipf(a=alpha, size=(iters, T, B, 3 * L)) -
                       1) % E
        for index_tuple in itertools.product(range(iters), range(T), range(B)):
            # sample without replacement from
            # https://stats.stackexchange.com/questions/20590/how-do-i-sample-without-replacement-using-a-sampling-with-replacement-function
            r = set()
            for x in all_indices[index_tuple]:
                if x not in r:
                    r.add(x)
                    if len(r) == L:
                        break
            assert (len(r)) == L, "too skewed distribution (alpha too big)"
            all_indices[index_tuple][:L] = list(r)
        # shuffle indices so we don't have unintended spatial locality
        all_indices = torch.as_tensor(all_indices[:, :, :, :L])
        rng = default_rng()
        permutation = torch.as_tensor(
            rng.choice(E, size=all_indices.max().item() + 1, replace=False))
        all_indices = permutation.gather(0, all_indices.flatten())
        all_indices = all_indices.to(get_device()).int().reshape(
            iters, T, B * L)
    for it in range(iters - 1):
        for t in range(T):
            reused_indices = torch.randperm(
                B * L, device=get_device())[:int(B * L * reuse)]
            all_indices[it + 1, t,
                        reused_indices] = all_indices[it, t, reused_indices]

    rs = []
    for it in range(iters):
        weights_tensor = (None if not weighted else torch.randn(
            T * B * L, device=get_device()))
        rs.append(
            get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L))
            + (weights_tensor, ))
    return rs
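
The "each bag is usually sorted" step relies on torch.sort operating along the last dimension by default; a tiny sketch with hypothetical sizes:

import torch

all_indices = torch.randint(0, 100, (2, 3, 4, 5))    # hypothetical (iters, T, B, L)
sorted_indices, _ = torch.sort(all_indices)           # sorts each bag along dim=-1
assert (sorted_indices[..., 1:] >= sorted_indices[..., :-1]).all()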
Esempio n. 60
0
def ddp_train_nerf(rank, args):
    ###### set up multi-processing
    setup(rank, args.world_size)
    ###### set up logger
    logger = logging.getLogger(__package__)
    setup_logger()

    ###### decide chunk size according to gpu memory
    logger.info('gpu_mem: {}'.format(
        torch.cuda.get_device_properties(rank).total_memory))
    if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
        logger.info('setting batch size according to 24G gpu')
        args.N_rand = 1024
        args.chunk_size = 8192
    else:
        logger.info('setting batch size according to 12G gpu')
        args.N_rand = 512
        args.chunk_size = 4096

    ###### Create log dir and copy the config file
    if rank == 0:
        os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
        f = os.path.join(args.basedir, args.expname, 'args.txt')
        with open(f, 'w') as file:
            for arg in sorted(vars(args)):
                attr = getattr(args, arg)
                file.write('{} = {}\n'.format(arg, attr))
        if args.config is not None:
            f = os.path.join(args.basedir, args.expname, 'config.txt')
            with open(f, 'w') as file:
                file.write(open(args.config, 'r').read())
    torch.distributed.barrier()

    ray_samplers = load_data_split(args.datadir,
                                   args.scene,
                                   split='train',
                                   try_load_min_depth=args.load_min_depth)
    val_ray_samplers = load_data_split(args.datadir,
                                       args.scene,
                                       split='validation',
                                       try_load_min_depth=args.load_min_depth,
                                       skip=args.testskip)

    # write training image names for autoexposure
    if args.optim_autoexpo:
        f = os.path.join(args.basedir, args.expname, 'train_images.json')
        with open(f, 'w') as file:
            img_names = [
                ray_samplers[i].img_path for i in range(len(ray_samplers))
            ]
            json.dump(img_names, file, indent=2)

    ###### create network and wrap in ddp; each process should do this
    start, models = create_nerf(rank, args)

    ##### important!!!
    # make sure different processes sample different rays
    np.random.seed((rank + 1) * 777)
    # make sure different processes have different perturbations in depth samples
    torch.manual_seed((rank + 1) * 777)

    ##### only main process should do the logging
    if rank == 0:
        writer = SummaryWriter(
            os.path.join(args.basedir, 'summaries', args.expname))

    # start training
    what_val_to_log = 0  # helper variable for parallel rendering of an image
    what_train_to_log = 0
    for global_step in range(start + 1, start + 1 + args.N_iters):
        time0 = time.time()
        scalars_to_log = OrderedDict()
        ### Start of core optimization loop
        scalars_to_log['resolution'] = ray_samplers[0].resolution_level
        # randomly sample rays and move to device
        i = np.random.randint(low=0, high=len(ray_samplers))
        ray_batch = ray_samplers[i].random_sample(args.N_rand,
                                                  center_crop=False)
        for key in ray_batch:
            if torch.is_tensor(ray_batch[key]):
                ray_batch[key] = ray_batch[key].to(rank)

        # forward and backward
        dots_sh = list(ray_batch['ray_d'].shape[:-1])  # number of rays
        all_rets = []  # results on different cascade levels
        for m in range(models['cascade_level']):
            optim = models['optim_{}'.format(m)]
            net = models['net_{}'.format(m)]

            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # foreground depth
                fg_far_depth = intersect_sphere(ray_batch['ray_o'],
                                                ray_batch['ray_d'])  # [...,]
                fg_near_depth = ray_batch['min_depth']  # [..., ]
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack(
                    [fg_near_depth + i * step for i in range(N_samples)],
                    dim=-1)  # [..., N_samples]
                fg_depth = perturb_samples(
                    fg_depth)  # random perturbation during training

                # background depth
                bg_depth = torch.linspace(0., 1., N_samples).view([
                    1,
                ] * len(dots_sh) + [
                    N_samples,
                ]).expand(dots_sh + [
                    N_samples,
                ]).to(rank)
                bg_depth = perturb_samples(
                    bg_depth)  # random perturbation during training
            else:
                # sample pdf and concat with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]
                                     )  # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]  # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid,
                                              weights=fg_weights,
                                              N_samples=N_samples,
                                              det=False)  # [..., N_samples]
                fg_depth, _ = torch.sort(
                    torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # sample pdf and concat with earlier samples
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]  # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid,
                                              weights=bg_weights,
                                              N_samples=N_samples,
                                              det=False)  # [..., N_samples]
                bg_depth, _ = torch.sort(
                    torch.cat((bg_depth, bg_depth_samples), dim=-1))

            optim.zero_grad()
            ret = net(ray_batch['ray_o'],
                      ray_batch['ray_d'],
                      fg_far_depth,
                      fg_depth,
                      bg_depth,
                      img_name=ray_batch['img_name'])
            all_rets.append(ret)

            rgb_gt = ray_batch['rgb'].to(rank)
            if 'autoexpo' in ret:
                scale, shift = ret['autoexpo']
                scalars_to_log['level_{}/autoexpo_scale'.format(
                    m)] = scale.item()
                scalars_to_log['level_{}/autoexpo_shift'.format(
                    m)] = shift.item()
                # rgb_gt = scale * rgb_gt + shift
                rgb_pred = (ret['rgb'] - shift) / scale
                rgb_loss = img2mse(rgb_pred, rgb_gt)
                loss = rgb_loss + args.lambda_autoexpo * (
                    torch.abs(scale - 1.) + torch.abs(shift))
            else:
                rgb_loss = img2mse(ret['rgb'], rgb_gt)
                loss = rgb_loss
            scalars_to_log['level_{}/loss'.format(m)] = rgb_loss.item()
            scalars_to_log['level_{}/psnr'.format(m)] = mse2psnr(
                rgb_loss.item())
            loss.backward()
            optim.step()

            # # clean unused memory
            # torch.cuda.empty_cache()

        ### end of core optimization loop
        dt = time.time() - time0
        scalars_to_log['iter_time'] = dt

        ### only main process should do the logging
        if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
            logstr = '{} step: {} '.format(args.expname, global_step)
            for k in scalars_to_log:
                logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
                writer.add_scalar(k, scalars_to_log[k], global_step)
            logger.info(logstr)

        ### each process should do this; but only main process merges the results
        if global_step % args.i_img == 0 or global_step == start + 1:
            #### critical: make sure each process is working on the same random image
            time0 = time.time()
            idx = what_val_to_log % len(val_ray_samplers)
            log_data = render_single_image(rank, args.world_size, models,
                                           val_ray_samplers[idx],
                                           args.chunk_size)
            what_val_to_log += 1
            dt = time.time() - time0
            if rank == 0:  # only main process should do this
                logger.info(
                    'Logged a random validation view in {} seconds'.format(dt))
                log_view_to_tb(writer,
                               global_step,
                               log_data,
                               gt_img=val_ray_samplers[idx].get_img(),
                               mask=None,
                               prefix='val/')

            time0 = time.time()
            idx = what_train_to_log % len(ray_samplers)
            log_data = render_single_image(rank, args.world_size, models,
                                           ray_samplers[idx], args.chunk_size)
            what_train_to_log += 1
            dt = time.time() - time0
            if rank == 0:  # only main process should do this
                logger.info(
                    'Logged a random training view in {} seconds'.format(dt))
                log_view_to_tb(writer,
                               global_step,
                               log_data,
                               gt_img=ray_samplers[idx].get_img(),
                               mask=None,
                               prefix='train/')

            del log_data
            torch.cuda.empty_cache()

        if rank == 0 and (global_step % args.i_weights == 0
                          and global_step > 0):
            # saving checkpoints and logging
            fpath = os.path.join(args.basedir, args.expname,
                                 'model_{:06d}.pth'.format(global_step))
            to_save = OrderedDict()
            for m in range(models['cascade_level']):
                name = 'net_{}'.format(m)
                to_save[name] = models[name].state_dict()

                name = 'optim_{}'.format(m)
                to_save[name] = models[name].state_dict()
            torch.save(to_save, fpath)

    # clean up for multi-processing
    cleanup()
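
The fine-sampling step above merges the newly drawn depth samples with the coarse ones and re-sorts per ray so the depths stay monotonically increasing; a minimal sketch with hypothetical depths:

import torch

coarse = torch.tensor([[0.1, 0.4, 0.9]])      # hypothetical coarse depths for one ray
fine = torch.tensor([[0.2, 0.7]])             # newly drawn samples
merged, _ = torch.sort(torch.cat((coarse, fine), dim=-1))
print(merged)  # tensor([[0.1000, 0.2000, 0.4000, 0.7000, 0.9000]])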