def q_deblurGAN(img, theta, use_gpu=False, weights="./weights/fpn_inception.h5"):
    """Run an FPNInception generator on ``img`` stacked with its blurred field.

    Args:
        img: 3-D array; it is concatenated with the blurred field along
            axis 2 (presumably H x W x C -- confirm against callers).
        theta: Forwarded unchanged to ``blurred_field_cal``.
        use_gpu: When true, the input tensor and the model are moved to CUDA.
        weights: Path to a checkpoint whose ``"model"`` entry is a state dict.

    Returns:
        Whatever ``post_process`` returns for the raw network output.
    """
    blurred_field = blurred_field_cal(img, theta)

    # Stack image and blurred field on axis 2, normalize, then reorder the
    # axes to (2, 0, 1) and add a leading batch dimension.
    img_input = np.concatenate([img, blurred_field], axis=2)
    img_input = normal_img(img_input, mean=0.5, std=0.5)
    img_input = np.transpose(img_input, (2, 0, 1))
    img_input = torch.tensor(img_input, dtype=torch.float32).unsqueeze(dim=0)

    norm_layer = functools.partial(
        nn.InstanceNorm2d, affine=False, track_running_stats=True)
    model = FPNInception(norm_layer=norm_layer)
    # NOTE(review): train mode is kept exactly as in the original; it changes
    # how the norm layers compute statistics, so do not switch to eval()
    # without confirming the outputs still match.
    model.train()

    if use_gpu:
        img_input = img_input.cuda()
        model.cuda()

    # Deserialize onto the CPU so a checkpoint saved from CUDA tensors can be
    # read on a CPU-only host; load_state_dict copies the values into the
    # model's parameters on whatever device the model lives on.
    checkpoint = torch.load(weights, map_location="cpu")
    # Drop the 'module.' fragment that nn.DataParallel adds to parameter names.
    state_dict = {
        k.replace('module.', ''): v
        for k, v in checkpoint["model"].items()
    }
    model.load_state_dict(state_dict)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        out = model(img_input)

    return post_process(out)
def get_generator(model_config: dict):
    """Build the generator network selected by ``model_config['g_name']``.

    NOTE(review): another ``get_generator`` definition appears later in this
    source; if both live in the same module the later one replaces this one
    at import time -- confirm which is intended.

    Args:
        model_config: Mapping with at least ``'g_name'``. Depending on the
            selected generator, ``'norm_layer'``, ``'dropout'``, ``'blocks'``,
            ``'learn_residual'`` and ``'pretrained'`` are also read.

    Returns:
        The chosen generator wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: If ``g_name`` is not one of the supported generators.
        KeyError: If a key required by the selected generator is missing.
    """
    generator_name = model_config['g_name']

    def build_norm():
        # Resolved lazily so generators that take no norm layer
        # (e.g. 'fpn_dense') do not require the 'norm_layer' key.
        return get_norm_layer(norm_type=model_config['norm_layer'])

    if generator_name == 'resnet':
        model_g = ResnetGenerator(
            norm_layer=build_norm(),
            use_dropout=model_config['dropout'],
            n_blocks=model_config['blocks'],
            learn_residual=model_config['learn_residual'])
    elif generator_name == 'fpn_mobilenet':
        model_g = FPNMobileNet(norm_layer=build_norm())
    elif generator_name == 'fpn_inception':
        model_g = FPNInception(norm_layer=build_norm())
    elif generator_name == 'fpn_inception_simple':
        model_g = FPNInceptionSimple(norm_layer=build_norm())
    elif generator_name == 'fpn_dense':
        model_g = FPNDense()
    elif generator_name == 'unet_seresnext':
        model_g = UNetSEResNext(
            norm_layer=build_norm(),
            pretrained=model_config['pretrained'])
    else:
        # Wrap the operand in a tuple: with a bare operand, a tuple-valued
        # name would be unpacked by '%' and raise TypeError instead.
        raise ValueError(
            "Generator Network [%s] not recognized." % (generator_name,))

    return nn.DataParallel(model_g)
def get_generator(model_config: dict):
    """Build the generator network selected by ``model_config['g_name']``.

    Supports the ResNet, FPN and UNet-SEResNeXt generators plus ``'mirnet'``,
    which is constructed with fixed hyper-parameters.

    NOTE(review): an earlier ``get_generator`` definition is visible in this
    source; if both live in the same module this one replaces it at import
    time -- confirm the earlier one can be removed.

    Args:
        model_config: Mapping with at least ``'g_name'``. Depending on the
            selected generator, ``'norm_layer'``, ``'dropout'``, ``'blocks'``,
            ``'learn_residual'`` and ``'pretrained'`` are also read.

    Returns:
        The chosen generator wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: If ``g_name`` is not one of the supported generators.
        KeyError: If a key required by the selected generator is missing.
    """
    generator_name = model_config['g_name']

    def build_norm():
        # Resolved lazily so generators that take no norm layer
        # ('fpn_dense', 'mirnet') do not require the 'norm_layer' key.
        return get_norm_layer(norm_type=model_config['norm_layer'])

    if generator_name == 'resnet':
        model_g = ResnetGenerator(
            norm_layer=build_norm(),
            use_dropout=model_config['dropout'],
            n_blocks=model_config['blocks'],
            learn_residual=model_config['learn_residual'])
    elif generator_name == 'fpn_mobilenet':
        model_g = FPNMobileNet(norm_layer=build_norm())
    elif generator_name == 'fpn_inception':
        model_g = FPNInception(norm_layer=build_norm())
    elif generator_name == 'fpn_inception_simple':
        model_g = FPNInceptionSimple(norm_layer=build_norm())
    elif generator_name == 'fpn_dense':
        model_g = FPNDense()
    elif generator_name == 'unet_seresnext':
        model_g = UNetSEResNext(
            norm_layer=build_norm(),
            pretrained=model_config['pretrained'])
    elif generator_name == 'mirnet':
        model_g = MIRNet(
            in_channels=3,
            out_channels=3,
            n_feat=32,
            kernel_size=3,
            stride=2,
            n_RRG=3,
            n_MSRB=2,
            height=3,
            width=2,
            bias=False)
    else:
        # Wrap the operand in a tuple: with a bare operand, a tuple-valued
        # name would be unpacked by '%' and raise TypeError instead.
        raise ValueError(
            "Generator Network [%s] not recognized." % (generator_name,))

    return nn.DataParallel(model_g)