class WN_ConvTranspose2d(nn.ConvTranspose2d):
    """Transposed 2-D convolution with weight normalization.

    The filter feeding each output channel is rescaled so that its L2 norm
    (taken over the in-channel and spatial axes) equals the matching entry of
    ``weight_scale``.  While ``init_mode`` is ``True`` a forward pass also
    performs data-dependent initialization: the output is standardized per
    channel and ``weight_scale`` / ``bias`` are overwritten so that later
    passes reproduce that standardized output on the same data.

    Args:
        in_channels, out_channels, kernel_size, stride, padding,
        output_padding, groups, bias: forwarded to ``nn.ConvTranspose2d``.
        train_scale: if True ``weight_scale`` is a learnable ``Parameter``,
            otherwise it is registered as a buffer.
        init_stdv: target per-channel standard deviation used by the
            data-dependent initialization.

    NOTE(review): the scale is broadcast over axis 1 of the weight, which
    only lines up with ``out_channels`` when ``groups == 1`` -- confirm no
    caller relies on grouped transposed convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, train_scale=False, init_stdv=1.0):
        super(WN_ConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias)
        if train_scale:
            self.weight_scale = Parameter(torch.ones(out_channels))
        else:
            self.register_buffer('weight_scale', torch.ones(out_channels))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv

        self._reset_parameters()

    def _reset_parameters(self):
        """Draw weights from N(0, 0.05^2), zero the bias, set the scale to 1."""
        with torch.no_grad():
            self.weight.normal_(std=0.05)
            if self.bias is not None:
                self.bias.zero_()
            self.weight_scale.fill_(1.)

    def _resolve_output_padding(self, input, output_size):
        """Return the per-axis output padding needed to reach ``output_size``.

        Computed locally instead of through the private
        ``nn.ConvTranspose2d._output_padding`` helper, whose signature differs
        between PyTorch releases.  Dilation is taken to be 1 because this
        constructor does not expose it.

        Raises:
            ValueError: if ``output_size`` has the wrong length or cannot be
                produced with the layer's stride/padding/kernel size.
        """
        if output_size is None:
            return self.output_padding

        output_size = list(output_size)
        if len(output_size) == input.dim():
            # A full (N, C, H, W) size was given; keep the spatial part only.
            output_size = output_size[2:]
        if len(output_size) != 2:
            raise ValueError(
                "output_size must have 2 or {} elements (got {})".format(
                    input.dim(), len(output_size)))

        resolved = []
        for axis in range(2):
            min_size = ((input.size(axis + 2) - 1) * self.stride[axis]
                        - 2 * self.padding[axis] + self.kernel_size[axis])
            extra = output_size[axis] - min_size
            if extra < 0 or extra >= max(self.stride[axis], 1):
                raise ValueError(
                    "requested output size {} is not reachable along spatial "
                    "axis {} (valid range: {} to {})".format(
                        output_size[axis], axis, min_size,
                        min_size + max(self.stride[axis], 1) - 1))
            resolved.append(extra)
        return tuple(resolved)

    def forward(self, input, output_size=None):
        """Apply the weight-normalized transposed convolution.

        Args:
            input: tensor accepted by ``F.conv_transpose2d``.
            output_size: optional desired spatial output size.

        Returns:
            The activation.  In init mode it is the per-channel standardized
            activation; otherwise the bias (if any) has been added.
        """
        # Weight layout is [in x out x h x w]; each output channel is
        # normalized over the (in, h, w) = (0, 2, 3) axes.  keepdim keeps the
        # norm aligned with axis 1 so the division broadcasts correctly.
        filter_norm = torch.sqrt(
            (self.weight ** 2).sum(dim=(0, 2, 3), keepdim=True) + 1e-6)
        norm_weight = self.weight * (
            self.weight_scale.view(1, -1, 1, 1) / filter_norm)

        output_padding = self._resolve_output_padding(input, output_size)
        activation = F.conv_transpose2d(input, norm_weight, bias=None,
                                        stride=self.stride, padding=self.padding,
                                        output_padding=output_padding, groups=self.groups)

        if self.init_mode:
            mean_act = activation.mean(dim=(0, 2, 3))
            activation = activation - mean_act.view(1, -1, 1, 1)

            inv_stdv = self.init_stdv / torch.sqrt(
                (activation ** 2).mean(dim=(0, 2, 3)) + 1e-6)
            activation = activation * inv_stdv.view(1, -1, 1, 1)

            # Fold the statistics into the layer without recording them in
            # the autograd graph.
            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.detach()
            else:
                self.weight_scale = self.weight_scale * inv_stdv.detach()
            # Without a bias the centering cannot be stored; only the scale
            # is carried over to later passes.
            if self.bias is not None:
                self.bias.data = -mean_act.detach() * inv_stdv.detach()
        elif self.bias is not None:
            activation = activation + self.bias.view(1, -1, 1, 1)

        return activation
# Example no. 2
# 0
class WN_Linear(nn.Linear):
    """Linear layer with weight normalization.

    Every row of the weight matrix (one row per output feature) is rescaled
    to have L2 norm equal to the matching entry of ``weight_scale``.  While
    ``init_mode`` is ``True`` a forward pass also performs data-dependent
    initialization: the output is standardized per feature and
    ``weight_scale`` / ``bias`` are overwritten so that later passes
    reproduce that standardized output on the same data.

    Args:
        in_features, out_features, bias: forwarded to ``nn.Linear``.
        train_scale: if True ``weight_scale`` is a learnable ``Parameter``,
            otherwise it is registered as a buffer.
        init_stdv: target per-feature standard deviation used by the
            data-dependent initialization.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 train_scale=False,
                 init_stdv=1.0):
        super(WN_Linear, self).__init__(in_features, out_features, bias=bias)
        if train_scale:
            self.weight_scale = Parameter(torch.ones(self.out_features))
        else:
            self.register_buffer('weight_scale', torch.ones(out_features))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv

        self._reset_parameters()

    def _reset_parameters(self):
        """Draw weights from N(0, 0.05^2), zero the bias, set the scale to 1."""
        with torch.no_grad():
            self.weight.normal_(mean=0, std=0.05)
            if self.bias is not None:
                self.bias.zero_()
            self.weight_scale.fill_(1.)

    def forward(self, input):
        """Apply the weight-normalized linear projection.

        Args:
            input: tensor whose last dimension is ``in_features``.

        Returns:
            The activation.  In init mode it is the per-feature standardized
            activation; otherwise the bias (if any) has been added.
        """
        # Normalize each output row of the weight matrix, then project.
        row_norm = torch.sqrt((self.weight ** 2).sum(1) + 1e-6)
        norm_weight = self.weight * (self.weight_scale / row_norm).unsqueeze(1)
        activation = F.linear(input, norm_weight)

        if self.init_mode:
            # Statistics are taken over every leading dimension so that the
            # resulting scale/bias stay one value per output feature even for
            # inputs with more than two dimensions.
            flat = activation.reshape(-1, activation.size(-1))
            mean_act = flat.mean(0)
            activation = activation - mean_act

            inv_stdv = self.init_stdv / torch.sqrt(
                ((flat - mean_act) ** 2).mean(0) + 1e-6)
            activation = activation * inv_stdv

            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.detach()
            else:
                self.weight_scale = self.weight_scale * inv_stdv.detach()
            # Without a bias the centering cannot be stored; only the scale
            # is carried over to later passes.
            if self.bias is not None:
                self.bias.data = -mean_act.detach() * inv_stdv.detach()
        elif self.bias is not None:
            activation = activation + self.bias

        return activation
# Example no. 3
# 0
class GroupNormMoving(nn.Module):
    """Group normalization that normalizes with moving-average statistics.

    The input is split into ``num_groups`` channel groups and the per-sample,
    per-group mean and variance are computed.  When ``track_running_stats``
    is True an exponential moving average of those statistics is kept (and
    updated in training mode) and the *running* values are used to normalize;
    otherwise the statistics of the current input are used directly.

    The running statistics are created lazily on the first forward pass and
    re-created whenever the statistics' shape changes.  Their shape is
    ``(N, num_groups, 1)``, i.e. it includes the batch dimension.

    Args:
        num_features: number of input channels ``C``.
        num_groups: number of groups; must divide ``C``.
        eps: value added to the variance for numerical stability.
        momentum: weight of the current batch in the moving average.
        affine: if True, learn a per-channel scale and shift.
        track_running_stats: if True, normalize with moving averages.
    """

    def __init__(self, num_features, num_groups=32, eps=1e-5,
                 momentum=0.1, affine=True,
                 track_running_stats=True
                 ):
        super(GroupNormMoving, self).__init__()

        self.num_features = num_features
        self.num_groups = num_groups
        self.eps = eps

        self.momentum = momentum
        self.affine = affine

        self.track_running_stats = track_running_stats

        tensor_shape = (1, num_features, 1, 1)

        if self.affine:
            self.weight = Parameter(torch.Tensor(*tensor_shape))
            self.bias = Parameter(torch.Tensor(*tensor_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        # Always register the slots so the attributes exist even when the
        # statistics are not tracked; they are filled in by forward().
        self.register_parameter('running_mean', None)
        self.register_parameter('running_var', None)
        self.reset_parameters()

    def forward(self, x):
        """Normalize a ``(N, C, H, W)`` tensor group-wise.

        Returns:
            Tensor of the same shape as ``x``.
        """
        N, C, H, W = x.size()
        G = self.num_groups
        assert C % G == 0, "Channel must be divided by groups"

        x = x.reshape(N, G, -1)
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)

        if self.track_running_stats:
            batch_mean = mean.detach()
            batch_var = var.detach()

            if self.running_mean is None or self.running_mean.size() != mean.size():
                # Seed each running statistic from its own batch statistic
                # (cloned so the stored value does not alias the batch).
                self.running_mean = Parameter(batch_mean.clone())
                self.running_var = Parameter(batch_var.clone())

            if self.training:
                self.running_mean.data = batch_mean * self.momentum + \
                                         self.running_mean.data * (1 - self.momentum)
                self.running_var.data = batch_var * self.momentum + \
                                        self.running_var.data * (1 - self.momentum)

            mean = self.running_mean
            var = self.running_var

        x = (x - mean) / (var + self.eps).sqrt()
        x = x.view(N, C, H, W)
        if self.affine:
            x = x * self.weight + self.bias
        return x

    def reset_parameters(self):
        """Reset running statistics (if present) and the affine parameters."""
        with torch.no_grad():
            if self.track_running_stats:
                if self.running_mean is not None and self.running_var is not None:
                    self.running_mean.zero_()
                    self.running_var.fill_(1)
            if self.affine:
                self.weight.uniform_()
                self.bias.zero_()

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' affine={affine}, track_running_stats={track_running_stats})'
                .format(name=self.__class__.__name__, **self.__dict__))
class MarkovFlow(nn.Module):
    def __init__(self, args, num_dims):
        super(MarkovFlow, self).__init__()

        self.args = args
        self.device = args.device

        # Gaussian Variance
        self.var = Parameter(torch.zeros(num_dims, dtype=torch.float32))

        if not args.train_var:
            self.var.requires_grad = False

        self.num_state = args.num_state
        self.num_dims = num_dims
        self.couple_layers = args.couple_layers
        self.cell_layers = args.cell_layers
        self.hidden_units = num_dims // 2
        self.lstm_hidden_units = self.num_dims

        # transition parameters in log space
        self.tparams = Parameter(torch.Tensor(self.num_state, self.num_state))

        self.prior_group = [self.tparams]

        # Gaussian means
        self.means = Parameter(torch.Tensor(self.num_state, self.num_dims))

        if args.mode == "unsupervised" and args.freeze_prior:
            self.tparams.requires_grad = False

        if args.mode == "unsupervised" and args.freeze_mean:
            self.means.requires_grad = False

        if args.model == 'nice':
            self.proj_layer = NICETrans(self.couple_layers, self.cell_layers,
                                        self.hidden_units, self.num_dims,
                                        self.device)
        elif args.model == "lstmnice":
            self.proj_layer = LSTMNICE(self.args.lstm_layers,
                                       self.args.couple_layers,
                                       self.args.cell_layers,
                                       self.lstm_hidden_units,
                                       self.hidden_units, self.num_dims,
                                       self.device)

        if args.mode == "unsupervised" and args.freeze_proj:
            for param in self.proj_layer.parameters():
                param.requires_grad = False

        if args.model == "gaussian":
            self.proj_group = [self.means, self.var]
        else:
            self.proj_group = list(
                self.proj_layer.parameters()) + [self.means, self.var]

        # prior
        self.pi = torch.zeros(self.num_state,
                              dtype=torch.float32,
                              requires_grad=False,
                              device=self.device).fill_(1.0 / self.num_state)

        self.pi = torch.log(self.pi)

    def init_params(self, train_data):
        """Initialise the transition, mean and variance parameters.

        Three paths are taken depending on ``self.args``:

        1. ``load_nice`` set: restore a full checkpoint, snapshot the
           restored means / transitions / projection parameters (read later
           by ``MSE_loss``), optionally re-estimate the variance, and return.
        2. unsupervised mode: estimate a mean and variance from the first
           batch of ``train_data`` only, seed ``self.var`` and ``self.means``
           from them, and return.
        3. otherwise: per-tag means via ``init_mean`` and a global variance
           via ``init_var``.

        Args:
            train_data: object exposing ``data_iter(batch_size)`` whose items
                provide ``embed`` and ``mask``.  Per the original note these
                are (seq_length, batch_size, features) and
                (seq_length, batch_size) respectively.

        NOTE(review): ``self.var`` / ``self.means`` are written in place
        outside any ``torch.no_grad()`` block on paths 1 and 3, and the
        snapshots on path 1 are plain ``clone()``s; this looks like it
        expects the caller to wrap the call in ``torch.no_grad()`` --
        confirm.
        """

        # initialize transition matrix params
        # self.tparams.data.uniform_().add_(1)
        self.tparams.data.uniform_()

        # load pretrained model
        if self.args.load_nice != '':
            self.load_state_dict(torch.load(self.args.load_nice), strict=True)

            # Snapshots of the restored values, used as regularisation
            # targets by MSE_loss.
            self.means_init = self.means.clone()
            self.tparams_init = self.tparams.clone()
            self.proj_init = [
                param.clone() for param in self.proj_layer.parameters()
            ]

            if self.args.init_var:
                self.init_var(train_data)

            # NOTE(review): despite the flag name this sets the variance to
            # 0.01 here, whereas the non-checkpoint path below uses 1.0 --
            # confirm which is intended.
            if self.args.init_var_one:
                self.var.fill_(0.01)

            # self.means_init.requires_grad = False
            # self.tparams_init.requires_grad = False
            # for tensor in self.proj_init:
            #     tensor.requires_grad = False

            return

        # load pretrained Gaussian baseline
        if self.args.load_gaussian != '':
            self.load_state_dict(torch.load(self.args.load_gaussian),
                                 strict=False)

        # fully unsupervised training
        # (the ``load_nice == ""`` test is always true here because the
        # checkpoint branch above has already returned)
        if self.args.mode == "unsupervised" and self.args.load_nice == "":
            with torch.no_grad():
                for iter_obj in train_data.data_iter(self.args.batch_size):
                    sents = iter_obj.embed
                    masks = iter_obj.mask
                    sents, _ = self.transform(sents, iter_obj.mask)
                    seq_length, _, features = sents.size()
                    flat_sents = sents.view(-1, features)
                    # Mask-weighted mean and variance over the tokens of this
                    # batch.
                    seed_mean = torch.sum(
                        masks.view(-1, 1).expand_as(flat_sents) * flat_sents,
                        dim=0) / masks.sum()
                    seed_var = torch.sum(
                        masks.view(-1, 1).expand_as(flat_sents) *
                        ((flat_sents - seed_mean.expand_as(flat_sents))**2),
                        dim=0) / masks.sum()
                    self.var.copy_(seed_var)
                    # self.var.fill_(0.02)

                    # add noise to the pretrained Gaussian mean
                    if self.args.load_gaussian != '' and self.args.model == 'nice':
                        # Shift the loaded means by the batch mean.
                        self.means.data.add_(
                            seed_mean.data.expand_as(self.means.data))
                    elif self.args.load_gaussian == '' and self.args.load_nice == '':
                        # No pretrained means: small Gaussian noise (std 0.04)
                        # around the batch mean.
                        self.means.data.normal_().mul_(0.04)
                        self.means.data.add_(
                            seed_mean.data.expand_as(self.means.data))

                    # Only the first batch is used as the seed.
                    return

        # Reached in non-unsupervised mode, or when the iterator above
        # yielded no batch.
        self.init_mean(train_data)
        self.var.fill_(1.0)
        self.init_var(train_data)

        if self.args.init_var_one:
            self.var.fill_(1.0)

    def init_mean(self, train_data):
        emb_dict = {}
        cnt_dict = Counter()
        for iter_obj in train_data.data_iter(self.args.batch_size):
            sents_t = iter_obj.embed
            sents_t, _ = self.transform(sents_t, iter_obj.mask)
            sents_t = sents_t.transpose(0, 1)
            pos_t = iter_obj.pos.transpose(0, 1)
            mask_t = iter_obj.mask.transpose(0, 1)

            for emb_s, tagid_s, mask_s in zip(sents_t, pos_t, mask_t):
                for tagid, emb, mask in zip(tagid_s, emb_s, mask_s):
                    tagid = tagid.item()
                    mask = mask.item()
                    if tagid in emb_dict:
                        emb_dict[tagid] = emb_dict[tagid] + emb * mask
                    else:
                        emb_dict[tagid] = emb * mask

                    cnt_dict[tagid] += mask

        for tagid in emb_dict:
            self.means[tagid] = emb_dict[tagid] / cnt_dict[tagid]

    def init_var(self, train_data):
        cnt = 0
        mean_sum = 0.
        var_sum = 0.
        for iter_obj in train_data.data_iter(batch_size=self.args.batch_size):
            sents, masks = iter_obj.embed, iter_obj.mask
            sents, _ = self.transform(sents, masks)
            seq_length, _, features = sents.size()
            flat_sents = sents.view(-1, features)
            mean_sum = mean_sum + torch.sum(
                masks.view(-1, 1).expand_as(flat_sents) * flat_sents, dim=0)
            cnt += masks.sum().item()

        mean = mean_sum / cnt

        for iter_obj in train_data.data_iter(batch_size=self.args.batch_size):
            sents, masks = iter_obj.embed, iter_obj.mask
            sents, _ = self.transform(sents, masks)
            seq_length, _, features = sents.size()
            flat_sents = sents.view(-1, features)
            var_sum = var_sum + torch.sum(
                masks.view(-1, 1).expand_as(flat_sents) *
                ((flat_sents - mean.expand_as(flat_sents))**2),
                dim=0)
        var = var_sum / cnt
        self.var.copy_(var)

    def _calc_log_density_c(self):
        # return -self.num_dims/2.0 * (math.log(2) + \
        #         math.log(np.pi)) - 0.5 * self.num_dims * (torch.log(self.var))

        return -self.num_dims/2.0 * (math.log(2) + \
                math.log(np.pi)) - 0.5 * torch.sum(torch.log(self.var))

    def transform(self, x, masks=None):
        """
        Args:
            x: (sent_length, batch_size, num_dims)
        """
        jacobian_loss = torch.zeros(1, device=self.device, requires_grad=False)

        if self.args.model != 'gaussian':
            x, jacobian_loss_new = self.proj_layer(x, masks)
            jacobian_loss = jacobian_loss + jacobian_loss_new

        return x, jacobian_loss

    def MSE_loss(self):
        # diff1 = ((self.means - self.means_init) ** 2).sum()
        diff_prior = ((self.tparams - self.tparams_init)**2).sum()

        # diff = diff1 + diff2
        diff_proj = 0.

        for i, param in enumerate(self.proj_layer.parameters()):
            diff_proj = diff_proj + ((self.proj_init[i] - param)**2).sum()

        diff_mean = ((self.means_init - self.means)**2).sum()

        return 0.5 * (self.args.beta_prior * diff_prior + self.args.beta_proj *
                      diff_proj + self.args.beta_mean * diff_mean)

    def unsupervised_loss(self, sents, masks):
        """Negative marginal log-likelihood computed with the forward pass.

        Args:
            sents: (sent_length, batch_size, self.num_dims)
            masks: (sent_length, batch_size)

        Returns: Tensor1, Tensor2
            Tensor1: negative log likelihood summed over the batch
            Tensor2: jacobian loss as returned by ``transform``

        Side effects: caches ``self.logA`` and ``self.log_density_c``, which
        ``_forward_cell`` and ``_eval_density`` read.
        """
        num_steps, num_sents, _ = sents.size()
        sents, jacobian_loss = self.transform(sents, masks)

        assert self.var.data.min() > 0

        self.logA = self._calc_logA()
        self.log_density_c = self._calc_log_density_c()

        # Forward recursion in log space; the first position is taken as is.
        alpha = self.pi + self._eval_density(sents[0])
        for step in range(1, num_steps):
            emission = self._eval_density(sents[step])
            # (batch_size, num_state): where the mask is 0 the previous alpha
            # is carried over unchanged.
            keep = masks[step].unsqueeze(1).expand(num_sents, self.num_state)
            advanced = self._forward_cell(alpha, emission)
            alpha = keep * advanced + (1 - keep) * alpha

        # calculate objective from log space
        log_likelihood = log_sum_exp(alpha, dim=1).sum()

        return -log_likelihood, jacobian_loss

    def supervised_loss(self, sents, tags, masks):
        """
        Args:
            sents: (sent_length, batch_size, num_dims)
            masks: (sent_length, batch_size)
            tags:  (sent_length, batch_size)

        Returns: Tensor1, Tensor2
            Tensor1: negative log likelihood, shape ([])
            Tensor2: jacobian loss, shape ([])

        """

        sent_len, batch_size, _ = sents.size()

        # (sent_length, batch_size, num_dims)
        sents, jacobian_loss = self.transform(sents, masks)

        # ()
        log_density_c = self._calc_log_density_c()

        # (1, 1, num_state, num_dims)
        means = self.means.view(1, 1, self.num_state, self.num_dims)
        means = means.expand(sent_len, batch_size, self.num_state,
                             self.num_dims)
        tag_id = tags.view(*tags.size(), 1, 1).expand(sent_len, batch_size, 1,
                                                      self.num_dims)

        # (sent_len, batch_size, num_dims)
        means = torch.gather(means, dim=2, index=tag_id).squeeze(2)

        var = self.var.view(1, 1, self.num_dims)

        # (sent_len, batch_size)
        log_emission_prob = log_density_c - \
                       0.5 * torch.sum((means-sents) ** 2 / var, dim=-1)

        log_emission_prob = torch.mul(masks, log_emission_prob).sum()

        # (num_state, num_state)
        log_trans = self._calc_logA()

        # (sent_len, batch_size, num_state, num_state)
        log_trans_prob = log_trans.view(1, 1, *log_trans.size()).expand(
            sent_len, batch_size, *log_trans.size())

        # (sent_len-1, batch_size, 1, num_state)
        tag_id = tags.view(*tags.size(), 1, 1).expand(sent_len, batch_size, 1,
                                                      self.num_state)[:-1]

        # (sent_len-1, batch_size, 1, num_state)
        log_trans_prob = torch.gather(log_trans_prob[:-1], dim=2, index=tag_id)

        # (sent_len-1, batch_size, 1, 1)
        tag_id = tags.view(*tags.size(), 1, 1)[1:]

        # (sent_len-1, batch_size)
        log_trans_prob = torch.gather(log_trans_prob, dim=3,
                                      index=tag_id).squeeze()

        log_trans_prob = torch.mul(masks[1:], log_trans_prob)

        log_trans_prior = self.pi.expand(batch_size, self.num_state)
        tag_id = tags[0].unsqueeze(dim=1)

        # (batch_size)
        log_trans_prior = torch.gather(log_trans_prior, dim=1,
                                       index=tag_id).sum()

        log_trans_prob = log_trans_prior + log_trans_prob.sum()

        return -(log_trans_prob + log_emission_prob), jacobian_loss

    def _calc_alpha(self, sents, masks):
        """
        sents: (sent_length, batch_size, self.num_dims)
        masks: (sent_length, batch_size)

        Returns:
            output: (batch_size, sent_length, num_state)

        """
        max_length, batch_size, _ = sents.size()

        alpha_all = []
        alpha = self.pi + self._eval_density(sents[0])
        alpha_all.append(alpha.unsqueeze(1))
        for t in range(1, max_length):
            density = self._eval_density(sents[t])
            mask_ep = masks[t].expand(self.num_state, batch_size) \
                      .transpose(0, 1)
            alpha = torch.mul(mask_ep, self._forward_cell(alpha, density)) + \
                    torch.mul(1-mask_ep, alpha)
            alpha_all.append(alpha.unsqueeze(1))

        return torch.cat(alpha_all, dim=1)

    def _forward_cell(self, alpha, density):
        """Advance the forward variable by one position.

        alpha: (batch_size, num_state), indexed by the previous state
        density: (batch_size, num_state), indexed by the next state

        Returns the new alpha, (batch_size, num_state): for every next state
        the log-sum-exp over previous states of
        ``alpha + self.logA + density``.
        """
        num_sents = len(alpha)
        shape = torch.Size([num_sents, self.num_state, self.num_state])

        # Axis 1 is the previous state, axis 2 the next state.
        from_prev = alpha.unsqueeze(dim=2).expand(shape)
        transition = self.logA.expand(shape)
        emission = density.unsqueeze(dim=1).expand(shape)

        return log_sum_exp(from_prev + transition + emission, dim=1)

    def _backward_cell(self, beta, density):
        """Move the backward variable one position towards the start.

        density: (batch_size, num_state)
        beta: (batch_size, num_state)

        Returns the new beta, (batch_size, num_state): for every current
        state the log-sum-exp over next states of
        ``self.logA + density + beta``.
        """
        num_sents = len(beta)
        shape = torch.Size([num_sents, self.num_state, self.num_state])

        # Axis 1 is the current state, axis 2 the next state; both density
        # and beta are indexed by the next state.
        transition = self.logA.expand(shape)
        emission = density.unsqueeze(dim=1).expand(shape)
        from_next = beta.unsqueeze(dim=1).expand(shape)

        return log_sum_exp(transition + emission + from_next, dim=2)

    def _eval_density(self, words):
        """
        Args:
            words: (batch_size, self.num_dims)

        Returns: Tensor1
            Tensor1: the density tensor with shape (batch_size, num_state)
        """

        batch_size = words.size(0)
        ep_size = torch.Size([batch_size, self.num_state, self.num_dims])
        words = words.unsqueeze(dim=1).expand(ep_size)
        means = self.means.expand(ep_size)
        var = self.var.expand(ep_size)

        return self.log_density_c - \
               0.5 * torch.sum((means-words) ** 2 / var, dim=2)

    def _calc_logA(self):
        """Return the log transition matrix: ``tparams`` normalised per row."""
        row_log_norm = log_sum_exp(self.tparams, dim=1, keepdim=True)
        return self.tparams - row_log_norm.expand(self.num_state,
                                                  self.num_state)

    def _calc_log_mul_emit(self):
        """Return ``self.emission`` normalised per row in log space.

        NOTE(review): neither ``self.emission`` nor ``self.vocab_size`` is
        assigned in the ``__init__`` of this class as visible here -- confirm
        this method is still used.
        """
        row_log_norm = log_sum_exp(self.emission, dim=1, keepdim=True)
        return self.emission - row_log_norm.expand(self.num_state,
                                                   self.vocab_size)

    def _viterbi(self, sents_var, masks):
        """Decode the highest-scoring state sequence for every sentence.

        Args:
            sents_var: (sent_length, batch_size, num_dims)
            masks: (sent_length, batch_size)

        Returns:
            Integer tensor of shape (batch_size, sent_length) holding one
            state index per position.  At positions where the mask is 0 the
            state decoded for the following position is repeated.

        Side effects: refreshes ``self.log_density_c`` and ``self.logA``.
        """

        self.log_density_c = self._calc_log_density_c()
        self.logA = self._calc_logA()

        length, batch_size = masks.size()

        # (batch_size, num_state)
        delta = self.pi + self._eval_density(sents_var[0])

        ep_size = torch.Size([batch_size, self.num_state, self.num_state])
        index_all = []

        # forward calculate delta
        for t in range(1, length):
            density = self._eval_density(sents_var[t])
            # delta_new[b, i, j]: score of being in state i at t-1 and moving
            # to state j at t.
            delta_new = self.logA.expand(ep_size) + \
                    density.unsqueeze(dim=1).expand(ep_size) + \
                    delta.unsqueeze(dim=2).expand(ep_size)
            mask_ep = masks[t].view(-1, 1, 1).expand(ep_size)
            # Where the mask is 0 the previous delta is copied along axis 1,
            # so the max below returns it unchanged.
            delta = mask_ep * delta_new + \
                    (1 - mask_ep) * delta.unsqueeze(dim=1).expand(ep_size)

            # index: (batch_size, num_state)
            delta, index = torch.max(delta, dim=1)
            index_all.append(index)

        assign_all = []
        # assign: (batch_size)
        _, assign = torch.max(delta, dim=1)
        assign_all.append(assign.unsqueeze(dim=1))

        # backward retrieve path
        # len(index_all) = length-1
        for t in range(length - 2, -1, -1):
            # index_all[t] holds the back-pointers recorded at position t+1.
            assign_new = torch.gather(index_all[t],
                                      dim=1,
                                      index=assign.view(-1, 1)).squeeze(dim=1)

            # Blend in float so the mask can select between the back-pointer
            # and the current assignment, then cast back to integer ids.
            assign_new = assign_new.float()
            assign = assign.float()
            assign = masks[t + 1] * assign_new + (1 - masks[t + 1]) * assign
            assign = assign.long()

            assign_all.append(assign.unsqueeze(dim=1))

        # The path was collected from the last position backwards.
        assign_all = assign_all[-1::-1]

        return torch.cat(assign_all, dim=1)

    def test_supervised(self, test_data):
        """Evaluate tagging performance with
        token-level supervised accuracy

        Args:
            test_data: ConlluData object

        Returns: a scalar accuracy value

        """
        total = 0.0
        correct = 0.0

        index_all = []
        eval_tags = []

        for iter_obj in test_data.data_iter(batch_size=self.args.batch_size,
                                            shuffle=False):
            sents_t = iter_obj.embed
            masks = iter_obj.mask
            tags_t = iter_obj.pos

            sents_t, _ = self.transform(sents_t, masks)

            # index: (batch_size, seq_length)
            index = self._viterbi(sents_t, masks)

            for index_s, tag_s, mask_s in zip(index, tags_t.transpose(0, 1),
                                              masks.transpose(0, 1)):
                for i in range(int(mask_s.sum().item())):
                    if index_s[i].item() == tag_s[i].item():
                        correct += 1
                    total += 1

        return correct / total

    def test_unsupervised(self,
                          test_data,
                          sentences=None,
                          tagging=False,
                          path=None,
                          null_index=None):
        """Evaluate tagging performance with
        many-to-1 metric, VM score and 1-to-1
        accuracy

        Args:
            test_data: ConlluData object
            sentences: passed through to ``write_conll`` when ``tagging``
            tagging: output the predicted tags if True
            path: The output tag file path
            null_index: the null element location in Penn
                        Treebank, only used for writing unsupervised
                        tags for downstream parsing task

        Returns:
            Tuple1: (M1, VM score, 1-to-1 accuracy)

        """
        total = 0.0
        index_all = []

        # Flat, aligned lists of gold and predicted tags over scored tokens.
        gold_vm = []
        model_vm = []

        # cnt_stats[model_tag][gold_tag] -> co-occurrence count
        cnt_stats = {state: Counter() for state in range(self.num_state)}

        for batch in test_data.data_iter(batch_size=self.args.batch_size,
                                         shuffle=False):
            total += batch.mask.sum().item()
            masks = batch.mask
            gold_all = batch.pos

            feats, _ = self.transform(batch.embed, masks)

            # index: (batch_size, seq_length)
            index = self._viterbi(feats, masks)
            index_all.extend(index)

            # Gold tags of each sentence, cut to its number of unmasked
            # positions.
            gold_rows = [
                gold_all[:int(masks[:, col].sum().item()), col]
                for col in range(index.size(0))
            ]

            # count
            for gold_row, model_row in zip(gold_rows, index):
                for gold_item, model_item in zip(gold_row, model_row):
                    model_tag = model_item.item()
                    gold_tag = gold_item.item()
                    gold_vm.append(gold_tag)
                    model_vm.append(model_tag)
                    cnt_stats[model_tag][gold_tag] += 1

        # evaluate one-to-one accuracy: rows are gold tags, columns are model
        # tags, and negated counts turn the assignment into a maximisation.
        cost_matrix = np.zeros((self.num_state, self.num_state))
        for gold in range(self.num_state):
            for model in range(self.num_state):
                cost_matrix[gold][model] = -cnt_stats[model][gold]

        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        one2one_hits = float(
            sum(1 for gold_tag, model_tag in zip(gold_vm, model_vm)
                if col_ind[gold_tag] == model_tag))
        one2one = one2one_hits / total

        # match: every used model tag maps to its most frequent gold tag
        match_dict = {
            tag: counts.most_common(1)[0][0]
            for tag, counts in cnt_stats.items() if len(counts) != 0
        }

        # eval many2one
        many2one_hits = float(
            sum(1 for gold_tag, model_tag in zip(gold_vm, model_vm)
                if match_dict[model_tag] == gold_tag))

        if tagging:
            write_conll(path, sentences, index_all, null_index)

        return (many2one_hits / total, v_measure_score(gold_vm, model_vm),
                one2one)
class WN_Conv3d(nn.Conv3d):
    """3-D convolution with weight normalization and "SAME"-style padding.

    Each output-channel filter is rescaled so that its L2 norm equals the
    matching entry of ``weight_scale``.  The padding passed to the
    convolution is recomputed on every call from the input size, kernel size
    and stride; the ``padding`` constructor argument is stored by
    ``nn.Conv3d`` but not used in ``forward``.

    While ``init_mode`` is ``True`` a forward pass also performs
    data-dependent initialization: the output is standardized per channel and
    ``weight_scale`` / ``bias`` are overwritten so that later passes
    reproduce that standardized output on the same data.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded to ``nn.Conv3d``.
        train_scale: if True ``weight_scale`` is a learnable ``Parameter``,
            otherwise it is registered as a buffer.
        init_stdv: target per-channel standard deviation used by the
            data-dependent initialization.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 train_scale=False,
                 init_stdv=1.0):
        super(WN_Conv3d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)
        if train_scale:
            self.weight_scale = Parameter(torch.ones(out_channels))
        else:
            self.register_buffer('weight_scale', torch.ones(out_channels))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv
        # ``self.kernel_size`` is deliberately left as the 3-tuple stored by
        # nn.Conv3d; forward() indexes it per axis, which would fail on a raw
        # integer ``kernel_size`` argument.

        self._reset_parameters()

    def _reset_parameters(self):
        """Draw weights from N(0, 0.05^2), zero the bias, set the scale to 1."""
        with torch.no_grad():
            self.weight.normal_(std=0.05)
            if self.bias is not None:
                self.bias.zero_()
            self.weight_scale.fill_(1.)

    def forward(self, input):
        """Apply the weight-normalized convolution.

        Args:
            input: 5-D tensor; axes 2-4 are the spatial axes used to compute
                the padding.

        Returns:
            The activation.  In init mode it is the per-channel standardized
            activation; otherwise the bias (if any) has been added.
        """
        # This is done to ensure padding as "SAME": per axis,
        # ceil((kernel - size * (1 - stride) - stride) / 2).
        # NOTE(review): dilation is not part of this formula -- confirm that
        # is intended for dilated convolutions.
        padding = tuple(
            math.ceil((kernel - size * (1 - stride) - stride) / 2)
            for kernel, size, stride in zip(self.kernel_size,
                                            input.shape[2:5], self.stride))

        # normalize weight matrix and linear projection [out x in x h x w x z]
        # for each output dimension, normalize through (in, h, w, z) = (1, 2, 3, 4) dims
        filter_norm = torch.sqrt(
            (self.weight**2).sum(dim=(1, 2, 3, 4), keepdim=True) + 1e-6)
        norm_weight = self.weight * (
            self.weight_scale.view(-1, 1, 1, 1, 1) / filter_norm)
        activation = F.conv3d(input,
                              norm_weight,
                              bias=None,
                              stride=self.stride,
                              padding=padding,
                              dilation=self.dilation,
                              groups=self.groups)

        if self.init_mode:
            mean_act = activation.mean(dim=(0, 2, 3, 4))
            activation = activation - mean_act.view(1, -1, 1, 1, 1)

            inv_stdv = self.init_stdv / torch.sqrt(
                (activation**2).mean(dim=(0, 2, 3, 4)) + 1e-6)
            activation = activation * inv_stdv.view(1, -1, 1, 1, 1)

            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.detach()
            else:
                self.weight_scale = self.weight_scale * inv_stdv.detach()
            # Without a bias the centering cannot be stored; only the scale
            # is carried over to later passes.
            if self.bias is not None:
                self.bias.data = -mean_act.detach() * inv_stdv.detach()
        elif self.bias is not None:
            activation = activation + self.bias.view(1, -1, 1, 1, 1)

        return activation
class WN_Conv2d(nn.Conv2d):
    """2-D convolution with weight normalization and mean-only centering.

    Every output filter of ``weight`` ([out x in x h x w]) is rescaled so
    that its L2 norm equals the matching entry of ``g`` before the
    convolution is applied.  In training mode the per-channel batch mean of
    the result is subtracted; in eval mode the running estimate ``avg_mean``
    is subtracted instead.  The bias (if any) is added last.

    The first training forward pass (or any training pass called with
    ``init_mode=True``) performs a data-dependent initialization: the
    centered output is rescaled to standard deviation ``init_stdv`` per
    channel and ``g`` is multiplied by the same factor.

    NOTE: ``has_init`` is a plain Python attribute rather than a buffer.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias: forwarded positionally to ``nn.Conv2d``.
        train_scale: if True ``g`` is a learnable ``Parameter``, otherwise a
            buffer.
        init_stdv: target per-channel standard deviation used by the
            data-dependent initialization.
        momentum: decay applied to ``avg_mean`` on every running-mean update.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 train_scale=False,
                 init_stdv=1.0,
                 momentum=0.999):
        super(WN_Conv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)
        if train_scale:
            self.g = Parameter(torch.ones(out_channels))
        else:
            self.register_buffer('g', torch.ones(out_channels))

        self.train_scale = train_scale
        self.init_stdv = init_stdv

        self.has_init = False

        self.register_buffer('avg_mean', torch.zeros(out_channels))
        self.momentum = momentum

        self._reset_parameters()

    def _reset_parameters(self):
        """Re-draw the weights, zero the bias, reset ``g`` to 1 and clear ``has_init``."""
        self.weight.data.normal_(std=0.05)
        if self.bias is not None:
            self.bias.data.zero_()
        if self.train_scale:
            self.g.data.fill_(1.)
        else:
            self.g.fill_(1.)
        self.has_init = False

    def forward(self, input, moving_average=True, init_mode=False):
        """Apply the weight-normalized convolution.

        Args:
            input: tensor accepted by ``F.conv2d``.
            moving_average: when True, training passes that do not run the
                data-dependent initialization update ``avg_mean``.
            init_mode: force the data-dependent initialization of ``g`` on
                this pass (training mode only).

        Returns:
            The centered convolution output, plus bias when one exists.
        """
        # normalize weight matrix and linear projection [out x in x h x w]
        # for each output dimension, normalize through (in, h, w) = (1, 2, 3) dims
        filter_norm = torch.sqrt((self.weight**2).sum(3).sum(2).sum(1))
        norm_weight = self.weight * (self.g / filter_norm).view(-1, 1, 1, 1)
        activation = F.conv2d(input,
                              norm_weight,
                              bias=None,
                              stride=self.stride,
                              padding=self.padding,
                              dilation=self.dilation,
                              groups=self.groups)

        if self.training:
            # The reductions already yield shape (out_channels,).  No
            # squeeze(): it would turn a single-channel statistic into a
            # 0-d tensor and break the per-channel broadcasting below.
            mean_act = activation.mean(3).mean(2).mean(0)
            activation = activation - mean_act.view(1, -1, 1, 1)

            if init_mode or not self.has_init:
                inv_stdv = self.init_stdv / torch.sqrt(
                    (activation**2).mean(3).mean(2).mean(0))
                activation = activation * inv_stdv.view(1, -1, 1, 1)

                # Fold the rescaling into g without recording it in autograd.
                if self.train_scale:
                    self.g.data = self.g.data * inv_stdv.detach()
                else:
                    self.g = self.g * inv_stdv.detach()

                self.has_init = True
            elif moving_average:
                # avg_mean <- momentum * avg_mean + (1 - momentum) * batch mean
                self.avg_mean.mul_(self.momentum).add_(
                    mean_act.detach() * (1. - self.momentum))
        else:
            assert not self.avg_mean.requires_grad
            activation = activation - self.avg_mean.view(1, -1, 1, 1)

        if self.bias is not None:
            activation = activation + self.bias.view(1, -1, 1, 1)

        return activation
class WN_Linear(nn.Linear):
    """Linear layer with weight normalization and mean-only centering.

    Every row of ``weight`` is rescaled so that its L2 norm equals the
    matching entry of ``g`` before the projection.  In training mode the
    batch mean (over dim 0) of the result is subtracted; in eval mode the
    running estimate ``avg_mean`` is subtracted instead.  The bias (if any)
    is added last.

    The first training forward pass (or any training pass called with
    ``init_mode=True``) performs a data-dependent initialization: the
    centered output is rescaled to standard deviation ``init_stdv`` per
    feature and ``g`` is multiplied by the same factor.

    NOTE: ``has_init`` is a plain Python attribute rather than a buffer.

    Args:
        in_features, out_features, bias: forwarded to ``nn.Linear``.
        train_scale: if True ``g`` is a learnable ``Parameter``, otherwise a
            buffer.
        init_stdv: target per-feature standard deviation used by the
            data-dependent initialization.
        momentum: decay applied to ``avg_mean`` on every running-mean update.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 train_scale=True,
                 init_stdv=1.0,
                 momentum=0.999):
        super(WN_Linear, self).__init__(in_features, out_features, bias=bias)
        if train_scale:
            self.g = Parameter(torch.ones(self.out_features))
        else:
            self.register_buffer('g', torch.ones(out_features))

        self.train_scale = train_scale
        self.init_stdv = init_stdv
        self.has_init = False
        self.register_buffer('avg_mean', torch.zeros(out_features))
        self.momentum = momentum
        self._reset_parameters()

    def _reset_parameters(self):
        """Re-draw the weights, zero the bias, reset ``g`` to 1 and clear ``has_init``."""
        self.weight.data.normal_(0, std=0.05)
        if self.bias is not None:
            self.bias.data.zero_()
        if self.train_scale:
            self.g.data.fill_(1.)
        else:
            self.g.fill_(1.)
        self.has_init = False

    def forward(self, input, moving_average=True, init_mode=False):
        """Apply the weight-normalized projection.

        Args:
            input: tensor accepted by ``F.linear``; batch statistics are
                taken over dim 0 of the result.
            moving_average: when True, training passes that do not run the
                data-dependent initialization update ``avg_mean``.
            init_mode: force the data-dependent initialization of ``g`` on
                this pass (training mode only).

        Returns:
            The centered projection, plus bias when one exists.
        """
        assert not self.avg_mean.requires_grad

        # normalize weight matrix and linear projection
        row_norm = torch.sqrt(torch.sum(self.weight**2, dim=1, keepdim=True))
        norm_weight = self.weight * (self.g.unsqueeze(1) / row_norm)
        activation = F.linear(input, norm_weight)

        if self.training:
            # mean(0) already drops the batch dim; an extra squeeze(0) would
            # only collapse a single-feature statistic to a 0-d tensor.
            mean_act = activation.mean(0)
            activation = activation - mean_act
            if init_mode or not self.has_init:
                inv_stdv = self.init_stdv / torch.sqrt(
                    (activation**2).mean(0))
                activation = activation * inv_stdv

                # Fold the rescaling into g without recording it in autograd.
                if self.train_scale:
                    self.g.data = self.g.data * inv_stdv.detach()
                else:
                    self.g = self.g * inv_stdv.detach()
                self.has_init = True
            elif moving_average:
                # avg_mean <- momentum * avg_mean + (1 - momentum) * batch mean
                # (written without the deprecated add_(Number, Tensor) form)
                self.avg_mean.mul_(self.momentum).add_(
                    mean_act.detach() * (1 - self.momentum))
        else:
            activation = activation - self.avg_mean

        if self.bias is not None:
            activation = activation + self.bias

        return activation
class WN_Conv2d_Mean_Only_BN(nn.Conv2d):
    """Weight norm combined with mean-only batch norm for 2d ConvNet.

    Every output filter of ``weight`` ([out x in x h x w]) is rescaled so
    that its L2 norm equals the matching entry of ``weight_scale``.  Outside
    init mode the per-channel mean is subtracted from the convolution output
    (the batch mean while training, ``running_mean`` in eval) and the bias,
    if any, is added.

    Setting ``init_mode = True`` on the instance switches ``forward`` to a
    data-dependent initialization: the output is centered and rescaled to
    standard deviation ``init_stdv`` per channel, ``weight_scale`` is
    multiplied by the same factor and, when the layer has a bias, the bias
    is set to ``-mean * inv_stdv``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias: forwarded positionally to ``nn.Conv2d``.
        train_scale: if True ``weight_scale`` is a learnable ``Parameter``,
            otherwise a buffer.
        init_stdv: target per-channel standard deviation in init mode.
        bn_momentum: weight of the current batch mean in the running-mean
            update.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 train_scale=False, init_stdv=1.0, bn_momentum=0.001):
        super(WN_Conv2d_Mean_Only_BN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        if train_scale:
            self.weight_scale = Parameter(torch.ones(out_channels))
        else:
            self.register_buffer('weight_scale', torch.ones(out_channels))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv

        # mean-only batch norm params
        self.register_buffer('running_mean', torch.zeros(out_channels))
        self.bn_momentum = bn_momentum

        self._reset_parameters()

    def _reset_parameters(self):
        """Re-draw the weights, zero the bias and running mean, reset the scale to 1."""
        self.weight.data.normal_(std=0.05)
        if self.bias is not None:
            self.bias.data.zero_()
        if self.train_scale:
            self.weight_scale.data.fill_(1.)
        else:
            self.weight_scale.fill_(1.)
        self.running_mean.zero_()

    def forward(self, input):
        """Apply the weight-normalized convolution with mean-only batch norm.

        Args:
            input: tensor accepted by ``F.conv2d``.

        Returns:
            In init mode, the centered and rescaled convolution output.
            Otherwise the mean-subtracted output plus bias (when present).
        """
        # normalize weight matrix and linear projection [out x in x h x w]
        # for each output dimension, normalize through (in, h, w) = (1, 2, 3) dims
        filter_norm = torch.sqrt((self.weight ** 2).sum(3).sum(2).sum(1) + 1e-8)
        norm_weight = self.weight * (self.weight_scale / filter_norm).view(-1, 1, 1, 1)
        activation = F.conv2d(input, norm_weight, bias=None,
                              stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups)

        # The reductions already yield shape (out_channels,).  No squeeze():
        # it would turn a single-channel statistic into a 0-d tensor and
        # break the per-channel reshaping below.
        batch_mean = activation.mean(3).mean(2).mean(0)

        if self.init_mode:
            activation = activation - batch_mean.view(1, -1, 1, 1)

            inv_stdv = self.init_stdv / torch.sqrt((activation ** 2).mean(3).mean(2).mean(0) + 1e-8)
            activation = activation * inv_stdv.view(1, -1, 1, 1)

            # Fold the rescaling into the scale without recording it in autograd.
            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.detach()
            else:
                self.weight_scale = self.weight_scale * inv_stdv.detach()
            # A layer built with bias=False has nothing to initialize here.
            if self.bias is not None:
                self.bias.data = -batch_mean.detach() * inv_stdv.detach()

        else:
            if self.training:
                mean = batch_mean
                self.running_mean = (self.running_mean * (1 - self.bn_momentum)
                                     + batch_mean.detach() * self.bn_momentum)
            else:
                mean = self.running_mean

            activation = activation - mean.view(1, -1, 1, 1)

            if self.bias is not None:
                activation = activation + self.bias.view(1, -1, 1, 1)

        return activation
class WN_Linear_Mean_Only_BN(nn.Linear):
    """Weight norm combined with mean-only batch norm for linear layer.

    Every row of ``weight`` is rescaled so that its L2 norm equals the
    matching entry of ``weight_scale``.  Outside init mode the per-feature
    mean is subtracted from the projection (the batch mean over dim 0 while
    training, ``running_mean`` in eval) and the bias, if any, is added.

    Setting ``init_mode = True`` on the instance switches ``forward`` to a
    data-dependent initialization: the output is centered and rescaled to
    standard deviation ``init_stdv`` per feature, ``weight_scale`` is
    multiplied by the same factor and, when the layer has a bias, the bias
    is set to ``-mean * inv_stdv``.

    Args:
        in_features, out_features, bias: forwarded to ``nn.Linear``.
        train_scale: if True ``weight_scale`` is a learnable ``Parameter``,
            otherwise a buffer.
        init_stdv: target per-feature standard deviation in init mode.
        bn_momentum: weight of the current batch mean in the running-mean
            update.
    """
    def __init__(self, in_features, out_features, bias=True, train_scale=False, init_stdv=1.0, bn_momentum=0.001):
        super(WN_Linear_Mean_Only_BN, self).__init__(in_features, out_features, bias=bias)
        if train_scale:
            self.weight_scale = Parameter(torch.ones(self.out_features))
        else:
            self.register_buffer('weight_scale', torch.ones(out_features))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv

        # mean-only batch norm params
        self.register_buffer('running_mean', torch.zeros(out_features))
        self.bn_momentum = bn_momentum

        self._reset_parameters()

    def _reset_parameters(self):
        """Re-draw the weights, zero the bias and running mean, reset the scale to 1."""
        self.weight.data.normal_(0, std=0.05)
        if self.bias is not None:
            self.bias.data.zero_()
        if self.train_scale:
            self.weight_scale.data.fill_(1.)
        else:
            self.weight_scale.fill_(1.)
        self.running_mean.zero_()

    def forward(self, input):
        """Apply the weight-normalized projection with mean-only batch norm.

        Args:
            input: tensor accepted by ``F.linear``; batch statistics are
                taken over dim 0 of the result.

        Returns:
            In init mode, the centered and rescaled projection.  Otherwise
            the mean-subtracted projection plus bias (when present).
        """
        # normalize weight matrix and linear projection
        row_norm = torch.sqrt((self.weight ** 2).sum(1) + 1e-8)
        norm_weight = self.weight * (self.weight_scale / row_norm).unsqueeze(1)
        activation = F.linear(input, norm_weight)

        # mean(0) already drops the batch dim; an extra squeeze(0) would
        # collapse a single-feature statistic to 0-d and, in init mode,
        # overwrite the (1,)-shaped bias with a 0-d tensor.
        batch_mean = activation.mean(0)

        if self.init_mode:
            activation = activation - batch_mean.unsqueeze(0)

            inv_stdv = self.init_stdv / torch.sqrt((activation ** 2).mean(0) + 1e-8)
            activation = activation * inv_stdv.unsqueeze(0)

            # Fold the rescaling into the scale without recording it in autograd.
            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.detach()
            else:
                self.weight_scale = self.weight_scale * inv_stdv.detach()
            # A layer built with bias=False has nothing to initialize here.
            if self.bias is not None:
                self.bias.data = -batch_mean.detach() * inv_stdv.detach()
        else:
            if self.training:
                mean = batch_mean
                self.running_mean = (self.running_mean * (1 - self.bn_momentum)
                                     + batch_mean.detach() * self.bn_momentum)
            else:
                mean = self.running_mean

            activation = activation - mean.unsqueeze(0)

            if self.bias is not None:
                activation = activation + self.bias

        return activation