def create_base_transform(i):
    """Build the i-th coupling layer of the base flow.

    The layer type is chosen by ``args.base_transform_type``: ``'rq'`` yields a
    ``PiecewiseRationalQuadraticCouplingTransform`` and ``'affine'`` yields an
    ``AffineCouplingTransform``. Both use an ``nn_.ResidualNet`` conditioner
    whose hyperparameters come from ``args``.

    Args:
        i: Index of the layer in the flow. Its parity is passed as ``even=`` to
            ``utils.create_alternating_binary_mask``, so the mask flips between
            consecutive layers.

    Returns:
        The constructed coupling transform.

    Raises:
        ValueError: If ``args.base_transform_type`` is neither ``'rq'`` nor
            ``'affine'``.
    """
    # NOTE(review): relies on module-level ``args`` and ``dim``; ``dim`` is
    # presumably the data dimensionality -- confirm at its definition site.
    transform_type = args.base_transform_type
    if transform_type not in ('rq', 'affine'):
        raise ValueError(
            "Unknown base_transform_type {!r}; expected 'rq' or 'affine'.".format(
                transform_type))

    mask = utils.create_alternating_binary_mask(features=dim, even=(i % 2 == 0))

    def create_resnet(in_features, out_features):
        # Conditioner factory; ``args`` is read each time this is invoked.
        return nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=args.hidden_features,
            num_blocks=args.num_transform_blocks,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm,
        )

    if transform_type == 'rq':
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_resnet,
            num_bins=args.num_bins,
            apply_unconditional_transform=False,
        )
    return transforms.AffineCouplingTransform(
        mask=mask,
        transform_net_create_fn=create_resnet,
    )
def create_base_transform(i):
    """Return the i-th coupling layer of the base flow.

    When ``args.base_transform_type`` equals ``'affine'`` an
    ``AffineCouplingTransform`` is built. Every other value falls through to a
    ``PiecewiseRationalQuadraticCouplingTransform`` with linear tails bounded
    at 5 and ``args.num_bins`` bins. Both variants use a fixed-size
    ``nn_.ResidualNet`` conditioner (32 hidden features, 2 blocks, batch norm).

    Args:
        i: Layer index; its parity is passed as ``even=`` to the alternating
            binary mask, so the mask flips between consecutive layers.

    Returns:
        The constructed coupling transform.
    """
    def make_residual_net(in_features, out_features):
        # Fixed-size conditioner, independent of ``args``.
        return nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=32,
            num_blocks=2,
            use_batch_norm=True,
        )

    use_affine = args.base_transform_type == 'affine'
    # NOTE(review): ``dim`` is a module-level name, presumably the data
    # dimensionality -- confirm at its definition site.
    alternating_mask = utils.create_alternating_binary_mask(
        features=dim, even=(i % 2 == 0))

    if use_affine:
        return transforms.AffineCouplingTransform(
            mask=alternating_mask,
            transform_net_create_fn=make_residual_net,
        )

    # Any non-'affine' value selects the rational-quadratic spline variant.
    return transforms.PiecewiseRationalQuadraticCouplingTransform(
        mask=alternating_mask,
        transform_net_create_fn=make_residual_net,
        tails='linear',
        tail_bound=5,
        num_bins=args.num_bins,
        apply_unconditional_transform=False,
    )
def create_base_transform(i, context_features=None):
    """Build the i-th transform of the prior flow selected by ``args.prior_type``.

    Supported values of ``args.prior_type``:

    * ``'affine-coupling'``: ``AffineCouplingTransform`` with a ResidualNet
      conditioner.
    * ``'rq-coupling'``: ``PiecewiseRationalQuadraticCouplingTransform`` with
      linear tails bounded by ``args.tail_bound``.
    * ``'affine-autoregressive'``: ``MaskedAffineAutoregressiveTransform``.
    * ``'rq-autoregressive'``: ``MaskedPiecewiseRationalQuadraticAutoregressiveTransform``
      with linear tails bounded by ``args.tail_bound``.

    All variants use ReLU activations and residual blocks configured from
    ``args`` (hidden features, block count, dropout, batch norm).

    Args:
        i: Index of the transform in the flow. For the coupling variants its
            parity is passed as ``even=`` to the alternating binary mask, so
            the mask flips between consecutive layers. Unused by the
            autoregressive variants.
        context_features: Forwarded unchanged as ``context_features`` to the
            conditioner network / autoregressive transform. ``None`` presumably
            means unconditional -- confirm against the nflows API.

    Returns:
        The constructed transform.

    Raises:
        ValueError: If ``args.prior_type`` is not one of the four values above.
    """
    prior_type = args.prior_type

    def create_mask():
        # Only built by the coupling variants, which need it.
        return utils.create_alternating_binary_mask(
            features=args.latent_features, even=(i % 2 == 0))

    def create_resnet(in_features, out_features):
        # Conditioner factory for the coupling variants; ``args`` is read each
        # time this is invoked.
        return nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=args.hidden_features,
            context_features=context_features,
            num_blocks=args.num_transform_blocks,
            activation=F.relu,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm,
        )

    if prior_type == 'affine-coupling':
        return transforms.AffineCouplingTransform(
            mask=create_mask(),
            transform_net_create_fn=create_resnet,
        )
    if prior_type == 'rq-coupling':
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=create_mask(),
            transform_net_create_fn=create_resnet,
            num_bins=args.num_bins,
            tails='linear',
            tail_bound=args.tail_bound,
            apply_unconditional_transform=args.apply_unconditional_transform,
        )
    if prior_type == 'affine-autoregressive':
        return transforms.MaskedAffineAutoregressiveTransform(
            features=args.latent_features,
            hidden_features=args.hidden_features,
            context_features=context_features,
            num_blocks=args.num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm,
        )
    if prior_type == 'rq-autoregressive':
        return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
            features=args.latent_features,
            hidden_features=args.hidden_features,
            context_features=context_features,
            num_bins=args.num_bins,
            tails='linear',
            tail_bound=args.tail_bound,
            num_blocks=args.num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm,
        )
    raise ValueError(
        "Unknown prior_type {!r}; expected one of 'affine-coupling', "
        "'rq-coupling', 'affine-autoregressive', 'rq-autoregressive'.".format(
            prior_type))
def create_transform_step(num_channels, hidden_channels, actnorm,
                          coupling_layer_type, spline_params, use_resnet,
                          num_res_blocks, resnet_batchnorm, dropout_prob):
    """Build one flow step: optional ActNorm, a 1x1 convolution, then a coupling layer.

    Args:
        num_channels: Number of channels the step operates on. Used for the
            mid-split binary mask, ActNorm and the 1x1 convolution.
        hidden_channels: Hidden width of the conditioner network.
        actnorm: If true, an ``ActNorm`` layer is prepended to the step.
        coupling_layer_type: One of ``'cubic_spline'``, ``'quadratic_spline'``,
            ``'rational_quadratic_spline'``, ``'affine'``, ``'additive'``.
        spline_params: Mapping read only by the spline variants. Keys used:
            ``'tail_bound'``, ``'num_bins'``, ``'apply_unconditional_transform'``,
            ``'min_bin_width'``, ``'min_bin_height'``, and (rational-quadratic
            only) ``'min_derivative'``.
        use_resnet: If true the conditioner is ``nn_.ConvResidualNet``;
            otherwise a plain ``ConvNet``.
        num_res_blocks: Number of residual blocks (resnet conditioner only).
        resnet_batchnorm: Whether the resnet conditioner uses batch norm.
        dropout_prob: Dropout probability for the resnet conditioner. Must be 0
            when ``use_resnet`` is false.

    Returns:
        A ``transforms.CompositeTransform`` of the step's layers.

    Raises:
        ValueError: If ``dropout_prob`` is non-zero while ``use_resnet`` is false.
        RuntimeError: If ``coupling_layer_type`` is not recognised.
    """
    if use_resnet:
        def create_convnet(in_channels, out_channels):
            return nn_.ConvResidualNet(
                in_channels=in_channels,
                out_channels=out_channels,
                hidden_channels=hidden_channels,
                num_blocks=num_res_blocks,
                use_batch_norm=resnet_batchnorm,
                dropout_probability=dropout_prob,
            )
    else:
        if dropout_prob != 0.:
            # The plain ConvNet below is constructed without any dropout
            # argument, so a non-zero value would be silently ignored.
            raise ValueError(
                'dropout_prob must be 0 when use_resnet is False '
                '(got {!r}).'.format(dropout_prob))

        def create_convnet(in_channels, out_channels):
            return ConvNet(in_channels, hidden_channels, out_channels)

    mask = utils.create_mid_split_binary_mask(num_channels)

    if coupling_layer_type in ('cubic_spline', 'quadratic_spline',
                               'rational_quadratic_spline'):
        # Arguments shared by all three spline coupling variants.
        spline_kwargs = dict(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails='linear',
            tail_bound=spline_params['tail_bound'],
            num_bins=spline_params['num_bins'],
            apply_unconditional_transform=spline_params[
                'apply_unconditional_transform'],
            min_bin_width=spline_params['min_bin_width'],
            min_bin_height=spline_params['min_bin_height'],
        )
        if coupling_layer_type == 'cubic_spline':
            coupling_layer = transforms.PiecewiseCubicCouplingTransform(
                **spline_kwargs)
        elif coupling_layer_type == 'quadratic_spline':
            coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
                **spline_kwargs)
        else:
            coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
                min_derivative=spline_params['min_derivative'],
                **spline_kwargs)
    elif coupling_layer_type == 'affine':
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    elif coupling_layer_type == 'additive':
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    else:
        raise RuntimeError(
            'Unknown coupling_layer_type: {!r}'.format(coupling_layer_type))

    # Construction order (coupling layer, then ActNorm, then 1x1 conv) matches
    # the original so any parameter initialisation consumes the RNG identically.
    step_transforms = []
    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))
    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])
    return transforms.CompositeTransform(step_transforms)