Exemplo n.º 1
0
def train_fold(fold, args):
    # loggers
    logging_logger = args.logging_logger
    if args.tb_log:
        tb_logger = args.tb_logger

    num_classes = utils.problem_class[args.problem_type]

    # init model
    model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False)
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    # transform for train/valid data
    train_transform, valid_transform = get_transform(args.model)

    # loss function
    loss_func = LossMulti(num_classes, args.jaccard_weight)
    if args.semi:
        loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight,
                                       args.semi_loss_alpha, args.semi_method)

    # train/valid filenames
    train_filenames, valid_filenames = utils.trainval_split(
        args.train_dir, fold)

    ckpt_dir = Path(args.ckpt_dir)
    ckpt_filename = ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold)[0]
    res = re.match(r'fold_%d_model_(\d+).pth' % fold, ckpt_filename)
    # restore epoch
    engine.state.epoch = int(res.groups()[0])
    # load model state dict
    model.load_state_dict(torch.load(str(ckpt_filename)))
    logging_logger.info('restore model [{}] from epoch {}.'.format(
        args.model, engine.state.epoch))

    # DataLoader and Dataset args
    # train_shuffle = True
    # train_ds_kwargs = {
    #     'filenames': train_filenames,
    #     'problem_type': args.problem_type,
    #     'transform': train_transform,
    #     'model': args.model,
    #     'mode': 'train',
    #     'semi': args.semi,
    # }

    valid_num_workers = args.num_workers
    valid_batch_size = args.batch_size
    # if 'TAPNet' in args.model:
    #     # for TAPNet, cancel default shuffle, use self-defined shuffle in torch.Dataset instead
    #     train_shuffle = False
    #     train_ds_kwargs['batch_size'] = args.batch_size
    #     train_ds_kwargs['mf'] = args.mf
    # if args.semi == True:
    #     train_ds_kwargs['semi_method'] = args.semi_method
    #     train_ds_kwargs['semi_percentage'] = args.semi_percentage

    # additional valid dataset kws
    valid_ds_kwargs = {
        'filenames': valid_filenames,
        'problem_type': args.problem_type,
        'transform': valid_transform,
        'model': args.model,
        'mode': 'valid',
    }

    if 'TAPNet' in args.model:
        # in validation, num_workers should be set to 0 for sequences
        valid_num_workers = 0
        # in validation, batch_size should be set to 1 for sequences
        valid_batch_size = 1
        valid_ds_kwargs['mf'] = args.mf

    # # train dataloader
    # train_loader = DataLoader(
    #     dataset=RobotSegDataset(**train_ds_kwargs),
    #     shuffle=train_shuffle, # set to False to disable pytorch dataset shuffle
    #     num_workers=args.num_workers,
    #     batch_size=args.batch_size,
    #     pin_memory=True
    # )
    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(**valid_ds_kwargs),
        shuffle=False,  # in validation, no need to shuffle
        num_workers=valid_num_workers,
        batch_size=
        valid_batch_size,  # in valid time. have to use one image by one
        pin_memory=True)

    # optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
    #     weight_decay=args.weight_decay, nesterov=True)

    # # ignite trainer process function
    # def train_step(engine, batch):
    #     # set model to train
    #     model.train()
    #     # clear gradients
    #     optimizer.zero_grad()

    #     # additional params to feed into model
    #     add_params = {}
    #     inputs = batch['input'].cuda(non_blocking=True)
    #     with torch.no_grad():
    #         targets = batch['target'].cuda(non_blocking=True)
    #         # for TAPNet, add attention maps
    #         if 'TAPNet' in args.model:
    #             add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

    #     outputs = model(inputs, **add_params)

    #     loss_kwargs = {}

    #     if args.semi:
    #         loss_kwargs['labeled'] = batch['labeled']
    #         if args.semi_method == 'rev_flow':
    #             loss_kwargs['optflow'] = batch['optflow']
    #         loss = loss_func_semi(outputs, targets, **loss_kwargs)
    #     else:
    #         loss = loss_func(outputs, targets, **loss_kwargs)
    #     loss.backward()
    #     optimizer.step()

    #     return_dict = {
    #         'output': outputs,
    #         'target': targets,
    #         'loss_kwargs': loss_kwargs,
    #         'loss': loss.item(),
    #     }

    #     # for TAPNet, update attention maps after each iteration
    #     if 'TAPNet' in args.model:
    #         # output_classes and target_classes: <b, h, w>
    #         output_softmax_np = torch.softmax(outputs, dim=1).detach().cpu().numpy()
    #         # update attention maps
    #         train_loader.dataset.update_attmaps(output_softmax_np, batch['abs_idx'].numpy())
    #         return_dict['attmap'] = add_params['attmap']

    #     return return_dict

    # # init trainer
    # trainer = engine.Engine(train_step)

    # # lr scheduler and handler
    # # cyc_scheduler = optim.lr_scheduler.CyclicLR(optimizer, args.lr / 100, args.lr)
    # # lr_scheduler = c_handlers.param_scheduler.LRScheduler(cyc_scheduler)
    # # trainer.add_event_handler(engine.Events.ITERATION_COMPLETED, lr_scheduler)

    # step_scheduler = optim.lr_scheduler.StepLR(optimizer,
    #     step_size=args.lr_decay_epochs, gamma=args.lr_decay)
    # lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler)
    # trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler)

    # @trainer.on(engine.Events.STARTED)
    # def trainer_start_callback(engine):
    #     logging_logger.info('training fold {}, {} train / {} valid files'. \
    #         format(fold, len(train_filenames), len(valid_filenames)))

    #     # resume training
    #     if args.resume:
    #         # ckpt for current fold fold_<fold>_model_<epoch>.pth
    #         ckpt_dir = Path(args.ckpt_dir)
    #         ckpt_filename = ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold)[0]
    #         res = re.match(r'fold_%d_model_(\d+).pth' % fold, ckpt_filename)
    #         # restore epoch
    #         engine.state.epoch = int(res.groups()[0])
    #         # load model state dict
    #         model.load_state_dict(torch.load(str(ckpt_filename)))
    #         logging_logger.info('restore model [{}] from epoch {}.'.format(args.model, engine.state.epoch))
    #     else:
    #         logging_logger.info('train model [{}] from scratch'.format(args.model))

    #     # record metrics history every epoch
    #     engine.state.metrics_records = {}

    # @trainer.on(engine.Events.EPOCH_STARTED)
    # def trainer_epoch_start_callback(engine):
    #     # log learning rate on pbar
    #     train_pbar.log_message('model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' % \
    #         (args.model, args.problem_type, fold, lr_scheduler.get_param(), args.batch_size))

    #     # for TAPNet, change dataset schedule to random after the first epoch
    #     if 'TAPNet' in args.model and engine.state.epoch > 1:
    #         train_loader.dataset.set_dataset_schedule("shuffle")

    # @trainer.on(engine.Events.ITERATION_COMPLETED)
    # def trainer_iter_comp_callback(engine):
    #     # logging_logger.info(engine.state.metrics)
    #     pass

    # # monitor loss
    # # running average loss
    # train_ra_loss = imetrics.RunningAverage(output_transform=
    #     lambda x: x['loss'], alpha=0.98)
    # train_ra_loss.attach(trainer, 'train_ra_loss')

    # # monitor train loss over epoch
    # if args.semi:
    #     train_loss = imetrics.Loss(loss_func_semi, output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs']))
    # else:
    #     train_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    # train_loss.attach(trainer, 'train_loss')

    # # progress bar
    # train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    # train_metric_names = ['train_ra_loss']
    # train_pbar.attach(trainer, metric_names=train_metric_names)

    # # tensorboardX: log train info
    # if args.tb_log:
    #     tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'),
    #         event_name=engine.Events.EPOCH_STARTED)

    #     tb_logger.attach(trainer, log_handler=OutputHandler('train_iter', train_metric_names),
    #         event_name=engine.Events.ITERATION_COMPLETED)

    #     tb_logger.attach(trainer, log_handler=OutputHandler('train_epoch', ['train_loss']),
    #         event_name=engine.Events.EPOCH_COMPLETED)

    #     tb_logger.attach(trainer,
    #          log_handler=WeightsScalarHandler(model, reduction=torch.norm),
    #          event_name=engine.Events.ITERATION_COMPLETED)

    # tb_logger.attach(trainer, log_handler=tb_log_train_vars,
    #     event_name=engine.Events.ITERATION_COMPLETED)

    # ignite validator process function
    def valid_step(engine, batch):
        """Evaluate one validation batch and fold its metrics into the engine state.

        Runs the model in eval mode under ``torch.no_grad()``, computes the
        loss with ``loss_func``, builds a confusion matrix per frame and
        derives IoU / Dice records from it.  The batch totals are merged into
        ``engine.state.epoch_metrics`` (which must already hold the keys
        ``'confusion_matrix'``, ``'iou'`` and ``'dice'``).

        Returns a dict with the scalar loss, the raw outputs, their argmax,
        the targets, and the per-batch IoU / Dice records; for TAPNet models
        the attention map that was fed to the model is returned as well.
        """
        is_tapnet = 'TAPNet' in args.model

        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            targets = batch['target'].cuda(non_blocking=True)

            # extra keyword inputs for the model (attention maps for TAPNet)
            extra_inputs = {}
            if is_tapnet:
                extra_inputs['attmap'] = batch['attmap'].cuda(
                    non_blocking=True)

            # logits and loss
            logits = model(inputs, **extra_inputs)
            loss = loss_func(logits, targets)

            # class probabilities along dim 1, then the predicted class ids
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)
            pred_frames = preds.cpu().numpy()
            target_frames = targets.cpu().numpy()

            # metrics of the current batch only
            batch_iou = MetricRecord()
            batch_dice = MetricRecord()
            batch_cm = np.zeros((num_classes, num_classes), dtype=np.uint32)

            # iterate over the leading (batch) axis: one confusion matrix per frame
            for pred_frame, target_frame in zip(pred_frames, target_frames):
                frame_cm = calculate_confusion_matrix_from_arrays(
                    pred_frame, target_frame, num_classes)
                batch_iou.update_record(calculate_iou(frame_cm))
                batch_dice.update_record(calculate_dice(frame_cm))
                batch_cm += frame_cm

            # accumulate the batch metrics into the per-epoch state
            epoch_metrics = engine.state.epoch_metrics
            epoch_metrics['confusion_matrix'] += batch_cm
            epoch_metrics['iou'].merge(batch_iou)
            epoch_metrics['dice'].merge(batch_dice)

            result = {
                'loss': loss.item(),
                'output': logits,
                'output_argmax': preds,
                'target': targets,
                # for monitoring
                'iou': batch_iou,
                'dice': batch_dice,
            }

            if is_tapnet:
                # TAPNet: refresh the dataset's attention maps from this prediction
                valid_loader.dataset.update_attmaps(
                    probs.cpu().numpy(), batch['abs_idx'].numpy())
                # also expose the attention map that was used as input
                result['attmap'] = extra_inputs['attmap']
                # TODO: for TAPNet, return internal self-learned attention maps

            return result
Exemplo n.º 2
0
def train_fold(fold, args):
    # loggers
    logging_logger = args.logging_logger
    if args.tb_log:
        tb_logger = args.tb_logger

    num_classes = utils.problem_class[args.problem_type]

    # init model
    model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False)
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    # transform for train/valid data
    train_transform, valid_transform = get_transform(args.model)

    # loss function
    loss_func = LossMulti(num_classes, args.jaccard_weight)
    if args.semi:
        loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight, args.semi_loss_alpha, args.semi_method)

    # train/valid filenames
    train_filenames, valid_filenames = utils.trainval_split(args.train_dir, fold)

    # DataLoader and Dataset args
    train_shuffle = True
    train_ds_kwargs = {
        'filenames': train_filenames,
        'problem_type': args.problem_type,
        'transform': train_transform,
        'model': args.model,
        'mode': 'train',
        'semi': args.semi,
    }

    valid_num_workers = args.num_workers
    valid_batch_size = args.batch_size
    if 'TAPNet' in args.model:
        # for TAPNet, cancel default shuffle, use self-defined shuffle in torch.Dataset instead
        train_shuffle = False
        train_ds_kwargs['batch_size'] = args.batch_size
        train_ds_kwargs['mf'] = args.mf
    if args.semi == True:
        train_ds_kwargs['semi_method'] = args.semi_method
        train_ds_kwargs['semi_percentage'] = args.semi_percentage

    # additional valid dataset kws
    valid_ds_kwargs = {
        'filenames': valid_filenames,
        'problem_type': args.problem_type,
        'transform': valid_transform,
        'model': args.model,
        'mode': 'valid',
    }

    if 'TAPNet' in args.model:
        # in validation, num_workers should be set to 0 for sequences
        valid_num_workers = 0
        # in validation, batch_size should be set to 1 for sequences
        valid_batch_size = 1
        valid_ds_kwargs['mf'] = args.mf

    # train dataloader
    train_loader = DataLoader(
        dataset=RobotSegDataset(**train_ds_kwargs),
        shuffle=train_shuffle, # set to False to disable pytorch dataset shuffle
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True
    )
    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(**valid_ds_kwargs),
        shuffle=False, # in validation, no need to shuffle
        num_workers=valid_num_workers,
        batch_size=valid_batch_size, # in valid time. have to use one image by one
        pin_memory=True
    )

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, 
    #     weight_decay=args.weight_decay, nesterov=True)    

    # ignite trainer process function
    def train_step(engine, batch):
        """Run one optimisation step on ``batch`` and return its outputs.

        Puts the model in train mode, clears gradients, runs the forward
        pass, computes the loss (``loss_func_semi`` when ``args.semi`` is
        set, ``loss_func`` otherwise), back-propagates and steps the
        optimizer.

        Returns a dict with keys ``'output'``, ``'target'``,
        ``'loss_kwargs'`` and ``'loss'`` (a Python float); for TAPNet models
        the input attention map is added under ``'attmap'``.
        """
        is_tapnet = 'TAPNet' in args.model

        model.train()
        optimizer.zero_grad()

        inputs = batch['input'].cuda(non_blocking=True)
        # extra keyword inputs for the model (attention maps for TAPNet)
        model_kwargs = {}
        with torch.no_grad():
            targets = batch['target'].cuda(non_blocking=True)
            if is_tapnet:
                model_kwargs['attmap'] = batch['attmap'].cuda(non_blocking=True)

        outputs = model(inputs, **model_kwargs)

        # pick the criterion and gather its extra keyword arguments
        loss_kwargs = {}
        if args.semi:
            loss_kwargs['labeled'] = batch['labeled']
            if args.semi_method == 'rev_flow':
                loss_kwargs['optflow'] = batch['optflow']
            criterion = loss_func_semi
        else:
            criterion = loss_func
        loss = criterion(outputs, targets, **loss_kwargs)

        loss.backward()
        optimizer.step()

        result = {
            'output': outputs,
            'target': targets,
            'loss_kwargs': loss_kwargs,
            'loss': loss.item(),
        }

        if is_tapnet:
            # TAPNet: refresh the dataset's attention maps after every iteration
            probs_np = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            train_loader.dataset.update_attmaps(probs_np, batch['abs_idx'].numpy())
            result['attmap'] = model_kwargs['attmap']

        return result
    
    # init trainer
    trainer = engine.Engine(train_step)

    # lr scheduler and handler
    # cyc_scheduler = optim.lr_scheduler.CyclicLR(optimizer, args.lr / 100, args.lr)
    # lr_scheduler = c_handlers.param_scheduler.LRScheduler(cyc_scheduler)
    # trainer.add_event_handler(engine.Events.ITERATION_COMPLETED, lr_scheduler)

    step_scheduler = optim.lr_scheduler.StepLR(optimizer,
        step_size=args.lr_decay_epochs, gamma=args.lr_decay)
    lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler)
    trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler)


    @trainer.on(engine.Events.STARTED)
    def trainer_start_callback(engine):
        """Log the data split and, when ``args.resume`` is set, restore a checkpoint.

        Checkpoints are looked up in ``args.ckpt_dir`` under the name
        ``fold_<fold>_model_<epoch>.pth``.  If several match, the one with
        the highest epoch number is loaded; the trainer's epoch counter is
        set to that number and the weights are loaded into ``model``.

        Raises:
            FileNotFoundError: ``args.resume`` is set but no checkpoint for
                this fold exists in ``args.ckpt_dir``.
        """
        logging_logger.info(
            'training fold {}, {} train / {} valid files'.format(
                fold, len(train_filenames), len(valid_filenames)))

        # resume training
        if args.resume:
            ckpt_dir = Path(args.ckpt_dir)
            # Path.glob() yields Path objects lazily, so it cannot be indexed;
            # match against the file *name* (a str) and escape the '.' in '.pth'.
            name_pattern = re.compile(r'fold_%d_model_(\d+)\.pth' % fold)
            candidates = []
            for ckpt_path in ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold):
                res = name_pattern.fullmatch(ckpt_path.name)
                if res is not None:
                    candidates.append((int(res.group(1)), ckpt_path))

            if not candidates:
                raise FileNotFoundError(
                    'no checkpoint fold_%d_model_<epoch>.pth found in %s'
                    % (fold, ckpt_dir))

            # resume from the most recent epoch
            ckpt_epoch, ckpt_filename = max(candidates, key=lambda item: item[0])

            # restore epoch
            engine.state.epoch = ckpt_epoch
            # load model state dict
            model.load_state_dict(torch.load(str(ckpt_filename)))
            logging_logger.info('restore model [{}] from epoch {}.'.format(args.model, engine.state.epoch))
        else:
            logging_logger.info('train model [{}] from scratch'.format(args.model))

        # record metrics history every epoch
        engine.state.metrics_records = {}


    @trainer.on(engine.Events.EPOCH_STARTED)
    def trainer_epoch_start_callback(engine):
        """Print the run configuration and switch TAPNet data to shuffled order.

        The message (model, problem type, fold, current learning rate and
        batch size) is written through the training progress bar.  For
        TAPNet models the training dataset schedule is set to ``"shuffle"``
        from the second epoch on.
        """
        current_lr = lr_scheduler.get_param()
        summary = 'model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' % (
            args.model, args.problem_type, fold, current_lr, args.batch_size)
        train_pbar.log_message(summary)

        # only TAPNet uses the dataset-level schedule
        if 'TAPNet' not in args.model:
            return
        # the first epoch keeps the dataset's initial schedule
        if engine.state.epoch > 1:
            train_loader.dataset.set_dataset_schedule("shuffle")


    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def trainer_iter_comp_callback(engine):
        # logging_logger.info(engine.state.metrics)
        pass

    # monitor loss
    # running average loss
    train_ra_loss = imetrics.RunningAverage(output_transform=
        lambda x: x['loss'], alpha=0.98)
    train_ra_loss.attach(trainer, 'train_ra_loss')

    # monitor train loss over epoch
    if args.semi:
        train_loss = imetrics.Loss(loss_func_semi, output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs']))
    else:
        train_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    train_loss.attach(trainer, 'train_loss')

    # progress bar
    train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    train_metric_names = ['train_ra_loss']
    train_pbar.attach(trainer, metric_names=train_metric_names)

    # tensorboardX: log train info
    if args.tb_log:
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'), 
            event_name=engine.Events.EPOCH_STARTED)

        tb_logger.attach(trainer, log_handler=OutputHandler('train_iter', train_metric_names),
            event_name=engine.Events.ITERATION_COMPLETED)

        tb_logger.attach(trainer, log_handler=OutputHandler('train_epoch', ['train_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)

        tb_logger.attach(trainer,
             log_handler=WeightsScalarHandler(model, reduction=torch.norm),
             event_name=engine.Events.ITERATION_COMPLETED)

        # tb_logger.attach(trainer, log_handler=tb_log_train_vars, 
        #     event_name=engine.Events.ITERATION_COMPLETED)


    # ignite validator process function
    def valid_step(engine, batch):
        """Evaluate one validation batch and fold its metrics into the engine state.

        Runs the model in eval mode under ``torch.no_grad()``, computes the
        loss with ``loss_func``, builds a confusion matrix per frame and
        derives IoU / Dice records from it.  The batch totals are merged into
        ``engine.state.epoch_metrics`` (initialised at epoch start with the
        keys ``'confusion_matrix'``, ``'iou'`` and ``'dice'``).

        Returns a dict with the scalar loss, the raw outputs, their argmax,
        the targets, and the per-batch IoU / Dice records; for TAPNet models
        the attention map that was fed to the model is returned as well.
        """
        is_tapnet = 'TAPNet' in args.model

        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            targets = batch['target'].cuda(non_blocking=True)

            # extra keyword inputs for the model (attention maps for TAPNet)
            extra_inputs = {}
            if is_tapnet:
                extra_inputs['attmap'] = batch['attmap'].cuda(non_blocking=True)

            # logits and loss
            logits = model(inputs, **extra_inputs)
            loss = loss_func(logits, targets)

            # class probabilities along dim 1, then the predicted class ids
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)
            pred_frames = preds.cpu().numpy()
            target_frames = targets.cpu().numpy()

            # metrics of the current batch only
            batch_iou = MetricRecord()
            batch_dice = MetricRecord()
            batch_cm = np.zeros((num_classes, num_classes), dtype=np.uint32)

            # iterate over the leading (batch) axis: one confusion matrix per frame
            for pred_frame, target_frame in zip(pred_frames, target_frames):
                frame_cm = calculate_confusion_matrix_from_arrays(pred_frame, target_frame, num_classes)
                batch_iou.update_record(calculate_iou(frame_cm))
                batch_dice.update_record(calculate_dice(frame_cm))
                batch_cm += frame_cm

            # accumulate the batch metrics into the per-epoch state
            epoch_metrics = engine.state.epoch_metrics
            epoch_metrics['confusion_matrix'] += batch_cm
            epoch_metrics['iou'].merge(batch_iou)
            epoch_metrics['dice'].merge(batch_dice)

            result = {
                'loss': loss.item(),
                'output': logits,
                'output_argmax': preds,
                'target': targets,
                # for monitoring
                'iou': batch_iou,
                'dice': batch_dice,
            }

            if is_tapnet:
                # TAPNet: refresh the dataset's attention maps from this prediction
                valid_loader.dataset.update_attmaps(probs.cpu().numpy(), batch['abs_idx'].numpy())
                # also expose the attention map that was used as input
                result['attmap'] = extra_inputs['attmap']
                # TODO: for TAPNet, return internal self-learned attention maps

            return result


    # validator engine
    validator = engine.Engine(valid_step)

    # monitor loss
    valid_ra_loss = imetrics.RunningAverage(output_transform=
        lambda x: x['loss'], alpha=0.98)
    valid_ra_loss.attach(validator, 'valid_ra_loss')

    # monitor validation loss over epoch
    valid_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    valid_loss.attach(validator, 'valid_loss')
    
    # monitor <data> mean metrics
    valid_data_miou = imetrics.RunningAverage(output_transform=
        lambda x: x['iou'].data_mean()['mean'], alpha=0.98)
    valid_data_miou.attach(validator, 'mIoU')
    valid_data_mdice = imetrics.RunningAverage(output_transform=
        lambda x: x['dice'].data_mean()['mean'], alpha=0.98)
    valid_data_mdice.attach(validator, 'mDice')

    # show metrics on progress bar (after every iteration)
    valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice']
    valid_pbar.attach(validator, metric_names=valid_metric_names)


    # ## monitor ignite IoU (the same as iou we are using) ###
    # cm = imetrics.ConfusionMatrix(num_classes, 
    #     output_transform=lambda x: (x['output'], x['target']))
    # imetrics.IoU(cm, 
    #     ignore_index=0
    #     ).attach(validator, 'iou')

    # # monitor ignite mean iou (over all classes even not exist in gt)
    # mean_iou = imetrics.mIoU(cm, 
    #     ignore_index=0
    #     ).attach(validator, 'mean_iou')


    @validator.on(engine.Events.STARTED)
    def validator_start_callback(engine):
        pass

    @validator.on(engine.Events.EPOCH_STARTED)
    def validator_epoch_start_callback(engine):
        """Reset the per-epoch metric accumulators on the validator state.

        Installs fresh, empty IoU / Dice records and a zeroed
        ``num_classes x num_classes`` confusion matrix under
        ``engine.state.epoch_metrics`` so each validation run starts clean.
        """
        engine.state.epoch_metrics = {
            # directly use definition to calculate
            'iou': MetricRecord(),
            'dice': MetricRecord(),
            # running sum of the per-batch confusion matrices
            'confusion_matrix': np.zeros((num_classes, num_classes), dtype=np.uint32),
        }


    # evaluate after iter finish
    @validator.on(engine.Events.ITERATION_COMPLETED)
    def validator_iter_comp_callback(engine):
        pass

    # evaluate after epoch finish
    @validator.on(engine.Events.EPOCH_COMPLETED)
    def validator_epoch_comp_callback(engine):
        """Log the epoch's validation metrics and store them on the trainer.

        Two different summaries are logged:

        1. per-class IoU / Dice derived from the confusion matrix summed over
           the whole epoch, together with their mean and std across classes;
        2. the mean and std of the per-frame records accumulated in
           ``epoch_metrics['iou']`` / ``epoch_metrics['dice']``.

        The second summary is what gets written to
        ``trainer.state.metrics_records`` under the trainer's current epoch.
        """
        epoch_metrics = engine.state.epoch_metrics

        # --- 1. metrics from the accumulated confusion matrix ---
        # Compared with using confusion_matrix[1:, 1:] in the original code,
        # the full confusion matrix is used here.
        confusion_matrix = epoch_metrics['confusion_matrix']
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        iou_values = list(ious.values())
        dice_values = list(dices.values())
        mean_ious = np.mean(iou_values)
        mean_dices = np.mean(dice_values)
        std_ious = np.std(iou_values)
        std_dices = np.std(dice_values)

        iou_msg = 'mean IoU: %.3f, std: %.3f, for each class: %s' % (mean_ious, std_ious, ious)
        logging_logger.info(iou_msg)
        dice_msg = 'mean Dice: %.3f, std: %.3f, for each class: %s' % (mean_dices, std_dices, dices)
        logging_logger.info(dice_msg)

        # --- 2. metrics averaged over the recorded per-frame results ---
        iou_data_mean = epoch_metrics['iou'].data_mean()
        dice_data_mean = epoch_metrics['dice'].data_mean()

        iou_count = len(iou_data_mean['items'])
        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
            (iou_count, iou_data_mean['mean'], iou_data_mean['std']))
        dice_count = len(dice_data_mean['items'])
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
            (dice_count, dice_data_mean['mean'], dice_data_mean['std']))

        # keep the per-frame summary as this epoch's record on the trainer
        record = {
            'miou': iou_data_mean['mean'],
            'std_miou': iou_data_mean['std'],
            'mdice': dice_data_mean['mean'],
            'std_mdice': dice_data_mean['std'],
        }
        trainer.state.metrics_records[trainer.state.epoch] = record


    # log interal variables(attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        """Write the latest validation predictions (and TAPNet attention maps) as images.

        Reads ``engine.state.output`` (the dict returned by the validation
        step), tiles the argmax predictions and the targets into one image
        grid and hands it to ``logger.writer.add_image``.  For TAPNet models
        a second, normalised grid with the attention maps is written too.
        ``event_name`` is part of the handler signature but is not used.
        """
        log_tag = 'valid_iter'
        output = engine.state.output
        batch_size = output['output'].shape[0]

        # predictions first, then targets; a channel axis is inserted at dim 1
        stacked = torch.cat([
            output['output_argmax'].unsqueeze(1),
            output['target'].unsqueeze(1),
        ])
        # normalize=False keeps the original values
        res_grid = tvutils.make_grid(
            stacked, padding=2, normalize=False, nrow=batch_size).cpu()
        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag), img_tensor=res_grid)

        if 'TAPNet' not in args.model:
            return

        # attention maps and other internal values
        internal_vals = torch.cat([
            output['attmap'],
        ])
        inter_vals_grid = tvutils.make_grid(
            internal_vals, padding=2, normalize=True, nrow=batch_size).cpu()
        logger.writer.add_image(tag='%s internal vals' % (log_tag), img_tensor=inter_vals_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        """Write the epoch's class-averaged IoU and Dice as scalars.

        Both values are derived from the confusion matrix accumulated in
        ``engine.state.epoch_metrics`` and are logged against
        ``engine.state.epoch`` under the tags ``'mIoU'`` and ``'mDice'``.
        ``event_name`` is part of the handler signature but is not used.
        """
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics
        confusion_matrix = epoch_metrics['confusion_matrix']
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch)
        # Dice gets its own tag; it was previously written under 'mIoU',
        # mixing both metrics into a single series.
        logger.writer.add_scalar('mDice', mean_dices, engine.state.epoch)



    if args.tb_log:
        # log internal values
        tb_logger.attach(validator, log_handler=tb_log_valid_iter_vars, 
            event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=tb_log_valid_epoch_vars,
            event_name=engine.Events.EPOCH_COMPLETED)
        # tb_logger.attach(validator, log_handler=OutputHandler('valid_iter', valid_metric_names),
        #     event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=OutputHandler('valid_epoch', ['valid_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)


    # score function for model saving
    ckpt_score_function = lambda engine: \
        np.mean(list(calculate_iou(engine.state.epoch_metrics['confusion_matrix']).values()))
    # ckpt_score_function = lambda engine: engine.state.epoch_metrics['iou'].data_mean()['mean']
    
    ckpt_filename_prefix = 'fold_%d' % fold

    # model saving handler
    model_ckpt_handler = handlers.ModelCheckpoint(
        dirname=args.model_save_dir,
        filename_prefix=ckpt_filename_prefix, 
        score_function=ckpt_score_function,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True,
        atomic=True)


    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED, 
        handler=model_ckpt_handler,
        to_save={
            'model': model,
        })

    # early stop
    # trainer=trainer, but should be handled by validator
    early_stopping = handlers.EarlyStopping(patience=args.es_patience, 
        score_function=ckpt_score_function,
        trainer=trainer
        )

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
        handler=early_stopping)


    # evaluate after epoch finish
    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def trainer_epoch_comp_callback(engine):
        """Run the validator over ``valid_loader`` once each training epoch completes."""
        validator.run(valid_loader)

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if args.tb_log:
        # close tb_logger
        tb_logger.close()

    return trainer.state.metrics_records
Exemplo n.º 3
0
def main():
    """Command-line entry point: configure and launch one training run.

    Parses the CLI options, builds the selected network and loss, creates the
    train/validation data loaders, writes the parsed options to
    ``<root>/params.json`` and finally delegates to ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=1, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=10)
    arg('--lr', type=float, default=0.0002)
    arg('--workers', type=int, default=10)
    arg('--type',
        type=str,
        default='binary',
        choices=['binary', 'parts', 'instruments'])
    # 'UNet16' is listed so that the UNet16 branch below is reachable.
    # NOTE(review): assumes UNet16 is imported next to the other model
    # classes - confirm against the module's imports.
    arg('--model',
        type=str,
        default='DLinkNet',
        choices=['UNet', 'UNet11', 'UNet16', 'LinkNet34', 'DLinkNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # number of output channels depends on the segmentation problem
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    # The chain is kept lazy on purpose: only the selected class is looked up.
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'DLinkNet':
        model = D_LinkNet34(num_classes=num_classes, pretrained=True)
    else:
        # defensive fallback; argparse `choices` normally prevents this
        model = UNet(num_classes=num_classes, input_channels=3)

    if torch.cuda.is_available():
        # an empty --device-ids string means "let DataParallel pick"
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    if args.type == 'binary':
        # loss = LossBinary(jaccard_weight=args.jaccard_weight)
        loss = LossBCE_DICE()
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names,
                    shuffle=False,
                    transform=None,
                    problem_type='binary'):
        """Wrap ``file_names`` in a RoboticsDataset and return its DataLoader."""
        return DataLoader(dataset=RoboticsDataset(file_names,
                                                  transform=transform,
                                                  problem_type=problem_type),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=args.batch_size,
                          pin_memory=torch.cuda.is_available())

    # train_file_names, val_file_names = get_split(args.fold)
    train_file_names, val_file_names = get_train_val_files()

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    # flips are applied to the training pipeline only
    train_transform = DualCompose(
        [HorizontalFlip(),
         VerticalFlip(),
         ImageOnly(Normalize())])

    val_transform = DualCompose([ImageOnly(Normalize())])

    train_loader = make_loader(train_file_names,
                               shuffle=True,
                               transform=train_transform,
                               problem_type=args.type)
    valid_loader = make_loader(val_file_names,
                               transform=val_transform,
                               problem_type=args.type)

    # keep a record of the options used for this run
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
Exemplo n.º 4
0
def main():
    """Command-line entry point: train TernausNet34 on the map dataset.

    Parses the CLI options, builds the network and loss, creates the
    train/validation loaders from ``TRAIN_ANNOTATIONS_PATH`` and
    ``VAL_ANNOTATIONS_PATH``, writes the parsed options to
    ``<root>/params.json`` and delegates to ``utils.train``.
    """
    parser = argparse.ArgumentParser()

    arg = parser.add_argument
    arg('--jaccard-weight', default=1, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=8)
    arg('--n-epochs', type=int, default=14)
    arg('--lr', type=float, default=0.000001)
    arg('--workers', type=int, default=8)
    arg('--type',
        type=str,
        default='binary',
        choices=['binary', 'parts', 'instruments'])
    arg('--model',
        type=str,
        default='TernausNet',
        choices=['UNet', 'UNet11', 'LinkNet34', 'TernausNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # number of output channels depends on the segmentation problem
    if args.type == 'parts':
        num_classes = 3
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    # Every --model choice builds TernausNet34 in this script; say so instead
    # of silently ignoring the user's selection.
    if args.model != 'TernausNet':
        print('Warning: --model {} is not implemented in this script; '
              'using TernausNet34 instead.'.format(args.model))
    model = TernausNet34(num_classes=num_classes)

    if torch.cuda.is_available():
        # an empty --device-ids string means "let DataParallel pick"
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names,
                    shuffle=False,
                    transform=None,
                    mode='train',
                    problem_type='binary'):
        """Wrap ``file_names`` in a MapDataset and return its DataLoader."""
        return DataLoader(dataset=MapDataset(file_names,
                                             transform=transform,
                                             problem_type=problem_type,
                                             mode=mode),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=args.batch_size,
                          pin_memory=torch.cuda.is_available())

    # Earlier data-loading / augmentation variant, kept for reference:
    # labels = pd.read_csv('data/stage1_train_labels.csv')
    # labels = os.listdir('data/stage1_train_')
    # train_file_names, val_file_names = train_test_split(labels, test_size=0.2, random_state=42)

    # print('num train = {}, num_val = {}'.format(len(train_file_names), len(val_file_names)))

    # train_transform = DualCompose([
    #     HorizontalFlip(),
    #     VerticalFlip(),
    #     RandomCrop([256, 256]),
    #     RandomRotate90(),
    #     ShiftScaleRotate(),
    #     ImageOnly(RandomHueSaturationValue()),
    #     ImageOnly(RandomBrightness()),
    #     ImageOnly(RandomContrast()),
    #     ImageOnly(Normalize())
    # ])
    geometric_distortion = OneOf([
        Distort1(distort_limit=0.05, shift_limit=0.05),
        Distort2(num_steps=2, distort_limit=0.05)
    ])
    shift_scale_rotate = ShiftScaleRotate(shift_limit=0.0625,
                                          scale_limit=0.10,
                                          rotate_limit=45)
    train_transform = DualCompose([
        OneOrOther(geometric_distortion, shift_scale_rotate, prob=0.5),
        RandomRotate90(),
        RandomCrop([256, 256]),
        RandomFlip(prob=0.5),
        Transpose(prob=0.5),
        ImageOnly(RandomContrast(limit=0.2, prob=0.5)),
        ImageOnly(RandomFilter(limit=0.5, prob=0.2)),
        ImageOnly(RandomHueSaturationValue(prob=0.2)),
        ImageOnly(RandomBrightness()),
        ImageOnly(Normalize())
    ])

    val_transform = DualCompose([
        # RandomCrop([256, 256]),
        Rescale([256, 256]),
        ImageOnly(Normalize())
    ])

    train_loader = make_loader(TRAIN_ANNOTATIONS_PATH,
                               shuffle=True,
                               transform=train_transform,
                               problem_type=args.type)
    valid_loader = make_loader(VAL_ANNOTATIONS_PATH,
                               transform=val_transform,
                               mode='valid',
                               problem_type=args.type)

    # keep a record of the options used for this run
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
Exemplo n.º 5
0
def main():
    """Command-line entry point: configure and launch one training run.

    Parses the CLI options, validates the crop sizes with
    ``utils.check_crop_size``, builds the selected network and loss, creates
    the train/validation data loaders for fold ``--fold``, writes the parsed
    options to ``<root>/params.json`` and delegates to ``utils.train``.

    Raises:
        SystemError: if no CUDA device is available.
        SystemExit: with status 1 if a crop size fails validation.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--train_crop_height', type=int, default=1024)
    arg('--train_crop_width', type=int, default=1280)
    arg('--val_crop_height', type=int, default=1024)
    arg('--val_crop_width', type=int, default=1280)
    arg('--type',
        type=str,
        default='binary',
        choices=['binary', 'parts', 'instruments'])
    arg('--model', type=str, default='UNet', choices=moddel_list.keys())

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # Reject unusable crop sizes up front; exit non-zero so that calling
    # scripts can tell the run did not start.
    crop_sizes = (
        ('train', args.train_crop_height, args.train_crop_width),
        ('validation', args.val_crop_height, args.val_crop_width),
    )
    for stage, crop_height, crop_width in crop_sizes:
        if not utils.check_crop_size(crop_height, crop_width):
            print('Input image sizes should be divisible by 32, but {stage} '
                  'crop sizes ({height} and {width}) '
                  'are not.'.format(stage=stage,
                                    height=crop_height,
                                    width=crop_width))
            sys.exit(1)

    # number of output channels depends on the segmentation problem
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        model_name = moddel_list[args.model]
        model = model_name(num_classes=num_classes, pretrained=True)

    if not torch.cuda.is_available():
        raise SystemError('GPU device not found')

    # an empty --device-ids string means "let DataParallel pick"
    if args.device_ids:
        device_ids = list(map(int, args.device_ids.split(',')))
    else:
        device_ids = None
    model = nn.DataParallel(model, device_ids=device_ids).cuda()

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names,
                    shuffle=False,
                    transform=None,
                    problem_type='binary',
                    batch_size=1):
        """Wrap ``file_names`` in a RoboticsDataset and return its DataLoader."""
        return DataLoader(dataset=RoboticsDataset(file_names,
                                                  transform=transform,
                                                  problem_type=problem_type),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split(args.fold)

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    def train_transform(p=1):
        """Training pipeline: pad, random crop, random flips, normalize."""
        return Compose([
            PadIfNeeded(min_height=args.train_crop_height,
                        min_width=args.train_crop_width,
                        p=1),
            RandomCrop(height=args.train_crop_height,
                       width=args.train_crop_width,
                       p=1),
            VerticalFlip(p=0.5),
            HorizontalFlip(p=0.5),
            Normalize(p=1)
        ],
                       p=p)

    def val_transform(p=1):
        """Validation pipeline: pad, center crop, normalize."""
        return Compose([
            PadIfNeeded(min_height=args.val_crop_height,
                        min_width=args.val_crop_width,
                        p=1),
            CenterCrop(
                height=args.val_crop_height, width=args.val_crop_width, p=1),
            Normalize(p=1)
        ],
                       p=p)

    # Validation uses one sample per GPU.  With device_ids=None DataParallel
    # uses every visible device, so count those instead of calling len(None).
    if device_ids:
        valid_batch_size = len(device_ids)
    else:
        valid_batch_size = torch.cuda.device_count()

    train_loader = make_loader(train_file_names,
                               shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size)
    valid_loader = make_loader(val_file_names,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=valid_batch_size)

    # keep a record of the options used for this run
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
Exemplo n.º 6
0
def main():
    """Command-line entry point: configure and launch one training run.

    Parses the CLI options, creates a time-stamped run directory below
    ``--root``, validates the crop sizes with ``utils.check_crop_size``,
    builds the selected network and (optionally class-weighted) loss, creates
    the loaders from ``trainval.txt`` / ``test.txt`` inside ``--filepath`` and
    delegates to ``utils.train``.

    ``make_loader`` is not defined in this function; it is expected to be
    available at module level.

    Raises:
        SystemError: if no CUDA device is available.
        SystemExit: with status 1 if a crop size fails validation.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids', type=str, default='0',
        help='For example 0,1 to run on two GPUs')
    # required: the list files below are always read from this folder, so a
    # missing value would otherwise surface as a TypeError in os.path.join
    arg('--filepath', type=str, required=True,
        help='folder with images and annotation masks')
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=32)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--train_crop_height', type=int, default=416)
    arg('--train_crop_width', type=int, default=416)
    arg('--val_crop_height', type=int, default=416)
    arg('--val_crop_width', type=int, default=416)
    arg('--type', type=str, default='binary', choices=['binary', 'multi'])
    arg('--model', type=str, default='UNet', choices=model_list.keys())
    arg('--datatype', type=str, default='buildings',
        choices=['buildings', 'roads', 'combined'])
    arg('--pretrained', action='store_true',
        help='use pretrained network for initialisation')
    arg('--num_classes', type=int, default=1)

    args = parser.parse_args()

    timestr = time.strftime("%Y%m%d-%H%M%S")

    # every run gets its own time-stamped sub-directory of --root
    root = Path(args.root) / timestr
    root.mkdir(exist_ok=True, parents=True)
#    dataset_type = args.filepath.split("/")[-3]
    dataset_type = args.datatype
    print('log', root, dataset_type)

    # Reject unusable crop sizes up front; exit non-zero so that calling
    # scripts can tell the run did not start.
    crop_sizes = (
        ('train', args.train_crop_height, args.train_crop_width),
        ('validation', args.val_crop_height, args.val_crop_width),
    )
    for stage, crop_height, crop_width in crop_sizes:
        if not utils.check_crop_size(crop_height, crop_width):
            print('Input image sizes should be divisible by 32, but {stage} '
                  'crop sizes ({height} and {width}) '
                  'are not.'.format(stage=stage,
                                    height=crop_height,
                                    width=crop_width))
            sys.exit(1)

    num_classes = args.num_classes

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        model_name = model_list[args.model]
        model = model_name(num_classes=num_classes, pretrained=args.pretrained)

    if not torch.cuda.is_available():
        raise SystemError('GPU device not found')

    # an empty --device-ids string means "let DataParallel pick"
    if args.device_ids:
        device_ids = list(map(int, args.device_ids.split(',')))
    else:
        device_ids = None
    model = nn.DataParallel(model, device_ids=device_ids).cuda()

    def class_weights_from_counts(label_counts):
        """Return sum(counts) / (num_classes * counts), element-wise."""
        return np.sum(label_counts) / np.multiply(num_classes, label_counts)

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        # hard-coded per-class counts
        # NOTE(review): presumably pixel counts of the training set - confirm
        if args.num_classes == 2:
            label_counts = [89371542, 7083233]
        else:
            # earlier experiments, kept for reference:
            # labelweights = [30740321,3046555,1554577]
            # labelweights = labelweights / np.sum(labelweights)
            # labelweights = 1 / np.log(1.2 + labelweights)
            label_counts = [89371542, 29703049, 7083233]
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight,
                         class_weights=class_weights_from_counts(label_counts))

    cudnn.benchmark = True

    train_filename = os.path.join(args.filepath, 'trainval.txt')
    val_filename = os.path.join(args.filepath, 'test.txt')

    def train_transform(p=1):
        """Training pipeline: pad, random crop, random flips, normalize."""
        return Compose([
            PadIfNeeded(min_height=args.train_crop_height,
                        min_width=args.train_crop_width, p=1),
            RandomCrop(height=args.train_crop_height,
                       width=args.train_crop_width, p=1),
            VerticalFlip(p=0.5),
            HorizontalFlip(p=0.5),
            Normalize(p=1)
        ], p=p)

    def val_transform(p=1):
        """Validation pipeline: pad, center crop, normalize."""
        return Compose([
            PadIfNeeded(min_height=args.val_crop_height,
                        min_width=args.val_crop_width, p=1),
            CenterCrop(height=args.val_crop_height,
                       width=args.val_crop_width, p=1),
            Normalize(p=1)
        ], p=p)

    # Validation uses one sample per GPU.  With device_ids=None DataParallel
    # uses every visible device, so count those instead of calling len(None).
    if device_ids:
        valid_batch_size = len(device_ids)
    else:
        valid_batch_size = torch.cuda.device_count()

    train_loader = make_loader(train_filename,
                               shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size,
                               datatype=args.datatype)
    valid_loader = make_loader(val_filename,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=valid_batch_size,
                               datatype=args.datatype)

    # Dump the options while args.root is still a plain string: the Path
    # assigned to it right afterwards is not JSON serializable.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))
    args.root = root
    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(
        init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
        args=args,
        model=model,
        criterion=loss,
        train_loader=train_loader,
        valid_loader=valid_loader,
        validation=valid,
        num_classes=num_classes,
        model_name=args.model,
        dataset_type=dataset_type
    )
Exemplo n.º 7
0
    save = lambda ep: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
    }, str(model_path))

    report_each = 10
    valid_each = 4
    log = root.joinpath('train_{fold}.log'.format(fold=fold)).open('at', encoding='utf8')
    valid_losses = []

    if(add_log == False):
        criterion = MultiDiceLoss(num_classes=11)
    else:
        criterion = LossMulti(num_classes=11, jaccard_weight=0.5)
    class_color_table = read_json(json_file_name)
    first_time = True
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        random.seed()
        tq = tqdm.tqdm(total=(len(train_loader) * batch_size))
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        losses = []
        try:
            mean_loss = 0
            for i, (inputs, targets) in enumerate(train_loader):
                # images = inputs.data.cpu().numpy()
                # targets = targets.data.cpu().numpy()
                # print(targets.shape)
                # images = np.moveaxis(images, [0, 1, 2, 3], [0, 3, 1, 2])
def main():
    """Command-line entry point: configure and launch one training run.

    Parses the CLI options, builds the selected network and loss, creates the
    train/validation data loaders, writes the parsed options to
    ``<root>/params.json`` and delegates to ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--type',
        type=str,
        default='binary',
        choices=['binary', 'parts', 'instruments'])
    # 'UNet16' is listed so that the UNet16 branch below is reachable.
    # NOTE(review): assumes UNet16 is imported next to the other model
    # classes - confirm against the module's imports.
    arg('--model',
        type=str,
        default='UNet',
        choices=['UNet', 'UNet11', 'UNet16', 'LinkNet34', 'AlbuNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # number of output channels depends on the segmentation problem
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    # The chain is kept lazy on purpose: only the selected class is looked up.
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained=True)
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained=True)
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'AlbuNet':
        model = AlbuNet(num_classes=num_classes, pretrained=True)
    else:
        # defensive fallback; argparse `choices` normally prevents this
        model = UNet(num_classes=num_classes, input_channels=3)

    # Defined up front so the validation batch size below can be computed
    # even when CUDA is unavailable and the model stays on the CPU.
    device_ids = None
    if torch.cuda.is_available():
        # an empty --device-ids string means "let DataParallel pick"
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names,
                    shuffle=False,
                    transform=None,
                    problem_type='binary',
                    batch_size=1):
        """Wrap ``file_names`` in a CustomDataset and return its DataLoader.

        ``problem_type`` is accepted for call compatibility but is not
        forwarded to CustomDataset.
        """
        return DataLoader(dataset=CustomDataset(file_names,
                                                transform=transform),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split()

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    def train_transform(p=1):
        """Training pipeline: random crop, colour/noise jitter, flip, normalize."""
        return Compose(
            [
                #            Rescale(SIZE),
                RandomCrop(SIZE),
                RandomBrightness(0.2),
                OneOf([
                    IAAAdditiveGaussianNoise(),
                    GaussNoise(),
                ], p=0.15),
                #            OneOf([
                #                OpticalDistortion(p=0.3),
                #                GridDistortion(p=.1),
                #                IAAPiecewiseAffine(p=0.3),
                #            ], p=0.1),
                #            OneOf([
                #                IAASharpen(),
                #                IAAEmboss(),
                #                RandomContrast(),
                #                RandomBrightness(),
                #            ], p=0.15),
                HueSaturationValue(p=0.15),
                HorizontalFlip(p=0.5),
                Normalize(p=1),
            ],
            p=p)

    def val_transform(p=1):
        """Validation pipeline: crop and normalize.

        NOTE(review): the crop is random, so validation samples differ from
        epoch to epoch - confirm this is intended.
        """
        return Compose(
            [
                #            Rescale(256),
                RandomCrop(SIZE),
                Normalize(p=1)
            ],
            p=p)

    # Validation uses one sample per device.  Fall back to the visible GPU
    # count when no explicit ids were given, and to 1 on a CPU-only machine.
    if device_ids:
        valid_batch_size = len(device_ids)
    else:
        valid_batch_size = max(torch.cuda.device_count(), 1)

    train_loader = make_loader(train_file_names,
                               shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size)
    valid_loader = make_loader(val_file_names,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=valid_batch_size)

    # keep a record of the options used for this run
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)