Esempio n. 1
0
    def train(self, params):
        """Train the 3D-CNN classifier and return the best validation loss.

        The method builds a ``ConvResnet3D`` model, an optimizer and an
        optional learning-rate scheduler from ``params``, optionally resumes
        from a checkpoint, then alternates one training pass and one
        validation pass per epoch. Early stopping triggers after 500
        consecutive epochs without a new best validation loss.

        Args:
            params: mapping with the keys ``num_elements``, ``mom_range``,
                ``n_res``, ``niter``, ``scheduler``, ``optimizer``,
                ``momentum``, ``learning_rate``, ``n_flows``,
                ``weight_decay``, ``warmup``, ``l1`` and ``l2``.
                ``learning_rate``, ``weight_decay``, ``l1`` and ``l2`` are
                truncated to one significant digit before use.
                ``optimizer`` must be ``'adamw'``, ``'sgd'`` or ``'rmsprop'``.
                ``scheduler`` may be ``'ReduceLROnPlateau'`` or
                ``'CycleScheduler'``; any other value disables scheduling.

        Returns:
            The lowest mean validation loss observed, or ``-1`` when no
            epoch produced a finite validation loss (and no checkpoint
            supplied one).

        Side effects:
            Sets ``self.modelname``, writes TensorBoard scalars to ``logs``
            and, when ``self.save`` is truthy, writes checkpoints through
            ``save_checkpoint``. Terminates the process through ``exit`` when
            the optimizer type is unknown.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        def one_digit(value):
            # Truncate (not round) to one significant digit by going through
            # the scientific-notation string, e.g. 1.23e-03 -> 1e-03. Splitting
            # on 'e' keeps this correct for 3-digit exponents and negatives.
            mantissa, exponent = format(value, 'e').split('e')
            return float(mantissa.split('.')[0] + 'e' + exponent)

        num_elements = params['num_elements']
        mom_range = params['mom_range']
        n_res = params['n_res']
        niter = params['niter']
        scheduler = params['scheduler']
        optimizer_type = params['optimizer']
        momentum = params['momentum']
        n_flows = params['n_flows']
        warmup = params['warmup']
        learning_rate = one_digit(params['learning_rate'])
        weight_decay = one_digit(params['weight_decay'])
        l1 = one_digit(params['l1'])
        l2 = one_digit(params['l2'])

        # Scientific-notation strings shared by the log output and model name.
        lr_str = format(learning_rate, 'e')
        wd_str = format(weight_decay, 'e')
        l1_str = format(l1, 'e')
        l2_str = format(l2, 'e')

        if self.verbose > 1:
            print("Parameters: \n\t",
                  'zdim: ' + str(self.n_classes) + "\n\t",
                  'mom_range: ' + str(mom_range) + "\n\t",
                  'num_elements: ' + str(num_elements) + "\n\t",
                  'niter: ' + str(niter) + "\n\t",
                  'nres: ' + str(n_res) + "\n\t",
                  'learning_rate: ' + lr_str + "\n\t",
                  'momentum: ' + str(momentum) + "\n\t",
                  'n_flows: ' + str(n_flows) + "\n\t",
                  'weight_decay: ' + wd_str + "\n\t",
                  'warmup: ' + str(warmup) + "\n\t",
                  'l1: ' + l1_str + "\n\t",
                  'l2: ' + l2_str + "\n\t",
                  'optimizer_type: ' + optimizer_type + "\n\t",
                  )

        # The name is passed to load_checkpoint/save_checkpoint, so its
        # layout must stay stable across runs.
        self.modelname = "classif_3dcnn_" \
                         + '_bn' + str(self.batchnorm) \
                         + '_niter' + str(niter) \
                         + '_nres' + str(n_res) \
                         + '_momrange' + str(mom_range) \
                         + '_momentum' + str(momentum) \
                         + '_' + str(optimizer_type) \
                         + "_nclasses" + str(self.n_classes) \
                         + '_gated' + str(self.gated) \
                         + '_resblocks' + str(self.resblocks) \
                         + '_initlr' + lr_str \
                         + '_warmup' + str(warmup) \
                         + '_wd' + wd_str \
                         + '_l1' + l1_str \
                         + '_l2' + l2_str \
                         + '_size' + str(self.size)

        model = ConvResnet3D(self.maxpool,
                             self.in_channels,
                             self.out_channels,
                             self.kernel_sizes,
                             self.strides,
                             self.dilatations,
                             self.padding,
                             self.batchnorm,
                             self.n_classes,
                             activation=torch.nn.ReLU,
                             n_res=n_res,
                             gated=self.gated,
                             has_dense=self.has_dense,
                             resblocks=self.resblocks,
                             ).to(device)
        model.random_init()
        criterion = nn.CrossEntropyLoss()

        if optimizer_type == 'adamw':
            optimizer = torch.optim.AdamW(params=model.parameters(),
                                          lr=learning_rate,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        elif optimizer_type == 'sgd':
            optimizer = torch.optim.SGD(params=model.parameters(),
                                        lr=learning_rate,
                                        weight_decay=weight_decay,
                                        momentum=momentum)
        elif optimizer_type == 'rmsprop':
            optimizer = torch.optim.RMSprop(params=model.parameters(),
                                            lr=learning_rate,
                                            weight_decay=weight_decay,
                                            momentum=momentum)
        else:
            # Kept as exit() so callers see the same termination behaviour.
            exit('error: no such optimizer type available')

        # Resume from a checkpoint if one is configured.
        epoch = 0
        best_loss = -1  # -1 is the "no best loss yet" sentinel
        if self.checkpoint_path is not None and self.save:
            # The loaded loss histories are not reused below; fresh ones are
            # started for this run.
            model, optimizer, \
            epoch, _losses, \
            _kl_divs, _losses_recon, \
            best_loss = load_checkpoint(self.checkpoint_path,
                                        model,
                                        self.maxpool,
                                        save=self.save,
                                        padding=self.padding,
                                        has_dense=self.has_dense,
                                        batchnorm=self.batchnorm,
                                        flow_type=None,
                                        padding_deconv=None,
                                        optimizer=optimizer,
                                        z_dim=self.n_classes,
                                        gated=self.gated,
                                        in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_sizes=self.kernel_sizes,
                                        kernel_sizes_deconv=None,
                                        strides=self.strides,
                                        strides_deconv=None,
                                        dilatations=self.dilatations,
                                        dilatations_deconv=None,
                                        name=self.modelname,
                                        n_flows=n_flows,
                                        n_res=n_res,
                                        resblocks=self.resblocks,
                                        h_last=None,
                                        n_elements=None
                                        )
        # load_checkpoint may hand back a different model object.
        model = model.to(device)

        train_transform = transforms.Compose([
            XFlip(),
            YFlip(),
            ZFlip(),
            Flip90(),
            Flip180(),
            Flip270(),
            torchvision.transforms.Normalize(mean=(self.mean), std=(self.std)),
            Normalize()
        ])
        all_set = MRIDatasetClassifier(self.path, transform=train_transform, size=self.size)
        train_set, valid_set = validation_split(all_set, val_share=self.val_share)

        train_loader = DataLoader(train_set,
                                  num_workers=0,
                                  shuffle=True,
                                  batch_size=self.batch_size,
                                  pin_memory=False,
                                  drop_last=True)
        # NOTE(review): shuffle/drop_last on the validation loader means an
        # odd-sized validation set silently loses one sample per epoch.
        valid_loader = DataLoader(valid_set,
                                  num_workers=0,
                                  shuffle=True,
                                  batch_size=2,
                                  pin_memory=False,
                                  drop_last=True)

        logger = SummaryWriter('logs')
        epoch_offset = max(1, epoch)

        # None means "no scheduling"; previously an unrecognised scheduler
        # name left lr_schedule unbound and crashed later in the epoch loop.
        lr_schedule = None
        if scheduler == 'ReduceLROnPlateau':
            lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                     factor=0.1,
                                                                     cooldown=50,
                                                                     patience=200,
                                                                     verbose=True,
                                                                     min_lr=1e-15)
        elif scheduler == 'CycleScheduler':
            lr_schedule = CycleScheduler(optimizer,
                                         learning_rate,
                                         n_iter=niter * len(train_loader),
                                         momentum=[
                                             max(0.0, momentum - mom_range),
                                             min(1.0, momentum + mom_range),
                                         ])

        losses = {
            "train": [],
            "valid": [],
        }
        accuracies = {
            "train": [],
            "valid": [],
        }
        early_stop_patience = 500
        early_stop_counter = 0
        print("Training Started on device:", device)
        for epoch in range(epoch_offset, self.epochs):
            if early_stop_counter == early_stop_patience:
                if self.verbose > 0:
                    print('EARLY STOPPING.')
                break
            best_epoch = False

            # ---- training pass ----
            model.train()
            train_losses = []
            train_accuracy = []
            for i, (images, targets) in enumerate(train_loader):
                model.zero_grad()
                images = images.to(device)
                targets = targets.to(device)
                preds = model(images)

                loss = criterion(preds, targets)
                # Accumulate the norms over every weight tensor; the
                # accumulators live on the model's device.
                l1_reg = torch.zeros((), device=device)
                l2_reg = torch.zeros((), device=device)
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        l1_reg = l1_reg + torch.norm(param, 1)
                        l2_reg = l2_reg + torch.norm(param, 2)
                loss = loss + l1 * l1_reg + l2 * l2_reg

                loss.backward()
                optimizer.step()

                # Fraction of samples whose arg-max class equals the target.
                accuracy = (preds.argmax(dim=1) == targets).float().mean().item()
                train_accuracy.append(accuracy)
                train_losses.append(loss.item())
                logger.add_scalar('training_loss', loss.item(), i + len(train_loader) * epoch)

            losses["train"].append(np.mean(train_losses))
            accuracies["train"].append(np.mean(train_accuracy))

            if epoch % self.epochs_per_print == 0:
                if self.verbose > 1:
                    print("Epoch: {}:\t"
                          "Train Loss: {:.5f} , "
                          "Accuracy: {:.3f} , "
                          .format(epoch,
                                  losses["train"][-1],
                                  accuracies["train"][-1]
                                  ))

            # ---- validation pass ----
            model.eval()
            valid_losses = []
            valid_accuracy = []
            # No gradients are needed here; skipping the graph saves memory.
            with torch.no_grad():
                for i, (images, targets) in enumerate(valid_loader):
                    images = images.to(device)
                    targets = targets.to(device)
                    preds = model(images)

                    loss = criterion(preds, targets)
                    valid_losses.append(loss.item())
                    accuracy = (preds.argmax(dim=1) == targets).float().mean().item()
                    valid_accuracy.append(accuracy)
                    # Logged under its own tag and step base so it no longer
                    # masquerades as a training curve.
                    logger.add_scalar('valid_loss', loss.item(), i + len(valid_loader) * epoch)
            losses["valid"].append(np.mean(valid_losses))
            accuracies["valid"].append(np.mean(valid_accuracy))

            # The scheduler is held back for the first epochs of this run.
            if lr_schedule is not None and epoch - epoch_offset > 5:
                lr_schedule.step(losses["valid"][-1])

            # Model selection is driven by the validation loss.
            mode = 'valid'
            current = losses[mode][-1]
            if (current < best_loss or best_loss == -1) and not np.isnan(current):
                if self.verbose > 1:
                    print('BEST EPOCH!', current, accuracies[mode][-1])
                early_stop_counter = 0
                best_loss = current
                best_epoch = True
            else:
                early_stop_counter += 1

            if epoch % self.epochs_per_checkpoint == 0:
                if best_epoch and self.save:
                    if self.verbose > 1:
                        print('Saving model...')
                    save_checkpoint(model=model,
                                    optimizer=optimizer,
                                    maxpool=self.maxpool,
                                    padding=self.padding,
                                    padding_deconv=None,
                                    learning_rate=learning_rate,
                                    epoch=epoch,
                                    checkpoint_path=None,
                                    z_dim=self.n_classes,
                                    gated=self.gated,
                                    batchnorm=self.batchnorm,
                                    losses=losses,
                                    kl_divs=None,
                                    losses_recon=None,
                                    in_channels=self.in_channels,
                                    out_channels=self.out_channels,
                                    kernel_sizes=self.kernel_sizes,
                                    kernel_sizes_deconv=None,
                                    strides=self.strides,
                                    strides_deconv=None,
                                    dilatations=self.dilatations,
                                    dilatations_deconv=None,
                                    best_loss=best_loss,
                                    save=self.save,
                                    name=self.modelname,
                                    n_flows=None,
                                    flow_type=None,
                                    n_res=n_res,
                                    resblocks=self.resblocks,
                                    h_last=None,
                                    n_elements=None
                                    )
            if epoch % self.epochs_per_print == 0:
                if self.verbose > 0:
                    print("Epoch: {}:\t"
                          "Valid Loss: {:.5f} , "
                          "Accuracy: {:.3f} "
                          .format(epoch,
                                  losses["valid"][-1],
                                  accuracies["valid"][-1],
                                  )
                          )
                if self.verbose > 1:
                    print("Current LR:", optimizer.param_groups[0]['lr'])
                if 'momentum' in optimizer.param_groups[0]:
                    print("Current Momentum:", optimizer.param_groups[0]['momentum'])

        # Flush and release the TensorBoard event file.
        logger.close()
        if self.verbose > 0:
            print('BEST LOSS :', best_loss)
        return best_loss
Esempio n. 2
0
    def predict(self, params):
        """Run the checkpointed model over the test set and write a submission CSV.

        A ``ConvResnet3D`` model and a matching optimizer are rebuilt from
        ``params`` (the optimizer only exists so ``load_checkpoint`` can
        restore its state), the checkpoint at ``self.checkpoint_path`` is
        loaded, and one row per test sample is written to
        ``<self.basedir>/submission.csv`` with the columns
        ``Patient_Week,FVC,Confidence``.

        Args:
            params: mapping with the keys ``mom_range``, ``n_res``, ``niter``,
                ``optimizer``, ``momentum``, ``learning_rate`` and
                ``weight_decay``. ``learning_rate`` and ``weight_decay`` are
                truncated to one significant digit before use. ``optimizer``
                must be ``'adamw'``, ``'sgd'`` or ``'rmsprop'``.

        Returns:
            None.

        Side effects:
            Sets ``self.modelname`` and overwrites ``submission.csv``.
            Terminates the process through ``exit`` when the optimizer type
            is unknown.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        def one_digit(value):
            # Truncate (not round) to one significant digit by going through
            # the scientific-notation string, e.g. 1.23e-03 -> 1e-03. Splitting
            # on 'e' keeps this correct for 3-digit exponents and negatives.
            mantissa, exponent = format(value, 'e').split('e')
            return float(mantissa.split('.')[0] + 'e' + exponent)

        mom_range = params['mom_range']
        n_res = params['n_res']
        niter = params['niter']
        optimizer_type = params['optimizer']
        momentum = params['momentum']
        learning_rate = one_digit(params['learning_rate'])
        weight_decay = one_digit(params['weight_decay'])

        lr_str = format(learning_rate, 'e')
        wd_str = format(weight_decay, 'e')

        if self.verbose > 1:
            print(
                "Parameters: \n\t",
                'zdim: ' + str(self.n_classes) + "\n\t",
                'mom_range: ' + str(mom_range) + "\n\t",
                'niter: ' + str(niter) + "\n\t",
                'nres: ' + str(n_res) + "\n\t",
                'learning_rate: ' + lr_str + "\n\t",
                'momentum: ' + str(momentum) + "\n\t",
                'weight_decay: ' + wd_str + "\n\t",
                'optimizer_type: ' + optimizer_type + "\n\t",
            )

        # The name is passed to load_checkpoint, so its layout must not change.
        self.modelname = "classif_3dcnn_" \
                         + '_bn' + str(self.batchnorm) \
                         + '_niter' + str(niter) \
                         + '_nres' + str(n_res) \
                         + '_momrange' + str(mom_range) \
                         + '_momentum' + str(momentum) \
                         + '_' + str(optimizer_type) \
                         + "_nclasses" + str(self.n_classes) \
                         + '_gated' + str(self.gated) \
                         + '_resblocks' + str(self.resblocks) \
                         + '_initlr' + lr_str \
                         + '_wd' + wd_str \
                         + '_size' + str(self.size)

        model = ConvResnet3D(self.maxpool,
                             self.in_channels,
                             self.out_channels,
                             self.kernel_sizes,
                             self.strides,
                             self.dilatations,
                             self.padding,
                             self.batchnorm,
                             self.n_classes,
                             is_bayesian=self.is_bayesian,
                             activation=torch.nn.ReLU,
                             n_res=n_res,
                             gated=self.gated,
                             has_dense=self.has_dense,
                             resblocks=self.resblocks,
                             max_fvc=self.max_fvc,
                             n_kernels=self.n_kernels).to(device)
        l1 = nn.L1Loss()

        if optimizer_type == 'adamw':
            optimizer = torch.optim.AdamW(params=model.parameters(),
                                          lr=learning_rate,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        elif optimizer_type == 'sgd':
            optimizer = torch.optim.SGD(params=model.parameters(),
                                        lr=learning_rate,
                                        weight_decay=weight_decay,
                                        momentum=momentum)
        elif optimizer_type == 'rmsprop':
            optimizer = torch.optim.RMSprop(params=model.parameters(),
                                            lr=learning_rate,
                                            weight_decay=weight_decay,
                                            momentum=momentum)
        else:
            # Kept as exit() so callers see the same termination behaviour.
            exit('error: no such optimizer type available')

        # Only the restored model is used below; the remaining return values
        # are unpacked to match load_checkpoint's 7-tuple.
        model, optimizer, \
        _epoch, _losses, \
        _kl_divs, _losses_recon, \
        _best_loss = load_checkpoint(self.checkpoint_path,
                                     model,
                                     self.maxpool,
                                     save=False,
                                     padding=self.padding,
                                     has_dense=self.has_dense,
                                     batchnorm=self.batchnorm,
                                     flow_type=None,
                                     padding_deconv=None,
                                     optimizer=optimizer,
                                     z_dim=self.n_classes,
                                     gated=self.gated,
                                     in_channels=self.in_channels,
                                     out_channels=self.out_channels,
                                     kernel_sizes=self.kernel_sizes,
                                     kernel_sizes_deconv=None,
                                     strides=self.strides,
                                     strides_deconv=None,
                                     dilatations=self.dilatations,
                                     dilatations_deconv=None,
                                     name=self.modelname,
                                     n_res=n_res,
                                     resblocks=self.resblocks,
                                     h_last=None,
                                     n_elements=None,
                                     n_flows=None,
                                     predict=True
                                     )
        model = model.to(device)
        # Inference mode: layers such as batch-norm must use their stored
        # statistics, especially with batch_size=1 below.
        model.eval()

        test_set = CTDatasetInfere(train_path=self.train_path,
                                   test_path=self.test_path,
                                   train_labels_path=self.train_labels_path,
                                   test_labels_path=self.test_labels_path,
                                   submission_file=self.submission_file,
                                   size=self.size)
        test_loader = DataLoader(test_set,
                                 num_workers=0,
                                 shuffle=False,
                                 batch_size=1,
                                 pin_memory=False,
                                 drop_last=False)

        # The context manager closes the file even if a batch fails.
        with open(self.basedir + "/submission.csv", "w") as f:
            f.write("Patient_Week,FVC,Confidence\n")
            with torch.no_grad():
                for patient, images, targets, patient_info in test_loader:
                    patient_info = patient_info.to(device)
                    images = images.to(device)
                    # Targets follow the selected device instead of a
                    # hard-coded .cuda(), so CPU-only hosts work too.
                    targets = targets.to(device)

                    _, mu, log_var = model(images, patient_info)

                    l1_loss = l1(mu * test_set.max_fvc, targets)

                    # NOTE(review): the FVC column receives the L1 error, not
                    # the rescaled prediction `mu * max_fvc` - confirm intent.
                    fvc = l1_loss.item()
                    # NOTE(review): sqrt(log_var) is NaN for negative
                    # log-variance, which makes int() below raise; the usual
                    # std-dev would be exp(0.5 * log_var) - confirm intent.
                    confidence = 2 * np.exp(np.sqrt(log_var.item())) * test_set.max_fvc
                    f.write(",".join([patient[0],
                                      str(int(fvc)),
                                      str(int(confidence))]) + '\n')
Esempio n. 3
0
    def train(self, params):
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'

        num_elements = params['num_elements']
        mom_range = params['mom_range']
        n_res = params['n_res']
        niter = params['niter']
        scheduler = params['scheduler']
        optimizer_type = params['optimizer']
        momentum = params['momentum']
        z_dim = params['z_dim']
        learning_rate = params['learning_rate'].__format__('e')
        n_flows = params['n_flows']
        weight_decay = params['weight_decay'].__format__('e')
        warmup = params['warmup']
        l1 = params['l1'].__format__('e')
        l2 = params['l2'].__format__('e')

        weight_decay = float(str(weight_decay)[:1] + str(weight_decay)[-4:])
        learning_rate = float(str(learning_rate)[:1] + str(learning_rate)[-4:])
        l1 = float(str(l1)[:1] + str(l1)[-4:])
        l2 = float(str(l2)[:1] + str(l2)[-4:])
        if self.verbose > 1:
            print("Parameters: \n\t",
                  'zdim: ' + str(z_dim) + "\n\t",
                  'mom_range: ' + str(mom_range) + "\n\t",
                  'num_elements: ' + str(num_elements) + "\n\t",
                  'niter: ' + str(niter) + "\n\t",
                  'nres: ' + str(n_res) + "\n\t",
                  'learning_rate: ' + learning_rate.__format__('e') + "\n\t",
                  'momentum: ' + str(momentum) + "\n\t",
                  'n_flows: ' + str(n_flows) + "\n\t",
                  'weight_decay: ' + weight_decay.__format__('e') + "\n\t",
                  'warmup: ' + str(warmup) + "\n\t",
                  'l1: ' + l1.__format__('e') + "\n\t",
                  'l2: ' + l2.__format__('e') + "\n\t",
                  'optimizer_type: ' + optimizer_type + "\n\t",
                  )

        self.modelname = "vae_1dcnn_" \
                         + '_flows' + self.flow_type + str(n_flows) \
                         + '_bn' + str(self.batchnorm) \
                         + '_niter' + str(niter) \
                         + '_nres' + str(n_res) \
                         + '_momrange' + str(mom_range) \
                         + '_momentum' + str(momentum) \
                         + '_' + str(optimizer_type) \
                         + "_zdim" + str(z_dim) \
                         + '_gated' + str(self.gated) \
                         + '_resblocks' + str(self.resblocks) \
                         + '_initlr' + learning_rate.__format__('e') \
                         + '_warmup' + str(warmup) \
                         + '_wd' + weight_decay.__format__('e') \
                         + '_l1' + l1.__format__('e') \
                         + '_l2' + l2.__format__('e') \
                         + '_size' + str(self.size)
        if self.flow_type != 'o-sylvester':
            model = Autoencoder1DCNN(z_dim,
                                     self.maxpool,
                                     self.in_channels,
                                     self.out_channels,
                                     self.kernel_sizes,
                                     self.kernel_sizes_deconv,
                                     self.strides,
                                     self.strides_deconv,
                                     self.dilatations,
                                     self.dilatations_deconv,
                                     self.padding,
                                     self.padding_deconv,
                                     has_dense=self.has_dense,
                                     batchnorm=self.batchnorm,
                                     flow_type=self.flow_type,
                                     n_flows=n_flows,
                                     n_res=n_res,
                                     gated=self.gated,
                                     resblocks=self.resblocks
                                     ).to(device)
        else:
            model = SylvesterVAE(z_dim=z_dim,
                                 maxpool=self.maxpool,
                                 in_channels=self.in_channels,
                                 out_channels=self.out_channels,
                                 kernel_sizes=self.kernel_sizes,
                                 kernel_sizes_deconv=self.kernel_sizes_deconv,
                                 strides=self.strides,
                                 strides_deconv=self.strides_deconv,
                                 dilatations=self.dilatations,
                                 dilatations_deconv=self.dilatations_deconv,
                                 padding=self.padding,
                                 padding_deconv=self.padding_deconv,
                                 batchnorm=self.batchnorm,
                                 flow_type=self.flow_type,
                                 n_res=n_res,
                                 gated=self.gated,
                                 has_dense=self.has_dense,
                                 resblocks=self.resblocks,
                                 h_last=z_dim,
                                 n_flows=n_flows,
                                 num_elements=num_elements,
                                 auxiliary=False,
                                 a_dim=0,

                                 )
        model.random_init()
        criterion = nn.MSELoss(reduction="none")
        if optimizer_type == 'adamw':
            optimizer = torch.optim.AdamW(params=model.parameters(),
                                          lr=learning_rate,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        elif optimizer_type == 'sgd':
            optimizer = torch.optim.SGD(params=model.parameters(),
                                        lr=learning_rate,
                                        weight_decay=weight_decay,
                                        momentum=momentum)
        elif optimizer_type == 'rmsprop':
            optimizer = torch.optim.RMSprop(params=model.parameters(),
                                            lr=learning_rate,
                                            weight_decay=weight_decay,
                                            momentum=momentum)
        else:
            exit('error: no such optimizer type available')
        # if self.fp16_run:
        #     from apex import amp
        #    model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

        # Load checkpoint if one exists
        epoch = 0
        best_loss = -1
        if self.checkpoint_path is not None and self.save:
            model, optimizer, \
            epoch, losses, \
            kl_divs, losses_recon, \
            best_loss = load_checkpoint(checkpoint_path,
                                        model,
                                        self.maxpool,
                                        save=self.save,
                                        padding=self.padding,
                                        has_dense=self.has_dense,
                                        batchnorm=self.batchnorm,
                                        flow_type=self.flow_type,
                                        padding_deconv=self.padding_deconv,
                                        optimizer=optimizer,
                                        z_dim=z_dim,
                                        gated=self.gated,
                                        in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_sizes=self.kernel_sizes,
                                        kernel_sizes_deconv=self.kernel_sizes_deconv,
                                        strides=self.strides,
                                        strides_deconv=self.strides_deconv,
                                        dilatations=self.dilatations,
                                        dilatations_deconv=self.dilatations_deconv,
                                        name=self.modelname,
                                        n_flows=n_flows,
                                        n_res=n_res,
                                        resblocks=resblocks,
                                        h_last=self.out_channels[-1],
                                        n_elements=num_elements,
                                        model_name=Autoencoder1DCNN
                                        )
        model = model.to(device)
        # t1 = torch.Tensor(np.load('/run/media/simon/DATA&STUFF/data/biology/arrays/t1.npy'))
        # targets = torch.Tensor([0 for _ in t1])

        all_set = EEGDataset(self.path, transform=None)
        train_set, valid_set = validation_split(all_set, val_share=self.val_share)

        train_loader = DataLoader(train_set,
                                  num_workers=0,
                                  shuffle=True,
                                  batch_size=self.batch_size,
                                  pin_memory=False,
                                  drop_last=True)
        valid_loader = DataLoader(valid_set,
                                  num_workers=0,
                                  shuffle=True,
                                  batch_size=2,
                                  pin_memory=False,
                                  drop_last=True)

        # Create the scalar logger (a SummaryWriter writing under 'logs')
        logger = SummaryWriter('logs')
        epoch_offset = max(1, epoch)

        if scheduler == 'ReduceLROnPlateau':
            lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                     factor=0.1,
                                                                     cooldown=50,
                                                                     patience=200,
                                                                     verbose=True,
                                                                     min_lr=1e-15)
        elif scheduler == 'CycleScheduler':
            lr_schedule = CycleScheduler(optimizer,
                                         learning_rate,
                                         n_iter=niter * len(train_loader),
                                         momentum=[
                                             max(0.0, momentum - mom_range),
                                             min(1.0, momentum + mom_range),
                                         ])

        losses = {
            "train": [],
            "valid": [],
        }
        kl_divs = {
            "train": [],
            "valid": [],
        }
        losses_recon = {
            "train": [],
            "valid": [],
        }
        running_abs_error = {
            "train": [],
            "valid": [],
        }
        shapes = {
            "train": len(train_set),
            "valid": len(valid_set),
        }
        early_stop_counter = 0

        for epoch in range(epoch_offset, self.epochs):
            if early_stop_counter == 500:
                if self.verbose > 0:
                    print('EARLY STOPPING.')
                break
            best_epoch = False
            model.train()
            train_losses = []
            train_abs_error = []
            train_kld = []
            train_recons = []

            # pbar = tqdm(total=len(train_loader))
            for i, batch in enumerate(train_loader):
                #    pbar.update(1)
                model.zero_grad()
                images = batch
                images = torch.autograd.Variable(images).to(device)
                # images = images.unsqueeze(1)
                reconstruct, kl = model(images)
                reconstruct = reconstruct[:, :,
                              :images.shape[2],
                              :images.shape[3],
                              :images.shape[4]].squeeze(1)
                images = images.squeeze(1)
                loss_recon = criterion(
                    reconstruct,
                    images
                ).sum() / self.batch_size
                kl_div = torch.mean(kl)
                loss = loss_recon + kl_div
                l2_reg = torch.tensor(0.)
                l1_reg = torch.tensor(0.)
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        l1_reg = l1 + torch.norm(param, 1)
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        l2_reg = l2 + torch.norm(param, 1)
                loss += l1 * l1_reg
                loss += l2 * l2_reg
                # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-8)
                loss.backward()
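                # NOTE: if gradient clipping is re-enabled it belongs here, i.e. after
                # loss.backward() has populated the gradients and before optimizer.step()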
                # torch.nn.utils.clip_grad_norm_(model.parameters(), 100)
                # lr_schedule.step()

                try:
                    train_losses += [loss.item()]
                except:
                    return best_loss
                train_kld += [kl_div.item()]
                train_recons += [loss_recon.item()]
                train_abs_error += [
                    float(torch.mean(torch.abs_(
                        reconstruct - images.to(device)
                    )).item())
                ]

                # if self.fp16_run:
                #    with amp.scale_loss(loss, optimizer) as scaled_loss:
                #        scaled_loss.backward()
                #    del scaled_loss
                # else:
                optimizer.step()
                logger.add_scalar('training_loss', loss.item(), i + len(train_loader) * epoch)
                del kl, loss_recon, kl_div, loss

            img = nib.Nifti1Image(images.detach().cpu().numpy()[0], np.eye(4))
            recon = nib.Nifti1Image(reconstruct.detach().cpu().numpy()[0], np.eye(4))
            if 'views' not in os.listdir():
                os.mkdir('views')
            img.to_filename(filename='views/image_train_' + str(epoch) + '.nii.gz')
            recon.to_filename(filename='views/reconstruct_train_' + str(epoch) + '.nii.gz')

            losses["train"] += [np.mean(train_losses)]
            kl_divs["train"] += [np.mean(train_kld)]
            losses_recon["train"] += [np.mean(train_recons)]
            running_abs_error["train"] += [np.mean(train_abs_error)]

            if epoch % self.epochs_per_print == 0:
                if self.verbose > 1:
                    print("Epoch: {}:\t"
                          "Train Loss: {:.5f} , "
                          "kld: {:.3f} , "
                          "recon: {:.3f}"
                          .format(epoch,
                                  losses["train"][-1],
                                  kl_divs["train"][-1],
                                  losses_recon["train"][-1])
                          )

            if np.isnan(losses["train"][-1]):
                if self.verbose > 0:
                    print('PREMATURE RETURN...')
                return best_loss
            model.eval()
            valid_losses = []
            valid_kld = []
            valid_recons = []
            valid_abs_error = []
            # pbar = tqdm(total=len(valid_loader))
            for i, batch in enumerate(valid_loader):
                #    pbar.update(1)
                images = batch
                images = images.to(device)
                # images = images.unsqueeze(1)
                reconstruct, kl = model(images)
                reconstruct = reconstruct[:, :,
                              :images.shape[2],
                              :images.shape[3],
                              :images.shape[4]].squeeze(1)
                images = images.squeeze(1)
                loss_recon = criterion(
                    reconstruct,
                    images.to(device)
                ).sum()
                kl_div = torch.mean(kl)
                if epoch < warmup:
                    kl_div = kl_div * (epoch / warmup)
                loss = loss_recon + kl_div
                try:
                    valid_losses += [loss.item()]
                except:
                    return best_loss
                valid_kld += [kl_div.item()]
                valid_recons += [loss_recon.item()]
                valid_abs_error += [float(torch.mean(torch.abs_(reconstruct - images.to(device))).item())]
                logger.add_scalar('training loss', np.log2(loss.item()), i + len(train_loader) * epoch)
            losses["valid"] += [np.mean(valid_losses)]
            kl_divs["valid"] += [np.mean(valid_kld)]
            losses_recon["valid"] += [np.mean(valid_recons)]
            running_abs_error["valid"] += [np.mean(valid_abs_error)]
            if epoch - epoch_offset > 5:
                lr_schedule.step(losses["valid"][-1])
            # Best-epoch selection should track the validation loss; switching to 'train' is only
            # acceptable when checking whether the model can fit the data at all, without caring
            # about generalisation
            mode = 'valid'
            if (losses[mode][-1] < best_loss or best_loss == -1) and not np.isnan(losses[mode][-1]):
                if self.verbose > 1:
                    print('BEST EPOCH!', losses[mode][-1])
                early_stop_counter = 0
                best_loss = losses[mode][-1]
                best_epoch = True
            else:
                early_stop_counter += 1

            if epoch % self.epochs_per_checkpoint == 0:
                img = nib.Nifti1Image(images.detach().cpu().numpy()[0], np.eye(4))
                recon = nib.Nifti1Image(reconstruct.detach().cpu().numpy()[0], np.eye(4))
                if 'views' not in os.listdir():
                    os.mkdir('views')
                img.to_filename(filename='views/image_' + str(epoch) + '.nii.gz')
                recon.to_filename(filename='views/reconstruct_' + str(epoch) + '.nii.gz')
                if best_epoch and self.save:
                    if self.verbose > 1:
                        print('Saving model...')
                    save_checkpoint(model=model,
                                    optimizer=optimizer,
                                    maxpool=maxpool,
                                    padding=self.padding,
                                    padding_deconv=self.padding_deconv,
                                    learning_rate=learning_rate,
                                    epoch=epoch,
                                    checkpoint_path=output_directory,
                                    z_dim=z_dim,
                                    gated=self.gated,
                                    batchnorm=self.batchnorm,
                                    losses=losses,
                                    kl_divs=kl_divs,
                                    losses_recon=losses_recon,
                                    in_channels=self.in_channels,
                                    out_channels=self.out_channels,
                                    kernel_sizes=self.kernel_sizes,
                                    kernel_sizes_deconv=self.kernel_sizes_deconv,
                                    strides=self.strides,
                                    strides_deconv=self.strides_deconv,
                                    dilatations=self.dilatations,
                                    dilatations_deconv=self.dilatations_deconv,
                                    best_loss=best_loss,
                                    save=self.save,
                                    name=self.modelname,
                                    n_flows=n_flows,
                                    flow_type=self.flow_type,
                                    n_res=n_res,
                                    resblocks=resblocks,
                                    h_last=z_dim,
                                    model_name=Autoencoder1DCNN
                                    )
            if epoch % self.epochs_per_print == 0:
                if self.verbose > 0:
                    print("Epoch: {}:\t"
                          "Valid Loss: {:.5f} , "
                          "kld: {:.3f} , "
                          "recon: {:.3f}"
                          .format(epoch,
                                  losses["valid"][-1],
                                  kl_divs["valid"][-1],
                                  losses_recon["valid"][-1]
                                  )
                          )
                if self.verbose > 1:
                    print("Current LR:", optimizer.param_groups[0]['lr'])
                if 'momentum' in optimizer.param_groups[0].keys():
                    print("Current Momentum:", optimizer.param_groups[0]['momentum'])
            if self.plot_perform:
                plot_performance(loss_total=losses, losses_recon=losses_recon, kl_divs=kl_divs, shapes=shapes,
                             results_path="../figures",
                             filename="training_loss_trace_"
                                      + self.modelname + '.jpg')
        if self.verbose > 0:
            print('BEST LOSS :', best_loss)
        return best_loss
Esempio n. 4
0
    def train(self, params):
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'

        best_losses = []

        num_elements = params['num_elements']
        mom_range = params['mom_range']
        n_res = params['n_res']
        niter = params['niter']
        scheduler = params['scheduler']
        optimizer_type = params['optimizer']
        momentum = params['momentum']
        z_dim = params['z_dim']
        learning_rate = params['learning_rate'].__format__('e')
        n_flows = params['n_flows']
        weight_decay = params['weight_decay'].__format__('e')
        warmup = params['warmup']
        l1 = params['l1'].__format__('e')
        l2 = params['l2'].__format__('e')

        weight_decay = float(str(weight_decay)[:1] + str(weight_decay)[-4:])
        learning_rate = float(str(learning_rate)[:1] + str(learning_rate)[-4:])
        l1 = float(str(l1)[:1] + str(l1)[-4:])
        l2 = float(str(l2)[:1] + str(l2)[-4:])
        if self.verbose > 1:
            print("Parameters: \n\t",
                  'zdim: ' + str(z_dim) + "\n\t",
                  'mom_range: ' + str(mom_range) + "\n\t",
                  'num_elements: ' + str(num_elements) + "\n\t",
                  'niter: ' + str(niter) + "\n\t",
                  'nres: ' + str(n_res) + "\n\t",
                  'learning_rate: ' + learning_rate.__format__('e') + "\n\t",
                  'momentum: ' + str(momentum) + "\n\t",
                  'n_flows: ' + str(n_flows) + "\n\t",
                  'weight_decay: ' + weight_decay.__format__('e') + "\n\t",
                  'warmup: ' + str(warmup) + "\n\t",
                  'l1: ' + l1.__format__('e') + "\n\t",
                  'l2: ' + l2.__format__('e') + "\n\t",
                  'optimizer_type: ' + optimizer_type + "\n\t",
                  'in_channels:' + "-".join([str(item) for item in self.in_channels]) + "\n\t",
                  'out_channels:' + "-".join([str(item) for item in self.out_channels]) + "\n\t",
                  'kernel_sizes:' + "-".join([str(item) for item in self.kernel_sizes]) + "\n\t",
                  'kernel_sizes_deconv:' + "-".join([str(item) for item in self.kernel_sizes_deconv]) + "\n\t",
                  'paddings:' + "-".join([str(item) for item in self.padding]) + "\n\t",
                  'padding_deconv:' + "-".join([str(item) for item in self.padding_deconv]) + "\n\t",
                  'dilatations:' + "-".join([str(item) for item in self.dilatations]) + "\n\t",
                  'dilatations_deconv:' + "-".join([str(item) for item in self.dilatations_deconv]) + "\n\t",
                  )

        self.model_name = "vae_3dcnn_" \
                          + '_flows' + self.flow_type + str(n_flows) \
                          + '_bn' + str(self.batchnorm) \
                          + '_niter' + str(niter) \
                          + '_nres' + str(n_res) \
                          + '_momrange' + str(mom_range) \
                          + '_momentum' + str(momentum) \
                          + '_' + str(optimizer_type) \
                          + "_zdim" + str(z_dim) \
                          + '_gated' + str(self.gated) \
                          + '_resblocks' + str(self.resblocks) \
                          + '_initlr' + learning_rate.__format__('e') \
                          + '_warmup' + str(warmup) \
                          + '_wd' + weight_decay.__format__('e') \
                          + '_l1' + l1.__format__('e') \
                          + '_l2' + l2.__format__('e') \
                          + '_size' + str(self.size) \
                          + "-".join([str(item) for item in self.in_channels]) \

        if self.flow_type != 'o-sylvester':
            model = Autoencoder3DCNN(z_dim,
                                     self.maxpool,
                                     # self.maxpool2,
                                     self.in_channels,
                                     # self.in_channels2,
                                     self.out_channels,
                                     # self.out_channels2,
                                     self.kernel_sizes,
                                     self.kernel_sizes_deconv,
                                     self.strides,
                                     self.strides_deconv,
                                     self.dilatations,
                                     self.dilatations_deconv,
                                     self.padding,
                                     # self.padding2,
                                     self.padding_deconv,
                                     # self.padding_deconv2,
                                     has_dense=self.has_dense,
                                     batchnorm=self.batchnorm,
                                     flow_type=self.flow_type,
                                     n_flows=n_flows,
                                     n_res=n_res,
                                     gated=self.gated,
                                     resblocks=self.resblocks
                                     ).to(device)
        else:
            model = SylvesterVAE(z_dim=z_dim,
                                 maxpool=self.maxpool,
                                 in_channels=self.in_channels,
                                 out_channels=self.out_channels,
                                 kernel_sizes=self.kernel_sizes,
                                 kernel_sizes_deconv=self.kernel_sizes_deconv,
                                 strides=self.strides,
                                 strides_deconv=self.strides_deconv,
                                 dilatations=self.dilatations,
                                 dilatations_deconv=self.dilatations_deconv,
                                 padding=self.padding,
                                 padding_deconv=self.padding_deconv,
                                 batchnorm=self.batchnorm,
                                 flow_type=self.flow_type,
                                 n_res=n_res,
                                 gated=self.gated,
                                 has_dense=self.has_dense,
                                 resblocks=self.resblocks,
                                 h_last=z_dim,
                                 n_flows=n_flows,
                                 num_elements=num_elements,
                                 auxiliary=False,
                                 a_dim=0,

                                 )
        model.random_init()
        criterion = nn.MSELoss(reduction="none")
        if optimizer_type == 'adamw':
            optimizer = torch.optim.AdamW(params=model.parameters(),
                                          lr=learning_rate,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        elif optimizer_type == 'sgd':
            optimizer = torch.optim.SGD(params=model.parameters(),
                                        lr=learning_rate,
                                        weight_decay=weight_decay,
                                        momentum=momentum)
        elif optimizer_type == 'rmsprop':
            optimizer = torch.optim.RMSprop(params=model.parameters(),
                                            lr=learning_rate,
                                            weight_decay=weight_decay,
                                            momentum=momentum)
        else:
            exit('error: no such optimizer type available')

        # Load checkpoint if one exists
        epoch = 0
        best_loss = -1
        if self.checkpoint_path is not None and self.load:
            model, _, \
            epoch, losses, \
            kl_divs, losses_recon, \
            best_loss = load_checkpoint(checkpoint_path,
                                        model,
                                        self.maxpool,
                                        save=self.save,
                                        padding=self.padding,
                                        has_dense=self.has_dense,
                                        batchnorm=self.batchnorm,
                                        flow_type=self.flow_type,
                                        padding_deconv=self.padding_deconv,
                                        optimizer=optimizer,
                                        z_dim=z_dim,
                                        gated=self.gated,
                                        in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_sizes=self.kernel_sizes,
                                        kernel_sizes_deconv=self.kernel_sizes_deconv,
                                        strides=self.strides,
                                        strides_deconv=self.strides_deconv,
                                        dilatations=self.dilatations,
                                        dilatations_deconv=self.dilatations_deconv,
                                        name=self.model_name,
                                        n_flows=n_flows,
                                        n_res=n_res,
                                        resblocks=resblocks,
                                        h_last=self.out_channels[-1],
                                        n_elements=num_elements
                                        )
        model = model.to(device)

        train_transform = transforms.Compose([
            transforms.RandomApply([
                XFlip(),
                YFlip(),
                ZFlip(),
                transforms.RandomChoice([
                    Flip90(),
                    Flip180(),
                    Flip270()
                ]),
                # ColorJitter3D(.01, .01, .01, .01),
                # transforms.RandomChoice([
                #     RandomAffine3D(0, [.05, .05], None, None),
                #     RandomAffine3D(1, [.05, .05], None, None),
                #     RandomAffine3D(2, [.05, .05], None, None),
                # ]),
                RandomRotation3D(90, 0),
                RandomRotation3D(90, 1),
                RandomRotation3D(90, 2),

            ]),
            torchvision.transforms.Normalize(mean=(self.mean), std=(self.std)),
            Normalize()
        ])
        all_set = CTDataset(self.path, transform=train_transform)
        spliter = validation_spliter(all_set, cv=self.cross_validation)

        epoch_offset = max(1, epoch)

        for cv in range(self.cross_validation):
            model.random_init()
            best_loss = -1
            valid_set, train_set = spliter.__next__()

            train_loader = DataLoader(train_set,
                                      num_workers=0,
                                      shuffle=True,
                                      batch_size=self.batch_size,
                                      pin_memory=False,
                                      drop_last=True)
            valid_loader = DataLoader(valid_set,
                                      num_workers=0,
                                      shuffle=True,
                                      batch_size=2,
                                      pin_memory=False,
                                      drop_last=True)

            if optimizer_type == 'adamw':
                optimizer = torch.optim.AdamW(params=model.parameters(),
                                              lr=learning_rate,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
            elif optimizer_type == 'sgd':
                optimizer = torch.optim.SGD(params=model.parameters(),
                                            lr=learning_rate,
                                            weight_decay=weight_decay,
                                            momentum=momentum)
            elif optimizer_type == 'rmsprop':
                optimizer = torch.optim.RMSprop(params=model.parameters(),
                                                lr=learning_rate,
                                                weight_decay=weight_decay,
                                                momentum=momentum)
            else:
                exit('error: no such optimizer type available')

            # Create the scalar logger (a SummaryWriter writing under 'logs')
            logger = SummaryWriter('logs')

            if scheduler == 'ReduceLROnPlateau':
                lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                         factor=0.1,
                                                                         cooldown=50,
                                                                         patience=200,
                                                                         verbose=True,
                                                                         min_lr=1e-15)
            elif scheduler == 'CycleScheduler':
                lr_schedule = CycleScheduler(optimizer,
                                             learning_rate,
                                             n_iter=niter * len(train_loader),
                                             momentum=[
                                                 max(0.0, momentum - mom_range),
                                                 min(1.0, momentum + mom_range),
                                             ])

            losses = {
                "train": [],
                "valid": [],
            }
            kl_divs = {
                "train": [],
                "valid": [],
            }
            losses_recon = {
                "train": [],
                "valid": [],
            }
            running_abs_error = {
                "train": [],
                "valid": [],
            }
            shapes = {
                "train": len(train_set),
                "valid": len(valid_set),
            }
            early_stop_counter = 0
            print("\n\n\nCV:", cv, "/", self.cross_validation, "\nTrain samples:", len(train_set),
                  "\nValid samples:", len(valid_set), "\n\n\n")
            valid_losses = []
            valid_kld = []
            valid_recons = []
            valid_abs_error = []
            train_losses = []
            train_abs_error = []
            train_kld = []
            train_recons = []

            for epoch in range(epoch_offset, self.epochs):
                if early_stop_counter == self.early_stop:
                    if self.verbose > 0:
                        print('EARLY STOPPING.')
                    break
                best_epoch = False
                model.train()
                # pbar = tqdm(total=len(train_loader))

                for i, batch in enumerate(train_loader):
                    # pbar.update(1)
                    model.zero_grad()
                    images, _ = batch
                    images = torch.autograd.Variable(images).to(device)
                    reconstruct, kl = model(images)
                    reconstruct = reconstruct[:, :,
                                  :images.shape[2],
                                  :images.shape[3],
                                  :images.shape[4]].squeeze(1)
                    images = images.squeeze(1)
                    loss_recon = criterion(
                        reconstruct,
                        images
                    ).sum() / self.batch_size
                    kl_div = torch.mean(kl)
                    loss = loss_recon + kl_div
                    l2_reg = torch.tensor(0.)
                    l1_reg = torch.tensor(0.)
                    for name, param in model.named_parameters():
                        if 'weight' in name:
                            l1_reg = l1 + torch.norm(param, 1)
                    for name, param in model.named_parameters():
                        if 'weight' in name:
                            l2_reg = l2 + torch.norm(param, 1)
                    loss += l1 * l1_reg
                    loss += l2 * l2_reg
                    loss.backward()

                    train_losses += [loss.item()]
                    train_kld += [kl_div.item()]
                    train_recons += [loss_recon.item()]
                    train_abs_error += [
                        float(torch.mean(torch.abs_(
                            reconstruct - images.to(device)
                        )).item())
                    ]

                    optimizer.step()
                    if scheduler == "CycleScheduler":
                        lr_schedule.step()
                        # optimizer = lr_schedule.optimizer
                    logger.add_scalar('training_loss', loss.item(), i + len(train_loader) * epoch)
                    del kl, loss_recon, kl_div, loss

                img = nib.Nifti1Image(images.detach().cpu().numpy()[0], np.eye(4))
                recon = nib.Nifti1Image(reconstruct.detach().cpu().numpy()[0], np.eye(4))
                if self.save:
                    if 'views' not in os.listdir():
                        os.mkdir('views')
                    img.to_filename(filename='views/image_train_' + str(epoch) + '.nii.gz')
                    recon.to_filename(filename='views/reconstruct_train_' + str(epoch) + '.nii.gz')

                losses["train"] += [np.mean(train_losses)]
                kl_divs["train"] += [np.mean(train_kld)]
                losses_recon["train"] += [np.mean(train_recons)]
                running_abs_error["train"] += [np.mean(train_abs_error)]

                if epoch % self.epochs_per_print == 0:
                    if self.verbose > 1:
                        print("Epoch: {}:\t"
                              "Train Loss: {:.5f} , "
                              "kld: {:.3f} , "
                              "recon: {:.3f}"
                              .format(epoch,
                                      losses["train"][-1],
                                      kl_divs["train"][-1],
                                      losses_recon["train"][-1])
                              )
                    train_losses = []
                    train_abs_error = []
                    train_kld = []
                    train_recons = []

                model.eval()
                for i, batch in enumerate(valid_loader):
                    images, _ = batch
                    images = images.to(device)
                    reconstruct, kl = model(images)
                    reconstruct = reconstruct[:, :,
                                  :images.shape[2],
                                  :images.shape[3],
                                  :images.shape[4]].squeeze(1)
                    images = images.squeeze(1)
                    loss_recon = criterion(
                        reconstruct,
                        images.to(device)
                    ).sum()
                    kl_div = torch.mean(kl)
                    if epoch < warmup:
                        kl_div = kl_div * (epoch / warmup)
                    loss = loss_recon + kl_div
                    valid_losses += [loss.item()]
                    valid_kld += [kl_div.item()]
                    valid_recons += [loss_recon.item()]
                    valid_abs_error += [float(torch.mean(torch.abs_(reconstruct - images.to(device))).item())]
                    logger.add_scalar('training loss', np.log2(loss.item()), i + len(train_loader) * epoch)
                losses["valid"] += [np.mean(valid_losses)]
                kl_divs["valid"] += [np.mean(valid_kld)]
                losses_recon["valid"] += [np.mean(valid_recons)]
                running_abs_error["valid"] += [np.mean(valid_abs_error)]
                if scheduler == "ReduceLROnPlateau":
                    if epoch - epoch_offset > 5:
                        lr_schedule.step(losses["valid"][-1])
                if (losses[self.mode][-1] < best_loss or best_loss == -1) and not np.isnan(losses[self.mode][-1]):
                    if self.verbose > 1:
                        print('BEST EPOCH!', losses[self.mode][-1])
                    early_stop_counter = 0
                    best_loss = losses[self.mode][-1]
                    best_epoch = True
                else:
                    early_stop_counter += 1

                if epoch % self.epochs_per_checkpoint == 0:
                    if self.save:
                        img = nib.Nifti1Image(images.detach().cpu().numpy()[0], np.eye(4))
                        recon = nib.Nifti1Image(reconstruct.detach().cpu().numpy()[0], np.eye(4))
                        img.to_filename(filename='views/image_' + str(epoch) + '.nii.gz')
                        recon.to_filename(filename='views/reconstruct_' + str(epoch) + '.nii.gz')
                    if best_epoch and self.save:
                        if self.verbose > 1:
                            print('Saving model...')
                        save_checkpoint(model=model,
                                        optimizer=optimizer,
                                        maxpool=maxpool,
                                        padding=self.padding,
                                        padding_deconv=self.padding_deconv,
                                        learning_rate=learning_rate,
                                        epoch=epoch,
                                        checkpoint_path=output_directory,
                                        z_dim=z_dim,
                                        gated=self.gated,
                                        batchnorm=self.batchnorm,
                                        losses=losses,
                                        kl_divs=kl_divs,
                                        losses_recon=losses_recon,
                                        in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_sizes=self.kernel_sizes,
                                        kernel_sizes_deconv=self.kernel_sizes_deconv,
                                        strides=self.strides,
                                        strides_deconv=self.strides_deconv,
                                        dilatations=self.dilatations,
                                        dilatations_deconv=self.dilatations_deconv,
                                        best_loss=best_loss,
                                        save=self.save,
                                        name=self.model_name,
                                        n_flows=n_flows,
                                        flow_type=self.flow_type,
                                        n_res=n_res,
                                        resblocks=resblocks,
                                        h_last=z_dim,
                                        n_elements=num_elements,
                                        )
                if epoch % self.epochs_per_print == 0:
                    if self.verbose > 0:
                        print("Epoch: {}:\t"
                              "Valid Loss: {:.5f} , "
                              "kld: {:.3f} , "
                              "recon: {:.3f}"
                              .format(epoch,
                                      losses["valid"][-1],
                                      kl_divs["valid"][-1],
                                      losses_recon["valid"][-1]
                                      )
                              )
                        valid_losses = []
                        valid_kld = []
                        valid_recons = []
                        valid_abs_error = []

                    if self.verbose > 1:
                        print("Current LR:", optimizer.param_groups[0]['lr'])
                    if 'momentum' in optimizer.param_groups[0].keys():
                        print("Current Momentum:", optimizer.param_groups[0]['momentum'])
                if self.plot_perform:
                    plot_performance(loss_total=losses, losses_recon=losses_recon, kl_divs=kl_divs, shapes=shapes,
                                     results_path="../figures",
                                     filename="training_loss_trace_"
                                              + self.model_name + '.jpg')
            if self.verbose > 0:
                print('BEST LOSS :', best_loss)
            best_losses += [best_loss]
        return min(best_losses)
Esempio n. 5
0
    def train(self, params):
        """Train a ``ConvResnet3D`` model with cross-validation.

        For every cross-validation fold the model is re-initialised through
        ``model.random_init(self.init_func)`` and optimised on the negative of
        ``log_gaussian(targets, mu, log_var)``, where ``mu`` and ``log_var``
        are the second and third outputs of the model.

        Args:
            params: Mapping of hyper-parameters. The keys read here are
                ``'mom_range'``, ``'n_res'``, ``'niter'``, ``'scheduler'``,
                ``'optimizer'`` (one of ``'adamw'``, ``'sgd'``, ``'rmsprop'``),
                ``'momentum'``, ``'learning_rate'`` and ``'weight_decay'``.

        Returns:
            The mean, over all folds, of the best validation score of each
            fold. The validation score is ``mean(exp(-loss)) / 2`` and is
            treated as higher-is-better when selecting the best epoch.

        Raises:
            SystemExit: If ``params['optimizer']`` is not a supported type.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        mom_range = params['mom_range']
        n_res = params['n_res']
        niter = params['niter']
        scheduler = params['scheduler']
        optimizer_type = params['optimizer']
        momentum = params['momentum']

        # Truncate both values to their leading digit and exponent: the
        # scientific notation string (e.g. '2.500000e-04') is reduced to its
        # first character plus its last four characters ('2e-04').
        learning_rate = format(params['learning_rate'], 'e')
        weight_decay = format(params['weight_decay'], 'e')
        weight_decay = float(weight_decay[:1] + weight_decay[-4:])
        learning_rate = float(learning_rate[:1] + learning_rate[-4:])

        if self.verbose > 1:
            print(
                "Parameters: \n\t",
                'zdim: ' + str(self.n_classes) + "\n\t",
                'mom_range: ' + str(mom_range) + "\n\t",
                'niter: ' + str(niter) + "\n\t",
                'nres: ' + str(n_res) + "\n\t",
                'learning_rate: ' + format(learning_rate, 'e') + "\n\t",
                'momentum: ' + str(momentum) + "\n\t",
                'weight_decay: ' + format(weight_decay, 'e') + "\n\t",
                'optimizer_type: ' + optimizer_type + "\n\t",
            )

        # The model name encodes the hyper-parameters; it is passed to the
        # checkpoint helpers and used in the performance plot file name.
        self.modelname = "classif_3dcnn_" \
                         + '_bn' + str(self.batchnorm) \
                         + '_niter' + str(niter) \
                         + '_nres' + str(n_res) \
                         + '_momrange' + str(mom_range) \
                         + '_momentum' + str(momentum) \
                         + '_' + str(optimizer_type) \
                         + "_nclasses" + str(self.n_classes) \
                         + '_gated' + str(self.gated) \
                         + '_resblocks' + str(self.resblocks) \
                         + '_initlr' + format(learning_rate, 'e') \
                         + '_wd' + format(weight_decay, 'e') \
                         + '_size' + str(self.size)

        model = ConvResnet3D(
            self.maxpool,
            self.in_channels,
            self.out_channels,
            self.kernel_sizes,
            self.strides,
            self.dilatations,
            self.padding,
            self.batchnorm,
            self.n_classes,
            max_fvc=None,
            n_kernels=self.n_kernels,
            is_bayesian=self.is_bayesian,
            activation=torch.nn.ReLU,
            n_res=n_res,
            gated=self.gated,
            has_dense=self.has_dense,
            resblocks=self.resblocks,
        ).to(device)

        # L1 distance is only used for reporting, not for optimisation.
        l1 = nn.L1Loss()

        if optimizer_type == 'adamw':
            optimizer = torch.optim.AdamW(params=model.parameters(),
                                          lr=learning_rate,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        elif optimizer_type == 'sgd':
            optimizer = torch.optim.SGD(params=model.parameters(),
                                        lr=learning_rate,
                                        weight_decay=weight_decay,
                                        momentum=momentum)
        elif optimizer_type == 'rmsprop':
            optimizer = torch.optim.RMSprop(params=model.parameters(),
                                            lr=learning_rate,
                                            weight_decay=weight_decay,
                                            momentum=momentum)
        else:
            # Same exception and message as the ``exit()`` helper, without
            # depending on the ``site`` module having injected it.
            raise SystemExit('error: no such optimizer type available')

        # Load checkpoint if one exists.
        # NOTE: the restored ``epoch``, ``losses`` and ``best_loss`` are
        # overwritten further down (the epoch loop rebinds ``epoch``; each
        # fold resets ``losses`` and ``best_loss``).
        epoch = 0
        best_loss = -1
        if self.checkpoint_path is not None and self.save:
            model, optimizer, \
            epoch, losses, \
            kl_divs, losses_recon, \
            best_loss = load_checkpoint(self.checkpoint_path,
                                        model,
                                        self.maxpool,
                                        save=self.save,
                                        padding=self.padding,
                                        has_dense=self.has_dense,
                                        batchnorm=self.batchnorm,
                                        flow_type=None,
                                        padding_deconv=None,
                                        optimizer=optimizer,
                                        z_dim=self.n_classes,
                                        gated=self.gated,
                                        in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_sizes=self.kernel_sizes,
                                        kernel_sizes_deconv=None,
                                        strides=self.strides,
                                        strides_deconv=None,
                                        dilatations=self.dilatations,
                                        dilatations_deconv=None,
                                        name=self.modelname,
                                        n_res=n_res,
                                        resblocks=self.resblocks,
                                        h_last=None,
                                        n_elements=None,
                                        n_flows=None,
                                        predict=False,
                                        n_kernels=self.n_kernels
                                        )
        model = model.to(device)

        # Augmentations: one random flip, one random quarter-turn flip and
        # one random rotation, followed by normalisation.
        train_transform = transforms.Compose([
            transforms.RandomChoice([XFlip(), YFlip(),
                                     ZFlip()]),
            transforms.RandomChoice([Flip90(), Flip180(),
                                     Flip270()]),
            transforms.RandomChoice([
                RandomRotation3D(25, 0),
                RandomRotation3D(25, 1),
                RandomRotation3D(25, 2)
            ]),
            torchvision.transforms.Normalize(mean=(self.mean), std=(self.std)),
        ])

        all_set = CTDataset(self.path,
                            self.train_csv,
                            transform=train_transform,
                            size=self.size)
        spliter = validation_spliter(all_set, cv=self.cross_validation)
        # The dataset-wide maximum is used below to rescale the reported
        # error and variance.
        model.max_fvc = all_set.max_fvc
        print("Training Started on device:", device)
        best_losses = []
        for cv in range(self.cross_validation):
            model.random_init(self.init_func)
            best_loss = -1
            valid_set, train_set = next(spliter)
            valid_set.transform = False
            train_loader = DataLoader(train_set,
                                      num_workers=0,
                                      shuffle=True,
                                      batch_size=self.batch_size,
                                      pin_memory=False,
                                      drop_last=True)
            valid_loader = DataLoader(valid_set,
                                      num_workers=0,
                                      shuffle=True,
                                      batch_size=2,
                                      pin_memory=False,
                                      drop_last=True)

            logger = SummaryWriter('logs')

            if scheduler == 'ReduceLROnPlateau':
                # NOTE(review): the scheduler is left in its default mode
                # while the validation score fed to it is treated as
                # higher-is-better below; confirm whether mode='max' is
                # what is intended.
                lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer,
                    factor=0.1,
                    cooldown=10,
                    patience=20,
                    verbose=True,
                    min_lr=1e-15)
            elif scheduler == 'CycleScheduler':
                lr_schedule = CycleScheduler(optimizer,
                                             learning_rate,
                                             n_iter=niter * len(train_loader),
                                             momentum=[
                                                 max(0.0,
                                                     momentum - mom_range),
                                                 min(1.0,
                                                     momentum + mom_range),
                                             ])

            # Per-fold history, one entry per printed epoch.
            losses = {
                "train": [],
                "valid": [],
            }
            log_gaussians = {
                "train": [],
                "valid": [],
            }
            variances = {
                "train": [],
                "valid": [],
            }
            accuracies = {
                "train": [],
                "valid": [],
            }
            shapes = {
                "train": len(train_set),
                "valid": len(valid_set),
            }
            early_stop_counter = 0
            print("\n\n\nCV:", cv, "/", self.cross_validation,
                  "\nTrain samples:", len(train_set), "\nValid samples:",
                  len(valid_set), "\n\n\n")

            # Running per-batch buffers; they are flushed into the history
            # dictionaries every ``self.epochs_per_print`` epochs.
            train_losses = []
            train_accuracy = []
            valid_losses = []
            train_log_gauss = []
            valid_log_gauss = []
            train_var = []
            valid_var = []
            valid_accuracy = []
            for epoch in range(self.epochs):
                if early_stop_counter == 100:
                    if self.verbose > 0:
                        print('EARLY STOPPING.')
                    break
                best_epoch = False
                model.train()

                for i, batch in enumerate(train_loader):
                    model.zero_grad()
                    _, images, targets, patient_info = batch
                    images = images.to(device)
                    targets = targets.type(torch.FloatTensor).to(device)
                    patient_info = patient_info.to(device)

                    _, mu, log_var = model(images, patient_info)
                    mu = mu.type(torch.FloatTensor).to(device)
                    log_var = log_var.type(torch.FloatTensor).to(device)

                    # Density of the predicted distribution evaluated at its
                    # own mean; reported as "confidence".
                    rv = norm(mu.detach().cpu().numpy(),
                              np.exp(log_var.detach().cpu().numpy()))
                    train_log_gauss += [rv.pdf(mu.detach().cpu().numpy())]

                    loss = -log_gaussian(targets, mu,
                                         log_var) / self.batch_size
                    loss = torch.sum(loss, 0)
                    # Index of the smallest summed loss; used to pick the
                    # entry that is reported.
                    argmin = torch.argmin(loss)
                    loss = torch.mean(loss)
                    train_var += [
                        np.exp(log_var[argmin].detach().cpu().numpy()) *
                        model.max_fvc
                    ]
                    loss.backward()
                    l1_loss = l1(mu[argmin].to(device), targets.to(device))

                    accuracy = l1_loss.item()
                    train_accuracy += [accuracy * model.max_fvc]

                    train_losses += [loss.item()]

                    optimizer.step()
                    if scheduler == "CycleScheduler":
                        lr_schedule.step()
                    logger.add_scalar('training_loss', loss.item(),
                                      i + len(train_loader) * epoch)
                    del loss

                if epoch % self.epochs_per_print == 0:
                    losses["train"] += [
                        np.mean(train_losses) / self.batch_size
                    ]
                    accuracies["train"] += [np.mean(train_accuracy)]
                    log_gaussians["train"] += [np.mean(train_log_gauss)]
                    variances['train'] += [np.mean(train_var)]
                    if self.verbose > 1:
                        print("Epoch: {}:\t"
                              "Train Loss: {:.5f} , "
                              "Accuracy: {:.3f} , "
                              "confidence: {:.9f} , "
                              "Vars: {:.9f} ".format(
                                  epoch, losses["train"][-1],
                                  accuracies["train"][-1],
                                  log_gaussians["train"][-1],
                                  np.sqrt(variances["train"][-1])))
                    train_losses = []
                    train_accuracy = []
                    train_log_gauss = []
                    train_var = []

                model.eval()
                # No gradients are needed for validation; disabling autograd
                # avoids keeping the graph of every forward pass in memory.
                with torch.no_grad():
                    for i, batch in enumerate(valid_loader):
                        _, images, targets, patient_info = batch
                        images = images.to(device)
                        targets = targets.to(device)
                        patient_info = patient_info.to(device)
                        _, mu, log_var = model(images, patient_info)
                        rv = norm(mu.detach().cpu().numpy(),
                                  np.exp(log_var.detach().cpu().numpy()))
                        loss = -log_gaussian(
                            targets.type(torch.FloatTensor).to(device),
                            mu.type(torch.FloatTensor).to(device),
                            torch.exp(log_var.type(
                                torch.FloatTensor).to(device))
                        ) / self.batch_size
                        loss = torch.sum(loss, 0)
                        argmin = torch.argmin(loss)
                        loss = torch.mean(loss)
                        # Stored as exp(-loss), so larger values are better.
                        valid_losses += [np.exp(-loss.item())]
                        valid_log_gauss += [
                            rv.pdf(mu.detach().cpu().numpy())
                        ]
                        valid_var += [
                            np.exp(log_var[argmin].detach().cpu().numpy()) *
                            model.max_fvc
                        ]
                        l1_loss = l1(mu[argmin], targets.to(device))

                        accuracy = l1_loss.item()
                        valid_accuracy += [accuracy * model.max_fvc]
                        # Logged under its own tag so validation values are
                        # not labelled as training loss.
                        logger.add_scalar('validation_loss', loss.item(),
                                          i + len(train_loader) * epoch)

                if scheduler == "ReduceLROnPlateau":
                    # The history is only appended every
                    # ``epochs_per_print`` epochs, so it may still be empty.
                    if epoch > 25 and losses["valid"]:
                        lr_schedule.step(losses["valid"][-1])
                if epoch % self.epochs_per_print == 0:
                    losses["valid"] += [np.mean(valid_losses) / 2]
                    accuracies["valid"] += [np.mean(valid_accuracy)]
                    log_gaussians["valid"] += [np.mean(valid_log_gauss)]
                    variances['valid'] += [np.mean(valid_var)]
                    if self.verbose > 0:
                        print("Epoch: {}:\t"
                              "Valid Loss: {:.5f} , "
                              "Accuracy: {:.3f} "
                              "confidence: {:.9f} "
                              "Vars: {:.9f} ".format(
                                  epoch,
                                  losses["valid"][-1],
                                  accuracies["valid"][-1],
                                  log_gaussians["valid"][-1],
                                  np.sqrt(variances['valid'][-1]),
                              ))
                    if self.verbose > 1:
                        print("Current LR:", optimizer.param_groups[0]['lr'])
                    if 'momentum' in optimizer.param_groups[0].keys():
                        print("Current Momentum:",
                              optimizer.param_groups[0]['momentum'])
                    valid_losses = []
                    valid_accuracy = []
                    valid_log_gauss = []
                    valid_var = []

                # Model selection: a higher validation score is better;
                # -1 marks "no best score yet" for the current fold.
                mode = 'valid'
                if epoch > 1 and epoch % self.epochs_per_print == 0:
                    if (losses[mode][-1] > best_loss or best_loss == -1) \
                            and not np.isnan(losses[mode][-1]):
                        if self.verbose > 1:
                            print('BEST EPOCH!', losses[mode][-1],
                                  accuracies[mode][-1])
                        early_stop_counter = 0
                        best_loss = losses[mode][-1]
                        best_epoch = True
                    else:
                        early_stop_counter += 1

                if epoch % self.epochs_per_checkpoint == 0:
                    if best_epoch and self.save:
                        if self.verbose > 1:
                            print('Saving model...')
                        # ``checkpoint_path`` and ``resblocks`` are taken
                        # from the instance: this method binds no local
                        # variables with those names.
                        save_checkpoint(model=model,
                                        optimizer=optimizer,
                                        maxpool=self.maxpool,
                                        padding=self.padding,
                                        padding_deconv=None,
                                        learning_rate=learning_rate,
                                        epoch=epoch,
                                        checkpoint_path=self.checkpoint_path,
                                        z_dim=self.n_classes,
                                        gated=self.gated,
                                        batchnorm=self.batchnorm,
                                        losses=losses,
                                        kl_divs=None,
                                        losses_recon=None,
                                        in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_sizes=self.kernel_sizes,
                                        kernel_sizes_deconv=None,
                                        strides=self.strides,
                                        strides_deconv=None,
                                        dilatations=self.dilatations,
                                        dilatations_deconv=None,
                                        best_loss=best_loss,
                                        save=self.save,
                                        name=self.modelname,
                                        n_flows=None,
                                        flow_type=None,
                                        n_res=n_res,
                                        resblocks=self.resblocks,
                                        h_last=None,
                                        n_elements=None,
                                        n_kernels=self.n_kernels)

                if self.plot_perform:
                    plot_performance(loss_total=losses,
                                     losses_recon=None,
                                     accuracies=accuracies,
                                     kl_divs=None,
                                     shapes=shapes,
                                     results_path="../figures",
                                     filename="training_loss_trace_" +
                                     self.modelname + '.jpg')
            if self.verbose > 0:
                print('BEST LOSS :', best_loss)
            best_losses += [best_loss]
        return np.mean(best_losses)