def run_once(self):
        
        log_dir = self.log_dir

        misc.check_manual_seed(self.seed)
        train_pairs, valid_pairs = dataset.prepare_data_VIABLE_2048()
        print('number of training pairs: %d' % len(train_pairs))
        # --------------------------- Dataloader

        train_augmentors = self.train_augmentors()
        train_dataset = dataset.DatasetSerial(train_pairs[:],
                        shape_augs=iaa.Sequential(train_augmentors[0]),
                        input_augs=iaa.Sequential(train_augmentors[1]))

        infer_augmentors = self.infer_augmentors()
        infer_dataset = dataset.DatasetSerial(valid_pairs[:],
                        shape_augs=iaa.Sequential(infer_augmentors))

        train_loader = data.DataLoader(train_dataset, 
                                num_workers=self.nr_procs_train, 
                                batch_size=self.train_batch_size, 
                                shuffle=True, drop_last=True)

        valid_loader = data.DataLoader(infer_dataset, 
                                num_workers=self.nr_procs_valid, 
                                batch_size=self.infer_batch_size, 
                                shuffle=True, drop_last=False)
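        # note: shuffling the validation loader does not change the reported
        # metrics, since accuracy/dice are computed over the accumulated,
        # flattened outputs of the whole validation set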

        # --------------------------- Training Sequence

        if self.logging:
            misc.check_log_dir(log_dir)

        device = 'cuda'

        # networks
        input_chs = 3    
        net = DenseNet(input_chs, self.nr_classes)
        net = torch.nn.DataParallel(net).to(device)
        # print(net)

        # optimizers
        optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, self.lr_steps)
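        # note: the scheduler is advanced once per epoch inside the
        # EPOCH_STARTED handler (log_lrs) defined further below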

        # load pre-trained models
        if self.load_network:
            saved_state = torch.load(self.save_net_path)
            net.load_state_dict(saved_state)
        #
        trainer = Engine(lambda engine, batch: self.train_step(net, batch, optimizer, device))
        inferer = Engine(lambda engine, batch: self.infer_step(net, batch, device))

        train_output = ['loss', 'acc']
        infer_output = ['prob', 'true']
        ##

        if self.logging:
            checkpoint_handler = ModelCheckpoint(log_dir, self.chkpts_prefix, 
                                            save_interval=1, n_saved=120, require_empty=False)
            # adding handlers using `trainer.add_event_handler` method API
            trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                                    to_save={'net': net}) 
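            # with save_interval=1 and n_saved=120, a network snapshot is
            # written at the end of every epoch, keeping up to 120 files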

        timer = Timer(average=True)
        timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                            pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
        timer.attach(inferer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                            pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
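        # per-iteration timing, averaged and reset at the start of each epoch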

        # attach running average metrics computation
        # decay of EMA to 0.95 to match tensorpack default
        RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss')
        RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc')

        # attach progress bar
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=['loss'])
        pbar.attach(inferer)

        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EXCEPTION_RAISED)
        def handle_exception(engine, e):
            if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
                engine.terminate()
                warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
                if self.logging:  # checkpoint_handler is only defined when logging
                    checkpoint_handler(engine, {'net_exception': net})
            else:
                raise e

        # writer for tensorboard logging
        if self.logging:
            writer = SummaryWriter(log_dir=log_dir)
            json_log_file = log_dir + '/stats.json'
            with open(json_log_file, 'w') as json_file:
                json.dump({}, json_file) # create empty file
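            # stats.json starts out empty and is re-read, merged and
            # overwritten by update_logs() after every epoch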

        @trainer.on(Events.EPOCH_STARTED)
        def log_lrs(engine):
            if self.logging:
                lr = float(optimizer.param_groups[0]['lr'])
                writer.add_scalar("lr", lr, engine.state.epoch)
            # advance scheduler clock
            scheduler.step()

        ####
        def update_logs(output, epoch, prefix, color):
            # print values and convert
            max_length = len(max(output.keys(), key=len))
            for metric in output:
                key = colored(prefix + '-' + metric.ljust(max_length), color)
                print('------%s : ' % key, end='')
                print('%0.7f' % output[metric])
            if 'train' in prefix:
                lr = float(optimizer.param_groups[0]['lr'])
                key = colored(prefix + '-' + 'lr'.ljust(max_length), color)
                print('------%s : %0.7f' % (key, lr))

            if not self.logging:
                return

            # create stat dicts
            stat_dict = {}
            for metric in output:
                metric_value = output[metric] 
                stat_dict['%s-%s' % (prefix, metric)] = metric_value

            # json stat log file, update and overwrite
            with open(json_log_file) as json_file:
                json_data = json.load(json_file)

            current_epoch = str(epoch)
            if current_epoch in json_data:
                old_stat_dict = json_data[current_epoch]
                stat_dict.update(old_stat_dict)
            current_epoch_dict = {current_epoch : stat_dict}
            json_data.update(current_epoch_dict)

            with open(json_log_file, 'w') as json_file:
                json.dump(json_data, json_file)

            # log values to tensorboard
            for metric in output:
                writer.add_scalar(prefix + '-' + metric, output[metric], epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_train_running_results(engine):
            """
            running training measurement
            """
            training_ema_output = engine.state.metrics #
            update_logs(training_ema_output, engine.state.epoch, prefix='train-ema', color='green')

        ####
        def get_init_accumulator(output_names):
            return {metric : [] for metric in output_names}

        import cv2  # only needed by the commented-out debug dump below
        def process_accumulated_output(output):
            def uneven_seq_to_np(seq, batch_size=self.infer_batch_size):
                # stack a list of batches (the last one may be smaller) into one array
                if batch_size == 1:
                    return np.squeeze(seq)

                item_count = batch_size * (len(seq) - 1) + len(seq[-1])
                cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
                for idx in range(0, len(seq) - 1):
                    cat_array[idx * batch_size:
                              (idx + 1) * batch_size] = seq[idx]
                # last (possibly partial) batch; indexing by len(seq) avoids a
                # NameError when the loader yields only a single batch
                cat_array[(len(seq) - 1) * batch_size:] = seq[-1]
                return cat_array
            #
            prob = uneven_seq_to_np(output['prob'])
            true = uneven_seq_to_np(output['true'])

            # cmap = plt.get_cmap('jet')
            # epi = prob[...,1]
            # epi = (cmap(epi) * 255.0).astype('uint8')
            # cv2.imwrite('sample.png', cv2.cvtColor(epi, cv2.COLOR_RGB2BGR))

            pred = np.argmax(prob, axis=-1)
            true = np.squeeze(true)

            # deal with ignore index
            pred = pred.flatten()
            true = true.flatten()
            pred = pred[true != 0] - 1
            true = true[true != 0] - 1

            acc = np.mean(pred == true)
            inter = (pred * true).sum()
            total = (pred + true).sum()
            dice = 2 * inter / total
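            # worked example (illustration only): pred=[1,0,1], true=[1,1,1]
            # gives inter=2, total=5, so dice=2*2/5=0.8 and acc=2/3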
            #
            proc_output = dict(acc=acc, dice=dice)
            return proc_output

        @trainer.on(Events.EPOCH_COMPLETED)
        def infer_valid(engine):
            """
            inference measurement
            """
            inferer.accumulator = get_init_accumulator(infer_output)
            inferer.run(valid_loader)
            output_stat = process_accumulated_output(inferer.accumulator)
            update_logs(output_stat, engine.state.epoch, prefix='valid', color='red')

        @inferer.on(Events.ITERATION_COMPLETED)
        def accumulate_outputs(engine):
            batch_output = engine.state.output
            for key, item in batch_output.items():
                engine.accumulator[key].extend([item])
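        # the accumulator gathers the per-batch 'prob' and 'true' arrays for
        # the whole validation set; process_accumulated_output() consumes it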
        ###
        # Setup is done. Now let's run the training
        trainer.run(train_loader, self.nr_epochs)
        return
    def run(self):
        def center_pad_to(img, h, w):
            shape = img.shape

            diff_h = h - shape[0]
            padt = diff_h // 2
            padb = diff_h - padt

            diff_w = w - shape[1]
            padl = diff_w // 2
            padr = diff_w - padl

            img = np.lib.pad(img, ((padt, padb), (padl, padr), (0, 0)),
                             'constant',
                             constant_values=255)
            return img
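        # note: center_pad_to assumes the (resized) image already fits inside
        # (h, w); a larger image would produce negative pad widths and fail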

        input_chs = 3
        net = DenseNet(input_chs, self.nr_classes)

        saved_state = torch.load(self.inf_model_path)
        pretrained_dict = saved_state.module.state_dict(
        )  # due to torch.nn.DataParallel
        net.load_state_dict(pretrained_dict, strict=False)
        net = net.to('cuda')

        file_list = glob.glob('%s/*%s' %
                              (self.inf_imgs_dir, self.inf_imgs_ext))
        file_list.sort()  # ensure same order

        if not os.path.isdir(self.inf_output_dir):
            os.makedirs(self.inf_output_dir)

        cmap = plt.get_cmap('jet')  # only used by the commented-out heatmap dump below
        for filename in file_list:
            filename = os.path.basename(filename)
            basename = filename.split('.')[0]

            print(filename, ' ---- ', end='', flush=True)

            img = cv2.imread(self.inf_imgs_dir + filename)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            img = cv2.resize(img, (0, 0),
                             fx=0.25,
                             fy=0.25,
                             interpolation=cv2.INTER_CUBIC)

            orig_shape = img.shape
            img = center_pad_to(img, 2880, 2880)
            pred = self.infer_step(net, [img])[0, ..., 1:]
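            # channel 0 is dropped here; it appears to correspond to the
            # ignore/background label filtered out via `true != 0` in training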
            pred = misc.cropping_center(pred, orig_shape[:-1])

            # plt.subplot(1,3,1)
            # plt.imshow(img)
            # plt.subplot(1,3,2)
            # plt.imshow(pred[...,0])
            # plt.subplot(1,3,3)
            # plt.imshow(pred[...,1])
            # plt.show()
            # exit()
            np.save('%s/%s.npy' % (self.inf_output_dir, basename), pred)

            # epi = cmap(pred[0,...,2])[...,:3] # gray to RGB heatmap
            # epi = (epi * 255).astype('uint8')
            # epi = cv2.cvtColor(epi, cv2.COLOR_RGB2BGR)

            # cv2.imwrite('%s/%s.png' % (self.inf_output_dir, basename), epi)
            print('FINISH')
Example #3
for checkpoint in filtered_checkpoints:
	print(checkpoint)
	state = torch.load(checkpoint)
	if 'best_loss' in state:
		if state['best_loss'] is None:
			continue
		print(state['best_loss'])

		if state['best_loss'] > 0.50 or state['best_loss'] < 0.47:
			continue
	if 'best_acc' in state:
		print(state['best_acc'])
		if state['best_acc'] is None or state['best_acc'] < 85.5:
			continue
	try:
		net.load_state_dict(state['state_dict'])
	except Exception:  # skip checkpoints whose state dict fails to load
		continue
	net.cuda()
	net.eval()
	#eval()

import pickle
for ID in final_dict:
	pred = final_dict[ID]
	#pickle.dump(pred, file=open("den_pred", 'w'))
	#pred = pred_dict[ID]
	subm['predicted'][ID] = np.argmax(pred) + 1

'''
for (data, ids) in test_dataloader:
Example #4
from config import Config
import cv2
import numpy as np
import openslide
import torch

device = 'cuda'

net = DenseNet(3, 2)
net.eval()  # infer mode

viable_saved_state = torch.load('log/v1.0.0.1/model_net_46.pth')

new_saved_state = {}

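# the checkpoint was saved from a DataParallel model, so every parameter key
# carries a 'module.' prefix (7 characters); strip it before loading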
for key, value in viable_saved_state.items():
    new_saved_state[key[7:]] = value

net.load_state_dict(new_saved_state)
net = torch.nn.DataParallel(net).to(device)

wsi_img = openslide.OpenSlide('01_01_0138.svs')
wsi_w, wsi_h = wsi_img.level_dimensions[0]

prediction = np.zeros((wsi_h, wsi_w))
batch = []
location = []
batch_size = 80
one = np.ones((512, 512))
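# walk the level-0 slide in non-overlapping 512x512 steps; `batch`, `location`
# and `batch_size` presumably collect patches to run through the net in groups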
for i in range(0, wsi_h, 512):
    for j in range(0, wsi_w, 512):
        print('{0} {1}'.format(i, j))
        # if i+512>wsi_h and j+512>wsi_w:
        #     patch = wsi_img.read_region((wsi_w - 512, wsi_h-512), 0, (512, 512)).convert('RGB')
Example #5
        end = time.time()
        print(
            '[{}/{}] - Loss: {:.4f} Acc: {:.4f} Val Loss: {:.4f} Val Acc: {:.4f} Time: {:.4f} min'
            .format(epoch, epochs, losses[epoch], accs[epoch],
                    val_losses[epoch], val_accs[epoch], (end - start) / 60))
        if epoch > 10 and val_losses[epoch] < best_loss:
            best_loss = val_losses[epoch]
            print('Saving model at epoch {}'.format(epoch))
            torch.save(net.state_dict(), 'model.ckpt')

    print('Finished Training')
    torch.save(net.state_dict(), 'model.ckpt')

if test:
    print('Starting inference...')
    correct = 0.0
    net.load_state_dict(torch.load('model.ckpt'))
    start = time.time()
    net.eval()
    for i, (image, target) in enumerate(val_loader):
        # forward pass only; no gradients are needed at test time
        with torch.no_grad():
            outputs = net(image) > 0.5
        # accumulate per-batch accuracy
        correct += ((outputs[:, 1] == target.byte()).sum().item() /
                    len(outputs))

    end = time.time()
    print('Acc: {} Time:{} min'.format(correct / len(val_loader),
                                       (end - start) / 60))
    print('Finished testing')