Example #1
def _get_random_mt_data(**tkwargs):
    train_x = torch.linspace(0, 0.95, 10, **tkwargs) + 0.05 * torch.rand(10, **tkwargs)
    train_y1 = torch.sin(train_x * (2 * math.pi)) + torch.randn_like(train_x) * 0.2
    train_y2 = torch.cos(train_x * (2 * math.pi)) + torch.randn_like(train_x) * 0.2
    train_i_task1 = torch.full_like(train_x, dtype=torch.long, fill_value=0)
    train_i_task2 = torch.full_like(train_x, dtype=torch.long, fill_value=1)
    full_train_x = torch.cat([train_x, train_x])
    full_train_i = torch.cat([train_i_task1, train_i_task2])
    full_train_y = torch.cat([train_y1, train_y2])
    train_X = torch.stack([full_train_x, full_train_i.type_as(full_train_x)], dim=-1)
    train_Y = full_train_y
    return train_X, train_Y
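Every example in this collection exercises torch.full_like(input, fill_value), which creates a tensor matching the input's shape, dtype, layout, and device, filled with a constant. A minimal sketch of that behavior:

import torch

x = torch.randn(2, 3, dtype=torch.float64)
y = torch.full_like(x, 0.5)
# the result mirrors the shape, dtype, and device of the reference tensor
assert y.shape == x.shape and y.dtype == x.dtype and y.device == x.device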
Example #2
def _get_fixed_noise_model_single_output(**tkwargs):
    train_X, train_Y = _get_random_mt_data(**tkwargs)
    train_Yvar = torch.full_like(train_Y, 0.05)
    model = FixedNoiseMultiTaskGP(
        train_X, train_Y, train_Yvar, task_feature=1, output_tasks=[1]
    )
    return model.to(**tkwargs)
Example #3
    def _get_model(self, batch_shape, num_outputs, n, **tkwargs):
        train_x, train_y = _get_random_data(
            batch_shape=batch_shape, num_outputs=num_outputs, n=n, **tkwargs
        )
        train_yvar = torch.full_like(train_y, 0.01)
        model = FixedNoiseGP(train_X=train_x, train_Y=train_y, train_Yvar=train_yvar)
        return model.to(**tkwargs)
Example #4
def _expand(bounds: Optional[Union[float, Tensor]], X: Tensor, lower: bool) -> Tensor:
    if bounds is None:
        ebounds = torch.full_like(X, float("-inf" if lower else "inf"))
    else:
        if not torch.is_tensor(bounds):
            bounds = torch.tensor(bounds)
        ebounds = bounds.expand_as(X)
    return _arrayify(ebounds).flatten()
Example #5
    def __init__(self, probs=None, logits=None, validate_args=None):
        if probs is not None:
            new_probs = torch.zeros_like(probs, dtype=torch.float)
            new_probs[torch.argmax(probs, dim=0)] = 1.0
            probs = new_probs
        elif logits is not None:
            new_logits = torch.full_like(logits, -1e8, dtype=torch.float)
            max_idx = torch.argmax(logits, dim=0)
            new_logits[max_idx] = logits[max_idx]
            logits = new_logits

        super(Argmax, self).__init__(probs=probs, logits=logits, validate_args=validate_args)
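For context, a minimal usage sketch, assuming Argmax subclasses torch.distributions.Categorical (consistent with the super().__init__(probs=..., logits=...) call above); the constructor makes sampling deterministic:

# hypothetical usage of the Argmax distribution above
dist = Argmax(logits=torch.tensor([0.1, 2.0, -1.0]))
sample = dist.sample()  # effectively always index 1: the other logits were pushed to -1e8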
Example #6
def standardize(X: Tensor) -> Tensor:
    r"""Standardize a tensor by dim=0.

    Args:
        X: A `n` or `n x d`-dim tensor

    Returns:
        The standardized `X`.

    Example:
        >>> X = torch.rand(4, 3)
        >>> X_standardized = standardize(X)
    """
    X_std = X.std(dim=0)
    X_std = X_std.where(X_std >= 1e-9, torch.full_like(X_std, 1.0))
    return (X - X.mean(dim=0)) / X_std
Example #7
    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= lr_decay:
            raise ValueError("Invalid lr_decay value: {}".format(lr_decay))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= initial_accumulator_value:
            raise ValueError("Invalid initial_accumulator_value value: {}".format(initial_accumulator_value))

        defaults = dict(lr=lr, lr_decay=lr_decay, weight_decay=weight_decay,
                        initial_accumulator_value=initial_accumulator_value)
        super(Adagrad, self).__init__(params, defaults)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['sum'] = torch.full_like(p.data, initial_accumulator_value)
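A hedged usage sketch of this constructor; it follows the standard torch.optim.Optimizer pattern, and the linear model here is only a stand-in:

model = torch.nn.Linear(4, 2)
opt = Adagrad(model.parameters(), lr=0.1, initial_accumulator_value=0.1)
# each parameter now has state['sum'] pre-filled via torch.full_like(p.data, 0.1)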
Example #8
    def test_single_task_batch_cv(self, cuda=False):
        n = 10
        for batch_shape in (torch.Size([]), torch.Size([2])):
            for num_outputs in (1, 2):
                for double in (False, True):
                    tkwargs = {
                        "device": torch.device("cuda") if cuda else torch.device("cpu"),
                        "dtype": torch.double if double else torch.float,
                    }
                    train_X, train_Y = _get_random_data(
                        batch_shape=batch_shape, num_outputs=num_outputs, n=n, **tkwargs
                    )
                    train_Yvar = torch.full_like(train_Y, 0.01)
                    noiseless_cv_folds = gen_loo_cv_folds(
                        train_X=train_X, train_Y=train_Y
                    )
                    # Test SingleTaskGP
                    cv_results = batch_cross_validation(
                        model_cls=SingleTaskGP,
                        mll_cls=ExactMarginalLogLikelihood,
                        cv_folds=noiseless_cv_folds,
                        fit_args={"options": {"maxiter": 1}},
                    )
                    expected_shape = batch_shape + torch.Size([n, 1, num_outputs])
                    self.assertEqual(cv_results.posterior.mean.shape, expected_shape)
                    self.assertEqual(cv_results.observed_Y.shape, expected_shape)

                    # Test FixedNoiseGP
                    noisy_cv_folds = gen_loo_cv_folds(
                        train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar
                    )
                    cv_results = batch_cross_validation(
                        model_cls=FixedNoiseGP,
                        mll_cls=ExactMarginalLogLikelihood,
                        cv_folds=noisy_cv_folds,
                        fit_args={"options": {"maxiter": 1}},
                    )
                    self.assertEqual(cv_results.posterior.mean.shape, expected_shape)
                    self.assertEqual(cv_results.observed_Y.shape, expected_shape)
Example #9
def _get_noiseless_fantasy_model(
    model: FixedNoiseGP, batch_X_observed: Tensor, Y_fantasized: Tensor
) -> FixedNoiseGP:
    r"""Construct a fantasy model from a fitted model and provided fantasies.

    The fantasy model uses the hyperparameters from the original fitted model and
    assumes the fantasies are noiseless.

    Args:
        model: a fitted FixedNoiseGP
        batch_X_observed: A `b x m x d` tensor of inputs where `b` is the number of
            fantasies.
        Y_fantasized: A `b x m` tensor of fantasized targets where `b` is the number of
            fantasies.

    Returns:
        The fantasy model.
    """
    # initialize a copy of FixedNoiseGP on the original training inputs
    # this makes FixedNoiseGP a non-batch GP, so that the same hyperparameters
    # are used across all batches (by default, a GP with batched training data
    # uses independent hyperparameters for each batch).
    fantasy_model = FixedNoiseGP(
        train_X=model.train_inputs[0],
        train_Y=model.train_targets,
        train_Yvar=model.likelihood.noise_covar.noise,
    )
    # update training inputs/targets to be batch mode fantasies
    fantasy_model.set_train_data(
        inputs=batch_X_observed, targets=Y_fantasized, strict=False
    )
    # use noiseless fantasies
    fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)
    # load hyperparameters from original model
    state_dict = deepcopy(model.state_dict())
    fantasy_model.load_state_dict(state_dict)
    return fantasy_model
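A minimal sketch of how this helper might be called, assuming `model` is an already fitted FixedNoiseGP; the sizes are hypothetical and follow the docstring shapes:

b, m, d = 4, 6, 2                       # b fantasies over m points in d dimensions
batch_X_observed = torch.rand(b, m, d)  # b x m x d
Y_fantasized = torch.rand(b, m)         # b x m
fantasy_model = _get_noiseless_fantasy_model(model, batch_X_observed, Y_fantasized)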
Example #10
    def _get_model(self, cuda=False, dtype=torch.float):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        state_dict = {
            "mean_module.constant": torch.tensor([-0.0066]),
            "covar_module.raw_outputscale": torch.tensor(1.0143),
            "covar_module.base_kernel.raw_lengthscale": torch.tensor([[-0.99]]),
            "covar_module.base_kernel.lengthscale_prior.concentration": torch.tensor(
                3.0
            ),
            "covar_module.base_kernel.lengthscale_prior.rate": torch.tensor(6.0),
            "covar_module.outputscale_prior.concentration": torch.tensor(2.0),
            "covar_module.outputscale_prior.rate": torch.tensor(0.1500),
        }
        train_x = torch.linspace(0, 1, 10, device=device, dtype=dtype)
        train_y = torch.sin(train_x * (2 * math.pi))
        noise = torch.tensor(NEI_NOISE, device=device, dtype=dtype)
        train_y += noise
        train_yvar = torch.full_like(train_y, 0.25 ** 2)
        train_x = train_x.view(-1, 1)
        model = FixedNoiseGP(train_X=train_x, train_Y=train_y, train_Yvar=train_yvar)
        model.load_state_dict(state_dict)
        model.to(train_x)
        model.eval()
        return model
Example #11
def build_targets(p, targets, model):
    # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
    det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
    na, nt = det.na, targets.shape[0]  # number of anchors, targets
    tcls, tbox, indices, anch, landmarks, lmks_mask = [], [], [], [], [], []
    #gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
    gain = torch.ones(17, device=targets.device)
    ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
    targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

    g = 0.5  # bias
    off = torch.tensor([[0, 0],
                        [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                        # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                        ], device=targets.device).float() * g  # offsets

    for i in range(det.nl):
        anchors = det.anchors[i]
        gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain
        #landmarks 10
        gain[6:16] = torch.tensor(p[i].shape)[[3, 2, 3, 2, 3, 2, 3, 2, 3, 2]]  # xyxy gain

        # Match targets to anchors
        t = targets * gain
        if nt:
            # Matches
            r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
            j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare
            # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
            t = t[j]  # filter

            # Offsets
            gxy = t[:, 2:4]  # grid xy
            gxi = gain[[2, 3]] - gxy  # inverse
            j, k = ((gxy % 1. < g) & (gxy > 1.)).T
            l, m = ((gxi % 1. < g) & (gxi > 1.)).T
            j = torch.stack((torch.ones_like(j), j, k, l, m))
            t = t.repeat((5, 1, 1))[j]
            offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
        else:
            t = targets[0]
            offsets = 0

        # Define
        b, c = t[:, :2].long().T  # image, class
        gxy = t[:, 2:4]  # grid xy
        gwh = t[:, 4:6]  # grid wh
        gij = (gxy - offsets).long()
        gi, gj = gij.T  # grid xy indices

        # Append
        a = t[:, 16].long()  # anchor indices
        indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))  # image, anchor, grid indices
        tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
        anch.append(anchors[a])  # anchors
        tcls.append(c)  # class

        #landmarks
        lks = t[:,6:16]
        #lks_mask = lks > 0
        #lks_mask = lks_mask.float()
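        # equivalent to (lks >= 0).float(): 1 where a landmark is present, 0 where it is missing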
        lks_mask = torch.where(lks < 0, torch.full_like(lks, 0.), torch.full_like(lks, 1.0))

        # The landmark coordinates should really be divided by the anchor width/height
        # instead, to make them easier for the model to learn; using gwh encodes
        # different landmarks differently, with no unified reference scale.

        lks[:, [0, 1]] = (lks[:, [0, 1]] - gij)
        lks[:, [2, 3]] = (lks[:, [2, 3]] - gij)
        lks[:, [4, 5]] = (lks[:, [4, 5]] - gij)
        lks[:, [6, 7]] = (lks[:, [6, 7]] - gij)
        lks[:, [8, 9]] = (lks[:, [8, 9]] - gij)

        '''
        #anch_w = torch.ones(5, device=targets.device).fill_(anchors[0][0])
        #anch_wh = torch.ones(5, device=targets.device)
        anch_f_0 = (a == 0).unsqueeze(1).repeat(1, 5)
        anch_f_1 = (a == 1).unsqueeze(1).repeat(1, 5)
        anch_f_2 = (a == 2).unsqueeze(1).repeat(1, 5)
        lks[:, [0, 2, 4, 6, 8]] = torch.where(anch_f_0, lks[:, [0, 2, 4, 6, 8]] / anchors[0][0], lks[:, [0, 2, 4, 6, 8]])
        lks[:, [0, 2, 4, 6, 8]] = torch.where(anch_f_1, lks[:, [0, 2, 4, 6, 8]] / anchors[1][0], lks[:, [0, 2, 4, 6, 8]])
        lks[:, [0, 2, 4, 6, 8]] = torch.where(anch_f_2, lks[:, [0, 2, 4, 6, 8]] / anchors[2][0], lks[:, [0, 2, 4, 6, 8]])

        lks[:, [1, 3, 5, 7, 9]] = torch.where(anch_f_0, lks[:, [1, 3, 5, 7, 9]] / anchors[0][1], lks[:, [1, 3, 5, 7, 9]])
        lks[:, [1, 3, 5, 7, 9]] = torch.where(anch_f_1, lks[:, [1, 3, 5, 7, 9]] / anchors[1][1], lks[:, [1, 3, 5, 7, 9]])
        lks[:, [1, 3, 5, 7, 9]] = torch.where(anch_f_2, lks[:, [1, 3, 5, 7, 9]] / anchors[2][1], lks[:, [1, 3, 5, 7, 9]])

        #new_lks = lks[lks_mask>0]
        #print('new_lks:   min --- ', torch.min(new_lks), '  max --- ', torch.max(new_lks))
        
        lks_mask_1 = torch.where(lks < -3, torch.full_like(lks, 0.), torch.full_like(lks, 1.0))
        lks_mask_2 = torch.where(lks > 3, torch.full_like(lks, 0.), torch.full_like(lks, 1.0))

        lks_mask_new = lks_mask * lks_mask_1 * lks_mask_2
        lks_mask_new[:, 0] = lks_mask_new[:, 0] * lks_mask_new[:, 1]
        lks_mask_new[:, 1] = lks_mask_new[:, 0] * lks_mask_new[:, 1]
        lks_mask_new[:, 2] = lks_mask_new[:, 2] * lks_mask_new[:, 3]
        lks_mask_new[:, 3] = lks_mask_new[:, 2] * lks_mask_new[:, 3]
        lks_mask_new[:, 4] = lks_mask_new[:, 4] * lks_mask_new[:, 5]
        lks_mask_new[:, 5] = lks_mask_new[:, 4] * lks_mask_new[:, 5]
        lks_mask_new[:, 6] = lks_mask_new[:, 6] * lks_mask_new[:, 7]
        lks_mask_new[:, 7] = lks_mask_new[:, 6] * lks_mask_new[:, 7]
        lks_mask_new[:, 8] = lks_mask_new[:, 8] * lks_mask_new[:, 9]
        lks_mask_new[:, 9] = lks_mask_new[:, 8] * lks_mask_new[:, 9]
        '''
        lks_mask_new = lks_mask
        lmks_mask.append(lks_mask_new)
        landmarks.append(lks)
        #print('lks: ',  lks.size())

    return tcls, tbox, indices, anch, landmarks, lmks_mask
Example #12
def compute_loss(p, targets, model):  # predictions, targets, model
    ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
    lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
    tcls, tbox, indices, anchors = build_targets(p, targets, model)  # targets
    h = model.hyp  # hyperparameters
    red = 'mean'  # Loss reduction (sum or mean)

    # Define criteria
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)

    # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    # focal loss
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    # per output
    nt = 0  # targets
    for i, pi in enumerate(p):  # layer index, layer predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tobj = torch.zeros_like(pi[..., 0])  # target obj

        nb = b.shape[0]  # number of targets
        if nb:
            nt += nb  # cumulative targets
            ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

            # GIoU
            pxy = ps[:, :2].sigmoid()
            pwh = ps[:, 2:4].exp().clamp(max=1E3) * anchors[i]
            pbox = torch.cat((pxy, pwh), 1)  # predicted box
            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False,
                            GIoU=True)  # giou(prediction, target)
            lbox += (1.0 - giou).sum() if red == 'sum' else (
                1.0 - giou).mean()  # giou loss

            # Obj
            tobj[b, a, gj,
                 gi] = (1.0 -
                        model.gr) + model.gr * giou.detach().clamp(0).type(
                            tobj.dtype)  # giou ratio

            # Class
            if model.nc > 1:  # cls loss (only if multiple classes)
                t = torch.full_like(ps[:, 5:], cn)  # targets
                t[range(nb), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE

            # Append targets to text file
            # with open('targets.txt', 'a') as file:
            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

        lobj += BCEobj(pi[..., 4], tobj)  # obj loss

    lbox *= h['giou']
    lobj *= h['obj']
    lcls *= h['cls']
    if red == 'sum':
        bs = tobj.shape[0]  # batch size
        g = 3.0  # loss gain
        lobj *= g / bs
        if nt:
            lcls *= g / nt / model.nc
            lbox *= g / nt

    loss = lbox + lobj + lcls
    return loss, torch.cat((lbox, lobj, lcls, loss)).detach()
Example #13
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead'
                    )
                amsbound = group['amsbound']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsbound:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                if group['eta'] != 0:
                    size = list(grad.size())
                    noise = np.random.normal(
                        0,
                        math.sqrt(group['eta']) /
                        math.pow(1 + state['step'], 0.55), size)
                    # match the gradient's dtype and device so the in-place
                    # update below also works for CUDA tensors
                    noise = torch.tensor(
                        noise, dtype=grad.dtype, device=grad.device)
                    grad.mul_(group['noise_coff']).add_(
                        1 - group['noise_coff'], noise)

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1**state['step']
                bias_correction2 = 1 - beta2**state['step']
                step_size = group['lr'] * math.sqrt(
                    bias_correction2) / bias_correction1

                # Applies bounds on actual learning rate
                # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
                final_lr = group['final_lr'] * group['lr'] / base_lr
                lower_bound = final_lr * (1 - 1 /
                                          (group['gamma'] * state['step'] + 1))
                upper_bound = final_lr * (1 + 1 /
                                          (group['gamma'] * state['step']))
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound,
                                             upper_bound).mul_(exp_avg)

                p.data.add_(-step_size)

        return loss
Example #14
    def test_full_like(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full_like(x, 2), x)
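Outside the test harness, roughly the same export can be sketched with the public torch.onnx.export API; the module name and buffer here are illustrative:

import io
import torch

class FullLikeModule(torch.nn.Module):
    def forward(self, x):
        return torch.full_like(x, 2)

buf = io.BytesIO()
# traces the module on the sample input and serializes the ONNX graph
torch.onnx.export(FullLikeModule(), torch.randn(3, 4), buf)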
Example #15
    def train_step_gen(self, training_batch: rlt.DiscreteDqnInput,
                       batch_idx: int):
        rewards = self.boost_rewards(training_batch.reward,
                                     training_batch.action)
        discount_tensor = torch.full_like(rewards, self.gamma)
        possible_next_actions_mask = training_batch.possible_next_actions_mask.float(
        )
        possible_actions_mask = training_batch.possible_actions_mask.float()

        not_terminal = training_batch.not_terminal.float()

        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            # pyre-fixme[16]: Optional type has no attribute `float`.
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.step.float())

        next_dist = self.q_network_target.log_dist(
            training_batch.next_state).exp()

        if self.maxq_learning:
            # Select distribution corresponding to max valued action
            if self.double_q_learning:
                next_q_values = (
                    self.q_network.log_dist(training_batch.next_state).exp() *
                    self.support).sum(2)
            else:
                next_q_values = (next_dist * self.support).sum(2)

            next_action = self.argmax_with_mask(next_q_values,
                                                possible_next_actions_mask)
            next_dist = next_dist[range(rewards.shape[0]),
                                  next_action.reshape(-1)]
        else:
            next_dist = (next_dist *
                         training_batch.next_action.unsqueeze(-1)).sum(1)

        # Build target distribution
        target_Q = rewards + discount_tensor * not_terminal * self.support
        target_Q = target_Q.clamp(self.qmin, self.qmax)

        # rescale to indices [0, 1, ..., N-1]
        b = (target_Q - self.qmin) / self.scale_support
        lo = b.floor().to(torch.int64)
        up = b.ceil().to(torch.int64)

        # handle corner cases of l == b == u
        # without the following, it would give 0 signal, whereas we want
        # m to add p(s_t+n, a*) to index l == b == u.
        # So we artificially adjust l and u.
        # (1) If 0 < l == u < N-1, we make l = l-1, so b-l = 1
        # (2) If 0 == l == u, we make u = 1, so u-b=1
        # (3) If l == u == N-1, we make l = N-2, so b-l = 1
        # This first line handles (1) and (3).
        lo[(up > 0) * (lo == up)] -= 1
        # Note: l has already changed, so the only way l == u is possible is
        # if u == 0, in which case we let u = 1
        # I don't even think we need the first condition in the next line
        up[(lo < (self.num_atoms - 1)) * (lo == up)] += 1

        # distribute the probabilities
        # m_l = m_l + p(s_t+n, a*)(u - b)
        # m_u = m_u + p(s_t+n, a*)(b - l)
        m = torch.zeros_like(next_dist)
        # pyre-fixme[16]: `Tensor` has no attribute `scatter_add_`.
        m.scatter_add_(dim=1, index=lo, src=next_dist * (up.float() - b))
        m.scatter_add_(dim=1, index=up, src=next_dist * (b - lo.float()))
        log_dist = self.q_network.log_dist(training_batch.state)

        # for reporting only
        all_q_values = (log_dist.exp() * self.support).sum(2).detach()
        model_action_idxs = self.argmax_with_mask(
            all_q_values,
            possible_actions_mask
            if self.maxq_learning else training_batch.action,
        )

        log_dist = (log_dist * training_batch.action.unsqueeze(-1)).sum(1)

        loss = -(m * log_dist).sum(1).mean()

        if batch_idx % self.trainer.log_every_n_steps == 0:
            self.reporter.log(
                td_loss=loss,
                logged_actions=torch.argmax(training_batch.action,
                                            dim=1,
                                            keepdim=True),
                logged_propensities=training_batch.extras.action_probability,
                logged_rewards=rewards,
                model_values=all_q_values,
                model_action_idxs=model_action_idxs,
            )
            self.log("td_loss", loss, prog_bar=True)

        yield loss
        result = self.soft_update_result()
        yield result
Example #16
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaMod does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Exponential moving average of actual learning rates
                    state['exp_avg_lr'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq, exp_avg_lr = state['exp_avg'], state[
                    'exp_avg_sq'], state['exp_avg_lr']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1**state['step']
                bias_correction2 = 1 - beta2**state['step']
                step_size = group['lr'] * math.sqrt(
                    bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p.data.add_(-group['weight_decay'] * group['lr'], p.data)

                # Applies momental bounds on actual learning rates
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom)
                exp_avg_lr.mul_(group['beta3']).add_(1 - group['beta3'],
                                                     step_size)
                step_size = torch.min(step_size, exp_avg_lr)
                step_size.mul_(exp_avg)

                p.data.add_(-step_size)

        return loss
Example #17
def get_dropout_mask(prob, x):
    # inverted dropout: keep each entry with probability (1 - prob), then rescale
    return Bernoulli(torch.full_like(x, 1 - prob)).sample() / (1 - prob)
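A hedged usage sketch of the mask above; x is a stand-in activation tensor:

x = torch.randn(4, 8)
mask = get_dropout_mask(0.3, x)  # entries are 0 with prob 0.3, else 1 / 0.7
x_dropped = x * mask             # E[x_dropped] == x, so no rescaling at eval time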
Example #18
def train_model(model, dataloaders, dataset_sizes, criterion_1, criterion_2,\
                optimizer, scheduler, gt_norm_dict, num_epoches=200):

    criterion_3 = nn.MSELoss()

    train_log_prefix = 'train_' + time.strftime("%m-%d-%H-%M",
                                                time.localtime())

    test_log_prefix = 'test_' + time.strftime("%m-%d-%H-%M", time.localtime())

    tensor_board_dir = train_log_prefix
    tensor_board_dir = os.path.join("./new_train_logs/", tensor_board_dir)
    os.makedirs(tensor_board_dir)

    best_model = None
    best_acc = 0
    best_pr, best_pr_pre, best_pr_rec = .0, .0, .0

    batch_size = 1

    tensorboard_writer = SummaryWriter(tensor_board_dir)

    nor_rev_std = torch.Tensor(np.diag(gt_norm_dict[:,
                                                    0].T)).double().to(device)
    nor_rev_men = torch.Tensor(gt_norm_dict[:, 1].T).double().to(device)

    loss_loc_param = 1.0
    loss_cls_param = 1.0

    min_loss = 1e4

    for epoch in tqdm(range(num_epoches)):
        print('-' * 70)
        print('Epoch: {}/{}'.format(epoch + 1, num_epoches))
        for phase in dataloaders.keys():
            if phase == 'train':
                model.train()
            elif phase == 'val' or phase == 'test':
                model.eval()
            else:
                raise ValueError("WRONG DATALOADER PHASE: {}".format(phase))

            running_loss = 0.0
            running_loss_cls = 0.0
            running_loss_loc = 0.0

            sample_num = 0
            dis_thresh = 0.45
            theta_thresh = np.pi / 9

            acc_den, acc_num = 0, 0
            iter_idx = 0

            for pcd_id, data in tqdm(dataloaders[phase]):
                iter_idx += 1
                inputs = data['input_data'].double().to(device)
                props = data['proposal'].double().to(device)
                gts = data['ground_truth'].double().to(device).view(1, -1, 13)

                if props.size(1) == 0:
                    continue

                optimizer.zero_grad()
                if torch.sum(gts[:, :, 0]).item() == 0:
                    continue

                with torch.set_grad_enabled(phase == 'train'):

                    inputs.transpose_(1, 2)

                    loc_pre, cls_pre = model(inputs, props)

                    cls_ = torch.sigmoid(cls_pre)

                    cls_ = torch.clamp(cls_, min=0.0001, max=1.0)
                    cls_ = cls_.view(1, -1).squeeze(0)
                    loc_ = loc_pre.view(-1, 4)
                    loc_ = torch.mm(loc_, nor_rev_std) + nor_rev_men

                    with torch.no_grad():

                        cls_np = cls_.cpu().numpy()
                        if not (np.logical_and(
                                cls_np > np.zeros_like(cls_np),
                                cls_np < np.ones_like(cls_np))).all():
                            print(cls_)
                            continue

                        pre_cls = (cls_ > 0.5)
                        pre_cls = pre_cls.long().to(device)
                        gt_cls = gts[:, :, 0].long().view(1, -1).squeeze(0)

                        theta_ = torch.atan2(loc_[:, 3], loc_[:,
                                                              2]).view(1, -1)

                        comp = (pre_cls == gt_cls)
                        all_truth = torch.full_like(gt_cls, 1).to(device)
                        all_false = torch.full_like(gt_cls, 0).to(device)

                        dis_err = (gts[:, :, 1:3] - loc_[:, 0:2]).view(
                            -1, 2).to(device)
                        dis_err = torch.sqrt(dis_err[:, 0]**2 +
                                             dis_err[:, 1]**2)

                        dis_positive = (dis_err < dis_thresh)
                        theta_err = (gts[:, :, 9] - theta_).squeeze()

                        the_positive_1 = torch.abs(theta_err) <= theta_thresh
                        the_positive_2 = torch.abs(theta_err +
                                                   np.pi) <= theta_thresh
                        the_positive_3 = torch.abs(theta_err -
                                                   np.pi) <= theta_thresh

                        theta_positive = (the_positive_1 | the_positive_2
                                          | the_positive_3)

                        dis_negative = ~dis_positive
                        theta_negative = ~theta_positive

                        cls_positive = (pre_cls == all_truth)
                        cls_negative = ~cls_positive

                        pre_positive = (cls_positive & dis_positive
                                        & theta_positive)
                        pre_negative = ~pre_positive

                        gt_positive = (gt_cls == all_truth)

                        acc_den += torch.sum(all_truth).item()
                        acc_num += torch.sum(gt_cls == pre_cls).item()

                    loss_cls_param = 1.5
                    loss_loc_param = 1.5

                    loc_.unsqueeze_(0)
                    if cls_.size() == torch.Size([]):
                        cls_ = cls_.unsqueeze(0)

                    loss_cls = criterion_1(cls_, gt_cls.double())

                    if torch.sum(gt_positive).item() > 0:

                        loss_loc = criterion_2(loc_[:, gt_positive, 0:4], gts[:, gt_positive, 1:5]) + \
                                   criterion_3(loc_[:, gt_positive, 0:4], gts[:, gt_positive, 1:5])
                    else:
                        loss_loc = 0
                    loss = loss_loc_param * loss_loc + loss_cls_param * loss_cls

                    if torch.isnan(loss):
                        print(loss_cls)
                        print(loss_loc)
                        print(loc_[:, gt_positive, 0:6])
                        print(gts[:, gt_positive, 1:7])

                    if phase == 'train':

                        loss.backward()

                        clip_grad_value_(model.parameters(), 5)

                        optimizer.step()

                running_loss += loss.item()
                running_loss_cls += loss_cls.item()
                running_loss_loc += loss_loc.item()
                sample_num += gts.size(1)

            # print(sample_num)
            epoch_loss = running_loss / dataset_sizes[phase]

            print('')
            print("Binary Classification Acc: {:.4f}".format(acc_num /
                                                             (acc_den + 1e-3)))
            print("Loss: {:.4f}".format(epoch_loss))

            if phase == 'train':
                tensorboard_writer.add_scalar(
                    "{}/runnning_loss".format(train_log_prefix),
                    running_loss / dataset_sizes[phase], epoch)
                tensorboard_writer.add_scalar(
                    "{}/loc loss".format(train_log_prefix),
                    running_loss_loc / dataset_sizes[phase], epoch)
                tensorboard_writer.add_scalar(
                    "{}/cls loss".format(train_log_prefix),
                    running_loss_cls / dataset_sizes[phase], epoch)
                tensorboard_writer.flush()

            elif phase == 'test':
                tensorboard_writer.add_scalar(
                    "{}/runnning_loss".format(test_log_prefix),
                    running_loss / dataset_sizes[phase], epoch)
                tensorboard_writer.add_scalar(
                    "{}/loc loss".format(test_log_prefix),
                    running_loss_loc / dataset_sizes[phase], epoch)
                tensorboard_writer.add_scalar(
                    "{}/cls loss".format(test_log_prefix),
                    running_loss_cls / dataset_sizes[phase], epoch)
                tensorboard_writer.flush()

            if phase == 'test' and epoch_loss < min_loss:

                min_loss = epoch_loss
                best_ws = copy.deepcopy(model.state_dict())
                best_epoch = epoch + 1
                torch.save(
                    best_ws, './' + tensor_board_dir + '/' + train_log_prefix +
                    'model_train_best_check.pth')
                print("best_check_point {} saved.".format(epoch))

        scheduler.step()

    print('Best epoch num: {}'.format(best_epoch))
    model.load_state_dict(best_ws)
    return model
Example #19
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)

    list_dloss = []
    list_gloss = []
    list_rloss = []
    for epoch in range(5):
        netG.train()
        netD.train()
        for i, (imgs, masks) in enumerate(dataloader):
            # Optimize Discriminator
            optD.zero_grad()
            netD.zero_grad()
            netG.zero_grad()
            optG.zero_grad()
            # mask is 1 on masked region
            imgs, masks = imgs.to(device), masks.to(device)

            coarse_imgs, recon_imgs = netG(imgs, masks)
            # vutils.save_image(imgs,
            #                   '%s/real_samples.png' % 'output',
            #                   normalize=True)
            # time.sleep(10000)

            complete_imgs = recon_imgs * masks + imgs * (1 - masks)

            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [complete_imgs, masks,
                 torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
            d_loss = DLoss(pred_pos, pred_neg)
            d_loss.backward(retain_graph=True)

            optD.step()

            # Optimize Generator
            optD.zero_grad()
            netD.zero_grad()
            optG.zero_grad()
            netG.zero_grad()
            pred_neg = netD(neg_imgs)
            # pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
            g_loss = GANLoss(pred_neg)
            r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

            whole_loss = g_loss + r_loss

            # Update the recorder for losses
            whole_loss.backward()

            optG.step()

            if i % 20 == 0:
                print('[%d/%d][%d/%d] d_loss: %.4f g_loss: %.4f r_loss: %.4f' %
                      (epoch, 5, i, len(dataloader), d_loss.item(),
                       g_loss.item(), r_loss.item()))

                list_dloss.append(d_loss.item())
                list_gloss.append(g_loss.item())
                list_rloss.append(r_loss.item())

            if i % 100 == 0:
                vutils.save_image(imgs,
                                  '%s/real_samples.png' % 'output',
                                  normalize=True)
                output_trans = torch.ones_like(recon_imgs)
                # use a separate loop variable so the batch index i is not clobbered
                for j in range(args.batchsize):
                    # output_trans[j] = in_transform(output[j])
                    output_trans[j][0] = recon_imgs[j][0] * opt.STD[0] + opt.MEAN[0]
                    output_trans[j][1] = recon_imgs[j][1] * opt.STD[1] + opt.MEAN[1]
                    output_trans[j][2] = recon_imgs[j][2] * opt.STD[2] + opt.MEAN[2]

                masks = 1 - masks
                # aaa = in_transform(imgs[0])
                # result = aaa.cpu() * masks[0].cpu() + output_trans[0].cpu() * (1 - masks[0].cpu())
                # result = transforms.ToPILImage()(result.cpu()).convert('RGB')
                # result.show()
                # aaa = in_transform(imgs[1])
                # result = aaa.cpu() * masks[1].cpu() + output_trans[1].cpu() * (1 - masks[1].cpu())
                # result = transforms.ToPILImage()(result.cpu()).convert('RGB')
                # result.show()
                # aaa = in_transform(imgs[2])
                # result = aaa.cpu() * masks[2].cpu() + output_trans[2].cpu() * (1 - masks[2].cpu())
                # result = transforms.ToPILImage()(result.cpu()).convert('RGB')
                # result.show()

                vutils.save_image(output_trans,
                                  '%s/output%03d.png' %
                                  ('output', epoch + 585),
                                  normalize=True)
                vutils.save_image(1 - masks,
                                  '%s/mask%03d.png' % ('output', epoch + 585),
                                  normalize=True)

    torch.save(netG, 'netG.pth')
    torch.save(netD, 'netD.pth')

    f = open("txt/dis_loss.txt", "a+")
    for a in list_dloss:
        f.write(str(a) + '\n')
    f.close()

    f = open("txt/gan_loss.txt", "a+")
    for a in list_gloss:
        f.write(str(a) + '\n')
    f.close()

    f = open("txt/recon_loss.txt", "a+")
    for a in list_rloss:
        f.write(str(a) + '\n')
    f.close()
Example #20
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                # state initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['sum'] = torch.full_like(
                        p.data, group['initial_accumulator_value'])
                    if self._is_share_memory:
                        state['sum'].share_memory_()

                state['step'] += 1

                if group['weight_decay'] != 0:
                    if p.grad.data.is_sparse:
                        raise RuntimeError(
                            "weight_decay option is not compatible with sparse gradients"
                        )
                    grad = grad.add(group['weight_decay'], p.data)

                clr = group['lr'] / (1 +
                                     (state['step'] - 1) * group['lr_decay'])

                if grad.is_sparse:
                    grad = grad.coalesce(
                    )  # the update is non-linear so indices must be unique
                    grad_indices = grad._indices()
                    grad_values = grad._values()
                    size = grad.size()

                    def make_sparse(values):
                        constructor = grad.new
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor().resize_as_(grad)
                        return constructor(grad_indices, values, size)

                    state['sum'].add_(make_sparse(grad_values.pow(2)))
                    std = state['sum'].sparse_mask(grad)
                    std_values = std._values().sqrt_().add_(1e-10)
                    p.data.add_(-clr, make_sparse(grad_values / std_values))
                else:
                    state['sum'].addcmul_(1, grad, grad)
                    std = state['sum'].sqrt().add_(1e-10)
                    p.data.addcdiv_(-clr, grad, std)

        return loss
Example #21
import torch

a = torch.Tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])

print(a.shape)
print(a[:, ::2])
print(a > 0)
print(a[0][1])
print(a[a > 2])

# with as_tuple=False, nonzero returns a single (N, ndim) tensor of index rows
indexes = torch.nonzero(a, as_tuple=False)
print(indexes.shape)
print(indexes)

for index in indexes:
    # we now have the index of one nonzero element
    print(a[index[0]][index[1]][index[2]])
print(torch.full_like(a, 1))

print(a.shape)

# torch.where with a single condition returns per-dimension index tensors
i, j, k = torch.where(a > 2)
print(a[i, j, k])  # advanced indexing gathers the matching elements
Example #22
    def forward_pyramid(self, pos, anchor,
                        mask):  # pos [1, 1536, 52, 52], anchor [1, 1536, 13, 13]
        # Reproduce the scale pyramid: multi-scale cat [x2, x3] plus mask resizing.
        # First convolve the features, then downsample the mask and multiply it
        # with the features; afterwards compute the adjacency matrix and run the
        # graph-convolution steps.
        b, dim, w, h = mask.size()
        b1, dim1, w1, h1 = pos.size()  # [1,1536,52,52]
        _, _, w2, h2 = anchor.size()
        ######################theta_s,phi_q, g_s,g_q#####################################
        theta_s = self.theta(pos)  # [1, 512, 52, 52]
        phi_q = self.phi(anchor)  # [1, 512, 13, 13] [1, 512, 52, 52]

        theta_s_1 = theta_s.view(1, -1, w1 * h1)  # [1, 512, 2704]
        phi_q_1 = phi_q.view(1, w2 * h2, -1)  # [1, 169, 512]

        g_s = self.g(pos)  # [1, 512, 13, 13]
        g_q = self.g(anchor)  # [1, 512, 13, 13][1, 512, 52, 52]

        ############## mean-pooling part ################################
        # When no downscaling is applied, this amounts to computing the mean
        # similarity over the feature map.

        ####################################################################

        # ################# compute the weights e_s ############################################
        e_s = torch.matmul(phi_q_1, theta_s_1)  # [1, 169,2704] [1, 2704, 2704]
        ######################### normalization #########################################
        pos_mask = F.interpolate(mask, [w1, h1],
                                 mode='bilinear',
                                 align_corners=False)  # [1, 1, 52, 52]
        e_s_0 = e_s.reshape(w2 * h2, 1, w1 * h1)  # [169, 1, 2704]
        pos_mask_0 = pos_mask.reshape(1, w1 * h1)  # [1,2704]
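        # scores outside the mask are set to -inf so the softmax below assigns them zero weight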
        e_s_masked = torch.where(pos_mask_0 == 1, e_s_0,
                                 torch.full_like(
                                     e_s_0, float('-inf')))  # [169, 1, 2704]
        e_s = F.softmax(e_s_masked, dim=-1).reshape(1, w2 * h2,
                                                    w1 * h1)  # [169, 1, 2704]
        #######################################################################
        g_s_0 = g_s.view(1, w1 * h1, -1)  # [1, 2704, 512]
        v_q = torch.matmul(e_s, g_s_0)  # [1, 169, 512][1,2704,512]
        v_q = v_q.view(1, -1, w2, h2)  # [1, 512, 13, 13]

        finall_q = self.eth(torch.cat([g_q, v_q], 1))  # [1, 512, 13, 13]
        if w2 == w1:
            pos_node = F.interpolate(theta_s,
                                     size=(w, h),
                                     mode='bilinear',
                                     align_corners=False)  # [1, 512, 416, 416]
            vec_pos = torch.sum(torch.sum(pos_node * mask, dim=3),
                                dim=2) / torch.sum(mask)  # [1, 512]
            vec_pos = vec_pos.unsqueeze(dim=2).repeat(1, 1,
                                                      w1 * h1)  # [1,512,52*52]
            ##### compute the weights
            e_s_1 = torch.matmul(phi_q_1, vec_pos)  # [1,16,16]
            e_s_1 = e_s_1 / torch.sum(e_s_1, 2)
            ########### compute g_s
            g_s_1 = F.interpolate(g_s,
                                  size=(w, h),
                                  mode='bilinear',
                                  align_corners=False)
            g_pos = torch.sum(torch.sum(g_s_1 * mask, dim=3),
                              dim=2) / torch.sum(mask)  # [1, 512]
            g_pos = g_pos.unsqueeze(dim=2).repeat(1, 1,
                                                  w1 * h1)  # [1,512,52*52]
            v_q_1 = torch.matmul(e_s_1, g_pos.reshape(1, w1 * h1, -1))
            v_q_1 = v_q_1.view(1, -1, w2, h2)
            finall_q_1 = self.eth(torch.cat([g_q, v_q_1],
                                            1))  # [1, 512, 52, 52]
            return finall_q_1 + finall_q
        # ############## mean-pooling part ################################
        # pos_node = F.interpolate(theta_s, size=(w, h), mode='bilinear', align_corners=False)  # [1, 512, 416, 416]
        # vec_pos = torch.sum(torch.sum(pos_node * mask, dim=3), dim=2) / torch.sum(mask)  # [1, 512]
        # vec_pos = vec_pos.unsqueeze(dim=2).repeat(1, 1, w1 * h1)  # [1,512,52*52]
        # # vec_pos = torch.sum(torch.sum(theta_s, dim=3), dim=2)
        # ##### compute the weights
        # e_s_1 = torch.matmul(phi_q_1, vec_pos)
        # ########### compute g_s
        # g_s_1 = F.interpolate(g_s, size=(w, h), mode='bilinear', align_corners=False)
        # g_pos = torch.sum(torch.sum(g_s_1 * mask, dim=3), dim=2) / torch.sum(mask)  # [1, 512]
        # g_pos = g_pos.unsqueeze(dim=2).repeat(1, 1, w1 * h1)  # [1,512,52*52]
        # v_q_1 = torch.matmul(e_s_1, g_pos.reshape(1, w1*h1, -1))
        # v_q_1 = v_q_1.view(1, -1, w2, h2)
        # finall_q_1 = self.eth(torch.cat([g_q, v_q_1], 1))  # [1, 512, 52, 52]
        # ######################################################################################
        # return finall_q + finall_q_1

        # if w2 == h2 == 52:
        #     pos_node = F.interpolate(theta_s, size=(w, h), mode='bilinear', align_corners=False)  # [1, 512, 416, 416]
        #     vec_pos = torch.sum(torch.sum(pos_node * mask, dim=3), dim=2) / torch.sum(mask)  # [1, 512]
        #     vec_pos = vec_pos.unsqueeze(dim=2).unsqueeze(dim=3).repeat(1, 1, w2, h2)  # [1,512,52,52]
        #     ##### every weight is the same, so the weighted sum is just the vector itself
        #     # vec_q = self.g(vec_pos)
        #     vec_q = self.eth(torch.cat([g_q, vec_pos], 1))  # [1, 512, 52, 52]
        #     ######################
        #     return finall_q + vec_q
        return finall_q
Example #23
    def forward_resize_feature(
            self, pos, anchor,
            mask):  # pos [1, 1536, 52, 52], anchor [1, 1536, 13, 13]
        # Reproduce the scale pyramid: multi-scale cat [x2, x3].
        # First convolve the features, then downsample the mask and multiply it
        # with the features; afterwards compute the adjacency matrix and run the
        # graph-convolution steps.
        b, dim, w, h = mask.size()
        b1, dim1, w1, h1 = pos.size()  # [1,1536,52,52]
        _, _, w2, h2 = anchor.size()
        theta_s = self.theta(pos)  # [1, 512, 47, 63]
        #################### upsample the support features first, multiply by the mask, then downsample ################
        theta_s = F.interpolate(theta_s, [w, h],
                                mode='bilinear',
                                align_corners=False)  # [1, 512, 374, 500]
        theta_s = theta_s * mask  # [1, 512, 374, 500]
        theta_s = F.interpolate(theta_s, [w1, h1],
                                mode='bilinear',
                                align_corners=False)  # [1, 512, 47, 63]
        inf_mask = torch.where(theta_s[0][0] == 0,
                               torch.full_like(theta_s[0][0], 1),
                               torch.full_like(theta_s[0][0],
                                               0))  # [1, 512, 47, 63]
        ################# compute the weights e_s ###############################################################
        theta_s_1 = theta_s.view(1, -1,
                                 w1 * h1)  # [1, 512, 2961][1, 512, 2704]
        phi_q = self.phi(anchor)  # [1, 512, 13, 13] [1, 512, 52, 52]
        phi_q_1 = phi_q.view(1, w2 * h2, -1)  # [1, 169, 512]
        e_s = torch.matmul(phi_q_1, theta_s_1)  # [1, 169,2704] [1, 2704, 2704]
        ################ normalize the weights #########
        # import matplotlib.pyplot as plt
        # plt.imshow(mask[0][0].cpu().detach().numpy())
        # plt.show()
        e_s_0 = e_s.reshape(w2 * h2, 1, w1 * h1)  # [169, 1, 2704]
        inf_mask = inf_mask.reshape(1, w1 * h1)  # [1,2704]
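        # positions flagged by inf_mask (zeroed-out support features) get -inf so softmax ignores them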
        e_s_masked = torch.where(inf_mask == 1,
                                 torch.full_like(e_s_0, float('-inf')),
                                 e_s_0)  # [169, 1, 2704]
        e_s = F.softmax(e_s_masked, dim=-1).reshape(1, w2 * h2,
                                                    w1 * h1)  # [169, 1, 2704]
        print(e_s.sum())
        ###################### compute the final q #######################################
        g_s = self.g(pos).view(1, w1 * h1, -1)  # [1, 2704, 512]
        g_q = self.g(anchor)  # [1, 512, 13, 13][1, 512, 52, 52]

        v_q = torch.matmul(e_s, g_s)  # [1, 169, 512][1,2704,512]
        v_q = v_q.view(1, -1, w2, h2)  # [1, 512, 13, 13]
        finall_q = self.eth(torch.cat([g_q, v_q], 1))  # [1, 512, 13, 13]

        ################ mean-pooling part ###############################
        if w2 == h2 == 52:
            pos_node = F.interpolate(theta_s,
                                     size=(w, h),
                                     mode='bilinear',
                                     align_corners=False)  # [1, 512, 416, 416]
            vec_pos = torch.sum(torch.sum(pos_node * mask, dim=3),
                                dim=2) / torch.sum(mask)  # [1, 512]
            vec_pos = vec_pos.unsqueeze(dim=2).unsqueeze(dim=3).repeat(
                1, 1, w2, h2)  # [1,512,52,52]
            ##### every weight is the same, so the weighted sum is just the vector itself
            # vec_q = self.g(vec_pos)
            vec_q = self.eth(torch.cat([g_q, vec_pos], 1))  # [1, 512, 52, 52]
            ######################
            return finall_q + vec_q
        return finall_q
Example #24
                    def run(self, test_idx, test_at_steps=None):
                        assert test_at_steps is None

                        logging.info(
                            'Test #{} nearest neighbor classification with k={} and {}-norm'
                            .format(test_idx, k, p))

                        state = self.state

                        with state.pretend(distill_epochs=1):
                            ref_data_label = tuple(get_data_label(state))

                        ref_flat_data = torch.cat(
                            [d for d, _ in ref_data_label], 0).flatten(1)
                        #print(ref_flat_data.shape)
                        ref_label = torch.cat([l for _, l in ref_data_label],
                                              0)

                        assert k <= ref_label.size(0), (
                            'k={} is greater than the number of data {}. '
                            'Set k to the latter').format(
                                k, ref_label.size(0))

                        total = np.array(0, dtype=np.int64)
                        corrects = np.array(0, dtype=np.int64)
                        for example in state.test_loader:
                            if state.textdata:
                                data = example.text[0]
                                target = example.label
                            else:
                                (data, target) = example
                            data = data.to(state.device, non_blocking=True)
                            target = target.to(state.device, non_blocking=True)
                            #print(data.shape)
                            if state.textdata:
                                data = encode(data, state)
                                data.unsqueeze_(1)
                                dists = torch.norm(
                                    data.flatten(1)[:, None, ...] -
                                    ref_flat_data,
                                    dim=2,
                                    p=p)
                            else:
                                dists = torch.norm(
                                    data.flatten(1)[:, None, ...] -
                                    ref_flat_data,
                                    dim=2,
                                    p=p)
                            if k == 1:
                                argmin_dist = dists.argmin(dim=1)
                                if state.num_classes == 2:
                                    pred = (ref_label[argmin_dist] > 0.5).to(
                                        target.dtype).view(-1)
                                else:
                                    pred = ref_label[argmin_dist].argmax(1)
                                del argmin_dist
                            else:
                                _, argmink_dist = torch.topk(dists,
                                                             k,
                                                             dim=1,
                                                             largest=False,
                                                             sorted=False)
                                labels = ref_label[argmink_dist]
                                #print(labels.shape)
                                #print(labels[0].argmax(1).shape)
                                #print(labels)
                                counts = [
                                    torch.bincount(torch.as_tensor(
                                        list(map(int, l > 0.5)),
                                        device=state.device)
                                                   if state.num_classes == 2
                                                   else l.argmax(1),
                                                   minlength=state.num_classes)
                                    for l in labels
                                ]
                                counts = torch.stack(counts, 0)
                                #print(counts.shape)
                                pred = counts.argmax(dim=1)
                                del argmink_dist, labels, counts
                            corrects += (pred == target).sum().item()
                            total += data.size(0)

                        at_steps = torch.ones(1,
                                              dtype=torch.long,
                                              device=state.device)
                        acc = torch.as_tensor(corrects / total,
                                              device=state.device).view(
                                                  1, 1)  # STEP x MODEL
                        loss = torch.full_like(acc, utils.nan)  # STEP x MODEL
                        return (at_steps, acc, loss)
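For reference, the pairwise-distance trick used above broadcasts an [N, 1, D] test batch against an [M, D] reference set to get an [N, M] distance matrix. A minimal standalone sketch with assumed sizes:

import torch

test = torch.rand(4, 3)   # 4 test points, 3 features (illustrative sizes)
ref = torch.rand(6, 3)    # 6 reference points
dists = torch.norm(test[:, None, :] - ref, dim=2, p=2)  # [4, 6] pairwise L2
pred_idx = dists.argmin(dim=1)  # index of each test point's nearest neighbor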
예제 #25
0
    def test_FixedNoiseMultiTaskGP(self):
        bounds = torch.tensor([[-1.0, 0.0], [1.0, 1.0]])
        for dtype, use_intf in itertools.product((torch.float, torch.double),
                                                 (False, True)):
            tkwargs = {"device": self.device, "dtype": dtype}
            intf = (Normalize(
                d=2, bounds=bounds.to(
                    **tkwargs), transform_on_train=True) if use_intf else None)
            model, train_X, _, _ = _get_fixed_noise_model_and_training_data(
                input_transform=intf, **tkwargs)
            self.assertIsInstance(model, FixedNoiseMultiTaskGP)
            self.assertEqual(model.num_outputs, 2)
            self.assertIsInstance(model.likelihood,
                                  FixedNoiseGaussianLikelihood)
            self.assertIsInstance(model.mean_module, ConstantMean)
            self.assertIsInstance(model.covar_module, ScaleKernel)
            matern_kernel = model.covar_module.base_kernel
            self.assertIsInstance(matern_kernel, MaternKernel)
            self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
            self.assertIsInstance(model.task_covar_module, IndexKernel)
            self.assertEqual(model._rank, 2)
            self.assertEqual(model.task_covar_module.covar_factor.shape[-1],
                             model._rank)
            if use_intf:
                self.assertIsInstance(model.input_transform, Normalize)

            # test model fitting
            mll = ExactMarginalLogLikelihood(model.likelihood, model)
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=OptimizationWarning)
                mll = fit_gpytorch_model(mll,
                                         options={"maxiter": 1},
                                         max_retries=1)

            # check that the train inputs have been transformed and set on the model
            if use_intf:
                self.assertTrue(
                    torch.equal(model.train_inputs[0],
                                model.input_transform(train_X)))

            # test posterior
            test_x = torch.rand(2, 1, **tkwargs)
            posterior_f = model.posterior(test_x)
            self.assertIsInstance(posterior_f, GPyTorchPosterior)
            self.assertIsInstance(posterior_f.mvn, MultitaskMultivariateNormal)
            self.assertEqual(posterior_f.mean.shape, torch.Size([2, 2]))
            self.assertEqual(posterior_f.variance.shape, torch.Size([2, 2]))

            # test that posterior w/ observation noise raises appropriate error
            with self.assertRaises(NotImplementedError):
                model.posterior(test_x, observation_noise=True)
            with self.assertRaises(NotImplementedError):
                model.posterior(test_x,
                                observation_noise=torch.rand(2, **tkwargs))

            # test posterior w/ single output index
            posterior_f = model.posterior(test_x, output_indices=[0])
            self.assertIsInstance(posterior_f, GPyTorchPosterior)
            self.assertIsInstance(posterior_f.mvn, MultivariateNormal)
            self.assertEqual(posterior_f.mean.shape, torch.Size([2, 1]))
            self.assertEqual(posterior_f.variance.shape, torch.Size([2, 1]))

            # test posterior w/ bad output index
            with self.assertRaises(ValueError):
                model.posterior(test_x, output_indices=[2])

            # test posterior (batch eval)
            test_x = torch.rand(3, 2, 1, **tkwargs)
            posterior_f = model.posterior(test_x)
            self.assertIsInstance(posterior_f, GPyTorchPosterior)
            self.assertIsInstance(posterior_f.mvn, MultitaskMultivariateNormal)

            # test that unsupported batch shape MTGPs throw correct error
            with self.assertRaises(ValueError):
                FixedNoiseMultiTaskGP(torch.rand(2, 2, 2), torch.rand(2, 2, 1),
                                      torch.rand(2, 2, 1), 0)

            # test that bad feature index throws correct error
            train_X, train_Y = _get_random_mt_data(**tkwargs)
            train_Yvar = torch.full_like(train_Y, 0.05)
            with self.assertRaises(ValueError):
                FixedNoiseMultiTaskGP(train_X, train_Y, train_Yvar, 2)

            # test that bad output task throws correct error
            with self.assertRaises(RuntimeError):
                FixedNoiseMultiTaskGP(train_X,
                                      train_Y,
                                      train_Yvar,
                                      0,
                                      output_tasks=[2])
예제 #26
0
def compute_loss(p, targets, model):  # predictions, targets, model
    device = targets.device
    lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(
        1, device=device), torch.zeros(1, device=device)
    tcls, tbox, indices, anchors = build_targets(p, targets, model)  # targets
    h = model.hyp  # hyperparameters

    # Define criteria
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(
        [h['cls_pw']], device=device))  # weight=model.class_weights)
    BCEobj = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([h['obj_pw']], device=device))

    # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    # Focal loss
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    # Losses
    balance = [4.0, 1.0, 0.4, 0.1]  # P3-P6
    for i, pi in enumerate(p):  # layer index, layer predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj

        n = b.shape[0]  # number of targets
        if n:
            ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

            # Regression
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2)**2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1)  # predicted box
            iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False,
                           CIoU=True)  # iou(prediction, target)
            lbox += (1.0 - iou).mean()  # iou loss

            # Objectness
            tobj[b, a, gj,
                 gi] = (1.0 -
                        model.gr) + model.gr * iou.detach().clamp(0).type(
                            tobj.dtype)  # iou ratio

            # Classification
            if model.nc > 1:  # cls loss (only if multiple classes)
                t = torch.full_like(ps[:, 5:], cn, device=device)  # targets
                t[range(n), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE

            # Append targets to text file
            # with open('targets.txt', 'a') as file:
            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

        lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj loss

    lbox *= h['box']
    lobj *= h['obj']
    lcls *= h['cls']
    bs = tobj.shape[0]  # batch size

    loss = lbox + lobj + lcls
    return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
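smooth_BCE is not shown in this snippet; in the ultralytics code this loss comes from, it returns the positive/negative target values for class label smoothing (eqn 3 of the paper linked above). A minimal sketch, assuming that convention:

def smooth_BCE(eps=0.1):
    # positive and negative BCE targets; eps=0.0 (as used above) disables smoothing
    return 1.0 - 0.5 * eps, 0.5 * eps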
예제 #27
0
def aa_center_target(gt_bboxes_list,
                     featmap_sizes,
                     anchor_scale,
                     anchor_strides,
                     center_ratio=0.2,
                     ignore_ratio=0.5):
    """
    Args:
        gt_bboxes_list (list[Tensor]): Gt bboxes of each image.
        featmap_sizes (list[tuple]): Multi level sizes of each feature maps.
        anchor_scale (int): Anchor scale.
        anchor_strides ([list[int]]): Multi level anchor strides.
        center_ratio (float): Ratio of center region.
        ignore_ratio (float): Ratio of ignore region.

    Returns:
        tuple
    """
    img_per_gpu = len(gt_bboxes_list)
    num_lvls = len(featmap_sizes)
    r1 = (1 - center_ratio) / 2
    r2 = (1 - ignore_ratio) / 2
    all_center_targets = []
    all_center_weights = []
    all_ignore_map = []
    for lvl_id in range(num_lvls):
        h, w = featmap_sizes[lvl_id]
        center_targets = torch.zeros(img_per_gpu,
                                     1,
                                     h,
                                     w,
                                     device=gt_bboxes_list[0].device,
                                     dtype=torch.float32)
        center_weights = torch.full_like(center_targets, -1)
        ignore_map = torch.zeros_like(center_targets)
        all_center_targets.append(center_targets)
        all_center_weights.append(center_weights)
        all_ignore_map.append(ignore_map)
    for img_id in range(img_per_gpu):
        gt_bboxes = gt_bboxes_list[img_id]
        scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) *
                           (gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1))
        min_anchor_size = scale.new_full(
            (1, ), float(anchor_scale * anchor_strides[0]))
        # assign gt bboxes to different feature levels w.r.t. their scales
        target_lvls = torch.floor(
            torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
        target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
        for gt_id in range(gt_bboxes.size(0)):
            lvl = target_lvls[gt_id].item()
            # rescaled to corresponding feature map
            gt_ = gt_bboxes[gt_id, :4] / anchor_strides[lvl]
            # calculate ignore regions
            ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                gt_, r2, featmap_sizes[lvl])
            # calculate positive (center) regions
            ctr_x1, ctr_y1, ctr_x2, ctr_y2 = calc_region(
                gt_, r1, featmap_sizes[lvl])
            all_center_targets[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
                                    ctr_x1:ctr_x2 + 1] = 1
            all_center_weights[lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                    ignore_x1:ignore_x2 + 1] = 0
            all_center_weights[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
                                    ctr_x1:ctr_x2 + 1] = 1
            # calculate ignore map on nearby low level feature
            if lvl > 0:
                d_lvl = lvl - 1
                # rescaled to corresponding feature map
                gt_ = gt_bboxes[gt_id, :4] / anchor_strides[d_lvl]
                ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                    gt_, r2, featmap_sizes[d_lvl])
                all_ignore_map[d_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                      ignore_x1:ignore_x2 + 1] = 1
            # calculate ignore map on nearby high level feature
            if lvl < num_lvls - 1:
                u_lvl = lvl + 1
                # rescaled to corresponding feature map
                gt_ = gt_bboxes[gt_id, :4] / anchor_strides[u_lvl]
                ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                    gt_, r2, featmap_sizes[u_lvl])
                all_ignore_map[u_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                      ignore_x1:ignore_x2 + 1] = 1
    for lvl_id in range(num_lvls):
        # ignore negative regions w.r.t. ignore map
        all_center_weights[lvl_id][(all_center_weights[lvl_id] < 0)
                                   & (all_ignore_map[lvl_id] > 0)] = 0
        # set negative regions with weight 0.1
        all_center_weights[lvl_id][all_center_weights[lvl_id] < 0] = 0.1
    # loc average factor to balance loss
    center_avg_factor = sum(
        [t.size(0) * t.size(-1) * t.size(-2)
         for t in all_center_targets]) / 200
    return all_center_targets, all_center_weights, center_avg_factor
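The level-assignment rule above is floor(log2(scale / min_anchor_size) + 0.5), clamped to the valid pyramid range. A worked sketch with assumed values anchor_scale=8 and anchor_strides=[8, 16, 32, 64]:

import torch

anchor_scale, anchor_strides = 8, [8, 16, 32, 64]  # assumed config
scale = torch.tensor([48.0, 300.0, 700.0])  # sqrt(w * h) of three example boxes
min_anchor_size = float(anchor_scale * anchor_strides[0])  # 64.0
target_lvls = torch.floor(torch.log2(scale / min_anchor_size) + 0.5)
target_lvls = target_lvls.clamp(min=0, max=len(anchor_strides) - 1).long()
# target_lvls == tensor([0, 2, 3])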
예제 #28
0
 def _get_src_permutation_idx(self, indices):
     # permute predictions following indices
     batch_idx = torch.cat(
         [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
     src_idx = torch.cat([src for (src, _) in indices])
     return batch_idx, src_idx
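A usage sketch (the matcher output below is made up): given one (src, tgt) index pair per image, the helper flattens them into (batch, query) indices for selecting matched predictions.

import torch

indices = [(torch.tensor([1, 3]), torch.tensor([0, 1])),  # image 0
           (torch.tensor([0]), torch.tensor([0]))]        # image 1
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
# batch_idx == tensor([0, 0, 1]), src_idx == tensor([1, 3, 0])
# predictions[batch_idx, src_idx] then picks out the matched queries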
예제 #29
0
    def forward(self, images, debug_percentile=None):
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile,
                                               dtype=torch.float32,
                                               device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply x-flip with probability (xflip * strength).
        if self.xflip > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 2)
            i = torch.where(
                torch.rand([batch_size], device=device) < self.xflip * self.p,
                i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

        # Apply 90 degree rotations with probability (rotate90 * strength).
        if self.rotate90 > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 4)
            i = torch.where(
                torch.rand([batch_size], device=device) <
                self.rotate90 * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 4))
            G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 -
                 1) * self.xint_max
            t = torch.where(
                torch.rand([batch_size, 1], device=device) <
                self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t,
                                    (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:, 0] * width),
                                            torch.round(t[:, 1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(
                torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(
                torch.rand([batch_size], device=device) < self.scale * self.p,
                s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        p_rot = 1 - torch.sqrt(
            (1 - self.rotate * self.p).clamp(0, 1))  # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 -
                     1) * np.pi * self.rotate_max
            theta = torch.where(
                torch.rand([batch_size], device=device) < p_rot, theta,
                torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) *
                                        np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta)  # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(
                torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(
                torch.rand([batch_size], device=device) < self.aniso * self.p,
                s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 -
                     1) * np.pi * self.rotate_max
            theta = torch.where(
                torch.rand([batch_size], device=device) < p_rot, theta,
                torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta)  # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(
                torch.rand([batch_size, 1], device=device) <
                self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(
                    t,
                    torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        if G_inv is not I_3:

            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1],
                        device=device)  # [idx, xyz]
            cp = G_inv @ cp.t()  # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0,
                                          2).flatten(1)  # [xy, batch * idx]
            margin = torch.cat([-margin,
                                margin]).max(dim=1).values  # [x0, y0, x1, y1]
            margin = margin + misc.constant(
                [Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(
                misc.constant([width - 1, height - 1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images,
                                             pad=[mx0, mx1, my0, my1],
                                             mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(
                2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5,
                                device=device) @ G_inv @ translate2d_inv(
                                    -0.5, -0.5, device=device)

            # Execute transformation.
            shape = [
                batch_size, num_channels, (height + Hz_pad * 2) * 2,
                (width + Hz_pad * 2) * 2
            ]
            G_inv = scale2d(2 / images.shape[3],
                            2 / images.shape[2],
                            device=device) @ G_inv @ scale2d_inv(
                                2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:, :2, :],
                                                   size=shape,
                                                   align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images,
                                            f=self.Hz_geom,
                                            down=2,
                                            padding=-Hz_pad * 2,
                                            flip_filter=True)

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(
                torch.rand([batch_size], device=device) <
                self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(
                    b,
                    torch.erfinv(debug_percentile * 2 - 1) *
                    self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(
                torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(
                torch.rand([batch_size], device=device) <
                self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(
                    c,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3),
                          device=device)  # Luma axis.
        if self.lumaflip > 0:
            i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
            i = torch.where(
                torch.rand([batch_size, 1, 1], device=device) <
                self.lumaflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            C = (I_4 - 2 * v.ger(v) * i) @ C  # Householder reflection.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 -
                     1) * np.pi * self.hue_max
            theta = torch.where(
                torch.rand([batch_size], device=device) < self.hue * self.p,
                theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) *
                                        np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C  # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(
                torch.randn([batch_size, 1, 1], device=device) *
                self.saturation_std)
            s = torch.where(
                torch.rand([batch_size, 1, 1], device=device) <
                self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) *
                        self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.

        # Does this transformation really only work for 3 or 1 channel ?
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2,
                                                  keepdims=True) + C[:, :, 3:]
            elif num_channels == 6:
                images[:, :3] = C[:, :3, :3] @ images[:, :3] + C[:, :3, 3:]
            else:
                raise ValueError(
                    'Image must have 1 (L), 3 (RGB) or 6 channels')
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(
                np.array([10, 1, 1, 1]) / 13,
                device=device)  # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands],
                           device=device)  # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(
                    torch.randn([batch_size], device=device) *
                    self.imgfilter_std)
                t_i = torch.where(
                    torch.rand([batch_size], device=device) <
                    self.imgfilter * self.p * band_strength, t_i,
                    torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(
                        t_i,
                        torch.exp2(
                            torch.erfinv(debug_percentile * 2 - 1) *
                            self.imgfilter_std)
                    ) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands],
                               device=device)  # Temporary gain vector.
                t[:, i] = t_i  # Replace i'th element.
                t = t / (expected_power * t.square()).sum(
                    dim=-1, keepdims=True).sqrt()  # Normalize power.
                g = g * t  # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank  # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat(
                [1, num_channels, 1])  # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1,
                                         -1])  # [batch * channels, 1, tap]

            # Apply filter.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape(
                [1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images,
                                             pad=[p, p, p, p],
                                             mode='reflect')
            images = conv2d_gradfix.conv2d(input=images,
                                           weight=Hz_prime.unsqueeze(2),
                                           groups=batch_size * num_channels)
            images = conv2d_gradfix.conv2d(input=images,
                                           weight=Hz_prime.unsqueeze(3),
                                           groups=batch_size * num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Image-space corruptions.
        # ------------------------

        # Apply additive RGB noise with probability (noise * strength).
        if self.noise > 0:
            sigma = torch.randn([batch_size, 1, 1, 1],
                                device=device).abs() * self.noise_std
            sigma = torch.where(
                torch.rand([batch_size, 1, 1, 1], device=device) <
                self.noise * self.p, sigma, torch.zeros_like(sigma))
            if debug_percentile is not None:
                sigma = torch.full_like(
                    sigma,
                    torch.erfinv(debug_percentile) * self.noise_std)
            images = images + torch.randn(
                [batch_size, num_channels, height, width],
                device=device) * sigma

        # Apply cutout with probability (cutout * strength).
        if self.cutout > 0:
            size = torch.full([batch_size, 2, 1, 1, 1],
                              self.cutout_size,
                              device=device)
            size = torch.where(
                torch.rand([batch_size, 1, 1, 1, 1], device=device) <
                self.cutout * self.p, size, torch.zeros_like(size))
            center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
            if debug_percentile is not None:
                size = torch.full_like(size, self.cutout_size)
                center = torch.full_like(center, debug_percentile)
            coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
            coord_y = torch.arange(height,
                                   device=device).reshape([1, 1, -1, 1])
            mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >=
                      size[:, 0] / 2)
            mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >=
                      size[:, 1] / 2)
            mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
            images = images * mask

        return images
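A note on the debug_percentile pattern repeated above: torch.erfinv(2p - 1) maps a percentile p in [0, 1] to a Gaussian quantile, so a fixed percentile yields a deterministic parameter instead of a random draw. The exact quantile of a unit normal carries an extra factor of sqrt(2), which the code above omits:

import torch

p = torch.tensor(0.975)
q = torch.erfinv(2 * p - 1) * (2 ** 0.5)  # ≈ 1.96, the 97.5% quantile of N(0, 1)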
예제 #30
0
 def _get_tgt_permutation_idx(self, indices):
     # permute targets following indices
     batch_idx = torch.cat(
         [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
     tgt_idx = torch.cat([tgt for (_, tgt) in indices])
     return batch_idx, tgt_idx
예제 #31
0
 def forward(self, x):
     value = 42 if x.dtype == torch.int32 else float_test_num
     return x.fill_(value), torch.full_like(x, value), torch.full(
         input_shapes, value, dtype=x.dtype)
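float_test_num and input_shapes above come from the snippet's enclosing scope and are not shown. A self-contained sketch contrasting the three fill operations, with assumed stand-ins:

import torch

x = torch.zeros(2, 3, dtype=torch.int32)
value = 42 if x.dtype == torch.int32 else 3.14  # 3.14 stands in for float_test_num
a = x.fill_(value)                    # in-place fill; returns x itself
b = torch.full_like(x, value)         # new tensor with x's shape/dtype/device
c = torch.full((2, 3), value, dtype=x.dtype)  # (2, 3) stands in for input_shapes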
예제 #32
0
def distiller_qparams_to_pytorch(scale, zp, num_bits, distiller_mode, dest_dtype, reduce_range=False):
    """
    Convert quantization parameters (scale and zero-point) calculated by Distiller APIs to quantization parameters
    compatible with PyTorch quantization APIs.

    By "calculated with Distiller APIs" we mean calculated using either of:
      * distiller.quantization.symmetric_linear_quantization_params
      * distiller.quantization.asymmetric_linear_quantization_params

    The main differences between quantization parameters as calculated by Distiller and PyTorch:
      * pytorch_scale = 1 / distiller_scale
      * pytorch_zero_point = -distiller_zero_point

    Args:
        scale (torch.Tensor): Scale factor calculated by Distiller
        zp (torch.Tensor): Zero point calculated by Distiller
        num_bits (int): Number of bits used for quantization in Distiller
        distiller_mode (distiller.quantization.LinearQuantMode): The quantization mode used in Distiller
        dest_dtype (torch.dtype): PyTorch quantized dtype to convert to. Must be one of: torch.quint8, torch.qint8
        reduce_range (bool): Reduces the range of the quantized data type by 1 bit. This should mainly be used for
          quantized activations with the "fbgemm" PyTorch backend - it prevents overflows. See:
          https://github.com/pytorch/pytorch/blob/fde94e75568b527b424b108c272793e096e8e471/torch/quantization/observer.py#L294

    Returns:
        Tuple of (scale, zero_point) which are compatible with PyTorch quantization API
    """
    assert dest_dtype in (torch.qint8, torch.quint8), 'Must specify one of the quantized PyTorch dtypes'

    distiller_symmetric = is_linear_quant_mode_symmetric(distiller_mode)
    if distiller_symmetric and dest_dtype == torch.quint8:
        reduce_range = False

    distiller_asym_signed = distiller_mode == LinearQuantMode.ASYMMETRIC_SIGNED

    if reduce_range:
        assert num_bits == 8, 'reduce_range needed only when num_bits == 8'
        if distiller_symmetric and dest_dtype == torch.quint8:
            raise NotImplementedError('reduce_range + symmetric + quint8 not supported in PyTorch')
        num_bits = 7
        if distiller_symmetric:
            ratio = 63. / 127.
        else:
            ratio = 127. / 255.
            zp_offset = 128 if distiller_asym_signed else 0
            zp = ((zp - zp_offset) * ratio + zp_offset / 2).round()
        scale = scale * ratio

    scale = scale.cpu().squeeze()
    zp = zp.cpu().squeeze().long()

    # Distiller scale is the reciprocal of PyTorch scale
    scale_torch = 1. / scale

    n_bins_half = 2 ** (num_bits - 1)

    if distiller_symmetric:
        # In Distiller symmetric is always signed with zero-point = 0, but in PyTorch it can be
        # unsigned in which case we offset the zero-point to the middle of the quantized range
        zp_torch = zp if dest_dtype == torch.qint8 else torch.full_like(zp, n_bins_half)
    else:
        pytorch_signed = dest_dtype == torch.qint8
        if distiller_asym_signed and not pytorch_signed:
            zp = zp - n_bins_half
        elif not distiller_asym_signed and pytorch_signed:
            zp = zp + n_bins_half
        # Distiller subtracts the zero-point when quantizing, PyTorch adds it.
        # So we negate the zero-point calculated in Distiller
        zp_torch = -zp
    return scale_torch, zp_torch
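A usage sketch of the conversion above, with made-up symmetric 8-bit parameters (LinearQuantMode.SYMMETRIC is assumed to be importable from distiller.quantization):

import torch

scale = torch.tensor([127.0 / 6.0])  # Distiller scale: (2**(num_bits-1) - 1) / max_abs
zp = torch.tensor([0])               # symmetric mode uses zero-point 0

scale_torch, zp_torch = distiller_qparams_to_pytorch(
    scale, zp, num_bits=8, distiller_mode=LinearQuantMode.SYMMETRIC,
    dest_dtype=torch.quint8)
# scale_torch ≈ 6/127 (reciprocal of Distiller's scale)
# zp_torch == 128 (zero-point shifted to mid-range for unsigned quint8)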
예제 #33
0
def compute_loss(p, targets, model):  # predictions, targets, model
    ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
    lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
    tcls, tbox, indices, anchor_vec = build_targets(p, targets, model)
    h = model.hyp  # hyperparameters
    red = 'mean'  # Loss reduction (sum or mean)

    # Define criteria
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)

    # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    # focal loss
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    # Compute losses
    np, ng = 0, 0  # number grid points, targets
    for i, pi in enumerate(p):  # layer index, layer predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tobj = torch.zeros_like(pi[..., 0])  # target obj
        np += tobj.numel()

        # Compute losses
        nb = len(b)
        if nb:  # number of targets
            ng += nb
            ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets
            # ps[:, 2:4] = torch.sigmoid(ps[:, 2:4])  # wh power loss (uncomment)

            # GIoU
            pxy = torch.sigmoid(ps[:, 0:2])  # pxy = pxy * s - (s - 1) / 2,  s = 1.5  (scale_xy)
            pwh = torch.exp(ps[:, 2:4]).clamp(max=1E3) * anchor_vec[i]
            pbox = torch.cat((pxy, pwh), 1)  # predicted box
            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou computation
            lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss
            tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype)  # giou ratio

            if model.nc > 1:  # cls loss (only if multiple classes)
                t = torch.full_like(ps[:, 5:], cn)  # targets
                t[range(nb), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE
                # lcls += CE(ps[:, 5:], tcls[i])  # CE

            # Append targets to text file
            # with open('targets.txt', 'a') as file:
            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

        lobj += BCEobj(pi[..., 4], tobj)  # obj loss

    lbox *= h['giou']
    lobj *= h['obj']
    lcls *= h['cls']
    if red == 'sum':
        bs = tobj.shape[0]  # batch size
        lobj *= 3 / (6300 * bs) * 2  # 3 / np * 2
        if ng:
            lcls *= 3 / ng / model.nc
            lbox *= 3 / ng

    loss = lbox + lobj + lcls
    return loss, torch.cat((lbox, lobj, lcls, loss)).detach()
예제 #34
0
    def forward(ctx,
                input,
                input_lengths,
                num_graphs,
                den_graphs,
                leaky_coefficient=1e-5):
        try:
            import pychain_C
        except ImportError:
            raise ImportError(
                "Please install OpenFST and PyChain by `make openfst pychain` "
                "after entering espresso/tools")

        input = input.clamp(
            -30, 30)  # clamp for both the denominator and the numerator
        B = input.size(0)
        if B != num_graphs.batch_size or B != den_graphs.batch_size:
            raise ValueError(
                "input batch size ({}) does not equal to num graph batch size ({}) "
                "or den graph batch size ({})".format(B, num_graphs.batch_size,
                                                      den_graphs.batch_size))
        packed_data = torch.nn.utils.rnn.pack_padded_sequence(
            input,
            input_lengths,
            batch_first=True,
        )
        batch_sizes = packed_data.batch_sizes
        input_lengths = input_lengths.cpu()

        exp_input = input.exp()
        den_objf, input_grad, denominator_ok = pychain_C.forward_backward(
            den_graphs.forward_transitions,
            den_graphs.forward_transition_indices,
            den_graphs.forward_transition_probs,
            den_graphs.backward_transitions,
            den_graphs.backward_transition_indices,
            den_graphs.backward_transition_probs,
            den_graphs.leaky_probs,
            den_graphs.initial_probs,
            den_graphs.final_probs,
            den_graphs.start_state,
            exp_input,
            batch_sizes,
            input_lengths,
            den_graphs.num_states,
            leaky_coefficient,
        )
        denominator_ok = denominator_ok.item()

        assert num_graphs.log_domain
        num_objf, log_probs_grad, numerator_ok = pychain_C.forward_backward_log_domain(
            num_graphs.forward_transitions,
            num_graphs.forward_transition_indices,
            num_graphs.forward_transition_probs,
            num_graphs.backward_transitions,
            num_graphs.backward_transition_indices,
            num_graphs.backward_transition_probs,
            num_graphs.initial_probs,
            num_graphs.final_probs,
            num_graphs.start_state,
            input,
            batch_sizes,
            input_lengths,
            num_graphs.num_states,
        )
        numerator_ok = numerator_ok.item()

        loss = -num_objf + den_objf

        # (loss - loss) != 0.0 is true only when loss is NaN or Inf
        if (loss - loss) != 0.0 or not denominator_ok or not numerator_ok:
            default_loss = 10
            input_grad = torch.zeros_like(input)
            logger.warning(
                f"Loss is {loss} and denominator computation "
                f"(if done) returned {denominator_ok} "
                f"and numerator computation returned {numerator_ok} "
                f", setting loss to {default_loss} per frame")
            loss = torch.full_like(num_objf,
                                   default_loss * input_lengths.sum())
        else:
            num_grad = log_probs_grad.exp()
            input_grad -= num_grad

        ctx.save_for_backward(input_grad)
        return loss
예제 #35
0
File: utils.py Project: saschwan/botorch
def _fix_feature(Z: Tensor, value: Optional[float]) -> Tensor:
    r"""Helper function returns a Tensor like `Z` filled with `value` if provided."""
    if value is None:
        return Z.detach()
    return torch.full_like(Z, value)
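A quick usage sketch:

import torch

Z = torch.rand(2, 3)
_fix_feature(Z, 0.5)   # tensor of 0.5s with Z's shape, dtype and device
_fix_feature(Z, None)  # Z detached from the autograd graph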
예제 #36
0
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = rlt.FeatureVector(
                rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                )
            )

        current_state_action = rlt.StateAction(state=state, action=action)

        q1_value = self.q1_network(current_state_action).q_value
        min_q_value = q1_value

        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
            min_q_value = torch.min(q1_value, q2_value)

        # Use the minimum as target, ensure no gradient going through
        min_q_value = min_q_value.detach()

        #
        # First, optimize value network; minimizing MSE between
        # V(s) & Q(s, a) - log(pi(a|s))
        #

        state_value = self.value_network(state.float_features)  # .q_value

        if self.logged_action_uniform_prior:
            log_prob_a = torch.zeros_like(min_q_value)
            target_value = min_q_value
        else:
            with torch.no_grad():
                log_prob_a = self.actor_network.get_log_prob(
                    state, action.float_features
                )
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        self.value_network_optimizer.zero_grad()
        value_loss.backward()
        self.value_network_optimizer.step()

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (
                self.value_network_target(learning_input.next_state.float_features)
                * not_done_mask.float()
            )

            if self.minibatch < self.reward_burnin:
                target_q_value = reward
            else:
                target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        self.q1_network_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_network_optimizer.step()
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            self.q2_network_optimizer.zero_grad()
            q2_loss.backward()
            self.q2_network_optimizer.step()

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=state))

        state_actor_action = rlt.StateAction(
            state=state, action=rlt.FeatureVector(float_features=actor_output.action)
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        self.actor_network_optimizer.zero_grad()
        actor_loss_mean.backward()
        self.actor_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.value_network, self.value_network_target, 1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.value_network, self.value_network_target, self.tau)

        # Logging at the end to schedule all the cuda operations first
        if (
            self.tensorboard_logging_freq is not None
            and self.minibatch % self.tensorboard_logging_freq == 0
        ):
            SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)

            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            SummaryWriterContext.add_histogram("value_network/target", target_value)
            SummaryWriterContext.add_histogram(
                "q_network/next_state_value", next_state_value
            )
            SummaryWriterContext.add_histogram(
                "q_network/target_q_value", target_q_value
            )
            SummaryWriterContext.add_histogram(
                "actor/min_q_actor_value", min_q_actor_value
            )
            SummaryWriterContext.add_histogram(
                "actor/action_log_prob", actor_output.log_prob
            )
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )
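rescale_torch_tensor is not shown in the snippet above; the action scaling it performs is a linear map between value ranges. A minimal sketch of that mapping, under that assumption:

import torch

def rescale(x, new_min, new_max, prev_min, prev_max):
    # linearly map x from [prev_min, prev_max] to [new_min, new_max]
    return (x - prev_min) / (prev_max - prev_min) * (new_max - new_min) + new_min

rescale(torch.tensor([0.0, 0.5, 1.0]), -1.0, 1.0, 0.0, 1.0)  # tensor([-1., 0., 1.])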
예제 #37
0
    def test_FixedNoiseMultiTaskGP(self, cuda=False):
        for double in (False, True):
            tkwargs = {
                "device": torch.device("cuda") if cuda else torch.device("cpu"),
                "dtype": torch.double if double else torch.float,
            }
            model = _get_fixed_noise_model(**tkwargs)
            self.assertIsInstance(model, FixedNoiseMultiTaskGP)
            self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
            self.assertIsInstance(model.mean_module, ConstantMean)
            self.assertIsInstance(model.covar_module, ScaleKernel)
            matern_kernel = model.covar_module.base_kernel
            self.assertIsInstance(matern_kernel, MaternKernel)
            self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
            self.assertIsInstance(model.task_covar_module, IndexKernel)
            self.assertEqual(model._rank, 2)
            self.assertEqual(
                model.task_covar_module.covar_factor.shape[-1], model._rank
            )

            # test model fitting
            mll = ExactMarginalLogLikelihood(model.likelihood, model)
            mll = fit_gpytorch_model(mll, options={"maxiter": 1})

            # test posterior
            test_x = torch.rand(2, 1, **tkwargs)
            posterior_f = model.posterior(test_x)
            self.assertIsInstance(posterior_f, GPyTorchPosterior)
            self.assertIsInstance(posterior_f.mvn, MultitaskMultivariateNormal)
            self.assertEqual(posterior_f.mean.shape, torch.Size([2, 2]))
            self.assertEqual(posterior_f.variance.shape, torch.Size([2, 2]))

            # TODO: test posterior w/ observation noise

            # test posterior w/ single output index
            posterior_f = model.posterior(test_x, output_indices=[0])
            self.assertIsInstance(posterior_f, GPyTorchPosterior)
            self.assertIsInstance(posterior_f.mvn, MultivariateNormal)
            self.assertEqual(posterior_f.mean.shape, torch.Size([2, 1]))
            self.assertEqual(posterior_f.variance.shape, torch.Size([2, 1]))

            # test posterior w/ bad output index
            with self.assertRaises(ValueError):
                model.posterior(test_x, output_indices=[2])

            # test posterior (batch eval)
            test_x = torch.rand(3, 2, 1, **tkwargs)
            posterior_f = model.posterior(test_x)
            self.assertIsInstance(posterior_f, GPyTorchPosterior)
            self.assertIsInstance(posterior_f.mvn, MultitaskMultivariateNormal)

            # test that unsupported batch shape MTGPs throw correct error
            with self.assertRaises(ValueError):
                FixedNoiseMultiTaskGP(
                    torch.rand(2, 2, 2), torch.rand(2, 1), torch.rand(2, 1), 0
                )

            # test that bad feature index throws correct error
            train_X, train_Y = _get_random_mt_data(**tkwargs)
            train_Yvar = torch.full_like(train_Y, 0.05)
            with self.assertRaises(ValueError):
                FixedNoiseMultiTaskGP(train_X, train_Y, train_Yvar, 2)

            # test that bad output task throws correct error
            with self.assertRaises(RuntimeError):
                FixedNoiseMultiTaskGP(train_X, train_Y, train_Yvar, 0, output_tasks=[2])
예제 #39
0
    def train(self, training_batch) -> None:
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        reward = learning_input.reward
        if self.multi_steps is not None:
            discount_tensor = torch.pow(self.gamma, learning_input.step.float())
        else:
            discount_tensor = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            # TODO: Implement this in another diff
            raise NotImplementedError

        if self.maxq_learning:
            all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
                learning_input.tiled_next_state, learning_input.possible_next_actions
            )
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values_with_target(
                all_next_q_values.q_value,
                all_next_q_values_target.q_value,
                learning_input.possible_next_actions_mask.float(),
            )
        else:
            # SARSA
            next_q_values, _ = self.get_detached_q_values(
                learning_input.next_state, learning_input.next_action
            )
            next_q_values = next_q_values.q_value

        filtered_max_q_vals = next_q_values * not_done_mask.float()

        if self.minibatch < self.reward_burnin:
            target_q_values = reward
        else:
            target_q_values = reward + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        current_state_action = rlt.StateAction(
            state=learning_input.state, action=learning_input.action
        )
        q_values = self.q_network(current_state_action).q_value
        self.all_action_scores = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        # TODO: Maybe soft_update should belong to the target network
        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(current_state_action).q_value
        reward_loss = F.mse_loss(reward_estimates, reward)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            model_values_on_logged_actions=self.all_action_scores,
        )
예제 #40
0
def _fmt_box_list(box_tensor, batch_index: int):
    repeated_index = torch.full_like(box_tensor[:, :1],
                                     batch_index,
                                     dtype=box_tensor.dtype,
                                     device=box_tensor.device)
    return cat((repeated_index, box_tensor), dim=1)
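A usage sketch (box values are illustrative): prepending the batch index produces the [batch_idx, x1, y1, x2, y2] rows expected by batched ROI operators.

import torch

boxes = torch.tensor([[0., 0., 10., 10.],
                      [5., 5., 20., 20.]])
repeated_index = torch.full_like(boxes[:, :1], 1)  # every box belongs to image 1
torch.cat((repeated_index, boxes), dim=1)
# tensor([[ 1.,  0.,  0., 10., 10.],
#         [ 1.,  5.,  5., 20., 20.]])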
예제 #41
0
def train(netG,
          netD,
          GANLoss,
          ReconLoss,
          DLoss,
          optG,
          optD,
          dataloader,
          epoch,
          device=cuda0,
          val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }

    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']

        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        guide = []
        transform = transforms.Compose([transforms.ToPILImage()])
        for k in range(imgs.shape[0]):
            im = transform(imgs[k])
            im = np.array(im)
            # cv2.imwrite('test.jpg', im)

            im = cv2.Canny(image=im, threshold1=20, threshold2=220)
            # cv2.imwrite('test1.jpg', im)

            guide.append(im)
        guide = torch.FloatTensor(guide)
        guide = guide[:, None, :, :]
        imgs, masks, guide = imgs.to(device), masks.to(device), guide.to(
            device)

        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        guide = guide / 255.0

        coarse_imgs, recon_imgs, attention = netG(imgs, masks, guide)
        # print(attention.size(), )
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)

        optD.step()

        # Optimize Generator
        optD.zero_grad()
        netD.zero_grad()
        optG.zero_grad()
        netG.zero_grad()
        pred_neg = netD(neg_imgs)
        # pred_pos, pred_neg = torch.chunk(pred_pos_neg,  2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()

        optG.step()

        # Update time recorder
        batch_time.update(time.time() - end)

        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info(
                "Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t"
                " Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f},\t"
                " Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                .format(epoch, i + 1, len(dataloader), batch_time=batch_time,
                        data_time=data_time, whole_loss=losses['whole_loss'],
                        r_loss=losses['r_loss'], g_loss=losses['g_loss'],
                        d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }

            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)

            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            info = {
                'train/whole_imgs':
                img2photo(
                    torch.cat([
                        imgs * (1 - masks), coarse_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))
            }

            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images,
                                                epoch * len(dataloader) + i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(netG,
                     netD,
                     GANLoss,
                     ReconLoss,
                     DLoss,
                     optG,
                     optD,
                     val_datas,
                     epoch,
                     device,
                     batch_n=i)
            netG.train()
            netD.train()
        end = time.time()
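`GANLoss` and `DLoss` are constructed elsewhere; for SN-PatchGAN they are typically hinge losses. A minimal sketch under that assumption (hypothetical stand-ins, not this repo's implementations):

import torch.nn.functional as F

def sn_patch_gan_d_loss(pred_pos, pred_neg):
    # Discriminator hinge loss: push real scores above +1, fakes below -1.
    return F.relu(1.0 - pred_pos).mean() + F.relu(1.0 + pred_neg).mean()

def sn_patch_gan_g_loss(pred_neg):
    # Generator hinge loss: raise the discriminator score on completions.
    return -pred_neg.mean()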
Example #42
def validate(netG,
             netD,
             GANLoss,
             ReconLoss,
             DLoss,
             optG,
             optD,
             dataloader,
             epoch,
             device=cuda0,
             batch_n="whole"):
    """
    validate phase
    """
    netG.to(device)
    netD.to(device)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }

    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}
    for i, (imgs, masks) in enumerate(dataloader):

        data_time.update(time.time() - end)
        masks = masks['val']
        # masks = (masks > 0).type(torch.FloatTensor)

        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
    coarse_imgs, recon_imgs, attention = netG(imgs, masks)

        complete_imgs = recon_imgs * masks + imgs * (1 - masks)

        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks,
             torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)

        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)

        g_loss = GANLoss(pred_neg)

        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)

        whole_loss = g_loss + r_loss

        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))

        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)

        # Logger logging

        if i + 1 < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'val/ori_imgs':img2photo(imgs),
            #          'val/coarse_imgs':img2photo(coarse_imgs),
            #          'val/recon_imgs':img2photo(recon_imgs),
            #          'val/comp_imgs':img2photo(complete_imgs),
            info['val/whole_imgs/{}'.format(i)] = img2photo(
                torch.cat([
                    imgs *
                    (1 - masks), coarse_imgs, recon_imgs, imgs, complete_imgs
                ],
                          dim=3))

        else:
            logger.info(
                "Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t"
                " Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f},\t"
                " Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                .format(epoch, i + 1, len(dataloader), batch_time=batch_time,
                        data_time=data_time, whole_loss=losses['whole_loss'],
                        r_loss=losses['r_loss'], g_loss=losses['g_loss'],
                        d_loss=losses['d_loss']))
            j = 0
            for tag, images in info.items():
                h, w = images.shape[1], images.shape[2] // 5
                for val_img in images:
                    real_img = val_img[:, (3 * w):(4 * w), :]
                    gen_img = val_img[:, (4 * w):, :]
                    real_img = Image.fromarray(real_img.astype(np.uint8))
                    gen_img = Image.fromarray(gen_img.astype(np.uint8))
                    real_img.save(
                        os.path.join(val_save_real_dir, "{}.png".format(j)))
                    gen_img.save(
                        os.path.join(val_save_gen_dir, "{}.png".format(j)))
                    j += 1
                tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            tensorboardlogger.scalar_summary('val/fid', fid_score.item(),
                                             epoch * len(dataloader) + i)
            tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(),
                                             epoch * len(dataloader) + i)
            break

        end = time.time()
Example #43
    def __call__(self, y_pred, y_true):
        na, nt = self.na, y_true.shape[0]  # number of anchors, targets
        t_cls, t_box, indices, anchors = [], [], [], []
        gain = torch.ones(7, device=y_true.device)  # normalized to grid-space gain
        ai = torch.arange(na, device=y_true.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
        true = torch.cat((y_true.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

        g = 0.5  # bias
        off = torch.tensor([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1], ], device=true.device).float() * g  # offsets

        for i in range(self.nl):
            anchor = self.anchors[i]
            gain[2:6] = torch.tensor(y_pred[i].shape)[[3, 2, 3, 2]]  # xy-xy gain

            # Match targets to anchors
            t = true * gain
            if nt:
                # Matches
                r = t[:, :, 4:6] / anchor[:, None]  # wh ratio
                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']  # compare
                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                t = t[j]  # filter

                # Offsets
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = true[0]
                offsets = 0

            # Define
            b, c = t[:, :2].long().T  # image, class
            gxy = t[:, 2:4]  # grid xy
            gwh = t[:, 4:6]  # grid wh
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid xy indices

            # Append
            a = t[:, 6].long()  # anchor indices
            indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))  # image, anchor, grid indices
            t_box.append(torch.cat((gxy - gij, gwh), 1))  # box
            anchors.append(anchor[a])  # anchors
            t_cls.append(c)  # class

        device = y_true.device
        l_cls = torch.zeros(1, device=device)
        l_box = torch.zeros(1, device=device)
        l_obj = torch.zeros(1, device=device)
        t_obj = None
        # Losses
        for i, pred in enumerate(y_pred):  # layer index, layer predictions
            b, a, gj, gi = indices[i]  # image, anchor, grid_y, grid_x
            t_obj = torch.zeros_like(pred[..., 0], device=device)  # target obj

            n = b.shape[0]  # number of targets
            if n:
                ps = pred[b, a, gj, gi]  # prediction subset corresponding to targets

                # Regression
                p_xy = ps[:, :2].sigmoid() * 2. - 0.5
                p_wh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                p_box = torch.cat((p_xy, p_wh), 1)  # predicted box
                iou = compute_iou(p_box.T, t_box[i], c_iou=True)  # iou(prediction, target)
                l_box += (1.0 - iou).mean()  # iou loss

                # Object-ness
                t_obj[b, a, gj, gi] = iou.detach().clamp(0).type(t_obj.dtype)  # iou ratio

                # Classification
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
                    t[range(n), t_cls[i]] = self.cp
                    l_cls += self.bce_cls(ps[:, 5:], t)  # BCE

            l_obj += self.bce_obj(pred[..., 4], t_obj) * self.balance[i]  # obj loss

        l_box *= self.hyp['box']
        l_obj *= self.hyp['obj']
        l_cls *= self.hyp['cls']
        bs = t_obj.shape[0]  # batch size

        loss = l_box + l_obj + l_cls
        return loss * bs, loss.detach()
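The `torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']` filter above keeps a target only when its width and height each lie within a factor of `anchor_t` of the anchor's. A toy illustration (values assumed; 4.0 is a common `anchor_t` default):

import torch

anchor = torch.tensor([10., 20.])             # anchor (w, h)
twh = torch.tensor([[12., 18.],               # near the anchor: kept
                    [50., 20.]])              # 5x too wide: rejected
r = twh / anchor                              # wh ratios
keep = torch.max(r, 1. / r).max(1)[0] < 4.0   # tensor([True, False])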
Example #44
def train_with_grad_control(model, trainloader, criterion, optimizer, HufuSize,
                            ModelSize, prevGH, prevGM, alpha, gamma, prevRatio,
                            epo):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    grad_H_R = torch.tensor(0)
    grad_M_R = torch.tensor(0)
    Ratio = 0
    cnt = 0

    Hgrad_epo = torch.tensor([0]).to(device)
    Mgrad_epo = torch.tensor([0]).to(device)

    for i, (input, target) in enumerate(trainloader):

        Grad_H = torch.tensor(0)
        Grad_M = torch.tensor(0)
        ratio = torch.tensor(0)

        input = torch.squeeze(input)
        target = torch.squeeze(target)
        # input = torch.unsqueeze(input, 1)
        input, target = input.to(device), target.to(device)

        # compute output
        output = model(input)
        if (epo == 0):
            loss_func = nn.CrossEntropyLoss().cuda()
            loss = loss_func(output, target)
        else:
            loss = criterion(output, target, prevGH, prevGM, alpha, gamma,
                             i % 100, prevRatio)

        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))

        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()

        Mgrad_nonzeroes = 0
        Hgrad_nonzeroes = 0
        Hgrad_list = torch.tensor([0]).to(device)
        Mgrad_list = torch.tensor([0]).to(device)

        Hgrad_list_i = torch.tensor([0]).to(device)
        Mgrad_list_i = torch.tensor([0]).to(device)

        for name, param in model.named_parameters():
            if 'conv' in name and 'weight' in name:
                name = name[:-len('.weight')]  # e.g. 'conv1_1.weight' -> 'conv1_1'
                name_sp = name.split('.')

                current_grad = getattr(model, name_sp[0]).weight.grad
                #current_grad = getattr(getattr(getattr(model, name_sp[0]),name_sp[1]),name_sp[2]).weight.grad

                hufu_mask = getattr(model, name_sp[0].replace('conv',
                                                              'grad')).weight

                b = torch.where(
                    torch.abs(current_grad) <= 0.00001,
                    torch.full_like(current_grad, 0),
                    torch.full_like(current_grad, 1))
                Mgrad_nonzeroes += torch.sum(b)

                Mgrad_list = torch.cat(
                    (Mgrad_list, torch.mul(current_grad, b).view(-1)), 0)

                c = torch.where(hufu_mask == 0, torch.full_like(hufu_mask, 1),
                                torch.full_like(hufu_mask, 0))
                d = torch.mul(c, b)
                Hz = torch.sum(d)
                hufu_grad = torch.mul(d, current_grad)

                Hgrad_list_i = torch.cat(
                    (Hgrad_list_i, torch.mul(c, current_grad).view(-1)), 0)
                Mgrad_list_i = torch.cat((Mgrad_list_i, current_grad.view(-1)),
                                         0)

                Hgrad_list = torch.cat((Hgrad_list, hufu_grad.view(-1)), 0)

                Hgrad_nonzeroes += Hz

                Grad_M = Grad_M + torch.sum(
                    torch.abs(torch.mul(current_grad, b))).item()
                Grad_H = Grad_H + torch.sum(
                    torch.abs(torch.mul(current_grad, d))).item()

                #unfiltered_grad = torch.mul(hufu_mask,current_grad)
                #getattr(model.module, name_sp[1]).weight.grad = current_grad
                getattr(model, name_sp[0]).weight.grad = current_grad

                # Sanity check: the gradient should stay unchanged within a slice of a single layer
                #getattr(model.module, name).weight.grad[:10,0,0,0] = torch.zeros(getattr(model.module, name).weight.grad[:10,0,0,0].size())

        optimizer.step()
        Hmean = torch.sum(Hgrad_list) / Hgrad_nonzeroes
        c = torch.where(Hgrad_list == 0,
                        torch.full_like(Hgrad_list, Hmean.item()), Hgrad_list)
        Mmean = torch.sum(Mgrad_list) / Mgrad_nonzeroes
        d = torch.where(Mgrad_list == 0,
                        torch.full_like(Mgrad_list, Mmean.item()), Mgrad_list)
        VarH = torch.sum(pow(c - Hmean, 2)) / Hgrad_nonzeroes
        VarM = torch.sum(pow(d - Mmean, 2)) / Mgrad_nonzeroes
        ratio = VarM / VarH
        Ratio += ratio

        grad_M_R = Grad_M / Mgrad_nonzeroes.item()
        grad_H_R = Grad_H / Hgrad_nonzeroes.item()
        cnt += 1

        if (i == 0):
            Hgrad_epo = Hgrad_list_i
            Mgrad_epo = Mgrad_list_i
        else:
            Hgrad_epo += Hgrad_list_i
            Mgrad_epo += Mgrad_list_i
    print(cnt)
    Ratio = Ratio / cnt

    return [grad_H_R, grad_M_R, Ratio, Mgrad_epo, Hgrad_epo]
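For reference, the `torch.where(..., torch.full_like(g, 0), torch.full_like(g, 1))` pattern used above to build the non-zero-gradient mask has a more direct equivalent:

import torch

g = torch.tensor([0.0, 2e-6, 0.3, -0.7])
b = (g.abs() > 1e-5).float()   # tensor([0., 0., 1., 1.]), same mask as above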