Ejemplo n.º 1
0
def da_rnn(train_data: TrainData, n_targs: int, encoder_hidden_size=64, decoder_hidden_size=64,
           T=10, learning_rate=0.01, batch_size=128):
    """Build the training configuration and a fresh encoder/decoder network.

    The constructor keyword arguments of the encoder and the decoder are
    written as JSON to ``data/enc_kwargs.json`` and ``data/dec_kwargs.json``.

    Args:
        train_data: object whose ``feats`` attribute exposes ``shape``.
        n_targs: forwarded to the decoder as ``out_feats``.
        encoder_hidden_size: hidden size handed to the encoder (and decoder).
        decoder_hidden_size: hidden size handed to the decoder.
        T: forwarded to ``TrainConfig``, the encoder and the decoder.
        learning_rate: learning rate of both Adam optimizers.
        batch_size: forwarded to ``TrainConfig``.

    Returns:
        Tuple ``(train_cfg, da_rnn_net)``.
    """

    def _dump_kwargs(file_name, kwargs):
        # Persist the constructor arguments next to the data.
        with open(os.path.join("data", file_name), "w") as fi:
            json.dump(kwargs, fi, indent=4)

    def _adam_for(module):
        # Only parameters that require gradients are given to the optimizer.
        trainable = list(filter(lambda p: p.requires_grad, module.parameters()))
        return optim.Adam(params=trainable, lr=learning_rate)

    # 70% of the rows of the feature matrix is passed as the second
    # TrainConfig argument.
    n_train = int(train_data.feats.shape[0] * 0.7)
    train_cfg = TrainConfig(T, n_train, batch_size, nn.MSELoss())
    logger.info(f"Training size: {train_cfg.train_size:d}.")

    enc_kwargs = dict(input_size=train_data.feats.shape[1],
                      hidden_size=encoder_hidden_size,
                      T=T)
    encoder = Encoder(**enc_kwargs).to(device)
    _dump_kwargs("enc_kwargs.json", enc_kwargs)

    dec_kwargs = dict(encoder_hidden_size=encoder_hidden_size,
                      decoder_hidden_size=decoder_hidden_size,
                      T=T,
                      out_feats=n_targs)
    decoder = Decoder(**dec_kwargs).to(device)
    _dump_kwargs("dec_kwargs.json", dec_kwargs)

    encoder_optimizer = _adam_for(encoder)
    decoder_optimizer = _adam_for(decoder)

    return train_cfg, DaRnnNet(encoder, decoder, encoder_optimizer, decoder_optimizer)
Ejemplo n.º 2
0
class RNN(object):
    """Encoder/decoder pair bundled with its loss and one Adam optimizer each.

    ``self.sos`` and ``self.eos`` are 1x1 ``LongTensor`` markers holding the
    indices 0 and 1; they frame the target sequence in ``train`` and
    start/stop decoding in ``eval``.
    """

    def __init__(self, input_size, output_size):
        super(RNN, self).__init__()

        self.encoder = Encoder(input_size)
        self.decoder = Decoder(output_size)

        self.loss = nn.CrossEntropyLoss()
        self.encoder_optimizer = optim.Adam(self.encoder.parameters())
        self.decoder_optimizer = optim.Adam(self.decoder.parameters())

        self.sos = torch.LongTensor([[0]])
        self.eos = torch.LongTensor([[1]])

    def _is_eos(self, token):
        """Return True when ``token`` is, or is equal in value to, ``self.eos``."""
        if token is self.eos:
            return True
        try:
            return bool(torch.equal(token, self.eos))
        except (TypeError, RuntimeError):
            # Not a tensor, or not comparable with the marker: not an EOS.
            return False

    def train(self, input, target):
        """Run one optimization step on a single (input, target) pair.

        Args:
            input: iterable of encoder inputs, fed one at a time.
            target: list of target tokens. It is NOT modified; the SOS/EOS
                markers are added to an internal copy.

        Returns:
            The summed loss over all decoding steps (after ``backward``).
        """
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        hidden_state = self.encoder.first_hidden()

        # Encoder
        for ivec in input:
            _, hidden_state = self.encoder.forward(Variable(ivec),
                                                   hidden_state)

        # Decoder: build a framed copy instead of insert()/append() on the
        # caller's list, which used to grow by two markers on every call.
        sequence = [self.sos] + list(target) + [self.eos]
        total_loss = 0
        for current, following in zip(sequence, sequence[1:]):
            _, softmax, hidden_state = self.decoder.forward(
                current, hidden_state)
            total_loss += self.loss(softmax, Variable(following[0]))

        total_loss.backward()

        self.decoder_optimizer.step()
        self.encoder_optimizer.step()

        return total_loss

    def eval(self, input, max_length=None):
        """Encode ``input`` and decode until the end-of-sequence marker.

        Args:
            input: iterable of encoder inputs, fed one at a time.
            max_length: optional cap on the number of decoding steps;
                ``None`` (default) means no cap.

        Returns:
            List accumulated from the decoder outputs of every step.
        """
        hidden_state = self.encoder.first_hidden()

        # Encoder
        for ivec in input:
            _, hidden_state = self.encoder.forward(ivec, hidden_state)

        outputs = []
        output = self.sos
        steps = 0
        # Decoder: compare by value. The previous identity check
        # (`output is not self.eos`) could never become false for a tensor
        # freshly produced by the decoder, so the loop never ended.
        while not self._is_eos(output):
            if max_length is not None and steps >= max_length:
                break
            output, _, hidden_state = self.decoder.forward(
                output, hidden_state)
            # NOTE(review): `+=` extends the list with the rows of `output`
            # (kept from the original); confirm whether append() was meant.
            outputs += output
            steps += 1

        return outputs
Ejemplo n.º 3
0
def da_rnn(train_data,
           n_targs: int,
           encoder_hidden_size=64,
           decoder_hidden_size=64,
           T=10,
           learning_rate=0.01,
           batch_size=128):
    """Build the training configuration and a CUDA encoder/decoder network.

    The constructor arguments of both sub-networks are stored as one-row CSV
    files under ``results/<save_name>/`` and are read back from those frames
    before being passed on. Both Adam optimizers get a cosine-annealing
    scheduler. Relies on the module-level names ``save_name`` and ``args``.

    Returns:
        Tuple ``(train_cfg, model)``.
    """

    def _store(values, file_name):
        # One-row frame, written to the results folder of the current run.
        frame = pd.DataFrame([values])
        frame.to_csv(os.path.join('results', save_name, file_name))
        return frame

    def _first_row(frame, keys):
        # Read the values back from the frame as plain Python scalars.
        return {key: frame[key][0].item() for key in keys}

    def _adam_for(module):
        trainable = list(filter(lambda p: p.requires_grad, module.parameters()))
        return optim.Adam(params=trainable,
                          lr=learning_rate,
                          weight_decay=args.wdecay)

    def _cosine_for(optimizer):
        # The annealing period is the number of rows of the feature matrix.
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer, train_data.feats.shape[0], eta_min=args.min_lr)

    n_train = int(train_data.feats.shape[0] * 0.7)
    train_cfg = TrainConfig(T, n_train, batch_size, nn.MSELoss())
    logging.info(f"Training size: {train_cfg.train_size:d}.")

    enc_params = _store(
        {
            'input_size': train_data.feats.shape[1],
            'hidden_size': encoder_hidden_size,
            'T': T
        }, 'enc_params.csv')
    encoder = Encoder(
        **_first_row(enc_params, ('input_size', 'hidden_size', 'T'))).cuda()

    dec_params = _store(
        {
            'encoder_hidden_size': encoder_hidden_size,
            'decoder_hidden_size': decoder_hidden_size,
            'T': T,
            'out_feats': n_targs
        }, 'dec_params.csv')
    decoder = Decoder(**_first_row(
        dec_params,
        ('encoder_hidden_size', 'decoder_hidden_size', 'T',
         'out_feats'))).cuda()

    encoder_optimizer = _adam_for(encoder)
    decoder_optimizer = _adam_for(decoder)

    encoder_scheduler = _cosine_for(encoder_optimizer)
    decoder_scheduler = _cosine_for(decoder_optimizer)

    return train_cfg, DaRnnNet(encoder, decoder, encoder_optimizer,
                               decoder_optimizer, encoder_scheduler,
                               decoder_scheduler)
Ejemplo n.º 4
0
def da_rnn(train_data,
           n_targs: int,
           encoder_hidden_size=64,
           decoder_hidden_size=64,
           T=10,
           learning_rate=0.01,
           batch_size=128):
    """Build the training configuration and a CUDA encoder/decoder network.

    The constructor keyword arguments of both sub-networks are written as JSON
    to ``data/enc_kwargs.json`` and ``data/dec_kwargs.json``. Each Adam
    optimizer is paired with a cosine-annealing scheduler whose period is
    ``args.epochs``. Relies on the module-level name ``args``.

    Returns:
        Tuple ``(train_cfg, da_rnn_net)``.
    """

    def _dump_kwargs(file_name, kwargs):
        with open(os.path.join("data", file_name), "w") as fi:
            json.dump(kwargs, fi, indent=4)

    def _adam_for(module):
        # Only parameters that require gradients are optimized.
        trainable = list(filter(lambda p: p.requires_grad, module.parameters()))
        return optim.Adam(params=trainable,
                          lr=learning_rate,
                          weight_decay=args.wdecay)

    def _cosine_for(optimizer):
        return optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                    args.epochs,
                                                    eta_min=args.min_lr)

    n_train = int(train_data.feats.shape[0] * 0.7)
    train_cfg = TrainConfig(T, n_train, batch_size, nn.MSELoss())
    logging.info(f"Training size: {train_cfg.train_size:d}.")

    enc_kwargs = dict(input_size=train_data.feats.shape[1],
                      hidden_size=encoder_hidden_size,
                      T=T)
    encoder = Encoder(**enc_kwargs).cuda()
    _dump_kwargs("enc_kwargs.json", enc_kwargs)

    dec_kwargs = dict(encoder_hidden_size=encoder_hidden_size,
                      decoder_hidden_size=decoder_hidden_size,
                      T=T,
                      out_feats=n_targs)
    decoder = Decoder(**dec_kwargs).cuda()
    _dump_kwargs("dec_kwargs.json", dec_kwargs)

    encoder_optimizer = _adam_for(encoder)
    decoder_optimizer = _adam_for(decoder)

    encoder_scheduler = _cosine_for(encoder_optimizer)
    decoder_scheduler = _cosine_for(decoder_optimizer)

    return train_cfg, DaRnnNet(encoder, decoder, encoder_optimizer,
                               decoder_optimizer, encoder_scheduler,
                               decoder_scheduler)
Ejemplo n.º 5
0
def darnn(train_data: TrainingData,
          n_targets: int,
          encoder_hidden_size: int,
          decoder_hidden_size: int,
          T: int,
          learning_rate=0.002,
          batch_size=32):
    """Build the training configuration and a fresh encoder/decoder network.

    Args:
        train_data: object whose ``features`` attribute exposes ``shape``.
        n_targets: forwarded to the decoder as ``out_features``.
        encoder_hidden_size: hidden size handed to the encoder (and decoder).
        decoder_hidden_size: hidden size handed to the decoder.
        T: forwarded to ``TrainingConfig``, the encoder and the decoder.
        learning_rate: learning rate of both Adam optimizers.
        batch_size: forwarded to ``TrainingConfig``.

    Returns:
        Tuple ``(train_cfg, da_rnn_net)``.
    """

    def _adam_for(module):
        # Only parameters that require gradients are optimized.
        trainable = list(filter(lambda p: p.requires_grad, module.parameters()))
        return optim.Adam(params=trainable, lr=learning_rate)

    n_train = int(train_data.features.shape[0] * 0.7)
    train_cfg = TrainingConfig(T, n_train, batch_size, nn.MSELoss())
    print(f"Training size: {train_cfg.train_size:d}.")

    encoder = Encoder(input_size=train_data.features.shape[1],
                      hidden_size=encoder_hidden_size,
                      T=T).to(device)
    decoder = Decoder(encoder_hidden_size=encoder_hidden_size,
                      decoder_hidden_size=decoder_hidden_size,
                      T=T,
                      out_features=n_targets).to(device)

    encoder_optimizer = _adam_for(encoder)
    decoder_optimizer = _adam_for(decoder)

    return train_cfg, Darnn_Net(encoder, decoder, encoder_optimizer, decoder_optimizer)
Ejemplo n.º 6
0
def TCHA(train_data: TrainData, n_targs: int, bidirec=False, num_layer=1, encoder_hidden_size=64, decoder_hidden_size=64,
         T=10, learning_rate=0.01, batch_size=128, interval=1, split=0.7, isMean=False):
    """Build the training configuration and a fresh TCHA encoder/decoder network.

    Args:
        train_data: object whose ``feats`` attribute exposes ``shape``.
        n_targs: forwarded to the decoder as ``out_feats``.
        bidirec: forwarded to both the encoder and the decoder.
        num_layer: forwarded to both the encoder and the decoder.
        encoder_hidden_size: hidden size handed to the encoder (and decoder).
        decoder_hidden_size: hidden size handed to the decoder.
        T: forwarded to ``TrainConfig`` (twice), the encoder and the decoder.
        learning_rate: learning rate of both Adam optimizers.
        batch_size: forwarded to ``TrainConfig``.
        interval: forwarded to ``TrainConfig``.
        split: fraction of the rows of ``train_data.feats`` used to compute
            the size passed to ``TrainConfig``.
        isMean: forwarded to ``TrainConfig``.

    Returns:
        Tuple ``(train_cfg, tcha)``.
    """

    def _adam_for(module):
        # Only parameters that require gradients are optimized.
        trainable = list(filter(lambda p: p.requires_grad, module.parameters()))
        return optim.Adam(params=trainable, lr=learning_rate)

    n_train = int(train_data.feats.shape[0] * split)
    train_cfg = TrainConfig(T, n_train, batch_size, nn.MSELoss(), interval, T, isMean)
    logger.info(f"Training size: {train_cfg.train_size:d}.")

    encoder = Encoder(input_size=train_data.feats.shape[1],
                      hidden_size=encoder_hidden_size,
                      T=T,
                      bidirec=bidirec,
                      num_layer=num_layer).to(device)

    decoder = Decoder(encoder_hidden_size=encoder_hidden_size,
                      decoder_hidden_size=decoder_hidden_size,
                      T=T,
                      out_feats=n_targs,
                      bidirec=bidirec,
                      num_layer=num_layer).to(device)

    encoder_optimizer = _adam_for(encoder)
    decoder_optimizer = _adam_for(decoder)

    return train_cfg, TCHA_Net(encoder, decoder, encoder_optimizer, decoder_optimizer)
Ejemplo n.º 7
0
        def set_params(train_data, device, **da_rnn_kwargs):
            """Build the training configuration and the encoder/decoder network.

            Reads ``time_step``, ``batch_size``, ``en_hidden_size``,
            ``de_hidden_size``, ``target_cols`` and ``learning_rate`` from
            ``da_rnn_kwargs`` and ``predict_size`` from the enclosing ``self``.

            Returns:
                Tuple ``(train_configs, da_rnn_net)``.
            """

            def _adam_for(module):
                # Only parameters that require gradients are optimized.
                trainable = list(
                    filter(lambda p: p.requires_grad, module.parameters()))
                return optim.Adam(params=trainable,
                                  lr=da_rnn_kwargs["learning_rate"],
                                  betas=(0.9, 0.999),
                                  eps=1e-08)

            # 95% of the rows is passed as the second TrainConfig argument.
            n_train = int(train_data.shape[0] * 0.95)
            train_configs = TrainConfig(da_rnn_kwargs["time_step"], n_train,
                                        da_rnn_kwargs["batch_size"],
                                        nn.MSELoss())

            enc_kwargs = dict(
                input_size=train_data.shape[1],
                hidden_size=da_rnn_kwargs["en_hidden_size"],
                time_step=int(da_rnn_kwargs["time_step"] / self.predict_size))
            dec_kwargs = dict(
                encoder_hidden_size=da_rnn_kwargs["en_hidden_size"],
                decoder_hidden_size=da_rnn_kwargs["de_hidden_size"],
                time_step=int(da_rnn_kwargs["time_step"] / self.predict_size),
                out_feats=da_rnn_kwargs["target_cols"])

            encoder = Encoder(**enc_kwargs).to(device)
            decoder = Decoder(**dec_kwargs).to(device)

            encoder_optimizer = _adam_for(encoder)
            decoder_optimizer = _adam_for(decoder)

            return train_configs, DaRnnNet(encoder, decoder,
                                           encoder_optimizer,
                                           decoder_optimizer)
Ejemplo n.º 8
0
class DeepAPI(nn.Module):
    '''Encoder/decoder model trained with a masked cross-entropy loss.

    The encoder consumes ``descs`` and the decoder is trained to reproduce
    ``apiseqs`` shifted by one position. The model owns its Adadelta optimizer
    (covering encoder and decoder parameters) and exposes one-step training,
    validation and sampling helpers.
    '''
    def __init__(self, config, vocab_size):
        '''Create embedders, encoder, decoder, optimizer and loss.

        ``config`` is indexed with the keys ``maxlen``, ``clip``, ``temp``,
        ``emb_size``, ``n_hidden``, ``n_layers``, ``noise_radius``,
        ``use_attention``, ``dropout`` and ``lr_ae``.
        '''
        super(DeepAPI, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.temp = config['temp']

        self.desc_embedder = nn.Embedding(vocab_size,
                                          config['emb_size'],
                                          padding_idx=PAD_ID)
        self.api_embedder = nn.Embedding(vocab_size,
                                         config['emb_size'],
                                         padding_idx=PAD_ID)
        # Encoder over the description tokens (uses desc_embedder).
        self.encoder = Encoder(self.desc_embedder, config['emb_size'],
                               config['n_hidden'], True, config['n_layers'],
                               config['noise_radius'])
        # Decoder over the API tokens (uses api_embedder); its hidden size is
        # twice the encoder's n_hidden.
        self.decoder = Decoder(self.api_embedder, config['emb_size'],
                               config['n_hidden'] * 2, vocab_size,
                               config['use_attention'], 1,
                               config['dropout'])
        self.optimizer = optim.Adadelta(list(self.encoder.parameters()) +
                                        list(self.decoder.parameters()),
                                        lr=config['lr_ae'],
                                        rho=0.95)
        self.criterion_ce = nn.CrossEntropyLoss()

    def forward(self, descs, desc_lens, apiseqs, api_lens):
        '''Return the cross-entropy loss over the non-zero target tokens.

        The decoder is fed ``apiseqs[:, :-1]`` and scored against
        ``apiseqs[:, 1:]``; positions whose target id is 0 are ignored.
        Logits are divided by ``self.temp`` before the loss.
        '''
        c, hids = self.encoder(descs, desc_lens)
        output, _ = self.decoder(c, hids, None, apiseqs[:, :-1],
                                 (api_lens - 1))
        logits = output.view(-1, self.vocab_size)  # [batch*seq_len x n_tokens]

        dec_target = apiseqs[:, 1:].contiguous().view(-1)
        keep = dec_target.gt(0)  # [(batch*seq_len)], True where target != 0
        masked_target = dec_target[keep]
        masked_output = logits[keep]  # [n_kept x n_tokens]
        return self.criterion_ce(masked_output / self.temp, masked_target)

    def train_AE(self, descs, desc_lens, apiseqs, api_lens):
        '''Run one optimization step and return ``{'train_loss': float}``.'''
        self.encoder.train()
        self.decoder.train()

        loss = self.forward(descs, desc_lens, apiseqs, api_lens)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the gradient norm to prevent exploding gradients in the RNNs.
        torch.nn.utils.clip_grad_norm_(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            self.clip)
        self.optimizer.step()
        return {'train_loss': loss.item()}

    def valid(self, descs, desc_lens, apiseqs, api_lens):
        '''Return ``{'valid_loss': float}`` without updating any parameter.'''
        self.encoder.eval()
        self.decoder.eval()
        # No backward pass follows, so skip building the autograd graph.
        with torch.no_grad():
            loss = self.forward(descs, desc_lens, apiseqs, api_lens)
        return {'valid_loss': loss.item()}

    def sample(self, descs, desc_lens, n_samples, mode='beamsearch'):
        '''Generate sequences for ``descs``.

        With ``mode='beamsearch'`` the decoder's ``beam_decode`` is used and
        only the results of the first batch element are returned; any other
        mode is forwarded to the decoder's ``sampling``.

        Returns:
            Tuple ``(sample_words, sample_lens)``.
        '''
        self.encoder.eval()
        self.decoder.eval()
        # Pure inference: gradients are not needed.
        with torch.no_grad():
            c, hids = self.encoder(descs, desc_lens)
            if mode == 'beamsearch':
                # NOTE(review): 12 is presumably the beam width -- confirm
                # against Decoder.beam_decode.
                sample_words, sample_lens, _ = self.decoder.beam_decode(
                    c, hids, None, 12, self.maxlen, n_samples)
                # [batch_size x n_samples x seq_len] -> first batch element
                sample_words, sample_lens = sample_words[0], sample_lens[0]
            else:
                sample_words, sample_lens = self.decoder.sampling(
                    c, hids, None, n_samples, self.maxlen, mode)
        return sample_words, sample_lens

    def adjust_lr(self):
        '''Placeholder: no learning-rate scheduler is configured.'''
        return None
Ejemplo n.º 9
0
class CVAE(nn.Module):
    """Conditional VAE that generates a poem line by line from a title.

    A shared sequence encoder embeds the title, the previous line (context)
    and, during training, the target line. A prior network (conditioned on
    title + context) and a posterior network (additionally conditioned on the
    target) produce the latent code ``z``; the decoder is initialised from
    ``z`` concatenated with the condition. The model owns its Adam optimizer.
    """

    # Factor applied to the total loss when a sentiment mask is supplied;
    # it effectively raises the learning rate for labelled batches.
    LABELED_LOSS_SCALE = 13.33

    def __init__(self, config, api, PAD_token=0):
        """Create all sub-networks, the optimizer and the loss criteria.

        Args:
            config: object providing ``maxlen``, ``clip``, ``temp``,
                ``full_kl_step``, ``z_size``, ``init_weight``, ``emb_size``,
                ``n_hidden``, ``n_layers``, ``noise_radius``, ``dropout`` and
                ``lr_ae``.
            api: object providing ``vocab`` (id -> token) and ``rev_vocab``
                (token -> id, must contain ``<s>`` and ``</s>``).
            PAD_token: padding index of the embedding table.
        """
        super(CVAE, self).__init__()
        self.vocab = api.vocab
        self.vocab_size = len(self.vocab)
        self.rev_vocab = api.rev_vocab
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.maxlen = config.maxlen
        self.clip = config.clip
        self.temp = config.temp
        self.full_kl_step = config.full_kl_step
        self.z_size = config.z_size
        self.init_w = config.init_weight
        self.softmax = nn.Softmax(dim=1)

        self.embedder = nn.Embedding(self.vocab_size,
                                     config.emb_size,
                                     padding_idx=PAD_token)
        # Encodes the title and every line of the poem.
        self.seq_encoder = Encoder(embedder=self.embedder,
                                   input_size=config.emb_size,
                                   hidden_size=config.n_hidden,
                                   bidirectional=True,
                                   n_layers=config.n_layers,
                                   noise_radius=config.noise_radius)

        # Prior network: input is the encoded title concatenated with the
        # encoded context (n_hidden * 4 features in total).
        self.prior_net = Variation(config.n_hidden * 4,
                                   config.z_size,
                                   dropout_rate=config.dropout,
                                   init_weight=self.init_w)
        # Posterior network: the prior input plus the encoded target line
        # (n_hidden * 6 features in total).
        self.post_net = Variation(config.n_hidden * 6,
                                  config.z_size,
                                  dropout_rate=config.dropout,
                                  init_weight=self.init_w)
        # MLP used for the bag-of-words loss.
        self.bow_project = nn.Sequential(
            nn.Linear(config.n_hidden * 4 + config.z_size, 400),
            nn.LeakyReLU(), nn.Dropout(config.dropout),
            nn.Linear(400, self.vocab_size))
        # Maps [z, condition] to the decoder's initial hidden state.
        self.init_decoder_hidden = nn.Sequential(
            nn.Linear(config.n_hidden * 4 + config.z_size, config.n_hidden),
            nn.BatchNorm1d(config.n_hidden, eps=1e-05, momentum=0.1),
            nn.LeakyReLU())

        self.init_decoder_hidden.apply(self.init_weights)
        self.bow_project.apply(self.init_weights)
        self.post_net.apply(self.init_weights)

        self.decoder = Decoder(embedder=self.embedder,
                               input_size=config.emb_size,
                               hidden_size=config.n_hidden,
                               vocab_size=self.vocab_size,
                               n_layers=1)

        self.optimizer_AE = optim.Adam(self._ae_parameters(),
                                       lr=config.lr_ae)

        self.criterion_ce = nn.CrossEntropyLoss()
        self.criterion_sent_lead = nn.CrossEntropyLoss()

    def _ae_parameters(self):
        """Return the parameters trained by ``optimizer_AE``, each exactly once.

        The embedder is handed to both the encoder and the decoder, so naive
        concatenation of their parameter lists can contain the same tensor
        twice, which torch.optim warns about.
        """
        modules = (self.seq_encoder, self.prior_net, self.post_net,
                   self.bow_project, self.init_decoder_hidden, self.decoder)
        seen = set()
        params = []
        for module in modules:
            for param in module.parameters():
                if id(param) not in seen:
                    seen.add(id(param))
                    params.append(param)
        return params

    def set_full_kl_step(self, kl_full_step):
        """Set the step count at which the KL weight reaches 1."""
        self.full_kl_step = kl_full_step

    def force_change_lr(self, new_init_lr_ae):
        """Replace the optimizer with a fresh Adam using ``new_init_lr_ae``.

        The optimizer state (moment estimates) is discarded.
        """
        self.optimizer_AE = optim.Adam(self._ae_parameters(),
                                       lr=new_init_lr_ae)

    def init_weights(self, m):
        """Uniformly initialise ``nn.Linear`` weights in ``[-init_w, init_w]``
        and zero their biases; other module types are left untouched."""
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-self.init_w, self.init_w)
            m.bias.data.fill_(0)

    def sample_code_post(self, x, c):
        """Sample ``z`` from the posterior network given target encoding ``x``
        and condition ``c``; returns ``(z, mu, logsigma)``."""
        z, mu, logsigma = self.post_net(torch.cat(
            (x, c), 1))  # input: (batch, 3 * 2 * n_hidden)
        return z, mu, logsigma

    def sample_code_prior(self, c, sentiment_mask=None, mask_type=None):
        """Forward condition ``c`` (and the optional masks) to the prior
        network and return its result unchanged."""
        return self.prior_net(c,
                              sentiment_mask=sentiment_mask,
                              mask_type=mask_type)

    def _compute_losses(self, global_t, title, context, target, target_lens,
                        sentiment_mask, sentiment_lead):
        """Run the full forward pass and store every loss term on ``self``.

        Sets ``sent_lead_loss``, ``rc_loss``, ``avg_kld``, ``kl_weights``,
        ``kl_loss``, ``bow_logits``, ``avg_bow_loss`` and ``aug_elbo_loss``.
        Shared by ``train_AE`` and ``valid_AE``.
        """
        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        condition_prior = torch.cat((title_last_hidden, context_last_hidden),
                                    dim=1)

        _, prior_mu, prior_logvar, pi, _ = self.sample_code_prior(
            condition_prior, sentiment_mask=sentiment_mask)
        z_post, post_mu, post_logvar = self.sample_code_post(
            x, condition_prior)

        if sentiment_lead is not None:
            self.sent_lead_loss = self.criterion_sent_lead(
                input=pi, target=sentiment_lead)
        else:
            self.sent_lead_loss = 0

        final_info = torch.cat((z_post, condition_prior), dim=1)

        # Reconstruction loss over the non-padding target tokens.
        output = self.decoder(init_hidden=self.init_decoder_hidden(final_info),
                              context=None,
                              inputs=target[:, :-1])
        flattened_output = output.view(-1, self.vocab_size)
        dec_target = target[:, 1:].contiguous().view(-1)
        keep = dec_target.gt(0)  # True where the target token is not 0 (pad)
        masked_target = dec_target[keep]
        masked_output = flattened_output[keep]  # [n_kept x n_tokens]
        self.rc_loss = self.criterion_ce(masked_output / self.temp,
                                         masked_target)

        # KL divergence with linear annealing of its weight.
        kld = gaussian_kld(post_mu, post_logvar, prior_mu, prior_logvar)
        self.avg_kld = torch.mean(kld)
        self.kl_weights = min(global_t / self.full_kl_step, 1.0)
        self.kl_loss = self.kl_weights * self.avg_kld

        # Bag-of-words loss: negative log-likelihood of every target token
        # under one shared distribution, summed per sequence (pads masked).
        self.bow_logits = self.bow_project(final_info)
        labels = target[:, 1:]
        label_mask = torch.sign(labels).detach().float()
        bow_loss = -F.log_softmax(self.bow_logits, dim=1).gather(
            1, labels) * label_mask
        self.avg_bow_loss = torch.mean(torch.sum(bow_loss, 1))

        self.aug_elbo_loss = self.avg_bow_loss + self.kl_loss + self.rc_loss

    def train_AE(self,
                 global_t,
                 title,
                 context,
                 target,
                 target_lens,
                 sentiment_mask=None,
                 sentiment_lead=None):
        """Run one optimization step.

        Args:
            global_t: current global step, used for KL annealing.
            title, context, target: token id tensors.
            target_lens: lengths of the target sequences.
            sentiment_mask: optional mask forwarded to the prior network;
                when given, the total loss is scaled by
                ``LABELED_LOSS_SCALE``.
            sentiment_lead: optional targets for the sentiment
                cross-entropy loss.

        Returns:
            Tuple ``(metrics, global_t + 1)`` where ``metrics`` is a list of
            ``(name, value)`` pairs.
        """
        # NOTE(review): only seq_encoder and decoder are switched to train
        # mode here; the other sub-networks keep whatever mode the caller set.
        self.seq_encoder.train()
        self.decoder.train()

        self._compute_losses(global_t, title, context, target, target_lens,
                             sentiment_mask, sentiment_lead)
        self.total_loss = self.aug_elbo_loss + self.sent_lead_loss

        if sentiment_mask is not None:
            self.total_loss = self.total_loss * self.LABELED_LOSS_SCALE

        self.optimizer_AE.zero_grad()
        self.total_loss.backward()
        self.optimizer_AE.step()

        avg_lead_loss = (0 if sentiment_lead is None else
                         self.sent_lead_loss.item())
        metrics = [('avg_total_loss', self.total_loss.item()),
                   ('avg_lead_loss', avg_lead_loss),
                   ('avg_aug_elbo_loss', self.aug_elbo_loss.item()),
                   ('avg_kl_loss', self.kl_loss.item()),
                   ('avg_rc_loss', self.rc_loss.item()),
                   ('avg_bow_loss', self.avg_bow_loss.item()),
                   ('kl_weight', self.kl_weights)]
        return metrics, global_t + 1

    def valid_AE(self,
                 global_t,
                 title,
                 context,
                 target,
                 target_lens,
                 sentiment_mask=None,
                 sentiment_lead=None):
        """Compute the validation losses without updating any parameter.

        Takes the same arguments as ``train_AE``.

        Returns:
            Tuple ``(metrics, global_t)`` where ``metrics`` is a list of
            ``(name, value)`` pairs; ``global_t`` is returned unchanged.
        """
        self.seq_encoder.eval()
        self.decoder.eval()

        # No backward pass follows, so skip building the autograd graph.
        with torch.no_grad():
            self._compute_losses(global_t, title, context, target,
                                 target_lens, sentiment_mask, sentiment_lead)

        avg_lead_loss = (0 if sentiment_lead is None else
                         self.sent_lead_loss.item())
        metrics = [('valid_lead_loss', avg_lead_loss),
                   ('valid_aug_elbo_loss', self.aug_elbo_loss.item()),
                   ('valid_kl_loss', self.kl_loss.item()),
                   ('valid_rc_loss', self.rc_loss.item()),
                   ('valid_bow_loss', self.avg_bow_loss.item())]
        return metrics, global_t

    def test(self, title_tensor, title_words, mask_type=None):
        """Generate a poem for a single title and return it as text.

        Only the prior network is used (there is no target, hence no
        posterior and no KL term). The batch size must be 1.

        Returns:
            String with the title on the first line followed by the
            generated lines, tokens separated by spaces, each line ending
            in a newline.
        """
        self.seq_encoder.eval()
        self.decoder.eval()
        assert title_tensor.size(0) == 1
        tem = [[2, 3] + [0] * (self.maxlen - 2)]
        pred_poems = []
        # Drop <s>, </s> and id 0 from the title; this is for printing only.
        title_tokens = [
            self.vocab[e] for e in title_words[0].tolist()
            if e not in [0, self.eos_id, self.go_id]
        ]
        pred_poems.append(title_tokens)

        # Pure inference: gradients are not needed.
        with torch.no_grad():
            for i in range(4):
                tem = to_tensor(np.array(tem))
                context = tem
                # The first line is conditioned on the title alone; later
                # lines use the previously generated line as context.
                if i == 0:
                    context_last_hidden, _ = self.seq_encoder(title_tensor)
                else:
                    context_last_hidden, _ = self.seq_encoder(context)
                title_last_hidden, _ = self.seq_encoder(title_tensor)

                condition_prior = torch.cat(
                    (title_last_hidden, context_last_hidden), dim=1)
                z_prior, _, _, _, _ = self.sample_code_prior(
                    condition_prior, mask_type=mask_type)
                final_info = torch.cat((z_prior, condition_prior), 1)

                decode_words = self.decoder.testing(
                    init_hidden=self.init_decoder_hidden(final_info),
                    maxlen=self.maxlen,
                    go_id=self.go_id,
                    mode="greedy")
                decode_words = decode_words[0].tolist()

                # Truncate or left-pad with 0 so the next context has
                # exactly maxlen tokens.
                if len(decode_words) >= self.maxlen:
                    tem = [decode_words[0:self.maxlen]]
                else:
                    tem = [[0] * (self.maxlen - len(decode_words)) +
                           decode_words]
                pred_tokens = [
                    self.vocab[e] for e in decode_words[:-1]
                    if e != self.eos_id and e != 0 and e != self.go_id
                ]
                pred_poems.append(pred_tokens)

        return "".join(" ".join(line) + '\n' for line in pred_poems)

    def sample(self, title, context, repeat, go_id, end_id):
        """Draw ``repeat`` greedy samples for one (title, context) pair.

        ``go_id`` and ``end_id`` are accepted for interface compatibility
        but are not used; the ids stored on the model are passed to the
        decoder instead.

        Returns:
            Tuple ``(sample_words, sample_lens)`` from the decoder.
        """
        self.seq_encoder.eval()
        self.decoder.eval()

        assert title.size(0) == 1
        # Pure inference: gradients are not needed.
        with torch.no_grad():
            title_last_hidden, _ = self.seq_encoder(title)
            context_last_hidden, _ = self.seq_encoder(context)
            condition_prior = torch.cat(
                (title_last_hidden, context_last_hidden), 1)
            condition_prior_repeat = condition_prior.expand(repeat, -1)

            z_prior_repeat, _, _, _, _ = self.sample_code_prior(
                condition_prior_repeat)

            final_info = torch.cat((z_prior_repeat, condition_prior_repeat),
                                   dim=1)
            sample_words, sample_lens = self.decoder.sampling(
                init_hidden=self.init_decoder_hidden(final_info),
                maxlen=self.maxlen,
                go_id=self.go_id,
                eos_id=self.eos_id,
                mode="greedy")

        return sample_words, sample_lens
Ejemplo n.º 10
0
class RNN(object):
    """Encoder/decoder pair trained jointly with a cross-entropy loss.

    The encoder consumes the input sequence one step at a time; its final
    hidden state seeds the decoder. Each network has its own Adam optimizer.
    """

    # Token ids that start and terminate decoding in ``eval``.
    SOS_ID = 0
    EOS_ID = 1

    def __init__(self, input_size, output_size, resume=False):
        """Build the networks, loss and optimizers.

        Args:
            input_size: Forwarded to ``Encoder``.
            output_size: Forwarded to ``Decoder``.
            resume: When true, load both state dicts from ``models/``.
        """
        super(RNN, self).__init__()

        self.encoder = Encoder(input_size)
        self.decoder = Decoder(output_size)

        self.loss = nn.CrossEntropyLoss()
        self.encoder_optimizer = optim.Adam(self.encoder.parameters())
        self.decoder_optimizer = optim.Adam(self.decoder.parameters())

        # ``eval`` starts decoding from ``self.sos``; without these attributes
        # it failed with AttributeError.
        self.sos = torch.LongTensor([[self.SOS_ID]])
        self.eos = torch.LongTensor([[self.EOS_ID]])

        if resume:
            self.encoder.load_state_dict(torch.load("models/encoder.ckpt"))
            self.decoder.load_state_dict(torch.load("models/decoder.ckpt"))

    def train(self, input, target):
        """Run one teacher-forced optimisation step.

        Args:
            input: Iterable of encoder inputs, fed step by step.
            target: Sequence of decoder tokens; step ``i`` is the decoder
                input and step ``i + 1`` is the expected output.

        Returns:
            ``(loss, outputs)`` where ``loss`` is the mean per-step loss as a
            Python float and ``outputs`` holds the arg-max prediction of every
            decoding step.

        Raises:
            ValueError: If ``target`` has fewer than two steps, in which case
                there is nothing to predict.
        """
        if len(target) < 2:
            raise ValueError("target must contain at least two steps")

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        # Encoder
        hidden_state = self.encoder.first_hidden()
        for ivec in input:
            _, hidden_state = self.encoder.forward(ivec, hidden_state)

        # Decoder
        total_loss, outputs = 0, []
        for i in range(len(target) - 1):
            _, softmax, hidden_state = self.decoder.forward(
                target[i], hidden_state)

            outputs.append(np.argmax(softmax.data.numpy(), 1)[:, np.newaxis])
            total_loss += self.loss(softmax, target[i + 1].squeeze(1))

        total_loss /= len(outputs)
        total_loss.backward()

        self.decoder_optimizer.step()
        self.encoder_optimizer.step()

        # ``.item()`` instead of ``.data[0]``: the loss is a 0-dim tensor and
        # cannot be indexed.
        return total_loss.item(), outputs

    def eval(self, input, max_len=None):
        """Greedily decode a sequence for ``input``.

        Args:
            input: Iterable of encoder inputs, fed step by step.
            max_len: Optional cap on the number of decoded tokens. ``None``
                (the default) decodes until the end token is produced.

        Returns:
            List of ``(1, 1)`` integer arrays, one per decoded token; the end
            token is included when it was produced.
        """
        hidden_state = self.encoder.first_hidden()

        # Encoder
        for ivec in input:
            _, hidden_state = self.encoder.forward(Variable(ivec),
                                                   hidden_state)

        sentence = []
        token = self.sos
        # Decoder
        while token.data[0, 0] != self.EOS_ID:
            if max_len is not None and len(sentence) >= max_len:
                break
            output, _, hidden_state = self.decoder.forward(token, hidden_state)
            word = np.argmax(output.data.numpy()).reshape((1, 1))
            token = Variable(torch.LongTensor(word))
            sentence.append(word)

        return sentence

    def save(self):
        """Write both state dicts to the ``models/`` directory."""
        torch.save(self.encoder.state_dict(), "models/encoder.ckpt")
        torch.save(self.decoder.state_dict(), "models/decoder.ckpt")
Ejemplo n.º 11
0
class Seq2Seq(nn.Module):
    """Encoder / attention-decoder pair that generates a poem line by line.

    One embedding table is shared by the encoder and the decoder, and a single
    Adam optimizer with a step learning-rate schedule updates both networks.
    """

    def __init__(self, config, api, pad_token=0):
        """Build the embedding, encoder, decoder, loss and optimizer.

        Args:
            config: Object providing ``maxlen``, ``emb_size``, ``n_hidden``,
                ``n_layers``, ``noise_radius`` and ``lr_s2s``; it is also
                handed to the decoder.
            api: Object exposing ``vocab`` (indexed by token id) and
                ``rev_vocab`` (token -> id, must contain ``<s>`` and ``</s>``).
            pad_token: Embedding index used as ``padding_idx``.
        """
        super(Seq2Seq, self).__init__()
        self.vocab = api.vocab
        self.vocab_size = len(self.vocab)
        self.rev_vocab = api.rev_vocab
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.maxlen = config.maxlen

        self.embedder = nn.Embedding(self.vocab_size, config.emb_size, padding_idx=pad_token)
        self.encoder = Encoder(self.embedder, config.emb_size, config.n_hidden,
                               True, config.n_layers, config.noise_radius)
        self.decoder = AttnDecoder(config=config, embedder=self.embedder, vocab_size=self.vocab_size)

        self.criterion = nn.NLLLoss(reduction='none')
        self.optimizer = optim.Adam(list(self.encoder.parameters())
                                   + list(self.decoder.parameters()),
                                   lr=config.lr_s2s)
        self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.6)

    def _encode(self, context):
        """Encode ``context`` and keep the second half of every feature vector.

        The encoder is treated as returning a final state of width
        ``2 * n_hidden`` and per-step outputs of the same width; only the
        second half is kept (presumably the backward direction of a
        bidirectional encoder -- TODO confirm against ``Encoder``).

        Returns:
            ``(last_hidden, encoder_output)`` shaped ``(1, batch, n_hidden)``
            and ``(batch, len, n_hidden)``.
        """
        encoder_last_hidden, encoder_output = self.encoder(context)
        batch_size = encoder_last_hidden.size(0)
        hidden_size = encoder_last_hidden.size(1) // 2
        # No ``.squeeze()`` here: the slice is already (batch, n_hidden), and
        # squeezing would also drop the batch axis when batch == 1.
        last_hidden = encoder_last_hidden.view(batch_size, 2, -1)[:, -1, :].unsqueeze(0)
        encoder_output = encoder_output.view(batch_size, -1, 2, hidden_size)[:, :, -1]
        return last_hidden, encoder_output

    def _sequence_loss(self, context, target, target_lens):
        """Return the teacher-forced loss averaged per row and then over the batch."""
        last_hidden, encoder_output = self._encode(context)

        decoder_input = target[:, :-1]   # tokens fed to the decoder
        decoder_target = target[:, 1:]   # tokens it has to predict
        step_losses = []
        for i in range(self.maxlen - 1):
            decoded_result, last_hidden = self.decoder(
                decoder_input=decoder_input[:, i],
                init_hidden=last_hidden,
                encoder_output=encoder_output)
            step_losses.append(self.criterion(decoded_result, decoder_target[:, i]))

        stack_loss = torch.stack(step_losses, 1)  # (batch, maxlen - 1)
        sum_loss = torch.sum(stack_loss, 1)       # sum over each row
        # Average each row over its number of predicted tokens first, then
        # average over the batch.
        # NOTE(review): padded steps are not masked before the sum -- confirm
        # the decoder/criterion contribute zero loss there.
        avg_loss_batch = sum_loss / (target_lens.float() - 1)
        return torch.mean(avg_loss_batch)

    # Titles and poem lines are both limited to a fixed number of characters:
    # padded when shorter, truncated when longer.
    # Every call receives one batch of context and target.
    def train_model(self, context, target, target_lens):
        """Run one optimisation step and return ``[('train_loss', value)]``."""
        self.encoder.train()
        self.decoder.train()
        self.optimizer.zero_grad()

        loss = self._sequence_loss(context, target, target_lens)
        loss.backward()
        self.optimizer.step()
        return [('train_loss', loss.item())]

    def valid(self, context, target, target_lens):
        """Compute the loss without updating weights; returns ``[('valid_loss', value)]``."""
        self.encoder.eval()
        self.decoder.eval()

        # Only the scalar value is needed, so skip building the autograd graph.
        with torch.no_grad():
            loss = self._sequence_loss(context, target, target_lens)

        return [('valid_loss', loss.item())]

    def test(self, title, title_list, batch_size):
        """Generate four lines for one title and return them as text.

        The title (truncated to ``maxlen``) is the first context; every
        generated line, left-padded with zeros or truncated to ``maxlen``,
        becomes the context for the next one.

        Args:
            title: Title token ids; its first dimension must be 1.
            title_list: Token ids of the title used for the printed first line.
            batch_size: Row count used when a list context is converted back
                to a tensor; it is replaced by the encoder's batch size after
                the first encoding step.

        Returns:
            The title followed by the four generated lines, tokens separated
            by spaces and every line terminated by a newline.
        """
        self.encoder.eval()
        self.decoder.eval()

        assert title.size(0) == 1
        tem = title[0][0: self.maxlen].unsqueeze(0)

        skipped = [0, self.eos_id, self.go_id]
        pred_poems = [[self.vocab[e] for e in title_list[0].tolist() if e not in skipped]]

        with torch.no_grad():
            for _ in range(4):
                context = tem
                if isinstance(context, list):
                    vec_context = np.zeros((batch_size, self.maxlen), dtype=np.int64)
                    for b_id in range(batch_size):
                        vec_context[b_id, :] = np.array(context[b_id])
                    context = to_tensor(vec_context)

                last_hidden, encoder_output = self._encode(context)
                batch_size = last_hidden.size(1)

                # decode_words is one complete poem line.
                decode_words = self.decoder.testing(init_hidden=last_hidden, encoder_output=encoder_output,
                                                    maxlen=self.maxlen, go_id=self.go_id, mode="greedy")
                decode_words = decode_words[0].tolist()

                if len(decode_words) > self.maxlen:
                    tem = [decode_words[0: self.maxlen]]
                else:
                    tem = [[0] * (self.maxlen - len(decode_words)) + decode_words]

                pred_poems.append(
                    [self.vocab[e] for e in decode_words[:-1] if e != self.eos_id and e != 0])

        return "".join(" ".join(line) + '\n' for line in pred_poems)

    def sample(self, title, context, repeat, go_id, end_id):
        """Decode ``repeat`` sequences from a single encoded context.

        Args:
            title: Unused; kept for interface compatibility.
            context: Encoder input; the expand below requires batch size 1.
            repeat: Number of rows the encoder state is expanded to.
            go_id: Start token id passed to the decoder.
            end_id: End token id passed to the decoder.

        Returns:
            The ``(sample_words, sample_lens)`` pair from ``self.decoder.sampling``.
        """
        self.encoder.eval()
        self.decoder.eval()

        with torch.no_grad():
            last_hidden, encoder_output = self._encode(context)
            hidden_size = last_hidden.size(-1)

            last_hidden = last_hidden.expand(1, repeat, hidden_size)
            encoder_output = encoder_output.expand(repeat, -1, hidden_size)

            sample_words, sample_lens = self.decoder.sampling(last_hidden, encoder_output, self.maxlen,
                                                              go_id, end_id, "greedy")
        return sample_words, sample_lens

    def adjust_lr(self):
        """Advance the step learning-rate schedule by one step."""
        self.lr_scheduler_AE.step()
def da_rnn(train_data: TrainData,
           n_targs: int,
           learning_rate=0.01,
           encoder_hidden_size=64,
           decoder_hidden_size=64,
           T=10,
           batch_size=128):
    """Assemble the DA-RNN networks, their optimizers and the training config.

    The first 70% of the rows in ``train_data.feats`` (rounded down) are
    reserved for training. The keyword arguments used to construct the encoder
    and the decoder are also written to ``data/enc_kwargs.json`` and
    ``data/dec_kwargs.json``.

    Args:
        train_data: Dataset whose ``feats`` array provides the row and feature
            counts. The annotation is documentation only; it is not enforced.
        n_targs: Number of target series, used as the decoder's ``out_feats``.
        learning_rate: Learning rate of both Adam optimizers.
        encoder_hidden_size: Hidden size handed to the encoder.
        decoder_hidden_size: Hidden size handed to the decoder.
        T: Window length handed to both networks and to the config.
        batch_size: Batch size stored in the config.

    Returns:
        ``(TrainConfig, DaRnnNet)``. The config is built positionally as
        ``(T, train_size, batch_size, loss_func)`` with an MSE loss.
    """

    def _persist(file_name, kwargs):
        # Record the constructor arguments next to the data.
        with open(os.path.join("data", file_name), "w") as fi:
            json.dump(kwargs, fi, indent=4)

    def _adam_for(module):
        # Only parameters that still require gradients are optimised.
        return optim.Adam(
            params=[p for p in module.parameters() if p.requires_grad],
            lr=learning_rate)

    train_rows = int(train_data.feats.shape[0] * 0.7)
    training_configuration = TrainConfig(T, train_rows, batch_size, nn.MSELoss())
    logger.info(f"Training size: {training_configuration.train_size:d}.")

    encoder_kwargs = dict(
        input_size=train_data.feats.shape[1],
        hidden_size=encoder_hidden_size,
        T=T,
    )
    encoder = Encoder(**encoder_kwargs).to(device)
    _persist("enc_kwargs.json", encoder_kwargs)

    decoder_kwargs = dict(
        encoder_hidden_size=encoder_hidden_size,
        decoder_hidden_size=decoder_hidden_size,
        T=T,
        out_feats=n_targs,
    )
    decoder = Decoder(**decoder_kwargs).to(device)
    _persist("dec_kwargs.json", decoder_kwargs)

    encoder_optimizer = _adam_for(encoder)
    decoder_optimizer = _adam_for(decoder)

    da_rnn_net = DaRnnNet(encoder, decoder, encoder_optimizer, decoder_optimizer)
    return training_configuration, da_rnn_net
Ejemplo n.º 13
0
class DA_RNN:
    """Encoder/decoder regressor trained on sliding windows of length ``T``.

    A window ``X[i-T+1 : i+1]`` is paired with the target ``Y[i]``, so a series
    of ``N`` rows yields ``N - T + 1`` samples.
    """

    def __init__(self, X_dim, Y_dim, encoder_hidden_size=64, decoder_hidden_size=64,
                 linear_dropout=0, T=10, learning_rate=1e-5, batch_size=128, decay_rate=0.95):
        """Build both networks, their Adam optimizers and the MSE loss.

        Args:
            X_dim: Number of input features per time step.
            Y_dim: Number of target values per sample.
            encoder_hidden_size: Hidden size handed to the encoder.
            decoder_hidden_size: Hidden size handed to the decoder.
            linear_dropout: Dropout value handed to both networks.
            T: Window length.
            learning_rate: Initial learning rate of both optimizers.
            batch_size: Maximum number of windows per training batch.
            decay_rate: Factor applied by ``adjust_learning_rate``.
        """
        self.T = T
        self.decay_rate = decay_rate
        self.batch_size = batch_size
        self.X_dim = X_dim
        self.Y_dim = Y_dim

        self.encoder = Encoder(X_dim, encoder_hidden_size, T, linear_dropout).to(device)
        self.decoder = Decoder(encoder_hidden_size, decoder_hidden_size, T, linear_dropout, Y_dim).to(device)

        self.encoder_optim = torch.optim.Adam(params=self.encoder.parameters(), lr=learning_rate)
        self.decoder_optim = torch.optim.Adam(params=self.decoder.parameters(), lr=learning_rate)
        self.loss_func = torch.nn.MSELoss()

    def adjust_learning_rate(self):
        """Multiply the learning rate of every parameter group by ``decay_rate``."""
        for enc_params, dec_params in zip(self.encoder_optim.param_groups, self.decoder_optim.param_groups):
            enc_params['lr'] = enc_params['lr'] * self.decay_rate
            dec_params['lr'] = dec_params['lr'] * self.decay_rate

    def ToTrainingBatches(self, X, Y, shuffle_slice=True):
        """Cut ``X``/``Y`` into batches of sliding windows.

        Args:
            X: Array of shape ``(N, X_dim)``.
            Y: Array of shape ``(N, Y_dim)``.
            shuffle_slice: When true, the batch lists are passed through
                ``shuffle`` before being returned.

        Returns:
            ``(X_batches, Y_batches)``: lists of arrays shaped
            ``(b, T, X_dim)`` and ``(b, Y_dim)``; only the last batch may be
            smaller than ``batch_size``.
        """
        X_batches = []
        Y_batches = []

        N = X.shape[0]
        # There are N - T + 1 windows; counting N - T dropped the final sample
        # whenever N - T was a multiple of batch_size.
        batch_num = math.ceil((N - self.T + 1) / self.batch_size)
        i = self.T - 1  # index of the last row of the current window

        for _ in range(batch_num):
            _batch_size = min(self.batch_size, N - i)
            X_batch = np.empty((_batch_size, self.T, self.X_dim))
            Y_batch = np.empty((_batch_size, self.Y_dim))

            for b_idx in range(_batch_size):
                X_batch[b_idx, :, :] = X[i - self.T + 1:i + 1]
                Y_batch[b_idx, :] = Y[i]
                i += 1

            X_batches.append(X_batch)
            Y_batches.append(Y_batch)

        # TODO: zero padding
        if shuffle_slice:
            return shuffle(X_batches, Y_batches)
        return X_batches, Y_batches

    def ToTestingBatch(self, X):
        """Return every sliding window of ``X`` as one ``(N-T+1, T, X_dim)`` array."""
        N = X.shape[0]
        X_batch = np.empty((N - self.T + 1, self.T, self.X_dim))

        for b_idx, i in enumerate(range(self.T - 1, N)):
            X_batch[b_idx, :, :] = X[i - self.T + 1:i + 1]

        # TODO: zero padding
        return X_batch

    def train(self, X_train, Y_train, X_val, Y_val, epochs):
        """Train for ``epochs`` epochs, reporting validation loss every second epoch.

        Args:
            X_train: Training inputs, shape ``(N, X_dim)``.
            Y_train: Training targets, shape ``(N,)`` or ``(N, Y_dim)``.
            X_val: Validation inputs.
            Y_val: Validation targets, same layout as ``Y_train``.
            epochs: Number of passes over the training batches.

        Returns:
            ``(epoch_loss_hist, iter_loss_hist)``: the per-batch losses grouped
            by epoch, and the same losses as one flat list.
        """
        if len(Y_train.shape) == 1:
            Y_train = Y_train[:, np.newaxis]
        if len(Y_val.shape) == 1:
            Y_val = Y_val[:, np.newaxis]

        assert len(X_train) == len(Y_train)
        assert len(X_val) == len(Y_val)

        epoch_loss_hist = []
        iter_loss_hist = []
        X_train_loss = float("nan")  # last batch loss; stays NaN if there are no batches

        for _e in range(epochs):
            X_train_batches, Y_train_batches = self.ToTrainingBatches(X_train, Y_train)
            epoch_losses = []
            for X_train_batch, Y_train_batch in zip(X_train_batches, Y_train_batches):
                X_train_loss = self.train_iter(X_train_batch, Y_train_batch)
                epoch_losses.append(X_train_loss)

            # decay learning rate
            # if _e % 20 == 0:
            #     self.adjust_learning_rate()

            iter_loss_hist.extend(epoch_losses)
            epoch_loss_hist.append(epoch_losses)

            if _e % 2 == 0:
                print("Epoch: {}\t".format(_e), end="")
                with torch.no_grad():
                    Y_val_pred = self.predict(X_val, on_train=True)
                    # Predictions exist for rows T-1 .. len(X_val)-1, so the
                    # targets must be aligned on the validation set itself,
                    # not on the training set size.
                    Y_val_loss = self.loss_func(Y_val_pred, toTorch(Y_val[self.T - 1:]))
                print("train_loss: {:.4f} val_loss: {:.4f}".format(X_train_loss, Y_val_loss.item()))

        return epoch_loss_hist, iter_loss_hist

    def train_iter(self, X, Y):
        """Run one optimisation step on a batch and return its loss as a float."""
        self.encoder.train()
        self.decoder.train()
        self.encoder_optim.zero_grad()
        self.decoder_optim.zero_grad()

        _, X_encoded = self.encoder(toTorch(X))
        Y_pred = self.decoder(X_encoded)

        loss = self.loss_func(Y_pred, toTorch(Y))
        loss.backward()

        self.encoder_optim.step()
        self.decoder_optim.step()

        return loss.item()

    def predict(self, X, on_train=False):
        """Predict one output per sliding window of ``X``.

        Args:
            X: Inputs of shape ``(N, X_dim)``.
            on_train: When true the raw tensor is returned; otherwise the
                result is detached and converted to a NumPy array.
        """
        self.encoder.eval()
        self.decoder.eval()

        X_batch = self.ToTestingBatch(X)

        _, X_encoded = self.encoder(toTorch(X_batch))
        Y_pred = self.decoder(X_encoded)

        if not on_train:
            Y_pred = Y_pred.cpu().detach().numpy()

        return Y_pred
Ejemplo n.º 14
0
class Trainer(object):
    def __init__(self, celeba_loader, config):
        # miscellaneous
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # data loader
        self.dataload = celeba_loader

        # model configurations
        self.c64 = config.c64
        self.c256 = config.c256
        self.c2048 = config.c2048
        self.rb6 = config.rb6
        self.attr_dim = config.attr_dim
        self.hair_dim = config.hair_dim

        # training configurations
        self.selected_attrs = config.selected_attrs
        self.train_iters = config.train_iters
        self.num_iters_decay = config.num_iters_decay
        self.n_critic = config.n_critic
        self.d_lr = config.d_lr
        self.r_lr = config.r_lr
        self.t_lr = config.t_lr
        self.e_lr = config.e_lr
        self.decay_rate = config.decay_rate
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.lambda_cls = config.lambda_cls
        self.lambda_cyc = config.lambda_cyc
        self.lambda_gp = config.lambda_gp

        # test configurations
        self.test_iters = config.test_iters

        # directories
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir
        self.log_dir = config.log_dir

        # step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # initial models
        self.build_models()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_models(self):
        """Create every network with its Adam optimizer, print them and move them to the device.

        One encoder (``E``), three attribute transformers (hair, gender,
        smiling), one reconstructor (``R``) and three discriminators are built.
        """
        def _adam(module, lr):
            # Every optimizer uses the same pair of betas.
            return torch.optim.Adam(module.parameters(), lr, [self.beta1, self.beta2])

        self.E = Encoder(self.c64, self.rb6)
        self.T_Hair = Transformer(self.hair_dim, self.c256, self.rb6)
        self.T_Gender = Transformer(self.attr_dim, self.c256, self.rb6)
        self.T_Smailing = Transformer(self.attr_dim, self.c256, self.rb6)
        self.R = Reconstructor(self.c256)
        self.D_Hair = Discriminator(self.hair_dim, self.c64)
        self.D_Gender = Discriminator(self.attr_dim, self.c64)
        self.D_Smailing = Discriminator(self.attr_dim, self.c64)

        self.e_optim = _adam(self.E, self.e_lr)
        self.th_optim = _adam(self.T_Hair, self.t_lr)
        self.tg_optim = _adam(self.T_Gender, self.t_lr)
        self.ts_optim = _adam(self.T_Smailing, self.t_lr)
        self.r_optim = _adam(self.R, self.r_lr)
        self.dh_optim = _adam(self.D_Hair, self.d_lr)
        self.dg_optim = _adam(self.D_Gender, self.d_lr)
        self.ds_optim = _adam(self.D_Smailing, self.d_lr)

        described = (
            (self.E, 'Encoder'),
            (self.T_Hair, 'Transformer for Hair Color'),
            (self.T_Gender, 'Transformer for Gender'),
            (self.T_Smailing, 'Transformer for Smailing'),
            (self.R, 'Reconstructor'),
            (self.D_Hair, 'D for Hair Color'),
            (self.D_Gender, 'D for Gender'),
            (self.D_Smailing, 'D for Smailing'),
        )
        for network, label in described:
            self.print_network(network, label)

        for network in (self.E, self.T_Hair, self.T_Gender, self.T_Smailing,
                        self.R, self.D_Gender, self.D_Smailing, self.D_Hair):
            network.to(self.device)

    def print_network(self, model, name):
        """Print ``name``, the total parameter count of ``model`` and the model itself."""
        total_params = sum(param.numel() for param in model.parameters())

        print(name)
        print("The number of parameters: {}".format(total_params))
        print(model)
        
    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def reset_grad(self):
        self.e_optim.zero_grad()
        self.th_optim.zero_grad()
        self.tg_optim.zero_grad()
        self.ts_optim.zero_grad()
        self.r_optim.zero_grad()
        self.dh_optim.zero_grad()
        self.dg_optim.zero_grad()
        self.ds_optim.zero_grad()

    def update_lr(self, e_lr, d_lr, r_lr, t_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.e_optim.param_groups:
            param_group['lr'] = e_lr
        for param_group in self.dh_optim.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.dg_optim.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.ds_optim.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.r_optim.param_groups:
            param_group['lr'] = r_lr
        for param_group in self.th_optim.param_groups:
            param_group['lr'] = t_lr
        for param_group in self.tg_optim.param_groups:
            param_group['lr'] = t_lr
        for param_group in self.ts_optim.param_groups:
            param_group['lr'] = t_lr

    def create_labels(self, c_org, c_dim=5, selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        # Get hair color indices.
        hair_color_indices = []
        for i, attr_name in enumerate(selected_attrs):
            if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair']:
                hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            c_trg = c_org.clone()
            if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                c_trg[:, i] = 1
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = 0
            else:
                c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def denorm(self, x):
        """Map values from [-1, 1] to [0, 1], clamping anything that falls outside."""
        shifted = x + 1
        return (shifted / 2).clamp_(0, 1)

    def train(self):
        data_loader = self.dataload

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, 5, self.selected_attrs)

        d_lr = self.d_lr
        r_lr = self.r_lr
        t_lr = self.t_lr
        e_lr = self.e_lr

        # Start training
        print('Starting point==============================')
        start_time = time.time()

        for i in range(0, self.train_iters):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels
            try:
                x_real, label_real = next(data_iter)
            except:
                data_iter = iter(data_loader)
                x_real, label_real = next(data_iter)

            rand_idx = torch.randperm(label_real.size(0))
            label_feak = label_real[rand_idx]

            x_real = x_real.to(self.device)
            # labels for hair color
            label_h_real = label_real[:, 0:3]
            label_h_feak = label_feak[:, 0:3]
            # labels for gender
            label_g_real = label_real[:, 3:4]
            label_g_feak = label_feak[:, 3:4]
            # labels for smailing
            label_s_real = label_real[:, 4:]
            label_s_feak = label_feak[:, 4:]

            label_h_real = label_h_real.to(self.device)
            label_h_feak = label_h_feak.to(self.device)
            label_g_real = label_g_real.to(self.device)
            label_g_feak = label_g_feak.to(self.device)
            label_s_real = label_s_real.to(self.device)
            label_s_feak = label_s_feak.to(self.device)

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Computer loss with real images
            h_src, h_cls = self.D_Hair(x_real)
            d_h_loss_real = -torch.mean(h_src)
            d_h_loss_cls = F.binary_cross_entropy_with_logits(h_cls, label_h_real, reduction='sum') / h_cls.size(0)

            g_src, g_cls = self.D_Gender(x_real)
            d_g_loss_real = -torch.mean(g_src)
            d_g_loss_cls = F.binary_cross_entropy_with_logits(g_cls, label_g_real, reduction='sum') / g_cls.size(0)

            s_src, s_cls = self.D_Smailing(x_real)
            d_s_loss_real = -torch.mean(s_src)
            d_s_loss_cls = F.binary_cross_entropy_with_logits(s_cls, label_s_real, reduction='sum') / s_cls.size(0)

            # Generate fake images and computer loss
            # Retrieve features of real image
            features = self.E(x_real)
            # Transform attributes from one value to an other
            t_h_features = self.T_Hair(features.detach(), label_h_feak)
            t_g_features = self.T_Gender(features.detach(), label_g_feak)
            t_s_features = self.T_Smailing(features.detach(), label_s_feak)
            # Reconstruct images from transformed attributes
            x_h_feak = self.R(t_h_features.detach())
            x_g_feak = self.R(t_g_features.detach())
            x_s_feak = self.R(t_s_features.detach())

            # Computer loss with fake images
            h_src, h_cls = self.D_Hair(x_h_feak.detach())
            d_h_loss_fake = torch.mean(h_src)

            g_src, g_cls = self.D_Gender(x_g_feak.detach())
            d_g_loss_fake = torch.mean(g_src)

            s_src, s_cls = self.D_Smailing(x_s_feak.detach())
            d_s_loss_fake = torch.mean(s_src)

            # Compute loss for gradient penalty
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_h_hat = (alpha * x_real.data + (1 - alpha) * x_h_feak.data).requires_grad_(True)
            #x_h_hat = (alpha * x_real.data + (1-alpha) * x_h_feak.data).requires_grad_(True).to(torch.float16)
            x_g_hat = (alpha * x_real.data + (1 - alpha) * x_g_feak.data).requires_grad_(True)
            #x_g_hat = (alpha * x_real.data + (1-alpha) * x_g_feak.data).requires_grad_(True).to(torch.float16)
            x_s_hat = (alpha * x_real.data + (1 - alpha) * x_s_feak.data).requires_grad_(True)
            #x_s_hat = (alpha * x_real.data + (1-alpha) * x_s_feak.data).requires_grad_(True).to(torch.float16)

            out_src, _ = self.D_Hair(x_h_hat)
            d_h_loss_gp = self.gradient_penalty(out_src, x_h_hat)
            out_src, _ = self.D_Gender(x_g_hat)
            d_g_loss_gp = self.gradient_penalty(out_src, x_g_hat)
            out_src, _ = self.D_Smailing(x_s_hat)
            d_s_loss_gp = self.gradient_penalty(out_src, x_s_hat)

            # Backward and optimize
            d_loss = d_h_loss_real + d_g_loss_real + d_s_loss_real + \
                     d_h_loss_fake + d_g_loss_fake + d_s_loss_fake + \
                     self.lambda_gp * (d_h_loss_gp + d_g_loss_gp + d_s_loss_gp) + \
                     self.lambda_cls * (d_h_loss_cls + d_g_loss_cls + d_s_loss_cls)
            #d_loss = d_h_loss_real + d_h_loss_fake + self.lambda_gp * d_h_loss_gp + self.lambda_cls * d_h_loss_cls


            self.reset_grad()
            d_loss.backward()
            self.dh_optim.step()
            self.dg_optim.step()
            self.ds_optim.step()

            # Logging
            loss = {}
            loss['D/h_loss_real'] = d_h_loss_real.item()
            loss['D/g_loss_real'] = d_g_loss_real.item()
            loss['D/s_loss_real'] = d_s_loss_real.item()
            loss['D/h_loss_fake'] = d_h_loss_fake.item()
            loss['D/g_loss_fake'] = d_g_loss_fake.item()
            loss['D/s_loss_fake'] = d_s_loss_fake.item()
            loss['D/h_loss_cls'] = d_h_loss_cls.item()
            loss['D/g_loss_cls'] = d_g_loss_cls.item()
            loss['D/s_loss_cls'] = d_s_loss_cls.item()
            loss['D/h_loss_gp'] = d_h_loss_gp.item()
            loss['D/g_loss_gp'] = d_g_loss_gp.item()
            loss['D/s_loss_gp'] = d_s_loss_gp.item()

            # =================================================================================== #
            #                  3. Train the encoder, transformer and reconstructor                #
            # =================================================================================== #

            if(i+1) % self.n_critic == 0:
                # Generate fake images and compute loss
                # Retrieve features of real image
                features = self.E(x_real)
                # Transform attributes from one value to an other
                t_h_features = self.T_Hair(features, label_h_feak)
                t_g_features = self.T_Gender(features, label_g_feak)
                t_s_features = self.T_Smailing(features, label_s_feak)
                # Reconstruct images from transformed attributes
                x_h_feak = self.R(t_h_features)
                x_g_feak = self.R(t_g_features)
                x_s_feak = self.R(t_s_features)

                # Computer loss with fake images
                h_src, h_cls = self.D_Hair(x_h_feak)
                etr_h_loss_fake = -torch.mean(h_src)
                etr_h_loss_cls = F.binary_cross_entropy_with_logits(h_cls, label_h_feak, reduction='sum') / h_cls.size(0)

                g_src, g_cls = self.D_Gender(x_g_feak)
                etr_g_loss_fake = -torch.mean(g_src)
                etr_g_loss_cls = F.binary_cross_entropy_with_logits(g_cls, label_g_feak, reduction='sum') / g_cls.size(0)

                s_src, s_cls = self.D_Smailing(x_s_feak)
                etr_s_loss_fake = -torch.mean(s_src)
                etr_s_loss_cls = F.binary_cross_entropy_with_logits(s_cls, label_s_feak, reduction='sum') / s_cls.size(0)

                # Real - Encoder - Reconstructor - Real loss
                x_re = self.R(features)
                er_loss_cyc = torch.mean(torch.abs(x_re - x_real))

                # Real - Encoder - Transform, Real - Encoder - Transform - Reconstructor - Encoder loss
                h_fake_features = self.E(x_h_feak)
                g_fake_features = self.E(x_g_feak)
                s_fake_features = self.E(x_s_feak)

                etr_h_loss_cyc = torch.mean(torch.abs(t_h_features - h_fake_features))
                etr_g_loss_cyc = torch.mean(torch.abs(t_g_features - g_fake_features))
                etr_s_loss_cyc = torch.mean(torch.abs(t_s_features - s_fake_features))

                # Backward and optimize
                etr_loss = etr_h_loss_fake + etr_g_loss_fake + etr_s_loss_fake + \
                           self.lambda_cls * (etr_h_loss_cls + etr_g_loss_cls + etr_s_loss_cls) + \
                           self.lambda_cyc * (er_loss_cyc + etr_h_loss_cyc + etr_g_loss_cyc + etr_s_loss_cyc)
                #etr_loss = etr_h_loss_fake + self.lambda_cls * etr_h_loss_cls + self.lambda_cyc * (er_loss_cyc + etr_h_loss_cyc)



                self.reset_grad()
                etr_loss.backward()
                self.e_optim.step()
                self.th_optim.step()
                self.tg_optim.step()
                self.ts_optim.step()
                self.r_optim.step()

                # Logging.
                loss['ETR/h_loss_fake'] = etr_h_loss_fake.item()
                loss['ETR/g_loss_fake'] = etr_g_loss_fake.item()
                loss['ETR/s_loss_fake'] = etr_s_loss_fake.item()
                loss['ETR/h_loss_cls'] = etr_h_loss_cls.item()
                loss['ETR/g_loss_cls'] = etr_g_loss_cls.item()
                loss['ETR/s_loss_cls'] = etr_s_loss_cls.item()
                loss['ER/er_loss_cyc'] = er_loss_cyc.item()
                loss['ETR/h_loss_cyc'] = etr_h_loss_cyc.item()
                loss['ETR/g_loss_cyc'] = etr_g_loss_cyc.item()
                loss['ETR/s_loss_cyc'] = etr_s_loss_cyc.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Translate fixed images for debugging.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        xf = self.E(x_fixed)
                        xth = self.T_Hair(xf, c_fixed[:, 0:3])
                        xtg = self.T_Gender(xth, c_fixed[:, 3:4])
                        xts = self.T_Smailing(xtg, c_fixed[:, 4:5])
                        x_fake_list.append(self.R(xts))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i + 1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.train_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # save model checkpoints
            if (i+1) % self.model_save_step == 0:
                E_path = os.path.join(self.model_save_dir, '{}-E.ckpt'.format(i+1))
                D_h_path = os.path.join(self.model_save_dir, '{}-D_h.ckpt'.format(i+1))
                D_g_path = os.path.join(self.model_save_dir, '{}-D_g.ckpt'.format(i+1))
                D_s_path = os.path.join(self.model_save_dir, '{}-D_s.ckpt'.format(i+1))
                R_path = os.path.join(self.model_save_dir, '{}-R.ckpt'.format(i+1))
                T_h_path = os.path.join(self.model_save_dir, '{}-T_h.ckpt'.format(i+1))
                T_g_path = os.path.join(self.model_save_dir, '{}-T_g.ckpt'.format(i+1))
                T_s_path = os.path.join(self.model_save_dir, '{}-T_s.ckpt'.format(i+1))
                torch.save(self.E.state_dict(), E_path)
                torch.save(self.D_Hair.state_dict(), D_h_path)
                torch.save(self.D_Gender.state_dict(), D_g_path)
                torch.save(self.D_Smailing.state_dict(), D_s_path)
                torch.save(self.R.state_dict(), R_path)
                torch.save(self.T_Hair.state_dict(), T_h_path)
                torch.save(self.T_Gender.state_dict(), T_g_path)
                torch.save(self.T_Smailing.state_dict(), T_s_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # decay learning rates
            if (i+1) % self.lr_update_step == 0 and (i+1) > self.num_iters_decay:
                e_lr -= (self.e_lr / float(self.decay_rate))
                d_lr -= (self.d_lr / float(self.decay_rate))
                r_lr -= (self.r_lr / float(self.decay_rate))
                t_lr -= (self.t_lr / float(self.decay_rate))
                self.update_lr(e_lr, d_lr, r_lr, t_lr)
                print ('Decayed learning rates, e_lr: {}, d_lr: {}, r_lr: {}, t_lr: {}.'.format(e_lr, d_lr, r_lr, t_lr))
Ejemplo n.º 15
0
class RNN(object):
    """Encoder/decoder sequence model trained with one Adam optimizer per sub-network.

    ``self.sos`` and ``self.eos`` are 1x1 ``LongTensor`` markers holding the
    indices 0 and 1; they frame every target sequence during training and
    ``eos`` terminates decoding in :meth:`eval`.
    """

    def __init__(self, input_size, output_size):
        """Create the encoder, decoder, loss, optimizers and SOS/EOS markers.

        Args:
            input_size: forwarded to ``Encoder``.
            output_size: forwarded to ``Decoder``.
        """
        super(RNN, self).__init__()

        self.encoder = Encoder(input_size)
        self.decoder = Decoder(output_size)

        self.loss = nn.CrossEntropyLoss()
        self.encoder_optimizer = optim.Adam(self.encoder.parameters())
        self.decoder_optimizer = optim.Adam(self.decoder.parameters())

        sos, eos = torch.LongTensor(1, 1).zero_(), torch.LongTensor(1, 1).zero_()
        sos[0, 0], eos[0, 0] = 0, 1

        self.sos, self.eos = sos, eos

    def train(self, input, target):
        """Run one optimisation step on a single (input, target) pair.

        Args:
            input: iterable of tensors fed to the encoder one at a time.
            target: sequence of 1x1 index tensors. It is NOT modified; the
                SOS/EOS markers are added to a private copy.

        Returns:
            Tuple ``(loss, outputs)`` where ``loss`` is the mean per-step loss
            as a Python number and ``outputs`` is the list of per-step argmax
            predictions (NumPy arrays with a trailing axis of size 1).
        """
        # Frame a copy so repeated calls with the same list do not keep
        # accumulating SOS/EOS markers in the caller's data.
        sequence = [self.sos] + list(target) + [self.eos]

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        # Encoder
        hidden_state = self.encoder.first_hidden()
        for ivec in input:
            _, hidden_state = self.encoder.forward(Variable(ivec), hidden_state)

        # Decoder: the decoder is fed the reference token at every step and is
        # scored against the following reference token.
        total_loss, outputs = 0, []
        for i in range(len(sequence) - 1):
            _, softmax, hidden_state = self.decoder.forward(Variable(sequence[i]), hidden_state)

            outputs.append(np.argmax(softmax.data.numpy(), 1)[:, np.newaxis])
            total_loss += self.loss(softmax, Variable(sequence[i + 1][0]))

        total_loss /= len(outputs)
        total_loss.backward()

        self.decoder_optimizer.step()
        self.encoder_optimizer.step()

        # ``.item()`` exists from PyTorch 0.4.0 on (where ``.data[0]`` fails on
        # a 0-dim tensor); older releases only support ``.data[0]``.
        if hasattr(total_loss, "item"):
            loss_value = total_loss.item()
        else:
            loss_value = total_loss.data[0]
        return loss_value, outputs

    def eval(self, input, max_length=None):
        """Greedily decode a sequence for ``input``.

        Args:
            input: iterable of tensors fed to the encoder one at a time.
            max_length: optional cap on the number of generated tokens. The
                default ``None`` keeps decoding until the EOS index is
                produced, which never returns if the decoder does not emit it.

        Returns:
            List of 1x1 NumPy index arrays, ending with the EOS index when
            decoding stopped because EOS was produced.
        """
        hidden_state = self.encoder.first_hidden()

        # Encoder
        for ivec in input:
            _, hidden_state = self.encoder.forward(Variable(ivec), hidden_state)

        eos_id = int(self.eos[0, 0])
        sentence = []
        token = Variable(self.sos)
        # Decoder
        while int(token.data[0, 0]) != eos_id:
            if max_length is not None and len(sentence) >= max_length:
                break
            output, _, hidden_state = self.decoder.forward(token, hidden_state)
            word = np.argmax(output.data.numpy()).reshape((1, 1))
            token = Variable(torch.LongTensor(word))
            sentence.append(word)

        return sentence

    def save(self):
        """Write encoder and decoder weights to ``models/encoder.ckpt`` and ``models/decoder.ckpt``."""
        torch.save(self.encoder.state_dict(), "models/encoder.ckpt")
        torch.save(self.decoder.state_dict(), "models/decoder.ckpt")
Ejemplo n.º 16
0
class PoemWAE(nn.Module):
    def __init__(self, config, api, PAD_token=0, pretrain_weight=None):
        """Build every sub-network, optimizer, scheduler and loss of the model.

        Args:
            config: object exposing the hyper-parameters read below
                (``maxlen``, ``clip``, ``n_hidden``, ``z_size``, ...).
            api: data API exposing ``vocab`` and ``rev_vocab``.
            PAD_token: index used as ``padding_idx`` of the embedding.
            pretrain_weight: optional NumPy array copied into the embedding
                weights.
        """
        super(PoemWAE, self).__init__()

        # Vocabulary bookkeeping.
        self.vocab = api.vocab
        self.vocab_size = len(self.vocab)
        self.rev_vocab = api.rev_vocab
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]

        # Scalar hyper-parameters kept on the instance.
        self.maxlen = config.maxlen
        self.clip = config.clip
        self.lambda_gp = config.lambda_gp
        self.lr_gan_g = config.lr_gan_g
        self.lr_gan_d = config.lr_gan_d
        self.n_d_loss = config.n_d_loss
        self.temp = config.temp
        self.init_w = config.init_weight

        emb_size = config.emb_size
        hidden = config.n_hidden
        latent = config.z_size
        # Size of the conditioning vector c (title code and context code
        # concatenated, each noted as 2*hidden by the original author).
        context_size = hidden * 4

        self.embedder = nn.Embedding(self.vocab_size, emb_size, padding_idx=PAD_token)
        if pretrain_weight is not None:
            pretrained = torch.from_numpy(pretrain_weight)
            self.embedder.weight.data.copy_(pretrained)

        # One shared sequence encoder handles the title, the context line and
        # the target line.
        self.seq_encoder = Encoder(self.embedder, emb_size, hidden, True,
                                   config.n_layers, config.noise_radius)

        # Prior network p(e|c): consumes the 4*hidden context only.
        # Original note: a subclass (Gaussian-mixture prior variant) may
        # supply its own prior_net.
        self.prior_net = Variation(context_size,
                                   latent,
                                   dropout_rate=config.dropout,
                                   init_weight=self.init_w)

        # Posterior network: consumes x (2*hidden) concatenated with c
        # (4*hidden), hence 6*hidden.
        self.post_net = Variation(hidden * 6,
                                  latent,
                                  dropout_rate=config.dropout,
                                  init_weight=self.init_w)

        def build_latent_mlp():
            """Three Linear layers with BatchNorm+ReLU in between, re-initialised by init_weights."""
            mlp = nn.Sequential(
                nn.Linear(latent, latent),
                nn.BatchNorm1d(latent, eps=1e-05, momentum=0.1),
                nn.ReLU(),
                nn.Linear(latent, latent),
                nn.BatchNorm1d(latent, eps=1e-05, momentum=0.1),
                nn.ReLU(),
                nn.Linear(latent, latent),
            )
            mlp.apply(self.init_weights)
            return mlp

        self.post_generator = build_latent_mlp()
        self.prior_generator = build_latent_mlp()

        # Maps cat(z, c) to the decoder's initial hidden state.
        self.init_decoder_hidden = nn.Sequential(
            nn.Linear(context_size + latent, context_size),
            nn.BatchNorm1d(context_size, eps=1e-05, momentum=0.1),
            nn.ReLU(),
        )

        self.decoder = Decoder(self.embedder,
                               emb_size,
                               context_size,
                               self.vocab_size,
                               n_layers=1)

        # Critic over cat(z, c): input width is 4*hidden + z_size.
        critic_width = hidden * 2
        self.discriminator = nn.Sequential(
            nn.Linear(context_size + latent, critic_width),
            nn.BatchNorm1d(critic_width, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(critic_width, critic_width),
            nn.BatchNorm1d(critic_width, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(critic_width, 1),
        )
        self.discriminator.apply(self.init_weights)

        # Three optimizers, one per training phase; note that they use
        # different algorithms and parameter sets.
        ae_modules = (self.seq_encoder, self.post_net, self.post_generator,
                      self.init_decoder_hidden, self.decoder)
        ae_params = [p for module in ae_modules for p in module.parameters()]
        self.optimizer_AE = optim.SGD(ae_params, lr=config.lr_ae)

        gen_modules = (self.post_net, self.post_generator,
                       self.prior_net, self.prior_generator)
        gen_params = [p for module in gen_modules for p in module.parameters()]
        self.optimizer_G = optim.RMSprop(gen_params, lr=self.lr_gan_g)

        self.optimizer_D = optim.RMSprop(self.discriminator.parameters(),
                                         lr=self.lr_gan_d)

        self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE,
                                                         step_size=10,
                                                         gamma=0.8)

        self.criterion_ce = nn.CrossEntropyLoss()

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-self.init_w, self.init_w)
            # nn.init.kaiming_normal_(m.weight.data)
            # nn.init.kaiming_uniform_(m.weight.data)
            m.bias.data.fill_(0)

    # x: (batch, 2*n_hidden)
    # c: (batch, 2*2*n_hidden)
    def sample_code_post(self, x, c):
        z, _, _ = self.post_net(torch.cat((x, c),
                                          1))  # 输入:(batch, 3*2*n_hidden)
        z = self.post_generator(z)
        return z

    def sample_code_prior_sentiment(self, c, align):
        choice_statistic = self.prior_net(c, align)  # e: (batch, z_size)
        return choice_statistic

    def sample_code_prior(self, c):
        z, _, _ = self.prior_net(c)  # e: (batch, z_size)
        z = self.prior_generator(z)  # z: (batch, z_size)
        return z

    # Inputs: title, context, target, target_lens.
    # c is the concatenation of the encoder hidden states of title and context.
    def train_AE(self, title, context, target, target_lens):
        """Run one reconstruction (autoencoder) update and report its loss.

        The target line is encoded, a posterior code is sampled from it and
        the context, and the decoder is trained with cross-entropy to
        reproduce the target shifted by one position. Positions whose
        reference token is 0 (padding) are excluded from the loss.

        Returns:
            ``[('train_loss_AE', loss_value)]``.
        """
        self.seq_encoder.train()
        self.decoder.train()

        # Encode title, context and the target without its first token.
        title_hidden, _ = self.seq_encoder(title)
        context_hidden, _ = self.seq_encoder(context)
        target_code, _ = self.seq_encoder(target[:, 1:], target_lens - 1)

        # Conditioning vector and posterior latent code.
        cond = torch.cat((title_hidden, context_hidden), 1)
        latent = self.sample_code_post(target_code, cond)

        # Teacher-forced decoding: the decoder sees target[:, :-1] and is
        # scored against target[:, 1:].
        decoder_init = self.init_decoder_hidden(torch.cat((latent, cond), 1))
        logits = self.decoder(decoder_init, None, target[:, :-1], target_lens - 1)
        logits = logits.view(-1, self.vocab_size)

        gold = target[:, 1:].contiguous().view(-1)
        keep = gold.gt(0)  # True where the reference token is not padding
        gold_kept = gold.masked_select(keep)
        keep_rows = keep.unsqueeze(1).expand(keep.size(0), self.vocab_size)
        logits_kept = logits.masked_select(keep_rows).view(-1, self.vocab_size)

        self.optimizer_AE.zero_grad()
        loss = self.criterion_ce(logits_kept / self.temp, gold_kept)
        loss.backward()

        # Only encoder and decoder gradients are clipped.
        clipped = list(self.seq_encoder.parameters()) + list(self.decoder.parameters())
        torch.nn.utils.clip_grad_norm_(clipped, self.clip)
        self.optimizer_AE.step()

        return [('train_loss_AE', loss.item())]

    # G is trained to shrink the Wasserstein distance, analogous to minimising the KL-divergence term in a VAE.
    def train_G(self,
                title,
                context,
                target,
                target_lens,
                sentiment_mask=None,
                mask_type=None):
        """Run one generator update against the frozen critic.

        The critic scores of posterior and prior codes are back-propagated
        with the module-level ``minus_one`` and ``one`` gradients respectively
        (presumably tensors holding -1 and +1; they are defined outside this
        class). ``sentiment_mask`` and ``mask_type`` are accepted but unused.

        Returns:
            ``[('train_loss_G', errG_prior - errG_post)]`` as a Python number.
        """
        self.seq_encoder.eval()
        self.optimizer_G.zero_grad()

        # Freeze the critic so only the generator side accumulates gradients.
        for critic_param in self.discriminator.parameters():
            critic_param.requires_grad = False

        title_hidden, _ = self.seq_encoder(title)
        context_hidden, _ = self.seq_encoder(context)
        c = torch.cat((title_hidden, context_hidden), 1)

        # ---- posterior samples ----
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        # Detached inputs keep gradients from flowing back into the encoder.
        z_post = self.sample_code_post(x.detach(), c.detach())
        post_critic_input = torch.cat((z_post, c.detach()), 1)
        errG_post = torch.mean(self.discriminator(post_critic_input)) * self.n_d_loss
        errG_post.backward(minus_one)

        # ---- prior samples ----
        prior_z = self.sample_code_prior(c.detach())
        prior_critic_input = torch.cat((prior_z, c.detach()), 1)
        errG_prior = torch.mean(self.discriminator(prior_critic_input)) * self.n_d_loss
        errG_prior.backward(one)
        self.optimizer_G.step()

        # Unfreeze the critic for its own training phase.
        for critic_param in self.discriminator.parameters():
            critic_param.requires_grad = True

        generator_cost = errG_prior - errG_post
        return [('train_loss_G', generator_cost.item())]

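    # D is used to estimate the Wasserstein distance; a decreasing loss means a better fit, and increasing gradient_penalty can improve the fit to some extent.
    # The larger n_iters_n is, the more often D is trained and the better the resulting fit.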
    def train_D(self, title, context, target, target_lens):
        """Run one critic (discriminator) update with a gradient penalty.

        Three backward passes accumulate into the critic before a single
        optimizer step: the score of posterior codes, the score of prior
        codes, and a penalty on the critic's gradient norm at points
        interpolated between the two.

        NOTE(review): ``one``, ``minus_one`` and ``to_tensor`` are defined
        outside this class; ``one``/``minus_one`` are presumably tensors
        holding +1/-1 and ``to_tensor`` presumably moves a tensor to the
        training device. Confirm against their definitions.

        Returns:
            ``[('train_loss_D', costD)]`` where
            ``costD = -(errD_prior - errD_post) + gradient_penalty``.
        """
        self.seq_encoder.eval()
        self.discriminator.train()
        self.optimizer_D.zero_grad()

        batch_size = context.size(0)

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        c = torch.cat((title_last_hidden, context_last_hidden),
                      1)  # concatenated along dim 1 -- presumably (batch, 4 * hidden_size)
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        post_z = self.sample_code_post(x, c)
        # Codes and context are detached: only the critic receives gradients.
        errD_post = torch.mean(
            self.discriminator(torch.cat(
                (post_z.detach(), c.detach()), 1))) * self.n_d_loss
        errD_post.backward(one)

        prior_z = self.sample_code_prior(c)
        errD_prior = torch.mean(
            self.discriminator(torch.cat(
                (prior_z.detach(), c.detach()), 1))) * self.n_d_loss
        errD_prior.backward(minus_one)

        # Gradient penalty: one random mixing weight per sample, broadcast
        # across the latent dimension.
        alpha = to_tensor(torch.rand(batch_size, 1))
        alpha = alpha.expand(prior_z.size())
        interpolates = alpha * prior_z.data + ((1 - alpha) * post_z.data)
        # The interpolated codes must require grad so the critic's gradient
        # with respect to them can be taken.
        interpolates = Variable(interpolates, requires_grad=True)

        d_input = torch.cat((interpolates, c.detach()), 1)
        disc_interpolates = torch.mean(self.discriminator(d_input))
        # create_graph=True keeps the gradient differentiable so the penalty
        # itself can be back-propagated into the critic.
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=to_tensor(torch.ones(disc_interpolates.size())),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        # Penalise the squared deviation of each sample's gradient norm from 1.
        gradient_penalty = (
            (gradients.contiguous().view(gradients.size(0), -1).norm(2, dim=1)
             - 1)**2).mean() * self.lambda_gp
        gradient_penalty.backward()

        self.optimizer_D.step()
        costD = -(errD_prior - errD_post) + gradient_penalty
        return [('train_loss_D', costD.item())]

    def valid(self, title, context, target, target_lens, sentiment_mask=None):
        self.seq_encoder.eval()
        self.discriminator.eval()
        self.decoder.eval()

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        c = torch.cat((title_last_hidden, context_last_hidden),
                      1)  # (batch, 2 * hidden_size * 2)
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)

        post_z = self.sample_code_post(x, c)
        prior_z = self.sample_code_prior(c)
        errD_post = torch.mean(self.discriminator(torch.cat((post_z, c), 1)))
        errD_prior = torch.mean(self.discriminator(torch.cat((prior_z, c), 1)))
        costD = -(errD_prior - errD_post)
        costG = -costD

        dec_target = target[:, 1:].contiguous().view(-1)  # (batch_size * len)
        mask = dec_target.gt(0)  # 即判断target的token中是否有0(pad项)
        masked_target = dec_target.masked_select(mask)  # 选出非pad项
        output_mask = mask.unsqueeze(1).expand(mask.size(0), self.vocab_size)

        output = self.decoder(
            self.init_decoder_hidden(torch.cat((post_z, c), 1)), None,
            target[:, :-1], (target_lens - 1))
        flattened_output = output.view(-1, self.vocab_size)
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        lossAE = self.criterion_ce(masked_output / self.temp, masked_target)
        return [('valid_loss_AE', lossAE.item()),
                ('valid_loss_G', costG.item()), ('valid_loss_D', costD.item())]

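    # As described in the paper, at generation time noise is taken from the prior network and G produces prior_z (sample_code_prior(c) in the code).
    # The decoder then takes the concatenation of prior_z and c as input and decodes the poem line (this differs slightly from the paper, which feeds only prior_z).
    # batch_size is 1: one line is generated at a time.

    # title: the poem title
    # context: the previous line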
    def test(self, title_tensor, title_words, headers):
        """Generate four lines for one title and return them as text.

        Each line is decoded greedily from a prior latent code conditioned on
        the title and on the previously generated line (the title itself is
        used as context for the first line). Generation runs under
        ``torch.no_grad()`` since nothing is back-propagated.

        Args:
            title_tensor: encoded title passed to the sequence encoder.
            title_words: indexable whose first element supports ``.tolist()``
                and yields the title's token ids.
            headers: indexable with at least four entries; ``headers[i]`` is
                forwarded to the decoder as ``header`` for line ``i``.

        Returns:
            A string with the title tokens on the first line followed by the
            four generated lines, tokens separated by spaces, every line
            terminated by a newline.
        """
        self.seq_encoder.eval()
        self.discriminator.eval()
        self.decoder.eval()

        # Initial context placeholder: [2, 3, 0, 0, ...] padded to maxlen.
        tem = [[2, 3] + [0] * (self.maxlen - 2)]
        pred_poems = []

        title_tokens = [
            self.vocab[e] for e in title_words[0].tolist()
            if e not in [0, self.eos_id, self.go_id]
        ]
        pred_poems.append(title_tokens)

        with torch.no_grad():
            for sent_id in range(4):
                tem = to_tensor(np.array(tem))
                context = tem

                title_last_hidden, _ = self.seq_encoder(title_tensor)
                if sent_id == 0:
                    # No previous line yet: the title doubles as context.
                    context_last_hidden, _ = self.seq_encoder(title_tensor)
                else:
                    context_last_hidden, _ = self.seq_encoder(context)
                c = torch.cat((title_last_hidden, context_last_hidden), 1)
                # Only one poem is processed at a time, so c is not repeated.
                prior_z = self.sample_code_prior(c)

                # decode_words holds one complete generated line.
                decode_words = self.decoder.testing(
                    init_hidden=self.init_decoder_hidden(
                        torch.cat((prior_z, c), 1)),
                    maxlen=self.maxlen,
                    go_id=self.go_id,
                    mode="greedy",
                    header=headers[sent_id])

                decode_words = decode_words[0].tolist()
                # The generated line becomes the next context: truncated to
                # maxlen, or left-padded with zeros up to maxlen.
                if len(decode_words) > self.maxlen:
                    tem = [decode_words[0:self.maxlen]]
                else:
                    tem = [[0] * (self.maxlen - len(decode_words)) + decode_words]

                pred_tokens = [
                    self.vocab[e] for e in decode_words[:-1]
                    if e != self.eos_id and e != 0
                ]
                pred_poems.append(pred_tokens)

        return "".join(" ".join(line) + '\n' for line in pred_poems)

    def sample(self, title, context, repeat, go_id, end_id):
        self.seq_encoder.eval()
        self.decoder.eval()

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        c = torch.cat((title_last_hidden, context_last_hidden),
                      1)  # (batch, 2 * hidden_size * 2)

        c_repeated = c.expand(
            repeat, -1)  # 注意,我们输入的batch_size是1,这里复制repeat遍,为了后面的BLEU计算

        prior_z = self.sample_code_prior(
            c_repeated)  # c_repeated: (batch_size=repeat, 4*hidden_size)

        # (batch, max_len, 1)  (batch_size, 1)
        sample_words, sample_lens = self.decoder.sampling(
            self.init_decoder_hidden(torch.cat((prior_z, c_repeated), 1)),
            self.maxlen, go_id, end_id, "greedy")
        return sample_words, sample_lens
Ejemplo n.º 17
0
class BiGAN(nn.Module):
    """Generator / encoder / discriminator trio trained adversarially with an MSE criterion.

    The networks, optimizers and schedulers are only created when
    ``config.work_type == 'train'``.
    """

    def __init__(self, config):
        """Copy settings from ``config`` and, in training mode, build the networks.

        Args:
            config: object exposing ``work_type``, ``epochs``, ``batch_size``,
                ``encoder_lr``, ``generator_lr``, ``discriminator_lr``,
                ``latent_dim``, ``weight_decay``, ``input_size``,
                ``image_save_path``, ``model_save_path`` and ``device``.
        """
        super(BiGAN, self).__init__()

        self._work_type = config.work_type
        self._epochs = config.epochs
        self._batch_size = config.batch_size

        self._encoder_lr = config.encoder_lr
        self._generator_lr = config.generator_lr
        self._discriminator_lr = config.discriminator_lr
        self._latent_dim = config.latent_dim
        self._weight_decay = config.weight_decay

        self._img_shape = (config.input_size, config.input_size)
        self._img_save_path = config.image_save_path
        self._model_save_path = config.model_save_path
        self._device = config.device

        if self._work_type == 'train':
            # Loss function
            self._adversarial_criterion = torch.nn.MSELoss()

            # Initialize generator, encoder and discriminator
            self._G = Generator(self._latent_dim, self._img_shape).to(self._device)
            self._E = Encoder(self._latent_dim, self._img_shape).to(self._device)
            self._D = Discriminator(self._latent_dim, self._img_shape).to(self._device)

            self._G.apply(self.weights_init)
            self._E.apply(self.weights_init)
            self._D.apply(self.discriminator_weights_init)

            # Generator and encoder share one optimizer (and the generator
            # learning rate); the discriminator has its own.
            self._G_optimizer = torch.optim.Adam(
                [{'params': self._G.parameters()}, {'params': self._E.parameters()}],
                lr=self._generator_lr, betas=(0.5, 0.999), weight_decay=self._weight_decay)
            self._D_optimizer = torch.optim.Adam(
                self._D.parameters(), lr=self._discriminator_lr, betas=(0.5, 0.999))

            self._G_scheduler = lr_scheduler.ExponentialLR(self._G_optimizer, gamma=0.99)
            self._D_scheduler = lr_scheduler.ExponentialLR(self._D_optimizer, gamma=0.99)

    def _sample_noise(self, n_samples):
        """Draw ``(n_samples, latent_dim)`` standard-normal noise via NumPy on the model device."""
        noise = np.random.normal(0, 1, (n_samples, self._latent_dim))
        return torch.from_numpy(noise).float().to(self._device)

    def train(self, train_loader=True):
        """Train on ``train_loader`` for the configured number of epochs.

        This method shadows ``nn.Module.train(mode)``, which PyTorch itself
        calls with a boolean (e.g. from ``eval()``). A boolean argument is
        therefore forwarded to the base implementation, so ``model.eval()``
        and ``model.train()`` keep working.

        Args:
            train_loader: iterable of ``(images, labels)`` batches supporting
                ``len()``; labels are ignored. A ``bool`` switches the
                train/eval mode instead.

        Returns:
            ``self`` when called with a boolean, otherwise ``None``.
        """
        if isinstance(train_loader, bool):
            return super(BiGAN, self).train(train_loader)

        n_total_steps = len(train_loader)
        for epoch in range(self._epochs):
            for i, (images, _) in enumerate(train_loader):
                # Adversarial ground truths, created directly on the model
                # device so any device spec (string or torch.device) works.
                valid = torch.ones(images.size(0), 1, device=self._device)
                fake = torch.zeros(images.size(0), 1, device=self._device)

                # Flatten every image to a vector.
                images = images.reshape(-1, np.prod(self._img_shape)).to(self._device)

                # ---------------------
                # Train Encoder + Generator
                # ---------------------

                # z_ is the encoded latent vector
                (original_img, z_) = self._E(images)
                predict_encoder = self._D(original_img, z_)

                # Sample noise as generator input
                z = self._sample_noise(images.shape[0])
                (gen_img, z) = self._G(z)
                predict_generator = self._D(gen_img, z)

                # Generator/encoder objective: generated pairs should score
                # as valid and encoded pairs as fake.
                G_loss = (self._adversarial_criterion(predict_generator, valid)
                          + self._adversarial_criterion(predict_encoder, fake)) * 0.5

                self._G_optimizer.zero_grad()
                G_loss.backward()
                self._G_optimizer.step()

                # ---------------------
                # Train Discriminator
                # ---------------------

                z = self._sample_noise(images.shape[0])
                (gen_img, z) = self._G(z)
                (original_img, z_) = self._E(images)
                predict_encoder = self._D(original_img, z_)
                predict_generator = self._D(gen_img, z)

                # Discriminator objective uses the opposite labels.
                D_loss = (self._adversarial_criterion(predict_encoder, valid)
                          + self._adversarial_criterion(predict_generator, fake)) * 0.5

                self._D_optimizer.zero_grad()
                D_loss.backward()
                self._D_optimizer.step()

                if i % 100 == 0:
                    print (f'Epoch [{epoch+1}/{self._epochs}], Step [{i+1}/{n_total_steps}]')
                    print (f'Generator Loss: {G_loss.item():.4f} Discriminator Loss: {D_loss.item():.4f}')

                if i % 400 == 0:
                    vutils.save_image(gen_img.unsqueeze(1).cpu().data[:64, ], f'{self._img_save_path}/E{epoch}_Iteration{i}_fake.png')
                    vutils.save_image(original_img.unsqueeze(1).cpu().data[:64, ], f'{self._img_save_path}/E{epoch}_Iteration{i}_real.png')
                    print('image saved')
                    print('')

            if epoch % 100 == 0:
                torch.save(self._G.state_dict(), f'{self._model_save_path}/netG_{epoch}epoch.pth')
                torch.save(self._E.state_dict(), f'{self._model_save_path}/netE_{epoch}epoch.pth')
                torch.save(self._D.state_dict(), f'{self._model_save_path}/netD_{epoch}epoch.pth')

            # Decay the learning rates once per epoch, after the optimizer
            # steps; stepping before them skips the first value of the
            # schedule on PyTorch >= 1.1.
            self._G_scheduler.step()
            self._D_scheduler.step()

    @staticmethod
    def _init_norm_and_bias(m, norm_std):
        """Shared initialiser: BatchNorm weights ~ N(1, norm_std), BatchNorm/Linear biases = 0.

        The module kind is detected from its class name; layers without a
        weight or bias (``affine=False``, ``bias=False``) are skipped.
        """
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if classname.find('BatchNorm') != -1:
            if weight is not None:
                weight.data.normal_(1.0, norm_std)
            if bias is not None:
                bias.data.fill_(0)
        elif classname.find('Linear') != -1:
            if bias is not None:
                bias.data.fill_(0)

    def weights_init(self, m):
        """``Module.apply`` hook for generator and encoder (BatchNorm std 0.02)."""
        self._init_norm_and_bias(m, 0.02)

    def discriminator_weights_init(self, m):
        """``Module.apply`` hook for the discriminator (BatchNorm std 0.5)."""
        self._init_norm_and_bias(m, 0.5)