Example #1
    def call_Encoder(self, nlayer, x_lt, R_lt, first, epoch):
        # look up the hidden-layer sizes for this layer index
        h_l_down_in, h_l_top_out, \
        h_l_down_out, h_Elt = self.hidden_layers_selctor(nlayer)
        Encode_lt = Encoder([h_l_down_in, h_l_down_out],
                            h_l_down_out, self.kernel_size,
                            self.image_size).cuda()
        E_lt = Encode_lt(x_lt, R_lt, first)
        if self.saveModel and epoch % self.numSaveIter == 0:
            self.save_models(Encode_lt, epoch, "Encoder")
        return E_lt, Encode_lt.parameters()
# training: loss functions, learning rates, parameter optimization

# Reconstruction phase parameters
# define the optimization criteria / loss functions
reconstruction_criterion_categorical = nn.BCELoss(reduction='mean')
reconstruction_criterion_numeric = nn.MSELoss(reduction='mean')
# push to cuda if cudnn is available
if torch.backends.cudnn.version() is not None and USE_CUDA:
    reconstruction_criterion_categorical = reconstruction_criterion_categorical.cuda()
    reconstruction_criterion_numeric = reconstruction_criterion_numeric.cuda()
# define encoder and decoder learning rates
learning_rate_enc = 1e-3
learning_rate_dec = 1e-3
# define encoder and decoder optimization strategy
encoder_optimizer = optim.Adam(encoder_train.parameters(),
                               lr=learning_rate_enc)
decoder_optimizer = optim.Adam(decoder_train.parameters(),
                               lr=learning_rate_dec)

# Regularization phase parameters
# init the discriminator loss
discriminator_criterion = nn.BCELoss()
# push to cuda if cudnn is available
if torch.backends.cudnn.version() is not None and USE_CUDA:
    discriminator_criterion = discriminator_criterion.cuda()
# define the discriminator learning rate
learning_rate_dis_z = 1e-5
# define the discriminator optimization strategy
discriminator_optimizer = optim.Adam(discriminator_train.parameters(),
                                     lr=learning_rate_dis_z)
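A minimal sketch of how these criteria and optimizers could be wired into one reconstruction step. The mini-batch split (`x_cat`, `x_num`) and the assumption that `decoder_train` returns separate categorical and numeric reconstructions are illustrative only, not part of the original code:

def reconstruction_step(x_cat, x_num):
    # clear gradients left over from the previous step
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    # encode the concatenated record and reconstruct it
    z = encoder_train(torch.cat([x_cat, x_num], dim=1))
    rec_cat, rec_num = decoder_train(z)
    # BCE over the categorical columns, MSE over the numeric ones
    loss = (reconstruction_criterion_categorical(rec_cat, x_cat) +
            reconstruction_criterion_numeric(rec_num, x_num))
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item()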
Example #3
    print('text word size:', text_alpha.m_size)
    print('label word size:', label_alpha.m_size)
    # print(label_alpha.id2string)
    '''
        seqs to id
    '''
    text_id_list = seq2id(text_alpha, text_sent_list)
    label_id_list = seq2id(label_alpha, label_sent_list)

    encoder = Encoder(text_alpha.m_size, config)
    decoder = AttnDecoderRNN(label_alpha.m_size, config)

    # print(encoder)
    # print(decoder)
    lr = config.lr
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=lr)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=lr)
    criterion = nn.NLLLoss()

    n_epochs = 1000
    plot_every = 200
    print_every = 10

    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0
    '''
        start...
    '''
    for epoch in range(n_epochs):
Example #4
class Mem2SeqRunner(ExperimentRunnerBase):
    def __init__(self, args):
        super(Mem2SeqRunner, self).__init__(args)

        # Model parameters
        self.gru_size = 128
        self.emb_size = 128
        #TODO: Try hops 4 with task 3
        self.hops = 3
        self.dropout = 0.2

        self.encoder = Encoder(self.hops, self.nwords, self.gru_size)
        self.decoder = Decoder(self.emb_size, self.hops, self.gru_size,
                               self.nwords)

        self.optim_enc = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
        self.optim_dec = torch.optim.Adam(self.decoder.parameters(), lr=0.001)
        if self.loss_weighting:
            self.optim_loss_weights = torch.optim.Adam([self.loss_weights],
                                                       lr=0.0001)
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optim_dec,
                                                        mode='max',
                                                        factor=0.5,
                                                        patience=1,
                                                        min_lr=0.0001,
                                                        verbose=True)

        if self.use_cuda:
            self.cross_entropy = self.cross_entropy.cuda()
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            if self.loss_weighting:
                self.loss_weights = self.loss_weights.cuda()

    def train_batch_wrapper(self, batch, new_epoch, clip_grads):
        context = batch[0].transpose(0, 1)
        responses = batch[1].transpose(0, 1)
        index = batch[2].transpose(0, 1)
        sentinel = batch[3].transpose(0, 1)
        context_lengths = batch[4]
        target_lengths = batch[5]
        return self.train_batch(context, responses, index, sentinel, new_epoch,
                                context_lengths, target_lengths, clip_grads)

    def train_batch(self, context, responses, index, sentinel, new_epoch,
                    context_lengths, target_lengths, clip_grads):

        # (TODO): remove transpose
        if new_epoch:  # (TODO): Change this part
            self.loss = 0
            self.ploss = 0
            self.vloss = 0
            self.n = 1

        context = context.type(self.TYPE)
        responses = responses.type(self.TYPE)
        index = index.type(self.TYPE)
        sentinel = sentinel.type(self.TYPE)

        self.optim_enc.zero_grad()
        self.optim_dec.zero_grad()
        if self.loss_weighting:
            self.optim_loss_weights.zero_grad()

        h = self.encoder(context.transpose(0, 1))
        self.decoder.load_memory(context.transpose(0, 1))
        y = torch.from_numpy(np.array([2] * context.size(1),
                                      dtype=int)).type(self.TYPE)
        y_len = 0

        h = h.unsqueeze(0)
        output_vocab = torch.zeros(max(target_lengths), context.size(1),
                                   self.nwords)
        output_ptr = torch.zeros(max(target_lengths), context.size(1),
                                 context.size(0))
        if self.use_cuda:
            output_vocab = output_vocab.cuda()
            output_ptr = output_ptr.cuda()
        while y_len < responses.size(0):  # TODO: Add EOS condition
            p_ptr, p_vocab, h = self.decoder(context, y, h)
            output_vocab[y_len] = p_vocab
            output_ptr[y_len] = p_ptr
            # TODO: Add teacher forcing ratio
            y = responses[y_len].type(self.TYPE)
            y_len += 1

        # print(loss)
        mask_v = torch.ones(output_vocab.size())
        mask_p = torch.ones(output_ptr.size())
        if self.use_cuda:
            mask_p = mask_p.cuda()
            mask_v = mask_v.cuda()
        for i in range(responses.size(1)):
            mask_v[target_lengths[i]:, i, :] = 0
            mask_p[target_lengths[i]:, i, :] = 0

        loss_v = self.cross_entropy(
            output_vocab.contiguous().view(-1, self.nwords),
            responses.contiguous().view(-1))

        loss_ptr = self.cross_entropy(
            output_ptr.contiguous().view(-1, context.size(0)),
            index.contiguous().view(-1))
        if self.loss_weighting:
            loss = loss_ptr/(2*self.loss_weights[0]*self.loss_weights[0]) + loss_v/(2*self.loss_weights[1]*self.loss_weights[1]) + \
               torch.log(self.loss_weights[0] * self.loss_weights[1])
            loss_ptr = loss_ptr / (2 * self.loss_weights[0] *
                                   self.loss_weights[0])
            loss_v = loss_v / (2 * self.loss_weights[1] * self.loss_weights[1])
        else:
            loss = loss_ptr + loss_v
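        # note: the weighted branch above resembles homoscedastic-uncertainty
        # multi-task weighting (Kendall et al.), with loss_weights playing the
        # role of the two sigmas: loss = L_ptr/(2*s0^2) + L_v/(2*s1^2) + log(s0*s1)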
        loss.backward()
        ec = torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), 10.0)
        dc = torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), 10.0)
        self.optim_enc.step()
        self.optim_dec.step()
        if self.loss_weighting:
            self.optim_loss_weights.step()

        self.loss += loss.item()
        self.vloss += loss_v.item()
        self.ploss += loss_ptr.item()

        return loss.item(), loss_v.item(), loss_ptr.item()

    def evaluate_batch(self,
                       batch_size,
                       input_batches,
                       input_lengths,
                       target_batches,
                       target_lengths,
                       target_index,
                       target_gate,
                       src_plain,
                       profile_memory=None):

        # Set to not-training mode to disable dropout
        self.encoder.train(False)
        self.decoder.train(False)
        # Run words through encoder
        decoder_hidden = self.encoder(input_batches.transpose(0,
                                                              1)).unsqueeze(0)
        self.decoder.load_memory(input_batches.transpose(0, 1))

        # Prepare input and output variables
        decoder_input = Variable(torch.LongTensor([2] * batch_size))

        decoded_words = []
        all_decoder_outputs_vocab = Variable(
            torch.zeros(max(target_lengths), batch_size, self.nwords))
        all_decoder_outputs_ptr = Variable(
            torch.zeros(max(target_lengths), batch_size,
                        input_batches.size(0)))
        # all_decoder_outputs_gate = Variable(torch.zeros(self.max_r, batch_size))
        # Move new Variables to CUDA

        if self.use_cuda:
            all_decoder_outputs_vocab = all_decoder_outputs_vocab.cuda()
            all_decoder_outputs_ptr = all_decoder_outputs_ptr.cuda()
            # all_decoder_outputs_gate = all_decoder_outputs_gate.cuda()
            decoder_input = decoder_input.cuda()

        p = []
        for elm in src_plain:
            elm_temp = [word_triple[0] for word_triple in elm]
            p.append(elm_temp)

        self.from_whichs = []
        acc_gate, acc_ptr, acc_vac = 0.0, 0.0, 0.0
        # Run through decoder one time step at a time
        for t in range(max(target_lengths)):
            decoder_ptr, decoder_vocab, decoder_hidden = self.decoder(
                input_batches, decoder_input, decoder_hidden)
            all_decoder_outputs_vocab[t] = decoder_vocab
            topv, topvi = decoder_vocab.data.topk(1)
            all_decoder_outputs_ptr[t] = decoder_ptr
            topp, toppi = decoder_ptr.data.topk(1)
            top_ptr_i = torch.gather(input_batches[:, :, 0], 0,
                                     Variable(toppi.view(1,
                                                         -1))).transpose(0, 1)
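            # take the pointed-to source word when the pointer stays inside the
            # input; otherwise fall back to the vocabulary prediction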
            next_in = [
                top_ptr_i[i].item() if
                (toppi[i].item() < input_lengths[i] - 1) else topvi[i].item()
                for i in range(batch_size)
            ]
            # if next_in in self.kb_entry.keys():
            #     ptr_distr.append([next_in, decoder_vocab.data])

            decoder_input = Variable(
                torch.LongTensor(next_in))  # Chosen word is next input
            if self.use_cuda: decoder_input = decoder_input.cuda()

            temp = []
            from_which = []
            for i in range(batch_size):
                if (toppi[i].item() < len(p[i]) - 1):
                    temp.append(p[i][toppi[i].item()])
                    from_which.append('p')
                else:
                    if target_index[t][i] != toppi[i].item():
                        self.incorrect_sentinel += 1
                    ind = topvi[i].item()
                    if ind == 3:
                        temp.append('<eos>')
                    else:
                        temp.append(self.i2w[ind])
                    from_which.append('v')
            decoded_words.append(temp)
            self.from_whichs.append(from_which)
        self.from_whichs = np.array(self.from_whichs)

        loss_v = self.cross_entropy(
            all_decoder_outputs_vocab.contiguous().view(-1, self.nwords),
            target_batches.contiguous().view(-1))
        loss_ptr = self.cross_entropy(
            all_decoder_outputs_ptr.contiguous().view(-1,
                                                      input_batches.size(0)),
            target_index.contiguous().view(-1))

        if self.loss_weighting:
            loss = loss_ptr/(2*self.loss_weights[0]*self.loss_weights[0]) + loss_v/(2*self.loss_weights[1]*self.loss_weights[1]) + \
               torch.log(self.loss_weights[0] * self.loss_weights[1])
        else:
            loss = loss_ptr + loss_v

        self.loss += loss.item()
        self.vloss += loss_v.item()
        self.ploss += loss_ptr.item()
        self.n += 1

        # Set back to training mode
        self.encoder.train(True)
        self.decoder.train(True)
        return decoded_words, self.from_whichs  # , acc_ptr, acc_vac

    def save_models(self, path):
        torch.save(self.encoder.state_dict(),
                   os.path.join(path, 'encoder.pth'))
        torch.save(self.decoder.state_dict(),
                   os.path.join(path, 'decoder.pth'))

    def load_models(self, path: str = '.'):
        self.encoder.load_state_dict(
            torch.load(os.path.join(path, 'encoder.pth')))
        self.decoder.load_state_dict(
            torch.load(os.path.join(path, 'decoder.pth')))
Example #5
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()

        self.config = config

        # affect (emotion) embedding layer
        self.affect_embedding = AffectEmbedding(config.num_vocab,
                                                config.affect_embedding_size,
                                                config.pad_id)

        # word embedding layer
        self.embedding = WordEmbedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id)  # pad_id

        # post encoder
        self.post_encoder = Encoder(
            config.post_encoder_cell_type,  # rnn cell type
            config.embedding_size,  # input dimension
            config.post_encoder_output_size,  # output dimension
            config.post_encoder_num_layers,  # number of rnn layers
            config.post_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # response encoder
        self.response_encoder = Encoder(
            config.response_encoder_cell_type,
            config.embedding_size,  # input dimension
            config.response_encoder_output_size,  # output dimension
            config.response_encoder_num_layers,  # number of rnn layers
            config.response_encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        # prior network
        self.prior_net = PriorNet(
            config.post_encoder_output_size,  # post input dimension
            config.latent_size,  # latent variable dimension
            config.dims_prior)  # hidden layer dimensions

        # recognition network
        self.recognize_net = RecognizeNet(
            config.post_encoder_output_size,  # post input dimension
            config.response_encoder_output_size,  # response input dimension
            config.latent_size,  # latent variable dimension
            config.dims_recognize)  # hidden layer dimensions

        # initialize the decoder state
        self.prepare_state = PrepareState(
            config.post_encoder_output_size + config.latent_size,
            config.decoder_cell_type, config.decoder_output_size,
            config.decoder_num_layers)

        # decoder
        self.decoder = Decoder(
            config.decoder_cell_type,  # rnn cell type
            config.embedding_size,  # input dimension
            config.decoder_output_size,  # output dimension
            config.decoder_num_layers,  # number of rnn layers
            config.dropout)  # dropout probability

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1))

    def forward(
            self,
            input,
            inference=False,  # whether we are in inference (test) mode
            use_true=False,
            max_len=60):  # maximum decoding length

        if not inference:  # training
            if use_true:
                id_posts = input['posts']  # [batch, seq]
                len_posts = input['len_posts']  # [batch]
                id_responses = input['responses']  # [batch, seq]
                len_responses = input['len_responses']  # [batch]
                sampled_latents = input[
                    'sampled_latents']  # [batch, latent_size]

                embed_posts = self.embedding(
                    id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(
                    id_responses)  # [batch, seq, embed_size]

                # the decoder input is the response with the end_id stripped
                decoder_input = embed_responses[:, :-1, :].transpose(
                    0, 1)  # [seq-1, batch, embed_size]
                len_decoder = decoder_input.size()[0]  # decoding length = seq-1
                decoder_input = decoder_input.split(
                    [1] * len_decoder,
                    0)  # per-step decoder inputs: seq-1 tensors of [1, batch, embed_size]
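                # teacher forcing: the decoder is fed the ground-truth response
                # embedding at every time step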

                # state = [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
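                # reparameterization trick: z = mu + exp(0.5 * logvar) * eps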
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                outputs = []

                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state  # initial decoder state
                    input = decoder_input[
                        idx]  # input at the current time step [1, batch, embed_size]
                    # output: [1, batch, dim_out]
                    # state: [num_layer, batch, dim_out]
                    output, state = self.decoder(input, state)
                    outputs.append(output)

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, _mu, _logvar, mu, logvar

            else:
                id_posts = input['posts']  # [batch, seq]
                len_posts = input['len_posts']  # [batch]
                id_responses = input['responses']  # [batch, seq]
                len_responses = input['len_responses']  # [batch]
                sampled_latents = input[
                    'sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size()[1] - 1
                batch_size = id_posts.size()[0]
                device = id_posts.device.type

                embed_posts = self.embedding(
                    id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(
                    id_responses)  # [batch, seq, embed_size]

                # state = [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                                   len_posts)
                _, state_responses = self.response_encoder(
                    embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]

                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                z = mu + (0.5 *
                          logvar).exp() * sampled_latents  # [batch, latent]

                first_state = self.prepare_state(torch.cat(
                    [z, x], 1))  # [num_layer, batch, dim_out]
                first_input_id = (torch.ones(
                    (1, batch_size)) * self.config.start_id).long()
                if device == 'cuda':
                    first_input_id = first_input_id.cuda()
                outputs = []

                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state
                        input = self.embedding(first_input_id)
                    else:
                        input = self.embedding(
                            next_input_id)  # input at the current time step [1, batch, embed_size]
                    output, state = self.decoder(input, state)
                    outputs.append(output)

                    vocab_prob = self.projector(
                        output)  # [1, batch, num_vocab]
                    next_input_id = torch.argmax(
                        vocab_prob, 2)  # take the most probable word as the next-step input [1, batch]

                outputs = torch.cat(outputs,
                                    0).transpose(0,
                                                 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(
                    outputs)  # [batch, seq-1, num_vocab]

                return output_vocab, _mu, _logvar, mu, logvar

        else:  # inference (test)

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            sampled_latents = input['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]

            # state = [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1),
                                               len_posts)
            if isinstance(state_posts, tuple):  # for an lstm, take h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # take the last layer [batch, dim]

            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            z = _mu + (0.5 *
                       _logvar).exp() * sampled_latents  # [batch, latent]

            first_state = self.prepare_state(torch.cat(
                [z, x], 1))  # [num_layer, batch, dim_out]
            outputs = []

            done = torch.BoolTensor([0] * batch_size)
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            if device == 'cuda':
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = first_state  # initial decoder state
                    input = self.embedding(
                        first_input_id)  # initial decoder input [1, batch, embed_size]

                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # take the most probable word as the next-step input [1, batch]
                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences finished at this step [batch]
                done = done | _done  # all finished sequences so far
                if done.sum() == batch_size:  # stop early once every sequence is decoded
                    break
                else:
                    input = self.embedding(
                        next_input_id)  # [1, batch, embed_size]

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, _mu, _logvar, None, None

    # count parameters
    def print_parameters(self):
        def statistic_param(params):
            total_num = 0  # total number of parameters
            for param in params:
                num = 1
                if param.requires_grad:
                    size = param.size()
                    for dim in size:
                        num *= dim
                total_num += num
            return total_num

        print("嵌入层参数个数: %d" % statistic_param(self.embedding.parameters()))
        print("post编码器参数个数: %d" %
              statistic_param(self.post_encoder.parameters()))
        print("response编码器参数个数: %d" %
              statistic_param(self.response_encoder.parameters()))
        print("先验网络参数个数: %d" % statistic_param(self.prior_net.parameters()))
        print("识别网络参数个数: %d" %
              statistic_param(self.recognize_net.parameters()))
        print("解码器初始状态参数个数: %d" %
              statistic_param(self.prepare_state.parameters()))
        print("解码器参数个数: %d" % statistic_param(self.decoder.parameters()))
        print("输出层参数个数: %d" % statistic_param(self.projector.parameters()))
        print("参数总数: %d" % statistic_param(self.parameters()))

    # save the model
    def save_model(self, epoch, global_step, path):

        torch.save(
            {
                'affect_embedding': self.affect_embedding.state_dict(),
                'embedding': self.embedding.state_dict(),
                'post_encoder': self.post_encoder.state_dict(),
                'response_encoder': self.response_encoder.state_dict(),
                'prior_net': self.prior_net.state_dict(),
                'recognize_net': self.recognize_net.state_dict(),
                'prepare_state': self.prepare_state.state_dict(),
                'decoder': self.decoder.state_dict(),
                'projector': self.projector.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    # load the model
    def load_model(self, path):

        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.post_encoder.load_state_dict(checkpoint['post_encoder'])
        self.response_encoder.load_state_dict(checkpoint['response_encoder'])
        self.prior_net.load_state_dict(checkpoint['prior_net'])
        self.recognize_net.load_state_dict(checkpoint['recognize_net'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']

        return epoch, global_step
Example #6
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()

        self.config = config

        # affect (emotion) embedding layer
        self.affect_embedding = AffectEmbedding(config.num_vocab,
                                                config.affect_embedding_size,
                                                config.pad_id)

        # word embedding layer
        self.embedding = WordEmbedding(
            config.num_vocab,  # vocabulary size
            config.embedding_size,  # embedding dimension
            config.pad_id)  # pad_id

        # affect encoder
        self.affect_encoder = Encoder(
            config.encoder_decoder_cell_type,  # rnn cell type
            config.affect_embedding_size,  # input dimension
            config.affect_encoder_output_size,  # output dimension
            config.encoder_decoder_num_layers,  # number of layers
            config.encoder_bidirectional,  # bidirectional or not
            config.dropout)

        # encoder
        self.encoder = Encoder(
            config.encoder_decoder_cell_type,  # rnn cell type
            config.embedding_size,  # input dimension
            config.encoder_decoder_output_size,  # output dimension
            config.encoder_decoder_num_layers,  # number of rnn layers
            config.encoder_bidirectional,  # bidirectional or not
            config.dropout)  # dropout probability

        self.attention = Attention(config.encoder_decoder_output_size,
                                   config.affect_encoder_output_size,
                                   config.attention_type,
                                   config.attention_size)

        self.prepare_state = PrepareState(
            config.encoder_decoder_cell_type,
            config.encoder_decoder_output_size +
            config.affect_encoder_output_size,
            config.encoder_decoder_output_size)

        self.linear_prepare_input = nn.Linear(
            config.embedding_size + config.affect_encoder_output_size +
            config.attention_size, config.decoder_input_size)

        # decoder
        self.decoder = Decoder(
            config.encoder_decoder_cell_type,  # rnn cell type
            config.decoder_input_size,  # input dimension
            config.encoder_decoder_output_size,  # output dimension
            config.encoder_decoder_num_layers,  # number of rnn layers
            config.dropout)  # dropout probability

        # output layer
        self.projector = nn.Sequential(
            nn.Linear(
                config.encoder_decoder_output_size + config.attention_size,
                config.num_vocab), nn.Softmax(-1))

    def forward(
            self,
            input,
            inference=False,  # whether we are in inference (test) mode
            max_len=60):  # maximum decoding length

        if not inference:  # training

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            id_responses = input['responses']  # [batch, seq]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            embed_responses = self.embedding(
                id_responses)  # [batch, seq, embed_size]
            affect_posts = self.affect_embedding(id_posts)

            # the decoder input is the response with the end_id stripped
            decoder_input = embed_responses[:, :-1, :].transpose(
                0, 1)  # [seq-1, batch, embed_size]
            len_decoder = decoder_input.size()[0]  # decoding length = seq-1
            decoder_input = decoder_input.split(
                [1] * len_decoder, 0)  # per-step decoder inputs: seq-1 tensors of [1, batch, embed_size]

            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            # output_affect: [seq, batch, dim]
            output_affect, state_affect_encoder = self.affect_encoder(
                affect_posts.transpose(0, 1), len_posts)
            context_affect = output_affect[-1, :, :].unsqueeze(
                0)  # [1, batch, dim]
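            # the final affect-encoder output is kept as a fixed affect context
            # and concatenated to the decoder input at every step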

            outputs = []
            weights = []

            init_attn = torch.zeros(
                [1, batch_size, self.config.attention_size])
            if device == 'cuda':
                init_attn = init_attn.cuda()

            for idx in range(len_decoder):
                if idx == 0:
                    state = self.prepare_state(state_encoder,
                                               state_affect_encoder)  # initial decoder state
                    input = torch.cat(
                        [decoder_input[idx], context_affect, init_attn], 2)  #
                else:
                    input = torch.cat([
                        decoder_input[idx], context_affect,
                        attn.transpose(0, 1)
                    ], 2)  #

                input = self.linear_prepare_input(input)

                # output: [1, batch, dim_out]
                # state: [num_layer, batch, dim_out]
                output, state = self.decoder(input, state)
                # attn: [batch, 1, attention_size]
                # weights: [batch, 1, encoder_len]
                attn, weight = self.attention(output.transpose(0, 1),
                                              output_affect.transpose(0, 1))
                outputs.append(torch.cat([output, attn.transpose(0, 1)], 2))
                weights.append(weight)

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq-1, dim_out]
            weights = torch.cat(weights, 1)  # [batch, seq-1, encoder_len]

            output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

            return output_vocab, weights

        else:  # inference (test)

            id_posts = input['posts']  # [batch, seq]
            len_posts = input['len_posts']  # [batch]
            batch_size = id_posts.size()[0]
            device = id_posts.device.type

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            affect_posts = self.affect_embedding(id_posts)

            # state: [layers, batch, dim]
            _, state_encoder = self.encoder(embed_posts.transpose(0, 1),
                                            len_posts)
            output_affect, state_affect_encoder = self.affect_encoder(
                affect_posts.transpose(0, 1), len_posts)
            context_affect = output_affect[-1, :, :].unsqueeze(
                0)  # [1, batch, dim]

            outputs = []
            weights = []

            done = torch.BoolTensor([0] * batch_size)
            first_input_id = (torch.ones(
                (1, batch_size)) * self.config.start_id).long()
            init_attn = torch.zeros(
                [1, batch_size, self.config.attention_size])
            if device == 'cuda':
                done = done.cuda()
                first_input_id = first_input_id.cuda()
                init_attn = init_attn.cuda()

            for idx in range(max_len):

                if idx == 0:  # first time step
                    state = self.prepare_state(state_encoder,
                                               state_affect_encoder)  # initial decoder state
                    input = torch.cat([
                        self.embedding(first_input_id), context_affect,
                        init_attn
                    ], 2)  # initial decoder input [1, batch, embed_size]
                else:
                    input = torch.cat(
                        [input, context_affect,
                         attn.transpose(0, 1)], 2)

                input = self.linear_prepare_input(input)

                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(input, state)
                # attn: [batch, 1, attention_size]
                # weights: [batch, 1, encoder_len]
                attn, weight = self.attention(output.transpose(0, 1),
                                              output_affect.transpose(0, 1))
                outputs.append(torch.cat([output, attn.transpose(0, 1)], 2))
                weights.append(weight)

                vocab_prob = self.projector(
                    torch.cat([output, attn.transpose(0, 1)],
                              2))  # [1, batch, num_vocab]
                next_input_id = torch.argmax(
                    vocab_prob, 2)  # take the most probable word as the next-step input [1, batch]

                _done = next_input_id.squeeze(
                    0) == self.config.end_id  # sequences finished at this step [batch]
                done = done | _done  # all finished sequences so far

                if done.sum() == batch_size:  # stop early once every sequence is decoded
                    break
                else:
                    input = self.embedding(
                        next_input_id)  # [1, batch, embed_size]

            outputs = torch.cat(outputs,
                                0).transpose(0, 1)  # [batch, seq, dim_out]
            weights = torch.cat(weights, 1)
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, weights

    # count parameters
    def print_parameters(self):
        def statistic_param(params):
            total_num = 0  # total number of parameters
            for param in params:
                num = 1
                if param.requires_grad:
                    size = param.size()
                    for dim in size:
                        num *= dim
                total_num += num
            return total_num

        print("嵌入层参数个数: %d" % statistic_param(self.embedding.parameters()))
        print("编码器参数个数: %d" % statistic_param(self.encoder.parameters()))
        print("情感编码器参数个数: %d" %
              statistic_param(self.affect_encoder.parameters()))
        print("准备状态参数个数: %d" %
              statistic_param(self.prepare_state.parameters()))
        print("准备输入参数个数: %d" %
              statistic_param(self.linear_prepare_input.parameters()))
        print("注意力参数个数: %d" % statistic_param(self.attention.parameters()))
        print("解码器参数个数: %d" % statistic_param(self.decoder.parameters()))
        print("输出层参数个数: %d" % statistic_param(self.projector.parameters()))
        print("参数总数: %d" % statistic_param(self.parameters()))

    # save the model
    def save_model(self, epoch, global_step, path):

        torch.save(
            {
                'affect_embedding': self.affect_embedding.state_dict(),
                'embedding': self.embedding.state_dict(),
                'encoder': self.encoder.state_dict(),
                'affect_encoder': self.affect_encoder.state_dict(),
                'prepare_state': self.prepare_state.state_dict(),
                'linear_prepare_input': self.linear_prepare_input.state_dict(),
                'attention': self.attention.state_dict(),
                'projector': self.projector.state_dict(),
                'decoder': self.decoder.state_dict(),
                'epoch': epoch,
                'global_step': global_step
            }, path)

    # load the model
    def load_model(self, path):

        checkpoint = torch.load(path)
        self.affect_embedding.load_state_dict(checkpoint['affect_embedding'])
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.affect_encoder.load_state_dict(checkpoint['affect_encoder'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.linear_prepare_input.load_state_dict(
            checkpoint['linear_prepare_input'])
        self.attention.load_state_dict(checkpoint['attention'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']

        return epoch, global_step
Example #7
class TrainBatch():
    def __init__(self,
                 input_size,
                 hidden_size,
                 batch_size,
                 learning_rate,
                 method,
                 num_layers=1):
        dataset = Seq2SeqDataset()
        self.data_loader = DataLoader(dataset=dataset,
                                      batch_size=batch_size,
                                      shuffle=True)
        self.vocab = dataset.vocab
        self.output_size = len(self.vocab)
        self.char2index, self.index2char = self.data_index()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_layers = num_layers
        self.method = method

        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.attn = Attn(method, hidden_size)
        self.encoder = Encoder(input_size, hidden_size, self.output_size,
                               self.num_layers)
        self.decoder = Decoder(hidden_size, self.output_size, method,
                               self.num_layers)

        self.attn = self.attn.to(self.device)
        self.encoder = self.encoder.to(self.device)
        self.decoder = self.decoder.to(self.device)

        self.loss_function = NLLLoss()
        self.encoder_optim = torch.optim.Adam(self.encoder.parameters(),
                                              lr=self.learning_rate)
        self.decoder_optim = torch.optim.Adam(self.decoder.parameters(),
                                              lr=self.learning_rate)

    def word_to_index(self, word):
        char_index = [self.char2index[w] for w in list(word)]
        return torch.LongTensor(char_index).to(self.device)

    # build the padded batch tensor (inputs are already sorted by length)
    def create_batch_tensor(self, batch_word, batch_len):
        batch_size = len(batch_word)
        seq_len = max(batch_len)
        seq_tensor = torch.zeros([batch_size, seq_len]).long().to(self.device)
        for i in range(batch_size):
            seq_tensor[i, :batch_len[i]] = self.word_to_index(batch_word[i])
        return seq_tensor

    def create_batch(self, input, target):
        input_seq = [list(w) for w in list(input)]
        target_seq = [list(w) for w in list(target)]

        seq_pairs = sorted(zip(input_seq, target_seq),
                           key=lambda p: len(p[0]),
                           reverse=True)
        input_seq, target_seq = zip(*seq_pairs)
        input_len = [len(w) for w in input_seq]
        target_len = [len(w) for w in target_seq]

        input_seq = self.create_batch_tensor(input_seq, input_len)
        input_len = torch.LongTensor(input_len).to(self.device)
        target_seq = self.create_batch_tensor(target_seq, target_len)
        return self.create_tensor(input_seq), \
               self.create_tensor(input_len), self.create_tensor(target_seq)

    def get_len(self, input):
        input_seq = [list(w) for w in list(input)]
        input_len = [len(w) for w in input_seq]
        input_len = torch.LongTensor(input_len).to(self.device)
        return input_len

    def data_index(self):
        char2index = {}
        char2index.update({w: i for i, w in enumerate(self.vocab)})
        index2char = {w[1]: w[0] for w in char2index.items()}
        return char2index, index2char

    def create_tensor(self, tensor):
        return Variable(tensor.to(self.device))

    def create_mask(self, tensor):
        return self.create_tensor(
            torch.gt(tensor,
                     torch.LongTensor([0]).to(self.device)))

    def mask_NLLLoss(self, inp, target, mask):
        nTotal = mask.sum()
        crossEntropy = -torch.log(torch.gather(inp, 2, target))
        loss = crossEntropy.masked_select(mask).mean()
        loss = loss.to(self.device)
        return loss, nTotal.item()

    def sequence_mask(self, sequence_length, max_len=None):
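        # build a [batch, max_len] boolean mask, True at positions < sequence_length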
        if max_len is None:
            max_len = sequence_length.data.max()
        batch_size = sequence_length.size(0)
        seq_range = torch.arange(0, max_len).long()
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
        seq_range_expand = Variable(seq_range_expand)
        if sequence_length.is_cuda:
            seq_range_expand = seq_range_expand.cuda()
        seq_length_expand = (
            sequence_length.unsqueeze(1).expand_as(seq_range_expand))
        return seq_range_expand < seq_length_expand

    def masked_cross_entropy(self, output, target, length):
        output = output.view(-1, output.size(2))
        log_output = F.log_softmax(output, 1)
        target = target.view(-1, 1)
        losses_flat = -torch.gather(log_output, 1, target)
        losses = losses_flat.view(*target.size())

        # mask = self.sequence_mask(sequence_length=length, max_len=target.size(1))
        mask = target.gt(torch.LongTensor([0]).to(self.device))
        losses = losses * mask.float()
        loss = losses.sum() / length.float().sum()
        return loss

    def step(self, input, target):
        input_seq, input_len, target_seq = self.create_batch(input, target)
        # encoder_output: (batch, max_len, hidden) (5,8,64)
        # hidden (1, batch, 64)
        encoder_output, (hidden_state,
                         cell_state) = self.encoder(input_seq, input_len)
        # SOS_index = torch.LongTensor(self.char2index[SOS]).to(self.device)
        batch_size = input_seq.size(0)
        max_len = target_seq.size(1)
        decoder_output = torch.zeros([batch_size, max_len,
                                      self.output_size]).to(self.device)
        # start of sentence
        decoder_input = torch.tensor((), dtype=torch.long)
        decoder_input = decoder_input.new_ones([batch_size, 1]).to(self.device)
        decoder_input = self.create_tensor(decoder_input *
                                           self.char2index['_'])

        output_tensor = torch.zeros([batch_size, max_len])
        # use scheduled sampling
        for i in range(max_len):
            output, (hidden_state,
                     cell_state) = self.decoder(decoder_input,
                                                (hidden_state, cell_state),
                                                encoder_output)
            if rd.random() > 0.5:
                decoder_input = target_seq[:, i].unsqueeze(1)
            else:
                decoder_input = output.topk(1)[1]
            # decoder_input = target_seq[:, i].unsqueeze(1)
            output_index = output.topk(1)[1]
            output_tensor[:, i] = output_index.squeeze(1)
            decoder_output[:, i] = output
        target_len = self.get_len(target)
        loss_ = self.masked_cross_entropy(decoder_output, target_seq,
                                          target_len)
        decoder_output = decoder_output.view(-1, self.output_size)
        target_seq = target_seq.view(-1)

        # loss = self.loss_function(decoder_output, target_seq)

        self.encoder_optim.zero_grad()
        self.decoder_optim.zero_grad()

        # loss.backward()
        loss_.backward()

        self.encoder_optim.step()
        self.decoder_optim.step()

        return loss_.item(), output_tensor.cpu().tolist()
Example #8
class LSTM_CTE_Model(nn.Module):
    """
    Implementation of a seq2seq model.
    Architecture:
        Encoder/decoder
        2 LTSM layers
    """
    def __init__(self, w2i, i2w, embs=None, title_emb=None):
        """
        Args:
            args: parameters of the model
            textData: the dataset object
        """
        super(LSTM_CTE_Model, self).__init__()
        print("Model creation...")

        self.word2index = w2i
        self.index2word = i2w
        self.max_length = args['maxLengthDeco']

        self.NLLloss = torch.nn.NLLLoss(ignore_index=0)
        self.CEloss = torch.nn.CrossEntropyLoss(ignore_index=0)

        if embs is not None:
            self.embedding = nn.Embedding.from_pretrained(embs)
        else:
            self.embedding = nn.Embedding(args['vocabularySize'],
                                          args['embeddingSize'])

        if title_emb is not None:
            self.field_embedding = nn.Embedding.from_pretrained(title_emb)
        else:
            self.field_embedding = nn.Embedding(args['TitleNum'],
                                                args['embeddingSize'])

        self.encoder = Encoder(w2i, i2w, bidirectional=True)
        # self.encoder_answer_only = Encoder(w2i, i2w)
        self.encoder_no_answer = Encoder(w2i, i2w)
        self.encoder_pure_answer = Encoder(w2i, i2w)

        self.decoder_answer = Decoder(w2i,
                                      i2w,
                                      self.embedding,
                                      copy='pure',
                                      max_dec_len=10)
        self.decoder_no_answer = Decoder(w2i,
                                         i2w,
                                         self.embedding,
                                         input_dim=args['embeddingSize'] * 2,
                                         copy='semi')

        self.ansmax2state_h = nn.Linear(args['embeddingSize'],
                                        args['hiddenSize'] * 2,
                                        bias=False)
        self.ansmax2state_c = nn.Linear(args['embeddingSize'],
                                        args['hiddenSize'] * 2,
                                        bias=False)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        self.att_size_r = 60
        # self.grm = GaussianOrthogonalRandomMatrix()
        # self.att_projection_matrix = Parameter(self.grm.get_2d_array(args['embeddingSize'], self.att_size_r))
        self.M = Parameter(
            torch.randn([args['embeddingSize'], args['hiddenSize'] * 2, 2]))

        self.shrink_copy_input = nn.Linear(args['hiddenSize'] * 2,
                                           args['hiddenSize'],
                                           bias=False)
        self.emb2hid = nn.Linear(args['embeddingSize'],
                                 args['hiddenSize'],
                                 bias=False)
        # self.z_logit2prob = nn.Sequential(
        #     nn.Linear(args['hiddenSize'], 2)
        # )

        # self.z_to_fea = nn.Linear(args['hiddenSize'], args['hiddenSize']).to(args['device'])
        # self.SEClassifier = nn.Sequential(
        #     nn.Linear(args['hiddenSize'], 2),
        #     nn.Sigmoid()
        # )
        #
        # self.SentenceClassifier = nn.Sequential(
        #     nn.Linear(args['hiddenSize'], 1),
        #     nn.Sigmoid()
        # )

    def sample_z(self, mu, log_var, batch_size):
        eps = Variable(
            torch.randn(batch_size, args['style_len'] * 2 *
                        args['numLayers'])).to(args['device'])
        return mu + torch.einsum('ba,ba->ba', torch.exp(log_var / 2), eps)

    def cos(self, x1, x2):
        '''
        :param x1: batch seq emb
        :param x2: batch seq emb
        :return: pairwise cosine similarity, [batch, seq1, seq2]
        '''
        xx = torch.einsum('bse,bte->bst', x1, x2)
        x1n = torch.norm(x1, dim=-1, keepdim=True)
        x2n = torch.norm(x2, dim=-1, keepdim=True)
        xd = torch.einsum('bse,bte->bst', x1n, x2n).clamp(min=0.0001)
        return xx / xd

    def sample_gumbel(self, shape, eps=1e-20):
        U = torch.rand(shape).to(args['device'])
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        y = logits + self.sample_gumbel(logits.size())
        return F.softmax(y / temperature, dim=-1)

    def gumbel_softmax(self, logits, temperature=args['temperature']):
        """
        ST-Gumbel-softmax
        input: [*, n_class]
        return: flatten --> [*, n_class], a one-hot vector
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
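        # straight-through estimator: the forward pass uses the hard one-hot,
        # while gradients flow through the soft sample y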
        y_hard = (y_hard - y).detach() + y
        return y_hard, y

    def get_pretrain_parameters(self):
        return list(self.embedding.parameters()) + list(
            self.encoder.parameters()) + list(
                self.decoder_no_answer.parameters())

    def build(self, x, mode, eps=1e-6):
        '''
        :param x: batch object whose fields hold the encoder inputs ([batch, enc_len])
                  and the decoder inputs/targets ([batch, dec_len])
        :param mode: 'train' or evaluation mode (Gumbel sampling vs. thresholding)
        :param eps: small constant for numerical stability
        :return:
        '''

        # D,Q -> s: P(s|D,Q)
        context_inputs = torch.LongTensor(x.contextSeqs).to(args['device'])
        field = torch.LongTensor(x.field).to(args['device'])
        answer_dec = torch.LongTensor(x.decoderSeqs).to(args['device'])
        answer_tar = torch.LongTensor(x.targetSeqs).to(args['device'])
        context_dec = torch.LongTensor(x.ContextDecoderSeqs).to(args['device'])
        context_tar = torch.LongTensor(x.ContextTargetSeqs).to(args['device'])
        pure_answer = torch.LongTensor(x.answerSeqs).to(args['device'])
        # context_mask = torch.FloatTensor(x.context_mask).to(args['device'])  # batch sentence
        # sentence_mask = torch.FloatTensor(x.sentence_mask).to(args['device'])  # batch sennum contextlen
        # start_positions = torch.FloatTensor(x.starts).to(args['device'])
        # end_positions = torch.FloatTensor(x.ends).to(args['device'])
        # ans_context_input = torch.LongTensor(x.ans_contextSeqs).to(args['device'])
        # ans_context_mask = torch.LongTensor(x.ans_con_mask).to(args['device'])

        pure_answer_embs = self.embedding(pure_answer)
        # print(' context_inputs: ', context_inputs[0])
        # print(' context_dec: ', context_dec[0])
        # print(' context_tar: ', context_tar[0])

        mask = torch.sign(context_inputs).float()
        mask_pure_answer = torch.sign(pure_answer).float()

        batch_size = context_inputs.size()[0]
        seq_len = context_inputs.size()[1]
        context_inputs_embs = self.embedding(context_inputs)
        q_inputs_embs = self.field_embedding(field)  #.unsqueeze(1) # batch emb
        #
        # attentioned_context = dot_product_attention(query=q_inputs_embs.unsqueeze(1), key=context_inputs_embs, value=context_inputs_embs,
        #                                             projection_matrix=self.att_projection_matrix)  # b s h

        en_context_output, en_context_state = self.encoder(
            context_inputs_embs)  # b s e

        # print(q_inputs_embs.size(), en_context_output.size())
        att1 = self.tanh(torch.einsum('be,ehc->bhc', q_inputs_embs, self.M))
        # print(att1.size(), en_context_output.size())
        z_logit = torch.einsum('bhc,bsh->bsc', att1, en_context_output)
        # z_embs = self.tanh(self.q_att_layer(q_inputs_embs) + self.c_att_layer(en_context_output)) # b s h
        # z_logit = self.z_logit2prob(z_embs).squeeze() # b s 2
        # z_logit = torch.cat([1-z_logit_1, z_logit_1], dim = 2)
        z_logit_fla = z_logit.reshape((batch_size * seq_len, 2))
        z_prob = self.softmax(z_logit)
        if mode == 'train':
            sampled_seq, sampled_seq_soft = self.gumbel_softmax(z_logit_fla)
            sampled_seq = sampled_seq.reshape((batch_size, seq_len, 2))
            sampled_seq_soft = sampled_seq_soft.reshape(
                (batch_size, seq_len, 2))
            sampled_seq = sampled_seq * mask.unsqueeze(2)
            sampled_seq_soft = sampled_seq_soft * mask.unsqueeze(2)
        else:
            sampled_seq = (z_prob > 0.5).float() * mask.unsqueeze(2)

        gold_ans_mask, _ = (
            context_inputs.unsqueeze(2) == pure_answer.unsqueeze(1)).max(2)

        if mode == 'train':
            ans_mask, _ = (
                context_inputs.unsqueeze(2) == pure_answer.unsqueeze(1)).max(2)
            noans_mask = 1 - ans_mask.int()
            # print(noans_mask)
        else:
            ans_mask = sampled_seq[:, :, 1].int()
            noans_mask = sampled_seq[:, :, 0].int()

        answer_only_sequence = context_inputs_embs * sampled_seq[:, :, 1].unsqueeze(2)
        no_answer_sequence = context_inputs_embs * noans_mask.unsqueeze(
            2)  #.detach()
        # ANS_END = torch.LongTensor([5] * batch_size).to(args['device'])
        # ANS_END = ANS_END.unsqueeze(1)
        # ANS_END_emb = self.embedding(ANS_END)
        # no_answer_sequence = torch.cat([pure_answer_embs, ANS_END_emb, no_answer_sequence], dim = 1)

        answer_only_logp_z0 = torch.log(z_prob[:, :, 0].clamp(
            eps, 1.0))  # [B,T], log P(z = 0 | x)
        answer_only_logp_z1 = torch.log(z_prob[:, :, 1].clamp(
            eps, 1.0))  # [B,T], log P(z = 1 | x)
        # answer_only_logpz = (1-sampled_seq[:, :, 1]) * answer_only_logp_z0 + sampled_seq[:, :, 1] * answer_only_logp_z1
        answer_only_logpz = torch.where(sampled_seq[:, :, 1] == 0,
                                        answer_only_logp_z0,
                                        answer_only_logp_z1)
        # no_answer_logpz = torch.where(sampled_seq[:, :, 1] == 0,answer_only_logp_z1, answer_only_logp_z0)
        answer_only_logpz = mask * answer_only_logpz
        # no_answer_logpz = mask * no_answer_logpz

        # answer_only_output, answer_only_state = self.encoder_answer_only(answer_only_sequence)
        answer_only_info, _ = torch.max(answer_only_sequence, dim=1)
        # print(answer_only_info.size())
        answer_only_state = (self.ansmax2state_h(answer_only_info).reshape([
            batch_size, args['numLayers'], args['hiddenSize']
        ]), self.ansmax2state_c(answer_only_info).reshape(
            [batch_size, args['numLayers'], args['hiddenSize']]))
        answer_only_state = (answer_only_state[0].transpose(0, 1).contiguous(),
                             answer_only_state[1].transpose(0, 1).contiguous())

        no_answer_output, no_answer_state = self.encoder_no_answer(
            no_answer_sequence)
        # no_answer_output, no_answer_state = self.encoder_no_answer(context_inputs_embs)

        en_context_output_shrink = self.shrink_copy_input(
            en_context_output)  # bsh
        # answer_latent_emb,_ = torch.max(answer_only_output)
        enc_onehot = F.one_hot(context_inputs,
                               num_classes=args['vocabularySize'])
        answer_de_output = self.decoder_answer(
            answer_only_state,
            answer_dec,
            answer_tar,
            enc_embs=en_context_output_shrink,
            enc_mask=mask,
            enc_onehot=enc_onehot)
        answer_recon_loss = self.NLLloss(
            torch.transpose(answer_de_output, 1, 2), answer_tar)
        # answer_mask = torch.sign(answer_tar.float())
        # answer_recon_loss = torch.squeeze(answer_recon_loss) * answer_mask
        answer_recon_loss_mean = answer_recon_loss  #torch.mean(answer_recon_loss, dim = 1)
        #
        ######################## no_answer do not contain answer #####################
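        # Cosine similarity between tokens retained in the "no-answer" view and the gold
        # answer embeddings; keeping it small discourages answer leakage into that view.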
        pred_no_answer_seq = context_inputs_embs * sampled_seq[:, :, 0].unsqueeze(2)
        cross_len = torch.abs(
            torch.einsum('bse,bae->bsa', pred_no_answer_seq, pure_answer_embs)
            * mask.unsqueeze(2)) / (
                torch.norm(pred_no_answer_seq, dim=2).unsqueeze(2) +
                eps) / (torch.norm(pure_answer_embs, dim=2).unsqueeze(1) + eps)
        cross_sim = torch.mean(cross_len)

        # ################### no-answer context  + answer info -> origin context  #############
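        # The "no-answer" context, augmented with a pooled gold-answer embedding, must be
        # enough for the second decoder to reconstruct the original context.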
        # pure_answer_output, pure_answer_state = self.encoder_pure_answer(pure_answer_embs)
        # pure_answer_output = torch.mean(pure_answer_embs, dim = 1, keepdim=True)
        pure_answer_mask = torch.sign(pure_answer).float()
        pure_answer_output = torch.sum(
            pure_answer_embs * pure_answer_mask.unsqueeze(2),
            dim=1,
            keepdim=True) / torch.sum(pure_answer_mask, dim=1,
                                      keepdim=True).unsqueeze(2)
        # no_ans_plus_pureans_state = (torch.cat([no_answer_state[0], pure_answer_state[0]], dim = 2),
        #                              torch.cat([no_answer_state[1], pure_answer_state[1]], dim=2))
        en_context_output_plus = torch.cat(
            [en_context_output_shrink * noans_mask.unsqueeze(2),
             self.emb2hid(pure_answer_embs)], dim=1)
        mask_plus = torch.cat([mask, mask_pure_answer], dim=1)
        enc_onehot_plus = F.one_hot(
            torch.cat([context_inputs * noans_mask, pure_answer], dim=1),
            num_classes=args['vocabularySize'])
        pa_mask, _ = F.one_hot(
            pure_answer, num_classes=args['vocabularySize']).max(1)  # batch x vocab
        pa_mask[:, 0] = 0
        pa_mask = pa_mask.detach()
        # print(pa_mask)
        context_de_output = self.decoder_no_answer(
            no_answer_state,
            context_dec,
            context_tar,
            cat=pure_answer_output,
            enc_embs=en_context_output_plus,
            enc_mask=mask_plus,
            enc_onehot=enc_onehot_plus,
            lstm_mask=pa_mask)
        # context_de_output = self.decoder_no_answer(no_answer_state, context_dec, context_tar, cat=None, enc_embs = en_context_output_plus, enc_mask=mask_plus, enc_onehot = enc_onehot_plus, lstm_mask = pa_mask)
        # context_de_output = self.decoder_no_answer(en_context_state, context_dec, context_tar)#, cat=torch.max(pure_answer_output, dim = 1, keepdim=True)[0])
        context_recon_loss = self.NLLloss(
            torch.transpose(context_de_output, 1, 2), context_tar)
        # context_mask = torch.sign(context_tar.float())
        # context_recon_loss = torch.squeeze(context_recon_loss) * context_mask
        context_recon_loss_mean = context_recon_loss  #torch.mean(context_recon_loss, dim = 1)

        # Compression prior on the mask: the average -log P(z = 0 | x) is pulled
        # toward log 2, i.e. roughly half of the context tokens being selected.
        I_x_z = torch.abs(
            torch.mean(-torch.log(z_prob[:, :, 0] + eps), 1) + np.log(0.5))
        # I_x_z = torch.abs(torch.mean(torch.log(z_prob[:, :, 1]+eps), 1) - np.log(0.1))

        loss = (10 * I_x_z.mean() + answer_recon_loss_mean.mean()
                + context_recon_loss_mean + cross_sim * 100)
        # Disabled alternatives:
        # loss += ((answer_recon_loss_mean.detach()) * answer_only_logpz.mean(1)).mean()
        # loss += (context_recon_loss_mean.detach() * no_answer_logpz.mean(1)).mean()
        # loss = context_recon_loss_mean.mean()
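        # Bookkeeping: individual loss terms and the average fraction of selected tokens, for logging.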
        self.tt = [
            answer_recon_loss_mean.mean(), context_recon_loss_mean,
            (sampled_seq[:, :, 1].sum(1) * 1.0 / mask.sum(1)).mean(), cross_sim
        ]
        # self.tt = [context_recon_loss_mean.mean(),]
        # self.tt = [answer_recon_loss_mean.mean() , (sampled_seq[:,:,1].sum(1)*1.0/ mask.sum(1)).mean()]
        # return loss, answer_only_state, no_answer_state, pure_answer_output, (sampled_seq[:,:,1].sum(1)*1.0/ mask.sum(1)).mean(), sampled_seq[:,:,1], \
        #        en_context_output_shrink, mask,  enc_onehot, en_context_output_plus, mask_plus, enc_onehot_plus
        return {
            'loss': loss,
            'answer_only_state': answer_only_state,
            'no_answer_state': no_answer_state,
            'pure_answer_output': pure_answer_output,
            'closs': (sampled_seq[:, :, 1].sum(1) * 1.0 / mask.sum(1)).mean(),
            'sampled_words': sampled_seq[:, :, 1],
            'en_context_output': en_context_output_shrink,
            'mask': mask,
            'enc_onehot': enc_onehot,
            'en_context_output_plus': en_context_output_plus,
            'mask_plus': mask_plus,
            'enc_onehot_plus': enc_onehot_plus,
            'context_inputs': context_inputs,
            'context_inputs_embs': context_inputs_embs,
            'ans_mask': ans_mask,
            'gold_ans_mask': gold_ans_mask
        }

    def forward(self, x):
        data = self.build(x, mode='train')
        return data['loss'], data['closs']

    def predict(self, x):
        data = self.build(x, mode='train')
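        # Note: build() is still called with mode='train', so the rationale mask is
        # Gumbel-sampled rather than thresholded even at prediction time.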
        de_words_answer = []
        if data['answer_only_state'] is not None:
            de_words_answer = self.decoder_answer.generate(
                data['answer_only_state'],
                enc_embs=data['en_context_output'],
                enc_mask=data['mask'],
                enc_onehot=data['enc_onehot'])
        de_words_context = self.decoder_no_answer.generate(
            data['no_answer_state'],
            cat=data['pure_answer_output'],
            enc_embs=data['en_context_output_plus'],
            enc_mask=data['mask_plus'],
            enc_onehot=data['enc_onehot_plus'])
        # de_words_context = self.decoder_no_answer.generate(no_answer_state, cat = None, enc_embs = en_context_output_plus, enc_mask=mask_plus, enc_onehot = enc_onehot_plus)

        return (data['loss'], de_words_answer, de_words_context,
                data['sampled_words'], data['gold_ans_mask'], data['mask'])

    def pre_training_forward(self, x, eps=1e-6):
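        # Warm-up objective: reconstruct the full context from the plain encoder state,
        # feeding a zero vector in place of the answer summary.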
        context_inputs = torch.LongTensor(x.contextSeqs).to(args['device'])
        context_dec = torch.LongTensor(x.ContextDecoderSeqs).to(args['device'])
        context_tar = torch.LongTensor(x.ContextTargetSeqs).to(args['device'])
        context_inputs_embs = self.embedding(context_inputs)
        en_context_output, en_context_state = self.encoder(
            context_inputs_embs)  # b s e

        mask = torch.sign(context_inputs).float()
        enc_onehot = F.one_hot(context_inputs,
                               num_classes=args['vocabularySize'])
        batch_size = context_inputs.size()[0]

        context_de_output = self.decoder_no_answer(
            en_context_state,
            context_dec,
            context_tar,
            cat=torch.zeros([batch_size, 1,
                             args['embeddingSize']]).to(args['device']),
            enc_embs=en_context_output,
            enc_mask=mask,
            enc_onehot=enc_onehot)
        context_recon_loss = self.NLLloss(
            torch.transpose(context_de_output, 1, 2), context_tar)
        context_mask = torch.sign(context_tar.float())
        context_recon_loss = torch.squeeze(context_recon_loss) * context_mask
        context_recon_loss_mean = torch.mean(context_recon_loss, dim=1)
        return context_recon_loss_mean
Example #9
0
File: train.py  Project: sztudy/poetry
decoder = Decoder(2 * e_hidden_size, wv_size, d_hidden_size, len(ch_index),
                  d_linear_size)

# Resume from previously saved parameters if checkpoints exist; otherwise train from scratch.
try:
    encoder.load_state_dict(torch.load('model/encoder.params.pkl'))
    decoder.load_state_dict(torch.load('model/decoder.params.pkl'))
    print('load')
except (OSError, RuntimeError):
    pass

if use_cuda:
    encoder = encoder.cuda()
    decoder = decoder.cuda()

criterion = nn.NLLLoss()
encoder_optimizer = optim.SGD(encoder.parameters(),
                              lr=encoder_lr,
                              momentum=momentum_rate)
decoder_optimizer = optim.SGD(decoder.parameters(),
                              lr=decoder_lr,
                              momentum=momentum_rate)
#encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
#decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)


def train(keywords, sentences):
    poem_type = len(sentences[0]) - 1

    for index, word in enumerate(keywords):
        if index == 0:
            inputs = torch.from_numpy(glove[word]).view(1, 1, -1)
Example #10
0
class Train():
    def __init__(self, input_size, hidden_size, batch_size, learning_rate,
                 num_epoch, method):
        dataset = Seq2SeqDataset()

        self.vocab = sorted(set(dataset.full_text))
        self.vocab_size = len(self.vocab)
        self.char2ind, self.ind2char = self.get_vocab()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = self.vocab_size
        self.method = method
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_size=batch_size,
                                     shuffle=True)

        self.encoder = Encoder(input_size, hidden_size, self.vocab_size)
        self.decoder = Decoder(hidden_size, self.output_size, method)

        self.encoder = self.encoder.to(self.device)
        self.decoder = self.decoder.to(self.device)

        self.loss_function = NLLLoss()

        self.encoder_optim = optim.Adam(self.encoder.parameters(),
                                        lr=self.learning_rate)
        self.decoder_optim = optim.Adam(self.decoder.parameters(),
                                        lr=self.learning_rate)

    def step(self, input, output):
        input_tensor = self.convert2indx(input)
        target_tensor = self.convert2indx(output).squeeze(0)

        encoder_output, (hidden_state, cell_state) = self.encoder(input_tensor)
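        # The encoder's final (hidden, cell) state initializes the decoder; its per-step
        # outputs are passed to the decoder at every step.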

        target_len = target_tensor.size(0)
        SOS_tensor = self.convert2indx(SOS)

        decoder_input = SOS_tensor
        decoder_output = torch.zeros([target_len,
                                      self.output_size]).to(self.device)
        output_index = torch.zeros(target_len)
        # use teacher forcing
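        # The ground-truth character (not the model's own prediction) is fed back as the
        # next decoder input at every step.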
        for i in range(target_len):
            output, (hidden_state,
                     cell_state) = self.decoder(decoder_input,
                                                (hidden_state, cell_state),
                                                encoder_output)
            decoder_output[i] = output
            output_index[i] = output.topk(1)[1]
            decoder_input = target_tensor[i].unsqueeze(0).unsqueeze(0)

        loss = self.loss_function(decoder_output, target_tensor)
        self.encoder_optim.zero_grad()
        self.decoder_optim.zero_grad()

        loss.backward()

        self.encoder_optim.step()
        self.decoder_optim.step()

        return loss.item(), decoder_output, output_index

    def train_batch(self):
        total_loss = 0
        for i, (x_data, y_data) in enumerate(self.dataloader):
            loss, _, output = self.step(x_data[0], y_data[0])
            total_loss += loss
        return total_loss

    def train(self):
        for i in range(self.num_epoch):
            loss = self.train_batch()
            print('Epoch : ', i, ' -->>>--> loss', loss)
            print('output ', self.step('We lost.', 'Nous fûmes défaites.')[2])
            print('output ', self.convert2indx('Nous fûmes défaites.'))

    def convert2indx(self, input):
        input_tensor = torch.LongTensor(
            [[self.char2ind[w] for w in list(input)]])
        return input_tensor.to(self.device)

    def get_vocab(self):
        # Build character <-> index maps; '_' appears to serve as a placeholder character.
        # Note: indices start at 1, so '_' and the first vocabulary character share index 1.
        char2ind = {'_': 1}
        char2ind.update({w: i + 1 for i, w in enumerate(self.vocab)})
        ind2char = {i: c for c, i in char2ind.items()}
        return char2ind, ind2char
Example #11
0
def main(args):
    print(device)

    # Path to save the trained models
    saving_model_path = args.saving_model_path

    # If both model paths are provided, resume from that checkpoint
    check_point = bool(args.encoder_model_path and args.decoder_model_path)

    # When resuming, recover the epoch number embedded in the checkpoint filename
    # (the digits after the first eight characters of the file stem).
    curr_epoch = int(args.encoder_model_path.split('/')[-1].split('.')[0]
                     [8:]) if check_point else 0

    # Load all vocabulary in the data set
    vocab = vocab_loader("vocab.txt")

    # Build the models
    encoder = Encoder(base_model=args.cnn_model,
                      embed_size=embedding_size,
                      init=not check_point,
                      train_cnn=args.train_cnn).to(device)
    decoder = Decoder(vocab_size=len(vocab),
                      input_size=embedding_size,
                      hidden_size=args.hidden_size).to(device)

    # Transform image size to 224 or 299
    size_of_image = 299 if args.cnn_model == "inception" else 224
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size_of_image),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    val_transform = transforms.Compose([
        transforms.Resize((size_of_image, size_of_image)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load training data
    train_loader = get_train_loader(vocab, train_image_file,
                                    train_captions_json, train_transform,
                                    args.batch_size, True)

    # Load validation data
    val_loader = get_val_loader(val_image_file, val_captions_json,
                                val_transform)

    # load model from a check point
    if check_point:
        encoder.load_state_dict(
            torch.load(args.encoder_model_path, map_location=device))
        decoder.load_state_dict(
            torch.load(args.decoder_model_path, map_location=device))

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
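    # Only parameters with requires_grad=True are optimized, so a frozen CNN backbone is excluded.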
    trainable_params = [p for p in list(encoder.parameters()) + list(decoder.parameters())
                        if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=args.learning_rate)

    while curr_epoch < args.num_epochs:
        curr_epoch += 1
        train(epoch=curr_epoch,
              num_epochs=args.num_epochs,
              vocab=vocab,
              train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              optimizer=optimizer,
              criterion=criterion,
              saving_model_path=saving_model_path)
        validation(vocab=vocab,
                   val_loader=val_loader,
                   encoder=encoder,
                   decoder=decoder,
                   beam_width=args.beam_width)