Ejemplo n.º 1
0
    if net_name == 'unet':
        model = UNet(n_channels=3, n_classes=2)
    else:
        model = HNNNet(pretrained=True, class_number=2)

    if config.D_MULTIPLY:
        dnet = DNet(input_dim=3, output_dim=1, input_size=config.PATCH_SIZE)
    else:
        dnet = DNet(input_dim=4, output_dim=1, input_size=config.PATCH_SIZE)

    g_optimizer = optim.SGD(model.parameters(),
                            lr=config.G_LEARNING_RATE,
                            momentum=0.9,
                            weight_decay=0.0005)
    d_optimizer = optim.SGD(dnet.parameters(),
                            lr=config.D_LEARNING_RATE,
                            momentum=0.9,
                            weight_decay=0.0005)
    resume = config.RESUME_MODEL
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume)
            start_epoch = checkpoint['epoch'] + 1
            start_step = checkpoint['step']
            try:
                model.load_state_dict(checkpoint['state_dict'])
                g_optimizer.load_state_dict(checkpoint['optimizer'])
            except:
                model.load_state_dict(checkpoint['g_state_dict'])
Ejemplo n.º 2
0
                RandomRotation(rotation_angle),
                RandomCrop(image_size),
                #ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
            ]))
        eval_dataset = IDRIDDataset(eval_image_paths,
                                    eval_mask_paths,
                                    config.CLASS_ID,
                                    transform=Compose([
                                        RandomCrop(image_size),
                                        Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225]),
                                    ]))
    train_loader = DataLoader(train_dataset, batchsize, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batchsize, shuffle=False)

    params = list(model.parameters())
    if dnet:
        params += list(dnet.parameters())
    optimizer = optim.SGD(params,
                          lr=config.LEARNING_RATE,
                          momentum=0.9,
                          weight_decay=0.0005)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.9)
    criterion = nn.CrossEntropyLoss(
        weight=torch.FloatTensor(config.CROSSENTROPY_WEIGHTS).to(device))

    train_model(model, train_loader, eval_loader, criterion, optimizer, scheduler, batchsize, \
            num_epochs=config.EPOCHES, start_epoch=start_epoch, start_step=start_step, dnet=dnet)