Example #1
    def test_forward(self):
        batch = 16
        len1, len2 = 21, 24
        seq_len1 = torch.randint(low=len1 - 10, high=len1 + 1, size=(batch,)).long()
        seq_len2 = torch.randint(low=len2 - 10, high=len2 + 1, size=(batch,)).long()

        mask1 = []
        for w in seq_len1:
            mask1.append([1] * w.item() + [0] * (len1 - w.item()))
        mask1 = torch.FloatTensor(mask1)
        mask2 = []
        for w in seq_len2:
            mask2.append([1] * w.item() + [0] * (len2 - w.item()))
        mask2 = torch.FloatTensor(mask2)

        d = 200  # hidden dimension
        l = 20  # number of perspectives
        test1 = torch.randn(batch, len1, d)
        test2 = torch.randn(batch, len2, d)
        test1 = test1 * mask1.view(-1, len1, 1).expand(-1, len1, d)
        test2 = test2 * mask2.view(-1, len2, 1).expand(-1, len2, d)

        test1_fw, test1_bw = torch.split(test1, d // 2, dim=-1)
        test2_fw, test2_bw = torch.split(test2, d // 2, dim=-1)

        ml_fw = BiMpmMatching.from_params(Params({"is_forward": True, "num_perspectives": l}))
        ml_bw = BiMpmMatching.from_params(Params({"is_forward": False, "num_perspectives": l}))

        vecs_p_fw, vecs_h_fw = ml_fw(test1_fw, mask1, test2_fw, mask2)
        vecs_p_bw, vecs_h_bw = ml_bw(test1_bw, mask1, test2_bw, mask2)
        vecs_p, vecs_h = torch.cat(vecs_p_fw + vecs_p_bw, dim=2), torch.cat(vecs_h_fw + vecs_h_bw, dim=2)

        assert vecs_p.size() == torch.Size([batch, len1, 10 + 10 * l])
        assert vecs_h.size() == torch.Size([batch, len2, 10 + 10 * l])
        assert ml_fw.get_output_dim() == ml_bw.get_output_dim() == vecs_p.size(2) // 2 == vecs_h.size(2) // 2
Example #2
    def __init__(self, X_in, y_in):
        # Check whether the inputs are torch.LongTensors (assume NumPy arrays otherwise)
        if not isinstance(X_in, torch.LongTensor):
            X_in = torch.from_numpy(X_in.astype('int64')).long()
        if not isinstance(y_in, torch.LongTensor):
            y_in = torch.from_numpy(y_in.astype('int64')).long()

        self.X_in = torch.split(X_in, 1, dim=0)
        self.y_in = torch.split(y_in, 1, dim=0)
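After the splits above, each element of self.X_in and self.y_in is a one-row tensor. A minimal sketch of the accessors such a dataset wrapper typically pairs with (the method bodies and the squeeze are assumptions, not taken from the original class):

    def __len__(self):
        # one entry per row of the original X_in
        return len(self.X_in)

    def __getitem__(self, idx):
        # each chunk from torch.split(..., 1, dim=0) keeps a leading dim of size 1; drop it
        return self.X_in[idx].squeeze(0), self.y_in[idx].squeeze(0)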
Example #3
 def forward(self, input):
     projected = self.layer(input)
     non_lin, gate = torch.split(projected, self.input_dim, -1)
     non_lin = self.activation(non_lin)
     gate = self.gate(gate)
     combined = gate * input + (1 - gate) * non_lin
     return combined
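The forward above assumes self.layer projects to twice the input width so torch.split can carve the result into a non-linear half and a gate half. A minimal constructor consistent with that forward, assuming ReLU and a sigmoid gate (the class name is hypothetical):

import torch.nn as nn

class Highway(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.layer = nn.Linear(input_dim, 2 * input_dim)  # produces [non_lin | gate]
        self.activation = nn.ReLU()
        self.gate = nn.Sigmoid()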
Example #4
    def forward(self, featuremap, boxes, box_ind):
        """
        RoIAlign based on crop_and_resize.
        See more details on https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301
        :param featuremap: NxCxHxW
        :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization**
        :param box_ind: M
        :return: MxCxoHxoW
        """
        x1, y1, x2, y2 = torch.split(boxes, 1, dim=1)
        image_height, image_width = featuremap.size()[2:4]

        if self.transform_fpcoor:
            spacing_w = (x2 - x1) / float(self.crop_width)
            spacing_h = (y2 - y1) / float(self.crop_height)

            nx0 = (x1 + spacing_w / 2 - 0.5) / float(image_width - 1)
            ny0 = (y1 + spacing_h / 2 - 0.5) / float(image_height - 1)
            nw = spacing_w * float(self.crop_width - 1) / float(image_width - 1)
            nh = spacing_h * float(self.crop_height - 1) / float(image_height - 1)

            boxes = torch.cat((ny0, nx0, ny0 + nh, nx0 + nw), 1)
        else:
            x1 = x1 / float(image_width - 1)
            x2 = x2 / float(image_width - 1)
            y1 = y1 / float(image_height - 1)
            y2 = y2 / float(image_height - 1)
            boxes = torch.cat((y1, x1, y2, x2), 1)

        boxes = boxes.detach().contiguous()
        box_ind = box_ind.detach()
        return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind)
Example #5
 def calc_pdparam(self, x, evaluate=True):
     '''
     Calculate pdparams for multi-action by chunking the network logits output
     '''
     x = torch.cat(torch.split(x, self.state_dims, dim=1)).unsqueeze_(dim=1)
     pdparam = SARSA.calc_pdparam(self, x, evaluate=evaluate)
     return pdparam
Example #6
 def calc_pdparam(self, x, evaluate=True):
     '''
     Calculate pdparams for multi-action by chunking the network logits output
     '''
     pdparam = super(MultitaskDQN, self).calc_pdparam(x, evaluate=evaluate)
     pdparam = torch.cat(torch.split(pdparam, self.action_dims, dim=1))
     return pdparam
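The cat(split(...)) idiom in Examples #5 and #6 simply moves the per-action chunks between the feature dimension and the batch dimension; a tiny standalone illustration with made-up sizes:

import torch

logits = torch.arange(12.).view(2, 6)      # batch of 2, two 3-dim action heads side by side
action_dims = [3, 3]
stacked = torch.cat(torch.split(logits, action_dims, dim=1))
print(stacked.shape)                       # torch.Size([4, 3]): heads stacked along the batch dim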
Example #7
    def forward(self, q, k, v, attn_mask=None):

        d_k, d_v = self.d_k, self.d_v
        n_head = self.n_head

        residual = q
        #print('q,k,v:',q.size(),k.size(),v.size())
        mb_size, len_q, q_hidden_size = q.size()
        mb_size, len_k, k_hidden_size = k.size()
        mb_size, len_v, v_hidden_size = v.size()

        # treat as a (n_head) size batch
        q_s = q.repeat(n_head, 1, 1).view(n_head, -1, q_hidden_size) # n_head x (mb_size*len_q) x d_model
        k_s = k.repeat(n_head, 1, 1).view(n_head, -1, k_hidden_size) # n_head x (mb_size*len_k) x d_model
        v_s = v.repeat(n_head, 1, 1).view(n_head, -1, v_hidden_size) # n_head x (mb_size*len_v) x d_model
        #print('q_s,k_s,v_s:',q_s.size(),k_s.size(),v_s.size())
        #print('w_qs',self.w_qs.size())
        # treat the result as a (n_head * mb_size) size batch
        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)   # (n_head*mb_size) x len_q x d_k
        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)   # (n_head*mb_size) x len_k x d_k
        v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)   # (n_head*mb_size) x len_v x d_v

        # perform attention, result size = (n_head * mb_size) x len_q x d_v
        #print('attn_mask:',attn_mask.size())
        #print(attn_mask)
        outputs, attns = self.attention.forward(q_s, k_s, v_s, attn_mask=attn_mask.repeat(n_head,1,1))

        # back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
        outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1) 

        # project back to residual size
        outputs = self.proj.forward(outputs)
        outputs = self.dropout(outputs)

        return self.layer_norm(outputs + residual), attns
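The final torch.cat(torch.split(...)) above undoes the earlier head stacking: the n_head blocks move from the batch dimension back into the feature dimension. A quick standalone check with made-up sizes:

import torch

n_head, mb_size, len_q, d_v = 2, 3, 5, 4
outputs = torch.randn(n_head * mb_size, len_q, d_v)
merged = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)
print(merged.shape)   # torch.Size([3, 5, 8]) = mb_size x len_q x (n_head * d_v)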
Example #8
    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).

        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """

        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0)
                      .expand(batch_size, *self.bias.size()))
        wh = torch.mm(h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        bn_wh = self.bn_hh(wh)
        bn_wi = self.bn_ih(wi)
        f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
                                 split_size_or_sections=self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1))
        return h_1, c_1
Example #9
def occlusion_sensitivity(
    model, images, ids, mean=None, patch=35, stride=1, n_batches=128
):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure A5 on page 17
    
    Originally proposed in:
    "Visualizing and Understanding Convolutional Networks"
    https://arxiv.org/abs/1311.2901
    """

    torch.set_grad_enabled(False)
    model.eval()
    mean = mean if mean else 0
    patch_H, patch_W = patch if isinstance(patch, Sequence) else (patch, patch)
    pad_H, pad_W = patch_H // 2, patch_W // 2

    # Padded image
    images = F.pad(images, (pad_W, pad_W, pad_H, pad_H), value=mean)
    B, _, H, W = images.shape
    new_H = (H - patch_H) // stride + 1
    new_W = (W - patch_W) // stride + 1

    # Prepare sampling grids
    anchors = []
    grid_h = 0
    while grid_h <= H - patch_H:
        grid_w = 0
        while grid_w <= W - patch_W:
            anchors.append((grid_h, grid_w))
            grid_w += stride
        grid_h += stride

    # Baseline score without occlusion
    baseline = model(images).detach().gather(1, ids)

    # Compute per-pixel logits
    scoremaps = []
    for i in tqdm(range(0, len(anchors), n_batches), leave=False):
        batch_images = []
        batch_ids = []
        for grid_h, grid_w in anchors[i : i + n_batches]:
            images_ = images.clone()
            images_[..., grid_h : grid_h + patch_H, grid_w : grid_w + patch_W] = mean
            batch_images.append(images_)
            batch_ids.append(ids)
        batch_images = torch.cat(batch_images, dim=0)
        batch_ids = torch.cat(batch_ids, dim=0)
        scores = model(batch_images).detach().gather(1, batch_ids)
        scoremaps += list(torch.split(scores, B))

    diffmaps = torch.cat(scoremaps, dim=1) - baseline
    diffmaps = diffmaps.view(B, new_H, new_W)

    return diffmaps
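A toy sanity check of the routine above, assuming the imports it relies on (F, tqdm, Sequence) are in scope; the tiny adaptive-pool model and every size below are made up for illustration:

import torch
import torch.nn as nn

model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 10))
images = torch.randn(2, 3, 8, 8)
ids = torch.tensor([[1], [7]])                  # one target class id per image
maps = occlusion_sensitivity(model, images, ids, patch=3, stride=2)
print(maps.shape)                               # (B, new_H, new_W)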
Example #10
    def forward(self, tensors):
        # tensors must all be the same shape, let's say (batch_size, timesteps, dim)
        assert self.num_tensors == len(tensors)

        normed_weights = torch.nn.functional.softmax(torch.cat([p for p in self.scalar_parameters]), dim=0)
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)
        pieces = []
        for weight, tensor in zip(normed_weights, tensors):
            pieces.append(weight * tensor)
        return self.gamma * sum(pieces)
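The forward above relies on one scalar weight per input tensor plus a gamma; a minimal parameter setup consistent with it (attribute names mirror the snippet, the initial values are assumptions):

import torch
import torch.nn as nn

class ScalarMix(nn.Module):
    def __init__(self, num_tensors):
        super().__init__()
        self.num_tensors = num_tensors
        self.scalar_parameters = nn.ParameterList(
            [nn.Parameter(torch.zeros(1)) for _ in range(num_tensors)])
        self.gamma = nn.Parameter(torch.ones(1))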
Example #11
def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :return: The matrix (size [N1, N2]) of the intersection area.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    max_ymin = torch.max(ymin1.repeat(1, ymin2.size(0)), torch.transpose(ymin2, 0, 1).repeat(ymin1.size(0), 1)) # PyTorch's bug
    min_ymax = torch.min(ymax1.repeat(1, ymax2.size(0)), torch.transpose(ymax2, 0, 1).repeat(ymax1.size(0), 1)) # PyTorch's bug
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1.repeat(1, xmin2.size(0)), torch.transpose(xmin2, 0, 1).repeat(xmin1.size(0), 1)) # PyTorch's bug
    min_xmax = torch.min(xmax1.repeat(1, xmax2.size(0)), torch.transpose(xmax2, 0, 1).repeat(xmax1.size(0), 1)) # PyTorch's bug
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width
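A quick shape check of the function above with random, purely illustrative boxes:

import torch

yx_min1, yx_max1 = torch.rand(3, 2), torch.rand(3, 2) + 1.0   # 3 boxes
yx_min2, yx_max2 = torch.rand(2, 2), torch.rand(2, 2) + 1.0   # 2 boxes
inter = intersection_area(yx_min1, yx_max1, yx_min2, yx_max2)
print(inter.shape)   # torch.Size([3, 2]): one intersection area per box pair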
Example #12
def filter_shard_state(state, shard_size=None):
    for k, v in state.items():
        if shard_size is None:
            yield k, v
            continue

        if v is not None:
            v_split = []
            if isinstance(v, torch.Tensor):
                for v_chunk in torch.split(v, shard_size):
                    v_chunk = v_chunk.data.clone()
                    v_chunk.requires_grad = v.requires_grad
                    v_split.append(v_chunk)
            yield k, (v, v_split)
Example #13
 def dis_update(self, images_a, images_b, hyperparameters):
   self.dis.zero_grad()
   x_aa, x_ba, x_ab, x_bb, shared = self.gen(images_a, images_b)
   data_a = torch.cat((images_a, x_ba), 0)
   data_b = torch.cat((images_b, x_ab), 0)
   res_a, res_b = self.dis(data_a,data_b)
   # res_true_a, res_true_b = self.dis(images_a,images_b)
   # res_fake_a, res_fake_b = self.dis(x_ba, x_ab)
   for it, (this_a, this_b) in enumerate(zip(res_a, res_b)):
     out_a = torch.sigmoid(this_a)
     out_b = torch.sigmoid(this_b)
     out_true_a, out_fake_a = torch.split(out_a, out_a.size(0) // 2, 0)
     out_true_b, out_fake_b = torch.split(out_b, out_b.size(0) // 2, 0)
     out_true_n = out_true_a.size(0)
     out_fake_n = out_fake_a.size(0)
     all1 = Variable(torch.ones((out_true_n)).cuda(self.gpu))
     all0 = Variable(torch.zeros((out_fake_n)).cuda(self.gpu))
     ad_true_loss_a = nn.functional.binary_cross_entropy(out_true_a, all1)
     ad_true_loss_b = nn.functional.binary_cross_entropy(out_true_b, all1)
     ad_fake_loss_a = nn.functional.binary_cross_entropy(out_fake_a, all0)
     ad_fake_loss_b = nn.functional.binary_cross_entropy(out_fake_b, all0)
     if it==0:
       ad_loss_a = ad_true_loss_a + ad_fake_loss_a
       ad_loss_b = ad_true_loss_b + ad_fake_loss_b
     else:
       ad_loss_a += ad_true_loss_a + ad_fake_loss_a
       ad_loss_b += ad_true_loss_b + ad_fake_loss_b
     true_a_acc = _compute_true_acc(out_true_a)
     true_b_acc = _compute_true_acc(out_true_b)
     fake_a_acc = _compute_fake_acc(out_fake_a)
     fake_b_acc = _compute_fake_acc(out_fake_b)
     exec( 'self.dis_true_acc_%d = 0.5 * (true_a_acc + true_b_acc)' %it)
     exec( 'self.dis_fake_acc_%d = 0.5 * (fake_a_acc + fake_b_acc)' %it)
   loss = hyperparameters['gan_w'] * ( ad_loss_a + ad_loss_b )
   loss.backward()
   self.dis_opt.step()
   self.dis_loss = loss.item()
   return
Example #14
def shards(state, shard_size, eval_only=False):
    """
    Args:
        state: A dictionary which corresponds to the output of
               *LossCompute._make_shard_state(). The values for
               those keys are Tensor-like or None.
        shard_size: The maximum size of the shards yielded by the model.
        eval_only: If True, only yield the state, nothing else.
              Otherwise, yield shards.

    Yields:
        Each yielded shard is a dict.

    Side effect:
        After the last shard, this function does back-propagation.
    """
    if eval_only:
        yield filter_shard_state(state)
    else:
        # non_none: the subdict of the state dictionary where the values
        # are not None.
        non_none = dict(filter_shard_state(state, shard_size))

        # Now, the iteration:
        # state is a dictionary of sequences of tensor-like but we
        # want a sequence of dictionaries of tensors.
        # First, unzip the dictionary into a sequence of keys and a
        # sequence of tensor-like sequences.
        keys, values = zip(*((k, [v_chunk for v_chunk in v_split])
                             for k, (_, v_split) in non_none.items()))

        # Now, yield a dictionary for each shard. The keys are always
        # the same. values is a sequence of length #keys where each
        # element is a sequence of length #shards. We want to iterate
        # over the shards, not over the keys: therefore, the values need
        # to be re-zipped by shard and then each shard can be paired
        # with the keys.
        for shard_tensors in zip(*values):
            yield dict(zip(keys, shard_tensors))

        # Assumed backprop'd
        variables = []
        for k, (v, v_split) in non_none.items():
            if isinstance(v, torch.Tensor) and state[k].requires_grad:
                variables.extend(zip(torch.split(state[k], shard_size),
                                     [v_chunk.grad for v_chunk in v_split]))
        inputs, grads = zip(*variables)
        torch.autograd.backward(inputs, grads)
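A rough sketch of how these shards are consumed, using shards() together with the filter_shard_state from Example #12; the toy state, the stand-in loss, and the sizes are all made up:

import torch

state = {"output": torch.randn(6, 4, requires_grad=True),
         "target": torch.randint(0, 4, (6,))}
for shard in shards(state, shard_size=2):
    loss = shard["output"].sum()   # stand-in for the real per-shard loss
    loss.backward()                # fills the grads of the detached chunks
# after the loop, shards() back-propagates those chunk grads into state["output"]
print(state["output"].grad.shape)  # torch.Size([6, 4])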
Example #15
 def forward(self, input, hidden_state):
     hidden, c = hidden_state  # hidden and c are images with several channels
     #print 'hidden ',hidden.size()
     #print 'input ',input.size()
     combined = torch.cat((input, hidden), 1)  # concatenate along the channel dimension
     #print 'combined',combined.size()
     A=self.conv(combined)
     (ai,af,ao,ag)=torch.split(A,self.num_features,dim=1)#it should return 4 tensors
     i=torch.sigmoid(ai)
     f=torch.sigmoid(af)
     o=torch.sigmoid(ao)
     g=torch.tanh(ag)
     
     next_c=f*c+i*g
     next_h=o*torch.tanh(next_c)
     return next_h, next_c
Example #16
    def forward(self, input_, c_input, hx):
        """
        Args:
            batch = 1
            input_: A (batch, input_size) tensor containing input
                features.
            c_input: A list of length c_num; each element is the input c_t from a skip word, of size (batch, hidden_size).
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """

        h_0, c_0 = hx
        batch_size = h_0.size(0)
        #assert(batch_size == 1)
        bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        i, o, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
        i = torch.sigmoid(i)
        g = torch.tanh(g)
        o = torch.sigmoid(o)
        c_num = len(c_input)
        if c_num == 0:
            f = 1 - i
            c_1 = f*c_0 + i*g
            h_1 = o * torch.tanh(c_1)
        else:
            c_input_var = torch.cat(c_input, 0)
            alpha_bias_batch = (self.alpha_bias.unsqueeze(0).expand(batch_size, *self.alpha_bias.size()))
            c_input_var = c_input_var.squeeze(1) ## (c_num, hidden_dim)
            alpha_wi = torch.addmm(self.alpha_bias, input_, self.alpha_weight_ih).expand(c_num, self.hidden_size)
            alpha_wh = torch.mm(c_input_var, self.alpha_weight_hh)
            alpha = torch.sigmoid(alpha_wi + alpha_wh)
            ## alpha  = i concat alpha
            alpha = torch.exp(torch.cat([i, alpha],0))
            alpha_sum = alpha.sum(0)
            ## alpha = softmax for each hidden element
            alpha = torch.div(alpha, alpha_sum)
            merge_i_c = torch.cat([g, c_input_var],0)
            c_1 = merge_i_c * alpha
            c_1 = c_1.sum(0).unsqueeze(0)
            h_1 = o * torch.tanh(c_1)
        return h_1, c_1
Example #17
    def forward(self, input_tensor, cur_state):

        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next
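The cell above expects self.conv to take the concatenated input and hidden channels and emit four gate maps of hidden_dim channels each; a minimal constructor consistent with that (the kernel size and padding are assumptions):

import torch.nn as nn

class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim,
                              kernel_size, padding=kernel_size // 2)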
Example #18
    def node_forward(self, inputs, child_c, child_h):
      
        child_h_sum = torch.mean(child_h, dim=0, keepdim=True)
      
        
        iou = self.ioux(inputs) + self.iouh(child_h_sum)
        i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
        i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u)

        f = F.sigmoid(
                self.fh(child_h) +
                self.fx(inputs).repeat(len(child_h), 1)
            )
        fc = torch.mul(f, child_c)

        c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
        h = torch.mul(o, F.tanh(c))
        return c, h
Example #19
    def forward(self, tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
                mask: torch.Tensor = None) -> torch.Tensor:
        """
        Compute a weighted average of the ``tensors``.  The input tensors can be any shape
        with at least two dimensions, but must all be the same shape.

        When ``do_layer_norm=True``, the ``mask`` is a required input.  If the ``tensors`` are
        dimensioned  ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
        ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
        ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.

        When ``do_layer_norm=False`` the ``mask`` is ignored.
        """
        if len(tensors) != self.mixture_size:
            raise ConfigurationError("{} tensors were passed, but the module was initialized to "
                                     "mix {} tensors.".format(len(tensors), self.mixture_size))

        def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
            tensor_masked = tensor * broadcast_mask
            mean = torch.sum(tensor_masked) / num_elements_not_masked
            variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked
            return (tensor - mean) / torch.sqrt(variance + 1E-12)

        normed_weights = torch.nn.functional.softmax(torch.cat([parameter for parameter
                                                                in self.scalar_parameters]), dim=0)
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)

        if not self.do_layer_norm:
            pieces = []
            for weight, tensor in zip(normed_weights, tensors):
                pieces.append(weight * tensor)
            return self.gamma * sum(pieces)

        else:
            mask_float = mask.float()
            broadcast_mask = mask_float.unsqueeze(-1)
            input_dim = tensors[0].size(-1)
            num_elements_not_masked = torch.sum(mask_float) * input_dim

            pieces = []
            for weight, tensor in zip(normed_weights, tensors):
                pieces.append(weight * _do_layer_norm(tensor,
                                                      broadcast_mask, num_elements_not_masked))
            return self.gamma * sum(pieces)
Example #20
    def forward(self, x, y, y_mask):
        """Input shapes:
            x = batch * len1 * h
            y = batch * len2 * h
            y_mask = batch * len2
        Output shapes:
            matched_seq = batch * len1 * h
        """
        # Project vectors
        x_s = x.repeat(self.n_head,1,1).view(self.n_head,-1,x.size(2)) # n_head * (batch x len1) * input_size
        #print('y', y.size())
        y_s = y.repeat(self.n_head, 1, 1).view(self.n_head, -1, y.size(2)) # n_head * (batch x len2) * input_size
        #print('y_s', y_s.size())
        x_s = torch.bmm(x_s,self.w) # n_head * (batch x len1) * hidden_size
        y_s = torch.bmm(y_s, self.w)  # n_head * (batch x len2) * hidden_size
        #print('y_s',y_s.size())
        x_s = x_s.view(-1,x.size(1),self.hidden_size) # (n_head x batch) * len1 * hidden_size
        y_s = y_s.view(-1, y.size(1), self.hidden_size) # (n_head x batch) * len2 * hidden_size
        #print('y_s', y_s.size())
        #x_proj = self.linear(x.view(-1, x.size(2))).view(x.size())
        x_s_proj = F.relu(x_s)
        #y_proj = self.linear(y.view(-1, y.size(2))).view(y.size())
        y_s_proj = F.relu(y_s)

        # Compute scores
        scores = x_s_proj.bmm(y_s_proj.transpose(2, 1)) # (n_head x batch) * len1 * len2

        # Mask padding
        y_mask = y_mask.unsqueeze(1).repeat(self.n_head,x.size(1),1)
        #print('y_mask:',y_mask.size())
        #print('scores:', scores.size())

        scores.data.masked_fill_(y_mask.data, -float('inf'))

        # Normalize with softmax
        alpha_flat = F.softmax(scores.view(-1, y.size(1)),dim=1)
        alpha = alpha_flat.view(-1, x.size(1), y.size(1))

        # Take weighted average
        matched_seq = alpha.bmm(y_s) # (n_head x batch) * len1 * hidden_size
        matched_seq = torch.cat(torch.split(matched_seq,x.size(0),dim=0),dim=-1) # batch x len1 x (n_head * hidden_size)

        return matched_seq
Example #21
    def test_elmo_bilm_can_handle_higher_dimensional_input_with_cache(self):
        sentences = [["This", "is", "a", "sentence"],
                     ["Here", "'s", "one"],
                     ["Another", "one"]]
        vocab, tensor = self.get_vocab_and_both_elmo_indexed_ids(sentences)
        words_to_cache = list(vocab.get_token_to_index_vocabulary("tokens").keys())
        elmo_bilm = Elmo(self.options_file, self.weight_file, 1, vocab_to_cache=words_to_cache)
        elmo_bilm.eval()

        individual_dim = elmo_bilm(tensor["character_ids"], tensor["tokens"])
        elmo_bilm = Elmo(self.options_file, self.weight_file, 1, vocab_to_cache=words_to_cache)
        elmo_bilm.eval()

        expanded_word_ids = torch.stack([tensor["tokens"] for _ in range(4)], dim=1)
        expanded_char_ids = torch.stack([tensor["character_ids"] for _ in range(4)], dim=1)
        expanded_result = elmo_bilm(expanded_char_ids, expanded_word_ids)
        split_result = [x.squeeze(1) for x in torch.split(expanded_result["elmo_representations"][0], 1, dim=1)]
        for expanded in split_result:
            numpy.testing.assert_array_almost_equal(expanded.data.cpu().numpy(),
                                                    individual_dim["elmo_representations"][0].data.cpu().numpy())
Example #22
    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
        Returns:
            c_1: A tensor containing the next cell state.
        """

        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
        return c_1
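A minimal parameter setup consistent with the three-way gate split above (the class name and the random initialization are assumptions):

import torch
import torch.nn as nn

class WordCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.weight_ih = nn.Parameter(0.1 * torch.randn(input_size, 3 * hidden_size))
        self.weight_hh = nn.Parameter(0.1 * torch.randn(hidden_size, 3 * hidden_size))
        self.bias = nn.Parameter(torch.zeros(3 * hidden_size))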
Example #23
    def forward(self, x, char_x, lens):
        B, T = x.shape
        # get the mask (non-padding positions)
        mask = x.gt(0)
        # get the word embedding vectors
        x = self.embed(x)

        # get the character embedding vectors
        char_x = self.char_lstm(char_x[mask])
        char_x = pad_sequence(torch.split(char_x, lens.tolist()), True)

        # concatenate the word and character representations
        x = torch.cat((x, char_x), dim=-1)
        x = self.drop(x)

        x = pack_padded_sequence(x, lens, True)
        x, _ = self.word_lstm(x)
        x, _ = pad_packed_sequence(x, True)
        x = self.drop(x)

        return self.out(x)
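The torch.split + pad_sequence pair above regroups a flat (total_tokens, dim) tensor back into a padded (batch, max_len, dim) batch; a tiny standalone illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

flat = torch.randn(7, 5)            # 7 tokens in total, 5-dim features each
lens = torch.tensor([3, 4])         # per-sentence token counts summing to 7
padded = pad_sequence(torch.split(flat, lens.tolist()), batch_first=True)
print(padded.shape)                 # torch.Size([2, 4, 5])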
Example #24
 def forward(self, k, q):
     if len(q.shape) == 2:  # q_len missing
         q = torch.unsqueeze(q, dim=1)
     if len(k.shape) == 2:  # k_len missing
         k = torch.unsqueeze(k, dim=1)
     mb_size = k.shape[0]  # ?
     k_len = k.shape[1]
     q_len = q.shape[1]
     # k: (?, k_len, embed_dim,)
     # q: (?, q_len, embed_dim,)
     # kx: (n_head, ?*k_len, embed_dim) -> (n_head*?, k_len, hidden_dim)
     # qx: (n_head, ?*q_len, embed_dim) -> (n_head*?, q_len, hidden_dim)
     # score: (n_head*?, q_len, k_len,)
     # output: (?, q_len, out_dim,)
     kx = k.repeat(self.n_head, 1, 1).view(self.n_head, -1, self.embed_dim)  # (n_head, ?*k_len, embed_dim)
     qx = q.repeat(self.n_head, 1, 1).view(self.n_head, -1, self.embed_dim)  # (n_head, ?*q_len, embed_dim)
     kx = torch.bmm(kx, self.w_kx).view(-1, k_len, self.hidden_dim)  # (n_head*?, k_len, hidden_dim)
     qx = torch.bmm(qx, self.w_qx).view(-1, q_len, self.hidden_dim)  # (n_head*?, q_len, hidden_dim)
     if self.score_function == 'scaled_dot_product':
         kt = kx.permute(0, 2, 1)
         qkt = torch.bmm(qx, kt)
         score = torch.div(qkt, math.sqrt(self.hidden_dim))
     elif self.score_function == 'mlp':
         kxx = torch.unsqueeze(kx, dim=1).expand(-1, q_len, -1, -1)
         qxx = torch.unsqueeze(qx, dim=2).expand(-1, -1, k_len, -1)
         kq = torch.cat((kxx, qxx), dim=-1)  # (n_head*?, q_len, k_len, hidden_dim*2)
         score = F.tanh(torch.matmul(kq, self.weight))
     elif self.score_function == 'bi_linear':
         qw = torch.matmul(qx, self.weight)
         kt = kx.permute(0, 2, 1)
         score = torch.bmm(qw, kt)
     else:
         raise RuntimeError('invalid score_function')
     score = F.softmax(score, dim=-1)
     output = torch.bmm(score, kx)  # (n_head*?, q_len, hidden_dim)
     output = torch.cat(torch.split(output, mb_size, dim=0), dim=-1)  # (?, q_len, n_head*hidden_dim)
     output = self.proj(output)  # (?, q_len, out_dim)
     output = self.dropout(output)
     return output
Example #25
 def forward(self, x):
     x_shape = x.size()  # (b, c, h, w)
     offset = self.offset_filter(x)  # (b, 2*c, h, w)
     offset_w, offset_h = torch.split(offset, self.regular_filter.in_channels, 1)  # (b, c, h, w)
     offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))  # (b*c, h, w)
     offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))  # (b*c, h, w)
     if not self.input_shape or self.input_shape != x_shape:
         self.input_shape = x_shape
         grid_w, grid_h = np.meshgrid(np.linspace(-1, 1, x_shape[3]), np.linspace(-1, 1, x_shape[2]))  # (h, w)
         grid_w = torch.Tensor(grid_w)
         grid_h = torch.Tensor(grid_h)
         if self.cuda:
             grid_w = grid_w.cuda()
             grid_h = grid_h.cuda()
         self.grid_w = nn.Parameter(grid_w)
         self.grid_h = nn.Parameter(grid_h)
     offset_w = offset_w + self.grid_w  # (b*c, h, w)
     offset_h = offset_h + self.grid_h  # (b*c, h, w)
     x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])).unsqueeze(1)  # (b*c, 1, h, w)
     x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3))  # (b*c, h, w)
     x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3]))  # (b, c, h, w)
     x = self.regular_filter(x)
     return x
Example #26
    def forward(
        self,
        queries,
        keys,
        time_mask,
        attn_mask,
        time_matrix_K,
        time_matrix_V,
        abs_pos_K,
        abs_pos_V,
    ):
        """Forward function.

        Args:
            queries ([type]): [description]
            keys ([type]): [description]
            time_mask ([type]): [description]
            attn_mask ([type]): [description]
            time_matrix_K ([type]): [description]
            time_matrix_V ([type]): [description]
            abs_pos_K ([type]): [description]
            abs_pos_V ([type]): [description]

        Returns:
            [type]: [description]
        """
        Q, K, V = self.Q_w(queries), self.K_w(keys), self.V_w(keys)

        # head dim * batch dim for parallelization (h*N, T, C/h)
        Q_ = torch.cat(torch.split(Q, self.head_size, dim=2), dim=0)
        K_ = torch.cat(torch.split(K, self.head_size, dim=2), dim=0)
        V_ = torch.cat(torch.split(V, self.head_size, dim=2), dim=0)

        time_matrix_K_ = torch.cat(
            torch.split(time_matrix_K, self.head_size, dim=3), dim=0
        )
        time_matrix_V_ = torch.cat(
            torch.split(time_matrix_V, self.head_size, dim=3), dim=0
        )
        abs_pos_K_ = torch.cat(torch.split(abs_pos_K, self.head_size, dim=2), dim=0)
        abs_pos_V_ = torch.cat(torch.split(abs_pos_V, self.head_size, dim=2), dim=0)

        # batched channel wise matmul to gen attention weights
        attn_weights = Q_.matmul(torch.transpose(K_, 1, 2))
        attn_weights += Q_.matmul(torch.transpose(abs_pos_K_, 1, 2))
        attn_weights += time_matrix_K_.matmul(Q_.unsqueeze(-1)).squeeze(-1)

        # seq length adaptive scaling
        attn_weights = attn_weights / (K_.shape[-1] ** 0.5)

        # key masking, -2^32 lead to leaking, inf lead to nan
        # 0 * inf = nan, then reduce_sum([nan,...]) = nan

        # time_mask = time_mask.unsqueeze(-1).expand(attn_weights.shape[0], -1, attn_weights.shape[-1])
        time_mask = time_mask.unsqueeze(-1).repeat(self.head_num, 1, 1)
        time_mask = time_mask.expand(-1, -1, attn_weights.shape[-1])
        attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
        paddings = torch.ones(attn_weights.shape) * (
            -(2 ** 32) + 1
        )  # -1e23 # float('-inf')
        paddings = paddings.to("cuda")
        attn_weights = torch.where(
            time_mask, paddings, attn_weights
        )  # True:pick padding
        attn_weights = torch.where(
            attn_mask, paddings, attn_weights
        )  # enforcing causality

        attn_weights = self.softmax(
            attn_weights
        )  # code as below invalids pytorch backward rules
        # attn_weights = torch.where(time_mask, paddings, attn_weights) # weird query mask in tf impl
        # https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/4
        # attn_weights[attn_weights != attn_weights] = 0 # rm nan for -inf into softmax case
        attn_weights = self.dropout(attn_weights)

        outputs = attn_weights.matmul(V_)
        outputs += attn_weights.matmul(abs_pos_V_)
        outputs += (
            attn_weights.unsqueeze(2)
            .matmul(time_matrix_V_)
            .reshape(outputs.shape)
            .squeeze(2)
        )

        # (num_head * N, T, C / num_head) -> (N, T, C)
        outputs = torch.cat(
            torch.split(outputs, Q.shape[0], dim=0), dim=2
        )  # div batch_size

        return outputs
Example #27
 def forward(self, x):
     split = torch.split(x, self.half, 1)
     out1 = self.IN(split[0].contiguous())
     out2 = self.BN(split[1].contiguous())
     out = torch.cat((out1, out2), 1)
     return out
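A minimal constructor consistent with the forward above, following the usual half-IN / half-BN channel split (the affine flag is an assumption):

import torch.nn as nn

class IBN(nn.Module):
    def __init__(self, planes):
        super().__init__()
        self.half = planes // 2
        self.IN = nn.InstanceNorm2d(self.half, affine=True)
        self.BN = nn.BatchNorm2d(planes - self.half)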
Example #28
    def _upward_downward(self, layer, direction, inputs, tree, idx):
        # check to see whether this node has been computed on this
        # layer in this direction, if so short circuit the rest of
        # this function and return that result
        if idx in self.hidden_state[layer][direction]:
            h_t = self.hidden_state[layer][direction][idx]
            c_t = self.cell_state[layer][direction][idx]

            return h_t, c_t

        x_t = self._construct_x_t(layer, inputs, idx, tree)

        oidx, (h_prev, c_prev) = self._construct_previous(layer, direction,
                                                          inputs, tree, idx)

        if self.bias:
            Wih, Whh, bih, bhh = self._get_parameters(layer, direction)

            # print(Wih.size())
            # print(Whh.size())
            # print(bih.size())
            # print(bhh.size())

            # print(x_t.size())
            # print(h_prev.size())

            fcio_t_raw = torch.matmul(Whh, h_prev) +\
                torch.matmul(Wih, x_t[:, None]) +\
                bhh[:, None] + bih[:, None]

        else:
            Wih, Whh = self._get_parameters(layer, direction)

            fcio_t_raw = torch.matmul(Whh, h_prev) +\
                torch.matmul(Wih, x_t[:, None])

        f_t_raw, c_hat_t_raw, i_t_raw, o_t_raw = torch.split(fcio_t_raw,
                                                             self.hidden_size,
                                                             dim=0)

        f_t = F.sigmoid(f_t_raw)

        gated_children = torch.mul(f_t, c_prev)
        gated_children = torch.sum(gated_children, 1, keepdim=False)

        c_hat_t_raw = torch.sum(c_hat_t_raw, 1, keepdim=False)
        i_t_raw = torch.sum(i_t_raw, 1, keepdim=False)
        o_t_raw = torch.sum(o_t_raw, 1, keepdim=False)

        c_hat_t = self.__class__.nonlinearity(c_hat_t_raw)
        i_t = F.sigmoid(i_t_raw)
        o_t = F.sigmoid(o_t_raw)

        c_t = gated_children + torch.mul(i_t, c_hat_t)
        h_t = torch.mul(o_t, self.__class__.nonlinearity(c_t))

        if self.dropout:
            dropout = Dropout(p=self.dropout)
            h_t = dropout(h_t)
            c_t = dropout(c_t)

        self.hidden_state[layer][direction][idx] = h_t
        self.cell_state[layer][direction][idx] = c_t

        if direction == 'up' and self.bidirectional:
            self._upward_downward(layer, 'down', inputs, tree, idx)

        return h_t, c_t
Example #29
  def train(x, y):
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0
    
    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
      utils.toggle_grad(D, True)
      utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
      # If accumulating gradients, loop multiple times before an optimizer step
      D.optim.zero_grad()
      for accumulation_index in range(config['num_D_accumulations']):
        z_.sample_()
        y_.sample_()
        D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']], 
                            x[counter], y[counter], train_G=False, 
                            split_D=config['split_D'])
         
        # Compute components of D's loss, average them, and divide by 
        # the number of gradient accumulations
        D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
        D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
        D_loss.backward()
        counter += 1
        
      # Optionally apply ortho reg in D
      if config['D_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in D.
        print('using modified ortho reg in D')
        utils.ortho(D, config['D_ortho'])
      
      D.optim.step()
    
    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
      utils.toggle_grad(D, False)
      utils.toggle_grad(G, True)
      
    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()
    
    # If accumulating gradients, loop multiple times
    for accumulation_index in range(config['num_G_accumulations']):    
      z_.sample_()
      y_.sample_()
      D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
      G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
      G_loss.backward()
    
    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
      print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G
      # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
      utils.ortho(G, config['G_ortho'],
                  blacklist=[param for param in G.shared.parameters()])
    G.optim.step()
    
    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
      ema.update(state_dict['itr'])
    
    out = {'G_loss': float(G_loss.item()), 
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item())}
    # Return G's loss and the components of D's loss.
    return out
Example #30
    def train_ctr_phase(self):

        self.max_epoch = -1
        self.max_val_score = 0.0
        for epoch in range(self.args.epochs):

            if epoch == 0 or (epoch % self.args.match_interrupt == 0
                              and self.args.match_flag):
                data_match_tensor, label_match_tensor = self.get_match_function(
                    epoch)

            penalty_same_ctr = 0
            penalty_diff_ctr = 0
            penalty_same_hinge = 0
            penalty_diff_hinge = 0
            train_acc = 0.0
            train_size = 0

            perm = torch.randperm(data_match_tensor.size(0))
            data_match_tensor_split = torch.split(data_match_tensor[perm],
                                                  self.args.batch_size,
                                                  dim=0)
            label_match_tensor_split = torch.split(label_match_tensor[perm],
                                                   self.args.batch_size,
                                                   dim=0)
            print('Split Matched Data: ', len(data_match_tensor_split),
                  data_match_tensor_split[0].shape,
                  len(label_match_tensor_split))

            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e, d_e,
                            idx_e) in enumerate(self.train_dataset):
                #         print('Batch Idx: ', batch_idx)

                self.opt.zero_grad()
                loss_e = torch.tensor(0.0).to(self.cuda)

                x_e = x_e.to(self.cuda)
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)
                d_e = torch.argmax(d_e, dim=1).numpy()

                same_ctr_loss = torch.tensor(0.0).to(self.cuda)
                diff_ctr_loss = torch.tensor(0.0).to(self.cuda)
                same_hinge_loss = torch.tensor(0.0).to(self.cuda)
                diff_hinge_loss = torch.tensor(0.0).to(self.cuda)

                if epoch > self.args.penalty_s:
                    # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split
                    total_batch_size = len(data_match_tensor_split)
                    if batch_idx >= total_batch_size:
                        break
                    curr_batch_size = data_match_tensor_split[batch_idx].shape[
                        0]

                    #             data_match= data_match_tensor[idx].to(cuda)
                    data_match = data_match_tensor_split[batch_idx].to(
                        self.cuda)
                    data_match = data_match.view(
                        data_match.shape[0] * data_match.shape[1],
                        data_match.shape[2], data_match.shape[3],
                        data_match.shape[4])
                    feat_match = self.phi(data_match)

                    #             label_match= label_match_tensor[idx].to(self.cuda)
                    label_match = label_match_tensor_split[batch_idx].to(
                        self.cuda)
                    label_match = label_match.view(label_match.shape[0] *
                                                   label_match.shape[1])

                    # Creating tensor of shape ( domain size, total domains, feat size )
                    if len(feat_match.shape) == 4:
                        feat_match = feat_match.view(
                            curr_batch_size, len(self.train_domains),
                            feat_match.shape[1] * feat_match.shape[2] *
                            feat_match.shape[3])
                    else:
                        feat_match = feat_match.view(curr_batch_size,
                                                     len(self.train_domains),
                                                     feat_match.shape[1])

                    label_match = label_match.view(curr_batch_size,
                                                   len(self.train_domains))

                    #             print(feat_match.shape)
                    data_match = data_match.view(curr_batch_size,
                                                 len(self.train_domains),
                                                 data_match.shape[1],
                                                 data_match.shape[2],
                                                 data_match.shape[3])

                    # Contrastive Loss
                    same_neg_counter = 1
                    diff_neg_counter = 1
                    for y_c in range(self.args.out_classes):

                        pos_indices = label_match[:, 0] == y_c
                        neg_indices = label_match[:, 0] != y_c
                        pos_feat_match = feat_match[pos_indices]
                        neg_feat_match = feat_match[neg_indices]

                        #                         if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
                        #                             print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])

                        # If no instances of label y_c in the current batch then continue
                        if pos_feat_match.shape[
                                0] == 0 or neg_feat_match.shape[0] == 0:
                            continue

                        # Iterating over anchors from different domains
                        for d_i in range(pos_feat_match.shape[1]):
                            if torch.sum(torch.isnan(neg_feat_match)):
                                print('Non Reshaped X2 is Nan')
                                sys.exit()

                            diff_neg_feat_match = neg_feat_match.view(
                                neg_feat_match.shape[0] *
                                neg_feat_match.shape[1],
                                neg_feat_match.shape[2])

                            if torch.sum(torch.isnan(diff_neg_feat_match)):
                                print('Reshaped X2 is Nan')
                                sys.exit()

                            neg_dist = embedding_dist(
                                pos_feat_match[:, d_i, :],
                                diff_neg_feat_match[:, :],
                                self.args.pos_metric,
                                self.args.tau,
                                xent=True)
                            if torch.sum(torch.isnan(neg_dist)):
                                print('Neg Dist Nan')
                                sys.exit()

                            # Iterating pos dist for current anchor
                            for d_j in range(pos_feat_match.shape[1]):
                                if d_i != d_j:
                                    pos_dist = 1.0 - embedding_dist(
                                        pos_feat_match[:, d_i, :],
                                        pos_feat_match[:, d_j, :],
                                        self.args.pos_metric)
                                    pos_dist = pos_dist / self.args.tau
                                    if torch.sum(torch.isnan(neg_dist)):
                                        print('Pos Dist Nan')
                                        sys.exit()

                                    if torch.sum(
                                            torch.isnan(
                                                torch.log(
                                                    torch.exp(pos_dist) +
                                                    neg_dist))):
                                        print('Xent Nan')
                                        sys.exit()

    #                                 print( 'Pos Dist', pos_dist )
    #                                 print( 'Log Dist ', torch.log( torch.exp(pos_dist) + neg_dist ))
                                    diff_hinge_loss += -1 * torch.sum(
                                        pos_dist - torch.log(
                                            torch.exp(pos_dist) + neg_dist))
                                    diff_ctr_loss += torch.sum(neg_dist)
                                    diff_neg_counter += pos_dist.shape[0]

                    same_ctr_loss = same_ctr_loss / same_neg_counter
                    diff_ctr_loss = diff_ctr_loss / diff_neg_counter
                    same_hinge_loss = same_hinge_loss / same_neg_counter
                    diff_hinge_loss = diff_hinge_loss / diff_neg_counter

                    penalty_same_ctr += float(same_ctr_loss)
                    penalty_diff_ctr += float(diff_ctr_loss)
                    penalty_same_hinge += float(same_hinge_loss)
                    penalty_diff_hinge += float(diff_hinge_loss)

                    loss_e += ((epoch - self.args.penalty_s) /
                               (self.args.epochs -
                                self.args.penalty_s)) * diff_hinge_loss

                loss_e.backward(retain_graph=False)
                self.opt.step()

                del same_ctr_loss
                del diff_ctr_loss
                del same_hinge_loss
                del diff_hinge_loss
                torch.cuda.empty_cache()

            print('Train Loss Ctr : ', penalty_same_ctr, penalty_diff_ctr,
                  penalty_same_hinge, penalty_diff_hinge)
            print('Done Training for epoch: ', epoch)

            if (epoch + 1) % 10 == 0:

                from evaluation.match_eval import MatchEval
                test_method = MatchEval(self.args, self.train_dataset,
                                        self.val_dataset, self.test_dataset,
                                        self.base_res_dir, self.run, self.cuda)
                #Compute test metrics: Mean Rank
                test_method.phi = self.phi
                test_method.get_metric_eval()

                # Save the model's weights post training
                if test_method.metric_score[
                        'TopK Perfect Match Score'] > self.max_val_score:
                    self.max_val_score = test_method.metric_score[
                        'TopK Perfect Match Score']
                    self.max_epoch = epoch
                    self.save_model_ctr_phase(epoch)

                print('Current Best Epoch: ', self.max_epoch,
                      ' with TopK Overlap: ', self.max_val_score)
Example #31
 def forward(self, x):
     size1 = x.size(1) // 2  # torch.split needs an integer chunk size
     x = torch.split(x, size1, 1)
     x = torch.max(x[0], x[1])
     return x
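This is a max-feature-map style activation: split the channels into two halves and keep the element-wise maximum, halving the channel count. A quick standalone check with made-up sizes:

import torch

x = torch.randn(2, 8, 4, 4)
a, b = torch.split(x, x.size(1) // 2, 1)
print(torch.max(a, b).shape)   # torch.Size([2, 4, 4, 4])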
Example #32
    def expected_calls(self, data, model, mode, callback):
        if not hasattr(model, callback.attr_name):
            return []

        if callback.split_channels == (2, 1) or callback.split_channels == ([
                2, 1
        ], None):
            pytest.skip("incompatible test")

        data1, data2 = data
        data1 = prepare_image(data1)
        data2 = prepare_image(data2)
        B, _, _, _ = data1.shape
        img1 = [data1]
        img2 = [data2]
        name1 = [
            f"{mode}/{callback.name}",
        ]
        name2 = [
            f"{mode}/{callback.name}",
        ]
        img = [img1, img2]
        name = [name1, name2]

        # channel splitting

        for pos in range(2):
            if callback.split_channels[pos]:
                img[pos] = []
                name[pos] = []
                img_new, name_new = [], []
                splits = torch.split(data[pos],
                                     callback.split_channels[pos],
                                     dim=-3)
                for i, s in enumerate(splits):
                    if isinstance(callback.name, str):
                        n = f"{mode}/{callback.name}_{i}"
                    else:
                        n = f"{mode}/{callback.name[i]}"
                    name_new.append(n)
                    img_new.append(s)
                img[pos] = img_new
                name[pos] = name_new

        if len(img[0]) != len(img[1]):
            if len(img[0]) == 1:
                img[0] = img[0] * len(img[1])
            elif len(img[1]) == 1:
                img[1] = img[1] * len(img[0])
            else:
                raise RuntimeError()

        for pos in range(2):
            if callback.max_resolution:
                resize_mode = callback.resize_mode
                target = callback.max_resolution
                H_max, W_max = target
                needs_resize = [
                    i.shape[-2] > H_max or i.shape[-1] > W_max
                    for i in img[pos]
                ]
                img[pos] = [
                    F.interpolate(i, target, mode=resize_mode) if resize else i
                    for i, resize in zip(img[pos], needs_resize)
                ]

        for pos in range(2):
            if (colormap := callback.colormap[pos]):
                if isinstance(colormap, str):
                    colormap = [colormap] * len(img[pos])
                img[pos] = [
                    apply_colormap(i, cmap)[..., :3, :, :]
                    if cmap is not None else i
                    for cmap, i in zip(colormap, img[pos])
                ]
class TestVisualizeCallback:

    callback_cls = VisualizeCallback

    # training, validation, or testing mode
    @pytest.fixture(params=["train", "val", "test"])
    def mode(self, request):
        return request.param

    @pytest.fixture
    def data_shape(self):
        return 2, 3, 32, 32

    @pytest.fixture
    def data(self, data_shape):
        data = create_image(*data_shape)
        return data

    @pytest.fixture(autouse=True)
    def model(self, request, mocker, callback, data, mode, trainer):

        if hasattr(request, "param"):
            step = request.param.pop("step", 10)
            epoch = request.param.pop("epoch", 1)
        else:
            step = 10
            epoch = 1

        model = mocker.MagicMock(name="module")
        model.current_epoch = epoch
        model.global_step = step
        if callback.attr_name is not None:
            setattr(model, callback.attr_name, data)

        if mode == "train":
            attr = "on_train_batch_end"
        elif mode == "val":
            attr = "on_validation_batch_end"
        elif mode == "test":
            attr = "on_test_batch_end"
        else:
            raise ValueError(f"{mode}")
        callback.trigger = lambda: getattr(callback, attr)(trainer, model)

        return model

    @pytest.fixture
    def callback(self, request):
        cls = self.callback_cls
        init_signature = inspect.signature(cls)
        defaults = {
            k: v.default
            for k, v in init_signature.parameters.items()
            if v.default is not inspect.Parameter.empty
        }

        if hasattr(request, "param"):
            name = request.param.get("name", "image")
            defaults.update(request.param)
        else:
            name = "image"
            defaults["name"] = name

        callback = cls(**defaults)
        return callback

    @pytest.fixture
    def logger_func(self, model):
        return model.logger.experiment.add_images

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback):
        if not hasattr(model, callback.attr_name):
            return []
        data = prepare_image(data)

        B, _, H, W = data.shape
        img = [data]
        name = [
            f"{mode}/{callback.name}",
        ]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                if isinstance(callback.name, str):
                    n = f"{mode}/{callback.name}_{i}"
                else:
                    n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            scale_factor = []
            for i in img:
                H, W = i.shape[-2:]
                height_ratio, width_ratio = H / H_max, W / W_max
                s = 1 / max(height_ratio, width_ratio)
                scale_factor.append(s)

            img = [
                F.interpolate(i, scale_factor=s, mode=resize_mode)
                if s < 1 else i for i, s in zip(img, scale_factor)
            ]

        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :]
                if cmap is not None else i for cmap, i in zip(colormap, img)
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        if callback.as_uint8:
            img = [
                to_8bit(i, same_on_batch=not callback.per_img_norm)
                for i in img
            ]

        step = [
            model.current_epoch
            if callback.epoch_counter else model.global_step
        ] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
Example #33
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        # encoder_padding_mask = encoder_out['encoder_padding_mask']
        encoder_out = encoder_out['encoder_out']

        bsz, seqlen, _ = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
        # srclen = encoder_outs.size(0)

        x = prev_output_tokens
        #x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C 10,32,16
        x = x.transpose(0, 1)

        # for_logging = ((x*self.all_stds)+self.all_means).cpu().detach().numpy()
        # #from fairseq import pdb; pdb.set_trace();
        # wandb.log(
        #             {'mean_input_velocities': for_logging[:,:,1::self.num_var_per_segment].mean(),
        #             'mean_input_velocities_1': for_logging[:,:,1].mean(),
        #             'mean_input_velocities_4': for_logging[:,:,-3].mean(),
        #             'mean_input_densities': for_logging[:,:,0::self.num_var_per_segment].mean(),
        #             'mean_input_density_1': for_logging[:,:,0].mean()#,
        #             # 'mean_onramp_flows': for_logging[:,:,::self.num_var_per_segment].mean(),
        #             # 'mean_offramp_flows': for_logging[:,:,::self.num_var_per_segment].mean()
        #             }
        #         )

        if self.encoder_hidden_proj is not None:
            prev_hiddens = self.encoder_hidden_proj(encoder_hiddens[0, :, :])
            prev_cells = self.encoder_cell_proj(encoder_cells[0, :, :])
        else:
            prev_hiddens = encoder_hiddens[0, :, :]
            prev_cells = encoder_cells[0, :, :]

        # input_feed = torch.sigmoid(self.encoder_hidden_to_input_feed_proj(prev_hiddens))
        if self.encoder_hidden_to_input_feed_proj is not None:
            if self.input_feed_activation is not None:
                input_feed = self.input_feed_activation(
                    self.encoder_hidden_to_input_feed_proj(
                        encoder_hiddens[0, :, :]))
            else:
                if self.extra_hidden_layer:
                    extra_hidden = torch.relu(
                        self.encoder_hidden_to_decoder_input_feed_hidden_layer(
                            encoder_hiddens[0, :, :]))
                else:
                    extra_hidden = encoder_hiddens[0, :, :]
                input_feed = self.encoder_hidden_to_input_feed_proj(
                    extra_hidden)
                # input_feed = self.encoder_hidden_to_input_feed_proj(encoder_hiddens[0,:,:])
        else:
            if self.input_feed_activation is not None:
                input_feed = self.input_feed_activation(
                    encoder_hiddens[0, :, :])
            else:
                input_feed = encoder_hiddens[0, :, :]

        self.first_input_feed = input_feed

        outs = []
        common_params_list = []
        segment_params_list = []

        flow_res_list = []

        for j in range(seqlen):

            #input_to_rnn = torch.cat((x[j, :,:], input_feed), dim=1)
            # hidden, cell = self.rnn(input_to_rnn, (prev_hiddens, prev_cells))

            input_x = ((x[j, :, :] * self.input_stds) + self.input_means
                       )  #+ torch.Tensor([0.5]).float()
            # input_x = F.dropout(input_x, p=self.dropout_in, training=self.training)

            # ['Seg00_q', 'Seg00_speed','Seg04_q', 'Seg04_speed','Seg04_r', 'Seg02_s']
            # T x (B x C)
            q0_i, v0_i, q4_i, v4_i, r4_i, s2_i = torch.unbind(input_x, dim=1)

            unnormed_input_feed = (input_feed * self.all_stds) + self.all_means
            rho1, v1, r1, s1, rho2, v2, r2, s2, rho3, v3, r3, s3, rho4, v4, r4, s4 = torch.unbind(
                unnormed_input_feed, dim=1)

            # q4 = rho4*v4*3.0
            # q4 = q4_i if q4_i>0 else q4

            rho4 = (q4_i / (v4_i * 3.0)) * (
                (q4_i / (v4_i * 3.0)) > 0).float() + rho4 * (
                    (q4_i / (v4_i * 3.0)) <=
                    0).float()  #if (q4_i/(v4_i*3.0))>0 else rho4
            v4 = v4_i * ((v4_i > 0).float()) + v4 * (
                (v4_i <= 0).float())  #if v4_i>0 else v4
            r4 = r4_i * ((r4_i > 0).float()) + r4 * (
                (r4_i <= 0).float())  #if r4_i>0 else r4
            s2 = s2_i * ((s2_i > 0).float()) + s2 * (
                (s2_i <= 0).float())  #if s2_i>0 else s2

            real_size_input = torch.stack([
                rho1, v1, r1, s1, rho2, v2, r2, s2, rho3, v3, r3, s3, rho4, v4,
                r4, s4
            ],
                                          dim=1)
            # real_size_input = (blended_input * self.all_stds) + self.all_means

            #input_x = F.dropout(input_x, p=self.dropout_in, training=self.training)
            #input_feed = input_feed #+ torch.Tensor([0.5]).float()
            # input_mask = (input_x*self.max_vals) > 0.0

            # input_mask = ((input_x*self.all_stds)+self.all_means) > 0.0

            # input_mask = input_mask*0.0
            # blended_input = (input_x*input_mask.float()) + ( (1-input_mask.float())*(input_feed))

            # input_feed_mask = ((input_feed*self.all_stds)+self.all_means) > 0.0
            # blended_input = input_feed * input_feed_mask.float()

            # hidden, cell = self.rnn(blended_input, (prev_hiddens, prev_cells))
            hidden, cell = self.rnn(x[j, :, :], (prev_hiddens, prev_cells))

            prev_hiddens = hidden  #for next loop
            prev_cells = cell

            ntf_params = self.ntf_projection(hidden)

            #NTF
            common_params, segment_params = torch.split(
                ntf_params,
                [self.num_common_params, self.total_segment_specific_params],
                dim=1)
            if self.common_param_activation is not None:
                common_params = self.common_param_activation(common_params)
            common_params = (self.common_param_multipliers *
                             common_params) + self.common_param_additions
            v0, q0, rhoNp1, vf, a_var, rhocr = torch.unbind(common_params,
                                                            dim=1)  #, g_var
            g_var = torch.Tensor([[1.0]])

            v0 = v0_i * ((v0_i > 0).float()) + v0 * (
                (v0_i <= 0).float())  #if v0_i>0 else v0
            q0 = q0_i * ((q0_i > 0).float()) + q0 * (
                (q0_i <= 0).float())  # if q0_i>0 else q0

            # vf = vf.detach() #* 0.0 +120.0
            # a_var = a_var.detach() #* 0.0 + 1.4
            # rhocr = rhocr.detach() #* 0.0 + 30.
            # g_var = g_var.detach() #*0.0 + 1.0

            if self.segment_param_activation is not None:
                segment_params = self.segment_param_activation(segment_params)
            segment_params = segment_params.view(
                (-1, self.num_segment_specific_params, self.num_segments))
            segment_params = segment_params * self.segment_param_multipliers + self.segment_param_additions
            future_r, future_s = torch.unbind(segment_params, dim=1)

            # real_size_input = blended_input*self.max_vals
            # real_size_input = real_size_input * input_feed_mask.float()
            model_steps = []

            for _ in range(self.num_ntf_steps):
                one_ntf_output, flow_res = self.ntf_module(
                    x=real_size_input, v0=v0, q0=q0, rhoNp1=rhoNp1, vf=vf, a_var=a_var, rhocr=rhocr,\
                    g_var=g_var, future_r=future_r, future_s=future_s)
                real_size_input = one_ntf_output
                flow_res_list.append(flow_res)

                model_steps.append(one_ntf_output)

            # mean_ntf_output = torch.stack(model_steps, dim=0).mean(dim=0)
            mean_ntf_output = real_size_input

            # scaled_output = mean_ntf_output/(self.max_vals+1e-6)
            normed_output = (mean_ntf_output -
                             self.all_means) / (self.all_stds)

            common_params_list.append(common_params)
            segment_params_list.append(segment_params)
            # outs.append(scaled_output)
            outs.append(normed_output)

            input_feed = normed_output  #- torch.Tensor([0.5]).float()

        # collect outputs across time steps
        # dim=1 to go from T x B x C -> B x T x C
        self.returned_out = torch.stack(outs, dim=1)
        self.all_common_params = torch.stack(common_params_list, dim=1)
        self.all_segment_params = torch.stack(segment_params_list, dim=1)

        v0_a, q0_a, rhoNp1_a, vf_a, a_var_a, rhocr_a = torch.unbind(
            self.all_common_params, dim=2)
        q0_a = (q0_a - 3000.) / 2000.
        v0_a = (v0_a - 90.) / 20.
        rho1_a, v1_a, r1_a, s1_a, rho2_a, v2_a, r2_a, s2_a, rho3_a, v3_a, r3_a, s3_a, rho4_a, v4_a, r4_a, s4_a = torch.unbind(
            self.returned_out, dim=2)

        # q4 = rho4 * v4 * 3.0
        q4_a = ((rho4_a * 15) + 15) * (
            (v4_a * 20) + 90) * 3.0  #3 lanes lambda 4
        q4_a = (q4_a - 3000.) / 2000.
        # v4 = v4
        # r4 =
        # s2 =

        new_out = torch.stack([q0_a, v0_a, q4_a, v4_a, r4_a, s2_a], dim=2)

        self.mean_flow_res = torch.stack(flow_res_list,
                                         dim=2).sum(axis=1).abs().mean(axis=1)

        # return returned_out, self.all_common_params, self.all_segment_params
        return new_out, self.all_common_params, self.all_segment_params
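The decoder above carves a single hidden-state projection into parameter groups with torch.split and a list of section sizes. A minimal sketch of that pattern under assumed, illustrative sizes (num_common and num_segment_specific are placeholders, not the module's real attributes):

import torch

hidden = torch.randn(4, 32)                     # (batch, hidden_dim)
num_common, num_segment_specific = 6, 8         # must sum to the projection width
projection = torch.nn.Linear(32, num_common + num_segment_specific)

params = projection(hidden)                     # (batch, 14)
common, segment = torch.split(params, [num_common, num_segment_specific], dim=1)
print(common.shape, segment.shape)              # torch.Size([4, 6]) torch.Size([4, 8])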
Example #35
0
 def g(self, t, y):  # Diagonal diffusion.
     y = torch.split(y, split_size_or_sections=1, dim=1)
     out = [g_net_i(y_i) for (g_net_i, y_i) in zip(self.g_nets, y)]
     return torch.cat(out, dim=1)
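Example #35 splits the state into per-dimension slices and runs each through its own network. A self-contained sketch of that diagonal-diffusion pattern, assuming g_nets is a ModuleList of 1-in/1-out networks (the class name and sizes here are illustrative):

import torch
import torch.nn as nn

class DiagonalDiffusion(nn.Module):
    def __init__(self, state_dim: int):
        super().__init__()
        self.g_nets = nn.ModuleList(nn.Linear(1, 1) for _ in range(state_dim))

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # Split (batch, state_dim) into state_dim slices of width 1,
        # run each through its own network, then re-assemble.
        ys = torch.split(y, split_size_or_sections=1, dim=1)
        out = [g_i(y_i) for g_i, y_i in zip(self.g_nets, ys)]
        return torch.cat(out, dim=1)

print(DiagonalDiffusion(3)(torch.randn(5, 3)).shape)  # torch.Size([5, 3])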
 def forward(self, input):
     embedding = self.main(input).view(-1, 64 * 8 * 2 * 4)
     embedding = self.flatten(embedding)
     ha, hg = torch.split(embedding, [opt.ha_dim, opt.hg_dim], dim=1)
     return ha, hg, embedding
    def forward(self, x):
        out = self.conv1(x)
        outs = torch.split(out, self.core_list, dim=1)
        #print(outs[0].size(),outs[1].size(),outs[2].size())
        #split the input features
        #print(outs[0].size(),outs[1].size())
        #outs = [self.module0_1(outs[0]),self.module0_2(outs[1])]
        outs = [
            self.trans0_1(self.module0_1(outs[0].contiguous())),
            self.trans0_2(self.module0_2(outs[1].contiguous())),
            self.trans0_3(self.module0_3(outs[2].contiguous()))
        ]
        #print(outs[0].size(),outs[1].size(),outs[2].size())
        out_temp = []

        for ind, i in enumerate(self.graph[0]):
            for index, k in enumerate(i):
                if index == 0:
                    out_temp.append(outs[k])
                    #print(out_temp[ind].size())
                    #print(k)
                else:
                    out_temp[ind] = torch.cat((out_temp[ind], outs[k]), dim=1)

        outs = out_temp

        #outs = [self.module1_1(outs[0]),self.module1_2(outs[1]),self.module1_3(outs[2])]

        outs = [
            self.trans1_1(self.module1_1(outs[0].contiguous())),
            self.trans1_2(self.module1_2(outs[1].contiguous())),
            self.trans1_3(self.module1_3(outs[2].contiguous()))
        ]
        #print(outs[0].size(),outs[1].size(),outs[2].size())
        out_temp = []

        for ind, i in enumerate(self.graph[1]):
            for index, k in enumerate(i):
                if index == 0:
                    out_temp.append(outs[k])
                    #print(out_temp[ind].size())
                    #print(k)
                else:
                    out_temp[ind] = torch.cat((out_temp[ind], outs[k]), dim=1)
        outs = out_temp

        outs = [
            self.trans2_1(self.module2_1(outs[0].contiguous())),
            self.trans2_2(self.module2_2(outs[1].contiguous())),
            self.trans2_3(self.module2_3(outs[2].contiguous()))
        ]
        #print(outs[0].size(),outs[1].size(),outs[2].size())
        out_temp = []

        for ind, i in enumerate(self.graph[2]):
            for index, k in enumerate(i):
                if index == 0:
                    out_temp.append(outs[k])
                    #print(out_temp[ind].size())
                    #print(k)
                else:
                    out_temp[ind] = torch.cat((out_temp[ind], outs[k]), dim=1)
        outs = out_temp

        outs = [
            self.trans3_1(self.module3_1(outs[0].contiguous())),
            self.trans3_2(self.module3_2(outs[1].contiguous())),
            self.trans3_3(self.module3_3(outs[2].contiguous()))
        ]
        #print(outs[0].size(),outs[1].size(),outs[2].size())

        out = torch.cat((outs[0], outs[1], outs[2]), dim=1)
        #print(out.size())
        out = F.avg_pool2d(F.relu(self.bn(out.contiguous())), 2)
        #print(out.size())
        out = out.view(out.size(0), -1)
        #print(out.size())
        #out = self.linear(out.contiguous())
        out = self.Linear_01(out.contiguous())

        return out
Example #38
0
 def mean_axis(self, xs, axis):
     y = list(map(lambda x: torch.mean(x, 0), torch.split(xs, axis)))
     return torch.stack(y)
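Despite the parameter name, `axis` in the snippet above is the chunk size handed to torch.split along dim 0, so the function averages every consecutive block of `axis` rows. A small usage sketch under that reading:

import torch

xs = torch.arange(12.0).view(6, 2)        # 6 rows, 2 columns
chunks = torch.split(xs, 3)               # two chunks of 3 rows each
means = torch.stack([c.mean(0) for c in chunks])
print(means.shape)                        # torch.Size([2, 2])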
Example #39
0
    def __call__(
        self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Inference

        Args:
            batch: Input speech data and corresponding lengths
        Returns:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        if isinstance(batch["speech"], np.ndarray):
            batch["speech"] = torch.tensor(batch["speech"])
        if isinstance(batch["speech_lengths"], np.ndarray):
            batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, encoder_out_lens = self.asr_model.encode(**batch)

        # logp_encoder_output: [N, T, C]
        logp_encoder_output = torch.nn.functional.log_softmax(
            self.asr_model.ctc.ctc_lo(enc), dim=2
        )

        # It may be useful to tune blank_bias.
        # The valid range of blank_bias is [-inf, 0]
        logp_encoder_output[:, :, 0] += self.blank_bias

        batch_size = encoder_out_lens.size(0)
        sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(torch.int32)
        start_frame = torch.zeros([batch_size], dtype=torch.int32).unsqueeze(0).t()
        num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
        supervision_segments = torch.cat([sequence_idx, start_frame, num_frames], dim=1)

        supervision_segments = supervision_segments.to(torch.int32)

        # An introduction to DenseFsaVec:
        # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
        # It can be viewed as an fsa-type logp_encoder_output,
        # whose arc weights are initialized with logp_encoder_output.
        # The point of converting the tensor type to an fsa type is to use the
        # fsa-related functions in k2, e.g. k2.intersect_dense_pruned below.
        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

        # The term "intersect" is similar to "compose" in k2.
        # The differences is are:
        # for "compose" functions, the composition involves
        # mathcing output label of a.fsa and input label of b.fsa
        # while for "intersect" functions, the composition involves
        # matching input label of a.fsa and input label of b.fsa
        # Actually, in compose functions, b.fsa is inverted and then
        # a.fsa and inv_b.fsa are intersected together.
        # For difference between compose and interset:
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/fsa_algo.py#L308
        # For definition of k2.intersect_dense_pruned:
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/autograd.py#L648
        lattices = k2.intersect_dense_pruned(
            self.decode_graph,
            dense_fsa_vec,
            self.search_beam_size,
            self.output_beam_size,
            self.min_active_states,
            self.max_active_states,
        )

        # lattices.scores is the sum of decode_graph.scores (a.k.a. lm weight) and
        # dense_fsa_vec.scores (a.k.a. am weight) on the related arcs.
        # For a CTC decoding graph, lattices.scores only stores the am weight,
        # since the decode_graph only defines the CTC topology and
        # has no lm weight on its arcs.
        # For 3-gram decoding, whose graph is converted from a language model,
        # lattices.scores contains both am weights and lm weights.
        #
        # It may be useful to tune lattices.scores.
        # The valid range of lattice_weight is [0, inf);
        # lattice_weight affects the search in k2.random_paths.
        lattices.scores *= self.lattice_weight

        results = []
        if self.use_nbest_rescoring:
            (
                am_scores,
                lm_scores,
                token_ids,
                new2old,
                path_to_seq_map,
                seq_to_path_splits,
            ) = nbest_am_lm_scores(
                lattices, self.num_paths, self.device, self.nbest_batch_size
            )

            ys_pad_lens = torch.tensor([len(hyp) for hyp in token_ids]).to(self.device)
            max_token_length = max(ys_pad_lens)
            ys_pad_list = []
            for hyp in token_ids:
                ys_pad_list.append(
                    torch.cat(
                        [
                            torch.tensor(hyp, dtype=torch.long),
                            torch.tensor(
                                [self.asr_model.ignore_id]
                                * (max_token_length.item() - len(hyp)),
                                dtype=torch.long,
                            ),
                        ]
                    )
                )

            ys_pad = (
                torch.stack(ys_pad_list).to(torch.long).to(self.device)
            )  # [batch, max_token_length]

            encoder_out = enc.index_select(0, path_to_seq_map.to(torch.long)).to(
                self.device
            )  # [batch, T, dim]
            encoder_out_lens = encoder_out_lens.index_select(
                0, path_to_seq_map.to(torch.long)
            ).to(
                self.device
            )  # [batch]

            decoder_scores = -self.asr_model.batchify_nll(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, self.nll_batch_size
            )

            # padded_value for nnlm is 0
            ys_pad[ys_pad == self.asr_model.ignore_id] = 0
            nnlm_nll, x_lengths = self.lm.batchify_nll(
                ys_pad, ys_pad_lens, self.nll_batch_size
            )
            nnlm_scores = -nnlm_nll.sum(dim=1)

            batch_tot_scores = (
                self.am_weight * am_scores
                + self.decoder_weight * decoder_scores
                + self.nnlm_weight * nnlm_scores
            )
            split_size = indices_to_split_size(
                seq_to_path_splits.tolist(), total_elements=batch_tot_scores.size(0)
            )
            batch_tot_scores = torch.split(
                batch_tot_scores,
                split_size,
            )

            hyps = []
            scores = []
            processed_seqs = 0
            for tot_scores in batch_tot_scores:
                if tot_scores.nelement() == 0:
                    # the last element by torch.tensor_split may be empty
                    # e.g.
                    # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4]))
                    # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64))
                    break
                best_seq_idx = processed_seqs + torch.argmax(tot_scores)

                assert best_seq_idx < len(token_ids)
                best_token_seqs = token_ids[best_seq_idx]
                processed_seqs += tot_scores.nelement()
                hyps.append(best_token_seqs)
                scores.append(tot_scores.max().item())

            assert len(hyps) == len(split_size)
        else:
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
            scores = best_paths.get_tot_scores(
                use_double_scores=True, log_semiring=False
            ).tolist()
            hyps = get_texts(best_paths)

        assert len(scores) == len(hyps)

        for token_int, score in zip(hyps, scores):
            # For decoding methods nbest_rescoring and ctc_decoding
            # hyps stores token_index, which is lattice.labels.

            # convert token_id to text with self.tokenizer
            token = self.converter.ids2tokens(token_int)
            assert self.tokenizer is not None
            text = self.tokenizer.tokens2text(token)
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results
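The n-best rescoring branch above regroups a flat vector of per-path scores into per-utterance slices with torch.split and a list of group sizes (computed there by indices_to_split_size). A toy sketch of that regrouping with the sizes written out explicitly:

import torch

path_scores = torch.tensor([0.1, 0.9, 0.4, 0.7, 0.2])  # 5 paths in total
paths_per_utt = [2, 3]                                  # utt0 has 2 paths, utt1 has 3

per_utt = torch.split(path_scores, paths_per_utt)
best = [int(torch.argmax(s)) for s in per_utt]          # best path index within each utt
print([s.tolist() for s in per_utt], best)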
 def forward(self, x):
     x = self.filter(x)
     out = torch.split(x, self.out_channels, 1)  # split channels in CNN or split D in FC
     return torch.max(out[0], out[1])
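The maxout layer above produces twice the target width, splits it in half, and keeps the element-wise maximum. A standalone sketch with an assumed fully-connected filter (the sizes are illustrative):

import torch
import torch.nn as nn

out_channels = 4
filt = nn.Linear(8, 2 * out_channels)      # FC variant; a Conv2d works the same way
x = torch.randn(5, 8)
a, b = torch.split(filt(x), out_channels, dim=1)
y = torch.max(a, b)
print(y.shape)                             # torch.Size([5, 4])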
Example #41
0
    def train_erm_phase(self):

        for run_erm in range(self.args.n_runs_matchdg_erm):

            self.max_epoch = -1
            self.max_val_acc = 0.0
            for epoch in range(self.args.epochs):

                if epoch == 0:
                    data_match_tensor, label_match_tensor = self.init_erm_phase(
                    )
                elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                    data_match_tensor, label_match_tensor = self.get_match_function(
                        epoch)

                penalty_erm = 0
                penalty_erm_extra = 0
                penalty_ws = 0
                train_acc = 0.0
                train_size = 0

                perm = torch.randperm(data_match_tensor.size(0))
                data_match_tensor_split = torch.split(data_match_tensor[perm],
                                                      self.args.batch_size,
                                                      dim=0)
                label_match_tensor_split = torch.split(
                    label_match_tensor[perm], self.args.batch_size, dim=0)
                print('Split Matched Data: ', len(data_match_tensor_split),
                      data_match_tensor_split[0].shape,
                      len(label_match_tensor_split))

                #Batch iteration over single epoch
                for batch_idx, (x_e, y_e, d_e,
                                idx_e) in enumerate(self.train_dataset):
                    #         print('Batch Idx: ', batch_idx)

                    self.opt.zero_grad()
                    loss_e = torch.tensor(0.0).to(self.cuda)

                    x_e = x_e.to(self.cuda)
                    y_e = torch.argmax(y_e, dim=1).to(self.cuda)
                    d_e = torch.argmax(d_e, dim=1).numpy()

                    #Forward Pass
                    out = self.phi(x_e)
                    erm_loss_extra = F.cross_entropy(out,
                                                     y_e.long()).to(self.cuda)
                    penalty_erm_extra += float(erm_loss_extra)

                    wasserstein_loss = torch.tensor(0.0).to(self.cuda)
                    erm_loss = torch.tensor(0.0).to(self.cuda)
                    if epoch > self.args.penalty_s:
                        # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split
                        total_batch_size = len(data_match_tensor_split)
                        if batch_idx >= total_batch_size:
                            break
                        curr_batch_size = data_match_tensor_split[
                            batch_idx].shape[0]

                        #             data_match= data_match_tensor[idx].to(self.cuda)
                        data_match = data_match_tensor_split[batch_idx].to(
                            self.cuda)
                        data_match = data_match.view(
                            data_match.shape[0] * data_match.shape[1],
                            data_match.shape[2], data_match.shape[3],
                            data_match.shape[4])
                        feat_match = self.phi(data_match)

                        #             label_match= label_match_tensor[idx].to(self.cuda)
                        label_match = label_match_tensor_split[batch_idx].to(
                            self.cuda)
                        label_match = label_match.view(label_match.shape[0] *
                                                       label_match.shape[1])

                        erm_loss += F.cross_entropy(feat_match,
                                                    label_match.long()).to(
                                                        self.cuda)
                        penalty_erm += float(erm_loss)

                        train_acc += torch.sum(
                            torch.argmax(feat_match, dim=1) ==
                            label_match).item()
                        train_size += label_match.shape[0]

                        # Creating tensor of shape ( domain size, total domains, feat size )
                        if len(feat_match.shape) == 4:
                            feat_match = feat_match.view(
                                curr_batch_size, len(self.train_domains),
                                feat_match.shape[1] * feat_match.shape[2] *
                                feat_match.shape[3])
                        else:
                            feat_match = feat_match.view(
                                curr_batch_size, len(self.train_domains),
                                feat_match.shape[1])

                        label_match = label_match.view(curr_batch_size,
                                                       len(self.train_domains))

                        #             print(feat_match.shape)
                        data_match = data_match.view(curr_batch_size,
                                                     len(self.train_domains),
                                                     data_match.shape[1],
                                                     data_match.shape[2],
                                                     data_match.shape[3])

                        #Positive Match Loss
                        pos_match_counter = 0
                        for d_i in range(feat_match.shape[1]):
                            #                 if d_i != base_domain_idx:
                            #                     continue
                            for d_j in range(feat_match.shape[1]):
                                if d_j > d_i:
                                    if self.args.pos_metric == 'l2':
                                        wasserstein_loss += torch.sum(
                                            torch.sum(
                                                (feat_match[:, d_i, :] -
                                                 feat_match[:, d_j, :])**2,
                                                dim=1))
                                    elif self.args.pos_metric == 'l1':
                                        wasserstein_loss += torch.sum(
                                            torch.sum(torch.abs(
                                                feat_match[:, d_i, :] -
                                                feat_match[:, d_j, :]),
                                                      dim=1))
                                    elif self.args.pos_metric == 'cos':
                                        wasserstein_loss += torch.sum(
                                            cosine_similarity(
                                                feat_match[:, d_i, :],
                                                feat_match[:, d_j, :]))

                                    pos_match_counter += feat_match.shape[0]

                        wasserstein_loss = wasserstein_loss / pos_match_counter
                        penalty_ws += float(wasserstein_loss)

                        loss_e += (self.args.penalty_ws *
                                   (epoch - self.args.penalty_s) /
                                   (self.args.epochs -
                                    self.args.penalty_s)) * wasserstein_loss
                        loss_e += erm_loss
                        loss_e += erm_loss_extra

                    loss_e.backward(retain_graph=False)
                    self.opt.step()

                    del erm_loss_extra
                    del erm_loss
                    del wasserstein_loss
                    del loss_e
                    torch.cuda.empty_cache()

                print('Train Loss Basic : ', penalty_erm_extra, penalty_erm,
                      penalty_ws)
                print('Train Acc Env : ', 100 * train_acc / train_size)
                print('Done Training for epoch: ', epoch)

                #Train Dataset Accuracy
                self.train_acc.append(100 * train_acc / train_size)

                #Val Dataset Accuracy
                self.val_acc.append(self.get_test_accuracy('val'))

                #Test Dataset Accuracy
                self.final_acc.append(self.get_test_accuracy('test'))

                #Save the model if current best epoch as per validation loss
                if self.val_acc[-1] > self.max_val_acc:
                    self.max_val_acc = self.val_acc[-1]
                    self.max_epoch = epoch
                    self.save_model_erm_phase(run_erm)

                print('Current Best Epoch: ', self.max_epoch,
                      ' with Test Accuracy: ', self.final_acc[self.max_epoch])
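The training loop above shuffles the matched tensors with one shared permutation and cuts them into mini-batches with torch.split; the last chunk may be smaller than batch_size, which is why the loop guards against indexing past the split list. A minimal sketch of that batching step, with illustrative shapes:

import torch

data = torch.randn(10, 3)
labels = torch.arange(10)
batch_size = 4

perm = torch.randperm(data.size(0))
data_batches = torch.split(data[perm], batch_size, dim=0)
label_batches = torch.split(labels[perm], batch_size, dim=0)
print([b.shape[0] for b in data_batches])   # [4, 4, 2]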
class TestBlendVisualizeCallback(TestVisualizeCallback):

    callback_cls = BlendVisualizeCallback

    @pytest.fixture(params=[
        pytest.param(True, id="float"),
        pytest.param(False, id="long")
    ])
    def data(self, data_shape, request):
        B, C, H, W = data_shape
        img = create_image(B, C, H, W)
        if request.param:
            img = img.float()
        else:
            img = img.long()
        return img.clone(), img.clone()

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback):
        if not hasattr(model, callback.attr_name):
            return []

        if callback.split_channels == (2, 1) or callback.split_channels == ([2, 1], None):
            pytest.skip("incompatible test")

        data1, data2 = data
        data1 = prepare_image(data1)
        data2 = prepare_image(data2)
        B, _, _, _ = data1.shape
        img1 = [data1]
        img2 = [data2]
        name1 = [
            f"{mode}/{callback.name}",
        ]
        name2 = [
            f"{mode}/{callback.name}",
        ]
        img = [img1, img2]
        name = [name1, name2]

        # channel splitting

        for pos in range(2):
            if callback.split_channels[pos]:
                img[pos] = []
                name[pos] = []
                img_new, name_new = [], []
                splits = torch.split(data[pos],
                                     callback.split_channels[pos],
                                     dim=-3)
                for i, s in enumerate(splits):
                    if isinstance(callback.name, str):
                        n = f"{mode}/{callback.name}_{i}"
                    else:
                        n = f"{mode}/{callback.name[i]}"
                    name_new.append(n)
                    img_new.append(s)
                img[pos] = img_new
                name[pos] = name_new

        if len(img[0]) != len(img[1]):
            if len(img[0]) == 1:
                img[0] = img[0] * len(img[1])
            elif len(img[1]) == 1:
                img[1] = img[1] * len(img[0])
            else:
                raise RuntimeError()

        for pos in range(2):
            if callback.max_resolution:
                resize_mode = callback.resize_mode
                target = callback.max_resolution
                H_max, W_max = target
                needs_resize = [
                    i.shape[-2] > H_max or i.shape[-1] > W_max
                    for i in img[pos]
                ]
                img[pos] = [
                    F.interpolate(i, target, mode=resize_mode) if resize else i
                    for i, resize in zip(img[pos], needs_resize)
                ]

        for pos in range(2):
            if (colormap := callback.colormap[pos]):
                if isinstance(colormap, str):
                    colormap = [colormap] * len(img[pos])
                img[pos] = [
                    apply_colormap(i, cmap)[..., :3, :, :]
                    if cmap is not None else i
                    for cmap, i in zip(colormap, img[pos])
                ]

        name = name[0]
        final_img = []
        for pos, (d, s) in enumerate(zip(img[0], img[1])):
            _, C1, _, _ = d.shape
            _, C2, _, _ = s.shape

            if C1 != C2:
                if C1 == 1:
                    d = d.repeat(1, C2, 1, 1)
                elif C2 == 1:
                    s = s.repeat(1, C1, 1, 1)
                else:
                    raise ValueError(
                        f"could not match shapes {d.shape}, {s.shape}")

            s = clamp_normalize(s, inplace=True)
            d = clamp_normalize(d, inplace=True)
            final_img.append(
                alpha_blend(d, s, callback.alpha[1], callback.alpha[0])[0])
        img = final_img

        if callback.as_uint8:
            img = [
                to_8bit(i, same_on_batch=not callback.per_img_norm)
                for i in img
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        step = [
            model.current_epoch
            if callback.epoch_counter else model.global_step
        ] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
        image_path = f'input/{i}.png'
        image = Image.open(image_path).convert("RGB")
        image = T.Resize(in_sz)(image)
        image = image_to_tensor(image).to(device=device)
        images.append(image)
    images = torch.stack(images)
    images = images.unsqueeze(0)
    print(f'i: {images.shape}')
    print(f'c: {cam_poses.shape}')

    net.encode(
        images, cam_poses, focal,
    )
    print("Rendering", args.num_views * H * W, "rays")
    all_rgb_fine = []
    for rays in tqdm.tqdm(torch.split(render_rays.view(-1, 8), 80000, dim=0)):
        rgb, _depth = render_par(rays[None])
        all_rgb_fine.append(rgb[0])
    _depth = None
    rgb_fine = torch.cat(all_rgb_fine)
    frames = (rgb_fine.view(args.num_views, H, W, 3).cpu().numpy() * 255).astype(
        np.uint8
    )

        
    #im_name = os.path.basename(os.path.splitext(image_path)[0])
    im_name = output_name

    frames_dir_name = os.path.join(args.output, im_name + "_frames")
    os.makedirs(frames_dir_name, exist_ok=True)
class TestKeypointVisualizeCallback(TestVisualizeCallback):

    callback_cls = KeypointVisualizeCallback

    @pytest.fixture(params=[
        pytest.param(True, id="float"),
        pytest.param(False, id="long")
    ])
    def data(self, data_shape, request):
        B, C, H, W = data_shape
        N = 3
        img = create_image(B, C, H, W)
        bbox = self.create_bbox(B, N)
        bbox = bbox.float() if request.param else bbox.long()
        cls = self.create_classes(B, N)
        score = self.create_classes(B, N)
        target = {"coords": bbox, "class": cls, "score": score}
        return img, target

    def create_bbox(self, B, N):
        torch.random.manual_seed(42)
        bbox = torch.empty(B, N, 4).fill_(-1).float()
        return bbox

    def create_classes(self, B, N):
        torch.random.manual_seed(42)
        return torch.empty(B, N, 1).fill_(-1).float()

    def create_scores(self, B, N):
        torch.random.manual_seed(42)
        return torch.empty(B, N, 1).fill_(-1)

    @pytest.fixture
    def expected_calls(self, data, model, mode, callback):
        if not hasattr(model, callback.attr_name):
            return []

        data, target = data
        data = prepare_image(data)
        B, _, _, _ = data.shape
        img = [data]
        name = [
            f"{mode}/{callback.name}",
        ]

        # channel splitting
        if callback.split_channels:
            img, name = [], []
            splits = torch.split(data, callback.split_channels, dim=-3)
            for i, s in enumerate(splits):
                n = f"{mode}/{callback.name[i]}"
                name.append(n)
                img.append(s)

        if callback.max_resolution:
            resize_mode = callback.resize_mode
            target = callback.max_resolution
            H_max, W_max = target
            needs_resize = [
                i.shape[-2] > H_max or i.shape[-1] > W_max for i in img
            ]
            img = [
                F.interpolate(i, target, mode=resize_mode) if resize else i
                for i, resize in zip(img, needs_resize)
            ]

        if (colormap := callback.colormap):
            if isinstance(colormap, str):
                colormap = [colormap] * len(img)
            img = [
                apply_colormap(i, cmap)[..., :3, :, :]
                if cmap is not None else i for cmap, i in zip(colormap, img)
            ]

        img = [i.repeat(1, 3, 1, 1) if i.shape[-3] == 1 else i for i in img]

        if callback.as_uint8:
            img = [
                to_8bit(i, same_on_batch=not callback.per_img_norm)
                for i in img
            ]

        if callback.split_batches:
            new_img, new_name = [], []
            for i, n in zip(img, name):
                split_i = torch.split(i, 1, dim=0)
                split_n = [f"{n}/{b}" for b in range(B)]
                new_img += split_i
                new_name += split_n
            name, img = new_name, new_img

        step = [
            model.current_epoch
            if callback.epoch_counter else model.global_step
        ] * len(name)
        expected = [(n, i, s) for n, i, s in zip(name, img, step)]
        return expected
Example #45
0
def decode_structure(model, root_code):
    """
    Decode a root code into a tree structure of boxes
    """
    decode = model.sampleDecoder(root_code)
    syms = [torch.ones(8).mul(10).cuda()]
    stack = [decode]
    boxes = []
    while len(stack) > 0:
        f = stack.pop()
        label_prob = model.nodeClassifier(f)
        _, label = torch.max(label_prob, 1)
        label = label.data
        if label[0] == 1:  # ADJ
            left, right = model.adjDecoder(f)
            stack.append(left)
            stack.append(right)
            s = syms.pop()
            syms.append(s)
            syms.append(s)
        if label[0] == 2:  # SYM
            left, s = model.symDecoder(f)
            s = s.squeeze(0)
            stack.append(left)
            syms.pop()
            syms.append(s.data)
        if label[0] == 0:  # BOX
            reBox = model.boxDecoder(f)
            reBoxes = [reBox]
            s = syms.pop()
            l1 = abs(s[0] + 1)
            l2 = abs(s[0])
            l3 = abs(s[0] - 1)

            if l1 < 0.15:
                sList = torch.split(s, 1, 0)
                bList = torch.split(reBox.data.squeeze(0), 1, 0)
                f1 = torch.cat([sList[1], sList[2], sList[3]])
                f1 = f1 / torch.norm(f1)
                f2 = torch.cat([sList[4], sList[5], sList[6]])
                folds = round(1 / s[7])
                for i in range(folds - 1):
                    rotvector = torch.cat(
                        [f1, sList[7].mul(2 * 3.1415).mul(i + 1)])
                    rotm = vrrotvec2mat(rotvector)
                    center = torch.cat([bList[0], bList[1], bList[2]])
                    dir0 = torch.cat([bList[3], bList[4], bList[5]])
                    dir1 = torch.cat([bList[6], bList[7], bList[8]])
                    dir2 = torch.cat([bList[9], bList[10], bList[11]])
                    newcenter = rotm.matmul(center.add(-f2)).add(f2)
                    newdir1 = rotm.matmul(dir1)
                    newdir2 = rotm.matmul(dir2)
                    newbox = torch.cat([newcenter, dir0, newdir1, newdir2])
                    reBoxes.append(Variable(newbox.unsqueeze(0)))

            if l2 < 0.15:
                sList = torch.split(s, 1, 0)
                bList = torch.split(reBox.data.squeeze(0), 1, 0)
                trans = torch.cat([sList[1], sList[2], sList[3]])
                trans_end = torch.cat([sList[4], sList[5], sList[6]])
                center = torch.cat([bList[0], bList[1], bList[2]])
                trans_length = math.sqrt(torch.sum(trans**2))
                trans_total = math.sqrt(torch.sum(trans_end.add(-center)**2))
                folds = round(trans_total / trans_length)
                for i in range(folds):
                    center = torch.cat([bList[0], bList[1], bList[2]])
                    dir0 = torch.cat([bList[3], bList[4], bList[5]])
                    dir1 = torch.cat([bList[6], bList[7], bList[8]])
                    dir2 = torch.cat([bList[9], bList[10], bList[11]])
                    newcenter = center.add(trans.mul(i + 1))
                    newbox = torch.cat([newcenter, dir0, dir1, dir2])
                    reBoxes.append(Variable(newbox.unsqueeze(0)))

            if l3 < 0.15:
                sList = torch.split(s, 1, 0)
                bList = torch.split(reBox.data.squeeze(0), 1, 0)
                ref_normal = torch.cat([sList[1], sList[2], sList[3]])
                ref_normal = ref_normal / torch.norm(ref_normal)
                ref_point = torch.cat([sList[4], sList[5], sList[6]])
                center = torch.cat([bList[0], bList[1], bList[2]])
                dir0 = torch.cat([bList[3], bList[4], bList[5]])
                dir1 = torch.cat([bList[6], bList[7], bList[8]])
                dir2 = torch.cat([bList[9], bList[10], bList[11]])
                if ref_normal.matmul(ref_point.add(-center)) < 0:
                    ref_normal = -ref_normal
                newcenter = ref_normal.mul(
                    2 * abs(torch.sum(
                        ref_point.add(-center).mul(ref_normal)))).add(center)
                if ref_normal.matmul(dir1) < 0:
                    ref_normal = -ref_normal
                dir1 = dir1.add(ref_normal.mul(-2 * ref_normal.matmul(dir1)))
                if ref_normal.matmul(dir2) < 0:
                    ref_normal = -ref_normal
                dir2 = dir2.add(ref_normal.mul(-2 * ref_normal.matmul(dir2)))
                newbox = torch.cat([newcenter, dir0, dir1, dir2])
                reBoxes.append(Variable(newbox.unsqueeze(0)))

            boxes.extend(reBoxes)

    return boxes
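decode_structure indexes the symmetry vector through torch.split(s, 1, 0), which keeps each entry as a length-1 tensor so several of them can be re-concatenated; torch.unbind would drop that dimension instead. A small sketch of the access pattern:

import torch

s = torch.arange(8.0)
sList = torch.split(s, 1, 0)
axis = torch.cat([sList[1], sList[2], sList[3]])
print(sList[0].shape, axis.shape)   # torch.Size([1]) torch.Size([3])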
    def forward(
        self,  # pylint: disable=arguments-differ
        inputs: torch.Tensor,
        mask: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, hidden_dim = inputs.size()
        if mask is None:
            mask = Variable(inputs.data.new(batch_size, timesteps).fill_(1.0))

        # Treat the queries, keys and values each as a ``num_heads`` size batch.
        # shape (num_heads, batch_size * timesteps, hidden_dim)
        inputs_per_head = inputs.repeat(num_heads, 1,
                                        1).view(num_heads,
                                                batch_size * timesteps,
                                                hidden_dim)
        # Do the projections for all the heads at once.
        # Then reshape the result as though it had a
        # (num_heads * batch_size) sized batch.
        queries_per_head = torch.bmm(inputs_per_head, self._query_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        queries_per_head = queries_per_head.view(num_heads * batch_size,
                                                 timesteps,
                                                 self._attention_dim)

        keys_per_head = torch.bmm(inputs_per_head, self._key_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        keys_per_head = keys_per_head.view(num_heads * batch_size, timesteps,
                                           self._attention_dim)

        values_per_head = torch.bmm(inputs_per_head, self._value_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        values_per_head = values_per_head.view(num_heads * batch_size,
                                               timesteps, self._values_dim)

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = (
            torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) /
            self._scale)

        # Masking should go here
        causality_mask = subsequent_mask(timesteps).cuda()
        masked_scaled_similarities = scaled_similarities.masked_fill(
            causality_mask == 0, -1e9)

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = masked_softmax(masked_scaled_similarities,
                                   mask.repeat(num_heads, 1))
        attention = self._attention_dropout(attention)
        # This is doing the following batch-wise matrix multiplication:
        # (num_heads * batch_size, timesteps, timesteps) *
        # (num_heads * batch_size, timesteps, values_dim)
        # which is equivalent to a weighted sum of the values with respect to
        # the attention distributions for each element in the num_heads * batch_size
        # dimension.
        # shape (num_heads * batch_size, timesteps, values_dim)
        outputs = torch.bmm(attention, values_per_head)

        # Reshape back to original shape (batch_size, timesteps, num_heads * values_dim)
        # Note that we _cannot_ use a reshape here, because this tensor was created
        # with num_heads being the first dimension, so reshaping naively would not
        # throw an error, but give an incorrect result.
        outputs = torch.cat(torch.split(outputs, batch_size, dim=0), dim=-1)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
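The comment above warns that a plain reshape would silently mix batch elements because the heads occupy the leading dimension. A tiny demonstration of why torch.cat(torch.split(...)) is the right recombination; the values encode the original (head, batch) position so the two layouts are easy to compare:

import torch

num_heads, batch_size, timesteps, values_dim = 2, 3, 1, 1
outputs = torch.arange(num_heads * batch_size, dtype=torch.float32)
outputs = outputs.view(num_heads * batch_size, timesteps, values_dim)

correct = torch.cat(torch.split(outputs, batch_size, dim=0), dim=-1)
naive = outputs.reshape(batch_size, timesteps, num_heads * values_dim)

print(correct.squeeze())  # tensor([[0., 3.], [1., 4.], [2., 5.]]) - both heads per item
print(naive.squeeze())    # tensor([[0., 1.], [2., 3.], [4., 5.]]) - items mixed across heads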
Example #47
0
    def losses(self):
        """
        Return the losses from a set of RPN predictions and their associated ground-truth.

        Returns:
            dict[loss name -> loss value]: A dict mapping from loss name to loss value.
                Loss names are: `loss_rpn_cls` for objectness classification and
                `loss_rpn_loc` for proposal localization.
        """

        def resample(label):
            """
            Randomly sample a subset of positive and negative examples by overwriting
            the label vector to the ignore value (-1) for all elements that are not
            included in the sample.
            """
            pos_idx, neg_idx = subsample_labels(
                label, self.batch_size_per_image, self.positive_fraction, 0
            )
            # Fill with the ignore label (-1), then set positive and negative labels
            label.fill_(-1)
            label.scatter_(0, pos_idx, 1)
            label.scatter_(0, neg_idx, 0)
            return label

        gt_objectness_logits, gt_anchor_deltas = self._get_ground_truth()
        """
        gt_objectness_logits: list of N tensors. Tensor i is a vector whose length is the
            total number of anchors in image i (i.e., len(anchors[i]))
        gt_anchor_deltas: list of N tensors. Tensor i has shape (len(anchors[i]), B),
            where B is the box dimension
        """
        # Collect all objectness labels and delta targets over feature maps and images
        # The final ordering is L, N, H, W, A from slowest to fastest axis.
        num_anchors_per_map = [np.prod(x.shape[1:]) for x in self.pred_objectness_logits]
        num_anchors_per_image = sum(num_anchors_per_map)

        # Stack to: (N, num_anchors_per_image)
        gt_objectness_logits = torch.stack(
            [resample(label) for label in gt_objectness_logits], dim=0
        )

        # Log the number of positive/negative anchors per-image that's used in training
        num_pos_anchors = (gt_objectness_logits == 1).sum().item()
        num_neg_anchors = (gt_objectness_logits == 0).sum().item()
        storage = get_event_storage()
        storage.put_scalar("rpn/num_pos_anchors", num_pos_anchors / self.num_images)
        storage.put_scalar("rpn/num_neg_anchors", num_neg_anchors / self.num_images)

        assert gt_objectness_logits.shape[1] == num_anchors_per_image
        # Split to tuple of L tensors, each with shape (N, num_anchors_per_map)
        gt_objectness_logits = torch.split(gt_objectness_logits, num_anchors_per_map, dim=1)
        # Concat from all feature maps
        gt_objectness_logits = cat([x.flatten() for x in gt_objectness_logits], dim=0)

        # Stack to: (N, num_anchors_per_image, B)
        gt_anchor_deltas = torch.stack(gt_anchor_deltas, dim=0)
        assert gt_anchor_deltas.shape[1] == num_anchors_per_image
        B = gt_anchor_deltas.shape[2]  # box dimension (4 or 5)

        # Split to tuple of L tensors, each with shape (N, num_anchors_per_image)
        gt_anchor_deltas = torch.split(gt_anchor_deltas, num_anchors_per_map, dim=1)
        # Concat from all feature maps
        gt_anchor_deltas = cat([x.reshape(-1, B) for x in gt_anchor_deltas], dim=0)

        # Collect all objectness logits and delta predictions over feature maps
        # and images to arrive at the same shape as the labels and targets
        # The final ordering is L, N, H, W, A from slowest to fastest axis.
        pred_objectness_logits = cat(
            [
                # Reshape: (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N*Hi*Wi*A, )
                x.permute(0, 2, 3, 1).flatten()
                for x in self.pred_objectness_logits
            ],
            dim=0,
        )
        pred_anchor_deltas = cat(
            [
                # Reshape: (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B)
                #          -> (N*Hi*Wi*A, B)
                x.view(x.shape[0], -1, B, x.shape[-2], x.shape[-1])
                .permute(0, 3, 4, 1, 2)
                .reshape(-1, B)
                for x in self.pred_anchor_deltas
            ],
            dim=0,
        )

        objectness_loss, localization_loss = rpn_losses(
            gt_objectness_logits,
            gt_anchor_deltas,
            pred_objectness_logits,
            pred_anchor_deltas,
            self.smooth_l1_beta,
        )
        normalizer = 1.0 / (self.batch_size_per_image * self.num_images)
        loss_cls = objectness_loss * normalizer  # cls: classification loss
        loss_loc = localization_loss * normalizer  # loc: localization loss
        losses = {"loss_rpn_cls": loss_cls, "loss_rpn_loc": loss_loc}

        return losses
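The loss above stacks labels to (N, num_anchors_per_image), splits them back into one block per feature map with torch.split along dim=1, and flattens each block so the final ordering is level-major. A toy sketch of that regrouping with assumed per-level anchor counts:

import torch

num_images = 2
num_anchors_per_map = [3, 5]                      # e.g. two feature maps
labels = torch.arange(num_images * sum(num_anchors_per_map)).view(num_images, -1)

per_level = torch.split(labels, num_anchors_per_map, dim=1)
flat = torch.cat([x.flatten() for x in per_level], dim=0)
print([x.shape for x in per_level])               # [torch.Size([2, 3]), torch.Size([2, 5])]
print(flat.shape)                                 # torch.Size([16])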
Example #48
0
def split(input, sizes_or_sections, dim):
    return th.split(input, sizes_or_sections, dim)
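The wrapper above just forwards to torch.split, which accepts either a single chunk size (the last chunk may be smaller) or an explicit list of section sizes summing to the dimension's length. A short usage sketch:

import torch

x = torch.arange(10)
print([t.tolist() for t in torch.split(x, 4)])          # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print([t.tolist() for t in torch.split(x, [2, 3, 5])])  # [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]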
Example #49
0
    def train(x, y, iteration, epoch, batch_size, target_map = None, r_mixup = 0.0):
        G.optim.zero_grad()
        D.optim.zero_grad()

        if config["unet_mixup"]:
            real_target = torch.tensor([1.0]).cuda()
            fake_target = torch.tensor([0.0]).cuda()

        if config["unet_mixup"] and not config["full_batch_mixup"]:
            use_mixup_in_this_round = True
        elif config["unet_mixup"] and config["full_batch_mixup"]:
            use_mixup_in_this_round = torch.rand(1).detach().item()<r_mixup
        else:
            use_mixup_in_this_round = False

        out = {}

        skip_normal_real_fake_loss = (use_mixup_in_this_round and config["full_batch_mixup"] )

        n_d_accu = config['num_D_accumulations']

        split_size = int(x.size(0)/n_d_accu)

        x = torch.split(x, split_size)
        y = torch.split(y, split_size)

        d_real_target = torch.tensor([1.0]).cuda()
        d_fake_target = torch.tensor([0.0]).cuda()

        discriminator_loss = functools.partial(BCEloss, d_real_target=d_real_target, d_fake_target=d_fake_target)
        mix_fake_target = torch.tensor([1.0]).cuda()
        fake_loss = functools.partial(BCEfakeloss, target = mix_fake_target)

        # Optionally toggle D and G's "require_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            counter = 0
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()

            for accumulation_index in range(n_d_accu):

                z_.sample_()
                y_.sample_()

                if use_mixup_in_this_round:

                    if (not config["full_batch_mixup"]) or (config["full_batch_mixup"] and (config["consistency_loss_and_augmentation"] or config["consistency_loss"]) ):

                        D_fake, D_real , D_mixed, G_z, mixed,  D_middle_fake, D_middle_real, D_middle_mixed, target_map   = GD(z_[:batch_size], y_[:batch_size],
                                                            x[counter], y[counter], train_G=False,
                                                            split_D=config['split_D'], mixup = True, target_map = target_map) # mixup can be true because weight is set to 0 when no mixup is used
                    else:
                        D_mixed, G_z, mixed, D_middle_mixed, target_map   = GD(z_[:batch_size], y_[:batch_size],
                                                            x[counter], y[counter], train_G=False, return_G_z = True,
                                                            split_D=config['split_D'], mixup = True, mixup_only = True, target_map = target_map)

                    if config["slow_mixup"] and not config["full_batch_mixup"]:
                        mixup_coeff = min(1.0, epoch/config["warmup_epochs"] )#use without full batch mixup
                    else:
                        mixup_coeff = 1.0

                    if config["display_mixed_batch"]:
                        # This can help for debugging
                        plt.figure()
                        m = torchvision.utils.make_grid(mixed,nrow=5,padding=2,normalize = True)
                        m = m.permute(1,2,0)
                        m = m.cpu().numpy()
                        plt.imshow(m)
                        plt.figure()
                        plt.figure()
                        m = torchvision.utils.make_grid(G_z,nrow=5,padding=2,normalize = True)
                        m = m.permute(1,2,0)
                        m = m.cpu().numpy()
                        plt.imshow(m)
                        plt.figure()
                        plt.figure()
                        m = torchvision.utils.make_grid(x[counter],nrow=5,padding=2,normalize = True)
                        m = m.permute(1,2,0)
                        m = m.cpu().numpy()
                        plt.imshow(m)
                        plt.figure()
                        m = torchvision.utils.make_grid(target_map,nrow=5,padding=2)
                        m = m.permute(1,2,0)
                        m = m.cpu().numpy()
                        plt.imshow(m)
                        plt.title("mix")
                        plt.show()
                        plt.figure()

                else:
                    D_fake, D_real , G_z, D_middle_fake, D_middle_real   = GD(z_[:batch_size], y_[:batch_size],
                                                        x[counter], y[counter], train_G=False,
                                                        split_D=config['split_D'])



                if not skip_normal_real_fake_loss:
                    D_loss_real_2d, D_loss_fake_2d = discriminator_loss(D_fake.view(-1), D_real.view(-1))
                    D_loss_real_2d_item = D_loss_real_2d.detach().item()
                    D_loss_fake_2d_item = D_loss_fake_2d.detach().item()

                if use_mixup_in_this_round and (config["consistency_loss"] or config["consistency_loss_and_augmentation"]):
                    mix = D_real * target_map + D_fake * (1 - target_map)

                if use_mixup_in_this_round:

                    D_mixed_flattened = D_mixed.view(-1)
                    target_map_flattened = target_map.view(-1)

                    mix_list = []
                    for i in range(D_mixed.size(0)):
                        # MIXUP LOSS 2D
                        mix2d_i= F.binary_cross_entropy_with_logits(D_mixed[i].view(-1),target_map[i].view(-1) )
                        mix_list.append(mix2d_i)

                    D_loss_mixed_2d = torch.stack(mix_list)
                    #-> D_loss_mixed_2d.mean() is taken later

                    D_loss_mixed_2d_item = D_loss_mixed_2d.mean().detach().item()
                    #D_loss_mixed_2d = D_loss_mixed_2d.view(D_mixed.size()).mean([2,3])

                if not skip_normal_real_fake_loss:
                    D_loss_real_middle, D_loss_fake_middle = discriminator_loss(D_middle_fake, D_middle_real)

                    D_loss_real_middle_item = D_loss_real_middle.detach().item()
                    D_loss_fake_middle_item = D_loss_fake_middle.detach().item()

                if use_mixup_in_this_round and not config["consistency_loss"]:
                    # The consistency loss is only concerned with the segmenter output;
                    # the target for the mixed encoder loss is "fake".
                    mix_bce = F.binary_cross_entropy_with_logits(D_middle_mixed, fake_target.expand_as(D_middle_mixed), reduction="none")

                    mixed_middle_loss = mixup_coeff*mix_bce
                    mixed_middle_loss_item = mixed_middle_loss.mean().detach().item()

                if skip_normal_real_fake_loss:
                    D_loss_real = torch.tensor([0.0]).cuda()
                    D_loss_fake = torch.tensor([0.0]).cuda()
                else:
                    D_loss_real = D_loss_real_2d + D_loss_real_middle
                    D_loss_fake = D_loss_fake_2d + D_loss_fake_middle

                D_loss_real_item = D_loss_real.detach().item()
                D_loss_fake_item = D_loss_fake.detach().item()

                D_loss = 0.5*D_loss_real + 0.5*D_loss_fake

                if use_mixup_in_this_round:
                    if config["consistency_loss"] or config["consistency_loss_and_augmentation"]:
                        consistency_loss = mixup_coeff*1.0*F.mse_loss(D_mixed, mix )
                        consistency_loss_item = consistency_loss.float().detach().item()

                    if not config["consistency_loss"]:
                        # GAN loss from the cutmix augmentation (distinct from the consistency loss)
                        mix_loss = D_loss_mixed_2d + mixed_middle_loss
                        mix_loss = mix_loss.mean()
                    else:
                        mix_loss = 0.0

                    if config["consistency_loss"]:
                        mix_loss = consistency_loss
                    elif config["consistency_loss_and_augmentation"]:
                        mix_loss = mix_loss + consistency_loss

                    D_loss = D_loss + mix_loss

                D_loss = D_loss / float(config['num_D_accumulations'])

                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            if iteration%2 == 0 and EG:
                print('D extrapolation')
                D.optim.extrapolation()
            else:
                print('D step')
                D.optim.step()
            del D_loss

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        ######################################
        # G-step
        ######################################
        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()
        counter = 0

        z_.sample_()
        y_.sample_()

        z__ = torch.split(z_, split_size)
        y__ = torch.split(y_, split_size)

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):
            G_fake, G_fake_middle = GD(z__[counter], y__[counter], train_G=True, split_D=config['split_D'], reference_x = x[counter] )

            G_loss_fake_2d = fake_loss(G_fake)
            G_loss_fake_middle = fake_loss(G_fake_middle)
            G_loss = 0.5*G_loss_fake_middle + 0.5*G_loss_fake_2d
            G_loss = G_loss / float(config['num_G_accumulations'])

            G_loss_fake_middle_item = G_loss_fake_middle.detach().item()
            G_loss_fake_2d_item = G_loss_fake_2d.detach().item()
            G_loss_item = G_loss.detach().item()

            G_loss.backward()
            counter += 1

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G, config['G_ortho'],
                                    blacklist=[param for param in G.shared.parameters()])

        print(iteration)
        if iteration%2 == 0 and EG:
            print('G extrapolation')
            G.optim.extrapolation()
        else:
            print('G step')
            G.optim.step()
        del G_loss

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])


        # save intermediate losses
        if use_mixup_in_this_round and (config["consistency_loss"] or config["consistency_loss_and_augmentation"]) and config["num_D_steps"]>0:
            out["consistency"] = float(consistency_loss_item)

        out['G_loss'] = float(G_loss_item)
        if not (config["full_batch_mixup"] and use_mixup_in_this_round) and config["num_D_steps"]>0:
            out['D_loss_real'] = float(D_loss_real_item)
            out['D_loss_fake'] = float(D_loss_fake_item)

        if use_mixup_in_this_round and not config["consistency_loss"] and config["num_D_steps"]>0:
            out["mixed_middle_loss"] = float(mixed_middle_loss_item)
            out["D_loss_mixed_2d"] = float(D_loss_mixed_2d_item)

        if not (config["full_batch_mixup"] and use_mixup_in_this_round):
            if config["num_D_steps"]>0:
                out["D_loss_real_middle"] = float(D_loss_real_middle_item)
                out["D_loss_fake_middle"] = float(D_loss_fake_middle_item)
                out["D_loss_real_2d"] = float(D_loss_real_2d_item)
                out["D_loss_fake_2d"] = float(D_loss_fake_2d_item)
            out["G_loss_fake_middle"] = float(G_loss_fake_middle_item)
            out["G_loss_fake_2d"] = float(G_loss_fake_2d_item)

        return out
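
A minimal standalone sketch (not the training code above) of the cutmix consistency term used in the discriminator step: the per-pixel output on the mixed image is pulled towards the same pixelwise mix of the outputs on the real and generated images. All names here are illustrative.

    import torch.nn.functional as F

    def consistency_term(d_real_map, d_fake_map, d_mixed_map, target_map):
        # target_map is 1 where the mixed image shows the real sample and 0 where
        # it shows the generated one.
        mix = d_real_map * target_map + d_fake_map * (1 - target_map)
        return F.mse_loss(d_mixed_map, mix)
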
Example #50
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno: torch.IntTensor = None,
            question_tf: torch.FloatTensor = None,
            passage_tf: torch.FloatTensor = None,
            q_em_cased: torch.IntTensor = None,
            p_em_cased: torch.IntTensor = None,
            q_em_uncased: torch.IntTensor = None,
            p_em_uncased: torch.IntTensor = None,
            q_in_lemma: torch.IntTensor = None,
            p_in_lemma: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        x1_c_emb = self._dropout(self._char_field_embedder(passage))
        x2_c_emb = self._dropout(self._char_field_embedder(question))

        # embedded_question = torch.cat([self._dropout(self._text_field_embedder(question)),
        #                                self._features_embedder(q_em_cased),
        #                                self._features_embedder(q_em_uncased),
        #                                self._features_embedder(q_in_lemma),
        #                                question_tf.unsqueeze(2)], dim=2)
        # embedded_passage = torch.cat([self._dropout(self._text_field_embedder(passage)),
        #                               self._features_embedder(p_em_cased),
        #                               self._features_embedder(p_em_uncased),
        #                               self._features_embedder(p_in_lemma),
        #                               passage_tf.unsqueeze(2)], dim=2)
        token_emb_q = self._dropout(self._text_field_embedder(question))
        token_emb_c = self._dropout(self._text_field_embedder(passage))
        token_emb_question, q_ner_and_pos = torch.split(token_emb_q, [300, 40],
                                                        dim=2)
        token_emb_passage, p_ner_and_pos = torch.split(token_emb_c, [300, 40],
                                                       dim=2)
        question_word_features = torch.cat([
            q_ner_and_pos,
            self._features_embedder(q_em_cased),
            self._features_embedder(q_em_uncased),
            self._features_embedder(q_in_lemma),
            question_tf.unsqueeze(2)
        ],
                                           dim=2)
        passage_word_features = torch.cat([
            p_ner_and_pos,
            self._features_embedder(p_em_cased),
            self._features_embedder(p_em_uncased),
            self._features_embedder(p_in_lemma),
            passage_tf.unsqueeze(2)
        ],
                                          dim=2)

        # embedded_question = self._highway_layer(embedded_q)
        # embedded_passage = self._highway_layer(embedded_q)

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        char_features_c = self._char_rnn(
            x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1),
                              x1_c_emb.size(2), x1_c_emb.size(3))),
            passage_lstm_mask.unsqueeze(2).repeat(
                1, 1, x1_c_emb.size(2)).reshape(
                    (x1_c_emb.size(0) * x1_c_emb.size(1),
                     x1_c_emb.size(2)))).reshape(
                         (x1_c_emb.size(0), x1_c_emb.size(1), x1_c_emb.size(2),
                          -1))[:, :, -1, :]
        char_features_q = self._char_rnn(
            x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1),
                              x2_c_emb.size(2), x2_c_emb.size(3))),
            question_lstm_mask.unsqueeze(2).repeat(
                1, 1, x2_c_emb.size(2)).reshape(
                    (x2_c_emb.size(0) * x2_c_emb.size(1),
                     x2_c_emb.size(2)))).reshape(
                         (x2_c_emb.size(0), x2_c_emb.size(1), x2_c_emb.size(2),
                          -1))[:, :, -1, :]

        # token_emb_q, char_emb_q, question_word_features = torch.split(embedded_question, [300, 300, 56], dim=2)
        # token_emb_c, char_emb_c, passage_word_features = torch.split(embedded_passage, [300, 300, 56], dim=2)

        # char_features_q = self._char_rnn(char_emb_q, question_lstm_mask)
        # char_features_c = self._char_rnn(char_emb_c, passage_lstm_mask)

        emb_question = torch.cat(
            [token_emb_question, char_features_q, question_word_features],
            dim=2)
        emb_passage = torch.cat(
            [token_emb_passage, char_features_c, passage_word_features], dim=2)

        encoded_question = self._dropout(
            self._phrase_layer(emb_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(emb_passage, passage_lstm_mask))

        batch_size = encoded_question.size(0)
        passage_length = encoded_passage.size(1)

        encoding_dim = encoded_question.size(-1)

        # c_check = self._stacked_brnn(encoded_passage, passage_lstm_mask)
        # q = self._stacked_brnn(encoded_question, question_lstm_mask)
        c_check = encoded_passage
        q = encoded_question
        for i in range(self.hops):
            q_tilde = self.interactive_aligners[i].forward(
                c_check, q, question_mask)
            c_bar = self.interactive_SFUs[i].forward(
                c_check,
                torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2))
            c_tilde = self.self_aligners[i].forward(c_bar, passage_mask)
            c_hat = self.self_SFUs[i].forward(
                c_bar, torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde],
                                 2))
            c_check = self.aggregate_rnns[i].forward(c_hat, passage_mask)

        # Predict
        start_scores, end_scores, yesno_scores = self.mem_ans_ptr.forward(
            c_check, q, passage_mask, question_mask)

        best_span, yesno_predict, loc = self.get_best_span(
            start_scores, end_scores, yesno_scores)

        output_dict = {
            "span_start_logits": start_scores,
            "span_end_logits": end_scores,
            "best_span": best_span
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(start_scores, span_start.squeeze(-1))
            self._span_start_accuracy(start_scores, span_start.squeeze(-1))
            loss += nll_loss(end_scores, span_end.squeeze(-1))
            self._span_end_accuracy(end_scores, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

            gold_span_end_loc = []
            span_end = span_end.view(batch_size).squeeze().data.cpu().numpy()
            for i in range(batch_size):
                gold_span_end_loc.append(
                    max(span_end[i] + i * passage_length, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(_yesno, yesno.view(-1), ignore_index=-1)

            pred_span_end_loc = []
            for i in range(batch_size):
                pred_span_end_loc.append(max(loc[i], 0))
            predicted_end = span_start.new(pred_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(0,
                                                           predicted_end).view(
                                                               -1, 3)
            self._span_yesno_accuracy(_yesno, yesno.squeeze(-1))

            output_dict['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['yesno'] = yesno_predict
        return output_dict
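
The forward pass above relies on torch.split with a list of section sizes to separate the 340-dimensional token embeddings into word vectors and NER/POS features. A small standalone sketch of that call (shapes assumed):

    import torch

    token_emb = torch.randn(2, 7, 340)  # (batch, seq_len, 300 word dims + 40 NER/POS dims)
    word_vecs, ner_and_pos = torch.split(token_emb, [300, 40], dim=2)
    assert word_vecs.shape == (2, 7, 300) and ner_and_pos.shape == (2, 7, 40)
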
Example #51
0
    def forward(self, x):

        batch_size = x.size(0)  # batch, window, n_val

        for i in range(self.target):
            embed = self.em(torch.tensor(i).cuda())
            embed = torch.unsqueeze(torch.unsqueeze(embed, 0), 1)  # 1, 1, size
            embed = embed.repeat(batch_size, window, 1)
            fea = x[:, :, 5 * i:5 * (i + 1)]
            # fea1=fea.unsqueeze(-1)
            # fea2=fea.unsqueeze(-2)
            # fea_cross=torch.matmul(fea1,fea2).reshape(batch_size,window,-1)
            # fea_map=[]
            # for j in range(self.em_size):
            #     fea_map.append(self.fea_cross_weight[j](fea_cross))
            # crossed_fea=torch.cat(fea_map,dim=2)
            # fc=crossed_fea
            fc = self.w(fea)
            fc += embed
            # fc = self.fc(torch.cat([embed, fea], dim=2))      # batch_size, window, size

            if i == 0:
                fc_output = fc
            else:
                fc_output = torch.cat([fc_output, fc], dim=2)
        index_fea = x[:, :, 30:]
        index_fea = self.index_fc(index_fea)

        features = torch.cat([fc_output, index_fea], dim=2)
        features_origin1 = features
        # lstnet_input = fc_output
        features = features.reshape(batch_size, window, -1, self.em_size)
        split_tensor1 = torch.stack(
            torch.split(features, self.em_size * [1], 3), 0)
        split_tensor2 = split_tensor1.permute(0, 1, 2, 4, 3)
        dot_result_m = torch.matmul(split_tensor1, split_tensor2)
        dot_result_m = dot_result_m.view(self.em_size, batch_size, window,
                                         7 * 7)
        crossed_feas = self.W1(dot_result_m)
        # (em_size, batch, window, k) -> (batch, window, k, em_size)
        crossed_feas = crossed_feas.permute(1, 2, 3, 0)
        # crossed_feas=F.relu(crossed_feas)
        features_origin2 = crossed_feas
        # features
        lstnet_input = crossed_feas
        lstnet_input = lstnet_input.reshape(batch_size, window, -1)
        lstnet_input += features_origin1
        # features2 = features.reshape(batch_size,window,self.em_size,-1)

        # CNN
        # lstnet_input=features
        c = lstnet_input.view(-1, 1, self.P, self.m)  # batch, 1, window, n_val
        c = F.relu(self.conv1(c))
        c = self.dropout(c)
        c = torch.squeeze(c, 3)  # batch, hidCNN, window-kernel_size+1

        # RNN
        r = c.permute(2, 0, 1).contiguous()
        _, r = self.GRU1(r)
        r = self.dropout(torch.squeeze(r, 0))  # batch, hidRNN

        # skip-rnn
        if (self.skip > 0):
            s = c[:, :, int(-self.pt * self.skip):].contiguous()
            s = s.view(batch_size, self.hidC, self.pt, self.skip)
            s = s.permute(2, 0, 3, 1).contiguous()
            s = s.view(self.pt, batch_size * self.skip, self.hidC)
            _, s = self.GRUskip(s)
            s = s.view(batch_size, self.skip * self.hidS)
            s = self.dropout(s)
            r = torch.cat((r, s), 1)  # batch, skip*hidSkip + hidRNN

        res = self.linear1(r)  # batch, n_val

        # highway
        if (self.hw > 0):
            z = lstnet_input[:, -self.hw:, :]  # batch, hw, n_val
            z = z.permute(0, 2, 1).contiguous().view(-1, self.hw)
            z = self.highway(z)
            z = z.view(-1, self.m)
            res = res + z  # batch, n_val

        res = self.output(res)

        return res
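
The split/stack/matmul block above builds, for every embedding dimension, the outer product between the seven field vectors. A minimal sketch of that trick with assumed sizes:

    import torch

    batch, window, fields, em_size = 2, 4, 7, 8
    feats = torch.randn(batch, window, fields, em_size)
    cols = torch.stack(torch.split(feats, em_size * [1], dim=3), dim=0)  # (em_size, b, w, 7, 1)
    rows = cols.permute(0, 1, 2, 4, 3)                                   # (em_size, b, w, 1, 7)
    outer = torch.matmul(cols, rows)                                     # per-dimension 7x7 outer products
    assert outer.shape == (em_size, batch, window, fields, fields)
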
Example #52
0
    def split(self, split_size, dim=0):
        r"""Splits this tensor into tensor chunks of :attr:`split_size` size.

        See :func:`torch.split`.
        """
        return torch.split(self, split_size, dim)
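
With an integer split_size, the tensor is cut into chunks of that size along dim, and the last chunk may be smaller. A quick illustration:

    import torch

    x = torch.arange(10)
    chunks = torch.split(x, 4)  # sizes 4, 4 and 2
    assert [c.tolist() for c in chunks] == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
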
Example #53
0
    def _train_epoch(self, epoch):
        """
        Train the model for the given epoch
        """
        # TODO: change this to the way we are loading data in the correct form (ask Sylvia)
        train_data_with_time = None
        train_data, train_times = data.get_time_columns(train_data_with_time)
        self.train_rnn_inp = data.get_rnn_input(
            self.train_tokens, self.train_counts, train_times, self.hyperparameters['num_times'], len(self.vocab),
            len(self.train_tokens))

        self.model.train()
        acc_loss = 0
        acc_nll = 0
        acc_kl_theta_loss = 0
        acc_kl_eta_loss = 0
        acc_kl_alpha_loss = 0
        cnt = 0
        indices = torch.randperm(train_data.shape[0])
        indices = torch.split(indices, self.hyperparameters['batch_size'])
        optimizer = self.set_optimizer()
        for idx, ind in enumerate(indices):
            optimizer.zero_grad()
            self.zero_grad()
            data_batch, times_batch = data.get_batch(train_data, ind, self.device,
                                                     train_times)  # we can use pytorch data loader here
            ### The following row is commented out just to make the code run :/
            # times_batch = get_indices(train_times, times_batch)
            sums = data_batch.sum(1).unsqueeze(1)
            times_batch = torch.from_numpy(times_batch)
            if self.hyperparameters['bow_norm']:
                normalized_data_batch = data_batch / sums
            else:
                normalized_data_batch = data_batch
            loss, nll, kl_alpha, kl_eta, kl_theta = self.model.forward(
                data_batch, normalized_data_batch, times_batch, self.train_rnn_inp, train_data.shape[0])
            loss.backward()
            if self.hyperparameters['clip'] > 0:
                torch.nn.utils.clip_grad_norm_(self.parameters(), self.hyperparameters['clip'])
            optimizer.step()
            acc_loss += torch.sum(loss).item()
            acc_nll += torch.sum(nll).item()
            acc_kl_theta_loss += torch.sum(kl_theta).item()
            acc_kl_eta_loss += torch.sum(kl_eta).item()
            acc_kl_alpha_loss += torch.sum(kl_alpha).item()
            cnt += 1
            if idx % self.hyperparameters['log_interval'] == 0 and idx > 0:
                cur_loss = round(acc_loss / cnt, 2)
                cur_nll = round(acc_nll / cnt, 2)
                cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
                cur_kl_eta = round(acc_kl_eta_loss / cnt, 2)
                cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2)
                lr = optimizer.param_groups[0]['lr']
                print(
                    'Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} .. KL_eta: {} .. KL_alpha: {} .. Rec_loss: {} .. NELBO: {}'.format(
                        epoch, idx, len(indices), lr, cur_kl_theta, cur_kl_eta, cur_kl_alpha, cur_nll, cur_loss))
        cur_loss = round(acc_loss / cnt, 2)
        cur_nll = round(acc_nll / cnt, 2)
        cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
        cur_kl_eta = round(acc_kl_eta_loss / cnt, 2)
        cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2)
        lr = optimizer.param_groups[0]['lr']
        print('*' * 100)
        print(
            'Epoch----->{} .. LR: {} .. KL_theta: {} .. KL_eta: {} .. KL_alpha: {} .. Rec_loss: {} .. NELBO: {}'.format(
                epoch, lr, cur_kl_theta, cur_kl_eta, cur_kl_alpha, cur_nll, cur_loss))
        print('*' * 100)
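
The epoch loop above forms mini-batches by shuffling document indices with torch.randperm and chunking them with torch.split. A tiny sketch of that pattern (sizes assumed):

    import torch

    num_docs, batch_size = 10, 4
    batches = torch.split(torch.randperm(num_docs), batch_size)
    # Three index batches of sizes 4, 4 and 2 covering every document exactly once.
    assert [len(b) for b in batches] == [4, 4, 2]
    assert sorted(torch.cat(batches).tolist()) == list(range(num_docs))
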
Example #54
0
def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    # bar = Bar('Training', max=args.train_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    model.train()
    for batch_idx in range(args.train_iteration):
        try:
            inputs_x, targets_x = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, targets_x = next(labeled_train_iter)

        try:
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)

        # measure data loading time
        data_time.update(time.time() - end)

        batch_size = inputs_x.size(0)

        # Transform label to one-hot
        targets_x = torch.zeros(batch_size, args.num_classes).scatter_(1, targets_x.view(-1,1).long(), 1)

        if use_cuda:
            inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
            inputs_u = inputs_u.cuda()
            inputs_u2 = inputs_u2.cuda()


        with torch.no_grad():
            # compute guessed labels of the unlabeled samples
            outputs_u = model(inputs_u)
            outputs_u2 = model(inputs_u2)
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
            pt = p**(1/args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        # mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

        l = np.random.beta(args.alpha, args.alpha)

        l = max(l, 1-l)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches to get a correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(mixed_input[0])]
        for input in mixed_input[1:]:
            logits.append(model(input))

        # put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/args.train_iteration)

        loss = Lx + w * Lu

        # record loss
        losses.update(loss.item(), inputs_x.size(0))
        losses_x.update(Lx.item(), inputs_x.size(0))
        losses_u.update(Lu.item(), inputs_x.size(0))
        ws.update(w, inputs_x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        # bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format(
        #             batch=batch_idx + 1,
        #             size=args.train_iteration,
        #             data=data_time.avg,
        #             bt=batch_time.avg,
        #             total=bar.elapsed_td,
        #             eta=bar.eta_td,
        #             loss=losses.avg,
        #             loss_x=losses_x.avg,
        #             loss_u=losses_u.avg,
        #             w=ws.avg,
        #             )
        # bar.next()
    # bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg,)
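
The guessed labels above are sharpened by raising the averaged prediction to the power 1/T and renormalizing. A standalone sketch (T = 0.5 is an assumed value, not necessarily args.T):

    import torch

    T = 0.5
    p = torch.softmax(torch.randn(4, 10), dim=1)    # averaged predictions over two augmentations
    pt = p ** (1 / T)
    targets_u = pt / pt.sum(dim=1, keepdim=True)    # sharpened soft labels, rows still sum to 1
    assert torch.allclose(targets_u.sum(dim=1), torch.ones(4))
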
Example #55
0
    def split(self, split_size, dim=0):
        """Splits this tensor into a list of tensors.

        See :func:`torch.split`.
        """
        return torch.split(self, split_size, dim)
Example #56
0
    def forward(self, x):
        x = self.filter(x)
        out = torch.split(x, self.out_channels, 1)
        return torch.max(out[0], out[1])
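
This forward implements a maxout unit: the channel dimension is split into two halves and the elementwise maximum is taken. A minimal sketch assuming out_channels = 4:

    import torch

    x = torch.randn(2, 8, 5)         # (batch, 2 * out_channels, length)
    a, b = torch.split(x, 4, dim=1)  # out_channels = 4 in this sketch
    out = torch.max(a, b)
    assert out.shape == (2, 4, 5)
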
Example #57
0
    def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]:
        """Compute the matrix fwt for the given input signal.

        Matrix fwt are used to avoid padding.

        Args:
            input_signal (torch.Tensor): Batched input data [batch_size, time],
                should be of even length. 1d inputs are interpreted as [time].

        Returns:
            List[torch.Tensor]: A list with the coefficients for each scale.

        Raises:
            ValueError: If the decomposition level is not a positive integer
                or if the input signal does not have the expected shape.
        """
        if input_signal.dim() == 1:
            # assume time series
            input_signal = input_signal.unsqueeze(0)
        elif input_signal.dim() != 2:
            raise ValueError(
                f"Invalid input tensor shape {input_signal.size()}. "
                "The input signal is expected to be of the form "
                "[batch_size, length].")

        if input_signal.shape[-1] % 2 != 0:
            # odd length input
            # print('input length odd, padding a zero on the right')
            input_signal = torch.nn.functional.pad(input_signal, [0, 1])

        _, length = input_signal.shape

        re_build = False
        if self.input_length != length:
            self.input_length = length
            re_build = True

        if self.level is None:
            self.level = int(np.log2(length))
            re_build = True
        elif self.level <= 0:
            raise ValueError("level must be a positive integer.")

        if not self.fwt_matrix_list or re_build:
            self._construct_analysis_matrices(device=input_signal.device,
                                              dtype=input_signal.dtype)

        lo = input_signal.T
        split_list = []
        for scale, fwt_matrix in enumerate(self.fwt_matrix_list):
            if self.pad_list[scale]:
                # fix odd coefficients lengths for the conv matrix to work.
                lo = torch.nn.functional.pad(lo.T.unsqueeze(1),
                                             [0, 1]).squeeze(1).T
            coefficients = torch.sparse.mm(fwt_matrix, lo)
            lo, hi = torch.split(coefficients,
                                 coefficients.shape[0] // 2,
                                 dim=0)
            split_list.append(hi)
        split_list.append(lo)
        return split_list[::-1]
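
At each scale the sparse analysis matrix produces the approximation and detail coefficients stacked along dim 0, and torch.split halves them. A small sketch of that step with assumed shapes:

    import torch

    coefficients = torch.randn(8, 3)  # [approximation; detail] stacked along dim 0
    lo, hi = torch.split(coefficients, coefficients.shape[0] // 2, dim=0)
    assert lo.shape == hi.shape == (4, 3)
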
Example #58
0
    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out
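
This forward applies InstanceNorm to the first half of the channels and BatchNorm to the second half (an IBN-style block). A self-contained sketch with assumed channel counts:

    import torch
    import torch.nn as nn

    half = 4
    x = torch.randn(2, 2 * half, 8, 8)
    instance_norm = nn.InstanceNorm2d(half, affine=True)
    batch_norm = nn.BatchNorm2d(half)
    first, second = torch.split(x, half, dim=1)
    out = torch.cat((instance_norm(first), batch_norm(second)), dim=1)
    assert out.shape == x.shape
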
Example #59
0
    times = []
    epochs = 500
    for epoch in range(epochs):
        print("Epoch", epoch)
        epoch_start = time.time()
        model.train()
        train_losses = []
        train_accuracies = []
        train_accuracies2 = []
        _start = time.time()
        for batch_idx, sequence_batch in enumerate(train_loader):
            sequence_batch = Variable(sequence_batch, requires_grad=False)
            if use_cuda:
                sequence_batch = sequence_batch.cuda()
            sequence = sequence_batch.squeeze(dim=0)
            subsequences = torch.split(sequence, 100)
            for seq in subsequences:
                batch_size = 1
                seq_len = seq.size()[0]
                seq = seq.view(seq_len, -1).contiguous()
                seq = seq.unsqueeze(dim=1)
                targets = seq[1:]

                optimizer.zero_grad()
                predictions = model(seq)
                losses = [F.binary_cross_entropy(input=pred, target=targets[step]) for step, pred in enumerate(predictions[:-1])]
                loss = sum(losses)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), .25)
                optimizer.step()
                train_losses.append(np.mean([l.data.cpu().numpy() for l in losses]))
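
The loop above cuts each long sequence into fixed-length subsequences with torch.split before running truncated backpropagation on them. A standalone sketch (the lengths are assumed):

    import torch

    sequence = torch.randn(250, 88)            # e.g. 250 time steps of an 88-dimensional sequence
    subsequences = torch.split(sequence, 100)  # chunks of 100 steps; the last one is shorter
    assert [s.size(0) for s in subsequences] == [100, 100, 50]
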
Example #60
0
    def cell(self, x, h_prev, x_mask, h_mask):
        """
        Forward pass inside the cell of our model.
        :param x: input
        :param h_prev: hidden of previous step
        :param x_mask: mask for input dropout
        :param h_mask: mask for hidden dropout
        :return: the hidden output of the current step
        """

        s0 = self._compute_init_state(x, h_prev, x_mask, h_mask)

        # states contains the nodes computed as described in the paper,
        # that means, the sum of output of the operations of the incoming
        # edges. If genotype defined, there is only one incoming edge
        # as a constraint described in the paper.
        states = [s0]

        # IMPORTANT: genotype is None when doing arch search
        # "i" is the index of the next intermediate node,
        # "name" is the label of the activation function,
        # "pred" is the index of the previous node, so the edge will be pred -->name--->i
        for i, (name, pred) in enumerate(self.genotype.recurrent):

            # taking the previous state using its index
            s_prev = states[pred]

            if self.use_edge_matrices:
                # sum of first (i-1) natural numbers plus the previous node index
                edge_weights_id = sum(num for num in range(1, i)) + pred

            else:
                edge_weights_id = i

            # applying dropout masks if training.
            # computing the matrix mul between the previous output
            # and the weights of the current node "i" (FC layer)
            if self.training:
                ch = (s_prev * h_mask).mm(self._Ws[edge_weights_id])
            else:
                ch = s_prev.mm(self._Ws[edge_weights_id])
            c, h = torch.split(ch, self.nhid, dim=-1)
            c = c.sigmoid()

            # getting the chosen activation function
            fn = self._get_activation(name)

            # activation function on hidden
            h = fn(h)

            s = s_prev + c * (h - s_prev)

            states += [s]

        # computing the output as the mean of the output of
        # the INTERMEDIATE nodes, where their index are
        # defined by the "concat" list in the genotype
        output = torch.mean(
            torch.stack([states[i] for i in self.genotype.concat], -1), -1)

        if self.handle_hidden_mode == 'ACTIVATION':
            # avoid the explosion of the hidden state by forcing the value in [-1, 1]
            output = torch.tanh(output)
        return output
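
Inside the cell, a single linear projection yields both the gate and the candidate hidden state, and torch.split separates them into two nhid-sized halves. A minimal sketch of that update (sizes assumed):

    import torch

    nhid = 16
    ch = torch.randn(4, 2 * nhid)          # joint projection for one node
    c, h = torch.split(ch, nhid, dim=-1)   # gate logits and candidate hidden state
    c = torch.sigmoid(c)
    s_prev = torch.randn(4, nhid)
    s = s_prev + c * (torch.tanh(h) - s_prev)
    assert s.shape == (4, nhid)
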