Code example #1
    def train_epoch(self, epoch):
        self.model.train()

        for batch_idx, input_tuple in enumerate(self.train_data_loader):
            self.optimizer.zero_grad()

            input_tensor, target = prepare_input(input_tuple=input_tuple, args=self.args)
            input_tensor.requires_grad = True
            output = self.model(input_tensor)
            loss_dice, per_ch_score = self.criterion(output, target)
            loss_dice.backward()
            self.optimizer.step()

            # Log batch loss and per-channel scores from the main process only.
            if self.args.local_rank == 0:
                self.writer.update_scores(batch_idx, loss_dice.item(), per_ch_score, 'train',
                                          epoch * self.len_epoch + batch_idx)

            if (batch_idx + 1) % self.terminal_show_freq == 0:
                partial_epoch = epoch + batch_idx / self.len_epoch - 1
                if self.args.local_rank == 0:
                    self.writer.display_terminal(partial_epoch, epoch, 'train')

        if self.args.local_rank == 0:
            self.writer.display_terminal(self.len_epoch, epoch, mode='train', summary=True)
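
Each train_epoch variant in this listing follows the same pattern: zero the gradients, build the batch with prepare_input, forward, compute the Dice loss, backward, step, then log. Below is a minimal, self-contained sketch of that pattern; the toy Conv3d model, random data, and cross-entropy loss are hypothetical stand-ins for the project's model, loader, Dice criterion, and writer.

import torch
import torch.nn as nn

model = nn.Conv3d(1, 2, kernel_size=3, padding=1)       # toy segmentation model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()                       # stand-in for the Dice loss

for epoch in range(1, 3):
    model.train()
    for batch_idx in range(4):
        input_tensor = torch.randn(2, 1, 8, 8, 8)       # (N, C, D, H, W)
        target = torch.randint(0, 2, (2, 8, 8, 8))      # per-voxel class indices
        optimizer.zero_grad()
        output = model(input_tensor)                    # logits: (N, classes, D, H, W)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print(f'epoch {epoch} batch {batch_idx} loss {loss.item():.4f}')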
Code example #2
File: trainer.py  Project: Zhongyihhh/CS446
    def train_epoch(self, epoch):
        self.model.train()

        for batch_idx, input_tuple in enumerate(self.train_data_loader):
            # print("batch_idx:",batch_idx)
            # print("input_tuple shape:", input_tuple[0].size())
            self.optimizer.zero_grad()

            input_tensor, target = prepare_input(input_tuple=input_tuple,
                                                 args=self.args)
            input_tensor.requires_grad = True
            output = self.model(input_tensor)
            # print("model output shape:", output.size())
            loss_dice, per_ch_score = self.criterion(output, target)
            loss_dice.backward()
            self.optimizer.step()

            self.writer.update_scores(batch_idx, loss_dice.item(),
                                      per_ch_score, 'train',
                                      epoch * self.len_epoch + batch_idx)

            if (batch_idx + 1) % self.terminal_show_freq == 0:
                partial_epoch = epoch + batch_idx / self.len_epoch - 1
                self.writer.display_terminal(partial_epoch, epoch, 'train')

        self.writer.display_terminal(self.len_epoch,
                                     epoch,
                                     mode='train',
                                     summary=True)
Code example #3
File: viz_old.py  Project: zmbhou/MedicalZooPytorch
import numpy as np
import torch

# prepare_input, create_2d_views and save_3d_vol are helpers defined elsewhere in the project.


def visualize_offline(args, epoch, model, full_volume, affine, writer):
    model.eval()
    test_loss = 0  # no loss is computed here; returned for interface compatibility

    # Hard-coded output geometry: (classes, depth, height, width).
    classes, slices, height, width = 4, 144, 192, 256

    predictions = torch.tensor([]).cpu()
    segment_map = torch.tensor([]).cpu()
    for batch_idx, input_tuple in enumerate(full_volume):
        with torch.no_grad():
            t1_path, t2_path, seg_path = input_tuple

            # Load both modalities and the label map, adding batch/channel dims.
            img_t1 = torch.tensor(np.load(t1_path), dtype=torch.float32)[None, None]
            img_t2 = torch.tensor(np.load(t2_path), dtype=torch.float32)[None, None]
            sub_segment_map = torch.tensor(np.load(seg_path), dtype=torch.float32)[None]

            input_tensor, sub_segment_map = prepare_input(args, (img_t1, img_t2, sub_segment_map))
            input_tensor.requires_grad = False

            predicted = model(input_tensor).cpu()
            predictions = torch.cat((predictions, predicted))
            segment_map = torch.cat((segment_map, sub_segment_map.cpu()))

    # Stitch the per-batch outputs back into full volumes.
    predictions = predictions.view(-1, classes, slices, height, width).detach()
    segment_map = segment_map.view(-1, slices, height, width).detach()
    save_path_2d_fig = args.save + '/' + 'epoch__' + str(epoch).zfill(4) + '.png'

    create_2d_views(predictions, segment_map, epoch, writer, save_path_2d_fig)

    # TODO: test save
    save_path = args.save + '/Pred_volume_epoch_' + str(epoch)
    save_3d_vol(predictions, affine, save_path)

    return test_loss
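
Note that the stitched predictions tensor still holds per-class scores of shape (N, classes, slices, height, width). If a discrete label volume is wanted instead, a per-voxel argmax over the class axis produces one; this is a hypothetical post-processing sketch, not part of viz_old.py.

import torch

predictions = torch.randn(1, 4, 144, 192, 256)   # stand-in for the stitched volume
label_volume = predictions.argmax(dim=1)         # (N, slices, height, width), dtype int64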
Code example #4
    def train_epoch(self, epoch):
        self.model.train()

        for batch_idx, input_tuple in enumerate(self.train_data_loader):
            self.optimizer.zero_grad()

            input_tensor, target = prepare_input(self.args, input_tuple)
            input_tensor.requires_grad = True
            output = self.model(input_tensor)
            loss_dice, per_ch_score = self.criterion(output, target)
            loss_dice.backward()
            self.optimizer.step()

            self.writer.update_scores(batch_idx, loss_dice.item(),
                                      per_ch_score, 'train',
                                      epoch * self.len_epoch + batch_idx)

            # TODO: display terminal statistics per batch or per iteration step.
            if batch_idx % 100 == 0:
                partial_epoch = epoch + batch_idx / self.len_epoch - 1
                self.writer.display_terminal(partial_epoch, epoch, 'train')

        # End-of-epoch summary.
        self.writer.display_terminal(self.len_epoch,
                                     epoch,
                                     mode='train',
                                     summary=True)
Code example #5
    def train_epoch_aug(self, epoch):
        self.model.train()

        # NOTE: an alternative pipeline based on batchgenerators
        # (MICCAI2020_RIBFRAC_DataLoader3D + SingleThreadedAugmenter with
        # get_train_transform) was prototyped here. train_data_loader_aug is
        # expected to yield dicts with numpy 'data' and 'seg' arrays.
        for batch_idx, data_seg_dict in enumerate(self.train_data_loader_aug):
            # Cap the epoch at len_epoch batches, since the augmented loader
            # may not terminate on its own.
            if batch_idx >= self.len_epoch:
                break
            input_tuple = [torch.from_numpy(data_seg_dict['data']),
                           torch.from_numpy(data_seg_dict['seg'])]
            self.optimizer.zero_grad()

            input_tensor, target = prepare_input(input_tuple=input_tuple, args=self.args)
            input_tensor.requires_grad = True
            output = self.model(input_tensor)
            loss_dice, per_ch_score = self.criterion(output, target)
            loss_dice.backward()
            self.optimizer.step()

            if self.args.local_rank == 0:
                self.writer.update_scores(batch_idx, loss_dice.item(), per_ch_score, 'train',
                                          epoch * self.len_epoch + batch_idx)

            if (batch_idx + 1) % self.terminal_show_freq == 0:
                partial_epoch = epoch + batch_idx / self.len_epoch - 1
                if self.args.local_rank == 0:
                    self.writer.display_terminal(partial_epoch, epoch, 'train')

        if self.args.local_rank == 0:
            self.writer.display_terminal(self.len_epoch, epoch, mode='train', summary=True)
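
The aug variant differs from the plain train_epoch only in how batches arrive: the augmented loader yields dicts of numpy arrays, which must be converted with torch.from_numpy before prepare_input. A tiny stand-alone illustration of that conversion follows; the array shapes and dict contents are hypothetical.

import numpy as np
import torch

data_seg_dict = {                                  # shape convention: (N, C, D, H, W)
    'data': np.random.rand(1, 1, 8, 8, 8).astype(np.float32),
    'seg': np.random.randint(0, 2, (1, 1, 8, 8, 8)).astype(np.float32),
}
input_tuple = [torch.from_numpy(data_seg_dict['data']),
               torch.from_numpy(data_seg_dict['seg'])]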
Code example #6
File: trainer.py  Project: yf817/MedicalZooPytorch
    def validate_epoch(self, epoch):
        self.model.eval()

        for batch_idx, input_tuple in enumerate(self.valid_data_loader):
            with torch.no_grad():
                input_tensor, target = prepare_input(self.args, input_tuple)
                input_tensor.requires_grad = False

                output = self.model(input_tensor)
                loss, per_ch_score = self.criterion(output, target)

                self.writer.update_scores(batch_idx, loss.item(), per_ch_score, 'val',
                                          epoch * self.len_epoch + batch_idx)

        self.writer.display_terminal(len(self.valid_data_loader), epoch, mode='val', summary=True)
Code example #7
    def validate_epoch(self, epoch):
        self.model.eval()

        for batch_idx, input_tuple in enumerate(self.valid_data_loader):
            with torch.no_grad():
                input_tensor, target = prepare_input(input_tuple=input_tuple, args=self.args)
                input_tensor.requires_grad = False

                output = self.model(input_tensor)
                loss, per_ch_score = self.criterion(output, target)
                if self.args.local_rank == 0:
                    self.writer.update_scores(batch_idx, loss.item(), per_ch_score, 'val',
                                              epoch * self.len_epoch + batch_idx)
        if self.args.local_rank == 0:
            self.writer.display_terminal(len(self.valid_data_loader), epoch, mode='val', summary=True)
Code example #8
File: trainer.py  Project: Zhongyihhh/CS446
    def validate_epoch(self, epoch):
        self.model.eval()

        for batch_idx, input_tuple in enumerate(self.valid_data_loader):
            # print("batch_idx:",batch_idx)
            # print("input_tuple shape:", input_tuple[0].size())
            with torch.no_grad():
                input_tensor, target = prepare_input(input_tuple=input_tuple,
                                                     args=self.args)
                input_tensor.requires_grad = False

                output = self.model(input_tensor)
                loss, per_ch_score = self.criterion(output, target)

                self.writer.update_scores(batch_idx, loss.item(), per_ch_score,
                                          'val',
                                          epoch * self.len_epoch + batch_idx)

        self.writer.display_terminal(len(self.valid_data_loader),
                                     epoch,
                                     mode='val',
                                     summary=True)
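
All three validate_epoch variants share the same skeleton: switch the model to eval mode, run the forward pass under torch.no_grad(), and log under the 'val' tag. A self-contained sketch of that skeleton follows; the toy model, data, and loss are hypothetical stand-ins for the project's components.

import torch
import torch.nn as nn

model = nn.Conv3d(1, 2, kernel_size=3, padding=1)
criterion = nn.CrossEntropyLoss()

model.eval()
with torch.no_grad():                      # no autograd bookkeeping during validation
    input_tensor = torch.randn(2, 1, 8, 8, 8)
    target = torch.randint(0, 2, (2, 8, 8, 8))
    output = model(input_tensor)
    loss = criterion(output, target)
print(f'val loss {loss.item():.4f}')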