Example #1
def _is_in(self, param_group, param_group_list):
    """Check whether `param_group` shares any parameter with the groups
    in `param_group_list`."""
    assert is_list_of(param_group_list, dict)
    param = set(param_group['params'])
    param_set = set()
    for group in param_group_list:
        param_set.update(set(group['params']))
    # an overlap exists iff the two sets are not disjoint
    return not param.isdisjoint(param_set)
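A minimal standalone sketch of the same disjointness check, using hypothetical stand-in objects in place of real `nn.Parameter` instances (the names here are illustrative, not part of the snippet above):

# illustrative stand-ins for nn.Parameter objects
w1, w2, w3 = object(), object(), object()
group = {'params': [w1, w2]}
group_list = [{'params': [w2]}, {'params': [w3]}]

param = set(group['params'])
param_set = set()
for g in group_list:
    param_set.update(set(g['params']))

# w2 appears in both, so the sets are not disjoint
print(not param.isdisjoint(param_set))  # True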
Example #2
    def __init__(self,
                 output_scale,
                 num_classes=0,
                 base_channels=64,
                 out_channels=3,
                 input_scale=4,
                 noise_size=128,
                 attention_cfg=dict(type='SelfAttentionBlock'),
                 attention_after_nth_block=0,
                 channels_cfg=None,
                 blocks_cfg=dict(type='SNGANGenResBlock'),
                 act_cfg=dict(type='ReLU'),
                 use_cbn=True,
                 auto_sync_bn=True,
                 with_spectral_norm=False,
                 with_embedding_spectral_norm=None,
                 norm_eps=1e-4,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN'),
                 pretrained=None):

        super().__init__()

        self.input_scale = input_scale
        self.output_scale = output_scale
        self.noise_size = noise_size
        self.num_classes = num_classes
        self.init_type = init_cfg.get('type', None)

        self.blocks_cfg = deepcopy(blocks_cfg)

        self.blocks_cfg.setdefault('num_classes', num_classes)
        self.blocks_cfg.setdefault('act_cfg', act_cfg)
        self.blocks_cfg.setdefault('use_cbn', use_cbn)
        self.blocks_cfg.setdefault('auto_sync_bn', auto_sync_bn)
        self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)

        # set `with_embedding_spectral_norm` as `with_spectral_norm` if
        # not defined
        with_embedding_spectral_norm = with_embedding_spectral_norm \
            if with_embedding_spectral_norm is not None else with_spectral_norm
        self.blocks_cfg.setdefault('with_embedding_spectral_norm',
                                   with_embedding_spectral_norm)
        self.blocks_cfg.setdefault('init_cfg', init_cfg)
        self.blocks_cfg.setdefault('norm_eps', norm_eps)
        self.blocks_cfg.setdefault('sn_eps', sn_eps)

        channels_cfg = deepcopy(self._default_channels_cfg) \
            if channels_cfg is None else deepcopy(channels_cfg)
        if isinstance(channels_cfg, dict):
            if output_scale not in channels_cfg:
                raise KeyError(f'`output_scale={output_scale}` is not found '
                               'in `channels_cfg`, only support configs for '
                               f'{list(channels_cfg.keys())}')
            self.channel_factor_list = channels_cfg[output_scale]
        elif isinstance(channels_cfg, list):
            self.channel_factor_list = channels_cfg
        else:
            raise ValueError('Only support list or dict for `channels_cfg`, '
                             f'but received {type(channels_cfg)}')

        # project the noise vector to the initial
        # `input_scale x input_scale` feature map
        self.noise2feat = nn.Linear(
            noise_size,
            input_scale**2 * base_channels * self.channel_factor_list[0])
        if with_spectral_norm:
            self.noise2feat = spectral_norm(self.noise2feat)

        # check `attention_after_nth_block`
        if not isinstance(attention_after_nth_block, list):
            attention_after_nth_block = [attention_after_nth_block]
        if not is_list_of(attention_after_nth_block, int):
        raise ValueError('`attention_after_nth_block` only supports an int '
                         'or a list of ints. Please check your input type.')

        self.conv_blocks = nn.ModuleList()
        self.attention_block_idx = []
        for idx in range(len(self.channel_factor_list)):
            factor_input = self.channel_factor_list[idx]
            factor_output = self.channel_factor_list[idx+1] \
                if idx < len(self.channel_factor_list)-1 else 1

            # get block-specific config
            block_cfg_ = deepcopy(self.blocks_cfg)
            block_cfg_['in_channels'] = factor_input * base_channels
            block_cfg_['out_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(block_cfg_))

            # build self-attention block
            # `idx` starts from 0, so add 1 to get the 1-based block index
            if idx + 1 in attention_after_nth_block:
                self.attention_block_idx.append(len(self.conv_blocks))
                attn_cfg_ = deepcopy(attention_cfg)
                attn_cfg_['in_channels'] = factor_output * base_channels
                self.conv_blocks.append(build_module(attn_cfg_))

        to_rgb_norm_cfg = dict(type='BN', eps=norm_eps)
        if check_dist_init() and auto_sync_bn:
            to_rgb_norm_cfg['type'] = 'SyncBN'

        self.to_rgb = ConvModule(factor_output * base_channels,
                                 out_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1,
                                 bias=True,
                                 norm_cfg=to_rgb_norm_cfg,
                                 act_cfg=act_cfg,
                                 order=('norm', 'act', 'conv'),
                                 with_spectral_norm=with_spectral_norm)
        self.final_act = build_activation_layer(dict(type='Tanh'))

        self.init_weights(pretrained)
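For context, the `channels_cfg` branch above is what selects the per-block channel multipliers. A standalone sketch of that resolution logic, with a hypothetical default table (the values are illustrative, not the class's actual `_default_channels_cfg`):

from copy import deepcopy

# hypothetical defaults: channel multipliers keyed by output scale
DEFAULT_CHANNELS_CFG = {
    32: [1, 1, 1],
    64: [16, 8, 4, 2],
    128: [16, 16, 8, 4, 2],
}

def resolve_channel_factors(output_scale, channels_cfg=None):
    cfg = deepcopy(DEFAULT_CHANNELS_CFG) \
        if channels_cfg is None else deepcopy(channels_cfg)
    if isinstance(cfg, dict):
        if output_scale not in cfg:
            raise KeyError(f'`output_scale={output_scale}` is not found in '
                           f'`channels_cfg`, only support configs for '
                           f'{list(cfg)}')
        return cfg[output_scale]
    if isinstance(cfg, list):
        return cfg
    raise ValueError('Only support list or dict for `channels_cfg`, '
                     f'but received {type(cfg)}')

print(resolve_channel_factors(64))             # [16, 8, 4, 2]
print(resolve_channel_factors(64, [8, 4, 2]))  # an explicit list wins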
Example #3
import torch
from mmcv.utils import is_list_of


def sample_conditional_model(model,
                             num_samples=16,
                             num_batches=4,
                             sample_model='ema',
                             label=None,
                             **kwargs):
    """Sampling from conditional models.

    Args:
        model (nn.Module): Conditional models in MMGeneration.
        num_samples (int, optional): The total number of samples.
            Defaults to 16.
        num_batches (int, optional): The batch size used for each inference
            step. Defaults to 4.
        sample_model (str, optional): Which model to sample from, one of
            ['ema', 'orig']. Defaults to 'ema'.
        label (int | torch.Tensor | list[int], optional): Labels used to
            generate images. Defaults to None.

    Returns:
        Tensor: Generated image tensor.
    """
    # set eval mode
    model.eval()
    # construct sampling list for batches
    n_repeat = num_samples // num_batches
    batches_list = [num_batches] * n_repeat

    # check and convert the input labels
    if isinstance(label, int):
        label = torch.LongTensor([label] * num_samples)
    elif isinstance(label, torch.Tensor):
        label = label.type(torch.int64)
        if label.numel() == 1:
            # repeat single tensor
            # call view(-1) to avoid nested tensor like [[[1]]]
            label = label.view(-1).repeat(num_samples)
        else:
            # flatten multi tensors
            label = label.view(-1)
    elif isinstance(label, list):
        if is_list_of(label, int):
            label = torch.LongTensor(label)
            # `nargs='+'` parses a single integer as a list
            if label.numel() == 1:
                # repeat single tensor
                label = label.repeat(num_samples)
        else:
            raise TypeError('Only support `int` for label list elements, '
                            f'but received {type(label[0])}.')
    elif label is None:
        pass
    else:
        raise TypeError('Only support `int`, `torch.Tensor`, `list[int]` or '
                        f'None as label, but received {type(label)}.')

    # check the length of the (converted) label
    if label is not None and label.size(0) != num_samples:
        raise ValueError('Number of elements in the label should be ONE '
                         'or equal to `num_samples`. Expected '
                         f'{num_samples}, but received {label.size(0)}.')

    # make label list
    label_list = []
    for n in range(n_repeat):
        if label is None:
            label_list.append(None)
        else:
            label_list.append(label[n * num_batches:(n + 1) * num_batches])

    # handle the remainder batch when `num_samples` is not divisible
    if num_samples % num_batches > 0:
        batches_list.append(num_samples % num_batches)
        if label is None:
            label_list.append(None)
        else:
            # slice from `n_repeat * num_batches` so this also works when
            # the loop above did not run (`num_samples < num_batches`)
            label_list.append(label[n_repeat * num_batches:])

    res_list = []

    # inference
    for batches, labels in zip(batches_list, label_list):
        res = model.sample_from_noise(None,
                                      num_batches=batches,
                                      label=labels,
                                      sample_model=sample_model,
                                      **kwargs)
        res_list.append(res.cpu())

    results = torch.cat(res_list, dim=0)

    return results
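A hedged usage sketch: this assumes a conditional model loaded through MMGeneration's `init_model` API; the config and checkpoint paths below are placeholders, and the output shape depends on the model:

from mmgen.apis import init_model

# placeholder paths: substitute a real conditional-GAN config/checkpoint
model = init_model('configs/sngan_proj/sngan-proj_cifar10.py',
                   'checkpoints/sngan_cifar10.pth',
                   device='cuda:0')

# draw 16 samples of class 3, four images per forward pass
images = sample_conditional_model(model,
                                  num_samples=16,
                                  num_batches=4,
                                  sample_model='ema',
                                  label=3)
print(images.shape)  # e.g. torch.Size([16, 3, 32, 32])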