Code Example #1
    def apply(module, name, dim):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)

        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(
            name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn
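These apply snippets are the hook behind torch.nn.utils.weight_norm. A minimal usage sketch (assuming a torch build that still ships nn.utils.weight_norm), showing the parameters the hook registers and the reconstruction w = g * v / ||v||:

import torch
import torch.nn as nn

# Wrap a Linear layer; this registers weight_g and weight_v and recomputes
# `weight` from them before every forward pass.
m = nn.utils.weight_norm(nn.Linear(20, 40), name='weight', dim=0)
print(m.weight_g.shape)  # torch.Size([40, 1])
print(m.weight_v.shape)  # torch.Size([40, 20])

# The recomputed weight is g * v / ||v||, with the norm taken over every
# dimension except `dim`.
w = m.weight_g * m.weight_v / torch.norm_except_dim(m.weight_v, 2, 0)
print(torch.allclose(m.weight, w))  # True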
Code Example #2
File: weight_norm.py Project: thomascong121/NCRF
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization'
            )
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(
            name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn
Code Example #3
File: weight_norm.py Project: victorcampos7/pytorch
    def apply(module, name, dim, init, gamma):
        assert init in ['default', 'norm_preserving'], \
            "Invalid init for WeightNorm ({}). It must be one of ['default', 'norm_preserving']".format(init)

        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)

        # remove w from parameter list
        del module._parameters[name]

        # initial value for g
        if init == 'default':
            g_init = norm_except_dim(weight, 2, dim)
        elif init == 'norm_preserving':
            fan_in, fan_out = _calculate_fan_in_and_fan_out(weight)
            g_init = full_like(norm_except_dim(weight, 2, dim),
                               math.sqrt(gamma * fan_in / fan_out))

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(g_init.data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn
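The 'norm_preserving' branch above fills g with sqrt(gamma * fan_in / fan_out) rather than the weight's norm. A small sketch of just that initial value; gamma is a hypothetical setting here, and _calculate_fan_in_and_fan_out is the private helper from torch.nn.init that the file presumably imports:

import math
import torch
from torch.nn.init import _calculate_fan_in_and_fan_out

weight = torch.randn(40, 20)   # a Linear-shaped weight
gamma = 2.0                    # hypothetical hyperparameter
fan_in, fan_out = _calculate_fan_in_and_fan_out(weight)  # 20, 40
g_init = torch.full_like(torch.norm_except_dim(weight, 2, 0),
                         math.sqrt(gamma * fan_in / fan_out))
print(g_init.shape)    # torch.Size([40, 1])
print(g_init[0, 0])    # tensor(1.), since sqrt(2.0 * 20 / 40) == 1.0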
Code Example #4
    def apply(module, name, dim):
        fn = WeightNorm(name, dim)

        weight = getattr(module, name)

        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn
Code Example #5
def _norm_except_dim(w, norm_type, dim):
    if norm_type == 1 or norm_type == 2:
        return torch.norm_except_dim(w, norm_type, dim)
    elif norm_type == float('inf'):
        return _max_except_dim(w, dim)
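For the 1- and 2-norm branch, this simply forwards to the built-in torch.norm_except_dim, which reduces over every dimension except dim and keeps that dimension. A short illustrative sketch:

import torch

w = torch.randn(4, 3)
# L2 norm over every dimension except dim=0; the result has shape (4, 1).
n = torch.norm_except_dim(w, 2, 0)
print(n.shape)  # torch.Size([4, 1])
# For a 2-D tensor this matches a row-wise norm with keepdim=True.
print(torch.allclose(n, w.norm(p=2, dim=1, keepdim=True)))  # True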
Code Example #6
File: pplm.py Project: AdarshKumar712/PPCM
def perturb_past(past, model, prev, args, classifier, good_index=None, stepsize=0.01, vocab_size=50257,
                 original_probs=None, accumulated_hidden=None, current_output=None, true_past=None, grad_norms=None,
                 knowledge_to_ent=None):
    window_length = args.window_length
    gm_scale, kl_scale = args.gm_scale, args.kl_scale
    one_hot_vectors = []
    for good_list in good_index:
        good_list = list(filter(lambda x: len(x) <= 1, good_list))
        good_list = torch.tensor(good_list).cuda()
        num_good = good_list.shape[0]
        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
        one_hot_good.scatter_(1, good_list, 1)
        one_hot_vectors.append(one_hot_good)


    # Generate initial perturbed past
    past_perturb_orig = [(np.random.uniform(0.0, 0.0, p.shape).astype('float32'))
                         for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if args.decay:
        decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0/(window_length))[1:]
    else:
        decay_mask = 1.0

    # Generate a mask so that the gradient perturbation is restricted to a window of the past
    _, batch_size, _, current_length, _ = past[0].shape
    if current_length > window_length and window_length > 0:
        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(
            past[0].shape[-1:])

        zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple([current_length - window_length]) + tuple(
            past[0].shape[-1:])

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask*ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).cuda()
    else:
        window_mask = torch.ones_like(past[0]).cuda()

    loss_per_iter = []
    loss_barrier = 1.0
    for current_iter in range(args.num_iterations):
        # print("Iteration ", i + 1)
        past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
        past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]

        perturbed_past = list(map(add, past, past_perturb))

        _, _, _, current_length, _ = past_perturb[0].shape

        # Compute hidden using perturbed past
        logits, future_past = model(prev, past=perturbed_past)
        hidden = model.hidden_states
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()

        # TODO: Check the layer-norm consistency of this with trained discriminator
        logits = logits[:, -1, :]
        probabs = F.softmax(logits, dim=-1)
        loss = 0.0

        ## BOW
        if args.loss_type == 1 or args.loss_type == 3:

            for one_hot_good in one_hot_vectors:
                good_logits = torch.mm(probabs, torch.t(one_hot_good))
                loss_word = good_logits
                loss_word = torch.sum(loss_word, dim=1)
                loss_word = -torch.log(loss_word)
                loss += loss_word.sum()
            loss_per_iter.append(loss_word.detach().tolist())

        ## DISCRIMINATOR
        if args.loss_type == 2 or args.loss_type == 3:
            ce_loss = torch.nn.CrossEntropyLoss(reduction='sum')
            new_true_past = true_past
            for i in range(args.horizon_length):

                future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                future_probabs = torch.unsqueeze(future_probabs, dim=1)

                _, new_true_past = model(future_probabs, past=new_true_past)
                future_hidden = model.hidden_states  # Get expected hidden states
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(future_hidden, dim=1)

            predicted_sentiment = classifier(new_accumulated_hidden / (current_length + 1 + args.horizon_length))

            label = torch.tensor([args.label_class], device='cuda', dtype=torch.long).repeat(batch_size)
            discrim_loss = ce_loss(predicted_sentiment, label)
            loss += discrim_loss

            ## LOGGING 
            ce_loss_logging = torch.nn.CrossEntropyLoss(reduction='none')
            loss_logging = ce_loss_logging(predicted_sentiment, label).detach().tolist()
            loss_per_iter.append(loss_logging)

        if args.loss_type == 4:  # Entailment loss
            _ = model(torch.tensor([knowledge_to_ent], device='cuda', dtype=torch.long))
            hidden_p = model.hidden_states.repeat(batch_size, 1, 1)  # torch.mean(model.hidden_states, dim=1).repeat(batch_size, 1)

            ce_loss = torch.nn.CrossEntropyLoss(reduction='sum')
            new_true_past = true_past
            for i in range(args.horizon_length):

                future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                future_probabs = torch.unsqueeze(future_probabs, dim=1)

                _, new_true_past = model(future_probabs, past=new_true_past)
                future_hidden = model.hidden_states  # Get expected hidden states
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(future_hidden, dim=1)

            if current_output.size(1) != 0:
                hidden_h = torch.cat((current_output, future_hidden), dim=1)
            else:
                hidden_h = future_hidden

            predicted_NLI = classifier(hidden_p, hidden_h)

            label = torch.tensor([args.label_class], device='cuda', dtype=torch.long).repeat(batch_size)
            discrim_loss = ce_loss(predicted_NLI, label)
            loss += discrim_loss

            ## LOGGING 
            ce_loss_logging = torch.nn.CrossEntropyLoss(reduction='none')
            loss_per_iter.append(ce_loss_logging(predicted_NLI, label).detach().tolist())

        if args.loss_type == 5:
            bce_loss = torch.nn.BCEWithLogitsLoss(reduction='sum')
            new_true_past = true_past
            for i in range(args.horizon_length):

                future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                future_probabs = torch.unsqueeze(future_probabs, dim=1)

                _, new_true_past = model(future_probabs, past=new_true_past)
                future_hidden = model.hidden_states  # Get expected hidden states
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(future_hidden, dim=1)

            predicted_sentiment = classifier(new_accumulated_hidden / (current_length + 1 + args.horizon_length))

            label = torch.tensor([1], device='cuda', dtype=torch.float).repeat(batch_size)
            discrim_loss = bce_loss(predicted_sentiment, label.unsqueeze(-1))
            loss += discrim_loss

            ## LOGGING 
            bce_loss_logging = torch.nn.BCEWithLogitsLoss(reduction='none')
            loss_per_iter.append(bce_loss_logging(predicted_sentiment, label.unsqueeze(-1)).detach().tolist())

        kl_loss = 0.0
        if kl_scale > 0.0:
            p = (F.softmax(original_probs[:, -1, :], dim=-1))
            p = p + SmallConst * (p <= SmallConst).type(torch.FloatTensor).cuda().detach()
            correction = SmallConst * (probabs <= SmallConst).type(torch.FloatTensor).cuda().detach()
            corrected_probabs = probabs + correction.detach()
            kl_loss = kl_scale * ((corrected_probabs * (corrected_probabs / p).log()).sum())
         
            ## TODO
            # print(' kl_loss', (kl_loss).data.cpu().numpy())
            loss += kl_loss  # + discrim_loss
        
        ## TODO
        # print(f'pplm_loss {current_iter}', ((loss/batch_size) - kl_loss).data.cpu().numpy())
        # print(f'pplm_min_loss {current_iter}', min(loss_logging))
        # print()

        loss.backward(retain_graph=True)
        if grad_norms is not None and args.loss_type == 1:
            grad_norms = [torch.max(grad_norms[index], 
                torch.norm_except_dim(p_.grad * window_mask, dim=1)) 
                    for index, p_ in enumerate(past_perturb)]
        else:
            grad_norms = [(torch.norm_except_dim(p_.grad * window_mask, dim=1) + SmallConst) for index, p_ in enumerate(past_perturb)]

        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[index] ** args.gamma).data.cpu().numpy()
            for index, p_ in enumerate(past_perturb)]
        past_perturb_orig = list(map(add, grad, past_perturb_orig))

        for p_ in past_perturb:
            p_.grad.data.zero_()

        new_past = []
        for p in past:
            new_past.append(p.detach())

        past = new_past

    past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
    past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
    perturbed_past = list(map(add, past, past_perturb))

    return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter
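In the gradient step above, torch.norm_except_dim(p_.grad * window_mask, dim=1) rescales each perturbation's gradient by a norm taken over every dimension except dim 1 before the update. A stripped-down sketch of just that normalization (the window mask, the gamma exponent, and the real loss are omitted, and the tensor shape is made up for illustration):

import torch

SmallConst = 1e-15   # small constant to avoid division by zero, as in the file above
stepsize = 0.01

# A stand-in for one perturbation tensor from past_perturb.
p_ = torch.randn(1, 2, 4, 8, requires_grad=True)
loss = (p_ ** 2).sum()
loss.backward()

# Norm over every dimension except dim=1 -> shape (1, 2, 1, 1), which
# broadcasts against the gradient.
grad_norm = torch.norm_except_dim(p_.grad, 2, dim=1) + SmallConst
step = -stepsize * (p_.grad / grad_norm)
print(step.shape)  # torch.Size([1, 2, 4, 8])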
Code Example #7
File: models.py Project: zeta1999/GTN
    def forward(self, x):
        x = super().forward(x)
        x = x * (self.weight_g / torch.norm_except_dim(self.weight, 2, 0)).transpose(1, 0)
        return x
Code Example #8
File: models.py Project: zeta1999/GTN
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_g = nn.Parameter(torch.norm_except_dim(self.weight, 2, 0).data)
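Code Examples #7 and #8 appear to come from the same GTN class; together they rescale each output feature by weight_g / ||weight||. A hypothetical self-contained version (the class name and the nn.Linear base are assumptions, chosen so the transpose in forward lines up):

import torch
import torch.nn as nn

class WeightNormedLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-output-row L2 norm of the weight, shape (out_features, 1).
        self.weight_g = nn.Parameter(torch.norm_except_dim(self.weight, 2, 0).data)

    def forward(self, x):
        x = super().forward(x)
        # (out_features, 1) -> (1, out_features) so it broadcasts over the batch.
        x = x * (self.weight_g / torch.norm_except_dim(self.weight, 2, 0)).transpose(1, 0)
        return x

layer = WeightNormedLinear(8, 16)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 16])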
Code Example #9
def l_norm(p, x, y):
    return torch.norm_except_dim(x - y, p)
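Here p is forwarded as the norm order while dim is left at torch.norm_except_dim's default of 0, so for 2-D inputs this is a row-wise distance with a kept dimension. A minimal usage sketch (the function is repeated so the snippet runs on its own):

import torch

def l_norm(p, x, y):
    return torch.norm_except_dim(x - y, p)

x = torch.randn(5, 3)
y = torch.randn(5, 3)
d = l_norm(2, x, y)
print(d.shape)  # torch.Size([5, 1])
print(torch.allclose(d, (x - y).norm(p=2, dim=1, keepdim=True)))  # True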