Example #1
    # Optional debugging helpers (left commented out): print a torchsummary
    # model summary and drop into pdb.
    # net.cuda()
    # import pdb
    # from torchsummary import summary
    # summary(net, (3, 1000, 1000))
    # pdb.set_trace()
    if args.load:
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
        net.cuda()

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu,
                  img_scale=args.scale)
        torch.save(net.state_dict(), 'model_fin.pth')

    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'interrupt.pth')
        print('saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
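
The snippet above reads its hyperparameters from an args object produced by an argument parser that is not shown. Below is a minimal sketch of such a parser, with option names inferred from the attributes used above (epochs, batchsize, lr, gpu, scale, load); the defaults are illustrative assumptions, not the original script's values.

import argparse

def get_args():
    # Hypothetical parser; option names mirror the args.* attributes used in
    # Example #1, defaults are assumptions for illustration only.
    parser = argparse.ArgumentParser(description='Train the network')
    parser.add_argument('--epochs', type=int, default=5,
                        help='number of training epochs')
    parser.add_argument('--batchsize', type=int, default=1,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--gpu', action='store_true',
                        help='train on the GPU')
    parser.add_argument('--scale', type=float, default=0.5,
                        help='downscaling factor applied to the input images')
    parser.add_argument('--load', type=str, default=None,
                        help='path to a .pth file to load before training')
    return parser.parse_args()
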
Example #2
        f'Network:\n'
        f'\t{net.n_channels} input channels\n'
        f'\t{net.n_classes} output channels (classes)\n'
        f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling')

    if model_path is not None:
        net.load_state_dict(torch.load(model_path, map_location=device))
        logging.info(f'Model loaded from {model_path}')

    if not os.path.exists(dir_checkpoint):
        os.mkdir(dir_checkpoint)

    net.to(device=device)

    try:
        train_net(net=net,
                  start_epoch=start_epoch,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
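
Example #2 begins partway through the setup, so the logging configuration and the names device, dir_checkpoint, model_path, and start_epoch are defined earlier in the script and not shown. The following is a minimal sketch of that earlier setup, assuming plain INFO-level logging and a single CUDA-or-CPU device; the concrete values are assumptions.

import logging
import os
import torch

# Hypothetical setup preceding Example #2; the variable names mirror those
# used in the snippet, the values are illustrative assumptions.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

dir_checkpoint = 'checkpoints/'
model_path = None   # set to a .pth checkpoint path to resume training
start_epoch = 0
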
Example #3
def main(args):
    presentParameters(vars(args))
    results_path = args.results
    if not os.path.exists(results_path):
        os.makedirs(results_path)

    save_args(args, modelpath=results_path)

    device = torch.device(args.device)
    if args.model == 'u-net':
        from unet.model import UNet
        model = UNet(in_channels=3, n_classes=1).to(device)
    elif args.model == 'fcd-net':
        from tiramisu.model import FCDenseNet
        # choose a model configuration small enough to train on a 16 GB GPU
        model = FCDenseNet(in_channels=3,
                           n_classes=1,
                           n_filter_first_conv=48,
                           n_pool=4,
                           growth_rate=8,
                           n_layers_per_block=3,
                           dropout_p=0.2).to(device)
    else:
        raise ValueError(
            'Invalid model argument "{}". Possible choices are "u-net" or "fcd-net".'
            .format(args.model))

    # Initialize the model weights
    model = model.apply(weights_init)

    transforms = my_transforms(scale=args.aug_scale,
                               angle=args.aug_angle,
                               flip_prob=args.aug_flip)
    print('Trainable parameters for model {}: {}'.format(
        args.model, get_number_params(model)))

    # create pytorch dataset
    dataset = DataSetfromNumpy(
        image_npy_path='data/train_img_{}x{}.npy'.format(
            args.image_size, args.image_size),
        mask_npy_path='data/train_mask_{}x{}.npy'.format(
            args.image_size, args.image_size),
        transform=transforms)

    # create training and validation set
    n_val = int(len(dataset) * args.val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])

    # Hacky workaround: apply only the CustomToTensor transform (no augmentation)
    # to the validation split. Note that random_split returns Subsets sharing the
    # same underlying dataset, so this assignment also affects the training split.
    from utils.transform import CustomToTensor
    val.dataset.transform = CustomToTensor()

    print('Training the model with n_train: {} and n_val: {} images/masks'.
          format(n_train, n_val))
    train_loader = DataLoader(train,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers)
    val_loader = DataLoader(val,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    dc_loss = DiceLoss()
    writer = SummaryWriter(log_dir=os.path.join(args.logs, args.model))
    optimizer = Adam(params=model.parameters(), lr=args.lr)
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.9,
                                                           patience=5)

    loss_train = []
    loss_valid = []

    # training loop:
    global_step = 0
    for epoch in range(args.epochs):
        eval_count = 0
        epoch_start_time = datetime.datetime.now().replace(microsecond=0)
        # set model into train mode
        model = model.train()
        train_epoch_loss = 0
        valid_epoch_loss = 0
        # tqdm progress bar
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{args.epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                # retrieve images and masks and send to pytorch device
                imgs = batch['image'].to(device=device, dtype=torch.float32)
                true_masks = batch['mask'].to(
                    device=device,
                    dtype=torch.float32
                    if model.n_classes == 1 else torch.long)

                # compute prediction masks
                predicted_masks = model(imgs)
                if model.n_classes == 1:
                    predicted_masks = torch.sigmoid(predicted_masks)
                elif model.n_classes > 1:
                    predicted_masks = F.softmax(predicted_masks, dim=1)

                # compute dice loss
                loss = dc_loss(y_true=true_masks, y_pred=predicted_masks)
                train_epoch_loss += loss.item()
                # update model network weights
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # logging
                writer.add_scalar('Loss/train', loss.item(), global_step)
                # update progress bar
                pbar.update(imgs.shape[0])
                # Do evaluation every 25 training steps
                if global_step % 25 == 0:
                    eval_count += 1
                    val_loss = np.mean(
                        eval_net(model, val_loader, device, dc_loss))
                    valid_epoch_loss += val_loss
                    writer.add_scalar('Loss/validation', val_loss, global_step)
                    if model.n_classes > 1:
                        pbar.set_postfix(
                            **{
                                'Training CE loss (batch)': loss.item(),
                                'Validation CE (val set)': val_loss
                            })
                    else:
                        pbar.set_postfix(
                            **{
                                'Training dice loss (batch)': loss.item(),
                                'Validation dice loss (val set)': val_loss
                            })

                global_step += 1
                # save images as well as true + predicted masks into writer
                if global_step % args.vis_images == 0:
                    writer.add_images('images', imgs, global_step)
                    if model.n_classes == 1:
                        writer.add_images('masks/true', true_masks,
                                          global_step)
                        writer.add_images('masks/pred', predicted_masks > 0.5,
                                          global_step)

            # Average the training and validation losses over the epoch
            valid_epoch_loss /= eval_count
            train_epoch_loss /= len(train_loader)

            # Apply learning rate scheduler per epoch
            scheduler.step(valid_epoch_loss)
            # Save the model only when the validation loss is the best so far; on the first epoch, save unconditionally
            if epoch > 0:
                best_model_bool = [valid_epoch_loss < l for l in loss_valid]
                best_model_bool = np.all(best_model_bool)
            else:
                best_model_bool = True

            # append
            loss_train.append(train_epoch_loss)
            loss_valid.append(valid_epoch_loss)

            if best_model_bool:
                print(
                    '\nSaving model, optimizer, and scheduler state at epoch {} with best validation loss of {}'
                    .format(epoch, valid_epoch_loss))
                torch.save(obj={
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict(),
                },
                           f=results_path +
                           '/model_epoch-{}_val_loss-{}.pth'.format(
                               epoch, np.round(valid_epoch_loss, 4)))
                epoch_time_difference = datetime.datetime.now().replace(
                    microsecond=0) - epoch_start_time
                print('Epoch: {:3d} time execution: {}'.format(
                    epoch, epoch_time_difference))

    print(
        'Finished training the segmentation model.\nAll results can be found at: {}'
        .format(results_path))
    # save scalars dictionary as json file
    scalars = {'loss_train': loss_train, 'loss_valid': loss_valid}
    with open('{}/all_scalars.json'.format(results_path), 'w') as fp:
        json.dump(scalars, fp)

    print('Logging file for tensorboard is stored at {}'.format(args.logs))
    writer.close()
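
Example #3 calls several helpers (weights_init, my_transforms, get_number_params, DataSetfromNumpy, DiceLoss, eval_net) that live elsewhere in the repository. As a reference, here is a minimal sketch of DiceLoss and eval_net inferred only from their call sites above: dc_loss(y_true=..., y_pred=...) returns a scalar loss tensor, and eval_net(model, val_loader, device, dc_loss) returns per-batch losses that np.mean can average. The actual implementations may differ.

import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    # Hypothetical soft Dice loss on probability masks: 1 - 2|A∩B| / (|A| + |B|).
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_true, y_pred):
        y_true = y_true.contiguous().view(-1)
        y_pred = y_pred.contiguous().view(-1)
        intersection = (y_true * y_pred).sum()
        dice = (2.0 * intersection + self.smooth) / (
            y_true.sum() + y_pred.sum() + self.smooth)
        return 1.0 - dice

def eval_net(model, loader, device, criterion):
    # Hypothetical evaluation helper: run the model over the validation loader
    # and return one loss value per batch.
    model.eval()
    losses = []
    with torch.no_grad():
        for batch in loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            true_masks = batch['mask'].to(device=device, dtype=torch.float32)
            preds = torch.sigmoid(model(imgs))
            losses.append(criterion(y_true=true_masks, y_pred=preds).item())
    model.train()
    return losses
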