예제 #1
0
    def train_Siamese_layer(self, layer_id):
        """Train the Siamese view-consistency network at one resolution level.

        Each epoch does the following, in order:
        - Runs one optimization pass of ``self.embednet`` over
          ``self.S_dataloader``. Frontal/lateral pairs are resized to
          ``(2 ** layer_id) * self.base_size``; the loss is
          ``self.S_criterion`` and the optimizer is ``self.S_optimizer``.
        - Steps ``self.S_lr_scheduler``.
        - Measures thresholded (> 0.5) accuracy and logs it to ``self.writer``.

        Args:
            layer_id: Resolution level. Indexes ``self.DISP_FREQs`` and
                ``self.S_max_epoch`` and determines the resize target.
        """
        # NOTE(review): DISP_FREQ is currently unused; the display-frequency
        # gate was disabled upstream, so images are logged on every step.
        DISP_FREQ = self.DISP_FREQs[layer_id]
        size = (2 ** layer_id) * self.base_size
        max_epoch = self.S_max_epoch[layer_id]
        n_batches = len(self.S_dataloader)

        for epoch in range(max_epoch):
            self.embednet.train()
            print('VCN Epoch [{}/{}]'.format(epoch, max_epoch))
            for idx, batch in enumerate(tqdm(self.S_dataloader)):
                image_f = batch['image_F'].to(self.device)
                image_l = batch['image_L'].to(self.device)
                label = batch['label'].to(self.device)

                r_image_f = F.interpolate(image_f, size=size)
                r_image_l = F.interpolate(image_l, size=size)

                self.S_optimizer.zero_grad()
                pred = self.embednet(r_image_f, r_image_l)
                loss = self.S_criterion(pred, label)
                loss.backward()
                self.S_optimizer.step()

                step = epoch * n_batches + idx
                self.writer.add_scalar('Train_Siamese {}_loss'.format(layer_id),
                                       loss.item(),
                                       step)
                self.writer.add_images("Train_front_{}_Original".format(layer_id),
                                       deNorm(r_image_f),
                                       step)
                self.writer.add_images("Train_lateral_{}_Original".format(layer_id),
                                       deNorm(r_image_l),
                                       step)

            self.S_lr_scheduler.step(epoch)

            # NOTE(review): accuracy is measured on the same S_dataloader that
            # was just used for training, i.e. it is training-set accuracy.
            self.embednet.eval()
            total = 0
            correct = 0
            # Evaluation only: no_grad avoids building an autograd graph for
            # every batch.
            with torch.no_grad():
                for batch in tqdm(self.S_dataloader):
                    image_f = batch['image_F'].to(self.device)
                    image_l = batch['image_L'].to(self.device)
                    label = batch['label'].to(self.device)
                    r_image_f = F.interpolate(image_f, size=size)
                    r_image_l = F.interpolate(image_l, size=size)

                    pred = self.embednet(r_image_f, r_image_l)
                    # Binarize at 0.5. Flatten both sides so that a (B, 1)
                    # vs (B,) shape difference cannot broadcast to (B, B)
                    # and inflate the count.
                    pred_bin = (pred > 0.5).to(pred.dtype).reshape(-1)
                    target = label.reshape(-1)

                    total += target.numel()
                    correct += torch.sum(pred_bin == target).item()

            # An empty dataloader would otherwise raise ZeroDivisionError.
            acc = correct / total if total else 0.0

            print("Accuracy {}".format(acc))
            self.writer.add_scalar('Acc_Siamese_Layer {}'.format(layer_id),
                                   acc,
                                   epoch)
예제 #2
0
    def save_origin(self):
        """Write the de-normalized ground-truth image pairs of the test set.

        For every sample in ``self.test_dataloader``, the frontal and lateral
        images are saved as ``{save_img_dir}/{subject_id}_f.png`` and
        ``{save_img_dir}/{subject_id}_l.png``.
        """
        out_dir = self.save_img_dir
        for batch in tqdm(self.test_dataloader):
            front = deNorm(batch['image_F'].to(self.device)).data.cpu()
            lateral = deNorm(batch['image_L'].to(self.device)).data.cpu()
            ids = batch['subject_id'].data.cpu().numpy()
            for k in range(front.shape[0]):
                sid = ids[k]
                save_image(front[k], '{}/{}_f.png'.format(out_dir, sid))
                save_image(lateral[k], '{}/{}_l.png'.format(out_dir, sid))
예제 #3
0
    def test(self):
        """Generate frontal and lateral images for every test report and save them.

        Loads weights via ``self.load_model()`` and puts the encoder and both
        decoders in eval mode. It then decodes each batch of
        ``self.test_dataloader`` and writes
        ``{save_img_dir}/{subject_id}_f.png`` and ``{subject_id}_l.png``
        per sample.
        """
        import torch

        self.load_model()
        self.encoder.eval()
        self.decoder_F.eval()
        self.decoder_L.eval()
        print("Start generating")
        # Pure inference: disable autograd so no graph or activations are kept.
        with torch.no_grad():
            for batch in tqdm(self.test_dataloader):
                finding = batch['finding'].to(self.device)
                impression = batch['impression'].to(self.device)
                # The encoder's second output is not used here.
                txt_embed, _ = self.encoder(finding, impression)
                pre_image_f = deNorm(self.decoder_F(txt_embed)).data.cpu()
                pre_image_l = deNorm(self.decoder_L(txt_embed)).data.cpu()
                subject_id = batch['subject_id'].data.cpu().numpy()
                for i in range(pre_image_f.shape[0]):
                    save_image(pre_image_f[i], '{}/{}_f.png'.format(self.save_img_dir, subject_id[i]))
                    save_image(pre_image_l[i], '{}/{}_l.png'.format(self.save_img_dir, subject_id[i]))
예제 #4
0
    def train_layer(self, max_epoch=20, disp_freq=10):
        """Train the encoder and both decoders with the non-adversarial loss.

        Each batch from ``self.train_dataloader`` is passed to
        ``self.Loss_on_layer`` twice: once for the frontal view
        (``decoder_F``) and once for the lateral view (``decoder_L``).
        Losses and original/predicted images are logged to ``self.writer``
        every ``disp_freq`` batches. ``self.G_lr_scheduler`` is stepped after
        each epoch.

        Args:
            max_epoch: Number of epochs to run (defaults to the previously
                hard-coded 20).
            disp_freq: Logging interval in batches (defaults to the
                previously hard-coded 10).
        """
        n_batches = len(self.train_dataloader)
        for epoch in range(max_epoch):
            print('Generator Epoch {}'.format(epoch))
            self.encoder.train()
            self.decoder_F.train()
            self.decoder_L.train()
            for idx, batch in enumerate(tqdm(self.train_dataloader)):
                finding = batch['finding'].to(self.device)
                impression = batch['impression'].to(self.device)
                image_f = batch['image_F'].to(self.device)
                image_l = batch['image_L'].to(self.device)

                # NOTE(review): no backward/optimizer step happens in this
                # method, so presumably Loss_on_layer performs it — confirm.
                loss_f, pre_image_f, r_image_f = self.Loss_on_layer(
                    image_f, finding, impression, self.decoder_F)
                loss_l, pre_image_l, r_image_l = self.Loss_on_layer(
                    image_l, finding, impression, self.decoder_L)

                if ((idx + 1) % disp_freq == 0) and idx != 0:
                    step = epoch * n_batches + idx
                    self.writer.add_scalar('Train_front loss', loss_f.item(), step)
                    self.writer.add_scalar('Train_lateral loss', loss_l.item(), step)

                    # Write image samples to TensorBoard.
                    self.writer.add_images("Train_front_Original", deNorm(r_image_f), step)
                    self.writer.add_images("Train_front_Predicted", deNorm(pre_image_f), step)
                    self.writer.add_images("Train_lateral_Original", deNorm(r_image_l), step)
                    self.writer.add_images("Train_lateral_Predicted", deNorm(pre_image_l), step)

            self.G_lr_scheduler.step(epoch)
예제 #5
0
    def train_GAN_layer(self):
        """Adversarially train the text encoder, both decoders and both discriminators.

        All five networks are put in train mode once, up front. Each epoch
        iterates ``self.train_dataloader`` and calls
        ``self.Loss_on_layer_GAN`` for two views:
        - frontal: ``decoder_F`` with ``D_F``
        - lateral: ``decoder_L`` with ``D_L``

        Every ``DISP_FREQ`` batches, the four losses and the original/predicted
        images are logged. After each epoch both LR schedulers are stepped.
        Every 10th epoch (excluding epoch 0), all five state dicts are saved.
        """
        DISP_FREQ = 10
        for net in (self.encoder, self.decoder_F, self.decoder_L, self.D_F, self.D_L):
            net.train()

        def save_state(module, directory, filename):
            # Persist one network's weights to directory/filename.
            torch.save(module.state_dict(), os.path.join(directory, filename))

        for epoch in range(self.max_epoch):
            print('GAN Epoch {}'.format(epoch))

            for idx, batch in enumerate(tqdm(self.train_dataloader)):
                finding = batch['finding'].to(self.device)
                impression = batch['impression'].to(self.device)
                image_f = batch['image_F'].to(self.device)
                image_l = batch['image_L'].to(self.device)

                D_loss_f, G_loss_f, pre_image_f, image_f = self.Loss_on_layer_GAN(
                    image_f, finding, impression, self.decoder_F, self.D_F)
                D_loss_l, G_loss_l, pre_image_l, image_l = self.Loss_on_layer_GAN(
                    image_l, finding, impression, self.decoder_L, self.D_L)

                # Only log on display steps.
                if (idx + 1) % DISP_FREQ != 0 or idx == 0:
                    continue

                step = epoch * len(self.train_dataloader) + idx
                loss_tags = (
                    ('GAN_G_train_Layer_front_loss', G_loss_f),
                    ('GAN_D_train_Layer_front_loss', D_loss_f),
                    ('GAN_G_train_Layer_lateral_loss', G_loss_l),
                    ('GAN_D_train_Layer_lateral_loss', D_loss_l),
                )
                for tag, value in loss_tags:
                    self.writer.add_scalar(tag, value.item(), step)

                image_tags = (
                    ("GAN_Train_Original_front", image_f),
                    ("GAN_Train_Predicted_front", pre_image_f),
                    ("GAN_Train_Original_lateral", image_l),
                    ("GAN_Train_Predicted_lateral", pre_image_l),
                )
                for tag, img in image_tags:
                    self.writer.add_images(tag, deNorm(img), step)

            self.G_lr_scheduler.step(epoch)
            self.D_lr_scheduler.step(epoch)

            # Checkpoint on epochs 10, 20, 30, ...
            if epoch == 0 or epoch % 10 != 0:
                continue
            save_state(self.encoder, self.encoder_checkpoint,
                       "Encoder_{}_epoch_{}_checkpoint.pth".format(
                           self.cfg["ENCODER"], epoch))
            save_state(self.D_F, self.D_checkpoint,
                       "D_{}_F_epoch_{}_checkpoint.pth".format(
                           self.cfg["DISCRIMINATOR"], epoch))
            save_state(self.D_L, self.D_checkpoint,
                       "D_{}_L_epoch_{}_checkpoint.pth".format(
                           self.cfg["DISCRIMINATOR"], epoch))
            save_state(self.decoder_F, self.decoder_checkpoint,
                       "Decoder_{}_F_epoch_{}_checkpoint.pth".format(
                           self.cfg["DECODER"], epoch))
            save_state(self.decoder_L, self.decoder_checkpoint,
                       "Decoder_{}_L_epoch_{}_checkpoint.pth".format(
                           self.cfg["DECODER"], epoch))
예제 #6
0
    def train_GAN_layer(self, layer_id):
        """Adversarially train the generator stack for one resolution level.

        Also applies a view-consistency loss computed by ``self.embednet``.

        For each batch, ``self.Loss_on_layer_GAN`` is called for two views:
        - frontal: ``decoder_F`` with ``D_F``
        - lateral: ``decoder_L`` with ``D_L``

        Then an extra generator step is taken on
        ``vc_loss_ratio * S_criterion(embednet(pred_f, pred_l), 0)``.
        Losses and images are logged every ``DISP_FREQ`` batches. Schedulers
        are stepped per epoch, and all five state dicts are saved every
        20 epochs.

        Args:
            layer_id: Resolution level. Indexes ``self.DISP_FREQs`` and
                ``self.max_epoch`` and is forwarded to
                ``self.Loss_on_layer_GAN``.
        """
        DISP_FREQ = self.DISP_FREQs[layer_id]
        max_epoch = self.max_epoch[layer_id]
        self.encoder.train()
        self.decoder_F.train()
        self.decoder_L.train()
        self.D_F.train()
        self.D_L.train()
        for epoch in range(max_epoch):
            print('GAN Epoch [{}/{}]'.format(epoch, max_epoch))

            for idx, batch in enumerate(tqdm(self.train_dataloader)):
                finding = batch['finding'].to(self.device)
                impression = batch['impression'].to(self.device)
                image_f = batch['image_F'].to(self.device)
                image_l = batch['image_L'].to(self.device)

                D_loss_f, G_loss_f, pre_image_f, image_f = self.Loss_on_layer_GAN(
                    image_f, finding, impression, layer_id, self.decoder_F, self.D_F)
                D_loss_l, G_loss_l, pre_image_l, image_l = self.Loss_on_layer_GAN(
                    image_l, finding, impression, layer_id, self.decoder_L, self.D_L)

                # View-consistency step: train the generators so embednet's
                # output on the generated (front, lateral) pair matches an
                # all-zeros target.
                # NOTE(review): embednet is not switched to eval mode here and
                # its parameters also receive gradients; confirm whether it
                # should be frozen during GAN training.
                self.G_optimizer.zero_grad()
                pred = self.embednet(pre_image_f, pre_image_l)
                # zeros_like already allocates on pred's device.
                id_loss = self.vc_loss_ratio * self.S_criterion(pred, torch.zeros_like(pred))
                id_loss.backward()
                self.G_optimizer.step()

                if ((idx + 1) % DISP_FREQ == 0) and idx != 0:
                    step = epoch * len(self.train_dataloader) + idx
                    self.writer.add_scalar('GAN_G_train_Layer_front_{}_loss'.format(layer_id),
                                           G_loss_f.item(), step)
                    self.writer.add_scalar('GAN_D_train_Layer_front_{}_loss'.format(layer_id),
                                           D_loss_f.item(), step)
                    self.writer.add_scalar('GAN_G_train_Layer_lateral_{}_loss'.format(layer_id),
                                           G_loss_l.item(), step)
                    self.writer.add_scalar('GAN_D_train_Layer_lateral_{}_loss'.format(layer_id),
                                           D_loss_l.item(), step)
                    # Write image samples to TensorBoard.
                    self.writer.add_images("GAN_Train_Original_front_{}".format(layer_id),
                                           deNorm(image_f), step)
                    self.writer.add_images("GAN_Train_Predicted_front_{}".format(layer_id),
                                           deNorm(pre_image_f), step)
                    self.writer.add_images("GAN_Train_Original_lateral_{}".format(layer_id),
                                           deNorm(image_l), step)
                    self.writer.add_images("GAN_Train_Predicted_lateral_{}".format(layer_id),
                                           deNorm(pre_image_l), step)
            self.G_lr_scheduler.step(epoch)
            self.D_lr_scheduler.step(epoch)
            # Checkpoint after every 20th epoch.
            if (epoch + 1) % 20 == 0:
                # Take a single timestamp so all five files of this checkpoint
                # share the same Time_ tag. Calling get_time() once per file
                # could yield different values.
                stamp = get_time()
                torch.save(self.encoder.state_dict(), os.path.join(
                    self.encoder_checkpoint,
                    "Encoder_{}_Layer_{}_Time_{}_checkpoint.pth".format(
                        self.cfg["ENCODER"], layer_id, stamp)))
                torch.save(self.D_F.state_dict(), os.path.join(
                    self.D_checkpoint,
                    "D_{}_F_Layer_{}_Time_{}_checkpoint.pth".format(
                        self.cfg["DISCRIMINATOR"], layer_id, stamp)))
                torch.save(self.D_L.state_dict(), os.path.join(
                    self.D_checkpoint,
                    "D_{}_L_Layer_{}_Time_{}_checkpoint.pth".format(
                        self.cfg["DISCRIMINATOR"], layer_id, stamp)))
                torch.save(self.decoder_F.state_dict(), os.path.join(
                    self.decoder_checkpoint,
                    "Decoder_{}_F_Layer_{}_Time_{}_checkpoint.pth".format(
                        self.cfg["DECODER"], layer_id, stamp)))
                torch.save(self.decoder_L.state_dict(), os.path.join(
                    self.decoder_checkpoint,
                    "Decoder_{}_L_Layer_{}_Time_{}_checkpoint.pth".format(
                        self.cfg["DECODER"], layer_id, stamp)))