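# PyTorch "agent" examples: Supervised_Model, DCGANAgent, MCWAE and
# BADGAN_Model. A shared import header is sketched below so the snippets are
# closer to runnable; the project-specific imports are illustrative guesses
# at the repo layout, not confirmed module paths.
import logging
import math
import os
import random
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-specific helpers (hypothetical paths; adjust to the actual repo):
# from agents.base import BaseAgent
# from graphs.models import (Generator, Discriminator, DiscriminatorZ,
#                            Encoder, Model)
# from graphs.losses import BinaryCrossEntropy, WAELoss, DLoss
# from datasets import (Supervised_Dataset, FewShot_Dataset, NoteDataset,
#                       CelebADataLoader)
# from utils.metrics import AverageMeter
# from utils.misc import print_cuda_statistics
# from utils.recompose import recompose3D_overlap

# ===== Supervised_Model =====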
class Supervised_Model(BaseAgent):
    def __init__(self, config):
        super().__init__(config)

        self.net = Discriminator(self.config)  # Segmentation network
        if config.phase == 'testing':
            self.testloader = Supervised_Dataset(self.config, "testing")
        else:
            self.trainloader = Supervised_Dataset(self.config, "training")
            self.valloader = Supervised_Dataset(self.config, "validating")

        # optimizer
        self.optimizer = torch.optim.Adam(self.net.parameters(),
                                          lr=self.config.learning_rate,
                                          betas=(self.config.beta1,
                                                 self.config.beta2))

        # counter initialization
        self.current_epoch = 0
        self.best_validation_dice = 0
        self.current_iteration = 0

        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.config.cuda

        if self.cuda:
            self.net = self.net.cuda()

        # CrossEntropyLoss expects a 1-D weight tensor of shape [num_classes]
        class_weights = torch.tensor([0.33, 1.5, 0.83, 1.33])
        if self.cuda:
            class_weights = class_weights.cuda()
        self.criterion = nn.CrossEntropyLoss(weight=class_weights)

        # set the manual seed for torch
        if not self.config.seed:
            self.manual_seed = random.randint(1, 10000)
        else:
            self.manual_seed = self.config.seed
        self.logger.info("seed: %d", self.manual_seed)
        random.seed(self.manual_seed)
        torch.manual_seed(self.manual_seed)  # seed the CPU RNG on both paths
        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            torch.cuda.manual_seed_all(self.manual_seed)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            self.logger.info("Program will run on *****CPU***** ")

        if self.config.load_chkpt:
            self.load_checkpoint(self.config.phase)

    def load_checkpoint(self, phase):
        try:
            if phase == 'training':
                filename = os.path.join(self.config.checkpoint_dir,
                                        'checkpoint.pth.tar')
            else:  # 'testing': load the best model
                filename = os.path.join(self.config.checkpoint_dir,
                                        'model_best.pth.tar')
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=self.device)

            self.current_epoch = checkpoint['epoch']
            self.net.load_state_dict(checkpoint['net'])
            self.manual_seed = checkpoint['manual_seed']

            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {})\n".
                format(self.config.checkpoint_dir, checkpoint['epoch']))

        except OSError as e:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, is_best=False):
        file_name = "checkpoint.pth.tar"
        state = {
            'epoch': self.current_epoch,
            'net': self.net.state_dict(),
            'manual_seed': self.manual_seed
        }
        torch.save(state, os.path.join(self.config.checkpoint_dir, file_name))
        if is_best:
            print("SAVING BEST CHECKPOINT !!!")
            shutil.copyfile(
                os.path.join(self.config.checkpoint_dir, file_name),
                os.path.join(self.config.checkpoint_dir, 'model_best.pth.tar'))

    def run(self):
        try:
            if self.config.phase == 'training':
                self.train()
            if self.config.phase == 'testing':
                self.load_checkpoint(self.config.phase)
                self.test()
        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        for epoch in range(self.current_epoch, self.config.epochs):
            self.current_epoch = epoch
            self.current_iteration = 0
            self.train_one_epoch()
            self.save_checkpoint()
            if (self.current_epoch % self.config.validation_every_epoch == 0):
                self.validate()

    def train_one_epoch(self):
        # initialize tqdm batch
        tqdm_batch = tqdm(self.trainloader.loader,
                          total=self.trainloader.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.net.train()
        epoch_loss = AverageMeter()

        for curr_it, (patches, labels) in enumerate(tqdm_batch):
            #y = torch.full((self.batch_size,), self.real_label)
            if self.cuda:
                patches = patches.cuda()
                labels = labels.cuda()

            labels = labels.long()  # Variable() is a no-op since PyTorch 0.4

            self.net.zero_grad()
            output_logits, output_prob = self.net(patches)
            loss = self.criterion(output_logits, labels)

            loss.backward()
            self.optimizer.step()

            epoch_loss.update(loss.item())
            self.current_iteration += 1
            print("Epoch: {0}, Iteration: {1}/{2}, Loss: {3}".format(self.current_epoch, self.current_iteration,\
                                                                    self.trainloader.num_iterations, loss.item()))

        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) +
                         " | " + "Model loss: " + str(epoch_loss.val))

    def validate(self):
        self.net.eval()
        prediction_image = torch.zeros(
            [self.valloader.dataset.label.shape[0], self.config.patch_shape[0],
             self.config.patch_shape[1], self.config.patch_shape[2]])
        whole_vol = self.valloader.dataset.whole_vol
        for batch_number, (patches, label,
                           _) in enumerate(self.valloader.loader):
            patches = patches.to(self.device)  # works on both CPU and GPU runs
            _, batch_prediction_softmax = self.net(patches)
            batch_prediction = torch.argmax(batch_prediction_softmax,
                                            dim=1).cpu()
            prediction_image[
                batch_number * self.config.batch_size:(batch_number + 1) *
                self.config.batch_size, :, :, :] = batch_prediction

            print("Validating.. [{0}/{1}]".format(
                batch_number, self.valloader.num_iterations))

        vol_shape_x, vol_shape_y, vol_shape_z = self.config.volume_shape
        prediction_image = prediction_image.numpy()
        val_image_pred = recompose3D_overlap(prediction_image, vol_shape_x,
                                             vol_shape_y, vol_shape_z,
                                             self.config.extraction_step[0],
                                             self.config.extraction_step[1],
                                             self.config.extraction_step[2])
        val_image_pred = val_image_pred.astype('uint8')
        pred2d = np.reshape(val_image_pred,
                            (val_image_pred.shape[0] * vol_shape_x *
                             vol_shape_y * vol_shape_z))
        lab2d = np.reshape(
            whole_vol,
            (whole_vol.shape[0] * vol_shape_x * vol_shape_y * vol_shape_z))

        classes = list(range(self.config.num_classes))
        # newer scikit-learn requires labels= as a keyword argument
        F1_score = f1_score(lab2d, pred2d, labels=classes, average=None)
        print("Validation Dice Coefficient.... ")
        print("Background:", F1_score[0])
        print("CSF:", F1_score[1])
        print("GM:", F1_score[2])
        print("WM:", F1_score[3])

        current_validation_dice = F1_score[2] + F1_score[3]
        if (self.best_validation_dice < current_validation_dice):
            self.best_validation_dice = current_validation_dice
            self.save_checkpoint(is_best=True)

    def test(self):
        self.net.eval()

        prediction_image = torch.zeros(
            [self.testloader.dataset.patches.shape[0], self.config.patch_shape[0],
             self.config.patch_shape[1], self.config.patch_shape[2]])
        whole_vol = self.testloader.dataset.whole_vol
        for batch_number, (patches, _) in enumerate(self.testloader.loader):
            patches = patches.to(self.device)
            _, batch_prediction_softmax = self.net(patches)
            batch_prediction = torch.argmax(batch_prediction_softmax,
                                            dim=1).cpu()
            prediction_image[
                batch_number * self.config.batch_size:(batch_number + 1) *
                self.config.batch_size, :, :, :] = batch_prediction

            print("Testing.. [{0}/{1}]".format(batch_number,
                                               self.testloader.num_iterations))

        vol_shape_x, vol_shape_y, vol_shape_z = self.config.volume_shape
        prediction_image = prediction_image.numpy()
        test_image_pred = recompose3D_overlap(prediction_image, vol_shape_x,
                                              vol_shape_y, vol_shape_z,
                                              self.config.extraction_step[0],
                                              self.config.extraction_step[1],
                                              self.config.extraction_step[2])
        test_image_pred = test_image_pred.astype('uint8')
        pred2d = np.reshape(test_image_pred,
                            (test_image_pred.shape[0] * vol_shape_x *
                             vol_shape_y * vol_shape_z))
        lab2d = np.reshape(
            whole_vol,
            (whole_vol.shape[0] * vol_shape_x * vol_shape_y * vol_shape_z))

        classes = list(range(self.config.num_classes))
        F1_score = f1_score(lab2d, pred2d, labels=classes, average=None)
        print("Test Dice Coefficient.... ")
        print("Background:", F1_score[0])
        print("CSF:", F1_score[1])
        print("GM:", F1_score[2])
        print("WM:", F1_score[3])

    def finalize(self):
        """
        Finalize all the operations of the 2 Main classes of the process the operator and the data loader
        :return:
        """
        self.logger.info(
            "Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint()
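
# Minimal usage sketch for the agent above (hypothetical: `process_config`
# and the JSON path are illustrative stand-ins for whatever builds the config
# object whose fields -- phase, learning_rate, beta1/beta2, cuda, gpu_device,
# checkpoint_dir, ... -- this class reads):
#
#     config = process_config("configs/supervised.json")
#     agent = Supervised_Model(config)
#     agent.run()       # dispatches to train() or test() based on config.phase
#     agent.finalize()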
# ===== DCGANAgent =====
class DCGANAgent:
    def __init__(self, config):
        self.config = config

        self.logger = logging.getLogger("DCGANAgent")

        # define models (generator and discriminator)
        self.netG = Generator(self.config)
        self.netD = Discriminator(self.config)

        # define dataloader
        self.dataloader = CelebADataLoader(self.config)
        self.batch_size = self.config.batch_size

        # define loss
        self.loss = BinaryCrossEntropy()

        # define optimizers for both generator and discriminator
        self.optimG = torch.optim.Adam(self.netG.parameters(),
                                       lr=self.config.learning_rate,
                                       betas=(self.config.beta1,
                                              self.config.beta2))
        self.optimD = torch.optim.Adam(self.netD.parameters(),
                                       lr=self.config.learning_rate,
                                       betas=(self.config.beta1,
                                              self.config.beta2))

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_valid_mean_iou = 0

        self.fixed_noise = torch.randn(self.batch_size,
                                       self.config.g_input_size, 1, 1)
        self.real_label = 1.0  # BCE targets must be float
        self.fake_label = 0.0

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.config.cuda
        # set the manual seed for torch (fall back to a random seed when
        # config.seed is absent or falsy)
        if not getattr(self.config, 'seed', None):
            self.manual_seed = random.randint(1, 10000)
        else:
            self.manual_seed = self.config.seed
        print("seed: ", self.manual_seed)
        random.seed(self.manual_seed)
        torch.manual_seed(self.manual_seed)

        if self.cuda:
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            torch.cuda.manual_seed_all(self.manual_seed)
            print_cuda_statistics()
            torch.cuda.set_device(self.config.gpu_device)
            # 'async' became a reserved word in Python 3.7; use non_blocking
            self.fixed_noise = self.fixed_noise.cuda(
                non_blocking=self.config.async_loading)
            self.device = torch.device("cuda")

        else:
            self.logger.info("Program will run on *****CPU***** ")
            self.device = torch.device("cpu")

        self.netG = self.netG.to(self.device)
        self.netD = self.netD.to(self.device)
        self.loss = self.loss.to(self.device)
        # Model Loading from the latest checkpoint if not found start from scratch.
        self.load_checkpoint(self.config.checkpoint_file)

        # Summary Writer
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                            comment='DCGAN')

    def load_checkpoint(self, file_name):
        filename = os.path.join(self.config.checkpoint_dir, file_name)
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=self.device)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.netG.load_state_dict(checkpoint['G_state_dict'])
            self.optimG.load_state_dict(checkpoint['G_optimizer'])
            self.netD.load_state_dict(checkpoint['D_state_dict'])
            self.optimD.load_state_dict(checkpoint['D_optimizer'])
            self.fixed_noise = checkpoint['fixed_noise']
            self.manual_seed = checkpoint['manual_seed']

            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                .format(self.config.checkpoint_dir, checkpoint['epoch'],
                        checkpoint['iteration']))
        except OSError as e:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=False):
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'G_state_dict': self.netG.state_dict(),
            'G_optimizer': self.optimG.state_dict(),
            'D_state_dict': self.netD.state_dict(),
            'D_optimizer': self.optimD.state_dict(),
            'fixed_noise': self.fixed_noise,
            'manual_seed': self.manual_seed
        }
        # Save the state
        torch.save(state, os.path.join(self.config.checkpoint_dir, file_name))
        # If it is the best, copy it to 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(
                os.path.join(self.config.checkpoint_dir, file_name),
                os.path.join(self.config.checkpoint_dir, 'model_best.pth.tar'))

    def run(self):
        """
        This function runs the operator (starts the training loop).
        :return:
        """
        try:
            self.train()

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        for epoch in range(self.current_epoch, self.config.max_epoch):
            self.current_epoch = epoch
            self.train_one_epoch()
            self.save_checkpoint()

    def train_one_epoch(self):
        # initialize tqdm batch
        tqdm_batch = tqdm(self.dataloader.loader,
                          total=self.dataloader.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.netG.train()
        self.netD.train()

        epoch_lossG = AverageMeter()
        epoch_lossD = AverageMeter()

        for curr_it, x in enumerate(tqdm_batch):
            #y = torch.full((self.batch_size,), self.real_label)
            x = x[0]
            y = torch.empty(x.size(0))  # filled with real/fake labels below
            fake_noise = torch.randn(x.size(0), self.config.g_input_size, 1, 1)

            if self.cuda:
                x = x.cuda(non_blocking=self.config.async_loading)
                y = y.cuda(non_blocking=self.config.async_loading)
                fake_noise = fake_noise.cuda(
                    non_blocking=self.config.async_loading)

            # tensors are used directly; Variable() has been a no-op since
            # PyTorch 0.4
            ####################
            # Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # train with real
            self.netD.zero_grad()
            D_real_out = self.netD(x)
            y.fill_(self.real_label)
            loss_D_real = self.loss(D_real_out, y)
            loss_D_real.backward()
            #D_mean_real_out = D_real_out.mean().item()

            # train with fake
            G_fake_out = self.netG(fake_noise)
            y.fill_(self.fake_label)

            D_fake_out = self.netD(G_fake_out.detach())

            loss_D_fake = self.loss(D_fake_out, y)
            loss_D_fake.backward()
            #D_mean_fake_out = D_fake_out.mean().item()

            loss_D = loss_D_fake + loss_D_real
            self.optimD.step()

            ####################
            # Update G network: maximize log(D(G(z)))
            self.netG.zero_grad()
            y.fill_(self.real_label)
            D_out = self.netD(G_fake_out)
            loss_G = self.loss(D_out, y)
            loss_G.backward()

            #D_G_mean_out = D_out.mean().item()

            self.optimG.step()

            epoch_lossD.update(loss_D.item())
            epoch_lossG.update(loss_G.item())

            self.current_iteration += 1

            self.summary_writer.add_scalar("epoch/Generator_loss",
                                           epoch_lossG.val,
                                           self.current_iteration)
            self.summary_writer.add_scalar("epoch/Discriminator_loss",
                                           epoch_lossD.val,
                                           self.current_iteration)

        # log one grid of generated samples at the end of each epoch
        with torch.no_grad():
            gen_out = self.netG(self.fixed_noise)
        out_img = self.dataloader.plot_samples_per_epoch(
            gen_out, self.current_iteration)
        self.summary_writer.add_image('train/generated_image', out_img,
                                      self.current_iteration)

        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) +
                         " | " + "Discriminator loss: " +
                         str(epoch_lossD.val) + " - Generator Loss-: " +
                         str(epoch_lossG.val))

    def validate(self):
        pass

    def finalize(self):
        """
        Finalize all the operations of the 2 Main classes of the process the operator and the data loader
        :return:
        """
        self.logger.info(
            "Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint()
        self.summary_writer.export_scalars_to_json(
            os.path.join(self.config.summary_dir, "all_scalars.json"))
        self.summary_writer.close()
        self.dataloader.finalize()
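
# Every agent here calls `meter.update(x)` and reads `meter.val`. A minimal
# AverageMeter with that interface (a sketch; the repo presumably ships its
# own in a utils module):
class AverageMeter:
    """Tracks the most recent value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # last value passed to update()
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0   # running mean of all values so far

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# Note: the epoch logs above report `.val` (the last batch's loss); use
# `.avg` for a true epoch mean.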
# ===== MCWAE =====
class MCWAE(object):
    def __init__(self, config):
        self.config = config

        self.logger = logging.getLogger("MC_WAE")

        self.batch_size = self.config.batch_size

        # define models (the WAE model and its two discriminators)
        self.model = Model()
        self.discriminator = Discriminator()
        self.discriminator_z = DiscriminatorZ()

        # define dataloader
        self.dataset = NoteDataset(self.config.root_path, self.config)
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=self.batch_size,
                                     shuffle=False,
                                     num_workers=3,
                                     pin_memory=self.config.pin_memory,
                                     collate_fn=self.make_batch)

        # define loss
        self.loss = WAELoss()
        self.lossD = DLoss()

        # define optimizers for both generator and discriminator
        self.lr = self.config.learning_rate
        self.lrD = self.config.learning_rate
        self.lrDZ = self.config.learning_rate
        self.optimWAE = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lrD)
        self.optimDZ = torch.optim.Adam(self.discriminator_z.parameters(),
                                        lr=self.lrDZ)

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_error = 9999999999.

        self.fixed_noise = torch.randn(1, 510, dtype=torch.float32)
        self.zero_note = torch.zeros(1, 1, 384, 96, dtype=torch.float32)

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.config.cuda

        self.manual_seed = random.randint(1, 10000)

        print("seed: ", self.manual_seed)
        random.seed(self.manual_seed)
        torch.manual_seed(self.manual_seed)

        if self.cuda:
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            torch.cuda.manual_seed_all(self.manual_seed)
            torch.cuda.set_device(self.config.gpu_device[0])
            # 'async' became a reserved word in Python 3.7; use non_blocking
            self.fixed_noise = self.fixed_noise.cuda(
                non_blocking=self.config.async_loading)
            self.zero_note = self.zero_note.cuda(
                non_blocking=self.config.async_loading)
            self.device = torch.device("cuda")

        else:
            self.logger.info("Program will run on *****CPU***** ")
            self.device = torch.device("cpu")

        self.model = self.model.to(self.device)
        self.discriminator = self.discriminator.to(self.device)
        self.discriminator_z = self.discriminator_z.to(self.device)
        self.loss = self.loss.to(self.device)
        self.lossD = self.lossD.to(self.device)

        if len(self.config.gpu_device) > 1:
            self.model = nn.DataParallel(self.model,
                                         device_ids=self.config.gpu_device)
            self.discriminator = nn.DataParallel(
                self.discriminator, device_ids=self.config.gpu_device)
            self.discriminator_z = nn.DataParallel(
                self.discriminator_z, device_ids=self.config.gpu_device)

        # Model Loading from the latest checkpoint if not found start from scratch.
        self.load_checkpoint(self.config.checkpoint_file)

        # Summary Writer
        self.summary_writer = SummaryWriter(log_dir=os.path.join(
            self.config.root_path, self.config.summary_dir),
                                            comment='MC_WAE')

    def free(self, module: nn.Module):
        for p in module.parameters():
            p.requires_grad = True

    def frozen(self, module: nn.Module):
        for p in module.parameters():
            p.requires_grad = False

    def make_batch(self, samples):
        note = np.concatenate([sample['note'] for sample in samples], axis=0)
        pre_note = np.concatenate([sample['pre_note'] for sample in samples],
                                  axis=0)
        position = np.concatenate([sample['position'] for sample in samples],
                                  axis=0)

        return tuple([
            torch.tensor(note, dtype=torch.float),
            torch.tensor(pre_note, dtype=torch.float),
            torch.tensor(position, dtype=torch.long)
        ])

    def load_checkpoint(self, file_name):
        filename = os.path.join(self.config.root_path,
                                self.config.checkpoint_dir, file_name)
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=self.device)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimWAE.load_state_dict(checkpoint['model_optimizer'])
            self.discriminator.load_state_dict(
                checkpoint['discriminator_state_dict'])
            self.optimD.load_state_dict(checkpoint['discriminator_optimizer'])
            self.discriminator_z.load_state_dict(
                checkpoint['discriminatorZ_state_dict'])
            self.optimDZ.load_state_dict(
                checkpoint['discriminatorZ_optimizer'])
            self.fixed_noise = checkpoint['fixed_noise']
            self.manual_seed = checkpoint['manual_seed']

            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                .format(self.config.checkpoint_dir, checkpoint['epoch'],
                        checkpoint['iteration']))
        except OSError as e:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name, is_best=False):
        gpu_cnt = len(self.config.gpu_device)
        file_name = os.path.join(self.config.root_path,
                                 self.config.checkpoint_dir, file_name)

        # unwrap nn.DataParallel (gpu_cnt > 1) so the checkpoint loads on any
        # device configuration
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'model_state_dict': (self.model.module.state_dict()
                                 if gpu_cnt > 1 else self.model.state_dict()),
            'model_optimizer': self.optimWAE.state_dict(),
            'discriminator_state_dict':
                (self.discriminator.module.state_dict()
                 if gpu_cnt > 1 else self.discriminator.state_dict()),
            'discriminator_optimizer': self.optimD.state_dict(),
            'discriminatorZ_state_dict':
                (self.discriminator_z.module.state_dict()
                 if gpu_cnt > 1 else self.discriminator_z.state_dict()),
            'discriminatorZ_optimizer': self.optimDZ.state_dict(),
            'fixed_noise': self.fixed_noise,
            'manual_seed': self.manual_seed
        }

        # Save the state
        torch.save(state, file_name)
        if is_best:
            shutil.copyfile(
                file_name,
                os.path.join(self.config.root_path, self.config.checkpoint_dir,
                             'model_best.pth.tar'))

    def run(self):
        try:
            self.train()

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        # ReduceLROnPlateau must be constructed once and stepped each epoch
        # with a monitored metric; a scheduler that is re-created every epoch
        # never gets a chance to act.
        schedulers = [
            torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                       mode='min',
                                                       factor=0.8,
                                                       cooldown=4)
            for optim in (self.optimWAE, self.optimD, self.optimDZ)
        ]
        for epoch in range(self.current_epoch, self.config.epoch):
            self.current_epoch = epoch
            is_best = self.train_one_epoch()
            self.save_checkpoint(self.config.checkpoint_file, is_best)
            for scheduler in schedulers:
                # step on the best reconstruction error tracked so far
                scheduler.step(self.best_error)

    def train_one_epoch(self):
        tqdm_batch = tqdm(self.dataloader,
                          total=self.dataset.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.model.train()
        self.discriminator.train()
        self.discriminator_z.train()

        epoch_loss = AverageMeter()
        epoch_lossD = AverageMeter()
        epoch_lossDZ = AverageMeter()

        for curr_it, (note, pre_note, position) in enumerate(tqdm_batch):
            if self.cuda:
                note = note.cuda(non_blocking=self.config.async_loading)
                pre_note = pre_note.cuda(
                    non_blocking=self.config.async_loading)
                position = position.cuda(
                    non_blocking=self.config.async_loading)

            ####################

            self.model.zero_grad()
            self.discriminator.zero_grad()
            self.discriminator_z.zero_grad()

            #################### Generator ####################
            self.free(self.model)
            self.frozen(self.discriminator)
            self.frozen(self.discriminator_z)

            gen_note, z = self.model(note, pre_note, position)
            f_logits = self.discriminator(gen_note)
            fz_logits = self.discriminator_z(z)

            loss_model = self.loss(gen_note, note, f_logits, fz_logits)
            loss_model.backward(retain_graph=True)
            self.optimWAE.step()

            #################### Discriminator ####################
            self.free(self.discriminator)
            self.frozen(self.discriminator_z)
            self.frozen(self.model)

            r_logits = self.discriminator(note)
            f_logits = self.discriminator(
                self.model(note, pre_note, position)[0])

            loss_D = -((torch.log(1.001 - r_logits).mean()) +
                       torch.log(f_logits).mean())
            loss_D.backward(retain_graph=True)
            self.optimD.step()

            #################### DiscriminatorZ ####################
            self.free(self.discriminator_z)
            self.frozen(self.discriminator)
            self.frozen(self.model)

            z_fake = torch.randn(note.size(0), 510, device=self.device)

            r_logits = self.discriminator_z(
                self.model(note, pre_note, position)[1])
            f_logits = self.discriminator_z(z_fake)

            loss_DZ = -((torch.log(1.001 - r_logits).mean()) +
                        torch.log(f_logits).mean())
            loss_DZ.backward(retain_graph=True)
            self.optimDZ.step()

            ####################
            epoch_lossD.update(loss_D.item())
            epoch_lossDZ.update(loss_DZ.item())
            epoch_loss.update(loss_model.item())

            self.current_iteration += 1

            self.summary_writer.add_scalar("epoch/Generator_loss",
                                           epoch_loss.val,
                                           self.current_iteration)
            self.summary_writer.add_scalar("epoch/Discriminator_loss",
                                           epoch_lossD.val,
                                           self.current_iteration)
            self.summary_writer.add_scalar("epoch/DiscriminatorZ_loss",
                                           epoch_lossDZ.val,
                                           self.current_iteration)

        with torch.no_grad():
            out_img = self.model(self.fixed_noise, self.zero_note,
                                 torch.tensor([330], dtype=torch.long,
                                              device=self.device),
                                 False)
        self.summary_writer.add_image(
            'train/generated_image',
            torch.gt(out_img, 0.3).type('torch.FloatTensor').view(1, 384, 96) *
            255, self.current_iteration)

        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) +
                         " | " + "Discriminator loss: " +
                         str(epoch_lossD.val) + " - Generator Loss-: " +
                         str(epoch_loss.val))

        if epoch_loss.val < self.best_error:
            self.best_error = epoch_loss.val
            return True
        else:
            return False

    def finalize(self):
        self.logger.info(
            "Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint(self.config.checkpoint_file)
        self.summary_writer.export_scalars_to_json(
            os.path.join(self.config.root_path, self.config.summary_dir,
                         'all_scalars.json'))
        self.summary_writer.close()
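
# MCWAE's generator step calls `self.loss(gen_note, note, f_logits, fz_logits)`.
# A speculative sketch of such a WAE-GAN objective, assuming MSE reconstruction
# plus adversarial terms that push the frozen discriminators toward their
# "real" label (~0, matching the convention used in train_one_epoch above);
# the repo's actual WAELoss may weight or define the terms differently:
class WAELossSketch(nn.Module):
    def __init__(self, lambda_d=1.0, lambda_z=1.0):
        super().__init__()
        self.lambda_d = lambda_d  # weight of the note-space adversarial term
        self.lambda_z = lambda_z  # weight of the latent-space adversarial term

    def forward(self, gen_note, note, f_logits, fz_logits):
        recon = F.mse_loss(gen_note, note)
        # -log(1.001 - p) is near zero when the discriminator scores ~0 ("real")
        adv_d = -torch.log(1.001 - f_logits).mean()
        adv_z = -torch.log(1.001 - fz_logits).mean()
        return recon + self.lambda_d * adv_d + self.lambda_z * adv_z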
# ===== BADGAN_Model =====
class BADGAN_Model(BaseAgent):
    def __init__(self, config):
        super().__init__(config)

        self.generator = Generator(self.config)
        self.encoder = Encoder(self.config)
        self.discriminator = Discriminator(self.config)  # Segmentation network
        if self.config.phase == 'testing':
            self.testloader = FewShot_Dataset(self.config, "testing")
        else:
            self.trainloader = FewShot_Dataset(self.config, "training")
            self.valloader = FewShot_Dataset(self.config, "validating")

        # optimizer
        self.g_optim = torch.optim.Adam(self.generator.parameters(),
                                        lr=self.config.learning_rate_G,
                                        betas=(self.config.beta1G,
                                               self.config.beta2G))
        self.d_optim = torch.optim.Adam(self.discriminator.parameters(),
                                        lr=self.config.learning_rate_D,
                                        betas=(self.config.beta1D,
                                               self.config.beta2D))
        self.e_optim = torch.optim.Adam(self.encoder.parameters(),
                                        lr=self.config.learning_rate_E,
                                        betas=(self.config.beta1E,
                                               self.config.beta2E))
        # counter initialization
        self.current_epoch = 0
        self.best_validation_dice = 0
        self.current_iteration = 0

        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.config.cuda

        if self.cuda:
            self.generator = self.generator.cuda()
            self.discriminator = self.discriminator.cuda()
            self.encoder = self.encoder.cuda()

        # CrossEntropyLoss expects a 1-D weight tensor of shape [num_classes]
        class_weights = torch.tensor([0.33, 1.5, 0.83, 1.33])
        if self.cuda:
            class_weights = class_weights.cuda()
        self.criterion = nn.CrossEntropyLoss(weight=class_weights)

        # set the manual seed for torch
        if not self.config.seed:
            self.manual_seed = random.randint(1, 10000)
        else:
            self.manual_seed = self.config.seed
        self.logger.info("seed: %d", self.manual_seed)
        random.seed(self.manual_seed)
        torch.manual_seed(self.manual_seed)  # seed the CPU RNG on both paths
        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            torch.cuda.manual_seed_all(self.manual_seed)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            self.logger.info("Program will run on *****CPU***** ")

        if self.config.load_chkpt:
            self.load_checkpoint(self.config.phase)

    def load_checkpoint(self, phase):
        try:
            if phase == 'training':
                filename = os.path.join(self.config.checkpoint_dir,
                                        'checkpoint.pth.tar')
            else:  # 'testing': load the best model
                filename = os.path.join(self.config.checkpoint_dir,
                                        'model_best.pth.tar')
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=self.device)

            self.current_epoch = checkpoint['epoch']
            self.generator.load_state_dict(checkpoint['generator'])
            self.discriminator.load_state_dict(checkpoint['discriminator'])
            self.manual_seed = checkpoint['manual_seed']

            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {})\n".
                format(self.config.checkpoint_dir, checkpoint['epoch']))

        except OSError as e:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, is_best=False):
        file_name = "checkpoint.pth.tar"
        state = {
            'epoch': self.current_epoch,
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'manual_seed': self.manual_seed
        }
        torch.save(state, os.path.join(self.config.checkpoint_dir, file_name))
        if is_best:
            print("SAVING BEST CHECKPOINT !!!\n")
            shutil.copyfile(
                os.path.join(self.config.checkpoint_dir, file_name),
                os.path.join(self.config.checkpoint_dir, 'model_best.pth.tar'))

    def run(self):
        try:
            if self.config.phase == 'training':
                self.train()
            if self.config.phase == 'testing':
                self.load_checkpoint(self.config.phase)
                self.test()
        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        for epoch in range(self.current_epoch, self.config.epochs):
            self.current_epoch = epoch
            self.current_iteration = 0
            self.train_one_epoch()
            self.save_checkpoint()
            if (self.current_epoch % self.config.validation_every_epoch == 0):
                self.validate()

    def train_one_epoch(self):
        # initialize tqdm batch
        tqdm_batch = tqdm(self.trainloader.loader,
                          total=self.trainloader.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.generator.train()
        self.discriminator.train()
        epoch_loss_gen = AverageMeter()
        epoch_loss_dis = AverageMeter()
        epoch_loss_ce = AverageMeter()
        epoch_loss_unlab = AverageMeter()
        epoch_loss_fake = AverageMeter()

        for curr_it, (patches_lab, patches_unlab,
                      labels) in enumerate(tqdm_batch):
            #y = torch.full((self.batch_size,), self.real_label)
            if self.cuda:
                patches_lab = patches_lab.cuda()
                patches_unlab = patches_unlab.cuda()
                labels = labels.cuda()

            patches_unlab = patches_unlab.float()
            labels = labels.long()  # Variable() is a no-op since PyTorch 0.4

            noise_vector = torch.tensor(
                np.random.uniform(
                    -1, 1,
                    [self.config.batch_size, self.config.noise_dim])).float()
            if self.cuda:
                noise_vector = noise_vector.cuda()
            patches_fake = self.generator(noise_vector)

            ## Discriminator
            # Supervised loss
            lab_output, lab_output_sofmax = self.discriminator(patches_lab)
            lab_loss = self.criterion(lab_output, labels)

            unlab_output, unlab_output_softmax = self.discriminator(
                patches_unlab)
            fake_output, fake_output_softmax = self.discriminator(
                patches_fake.detach())

            # Unlabeled Loss and Fake loss
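            # With lsp = logsumexp(logits) over the K real classes, the
            # implicit K+1-way discriminator gives P(real) = e^lsp / (e^lsp + 1),
            # so -log P(real | unlab) = softplus(lsp) - lsp and
            #    -log P(fake | fake)  = softplus(lsp),
            # which are the two 0.5-weighted terms below.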
            unlab_lsp = torch.logsumexp(unlab_output, dim=1)
            fake_lsp = torch.logsumexp(fake_output, dim=1)
            unlab_loss = -0.5 * torch.mean(unlab_lsp) + 0.5 * torch.mean(
                F.softplus(unlab_lsp, 1))
            fake_loss = 0.5 * torch.mean(F.softplus(fake_lsp, 1))
            discriminator_loss = lab_loss + unlab_loss + fake_loss

            self.d_optim.zero_grad()
            discriminator_loss.backward()
            self.d_optim.step()

            ## Generator
            _, _, unlab_feature = self.discriminator(patches_unlab,
                                                     get_feature=True)
            _, _, fake_feature = self.discriminator(patches_fake,
                                                    get_feature=True)

            # Feature matching loss
            unlab_feature = torch.mean(unlab_feature, 0)
            fake_feature = torch.mean(fake_feature, 0)
            fm_loss = torch.mean(torch.abs(unlab_feature - fake_feature))

            # Variational inference loss
            mu, log_sigma = self.encoder(patches_fake)
            vi_loss = gaussian_nll(mu, log_sigma, noise_vector)

            generator_loss = fm_loss + self.config.vi_loss_weight * vi_loss

            self.g_optim.zero_grad()
            self.e_optim.zero_grad()
            generator_loss.backward()
            self.g_optim.step()
            self.e_optim.step()

            epoch_loss_gen.update(generator_loss.item())
            epoch_loss_dis.update(discriminator_loss.item())
            epoch_loss_ce.update(lab_loss.item())
            epoch_loss_unlab.update(unlab_loss.item())
            epoch_loss_fake.update(fake_loss.item())
            self.current_iteration += 1

            print("Epoch: {0}, Iteration: {1}/{2}, Gen loss: {3:.3f}, Dis loss: {4:.3f} :: CE loss {5:.3f}, Unlab loss: {6:.3f}, Fake loss: {7:.3f}, VI loss: {8:.3f}".format(
                                self.current_epoch, self.current_iteration,\
                                self.trainloader.num_iterations, generator_loss.item(), discriminator_loss.item(),\
                                lab_loss.item(), unlab_loss.item(), fake_loss.item(), vi_loss.item()))

        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) + " | " +\
         " Generator loss: " + str(epoch_loss_gen.val) +\
          " Discriminator loss: " + str(epoch_loss_dis.val) +\
           " CE loss: " + str(epoch_loss_ce.val) + " Unlab loss: " + str(epoch_loss_unlab.val) + " Fake loss: " + str(epoch_loss_fake.val))

    def validate(self):
        self.discriminator.eval()

        prediction_image = torch.zeros(
            [self.valloader.dataset.label.shape[0], self.config.patch_shape[0],
             self.config.patch_shape[1], self.config.patch_shape[2]])
        whole_vol = self.valloader.dataset.whole_vol
        for batch_number, (patches, label,
                           _) in enumerate(self.valloader.loader):
            patches = patches.to(self.device)
            _, batch_prediction_softmax = self.discriminator(patches)
            batch_prediction = torch.argmax(batch_prediction_softmax,
                                            dim=1).cpu()
            prediction_image[
                batch_number * self.config.batch_size:(batch_number + 1) *
                self.config.batch_size, :, :, :] = batch_prediction

            print("Validating.. [{0}/{1}]".format(
                batch_number, self.valloader.num_iterations))

        vol_shape_x, vol_shape_y, vol_shape_z = self.config.volume_shape
        prediction_image = prediction_image.numpy()
        val_image_pred = recompose3D_overlap(prediction_image, vol_shape_x,
                                             vol_shape_y, vol_shape_z,
                                             self.config.extraction_step[0],
                                             self.config.extraction_step[1],
                                             self.config.extraction_step[2])
        val_image_pred = val_image_pred.astype('uint8')
        pred2d = np.reshape(val_image_pred,
                            (val_image_pred.shape[0] * vol_shape_x *
                             vol_shape_y * vol_shape_z))
        lab2d = np.reshape(
            whole_vol,
            (whole_vol.shape[0] * vol_shape_x * vol_shape_y * vol_shape_z))

        classes = list(range(self.config.num_classes))
        F1_score = f1_score(lab2d, pred2d, labels=classes, average=None)
        print("Validation Dice Coefficient.... ")
        print("Background:", F1_score[0])
        print("CSF:", F1_score[1])
        print("GM:", F1_score[2])
        print("WM:", F1_score[3])

        current_validation_dice = F1_score[2] + F1_score[3]
        if (self.best_validation_dice < current_validation_dice):
            self.best_validation_dice = current_validation_dice
            self.save_checkpoint(is_best=True)

    def test(self):
        self.discriminator.eval()

        prediction_image = torch.zeros(
            [self.testloader.dataset.patches.shape[0], self.config.patch_shape[0],
             self.config.patch_shape[1], self.config.patch_shape[2]])
        whole_vol = self.testloader.dataset.whole_vol
        for batch_number, (patches, _) in enumerate(self.testloader.loader):
            patches = patches.to(self.device)
            _, batch_prediction_softmax = self.discriminator(patches)
            batch_prediction = torch.argmax(batch_prediction_softmax,
                                            dim=1).cpu()
            prediction_image[
                batch_number * self.config.batch_size:(batch_number + 1) *
                self.config.batch_size, :, :, :] = batch_prediction

            print("Testing.. [{0}/{1}]".format(batch_number,
                                               self.testloader.num_iterations))

        vol_shape_x, vol_shape_y, vol_shape_z = self.config.volume_shape
        prediction_image = prediction_image.numpy()
        test_image_pred = recompose3D_overlap(prediction_image, vol_shape_x,
                                              vol_shape_y, vol_shape_z,
                                              self.config.extraction_step[0],
                                              self.config.extraction_step[1],
                                              self.config.extraction_step[2])
        test_image_pred = test_image_pred.astype('uint8')
        pred2d = np.reshape(test_image_pred,
                            (test_image_pred.shape[0] * vol_shape_x *
                             vol_shape_y * vol_shape_z))
        lab2d = np.reshape(
            whole_vol,
            (whole_vol.shape[0] * vol_shape_x * vol_shape_y * vol_shape_z))

        classes = list(range(self.config.num_classes))
        F1_score = f1_score(lab2d, pred2d, labels=classes, average=None)
        print("Test Dice Coefficient.... ")
        print("Background:", F1_score[0])
        print("CSF:", F1_score[1])
        print("GM:", F1_score[2])
        print("WM:", F1_score[3])

    def finalize(self):
        """
        Finalize all the operations of the 2 Main classes of the process the operator and the data loader
        :return:
        """
        self.logger.info(
            "Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint()
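
# BADGAN_Model's variational-inference term calls
# `gaussian_nll(mu, log_sigma, noise_vector)`. A sketch under the standard
# definition (negative log-likelihood of the noise vector under a diagonal
# Gaussian N(mu, sigma^2), assuming (batch, noise_dim) shapes); the repo's
# implementation may differ in reduction or constants:
def gaussian_nll(mu, log_sigma, noise):
    # elementwise 0.5*log(2*pi) + log_sigma + (z - mu)^2 / (2*sigma^2)
    nll = (0.5 * math.log(2 * math.pi) + log_sigma
           + (noise - mu) ** 2 / (2 * torch.exp(2 * log_sigma)))
    # sum over latent dimensions, average over the batch
    return nll.sum(dim=1).mean()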