Example #1
0
def main():
    """Train a UNet with a binary Dice loss and save a final checkpoint.

    Relies on module-level settings that are not visible here (``device``,
    ``lr``, ``train_batch_size``, ``valid_batch_size``, ``num_workers``,
    ``epochs``) -- presumably defined elsewhere in the file; confirm.

    Each epoch runs ``train`` on the shuffled 'train' split, then ``valid`` on
    the unshuffled 'valid' split. After the last epoch, the model and
    optimizer state dicts are written to ``pretrained/unet.pth.tar`` via
    ``save_checkpoint`` (always with ``is_best=False``).

    NOTE(review): a second ``def main()`` appears later in this file and
    rebinds the name ``main``, so this definition is shadowed once the module
    has finished loading.
    """
    model = UNet().to(device)
    criterion = BinaryDiceLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # pin_memory=True places batches in page-locked host memory, which speeds
    # up copies to a CUDA device.
    train_dataset = ImageDataset(split='train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=train_batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True)
    valid_dataset = ImageDataset(split='valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=valid_batch_size,
                                               num_workers=num_workers,
                                               shuffle=False,
                                               pin_memory=True)

    # Stays -1 when the loop below never runs (epochs <= 0). Without this,
    # building the checkpoint dict raises UnboundLocalError on `epoch`.
    epoch = -1
    for epoch in range(epochs):
        train(epoch, model, train_loader, criterion, optimizer)
        valid(epoch, model, valid_loader, criterion)

    # 'epoch' holds the 0-based index of the last completed epoch.
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        save_file='pretrained/unet.pth.tar',
        is_best=False)


def main():
    """Train a depth-5 UNet with an element-wise L1 loss on GPU and save it.

    Reads its configuration from the module-level ``parser`` and uses
    ``args.prefix``, ``args.cuda``, ``args.lr`` and ``args.epochs``. It then:

    1. trains for ``args.epochs`` epochs;
    2. runs ``validate`` every 40 epochs and once more after training, unless
       the final epoch was already validated inside the loop;
    3. saves the model weights and the parsed arguments under ``args.prefix``.

    Raises:
        RuntimeError: if ``args.cuda`` is falsy; this script only supports GPU.
    """
    global args
    # Validate every this many epochs (plus once after training if needed).
    validate_every = 40

    args = parser.parse_args()
    set_prefix(args.prefix, __file__)

    model = UNet(3, depth=5, in_channels=3)
    print(model)
    print('load unet with depth=5')
    if not args.cuda:
        raise RuntimeError('there is no gpu')
    model = DataParallel(model).cuda()

    # reduction='none' is the non-deprecated equivalent of reduce=False. The
    # loss is returned per element, so train() presumably reduces it itself
    # (TODO confirm).
    criterion = nn.L1Loss(reduction='none').cuda()
    print('use l1_loss')
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Let cuDNN benchmark and choose the fastest convolution algorithms; this
    # pays off when input sizes stay fixed between iterations.
    cudnn.benchmark = True

    data_loader = get_dataloader()

    since = time.time()
    print('-' * 10)
    last_validated = None
    for epoch in range(1, args.epochs + 1):
        train(data_loader, model, optimizer, criterion, epoch)
        if epoch % validate_every == 0:
            validate(model, epoch, data_loader)
            last_validated = epoch

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # Final validation. Skip it when the loop already validated the last epoch
    # (args.epochs is a multiple of validate_every): no training has happened
    # since, so a second pass would only duplicate work and output.
    if last_validated != args.epochs:
        validate(model, args.epochs, data_loader)

    # NOTE(review): the model is wrapped in DataParallel, so state_dict() keys
    # carry a 'module.' prefix. Loaders must wrap the same way or strip it.
    torch.save(model.state_dict(),
               add_prefix(args.prefix, 'identical_mapping.pkl'))
    # Record the run's argument settings next to the weights.
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))