Example #1
class LearnedNorm(DataDepInitModule):
    def __init__(self, shape, init_scale=1.0):
        super().__init__()
        self.init_scale = init_scale
        self.g = Parameter(torch.ones(*shape))
        self.b = Parameter(torch.zeros(*shape))

    def _init(self, x, *, inverse):
        assert not inverse
        assert x.shape[1:] == self.g.shape == self.b.shape
        m_init = x.mean(dim=0)
        scale_init = self.init_scale / (x.std(dim=0) + _SMALL)
        self.g.copy_(scale_init)
        self.b.copy_(-m_init * scale_init)
        return self._forward(x, inverse=inverse)

    def get_gain(self):
        return torch.clamp(self.g, min=1e-10)

    def _forward(self, x, *, inverse):
        """
        inverse == False to normalize; inverse == True to unnormalize
        """
        assert x.shape[1:] == self.g.shape == self.b.shape
        assert x.dtype == self.g.dtype == self.b.dtype
        g = self.get_gain()
        if not inverse:
            return x * g[None] + self.b[None]
        else:
            return (x - self.b[None]) / g[None]
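
A minimal usage sketch (not part of the snippet) of the round trip this layer performs after its data-dependent init. It assumes _SMALL is a tiny constant on the order of 1e-10 and that the DataDepInitModule base class (not shown) runs _init under torch.no_grad() on the first batch:

import torch

_SMALL = 1e-10        # assumed value of the module-level constant
init_scale = 1.0

x = torch.randn(64, 8) * 3.0 + 5.0            # batch with non-trivial mean/std
g = init_scale / (x.std(dim=0) + _SMALL)      # scale_init
b = -x.mean(dim=0) * g                        # shift so outputs are centered

y = x * g + b                                 # forward direction (normalize)
x_rec = (y - b) / g.clamp(min=1e-10)          # inverse direction (unnormalize)

print(y.mean(dim=0).abs().max().item())       # ~0
print(y.std(dim=0))                           # ~init_scale per feature
print((x_rec - x).abs().max().item())         # ~0: the round trip recovers x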
Example #2
def ircn(w: nn.Parameter, scale_factor, init=None):
    """
    Initialization for convolution in upsampling block to prevent artifacts.
    Paper: https://arxiv.org/abs/1707.02937
    """

    # new_shape = [size // (scale_factor ** 2) if idx == 2 else size for idx, size in enumerate(w.shape)]
    new_shape = [
        size // (scale_factor**2) if idx == 1 else size
        for idx, size in enumerate(w.shape)
    ]
    x = zeros(new_shape)

    if init is not None:
        init(x)
    else:
        stdv = 1. / sqrt(x.size(1))
        x.data.uniform_(-stdv, stdv)

    # interpolate treats the weight as N C H W (dims 2 and 3 are spatial)
    x = nn.functional.interpolate(x, scale_factor=scale_factor)

    # backwards PixelShuffle operation
    # source: https://github.com/pytorch/pytorch/issues/2456
    out_channel = x.shape[1] * (scale_factor**2)
    out_h, out_w = (x.shape[i] // scale_factor for i in [2, 3])
    x = x.contiguous().view(x.shape[0], x.shape[1], out_h, scale_factor, out_w,
                            scale_factor)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(x.shape[0], out_channel,
                                                      out_h, out_w)

    with no_grad():
        w.copy_(x)
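
A hedged usage sketch with hypothetical layer sizes. It assumes ircn's source file imports zeros and no_grad from torch and sqrt from math, as the body implies; note that dimension 1 of the weight must be divisible by scale_factor ** 2 for the reshape to go through:

import torch
from torch import nn

# Hypothetical sizes: dim 1 of the weight (16) is divisible by 2 ** 2.
conv = nn.Conv2d(16, 12, kernel_size=3, padding=1)

ircn(conv.weight, scale_factor=2)             # default uniform fan-in init
# or with an explicit initializer for the seed tensor:
# ircn(conv.weight, scale_factor=2, init=nn.init.kaiming_normal_)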
Example #3
class EmbeddingToWord(nn.Module):
    """
    Restores the actual word (index) from an embedding by cosine similarity
    against a fixed, normalized embedding table.
    """
    def __init__(self, embedding_size, words_count):
        super(EmbeddingToWord, self).__init__()

        self.embedding_size = embedding_size
        self.words_count = words_count

        self.norm_weights = Parameter(
            torch.FloatTensor(self.words_count, self.embedding_size))

        self.norm_weights.requires_grad = False

    @classmethod
    def _norm_tensor(cls, vectors: torch.FloatTensor):
        """
        Normalize each vector in the batch so that every vector has unit L2 norm.

        :param vectors: [batch_size, vector_length]
        :return:
        """
        return vectors / vectors.norm(p=2, dim=1, keepdim=True)

    def init_from_embeddings(self, embeddings: torch.FloatTensor):
        """

        :param embeddings: word2vec embeddings in shape: [word_count, embedding_size]
        :return:
        """
        self.norm_weights.copy_(self._norm_tensor(embeddings))

    def forward(self, vectors: torch.FloatTensor):
        """

        :param vectors: [batch_size x embedding_size]
        :return: [batch_size x word_count] cosine similarities in the range [-1, 1]
        """

        # shape: [embedding_size, batch_size]
        norm_vectors = self._norm_tensor(vectors).transpose(1, 0)

        # shape: [batch_size x word_count]
        return torch.mm(self.norm_weights, norm_vectors).transpose(0, 1)
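
A usage sketch with toy sizes (assuming the class above plus the snippet's torch imports): load an embedding table, score a query embedding against every word, and take the argmax to recover the word index:

import torch

emb = torch.randn(1000, 300)                      # [words_count, embedding_size]
decoder = EmbeddingToWord(embedding_size=300, words_count=1000)
decoder.init_from_embeddings(emb)

query = emb[42].unsqueeze(0) + 0.01 * torch.randn(1, 300)   # noisy copy of word 42
scores = decoder(query)                           # [1, 1000] cosine similarities
print(scores.argmax(dim=1).item())                # 42 (with high probability)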
Example #4
class Dense(DataDepInitModule):
    def __init__(self, in_features, out_features, init_scale=1.0):
        super().__init__()
        self.in_features, self.out_features, self.init_scale = in_features, out_features, init_scale

        self.w = Parameter(torch.Tensor(out_features, in_features))
        self.b = Parameter(torch.Tensor(out_features))

        init.normal_(self.w, 0, _WN_INIT_STDV)
        init.zeros_(self.b)

    def _init(self, x):
        y = self._forward(x)
        m = y.mean(dim=0)
        s = self.init_scale / (y.std(dim=0) + _SMALL)
        assert m.shape == s.shape == self.b.shape
        self.w.copy_(self.w * s[:, None])
        self.b.copy_(-m * s)
        return self._forward(x)

    def _forward(self, x):
        return F.linear(x, self.w, self.b[None, :])
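
A minimal sketch of the arithmetic behind _init above, with made-up sizes and constants standing in for _WN_INIT_STDV and _SMALL (the DataDepInitModule base class, not shown, is assumed to run _init under torch.no_grad()). Rescaling the rows of w by s and setting b = -m * s makes the first batch's outputs zero-mean with std ≈ init_scale per unit:

import torch

torch.manual_seed(0)
x = torch.randn(256, 32)
w = torch.randn(16, 32) * 0.05            # stands in for init.normal_(w, 0, _WN_INIT_STDV)
b = torch.zeros(16)

y = x @ w.t() + b                         # pre-init activations
m = y.mean(dim=0)
s = 1.0 / (y.std(dim=0) + 1e-10)          # init_scale = 1.0, _SMALL ~ 1e-10

w, b = w * s[:, None], -m * s             # the two copy_ updates in _init
y2 = x @ w.t() + b
print(y2.mean(dim=0).abs().max().item())  # ~0
print(y2.std(dim=0))                      # ~1.0 per output unit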
Example #5
class Conv2d(DataDepInitModule):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 init_scale=1.0):
        super().__init__()
        self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.init_scale = \
            in_channels, out_channels, kernel_size, stride, padding, dilation, init_scale

        self.w = Parameter(
            torch.Tensor(out_channels, in_channels, self.kernel_size,
                         self.kernel_size))
        self.b = Parameter(torch.Tensor(out_channels))

        init.normal_(self.w, 0, _WN_INIT_STDV)
        init.ones_(self.b)

    def _init(self, x):
        # x.shape == (batch, channels, h, w)
        y = self._forward(x)  # (batch, out_channels, h, w)
        m = y.transpose(0, 1).reshape(y.shape[1],
                                      -1).mean(dim=1)  # out_channels
        s = self.init_scale / (y.transpose(0, 1).reshape(
            y.shape[1], -1).std(dim=1) + _SMALL)  # out_channels
        self.w.copy_(self.w *
                     s[:, None, None, None])  # (out, in, k, k) * (out, 1, 1, 1)
        self.b.copy_(-m * s)
        return self._forward(x)

    def _forward(self, x):
        return F.conv2d(x, self.w, self.b, self.stride, self.padding,
                        self.dilation, 1)
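
The only difference from Dense is that statistics are computed per out_channel, pooled over the batch and spatial dimensions. A short sketch of that reduction:

import torch

y = torch.randn(8, 4, 10, 10)                        # (batch, out_channels, h, w)
per_ch = y.transpose(0, 1).reshape(y.shape[1], -1)   # (out_channels, batch*h*w)
print(per_ch.mean(dim=1).shape, per_ch.std(dim=1).shape)   # torch.Size([4]) twice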
Example #6
class Third_Level_Agent(Second_Level_Agent):
    def __init__(self,
                 n_concepts,
                 n_actions,
                 concept_architecture,
                 second_level_architecture,
                 first_level_actor,
                 noop_action,
                 min_entropy_factor=0.1,
                 lr=1e-4,
                 lr_Alpha=1e-4,
                 entropy_update_rate=0.05,
                 target_update_rate=5e-3,
                 init_Epsilon=1.0,
                 delta_Epsilon=7.5e-4,
                 temporal_ratio=5):
        super().__init__(n_actions, second_level_architecture,
                         first_level_actor, noop_action, temporal_ratio)

        self.concept_architecture = concept_architecture
        freeze(self.concept_architecture)

        self.Q_table = Parameter(torch.Tensor(n_concepts, self._n_actions),
                                 requires_grad=False)
        nn.init.constant_(self.Q_table, 0.0)
        self.Q_target = Parameter(torch.Tensor(n_concepts, self._n_actions),
                                  requires_grad=False)
        nn.init.constant_(self.Q_target, 0.0)
        self.C_table = Parameter(torch.Tensor(n_concepts, self._n_actions),
                                 requires_grad=False)
        nn.init.constant_(self.C_table, 0.0)
        self.Pi_table = Parameter(torch.Tensor(n_concepts, self._n_actions),
                                  requires_grad=False)
        nn.init.constant_(self.Pi_table, 1.0 / self._n_actions)
        # self.Q_table2 = Parameter(torch.Tensor(n_concepts, self._n_actions))
        # nn.init.constant_(self.Q_table2, 0.0)
        # self.Q_table1_target = Parameter(torch.Tensor(n_concepts, self._n_actions), requires_grad=False)
        # nn.init.constant_(self.Q_table1_target, 0.0)
        # self.Q_table2_target = Parameter(torch.Tensor(n_concepts, self._n_actions), requires_grad=False)
        # nn.init.constant_(self.Q_table2_target, 0.0)
        self.log_Alpha = Parameter(torch.Tensor(1), requires_grad=False)
        nn.init.constant_(self.log_Alpha, 1.0)
        self.Epsilon = Parameter(torch.Tensor(1), requires_grad=False)
        nn.init.constant_(self.Epsilon, init_Epsilon)
        self.H_mean = Parameter(torch.Tensor(1), requires_grad=False)
        nn.init.constant_(self.H_mean, -1.0)

        self._n_concepts = n_concepts
        self.H_min = np.log(self._n_actions)
        self.min_Epsilon = min_entropy_factor
        self.delta_Epsilon = delta_Epsilon
        self.lr_Alpha = lr_Alpha
        self.entropy_update_rate = entropy_update_rate
        self.target_update_rate = target_update_rate

        # self.Q_optimizer = Adam([self.Q_table1, self.Q_table2], lr=lr)

    def save(self, save_path, best=False):
        if best:
            model_path = save_path + 'best_agent_3l_' + self._id
        else:
            model_path = save_path + 'last_agent_3l_' + self._id
        torch.save(self.state_dict(), model_path)

    def load(self, load_directory_path, model_id, device='cuda'):
        dev = torch.device(device)
        self.load_state_dict(
            torch.load(load_directory_path + 'agent_3l_' + model_id,
                       map_location=dev))

    # def PA_S(self, target=True):
    #     #Q_min = torch.min(self.Q_table1, self.Q_table2)
    #     if not target:
    #         Q = self.Q_table
    #     else:
    #         Q = self.Q_target
    #     Alpha = self.log_Alpha.exp().item()
    #     Z = torch.logsumexp(Q/(Alpha + 1e-6), dim=1, keepdim=True)
    #     log_PA_S = Q/Alpha - Z
    #     PA_S = log_PA_S.exp() + 1e-6
    #     PA_S = PA_S / PA_S.sum(1, keepdim=True)
    #     log_PA_S = torch.log(PA_S)
    #     return PA_S, log_PA_S

    def PA_S(self):
        PA_S = self.Pi_table
        log_PA_S = torch.log(PA_S)
        return PA_S, log_PA_S

    def update_Alpha(self, HA_S):  #, n_updates=1):
        # for update in range(0,n_updates):
        #     PA_S, log_PA_S = self.PA_S()
        #     HA_gS = -(PA_S * log_PA_S).sum(1)
        #     HA_S = (PS * HA_gS).sum()
        error = HA_S.item() - self.H_min * self.Epsilon
        new_log_Alpha = self.log_Alpha - self.lr_Alpha * error
        self.log_Alpha.copy_(new_log_Alpha)

        new_Epsilon = torch.max(
            self.Epsilon - self.delta_Epsilon,
            self.min_Epsilon * torch.ones_like(self.Epsilon))
        self.Epsilon.copy_(new_Epsilon)

    def update_mean_entropy(self, H):
        if self.H_mean < 0.0:
            self.H_mean.copy_(H.detach())
        else:
            H_new = self.H_mean * (
                1.0 - self.entropy_update_rate) + H * self.entropy_update_rate
            self.H_mean.copy_(H_new.detach())

    def update_Q(self, Q, C):
        self.Q_table.copy_(Q)
        self.C_table.copy_(C)

    def update_policy(self, Pi):
        self.Pi_table.copy_(Pi)

    def update_target(self, rate=1.0):
        Q_target = self.Q_target * (1. - rate) + self.Q_table * rate
        self.Q_target.copy_(Q_target)

    def sample_action_from_concept(self, state, explore=True):
        inner_state, outer_state = self.observe_second_level_state(state)
        with torch.no_grad():
            PS_s = self.concept_architecture(inner_state.view(1, -1),
                                             outer_state.unsqueeze(0))[0]
            concept = PS_s.argmax(1).item()
            dist = self.Pi_table[concept, :]
            if explore:
                action = Categorical(probs=dist).sample().item()
            else:
                tie_breaking_dist = torch.isclose(dist, dist.max()).float()
                tie_breaking_dist /= tie_breaking_dist.sum()
                action = Categorical(probs=tie_breaking_dist).sample().item()
            return action, dist.detach().cpu().numpy()
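
A small standalone sketch of the greedy branch in sample_action_from_concept: rather than a bare argmax, ties among maximal-probability actions are broken uniformly at random:

import torch
from torch.distributions import Categorical

dist = torch.tensor([0.4, 0.4, 0.2])
tie_breaking_dist = torch.isclose(dist, dist.max()).float()
tie_breaking_dist /= tie_breaking_dist.sum()          # uniform over the argmax set
action = Categorical(probs=tie_breaking_dist).sample().item()   # 0 or 1, never 2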
Example #7
class MarkovFlow(nn.Module):
    def __init__(self, args, num_dims):
        super(MarkovFlow, self).__init__()

        self.args = args
        self.device = args.device

        # Gaussian Variance
        self.var = Parameter(torch.zeros(num_dims, dtype=torch.float32))

        if not args.train_var:
            self.var.requires_grad = False

        self.num_state = args.num_state
        self.num_dims = num_dims
        self.couple_layers = args.couple_layers
        self.cell_layers = args.cell_layers
        self.hidden_units = num_dims // 2
        self.lstm_hidden_units = self.num_dims

        # transition parameters in log space
        self.tparams = Parameter(torch.Tensor(self.num_state, self.num_state))

        self.prior_group = [self.tparams]

        # Gaussian means
        self.means = Parameter(torch.Tensor(self.num_state, self.num_dims))

        if args.mode == "unsupervised" and args.freeze_prior:
            self.tparams.requires_grad = False

        if args.mode == "unsupervised" and args.freeze_mean:
            self.means.requires_grad = False

        if args.model == 'nice':
            self.proj_layer = NICETrans(self.couple_layers, self.cell_layers,
                                        self.hidden_units, self.num_dims,
                                        self.device)
        elif args.model == "lstmnice":
            self.proj_layer = LSTMNICE(self.args.lstm_layers,
                                       self.args.couple_layers,
                                       self.args.cell_layers,
                                       self.lstm_hidden_units,
                                       self.hidden_units, self.num_dims,
                                       self.device)

        if args.mode == "unsupervised" and args.freeze_proj:
            for param in self.proj_layer.parameters():
                param.requires_grad = False

        if args.model == "gaussian":
            self.proj_group = [self.means, self.var]
        else:
            self.proj_group = list(
                self.proj_layer.parameters()) + [self.means, self.var]

        # prior
        self.pi = torch.zeros(self.num_state,
                              dtype=torch.float32,
                              requires_grad=False,
                              device=self.device).fill_(1.0 / self.num_state)

        self.pi = torch.log(self.pi)

    def init_params(self, train_data):
        """
        init_seed:(sents, masks)
        sents: (seq_length, batch_size, features)
        masks: (seq_length, batch_size)

        """

        # initialize transition matrix params
        # self.tparams.data.uniform_().add_(1)
        self.tparams.data.uniform_()

        # load pretrained model
        if self.args.load_nice != '':
            self.load_state_dict(torch.load(self.args.load_nice), strict=True)

            self.means_init = self.means.clone()
            self.tparams_init = self.tparams.clone()
            self.proj_init = [
                param.clone() for param in self.proj_layer.parameters()
            ]

            if self.args.init_var:
                self.init_var(train_data)

            if self.args.init_var_one:
                self.var.fill_(0.01)

            # self.means_init.requires_grad = False
            # self.tparams_init.requires_grad = False
            # for tensor in self.proj_init:
            #     tensor.requires_grad = False

            return

        # load pretrained Gaussian baseline
        if self.args.load_gaussian != '':
            self.load_state_dict(torch.load(self.args.load_gaussian),
                                 strict=False)

        # fully unsupervised training
        if self.args.mode == "unsupervised" and self.args.load_nice == "":
            with torch.no_grad():
                for iter_obj in train_data.data_iter(self.args.batch_size):
                    sents = iter_obj.embed
                    masks = iter_obj.mask
                    sents, _ = self.transform(sents, iter_obj.mask)
                    seq_length, _, features = sents.size()
                    flat_sents = sents.view(-1, features)
                    seed_mean = torch.sum(
                        masks.view(-1, 1).expand_as(flat_sents) * flat_sents,
                        dim=0) / masks.sum()
                    seed_var = torch.sum(
                        masks.view(-1, 1).expand_as(flat_sents) *
                        ((flat_sents - seed_mean.expand_as(flat_sents))**2),
                        dim=0) / masks.sum()
                    self.var.copy_(seed_var)
                    # self.var.fill_(0.02)

                    # add noise to the pretrained Gaussian mean
                    if self.args.load_gaussian != '' and self.args.model == 'nice':
                        self.means.data.add_(
                            seed_mean.data.expand_as(self.means.data))
                    elif self.args.load_gaussian == '' and self.args.load_nice == '':
                        self.means.data.normal_().mul_(0.04)
                        self.means.data.add_(
                            seed_mean.data.expand_as(self.means.data))

                    return

        self.init_mean(train_data)
        self.var.fill_(1.0)
        self.init_var(train_data)

        if self.args.init_var_one:
            self.var.fill_(1.0)

    def init_mean(self, train_data):
        emb_dict = {}
        cnt_dict = Counter()
        for iter_obj in train_data.data_iter(self.args.batch_size):
            sents_t = iter_obj.embed
            sents_t, _ = self.transform(sents_t, iter_obj.mask)
            sents_t = sents_t.transpose(0, 1)
            pos_t = iter_obj.pos.transpose(0, 1)
            mask_t = iter_obj.mask.transpose(0, 1)

            for emb_s, tagid_s, mask_s in zip(sents_t, pos_t, mask_t):
                for tagid, emb, mask in zip(tagid_s, emb_s, mask_s):
                    tagid = tagid.item()
                    mask = mask.item()
                    if tagid in emb_dict:
                        emb_dict[tagid] = emb_dict[tagid] + emb * mask
                    else:
                        emb_dict[tagid] = emb * mask

                    cnt_dict[tagid] += mask

        for tagid in emb_dict:
            self.means[tagid] = emb_dict[tagid] / cnt_dict[tagid]

    def init_var(self, train_data):
        cnt = 0
        mean_sum = 0.
        var_sum = 0.
        for iter_obj in train_data.data_iter(batch_size=self.args.batch_size):
            sents, masks = iter_obj.embed, iter_obj.mask
            sents, _ = self.transform(sents, masks)
            seq_length, _, features = sents.size()
            flat_sents = sents.view(-1, features)
            mean_sum = mean_sum + torch.sum(
                masks.view(-1, 1).expand_as(flat_sents) * flat_sents, dim=0)
            cnt += masks.sum().item()

        mean = mean_sum / cnt

        for iter_obj in train_data.data_iter(batch_size=self.args.batch_size):
            sents, masks = iter_obj.embed, iter_obj.mask
            sents, _ = self.transform(sents, masks)
            seq_length, _, features = sents.size()
            flat_sents = sents.view(-1, features)
            var_sum = var_sum + torch.sum(
                masks.view(-1, 1).expand_as(flat_sents) *
                ((flat_sents - mean.expand_as(flat_sents))**2),
                dim=0)
        var = var_sum / cnt
        self.var.copy_(var)

    def _calc_log_density_c(self):
        # return -self.num_dims/2.0 * (math.log(2) + \
        #         math.log(np.pi)) - 0.5 * self.num_dims * (torch.log(self.var))

        return -self.num_dims/2.0 * (math.log(2) + \
                math.log(np.pi)) - 0.5 * torch.sum(torch.log(self.var))

    def transform(self, x, masks=None):
        """
        Args:
            x: (sent_length, batch_size, num_dims)
        """
        jacobian_loss = torch.zeros(1, device=self.device, requires_grad=False)

        if self.args.model != 'gaussian':
            x, jacobian_loss_new = self.proj_layer(x, masks)
            jacobian_loss = jacobian_loss + jacobian_loss_new

        return x, jacobian_loss

    def MSE_loss(self):
        # diff1 = ((self.means - self.means_init) ** 2).sum()
        diff_prior = ((self.tparams - self.tparams_init)**2).sum()

        # diff = diff1 + diff2
        diff_proj = 0.

        for i, param in enumerate(self.proj_layer.parameters()):
            diff_proj = diff_proj + ((self.proj_init[i] - param)**2).sum()

        diff_mean = ((self.means_init - self.means)**2).sum()

        return 0.5 * (self.args.beta_prior * diff_prior + self.args.beta_proj *
                      diff_proj + self.args.beta_mean * diff_mean)

    def unsupervised_loss(self, sents, masks):
        """
        Args:
            sents: (sent_length, batch_size, self.num_dims)
            masks: (sent_length, batch_size)

        Returns: Tensor1, Tensor2
            Tensor1: negative log likelihood, shape ([])
            Tensor2: jacobian loss, shape ([])


        """
        max_length, batch_size, _ = sents.size()
        sents, jacobian_loss = self.transform(sents, masks)

        assert self.var.data.min() > 0

        self.logA = self._calc_logA()
        self.log_density_c = self._calc_log_density_c()

        alpha = self.pi + self._eval_density(sents[0])
        for t in range(1, max_length):
            density = self._eval_density(sents[t])
            mask_ep = masks[t].expand(self.num_state, batch_size) \
                      .transpose(0, 1)
            alpha = torch.mul(mask_ep,
                              self._forward_cell(alpha, density)) + \
                    torch.mul(1-mask_ep, alpha)

        # calculate objective from log space
        objective = torch.sum(log_sum_exp(alpha, dim=1))

        return -objective, jacobian_loss

    def supervised_loss(self, sents, tags, masks):
        """
        Args:
            sents: (sent_length, batch_size, num_dims)
            masks: (sent_length, batch_size)
            tags:  (sent_length, batch_size)

        Returns: Tensor1, Tensor2
            Tensor1: negative log likelihood, shape ([])
            Tensor2: jacobian loss, shape ([])

        """

        sent_len, batch_size, _ = sents.size()

        # (sent_length, batch_size, num_dims)
        sents, jacobian_loss = self.transform(sents, masks)

        # ()
        log_density_c = self._calc_log_density_c()

        # (1, 1, num_state, num_dims)
        means = self.means.view(1, 1, self.num_state, self.num_dims)
        means = means.expand(sent_len, batch_size, self.num_state,
                             self.num_dims)
        tag_id = tags.view(*tags.size(), 1, 1).expand(sent_len, batch_size, 1,
                                                      self.num_dims)

        # (sent_len, batch_size, num_dims)
        means = torch.gather(means, dim=2, index=tag_id).squeeze(2)

        var = self.var.view(1, 1, self.num_dims)

        # (sent_len, batch_size)
        log_emission_prob = log_density_c - \
                       0.5 * torch.sum((means-sents) ** 2 / var, dim=-1)

        log_emission_prob = torch.mul(masks, log_emission_prob).sum()

        # (num_state, num_state)
        log_trans = self._calc_logA()

        # (sent_len, batch_size, num_state, num_state)
        log_trans_prob = log_trans.view(1, 1, *log_trans.size()).expand(
            sent_len, batch_size, *log_trans.size())

        # (sent_len-1, batch_size, 1, num_state)
        tag_id = tags.view(*tags.size(), 1, 1).expand(sent_len, batch_size, 1,
                                                      self.num_state)[:-1]

        # (sent_len-1, batch_size, 1, num_state)
        log_trans_prob = torch.gather(log_trans_prob[:-1], dim=2, index=tag_id)

        # (sent_len-1, batch_size, 1, 1)
        tag_id = tags.view(*tags.size(), 1, 1)[1:]

        # (sent_len-1, batch_size)
        log_trans_prob = torch.gather(log_trans_prob, dim=3,
                                      index=tag_id).squeeze()

        log_trans_prob = torch.mul(masks[1:], log_trans_prob)

        log_trans_prior = self.pi.expand(batch_size, self.num_state)
        tag_id = tags[0].unsqueeze(dim=1)

        # (batch_size)
        log_trans_prior = torch.gather(log_trans_prior, dim=1,
                                       index=tag_id).sum()

        log_trans_prob = log_trans_prior + log_trans_prob.sum()

        return -(log_trans_prob + log_emission_prob), jacobian_loss

    def _calc_alpha(self, sents, masks):
        """
        sents: (sent_length, batch_size, self.num_dims)
        masks: (sent_length, batch_size)

        Returns:
            output: (batch_size, sent_length, num_state)

        """
        max_length, batch_size, _ = sents.size()

        alpha_all = []
        alpha = self.pi + self._eval_density(sents[0])
        alpha_all.append(alpha.unsqueeze(1))
        for t in range(1, max_length):
            density = self._eval_density(sents[t])
            mask_ep = masks[t].expand(self.num_state, batch_size) \
                      .transpose(0, 1)
            alpha = torch.mul(mask_ep, self._forward_cell(alpha, density)) + \
                    torch.mul(1-mask_ep, alpha)
            alpha_all.append(alpha.unsqueeze(1))

        return torch.cat(alpha_all, dim=1)

    def _forward_cell(self, alpha, density):
        batch_size = len(alpha)
        ep_size = torch.Size([batch_size, self.num_state, self.num_state])
        alpha = log_sum_exp(alpha.unsqueeze(dim=2).expand(ep_size) +
                            self.logA.expand(ep_size) +
                            density.unsqueeze(dim=1).expand(ep_size),
                            dim=1)

        return alpha

    def _backward_cell(self, beta, density):
        """
        density: (batch_size, num_state)
        beta: (batch_size, num_state)

        """
        batch_size = len(beta)
        ep_size = torch.Size([batch_size, self.num_state, self.num_state])
        beta = log_sum_exp(self.logA.expand(ep_size) +
                           density.unsqueeze(dim=1).expand(ep_size) +
                           beta.unsqueeze(dim=1).expand(ep_size),
                           dim=2)

        return beta

    def _eval_density(self, words):
        """
        Args:
            words: (batch_size, self.num_dims)

        Returns: Tensor1
            Tensor1: the density tensor with shape (batch_size, num_state)
        """

        batch_size = words.size(0)
        ep_size = torch.Size([batch_size, self.num_state, self.num_dims])
        words = words.unsqueeze(dim=1).expand(ep_size)
        means = self.means.expand(ep_size)
        var = self.var.expand(ep_size)

        return self.log_density_c - \
               0.5 * torch.sum((means-words) ** 2 / var, dim=2)

    def _calc_logA(self):
        return (self.tparams - \
                log_sum_exp(self.tparams, dim=1, keepdim=True) \
                .expand(self.num_state, self.num_state))

    def _calc_log_mul_emit(self):
        return self.emission - \
                log_sum_exp(self.emission, dim=1, keepdim=True) \
                .expand(self.num_state, self.vocab_size)

    def _viterbi(self, sents_var, masks):
        """
        Args:
            sents_var: (sent_length, batch_size, num_dims)
            masks: (sent_length, batch_size)
        """

        self.log_density_c = self._calc_log_density_c()
        self.logA = self._calc_logA()

        length, batch_size = masks.size()

        # (batch_size, num_state)
        delta = self.pi + self._eval_density(sents_var[0])

        ep_size = torch.Size([batch_size, self.num_state, self.num_state])
        index_all = []

        # forward calculate delta
        for t in range(1, length):
            density = self._eval_density(sents_var[t])
            delta_new = self.logA.expand(ep_size) + \
                    density.unsqueeze(dim=1).expand(ep_size) + \
                    delta.unsqueeze(dim=2).expand(ep_size)
            mask_ep = masks[t].view(-1, 1, 1).expand(ep_size)
            delta = mask_ep * delta_new + \
                    (1 - mask_ep) * delta.unsqueeze(dim=1).expand(ep_size)

            # index: (batch_size, num_state)
            delta, index = torch.max(delta, dim=1)
            index_all.append(index)

        assign_all = []
        # assign: (batch_size)
        _, assign = torch.max(delta, dim=1)
        assign_all.append(assign.unsqueeze(dim=1))

        # backward retrieve path
        # len(index_all) = length-1
        for t in range(length - 2, -1, -1):
            assign_new = torch.gather(index_all[t],
                                      dim=1,
                                      index=assign.view(-1, 1)).squeeze(dim=1)

            assign_new = assign_new.float()
            assign = assign.float()
            assign = masks[t + 1] * assign_new + (1 - masks[t + 1]) * assign
            assign = assign.long()

            assign_all.append(assign.unsqueeze(dim=1))

        assign_all = assign_all[-1::-1]

        return torch.cat(assign_all, dim=1)

    def test_supervised(self, test_data):
        """Evaluate tagging performance with
        token-level supervised accuracy

        Args:
            test_data: ConlluData object

        Returns: a scalar accuracy value

        """
        total = 0.0
        correct = 0.0

        index_all = []
        eval_tags = []

        for iter_obj in test_data.data_iter(batch_size=self.args.batch_size,
                                            shuffle=False):
            sents_t = iter_obj.embed
            masks = iter_obj.mask
            tags_t = iter_obj.pos

            sents_t, _ = self.transform(sents_t, masks)

            # index: (batch_size, seq_length)
            index = self._viterbi(sents_t, masks)

            for index_s, tag_s, mask_s in zip(index, tags_t.transpose(0, 1),
                                              masks.transpose(0, 1)):
                for i in range(int(mask_s.sum().item())):
                    if index_s[i].item() == tag_s[i].item():
                        correct += 1
                    total += 1

        return correct / total

    def test_unsupervised(self,
                          test_data,
                          sentences=None,
                          tagging=False,
                          path=None,
                          null_index=None):
        """Evaluate tagging performance with
        many-to-1 metric, VM score and 1-to-1
        accuracy

        Args:
            test_data: ConlluData object
            tagging: output the predicted tags if True
            path: The output tag file path
            null_index: the null element location in Penn
                        Treebank, only used for writing unsupervised
                        tags for downstream parsing task

        Returns:
            Tuple1: (many-to-1 accuracy, VM score, 1-to-1 accuracy)

        """

        total = 0.0
        correct = 0.0
        cnt_stats = {}
        match_dict = {}

        index_all = []
        eval_tags = []

        gold_vm = []
        model_vm = []

        for i in range(self.num_state):
            cnt_stats[i] = Counter()

        for iter_obj in test_data.data_iter(batch_size=self.args.batch_size,
                                            shuffle=False):
            total += iter_obj.mask.sum().item()
            sents_t = iter_obj.embed
            tags_t = iter_obj.pos
            masks = iter_obj.mask

            sents_t, _ = self.transform(sents_t, masks)

            # index: (batch_size, seq_length)
            index = self._viterbi(sents_t, masks)

            index_all += list(index)

            tags = [
                tags_t[:int(masks[:, i].sum().item()), i]
                for i in range(index.size(0))
            ]
            eval_tags += tags

            # count
            for (seq_gold_tags, seq_model_tags) in zip(tags, index):
                for (gold_tag, model_tag) in zip(seq_gold_tags,
                                                 seq_model_tags):
                    model_tag = model_tag.item()
                    gold_tag = gold_tag.item()
                    gold_vm += [gold_tag]
                    model_vm += [model_tag]
                    cnt_stats[model_tag][gold_tag] += 1

        # evaluate one-to-one accuracy
        cost_matrix = np.zeros((self.num_state, self.num_state))
        for i in range(self.num_state):
            for j in range(self.num_state):
                cost_matrix[i][j] = -cnt_stats[j][i]

        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        for (seq_gold_tags, seq_model_tags) in zip(eval_tags, index_all):
            for (gold_tag, model_tag) in zip(seq_gold_tags, seq_model_tags):
                model_tag = model_tag.item()
                gold_tag = gold_tag.item()

                if col_ind[gold_tag] == model_tag:
                    correct += 1

        one2one = correct / total

        correct = 0.

        # match
        for tag in cnt_stats:
            if len(cnt_stats[tag]) != 0:
                match_dict[tag] = cnt_stats[tag].most_common(1)[0][0]

        # eval many2one
        for (seq_gold_tags, seq_model_tags) in zip(eval_tags, index_all):
            for (gold_tag, model_tag) in zip(seq_gold_tags, seq_model_tags):
                model_tag = model_tag.item()
                gold_tag = gold_tag.item()
                if match_dict[model_tag] == gold_tag:
                    correct += 1

        if tagging:
            write_conll(path, sentences, index_all, null_index)

        return correct / total, v_measure_score(gold_vm, model_vm), one2one
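
MarkovFlow calls a log_sum_exp helper that is not included in the snippet. A minimal sketch of the behaviour the calls above appear to expect, assumed here to match torch.logsumexp (a numerically stable log of summed exponentials along dim, with optional keepdim):

import torch

def log_sum_exp(value, dim=None, keepdim=False):
    # Assumed behaviour, matching how MarkovFlow calls it (dim=1/2, keepdim=True).
    if dim is None:
        return torch.logsumexp(value.view(-1), dim=0)
    return torch.logsumexp(value, dim=dim, keepdim=keepdim)

# e.g. the per-sentence objective in unsupervised_loss:
alpha = torch.randn(3, 5)                       # (batch_size, num_state)
objective = torch.sum(log_sum_exp(alpha, dim=1))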