def get_generator(model_config):
    """Construct the generator named by ``model_config['g_name']``.

    The built network is returned wrapped in ``nn.DataParallel``.

    NOTE(review): a second ``get_generator`` is defined immediately after
    this one in the same file (it also accepts ``'mirnet'``). That later
    definition rebinds the name, so this version appears to be dead code.
    Consider deleting one of them.

    Args:
        model_config: mapping that must contain ``'g_name'``. Depending on
            the chosen architecture it is also read for ``'norm_layer'``,
            ``'dropout'``, ``'blocks'``, ``'learn_residual'`` and/or
            ``'pretrained'``. Only the keys used by the selected branch
            are accessed.

    Returns:
        ``nn.DataParallel`` wrapping the constructed generator.

    Raises:
        KeyError: if ``'g_name'`` (or a key the chosen branch needs) is
            missing.
        ValueError: if ``'g_name'`` matches no known architecture.
    """
    requested = model_config['g_name']

    def norm():
        # Called only by architectures that take a norm layer, so configs
        # for e.g. 'fpn_dense' need not carry a 'norm_layer' key.
        return get_norm_layer(norm_type=model_config['norm_layer'])

    # Builders are lambdas so no class name is resolved and no config key
    # is read unless that architecture is actually selected.
    builders = (
        ('resnet', lambda: ResnetGenerator(
            norm_layer=norm(),
            use_dropout=model_config['dropout'],
            n_blocks=model_config['blocks'],
            learn_residual=model_config['learn_residual'])),
        ('fpn_mobilenet', lambda: FPNMobileNet(norm_layer=norm())),
        ('fpn_inception', lambda: FPNInception(norm_layer=norm())),
        ('fpn_inception_simple',
         lambda: FPNInceptionSimple(norm_layer=norm())),
        ('fpn_dense', lambda: FPNDense()),
        ('unet_seresnext', lambda: UNetSEResNext(
            norm_layer=norm(),
            pretrained=model_config['pretrained'])),
    )

    for candidate, build in builders:
        if requested == candidate:
            return nn.DataParallel(build())
    raise ValueError("Generator Network [%s] not recognized." % requested)
def get_generator(model_config):
    """Construct the generator named by ``model_config['g_name']``.

    The built network is returned wrapped in ``nn.DataParallel``.

    Args:
        model_config: mapping that must contain ``'g_name'``. Depending on
            the chosen architecture it is also read for ``'norm_layer'``,
            ``'dropout'``, ``'blocks'``, ``'learn_residual'`` and/or
            ``'pretrained'``. Only the keys used by the selected branch
            are accessed. ``'mirnet'`` and ``'fpn_dense'`` read no extra
            keys.

    Returns:
        ``nn.DataParallel`` wrapping the constructed generator.

    Raises:
        KeyError: if ``'g_name'`` (or a key the chosen branch needs) is
            missing.
        ValueError: if ``'g_name'`` matches no known architecture.
    """
    requested = model_config['g_name']

    def norm():
        # Called only by architectures that take a norm layer, so configs
        # for e.g. 'fpn_dense' or 'mirnet' need not carry 'norm_layer'.
        return get_norm_layer(norm_type=model_config['norm_layer'])

    # Builders are lambdas so no class name is resolved and no config key
    # is read unless that architecture is actually selected.
    builders = (
        ('resnet', lambda: ResnetGenerator(
            norm_layer=norm(),
            use_dropout=model_config['dropout'],
            n_blocks=model_config['blocks'],
            learn_residual=model_config['learn_residual'])),
        ('fpn_mobilenet', lambda: FPNMobileNet(norm_layer=norm())),
        ('fpn_inception', lambda: FPNInception(norm_layer=norm())),
        ('fpn_inception_simple',
         lambda: FPNInceptionSimple(norm_layer=norm())),
        ('fpn_dense', lambda: FPNDense()),
        ('unet_seresnext', lambda: UNetSEResNext(
            norm_layer=norm(),
            pretrained=model_config['pretrained'])),
        # MIRNet hyperparameters are fixed here, not taken from the config.
        ('mirnet', lambda: MIRNet(
            in_channels=3, out_channels=3, n_feat=32, kernel_size=3,
            stride=2, n_RRG=3, n_MSRB=2, height=3, width=2, bias=False)),
    )

    for candidate, build in builders:
        if requested == candidate:
            return nn.DataParallel(build())
    raise ValueError("Generator Network [%s] not recognized." % requested)