Example #1
def scale_tensor(tensor, scale):
    """
    Safely scale a tensor without increasing its ``.shape``.
    This avoids NANs by assuming ``inf * 0 = 0 * inf = 0``.
    """
    if isinstance(tensor, numbers.Number):
        if isinstance(scale, numbers.Number):
            return tensor * scale
        elif tensor == 0:
            return torch.zeros_like(scale)
        elif tensor == 1:
            return scale
        else:
            return tensor * scale
    if isinstance(scale, numbers.Number):
        if scale == 0:
            return torch.zeros_like(tensor)
        elif scale == 1:
            return tensor
        else:
            return tensor * scale
    result = tensor * scale
    result[(scale == 0).expand_as(result)] = 0  # avoid NANs
    if result.shape != tensor.shape:
        raise ValueError("Broadcasting error: scale is incompatible with tensor: "
                         "{} vs {}".format(scale.shape, tensor.shape))
    return result
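
A minimal usage sketch (illustrative values; assumes the ``scale_tensor`` above is in scope along with ``import numbers`` and ``import torch``), showing that a zero scale wipes out an infinite entry instead of producing NaN:

import torch

tensor = torch.tensor([float('inf'), 2.0, -3.0])
scale = torch.tensor([0.0, 0.5, 1.0])
print(tensor * scale)               # nan in position 0 (inf * 0)
print(scale_tensor(tensor, scale))  # 0.0 in position 0 instead of nan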
Example #2
    def __init__(self, block, layers, c_out=1000):
        self.inplanes = 64
        super(XResNet, self).__init__()
        self.conv1 = conv2d(3, 32, 2)
        self.conv2 = conv2d(32, 32, 1)
        self.conv3 = conv2d(32, 64, 1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, c_out)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        for m in self.modules():
            if isinstance(m, BasicBlock): m.bn2.weight = nn.Parameter(torch.zeros_like(m.bn2.weight))
            if isinstance(m, Bottleneck): m.bn3.weight = nn.Parameter(torch.zeros_like(m.bn3.weight))
            if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)
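
The last loop above zero-initializes the final BatchNorm weight in each residual block so the block starts out as an identity mapping. A small, self-contained sketch of that effect on a stand-alone nn.BatchNorm2d (illustrative only, not part of XResNet):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)
nn.init.constant_(bn.weight, 0)   # gamma = 0
nn.init.constant_(bn.bias, 0)     # beta = 0
x = torch.randn(2, 64, 8, 8)
print(bn(x).abs().max())          # 0: the branch contributes nothing at initialization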
Example #3
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                if state['step'] > 1:
                    prev_bias_correction1 = 1 - beta1 ** (state['step'] - 1)
                    prev_bias_correction2 = 1 - beta2 ** (state['step'] - 1)
                    # Hypergradient for Adam:
                    h = torch.dot(grad.view(-1), torch.div(exp_avg, exp_avg_sq.sqrt().add_(group['eps'])).view(-1)) * math.sqrt(prev_bias_correction2) / prev_bias_correction1
                    # Hypergradient descent of the learning rate:
                    tmp = group['hypergrad_lr'] * h
                    group['lr'] += tmp.double().cpu()

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)

        return loss
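
The ``h = torch.dot(...)`` line above is the hypergradient: roughly the inner product of the current gradient with the previous Adam update direction, used to adapt the learning rate itself. A toy, self-contained sketch of that learning-rate update (all names and values illustrative):

import torch

grad = torch.tensor([0.20, -0.10])
prev_direction = torch.tensor([0.15, -0.05])   # stands in for exp_avg / (sqrt(exp_avg_sq) + eps)
lr, hypergrad_lr = 1e-3, 1e-7
h = torch.dot(grad, prev_direction)            # positive when the current gradient agrees with recent gradients
lr = lr + hypergrad_lr * h.item()              # raise lr when successive gradients agree, lower it when they oppose
print(lr)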
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['b1'], group['b2']

                state['step'] += 1

                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                denom = exp_avg_sq.sqrt().add_(group['e'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                schedule_fct = SCHEDULES[group['schedule']]
                lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
                step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)

                # Add weight decay at the end (fixed version)
                if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0:
                    p.data.add_(-lr_scheduled * group['l2'], p.data)

        return loss
def manual_forget_mult(x, f, h=None, batch_first=True, backward=False):
    if batch_first: x,f = x.transpose(0,1),f.transpose(0,1)
    out = torch.zeros_like(x)
    prev = h if h is not None else torch.zeros_like(out[0])
    idx_range = range(x.shape[0]-1,-1,-1) if backward else range(x.shape[0])
    for i in idx_range:
        out[i] = f[i] * x[i] + (1-f[i]) * prev
        prev = out[i]
    if batch_first: out = out.transpose(0,1)
    return out
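
A minimal usage sketch for ``manual_forget_mult`` (assumes ``torch`` is imported and the function above is in scope; shapes are illustrative):

import torch

x = torch.randn(2, 5, 3)   # batch, seq_len, features (batch_first=True)
f = torch.rand(2, 5, 3)    # forget gates in [0, 1]
out = manual_forget_mult(x, f)
print(out.shape)                                     # torch.Size([2, 5, 3])
print(torch.allclose(out[:, 0], f[:, 0] * x[:, 0]))  # True: the first step has no history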
Example #6
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.zeros_like(p.data)
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(p.data)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p.data)

                square_avg = state['square_avg']
                alpha = group['alpha']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)

                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.mul_(alpha).add_(1 - alpha, grad)
                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt().add_(group['eps'])
                else:
                    avg = square_avg.sqrt().add_(group['eps'])

                if group['momentum'] > 0:
                    buf = state['momentum_buffer']
                    buf.mul_(group['momentum']).addcdiv_(grad, avg)
                    p.data.add_(-group['lr'], buf)
                else:
                    p.data.addcdiv_(-group['lr'], grad, avg)

        return loss
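
The ``centered`` branch above divides by an estimate of the gradient's standard deviation, sqrt(E[g^2] - E[g]^2), rather than its plain RMS. A tiny, self-contained numeric sketch (values illustrative):

import torch

g_sq_avg = torch.tensor([0.30])   # running average of g^2
g_avg = torch.tensor([0.40])      # running average of g
eps = 1e-8
centered = (g_sq_avg - g_avg * g_avg).sqrt().add_(eps)
plain = g_sq_avg.sqrt().add_(eps)
print(centered, plain)            # ~0.374 vs ~0.548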
Example #7
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        loss = None
        if closure is not None:
            loss = closure()

        group = self.param_groups[0]
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        grad = self._gather_flat_grad_with_weight_decay(weight_decay)

        # NOTE: SGDHD has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        # State initialization
        if len(state) == 0:
            state['grad_prev'] = torch.zeros_like(grad)

        grad_prev = state['grad_prev']
        # Hypergradient for SGD
        h = torch.dot(grad, grad_prev)
        # Hypergradient descent of the learning rate:
        group['lr'] += group['hypergrad_lr'] * h

        if momentum != 0:
            if 'momentum_buffer' not in state:
                buf = state['momentum_buffer'] = torch.zeros_like(grad)
                buf.mul_(momentum).add_(grad)
            else:
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(1 - dampening, grad)
            if nesterov:
                grad.add_(momentum, buf)
            else:
                grad = buf

        state['grad_prev'] = grad

        self._add_grad(-group['lr'], grad)

        return loss
Example #8
def get_analytical_jacobian(input, output):
    diff_input_list = list(iter_tensors(input, True))
    jacobian = make_jacobian(input, output.numel())
    jacobian_reentrant = make_jacobian(input, output.numel())
    grad_output = torch.zeros_like(output)
    flat_grad_output = grad_output.view(-1)
    reentrant = True
    correct_grad_sizes = True

    for i in range(flat_grad_output.numel()):
        flat_grad_output.zero_()
        flat_grad_output[i] = 1
        for jacobian_c in (jacobian, jacobian_reentrant):
            grads_input = torch.autograd.grad(output, diff_input_list, grad_output,
                                              retain_graph=True, allow_unused=True)
            for jacobian_x, d_x, x in zip(jacobian_c, grads_input, diff_input_list):
                if d_x is not None and d_x.size() != x.size():
                    correct_grad_sizes = False
                elif jacobian_x.numel() != 0:
                    if d_x is None:
                        jacobian_x[:, i].zero_()
                    else:
                        d_x_dense = d_x.to_dense() if d_x.is_sparse else d_x
                        assert jacobian_x[:, i].numel() == d_x_dense.numel()
                        jacobian_x[:, i] = d_x_dense.contiguous().view(-1)

    for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant):
        if jacobian_x.numel() != 0 and (jacobian_x - jacobian_reentrant_x).abs().max() != 0:
            reentrant = False

    return jacobian, reentrant, correct_grad_sizes
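
The loop above recovers the analytical Jacobian one output element at a time by backpropagating one-hot ``grad_output`` vectors. A self-contained sketch of the same trick on a toy two-output function (not the test-suite helper itself):

import torch

x = torch.randn(3, requires_grad=True)
y = torch.stack([x[0] * x[1], x[1] + x[2]])    # toy function from R^3 to R^2

jacobian = torch.zeros(y.numel(), x.numel())
grad_output = torch.zeros_like(y)
for i in range(y.numel()):
    grad_output.zero_()
    grad_output[i] = 1                         # select output element i
    (dx,) = torch.autograd.grad(y, x, grad_output, retain_graph=True)
    jacobian[i] = dx                           # row i = d y[i] / d x
print(jacobian)                                # [[x1, x0, 0], [0, 1, 1]]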
  def testDutyCycleUpdate(self):
    """
    Start with equal duty cycle, boost factor=0, k=4, batch size=2
    """
    x = self.x2

    expected = torch.zeros_like(x)
    expected[0, 0, 1, 0] = 1.1
    expected[0, 0, 1, 1] = 1.2
    expected[0, 1, 0, 1] = 1.2
    expected[0, 2, 1, 0] = 1.3
    expected[1, 0, 0, 0] = 1.4
    expected[1, 1, 0, 0] = 1.5
    expected[1, 1, 0, 1] = 1.6
    expected[1, 2, 1, 1] = 1.7

    dutyCycle = torch.zeros((1, 3, 1, 1))
    dutyCycle[:] = 1.0 / 3.0
    updateDutyCycleCNN(expected, dutyCycle, 2, 2)
    newDuty = torch.tensor([1.5000, 1.5000, 1.0000]) / 4.0
    diff = (dutyCycle.reshape(-1) - newDuty).abs().sum()
    self.assertLessEqual(diff, 0.001)

    dutyCycle[:] = 1.0 / 3.0
    updateDutyCycleCNN(expected, dutyCycle, 4, 4)
    newDuty = torch.tensor([0.3541667, 0.3541667, 0.2916667])
    diff = (dutyCycle.reshape(-1) - newDuty).abs().sum()
    self.assertLessEqual(diff, 0.001)
  def testFour(self):
    """
    Equal duty cycle, boost factor=0, k=3, batch size=2
    """
    x = self.x2

    ctx = TestContext()

    result = KWinnersCNN.forward(ctx, x, self.dutyCycle, k=3, boostStrength=0.0)

    expected = torch.zeros_like(x)
    expected[0, 0, 1, 1] = 1.2
    expected[0, 1, 0, 1] = 1.2
    expected[0, 2, 1, 0] = 1.3
    expected[1, 1, 0, 0] = 1.5
    expected[1, 1, 0, 1] = 1.6
    expected[1, 2, 1, 1] = 1.7

    self.assertEqual(result.shape, expected.shape)

    numCorrect = (result == expected).sum()
    self.assertEqual(numCorrect, result.reshape(-1).size()[0])

    indices = ctx.saved_tensors[0]
    expectedIndices = torch.tensor([[3, 10, 5], [4, 5, 11]])
    numCorrect = (indices == expectedIndices).sum()
    self.assertEqual(numCorrect, 6)

    # Test that gradient values are in the right places, that their sum is
    # equal, and that they have exactly the right number of nonzeros
    out_grad, _, _, _ = KWinnersCNN.backward(ctx, self.gradient2)
    out_grad = out_grad.reshape(2, -1)
    in_grad = self.gradient2.reshape(2, -1)
    self.assertEqual((out_grad == in_grad).sum(), 6)
    self.assertEqual(len(out_grad.nonzero()), 6)
Example #11
 def forward(self, x):
     x = torch.tanh(self.fc1(x))
     x = torch.tanh(self.fc2(x))
     mu = self.fc3(x)
     logstd = torch.zeros_like(mu)
     std = torch.exp(logstd)
     return mu, std
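
Such a ``(mu, std)`` pair is typically consumed by a diagonal Gaussian policy. A minimal sketch of sampling an action and its log-probability (names and values illustrative, not from the original model):

import torch
from torch.distributions import Normal

mu = torch.tensor([0.1, -0.3])
std = torch.ones_like(mu)                 # logstd = 0 above, so std = exp(0) = 1
dist = Normal(mu, std)
action = dist.sample()
log_prob = dist.log_prob(action).sum(-1)  # joint log-density of the independent dimensions
print(action, log_prob)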
  def testOne(self):
    """
    Equal duty cycle, boost factor 0, k=4, batch size 1
    """
    x = self.x

    ctx = TestContext()

    result = KWinnersCNN.forward(ctx, x, self.dutyCycle, k=4, boostStrength=0.0)

    expected = torch.zeros_like(x)
    expected[0, 0, 1, 0] = 1.1
    expected[0, 0, 1, 1] = 1.2
    expected[0, 1, 0, 1] = 1.2
    expected[0, 2, 1, 0] = 1.3

    self.assertEqual(result.shape, expected.shape)

    numCorrect = (result == expected).sum()
    self.assertEqual(numCorrect, result.reshape(-1).size()[0])

    indices = ctx.saved_tensors[0].reshape(-1)
    expectedIndices = torch.tensor([2, 3, 10, 5])
    numCorrect = (indices == expectedIndices).sum()
    self.assertEqual(numCorrect, 4)

    # Test that gradient values are in the right places, that their sum is
    # equal, and that they have exactly the right number of nonzeros
    grad_x, _, _, _ = KWinnersCNN.backward(ctx, self.gradient)
    grad_x = grad_x.reshape(-1)
    self.assertEqual(
      (grad_x[indices] == self.gradient.reshape(-1)[indices]).sum(), 4)
    self.assertAlmostEqual(
      grad_x.sum(), self.gradient.reshape(-1)[indices].sum(), places=4)
    self.assertEqual(len(grad_x.nonzero()), 4)
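
The tests above exercise a k-winners nonlinearity (``KWinnersCNN``, defined elsewhere): keep the k largest activations and zero out the rest. A self-contained sketch of that idea with ``torch.topk`` (a sketch of the concept, not the tested implementation):

import torch

x = torch.tensor([[0.3, 1.1, 0.2, 1.2, 0.7]])
k = 2
_, idx = x.topk(k, dim=1)
mask = torch.zeros_like(x).scatter_(1, idx, 1.0)
print(x * mask)   # only the two largest activations (1.1 and 1.2) survive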
Example #13
 def predict(self, x, attn_type = "hard"):
     #predict with greedy decoding
     emb = self.embedding(x)
     h = Variable(torch.zeros(1, x.size(0), self.hidden_dim))
     c = Variable(torch.zeros(1, x.size(0), self.hidden_dim))
     enc_h, _ = self.encoder(emb, (h, c))
     y = [Variable(torch.zeros(x.size(0)).long())]
     self.attn = []        
     for t in range(x.size(1)):
         emb_t = self.embedding(y[-1])
         dec_h, (h, c) = self.decoder(emb_t.unsqueeze(1), (h, c))
         scores = torch.bmm(enc_h, dec_h.transpose(1,2)).squeeze(2)
         attn_dist = F.softmax(scores, dim = 1)
         self.attn.append(attn_dist.data)
         if attn_type == "hard":
             _, argmax = attn_dist.max(1)
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, argmax.data.unsqueeze(1), 1))
             context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)                    
         else:                
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1)
         pred = self.vocab_layer(torch.cat([dec_h.squeeze(1), context], 1))
         _, next_token = pred.max(1)
         y.append(next_token)
     self.attn = torch.stack(self.attn, 0).transpose(0, 1)
     return torch.stack(y, 0).transpose(0, 1)
Example #14
def bisect_demo():
    """ Bisect the LB/UB on specified columns.
        The key is to use scatter_() to convert indices into one-hot encodings.
    """
    t1t2 = torch.stack((torch.randn(5, 4), torch.randn(5, 4)), dim=-1)
    lb, _ = torch.min(t1t2, dim=-1)
    ub, _ = torch.max(t1t2, dim=-1)
    print('LB:', lb)
    print('UB:', ub)

    # random idxs for testing
    idxs = torch.randn_like(lb)
    _, idxs = idxs.max(dim=-1)  # <Batch>
    print('Split idxs:', idxs)

    idxs = idxs.unsqueeze(dim=-1)  # Batch x 1
    idxs = torch.zeros_like(lb).byte().scatter_(-1, idxs, 1)  # convert into one-hot encoding
    print('Reorg idxs:', idxs)

    mid = (lb + ub) / 2.0
    lefts_lb = lb
    lefts_ub = torch.where(idxs, mid, ub)  # use the one-hot encoding to call torch.where()
    rights_lb = torch.where(idxs, mid, lb)  # definitely faster than element-wise reassignment
    rights_ub = ub

    print('LEFT LB:', lefts_lb)
    print('LEFT UB:', lefts_ub)
    print('RIGHT LB:', rights_lb)
    print('RIGHT UB:', rights_ub)

    newlb = torch.cat((lefts_lb, rights_lb), dim=0)
    newub = torch.cat((lefts_ub, rights_ub), dim=0)
    return newlb, newub
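
The docstring above points out that ``scatter_()`` is the key to turning per-row indices into one-hot masks. A minimal, self-contained sketch of that conversion:

import torch

idxs = torch.tensor([2, 0, 1])                                 # one selected column per row
one_hot = torch.zeros(3, 4).scatter_(-1, idxs.unsqueeze(-1), 1.0)
print(one_hot)   # row 0 -> column 2, row 1 -> column 0, row 2 -> column 1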
Example #15
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate Expected Improvement on the candidate set X.

        Args:
            X: A `b1 x ... bk x 1 x d`-dim batched tensor of `d`-dim design points.
                Expected Improvement is computed for each point individually,
                i.e., what is considered are the marginal posteriors, not the
                joint.

        Returns:
            A `b1 x ... bk`-dim tensor of Expected Improvement values at the
            given design points `X`.
        """
        self.best_f = self.best_f.to(X)
        posterior = self.model.posterior(X)
        self._validate_single_output_posterior(posterior)
        mean = posterior.mean
        # deal with batch evaluation and broadcasting
        view_shape = mean.shape[:-2] if mean.dim() >= X.dim() else X.shape[:-2]
        mean = mean.view(view_shape)
        sigma = posterior.variance.clamp_min(1e-9).sqrt().view(view_shape)
        u = (mean - self.best_f.expand_as(mean)) / sigma
        if not self.maximize:
            u = -u
        normal = Normal(torch.zeros_like(u), torch.ones_like(u))
        ucdf = normal.cdf(u)
        updf = torch.exp(normal.log_prob(u))
        ei = sigma * (updf + u * ucdf)
        return ei
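
The last few lines above implement the closed-form Expected Improvement, EI = sigma * (phi(u) + u * Phi(u)) with u = (mu - best_f) / sigma, where phi and Phi are the standard normal pdf and cdf. A self-contained numeric sketch (values illustrative):

import torch
from torch.distributions import Normal

mu, sigma, best_f = torch.tensor(1.2), torch.tensor(0.5), torch.tensor(1.0)
u = (mu - best_f) / sigma
normal = Normal(torch.zeros_like(u), torch.ones_like(u))
ei = sigma * (torch.exp(normal.log_prob(u)) + u * normal.cdf(u))
print(ei)   # ~0.3152 for these values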
    def get_loss(self, image_a_pred, image_b_pred, mask_a, mask_b):
        loss = 0

        # get the nonzero indices
        mask_a_indices_flat = torch.nonzero(mask_a)
        mask_b_indices_flat = torch.nonzero(mask_b)
        if len(mask_a_indices_flat) == 0:
            return Variable(torch.cuda.LongTensor([0]), requires_grad=True)
        if len(mask_b_indices_flat) == 0:
            return Variable(torch.cuda.LongTensor([0]), requires_grad=True)

        # take num_samples random pixel samples of the object, using the mask
        num_samples = 10000

        rand_numbers_a = (torch.rand(num_samples)*len(mask_a_indices_flat)).cuda()
        rand_indices_a = Variable(torch.floor(rand_numbers_a).type(torch.cuda.LongTensor), requires_grad=False)
        randomized_mask_a_indices_flat = torch.index_select(mask_a_indices_flat, 0, rand_indices_a).squeeze(1)

        rand_numbers_b = (torch.rand(num_samples)*len(mask_b_indices_flat)).cuda()
        rand_indices_b = Variable(torch.floor(rand_numbers_b).type(torch.cuda.LongTensor), requires_grad=False)
        randomized_mask_b_indices_flat = torch.index_select(mask_b_indices_flat, 0, rand_indices_b).squeeze(1)

        # index into the image and get descriptors
        M_margin = 0.5 # margin parameter
        random_img_a_object_descriptors = torch.index_select(image_a_pred, 1, randomized_mask_a_indices_flat)
        random_img_b_object_descriptors = torch.index_select(image_b_pred, 1, randomized_mask_b_indices_flat)
        pixel_wise_loss = (random_img_a_object_descriptors - random_img_b_object_descriptors).pow(2).sum(dim=2)
        pixel_wise_loss = torch.add(pixel_wise_loss, -2*M_margin)
        zeros_vec = torch.zeros_like(pixel_wise_loss)
        loss += torch.max(zeros_vec, pixel_wise_loss).sum()

        return loss
Example #17
 def predict(self, x_de, x_en):
     bs = x_de.size(0)
     emb_de = self.embedding_de(x_de) # bs,n_de,word_dim
     emb_en = self.embedding_en(x_en) # bs,n_en,word_dim
     h = Variable(torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda())
     c = Variable(torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda())
     enc_h, _ = self.encoder(emb_de, (h, c))
     dec_h, _ = self.decoder(emb_en, (h, c))
     # all the same. enc_h is bs,n_de,hiddensz*n_directions. h and c are both n_layers*n_directions,bs,hiddensz
     if self.directions == 2:
         enc_h = self.dim_reduce(enc_h) # bs,n_de,hiddensz
     scores = torch.bmm(enc_h, dec_h.transpose(1,2))
     # (bs,n_de,hiddensz) * (bs,hiddensz,n_en) = (bs,n_de,n_en)
     y = [Variable(torch.cuda.LongTensor([sos_token]*bs))] # bs
     self.attn = []
     for t in range(x_en.size(1)-1): # iterate over english words, with teacher forcing
         attn_dist = F.softmax(scores[:,:,t],dim=1) # bs,n_de
         self.attn.append(attn_dist.data)
         if self.attn_type == "hard":
             _, argmax = attn_dist.max(1) # bs. for each batch, select most likely german word to pay attention to
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, argmax.data.unsqueeze(1), 1).cuda())
             context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)
         else:
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1)
         # the difference btwn hard and soft is just whether we use a one_hot or a distribution
         # context is bs,hiddensz
         pred = self.vocab_layer(torch.cat([dec_h[:,t,:], context], 1)) # bs,len(EN.vocab)
         _, next_token = pred.max(1) # bs
         y.append(next_token)
     self.attn = torch.stack(self.attn, 0).transpose(0, 1) # bs,n_en,n_de (for visualization!)
     y = torch.stack(y,0).transpose(0,1) # bs,n_en
     return y,self.attn
Example #18
 def predict2(self, x_de, beamsz, gen_len):
     emb_de = self.embedding_de(x_de) # "batch size",n_de,word_dim, but "batch size" is 1 in this case!
     h0 = Variable(torch.zeros(self.n_layers*self.directions, 1, self.hidden_dim).cuda())
     c0 = Variable(torch.zeros(self.n_layers*self.directions, 1, self.hidden_dim).cuda())
     enc_h, _ = self.encoder(emb_de, (h0, c0))
     # since enc batch size=1, enc_h is 1,n_de,hiddensz*n_directions
     if self.directions == 2:
         enc_h = self.dim_reduce(enc_h) # 1,n_de,hiddensz
     masterheap = CandList(self.n_layers,self.hidden_dim,enc_h.size(1),beamsz)
     # in the following loop, beamsz is length 1 for first iteration, length true beamsz (100) afterward
     for i in range(gen_len):
         prev = masterheap.get_prev() # beamsz
         emb_t = self.embedding_en(prev) # embed the last thing we generated. beamsz,word_dim
         enc_h_expand = enc_h.expand(prev.size(0),-1,-1) # beamsz,n_de,hiddensz
         
         h, c = masterheap.get_hiddens() # (n_layers,beamsz,hiddensz),(n_layers,beamsz,hiddensz)
         dec_h, (h, c) = self.decoder(emb_t.unsqueeze(1), (h, c)) # dec_h is beamsz,1,hiddensz (batch_first=True)
         scores = torch.bmm(enc_h_expand, dec_h.transpose(1,2)).squeeze(2)
         # (beamsz,n_de,hiddensz) * (beamsz,hiddensz,1) = (beamsz,n_de,1). squeeze to beamsz,n_de
         attn_dist = F.softmax(scores,dim=1)
         if self.attn_type == "hard":
             _, argmax = attn_dist.max(1) # beamsz for each batch, select most likely german word to pay attention to
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, argmax.data.unsqueeze(1), 1).cuda())
             context = torch.bmm(one_hot.unsqueeze(1), enc_h_expand).squeeze(1)
         else:
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h_expand).squeeze(1)
         # the difference btwn hard and soft is just whether we use a one_hot or a distribution
         # context is beamsz,hiddensz*n_directions
         pred = self.vocab_layer(torch.cat([dec_h.squeeze(1), context], 1)) # beamsz,len(EN.vocab)
         # TODO: set the columns corresponding to <pad>,<unk>,</s>,etc to 0
         masterheap.update_beam(pred)
         masterheap.update_hiddens(h,c)
         masterheap.update_attentions(attn_dist)
         masterheap.firstloop = False
     return masterheap.probs,masterheap.wordlist,masterheap.attentions
Example #19
    def forward(self, y_pred, y_true, eps=1e-6):
        raise NotImplementedError  # the reference computation below is unreachable

        torch.nn.modules.loss._assert_no_grad(y_true)

        assert y_pred.shape[1] == 2

        same_left = torch.stack([y_true[:, 0], y_pred[:, 0]], dim=1)
        same_left, _ = torch.max(same_left, dim=1)

        same_right = torch.stack([y_true[:, 1], y_pred[:, 1]], dim=1)
        same_right, _ = torch.min(same_right, dim=1)

        same_len = same_right - same_left + 1   # (batch_size,)
        same_len = torch.stack([same_len, torch.zeros_like(same_len)], dim=1)
        same_len, _ = torch.max(same_len, dim=1)

        same_len = same_len.type(torch.float)

        pred_len = (y_pred[:, 1] - y_pred[:, 0] + 1).type(torch.float)
        true_len = (y_true[:, 1] - y_true[:, 0] + 1).type(torch.float)

        pre = same_len / (pred_len + eps)
        rec = same_len / (true_len + eps)

        f1 = 2 * pre * rec / (pre + rec + eps)

        return -torch.mean(f1)
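
A self-contained numeric sketch of the span-overlap F1 that the body above computes, for inclusive [start, end] spans (values illustrative):

import torch

y_true = torch.tensor([[2, 6]])
y_pred = torch.tensor([[4, 8]])
same_left = torch.max(y_true[:, 0], y_pred[:, 0])              # 4
same_right = torch.min(y_true[:, 1], y_pred[:, 1])             # 6
same_len = (same_right - same_left + 1).clamp(min=0).float()   # 3 overlapping tokens
pre = same_len / (y_pred[:, 1] - y_pred[:, 0] + 1).float()     # 3 / 5
rec = same_len / (y_true[:, 1] - y_true[:, 0] + 1).float()     # 3 / 5
f1 = 2 * pre * rec / (pre + rec)
print(f1)   # 0.6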
    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
        sequence_output = all_encoder_layers[-1]
        pooled_output = self.pooler(sequence_output)
        return all_encoder_layers, pooled_output
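
A self-contained sketch of the additive attention-mask construction described in the comments above (values illustrative):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])          # 1 = real token, 0 = padding
extended = attention_mask[:, None, None, :].float()       # [batch_size, 1, 1, to_seq_length]
extended = (1.0 - extended) * -10000.0
print(extended)   # 0.0 at positions to attend, -10000.0 at padded positions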
    def sample_conditional_a(self, resid_image, var_so_far, pixel_1d):

        is_on = (pixel_1d < (self.n_discrete_latent - 1)).float()

        # pass through galaxy encoder
        pixel_2d = self.one_galaxy_vae.pixel_1d_to_2d(pixel_1d)
        z_mean, z_var = self.one_galaxy_vae.enc(resid_image, pixel_2d)

        # sample z
        q_z = Normal(z_mean, z_var.sqrt())
        z_sample = q_z.rsample()

        # kl term for continuous latent vars
        log_q_z = q_z.log_prob(z_sample).sum(1)
        p_z = Normal(torch.zeros_like(z_sample), torch.ones_like(z_sample))
        log_p_z = p_z.log_prob(z_sample).sum(1)
        kl_z = is_on * (log_q_z - log_p_z)

        # run through decoder
        recon_mean, recon_var = self.one_galaxy_vae.dec(is_on, pixel_2d, z_sample)

        # NOTE: we will have to adjust the recon means once we do more detections
        # recon_means = recon_mean + image_so_far
        # recon_vars = recon_var + var_so_far

        return recon_mean, recon_var, is_on, kl_z
Example #22
def get_analytical_jacobian(input, output):
    input = contiguous(input)
    jacobian = make_jacobian(input, output.numel())
    jacobian_reentrant = make_jacobian(input, output.numel())
    grad_output = torch.zeros_like(output)
    flat_grad_output = grad_output.view(-1)
    reentrant = True
    correct_grad_sizes = True

    for i in range(flat_grad_output.numel()):
        flat_grad_output.zero_()
        flat_grad_output[i] = 1
        for jacobian_c in (jacobian, jacobian_reentrant):
            zero_gradients(input)
            output.backward(grad_output, create_graph=True)
            for jacobian_x, (d_x, x) in zip(jacobian_c, iter_variables(input)):
                if jacobian_x.numel() != 0:
                    if d_x is None:
                        jacobian_x[:, i].zero_()
                    else:
                        jacobian_x[:, i] = d_x.to_dense() if d_x.is_sparse else d_x
                if d_x is not None and d_x.size() != x.size():
                    correct_grad_sizes = False

    for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant):
        if jacobian_x.numel() != 0 and (jacobian_x - jacobian_reentrant_x).abs().max() != 0:
            reentrant = False

    return jacobian, reentrant, correct_grad_sizes
Example #23
    def test_sparse_variable_methods(self):
        # TODO: delete when tensor/variable are merged
        from torch.autograd import Variable
        i = self.IndexTensor([[0, 1, 1], [2, 0, 2]])
        v = self.ValueTensor([3, 4, 5])
        sparse_mat = self.SparseTensor(i, v, torch.Size([2, 3]))
        sparse_var = Variable(sparse_mat)

        to_test_one_arg = {
            'zeros_like': lambda x: torch.zeros_like(x),
            'transpose': lambda x: x.transpose(0, 1),
            'transpose_': lambda x: x.transpose(0, 1),
            't': lambda x: x.t(),
            't_': lambda x: x.t_(),
            'div': lambda x: x.div(2),
            'div_': lambda x: x.div_(2),
            'pow': lambda x: x.pow(2),
            '_nnz': lambda x: x._nnz(),
            'is_coalesced': lambda x: x.is_coalesced(),
            'coalesce': lambda x: x.coalesce(),
            'to_dense': lambda x: x.to_dense(),
            '_dimI': lambda x: x._dimI(),
            '_dimV': lambda x: x._dimV(),
        }

        for test_name, test_fn in to_test_one_arg.items():
            var1 = sparse_var.clone()
            tensor1 = sparse_mat.clone()

            out_var = test_fn(var1)
            out_tensor = test_fn(tensor1)

            if isinstance(out_tensor, int) or isinstance(out_tensor, bool):
                self.assertEqual(out_var, out_tensor)
                continue

            # Assume output is variable / tensor
            self.assertEqual(test_fn(var1).data, test_fn(tensor1),
                             test_name)

        i = self.IndexTensor([[0, 0, 1], [1, 2, 1]])
        v = self.ValueTensor([3, 3, 4])
        sparse_mat2 = self.SparseTensor(i, v, torch.Size([2, 3]))
        sparse_var2 = Variable(sparse_mat2)

        to_test_two_arg = {
            'sub': lambda x, y: x.sub(y),
            'sub_': lambda x, y: x.sub_(y),
            'mul': lambda x, y: x.mul(y),
            'mul_': lambda x, y: x.mul_(y),
        }

        for test_name, test_fn in to_test_two_arg.items():
            var1 = sparse_var.clone()
            var2 = sparse_var2.clone()
            tensor1 = sparse_mat.clone()
            tensor2 = sparse_mat2.clone()
            self.assertEqual(test_fn(var1, var2).data,
                             test_fn(tensor1, tensor2), test_name)
def softplus_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    input, gO, output = t[0], t[1], t[2]
    beta, threshold = ctx.additional_args[0], ctx.additional_args[1]

    input_beta = input * beta
    above_threshold = torch.zeros_like(ggI).masked_fill_(input_beta > threshold, 1)
    below_threshold = torch.zeros_like(ggI).masked_fill_(input_beta <= threshold, 1)

    exp_output_beta = (output * beta).exp()
    first_deriv = (exp_output_beta - 1) / exp_output_beta
    first_deriv_below_threshold = first_deriv * below_threshold

    gI = ggI * gO * first_deriv_below_threshold * beta / exp_output_beta
    ggO = ggI * (above_threshold + first_deriv_below_threshold)

    return gI, ggO, None, None, None, None
 def backward(ctx, grad_output):
     input,         = ctx.saved_tensors
     grad_input     = torch.stack((grad_output, torch.zeros_like(grad_output)), dim=len(grad_output.shape))
     phase_input    = angle(input)
     phase_input    = torch.stack((torch.cos(phase_input), torch.sin(phase_input)), dim=len(grad_output.shape))
     grad_input     = multiply_complex(phase_input, grad_input)
     
     return 0.5*grad_input
Example #26
 def forward(self, input, target):
     one_hot = torch.zeros_like(input)
     one_hot.scatter_(1, target.view(-1, 1), 1.0)
     # torch.where(cond, x, y) semantics: out_i = x_i if cond_i else y_i
     output = self.s * (input - one_hot * self.m)
     
     loss = self.ce(output, target)
     return loss
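
A self-contained sketch of the margin trick above: subtract the margin ``m`` from the target-class logit only, scale everything by ``s``, then apply the usual cross-entropy (values illustrative):

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.8, 0.3, 0.1]])
target = torch.tensor([0])
s, m = 30.0, 0.35
one_hot = torch.zeros_like(logits).scatter_(1, target.view(-1, 1), 1.0)
adjusted = s * (logits - one_hot * m)
print(F.cross_entropy(adjusted, target))   # larger than the unmodified loss, enforcing the margin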
Example #27
 def forward(self, x_de, x_en, update_baseline=True):
     bs = x_de.size(0)
     # x_de is bs,n_de. x_en is bs,n_en
     emb_de = self.embedding_de(x_de) # bs,n_de,word_dim
     emb_en = self.embedding_en(x_en) # bs,n_en,word_dim
     h0_enc = torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda()
     c0_enc = torch.zeros(self.n_layers*self.directions, bs, self.hidden_dim).cuda()
     h0_dec = torch.zeros(self.n_layers, bs, self.hidden_dim).cuda()
     c0_dec = torch.zeros(self.n_layers, bs, self.hidden_dim).cuda()
     # hidden vars have dimension n_layers*n_directions,bs,hiddensz
     enc_h, _ = self.encoder(emb_de, (Variable(h0_enc), Variable(c0_enc)))
     # enc_h is bs,n_de,hiddensz*n_directions. ordering is different from last week because batch_first=True
     dec_h, _ = self.decoder(emb_en, (Variable(h0_dec), Variable(c0_dec)))
     # dec_h is bs,n_en,hidden_size*n_directions
     # we've gotten our encoder/decoder hidden states so we are ready to do attention
     # first let's get all our scores, which we can do easily since we are using dot-prod attention
     if self.directions == 2:
         scores = torch.bmm(self.dim_reduce(enc_h), dec_h.transpose(1,2))
         # TODO: any easier ways to reduce dimension?
     else:
         scores = torch.bmm(enc_h, dec_h.transpose(1,2))
     # (bs,n_de,hiddensz*n_directions) * (bs,hiddensz*n_directions,n_en) = (bs,n_de,n_en)
     reinforce_loss = 0 # we only use this variable for hard attention
     loss = 0
     avg_reward = 0
     # we just iterate to dec_h.size(1)-1, since there's </s> at the end of each sentence
     for t in range(dec_h.size(1)-1): # iterate over english words, with teacher forcing
         attn_dist = F.softmax(scores[:, :, t],dim=1) # bs,n_de. these are the alphas (attention scores for each german word)
         if self.attn_type == "hard":
             cat = torch.distributions.Categorical(attn_dist) 
             attn_samples = cat.sample() # bs. each element is a sample from categorical distribution
             one_hot = Variable(torch.zeros_like(attn_dist.data).scatter_(-1, attn_samples.data.unsqueeze(1), 1).cuda()) # bs,n_de
             # made a bunch of one-hot vectors
             context = torch.bmm(one_hot.unsqueeze(1), enc_h).squeeze(1)
             # now we use the one-hot vectors to select correct hidden vectors from enc_h
             # (bs,1,n_de) * (bs,n_de,hiddensz*n_directions) = (bs,1,hiddensz*n_directions). squeeze to bs,hiddensz*n_directions
         else:
             context = torch.bmm(attn_dist.unsqueeze(1), enc_h).squeeze(1) # same dimensions
             # (bs,1,n_de) * (bs,n_de,hiddensz*n_directions) = (bs,1,hiddensz*n_directions)
         # context is bs,hidden_size*n_directions
         # the rnn output and the context together make the decoder "hidden state", which is bs,2*hidden_size*n_directions
         pred = self.vocab_layer(torch.cat([dec_h[:,t,:], context], 1)) # bs,len(EN.vocab)
         y = x_en[:, t+1] # bs. these are our labels
         no_pad = (y != pad_token) # exclude english padding tokens
         reward = torch.gather(pred, 1, y.unsqueeze(1)) # bs,1
         # reward[i,1] = pred[i,y[i]]. this gets log prob of correct word for each batch. similar to -crossentropy
         reward = reward.squeeze(1)[no_pad] # less than bs
         if self.attn_type == "hard":
             reinforce_loss -= (cat.log_prob(attn_samples[no_pad]) * (reward-self.baseline).detach()).sum() 
             # reinforce rule (just read the formula), with special baseline
         loss -= reward.sum() # minimizing loss is maximizing reward
     no_pad_total = (x_en[:,1:] != pad_token).data.sum() # TODO: i think this is right, right?
     loss /= no_pad_total
     reinforce_loss /= no_pad_total
     avg_reward = -loss.data[0]
     if update_baseline: # update baseline as a moving average
         self.baseline = Variable(0.95*self.baseline.data + 0.05*avg_reward)
     return loss, reinforce_loss,avg_reward
Example #28
    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0):
        defaults = dict(lr=lr, lr_decay=lr_decay, weight_decay=weight_decay)
        super(Adagrad, self).__init__(params, defaults)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['sum'] = torch.zeros_like(p.data)
Example #29
 def noncontiguize(self, obj):
     if isinstance(obj, list):
         return [self.noncontiguize(o) for o in obj]
     tensor = obj
     ndim = tensor.dim()
     noncontig = torch.stack([torch.zeros_like(tensor), tensor], ndim).select(ndim, 1).detach()
     assert noncontig.numel() == 1 or not noncontig.is_contiguous()
     noncontig.requires_grad = tensor.requires_grad
     return noncontig
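
A self-contained sketch of the stack/select trick used above to build a tensor with the same values but a non-contiguous memory layout:

import torch

t = torch.randn(2, 3)
nc = torch.stack([torch.zeros_like(t), t], dim=t.dim()).select(t.dim(), 1)
print(nc.is_contiguous(), torch.equal(nc, t))   # False True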
Example #30
def get_paths(tree, actions, batch_size, num_actions):
    # gets the parts of the tree corresponding to actions taken
    action_indices = cudify(torch.zeros_like(actions[:,0]).long())
    output = []
    for i, x in enumerate(tree):
        action_indices = action_indices * num_actions + actions[:, i]
        batch_indices = cudify(torch.arange(0, batch_size).long() * x.size(0) / batch_size) + action_indices
        output.append(x[batch_indices])
    return output
Example #31
    def compute_ptps(self):

        t_range = torch.arange(-(self.n_times // 2),
                               self.n_times // 2 + 1).cuda()

        ptps_raw = torch.zeros(self.spike_index.shape[0]).float().cuda()
        if self.denoiser is not None:
            ptps_denoised = torch.zeros(
                self.spike_index.shape[0]).float().cuda()
        else:
            ptps_denoised = None

        # batch offsets
        offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] -
                                   self.reader_residual.buffer).cuda().long()

        with tqdm(total=self.reader_residual.n_batches) as pbar:

            for batch_id in range(self.reader_residual.n_batches):

                # load residual data
                dat = self.reader_residual.read_data_batch(batch_id,
                                                           add_buffer=True)
                dat = torch.from_numpy(dat).cuda()

                # relevant idx
                idx_in = torch.nonzero(
                    (self.spike_index[:, 0] >
                     self.reader_residual.idx_list[batch_id][0])
                    & (self.spike_index[:, 0] <
                       self.reader_residual.idx_list[batch_id][1]))[:, 0]

                spike_index_batch = self.spike_index[idx_in]
                spike_index_batch[:, 0] -= offsets[batch_id]

                # get residual snippets
                t_index = spike_index_batch[:, 0][:, None] + t_range
                c_index = spike_index_batch[:, 1].long()

                dat = torch.cat((dat, torch.zeros((dat.shape[0], 1)).cuda()),
                                1)
                residuals = dat[t_index, c_index[:, None]]

                # TODO: align residuals
                #shifts_batch = self.shifts[idx_in]
                #residuals = shift_chans(residuals, -shifts_batch)

                # make clean wfs
                wfs = residuals + self.scales[
                    idx_in][:, None] * self.templates[self.labels[idx_in]]

                ptps_raw[idx_in] = (torch.max(wfs, 1)[0] -
                                    torch.min(wfs, 1)[0])

                if self.denoiser is not None:
                    n_sample_run = 1000

                    idx_list = np.hstack(
                        (np.arange(0, wfs.shape[0],
                                   n_sample_run), wfs.shape[0]))
                    denoised_wfs = torch.zeros_like(wfs).cuda()
                    for j in range(len(idx_list) - 1):
                        denoised_wfs[
                            idx_list[j]:idx_list[j + 1]] = self.denoiser(
                                wfs[idx_list[j]:idx_list[j + 1]])[0].data
                    ptps_denoised[idx_in] = (torch.max(denoised_wfs, 1)[0] -
                                             torch.min(denoised_wfs, 1)[0])

                pbar.update()

        ptps_raw_cpu = ptps_raw.cpu().numpy()

        del dat, idx_in, spike_index_batch, t_index, c_index, residuals, wfs, ptps_raw

        torch.cuda.empty_cache()

        if self.denoiser is not None:
            ptps_denoised_cpu = ptps_denoised.cpu().numpy()
            del denoised_wfs, ptps_denoised
        else:
            ptps_denoised_cpu = np.copy(ptps_raw_cpu)

        return ptps_raw_cpu, ptps_denoised_cpu
Example #32
def test(**kwargs):
    #refresh parameter
    opt._parse(kwargs)

    #create save folder, open log file
    save_test_root = opt.save_test_root
    if not os.path.exists(save_test_root):
        os.makedirs(save_test_root)
    log_file = open(save_test_root+"/result.txt",mode='w')

    #set gpu environment, load weights
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    bernoulli_weights = np.loadtxt(opt.weights)
    bernoulli_weights = t.from_numpy(bernoulli_weights).float().to(opt.device)

    #initialize model, load model file
    model = MCNet(bernoulli_weights,opt.cr,opt.blk_size,opt.ref_size).eval()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    model.to(opt.device) 
    model.eval()

    #get test videos
    videos = [os.path.join(opt.test_data_root,video) for video in os.listdir(opt.test_data_root)]
    video_num = len(videos)
    print("total test video number:",video_num)

    end = time.time()
    psnr_av = 0
    ssim_av = 0
    time_av = 0
    for item in videos:
        if (item.split(".")[-1]!='avi'):
            continue
        print("now is processing:",item)
        log_file.write("%s"%item)
        log_file.write("\n")
        
        uv = utils.Video(opt.height,opt.width)
        test_data = uv.video2array(item,opt.frame_num)
        test_data_t = t.from_numpy(test_data).float().to(opt.device)
        result_data_t = t.zeros_like(test_data_t).cuda()

        psnr_total = 0
        ssim_total = 0
        frame_cnt = 0
        
        #do test on every video
        for i in range(test_data_t.size(0)):
            for j in range(test_data_t.size(1)):
                
                frames = test_data_t[i,j,:,:,:]
                frames_num = frames.size(0)
        
                result_frame = t.ones(1,frames[0].size(0),frames[0].size(1)).float().to(opt.device)
                result_frames = t.zeros(frames_num,frames[0].size(0),frames[0].size(1)).to(opt.device)
                frames_t = frames

                x_b = uv.frame_unfold(frames_t,opt.blk_size,int(opt.blk_size/2)).to(opt.device)
                blk_num_h = x_b.size(1)
                blk_num_w = x_b.size(2)
        
                for ii in range(frames_num):
                    x_ref_b = uv.frame_unfold(result_frame,opt.ref_size,int(opt.ref_size/2))
                    result_b = t.zeros_like(x_b[0].unsqueeze_(0))

                    input_ = (x_b[ii,:,:,:,:]/255.0).float().to(opt.device)
                    input_target = input_.view(1*blk_num_h*blk_num_w,opt.blk_size,opt.blk_size)
                    input_m = input_.view(1*blk_num_h*blk_num_w,opt.blk_size*opt.blk_size,1)

                    ref = x_ref_b.repeat(1,2,1,1,1)
                    ref[:,[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17],:,:,:] = ref[:,[0,9,1,10,2,11,3,12,4,13,5,14,6,15,7,16,8,17],:,:,:]
                    ref = ref.repeat(1,1,2,1,1)
                    ref[:,:,[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17],:,:] = ref[:,:,[0,9,1,10,2,11,3,12,4,13,5,14,6,15,7,16,8,17],:,:]
                    ref_cat = ref[:,:,-1,:,:].view(1,ref.size(1),1,opt.ref_size,opt.ref_size)
                    ref = t.cat((ref,ref_cat),2)
                    ref_cat = ref[:,-1,:,:,:].view(1,1,ref.size(2),opt.ref_size,opt.ref_size)
                    ref = t.cat((ref,ref_cat),1)
                    ref = ref.view(1*blk_num_h*blk_num_w,opt.ref_size,opt.ref_size)

                    b_s = input_target.size(0)
                    weight = bernoulli_weights.unsqueeze(0).repeat(b_s,1,1).to(opt.device)
                    input = t.bmm(weight,input_m).squeeze_(2)
                    if(opt.noise_snr>0):
                        input = add_noise(input,opt.noise_snr,10)

                    output,_ = model(input,ref,input)
                    result_b = output.view(1,blk_num_h,blk_num_w,opt.blk_size,opt.blk_size)

                    frame_cnt = frame_cnt + 1
                    result_frame = uv.frame_fold(result_b,opt.blk_size,int(opt.blk_size/2))
                 
                    result_frames[ii] = result_frame
                    psnr = compare_psnr((frames_t[ii].unsqueeze(0)).cpu().numpy(),(result_frame*255).cpu().numpy(),data_range=255)
                    ssim = compare_ssim(frames_t[ii].cpu().numpy(),(result_frame*255).squeeze(0).cpu().numpy())
                    psnr_total = psnr_total + psnr
                    ssim_total = ssim_total + ssim
                
                result_data_t[i,j,:,:,:] = result_frames

        uv.array2video(result_data_t,opt.save_test_root)

        #get log information
        video_time = time.time() - end
        info = str(psnr_total/frame_cnt)+"\n"
        log_file.write("%s"%info)
        info = str(ssim_total/frame_cnt)+"\n"
        log_file.write("%s"%info)
        info = str(video_time/frame_cnt)+"\n"
        log_file.write("%s"%info)
        end = time.time()

        print("PSNR is:",psnr_total/frame_cnt,"SSIM is:",ssim_total/frame_cnt,"Time per frame is:",video_time/frame_cnt)
        psnr_av = psnr_av + psnr_total/frame_cnt
        ssim_av = ssim_av + ssim_total/frame_cnt
        time_av = time_av + video_time/frame_cnt
    log_file.close()

    print("Average PSNR is:",psnr_av/video_num,"Average SSIM is:",ssim_av/video_num,"Average Time per frame is:",time_av/video_num)
Example #33
 def __call__(self, tensor):
     assert tensor.size(0) == 3
     alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
     quatity = torch.mm(self.eig_val * alpha, self.eig_vec)
     tensor = tensor + quatity.view(3, 1, 1)
     return tensor
Example #34
  cand_tensor_2d.scatter_(1, head_indices, 1)
  cand_tensor_3d = torch.zeros_like(logits)
  cand_tensor_3d.scatter_(1, dep_indices_, cand_tensor_2d.unsqueeze(1))

  #print ("max_tensor_2d:\n", max_tensor_2d)
  #print ("dep_indices_:\n", dep_indices_)
  #print ("logits_:\n", logits_)
  #print ("head_indices:\n", head_indices)
  #print ("cand_tensor_2d:\n", cand_tensor_2d)
  print ("cand_tensor_3d:\n", cand_tensor_3d)


batch = 2
length = 3

logp = torch.Tensor(batch, length, length).random_() % 10
print ('logp:\n',logp)

#a = torch.Tensor([.9,.2,.5])
#flag = torch.ge(a, 0.5).sum() == 2
#if flag:
#	print ("flag")

#exit()

#_get_recomp_logp(logp)
max_tensor = torch.zeros_like(logp)
max_tensor[0,2,2] = 1
max_tensor[1,2,1] = 1
print ("max_tensor:\n", max_tensor)
get_topk_v2(2, max_tensor, logp)
Example #35
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            square_avgs = []
            grad_avgs = []
            momentum_buffer_list = []

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)

                if p.grad.is_sparse:
                    raise RuntimeError(
                        'RMSprop does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)

                square_avgs.append(state['square_avg'])

                if group['momentum'] > 0:
                    momentum_buffer_list.append(state['momentum_buffer'])
                if group['centered']:
                    grad_avgs.append(state['grad_avg'])

                state['step'] += 1

            F.rmsprop(params_with_grad,
                      grads,
                      square_avgs,
                      grad_avgs,
                      momentum_buffer_list,
                      lr=group['lr'],
                      alpha=group['alpha'],
                      eps=group['eps'],
                      weight_decay=group['weight_decay'],
                      momentum=group['momentum'],
                      centered=group['centered'])

        return loss
Example #36
def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, window_width,
                                lowpass_cutoff, lowpass_filter_width):
    r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
    resampling as well as the indices in which they are valid. LinearResample (LR) means
    that the output signal is at linearly spaced intervals (i.e the output signal has a
    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
    the signal.

    The reason the same filter is not used for multiple convolutions is that the
    sinc function can be sampled at different points in time. For example, suppose
    a signal is sampled at the timestamps (seconds)
    0         16        32
    and we want it to be sampled at the timestamps (seconds)
    0 5 10 15   20 25 30  35
    at the timestamp of 16, the delta timestamps are
    16 11 6 1   4  9  14  19
    at the timestamp of 32, the delta timestamps are
    32 27 22 17 12 8 2    3

    As we can see from the deltas, the sinc function is sampled at different points in time
    assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ....]
    for 16 vs [...., 2, 3, ....] for 32)

    For example, when ``orig_freq`` and ``new_freq`` are multiples of each other, only one
    filter is needed.

    A windowed filter function (i.e. Hanning * sinc) is used because the ideal sinc function
    has infinite support (it is non-zero for all values), so instead it is truncated and
    multiplied by a window function, which gives it a less-than-perfect rolloff [1].

    [1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm

    Args:
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        output_samples_in_unit (int): The number of output samples in the smallest repeating unit:
            num_samp_out = new_freq / Gcd(orig_freq, new_freq)
        window_width (float): The width of the window which is nonzero
        lowpass_cutoff (float): The filter cutoff in Hz. The filter cutoff needs to be less
            than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.
        lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper but less
            efficient. We suggest around 4 to 10 for normal use

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple of ``min_input_index`` (which is the minimum indices
        where the window is valid, size (``output_samples_in_unit``)) and ``weights`` (which is the weights
        which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
    """
    assert lowpass_cutoff < min(orig_freq, new_freq) / 2
    output_t = torch.arange(0, output_samples_in_unit, dtype=torch.get_default_dtype()) / new_freq
    min_t = output_t - window_width
    max_t = output_t + window_width

    min_input_index = torch.ceil(min_t * orig_freq)  # size (output_samples_in_unit)
    max_input_index = torch.floor(max_t * orig_freq)  # size (output_samples_in_unit)
    num_indices = max_input_index - min_input_index + 1  # size (output_samples_in_unit)

    max_weight_width = num_indices.max()
    # create a group of weights of size (output_samples_in_unit, max_weight_width)
    j = torch.arange(max_weight_width).unsqueeze(0)
    input_index = min_input_index.unsqueeze(1) + j
    delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)

    weights = torch.zeros_like(delta_t)
    inside_window_indices = delta_t.abs().lt(window_width)
    # raised-cosine (Hanning) window with width `window_width`
    weights[inside_window_indices] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff /
                                            lowpass_filter_width * delta_t[inside_window_indices]))

    t_eq_zero_indices = delta_t.eq(0.0)
    t_not_eq_zero_indices = ~t_eq_zero_indices
    # sinc filter function
    weights[t_not_eq_zero_indices] *= torch.sin(
        2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (math.pi * delta_t[t_not_eq_zero_indices])
    # limit of the function at t = 0
    weights[t_eq_zero_indices] *= 2 * lowpass_cutoff

    weights /= orig_freq  # size (output_samples_in_unit, max_weight_width)
    return min_input_index, weights
Example #37
def show_full_render(neural_radiance_field,
                     camera,
                     target_image,
                     target_silhouette,
                     loss_history_color,
                     loss_history_sil,
                     renderer_grid,
                     num_forward=1):
    """
    This is a helper function for visualizing the
    intermediate results of the learning.

    Since the `NeuralRadianceField` suffers from
    a large memory footprint, which does not allow to
    render the full image grid in a single forward pass,
    we utilize the `NeuralRadianceField.batched_forward`
    function in combination with disabling the gradient caching.
    This chunks the set of emitted rays to batches and
    evaluates the implicit function on one-batch at a time
    to prevent GPU memory overflow.
    """

    rendered_image_list, rendered_silhouette_list = [], []
    # Prevent gradient caching.
    with torch.no_grad():
        for _ in range(num_forward):
            # Render using the grid renderer and the
            # batched_forward function of neural_radiance_field.
            rendered_image_silhouette, _ = renderer_grid(
                cameras=camera,
                volumetric_function=partial(batched_forward,
                                            net=neural_radiance_field))
            # Split the rendering result to a silhouette render
            # and the image render.
            rendered_image_, rendered_silhouette_ = (
                rendered_image_silhouette[0].split([3, 1], dim=-1))
            rendered_image_list.append(rendered_image_)
            rendered_silhouette_list.append(rendered_silhouette_)

        rendered_images = torch.stack(rendered_image_list)
        rendered_image = rendered_images.mean(0)

        rendered_silhouettes = torch.stack(rendered_silhouette_list)
        rendered_silhouette = rendered_silhouettes.mean(0)

        if num_forward > 1:
            rendered_image_std = rendered_images.var(0).sum(-1).sqrt()
            rendered_silhouette_std = rendered_silhouettes.std(0)
        else:
            rendered_image_std = torch.zeros_like(rendered_image[..., 0])
            rendered_silhouette_std = torch.zeros_like(rendered_silhouette)

    print(f"Max image std: {rendered_image_std.max().item():.4f}; "
          f"max image: {rendered_image.max().item():.4f}; "
          f"max silhouette std: {rendered_silhouette_std.max().item():.4f}; "
          f"max silhouette: {rendered_silhouette.max().item():.4f}")
    # Generate plots.
    fig, ax = plt.subplots(2, 4, figsize=(20, 10))
    ax = ax.ravel()
    clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()
    ax[0].plot(list(range(len(loss_history_color))),
               loss_history_color,
               linewidth=1)
    ax[1].imshow(clamp_and_detach(rendered_image))
    ax[2].imshow(clamp_and_detach(rendered_silhouette[..., 0]))
    ax[3].imshow(clamp_and_detach(rendered_image_std),
                 cmap="hot",
                 vmax=0.75**0.5)
    ax[4].plot(list(range(len(loss_history_sil))),
               loss_history_sil,
               linewidth=1)
    ax[5].imshow(clamp_and_detach(target_image))
    ax[6].imshow(clamp_and_detach(target_silhouette))
    ax[7].imshow(clamp_and_detach(rendered_silhouette_std),
                 cmap="hot",
                 vmax=0.5)
    for ax_, title_ in zip(
            ax, ("loss color", "rendered image", "rendered silhouette",
                 "image uncertainty", "loss silhouette", "target image",
                 "target silhouette", "silhouette uncertainty")):
        if not title_.startswith('loss'):
            ax_.grid("off")
            ax_.axis("off")
        ax_.set_title(title_)
    fig.canvas.draw()
    fig.show()
    return fig
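# A minimal sketch (an assumption, not PyTorch3D's actual batched_forward) of the chunked
# evaluation idea described in the docstring above: flatten the ray points, run the implicit
# function one batch at a time, and re-assemble, keeping peak GPU memory low.
import torch

def chunked_forward(points, implicit_fn, n_batches=16):
    flat = points.reshape(-1, points.shape[-1])
    batches = torch.chunk(torch.arange(flat.shape[0]), n_batches)
    outputs = [implicit_fn(flat[idx]) for idx in batches]
    return torch.cat(outputs, dim=0).reshape(*points.shape[:-1], -1)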
Beispiel #38
0
 def postprocess(self, result_batch: InferenceResultBatch) -> InferenceResultBatch:
     predictions = result_batch.get_predictions(self.prediction_subscription_key)
     binarized_outputs = torch.zeros_like(predictions).int()
     binarized_outputs[predictions > self.threshold] = 1
     result_batch.add_predictions(key=self.prediction_publication_key, predictions=binarized_outputs)
     return result_batch
Beispiel #39
0
def get_mel_banks(num_bins, window_length_padded, sample_freq,
                  low_freq, high_freq, vtln_low, vtln_high, vtln_warp_factor):
    # type: (int, int, float, float, float, float, float, float) -> Tuple[Tensor, Tensor]
    """
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The tuple consists of ``bins`` (which is
        melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
        center frequencies of bins of size (``num_bins``)).
    """
    assert num_bins > 3, 'Must have at least 3 mel bins'
    assert window_length_padded % 2 == 0
    num_fft_bins = window_length_padded / 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
        ('Bad values in options: low-freq %f and high-freq %f vs. nyquist %f' % (low_freq, high_freq, nyquist))

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins+1 in next line because of end-effects where the bins
    # spread out to the sides.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    if vtln_high < 0.0:
        vtln_high += nyquist

    assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
                                       (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
        ('Bad values in options: vtln-low %f and vtln-high %f, versus low-freq %f and high-freq %f' %
            (vtln_low, vtln_high, low_freq, high_freq))

    bin = torch.arange(num_bins, dtype=torch.get_default_dtype()).unsqueeze(1)
    left_mel = mel_low_freq + bin * mel_freq_delta  # size(num_bins, 1)
    center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta  # size(num_bins, 1)
    right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta  # size(num_bins, 1)

    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
        center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
        right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)

    center_freqs = inverse_mel_scale(center_mel)  # size (num_bins)
    # size(1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins, dtype=torch.get_default_dtype())).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    if vtln_warp_factor == 1.0:
        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
    else:
        # warping can move the order of left_mel, center_mel, right_mel anywhere
        bins = torch.zeros_like(up_slope)
        up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel)  # left_mel < mel <= center_mel
        down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel)  # center_mel < mel < right_mel
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]

    return bins, center_freqs
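# The mel conversion helpers used above are not shown in this snippet; a minimal sketch
# assuming the Kaldi/HTK convention mel = 1127 * ln(1 + f / 700):
import math
import torch

def mel_scale_scalar(freq: float) -> float:
    return 1127.0 * math.log(1.0 + freq / 700.0)

def mel_scale(freq: torch.Tensor) -> torch.Tensor:
    return 1127.0 * (1.0 + freq / 700.0).log()

def inverse_mel_scale(mel_freq: torch.Tensor) -> torch.Tensor:
    return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)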
Beispiel #40
0
def compute_loss(p, targets, model):  # predictions, targets, model
    device = p[0].device
    lcls = torch.zeros(1, device=device)  # Tensor(0)
    lbox = torch.zeros(1, device=device)  # Tensor(0)
    lobj = torch.zeros(1, device=device)  # Tensor(0)
    tcls, tbox, indices, anchors = build_targets(p, targets, model)  # targets
    h = model.hyp  # hyperparameters
    red = 'mean'  # Loss reduction (sum or mean)

    # Define criteria
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']],
                                                          device=device),
                                  reduction=red)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']],
                                                          device=device),
                                  reduction=red)

    # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    # focal loss
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    # per output
    nt = 0  # targets
    for i, pi in enumerate(p):  # layer index, layer predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj

        nb = b.shape[0]  # number of targets
        if nb:
            # prediction information corresponding to the matched positive samples
            ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

            # GIoU
            pxy = ps[:, :2].sigmoid()
            pwh = ps[:, 2:4].exp().clamp(max=1E3) * anchors[i]
            pbox = torch.cat((pxy, pwh), 1)  # predicted box
            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False,
                            GIoU=True)  # giou(prediction, target)
            lbox += (1.0 - giou).mean()  # giou loss

            # Obj
            tobj[b, a, gj,
                 gi] = (1.0 -
                        model.gr) + model.gr * giou.detach().clamp(0).type(
                            tobj.dtype)  # giou ratio

            # Class
            if model.nc > 1:  # cls loss (only if multiple classes)
                t = torch.full_like(ps[:, 5:], cn, device=device)  # targets
                t[range(nb), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE

            # Append targets to text file
            # with open('targets.txt', 'a') as file:
            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

        lobj += BCEobj(pi[..., 4], tobj)  # obj loss

    # multiply each loss term by its corresponding weight
    lbox *= h['giou']
    lobj *= h['obj']
    lcls *= h['cls']

    # loss = lbox + lobj + lcls
    return {"box_loss": lbox, "obj_loss": lobj, "class_loss": lcls}
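# A minimal sketch of the label-smoothing helper referenced above (an assumption based on
# the cited paper, eqn 3): `cp` is the positive target and `cn` the negative target, so
# eps=0.0 reduces to plain one-hot targets.
def smooth_BCE(eps=0.1):
    return 1.0 - 0.5 * eps, 0.5 * eps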
Beispiel #41
0
def evaluate(model, data_loader, device, use_squad_v2):
    """ Evaluate on dev questions
    @param model (Module): Question Generation Model
    @param data_loader (DataLoader): DataLoader to load dev examples in batches
    @param device (string): 'cuda:0' or 'cpu'
    @param use_squad_v2 (bool): boolean flag to indicate whether to use SQuAD 2.0 
    @returns results (dictionary of evaluation results on the dev questions)
    """
    nll_meter = util.AverageMeter()

    was_training = model.training
    model.eval()

    total_loss = 0.
    total_words = 0.

    criterion = nn.NLLLoss(ignore_index=PAD, reduction='sum')
    loss_compute = SimpleLossCompute(criterion)

    # no_grad() signals backend to throw away all gradients
    with torch.no_grad():
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:

            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)
            batch_size = cw_idxs.size(0)

            # Setup for forward
            src_idxs = cw_idxs
            src_idxs = torch.cat(
                (torch.zeros((batch_size, 1), device=device,
                             dtype=torch.long), src_idxs,
                 torch.zeros(
                     (batch_size, 1), device=device, dtype=torch.long)),
                dim=-1)
            src_idxs[:, 0] = SOS
            src_idxs[:, -1] = EOS
            tgt_idxs = qw_idxs[:, :-1]
            tgt_idxs_y = qw_idxs[:, 1:]

            src_mask = src_idxs == PAD
            tgt_mask = tgt_idxs == PAD

            # Forward
            log_p = model(src_idxs, tgt_idxs, src_mask,
                          tgt_mask)  #(batch_size, q_len, vocab_size)
            log_p = log_p.contiguous().view(-1, log_p.size(-1))

            tgt_idxs_y = tgt_idxs_y.contiguous().view(-1)

            tgt_no_pad = torch.zeros_like(tgt_idxs) != tgt_idxs
            tgt_len = tgt_no_pad.sum(-1)
            batch_words = torch.sum(tgt_len).item()
            #loss = F.nll_loss(log_p, qw_idxs_target, ignore_index=0, reduction='sum')

            batch_loss = loss_compute(log_p, tgt_idxs_y, batch_words,
                                      model.training)
            loss_val = batch_loss / batch_words

            nll_meter.update(loss_val, batch_words)

            # Calculate perplexity
            total_loss += batch_loss
            total_words += batch_words

        ppl = np.exp(total_loss / total_words)
        avg_loss = nll_meter.avg

    results_list = [('NLL', avg_loss), \
                ('PPL', ppl)]
    results = OrderedDict(results_list)

    if was_training:
        model.train()

    return results
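# Worked example of the perplexity computation above: sum the NLL over all non-pad target
# words, divide by the word count, and exponentiate (numbers below are made up).
import math

total_loss, total_words = 690.78, 100.0
print(math.exp(total_loss / total_words))   # ~1000.0, i.e. a perplexity of about 1000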
def process_train_MAM_data(spec, config=None):
    """Process training data for the masked acoustic model"""

    dr = config['downsample_rate'] if config is not None else DR
    hidden_size = config['hidden_size'] if config is not None else HIDDEN_SIZE
    mask_proportion = config['mask_proportion'] if config is not None else MASK_PROPORTION
    mask_consecutive_min = config['mask_consecutive_min'] if config is not None else MASK_CONSECUTIVE
    mask_consecutive_max = config['mask_consecutive_max'] if config is not None else MASK_CONSECUTIVE
    mask_allow_overlap = config['mask_allow_overlap'] if config is not None else True
    mask_bucket_ratio = config['mask_bucket_ratio'] if config is not None else MASK_BUCKET_RATIO
    mask_frequency = config['mask_frequency'] if config is not None else MASK_FREQUENCY
    noise_proportion = config['noise_proportion'] if config is not None else NOISE_PROPORTION
    test_reconstruct = False

    with torch.no_grad():
        if len(spec) == 2: # if self.duo_feature: dataloader will output `source_spec` and `target_spec`
            source_spec = spec[0]
            target_spec = spec[1]
        elif len(spec) == 1:
            source_spec = spec[0]
            target_spec = copy.deepcopy(spec[0])
        else:
            raise NotImplementedError('Input spec should be either (spec,) or (target_spec, source_spec), where `spec` has shape BxTxD.')

        # Down sample
        spec_masked = down_sample_frames(source_spec, dr) # (batch_size, seq_len, mel_dim * dr)
        spec_stacked = down_sample_frames(target_spec, dr) # (batch_size, seq_len, mel_dim * dr)
        assert(spec_masked.shape[1] == spec_stacked.shape[1]), 'Input and output spectrogram should have the same shape'

        # Record length for each uttr
        spec_len = (spec_stacked.sum(dim=-1) != 0).long().sum(dim=-1).tolist()
        batch_size = spec_stacked.shape[0]
        seq_len = spec_stacked.shape[1]
        
        pos_enc = fast_position_encoding(seq_len, hidden_size) # (seq_len, hidden_size)
        mask_label = torch.zeros_like(spec_stacked, dtype=torch.uint8)
        attn_mask = torch.ones((batch_size, seq_len)) # (batch_size, seq_len)

        for idx in range(batch_size):
            # zero vectors for padding dimension
            attn_mask[idx, spec_len[idx]:] = 0

            if test_reconstruct:
                mask_label[idx, :, :] = 1
                continue

            def starts_to_intervals(starts, consecutive):
                tiled = starts.expand(consecutive, starts.size(0)).permute(1, 0)
                offset = torch.arange(consecutive).expand_as(tiled)
                intervals = tiled + offset
                return intervals.view(-1)
            
            # time masking
            mask_consecutive = random.randint(mask_consecutive_min, mask_consecutive_max)
            valid_start_max = max(spec_len[idx] - mask_consecutive - 1, 0) # compute max valid start point for a consecutive mask
            proportion = round(spec_len[idx] * mask_proportion / mask_consecutive)
            if mask_allow_overlap:
                # draw `proportion` samples from the range (0, valid_index_range) and without replacement
                chosen_starts = torch.randperm(valid_start_max + 1)[:proportion]
            else:
                mask_bucket_size = round(mask_consecutive * mask_bucket_ratio)
                rand_start = random.randint(0, min(mask_consecutive, valid_start_max))
                valid_starts = torch.arange(rand_start, valid_start_max + 1, mask_bucket_size)
                chosen_starts = valid_starts[torch.randperm(len(valid_starts))[:proportion]]
            chosen_intervals = starts_to_intervals(chosen_starts, mask_consecutive)
            
            # determine whether to mask / random / or do nothing to the frame
            dice = random.random()
            # mask to zero
            if dice < 0.8:
                spec_masked[idx, chosen_intervals, :] = 0
            # replace to random frames
            elif dice >= 0.8 and dice < 0.9:
                random_starts = torch.randperm(valid_start_max + 1)[:proportion]
                random_intervals = starts_to_intervals(random_starts, mask_consecutive)
                spec_masked[idx, chosen_intervals, :] = spec_masked[idx, random_intervals, :]
            # do nothing
            else:
                pass

            # the gradients will be calculated on chosen frames
            mask_label[idx, chosen_intervals, :] = 1

            # frequency masking
            if mask_frequency > 0:
                rand_bandwidth = random.randint(0, mask_frequency)
                chosen_starts = torch.randperm(spec_masked.shape[2] - rand_bandwidth)[:1]
                chosen_intervals = starts_to_intervals(chosen_starts, rand_bandwidth)
                spec_masked[idx, :, chosen_intervals] = 0
                
                # the gradients will be calculated on chosen frames
                mask_label[idx, :, chosen_intervals] = 1   

        if not test_reconstruct:
            # noise augmentation
            dice = random.random()
            if dice < noise_proportion:
                noise_sampler = torch.distributions.Normal(0, 0.2)
                spec_masked += noise_sampler.sample(spec_masked.shape)
        
        valid_batchid = mask_label.view(batch_size, -1).sum(dim=-1).nonzero().view(-1)
        batch_is_valid = len(valid_batchid) > 0
        spec_masked = spec_masked.to(dtype=torch.float32)[valid_batchid]
        pos_enc = pos_enc.to(dtype=torch.float32)
        mask_label = mask_label.to(dtype=torch.uint8)[valid_batchid]
        attn_mask = attn_mask.to(dtype=torch.float32)[valid_batchid]
        spec_stacked = spec_stacked.to(dtype=torch.float32)[valid_batchid]

    return batch_is_valid, spec_masked, pos_enc, mask_label, attn_mask, spec_stacked
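# A minimal sketch (an assumption, not the repo's exact helper) of the frame-stacking
# downsampling used above: trim frames so the length divides `dr`, then stack every `dr`
# consecutive frames along the feature dimension.
import torch

def down_sample_frames(spec, dr):
    batch_size, seq_len, feat_dim = spec.shape
    left_over = seq_len % dr
    if left_over != 0:
        spec = spec[:, :seq_len - left_over, :]
    return spec.contiguous().view(batch_size, seq_len // dr, feat_dim * dr)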
    def forward(self, predictions, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)

            targets (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size,num_objs,5] (last idx is the label).
        """
        loc_data, conf_data, priors = predictions
        num = loc_data.size(0)
        priors = priors[:loc_data.size(1), :]
        num_priors = priors.size(0)
        num_classes = self.num_classes

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.cpu().data  ## add cpu
            match(self.threshold, truths, defaults, self.variance, labels,
                  loc_t, conf_t, idx)
        if self.use_gpu:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
        # wrap targets
        loc_t = Variable(loc_t, requires_grad=False)
        conf_t = Variable(conf_t, requires_grad=False)

        pos = conf_t > 0
        # num_pos = pos.sum()

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
        # Avoid over stack
        loss_l = torch.clamp(loss_l, min=-9999.0, max=9999.0)

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Avoid over stack
        loss_c = torch.clamp(loss_c, min=-9999.0, max=9999.0)

        # Hard Negative Mining
        loss_c[pos.view(-1,1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos+neg).gt(0)]
        if self.lm:
            targets_onehot = torch.zeros_like(conf_p)
            targets_onehot.zero_().scatter_(1, targets_weighted.unsqueeze(-1), 1)
            # Label Smoothing
            targets_onehot = targets_onehot * 0.99 + targets_onehot * 0.01 * 20.0 / 21.0
            outputs = F.softmax(conf_p, dim=1)
            loss_c = -targets_onehot.float() * torch.log(outputs)

            N = num_pos.data.sum().float() * 10.0
            loss_l /= N
            loss_c /= N
            loss_c = loss_c.sum()
            # assert loss_l.size(0) == loss_c.size(0), "Should be same???"
            return loss_l, loss_c

        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N

        N = max(num_pos.data.sum().float(),1)
        loss_l /= N
        loss_c /= N
        return loss_l, loss_c
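# A minimal sketch of the numerically stable log-sum-exp used for the hard negative mining
# above (assumption: it reduces over the class dimension and keeps that dimension so the
# result can be combined with the gathered class scores).
import torch

def log_sum_exp(x):
    x_max = x.detach().max()
    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max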
Beispiel #44
0
    def forward(self,
                sender,
                receiver,
                loss,
                sender_input,
                labels,
                receiver_input=None):
        message, log_prob_s, entropy_s = sender(sender_input)
        message_length = find_lengths(message)
        receiver_output, log_prob_r, entropy_r = receiver(
            message, receiver_input, message_length)

        loss, aux_info = loss(sender_input, message, receiver_input,
                              receiver_output, labels)

        # the entropy of the outputs of S before and including the eos symbol - as we don't care about what's after
        effective_entropy_s = torch.zeros_like(entropy_r)

        # the log prob of the choices made by S before and including the eos symbol - again, we don't
        # care about the rest
        effective_log_prob_s = torch.zeros_like(log_prob_r)

        for i in range(message.size(1)):
            not_eosed = (i < message_length).float()
            effective_entropy_s += entropy_s[:, i] * not_eosed
            effective_log_prob_s += log_prob_s[:, i] * not_eosed
        effective_entropy_s = effective_entropy_s / message_length.float()
        effective_log_prob_s = effective_log_prob_s / message_length.float()

        weighted_entropy = (
            effective_entropy_s.mean() * self.sender_entropy_coeff +
            entropy_r.mean() * self.receiver_entropy_coeff)

        log_prob = effective_log_prob_s + log_prob_r

        length_loss = message_length.float() * self.length_cost

        policy_length_loss = (
            (length_loss - self.baselines["length"].predict(length_loss)) *
            effective_log_prob_s).mean()
        policy_loss = (
            (loss.detach() - self.baselines["loss"].predict(loss.detach())) *
            log_prob).mean()

        optimized_loss = policy_length_loss + policy_loss - weighted_entropy
        # if the receiver is deterministic/differentiable, we apply the actual loss
        optimized_loss += loss.mean()

        if self.training:
            self.baselines["loss"].update(loss)
            self.baselines["length"].update(length_loss)

        aux_info["sender_entropy"] = entropy_s.detach()
        aux_info["receiver_entropy"] = entropy_r.detach()
        aux_info["length"] = message_length.float()  # will be averaged

        logging_strategy = (self.train_logging_strategy
                            if self.training else self.test_logging_strategy)
        interaction = logging_strategy.filtered_interaction(
            sender_input=sender_input,
            labels=labels,
            receiver_input=receiver_input,
            message=message.detach(),
            receiver_output=receiver_output.detach(),
            message_length=message_length,
            aux=aux_info,
        )

        return optimized_loss, interaction
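# A minimal sketch of the kind of running-mean baseline assumed by `self.baselines` above;
# the real object may differ, but a `predict`/`update` pair around a mean estimate is enough
# to reduce the variance of the REINFORCE terms.
class MeanBaseline:
    def __init__(self):
        self.mean = 0.0
        self.count = 0

    def update(self, value):
        self.count += 1
        self.mean += (value.detach().mean().item() - self.mean) / self.count

    def predict(self, value):
        return self.mean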
def match_priors(prior_bboxes, gt_bboxes, gt_labels, iou_threshold=0.5):
    """
    Match the ground-truth boxes with the priors.
    Note: Use this function in your 'cityscape_dataset.py'; see the SSD paper, page 5, for reference.

    :param gt_bboxes: ground-truth bounding boxes, dim:(n_samples, 4)
    :param gt_labels: ground-truth classification labels, negative (background) = 0, dim: (n_samples)
    :param prior_bboxes: prior bounding boxes on different levels, dim:(num_priors, 4)
    :param iou_threshold: matching criterion
    :return matched_boxes: real matched bounding box, dim: (num_priors, 4)
    :return matched_labels: real matched classification label, dim: (num_priors)
    """
    # [DEBUG] Check if input is the desire shape
    assert gt_bboxes.dim() == 2
    assert gt_bboxes.shape[1] == 4
    assert gt_labels.dim() == 1
    assert gt_labels.shape[0] == gt_bboxes.shape[0]
    assert prior_bboxes.dim() == 2
    assert prior_bboxes.shape[1] == 4

    matched_boxes = torch.zeros_like(prior_bboxes)
    matched_labels = torch.zeros(len(prior_bboxes), dtype=torch.long)

    # TODO: implement prior matching
    # Compute the Jaccard's Similarity using the iou function.
    jaccard_sim = iou(gt_bboxes, center2corner(prior_bboxes))

    # Find the best prior matching gt and vice versa along with the
    # corresponding indices.
    best_prior_sim, best_prior_idx = jaccard_sim.max(1)
    best_gt_sim, best_gt_idx = jaccard_sim.max(0)

    # Make sure every best prior that was selected for a ground truth
    # is not eliminated.
    best_gt_sim.index_fill_(0, best_prior_idx, 1)

    # Make sure the ground truth that needs the prior more is selected.
    for i in range(len(best_prior_idx)):
        best_gt_idx[best_prior_idx[i]] = i

    # Remove the priors that do no meet the threshold overlap
    best_prior_idx = torch.arange(len(best_gt_idx))
    best_prior_idx = (best_prior_idx + 1) * (best_gt_sim > iou_threshold).long()
    best_prior_idx = best_prior_idx[best_prior_idx != 0]
    best_prior_idx -= 1

    # Remove the ground truth corresponding to the removed prior
    best_gt_idx = best_gt_idx[best_gt_sim > iou_threshold]

    # Convert the ground truth bounding boxes to center format
    matched_boxes[best_prior_idx] = corner2center(gt_bboxes[best_gt_idx])

    # Convert the ground truth bounding boxes to ssd locations.
    matched_boxes[best_prior_idx] = bbox2loc(matched_boxes[best_prior_idx],
                                             prior_bboxes[best_prior_idx])

    # Extract the ground truth labels for each prior
    matched_labels[best_prior_idx] = gt_labels[best_gt_idx]

    # [DEBUG] Check if output is the desire shape
    assert matched_boxes.dim() == 2
    assert matched_boxes.shape[1] == 4
    assert matched_labels.dim() == 1
    assert matched_labels.shape[0] == matched_boxes.shape[0]

    return matched_boxes, matched_labels
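# Minimal sketches of the box-format conversions assumed above; the repo's own
# center2corner/corner2center helpers may differ in details such as clamping.
import torch

def center2corner(boxes):
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=-1)

def corner2center(boxes):
    x1, y1, x2, y2 = boxes.unbind(-1)
    return torch.stack(((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1), dim=-1)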
Beispiel #46
0
    def build_target(self, pred, target):
        target_num = target.size(0)
        tcls, tbox, indices, anch = [], [], [], []
        # normalized to gridspace gain
        gain = torch.ones(7, device=self.device)
        # targets: (N, 6) -> (3, N, 7), at last append anchor index
        target = [
            torch.cat(
                [target, i.repeat(target_num).view(target_num, 1)], dim=1)
            for i in torch.arange(self.anchor_num, device=self.device)
        ]
        target = torch.stack(target, dim=0)

        off = torch.tensor([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]],
                           device=self.device) * 0.5
        for i in range(self.detect_layer_num):
            anchors = self.anchors[i]
            # pred: [(n, c, h1, w1, 9), (n, c, h2, w2, 9), (n, c, h3, w3, 9)] gain: [1, 1, w, h, w, h, 1]
            gain[2:6] = torch.tensor(pred[i].shape)[[3, 2, 3, 2]]
            # target: [[[id, cls, cx, cy, w, h, anchor_id], ...n...]]
            # normalize by w, h -> origin size
            t = target * gain
            if target_num:
                # anchors: (3, 2) -> (3, 1, 2)
                wh_ratio = t[:, :, 4:6] / anchors.unsqueeze(dim=1)
                # index: (3, N) t: (3, N, 7) -> (n, 7)
                index = torch.max(wh_ratio,
                                  1 / wh_ratio).max(dim=2)[0] < self.anchor_t
                t = t[index]
                # cxcy: [[cx, xy], ...n...]
                cxcy = t[:, 2:4]
                # inverse_cxcy: [[w-cx, h-cy], ...n...]
                inverse_cxcy = gain[[2, 3]] - cxcy
                # cx_index: x_index  cy_index: y_index
                cx_index, cy_index = ((cxcy.fmod(1.) < 0.5) & (cxcy > 1.)).T
                # inverse_cx_index: x_index  inverse_cy_index: y_index
                inverse_cx_index, inverse_cy_index = ((inverse_cxcy % 1. < 0.5)
                                                      & (inverse_cxcy > 1.)).T
                # cx_index: (n) -> (5, n)
                cx_index = torch.stack(
                    (torch.ones_like(cx_index), cx_index, cy_index,
                     inverse_cx_index, inverse_cy_index))
                # t: (n, 7) -> (5, n, 7) -> (n', 7)
                t = t.unsqueeze(dim=0).repeat((5, 1, 1))[cx_index]
                offsets = (torch.zeros_like(cxcy.unsqueeze(dim=0)) +
                           off.unsqueeze(dim=1))[cx_index]
                # offsets = (torch.zeros_like(cxcy)[None] + off[:, None])[cx_index]
            else:
                t = target[0]
                offsets = 0

            img_id, cls = t[:, :2].long().T
            cxcy = t[:, 2:4]
            wh = t[:, 4:6]
            cxcy_index = (cxcy - offsets).long()
            cx_index, cy_index = cxcy_index.T

            anchor_index = t[:, 6].long()
            indices.append(
                (img_id, anchor_index, cy_index.clamp_(0, gain[3] - 1),
                 cx_index.clamp_(0, gain[2] - 1)))
            tbox.append(torch.cat((cxcy - cxcy_index, wh), 1))
            anch.append(anchors[anchor_index])
            tcls.append(cls)
        return tcls, tbox, indices, anch
    def label_compress(this, label, labels_in_task, nidx):
        compact_ids = torch.zeros_like(label) + this.label_dict["[UNK]"]  # fill with unk
        for i in range(len(labels_in_task)):
            compact_ids[label == labels_in_task[i]] = nidx[i]
        return compact_ids
Beispiel #48
0
    def do_one_epoch(self, epoch, episodes):
        mode = "train" if self.encoder.training and self.classifier1.training else "val"
        epoch_loss, accuracy, steps = 0., 0., 0
        accuracy1, accuracy2 = 0., 0.
        epoch_loss1, epoch_loss2 = 0., 0.
        data_generator = self.generate_batch(episodes)
        for x_t, x_tprev, x_that, ts, thats in data_generator:
            f_t_maps, f_t_prev_maps = self.encoder(
                x_t, fmaps=True), self.encoder(x_tprev, fmaps=True)
            f_t_hat_maps = self.encoder(x_that, fmaps=True)

            # Loss 1: Global at time t, f5 patches at time t-1
            f_t, f_t_prev = f_t_maps['out'], f_t_prev_maps['f5']
            f_t_hat = f_t_hat_maps['f5']
            f_t = f_t.unsqueeze(1).unsqueeze(1).expand(
                -1, f_t_prev.size(1), f_t_prev.size(2),
                self.encoder.hidden_size)

            target = torch.cat((torch.ones_like(
                f_t[:, :, :, 0]), torch.zeros_like(f_t[:, :, :, 0])),
                               dim=0).to(self.device)

            x1, x2 = torch.cat([f_t, f_t],
                               dim=0), torch.cat([f_t_prev, f_t_hat], dim=0)
            shuffled_idxs = torch.randperm(len(target))
            x1, x2, target = x1[shuffled_idxs], x2[shuffled_idxs], target[
                shuffled_idxs]
            self.optimizer.zero_grad()
            loss1 = self.loss_fn(self.classifier1(x1, x2).squeeze(), target)

            # Loss 2: f5 patches at time t, with f5 patches at time t-1
            f_t = f_t_maps['f5']
            x1_p, x2_p = torch.cat([f_t, f_t],
                                   dim=0), torch.cat([f_t_prev, f_t_hat],
                                                     dim=0)
            x1_p, x2_p = x1_p[shuffled_idxs], x2_p[shuffled_idxs]
            loss2 = self.loss_fn(
                self.classifier2(x1_p, x2_p).squeeze(), target)

            loss = loss1 + loss2
            if mode == "train":
                loss.backward()
                self.optimizer.step()

            epoch_loss += loss.detach().item()
            epoch_loss1 += loss1.detach().item()
            epoch_loss2 += loss2.detach().item()
            preds1 = torch.sigmoid(self.classifier1(x1, x2).squeeze())
            accuracy1 += calculate_accuracy(preds1, target)
            preds2 = torch.sigmoid(self.classifier2(x1_p, x2_p).squeeze())
            accuracy2 += calculate_accuracy(preds2, target)
            steps += 1
        self.log_results(epoch,
                         epoch_loss1 / steps,
                         epoch_loss2 / steps,
                         epoch_loss / steps,
                         accuracy1 / steps,
                         accuracy2 / steps, (accuracy1 + accuracy2) / steps,
                         prefix=mode)
        if mode == "val":
            self.early_stopper((accuracy1 + accuracy2) / steps, self.encoder)
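# A minimal sketch of the binary accuracy helper assumed above: threshold the sigmoid
# outputs at 0.5 and compare against the 0/1 targets.
import torch

def calculate_accuracy(preds, target):
    with torch.no_grad():
        return ((preds > 0.5).float() == target).float().mean().item()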
Beispiel #49
0
def decode_dataset(
    model,
    data_loader,
    bos_token,
    eos_token,
    num_samples,
    max_steps,
    mode,
    device,
    prefix_length=0,
    temperature=1.0,
    progress_bar=False,
    consistent_sampling=False,
):
    with torch.no_grad():
        xs = []
        prefixes = []
        iterator = (
            tqdm(enumerate(data_loader), total=len(data_loader) if num_samples == -1 else num_samples//data_loader.batch_size)
            if progress_bar
            else enumerate(data_loader)
        )
        for minibatch_id, (inp, target) in iterator:
            inp = inp.to(device)
            # encode the prefix
            hidden = None
            p_eos_prev = None
            prefix = inp[:, : prefix_length + 1]  # +1 for <bos>
            if isinstance(model, RNNLanguageModelST):
                output, hidden, p_eos_prev = model(prefix, return_extra=True)
            else:
                output, hidden = model.step(prefix, hidden)

            if 'beam' in mode:
                beam_size = int(mode.split("_")[1])  # e.g. beam_4
                batch_size = inp.size(0)
                max_timestep = max_steps

                count_finished = torch.zeros(batch_size, device=device).long()
                finished_hypotheses = {
                    i:[] for i in range(batch_size)
                }
                finished_scores = {
                    i:[] for i in range(batch_size)
                }

                # first beam iteration is out of the loop here
                log_probs = torch.log_softmax(output[:, -1, :], dim=-1)  # (batch_size, vocab_size)
                vocab_size = log_probs.size(-1)

                top_scores, top_tokens = torch.topk(log_probs, beam_size, dim=-1, largest=True, sorted=True)

                # we add to finished even now when eos is selected too
                current_eos_mask = (top_tokens == eos_token)
                count_finished = count_finished + current_eos_mask.sum(1).long()

                # we need this loop... ?
                for beam_id, beam_eos_mask in enumerate(current_eos_mask):
                    if any(beam_eos_mask):
                        finished_in_this_beam = top_tokens[beam_id, beam_eos_mask]
                        finished_scores_in_this_beam = top_scores[beam_id, beam_eos_mask]
                        finished_hypotheses[beam_id].extend([finished_in_this_beam.tolist()])
                        finished_scores[beam_id].extend(finished_scores_in_this_beam.tolist())
                
                hypotheses = [
                    (
                        top_tokens[:,:,None],
                        top_scores,
                        torch.zeros_like(top_tokens)
                    )
                ]

                # expanding the hidden tuple up to the beam_size
                if isinstance(hidden, tuple):  # LSTM
                    expanded_hidden = [None,None]
                    for i in range(2):
                        expanded_hidden[i] = hidden[i][:,:,None,:].expand(-1,-1,beam_size,-1).reshape(2,batch_size*beam_size,-1)
                    expanded_hidden = tuple(expanded_hidden)
                else:
                    expanded_hidden = hidden[:,:,None,:].expand(-1,-1,beam_size,-1).reshape(2,batch_size*beam_size,-1)

                # input for the first beam timestep
                expanded_input = top_tokens.view(batch_size*beam_size,1)
                for timestep in range(1, max_timestep):
                    # change below should be enough for the STRNN
                    expanded_output, expanded_hidden = model.step(expanded_input, expanded_hidden)

                    # reshaping back as batch_size * beam_size
                    decoupled_output = expanded_output[:, None, :,:].view(batch_size, beam_size, 1, -1)  # (batch, beam, 1, vocab)
                    # -> log_softmax
                    decoupled_output = torch.log_softmax(decoupled_output, dim=-1)

                    partial_from_prev_timestep = hypotheses[timestep-1][0]  # index 0 is partial
                    scores_from_prev_timestep = hypotheses[timestep-1][1]  # index 1 is scores
                    
                    # check for eos, do not select anything after eos
                    eos_mask = partial_from_prev_timestep[:,:,-1] == eos_token
                    scores_from_prev_timestep[eos_mask] = -10e15

                    extended_scores = decoupled_output.add(scores_from_prev_timestep[:,:,None,None])

                    # coupling it beam*vocab for topk
                    coupled_extended_scores = extended_scores.view(batch_size, beam_size*vocab_size)
                    top_scores, top_ids = torch.topk(coupled_extended_scores, beam_size, dim=-1, largest=True, sorted=True)
                    
                    actual_word_ids = top_ids % vocab_size
                    
                    # make a new input for next iteration

                    expanded_input = actual_word_ids.view(batch_size*beam_size, -1)

                    prev_hyp_id_per_sample = top_ids // vocab_size

                    prev_hyp_id_flat = ((torch.arange(batch_size, device=device) * beam_size)[:,None] + prev_hyp_id_per_sample).view(-1)
                    reordered_prev_hypotheses = torch.index_select(partial_from_prev_timestep.view(batch_size*beam_size,-1), dim=0, index=prev_hyp_id_flat).view(batch_size, beam_size, -1)
                    extended_current_hypotheses = torch.cat([reordered_prev_hypotheses, actual_word_ids[:,:,None]], dim=2)

                    # check currently extended hyps for eos
                    current_eos_mask = (actual_word_ids == eos_token)
                    count_finished = count_finished + current_eos_mask.sum(1).long()

                    # we need this loop... ?
                    for beam_id, beam_eos_mask in enumerate(current_eos_mask):
                        if any(beam_eos_mask):
                            finished_in_this_beam = extended_current_hypotheses[beam_id, beam_eos_mask, :]
                            finished_scores_in_this_beam = top_scores[beam_id, beam_eos_mask]
                            finished_hypotheses[beam_id].extend(finished_in_this_beam.tolist())
                            finished_scores[beam_id].extend(finished_scores_in_this_beam.tolist())

                    # reorder the hidden state
                    if isinstance(expanded_hidden, tuple):  # LSTM
                        # reorder both the hidden state and the cell state
                        new_expanded_hidden = [None, None]
                        for i in range(2):
                            new_expanded_hidden[i] = torch.index_select(expanded_hidden[i], dim=1, index=prev_hyp_id_flat)
                        new_expanded_hidden = tuple(new_expanded_hidden)
                    else:
                        new_expanded_hidden = torch.index_select(expanded_hidden, dim=1, index=prev_hyp_id_flat)
                    expanded_hidden = new_expanded_hidden

                    # add new hypotheses to beam
                    hypotheses.append(
                        (extended_current_hypotheses, top_scores, prev_hyp_id_per_sample)
                    )
                    # check if we have enough ( at least 1) finished for each sample in mini batch
                    # ideally one would do at least beam size, with 1 avg len might be shorter
                    if all(count_finished > 0):
                        break
                
                # now we check what hypotheses are finished
                best_finished_seqs = []
                for beam_id in range(batch_size):
                    if count_finished[beam_id].item() == 0:
                        # non-terminated here
                        # take the first seq from the beam
                        seq = hypotheses[-1][0][beam_id][0].cpu().tolist()
                    else:
                        # find the best one w.r.t score
                        finished_here = finished_scores[beam_id]
                        best_finished_id = np.array(finished_here).argmax()
                        seq = finished_hypotheses[beam_id][best_finished_id]
                    best_finished_seqs.append(seq) 
                x = best_finished_seqs

            else:
                # decode
                x = []
                p_eoss = []
                output = output[:, -1, :].unsqueeze(1)
                for t in range(max_steps):

                    if mode == "greedy":
                        xt = output.argmax(-1)
                    elif mode == "sample":
                        if isinstance(model, RNNLanguageModelST) and temperature != 1.0:
                            raise NotImplementedError
                        elif isinstance(model, RNNLanguageModelST):
                            xt = output.exp().squeeze(1).multinomial(1)
                        else:
                            xt = (
                                torch.softmax(output / temperature, -1)
                                .squeeze(1)
                                .multinomial(1)
                            )
                    elif isinstance(mode, tuple):
                        if mode[0] == "topk":
                            output = top_k_top_p_filtering(output.squeeze(1), top_k=mode[1], 
                                consistent_sampling=consistent_sampling, eos_idx=model.eos_idx)
                        elif mode[0] == "topp":
                            output = top_k_top_p_filtering(output.squeeze(1), top_p=mode[1], 
                                consistent_sampling=consistent_sampling, eos_idx=model.eos_idx)
                        xt = torch.softmax(output, -1).multinomial(1)

                    if isinstance(model, RNNLanguageModelST):
                        output, hidden, p_eos_prev = model.step(xt, hidden, p_eos_prev)
                        p_eoss.append(p_eos_prev)
                    else:
                        output, hidden = model.step(xt, hidden)

                    x.append(xt)
                x = torch.cat(x, 1)
            prefixes.append(prefix)
            if isinstance(x, torch.Tensor):
                xs.append(x)
            else:
                xs.extend(x)
            if num_samples >= 0 and (minibatch_id + 1) * inp.size(0) > num_samples:
                break
        if isinstance(xs[0], torch.Tensor):
            xs = torch.cat(xs, 0).tolist()
        prefixes = torch.cat(prefixes, 0).tolist()
        for i, x in enumerate(xs):
            if eos_token in x:
                xs[i] = x[: x.index(eos_token)]

        return xs, prefixes
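# A minimal sketch of the top-k part of the filtering used above. The real
# top_k_top_p_filtering also implements nucleus (top-p) filtering and the
# consistent_sampling / eos handling, which are omitted here.
import torch

def top_k_filtering(logits, top_k=50, filter_value=-float("inf")):
    if top_k > 0:
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, filter_value)
    return logits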
Beispiel #50
0
    def compute_loss_for_batch(self, data, model, K=K, test=False):
        # data = (B, 1, H, W)
        B, _, H, W = data.shape

        # Generate K copies of each observation. Each will get sampled once according to the generated distribution to generate a total of K observation samples
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)

        # Retrieve the estimated mean and log(standard deviation) estimates from the posterior approximator
        mu, logstd = model.encode(data_k_vec)

        # Use the reparametrization trick to generate (mean)+(epsilon)*(standard deviation) for each sample of each observation
        z = model.reparameterize(mu, logstd)

        # Calculate log q(z|x) - how likely are the importance samples given the distribution that generated them?
        log_q = compute_log_probabitility_gaussian(z, mu, logstd)

        # Calculate log p(z) - how likely are the importance samples under the prior N(0,1) assumption?
        log_p_z = compute_log_probabitility_gaussian(
            z, torch.zeros_like(z, requires_grad=False),
            torch.zeros_like(z, requires_grad=False))

        # Hand the samples to the decoder network and get a reconstruction of each sample.
        decoded = model.decode(z)

        # Calculate log p(x|z) with a bernoulli distribution - how likely are the recreations given the latents that generated them?
        log_p = compute_log_probabitility_bernoulli(decoded, data_k_vec)

        # Begin calculating L_alpha depending on the (a) model type, and (b) optimization method
        # log_p_z + log_p - log_q = log(p(z_i)p(x|z_i)/q(z_i|x)) = log(p(x,z_i)/q(z_i|x)) = L_VI
        #   (for each importance sample i out of K for each observation)
        if model_type == 'iwae' or test:
            # Re-order the entries so that each row holds the K importance samples for each observation
            log_w_matrix = (log_p_z + log_p - log_q).view(B, K)

        elif model_type == 'vae':
            # Don't reorder, and divide by K in anticipation of taking a batch sum of (1/K)*SUM(log(p(x,z)/q(z|x)))
            log_w_matrix = (log_p_z + log_p - log_q).view(B * K, 1) * 1 / K

        elif model_type == 'general_alpha' or model_type == 'vralpha':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Multiply by (1-alpha) because (1-alpha)* log(p(x,z_i)/q(z_i|x)) =  log([p(x,z_i)/q(z_i|x)]^(1-alpha))
            log_w_matrix = (log_p_z + log_p - log_q).view(-1, K) * (1 - alpha)

        elif model_type == 'vrmax':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Take the max in each row, representing the maximum-weighted sample
            log_w_matrix = (log_p_z + log_p - log_q).view(-1, K).max(
                axis=1, keepdim=True).values

            # immediately return loss = -sum(L_alpha) over each observation
            return -torch.sum(log_w_matrix)

        # Begin using the "max trick". Subtract the maximum log(*) sample value for each observation.
        # log_w_minus_max = log([p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]))
        log_w_minus_max = log_w_matrix - torch.max(
            log_w_matrix, 1, keepdim=True)[0]

        # Exponentiate so that each term is [p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]) (no log)
        ws_matrix = torch.exp(log_w_minus_max)

        # Calculate normalized weights in each row. Max denominators cancel out!
        # ws_norm = [p(z_i,x)/q(z_i|x)]/SUM([p(z_k,x)/q(z_k|x)])
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if model_type == 'vralpha' and not test:
            # If we're specifically using a VR-alpha model, we want to choose a sample to backprop according to the values in ws_norm above
            # So we make a distribution in each row
            sample_dist = Multinomial(1, ws_norm)

            # Then we choose a sample in each row acccording to this distribution
            ws_sum_per_datapoint = log_w_matrix.gather(
                1,
                sample_dist.sample().argmax(1, keepdim=True))
        else:
            # For any other model, we're taking the full sum at this point
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if model_type in ["general_alpha", "vralpha"] and not test:
            # For both VR-alpha and directly estimating L_alpha with a sum, we have to renormalize the sum with 1-alpha
            ws_sum_per_datapoint /= (1 - alpha)

        # Return a value of loss = -L_alpha as the batch sum.
        loss = -torch.sum(ws_sum_per_datapoint)

        return loss
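# Minimal sketches of the log-probability helpers referenced above, assuming a diagonal
# Gaussian parameterized by (mu, logstd) and Bernoulli pixel likelihoods; both sum over
# the last (feature) dimension.
import math
import torch

def log_prob_gaussian(z, mu, logstd):
    return torch.sum(-0.5 * ((z - mu) / logstd.exp()) ** 2 - logstd
                     - 0.5 * math.log(2 * math.pi), dim=-1)

def log_prob_bernoulli(p, x, eps=1e-8):
    return torch.sum(x * torch.log(p + eps) + (1 - x) * torch.log(1 - p + eps), dim=-1)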
Beispiel #51
0
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'DiffMod does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Exponential moving average of actual learning rates
                    state['exp_avg_lr'] = torch.zeros_like(p.data)
                    # Previous gradient
                    state['previous_grad'] = torch.zeros_like(p.data)                    

                exp_avg, exp_avg_sq, exp_avg_lr = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_lr']
                previous_grad = state['previous_grad']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
                
                # compute diffgrad coefficient (dfc)
                if self.version==0:
                    diff = abs(previous_grad - grad)
                    
                elif self.version ==1:
                    diff = previous_grad-grad
               
                if self.version==0 or self.version==1:    
                    dfc = 1. / (1. + torch.exp(-diff))
                    
                state['previous_grad'] = grad                

                if group['weight_decay'] != 0:
                    p.data.add_(-group['weight_decay'] * group['lr'], p.data)

                # create long term memory of actual learning rates (from AdaMod)
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom)
                exp_avg_lr.mul_(group['beta3']).add_(1 - group['beta3'], step_size)
                
                if self.debug_print:
                    print(f"batch step size {step_size} and exp_avg_step {exp_avg_lr}")
                    
                #Blend the mini-batch step size with long term memory
                step_size = step_size.add(exp_avg_lr)
                step_size = step_size.div(2.)

                # update momentum with dfc
                exp_avg1 = exp_avg * dfc
                
                step_size.mul_(exp_avg1)

                p.data.add_(-step_size)

        return loss
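# Worked illustration of the diffGrad friction coefficient (dfc) above: with version 0 the
# absolute gradient change is fed through a sigmoid, so dfc lies in [0.5, 1) and damps the
# momentum update when consecutive gradients are similar.
import torch

previous_grad = torch.tensor([0.10, 0.50])
grad = torch.tensor([0.10, -0.50])
dfc = 1.0 / (1.0 + torch.exp(-torch.abs(previous_grad - grad)))
print(dfc)  # tensor([0.5000, 0.7311])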
Beispiel #52
0
def r2c(x):
    return torch.stack([x, torch.zeros_like(x)], -1)
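# Usage illustration: r2c appends a zero imaginary part, producing the (..., 2)
# real/imaginary layout that older torch.fft-style complex APIs expected.
import torch

x = torch.tensor([1.0, 2.0])
print(r2c(x))  # tensor([[1., 0.], [2., 0.]])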
def train(epoch):
    model.train()
    train_loss = 0
    ACT = 0
    ACTn = 0

    for batch_idx, (data, gt) in enumerate(train_loader):
        lr = args.lr
        if epoch>=0:
            lr = max(args.lr * 0.8**(np.floor(epoch/5.0)), 0.0001)
            #if epoch > 100:
            #    lr = lr * ((np.cos(((epoch-100)/(math.pi*2)))+1)/2)
            #lr = args.lr * ((np.sin((epoch/(math.pi*2)))+1)/2)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        gmm_lr = args.gmmlr

        data = data.to(device)
        gt = gt.to(device)

        optimizer.zero_grad()
        recon_batch, z, gamma, mu, logvar, gamma_l, a, b = model(data, lr=lr, gmm_lr = gmm_lr )

        #print(np.any(np.isnan(recon_batch.detach().cpu().numpy())), np.any(np.isnan(z.detach().cpu().numpy())), np.any(np.isnan(gamma.detach().cpu().numpy())), np.any(np.isnan(mu.detach().cpu().numpy())), np.any(np.isnan(logvar.detach().cpu().numpy())), np.any(np.isnan(gamma_l.detach().cpu().numpy())))

        # BCE = F.binary_cross_entropy(recon_batch, data, reduction='mean') * 561
        # BCE = F.mse_loss(recon_batch, data, reduction='mean')* 561

        BCE = torch.sum(torch.mean((recon_batch - data).pow(2), 0),-1)

        KLD = torch.sum(0.5 * gamma.unsqueeze(-1) * (
                        torch.log((torch.zeros_like(gamma.unsqueeze(-1)) + 2) * math.pi+1e-10) + torch.log(
                    model.var.unsqueeze(0)+1e-10) + torch.exp(logvar.unsqueeze(-2)) / (model.var.unsqueeze(0)+1e-10) + (
                                    mu.unsqueeze(-2) - model.mu.unsqueeze(0)).pow(2) / (model.var.unsqueeze(0)+1e-10)),
                            [-1, -2])
        KLD -= 0.5 * torch.sum(logvar + 1, -1)
        KLD += torch.sum(torch.log(gamma+1e-10) * gamma, -1)
        KLD -= torch.sum(torch.log(gamma_l+1e-10) * gamma, -1)
        KLD = torch.mean(KLD)

        prior_alpha = torch.Tensor(1).zero_().cuda() + 1
        prior_beta = torch.Tensor(1).zero_().cuda() + 2
        SBKL = calc_kl_divergence(a, b, prior_alpha, prior_beta)/ data.shape[0]
        loss = BCE + KLD + SBKL*0.005

        if batch_idx==0:
            gt_all = gt
            ret_all = torch.max(gamma, 1)[1]
        else:
            gt_all = torch.cat((gt_all, gt))
            ret_all = torch.cat((ret_all, torch.max(gamma, 1)[1]))

        # ACT += cluster_acc(torch.max(gamma, 1)[1].detach().cpu().numpy(), gt.detach().cpu().numpy())[0]
        # ACTn += 1
        if batch_idx % args.log_interval == 0:
            acc = cluster_acc(torch.max(gamma, 1)[1].detach().cpu().numpy(), gt.detach().cpu().numpy())
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tBCE: {:.6f}\tKLD: {:.6f}\tKL: {:.6f}\tACC: {:.2f}\tLR: {:.6f}\tgmmLR: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(), BCE, KLD, SBKL, acc[0], lr, gmm_lr))


        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    if epoch % 10 == 0:
        print('Saving..')
        state = {
            'net': model.state_dict(),
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt8-9rts4.t8-9')

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / (len(train_loader.dataset))))
    acc = cluster_acc(ret_all.detach().cpu().numpy(), gt_all.detach().cpu().numpy())[0]

    global ac
    print('current:', acc, 'best:', ac)

    if acc > ac:
        ac = acc
        print('Saving best..')

        state = {
            'net': model.state_dict(),
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt8-9rts4_best.t8-9')
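A minimal driver sketch (added for illustration, not part of the original): the snippet assumes module-level globals such as `model`, `optimizer`, `train_loader`, `args`, `device`, `cluster_acc` and the best-accuracy tracker `ac`; `args.epochs` below is a hypothetical attribute.

ac = 0.0
for epoch in range(1, args.epochs + 1):
    train(epoch)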
Beispiel #54
0
    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
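For reference, the per-step quantities cached in `self.buffer` above can be reproduced standalone; a minimal sketch (added for illustration, not part of the original) of the same RAdam rectification formulas with plain-Python inputs:

import math

def radam_step_size(step, beta1, beta2, lr):
    # Approximate SMA length and the rectified step size used by RAdam,
    # mirroring the buffered computation in the snippet above.
    beta2_t = beta2 ** step
    n_sma_max = 2 / (1 - beta2) - 1
    n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
    if n_sma >= 5:
        step_size = lr * math.sqrt(
            (1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4)
            * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
        ) / (1 - beta1 ** step)
    else:
        step_size = lr / (1 - beta1 ** step)
    return n_sma, step_size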
Beispiel #55
0
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay_rate'] > 0.0:
                    update += group['weight_decay_rate'] * p.data

                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
                else:
                    lr_scheduled = group['lr']

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

                # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
                # No bias correction
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']

        return loss
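The `SCHEDULES` lookup used above is not shown in the snippet; in BertAdam-style optimizers it maps schedule names to warmup functions of training progress. A rough sketch of the commonly used 'warmup_linear' entry (an assumption, reproduced here only for context):

def warmup_linear(x, warmup=0.002):
    # x is the fraction of total steps completed; ramp up during warmup,
    # then decay linearly to zero.
    if x < warmup:
        return x / warmup
    return 1.0 - x

SCHEDULES = {'warmup_linear': warmup_linear}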
Beispiel #56
0
def reconstruct_weight_from_k_means_result(centroids, labels):
    weight = torch.zeros_like(labels).float().cuda()
    for i, c in enumerate(centroids.cpu().numpy().squeeze()):
        weight[labels == i] = c.item()
    return weight
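A hypothetical usage sketch for the helper above: quantize a weight matrix with scikit-learn's KMeans and rebuild it from the centroids (requires a CUDA device, since the function hard-codes `.cuda()`; all names below are illustrative):

import torch
from sklearn.cluster import KMeans

w = torch.randn(64, 64)
km = KMeans(n_clusters=16, n_init=10).fit(w.reshape(-1, 1).numpy())
labels = torch.from_numpy(km.labels_).reshape(w.shape).cuda()
centroids = torch.from_numpy(km.cluster_centers_).cuda()
w_quantized = reconstruct_weight_from_k_means_result(centroids, labels)  # same shape as w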
Beispiel #57
0
    def build_target(self, pds, gts):
        self.device = 'cuda' if pds.is_cuda else 'cpu'
        nB, nA, nH, nW, _ = pds.shape
        assert nH == nW
        nC = self.cls_num
        #threshold = th
        nGts = len(gts)
        obj_mask = torch.zeros(nB,
                               nA,
                               nH,
                               nW,
                               dtype=torch.bool,
                               device=self.device)
        noobj_mask = torch.ones(nB,
                                nA,
                                nH,
                                nW,
                                dtype=torch.bool,
                                device=self.device)
        tbboxes = torch.zeros(nB,
                              nA,
                              nH,
                              nW,
                              4,
                              dtype=torch.float,
                              device=self.device)
        tcls = torch.zeros(nB,
                           nA,
                           nH,
                           nW,
                           nC,
                           dtype=torch.float,
                           device=self.device)
        if nGts == 0:
            return obj_mask, noobj_mask, tbboxes, tcls, obj_mask.float()
        #convert target
        gt_boxes = gts[:, 2:6]
        gws = gt_boxes[:, 2]
        ghs = gt_boxes[:, 3]

        ious = torch.stack(
            [iou_wo_center(gws, ghs, w, h) for (w, h) in self.scaled_anchors])
        vals, best_n = ious.max(0)
        ind = torch.arange(vals.shape[0], device=self.device)
        '''ind = torch.argsort(vals)
        # so that the obj with the bigger iou covers the smaller one
        # (useful for crowded scenes)
        idx = torch.argsort(gts[ind, -1], descending=True)  # sort by match count so still-unmatched gts get matched first
        ind = ind[idx]
        # discard gts below the match threshold and those already matched
        best_n = best_n[ind]
        gts = gts[ind, :]
        gt_boxes = gt_boxes[ind, :]
        ious = ious[:, ind]
        '''

        batch = gts[:, 0].long()
        labels = gts[:, 1].long()
        gxs, gys = gt_boxes[:, 0] * nW, gt_boxes[:, 1] * nH
        gis, gjs = gxs.long(), gys.long()
        #calculate bbox ious with anchors
        obj_mask[batch, best_n, gjs, gis] = 1
        noobj_mask[batch, best_n, gjs, gis] = 0
        ious = ious.t()
        #ignore big overlap but not the best
        for i, iou in enumerate(ious):
            noobj_mask[batch[i], iou > self.ignore_threshold, gjs[i],
                       gis[i]] = 0

        selected = torch.zeros_like(obj_mask, dtype=torch.long).fill_(-1)

        tbboxes[batch, best_n, gjs, gis] = gt_boxes
        tcls[batch, best_n, gjs, gis, labels] = 1
        selected[batch, best_n, gjs, gis] = ind

        selected = torch.unique(selected[selected >= 0])
        gts[selected, -1] += 1  #marked as matched

        return obj_mask, noobj_mask, tbboxes, tcls, obj_mask.float()
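The helper `iou_wo_center` used above is not part of the snippet; a typical definition (an assumption) computes the IoU of two boxes placed on a common center, so only widths and heights matter:

def iou_wo_center(w1, h1, w2, h2):
    # Overlap of boxes (w1, h1) and (w2, h2) that share a center.
    inter = torch.min(w1, w2) * torch.min(h1, h2)
    union = w1 * h1 + w2 * h2 - inter
    return inter / (union + 1e-16)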
Beispiel #58
0
    def compute_ptps(self):

        t_range = torch.arange(-(self.n_times // 2),
                               self.n_times // 2 + 1).cuda()

        ptps_raw = torch.zeros(self.spike_index.shape[0]).float().cuda()
        if self.denoiser is not None:
            ptps_denoised = torch.zeros(
                self.spike_index.shape[0]).float().cuda()
        else:
            ptps_denoised = None

        # batch offsets
        offsets = torch.from_numpy(self.reader.idx_list[:, 0] -
                                   self.reader.buffer).cuda().long()

        with tqdm(total=self.reader.n_batches) as pbar:

            for batch_id in range(self.reader.n_batches):

                # load residual data
                dat = self.reader.read_data_batch(batch_id, add_buffer=True)
                dat = torch.from_numpy(dat).cuda()

                # relevant idx
                idx_in = torch.nonzero(
                    (self.spike_index[:, 0] > self.reader.idx_list[batch_id][0])
                    & (self.spike_index[:, 0] < self.reader.idx_list[batch_id][1]))[:, 0]

                spike_index_batch = self.spike_index[idx_in]
                spike_index_batch[:, 0] -= offsets[batch_id]

                # skip if no spikes
                if len(spike_index_batch) == 0:
                    continue

                # get residual snippets
                t_index = spike_index_batch[:, 0][:, None] + t_range
                c_index = spike_index_batch[:, 1].long()

                dat = torch.cat((dat, torch.zeros((dat.shape[0], 1)).cuda()),
                                1)
                wfs = dat[t_index, c_index[:, None]]
                ptps_raw[idx_in] = (torch.max(wfs, 1)[0] -
                                    torch.min(wfs, 1)[0])

                if self.denoiser is not None:
                    n_sample_run = 1000

                    idx_list = np.hstack(
                        (np.arange(0, wfs.shape[0],
                                   n_sample_run), wfs.shape[0]))
                    denoised_wfs = torch.zeros_like(wfs).cuda()
                    #print ("denoised_wfs; ", denoised_wfs.shape)
                    #print ("wfs; ", wfs.shape)
                    for j in range(len(idx_list) - 1):
                        #print ("idx_list[j], j+1: ", idx_list[j], idx_list[j+1])
                        denoised_wfs[
                            idx_list[j]:idx_list[j + 1]] = self.denoiser(
                                wfs[idx_list[j]:idx_list[j + 1]])[0].data
                    ptps_denoised[idx_in] = (torch.max(denoised_wfs, 1)[0] -
                                             torch.min(denoised_wfs, 1)[0])

                pbar.update()

        ptps_raw_cpu = ptps_raw.cpu().numpy()

        del dat, idx_in, spike_index_batch, t_index, c_index, wfs, ptps_raw

        if self.denoiser is not None:
            ptps_denoised_cpu = ptps_denoised.cpu().numpy()
            del denoised_wfs, ptps_denoised
        else:
            ptps_denoised_cpu = np.copy(ptps_raw_cpu)

        torch.cuda.empty_cache()

        return ptps_raw_cpu, ptps_denoised_cpu
Beispiel #59
0
    def init_zero_sm_weight(self, sm_weight):
        self.update_sm_weight = torch.zeros_like(sm_weight).detach()
Beispiel #60
0
def calculate_variance_tensor(embeddings, spatial_separation=False):
    '''
    Calculates the mean distance of all embeddings relative to their average
    over all paths, independently for every index within a mini-batch.
    embeddings: (B, paths, pred_step, F, S, S).
    spatial_separation: Whether to avoid averaging over H, W.
    
    If False, returns: (B, 6, pred_step) where the 6 metrics =
    (L2, pooled L2, normalized L2, pooled normalized L2, cosine, pooled cosine).
    NOTE: pooling is done before everything else, followed by L2 normalization.

    If True, returns: (B, 3, pred_step, S, S) where the 3 metrics =
    (L2, normalized L2, cosine).
    '''

    embeddings = embeddings.cpu()
    (B, paths, pred_step, _, H, W) = embeddings.shape

    # Construct auxiliary variables
    embs_pool = embeddings.mean(dim=[4, 5])  # (B, paths, pred_step, F)
    embs_norm = torch.zeros_like(embeddings)  # (B, paths, pred_step, F, S, S)
    embs_norm_pool = torch.zeros_like(embs_pool)  # (B, paths, pred_step, F)
    avg = embeddings.mean(dim=1)  # (B, pred_step, F, S, S)
    avg_pool = avg.mean(dim=[3, 4])  # (B, pred_step, F)
    avg_norm = torch.zeros_like(avg)  # (B, pred_step, F, S, S)
    avg_norm_pool = torch.zeros_like(avg_pool)  # (B, pred_step, F)

    if spatial_separation:
        sum_metrics = torch.zeros(B, 3, pred_step, H,
                                  W)  # (B, metric, pred_step, S, S)
    else:
        sum_metrics = torch.zeros(B, 6, pred_step)  # (B, metric, pred_step)

    # Get normalized & pooled embeddings first (needed by both branches below,
    # so this runs regardless of spatial_separation)
    for i in range(B):
        for t in range(pred_step):
            for j in range(paths):
                embs_norm[i, j, t] = embeddings[i, j, t] / torch.norm(
                    embeddings[i, j, t], 2)
                embs_norm_pool[i, j, t] = embs_pool[i, j, t] / torch.norm(
                    embs_pool[i, j, t], 2)
            avg_norm[i, t] = avg[i, t] / torch.norm(avg[i, t], 2)
            avg_norm_pool[i, t] = avg_pool[i, t] / torch.norm(
                avg_pool[i, t], 2)

    # Calculate variance metrics (distances from average over all paths)
    for i in range(B):
        for t in range(pred_step):
            for j in range(paths):

                if spatial_separation:
                    # Process every position independently; pooling is not relevant here
                    for y in range(H):
                        for x in range(W):
                            # L2-distance between raw spatio-temporal blocks
                            sum_metrics[i, 0, t, y, x] += torch.norm(
                                embeddings[i, j, t, :, y, x] -
                                avg[i, t, :, y, x], 2)
                            # L2-distance between L2-normalized spatio-temporal blocks
                            sum_metrics[i, 1, t, y, x] += torch.norm(
                                embs_norm[i, j, t, :, y, x] -
                                avg_norm[i, t, :, y, x], 2)
                            # Cosine similarity between raw spatio-temporal blocks
                            sum_metrics[i, 2, t, y, x] += F.cosine_similarity(
                                embeddings[i, j, t, :, y, x],
                                avg[i, t, :, y, x],
                                dim=0)

                else:
                    # L2-distance between raw embeddings
                    sum_metrics[i, 0, t] += torch.norm(
                        embeddings[i, j, t] - avg[i, t], 2)
                    # L2-distance between spatially averaged embeddings
                    sum_metrics[i, 1, t] += torch.norm(
                        embs_pool[i, j, t] - avg_pool[i, t], 2)
                    # L2-distance between L2-normalized embeddings
                    sum_metrics[i, 2, t] += torch.norm(
                        embs_norm[i, j, t] - avg_norm[i, t], 2)
                    # L2-distance between L2-normalized spatially averaged embeddings
                    sum_metrics[i, 3, t] += torch.norm(
                        embs_norm_pool[i, j, t] - avg_norm_pool[i, t], 2)
                    # Cosine similarity between raw embeddings, mean afterwards
                    sum_metrics[i, 4,
                                t] += F.cosine_similarity(embeddings[i, j, t],
                                                          avg[i, t],
                                                          dim=0).mean()
                    # Cosine similarity between spatially averaged embeddings
                    sum_metrics[i, 5, t] += F.cosine_similarity(embs_pool[i, j,
                                                                          t],
                                                                avg_pool[i, t],
                                                                dim=0)

    result = sum_metrics / paths  # average over number of paths
    return result
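A hypothetical usage sketch with made-up dimensions (2 clips, 4 sampled paths, 3 prediction steps, feature dim 8, 4x4 spatial grid):

embs = torch.randn(2, 4, 3, 8, 4, 4)  # (B, paths, pred_step, F, S, S)
flat = calculate_variance_tensor(embs)                               # (2, 6, 3)
per_cell = calculate_variance_tensor(embs, spatial_separation=True)  # (2, 3, 3, 4, 4)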