Exemplo n.º 1
0
def q_deblurGAN(img,
                theta,
                use_gpu=False,
                weights="./weights/fpn_inception.h5"):
    """Deblur ``img`` with a pretrained FPN-Inception generator.

    The network input is ``img`` concatenated (along axis 2) with the blur
    field computed by ``blurred_field_cal(img, theta)``, normalized with
    mean 0.5 / std 0.5, moved to channels-first order and given a leading
    batch dimension.

    Args:
        img: image array; concatenated with the blur field along axis 2,
            so presumably H x W x C -- TODO confirm with callers.
        theta: forwarded unchanged to ``blurred_field_cal``.
        use_gpu: when true, the model and input are moved to CUDA.
        weights: path to a checkpoint loadable by ``torch.load`` whose
            ``"model"`` entry holds the generator's state dict.

    Returns:
        Whatever ``post_process`` returns for the network output.
    """
    blurred_field = blurred_field_cal(img, theta)
    img_input = np.concatenate([img, blurred_field], axis=2)
    img_input = normal_img(img_input, mean=0.5, std=0.5)
    # Move axis 2 (presumably channels) to the front, then add a batch axis.
    img_input = np.transpose(img_input, (2, 0, 1))
    img_input = torch.tensor(img_input, dtype=torch.float32)
    img_input = img_input.unsqueeze(dim=0)

    norm_layer = functools.partial(nn.InstanceNorm2d,
                                   affine=False,
                                   track_running_stats=True)
    model = FPNInception(norm_layer=norm_layer)
    # Kept in train mode as in the original: there, InstanceNorm2d normalizes
    # with the current input's statistics rather than its running statistics.
    # NOTE(review): presumably intentional -- confirm before switching to eval().
    model.train()
    if use_gpu:
        img_input = img_input.cuda()
        model.cuda()

    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host; load_state_dict then copies each tensor onto the model's device.
    checkpoint = torch.load(weights, map_location="cpu")
    prefix = "module."
    state_dict = {}
    for key, value in checkpoint["model"].items():
        # Strip only *leading* DataParallel prefixes; str.replace would also
        # mangle keys that merely contain "module." (e.g. "x.submodule.w").
        while key.startswith(prefix):
            key = key[len(prefix):]
        state_dict[key] = value
    model.load_state_dict(state_dict)

    # Inference only: no autograd graph is needed for the output.
    with torch.no_grad():
        out = model(img_input)
    out = post_process(out)
    return out
Exemplo n.º 2
0
def get_generator(model_config):
    """Construct the generator named by ``model_config['g_name']``.

    NOTE(review): ``get_generator`` is defined a second time further down in
    this file; that later definition (which also accepts ``'mirnet'``)
    replaces this one once the module body has run.

    Args:
        model_config: mapping indexed with ``[]``. ``'g_name'`` is always
            read; depending on the generator, ``'norm_layer'``,
            ``'dropout'``, ``'blocks'``, ``'learn_residual'`` and
            ``'pretrained'`` may also be read.

    Returns:
        The constructed generator wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: if ``g_name`` matches none of the supported names.
    """
    name = model_config['g_name']

    def norm():
        # Norm-layer factory shared by most generator variants.
        return get_norm_layer(norm_type=model_config['norm_layer'])

    # Ordered (name, builder) pairs. Builders are lambdas so that no
    # generator class is looked up or instantiated unless it is selected.
    candidates = (
        ('resnet', lambda: ResnetGenerator(
            norm_layer=norm(),
            use_dropout=model_config['dropout'],
            n_blocks=model_config['blocks'],
            learn_residual=model_config['learn_residual'])),
        ('fpn_mobilenet', lambda: FPNMobileNet(norm_layer=norm())),
        ('fpn_inception', lambda: FPNInception(norm_layer=norm())),
        ('fpn_inception_simple',
         lambda: FPNInceptionSimple(norm_layer=norm())),
        ('fpn_dense', lambda: FPNDense()),
        ('unet_seresnext', lambda: UNetSEResNext(
            norm_layer=norm(),
            pretrained=model_config['pretrained'])),
    )

    # Compare with == in order (not a dict lookup) so any value of g_name,
    # hashable or not, is matched exactly as an if/elif chain would.
    for candidate, build in candidates:
        if name == candidate:
            generator = build()
            return nn.DataParallel(generator)

    raise ValueError("Generator Network [%s] not recognized." %
                     name)
def get_generator(model_config):
    """Construct the generator named by ``model_config['g_name']``.

    Args:
        model_config: mapping indexed with ``[]``. ``'g_name'`` is always
            read; depending on the generator, ``'norm_layer'``,
            ``'dropout'``, ``'blocks'``, ``'learn_residual'`` and
            ``'pretrained'`` may also be read. The ``'mirnet'`` variant
            reads nothing else: its hyperparameters are fixed below.

    Returns:
        The constructed generator wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: if ``g_name`` matches none of the supported names.
    """
    name = model_config['g_name']

    def norm():
        # Norm-layer factory shared by most generator variants.
        return get_norm_layer(norm_type=model_config['norm_layer'])

    def mirnet():
        # Hard-coded MIRNet configuration; model_config is not consulted.
        return MIRNet(in_channels=3,
                      out_channels=3,
                      n_feat=32,
                      kernel_size=3,
                      stride=2,
                      n_RRG=3,
                      n_MSRB=2,
                      height=3,
                      width=2,
                      bias=False)

    # Ordered (name, builder) pairs. Builders are callables so that no
    # generator class is looked up or instantiated unless it is selected.
    candidates = (
        ('resnet', lambda: ResnetGenerator(
            norm_layer=norm(),
            use_dropout=model_config['dropout'],
            n_blocks=model_config['blocks'],
            learn_residual=model_config['learn_residual'])),
        ('fpn_mobilenet', lambda: FPNMobileNet(norm_layer=norm())),
        ('fpn_inception', lambda: FPNInception(norm_layer=norm())),
        ('fpn_inception_simple',
         lambda: FPNInceptionSimple(norm_layer=norm())),
        ('fpn_dense', lambda: FPNDense()),
        ('unet_seresnext', lambda: UNetSEResNext(
            norm_layer=norm(),
            pretrained=model_config['pretrained'])),
        ('mirnet', mirnet),
    )

    # Compare with == in order (not a dict lookup) so any value of g_name,
    # hashable or not, is matched exactly as an if/elif chain would.
    for candidate, build in candidates:
        if name == candidate:
            generator = build()
            return nn.DataParallel(generator)

    raise ValueError("Generator Network [%s] not recognized." %
                     name)