示例#1
0
            net = Road_Seg(n_channels=3, n_classes=3, bilinear=True)

    else:
        print(
            "Error: {} is not supported, please chooes a model in [unet/gcn]".
            format(args.model))
        exit()

    logging.info(
        f'Network:\n'
        f'\t{net.n_channels} input channels\n'
        f'\t{net.n_classes} output channels (classes)\n'
        f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  dataset=args.dataset,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
示例#2
0
)

ARGS = parser.parse_args()

log = Logger("PREDICT", ARGS.debug, ARGS.log_dir)

# Script entry point: compute predictions by running a trained model on the
# given input.
if __name__ == "__main__":

    if ARGS.checkpoint_path is not None:
        log.info("Restoring checkpoint from {} instead of using a model file.".
                 format(ARGS.checkpoint_path))
        checkpoint = torch.load(ARGS.checkpoint_path)
        model = UNet(1, 1, bilinear=False)
        model.load_state_dict(checkpoint["modelState"])
        log.warning(
            "Using default preprocessing options. Provide Model file if they are changed"
        )
        dataOpts = DefaultSpecDatasetOps
    else:
        if ARGS.jit_load:
            extra_files = {}
            extra_files['dataOpts'] = ''
            model = torch.jit.load(ARGS.model_path, _extra_files=extra_files)
            unetState = model.state_dict()
            dataOpts = eval(extra_files['dataOpts'])
            log.debug("Model successfully load via torch jit: " +
                      str(ARGS.model_path))
        else:
            model_dict = torch.load(ARGS.model_path)
示例#3
0
# Make everything parallelisable.
# NOTE: DataParallel exposes the wrapped model as `.module`, so every key in
# these models' state_dicts now starts with "module.". Because the loads
# below are strict, the checkpoints presumably were saved from
# DataParallel-wrapped models too — confirm against the training script.
unet = nn.DataParallel(unet)
segmenter = nn.DataParallel(segmenter)
domain_pred = nn.DataParallel(domain_pred)

if LOAD_PATH_UNET:
    print('Loading Weights')
    encoder_dict = unet.state_dict()
    # Deserialize the checkpoint once; the same object feeds both the key
    # report and the actual load.
    checkpoint = torch.load(LOAD_PATH_UNET)
    # Checkpoint entries whose key also exists in the model. This subset is
    # only used for the progress message below.
    pretrained_dict = {
        k: v
        for k, v in checkpoint.items() if k in encoder_dict
    }
    print('weights loaded encoder = ', len(pretrained_dict), '/',
          len(encoder_dict))
    # load_state_dict is strict by default: any missing or unexpected key
    # raises, so the full checkpoint (not the filtered subset) is loaded.
    unet.load_state_dict(checkpoint)
    del checkpoint

if LOAD_PATH_SEGMENTER:
    regressor_dict = segmenter.state_dict()
    # Deserialize the checkpoint once and reuse it for the report and the load.
    checkpoint = torch.load(LOAD_PATH_SEGMENTER)
    # Matching-key subset, used only for the progress message.
    pretrained_dict = {
        k: v
        for k, v in checkpoint.items() if k in regressor_dict
    }
    print('weights loaded regressor = ', len(pretrained_dict), '/',
          len(regressor_dict))
    # Strict load of the full checkpoint (mismatched keys raise).
    segmenter.load_state_dict(checkpoint)
    del checkpoint
if LOAD_PATH_DOMAIN:
    domain_dict = domain_pred.state_dict()
    pretrained_dict = torch.load(LOAD_PATH_DOMAIN)