Esempio n. 1
0
    def __init__(  # TODO move parameters to config file
            self,
            pset,
            batch_size=64,
            max_size=100,
            vocab_inp_size=32,
            vocab_tar_size=32,
            embedding_dim=64,
            units=128,
            hidden_size=128,
            alpha=0.1,
            epochs=200,
            epoch_decay=1,
            min_epochs=10,
            verbose=True):
        """Build the encoder/decoder pair, surrogate model and population.

        Hyper-parameters are stored on the instance; sub-modules are
        constructed from the vocabulary/dimension arguments.
        """
        # Training hyper-parameters.
        self.alpha = alpha
        self.batch_size = batch_size
        self.max_size = max_size
        self.epochs = epochs
        self.epoch_decay = epoch_decay
        self.min_epochs = min_epochs
        self.verbose = verbose

        # Optimisation steps performed so far.
        self.train_steps = 0

        # Seq2seq halves plus the surrogate and the candidate population.
        self.enc = Encoder(vocab_inp_size, embedding_dim, units, batch_size)
        self.dec = Decoder(vocab_inp_size, vocab_tar_size, embedding_dim,
                           units, batch_size)
        self.surrogate = Surrogate(hidden_size)
        self.population = Population(pset, max_size, batch_size)

        # NOTE(review): fixed sampling probability; meaning not visible here.
        self.prob = 0.5

        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False, reduction='none')
Esempio n. 2
0
    def __init__(self, config):
        """Set up the training data loader, output directories and models."""
        # Training-time augmentation pipeline.
        train_trans = transforms.Compose([
            transforms.RandomSizedCrop(config.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.train_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(config.train_data_dir, train_trans),
            batch_size=config.batch_size, shuffle=True,
            num_workers=config.workers, pin_memory=True)

        # One sub-directory per artefact kind under save_dir/version.
        root = f'{config.save_dir}/{config.version}'
        os.makedirs(root, exist_ok=True)

        self.loss_dir = f'{root}/loss'
        self.model_state_dir = f'{root}/model_state'
        self.image_dir = f'{root}/image'
        self.psnr_dir = f'{root}/psnr'
        for path in (self.loss_dir, self.model_state_dir,
                     self.image_dir, self.psnr_dir):
            os.makedirs(path, exist_ok=True)

        # Model components on the GPU.
        self.encoder = Encoder(True).cuda()
        self.decoder = Decoder(False, True).cuda()
        self.D = VGG16_mid().cuda()

        self.config = config
Esempio n. 3
0
 def __init__(self, model_name: str, args):
     """Store configuration and build the encoder/decoder, SVM and ResNet."""
     self.model_name = model_name
     self.args = args

     # Auto-encoder halves on the configured device.
     self.encoder = Encoder(self.args.add_noise).to(self.args.device)
     self.decoder = Decoder(self.args.upsample_mode).to(self.args.device)

     # Lazily-initialised training state; populated by later setup methods.
     for attr in ('pretrainDataset', 'pretrainDataloader',
                  'pretrainOptimizer', 'pretrainScheduler', 'RHO_tensor',
                  'writer', 'svmDataset', 'svmDataloader',
                  'testDataset', 'testDataloader',
                  'resnetOptimizer', 'resnetScheduler', 'resnetLossFn'):
         setattr(self, attr, None)
     self.pretrain_batch_cnt = 0

     # Classical SVM classifier head.
     self.svm = SVC(C=self.args.svm_c,
                    kernel=self.args.svm_ker,
                    verbose=True,
                    max_iter=self.args.svm_max_iter)

     # ResNet classifier head.
     self.resnet = Resnet(use_pretrained=True,
                          num_classes=self.args.classes,
                          resnet_depth=self.args.resnet_depth,
                          dropout=self.args.resnet_dropout).to(
                              self.args.device)
Esempio n. 4
0
    def __init__(self, opt, vocabs):
        """Assemble the seq2seq model: encoder, decoder and copy-generator."""
        super(S2SModel, self).__init__()

        self.opt = opt
        self.vocabs = vocabs
        self.encoder = Encoder(vocabs, opt)
        self.decoder = Decoder(vocabs, opt)
        self.generator = ProdGenerator(opt.decoder_rnn_size, vocabs, opt)
Esempio n. 5
0
    def __init__(self, config):
        """Build the data pipeline, output dirs, models and style target.

        Fix: output directories are now each created with
        ``os.makedirs(..., exist_ok=True)``. The original guarded four
        ``os.mkdir`` calls behind a single ``exists()`` check on
        ``loss_dir``, so a partially-created save_dir (loss/ present,
        image/ missing) skipped creation and made later saves fail.
        """
        # Content augmentation pipeline for training.
        content_trans = transforms.Compose([
            transforms.RandomSizedCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        self.train_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(config.train_data_dir, content_trans),
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.workers,
            pin_memory=True,
            drop_last=True)

        # Transform for ad-hoc (test) images.
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        # Transform for the style image (no augmentation).
        style_trans = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        self.loss_dir = f'{config.save_dir}/loss'
        self.model_state_dir = f'{config.save_dir}/model_state'
        self.image_dir = f'{config.save_dir}/image'
        self.psnr_dir = f'{config.save_dir}/psnr'

        # Create every output dir independently (also creates save_dir).
        for path in (self.loss_dir, self.model_state_dir,
                     self.image_dir, self.psnr_dir):
            os.makedirs(path, exist_ok=True)

        self.encoder = Encoder().cuda()
        self.transformer = Attention().cuda()
        self.decoder = Decoder().cuda()

        self.wavepool = WavePool(256).cuda()

        self.decoder.load_state_dict(torch.load("./decoder.pth"))

        # Use the first image of style set config.S as the fixed style:
        # once as a single test image, and repeated across the batch as
        # the training target.
        S_path = os.path.join(config.style_dir, str(config.S))
        style_images = glob.glob((S_path + '/*.jpg'))
        s = Image.open(style_images[0])
        s = s.resize((512, 320), 0)
        s = style_trans(s).cuda()
        self.style_image = s.unsqueeze(0)
        self.style_target = torch.stack([s for i in range(config.batch_size)],
                                        0)

        self.config = config
Esempio n. 6
0
    def __init__(self, config):
        """Build a CVAE-style dialogue model: embedding, post/response
        encoders, prior and recognition networks over a latent variable,
        decoder-state preparation, decoder and output projection.
        """
        super(Model, self).__init__()
        self.config = config

        # Word embedding layer.
        self.embedding = Embedding(config.num_vocab,  # vocabulary size
                                   config.embedding_size,  # embedding dim
                                   config.pad_id,  # pad_id
                                   config.dropout)

        # Encoder for the post (context) side.
        self.post_encoder = Encoder(config.post_encoder_cell_type,  # rnn type
                                    config.embedding_size,  # input size
                                    config.post_encoder_output_size,  # output size
                                    config.post_encoder_num_layers,  # rnn layers
                                    config.post_encoder_bidirectional,  # bidirectional?
                                    config.dropout)  # dropout prob

        # Encoder for the response side.
        self.response_encoder = Encoder(config.response_encoder_cell_type,
                                        config.embedding_size,  # input size
                                        config.response_encoder_output_size,  # output size
                                        config.response_encoder_num_layers,  # rnn layers
                                        config.response_encoder_bidirectional,  # bidirectional?
                                        config.dropout)  # dropout prob

        # Prior network: latent distribution from the post encoding.
        self.prior_net = PriorNet(config.post_encoder_output_size,  # post input size
                                  config.latent_size,  # latent dim
                                  config.dims_prior)  # hidden dims

        # Recognition network: latent distribution from post + response.
        self.recognize_net = RecognizeNet(config.post_encoder_output_size,  # post input size
                                          config.response_encoder_output_size,  # response input size
                                          config.latent_size,  # latent dim
                                          config.dims_recognize)  # hidden dims

        # Builds the initial decoder state from [post encoding; latent z].
        self.prepare_state = PrepareState(config.post_encoder_output_size+config.latent_size,
                                          config.decoder_cell_type,
                                          config.decoder_output_size,
                                          config.decoder_num_layers)

        # Decoder.
        self.decoder = Decoder(config.decoder_cell_type,  # rnn type
                               config.embedding_size,  # input size
                               config.decoder_output_size,  # output size
                               config.decoder_num_layers,  # rnn layers
                               config.dropout)  # dropout prob

        # Output layer: project to the vocabulary and normalise.
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1)
        )
Esempio n. 7
0
    def __init__(
            self,
            n_cap_vocab, cap_max_seq, dim_language = 768,
            d_word_vec=512, d_model=512, d_inner=2048,
            n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
            c3d_path=False, tgt_emb_prj_weight_sharing=True):
        """Caption model: linear visual projection feeding a transformer
        decoder, with an optional tied output projection."""
        super().__init__()

        # Project 1024-d visual features into the 768-d language space.
        # (Original comment mentioned a pre-trained C3D backbone.)
        self.encoder = nn.Linear(1024, 768)

        self.decoder = Decoder(
            n_tgt_vocab=n_cap_vocab, len_max_seq=cap_max_seq,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        # Final projection onto the caption vocabulary.
        self.cap_word_prj = nn.Linear(d_model, n_cap_vocab, bias=False)
        nn.init.xavier_normal_(self.cap_word_prj.weight)

        assert d_model == d_word_vec, \
            'To facilitate the residual connections, ' \
            'the dimensions of all module outputs shall be the same.'

        if tgt_emb_prj_weight_sharing:
            # Tie the output projection to the decoder's word embedding
            # and rescale logits accordingly.
            self.cap_word_prj.weight = self.decoder.tgt_word_emb.weight
            self.x_logit_scale = d_model ** -0.5
        else:
            self.x_logit_scale = 1.
Esempio n. 8
0
    def __init__(self, config):
        """HDR/LDR fusion trainer: data, output dirs, code snapshot, models."""
        content_trans = transforms.Compose([
            transforms.Resize(config.image_size),
            transforms.ToTensor(),
            normalize,
        ])

        self.train_loader = torch.utils.data.DataLoader(
            HDR_LDR(config.ldr_dir, config.hdr_dir, content_trans),
            batch_size=1,
            shuffle=True,
            num_workers=config.workers,
            pin_memory=True,
            drop_last=True)

        # Per-run output layout under save_dir/version.
        root = f'{config.save_dir}/{config.version}'
        os.makedirs(root, exist_ok=True)

        self.loss_dir = f'{root}/loss'
        self.model_state_dir = f'{root}/model_state'
        self.image_dir = f'{root}/image'
        self.psnr_dir = f'{root}/psnr'
        self.code_dir = f'{root}/code'
        for path in (self.loss_dir, self.model_state_dir, self.image_dir,
                     self.psnr_dir, self.code_dir):
            os.makedirs(path, exist_ok=True)

        # Snapshot the training scripts next to the results for repro.
        script_name = 'trainer_' + config.script_name + '.py'
        shutil.copyfile(os.path.join('scripts', script_name),
                        os.path.join(self.code_dir, script_name))
        shutil.copyfile('components/transformer.py',
                        os.path.join(self.code_dir, 'transformer.py'))
        shutil.copyfile('model/Fusion.py',
                        os.path.join(self.code_dir, 'Fusion.py'))

        self.encoder = Encoder().cuda()
        self.attention = Transformer(config.topk, True, False).cuda()
        self.decoder = Decoder().cuda()

        self.decoder.load_state_dict(torch.load("./hdr_decoder.pth"))

        self.config = config
Esempio n. 9
0
class S2SModel(nn.Module):
    """Seq2seq model with copy attention: Encoder -> Decoder -> ProdGenerator."""

    def __init__(self, opt, vocabs):
        super(S2SModel, self).__init__()

        self.opt = opt
        self.vocabs = vocabs
        self.encoder = Encoder(vocabs, opt)
        self.decoder = Decoder(vocabs, opt)
        self.generator = ProdGenerator(self.opt.decoder_rnn_size, vocabs,
                                       self.opt)

    def forward(self, batch):
        """Run one training step over *batch*; return (loss, Statistics)."""
        # Initial (zero) parent states for the Prod decoder, one per example.
        batch_size = batch['seq2seq'].size(0)
        batch['parent_states'] = {}
        for j in range(0, batch_size):
            batch['parent_states'][j] = {}
            batch['parent_states'][j][0] = Variable(torch.zeros(
                1, 1, self.opt.decoder_rnn_size),
                                                    requires_grad=False)

        context, context_lengths, enc_hidden = self.encoder(batch)

        dec_initial_state = DecoderState(
            enc_hidden,
            Variable(torch.zeros(batch_size, 1, self.opt.decoder_rnn_size),
                     requires_grad=False))

        output, attn, copy_attn = self.decoder(batch, context, context_lengths,
                                               dec_initial_state)

        del batch['parent_states']

        # Concatenate the copy-source maps for variables and methods.
        # NOTE(review): concatenating onto an empty (0, 0) tensor relies on
        # old-PyTorch cat semantics that ignored empty tensors -- confirm
        # on the targeted torch version.
        src_map = torch.zeros(0, 0)
        src_map = torch.cat((src_map, batch['concode_src_map_vars']), 1)
        src_map = torch.cat((src_map, batch['concode_src_map_methods']), 1)

        scores = self.generator(bottle(output), bottle(copy_attn), src_map,
                                batch)
        loss, total, correct = self.generator.computeLoss(scores, batch)

        # Fix: `loss.data[0]` (0-dim tensor indexing) raises on
        # PyTorch >= 0.5; `loss.item()` is the supported equivalent.
        return loss, Statistics(loss.item(), total, correct,
                                self.encoder.n_src_words)

    # This only works for a batch size of 1
    def predict(self, batch, opt, vis_params):
        """Beam-search decode a single example (asserts batch size == 1)."""
        curr_batch_size = batch['seq2seq'].size(0)
        assert (curr_batch_size == 1)
        context, context_lengths, enc_hidden = self.encoder(batch)
        return self.decoder.predict(enc_hidden, context, context_lengths,
                                    batch, opt.beam_size, opt.max_sent_length,
                                    self.generator, opt.replace_unk,
                                    vis_params)
Esempio n. 10
0
 def __init__(self,
              args,
              vocab,
              n_dim,
              image_dim,
              layers,
              dropout,
              num_choice=5):
     """DA model: encoder and decoder built from the same arguments."""
     super().__init__()
     print("Model name: DA, 1 layer, fixed subspaces")
     self.vocab = vocab
     # Both halves take the identical argument tuple.
     shared = (args, vocab, n_dim, image_dim, layers, dropout, num_choice)
     self.encoder = Encoder(*shared).cuda()
     self.decoder = Decoder(*shared).cuda()
Esempio n. 11
0
    def __init__(self,
                 num_layers,
                 d_model,
                 num_heads,
                 dff,
                 input_vocab_size,
                 target_vocab_size,
                 dropout=0.1):
        """Transformer: encoder + decoder stacks and a vocab-sized head."""
        super(Transformer, self).__init__()

        # Encoder over the input vocabulary.
        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, dropout)
        # Decoder over the target vocabulary.
        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, dropout)
        # Projects decoder output to target-vocabulary logits.
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)
Esempio n. 12
0
    def __init__(self,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 dim_embedding: int = 512,
                 num_heads: int = 6,
                 dim_feedfordward: int = 512,
                 dropout: float = 0.1,
                 activation: nn.Module = nn.ReLU()):
        """Transformer: symmetric encoder/decoder stacks plus a
        cross-entropy criterion.

        NOTE(review): `dim_feedfordward` (sic) is kept because it is part
        of the public signature; `activation` is accepted but not used
        here -- confirm whether it should be forwarded to the stacks.
        """
        super().__init__()
        # Both stacks share every dimension except their layer count.
        shared = dict(dim_embedding=dim_embedding,
                      num_heads=num_heads,
                      dim_feedfordward=dim_feedfordward,
                      dropout=dropout)
        self.encoder = Encoder(num_layers=num_encoder_layers, **shared)
        self.decoder = Decoder(num_layers=num_decoder_layers, **shared)
        self.criterion = nn.CrossEntropyLoss()
Esempio n. 13
0
    def __init__(self):
        """Build encoder, decoder and embedding table, then move to GPU."""
        super(Model, self).__init__()

        encoder = Encoder()
        decoder = Decoder()
        embeds = nn.Embedding(config.vocab_size, config.emb_dim)
        # Initialise the embedding weights before any device move so the
        # random draws match the original CPU-side initialisation.
        init_wt.init_wt_normal(embeds.weight)

        self.encoder = get_cuda(encoder)
        self.decoder = get_cuda(decoder)
        self.embeds = get_cuda(embeds)


# if __name__ == '__main__':
#
#     my_model = Model()
#     my_model_paramters = my_model.parameters()
#
#     print(my_model_paramters)
#     my_model_paramters_group = list(my_model_paramters)
#     print(my_model_paramters_group)
Esempio n. 14
0
 def __init__(self,
              args,
              vocab,
              n_dim,
              image_dim,
              layers,
              dropout,
              num_choice=5):
     """DA model: Encoder/Decoder pair built from the same arguments.

     Several alternative encoder/decoder variants (Transformer, DA, MH,
     hierarchical, Disc) were previously tried here; only the plain
     Encoder/Decoder pair remains active.
     """
     super().__init__()
     print("Model name: DA")
     self.vocab = vocab
     shared = (args, vocab, n_dim, image_dim, layers, dropout, num_choice)
     self.encoder = Encoder(*shared).cuda()
     self.decoder = Decoder(*shared).cuda()
Esempio n. 15
0
def make_model(src_vocab,
               tar_vocab,
               N=6,
               d_model=512,
               d_ff=2014,
               h=8,
               dropout=0.1):
    """Construct a full encoder-decoder transformer.

    NOTE(review): the default ``d_ff=2014`` looks like a typo for the
    conventional 2048 -- confirm before relying on the default.
    """
    c = copy.deepcopy
    attention = MultiHeadedAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_enc = PositionalEncoding(d_model, dropout)

    model = GeneralEncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attention), c(feed_forward), dropout),
                N),
        Decoder(DecoderLayer(d_model, c(attention), c(attention),
                             c(feed_forward), dropout), N),
        nn.Sequential(Embedding(d_model, src_vocab), c(pos_enc)),
        nn.Sequential(Embedding(d_model, tar_vocab), c(pos_enc)),
        Generator(d_model, tar_vocab))

    # Randomly (Xavier) initialise every weight matrix -- important for
    # stable training.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
Esempio n. 16
0
    def __init__(self,
                 n_cap_vocab,
                 n_cms_vocab,
                 cap_max_seq,
                 cms_max_seq,
                 vis_emb=2048,
                 d_word_vec=512,
                 d_model=512,
                 d_inner=2048,
                 n_layers=6,
                 rnn_layers=1,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 tgt_emb_prj_weight_sharing=True):
        """Caption + commonsense model: visual embedding, shared encoder
        and two transformer decoders with optionally tied projections.

        Fix: the encoder previously hard-coded ``dropout=0.1`` instead of
        forwarding the ``dropout`` parameter (both decoders already
        forward it). The default is unchanged, so callers relying on the
        default see identical behaviour.
        """
        super().__init__()

        # set RNN layers at 1 or 2 yield better performance.
        self.vis_emb = nn.Linear(vis_emb, d_model)
        self.encoder = Encoder(40,
                               d_model,
                               rnn_layers,
                               n_head,
                               d_k,
                               d_v,
                               d_model,
                               d_inner,
                               dropout=dropout)

        # Caption decoder.
        self.decoder = Decoder(n_tgt_vocab=n_cap_vocab,
                               len_max_seq=cap_max_seq,
                               d_word_vec=d_word_vec,
                               d_model=d_model,
                               d_inner=d_inner,
                               n_layers=n_layers,
                               n_head=n_head,
                               d_k=d_k,
                               d_v=d_v,
                               dropout=dropout)

        # Commonsense decoder.
        self.cms_decoder = Decoder(n_tgt_vocab=n_cms_vocab,
                                   len_max_seq=cms_max_seq,
                                   d_word_vec=d_word_vec,
                                   d_model=d_model,
                                   d_inner=d_inner,
                                   n_layers=n_layers,
                                   n_head=n_head,
                                   d_k=d_k,
                                   d_v=d_v,
                                   dropout=dropout)

        # Output projections onto the two vocabularies.
        self.cap_word_prj = nn.Linear(d_model, n_cap_vocab, bias=False)
        self.cms_word_prj = nn.Linear(d_model, n_cms_vocab, bias=False)

        nn.init.xavier_normal_(self.cap_word_prj.weight)
        nn.init.xavier_normal_(self.cms_word_prj.weight)

        assert d_model == d_word_vec, \
            'To facilitate the residual connections, ' \
            'the dimensions of all module outputs shall be the same.'

        if tgt_emb_prj_weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.cap_word_prj.weight = self.decoder.tgt_word_emb.weight
            self.cms_word_prj.weight = self.cms_decoder.tgt_word_emb.weight
            self.x_logit_scale = (d_model**-0.5)
        else:
            self.x_logit_scale = 1.
Esempio n. 17
0
class Trainer(Trainer_Base):
    """Trains the attention module and decoder for style transfer on top
    of a VGG-based feature extractor."""

    def create_model(self):
        """Instantiate encoder/decoder, VGG extractor and attention module."""
        self.encoder = Encoder(True).cuda()
        self.decoder = Decoder(True, True).cuda()
        self.D = VGG16_mid().cuda()
        self.attention1 = Transformer(4, 512, self.config.topk, True,
                                      False).cuda()

    def train(self):
        """Main optimisation loop over content/style image pairs."""
        self.create_model()

        optimizer = torch.optim.Adam(self.attention1.parameters(),
                                     lr=self.config.learning_rate)
        optimizer2 = torch.optim.Adam(self.decoder.parameters(),
                                      lr=self.config.learning_rate)

        criterion = torch.nn.L1Loss()
        styles = iter(self.style_loader)

        self.encoder.eval()
        self.decoder.train()
        self.reporter.writeInfo("Start to train the model")
        for e in range(1, self.config.epoch_size + 1):
            for i, (content, target) in enumerate(self.train_loader):
                # Cycle the style loader forever; restart only when it is
                # exhausted. (Fix: was a bare `except:`, which also
                # swallowed unrelated loader errors.)
                try:
                    style, target = next(styles)
                except StopIteration:
                    styles = iter(self.style_loader)
                    style, target = next(styles)

                content = content.cuda()
                style = style.cuda()

                fea_c = self.encoder(content)
                fea_s = self.encoder(style)

                out_feature, attention_map = self.attention1(fea_c, fea_s)
                rec, _ = self.attention1(fea_s, fea_s)
                out_content = self.decoder(out_feature)

                # Multi-level VGG features of content, output and style.
                c1, c2, c3, _ = self.D(content)
                h1, h2, h3, _ = self.D(out_content)
                s1, s2, s3, _ = self.D(style)

                # Content loss on the deepest level; perceptual (style)
                # loss as L1 between Gram matrices per level. (Fix: the
                # eval('s' + str(t+1)) string hack is replaced by an
                # explicit zip over the feature tuples.)
                loss_content = torch.norm(c3 - h3, p=2)
                loss_perceptual = 0
                for s_feat, h_feat in zip((s1, s2, s3), (h1, h2, h3)):
                    loss_perceptual += criterion(gram_matrix(s_feat),
                                                 gram_matrix(h_feat))
                loss = loss_content * self.config.content_weight + loss_perceptual * self.config.style_weight

                optimizer.zero_grad()
                optimizer2.zero_grad()
                loss.backward()
                optimizer.step()
                optimizer2.step()

                if i % self.config.log_interval == 0:
                    now = datetime.datetime.now()
                    otherStyleTime = now.strftime("%Y-%m-%d %H:%M:%S")
                    print(otherStyleTime)
                    print('epoch: ', e, ' iter: ', i)
                    print(
                        'attention scartters: ',
                        torch.std(attention_map.argmax(-1).float(),
                                  1).mean().cpu())
                    print(attention_map.shape)

                    # Switch to eval mode, render a sample, then resume.
                    self.attention1.eval()
                    self.decoder.eval()
                    tosave, perc, cont = self.eval()
                    save_image(
                        denorm(tosave),
                        self.image_dir + '/epoch_{}-iter_{}.png'.format(e, i))
                    print("image saved to " + self.image_dir +
                          '/epoch_{}-iter_{}.png'.format(e, i))
                    print('content loss:', cont)
                    print('perceptual loss:', perc)

                    self.reporter.writeTrainLog(
                        e, i, f'''
                        attention scartters: {torch.std(attention_map.argmax(-1).float(), 1).mean().cpu()}\n
                        content loss: {cont}\n
                        perceptual loss: {perc}
                    ''')

                    self.attention1.train()
                    self.decoder.train()

                    torch.save(
                        {
                            'layer1': self.attention1.state_dict(),
                            'decoder': self.decoder.state_dict()
                        },
                        f'{self.model_state_dir}/epoch_{e}-iter_{i}.pth')
Esempio n. 18
0
 def create_model(self):
     """Instantiate all model components directly on the GPU."""
     self.encoder = Encoder(True).cuda()
     self.decoder = Decoder(True, True).cuda()
     self.D = VGG16_mid().cuda()
     self.attention1 = Transformer(
         4, 512, self.config.topk, True, False).cuda()
Esempio n. 19
0
class Trainer(object):
    """Trains only the feature-space transformer for a fixed style image."""

    def __init__(self, config):
        """Build the data pipeline, output dirs, models and style target.

        Fix: output directories are each created with
        ``os.makedirs(..., exist_ok=True)``. The original guarded four
        ``os.mkdir`` calls behind a single ``exists()`` check on
        ``loss_dir``, which broke when the dirs were partially present.
        """
        content_trans = transforms.Compose([
            transforms.RandomSizedCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        self.train_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(config.train_data_dir, content_trans),
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.workers,
            pin_memory=True,
            drop_last=True)

        # Transform for ad-hoc (test) images.
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        # Transform for the style image (no augmentation).
        style_trans = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        self.loss_dir = f'{config.save_dir}/loss'
        self.model_state_dir = f'{config.save_dir}/model_state'
        self.image_dir = f'{config.save_dir}/image'
        self.psnr_dir = f'{config.save_dir}/psnr'

        # Create every output dir independently (also creates save_dir).
        for path in (self.loss_dir, self.model_state_dir,
                     self.image_dir, self.psnr_dir):
            os.makedirs(path, exist_ok=True)

        self.encoder = Encoder().cuda()
        self.transformer = Attention().cuda()
        self.decoder = Decoder().cuda()

        self.wavepool = WavePool(256).cuda()

        self.decoder.load_state_dict(torch.load("./decoder.pth"))

        # First image of style set config.S: used once as a single test
        # style and repeated across the batch as the training target.
        S_path = os.path.join(config.style_dir, str(config.S))
        style_images = glob.glob((S_path + '/*.jpg'))
        s = Image.open(style_images[0])
        s = s.resize((512, 320), 0)
        s = style_trans(s).cuda()
        self.style_image = s.unsqueeze(0)
        self.style_target = torch.stack([s for i in range(config.batch_size)],
                                        0)

        self.config = config

    def train(self):
        """Optimise the transformer so its output keeps the content's
        low-pass (wavelet-pooled) feature structure."""
        optimizer = torch.optim.Adam(self.transformer.parameters(),
                                     lr=self.config.learning_rate)
        criterion_p = torch.nn.MSELoss(reduction='mean')

        self.transformer.train()

        for e in range(1, self.config.epoch_size + 1):
            print(f'Start {e} epoch')
            for i, (content, target) in enumerate(self.train_loader):
                content = content.cuda()

                content_feature = self.encoder(content)
                style_feature = self.encoder(self.style_target)

                out_feature = self.transformer(content_feature, style_feature)

                # MSE between the low-pass wavelet bands of the stylised
                # and the original content features.
                loss = criterion_p(
                    self.wavepool(out_feature)[0],
                    self.wavepool(content_feature)[0])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if i % self.config.log_interval == 0:
                    now = datetime.datetime.now()
                    otherStyleTime = now.strftime("%Y-%m-%d %H:%M:%S")
                    print(otherStyleTime)
                    print('epoch: ', e, ' iter: ', i)
                    print('loss:', loss.cpu().item())

                    self.encoder.eval()
                    self.transformer.eval()
                    self.decoder.eval()

                    with torch.no_grad():
                        # NOTE(review): test.jpg is loaded but the preview
                        # below is computed from the current training batch
                        # (`content`), not from test_image -- confirm intent.
                        test_image = Image.open('test.jpg')
                        test_image = self.trans(test_image).unsqueeze(0).cuda()
                        content_feature = self.encoder(content)
                        style_feature = self.encoder(self.style_image)
                        out_feature = self.transformer(content_feature,
                                                       style_feature)
                        out_content = self.decoder(out_feature)

                    self.transformer.train()

                    save_image(
                        denorm(out_content),
                        self.image_dir + '/epoch_{}-iter_{}.png'.format(e, i))
                    print("image saved to " + self.image_dir +
                          '/epoch_{}-iter_{}.png'.format(e, i))

                    model_dicts = self.transformer.state_dict()
                    torch.save(
                        model_dicts,
                        f'{self.model_state_dir}/epoch_{e}-iter_{i}.pth')
Esempio n. 20
0
# Training-set loader (trainset / transform_test are defined earlier in
# this file, outside the visible chunk).
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=256,
                                          shuffle=True,
                                          num_workers=2)

# Test-set loader: fixed order (shuffle=False) for reproducible evaluation.
testset = torchvision.datasets.ImageFolder(root='./data/Test',
                                           transform=transform_test)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=200,
                                         shuffle=False,
                                         num_workers=2)

# Model
print('==> Building model..')
encoder = Encoder(mask=mask)
decoder = Decoder(mask=mask)
classifier = Classifier()
encoder = encoder.to(device)
decoder = decoder.to(device)
classifier = classifier.to(device)
if device == 'cuda':
    # Let cuDNN benchmark kernels for the fixed input sizes.
    cudnn.benchmark = True
if args.resume:
    # Load checkpoint: restore all three sub-models and the best loss.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt_' + codir + '.t7')
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    classifier.load_state_dict(checkpoint['classifier'])
    best_loss = checkpoint['loss']
Esempio n. 21
0
class Instructor:
    """Orchestrates a facial-expression-recognition pipeline.

    Stages: (1) pretrain a (optionally sparse) auto-encoder on labeled
    images, (2) fit an SVM on the frozen encoder's embeddings, (3) train a
    ResNet baseline, (4) write a submission CSV for the test set.

    NOTE(review): depends on project helpers (FERDataset, LabelEnum,
    Encoder, Decoder, Resnet, get_linear_schedule_with_warmup) and on an
    `args` namespace for all paths and hyper-parameters.
    """

    def __init__(self, model_name: str, args):
        """Build the models eagerly; datasets, loaders, optimizers and
        schedulers stay None until the corresponding train* method runs."""
        self.model_name = model_name
        self.args = args
        self.encoder = Encoder(self.args.add_noise).to(self.args.device)
        self.decoder = Decoder(self.args.upsample_mode).to(self.args.device)
        # Auto-encoder pretraining state (filled in by trainAutoEncoder).
        self.pretrainDataset = None
        self.pretrainDataloader = None
        self.pretrainOptimizer = None
        self.pretrainScheduler = None
        # Target sparsity vector for the KL penalty (set in trainAutoEncoder).
        self.RHO_tensor = None
        self.pretrain_batch_cnt = 0
        self.writer = None
        # SVM / test-set state (filled in by trainSVM / genTestResult).
        self.svmDataset = None
        self.svmDataloader = None
        self.testDataset = None
        self.testDataloader = None
        self.svm = SVC(C=self.args.svm_c,
                       kernel=self.args.svm_ker,
                       verbose=True,
                       max_iter=self.args.svm_max_iter)
        self.resnet = Resnet(use_pretrained=True,
                             num_classes=self.args.classes,
                             resnet_depth=self.args.resnet_depth,
                             dropout=self.args.resnet_dropout).to(
                                 self.args.device)
        self.resnetOptimizer = None
        self.resnetScheduler = None
        self.resnetLossFn = None

    def _load_data_by_label(self, label: str) -> list:
        """Read every image under TRAIN_PATH/<label> and return them as
        nested Python lists (one entry per image)."""
        ret = []
        LABEL_PATH = os.path.join(self.args.TRAIN_PATH, label)
        for dir_path, _, file_list in os.walk(LABEL_PATH, topdown=False):
            for file_name in file_list:
                file_path = os.path.join(dir_path, file_name)
                img_np = imread(file_path)
                img = img_np.copy()
                img = img.tolist()
                ret.append(img)
        return ret

    def _load_all_data(self):
        """Load training images for every class id in [0, classes).

        Returns (all_data, all_labels) as parallel lists; the class name is
        resolved via LabelEnum(label_id).name.
        """
        all_data = []
        all_labels = []
        for label_id in range(0, self.args.classes):
            expression = LabelEnum(label_id)
            sub_data = self._load_data_by_label(expression.name)
            sub_labels = [label_id] * len(sub_data)
            all_data.extend(sub_data)
            all_labels.extend(sub_labels)
        return all_data, all_labels

    def _load_test_data(self):
        """Load test images in the order listed by RAW_PATH/submission.csv.

        Returns (test_data, img_names) as parallel lists.
        """
        file_map = pd.read_csv(
            os.path.join(self.args.RAW_PATH, 'submission.csv'))
        test_data = []
        img_names = []
        for file_name in file_map['file_name']:
            file_path = os.path.join(self.args.TEST_PATH, file_name)
            img_np = imread(file_path)
            img = img_np.copy()
            img = img.tolist()
            test_data.append(img)
            img_names.append(file_name)
        return test_data, img_names

    def trainAutoEncoder(self):
        """Run the full auto-encoder pretraining loop for args.epochs epochs.

        Sets up the dataloader, a joint Adam optimizer over encoder+decoder,
        a linear-decay LR schedule sized to the gradient-accumulation steps,
        and the sparsity target RHO_tensor used by the KL penalty.
        """
        self.writer = SummaryWriter(
            os.path.join(self.args.LOG_PATH, self.model_name))
        all_data, all_labels = self._load_all_data()
        self.pretrainDataset = FERDataset(all_data,
                                          labels=all_labels,
                                          args=self.args)
        self.pretrainDataloader = DataLoader(dataset=self.pretrainDataset,
                                             batch_size=self.args.batch_size,
                                             shuffle=True,
                                             num_workers=self.args.num_workers)
        self.pretrainOptimizer = torch.optim.Adam([{
            'params':
            self.encoder.parameters(),
            'lr':
            self.args.pretrain_lr
        }, {
            'params':
            self.decoder.parameters(),
            'lr':
            self.args.pretrain_lr
        }])
        # One scheduler step per *optimizer* step (i.e. per cumul_batch batches).
        tot_steps = math.ceil(
            len(self.pretrainDataloader) /
            self.args.cumul_batch) * self.args.epochs
        self.pretrainScheduler = get_linear_schedule_with_warmup(
            self.pretrainOptimizer,
            num_warmup_steps=0,
            num_training_steps=tot_steps)
        # (1, embed_dim) tensor holding the target activation rho per unit.
        self.RHO_tensor = torch.tensor(
            [self.args.rho for _ in range(self.args.embed_dim)],
            dtype=torch.float).unsqueeze(0).to(self.args.device)
        epochs = self.args.epochs
        for epoch in range(1, epochs + 1):
            print()
            print(
                "================ AutoEncoder Training Epoch {:}/{:} ================"
                .format(epoch, epochs))
            print(" ---- Start training ------>")
            self.epochTrainAutoEncoder(epoch)
            print()
        self.writer.close()

    def epochTrainAutoEncoder(self, epoch):
        """Train encoder+decoder for one epoch with gradient accumulation.

        Loss is MSE reconstruction, plus (when args.use_sparse) a KL-based
        sparsity penalty between the mean batch activation and RHO_tensor.
        Logs running loss and both param-group LRs every args.disp_period
        steps, then checkpoints via saveAutoEncoder.
        """
        self.encoder.train()
        self.decoder.train()

        cumul_loss = 0
        cumul_steps = 0
        cumul_samples = 0

        self.pretrainOptimizer.zero_grad()
        cumulative_batch = 0

        for idx, (images, labels) in enumerate(tqdm(self.pretrainDataloader)):
            batch_size = images.shape[0]
            images, labels = images.to(self.args.device), labels.to(
                self.args.device)

            embeds = self.encoder(images)
            outputs = self.decoder(embeds)

            loss = torch.nn.functional.mse_loss(outputs, images)
            if self.args.use_sparse:
                # Mean activation over the batch vs. the target rho distribution.
                rho_hat = torch.mean(embeds, dim=0, keepdim=True)
                sparse_penalty = self.args.regulizer_weight * torch.nn.functional.kl_div(
                    input=torch.nn.functional.log_softmax(rho_hat, dim=-1),
                    target=torch.nn.functional.softmax(self.RHO_tensor,
                                                       dim=-1))
                loss = loss + sparse_penalty

            # Scale down so accumulated gradients average over cumul_batch batches.
            loss_each = loss / self.args.cumul_batch
            loss_each.backward()

            cumulative_batch += 1
            cumul_steps += 1
            cumul_loss += loss.detach().cpu().item() * batch_size
            cumul_samples += batch_size

            if cumulative_batch >= self.args.cumul_batch:
                torch.nn.utils.clip_grad_norm_(self.encoder.parameters(),
                                               max_norm=self.args.max_norm)
                torch.nn.utils.clip_grad_norm_(self.decoder.parameters(),
                                               max_norm=self.args.max_norm)
                self.pretrainOptimizer.step()
                self.pretrainScheduler.step()
                self.pretrainOptimizer.zero_grad()
                cumulative_batch = 0

            if cumul_steps >= self.args.disp_period or idx + 1 == len(
                    self.pretrainDataloader):
                print(" -> cumul_steps={:} loss={:}".format(
                    cumul_steps, cumul_loss / cumul_samples))
                self.pretrain_batch_cnt += 1
                self.writer.add_scalar('batch-loss',
                                       cumul_loss / cumul_samples,
                                       global_step=self.pretrain_batch_cnt)
                self.writer.add_scalar('encoder_lr',
                                       self.pretrainOptimizer.state_dict()
                                       ['param_groups'][0]['lr'],
                                       global_step=self.pretrain_batch_cnt)
                self.writer.add_scalar('decoder_lr',
                                       self.pretrainOptimizer.state_dict()
                                       ['param_groups'][1]['lr'],
                                       global_step=self.pretrain_batch_cnt)
                cumul_steps = 0
                cumul_loss = 0
                cumul_samples = 0

        # Checkpoint after every epoch.
        self.saveAutoEncoder(epoch)

    def saveAutoEncoder(self, epoch):
        """Save encoder and decoder state_dicts under CKPT_PATH, keyed by
        model name and epoch (paths must match loadAutoEncoder)."""
        encoderPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Encoder" + "--EPOCH-{:}".format(epoch))
        decoderPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Decoder" + "--EPOCH-{:}".format(epoch))
        print("-----------------------------------------------")
        print("  -> Saving AutoEncoder {:} ......".format(self.model_name))
        torch.save(self.encoder.state_dict(), encoderPath)
        torch.save(self.decoder.state_dict(), decoderPath)
        print("  -> Successfully saved AutoEncoder.")
        print("-----------------------------------------------")

    def generateAutoEncoderTestResultSamples(self, sample_cnt):
        """Reconstruct `sample_cnt` randomly chosen test images through the
        auto-encoder and save them under SAMPLE_PATH/<model_name>.

        NOTE(review): random.choices samples *with* replacement, so the same
        file may be reconstructed more than once — confirm this is intended.
        Images are assumed to be 48x48 single-channel (hard-coded reshape).
        """
        self.encoder.eval()
        self.decoder.eval()
        print('  -> Generating samples with AutoEncoder ...')
        save_path = os.path.join(self.args.SAMPLE_PATH, self.model_name)
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        with torch.no_grad():
            for dir_path, _, file_list in os.walk(self.args.TEST_PATH,
                                                  topdown=False):
                sample_file_list = random.choices(file_list, k=sample_cnt)
                for file_name in sample_file_list:
                    file_path = os.path.join(dir_path, file_name)
                    img_np = imread(file_path)
                    img = img_np.copy()
                    img = ToTensor()(img)
                    img = img.reshape(1, 1, 48, 48)
                    img = img.to(self.args.device)
                    embed = self.encoder(img)
                    out = self.decoder(embed).cpu()
                    out = out.reshape(1, 48, 48)
                    out_img = ToPILImage()(out)
                    out_img.save(os.path.join(save_path, file_name))
        print('  -> Done sampling from AutoEncoder with test pictures.')

    def loadAutoEncoder(self, epoch):
        """Load encoder/decoder weights saved by saveAutoEncoder(epoch)."""
        encoderPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Encoder" + "--EPOCH-{:}".format(epoch))
        decoderPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Decoder" + "--EPOCH-{:}".format(epoch))
        print("-----------------------------------------------")
        print("  -> Loading AutoEncoder {:} ......".format(self.model_name))
        self.encoder.load_state_dict(torch.load(encoderPath))
        self.decoder.load_state_dict(torch.load(decoderPath))
        print("  -> Successfully loaded AutoEncoder.")
        print("-----------------------------------------------")

    def generateExtractedFeatures(
            self, data: torch.FloatTensor) -> torch.FloatTensor:
        """Encode a batch without tracking gradients.

        :param data: (batch, channel, l, w)
        :return: embed: (batch, embed_dim), detached and on CPU
        """
        with torch.no_grad():
            data = data.to(self.args.device)
            embed = self.encoder(data)
            embed = embed.detach().cpu()
            return embed

    def trainSVM(self, load: bool):
        """Fit (or load) the SVM on encoder embeddings of the training set.

        Loads the final-epoch auto-encoder first; when `load` is True the
        previously fitted SVM is restored from disk and training is skipped.
        """
        svm_path = os.path.join(self.args.CKPT_PATH, self.model_name + '--svm')
        self.loadAutoEncoder(self.args.epochs)
        self.encoder.eval()
        self.decoder.eval()
        if load:
            print('  -> Loaded from SVM trained model.')
            self.svm = joblib.load(svm_path)
            return
        print()
        print("================ SVM Training Starting ================")
        all_data, all_labels = self._load_all_data()
        all_length = len(all_data)  # NOTE(review): unused — candidate for removal
        self.svmDataset = FERDataset(all_data,
                                     labels=all_labels,
                                     use_da=False,
                                     args=self.args)
        self.svmDataloader = DataLoader(dataset=self.svmDataset,
                                        batch_size=self.args.batch_size,
                                        shuffle=False,
                                        num_workers=self.args.num_workers)
        print("  -> Converting to extracted features ...")
        cnt = 0
        all_embeds = []
        all_labels = []
        for images, labels in self.svmDataloader:
            cnt += 1
            embeds = self.generateExtractedFeatures(images)
            all_embeds.extend(embeds.tolist())
            all_labels.extend(labels.reshape(-1).tolist())
        print('  -> Start SVM fit ...')
        self.svm.fit(X=all_embeds, y=all_labels)
        # self.svm.fit(X=all_embeds[0:3], y=[0, 1, 2])
        joblib.dump(self.svm, svm_path)
        print("  -> Done training for SVM.")

    def genTestResult(self, from_svm=True):
        """Predict test-set classes (SVM on embeddings, or ResNet directly)
        and write DATA_PATH/submission.csv with columns file_name, class."""
        print()
        print('-------------------------------------------------------')
        print('  -> Generating test result for {:} ...'.format(
            'SVM' if from_svm else 'Resnet'))
        test_data, img_names = self._load_test_data()
        test_length = len(test_data)  # NOTE(review): unused — candidate for removal
        self.testDataset = FERDataset(test_data,
                                      filenames=img_names,
                                      use_da=False,
                                      args=self.args)
        self.testDataloader = DataLoader(dataset=self.testDataset,
                                         batch_size=self.args.batch_size,
                                         shuffle=False,
                                         num_workers=self.args.num_workers)
        str_preds = []
        for images, filenames in self.testDataloader:
            if from_svm:
                embeds = self.generateExtractedFeatures(images)
                preds = self.svm.predict(X=embeds)
            else:
                self.resnet.eval()
                # ResNet expects 3 channels; replicate the grayscale channel.
                outs = self.resnet(
                    images.repeat(1, 3, 1, 1).to(self.args.device))
                preds = outs.max(-1)[1].cpu().tolist()
            str_preds.extend([LabelEnum(pred).name for pred in preds])
        # generate submission
        assert len(str_preds) == len(img_names)
        submission = pd.DataFrame({'file_name': img_names, 'class': str_preds})
        submission.to_csv(os.path.join(self.args.DATA_PATH, 'submission.csv'),
                          index=False,
                          index_label=False)
        print('  -> Done generation of submission.csv with model {:}'.format(
            self.model_name))

    def epochTrainResnet(self, epoch):
        """Train the ResNet for one epoch with gradient accumulation.

        Adds Gaussian noise scaled by args.add_noise, replicates grayscale
        to 3 channels, and logs loss/accuracy/LR every args.disp_period
        steps. Checkpoints every 10th epoch.
        """
        self.resnet.train()

        cumul_loss = 0
        cumul_acc = 0
        cumul_steps = 0
        cumul_samples = 0
        cumulative_batch = 0

        self.resnetOptimizer.zero_grad()

        for idx, (images, labels) in enumerate(tqdm(self.pretrainDataloader)):
            batch_size = images.shape[0]
            images, labels = images.to(self.args.device), labels.to(
                self.args.device)
            # Noise augmentation, then grayscale -> 3-channel for ResNet input.
            images += torch.randn(images.shape).to(
                images.device) * self.args.add_noise
            images = images.repeat(1, 3, 1, 1)

            outs = self.resnet(images)
            preds = outs.max(-1)[1].unsqueeze(dim=1)
            cur_acc = (preds == labels).type(torch.int).sum().item()

            loss = self.resnetLossFn(outs, labels.squeeze(dim=1))

            # Scale down so accumulated gradients average over cumul_batch batches.
            loss_each = loss / self.args.cumul_batch
            loss_each.backward()

            cumulative_batch += 1
            cumul_steps += 1
            cumul_loss += loss.detach().cpu().item() * batch_size
            cumul_acc += cur_acc
            cumul_samples += batch_size

            if cumulative_batch >= self.args.cumul_batch:
                torch.nn.utils.clip_grad_norm_(self.resnet.parameters(),
                                               max_norm=self.args.max_norm)
                self.resnetOptimizer.step()
                self.resnetScheduler.step()
                self.resnetOptimizer.zero_grad()
                cumulative_batch = 0

            if cumul_steps >= self.args.disp_period or idx + 1 == len(
                    self.pretrainDataloader):
                print(" -> cumul_steps={:} loss={:} acc={:}".format(
                    cumul_steps, cumul_loss / cumul_samples,
                    cumul_acc / cumul_samples))
                self.pretrain_batch_cnt += 1
                self.writer.add_scalar('batch-loss',
                                       cumul_loss / cumul_samples,
                                       global_step=self.pretrain_batch_cnt)
                self.writer.add_scalar('batch-acc',
                                       cumul_acc / cumul_samples,
                                       global_step=self.pretrain_batch_cnt)
                self.writer.add_scalar(
                    'resnet_lr',
                    self.resnetOptimizer.state_dict()['param_groups'][0]['lr'],
                    global_step=self.pretrain_batch_cnt)
                cumul_steps = 0
                cumul_loss = 0
                cumul_acc = 0
                cumul_samples = 0

        if epoch % 10 == 0:
            self.saveResnet(epoch)

    def saveResnet(self, epoch):
        """Save the ResNet state_dict under CKPT_PATH (path matches loadResnet)."""
        resnetPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Resnet" + "--EPOCH-{:}".format(epoch))
        print("-----------------------------------------------")
        print("  -> Saving Resnet {:} ......".format(self.model_name))
        torch.save(self.resnet.state_dict(), resnetPath)
        print("  -> Successfully saved Resnet.")
        print("-----------------------------------------------")

    def loadResnet(self, epoch):
        """Load ResNet weights saved by saveResnet(epoch)."""
        resnetPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Resnet" + "--EPOCH-{:}".format(epoch))
        print("-----------------------------------------------")
        print("  -> Loading Resnet {:} ......".format(self.model_name))
        self.resnet.load_state_dict(torch.load(resnetPath))
        print("  -> Successfully loaded Resnet.")
        print("-----------------------------------------------")

    def trainResnet(self):
        """Run the full ResNet training loop for args.epochs epochs.

        Uses a warmup + linear-decay schedule sized to accumulation steps
        and a class-weighted CrossEntropyLoss (weights presumably derived
        from inverse class frequencies — TODO confirm).
        """
        self.writer = SummaryWriter(
            os.path.join(self.args.LOG_PATH, self.model_name))
        all_data, all_labels = self._load_all_data()
        self.pretrainDataset = FERDataset(all_data,
                                          labels=all_labels,
                                          args=self.args)
        self.pretrainDataloader = DataLoader(dataset=self.pretrainDataset,
                                             batch_size=self.args.batch_size,
                                             shuffle=True,
                                             num_workers=self.args.num_workers)
        self.resnetOptimizer = self.getResnetOptimizer()
        tot_steps = math.ceil(
            len(self.pretrainDataloader) /
            self.args.cumul_batch) * self.args.epochs
        self.resnetScheduler = get_linear_schedule_with_warmup(
            self.resnetOptimizer,
            num_warmup_steps=tot_steps * self.args.warmup_rate,
            num_training_steps=tot_steps)
        self.resnetLossFn = torch.nn.CrossEntropyLoss(
            weight=torch.tensor([
                9.40661861,
                1.00104606,
                0.56843877,
                0.84912748,
                1.02660468,
                1.29337298,
                0.82603942,
            ],
                                dtype=torch.float,
                                device=self.args.device))
        epochs = self.args.epochs
        for epoch in range(1, epochs + 1):
            print()
            print(
                "================ Resnet Training Epoch {:}/{:} ================"
                .format(epoch, epochs))
            print(" ---- Start training ------>")
            self.epochTrainResnet(epoch)
            print()
        self.writer.close()

    def getResnetOptimizer(self):
        """Build the ResNet optimizer with separate base/finetune param groups.

        NOTE(review): returns None implicitly for any resnet_optim value
        other than 'SGD' or 'Adam' — callers would fail later; consider
        raising ValueError for unknown values.
        """
        if self.args.resnet_optim == 'SGD':
            return torch.optim.SGD([{
                'params': self.resnet.baseParameters(),
                'lr': self.args.resnet_base_lr,
                'weight_decay': self.args.weight_decay,
                'momentum': self.args.resnet_momentum
            }, {
                'params': self.resnet.finetuneParameters(),
                'lr': self.args.resnet_ft_lr,
                'weight_decay': self.args.weight_decay,
                'momentum': self.args.resnet_momentum
            }],
                                   lr=self.args.resnet_base_lr)
        elif self.args.resnet_optim == 'Adam':
            return torch.optim.Adam([{
                'params': self.resnet.baseParameters(),
                'lr': self.args.resnet_base_lr
            }, {
                'params': self.resnet.finetuneParameters(),
                'lr': self.args.resnet_ft_lr,
                'weight_decay': self.args.weight_decay
            }])
Esempio n. 22
0
class NeoOriginal:
    """Neural program-evolution model: an LSTM seq2seq auto-encoder over
    tokenized expressions plus a surrogate fitness network on the latent
    space. Breeding perturbs latents along the surrogate's gradient and
    decodes them back into offspring expressions.

    NOTE(review): depends on project classes Encoder, Decoder, Surrogate,
    Population and helper create_expression_tree; built on TensorFlow 2.
    """

    def __init__(  # TODO move parameters to config file
            self,
            pset,
            batch_size=64,
            max_size=100,
            vocab_inp_size=32,
            vocab_tar_size=32,
            embedding_dim=64,
            units=128,
            hidden_size=128,
            alpha=0.1,
            epochs=200,
            epoch_decay=1,
            min_epochs=10,
            verbose=True):
        """Build encoder/decoder/surrogate and the population wrapper.

        alpha weights the surrogate loss against the reconstruction loss;
        epochs shrinks by epoch_decay after each __train call, floored at
        min_epochs.
        """
        self.alpha = alpha
        self.batch_size = batch_size
        self.max_size = max_size
        self.epochs = epochs
        self.epoch_decay = epoch_decay
        self.min_epochs = min_epochs
        self.train_steps = 0

        self.verbose = verbose

        self.enc = Encoder(vocab_inp_size, embedding_dim, units, batch_size)
        self.dec = Decoder(vocab_inp_size, vocab_tar_size, embedding_dim,
                           units, batch_size)
        self.surrogate = Surrogate(hidden_size)
        self.population = Population(pset, max_size, batch_size)
        # Teacher-forcing probability threshold used in train_step.
        self.prob = 0.5

        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False, reduction='none')

    def save_models(self):
        """Checkpoint all three networks, keyed by the current train_steps."""
        self.enc.save_weights("model/weights/encoder/enc_{}".format(
            self.train_steps),
                              save_format="tf")
        self.dec.save_weights("model/weights/decoder/dec_{}".format(
            self.train_steps),
                              save_format="tf")
        self.surrogate.save_weights(
            "model/weights/surrogate/surrogate_{}".format(self.train_steps),
            save_format="tf")

    def load_models(self, train_steps):
        """Restore all three networks from the checkpoint at `train_steps`."""
        self.enc.load_weights(
            "model/weights/encoder/enc_{}".format(train_steps))
        self.dec.load_weights(
            "model/weights/decoder/dec_{}".format(train_steps))
        self.surrogate.load_weights(
            "model/weights/surrogate/surrogate_{}".format(train_steps))

    @tf.function
    def train_step(self, inp, targ, targ_surrogate, enc_hidden, enc_cell):
        """One joint optimization step over encoder, decoder and surrogate.

        Decodes token-by-token with probabilistic teacher forcing (target
        token with probability self.prob, else the argmax prediction).
        Gradients flow from autoencoder_loss + alpha * surrogate_loss.

        Returns (batch_loss, batch_ae_loss, batch_surrogate_loss) where the
        AE part is normalized per target token.
        """
        autoencoder_loss = 0
        with tf.GradientTape(persistent=True) as tape:
            enc_output, enc_hidden, enc_cell = self.enc(
                inp, [enc_hidden, enc_cell])

            surrogate_output = self.surrogate(enc_hidden)
            surrogate_loss = self.surrogate_loss_function(
                targ_surrogate, surrogate_output)

            dec_hidden = enc_hidden
            dec_cell = enc_cell
            context = tf.zeros(shape=[len(dec_hidden), 1, dec_hidden.shape[1]])

            # Token id 1 acts as the start-of-sequence token.
            dec_input = tf.expand_dims([1] * len(inp), 1)

            for t in range(1, self.max_size):
                initial_state = [dec_hidden, dec_cell]
                predictions, context, [dec_hidden, dec_cell
                                       ], _ = self.dec(dec_input, context,
                                                       enc_output,
                                                       initial_state)
                autoencoder_loss += self.autoencoder_loss_function(
                    targ[:, t], predictions)

                # Probabilistic teacher forcing
                # (feeding the target as the next input)
                if tf.random.uniform(shape=[], maxval=1,
                                     dtype=tf.float32) > self.prob:
                    dec_input = tf.expand_dims(targ[:, t], 1)
                else:
                    pred_token = tf.argmax(predictions,
                                           axis=1,
                                           output_type=tf.dtypes.int32)
                    dec_input = tf.expand_dims(pred_token, 1)

            # Gradients use the un-normalized AE loss; reported metrics below
            # normalize by sequence length.
            loss = autoencoder_loss + self.alpha * surrogate_loss

        ae_loss_per_token = autoencoder_loss / int(targ.shape[1])
        batch_loss = ae_loss_per_token + self.alpha * surrogate_loss
        batch_ae_loss = (autoencoder_loss / int(targ.shape[1]))
        batch_surrogate_loss = surrogate_loss

        gradients, variables = self.backward(loss, tape)
        self.optimize(gradients, variables)

        return batch_loss, batch_ae_loss, batch_surrogate_loss

    def backward(self, loss, tape):
        """Compute gradients of `loss` w.r.t. all trainable variables of the
        encoder, decoder and surrogate."""
        variables = \
            self.enc.trainable_variables + self.dec.trainable_variables \
            + self.surrogate.trainable_variables
        gradients = tape.gradient(loss, variables)
        return gradients, variables

    def optimize(self, gradients, variables):
        """Apply one Adam update."""
        self.optimizer.apply_gradients(zip(gradients, variables))

    def surrogate_breed(self, output, latent, tape):
        """Gradient of the surrogate's output w.r.t. the latent vector —
        the direction that (locally) improves predicted fitness."""
        gradients = tape.gradient(output, latent)
        return gradients

    def update_latent(self, latent, gradients, eta):
        """Gradient-ascent step of size eta in latent space."""
        latent += eta * gradients
        return latent

    def autoencoder_loss_function(self, real, pred):
        """Sparse categorical cross-entropy, masking padding (token id 0)."""
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = self.loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        return tf.reduce_mean(loss_)

    def surrogate_loss_function(self, real, pred):
        """Mean squared error between real and predicted fitness."""
        loss_ = tf.keras.losses.mean_squared_error(real, pred)
        return tf.reduce_mean(loss_)

    def __train(self):
        """Train for self.epochs epochs over the population's batches, then
        decay self.epochs by epoch_decay (floored at min_epochs)."""

        for epoch in range(self.epochs):
            self.epoch = epoch
            start = time.time()

            total_loss = 0
            total_ae_loss = 0
            total_surrogate_loss = 0

            data_generator = self.population()
            for (batch, (inp, targ,
                         targ_surrogate)) in enumerate(data_generator):
                enc_hidden = self.enc.initialize_hidden_state(
                    batch_sz=len(inp))
                enc_cell = self.enc.initialize_cell_state(batch_sz=len(inp))
                batch_loss, batch_ae_loss, batch_surr_loss = self.train_step(
                    inp, targ, targ_surrogate, enc_hidden, enc_cell)
                total_loss += batch_loss
                total_ae_loss += batch_ae_loss
                total_surrogate_loss += batch_surr_loss

                # Per-batch logging intentionally disabled (`False and ...`).
                if False and self.verbose:
                    print(f'Epoch {epoch + 1} Batch {batch} '
                          f'Loss {batch_loss.numpy():.4f}')

            if self.verbose and ((epoch + 1) % 10 == 0 or epoch == 0):
                epoch_loss = total_loss / self.population.steps_per_epoch
                ae_loss = total_ae_loss / self.population.steps_per_epoch
                surrogate_loss = \
                    total_surrogate_loss / self.population.steps_per_epoch
                epoch_time = time.time() - start
                print(f'Epoch {epoch + 1} Loss {epoch_loss:.6f} AE_loss '
                      f'{ae_loss:.6f} Surrogate_loss '
                      f'{surrogate_loss:.6f} Time: {epoch_time:.3f}')

        # decrease number of epochs, but don't go below self.min_epochs
        self.epochs = max(self.epochs - self.epoch_decay, self.min_epochs)

    def _gen_children(self,
                      candidates,
                      enc_output,
                      enc_hidden,
                      enc_cell,
                      max_eta=1000):
        """Decode children at increasing latent step sizes eta until every
        candidate has changed (or max_eta is reached).

        Candidates whose decoding is unchanged are retried with a larger
        eta; unchanged leftovers at the end are kept as-is.
        NOTE(review): reads the private `_keras_mask` attribute of the
        encoder output — fragile across TF versions.
        """
        children = []
        eta = 0
        enc_mask = enc_output._keras_mask
        last_copy_ind = len(candidates)
        while eta < max_eta:
            eta += 1
            start = time.time()
            new_children = self._gen_decoded(eta, enc_output, enc_hidden,
                                             enc_cell, enc_mask).numpy()
            new_children = self.cut_seq(new_children, end_token=2)
            new_ind, copy_ind = self.find_new(new_children, candidates)
            if len(copy_ind) < last_copy_ind:
                last_copy_ind = len(copy_ind)
                print("Eta {} Not-changed {} Time: {:.3f}".format(
                    eta, len(copy_ind),
                    time.time() - start))
            for i in new_ind:
                children.append(new_children[i])
            if len(copy_ind) < 1:
                break
            # Keep only the still-unchanged candidates for the next eta.
            enc_output = tf.gather(enc_output, copy_ind)
            enc_mask = tf.gather(enc_mask, copy_ind)
            enc_hidden = tf.gather(enc_hidden, copy_ind)
            enc_cell = tf.gather(enc_cell, copy_ind)
            candidates = tf.gather(candidates, copy_ind)
        if eta == max_eta:
            print("Maximal value of eta reached - breed stopped")
        for i in copy_ind:
            children.append(new_children[i])
        return children

    def _gen_decoded(self, eta, enc_output, enc_hidden, enc_cell, enc_mask):
        """Shift latents by eta along the surrogate gradient and greedily
        decode a full child sequence (start token 1, stop token 2)."""
        with tf.GradientTape(persistent=True,
                             watch_accessed_variables=False) as tape:
            tape.watch(enc_hidden)
            surrogate_output = self.surrogate(enc_hidden)
        gradients = self.surrogate_breed(surrogate_output, enc_hidden, tape)
        dec_hidden = self.update_latent(enc_hidden, gradients, eta=eta)
        dec_cell = enc_cell
        context = tf.zeros(shape=[len(dec_hidden), 1, dec_hidden.shape[1]])

        dec_input = tf.expand_dims([1] * len(enc_hidden), 1)

        child = dec_input
        for _ in range(1, self.max_size - 1):
            initial_state = [dec_hidden, dec_cell]
            predictions, context, [dec_hidden, dec_cell
                                   ], _ = self.dec(dec_input, context,
                                                   enc_output, initial_state,
                                                   enc_mask)
            dec_input = tf.expand_dims(
                tf.argmax(predictions, axis=1, output_type=tf.dtypes.int32), 1)
            child = tf.concat([child, dec_input], axis=1)
        # Force-append the stop token so every child terminates.
        stop_tokens = tf.expand_dims([2] * len(enc_hidden), 1)
        child = tf.concat([child, stop_tokens], axis=1)
        return child

    def cut_seq(self, seq, end_token=2):
        """Truncate each decoded sequence at its first end token, repair it
        into a valid expression tree, and re-pad to a fixed width.

        NOTE(review): `tree_max` is collected but never used or returned.
        """
        ind = (seq == end_token).argmax(1)
        res = []
        tree_max = []
        for d, i in zip(seq, ind):
            # Strip start/stop tokens before tree repair.
            repaired_tree = create_expression_tree(d[:i + 1][1:-1])
            repaired_seq = [i.data for i in repaired_tree.preorder()
                            ][-(self.max_size - 2):]
            tree_max.append(len(repaired_seq) == self.max_size - 2)
            repaired_seq = [1] + repaired_seq + [2]
            res.append(np.pad(repaired_seq, (0, self.max_size - i - 1)))
        return res

    def find_new(self, seq, candidates):
        """Split indices into (changed, unchanged) by element-wise comparing
        each decoded sequence with its original candidate.

        NOTE(review): the `n`/`cp` flags are dead code — set but never read.
        """
        new_ind = []
        copy_ind = []
        n = False
        cp = False
        for i, (s, c) in enumerate(zip(seq, candidates)):
            if not np.array_equal(s, c):
                if not n:
                    n = True
                new_ind.append(i)
            else:
                if not cp:
                    cp = True
                copy_ind.append(i)
        return new_ind, copy_ind

    def _gen_latent(self, candidates):
        """Encode candidates into (enc_output, enc_hidden, enc_cell) with
        freshly initialized LSTM states."""
        enc_hidden = self.enc.initialize_hidden_state(batch_sz=len(candidates))
        enc_cell = self.enc.initialize_cell_state(batch_sz=len(candidates))
        enc_output, enc_hidden, enc_cell = self.enc(candidates,
                                                    [enc_hidden, enc_cell])
        return enc_output, enc_hidden, enc_cell

    def update(self):
        """Train on the current population, checkpoint, and bump train_steps."""
        print("Training")
        self.enc.train()
        self.dec.train()
        self.__train()
        self.save_models()
        self.train_steps += 1

    def breed(self):
        """Generate offspring for the whole population in one batch and wrap
        them as DEAP individuals."""
        print("Breed")
        self.dec.eval()
        data_generator = self.population(
            batch_size=len(self.population.samples))

        tokenized_pop = []
        for (batch, (inp, _, _)) in enumerate(data_generator):
            enc_output, enc_hidden, enc_cell = self._gen_latent(inp)

            tokenized_pop += (self._gen_children(inp, enc_output, enc_hidden,
                                                 enc_cell))

        pop_expressions = [
            self.population.tokenizer.reproduce_expression(tp)
            for tp in tokenized_pop
        ]
        offspring = [deap.creator.Individual(pe) for pe in pop_expressions]
        return offspring
Esempio n. 23
0
class Trainer(object):
    """Trains a patch-attention Transformer that swaps HDR features into an
    LDR patch, using a frozen pretrained decoder for reconstruction.

    Only ``self.attention`` receives gradient updates; the decoder weights
    are loaded from ``./hdr_decoder.pth`` and kept in eval mode.
    """

    def __init__(self, config):
        # Deterministic content transform (horizontal flip left disabled).
        content_trans = transforms.Compose([
            transforms.Resize(config.image_size),
            #transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        # Paired LDR/HDR loader; batch_size is fixed at 1 because each
        # sample is later split into four patches by patch_crop.
        self.train_loader = torch.utils.data.DataLoader(
            HDR_LDR(config.ldr_dir, config.hdr_dir, content_trans),
            batch_size=1,
            shuffle=True,
            num_workers=config.workers,
            pin_memory=True,
            drop_last=True)

        os.makedirs(f'{config.save_dir}/{config.version}', exist_ok=True)

        # Output locations for losses, checkpoints, sample images, PSNR
        # logs and a snapshot of the training code.
        self.loss_dir = f'{config.save_dir}/{config.version}/loss'
        self.model_state_dir = f'{config.save_dir}/{config.version}/model_state'
        self.image_dir = f'{config.save_dir}/{config.version}/image'
        self.psnr_dir = f'{config.save_dir}/{config.version}/psnr'
        self.code_dir = f'{config.save_dir}/{config.version}/code'

        os.makedirs(self.loss_dir, exist_ok=True)
        os.makedirs(self.model_state_dir, exist_ok=True)
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.psnr_dir, exist_ok=True)
        os.makedirs(self.code_dir, exist_ok=True)

        # Snapshot the scripts that produced this run for reproducibility.
        script_name = 'trainer_' + config.script_name + '.py'
        shutil.copyfile(os.path.join('scripts', script_name),
                        os.path.join(self.code_dir, script_name))
        shutil.copyfile('components/transformer.py',
                        os.path.join(self.code_dir, 'transformer.py'))
        shutil.copyfile('model/Fusion.py',
                        os.path.join(self.code_dir, 'Fusion.py'))

        self.encoder = Encoder().cuda()
        self.attention = Transformer(config.topk, True, False).cuda()
        self.decoder = Decoder().cuda()

        # Pretrained HDR decoder; never updated in train().
        self.decoder.load_state_dict(torch.load("./hdr_decoder.pth"))

        self.config = config

    def train(self):
        """Optimize the attention module with L1 image + feature losses.

        Per iteration: crop LDR/HDR into 4 patches, pick one LDR patch at
        random, attend over the other three HDR patches, decode, and match
        both the decoded image and the attended features against the HDR
        target patch. Logs, saves a sample image and checkpoints the
        attention weights every ``config.log_interval`` iterations.
        """
        # Only the attention module is trained.
        optimizer = torch.optim.Adam(self.attention.parameters(),
                                     lr=self.config.learning_rate)
        criterion = torch.nn.L1Loss()  #torch.nn.MSELoss()

        self.decoder.eval()
        # NOTE(review): attention is put in eval mode here even though it is
        # the module being optimized; it is only toggled to train mode after
        # the first logging block below (see the hard/eval swap). Confirm
        # this is intentional.
        self.attention.eval()
        self.encoder.eval()

        for e in range(1, self.config.epoch_size + 1):
            print(f'Start {e} epoch')
            psnr_list = []
            # for i, (content, target)  in tqdm(enumerate(self.train_loader, 1)):
            for i, (ldr, hdr, ref) in enumerate(self.train_loader):
                ### data prepare
                ldr = ldr.cuda()
                hdr = hdr.cuda()

                # Split each image into 4 patches; pick one LDR patch as the
                # query, the other three HDR patches as the reference bank.
                ldrs, hdrs = patch_crop(ldr, True), patch_crop(hdr, True)
                flag = random.randint(0, 3)

                ldr_in = ldrs[flag]
                hdr_in = [hdrs[i] for i in range(4) if i != flag]
                target = hdrs[flag]

                ### encode
                fea_ldr = self.encoder(ldr_in)
                fea_hdr = [self.encoder(hdr_i) for hdr_i in hdr_in]

                ### attention swap
                out_feature, attention_map = self.attention(fea_ldr, fea_hdr)
                target_feature = self.encoder(target)

                ### decode
                out_image = self.decoder(out_feature)
                out_image = align_shape(out_image, target)

                # Image loss is down-weighted; feature loss dominates.
                loss = criterion(out_image, target) * 0.1
                loss_fea = criterion(out_feature, target_feature) * 10
                loss = loss + loss_fea

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if i % self.config.log_interval == 0:
                    now = datetime.datetime.now()
                    otherStyleTime = now.strftime("%Y-%m-%d %H:%M:%S")
                    print(otherStyleTime)
                    print('epoch: ', e, ' iter: ', i)
                    print('loss:', loss.cpu().item())
                    print('loss_fea:', loss_fea.cpu().item())

                    # Re-run the forward pass with hard attention for
                    # visualization, gradients disabled.
                    self.attention.hard = True
                    self.attention.eval()
                    with torch.no_grad():
                        out_feature, _ = self.attention(fea_ldr, fea_hdr)
                        out_image = self.decoder(out_feature)
                        out_image = align_shape(out_image, target)

                    self.attention.hard = False
                    self.attention.train()

                    # Substitute the reconstructed patch so the saved image
                    # shows the model's output in context.
                    ldrs[flag] = out_image
                    print('attention scartters: ',
                          torch.std(attention_map.argmax(-1).float(), 1).cpu())
                    print(attention_map.shape)

                    # Side-by-side: reassembled patches next to the full HDR.
                    tosave = torch.cat([patch_crop(ldrs, False), hdr], -1)
                    save_image(
                        denorm(tosave[0]).cpu(),
                        self.image_dir + '/epoch_{}-iter_{}.png'.format(e, i))
                    print("image saved to " + self.image_dir +
                          '/epoch_{}-iter_{}.png'.format(e, i))

                    torch.save(
                        self.attention.state_dict(),
                        f'{self.model_state_dir}/epoch_{e}-iter_{i}.pth')
Esempio n. 24
0
class Trainer(object):
    """Trains an encoder/decoder reconstruction pair with an L1 pixel loss
    plus VGG16-based content and Gram-matrix perceptual losses.
    """

    def __init__(self, config):
        # Augmented ImageFolder loader over the training directory.
        self.train_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(config.train_data_dir, transforms.Compose([
                transforms.RandomSizedCrop(config.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=config.batch_size, shuffle=True,
            num_workers=config.workers, pin_memory=True)

        os.makedirs(f'{config.save_dir}/{config.version}', exist_ok=True)

        # Output locations for losses, checkpoints, sample images and PSNR logs.
        self.loss_dir = f'{config.save_dir}/{config.version}/loss'
        self.model_state_dir = f'{config.save_dir}/{config.version}/model_state'
        self.image_dir = f'{config.save_dir}/{config.version}/image'
        self.psnr_dir = f'{config.save_dir}/{config.version}/psnr'

        for d in (self.loss_dir, self.model_state_dir,
                  self.image_dir, self.psnr_dir):
            os.makedirs(d, exist_ok=True)

        self.encoder = Encoder(True).cuda()
        self.decoder = Decoder(False, True).cuda()
        # VGG16 feature extractor used only to compute the losses.
        self.D = VGG16_mid().cuda()

        self.config = config

    def train(self):
        """Jointly optimize encoder and decoder.

        Logs losses/PSNR and saves a sample image and the decoder state
        every ``config.log_interval`` iterations (files are keyed by epoch,
        so later iterations within an epoch overwrite earlier ones).
        """
        optimizer = torch.optim.Adam(itertools.chain(self.encoder.parameters(),
                                                     self.decoder.parameters()),
                                                     lr=self.config.learning_rate)
        criterion = torch.nn.L1Loss()  #torch.nn.MSELoss()

        loss_list = []
        psnr_list = []
        self.encoder.train()
        self.decoder.train()
        for e in range(1, self.config.epoch_size + 1):
            print(f'Start {e} epoch')
            psnr_list = []
            for i, (content, target) in enumerate(self.train_loader):
                content = content.cuda()
                content_feature = self.encoder(content)
                out_content = self.decoder(content_feature)

                # Pixel-wise reconstruction loss.
                loss = criterion(content, out_content)

                # VGG16 intermediate features of input and reconstruction.
                c1, c2, c3, _ = self.D(content)
                h1, h2, h3, _ = self.D(out_content)

                b, c, w, h = c3.shape
                # Content loss on the deepest used feature map, normalized
                # by its channel/spatial size.
                loss_content = torch.norm(c3 - h3, p=2) / c / w / h
                # Gram-matrix perceptual loss over the three feature levels.
                # (Replaces the original eval('c'+str(t+1)) string hack with
                # an explicit pairing of the feature tensors.)
                loss_perceptual = 0
                for cf, hf in zip((c1, c2, c3), (h1, h2, h3)):
                    loss_perceptual += criterion(gram_matrix(cf),
                                                 gram_matrix(hf))
                loss = loss + loss_content + loss_perceptual * 10000

                loss_list.append(loss.item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    if i % self.config.log_interval == 0:
                        print(loss.item())
                        print(loss_content.item())
                        print(loss_perceptual.item() * 10000)
                        psnr = PSNR2(denorm(content).cpu().numpy(),
                                     denorm(out_content).cpu().numpy())
                        psnr_list.append(psnr)
                        print('psnr:', psnr)

                        # Side-by-side strip: originals above reconstructions.
                        ori = torch.cat(list(denorm(content)), 2)
                        out = torch.cat(list(denorm(out_content)), 2)
                        save_image(torch.cat([ori, out], 1),
                                   self.image_dir + '/epoch_{}.png'.format(e))
                        print("image saved to " + self.image_dir + '/epoch_{}.png'.format(e))

                        torch.save(self.decoder.state_dict(),
                                   f'{self.model_state_dir}/{e}_epoch.pth')
                        filename = self.psnr_dir + '/e' + str(e) + '.pkl'
                        joblib.dump(psnr_list, filename)

        self.plot_loss_curve(loss_list)

    def plot_loss_curve(self, loss_list):
        """Plot the per-iteration training loss and persist it as PNG + text."""
        plt.plot(range(len(loss_list)), loss_list)
        plt.xlabel('iteration')
        plt.ylabel('loss')
        plt.title('train loss')
        plt.savefig(f'{self.loss_dir}/train_loss.png')
        with open(f'{self.loss_dir}/loss_log.txt', 'w') as f:
            for l in loss_list:
                f.write(f'{l}\n')
        print(f'Loss saved in {self.loss_dir}')
Esempio n. 25
0
    sample_cell = encoder.initialize_cell_state()
    sample_output, sample_hidden, cell_hidden = encoder(
        example_input_batch, [sample_hidden, sample_cell])
    print(
        'Encoder output shape: (batch size, sequence length, units) {}'.format(
            sample_output.shape))
    print('Encoder Hidden state shape: (batch size, units) {}'.format(
        sample_hidden.shape))
    print('Encoder Cell state shape: (batch size, units) {}'.format(
        sample_hidden.shape))

    # Attention
    attention_layer = Attention()
    attention_result, attention_weights = attention_layer(
        sample_hidden, sample_output)

    print("Attention result shape: (batch size, units) {}".format(
        attention_result.shape))
    print(
        "Attention weights shape: (batch_size, sequence_length, 1) {}".format(
            attention_weights.shape))

    # Decoder
    decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)

    sample_decoder_output, _, _, _ = decoder(
        tf.random.uniform((BATCH_SIZE, 1)), sample_hidden, sample_output)

    print('Decoder output shape: (batch_size, vocab size) {}'.format(
        sample_decoder_output.shape))
Esempio n. 26
0
def main():
    """Train a convolutional autoencoder over chunked data loads.

    Iterates over data files in steps of 30000 samples, training encoder
    and decoder jointly with an L2 reconstruction loss, running a
    validation pass (with sample image dumps) after every epoch, and
    checkpointing both models after each data chunk.
    """
    # Tensorboard summary writer (kept for interface parity; this loop
    # currently only prints to stdout).
    writer = SummaryWriter(args.experiment_id)
    #[TODO] may need to resize input image
    cudnn.enabled = True

    # Encoder and its optimizer.
    model_encoder = Encoder()
    model_encoder.train()
    model_encoder.cuda(args.gpu)
    optimizer_encoder = optim.Adam(model_encoder.parameters(),
                                   lr=args.learning_rate, betas=(0.95, 0.99))
    optimizer_encoder.zero_grad()

    # Decoder and its optimizer.
    model_decoder = Decoder()
    model_decoder.train()
    model_decoder.cuda(args.gpu)
    optimizer_decoder = optim.Adam(model_decoder.parameters(),
                                   lr=args.learning_rate, betas=(0.95, 0.99))
    optimizer_decoder.zero_grad()

    l2loss = nn.MSELoss()

    # Load data in chunks of 30000 samples.
    for i in range(1, 360002, 30000):
        train_data, valid_data = get_data(i)
        for e in range(1, args.epoch + 1):
            train_loss_value = 0
            validation_loss_value = 0

            model_encoder.train()
            model_decoder.train()
            for j in range(0, int(args.train_size / 4), args.batch_size):
                # BUG FIX: the original zeroed the decoder optimizer twice
                # and never zeroed the encoder's, so encoder gradients
                # accumulated across iterations.
                optimizer_encoder.zero_grad()
                optimizer_decoder.zero_grad()
                image = Variable(torch.tensor(
                    train_data[j: j + args.batch_size, :, :])).cuda(args.gpu)
                latent = model_encoder(image)
                img_recon = model_decoder(latent)
                # Match the reconstruction's spatial size to the input.
                img_recon = F.interpolate(img_recon, size=image.shape[2:],
                                          mode='bilinear', align_corners=True)
                loss = l2loss(img_recon, image)
                train_loss_value += loss.data.cpu().numpy() / args.batch_size
                loss.backward()
                optimizer_decoder.step()
                optimizer_encoder.step()
            print("data load: {:8d}".format(i))
            print("epoch: {:8d}".format(e))
            print("train_loss: {:08.6f}".format(
                train_loss_value / (args.train_size / args.batch_size)))

            # Validation pass: eval mode hoisted out of the loop, and no
            # gradients tracked for the whole pass.
            model_encoder.eval()
            model_decoder.eval()
            with torch.no_grad():
                for j in range(0, int(args.validation_size / 4), args.batch_size):
                    image = Variable(torch.tensor(
                        valid_data[j: j + args.batch_size, :, :])).cuda(args.gpu)
                    latent = model_encoder(image)
                    img_recon = model_decoder(latent)
                    # Dump one fake/real pair per batch for visual inspection.
                    save_image(img_recon[0][0],
                               args.image_dir + '/fake' + str(i) + "_" + str(j) + ".png")
                    save_image(image[0][0],
                               args.image_dir + '/real' + str(i) + "_" + str(j) + ".png")
                    img_recon = F.interpolate(img_recon, size=image.shape[2:],
                                              mode='bilinear', align_corners=True)
                    # BUG FIX: the original overwrote `image` with a batch of
                    # *training* data here, so the "validation" loss was
                    # computed against the wrong images.
                    loss = l2loss(img_recon, image)
                    validation_loss_value += loss.data.cpu().numpy() / args.batch_size
            model_encoder.train()
            model_decoder.train()
            # BUG FIX: this metric was mislabeled "train_loss".
            print("validation_loss: {:08.6f}".format(
                validation_loss_value / (args.validation_size / args.batch_size)))
        torch.save({'encoder_state_dict': model_encoder.state_dict()},
                   osp.join(args.checkpoint_dir, 'AE_encoder.pth'))
        torch.save({'decoder_state_dict': model_decoder.state_dict()},
                   osp.join(args.checkpoint_dir, 'AE_decoder.pth'))
Esempio n. 27
0
class Model(nn.Module):
    """CVAE-style dialogue model.

    Encodes a post (and, during training, a response), samples a latent
    variable via a prior/recognition network pair with the
    reparameterization trick, and decodes a response token-by-token
    (teacher-forced in training, greedy at inference).
    """

    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # Token embedding layer (shared by post/response encoders and decoder input).
        self.embedding = Embedding(config.num_vocab,  # vocabulary size
                                   config.embedding_size,  # embedding dimension
                                   config.pad_id,  # pad_id
                                   config.dropout)

        # Post (context) encoder.
        self.post_encoder = Encoder(config.post_encoder_cell_type,  # RNN cell type
                                    config.embedding_size,  # input size
                                    config.post_encoder_output_size,  # output size
                                    config.post_encoder_num_layers,  # number of RNN layers
                                    config.post_encoder_bidirectional,  # bidirectional flag
                                    config.dropout)  # dropout probability

        # Response encoder (used only in the training branch of forward()).
        self.response_encoder = Encoder(config.response_encoder_cell_type,
                                        config.embedding_size,  # input size
                                        config.response_encoder_output_size,  # output size
                                        config.response_encoder_num_layers,  # number of RNN layers
                                        config.response_encoder_bidirectional,  # bidirectional flag
                                        config.dropout)  # dropout probability

        # Prior network p(z|x).
        self.prior_net = PriorNet(config.post_encoder_output_size,  # post input size
                                  config.latent_size,  # latent size
                                  config.dims_prior)  # hidden layer sizes

        # Recognition network p(z|x,y).
        self.recognize_net = RecognizeNet(config.post_encoder_output_size,  # post input size
                                          config.response_encoder_output_size,  # response input size
                                          config.latent_size,  # latent size
                                          config.dims_recognize)  # hidden layer sizes

        # Builds the decoder's initial state from the concatenation [z, x].
        self.prepare_state = PrepareState(config.post_encoder_output_size+config.latent_size,
                                          config.decoder_cell_type,
                                          config.decoder_output_size,
                                          config.decoder_num_layers)

        # Decoder RNN.
        self.decoder = Decoder(config.decoder_cell_type,  # RNN cell type
                               config.embedding_size,  # input size
                               config.decoder_output_size,  # output size
                               config.decoder_num_layers,  # number of RNN layers
                               config.dropout)  # dropout probability

        # Output projection: decoder state -> probability over the vocabulary.
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1)
        )

    def forward(self, inputs, inference=False, max_len=60, gpu=True):
        """Run the training (teacher-forced) or inference (greedy) branch.

        Args:
            inputs: dict of tensors. Training needs 'posts', 'len_posts',
                'responses', 'len_responses', 'sampled_latents'; inference
                needs 'posts', 'len_posts', 'sampled_latents'.
            inference: select the greedy-decoding branch when True.
            max_len: maximum number of decoding steps at inference.
            gpu: move decoding bookkeeping tensors to CUDA when True.

        Returns:
            (output_vocab, _mu, _logvar, mu, logvar); the recognition
            parameters (mu, logvar) are None at inference.
        """
        if not inference:  # training
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            id_responses = inputs['responses']  # [batch, seq]
            len_responses = inputs['len_responses']  # [batch, seq]
            sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
            len_decoder = id_responses.size(1) - 1

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            embed_responses = self.embedding(id_responses)  # [batch, seq, embed_size]
            # state: [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
            _, state_responses = self.response_encoder(embed_responses.transpose(0, 1), len_responses)
            # For LSTM cells the state is (h, c); keep only h.
            if isinstance(state_posts, tuple):
                state_posts = state_posts[0]
            if isinstance(state_responses, tuple):
                state_responses = state_responses[0]
            x = state_posts[-1, :, :]  # last layer, [batch, dim]
            y = state_responses[-1, :, :]  # last layer, [batch, dim]

            # p(z|x)
            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            # p(z|x,y)
            mu, logvar = self.recognize_net(x, y)  # [batch, latent]
            # Reparameterization trick: z = mu + sigma * eps.
            z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]

            # Decoder input is the response with the end token dropped.
            decoder_inputs = embed_responses[:, :-1, :].transpose(0, 1)  # [seq-1, batch, embed_size]
            # Per-step inputs: seq-1 tensors of [1, batch, embed_size].
            decoder_inputs = decoder_inputs.split([1] * len_decoder, 0)
            first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]

            outputs = []
            for idx in range(len_decoder):
                if idx == 0:
                    state = first_state  # initial decoder state
                decoder_input = decoder_inputs[idx]  # current-step input [1, batch, embed_size]
                # output: [1, batch, dim_out]
                # state: [num_layer, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                # Sanity check: the output equals the top-layer hidden state.
                # NOTE(review): stripped when Python runs with -O.
                assert output.squeeze().equal(state[0][-1])
                outputs.append(output)

            outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq-1, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

            return output_vocab, _mu, _logvar, mu, logvar
        else:  # inference
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size(0)

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            # state = [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
            if isinstance(state_posts, tuple):  # for LSTM keep only h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # last layer, [batch, dim]

            # p(z|x)
            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            # Reparameterization trick (prior parameters only — no response).
            z = _mu + (0.5 * _logvar).exp() * sampled_latents  # [batch, latent]

            first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]
            done = torch.tensor([0] * batch_size).bool()
            first_input_id = (torch.ones((1, batch_size)) * self.config.start_id).long()
            if gpu:
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            outputs = []
            for idx in range(max_len):
                if idx == 0:  # first time step
                    state = first_state  # initial decoder state
                    decoder_input = self.embedding(first_input_id)  # initial input [1, batch, embed_size]
                else:
                    decoder_input = self.embedding(next_input_id)  # [1, batch, embed_size]
                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                # Greedy decoding: argmax token becomes the next input [1, batch].
                next_input_id = torch.argmax(vocab_prob, 2)

                _done = next_input_id.squeeze(0) == self.config.end_id  # finished this step [batch]
                done = done | _done  # cumulative finished mask
                if done.sum() == batch_size:  # stop early once every batch item is done
                    break

            outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, _mu, _logvar, None, None

    def print_parameters(self):
        r"""Count and print the total number of parameters.

        NOTE(review): `num` starts at 1 and is only multiplied by the
        dimensions when `requires_grad` is True, so every frozen parameter
        still contributes 1 to the total.
        """
        total_num = 0  # running parameter count
        for param in self.parameters():
            num = 1
            if param.requires_grad:
                size = param.size()
                for dim in size:
                    num *= dim
            total_num += num
        print(f"参数总数: {total_num}")

    def save_model(self, epoch, global_step, path):
        r"""Save all sub-module state dicts plus training progress to ``path``."""
        torch.save({'embedding': self.embedding.state_dict(),
                    'post_encoder': self.post_encoder.state_dict(),
                    'response_encoder': self.response_encoder.state_dict(),
                    'prior_net': self.prior_net.state_dict(),
                    'recognize_net': self.recognize_net.state_dict(),
                    'prepare_state': self.prepare_state.state_dict(),
                    'decoder': self.decoder.state_dict(),
                    'projector': self.projector.state_dict(),
                    'epoch': epoch,
                    'global_step': global_step}, path)

    def load_model(self, path):
        r"""Load a checkpoint saved by :meth:`save_model`.

        Returns:
            (epoch, global_step) recorded in the checkpoint.
        """
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        self.embedding.load_state_dict(checkpoint['embedding'])
        self.post_encoder.load_state_dict(checkpoint['post_encoder'])
        self.response_encoder.load_state_dict(checkpoint['response_encoder'])
        self.prior_net.load_state_dict(checkpoint['prior_net'])
        self.recognize_net.load_state_dict(checkpoint['recognize_net'])
        self.prepare_state.load_state_dict(checkpoint['prepare_state'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.projector.load_state_dict(checkpoint['projector'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        return epoch, global_step