Example #1
    def compute_frustum_bounds(self, world_to_grid, camera_to_world):
        corner_points = camera_to_world.new(8, 4, 1).fill_(1)
        # depth min
        corner_points[0,:3,0] = self.depth_to_skeleton(0, 0, self.depth_min)
        corner_points[1,:3,0] = self.depth_to_skeleton(self.image_dims[0] - 1, 0, self.depth_min)
        corner_points[2,:3,0] = self.depth_to_skeleton(self.image_dims[0] - 1, self.image_dims[1] - 1, self.depth_min)
        corner_points[3,:3,0] = self.depth_to_skeleton(0, self.image_dims[1] - 1, self.depth_min)
        # depth max
        corner_points[4,:3,0] = self.depth_to_skeleton(0, 0, self.depth_max)
        corner_points[5,:3,0] = self.depth_to_skeleton(self.image_dims[0] - 1, 0, self.depth_max)
        corner_points[6,:3,0] = self.depth_to_skeleton(self.image_dims[0] - 1, self.image_dims[1] - 1, self.depth_max)
        corner_points[7,:3,0] = self.depth_to_skeleton(0, self.image_dims[1] - 1, self.depth_max)

        p = torch.bmm(camera_to_world.repeat(8, 1, 1), corner_points)
        pl = torch.round(torch.bmm(world_to_grid.repeat(8, 1, 1), torch.floor(p)))
        pu = torch.round(torch.bmm(world_to_grid.repeat(8, 1, 1), torch.ceil(p)))
        bbox_min0, _ = torch.min(pl[:, :3, 0], 0)
        bbox_min1, _ = torch.min(pu[:, :3, 0], 0)
        bbox_min = torch.min(bbox_min0, bbox_min1)
        bbox_max0, _ = torch.max(pl[:, :3, 0], 0)
        bbox_max1, _ = torch.max(pu[:, :3, 0], 0) 
        bbox_max = torch.max(bbox_max0, bbox_max1)
        return bbox_min, bbox_max
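
A small self-contained sketch (not part of the repository above) of the rounding trick used here: flooring and ceiling the transformed corners and then taking element-wise min/max gives integer grid bounds that are guaranteed to enclose all corner points.

import torch

# Hypothetical frustum corners already expressed in grid coordinates.
p = torch.tensor([[1.2, 3.7, 0.4],
                  [5.9, 2.1, 4.8],
                  [0.3, 6.6, 2.2]])
pl, pu = torch.floor(p), torch.ceil(p)
bbox_min = torch.min(pl.min(dim=0)[0], pu.min(dim=0)[0])
bbox_max = torch.max(pl.max(dim=0)[0], pu.max(dim=0)[0])
print(bbox_min, bbox_max)  # tensor([0., 2., 0.]) tensor([6., 7., 5.])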
Example #2
def interpolate_dense_features(pos, dense_features, return_corners=False):
    """
    Args:
        pos
        dense_features
        return_corners:
    
    Returns:
        descriptors
        pos
        ids
        corners
    """
    device = pos.device

    ids = torch.arange(0, pos.size(1), device=device)

    _, h, w = dense_features.size()

    i = pos[0, :]
    j = pos[1, :]

    # Valid corners
    i_top_left = torch.floor(i).long()
    j_top_left = torch.floor(j).long()
    valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)

    i_top_right = torch.floor(i).long()
    j_top_right = torch.ceil(j).long()
    valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)

    i_bottom_left = torch.ceil(i).long()
    j_bottom_left = torch.floor(j).long()
    valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)

    i_bottom_right = torch.ceil(i).long()
    j_bottom_right = torch.ceil(j).long()
    valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)

    valid_corners = torch.min(torch.min(valid_top_left, valid_top_right),
                              torch.min(valid_bottom_left, valid_bottom_right))

    i_top_left = i_top_left[valid_corners]
    j_top_left = j_top_left[valid_corners]

    i_top_right = i_top_right[valid_corners]
    j_top_right = j_top_right[valid_corners]

    i_bottom_left = i_bottom_left[valid_corners]
    j_bottom_left = j_bottom_left[valid_corners]

    i_bottom_right = i_bottom_right[valid_corners]
    j_bottom_right = j_bottom_right[valid_corners]

    ids = ids[valid_corners]
    if ids.size(0) == 0:
        raise Exception  # EmptyTensorError

    # Interpolation
    i = i[ids]
    j = j[ids]
    dist_i_top_left = i - i_top_left.float()
    dist_j_top_left = j - j_top_left.float()
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    descriptors = (
        w_top_left * dense_features[:, i_top_left, j_top_left] +
        w_top_right * dense_features[:, i_top_right, j_top_right] +
        w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] +
        w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right])

    pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    if not return_corners:
        return [descriptors, pos, ids]
    else:
        corners = torch.stack([
            torch.stack([i_top_left, j_top_left], dim=0),
            torch.stack([i_top_right, j_top_right], dim=0),
            torch.stack([i_bottom_left, j_bottom_left], dim=0),
            torch.stack([i_bottom_right, j_bottom_right], dim=0)
        ],
                              dim=0)
        return [descriptors, pos, ids, corners]
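
The four floor/ceil corner pairs above implement standard bilinear interpolation. A minimal sketch with made-up numbers (independent of the function above) showing the same weighting:

import torch

dense = torch.arange(12.0).view(1, 3, 4)  # [C=1, H=3, W=4] toy feature map
i = torch.tensor([0.5, 1.25])             # row coordinates
j = torch.tensor([0.5, 2.75])             # column coordinates
i0, j0 = torch.floor(i).long(), torch.floor(j).long()
i1, j1 = torch.ceil(i).long(), torch.ceil(j).long()
di, dj = i - i0.float(), j - j0.float()
descriptors = ((1 - di) * (1 - dj) * dense[:, i0, j0] +
               (1 - di) * dj * dense[:, i0, j1] +
               di * (1 - dj) * dense[:, i1, j0] +
               di * dj * dense[:, i1, j1])
print(descriptors)  # tensor([[2.5000, 7.7500]])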
Example #3
 def ceil(self, x):
     return torch.ceil(x)
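
For reference, a one-line illustration of the element-wise behaviour wrapped by this helper:

import torch

print(torch.ceil(torch.tensor([-1.5, 0.2, 1.0, 1.5])))  # tensor([-1., 1., 1., 2.])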
Example #4
 def test_ceil(x, y):
     c = torch.ceil(torch.add(x, y))
     return c
Example #5
    def forward(self, batch, classifier, p=0):
        bs, ts = batch.words.shape
        mask = torch.ne(batch.words, constants.PAD_ID)
        lengths = mask.int().sum(dim=-1)

        # get word embeddings
        explainer_embs = self.word_emb_explainer(batch.words)

        # apply idf
        if self.explainer_idf == 'embs':
            w_idf = self.idf[batch.words].unsqueeze(-1)
            explainer_embs = explainer_embs * w_idf

        # forward and backward hidden states
        lstm_vecs = pack(explainer_embs,
                         lengths,
                         batch_first=True,
                         enforce_sorted=False)
        lstm_vecs, lstm_hidden = self.lstm_explainer(lstm_vecs)
        lstm_vecs, _ = unpack(lstm_vecs, batch_first=True)

        # recover the classifier predictions and its hidden states
        # clf_pred_classes, clf_hidden = self.get_clf_pred_and_hidden(
        #     batch, classifier, p=p
        # )
        clf_pred_classes, clf_hidden = self.get_clf_pred_and_hidden_inside(
            batch, classifier, p=p)

        # concat the lstm hidden states of the hypothesis and create a time dim
        hidden_states = [lstm_hidden[0][0], lstm_hidden[0][1]]
        lstm_hidden = torch.cat(hidden_states, dim=-1).unsqueeze(1)
        self.lstm_hidden = lstm_hidden
        self.lstm_out = lstm_vecs

        # concat clf output and hidden reps with lstm hidden
        # lstm_hidden = torch.cat(
        #   (lstm_hidden, clf_hidden, clf_pred_classes), dim=-1
        # )
        lstm_hidden = torch.cat((lstm_hidden, clf_pred_classes), dim=-1)

        # attention over explainer embs
        # message_emb, attn_weights = self.attn(lstm_hidden,
        #                                       explainer_embs,
        #                                       values=explainer_embs,
        #                                       mask=mask)

        # set score weights as idfs
        if self.explainer_idf == 'scores':
            s_weights = self.idf[batch.words].unsqueeze(-1)
        else:
            s_weights = None

        # attention over lstm vecs
        message_emb, attn_weights = self.attn(lstm_hidden,
                                              lstm_vecs,
                                              values=lstm_vecs,
                                              mask=mask,
                                              s_weights=s_weights)
        # self.lstm_hidden = message_emb

        # (bs, 1, ts) -> (bs, ts)
        attn_weights = attn_weights.squeeze()
        self.attn_weights = attn_weights

        # try to use the straight-through gumbel softmax
        # argmaxes = torch.nn.functional.gumbel_softmax(
        #     torch.log(attn_weights), tau=0.8, hard=True
        # )
        # top_word_ids = batch.words[argmaxes.int().bool()]

        # create message
        # use bag of probas during training
        if self.training:  # training time
            # k = min(self.explainer_attn_top_k, attn_weights.shape[-1])
            # top_probas, top_idxs = torch.topk(attn_weights, k, dim=-1)
            # top_word_ids = batch.words.gather(1, top_idxs)
            # probas = torch.zeros_like(attn_weights)
            # probas.scatter_(1, top_idxs, top_probas)
            # message = self.bag_of_probas(batch.words, probas, normalize=True)
            # message = message_emb.squeeze()
            # message = torch.sum(explainer_embs * attn_weights.unsqueeze(-1), 1)
            message = self.bag_of_probas(batch.words,
                                         attn_weights,
                                         normalize=False)

        # bag of words during test
        else:  # test time
            k = min(self.explainer_attn_top_k, attn_weights.shape[-1])
            top_probas, top_idxs = torch.topk(attn_weights, k, dim=-1)
            top_word_ids = batch.words.gather(1, top_idxs)

            # this is not a part of the computation graph, it is just for saving
            # the valid top word ids in case we need to access them later:
            self.valid_top_word_ids = filter_word_ids_with_non_zero_probability(
                top_word_ids, top_probas, pad_id=constants.PAD_ID)

            message = self.bag_of_probas(top_word_ids,
                                         torch.ceil(top_probas),
                                         normalize=False)
            # top_probas = top_probas / top_probas.sum(1).unsqueeze(-1)
            # message = self.bag_of_probas(top_word_ids, top_probas,
            #                              normalize=True)
            message = message / message.sum(-1).unsqueeze(-1)

        # create a time dimension of size 1
        message = message.unsqueeze(1)

        return message, message_emb
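
At test time, torch.ceil turns the strictly positive top-k attention probabilities into hard 0/1 counts before the bag of words is normalized. A hedged sketch of just that rounding step (bag_of_probas itself is not reproduced here):

import torch

top_probas = torch.tensor([[0.61, 0.27, 0.12],
                           [0.90, 0.10, 0.00]])
hard_counts = torch.ceil(top_probas)  # 1 wherever the probability is non-zero
print(hard_counts)  # tensor([[1., 1., 1.], [1., 1., 0.]])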
Example #6
def proximity_cost(images,
                   states,
                   car_size=(6.4, 14.3),
                   green_channel=1,
                   unnormalize=False,
                   s_mean=None,
                   s_std=None):
    SCALE = 0.25
    safe_factor = 1.5
    bsize, npred, nchannels, crop_h, crop_w = images.size()
    images = images.view(bsize * npred, nchannels, crop_h, crop_w)
    states = states.view(bsize * npred, 4).clone()

    if unnormalize:
        states = states * (1e-8 +
                           s_std.view(1, 4).expand(states.size())).cuda()
        states = states + s_mean.view(1, 4).expand(states.size()).cuda()

    speed = states[:, 2:].norm(2, 1) * SCALE  # pixel/s
    width, length = car_size[:, 0], car_size[:, 1]  # feet; expects car_size as a (bsize, 2) tensor despite the tuple default
    width = width * SCALE * (0.3048 * 24 / 3.7)  # pixels
    length = length * SCALE * (0.3048 * 24 / 3.7)  # pixels

    safe_distance = torch.abs(speed) * safe_factor + (
        1 * 24 / 3.7) * SCALE  # plus one metre (TODO change)

    # Compute x/y minimum distance to other vehicles (pixel version)
    # Account for 1 metre overlap (low data accuracy)
    alpha = 1 * SCALE * (24 / 3.7)  # 1 m overlap collision
    # Create separable proximity mask

    max_x = torch.ceil((crop_h - torch.clamp(length - alpha, min=0)) / 2)
    max_y = torch.ceil((crop_w - torch.clamp(width - alpha, min=0)) / 2)
    max_x = max_x.view(bsize, 1).expand(bsize, npred).contiguous().view(
        bsize * npred).cuda()
    max_y = max_y.view(bsize, 1).expand(bsize, npred).contiguous().view(
        bsize * npred).cuda()

    min_x = torch.clamp(max_x - safe_distance, min=0)
    min_y = torch.ceil(crop_w / 2 -
                       width)  # assumes other._width / 2 = self._width / 2
    min_y = min_y.view(bsize, 1).expand(bsize, npred).contiguous().view(
        bsize * npred).cuda()

    x_filter = (1 - torch.abs(torch.linspace(-1, 1, crop_h))) * crop_h / 2
    x_filter = x_filter.unsqueeze(0).expand(bsize * npred, crop_h).cuda()
    x_filter = torch.min(x_filter,
                         max_x.view(bsize * npred, 1).expand(x_filter.size()))
    x_filter = torch.max(x_filter, min_x.view(bsize * npred, 1))

    x_filter = (x_filter - min_x.view(bsize * npred, 1)) / (
        max_x - min_x).view(bsize * npred, 1)
    y_filter = (1 - torch.abs(torch.linspace(-1, 1, crop_w))) * crop_w / 2
    y_filter = y_filter.view(1, crop_w).expand(bsize * npred, crop_w).cuda()
    y_filter = torch.min(y_filter, max_y.view(bsize * npred, 1))
    y_filter = torch.max(y_filter, min_y.view(bsize * npred, 1))
    y_filter = (y_filter - min_y.view(bsize * npred, 1)) / (
        max_y.view(bsize * npred, 1) - min_y.view(bsize * npred, 1))
    x_filter = x_filter.cuda()
    y_filter = y_filter.cuda()
    proximity_mask = torch.bmm(x_filter.view(-1, crop_h, 1),
                               y_filter.view(-1, 1, crop_w))
    proximity_mask = proximity_mask.view(bsize, npred, crop_h, crop_w)
    images = images.view(bsize, npred, nchannels, crop_h, crop_w)
    costs = torch.max(
        (proximity_mask * images[:, :, green_channel].float()).view(
            bsize, npred, -1), 2)[0]
    #    costs = torch.sum((proximity_mask * images[:, :, green_channel].float()).view(bsize, npred, -1), 2)
    #    costs = torch.max((proximity_mask * images[:, :, green_channel].float()).view(bsize, npred, -1), 2)[0]
    return costs, proximity_mask
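
The separable proximity mask above is just a batched outer product of the 1-D x/y filters; a toy shape check (values made up, unrelated to the cost itself):

import torch

bsize_npred, crop_h, crop_w = 2, 5, 7
x_filter = torch.rand(bsize_npred, crop_h)
y_filter = torch.rand(bsize_npred, crop_w)
proximity_mask = torch.bmm(x_filter.view(-1, crop_h, 1),
                           y_filter.view(-1, 1, crop_w))
print(proximity_mask.shape)  # torch.Size([2, 5, 7])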
Example #7
 def get_seq_len(self, length):
     # Called by forward()
     return torch.ceil(length / self.hop_length).to(dtype=torch.long)
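
A quick sanity check of the formula above with an assumed hop_length of 256 (the real value lives in the module's configuration):

import torch

length = torch.tensor([1000, 256, 257])
hop_length = 256  # assumed for illustration
print(torch.ceil(length / hop_length).to(dtype=torch.long))  # tensor([4, 1, 2])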
Example #8
def extract_features(feature_model,
                     image,
                     boxes,
                     feat_map_keys=['map3', 'map4'],
                     exemplar_scales=[0.9, 1.1]):
    N, M = image.shape[0], boxes.shape[2]
    """
    Getting features for the image N * C * H * W
    """
    Image_features = feature_model(image)
    """
    Getting features for the examples (N*M) * C * h * w
    """
    for ix in range(0, N):
        # boxes = boxes.squeeze(0)
        boxes = boxes[ix][0]
        cnter = 0
        Cnter1 = 0
        for keys in feat_map_keys:
            image_features = Image_features[keys][ix].unsqueeze(0)
            if keys == 'map1' or keys == 'map2':
                Scaling = 4.0
            elif keys == 'map3':
                Scaling = 8.0
            elif keys == 'map4':
                Scaling = 16.0
            else:
                Scaling = 32.0
            boxes_scaled = boxes / Scaling
            boxes_scaled[:, 1:3] = torch.floor(boxes_scaled[:, 1:3])
            boxes_scaled[:, 3:5] = torch.ceil(boxes_scaled[:, 3:5])
            boxes_scaled[:, 3:5] = boxes_scaled[:, 3:5] + 1  # make the end indices exclusive
            feat_h, feat_w = image_features.shape[-2], image_features.shape[-1]
            # make sure exemplars don't go out of bound
            boxes_scaled[:, 1:3] = torch.clamp_min(boxes_scaled[:, 1:3], 0)
            boxes_scaled[:, 3] = torch.clamp_max(boxes_scaled[:, 3], feat_h)
            boxes_scaled[:, 4] = torch.clamp_max(boxes_scaled[:, 4], feat_w)
            box_hs = boxes_scaled[:, 3] - boxes_scaled[:, 1]
            box_ws = boxes_scaled[:, 4] - boxes_scaled[:, 2]
            max_h = math.ceil(max(box_hs))
            max_w = math.ceil(max(box_ws))
            for j in range(0, M):
                y1, x1 = int(boxes_scaled[j, 1]), int(boxes_scaled[j, 2])
                y2, x2 = int(boxes_scaled[j, 3]), int(boxes_scaled[j, 4])
                #print(y1,y2,x1,x2,max_h,max_w)
                if j == 0:
                    examples_features = image_features[:, :, y1:y2, x1:x2]
                    if examples_features.shape[
                            2] != max_h or examples_features.shape[3] != max_w:
                        #examples_features = pad_to_size(examples_features, max_h, max_w)
                        examples_features = F.interpolate(examples_features,
                                                          size=(max_h, max_w),
                                                          mode='bilinear')
                else:
                    feat = image_features[:, :, y1:y2, x1:x2]
                    if feat.shape[2] != max_h or feat.shape[3] != max_w:
                        feat = F.interpolate(feat,
                                             size=(max_h, max_w),
                                             mode='bilinear')
                        #feat = pad_to_size(feat, max_h, max_w)
                    examples_features = torch.cat((examples_features, feat),
                                                  dim=0)
            """
            Convolving example features over image features
            """
            h, w = examples_features.shape[2], examples_features.shape[3]
            features = F.conv2d(
                F.pad(image_features, ((int(w / 2)), int(
                    (w - 1) / 2), int(h / 2), int((h - 1) / 2))),
                examples_features)
            combined = features.permute([1, 0, 2, 3])
            # computing features for scales 0.9 and 1.1
            for scale in exemplar_scales:
                h1 = math.ceil(h * scale)
                w1 = math.ceil(w * scale)
                if h1 < 1:  # use original size if scaled size is too small
                    h1 = h
                if w1 < 1:
                    w1 = w
                examples_features_scaled = F.interpolate(examples_features,
                                                         size=(h1, w1),
                                                         mode='bilinear')
                features_scaled = F.conv2d(
                    F.pad(image_features, ((int(w1 / 2)), int(
                        (w1 - 1) / 2), int(h1 / 2), int((h1 - 1) / 2))),
                    examples_features_scaled)
                features_scaled = features_scaled.permute([1, 0, 2, 3])
                combined = torch.cat((combined, features_scaled), dim=1)
            if cnter == 0:
                Combined = 1.0 * combined
            else:
                if Combined.shape[2] != combined.shape[2] or Combined.shape[
                        3] != combined.shape[3]:
                    combined = F.interpolate(combined,
                                             size=(Combined.shape[2],
                                                   Combined.shape[3]),
                                             mode='bilinear')
                Combined = torch.cat((Combined, combined), dim=1)
            cnter += 1
        if ix == 0:
            All_feat = 1.0 * Combined.unsqueeze(0)
        else:
            All_feat = torch.cat((All_feat, Combined.unsqueeze(0)), dim=0)
    return All_feat
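
The floor/ceil pair above rounds the scaled exemplar boxes outward, so the cropped feature patch never shrinks below the annotated box; a standalone sketch with a made-up box (the [idx, y1, x1, y2, x2] layout is assumed from the slicing above):

import torch

boxes = torch.tensor([[0.0, 17.0, 33.0, 80.0, 95.0]])  # [idx, y1, x1, y2, x2]
scaling = 8.0  # e.g. the 'map3' stride
boxes_scaled = boxes / scaling
boxes_scaled[:, 1:3] = torch.floor(boxes_scaled[:, 1:3])
boxes_scaled[:, 3:5] = torch.ceil(boxes_scaled[:, 3:5]) + 1  # end indices exclusive
print(boxes_scaled)  # tensor([[ 0.,  2.,  4., 11., 13.]])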
Example #9
    def inference(self,
                  tokens,
                  token_lengths,
                  mels_for_prosody,
                  mel_lengths_for_prosody,
                  speakers,
                  mels_for_ge2e,
                  pitches,
                  pitch_lengths,
                  noise_scale=1.0,
                  length_scale=1.0):
        '''
        For inference.
        tokens: [Batch, Token_t] # Input text
        token_lengths: [Batch]  # Length of input text
        mels_for_prosody: [Batch, Mel_d, Mel_t] # Input of prosody encoder
        mel_lengths_for_prosody: [Batch]    # Length of input mel for prosody
        speakers: [Batch] or None   # Index of speaker. Only used when hp.Speaker_Embedding.Type.upper() == 'LUT'
        mels_for_ge2e: [Batch * Samples, Mel_d, Mel_SE_t]    # Input of speaker embedding
        noise_scale: float scalar
        length_scale: float scalar or [Batch]  # (may later become a matrix to control speed letter by letter)
        '''
        if 'LUT' in self.layer_Dict.keys():
            speakers = self.layer_Dict['LUT'](speakers)
        elif 'GE2E' in self.layer_Dict.keys():
            speakers = self.layer_Dict['GE2E'](mels_for_ge2e)
            speakers = GE2E_Normalize(speakers)
        else:
            speakers = None

        if 'Prosody_Encoder' in self.layer_Dict.keys():
            prosodies = self.layer_Dict['Prosody_Encoder'](
                mels_for_prosody, mel_lengths_for_prosody)
        else:
            prosodies = None

        if hp.Device != '-1': torch.cuda.synchronize()

        token_Masks = self.Mask_Generate(token_lengths)
        mean, log_Std, log_Durations, mask = self.layer_Dict['Encoder'](
            tokens, token_Masks, speakers, prosodies)
        length_scale = length_scale.unsqueeze(-1).unsqueeze(-1)

        if hp.Device != '-1': torch.cuda.synchronize()

        durations = torch.ceil(torch.exp(log_Durations) * mask *
                               length_scale).squeeze(1)
        mel_Lengths = torch.clamp_min(torch.sum(durations, dim=1), 1.0).long()
        mel_Masks = self.Mask_Generate(mel_Lengths)

        attention_Masks = torch.unsqueeze(token_Masks, -1) * torch.unsqueeze(
            mel_Masks, 2)
        attention_Masks = attention_Masks.squeeze(1)

        attentions = self.Path_Generate(
            durations, attention_Masks)  # [Batch, Token_t, Mel_t]

        if hp.Device != '-1': torch.cuda.synchronize()

        mel_Mean = mean @ attentions  # [Batch, Mel_Dim, Token_t] @ [Batch, Token_t, Mel_t] -> [Batch, Mel_dim, Mel_t]
        mel_Log_Std = log_Std @ attentions  # [Batch, Mel_Dim, Token_t] @ [Batch, Token_t, Mel_t] -> [Batch, Mel_dim, Mel_t]
        noises = torch.randn_like(mel_Mean) * noise_scale

        if hp.Device != '-1': torch.cuda.synchronize()

        z = (mel_Mean + torch.exp(mel_Log_Std) * noises) * mel_Masks

        if 'Pitch_Interpolater' in self.layer_Dict.keys():
            pitches = self.layer_Dict['Pitch_Interpolater'](pitches,
                                                            pitch_lengths,
                                                            mel_Lengths)
        else:
            pitches = None

        mels, _, mel_Masks = self.layer_Dict['Decoder'](z,
                                                        mel_Masks,
                                                        speakers,
                                                        prosodies,
                                                        pitches,
                                                        reverse=True)

        if hp.Device != '-1': torch.cuda.synchronize()

        mels.masked_fill_(mel_Masks == 0.0, -hp.Sound.Max_Abs_Mel)

        return mels, mel_Lengths, attentions
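
torch.ceil is what converts the predicted (continuous) durations into integer frame counts before the alignment path is built; a standalone sketch with made-up values:

import torch

log_durations = torch.tensor([[0.1, 1.3, -0.5]])
length_scale = 1.2
durations = torch.ceil(torch.exp(log_durations) * length_scale)
mel_lengths = torch.clamp_min(durations.sum(dim=1), 1.0).long()
print(durations, mel_lengths)  # tensor([[2., 5., 1.]]) tensor([8])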
Example #10
    def forward(self,
                text,
                melspec,
                align,
                text_lengths,
                mel_lengths,
                criterion,
                stage,
                log_viterbi=False,
                cpu_viterbi=False):
        text = text[:, :text_lengths.max().item()]
        melspec = melspec[:, :, :mel_lengths.max().item()]

        if stage == 0:
            # encoder_input = self.Prenet(text)
            # import pdb;pdb.set_trace()
            # hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            # hidden_states, _ = self.FFT_lower(encoder_input, mel_lengths)
            log_probs, hidden_states_spec, _ = self.get_am(
                melspec, mel_lengths, text)
            # mu_sigma = self.get_mu_sigma(hidden_states)
            # mdn_loss, log_prob_matrix = criterion(probs, hidden_states_spec, text_lengths, mel_lengths)
            # mdn_loss, _ = criterion(mu_sigma, melspec, text_lengths, mel_lengths)
            # import pdb;pdb.set_trace()
            mel_lengths = torch.ceil(mel_lengths.float() / 2).long()
            mdn_loss = self.ctc_loss(log_probs, text, mel_lengths,
                                     text_lengths) / log_probs.size(1)
            return mdn_loss

        elif stage == 1:
            align = align[:, :text_lengths.max().item(), :mel_lengths.max().
                          item()]
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mel_out = self.get_melspec(hidden_states, align, mel_lengths)

            mel_mask = ~get_mask_from_lengths(mel_lengths)
            melspec = melspec.masked_select(mel_mask.unsqueeze(1))
            mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
            fft_loss = nn.L1Loss()(mel_out, melspec)

            return fft_loss

        elif stage == 2:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            probs, hidden_states_spec = self.get_am(melspec, mel_lengths, text)

            # mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, log_prob_matrix = criterion(probs, hidden_states_spec,
                                                  text_lengths, mel_lengths)

            before = datetime.now()
            if cpu_viterbi:
                align = self.viterbi_cpu(log_prob_matrix, text_lengths.cpu(),
                                         mel_lengths.cpu())  # B, T
            else:
                align = self.viterbi(log_prob_matrix, text_lengths,
                                     mel_lengths)  # B, T
            after = datetime.now()

            if log_viterbi:
                time_delta = after - before
                print(f'Viterbi took {time_delta.total_seconds()} secs')

            mel_out = self.get_melspec(hidden_states, align, mel_lengths)

            mel_mask = ~get_mask_from_lengths(mel_lengths)
            melspec = melspec.masked_select(mel_mask.unsqueeze(1))
            mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
            fft_loss = nn.L1Loss()(mel_out, melspec)

            return mdn_loss + fft_loss

        elif stage == 3:
            align = align[:, :text_lengths.max().item(), :mel_lengths.max().
                          item()]
            duration_out = self.get_duration(text,
                                             text_lengths)  # gradient cut
            duration_target = align.sum(-1)

            duration_mask = ~get_mask_from_lengths(text_lengths)
            duration_target = duration_target.masked_select(duration_mask)
            duration_out = duration_out.masked_select(duration_mask)
            duration_loss = nn.MSELoss()(torch.log(duration_out),
                                         torch.log(duration_target))

            return duration_loss
Example #11
    def forward(self, episode_idx, sequence, feature_map, oom_val):
        """
        inputs
        episode_idx: [A]
        sequence : [A X Td X 2]
        feature_map: [B X Ce X 100 X 100]
        oom_val: padding value
        outputs
        local_featrue_bt: [A X Td X Ce]
        sequence_mapCS: [A X Td X 2]
        """
        # Detect total agents
        total_agents = sequence.size(0)
        # Detect sequence length
        seq_len = sequence.size(1)

        if feature_map.device != sequence.device:
            feature_map = feature_map.to(sequence.device)

        # Pad the feature_map with oom_val
        pad = (1, 1, 1, 1)
        feature_map_padded = F.pad(feature_map, pad, mode='constant', value=oom_val) # [A X Ce X 102 X 102]

        # Change to map CS
        sequence_mapCS = (sequence + 56.0) / 112.0 * 100.0 + 1.0

        # Merge Agents-Time dimensions
        sequence_mapCS_bt = sequence_mapCS.reshape(-1, 2) # [A*Td, 2]
        x = sequence_mapCS_bt[:, 0:1] # [A*Td, 1]
        y = sequence_mapCS_bt[:, 1:] # [A*Td, 1]

        # Quantize x and y
        floor_mapCS_bt = torch.floor(sequence_mapCS_bt)
        ceil_mapCS_bt = torch.ceil(sequence_mapCS_bt)

        # Clamp by range [0, 101]
        floor_mapCS_bt = torch.clamp(floor_mapCS_bt, 0, 101)
        ceil_mapCS_bt = torch.clamp(ceil_mapCS_bt, 0, 101)
        x1 = floor_mapCS_bt[:, 0:1]
        y1 = floor_mapCS_bt[:, 1:]
        x2 = ceil_mapCS_bt[:, 0:1]
        y2 = ceil_mapCS_bt[:, 1:]

        # Make integers for indexing
        x1_int = x1.long().squeeze()
        x2_int = x2.long().squeeze()
        y1_int = y1.long().squeeze()
        y2_int = y2.long().squeeze()

        # Generate duplicated batch indexes for prediction length
        # batch_idx_array = [0,0,..,0,1,1,...,1,A-1,A-1,...,A-1]
        # of length (Td * A)
        batch_idx_array = episode_idx.repeat_interleave(seq_len)

        # Get the four quadrants around (x, y)
        q11 = feature_map_padded[batch_idx_array, :, y1_int, x1_int]
        q12 = feature_map_padded[batch_idx_array, :, y1_int, x2_int]
        q21 = feature_map_padded[batch_idx_array, :, y2_int, x1_int]
        q22 = feature_map_padded[batch_idx_array, :, y2_int, x2_int]
        
        # Perform bilinear interpolation
        local_featrue_flat = (q11 * ((x2 - x) * (y2 - y)) +
                              q21 * ((x - x1) * (y2 - y)) +
                              q12 * ((x2 - x) * (y - y1)) +
                              q22 * ((x - x1) * (y - y1))
                              ) # (A*Td) X Ce

        if total_agents == 0:
            local_featrue_bt = local_featrue_flat.reshape((total_agents, seq_len, 6))
        else:
            local_featrue_bt = local_featrue_flat.reshape((total_agents, seq_len, -1))

        return local_featrue_bt, sequence_mapCS
Example #12
	def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
		"""
		Args:
			x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those
				used to construct this module.
			box_lists (list[Boxes] | list[RotatedBoxes]):
				A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch.
				The box coordinates are defined on the original image and
				will be scaled by the `scales` argument of :class:`ROIPooler`.

		Returns:
			Tensor:
				A tensor of shape (M, C, output_size, output_size) where M is the total number of
				boxes aggregated over all N batch images and C is the number of channels in `x`.
		"""
		num_level_assignments = len(self.level_poolers)

		assert isinstance(x, list) and isinstance(
			box_lists, list
		), "Arguments to pooler must be lists"
		assert (
			len(x) == num_level_assignments
		), "unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
			num_level_assignments, len(x)
		)

		assert len(box_lists) == x[0].size(
			0
		), "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
			x[0].size(0), len(box_lists)
		)
		if len(box_lists) == 0:
			return torch.zeros(
				(0, x[0].shape[1]) + self.output_size, device=x[0].device, dtype=x[0].dtype
			)

		pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists)

		if num_level_assignments == 1:
			return self.level_poolers[0](x[0], pooler_fmt_boxes)

		level_assignments = assign_boxes_to_levels(
			box_lists, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
		)

		num_boxes = pooler_fmt_boxes.size(0)
		num_channels = x[0].shape[1]
		output_size = self.output_size[0]

		dtype, device = x[0].dtype, x[0].device

		if(self.output_size == 14):
			boxes = pooler_fmt_boxes
			scales = []
			for i in range(boxes.shape[0]):
				scales.append(self.level_poolers[level_assignments[i]].spatial_scale)
			scale = torch.tensor(scales,device=device)
			boxes[:,1:3] = torch.floor(boxes[:,1:3]*scale[:,None])
			boxes[:,3:5] = torch.ceil(boxes[:,3:5]*scale[:,None])
			boxes = boxes.to(device=device,dtype=torch.long)

			boxes[boxes[:,1]< 0,1] = 0
			boxes[boxes[:,2]< 0,2] = 0

			#boxes[boxes[:,3] >= feats[0].shape[-1],3] = feats[0].shape[-1]-1
			#boxes[boxes[:,4] >= feats[0].shape[-2],4] = feats[0].shape[-2]-1
			mask = torch.logical_and((boxes[:,3]-boxes[:,1]) > 1,(boxes[:,4]-boxes[:,2]) > 1)
			height = boxes[:,4] - boxes[:,2] + 1
			width = boxes[:,3] - boxes[:,1] + 1
			if boxes.shape[0] > 0:
				max_h,max_w = torch.max(torch.max(height),0)[0], torch.max(torch.max(width),0)[0]
				max_h,max_w = torch.max(torch.tensor([max_h,3])), torch.max(torch.tensor([max_w,3]))
			else:
				max_h,max_w = torch.tensor(1,device=device),torch.tensor(1,device=device)
			output = torch.zeros(
				(num_boxes, num_channels, max_h, max_w), dtype=dtype, device=device
			)
			
			
			for i in range(boxes.shape[0]):
				ind,x0,y0,x1,y1 = boxes[i]
				print(x1,x[level_assignments[i]][0].shape[-1]-1)
				x1 = torch.min(x1,x[level_assignments[i]][0].shape[-1]-1)
				y1 = torch.min(y1,x[level_assignments[i]][0].shape[-2]-1)
				boxes[i][3] = x1
				boxes[i][4] = y1
				output[i,:,:y1-y0+1,:x1-x0+1] = x[level_assignments[i]][ind][:,y0:y1+1,x0:x1+1]

			boxes[:,0] = torch.arange(boxes.shape[0]) ## i changes this from 0
			boxes[:,3:5] -= boxes[:,1:3]
			boxes[:,1:3] = 0


			return output, boxes
		
		elif(self.output_size == 7):

			output = torch.zeros(
			(num_boxes, num_channels, output_size, output_size), dtype=dtype, device=device)

			for level, pooler in enumerate(self.level_poolers):
				inds = nonzero_tuple(level_assignments == level)[0]
				pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
				output[inds] = pooler(x[level], pooler_fmt_boxes_level)

			return output
Example #13
def get_local_maps(img, sr, sl):

    c, h, w = img.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0

    map_maxv = torch.zeros((c, h + 1, w + 1))
    map_maxv[:, :-1, :-1] = torch.max(map_maxv[:, :-1, :-1], img)
    map_maxv[:, +1:, :-1] = torch.max(map_maxv[:, +1:, :-1], img)
    map_maxv[:, :-1, +1:] = torch.max(map_maxv[:, :-1, +1:], img)
    map_maxv[:, +1:, +1:] = torch.max(map_maxv[:, +1:, +1:], img)
    map_minv = torch.zeros((c, h + 1, w + 1))
    map_minv[:, 1:-1, 1:-1] = img[:, :-1, :-1]
    map_minv[:, 1:-1, 1:-1] = torch.min(map_minv[:, 1:-1, 1:-1], img[:, :-1,
                                                                     +1:])
    map_minv[:, 1:-1, 1:-1] = torch.min(map_minv[:, 1:-1, 1:-1], img[:,
                                                                     +1:, :-1])
    map_minv[:, 1:-1, 1:-1] = torch.min(map_minv[:, 1:-1, 1:-1], img[:, +1:,
                                                                     +1:])
    map_maxd = map_maxv - map_minv

    # brute force for correctness checking
    # t_maxvr_map = torch.zeros_like(img)
    # t_maxdr_map = torch.zeros_like(img)
    # t_maxvl_map = torch.zeros_like(img)
    # t_maxdl_map = torch.zeros_like(img)
    # for i in range(h):
    #     for j in range(w):
    #         nir = (i - cy) / sr + cy
    #         njr = (j - cx) / sr + cx
    #         nil = (i - cy) / sl + cy
    #         njl = (j - cx) / sl + cx
    #         nirf, njrf, nilf, njlf = math.floor(nir), math.floor(njr), math.floor(nil), math.floor(njl)
    #         nir, njr, nil, njl = math.ceil(nir), math.ceil(njr), math.ceil(nil), math.ceil(njl)
    #         nir, njr, nil, njl = int(nir), int(njr), int(nil), int(njl)
    #         if 0 <= nirf and nir <= h-1 and 0 <= njrf and njr <= w-1:
    #             t_maxvr_map[:, i, j] = torch.max(t_maxvr_map[:, i, j], map_maxv[:, nir, njr])
    #             t_maxdr_map[:, i, j] = torch.max(t_maxdr_map[:, i, j], map_maxd[:, nir, njr])
    #         if 0 <= nilf and nil <= h-1 and 0 <= njlf and njl <= w-1:
    #             t_maxvl_map[:, i, j] = torch.max(t_maxvl_map[:, i, j], map_maxv[:, nil, njl])
    #             t_maxdl_map[:, i, j] = torch.max(t_maxdl_map[:, i, j], map_maxd[:, nil, njl])
    # t_maxv_map = torch.max(t_maxvr_map, t_maxvl_map)
    # t_maxd_map = torch.max(t_maxdr_map, t_maxdl_map)
    # brute force part ends

    rows = torch.linspace(0.0, h - 1, steps=h)
    cols = torch.linspace(0.0, w - 1, steps=w)
    nyrs = (rows - cy) / sr + cy
    nxrs = (cols - cx) / sr + cx
    nyr_mat = nyrs.unsqueeze(1).repeat(1, w)
    nxr_mat = nxrs.repeat(h, 1)
    nyls = (rows - cy) / sl + cy
    nxls = (cols - cx) / sl + cx
    nyl_mat = nyls.unsqueeze(1).repeat(1, w)
    nxl_mat = nxls.repeat(h, 1)

    nxl_mat, nxr_mat, nyl_mat, nyr_mat = \
        torch.ceil(nxl_mat).type(torch.LongTensor), torch.ceil(nxr_mat).type(torch.LongTensor), \
        torch.ceil(nyl_mat).type(torch.LongTensor), torch.ceil(nyr_mat).type(torch.LongTensor)

    # handling sr
    il = max(math.ceil(cy * (1.0 - sr)), 0)
    ir = min(math.floor(sr * (h - 1) + cy * (1.0 - sr)), h - 1)
    jl = max(math.ceil(cx * (1.0 - sr)), 0)
    jr = min(math.floor(sr * (w - 1) + cx * (1.0 - sr)), w - 1)
    # il = max(math.floor(-sr + cy * (1.0 - sr)) + 1, 0)
    # ir = min(math.floor(sr * h + cy * (1.0 - sr)), h-1)
    # jl = max(math.floor(-sr + cx * (1.0 - sr)) + 1, 0)
    # jr = min(math.floor(sr * w + cx * (1.0 - sr)), w-1)

    maxv_sr_mat = torch.zeros_like(img)
    maxv_sr_mat[:, il:ir + 1, jl:jr + 1] = torch.gather(
        map_maxv.reshape(c, (h + 1) * (w + 1)),
        dim=1,
        index=(nyr_mat[il:ir + 1, jl:jr + 1] * (w + 1) +
               nxr_mat[il:ir + 1, jl:jr + 1]).flatten().repeat(c, 1)).reshape(
                   c, ir - il + 1, jr - jl + 1)

    maxd_sr_mat = torch.zeros_like(img)
    maxd_sr_mat[:, il:ir + 1, jl:jr + 1] = torch.gather(
        map_maxd.reshape(c, (h + 1) * (w + 1)),
        dim=1,
        index=(nyr_mat[il:ir + 1, jl:jr + 1] * (w + 1) +
               nxr_mat[il:ir + 1, jl:jr + 1]).flatten().repeat(c, 1)).reshape(
                   c, ir - il + 1, jr - jl + 1)

    # maxv_sr_mat_old = torch.zeros_like(img)
    # maxv_sr_mat_old[:, il: ir+1, jl: jr+1] = torch.gather(
    #     torch.index_select(map_maxv, dim=1, index=nyr_mat[il: ir + 1, jl: jr + 1].flatten()),
    #     dim=2, index=nxr_mat[il: ir + 1, jl: jr + 1].flatten().repeat(c, 1).unsqueeze(2)).reshape(c, ir-il+1, jr-jl+1)
    # maxd_sr_mat_old = torch.zeros_like(img)
    # maxd_sr_mat_old[:, il: ir+1, jl: jr+1] = torch.gather(
    #     torch.index_select(map_maxd, dim=1, index=nyr_mat[il: ir + 1, jl: jr + 1].flatten()),
    #     dim=2, index=nxr_mat[il: ir + 1, jl: jr + 1].flatten().repeat(c, 1).unsqueeze(2)).reshape(c, ir-il+1, jr-jl+1)
    #
    # diff(maxv_sr_mat, maxv_sr_mat_old)
    # diff(maxd_sr_mat, maxd_sr_mat_old)

    # handling sl
    il = max(math.ceil(cy * (1.0 - sl)), 0)
    ir = min(math.floor(sl * (h - 1) + cy * (1.0 - sl)), h - 1)
    jl = max(math.ceil(cx * (1.0 - sl)), 0)
    jr = min(math.floor(sl * (w - 1) + cx * (1.0 - sl)), w - 1)
    # il = max(math.floor(-sl + cy * (1.0 - sl)) + 1, 0)
    # ir = min(math.floor(sl * h + cy * (1.0 - sl)), h - 1)
    # jl = max(math.floor(-sl + cx * (1.0 - sl)) + 1, 0)
    # jr = min(math.floor(sl * w + cx * (1.0 - sl)), w - 1)

    maxv_sl_mat = torch.zeros_like(img)
    maxv_sl_mat[:, il:ir + 1, jl:jr + 1] = torch.gather(
        map_maxv.reshape(c, (h + 1) * (w + 1)),
        dim=1,
        index=(nyl_mat[il:ir + 1, jl:jr + 1] * (w + 1) +
               nxl_mat[il:ir + 1, jl:jr + 1]).flatten().repeat(c, 1)).reshape(
                   c, ir - il + 1, jr - jl + 1)

    maxd_sl_mat = torch.zeros_like(img)
    maxd_sl_mat[:, il:ir + 1, jl:jr + 1] = torch.gather(
        map_maxd.reshape(c, (h + 1) * (w + 1)),
        dim=1,
        index=(nyl_mat[il:ir + 1, jl:jr + 1] * (w + 1) +
               nxl_mat[il:ir + 1, jl:jr + 1]).flatten().repeat(c, 1)).reshape(
                   c, ir - il + 1, jr - jl + 1)

    # maxv_sl_mat_old = torch.zeros_like(img)
    # maxv_sl_mat_old[:, il: ir+1, jl: jr+1] = torch.gather(
    #     torch.index_select(map_maxv, dim=1, index=nyl_mat[il: ir + 1, jl: jr + 1].flatten()),
    #     dim=2, index=nxl_mat[il: ir + 1, jl: jr + 1].flatten().repeat(c, 1).unsqueeze(2)).reshape(c, ir-il+1, jr-jl+1)
    # maxd_sl_mat_old = torch.zeros_like(img)
    # maxd_sl_mat_old[:, il: ir+1, jl: jr+1] = torch.gather(
    #     torch.index_select(map_maxd, dim=1, index=nyl_mat[il: ir + 1, jl: jr + 1].flatten()),
    #     dim=2, index=nxl_mat[il: ir + 1, jl: jr + 1].flatten().repeat(c, 1).unsqueeze(2)).reshape(c, ir-il+1, jr-jl+1)
    #
    # diff(maxv_sl_mat, maxv_sl_mat_old)
    # diff(maxd_sl_mat, maxd_sl_mat_old)

    ret_maxv = torch.max(maxv_sl_mat, maxv_sr_mat)
    ret_maxd = torch.max(maxd_sl_mat, maxd_sr_mat)

    # print('diff (error checking):', torch.sum(torch.abs(ret_maxv - t_maxv_map)), torch.sum(torch.abs(ret_maxd - t_maxd_map)))

    return ret_maxv, ret_maxd
Example #14
def L1_projection(x2, y2, eps1):
    '''
    x2: center of the L1 ball (bs x input_dim)
    y2: current perturbation (x2 + y2 is the point to be projected)
    eps1: radius of the L1 ball

    output: delta s.th. ||y2 + delta||_1 <= eps1
    and 0 <= x2 + y2 + delta <= 1
    '''

    x = x2.clone().float().view(x2.shape[0], -1)
    y = y2.clone().float().view(y2.shape[0], -1)
    sigma = y.clone().sign()
    u = torch.min(1 - x - y, x + y)
    #u = torch.min(u, epsinf - torch.clone(y).abs())
    u = torch.min(torch.zeros_like(y), u)
    l = -torch.clone(y).abs()
    d = u.clone()

    bs, indbs = torch.sort(-torch.cat((u, l), 1), dim=1)
    bs2 = torch.cat((bs[:, 1:], torch.zeros(bs.shape[0], 1).to(bs.device)), 1)

    inu = 2 * (indbs < u.shape[1]).float() - 1
    size1 = inu.cumsum(dim=1)

    s1 = -u.sum(dim=1)

    c = eps1 - y.clone().abs().sum(dim=1)
    c5 = s1 + c < 0
    c2 = c5.nonzero().squeeze(1)

    s = s1.unsqueeze(-1) + torch.cumsum((bs2 - bs) * size1, dim=1)

    if c2.nelement() != 0:

        lb = torch.zeros_like(c2).float()
        ub = torch.ones_like(lb) * (bs.shape[1] - 1)

        #print(c2.shape, lb.shape)

        nitermax = torch.ceil(torch.log2(torch.tensor(bs.shape[1]).float()))
        counter2 = torch.zeros_like(lb).long()
        counter = 0

        while counter < nitermax:
            counter4 = torch.floor((lb + ub) / 2.)
            counter2 = counter4.type(torch.LongTensor)

            c8 = s[c2, counter2] + c[c2] < 0
            ind3 = c8.nonzero().squeeze(1)
            ind32 = (~c8).nonzero().squeeze(1)
            #print(ind3.shape)
            if ind3.nelement() != 0:
                lb[ind3] = counter4[ind3]
            if ind32.nelement() != 0:
                ub[ind32] = counter4[ind32]

            #print(lb, ub)
            counter += 1

        lb2 = lb.long()
        alpha = (-s[c2, lb2] - c[c2]) / size1[c2, lb2 + 1] + bs2[c2, lb2]
        d[c2] = -torch.min(torch.max(-u[c2], alpha.unsqueeze(-1)), -l[c2])

    return (sigma * d).view(x2.shape)
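
The bisection above terminates after ceil(log2(K)) iterations, where K is the number of breakpoints (bs.shape[1]); a quick check of that bound:

import torch

K = torch.tensor(1000.0)
nitermax = torch.ceil(torch.log2(K))
print(nitermax)  # tensor(10.) -> at most 10 bisection steps for 1000 breakpoints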
Example #15
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'template_images', 'search_images', 'template_anno', 'search_anno'
        returns:
            TensorDict - output data block with following fields:
                'template_images', 'search_images', 'template_anno', 'search_anno', 'test_proposals', 'proposal_iou'
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['template_images'], data['template_anno'], data[
                'template_masks'] = self.transform['joint'](
                    image=data['template_images'],
                    bbox=data['template_anno'],
                    mask=data['template_masks'])
            data['search_images'], data['search_anno'], data[
                'search_masks'] = self.transform['joint'](
                    image=data['search_images'],
                    bbox=data['search_anno'],
                    mask=data['search_masks'],
                    new_roll=False)

        for s in ['template', 'search']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add a uniform noise to the center pos
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # 2021.1.9 Check whether data is valid. Avoid too small bounding boxes
            w, h = torch.stack(jittered_anno,
                               dim=0)[:, 2], torch.stack(jittered_anno,
                                                         dim=0)[:, 3]

            crop_sz = torch.ceil(
                torch.sqrt(w * h) * self.search_area_factor[s])
            if (crop_sz < 1).any():
                data['valid'] = False
                # print("Too small box is found. Replace it with new data.")
                return data

            # Crop image region centered at jittered_anno box and get the attention mask
            crops, boxes, att_mask, mask_crops = prutils.jittered_center_crop(
                data[s + '_images'],
                jittered_anno,
                data[s + '_anno'],
                self.search_area_factor[s],
                self.output_sz[s],
                masks=data[s + '_masks'])
            # Apply transforms
            data[s + '_images'], data[s + '_anno'], data[s + '_att'], data[
                s + '_masks'] = self.transform[s](image=crops,
                                                  bbox=boxes,
                                                  att=att_mask,
                                                  mask=mask_crops,
                                                  joint=False)

            # 2021.1.9 Check whether elements in data[s + '_att'] is all 1
            # Note that type of data[s + '_att'] is tuple, type of ele is torch.tensor
            for ele in data[s + '_att']:
                if (ele == 1).all():
                    data['valid'] = False
                    # print("Values of original attention mask are all one. Replace it with new data.")
                    return data
            # 2021.1.10 more strict conditions: require the downsampled masks not to be all 1
            for ele in data[s + '_att']:
                feat_size = self.output_sz[s] // 16  # 16 is the backbone stride
                # (1,1,128,128) (1,1,256,256) --> (1,1,8,8) (1,1,16,16)
                mask_down = F.interpolate(ele[None, None].float(),
                                          size=feat_size).to(torch.bool)[0]
                if (mask_down == 1).all():
                    data['valid'] = False
                    # print("Values of down-sampled attention mask are all one. "
                    #       "Replace it with new data.")
                    return data

        data['valid'] = True
        # if we use copy-and-paste augmentation
        if data["template_masks"] is None or data["search_masks"] is None:
            data["template_masks"] = torch.zeros(
                (1, self.output_sz["template"], self.output_sz["template"]))
            data["search_masks"] = torch.zeros(
                (1, self.output_sz["search"], self.output_sz["search"]))
        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
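
The crop-size check above rejects samples whose jittered box would produce a crop smaller than one pixel; a minimal sketch with an assumed search_area_factor:

import torch

w = torch.tensor([12.0, 0.0])
h = torch.tensor([30.0, 5.0])
search_area_factor = 4.0  # assumed value for illustration
crop_sz = torch.ceil(torch.sqrt(w * h) * search_area_factor)
print(crop_sz, (crop_sz < 1).any())  # tensor([76., 0.]) tensor(True) -> sample rejected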
Example #16
def genFlowVector4Visualization(F_fine2coarse):
    F1tmp = F_fine2coarse[0]
    H, W = F1tmp.shape[1:]

    maxvalXMask = torch.ones(1, 1) * (W - 1)
    maxvalXMask = maxvalXMask.repeat(H, W)  #.to(device)
    maxvalYMask = torch.ones(1, 1) * (H - 1)
    maxvalYMask = maxvalYMask.repeat(H, W)  #.to(device)
    minvalMask = torch.zeros(1, 1)
    minvalMask = minvalMask.repeat(H, W)  #.to(device)

    UV = torch.zeros_like(F1tmp)
    grid_x = torch.arange(0, W).view(1, -1).repeat(H, 1).float()  #.to(device)
    grid_y = torch.arange(0, H).view(-1, 1).repeat(1, W).float()  #.to(device)
    #ylist, xlist = grid_y.numpy(), grid_x.numpy()
    ycoord, xcoord = grid_y, grid_x
    for i, Fvec in enumerate(F_fine2coarse):
        xcoord_round = torch.round(xcoord)
        xcoord_round = clipTensor(xcoord_round, maxvalXMask, minvalMask)
        ycoord_round = torch.round(ycoord)
        ycoord_round = clipTensor(ycoord_round, maxvalYMask, minvalMask)
        xcoord_ceil = torch.ceil(xcoord)
        xcoord_ceil = clipTensor(xcoord_ceil, maxvalXMask, minvalMask)
        xcoord_floor = torch.floor(xcoord)
        xcoord_floor = clipTensor(xcoord_floor, maxvalXMask, minvalMask)
        ycoord_ceil = torch.ceil(ycoord)
        ycoord_ceil = clipTensor(ycoord_ceil, maxvalYMask, minvalMask)
        ycoord_floor = torch.floor(ycoord)
        ycoord_floor = clipTensor(ycoord_floor, maxvalYMask, minvalMask)

        xcoord_round = xcoord_round.detach().cpu().numpy()
        ycoord_round = ycoord_round.detach().cpu().numpy()
        xcoord_ceil = xcoord_ceil.detach().cpu().numpy()
        xcoord_floor = xcoord_floor.detach().cpu().numpy()
        ycoord_ceil = ycoord_ceil.detach().cpu().numpy()
        ycoord_floor = ycoord_floor.detach().cpu().numpy()

        xlist_supp_round, ylist_supp_round = Fvec[
            0, ycoord_round, xcoord_round], Fvec[1, ycoord_round, xcoord_round]
        xlist_supp_UL, ylist_supp_UL = Fvec[0, ycoord_floor,
                                            xcoord_floor], Fvec[1,
                                                                ycoord_floor,
                                                                xcoord_floor]
        xlist_supp_UR, ylist_supp_UR = Fvec[0, ycoord_floor,
                                            xcoord_ceil], Fvec[1, ycoord_floor,
                                                               xcoord_ceil]
        xlist_supp_BL, ylist_supp_BL = Fvec[0, ycoord_ceil,
                                            xcoord_floor], Fvec[1, ycoord_ceil,
                                                                xcoord_floor]
        xlist_supp_BR, ylist_supp_BR = Fvec[0, ycoord_ceil,
                                            xcoord_ceil], Fvec[1, ycoord_ceil,
                                                               xcoord_ceil]

        xcoord_ceil = torch.from_numpy(xcoord_ceil)
        xcoord_floor = torch.from_numpy(xcoord_floor)
        ycoord_ceil = torch.from_numpy(ycoord_ceil)
        ycoord_floor = torch.from_numpy(ycoord_floor)

        dominatorTMP = xcoord_ceil - xcoord_floor
        dominatorTMP[dominatorTMP == 0] = 1
        wLeft = xcoord_ceil - xcoord
        wRight = xcoord - xcoord_floor
        wLeft[wLeft == 0] = 0.5
        wRight[wRight == 0] = 0.5

        xlist_supp_u = wLeft / dominatorTMP * xlist_supp_UL + wRight / dominatorTMP * xlist_supp_UR
        xlist_supp_b = wLeft / dominatorTMP * xlist_supp_BL + wRight / dominatorTMP * xlist_supp_BR

        dominatorTMP = ycoord_ceil - ycoord_floor
        dominatorTMP[dominatorTMP == 0] = 1
        wUpper = ycoord_ceil - ycoord
        wBottom = ycoord - ycoord_floor
        wUpper[wUpper == 0] = 0.5
        wBottom[wBottom == 0] = 0.5
        xlist_supp = wUpper / dominatorTMP * xlist_supp_u + wBottom / dominatorTMP * xlist_supp_b

        dominatorTMP = xcoord_ceil - xcoord_floor
        dominatorTMP[dominatorTMP == 0] = 1
        wLeft = xcoord_ceil - xcoord
        wRight = xcoord - xcoord_floor
        wLeft[wLeft == 0] = 0.5
        wRight[wRight == 0] = 0.5

        ylist_supp_u = wLeft / dominatorTMP * ylist_supp_UL + wRight / dominatorTMP * ylist_supp_UR
        ylist_supp_b = wLeft / dominatorTMP * ylist_supp_BL + wRight / dominatorTMP * ylist_supp_BR

        dominatorTMP = ycoord_ceil - ycoord_floor
        dominatorTMP[dominatorTMP == 0] = 1
        wUpper = ycoord_ceil - ycoord
        wBottom = ycoord - ycoord_floor
        wUpper[wUpper == 0] = 0.5
        wBottom[wBottom == 0] = 0.5
        ylist_supp = wUpper / dominatorTMP * ylist_supp_u + wBottom / dominatorTMP * ylist_supp_b

        if i == len(F_fine2coarse) - 1:
            xlist_supp, ylist_supp = xlist_supp_round, ylist_supp_round
            #xlist, ylist = xcoord-grid_x.detach().cpu(), ycoord-grid_y.detach().cpu()
            #xlist, ylist = torch.round(xlist), torch.round(ylist)

        xcoord, ycoord = xlist_supp + xcoord, ylist_supp + ycoord
        xcoord = xcoord.detach().cpu()  #.numpy()
        ycoord = ycoord.detach().cpu()  #.numpy()

        if i == len(F_fine2coarse) - 1:
            xlist, ylist = xcoord - grid_x.detach().cpu(
            ), ycoord - grid_y.detach().cpu()

    UV[0] = xlist.view(1, H, W)
    UV[1] = ylist.view(1, H, W)
    return UV
Example #17
    def forward(self, src_tokens, src_lengths):
        # embed tokens and positions
        x = src_tokens
        x = torch.tanh(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = torch.tanh(self.fc2(x))

        # B x T x C -> B x 1 x T x C
        x = x.unsqueeze(1)
        # temporal convolutions
        for conv in self.convolutions:
            x = F.dropout(x, p=self.dropout, training=self.training)
            if conv.kernel_size[0] % 2 == 1:
                # padding is implicit in the conv
                x = conv(x)
            else:
                padding_l = (conv.kernel_size[0] - 1) // 2
                padding_r = conv.kernel_size[0] // 2
                x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
                x = conv(x)
            src_lengths = torch.ceil(src_lengths.float() / 2).long()

        residual = x

        for conv in self.deep_convolutions:
            norm = torch.nn.BatchNorm2d(16).cuda()
            x = norm(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if conv.kernel_size[0] % 2 == 1:
                # padding is implicit in the conv
                x = conv(x)
            else:
                padding_l = (conv.kernel_size[0] - 1) // 2
                padding_r = conv.kernel_size[0] // 2
                x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
                x = conv(x)
            # x = F.relu(x)
            x += residual
            residual = x

        # B x Cout x T x F -> T x B x C
        bsz, out_channels, time, feats = x.size()
        x = x.transpose(1, 2).contiguous().view(bsz, time, -1) \
            .contiguous().transpose(0, 1)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.relu(self.fc3(x))

        x = x + self.embed_positions(x.transpose(0, 1), src_lengths).transpose(
            0, 1)
        x = F.dropout(x, p=self.dropout, training=self.training)

        encoder_padding_mask = self.create_mask(src_lengths)

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.normalize:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }
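
Each length update above halves the sequence length with a rounded-up division, once per temporal convolution; a tiny standalone check (two time-downsampling layers assumed):

import torch

src_lengths = torch.tensor([100, 37])
for _ in range(2):  # two stride-2 convolutions along time assumed
    src_lengths = torch.ceil(src_lengths.float() / 2).long()
print(src_lengths)  # tensor([25, 10])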
Example #18
    def train_Dnet(self, idx, count):
        if idx == 0 or idx == 2:  # Discriminator is only trained in background and child stage. (NOT in parent stage)
            flag = count % 100
            batch_size = self.real_fimgs[0].size(0)
            criterion, criterion_one = self.criterion, self.criterion_one

            netD, optD = self.netsD[idx], self.optimizersD[idx]
            if idx == 0:
                real_imgs = self.real_fimgs[0]

            elif idx == 2:
                real_imgs = self.real_cimgs[0]

            fake_imgs = self.fake_imgs[idx]
            netD.zero_grad()
            real_logits = netD(real_imgs)

            if idx == 2:
                fake_labels = torch.zeros_like(real_logits[1])
                real_labels = torch.ones_like(real_logits[1])
            elif idx == 0:

                fake_labels = torch.zeros_like(real_logits[1])
                ext, output = real_logits
                weights_real = torch.ones_like(output)
                real_labels = torch.ones_like(output)

                for i in range(batch_size):
                    x1 = self.warped_bbox[0][i]
                    x2 = self.warped_bbox[2][i]
                    y1 = self.warped_bbox[1][i]
                    y2 = self.warped_bbox[3][i]

                    a1 = max(
                        torch.tensor(0).float().cuda(),
                        torch.ceil((x1 - self.recp_field) / self.patch_stride))
                    a2 = min(
                        torch.tensor(self.n_out - 1).float().cuda(),
                        torch.floor((self.n_out - 1) -
                                    ((126 - self.recp_field) - x2) /
                                    self.patch_stride)) + 1
                    b1 = max(
                        torch.tensor(0).float().cuda(),
                        torch.ceil((y1 - self.recp_field) / self.patch_stride))
                    b2 = min(
                        torch.tensor(self.n_out - 1).float().cuda(),
                        torch.floor((self.n_out - 1) -
                                    ((126 - self.recp_field) - y2) /
                                    self.patch_stride)) + 1

                    if (x1 != x2 and y1 != y2):
                        weights_real[
                            i, :,
                            a1.type(torch.int):a2.type(torch.int),
                            b1.type(torch.int):b2.type(torch.int)] = 0.0

                norm_fact_real = weights_real.sum()
                norm_fact_fake = weights_real.shape[0] * weights_real.shape[
                    1] * weights_real.shape[2] * weights_real.shape[3]
                real_logits = ext, output

            fake_logits = netD(fake_imgs.detach())

            if idx == 0:  # Background stage

                errD_real_uncond = criterion(
                    real_logits[1], real_labels
                )  # Real/Fake loss for 'real background' (on patch level)
                errD_real_uncond = torch.mul(
                    errD_real_uncond, weights_real
                )  # Masking output units which correspond to receptive fields which lie within the boundin box
                errD_real_uncond = errD_real_uncond.mean()

                errD_real_uncond_classi = criterion(
                    real_logits[0],
                    weights_real)  # Background/foreground classification loss
                errD_real_uncond_classi = errD_real_uncond_classi.mean()

                errD_fake_uncond = criterion(
                    fake_logits[1], fake_labels
                )  # Real/Fake loss for 'fake background' (on patch level)
                errD_fake_uncond = errD_fake_uncond.mean()

                if (
                        norm_fact_real > 0
                ):  # Normalize the real/fake loss for the background after accounting for the number of masked units in the output.
                    errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /
                                                    (norm_fact_real * 1.0))
                else:
                    errD_real = errD_real_uncond

                errD_fake = errD_fake_uncond
                errD = ((errD_real + errD_fake) *
                        cfg.TRAIN.BG_LOSS_WT) + errD_real_uncond_classi

            if idx == 2:

                errD_real = criterion_one(
                    real_logits[1],
                    real_labels)  # Real/Fake loss for the real image
                errD_fake = criterion_one(
                    fake_logits[1],
                    fake_labels)  # Real/Fake loss for the fake image
                errD = errD_real + errD_fake

            if (idx == 0 or idx == 2):
                errD.backward()
                optD.step()

            if (flag == 0):
                summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
                self.summary_writer.add_summary(summary_D, count)
                summary_D_real = summary.scalar('D_loss_real_%d' % idx,
                                                errD_real.data[0])
                self.summary_writer.add_summary(summary_D_real, count)
                summary_D_fake = summary.scalar('D_loss_fake_%d' % idx,
                                                errD_fake.data[0])
                self.summary_writer.add_summary(summary_D_fake, count)

            return errD
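Below is a small, hedged sketch of the receptive-field masking computed in the loop above. The values for recp_field, patch_stride, and n_out are hypothetical stand-ins (they are not taken from the snippet), and the .cuda() calls are dropped so it runs on CPU.

import torch

# Hypothetical discriminator geometry: 7 x 7 output patches, each with a 34-pixel
# receptive field and a stride of 16 on a 126-pixel input.
recp_field, patch_stride, n_out = 34.0, 16.0, 7
x1, x2 = torch.tensor(40.0), torch.tensor(100.0)   # warped bounding-box extent along one axis

a1 = max(torch.tensor(0.0), torch.ceil((x1 - recp_field) / patch_stride))
a2 = min(torch.tensor(n_out - 1.0),
         torch.floor((n_out - 1) - ((126 - recp_field) - x2) / patch_stride)) + 1
print(int(a1), int(a2))  # 1 7 -> patch rows 1..6 would be zeroed in weights_real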
Beispiel #19
0
def interpolate_depth(pos, depth):
    device = pos.device

    ids = torch.arange(0, pos.size(1), device=device)

    h, w = depth.size()

    i = pos[0, :]
    j = pos[1, :]

    # Valid corners
    i_top_left = torch.floor(i).long()
    j_top_left = torch.floor(j).long()
    valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)

    i_top_right = torch.floor(i).long()
    j_top_right = torch.ceil(j).long()
    valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)

    i_bottom_left = torch.ceil(i).long()
    j_bottom_left = torch.floor(j).long()
    valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)

    i_bottom_right = torch.ceil(i).long()
    j_bottom_right = torch.ceil(j).long()
    valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)

    valid_corners = torch.min(torch.min(valid_top_left, valid_top_right),
                              torch.min(valid_bottom_left, valid_bottom_right))

    i_top_left = i_top_left[valid_corners]
    j_top_left = j_top_left[valid_corners]

    i_top_right = i_top_right[valid_corners]
    j_top_right = j_top_right[valid_corners]

    i_bottom_left = i_bottom_left[valid_corners]
    j_bottom_left = j_bottom_left[valid_corners]

    i_bottom_right = i_bottom_right[valid_corners]
    j_bottom_right = j_bottom_right[valid_corners]

    ids = ids[valid_corners]
    if ids.size(0) == 0:
        raise EmptyTensorError

    # Valid depth
    valid_depth = torch.min(
        torch.min(depth[i_top_left, j_top_left] > 0,
                  depth[i_top_right, j_top_right] > 0),
        torch.min(depth[i_bottom_left, j_bottom_left] > 0,
                  depth[i_bottom_right, j_bottom_right] > 0))

    i_top_left = i_top_left[valid_depth]
    j_top_left = j_top_left[valid_depth]

    i_top_right = i_top_right[valid_depth]
    j_top_right = j_top_right[valid_depth]

    i_bottom_left = i_bottom_left[valid_depth]
    j_bottom_left = j_bottom_left[valid_depth]

    i_bottom_right = i_bottom_right[valid_depth]
    j_bottom_right = j_bottom_right[valid_depth]

    ids = ids[valid_depth]
    if ids.size(0) == 0:
        raise EmptyTensorError

    # Interpolation
    i = i[ids]
    j = j[ids]
    dist_i_top_left = i - i_top_left.float()
    dist_j_top_left = j - j_top_left.float()
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    interpolated_depth = (
        w_top_left * depth[i_top_left, j_top_left] +
        w_top_right * depth[i_top_right, j_top_right] +
        w_bottom_left * depth[i_bottom_left, j_bottom_left] +
        w_bottom_right * depth[i_bottom_right, j_bottom_right])

    pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    return [interpolated_depth, pos, ids]
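A minimal usage sketch for interpolate_depth above; EmptyTensorError is assumed to be defined in the surrounding module, and the shapes and values below are purely illustrative.

import torch

class EmptyTensorError(Exception):
    """Assumed stand-in for the exception referenced by the snippet above."""

depth = torch.rand(480, 640) + 0.1                 # H x W depth map with strictly positive depths
pos = torch.tensor([[10.3, 200.7, 479.5],          # row (i) coordinates
                    [15.9, 300.2, 639.5]])         # column (j) coordinates -> pos is 2 x N
interpolated_depth, kept_pos, kept_ids = interpolate_depth(pos, depth)
# The third point falls partly outside the 480 x 640 map, so it is dropped by the validity checks.
print(interpolated_depth.shape, kept_ids)          # torch.Size([2]) tensor([0, 1])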
Beispiel #20
0
def expm_taylor(A):
    if A.ndimension() < 2 or A.size(-2) != A.size(-1):
        raise ValueError(
            "Expected a square matrix or a batch of square matrices")

    if A.ndimension() == 2:
        # Just one matrix

        # Trivial case
        if A.size() == (1, 1):
            return torch.exp(A)

        if A.element_size() > 4:
            thetas = thetas_dict["double"]
        else:
            thetas = thetas_dict["single"]

        normA = torch.max(torch.sum(torch.abs(A), axis=0)).item()

        # No scale-square needed
        # This could be done marginally faster if iterated in reverse
        for deg, theta in zip(degs, thetas):
            if normA <= theta:
                return taylor_approx(A, deg)

        # Scale square
        s = int(math.ceil(math.log2(normA) - math.log2(thetas[-1])))
        A = A * (2**-s)
        X = taylor_approx(A, degs[-1])
        return torch.matrix_power(X, 2**s)
    else:
        # Batching

        # Trivial case
        if A.size()[-2:] == (1, 1):
            return torch.exp(A)

        if A.element_size() > 4:
            thetas = thetas_dict["double"]
        else:
            thetas = thetas_dict["single"]

        normA = torch.max(torch.sum(torch.abs(A), axis=-2), axis=-1).values

        # Handle trivial case
        if (normA == 0.0).all():
            Id = torch.eye(A.size(-2),
                           A.size(-1),
                           dtype=A.dtype,
                           device=A.device)
            return Id.expand_as(A)

        # Handle small normA
        more = normA > thetas[-1]
        s = normA.new_zeros(normA.size(), dtype=torch.long)
        s[more] = torch.ceil(torch.log2(normA[more]) -
                             math.log2(thetas[-1])).long()

        # A = A * 2**(-s)
        A = torch.pow(0.5,
                      s.float()).unsqueeze_(-1).unsqueeze_(-1).expand_as(A) * A
        X = taylor_approx(A, degs[-1])
        return matrix_power_two_batch(X, s)
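A hedged sanity check for the Taylor-based matrix exponential above; it assumes the module-level helpers (thetas_dict, degs, taylor_approx, matrix_power_two_batch) are available, and compares against torch.matrix_exp, which exists in PyTorch 1.7+.

import torch

A = torch.randn(4, 4, dtype=torch.float64)
expA = expm_taylor(A)
# torch.matrix_exp serves as a reference; the two results should agree to high precision.
print(torch.allclose(expA, torch.matrix_exp(A), atol=1e-10))  # expected: True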
Beispiel #21
0
def ceil(x):
    return torch.ceil(x).int()
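A quick illustration of the one-liner above; note that torch.ceil expects a floating-point tensor, so integer inputs would need a cast first.

import torch

x = torch.tensor([0.1, 2.0, -1.5])
print(ceil(x))  # tensor([ 1,  2, -1], dtype=torch.int32)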
Beispiel #22
0
def forward(inj_time, Mc, sphi, stheta, spsi, dl0, cosi):  # spsi: polarization angle, used below
    det_num = 3
    wave_length = 8192
    noise_scale = 8.0e-22
    pi = 3.1415926
    G = 6.673e-11
    c = 299792458.0
    Mpc = 3.08567758e22
    Msun = 1.989e30
    fs = 8192
    T = 1
    nsamples = T * fs
    t = np.arange(nsamples) / fs

    Det1_V = np.array([-2.161414928e+06, -3.834695183e+06, 4.600350224e+06])
    Det2_V = np.array([-7.427604192e+04, -5.496283721e+06, 3.224257016e+06])
    Det3_V = np.array([4546374.0, 842990.0, 4378577.0])
    Det1_d = np.array(
        [[-0.392614701790361, -0.077612252813702, -0.247388405118613],
         [-0.077612252813702, 0.319524089053145, 0.227998293910978],
         [-0.247388405118613, 0.227998293910978, 0.073090613199948]])
    Det2_d = np.array(
        [[0.411281743683125, 0.140209630402064, 0.247293475274344],
         [0.140209630402064, -0.109005942619247, -0.181616030843724],
         [0.247293475274344, -0.181616030843724, -0.302275800865383]])
    Det3_d = np.array(
        [[0.243874678248284, -0.099086615422263, -0.232575796255783],
         [-0.099086615422263, -0.447827871578090, 0.187828534783639],
         [-0.232575796255783, 0.187828534783639, 0.203953193329806]])

    dl = dl0 * Mpc

    m1 = torch.sin(sphi) * torch.cos(spsi) - torch.cos(sphi) * torch.cos(
        stheta) * torch.sin(spsi)
    m2 = -torch.cos(sphi) * torch.cos(spsi) - torch.sin(sphi) * torch.cos(
        stheta) * torch.sin(spsi)
    m3 = torch.sin(stheta) * torch.sin(spsi)
    n1 = -torch.sin(sphi) * torch.sin(spsi) - torch.cos(sphi) * torch.cos(
        stheta) * torch.cos(spsi)
    n2 = torch.cos(sphi) * torch.sin(spsi) - torch.sin(sphi) * torch.cos(
        stheta) * torch.cos(spsi)
    n3 = torch.sin(stheta) * torch.cos(spsi)
    mm = torch.cat((m1 * m1, m1 * m2, m1 * m3, m2 * m1, m2 * m2, m2 * m3,
                    m3 * m1, m3 * m2, m3 * m3), 0)
    mn = torch.cat((m1 * n1, m1 * n2, m1 * n3, m2 * n1, m2 * n2, m2 * n3,
                    m3 * n1, m3 * n2, m3 * n3), 0)
    nm = torch.cat((n1 * m1, n1 * m2, n1 * m3, n2 * m1, n2 * m2, n2 * m3,
                    n3 * m1, n3 * m2, n3 * m3), 0)
    nn = torch.cat((n1 * n1, n1 * n2, n1 * n3, n2 * n1, n2 * n2, n2 * n3,
                    n3 * n1, n3 * n2, n3 * n3), 0)
    e_plus = mm - nn
    e_cross = mn + nm
    d1 = torch.from_numpy(Det1_d.reshape(9))
    d2 = torch.from_numpy(Det2_d.reshape(9))
    d3 = torch.from_numpy(Det3_d.reshape(9))
    Fp1 = torch.sum(e_plus * Variable(d1))
    Fx1 = torch.sum(e_cross * Variable(d1))
    Fp2 = torch.sum(e_plus * Variable(d2))
    Fx2 = torch.sum(e_cross * Variable(d2))
    Fp3 = torch.sum(e_plus * Variable(d3))
    Fx3 = torch.sum(e_cross * Variable(d3))

    omega = torch.cat((torch.sin(stheta) * torch.cos(sphi),
                       torch.sin(stheta) * torch.sin(sphi), torch.cos(stheta)),
                      0)

    delay_1 = -torch.sum(Variable(torch.from_numpy(Det1_V)) * omega) / c
    delay_2 = -torch.sum(Variable(torch.from_numpy(Det2_V)) * omega) / c
    delay_3 = -torch.sum(Variable(torch.from_numpy(Det3_V)) * omega) / c
    tc1 = inj_time + delay_1
    tc2 = inj_time + delay_2
    tc3 = inj_time + delay_3
    idinjt1 = torch.ceil(tc1 * fs)
    idinjt2 = torch.ceil(tc2 * fs)
    idinjt3 = torch.ceil(tc3 * fs)

    npbase = np.arange(fs) / fs
    base = Variable(torch.from_numpy(npbase))
    tau1 = tc1.expand(base.size()) - base
    tau2 = tc2.expand(base.size()) - base
    tau3 = tc3.expand(base.size()) - base
    #tau1_relu=0.5*(torch.sign(tau1)+1)
    #tau2_relu=0.5*(torch.sign(tau2)+1)
    #tau3_relu=0.5*(torch.sign(tau3)+1)

    tau1_phi = (torch.pow(relu(tau1), 5 / 8))
    phi_t1 = -2 * torch.pow(
        (5 * G * Mc / (c * c * c)), -5 / 8).expand(tau1_phi.size()) * tau1_phi
    tau1_Ah = torch.pow(relu(5 / (c * tau1)), 1 / 4)
    Ah1 = (1 / dl.expand(tau1_Ah.size())) * torch.pow(
        G * Mc / (c * c), 5 / 4).expand(tau1_Ah.size()) * tau1_Ah
    hp1 = 0.5 * (1 + torch.pow(cosi, 2.0)).expand(
        Ah1.size()) * (Ah1 * torch.cos(phi_t1).expand(Ah1.size()))
    hx1 = Ah1 * torch.cos(phi_t1).expand(Ah1.size()) * cosi.expand(Ah1.size())

    tau2_phi = torch.pow(relu(tau2), 5 / 8)
    phi_t2 = -2 * torch.pow(
        (5 * G * Mc / (c * c * c)), -5 / 8).expand(tau2_phi.size()) * tau2_phi
    tau2_Ah = torch.pow(relu(5 / (c * tau2)), 1 / 4)
    Ah2 = (1 / dl.expand(tau2_Ah.size())) * torch.pow(
        G * Mc / (c * c), 5 / 4).expand(tau2_Ah.size()) * tau2_Ah
    hp2 = 0.5 * (1 + torch.pow(cosi, 2.0)).expand(
        Ah2.size()) * (Ah2 * torch.cos(phi_t2).expand(Ah2.size()))
    hx2 = Ah2 * torch.cos(phi_t2).expand(Ah2.size()) * cosi.expand(Ah2.size())

    tau3_phi = torch.pow(relu(tau3), 5 / 8)
    phi_t3 = -2 * torch.pow(
        (5 * G * Mc / (c * c * c)), -5 / 8).expand(tau3_phi.size()) * tau3_phi
    tau3_Ah = torch.pow(relu(5 / (c * tau3)), 1 / 4)
    Ah3 = (1 / dl.expand(tau3_Ah.size())) * torch.pow(
        G * Mc / (c * c), 5 / 4).expand(tau3_Ah.size()) * tau3_Ah
    hp3 = 0.5 * (1 + torch.pow(cosi, 2.0)).expand(
        Ah3.size()) * (Ah3 * torch.cos(phi_t3).expand(Ah3.size()))
    hx3 = Ah3 * torch.cos(phi_t3).expand(Ah3.size()) * cosi.expand(Ah3.size())

    Wave1 = Fp1.expand(hp1.size()) * hp1 + Fx1.expand(hx1.size()) * hx1
    Wave2 = Fp2.expand(hp2.size()) * hp2 + Fx2.expand(hx2.size()) * hx2
    Wave3 = Fp3.expand(hp3.size()) * hp3 + Fx3.expand(hx3.size()) * hx3
    Wave = torch.cat((Wave1, Wave2, Wave3), 0).view(-1, fs)

    return Wave
Beispiel #23
0
 def get_seq_len(self, seq_len):
     return torch.ceil(seq_len / self.hop_length).to(dtype=torch.long)
Beispiel #24
0
def render_differentiable(mesh, direction, triangleIndexMap, lighting, sensor,
                          lighting_normal, sensor_normal, opt, device):
    angular_transient = torch.DoubleTensor(
        opt.max_distance_bin).fill_(0).to(device)

    triangleIndexMap = torch.from_numpy(triangleIndexMap).long().to(device)
    direction = torch.from_numpy(direction).to(device)
    lighting = torch.from_numpy(lighting).to(device)
    sensor = torch.from_numpy(sensor).to(device)

    d1 = initialize_variable(opt.sample_num, -1, device)
    d2 = initialize_variable(opt.sample_num, np.nan, device)
    intensity = initialize_variable(opt.sample_num, np.nan, device)
    uMap = initialize_variable(opt.sample_num, np.nan, device)
    vMap = initialize_variable(opt.sample_num, np.nan, device)
    cos_theta2 = initialize_variable(opt.sample_num, 0, device)
    intersection_p = initialize_variable((opt.sample_num, 3), np.nan, device)
    normalMap = initialize_variable((opt.sample_num, 3), np.nan, device)
    v2 = initialize_variable((opt.sample_num, 3), np.nan, device)
    tmp_e1 = initialize_variable((opt.sample_num, 3), np.nan, device)
    tmp_e2 = initialize_variable((opt.sample_num, 3), np.nan, device)
    fn = initialize_variable((opt.sample_num, 3), np.nan, device)
    fn_len = initialize_variable(opt.sample_num, np.nan, device)

    distance_bin = torch.LongTensor(
        opt.sample_num).fill_(opt.max_distance_bin + 1).to(device)

    inds = torch.squeeze((triangleIndexMap > -1).data.nonzero())

    for i in inds:
        uMap[i], vMap[i], d1[i] = intersect_ray_mesh_one_direction(
            mesh, torch.squeeze(direction[i, :]),
            torch.squeeze(mesh.f[triangleIndexMap[i], :]), lighting)

    inds = torch.squeeze((d1 > 0).data.nonzero())

    triangleIndexMap_input = torch.index_select(triangleIndexMap, 0, inds)

    data = element_multiply2(
        1 - uMap[inds] - vMap[inds],
        mesh.v[torch.index_select(mesh.f[:, 2], 0, triangleIndexMap_input), :]
    ) + element_multiply2(
        uMap[inds],
        mesh.v[torch.index_select(mesh.f[:, 0], 0, triangleIndexMap_input), :]
    ) + element_multiply2(
        vMap[inds],
        mesh.v[torch.index_select(mesh.f[:, 1], 0, triangleIndexMap_input), :])
    intersection_p.index_copy_(0, inds, data)

    v2[inds, :] = sensor - intersection_p[inds, :]
    d2[inds] = torch.sqrt(torch.sum(v2[inds, :]**2, 1))
    v2[inds, :] = element_divide2(v2[inds, :], d2[inds])

    tmp_e1[inds, :] = mesh.v[torch.index_select(
        mesh.f[:, 1], 0, triangleIndexMap_input), :] - mesh.v[
            torch.index_select(mesh.f[:, 0], 0, triangleIndexMap_input), :]
    tmp_e2[inds, :] = mesh.v[torch.index_select(
        mesh.f[:, 2], 0, triangleIndexMap_input), :] - mesh.v[
            torch.index_select(mesh.f[:, 0], 0, triangleIndexMap_input), :]

    fn[inds, :] = torch.cross(tmp_e1[inds, :], tmp_e2[inds, :], 1)
    fn_len[inds] = torch.sqrt(torch.sum(fn[inds, :]**2, 1))
    normalMap[inds, :] = element_divide2(fn[inds, :], fn_len[inds])
    cos_theta2[inds] = torch.sum(torch.mul(normalMap[inds, :], v2[inds, :]), 1)
    index = torch.squeeze((cos_theta2 < 0).data.nonzero())
    if index.dim() != 0:
        if len(index) != 0:
            cos_theta2 = cos_theta2.index_fill(0, index, 0)

    distance_bin[inds] = torch.ceil(
        (d1[inds] + d2[inds]) / opt.distance_resolution).long() - 1
    inds = torch.squeeze((distance_bin < opt.max_distance_bin).data.nonzero())
    if inds.dim() != 0:
        if len(inds) != 0:
            val = torch.div(cos_theta2.index_select(0, inds),
                            d2.index_select(0, inds)**2)
            intensity.index_copy_(0, inds, val)
            angular_transient.index_add_(0, distance_bin[inds],
                                         intensity[inds])
    angular_transient *= 2 * math.pi
    angular_transient /= opt.sample_num
    #smooth = torch.reshape(torch.from_numpy(np.array([0.2, 0.6, 0.2])), (1,1,3))
    #angular_transient = torch.reshape(angular_transient,(1,1,opt.max_distance_bin))
    #angular_transient = torch.nn.functional.conv1d(angular_transient, smooth, None, 1, 1)
    #angular_transeint = torch.reshape(angular_transient, opt.max_distance_bin)
    return angular_transient
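For intuition, a hedged micro-example of the transient binning above: each light path of total length d1 + d2 is accumulated into bin ceil((d1 + d2) / distance_resolution) - 1. The resolution and cutoff below are illustrative, not taken from the snippet.

import torch

distance_resolution, max_distance_bin = 0.05, 100      # hypothetical values
d1 = torch.tensor([1.20, 2.00])                        # lighting -> surface distances
d2 = torch.tensor([0.83, 3.12])                        # surface -> sensor distances
bins = torch.ceil((d1 + d2) / distance_resolution).long() - 1
print(bins)  # tensor([ 40, 102]); the second path exceeds max_distance_bin and would be dropped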
Beispiel #25
0
def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, window_width,
                                lowpass_cutoff, lowpass_filter_width):
    # type: (float, float, int, float, float, int) -> Tuple[Tensor, Tensor]
    r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
    resampling as well as the indices in which they are valid. LinearResample (LR) means
    that the output signal is at linearly spaced intervals (i.e the output signal has a
    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
    the signal.

    The reason the same filter is not used for multiple convolutions is that the
    sinc function can be sampled at different points in time. For example, suppose
    a signal is sampled at the timestamps (seconds)
    0         16        32
    and we want it to be sampled at the timestamps (seconds)
    0 5 10 15   20 25 30  35
    at the timestamp of 16, the delta timestamps are
    16 11 6 1   4  9  14  19
    at the timestamp of 32, the delta timestamps are
    32 27 22 17 12 8 2    3

    As we can see from deltas, the sinc function is sampled at different points of time
    assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ....]
    for 16 vs [...., 2, 3, ....] for 32)

    For example, when ``orig_freq`` and ``new_freq`` are multiples of each other, only a
    single filter is needed.

    A windowed filter function (i.e. Hanning * sinc) is used because the ideal sinc filter
    has infinite support (it is non-zero for all values), so instead it is truncated and
    multiplied by a window function, which gives it a less-than-perfect rolloff [1].

    [1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm

    Args:
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        output_samples_in_unit (int): The number of output samples in the smallest repeating unit:
            num_samp_out = new_freq / Gcd(orig_freq, new_freq)
        window_width (float): The width of the window which is nonzero
        lowpass_cutoff (float): The filter cutoff in Hz. The filter cutoff needs to be less
            than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.
        lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper but less
            efficient. We suggest around 4 to 10 for normal use

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple of ``min_input_index`` (which is the minimum indices
        where the window is valid, size (``output_samples_in_unit``)) and ``weights`` (which is the weights
        which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
    """
    assert lowpass_cutoff < min(orig_freq, new_freq) / 2
    output_t = torch.arange(0., output_samples_in_unit) / new_freq
    min_t = output_t - window_width
    max_t = output_t + window_width

    min_input_index = torch.ceil(min_t * orig_freq)  # size (output_samples_in_unit)
    max_input_index = torch.floor(max_t * orig_freq)  # size (output_samples_in_unit)
    num_indices = max_input_index - min_input_index + 1  # size (output_samples_in_unit)

    max_weight_width = num_indices.max()
    # create a group of weights of size (output_samples_in_unit, max_weight_width)
    j = torch.arange(max_weight_width).unsqueeze(0)
    input_index = min_input_index.unsqueeze(1) + j
    delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)

    weights = torch.zeros_like(delta_t)
    inside_window_indices = delta_t.abs().lt(window_width)
    # raised-cosine (Hanning) window with width `window_width`
    weights[inside_window_indices] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff /
                                            lowpass_filter_width * delta_t[inside_window_indices]))

    t_eq_zero_indices = delta_t.eq(0.0)
    t_not_eq_zero_indices = ~t_eq_zero_indices
    # sinc filter function
    weights[t_not_eq_zero_indices] *= torch.sin(
        2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (math.pi * delta_t[t_not_eq_zero_indices])
    # limit of the function at t = 0
    weights[t_eq_zero_indices] *= 2 * lowpass_cutoff

    weights /= orig_freq  # size (output_samples_in_unit, max_weight_width)
    return min_input_index, weights
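A hedged usage sketch of the helper above. The way output_samples_in_unit, window_width, and lowpass_cutoff are derived mirrors the usual Kaldi-style resampling setup, but the numbers are illustrative, not taken from the snippet.

import math
import torch

orig_freq, new_freq = 16000.0, 8000.0
lowpass_filter_width = 6
lowpass_cutoff = 0.99 * 0.5 * min(orig_freq, new_freq)          # stays below both Nyquist limits
base_freq = math.gcd(int(orig_freq), int(new_freq))
output_samples_in_unit = int(new_freq) // base_freq
window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)

first_indices, weights = _get_LR_indices_and_weights(
    orig_freq, new_freq, output_samples_in_unit, window_width,
    lowpass_cutoff, lowpass_filter_width)
print(first_indices.shape, weights.shape)  # (output_samples_in_unit,) and (output_samples_in_unit, max_weight_width)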
Beispiel #26
0
def render_intensity_differentiable(mesh, direction, triangleIndexMap,
                                    difference, lighting, sensor,
                                    lighting_normal, sensor_normal, opt,
                                    device):

    triangleIndexMap = torch.from_numpy(triangleIndexMap).long().to(device)
    direction = torch.from_numpy(direction).to(device)
    lighting = torch.from_numpy(lighting).to(device)
    sensor = torch.from_numpy(sensor).to(device)

    d1 = initialize_variable(opt.sample_num, -1, device)
    d2 = initialize_variable(opt.sample_num, np.nan, device)
    intensity = initialize_variable(opt.sample_num, np.nan, device)
    uMap = initialize_variable(opt.sample_num, np.nan, device)
    vMap = initialize_variable(opt.sample_num, np.nan, device)
    cos_theta2 = initialize_variable(opt.sample_num, 0, device)
    intersection_p = initialize_variable((opt.sample_num, 3), np.nan, device)
    normalMap = initialize_variable((opt.sample_num, 3), np.nan, device)
    v2 = initialize_variable((opt.sample_num, 3), np.nan, device)
    tmp_e1 = initialize_variable((opt.sample_num, 3), np.nan, device)
    tmp_e2 = initialize_variable((opt.sample_num, 3), np.nan, device)
    fn = initialize_variable((opt.sample_num, 3), np.nan, device)
    fn_len = initialize_variable(opt.sample_num, np.nan, device)
    w = initialize_variable(opt.sample_num, np.nan, device)

    distance_bin = torch.LongTensor(
        opt.sample_num).fill_(opt.max_distance_bin + 1).to(device)

    inds = torch.squeeze((triangleIndexMap > -1).data.nonzero())

    for i in inds:
        uMap[i], vMap[i], d1[i] = intersect_ray_mesh_one_direction(
            mesh, torch.squeeze(direction[i, :]),
            torch.squeeze(mesh.f[triangleIndexMap[i], :]), lighting)

    inds = torch.squeeze((d1 > 0).data.nonzero())

    triangleIndexMap_input = torch.index_select(triangleIndexMap, 0, inds)

    data = element_multiply2(
        1 - uMap[inds] - vMap[inds],
        mesh.v[torch.index_select(mesh.f[:, 2], 0, triangleIndexMap_input), :]
    ) + element_multiply2(
        uMap[inds],
        mesh.v[torch.index_select(mesh.f[:, 0], 0, triangleIndexMap_input), :]
    ) + element_multiply2(
        vMap[inds],
        mesh.v[torch.index_select(mesh.f[:, 1], 0, triangleIndexMap_input), :])
    intersection_p.index_copy_(0, inds, data)

    v2[inds, :] = sensor - intersection_p[inds, :]
    d2[inds] = torch.sqrt(torch.sum(v2[inds, :]**2, 1))
    v2[inds, :] = element_divide2(v2[inds, :], d2[inds])

    tmp_e1[inds, :] = mesh.v[torch.index_select(
        mesh.f[:, 1], 0, triangleIndexMap_input), :] - mesh.v[
            torch.index_select(mesh.f[:, 0], 0, triangleIndexMap_input), :]
    tmp_e2[inds, :] = mesh.v[torch.index_select(
        mesh.f[:, 2], 0, triangleIndexMap_input), :] - mesh.v[
            torch.index_select(mesh.f[:, 0], 0, triangleIndexMap_input), :]

    fn[inds, :] = torch.cross(tmp_e1[inds, :], tmp_e2[inds, :], 1)
    fn_len[inds] = torch.sqrt(torch.sum(fn[inds, :]**2, 1))
    normalMap[inds, :] = element_divide2(fn[inds, :], fn_len[inds])
    cos_theta2[inds] = torch.sum(torch.mul(normalMap[inds, :], v2[inds, :]), 1)
    index = torch.squeeze((cos_theta2 < 0).data.nonzero())
    if index.dim() != 0:
        if len(index) != 0:
            cos_theta2 = cos_theta2.index_fill(0, index, 0)

    distance_bin[inds] = torch.ceil(
        (d1[inds] + d2[inds]) / opt.distance_resolution).long() - 1
    inds = torch.squeeze((distance_bin < opt.max_distance_bin).data.nonzero())
    if inds.dim() != 0:
        if len(inds) != 0:
            w[inds] = difference.index_select(0, distance_bin[inds])
            val1 = torch.div(cos_theta2.index_select(0, inds),
                             d2.index_select(0, inds)**2)
            val2 = torch.mul(w.index_select(0, inds), val1)
            #val2 = torch.mul(difference.index_select(0,inds), val1)
            intensity.index_copy_(0, inds, val2)
    return torch.sum(intensity) * 2 * math.pi / opt.sample_num
Beispiel #27
0
def resize_tensor(tensor, new_size, do_ceil=False):  # for 3d tensor of shape (C, H, W)
    image_array = np.transpose(rescale_image(tensor.cpu().numpy()), axes=(1, 2, 0))
    image = Image.fromarray(image_array).resize(new_size)  # new_size of form (W, H)
    if do_ceil:
        return torch.ceil(transforms.ToTensor()(image)).to(device)
    return transforms.ToTensor()(image).to(device)  # float values in shape (C, H, W)
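A self-contained sketch of how the helper above might be called; rescale_image and device are not shown in the snippet, so hypothetical stand-ins are defined here.

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

device = torch.device('cpu')                 # assumed; the snippet uses a module-level device

def rescale_image(arr):                      # hypothetical stand-in: map a (C, H, W) float array to uint8
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
    return (arr * 255).astype(np.uint8)

t = torch.rand(3, 32, 48)                    # (C, H, W)
out = resize_tensor(t, (64, 96), do_ceil=True)   # new_size is (W, H)
print(out.shape)                             # torch.Size([3, 96, 64]); do_ceil binarizes the values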
Beispiel #28
0
 def get_seq_len(self, seq_len):
     return torch.ceil(seq_len.to(dtype=torch.float) /
                       self.hop_length).to(dtype=torch.int)
Beispiel #29
0
    def forward(self, x, img_ids, mask):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        #Algo
        """
        1. Find patch index in x and y of the key we want for each query using normal dis

        """
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        att = []

        # Load on GPU
        self.cuda_avgs = Parameter(self.avgs[img_ids].cuda(),
                                   requires_grad=True)
        self.cuda_std_devs = Parameter(self.std_devs[img_ids].cuda(),
                                       requires_grad=True)

        for j, img_id in enumerate(img_ids):
            norm_x = torch.normal(mean=torch.zeros(1,
                                                   self.no_of_patches,
                                                   requires_grad=True),
                                  std=torch.ones(1,
                                                 self.no_of_patches,
                                                 requires_grad=True)).cuda()
            norm_y = torch.normal(mean=torch.zeros(1,
                                                   self.no_of_patches,
                                                   requires_grad=True),
                                  std=torch.ones(1,
                                                 self.no_of_patches,
                                                 requires_grad=True)).cuda()
            key_x = (norm_x - self.cuda_avgs[j][0]) / self.cuda_std_devs[j][0]
            key_y = (norm_y - self.cuda_avgs[j][1]) / self.cuda_std_devs[j][1]

            key_x_1 = torch.ceil(key_x)
            key_x_2 = torch.floor(key_x)
            key_y_1 = torch.ceil(key_y)
            key_y_2 = torch.floor(key_y)

            key_index = [0, 0, 0, 0]
            key_index[0] = to_indices(self.grid_dim * key_y_1 + key_x_1)
            key_index[1] = to_indices(self.grid_dim * key_y_1 + key_x_2)
            key_index[2] = to_indices(self.grid_dim * key_y_2 + key_x_1)
            key_index[3] = to_indices(self.grid_dim * key_y_2 + key_x_2)

            # TODO: refactor this; compute the bilinear weights once for both keys and values, and multiply key and value only once
            # SAMPLED KEY = E{ (1 - abs(Pn_x - Sample_x)) * (1 - abs(Pn_y - Sample_y)) * Kn }
            sample = (key_x, key_y)
            sampled_key = (bilinear((key_x_1 , key_y_1), sample) * k[j][key_index[0]].transpose(dim0=1, dim1=2) + \
                           bilinear((key_x_2 , key_y_1), sample) * k[j][key_index[1]].transpose(dim0=1, dim1=2) + \
                           bilinear((key_x_1 , key_y_2), sample) * k[j][key_index[2]].transpose(dim0=1, dim1=2) + \
                           bilinear((key_x_2 , key_y_2), sample) * k[j][key_index[3]].transpose(dim0=1, dim1=2)).transpose(dim0=1, dim1=2)

            sampled_value = (bilinear((key_x_1 , key_y_1), sample) * v[j][key_index[0]].transpose(dim0=1, dim1=2) + \
                             bilinear((key_x_2 , key_y_1), sample) * v[j][key_index[1]].transpose(dim0=1, dim1=2) + \
                             bilinear((key_x_1 , key_y_2), sample) * v[j][key_index[2]].transpose(dim0=1, dim1=2) + \
                             bilinear((key_x_2 , key_y_2), sample) * v[j][key_index[3]].transpose(dim0=1, dim1=2)).transpose(dim0=1, dim1=2)

            # Lets add ones vector for class embedding
            _, _, k_dim = sampled_key.shape
            class_emb = to_device(torch.ones(1, 1, k_dim),
                                  get_default_device())
            sampled_key = torch.cat((class_emb, sampled_key), dim=1)
            sampled_value = torch.cat((class_emb, sampled_value), dim=1)

            at_sc = torch.matmul(sampled_key.transpose(dim0=0, dim1=1),
                                 q[j].unsqueeze(dim=2))
            full_att = F.softmax(at_sc, dim=1).transpose(
                dim0=0, dim1=1) * sampled_value.squeeze(dim=0)

            att.append(torch.sum(full_att, dim=0))

        return torch.stack(att)
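A hedged micro-example of the bilinear weighting referenced in the comment above. The bilinear helper used by the snippet is not shown, so a stand-in matching the commented formula is defined here.

import torch

def bilinear(corner, sample):
    # Hypothetical stand-in: (1 - |Pn_x - Sample_x|) * (1 - |Pn_y - Sample_y|)
    return (1 - (corner[0] - sample[0]).abs()) * (1 - (corner[1] - sample[1]).abs())

key_x, key_y = torch.tensor(2.3), torch.tensor(5.8)          # a sampled (fractional) key location
corners = [(torch.ceil(key_x), torch.ceil(key_y)), (torch.floor(key_x), torch.ceil(key_y)),
           (torch.ceil(key_x), torch.floor(key_y)), (torch.floor(key_x), torch.floor(key_y))]
weights = [bilinear(c, (key_x, key_y)) for c in corners]
print([round(w.item(), 2) for w in weights])  # [0.24, 0.56, 0.06, 0.14]; the four weights sum to 1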
Beispiel #30
0
    def forward(
        self,
        src_tokens,
        src_lengths,
        context_out,
        cls_input: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
    ):
        x = src_tokens.unsqueeze(1)
        for i, conv in enumerate(self.convolutions):
            if conv.kernel_size[0] % 2 == 1:
                # padding is implicit in the conv
                x = conv(x)
            else:
                padding_l = (conv.kernel_size[0] - 1) // 2
                padding_r = conv.kernel_size[0] // 2
                x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
                x = conv(x)
            x = self.bn[i](self.activation_fn(x))
            src_lengths = torch.ceil(src_lengths.float() / 2).long()
            x = F.dropout(x, p=max(self.dropout, .1), training=self.training)

        if hasattr(self, 'attn_2d'):
            residual = x
            x, _ = self.attn_2d[0](query=x, key=x, value=x)
            x = x + residual
            residual = x
            x, _ = self.attn_2d[1](query=x, key=x, value=x)
            x = x + residual

        # B x Cout x T x F -> T x B x C
        bsz, out_channels, time, feats = x.size()
        x = x.transpose(1,
                        2).contiguous().view(bsz, time,
                                             -1).contiguous().transpose(0, 1)
        x = self.activation_fn(self.fc3(x))

        x = x + self.embed_positions(x.transpose(0, 1), src_lengths).transpose(
            0, 1)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        encoder_padding_mask = self.create_mask(src_lengths)

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.empty(1).uniform_()
            if not self.training or (dropout_probability >
                                     self.encoder_layerdrop):
                x = layer(
                    x,
                    encoder_padding_mask,
                    context=context_out['context_out'],
                    context_padding_mask=context_out['context_padding_mask'])
                if return_all_hiddens:
                    assert encoder_states is not None
                    encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
            if return_all_hiddens:
                encoder_states[-1] = x

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=None,
            encoder_states=encoder_states,  # List[T x B x C]
        )
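A hedged illustration of the length bookkeeping in the convolutional front end above: assuming each block halves the time axis, the per-utterance lengths are kept aligned with the feature map via ceil(len / 2).

import torch

src_lengths = torch.tensor([100, 37, 9])
for _ in range(2):   # two downsampling blocks (assumed)
    src_lengths = torch.ceil(src_lengths.float() / 2).long()
print(src_lengths)   # tensor([25, 10, 3])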
Beispiel #31
0
    def inference(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        feats_lengths: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        dur: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (B, T_text,).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
                skip the prediction of durations (i.e., teacher forcing).
            noise_scale (float): Noise scale parameter for flow.
            noise_scale_dur (float): Noise scale parameter for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length of acoustic feature sequence.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Tensor: Generated waveform tensor (B, T_wav).
            Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
            Tensor: Duration tensor (B, T_text).

        """
        # encoder
        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
        g = None
        if self.spks is not None:
            # (B, global_channels, 1)
            g = self.global_emb(sids.view(-1)).unsqueeze(-1)
        if self.spk_embed_dim is not None:
            # (B, global_channels, 1)
            g_ = self.spemb_proj(F.normalize(
                spembs.unsqueeze(0))).unsqueeze(-1)
            if g is None:
                g = g_
            else:
                g = g + g_
        if self.langs is not None:
            # (B, global_channels, 1)
            g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
            if g is None:
                g = g_
            else:
                g = g + g_

        if use_teacher_forcing:
            # forward posterior encoder
            z, m_q, logs_q, y_mask = self.posterior_encoder(feats,
                                                            feats_lengths,
                                                            g=g)

            # forward flow
            z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)

            # monotonic alignment search
            s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
            # (B, 1, T_text)
            neg_x_ent_1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p,
                [1],
                keepdim=True,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2),
                s_p_sq_r,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_3 = torch.matmul(
                z_p.transpose(1, 2),
                (m_p * s_p_sq_r),
            )
            # (B, 1, T_text)
            neg_x_ent_4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r,
                [1],
                keepdim=True,
            )
            # (B, T_feats, T_text)
            neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
            # (B, 1, T_feats, T_text)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(
                y_mask, -1)
            # monotonic attention weight: (B, 1, T_feats, T_text)
            attn = self.maximum_path(
                neg_x_ent,
                attn_mask.squeeze(1),
            ).unsqueeze(1)
            dur = attn.sum(2)  # (B, 1, T_text)

            # forward decoder with random segments
            wav = self.decoder(z * y_mask, g=g)
        else:
            # duration
            if dur is None:
                logw = self.duration_predictor(
                    x,
                    x_mask,
                    g=g,
                    inverse=True,
                    noise_scale=noise_scale_dur,
                )
                w = torch.exp(logw) * x_mask * alpha
                dur = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
            y_mask = make_non_pad_mask(y_lengths).unsqueeze(1).to(text.device)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(
                y_mask, -1)
            attn = self._generate_path(dur, attn_mask)

            # expand the length to match with the feature sequence
            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
            m_p = torch.matmul(
                attn.squeeze(1),
                m_p.transpose(1, 2),
            ).transpose(1, 2)
            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
            logs_p = torch.matmul(
                attn.squeeze(1),
                logs_p.transpose(1, 2),
            ).transpose(1, 2)

            # decoder
            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
            z = self.flow(z_p, y_mask, g=g, inverse=True)
            wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)

        return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
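A hedged micro-example of the duration rounding in the non-teacher-forcing branch above: predicted log-durations become integer frame counts via exp, scaling by alpha, and ceil, so a smaller alpha yields faster speech. The numbers below are illustrative.

import torch

logw = torch.tensor([[[0.0, 0.7, 1.2]]])      # (B, 1, T_text) per-token log-durations, illustrative
x_mask = torch.ones_like(logw)
for alpha in (0.8, 1.0, 1.2):
    dur = torch.ceil(torch.exp(logw) * x_mask * alpha)
    print(alpha, int(dur.sum()))              # total feature frames: 6, 8, 9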
Beispiel #32
0
    def train(self):

        if self.T - self.target_sync_T > self.args.target:
            self.sync_target_network()
            self.target_sync_T = self.T

        info = {}

        for _ in range(self.args.iters):
            self.dqn.eval()

            batch, indices, is_weights = self.replay.Sample_N(self.args.batch_size, self.args.n_step, self.args.gamma)
            columns = list(zip(*batch))

            states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
            actions = Variable(torch.LongTensor(columns[1]))
            terminal_states = Variable(torch.FloatTensor(columns[5]))
            rewards = Variable(torch.FloatTensor(columns[2]))
            # Have to clip rewards for DQN
            rewards = torch.clamp(rewards, -1, 1)
            steps = Variable(torch.FloatTensor(columns[4]))
            new_states = Variable(torch.from_numpy(np.array(columns[3])).float().transpose_(1, 3))

            target_dqn_qvals = self.target_dqn(new_states).cpu()
            # Make a new variable with those values so that these are treated as constants
            target_dqn_qvals_data = Variable(target_dqn_qvals.data)

            q_value_gammas = (Variable(torch.ones(terminal_states.size()[0])) - terminal_states)
            inter = Variable(torch.ones(terminal_states.size()[0]) * self.args.gamma)
            # print(steps)
            q_value_gammas = q_value_gammas * torch.pow(inter, steps)

            values = torch.linspace(self.args.v_min, self.args.v_max, steps=self.args.atoms)
            values = Variable(values)
            values = values.view(1, 1, self.args.atoms)
            values = values.expand(self.args.batch_size, self.args.actions, self.args.atoms)
            # print(values)

            q_value_gammas = q_value_gammas.view(self.args.batch_size, 1, 1)
            q_value_gammas = q_value_gammas.expand(self.args.batch_size, self.args.actions, self.args.atoms)
            # print(q_value_gammas)
            gamma_values = q_value_gammas * values
            # print(gamma_values)
            rewards = rewards.view(self.args.batch_size, 1, 1)
            rewards = rewards.expand(self.args.batch_size, self.args.actions, self.args.atoms)
            # print(rewards)
            operator_q_values = rewards + gamma_values
            # print(operator_q_values)

            clipped_operator_q_values = torch.clamp(operator_q_values, self.args.v_min, self.args.v_max)

            delta_z = (self.args.v_max - self.args.v_min) / (self.args.atoms - 1)
            # Using the notation from the categorical paper
            b_j = (clipped_operator_q_values - self.args.v_min) / delta_z
            # print(b_j)
            lower_bounds = torch.floor(b_j)
            upper_bounds = torch.ceil(b_j)

            # Work out the max action
            atom_values = Variable(torch.linspace(self.args.v_min, self.args.v_max, steps=self.args.atoms))
            atom_values = atom_values.view(1, 1, self.args.atoms)
            atom_values = atom_values.expand(self.args.batch_size, self.args.actions, self.args.atoms)

            # Sum over the atoms dimension
            target_expected_qvalues = torch.sum(target_dqn_qvals_data * atom_values, dim=2)
            # Get the maximum actions index across the batch size
            max_actions = target_expected_qvalues.max(dim=1)[1].view(-1)

            # Project back onto the original support for the max actions
            q_value_distribution_targets = torch.zeros(self.args.batch_size, self.args.atoms)

            # Distributions for the max actions
            # print(target_dqn_qvals_data, max_actions)
            q_value_max_actions_distribs = target_dqn_qvals_data.index_select(dim=1, index=max_actions)[:,0,:]
            # print(q_value_max_actions_distribs)

            # Lower_bounds_actions
            lower_bounds_actions = lower_bounds.index_select(dim=1, index=max_actions)[:,0,:]
            upper_bounds_actions = upper_bounds.index_select(dim=1, index=max_actions)[:,0,:]
            b_j_actions = b_j.index_select(dim=1, index=max_actions)[:,0,:]

            lower_bound_values_to_add = q_value_max_actions_distribs * (upper_bounds_actions - b_j_actions)
            upper_bound_values_to_add = q_value_max_actions_distribs * (b_j_actions - lower_bounds_actions)
            # print(lower_bounds_actions)
            # print(lower_bound_values_to_add)
            # Naive looping
            for b in range(self.args.batch_size):
                for l, pj in zip(lower_bounds_actions.data.type(torch.LongTensor)[b], lower_bound_values_to_add[b].data):
                    q_value_distribution_targets[b][l] += pj
                for u, pj in zip(upper_bounds_actions.data.type(torch.LongTensor)[b], upper_bound_values_to_add[b].data):
                    q_value_distribution_targets[b][u] += pj

            self.dqn.train()
            if self.args.gpu:
                actions = actions.cuda()
                # q_value_targets = q_value_targets.cuda()
                q_value_distribution_targets = q_value_distribution_targets.cuda()
            model_predictions = self.dqn(states).index_select(1, actions.view(-1))[:,0,:]
            q_value_distribution_targets = Variable(q_value_distribution_targets)
            # print(q_value_distribution_targets)
            # print(model_predictions) 

            # Cross entropy loss
            ce_loss = -torch.sum(q_value_distribution_targets * torch.log(model_predictions), dim=1)
            ce_batch_loss = ce_loss.mean()

            info = {}

            self.log("DQN/X_Entropy_Loss", ce_batch_loss.data[0], step=self.T)

            # Update
            self.optimizer.zero_grad()
            ce_batch_loss.backward()

            # Taken from pytorch clip_grad_norm
            # Remove once the pip version it up to date with source
            gradient_norm = clip_grad_norm(self.dqn.parameters(), self.args.clip_value)
            if gradient_norm is not None:
                info["Norm"] = gradient_norm

            self.optimizer.step()

            if "States" in info:
                states_trained = info["States"]
                info["States"] = states_trained + columns[0]
            else:
                info["States"] = columns[0]

        # Pad out the states to be of size batch_size
        if len(info["States"]) < self.args.batch_size:
            old_states = list(info["States"])
            # Repeat the first stored state to pad the batch up to batch_size.
            info["States"] = old_states + [old_states[0]] * (self.args.batch_size - len(old_states))

        return info
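For intuition, a hedged worked example of the categorical projection that the naive loop above performs: the probability mass at each Bellman-updated atom value is split between its two neighbouring support atoms in proportion to distance. The support and mass below are illustrative.

import torch

v_min, v_max, atoms = -10.0, 10.0, 5            # support: [-10, -5, 0, 5, 10]
delta_z = (v_max - v_min) / (atoms - 1)         # 5.0
target_value = torch.tensor(3.0)                # one clipped r + gamma * z value
b_j = (target_value - v_min) / delta_z          # 2.6 -> lies between atom index 2 and 3
l, u = torch.floor(b_j).long(), torch.ceil(b_j).long()
projected = torch.zeros(atoms)
p = 1.0                                         # probability mass carried by this atom
projected[l] += p * (u.float() - b_j)           # ~0.4 of the mass goes to atom 2 (value 0)
projected[u] += p * (b_j - l.float())           # ~0.6 of the mass goes to atom 3 (value 5)
print(projected)                                # ~tensor([0.0, 0.0, 0.4, 0.6, 0.0])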