Example #1
def main():
    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default=None,
                        help='path to the data')
    parser.add_argument('-e',
                        '--epochs',
                        dest='epochs',
                        default=20,
                        type=int,
                        help='number of epochs')
    parser.add_argument('-b',
                        '--batch_size',
                        dest='batch_size',
                        default=40,
                        type=int,
                        help='batch size')
    parser.add_argument('-s',
                        '--image_size',
                        dest='image_size',
                        default=256,
                        type=int,
                        help='input image size')
    parser.add_argument('-lr',
                        '--learning_rate',
                        dest='lr',
                        default=0.0001,
                        type=float,
                        help='learning rate')
    parser.add_argument('-wd',
                        '--weight_decay',
                        dest='weight_decay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('-lrs',
                        '--learning_rate_step',
                        dest='lr_step',
                        default=10,
                        type=int,
                        help='learning rate step')
    parser.add_argument('-lrg',
                        '--learning_rate_gamma',
                        dest='lr_gamma',
                        default=0.5,
                        type=float,
                        help='learning rate gamma')
    parser.add_argument('-m',
                        '--model',
                        dest='model',
                        default='fpn')
    parser.add_argument('-w',
                        '--weight_bce',
                        default=0.5,
                        type=float,
                        help='weight BCE loss')
    parser.add_argument('-l',
                        '--load',
                        dest='load',
                        default=False,
                        help='path to a model file to load')
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        default=0.7,
                        type=float,
                        help='train fraction of the train/val split')
    parser.add_argument('-o',
                        '--output_dir',
                        dest='output_dir',
                        default='./output',
                        help='dir to save log and models')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)


    # net = UNet()  # TODO: use a more novel arch and/or more lightweight blocks (mobilenet) to enlarge the batch_size
    # net = smp.FPN('mobilenet_v2', encoder_weights='imagenet', classes=2)
    net = smp.FPN('se_resnet50', encoder_weights='imagenet', classes=2)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
    logger.info('Model type: {}'.format(net.__class__.__name__))

    net.to(device)

    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
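    # NOTE: the criterion returns the weighted BCE and Dice terms as a tuple;
    # presumably train() combines and/or logs the two components separately.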
    criterion = lambda x, y: (args.weight_bce * nn.BCELoss()(x, y),
                              (1. - args.weight_bce) * dice_loss(x, y))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) \
        if args.lr_step > 0 else None

    train_transforms = Compose([
        Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
        Flip(p=0.05),
        RandomRotate(),
        Pad(max_size=0.6, p=0.25),
        Resize(size=(args.image_size, args.image_size), keep_aspect=True),
        ScaleToZeroOne(),
    ])
    val_transforms = Compose([
        Resize(size=(args.image_size, args.image_size)),
        ScaleToZeroOne(),
    ])

    train_dataset = DetectionDataset(args.data_path,
                                     os.path.join(args.data_path,
                                                  'train_mask.json'),
                                     transforms=train_transforms)
    val_dataset = DetectionDataset(args.data_path,
                                   None,
                                   transforms=val_transforms)

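    # Split into train/val by slicing the same dataset's file lists:
    # the first train_size names stay in train, the remainder go to val.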
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.mask_names = train_dataset.mask_names[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.mask_names = train_dataset.mask_names[:train_size]
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True,
                                  drop_last=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False,
                                drop_last=False)
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader),
                len(val_dataloader))

    try:
        train(net,
              optimizer,
              criterion,
              scheduler,
              train_dataloader,
              val_dataloader,
              logger=logger,
              args=args,
              device=device)
    except KeyboardInterrupt:
        torch.save(
            net.state_dict(),
            os.path.join(args.output_dir, f'{args.model}_INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)
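
The example above shows only main(); ArgumentParser, os, sys, torch, nn, optim, DataLoader, smp, the transform classes, DetectionDataset, get_logger, dice_loss and train all come from imports in the full train.py that this page does not show. One assumed helper worth sketching is dice_loss; a minimal soft-Dice version, assuming the predictions are already probabilities in [0, 1], could look like this:

import torch

def dice_loss(pred, target, eps=1e-7):
    # Flatten everything except the batch dimension and return 1 - Dice coefficient.
    pred = pred.contiguous().view(pred.size(0), -1)
    target = target.contiguous().view(target.size(0), -1)
    intersection = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    return (1. - (2. * intersection + eps) / (union + eps)).mean()

This pairs with the criterion above, which weights the Dice term against nn.BCELoss by args.weight_bce.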
Example #2
File: train.py Project: veegaaa/MADE
def main():
    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default=None,
                        help='path to the data')
    parser.add_argument('-e',
                        '--epochs',
                        dest='epochs',
                        default=20,
                        type=int,
                        help='number of epochs')
    parser.add_argument('-b',
                        '--batch_size',
                        dest='batch_size',
                        default=40,
                        type=int,
                        help='batch size')
    parser.add_argument('-s',
                        '--image_size',
                        dest='image_size',
                        default=256,
                        type=int,
                        help='input image size')
    parser.add_argument('-lr',
                        '--learning_rate',
                        dest='lr',
                        default=0.0001,
                        type=float,
                        help='learning rate')
    parser.add_argument('-wd',
                        '--weight_decay',
                        dest='weight_decay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('-lrs',
                        '--learning_rate_step',
                        dest='lr_step',
                        default=10,
                        type=int,
                        help='learning rate step')
    parser.add_argument('-lrg',
                        '--learning_rate_gamma',
                        dest='lr_gamma',
                        default=0.5,
                        type=float,
                        help='learning rate gamma')
    parser.add_argument('-m',
                        '--model',
                        dest='model',
                        default='unet',
                        choices=('unet', ))
    parser.add_argument('-w',
                        '--weight_bce',
                        default=0.5,
                        type=float,
                        help='weight BCE loss')
    parser.add_argument('-l',
                        '--load',
                        dest='load',
                        default=False,
                        help='path to a model file to load')
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        default=0.8,
                        type=float,
                        help='train fraction of the train/val split')
    parser.add_argument('-o',
                        '--output_dir',
                        dest='output_dir',
                        default='/tmp/logs/',
                        help='dir to save log and models')
    args = parser.parse_args()
    #
    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)
    #
    net = UNet()  # TODO: use a more novel arch and/or more lightweight blocks (mobilenet) to enlarge the batch_size
    # TODO: img_size=256 is rather mediocre; try to optimize the network for at least 512
    logger.info('Model type: {}'.format(net.__class__.__name__))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
    net.to(device)
    # net = nn.DataParallel(net)

    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # TODO: loss experimentation, fight class imbalance; there are many ways to tackle this challenge
    criterion = lambda x, y: (args.weight_bce * nn.BCELoss()(x, y),
                              (1. - args.weight_bce) * dice_loss(x, y))
    # TODO: you can always try a plateau scheduler (ReduceLROnPlateau) as a default option
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) \
        if args.lr_step > 0 else None

    # dataset
    # TODO: work on the transformations a lot; look at the albumentations package for inspiration
    train_transforms = Compose([
        Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
        Flip(p=0.05),
        Pad(max_size=0.6, p=0.25),
        Resize(size=(args.image_size, args.image_size), keep_aspect=True)
    ])
    # TODO: don't forget to work on class imbalance and data cleansing
    val_transforms = Resize(size=(args.image_size, args.image_size))

    train_dataset = DetectionDataset(args.data_path,
                                     os.path.join(args.data_path,
                                                  'train_mask.json'),
                                     transforms=train_transforms)
    val_dataset = DetectionDataset(args.data_path,
                                   None,
                                   transforms=val_transforms)

    # split dataset into train/val, don't try to do this at home ;)
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.mask_names = train_dataset.mask_names[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.mask_names = train_dataset.mask_names[:train_size]

    # TODO: always work with the data: cleaning, sampling
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True,
                                  drop_last=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False,
                                drop_last=False)
    logger.info('Length of train/val=%d/%d', len(train_dataset),
                len(val_dataset))
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader),
                len(val_dataloader))

    try:
        train(net,
              optimizer,
              criterion,
              scheduler,
              train_dataloader,
              val_dataloader,
              logger=logger,
              args=args,
              device=device)
    except KeyboardInterrupt:
        torch.save(net.state_dict(),
                   os.path.join(args.output_dir, 'INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)
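
Both examples also rely on a get_logger helper that is not shown on this page. A minimal sketch, assuming it should write INFO-level messages to both the console and the given log file, might be:

import logging

def get_logger(log_path):
    # Send records to the console and to the given file (e.g. <output_dir>/train.log).
    logger = logging.getLogger('train')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    for handler in (logging.StreamHandler(), logging.FileHandler(log_path)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

With such helpers in place, either script could be launched along the lines of python train.py --data_path /path/to/data --epochs 20 --batch_size 40 --output_dir ./output, with the data path substituted for your own.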