Example #1
    def __init__(self, resume, device, weight_path, batch_size, dataset):
        torch.manual_seed(7)
        self.resume = resume
        self.device = device
        if not self.resume:
            self.init_routes()
        # Training params
        self.start_epoch = 0
        self.num_timesteps_input = 12
        self.num_timesteps_output = 3

        self.epochs = 1000
        self.batch_size = batch_size

        # Tools
        self.summary = TensorboardSummary(os.path.join('result', 'events'))

        # Model
        self.model = None
        self.optimizer = None
        self.loss_criterion = None

        # Eval measures
        self.best_mae = 10000
        self.best_loss = 10000
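
For context, a minimal sketch of how the eval measures tracked above might be written to TensorBoard. It assumes this TensorboardSummary ultimately wraps a torch.utils.tensorboard.SummaryWriter (as the later examples suggest); the tag names, loop, and metric values are placeholders.

import os
from torch.utils.tensorboard import SummaryWriter

# Stand-in for self.summary / create_summary(); an assumption, not the wrapper's API.
writer = SummaryWriter(os.path.join('result', 'events'))
for epoch in range(3):                                   # placeholder training loop
    mae, loss = 1.0 / (epoch + 1), 2.0 / (epoch + 1)     # dummy metrics
    writer.add_scalar('val/mae', mae, epoch)
    writer.add_scalar('val/loss', loss, epoch)
writer.close()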
Example #2
    def __init__(self, args, model, train_set, val_set, test_set, class_weights, saver):
        self.args = args
        self.saver = saver
        self.saver.save_experiment_config()
        self.train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        self.val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        self.test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        self.train_summary = TensorboardSummary(os.path.join(self.saver.experiment_dir, "train"))
        self.train_writer = self.train_summary.create_summary()
        self.val_summary = TensorboardSummary(os.path.join(self.saver.experiment_dir, "validation"))
        self.val_writer = self.val_summary.create_summary()
        self.model = model
        self.dataset_size = {'train': len(train_set), 'val': len(val_set), 'test': len(test_set)}

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        if args.use_balanced_weights:
            weight = torch.from_numpy(class_weights.astype(np.float32))
        else:
            weight = None

        if args.optimizer == 'SGD':
            print('Using SGD')
            self.optimizer = torch.optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            print('Using Adam')
            self.optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        self.lr_scheduler = None
        if args.use_lr_scheduler:
            if args.lr_scheduler == 'step':
                print('Using step lr scheduler')                
                self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[int(x) for x in args.step_size.split(",")], gamma=0.1)

        self.criterion = SegmentationLosses(weight=weight, ignore_index=255, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.evaluator = Evaluator(train_set.num_classes)
        self.best_pred = 0.0
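
For reference, a self-contained sketch of the optimizer/scheduler wiring above: a 1x learning-rate group for the backbone, a 10x group for the head, and MultiStepLR milestones parsed from a comma-separated string. The toy model and the step_size value are assumptions made for illustration.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 2))    # toy "backbone" + "head"
lr = 0.01
train_params = [
    {'params': model[0].parameters(), 'lr': lr},              # 1x group
    {'params': model[1].parameters(), 'lr': lr * 10},         # 10x group
]
optimizer = torch.optim.SGD(train_params, momentum=0.9, weight_decay=5e-4, nesterov=True)

step_size = "30,60,90"                                        # hypothetical args.step_size
milestones = [int(x) for x in step_size.split(",")]
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

for epoch in range(100):
    # ... one epoch of training would go here ...
    scheduler.step()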
Example #3
    def __init__(self):
        super().__init__()
        now_time = time.strftime('%Y-%m-%d-%H-%M',time.localtime(time.time()))
        logger_path = os.path.join(
            self.args.training.save_dir,
            self.args.dataset.dataset_train,
            self.args.models.model_warpper,
            self.args.training.experiment_id,
            '%s.log' % now_time
        )
        set_logger_path(logger_path)
        logger.info(self.args)

        # Define Saver
        self.saver = Saver(self.args)

        # Define Tensorboard Summary
        self.summary = TensorboardSummary()
        self.writer = self.summary.create_summary(self.saver.experiment_dir, self.args.models)


        self.init_training_container()
        self.batchsize = self.args.training.batchsize
        self.reset_batchsize()
        self.evaluator = Evaluator()
        self.best = 0.0

        # show parameters to be trained
        logger.debug('\nTraining params:')
        for p in self.model.named_parameters():
            if p[1].requires_grad:
                logger.debug(p[0])
        logger.debug('\n')

        # Log the starting iteration (non-zero when resuming/fine-tuning)
        logger.info('Starting iteration: %d' % self.start_it)
        logger.info('Total iterations: %d' % self.args.training.max_iter)
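
A standalone sketch of the "parameters to be trained" listing above, using the standard logging module in place of the project's logger (an assumption) and a toy model with one frozen parameter.

import logging
import torch.nn as nn

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

model = nn.Linear(4, 2)
model.bias.requires_grad = False          # freeze one parameter for illustration

logger.debug('Training params:')
for name, param in model.named_parameters():
    if param.requires_grad:
        logger.debug(name)                # only "weight" is printed here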
Example #4
class Trainer:

    def __init__(self, args, model, train_set, val_set, test_set, class_weights, saver):
        self.args = args
        self.saver = saver
        self.saver.save_experiment_config()
        self.train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        self.val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        self.test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        self.train_summary = TensorboardSummary(os.path.join(self.saver.experiment_dir, "train"))
        self.train_writer = self.train_summary.create_summary()
        self.val_summary = TensorboardSummary(os.path.join(self.saver.experiment_dir, "validation"))
        self.val_writer = self.val_summary.create_summary()
        self.model = model
        self.dataset_size = {'train': len(train_set), 'val': len(val_set), 'test': len(test_set)}

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        if args.use_balanced_weights:
            weight = torch.from_numpy(class_weights.astype(np.float32))
        else:
            weight = None

        if args.optimizer == 'SGD':
            print('Using SGD')
            self.optimizer = torch.optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            print('Using Adam')
            self.optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        self.lr_scheduler = None
        if args.use_lr_scheduler:
            if args.lr_scheduler == 'step':
                print('Using step lr scheduler')                
                self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[int(x) for x in args.step_size.split(",")], gamma=0.1)

        self.criterion = SegmentationLosses(weight=weight, ignore_index=255, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.evaluator = Evaluator(train_set.num_classes)
        self.best_pred = 0.0

    def training(self, epoch):

        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_dataloader)
        tbar = tqdm(self.train_dataloader, desc='\r')

        visualization_index = int(random.random() * len(self.train_dataloader))
        vis_img, vis_tgt, vis_out = None, None, None

        self.train_writer.add_scalar('learning_rate', get_learning_rate(self.optimizer), epoch)

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.train_writer.add_scalar('total_loss_iter', loss.item(), i + num_img_tr * epoch)

            if i == visualization_index:
                vis_img, vis_tgt, vis_out = image, target, output

        self.train_writer.add_scalar('total_loss_epoch', train_loss / self.dataset_size['train'], epoch)
        if constants.VISUALIZATION:
            self.train_summary.visualize_state(self.train_writer, self.args.dataset, vis_img, vis_tgt, vis_out, epoch)

        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)
        print('BestPred: %.3f' % self.best_pred)

    def validation(self, epoch, test=False):
        self.model.eval()
        self.evaluator.reset()
        
        ret_list = []
        if test:
            tbar = tqdm(self.test_dataloader, desc='\r')
        else:
            tbar = tqdm(self.val_dataloader, desc='\r')
        test_loss = 0.0

        visualization_index = int(random.random() * len(self.val_dataloader))
        vis_img, vis_tgt, vis_out = None, None, None

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)

            if i == visualization_index:
                vis_img, vis_tgt, vis_out = image, target, output

            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = torch.argmax(output, dim=1).data.cpu().numpy()
            target = target.cpu().numpy()
            self.evaluator.add_batch(target, pred)
            
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU_20 = self.evaluator.Mean_Intersection_over_Union_20()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        if not test:
            self.val_writer.add_scalar('total_loss_epoch', test_loss / self.dataset_size['val'], epoch)
            self.val_writer.add_scalar('mIoU', mIoU, epoch)
            self.val_writer.add_scalar('mIoU_20', mIoU_20, epoch)
            self.val_writer.add_scalar('Acc', Acc, epoch)
            self.val_writer.add_scalar('Acc_class', Acc_class, epoch)
            self.val_writer.add_scalar('fwIoU', FWIoU, epoch)
            if constants.VISUALIZATION:
                self.val_summary.visualize_state(self.val_writer, self.args.dataset, vis_img, vis_tgt, vis_out, epoch)

        print("Test: " if test else "Validation:")
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, mIoU_20:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, mIoU_20, FWIoU))
        print('Loss: %.3f' % test_loss)

        if not test:
            new_pred = mIoU
            if new_pred > self.best_pred:
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                })

        return test_loss, mIoU, mIoU_20, Acc, Acc_class, FWIoU

    def load_best_checkpoint(self):
        checkpoint = self.saver.load_checkpoint()
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        print(f'=> loaded checkpoint (epoch {checkpoint["epoch"]})')
        return checkpoint["epoch"]
Example #5
        'lr': params.learning_rate
    }, {
        'params': model.get_10x_lr_params(),
        'lr': params.learning_rate * 10
    }]

    optimizer = optim.SGD(train_params,
                          momentum=params.momentum,
                          weight_decay=params.weight_decay)

    if params.cuda:
        model = nn.DataParallel(model, device_ids=[0])
        patch_replication_callback(model)
        model = model.cuda()

    scheduler = LR_Scheduler("poly", params.learning_rate, params.num_epochs,
                             len(train_dl))

    loss_fns = loss_fns

    # Define Tensorboard Summary
    summary = TensorboardSummary(args.model_dir)
    writer = summary.create_summary()

    evaluator = Evaluator(20 + 1)

    logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
    train_and_evaluate(model, train_dl, val_dl, optimizer, loss_fns, scheduler,
                       evaluator, writer, params, args.model_dir,
                       args.model_type, args.restore_file)
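
The "poly" schedule requested here is presumably the usual polynomial decay lr = base_lr * (1 - it / max_it) ** 0.9; the sketch below reproduces that rule with LambdaLR. This is an assumption about what LR_Scheduler computes, not its actual implementation.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
base_lr, num_epochs, iters_per_epoch = 0.007, 50, 100
max_it = num_epochs * iters_per_epoch
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda it: (1 - it / max_it) ** 0.9)

for it in range(max_it):
    # ... forward / backward / optimizer.step() would go here ...
    scheduler.step()                      # one scheduler step per iteration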
Example #6
def main():

    args = argument_parser.parse_args()
    print(args)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = True
    # hardcoding scannet

    # get handle to lmdb dataset
    lmdb_handle = dataset_base.LMDBHandle(os.path.join(constants.HDD_DATASET_ROOT, args.dataset, "dataset.lmdb"), args.memory_hog)
    
    # create train val and test sets
    train_set = get_active_dataset(args.active_selection_mode)(args.dataset, lmdb_handle, args.superpixel_dir, args.base_size, 'seedset_0')
    val_set = IndoorScenes(args.dataset, lmdb_handle, args.base_size, 'val')
    test_set = IndoorScenes(args.dataset, lmdb_handle, args.base_size, 'test')

    class_weights = None
    if args.use_balanced_weights:
        class_weights = calculate_weights_labels(get_active_dataset(args.active_selection_mode)(args.dataset, lmdb_handle, args.superpixel_dir, args.base_size, 'train'))

    saver = Saver(args)
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()

    # get active selection method
    active_selector = get_active_selector(args, lmdb_handle, train_set)


    # for each active selection iteration
    for selection_iter in range(args.max_iterations):

        fraction_of_data_labeled = int(round(train_set.get_fraction_of_labeled_data() * 100))
        
        if os.path.exists(os.path.join(constants.RUNS, args.dataset, args.checkname, f'runs_{fraction_of_data_labeled:03d}', "selections")):
            # resume: load selections if this is a rerun, and selections are available from a previous run
            train_set.load_selections(os.path.join(constants.RUNS, args.dataset, args.checkname, f'runs_{fraction_of_data_labeled:03d}', "selections"))
        elif os.path.exists(os.path.join(constants.RUNS, args.dataset, args.checkname, f'runs_{fraction_of_data_labeled:03d}', "selections.txt")):
            # resume: load selections if this is a rerun, and selections are available from a previous run
            train_set.load_selections(os.path.join(constants.RUNS, args.dataset, args.checkname, f'runs_{fraction_of_data_labeled:03d}', "selections.txt"))
        else:
            # active selection iteration

            train_set.make_dataset_multiple_of_batchsize(args.batch_size)
            # create model from scratch
            model = DeepLab(num_classes=train_set.num_classes, backbone=args.backbone, output_stride=args.out_stride, sync_bn=args.sync_bn,
                            mc_dropout=(args.active_selection_mode.startswith('viewmc') or args.active_selection_mode.startswith('vote') or args.view_entropy_mode == 'mc_dropout'))
            model = model.cuda()
            # create trainer
            trainer = Trainer(args, model, train_set, val_set, test_set, class_weights, Saver(args, suffix=f'runs_{fraction_of_data_labeled:03d}'))
            
            # train for args.epochs epochs
            lr_scheduler = trainer.lr_scheduler
            for epoch in range(args.epochs):
                trainer.training(epoch)
                if epoch % args.eval_interval == (args.eval_interval - 1):
                    trainer.validation(epoch)
                if lr_scheduler:
                    lr_scheduler.step()

            train_set.reset_dataset()
            epoch = trainer.load_best_checkpoint()

            # get best val miou / metrics
            _, best_mIoU, best_mIoU_20, best_Acc, best_Acc_class, best_FWIoU = trainer.validation(epoch, test=True)

            trainer.evaluator.dump_matrix(os.path.join(trainer.saver.experiment_dir, "confusion_matrix.npy"))

            writer.add_scalar('active_loop/mIoU', best_mIoU, train_set.get_fraction_of_labeled_data() * 100)
            writer.add_scalar('active_loop/mIoU_20', best_mIoU_20, train_set.get_fraction_of_labeled_data() * 100)
            writer.add_scalar('active_loop/Acc', best_Acc, train_set.get_fraction_of_labeled_data() * 100)
            writer.add_scalar('active_loop/Acc_class', best_Acc_class, train_set.get_fraction_of_labeled_data() * 100)
            writer.add_scalar('active_loop/fwIoU', best_FWIoU, train_set.get_fraction_of_labeled_data() * 100)

            # make active selection
            active_selector.select_next_batch(model, train_set, args.active_selection_size)
            # save selections
            trainer.saver.save_active_selections(train_set.get_selections(), args.active_selection_mode.endswith("_region"))
            trainer.train_writer.close()
            trainer.val_writer.close()

        print(selection_iter, " / Train-set length: ", len(train_set))
        
    writer.close()
Example #7
from utils.utils import AverageMeter
from utils.utils import to_scalar
from utils.utils import may_set_mode
from utils.utils import may_mkdir
from utils.utils import set_seed
from sklearn.metrics import r2_score

# parsed arguments; see args.py
cfg = Config()
# log
if cfg.log_to_file:
    # redirect the standard outputs/error to local txt files.
    ReDirectSTD(cfg.stdout_file, 'stdout', False)
    ReDirectSTD(cfg.stderr_file, 'stderr', False)

summary = TensorboardSummary(cfg.exp_dir)
writer = summary.create_summary()

# dump the configuration to log.
import pprint
print(('-' * 60))
print('cfg.__dict__')
pprint.pprint(cfg.__dict__)
print(('-' * 60))

# set the random seed
if cfg.set_seed:
    set_seed(cfg.rand_seed)

# init devices
set_devices(cfg.sys_device_ids)  # select the devices (GPUs) to run on
Example #8
class Trainer(BaseContainer):
    def __init__(self):
        super().__init__()
        now_time = time.strftime('%Y-%m-%d-%H-%M',time.localtime(time.time()))
        logger_path = os.path.join(
            self.args.training.save_dir,
            self.args.dataset.dataset_train,
            self.args.models.model_warpper,
            self.args.training.experiment_id,
            '%s.log' % now_time
        )
        set_logger_path(logger_path)
        logger.info(self.args)

        # Define Saver
        self.saver = Saver(self.args)

        # Define Tensorboard Summary
        self.summary = TensorboardSummary()
        self.writer = self.summary.create_summary(self.saver.experiment_dir, self.args.models)


        self.init_training_container()
        self.batchsize = self.args.training.batchsize
        self.reset_batchsize()
        self.evaluator = Evaluator()
        self.best = 0.0

        # show parameters to be trained
        logger.debug('\nTraining params:')
        for p in self.model.named_parameters():
            if p[1].requires_grad:
                logger.debug(p[0])
        logger.debug('\n')

        # Log the starting iteration (non-zero when resuming/fine-tuning)
        logger.info('Starting iteration: %d' % self.start_it)
        logger.info('Total iterations: %d' % self.args.training.max_iter)

    # main function for training
    def training(self):
        self.model.train()

        num_img_tr = len(self.train_loader)
        logger.info('\nTraining')

        max_iter = self.args.training.max_iter
        it = self.start_it

        # support multiple optimizers, but only one 
        # optimizer is used here, i.e., names = ['match']
        names = self.args.training.optimizer.keys()

        while it < max_iter:
            for samples in self.train_loader:
                samples = to_cuda(samples)

                # validation
                val_iter = self.args.training.get('val_iter', -1)
                if val_iter > 0 and it % val_iter == 0 and it >= self.args.training.get('start_eval_it', 15000):
                    self.validation(it, 'val')
                    self.model.train()

                if it % 100 == 0:
                    logger.info('\n===> Iteration  %d/%d' % (it, max_iter))
    
                # update class weights
                if it >= 500 and self.args.training.get('weight_update_iter', -1) > 0 and it % self.args.training.get('weight_update_iter', -1) == 0:
                    self.model.update_hard()
                    logger.info('\nUpdate hard ID: %.3f'%self.model.center.ratio)
                    self.writer.add_scalar('train/data_ratio', self.model.center.ratio, it)

                for name in names:
                    self.optimizer[name].zero_grad()
                    outputs = self.model(samples, type=name)
                    losses = self.criterion(outputs, name)
                    loss = losses['loss']
                    loss.backward()
                    self.optimizer[name].step()

                    # log training loss
                    if it % 100 == 0:
                        loss_log_str = '=>%s   loss: %.4f'%(name, loss.item())
                        for loss_name in losses.keys():
                            if loss_name != 'loss':
                                loss_log_str += '    %s: %.4f'%(loss_name, losses[loss_name])
                                self.writer.add_scalar('train/%s_iter'%loss_name, losses[loss_name], it)
                        logger.info(loss_log_str)
                        self.writer.add_scalar('train/total_loss_iter_%s'%name, loss.item(), it)

                    # adjust learning rate
                    lr_decay_iter = self.args.training.optimizer[name].get('lr_decay_iter', None)
                    if lr_decay_iter is not None:
                        for i in range(len(lr_decay_iter)):
                            if it == lr_decay_iter[i]:
                                lr = self.args.training.optimizer[name].lr * (self.args.training.optimizer[name].lr_decay ** (i+1))
                                logger.info('\nReduce lr to %.6f\n'%(lr))
                                for param_group in self.optimizer[name].param_groups:
                                    param_group["lr"] = lr 
                                break

                it += 1

                # save model and optimizer
                if it % self.args.training.save_iter == 0 or it == max_iter or it == 1:
                    logger.info('\nSaving checkpoint ......')
                    optimizer_to_save = dict()
                    for i in self.optimizer.keys():
                        optimizer_to_save[i] = self.optimizer[i].state_dict()
                    self.saver.save_checkpoint({
                        'start_it': it,
                        'stage': self.stage,
                        'state_dict': self.model.state_dict(),
                        'optimizer': optimizer_to_save,
                    }, filename='ckp_%06d.pth.tar'%it)
                    logger.info('Done.')

    # main function for validation
    def validation(self, it, split):
        logger.info('\nEvaluating %s...'%split)
        self.evaluator.reset()
        self.model.eval()

        data_loader = self.val_loader if split == 'val' else self.test_loader
        num_img_tr = len(data_loader)
        dist_pos = []
        dist_neg = []
        total_loss = []
        name = list(self.args.training.optimizer.keys())[0]
        for i, samples in enumerate(data_loader):
            samples = to_cuda(samples)

            with torch.no_grad():
                outputs = self.model(samples, type=name, is_triple=True)
                dist_pos.append(outputs[-1]['dist_pos'].mean().item())
                dist_neg.append(outputs[-1]['dist_neg'].mean().item())

            self.evaluator.add_batch(outputs[-1]['pred'], outputs[0]['target'])

        self.writer.add_scalar('%s/dist_pos'%split, np.array(dist_pos).mean(), it)
        self.writer.add_scalar('%s/dist_neg'%split, np.array(dist_neg).mean(), it)

        acc = self.evaluator.Accuracy()
        self.writer.add_scalar('%s/acc'%split, acc, it)
        if split == 'val':
            logger.info('=====>[Iteration: %d    %s/acc=%.4f    previous best=%.4f'%(it, split, acc, self.best))
        else:
            logger.info('=====>[Iteration: %d    %s/acc=%.4f'%(it, split, acc))

        # if split == 'val':
        #     self.validation(it, 'test')

        if split == 'val' and acc > self.best:
            self.best = acc
            logger.info('\nSaving checkpoint ......')
            optimizer_to_save = dict()
            for i in self.optimizer.keys():
                optimizer_to_save[i] = self.optimizer[i].state_dict()
            self.saver.save_checkpoint({
                'start_it': it,
                'stage': self.stage,
                'state_dict': self.model.state_dict(),
                'optimizer': optimizer_to_save,
            }, filename='best.pth.tar')
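
A minimal sketch of the iteration-driven training pattern used above: keep cycling the DataLoader until max_iter parameter updates have been made. Model, data, and hyperparameters are toy placeholders.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 2))
loader = DataLoader(dataset, batch_size=8, shuffle=True)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

max_iter, it = 20, 0
while it < max_iter:
    for x, y in loader:                   # re-enter the loader until max_iter is reached
        if it >= max_iter:
            break
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        it += 1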
Example #9
def main():

    # script for training a model using 100% train set

    args = argument_parser.parse_args()
    print(args)
    torch.manual_seed(args.seed)

    lmdb_handle = dataset_base.LMDBHandle(
        os.path.join(constants.HDD_DATASET_ROOT, args.dataset, "dataset.lmdb"),
        args.memory_hog)
    train_set = IndoorScenes(args.dataset, lmdb_handle, args.base_size,
                             'train')
    val_set = IndoorScenes(args.dataset, lmdb_handle, args.base_size, 'val')
    test_set = IndoorScenes(args.dataset, lmdb_handle, args.base_size, 'test')
    train_set.make_dataset_multiple_of_batchsize(args.batch_size)

    model = DeepLab(num_classes=train_set.num_classes,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn)
    model = model.cuda()

    class_weights = None
    if args.use_balanced_weights:
        class_weights = calculate_weights_labels(train_set)

    saver = Saver(args)
    trainer = Trainer(args, model, train_set, val_set, test_set, class_weights,
                      Saver(args))
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()

    start_epoch = 0
    if args.resume:
        args.resume = os.path.join(constants.RUNS, args.dataset, args.resume,
                                   'checkpoint.pth.tar')
        if not os.path.isfile(args.resume):
            raise RuntimeError(f"=> no checkpoint found at {args.resume}")
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        trainer.model.load_state_dict(checkpoint['state_dict'])
        trainer.optimizer.load_state_dict(checkpoint['optimizer'])
        trainer.best_pred = checkpoint['best_pred']
        print(
            f'=> loaded checkpoint {args.resume} (epoch {checkpoint["epoch"]})'
        )

    lr_scheduler = trainer.lr_scheduler

    for epoch in range(start_epoch, args.epochs):
        trainer.training(epoch)
        if epoch % args.eval_interval == (args.eval_interval - 1):
            trainer.validation(epoch)
        if lr_scheduler:
            lr_scheduler.step()

    epoch = trainer.load_best_checkpoint()
    _, best_mIoU, best_mIoU_20, best_Acc, best_Acc_class, best_FWIoU = trainer.validation(
        epoch, test=True)

    writer.add_scalar('test/mIoU', best_mIoU, epoch)
    writer.add_scalar('test/mIoU_20', best_mIoU_20, epoch)
    writer.add_scalar('test/Acc', best_Acc, epoch)
    writer.add_scalar('test/Acc_class', best_Acc_class, epoch)
    writer.add_scalar('test/fwIoU', best_FWIoU, epoch)

    trainer.train_writer.close()
    trainer.val_writer.close()
Example #10
def main():
    global args, best_acc
    args = get_args()

    #cnn
    with procedure('init model'):
        model = models.resnet34(True)
        model.fc = nn.Linear(model.fc.in_features, 2)
        model = torch.nn.parallel.DataParallel(model.cuda())

    with procedure('loss and optimizer'):
        if cfg.weights == 0.5:
            criterion = nn.CrossEntropyLoss().cuda()
        else:
            w = torch.Tensor([1 - cfg.weights, cfg.weights])
            criterion = nn.CrossEntropyLoss(w).cuda()
        optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

    cudnn.benchmark = True

    #normalization
    normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    with procedure('prepare dataset'):
        #load data
        with open(cfg.data_split) as f:
            data = json.load(f)
        train_dset = MILdataset(data['train_neg'][:14] + data['train_pos'],
                                args.patch_size, trans)
        train_loader = torch.utils.data.DataLoader(train_dset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        if args.val:
            val_dset = MILdataset(data['val_pos'] + data['val_neg'],
                                  args.patch_size, trans)
            val_loader = torch.utils.data.DataLoader(
                val_dset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
    with procedure('init tensorboardX'):
        tensorboard_path = os.path.join(args.output, 'tensorboard')
        if not os.path.isdir(tensorboard_path):
            os.makedirs(tensorboard_path)
        summary = TensorboardSummary(tensorboard_path, args.dis_slide)
        writer = summary.create_writer()

    #open output file
    fconv = open(os.path.join(args.output, 'convergence.csv'), 'w')
    fconv.write('epoch,metric,value\n')
    fconv.close()

    # loop through epochs
    for epoch in range(args.nepochs):
        train_dset.setmode(1)
        probs = inference(epoch, train_loader, model)
        topk = group_argtopk(np.array(train_dset.slideIDX), probs, args.k)
        images, names, labels = train_dset.getpatchinfo(topk)
        summary.plot_calsses_pred(writer, images, names, labels,
                                  np.array([probs[k] for k in topk]), args.k,
                                  epoch)
        slidenames, length = train_dset.getslideinfo()
        summary.plot_histogram(writer, slidenames, probs, length, epoch)
        #print([probs[k] for k in topk ])
        train_dset.maketraindata(topk)
        train_dset.shuffletraindata()
        train_dset.setmode(2)
        loss = train(epoch, train_loader, model, criterion, optimizer, writer)
        cp('(#r)Training(#)\t(#b)Epoch: [{}/{}](#)\t(#g)Loss:{}(#)'.format(
            epoch + 1, args.nepochs, loss))
        fconv = open(os.path.join(args.output, 'convergence.csv'), 'a')
        fconv.write('{},loss,{}\n'.format(epoch + 1, loss))
        fconv.close()

        #Validation
        if args.val and (epoch + 1) % args.test_every == 0:
            val_dset.setmode(1)
            probs = inference(epoch, val_loader, model)
            maxs = group_max(np.array(val_dset.slideIDX), probs,
                             len(val_dset.targets))
            pred = [1 if x >= 0.5 else 0 for x in maxs]
            err, fpr, fnr = calc_err(pred, val_dset.targets)
            #print('Validation\tEpoch: [{}/{}]\tError: {}\tFPR: {}\tFNR: {}'.format(epoch+1, args.nepochs, err, fpr, fnr))
            cp('(#y)Validation\t(#)(#b)Epoch: [{}/{}]\t(#)(#g)Error: {}\tFPR: {}\tFNR: {}(#)'
               .format(epoch + 1, args.nepochs, err, fpr, fnr))
            fconv = open(os.path.join(args.output, 'convergence.csv'), 'a')
            fconv.write('{},error,{}\n'.format(epoch + 1, err))
            fconv.write('{},fpr,{}\n'.format(epoch + 1, fpr))
            fconv.write('{},fnr,{}\n'.format(epoch + 1, fnr))
            fconv.close()
            #Save best model
            err = (fpr + fnr) / 2.
            if 1 - err >= best_acc:
                best_acc = 1 - err
                obj = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict()
                }
                torch.save(obj, os.path.join(args.output,
                                             'checkpoint_best.pth'))
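
For illustration, one plausible reading of the error/FPR/FNR computation that calc_err performs (an assumption, not the repo's code), written with numpy.

import numpy as np

def error_rates(pred, target):
    # pred/target are 0/1 slide-level labels; returns (error, FPR, FNR).
    pred, target = np.asarray(pred), np.asarray(target)
    err = float(np.not_equal(pred, target).mean())
    fpr = float(((pred == 1) & (target == 0)).sum() / max((target == 0).sum(), 1))
    fnr = float(((pred == 0) & (target == 1)).sum() / max((target == 1).sum(), 1))
    return err, fpr, fnr

print(error_rates([1, 0, 1, 1], [1, 0, 0, 1]))    # (0.25, 0.5, 0.0)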