コード例 #1
0
ファイル: train_val_Task1.py プロジェクト: zlinzju/MICCAI2019
def train(epoch):
    """Train ``net`` for one epoch, feeding each volume in slice sub-batches.

    Every item yielded by ``train_loader`` is an ``(inputs_volume,
    labels_volume)`` pair indexed as a 4-D tensor. The volume is cut along
    axis 1 into chunks of ``args['train_batch_size']`` slices, and one
    optimizer step is taken per chunk.

    Relies on module-level ``net``, ``optimizer``, ``train_loader`` and
    ``args``.

    Args:
        epoch: Index of the current epoch; drives the learning-rate decay
            check and is shown in the progress log.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss_record = AvgMeter()

    # The decay condition depends only on ``epoch``, so evaluate it once
    # instead of once per batch. Param group 0 runs at twice the rate of
    # param group 1.
    if epoch % args['lr_step'] == 0 and epoch != 0:
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] / args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] / args['lr_decay']

    # Build the loss once; it was previously re-created for every sub-batch.
    criterion = MultiClassCriterion('OhemCrossEntropy')
    slices_per_step = args['train_batch_size']

    for batch_idx, data in enumerate(train_loader):
        inputs_volume, labels_volume = data
        num_slices = inputs_volume.shape[1]

        # Split the volume into sub-batches of consecutive slices.
        for start in range(0, num_slices, slices_per_step):
            end = min(start + slices_per_step, num_slices)

            # Move the slice axis to the front so slices act as the batch,
            # then broadcast the size-1 axis to 3 to match the net's input
            # channels.
            inputs = inputs_volume[:, start:end, :, :].permute(1, 0, 2, 3)
            inputs = inputs.expand(-1, 3, -1, -1)

            # clone() so the masking below never writes into the loader's
            # tensor. squeeze(0) only drops the leading axis if it has size 1.
            labels = labels_volume[:, start:end, :, :].squeeze(0).clone()

            # Training trick: overwrite everything outside the central
            # 128..383 window of the first slice with 255.
            # NOTE(review): 255 is presumably the criterion's ignore index,
            # and only index 0 is masked -- confirm that this is intended
            # when train_batch_size > 1.
            labels[0, :128, :] = 255
            labels[0, 384:, :] = 255
            labels[0, :, :128] = 255
            labels[0, :, 384:] = 255

            batch_size = inputs.size(0)
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss_record.update(loss.item(), batch_size)
            # NOTE(review): the value labelled "iter" is the epoch index.
            log = 'iter: %d | [Total loss: %.5f], [lr: %.8f]' % \
                  (epoch, train_loss_record.avg, optimizer.param_groups[1]['lr'])
            progress_bar(batch_idx, len(train_loader), log)
コード例 #2
0
# Remove the dataset selector from the data configuration.
# NOTE(review): presumably so the remaining entries can be forwarded as
# keyword arguments to the dataset class -- not visible in this fragment.
del data_config['dataset']

# Per-config output locations; parents=True so a fresh checkout without
# ../model or ../logs does not fail with FileNotFoundError.
modelname = config_path.stem
output_dir = Path('../model') / modelname
output_dir.mkdir(parents=True, exist_ok=True)
log_dir = Path('../logs') / modelname
log_dir.mkdir(parents=True, exist_ok=True)

logger = debug_logger(log_dir)
logger.debug(config)
logger.info(f'Device: {device}')
logger.info(f'Max Epoch: {max_epoch}')

# Loss
print('Initializing loss function, optimizer and scheduler...')
loss_fn = MultiClassCriterion(**loss_config).to(device)
params = model.parameters()
optimizer, scheduler = create_optimizer(params, **opt_config)

# history
if resume:
    # SECURITY: pickle.load can execute arbitrary code; only resume from
    # history files written by a trusted run.
    with open(log_dir.joinpath('history.pkl'), 'rb') as f:
        history_dict = pickle.load(f)
    best_metrics = history_dict['best_metrics']
    loss_history = history_dict['loss']
    iou_history = history_dict['iou']
    # One history entry per finished epoch; replay that many scheduler steps
    # so the learning-rate schedule continues where it stopped.
    start_epoch = len(iou_history)
    for _ in range(start_epoch):
        scheduler.step()
else:
    # NOTE(review): best_metrics / loss_history / iou_history are not
    # initialised on this path in the visible fragment.
    start_epoch = 0
コード例 #3
0
def main(argv, configPath=None):
    """Train and validate a Deeplabv3+ (MobileNet) model on BDD100K area data.

    Builds the datasets, loaders, model, optimizer, scheduler and loss from
    the parsed arguments, then runs one ``train_seg`` / ``val_seg`` pass per
    epoch and writes a checkpoint after every epoch.

    Args:
        argv: Command-line arguments forwarded to ``getArgs_``.
        configPath: Optional configuration path forwarded to ``getArgs_``.
    """
    # arguments
    args = getArgs_(argv, configPath)
    saveDir = savePath(args)
    logger = infoLogger(logdir=saveDir, name=args.model)

    logger.info(argv)
    logger.debug(cfgInfo(args))
    logger.info("CheckPoints path: {}".format(saveDir))
    logger.debug("Model Name: {}".format(args.model))

    train_dataset = BDD100K_Area_Seg(base_dir=args.dataPath,
                                     split='train',
                                     target_size=args.size)
    valid_dataset = BDD100K_Area_Seg(base_dir=args.dataPath,
                                     split='val',
                                     target_size=args.size)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_worker,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_worker,
                              pin_memory=True)

    args.num_gpus, args.device = deviceSetting(logger=logger,
                                               device=args.device)
    # model
    model = Deeplabv3plus_Mobilenet(args.output_channels,
                                    output_stride=args.output_stride)

    optimizer, scheduler = create_optimizer_(model, args)
    loss_fn = MultiClassCriterion(loss_type=args.loss_type,
                                  ignore_index=args.ignore_index)
    model, trainData = modelDeploy(args, model, optimizer, scheduler, logger)

    tensorLogger = SummaryWriter(log_dir=os.path.join(saveDir, 'runs'),
                                 filename_suffix=args.model)
    # Close the event writer even if training or validation raises, so the
    # writer is not leaked on failure.
    try:
        logger.info("Tensorboard event log saved in {}".format(
            tensorLogger.log_dir))

        logger.info('Start training...')
        start_epoch = trainData['epoch']

        num_classes = args.output_channels
        extra_info_ckpt = '{}_{}_{}'.format(args.model, args.size[0],
                                            args.size[1])
        for i_epoch in range(start_epoch, args.max_epoch):
            # From epoch index 29 on, the first param group's learning rate
            # is pinned to 1e-5 at the start of every epoch, overriding what
            # scheduler.step() set at the end of the previous one.
            if i_epoch >= 29:
                optimizer.param_groups[0]["lr"] = np.float64(0.00001)
            trainData['epoch'] = i_epoch
            lossList, miouList = train_seg(model,
                                           train_loader,
                                           i_epoch,
                                           optimizer,
                                           loss_fn,
                                           num_classes,
                                           logger,
                                           tensorLogger,
                                           args=args)
            scheduler.step()
            trainData['loss'].extend(lossList)
            trainData['miou'].extend(miouList)

            valLoss, valMiou = val_seg(model,
                                       valid_loader,
                                       i_epoch,
                                       loss_fn,
                                       num_classes,
                                       logger,
                                       tensorLogger,
                                       args=args)
            trainData['val'].append([valLoss, valMiou])

            # Evaluate the "new best" condition once and reuse it for both
            # the bookkeeping and the checkpoint flag.
            is_best = valMiou > trainData['bestMiou']
            if is_best:
                trainData['bestMiou'] = valMiou

            # NOTE(review): assumes the model exposes ``.module`` (e.g. a
            # DataParallel wrapper from modelDeploy) whenever
            # args.device == 'cuda' -- confirm.
            weights_dict = model.module.state_dict(
            ) if args.device == 'cuda' else model.state_dict()

            save_checkpoint(
                {
                    'trainData': trainData,
                    'model': weights_dict,
                    'optimizer': optimizer.state_dict(),
                },
                is_best=is_best,
                dir=saveDir,
                extra_info=extra_info_ckpt,
                miou_val=valMiou,
                logger=logger)
    finally:
        tensorLogger.close()
コード例 #4
0
def process(config_path):
    """Train and validate a segmentation network described by a YAML config.

    The config must provide the sections ``Net``, ``Data``, ``Train``,
    ``Loss`` and ``Optimizer``. For every epoch the function rebuilds the
    train/valid datasets, trains, validates with test-time augmentation
    (``model.tta``), stores temporary and best weights under a timestamped
    ``model/`` directory and writes loss/IoU history under a timestamped
    ``logs/`` directory (both below ``ROOT_DIR``).

    Args:
        config_path: Path object pointing to the YAML file; its ``stem``
            becomes part of the output directory names.

    Raises:
        NotImplementedError: If ``Data.dataset`` is not
            ``'deepglobe-dynamic'``.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # Close the file deterministically. safe_load replaces the Loader-less
    # yaml.load call, which newer PyYAML rejects and which can construct
    # arbitrary objects on older versions.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    net_config = config['Net']
    data_config = config['Data']
    train_config = config['Train']
    loss_config = config['Loss']
    opt_config = config['Optimizer']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Collect training parameters
    max_epoch = train_config['max_epoch']
    batch_size = train_config['batch_size']
    fp16 = train_config['fp16']
    resume = train_config['resume']
    pretrained_path = train_config['pretrained_path']
    freeze_enabled = train_config['freeze']
    seed_enabled = train_config['seed']

    #########################################
    # Deterministic training: seed every RNG with the same value.
    if seed_enabled:
        seed = 100
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed=seed)
        import random
        random.seed(a=seed)
    #########################################

    # Network
    if 'unet' in net_config['dec_type']:
        net_type = 'unet'
        model = EncoderDecoderNet(**net_config)
    else:
        net_type = 'deeplab'
        # The deeplab model is first built with 19 output channels; the
        # logits layer is only swapped for the dataset's class count when
        # pretrained weights are loaded further down.
        net_config['output_channels'] = 19
        model = SPPNet(**net_config)

    dataset = data_config['dataset']
    if dataset == 'deepglobe-dynamic':
        from dataset.deepglobe_dynamic import DeepGlobeDatasetDynamic as Dataset
        net_config['output_channels'] = 7
        classes = np.arange(0, 7)
    else:
        raise NotImplementedError
    # The remaining data_config entries are forwarded as keyword arguments
    # to Dataset, so the selector key must be removed.
    del data_config['dataset']

    modelname = config_path.stem
    # Format the run timestamp once so both directories share the same name.
    timestamp = datetime.timestamp(datetime.now())
    run_started = datetime.fromtimestamp(timestamp)
    print("timestamp =", run_started)
    run_name = f'{modelname}_{run_started}'
    # parents=True so a missing model/ or logs/ folder under ROOT_DIR does
    # not abort the run.
    output_dir = Path(os.path.join(ROOT_DIR, f'model/{run_name}'))
    output_dir.mkdir(parents=True, exist_ok=True)
    log_dir = Path(os.path.join(ROOT_DIR, f'logs/{run_name}'))
    log_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): machine-specific absolute path; should come from the
    # config.
    dataset_dir = '/home/sfoucher/DEV/pytorch-segmentation/data/deepglobe_as_pascalvoc/VOCdevkit/VOC2012'
    logger = debug_logger(log_dir)
    logger.debug(config)
    logger.info(f'Device: {device}')
    logger.info(f'Max Epoch: {max_epoch}')

    # Loss
    loss_fn = MultiClassCriterion(**loss_config).to(device)
    params = model.parameters()
    optimizer, scheduler = create_optimizer(params, **opt_config)

    # history
    if resume:
        # NOTE(review): log_dir / output_dir are freshly timestamped for
        # this run, so files from an earlier run will not be found here --
        # resume appears unable to work as written.
        # SECURITY: pickle.load can execute arbitrary code; only load
        # history files from trusted runs.
        with open(log_dir.joinpath('history.pkl'), 'rb') as f:
            history_dict = pickle.load(f)
        best_metrics = history_dict['best_metrics']
        loss_history = history_dict['loss']
        iou_history = history_dict['iou']
        # One history entry per finished epoch; replay the scheduler.
        start_epoch = len(iou_history)
        for _ in range(start_epoch):
            scheduler.step()
    else:
        start_epoch = 0
        best_metrics = 0
        loss_history = []
        iou_history = []

    affine_augmenter = albu.Compose([albu.HorizontalFlip(p=.5),
                                     albu.VerticalFlip(p=.5)])
    image_augmenter = None

    # Datasets and loaders are built inside the epoch loop (dynamic
    # training), not here.

    # Pretrained model
    if pretrained_path:
        logger.info(f'Resume from {pretrained_path}')
        param = torch.load(pretrained_path)
        model.load_state_dict(param)
        # Replace the classification head with one sized for this dataset.
        # NOTE(review): the optimizer above was built before this
        # replacement, so unless the freeze branch below rebuilds it, the
        # new logits parameters are not registered with the optimizer.
        model.logits = torch.nn.Conv2d(256, net_config['output_channels'], 1)
        del param

    # To device
    model = model.to(device)

    #########################################
    if freeze_enabled:
        # Freeze the first half of the parameters the optimizer currently
        # holds, then rebuild the optimizer over what is still trainable.
        candidates = optimizer.param_groups[0]['params']
        for frozen_param in candidates[:int(len(candidates) * 0.5)]:
            frozen_param.requires_grad = False

        print("Params to learn:")
        params_to_update = []
        for name, param in model.named_parameters():
            if param.requires_grad:
                params_to_update.append(param)
                print("\t", name)
        # NOTE(review): the rebuilt scheduler is not fast-forwarded to
        # start_epoch when resuming.
        optimizer, scheduler = create_optimizer(params_to_update, **opt_config)
    #########################################

    # fp16
    if fp16:
        # Only the required apex files are vendored because the C backend
        # of apex could not be installed.
        from utils.apex.apex.fp16_utils.fp16util import BN_convert_float
        from utils.apex.apex.fp16_utils.fp16_optimizer import FP16_Optimizer
        model = BN_convert_float(model.half())
        optimizer = FP16_Optimizer(optimizer, verbose=False, dynamic_loss_scale=True)
        logger.info('Apply fp16')

    # Restore model
    if resume:
        model_path = output_dir.joinpath('model_tmp.pth')
        logger.info(f'Resume from {model_path}')
        param = torch.load(model_path)
        model.load_state_dict(param)
        del param
        opt_path = output_dir.joinpath('opt_tmp.pth')
        param = torch.load(opt_path)
        optimizer.load_state_dict(param)
        del param

    i_iter = 0
    # Exponential moving averages (weight 0.01 on the newest value) used
    # only for plotting.
    ma_loss = 0
    ma_iou = 0
    # Train
    for i_epoch in range(start_epoch, max_epoch):
        logger.info(f'Epoch: {i_epoch}')
        logger.info(f'Learning rate: {optimizer.param_groups[0]["lr"]}')

        train_losses = []
        train_ious = []
        model.train()

        # Initialize randomized but balanced datasets (rebuilt every epoch).
        train_dataset = Dataset(base_dir=dataset_dir,
                                affine_augmenter=affine_augmenter, image_augmenter=image_augmenter,
                                net_type=net_type, **data_config)
        valid_dataset = Dataset(base_dir=dataset_dir,
                                split='valid', net_type=net_type, **data_config)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4,
                                  pin_memory=True, drop_last=True)
        valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

        with tqdm(train_loader) as _tqdm:
            for batched in _tqdm:
                images, labels = batched
                if fp16:
                    images = images.half()
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                preds = model(images)
                if net_type == 'deeplab':
                    # Upsample the deeplab output to the label resolution.
                    preds = F.interpolate(preds, size=labels.shape[1:], mode='bilinear', align_corners=True)
                if fp16:
                    loss = loss_fn(preds.float(), labels)
                else:
                    loss = loss_fn(preds, labels)
                # Read the scalar once; each .item() call copies the value
                # back to the host.
                loss_value = loss.item()

                preds_np = preds.detach().cpu().numpy()
                labels_np = labels.detach().cpu().numpy()
                iou = compute_iou_batch(np.argmax(preds_np, axis=1), labels_np, classes)

                _tqdm.set_postfix(OrderedDict(seg_loss=f'{loss_value:.5f}', iou=f'{iou:.3f}'))
                train_losses.append(loss_value)
                train_ious.append(iou)
                ma_loss = 0.01 * loss_value + 0.99 * ma_loss
                ma_iou = 0.01 * iou + 0.99 * ma_iou
                plotter.plot('loss', 'train', 'iteration Loss', i_iter, loss_value)
                plotter.plot('iou', 'train', 'iteration iou', i_iter, iou)
                plotter.plot('loss', 'ma_loss', 'iteration Loss', i_iter, ma_loss)
                plotter.plot('iou', 'ma_iou', 'iteration iou', i_iter, ma_iou)
                if fp16:
                    # FP16_Optimizer applies loss scaling in its own backward.
                    optimizer.backward(loss)
                else:
                    loss.backward()
                optimizer.step()
                i_iter += 1
        scheduler.step()

        train_loss = np.mean(train_losses)
        train_iou = np.nanmean(train_ious)
        logger.info(f'train loss: {train_loss}')
        logger.info(f'train iou: {train_iou}')
        plotter.plot('loss-epoch', 'train', 'iteration Loss', i_epoch, train_loss)
        plotter.plot('iou-epoch', 'train', 'iteration iou', i_epoch, train_iou)
        torch.save(model.state_dict(), output_dir.joinpath('model_tmp.pth'))
        torch.save(optimizer.state_dict(), output_dir.joinpath('opt_tmp.pth'))

        valid_losses = []
        valid_ious = []
        model.eval()
        with torch.no_grad():
            with tqdm(valid_loader) as _tqdm:
                for batched in _tqdm:
                    images, labels = batched
                    if fp16:
                        images = images.half()
                    images, labels = images.to(device), labels.to(device)
                    preds = model.tta(images, net_type=net_type)
                    if fp16:
                        loss = loss_fn(preds.float(), labels)
                    else:
                        loss = loss_fn(preds, labels)
                    loss_value = loss.item()

                    preds_np = preds.detach().cpu().numpy()
                    labels_np = labels.detach().cpu().numpy()

                    # compute_iou was modified upstream so it does not
                    # yield NaNs here, hence the plain mean below.
                    iou = compute_iou_batch(np.argmax(preds_np, axis=1), labels_np, classes)

                    _tqdm.set_postfix(OrderedDict(seg_loss=f'{loss_value:.5f}', iou=f'{iou:.3f}'))
                    valid_losses.append(loss_value)
                    valid_ious.append(iou)

        valid_loss = np.mean(valid_losses)
        valid_iou = np.mean(valid_ious)
        logger.info(f'valid seg loss: {valid_loss}')
        logger.info(f'valid iou: {valid_iou}')
        plotter.plot('loss-epoch', 'valid', 'iteration Loss', i_epoch, valid_loss)
        plotter.plot('iou-epoch', 'valid', 'iteration iou', i_epoch, valid_iou)
        if best_metrics < valid_iou:
            best_metrics = valid_iou
            logger.info('Best Model!')
            torch.save(model.state_dict(), output_dir.joinpath('model.pth'))
            torch.save(optimizer.state_dict(), output_dir.joinpath('opt.pth'))

        loss_history.append([train_loss, valid_loss])
        iou_history.append([train_iou, valid_iou])
        history_ploter(loss_history, log_dir.joinpath('loss.png'))
        history_ploter(iou_history, log_dir.joinpath('iou.png'))

        # Persist the history after every epoch so a run can be inspected
        # (or resumed) from the last completed epoch.
        history_dict = {'loss': loss_history,
                        'iou': iou_history,
                        'best_metrics': best_metrics}
        with open(log_dir.joinpath('history.pkl'), 'wb') as f:
            pickle.dump(history_dict, f)