Code Example #1
File: encoder_simple.py Project: yanzhicong/VAE-GAN
    def __call__(self, i, condition=None):

        output_dims = self.config.get("output dims", 3)
        output_act_fn = get_activation(
            self.config.get('output_activation', 'none'))

        x, end_points = self.network(i)

        x = tcl.flatten(x)
        if condition is not None:
            x = tf.concat([x, condition], axis=-1)

        with tf.variable_scope(self.name):
            if self.reuse:
                tf.get_variable_scope().reuse_variables()
            else:
                assert tf.get_variable_scope().reuse is False
                self.reuse = True

            if self.output_distribution == 'gaussian':
                mean = self.fc('fc_out_mean', x, output_dims,
                               **self.out_fc_args)
                log_var = self.fc('fc_out_log_var', x, output_dims,
                                  **self.out_fc_args)
                return mean, log_var

            elif self.output_distribution == 'mean':
                mean = self.fc('fc_out_mean', x, output_dims,
                               **self.out_fc_args)
                return mean

            elif self.output_distribution == 'none':
                out = self.fc('fc_out_mean', x, output_dims,
                              **self.out_fc_args)
                return out
            else:
                raise Exception("Unknown output distribution: " +
                                self.output_distribution)
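
When output_distribution is 'gaussian', the returned mean and log_var are typically combined via the reparameterization trick to draw a latent sample. A minimal standalone sketch (TF1-style placeholders stand in for the encoder outputs; the latent width of 128 is illustrative, not taken from the repository):

import tensorflow as tf

# Placeholders standing in for the encoder's (mean, log_var) outputs.
mean = tf.placeholder(tf.float32, [None, 128])
log_var = tf.placeholder(tf.float32, [None, 128])

# Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
eps = tf.random_normal(tf.shape(mean))
z = mean + tf.exp(0.5 * log_var) * eps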
Code Example #2
    def __call__(self, i, reuse=False):

        act_fn = get_activation(self.config.get('activation', 'relu'),
                                self.config.get('activation_params', {}))

        norm_fn, norm_params = get_normalization(
            self.config.get('batch_norm', 'batch_norm'),
            self.config.get('batch_norm_params', self.normalizer_params))

        winit_fn = get_weightsinit(
            self.config.get('weightsinit', 'normal'),
            self.config.get('weightsinit_params', '0.00 0.02'))

        filters = self.config.get('nb_filters', 32)

        # fully connected parameters
        including_top = self.config.get("including top", True)
        nb_fc_nodes = self.config.get('nb_fc_nodes', [1024, 1024])

        # output stage parameters
        output_dims = self.config.get("output dims",
                                      0)  # zero for no output layer
        output_act_fn = get_activation(
            self.config.get('output_activation', 'none'),
            self.config.get('output_activation_params', ''))

        with tf.variable_scope(self.name):
            if reuse:
                tf.get_variable_scope().reuse_variables()
            else:
                assert tf.get_variable_scope().reuse is False

            end_points = {}

            # x : 299 * 299 * 3
            x = tcl.conv2d(i,
                           filters,
                           3,
                           stride=2,
                           activation_fn=act_fn,
                           normalizer_fn=norm_fn,
                           normalizer_params=norm_params,
                           padding='VALID',
                           weights_initializer=winit_fn,
                           scope='conv1')
            end_points['conv1'] = x

            # x : 149 * 149 * 32
            x = tcl.conv2d(x,
                           filters,
                           3,
                           stride=1,
                           activation_fn=act_fn,
                           normalizer_fn=norm_fn,
                           normalizer_params=norm_params,
                           padding='VALID',
                           weights_initializer=winit_fn,
                           scope='conv2')
            end_points['conv2'] = x

            # x : 147 * 147 * 32
            x = tcl.conv2d(x,
                           2 * filters,
                           3,
                           stride=1,
                           activation_fn=act_fn,
                           normalizer_fn=norm_fn,
                           normalizer_params=norm_params,
                           padding='SAME',
                           weights_initializer=winit_fn,
                           scope='conv3')
            end_points['conv3'] = x

            # x : 147 * 147 * 64
            x = tcl.max_pool2d(x, 3, stride=2, padding='VALID', scope='pool1')
            end_points['pool1'] = x

            # x : 73 * 73 * 64
            x = tcl.conv2d(x,
                           80,
                           3,
                           stride=1,
                           activation_fn=act_fn,
                           normalizer_fn=norm_fn,
                           normalizer_params=norm_params,
                           padding='VALID',
                           weights_initializer=winit_fn,
                           scope='conv4')
            end_points['conv4'] = x

            # x : 71 * 71 * 80
            x = tcl.conv2d(x,
                           192,
                           3,
                           stride=2,
                           activation_fn=act_fn,
                           normalizer_fn=norm_fn,
                           normalizer_params=norm_params,
                           padding='VALID',
                           weights_initializer=winit_fn,
                           scope='conv5')
            end_points['conv5'] = x

            # x : 35 * 35 * 192
            x = tcl.conv2d(x,
                           288,
                           3,
                           stride=1,
                           activation_fn=act_fn,
                           normalizer_fn=norm_fn,
                           normalizer_params=norm_params,
                           padding='SAME',
                           weights_initializer=winit_fn,
                           scope='conv6')
            end_points['conv6'] = x

            # x : 35 * 35 * 288
            x, end_points = inception_v3_figure5('inception1a',
                                                 x,
                                                 end_points,
                                                 act_fn=act_fn,
                                                 norm_fn=norm_fn,
                                                 norm_params=norm_params,
                                                 winit_fn=winit_fn)
            x, end_points = inception_v3_figure5('inception1b',
                                                 x,
                                                 end_points,
                                                 act_fn=act_fn,
                                                 norm_fn=norm_fn,
                                                 norm_params=norm_params,
                                                 winit_fn=winit_fn)
            x, end_points = inception_v3_figure5_downsample(
                'inception1c',
                x,
                end_points,
                act_fn=act_fn,
                norm_fn=norm_fn,
                norm_params=norm_params,
                winit_fn=winit_fn)

            # x : 17 * 17 * 768
            x, end_points = inception_v3_figure6('inception2a',
                                                 x,
                                                 end_points,
                                                 n=7,
                                                 act_fn=act_fn,
                                                 norm_fn=norm_fn,
                                                 norm_params=norm_params,
                                                 winit_fn=winit_fn)
            x, end_points = inception_v3_figure6('inception2b',
                                                 x,
                                                 end_points,
                                                 n=7,
                                                 act_fn=act_fn,
                                                 norm_fn=norm_fn,
                                                 norm_params=norm_params,
                                                 winit_fn=winit_fn)
            x, end_points = inception_v3_figure6('inception2c',
                                                 x,
                                                 end_points,
                                                 n=7,
                                                 act_fn=act_fn,
                                                 norm_fn=norm_fn,
                                                 norm_params=norm_params,
                                                 winit_fn=winit_fn)
            x, end_points = inception_v3_figure6('inception2d',
                                                 x,
                                                 end_points,
                                                 n=7,
                                                 act_fn=act_fn,
                                                 norm_fn=norm_fn,
                                                 norm_params=norm_params,
                                                 winit_fn=winit_fn)
            x, end_points = inception_v3_figure6_downsample(
                'inception2e',
                x,
                end_points,
                n=7,
                act_fn=act_fn,
                norm_fn=norm_fn,
                norm_params=norm_params,
                winit_fn=winit_fn)

            # x : 8 * 8 * 1280
            x, end_points = inception_v3_figure7('inception3a',
                                                 x,
                                                 end_points,
                                                 act_fn=act_fn,
                                                 norm_fn=norm_fn,
                                                 norm_params=norm_params,
                                                 winit_fn=winit_fn)
            x, end_points = inception_v3_figure7('inception3b',
                                                 x,
                                                 end_points,
                                                 act_fn=act_fn,
                                                 norm_fn=norm_fn,
                                                 norm_params=norm_params,
                                                 winit_fn=winit_fn)

            # construct top fully connected layer
            if including_top:
                with tf.variable_scope("logits"):
                    x = tcl.avg_pool2d(x,
                                       kernel_size=[8, 8],
                                       padding="VALID",
                                       scope="avgpool_1a_8x8")
                    x = tcl.dropout(x, keep_prob=0.5, scope="dropout_1b")
                    end_points["global_avg_pooling"] = x
                    x = tcl.flatten(x)

                    if output_dims != 0:
                        x = tcl.fully_connected(x,
                                                output_dims,
                                                activation_fn=output_act_fn,
                                                normalizer_fn=None,
                                                weights_initializer=winit_fn,
                                                scope='fc_out')
                        end_points['fc_out'] = x

            return x, end_points
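
The spatial sizes annotated in the stem comments above (299 -> 149 -> 147 -> ... -> 35) follow directly from the VALID/SAME convolution arithmetic. A quick standalone check of those sizes (not part of the repository):

# Output spatial size of a square conv/pool layer; SAME gives ceil(size / stride).
def conv_out(size, ksize, stride, padding):
    if padding == 'VALID':
        return (size - ksize) // stride + 1
    return (size + stride - 1) // stride

s = 299
s = conv_out(s, 3, 2, 'VALID')  # conv1 -> 149
s = conv_out(s, 3, 1, 'VALID')  # conv2 -> 147
s = conv_out(s, 3, 1, 'SAME')   # conv3 -> 147
s = conv_out(s, 3, 2, 'VALID')  # pool1 -> 73
s = conv_out(s, 3, 1, 'VALID')  # conv4 -> 71
s = conv_out(s, 3, 2, 'VALID')  # conv5 -> 35
s = conv_out(s, 3, 1, 'SAME')   # conv6 -> 35
print(s)  # 35, matching the '35 * 35 * 288' comment before the inception blocks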
Code Example #3
File: base_network.py Project: yanzhicong/VAE-GAN
    def activation(self, x, act_fn='relu'):
        if not callable(act_fn):
            act_fn = get_activation(act_fn)
        return act_fn(x)
Code Example #4
File: base_network.py Project: yanzhicong/VAE-GAN
    def fc(self,
           name,
           x,
           nb_nodes,
           *,
           norm_fn='none',
           norm_params=None,
           act_fn='none',
           winit_fn='xavier',
           binit_fn='zeros',
           has_bias=True,
           disp=True,
           collect_end_points=True):

        if callable(act_fn):
            act_fn_str = 'func'
        else:
            act_fn_str = self.config.get(name + ' activation', act_fn)
            act_fn = get_activation(act_fn_str)

        if callable(norm_fn):
            norm_fn_str = 'func'
        else:
            norm_fn_str = self.config.get(name + ' normalization', norm_fn)
            norm_fn = get_normalization(norm_fn_str)

        winit_fn_str = self.config.get(name + ' weightsinit', winit_fn)
        if 'special' in winit_fn_str:
            split = winit_fn_str.split()
            winit_name = split[0]
            if winit_name == 'glorot_uniform':
                input_nb_nodes = int(x.get_shape()[-1])
                filters_stdev = np.sqrt(2.0 / (input_nb_nodes + nb_nodes))
                winit_fn = self.uniform_initializer(filters_stdev)
            else:
                raise Exception('Error weights initializer function name : ' +
                                winit_fn_str)
        else:
            winit_fn = get_weightsinit(winit_fn_str)

        binit_fn_str = self.config.get(name + ' biasesinit', binit_fn)
        binit_fn = get_weightsinit(binit_fn_str)

        if self.using_tcl_library:
            x = tcl.fully_connected(x,
                                    nb_nodes,
                                    activation_fn=act_fn,
                                    normalizer_fn=norm_fn,
                                    normalizer_params=norm_params,
                                    weights_initializer=winit_fn,
                                    scope=name)
        else:
            x = tl.dense(x,
                         nb_nodes,
                         use_bias=has_bias,
                         kernel_initializer=winit_fn,
                         bias_initializer=binit_fn,
                         trainable=True,
                         name=name)

            with tf.variable_scope(name):
                if norm_fn is not None:
                    norm_params = norm_params or {}
                    x = norm_fn(x, **norm_params)
                if act_fn is not None:
                    x = act_fn(x)

        if disp:
            print('\t\tFC(' + str(name) + ') --> ', x.get_shape(), '  ',
                  (act_fn_str, norm_fn_str, winit_fn_str))
        if collect_end_points:
            self.end_points[name] = x
        return x
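
The per-layer settings are looked up in self.config under keys built from the layer name plus a space-separated suffix, falling back to the call-site defaults. An illustrative config (the layer names and values here are assumptions, not taken from the repository's config files):

config = {
    'fc1 activation': 'lrelu',                       # overrides act_fn for layer 'fc1'
    'fc1 normalization': 'none',
    'fc1 weightsinit': 'normal 0.00 0.02',
    'fc_out weightsinit': 'glorot_uniform special',  # takes the 'special' branch above
}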
Code Example #5
File: base_network.py Project: yanzhicong/VAE-GAN
    def deconv2d(self,
                 name,
                 x,
                 nb_filters,
                 ksize,
                 stride,
                 *,
                 norm_fn='none',
                 norm_params=None,
                 act_fn='relu',
                 winit_fn='xavier',
                 binit_fn='zeros',
                 padding='SAME',
                 has_bias=True,
                 disp=True,
                 collect_end_points=True):

        if callable(act_fn):
            act_fn_str = 'func'
        else:
            act_fn_str = self.config.get(name + ' activation', act_fn)
            act_fn = get_activation(act_fn_str)

        if callable(norm_fn):
            norm_fn_str = 'func'
        else:
            norm_fn_str = self.config.get(name + ' normalization', norm_fn)
            norm_fn = get_normalization(norm_fn_str)

        winit_fn_str = self.config.get(name + ' weightsinit', winit_fn)
        if 'special' in winit_fn_str:
            split = winit_fn_str.split()
            winit_name = split[0]
            if winit_name == 'he_uniform':
                input_nb_filters = int(x.get_shape()[-1])
                fan_in = input_nb_filters * (ksize**2) / (stride**2)
                fan_out = nb_filters * (ksize**2)
                filters_stdev = np.sqrt(4.0 / (fan_in + fan_out))
                winit_fn = self.uniform_initializer(filters_stdev)
            else:
                raise Exception('Error weights initializer function name : ' +
                                winit_fn_str)
        else:
            winit_fn = get_weightsinit(winit_fn_str)
        binit_fn_str = self.config.get(name + ' biasesinit', binit_fn)
        binit_fn = get_weightsinit(binit_fn_str)
        _padding = self.config.get(name + ' padding', padding)

        if self.using_tcl_library:
            x = tcl.conv2d_transpose(x,
                                     nb_filters,
                                     ksize,
                                     stride=stride,
                                     activation_fn=act_fn,
                                     normalizer_fn=norm_fn,
                                     normalizer_params=norm_params,
                                     weights_initializer=winit_fn,
                                     padding=_padding,
                                     scope=name)
        else:
            x = tl.conv2d_transpose(x,
                                    nb_filters,
                                    ksize,
                                    strides=stride,
                                    padding=_padding,
                                    use_bias=has_bias,
                                    kernel_initializer=winit_fn,
                                    bias_initializer=binit_fn,
                                    trainable=True,
                                    name=name)
            with tf.variable_scope(name):
                if norm_fn is not None:
                    norm_params = norm_params or {}
                    x = norm_fn(x, **norm_params)
                if act_fn is not None:
                    x = act_fn(x)

        if disp:
            print('\t\tDeconv2D(' + str(name) + ') --> ', x.get_shape(), '  ',
                  (act_fn_str, norm_fn_str, winit_fn_str, _padding))
        if collect_end_points:
            self.end_points[name] = x
        return x
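
In the 'he_uniform special' branch above, the fan-in of a transposed convolution is scaled down by the squared stride because each input pixel fans out to roughly stride x stride output positions. A worked numeric example (the layer sizes are illustrative):

import numpy as np

# 4x4 transposed conv, stride 2, 128 -> 64 filters.
input_nb_filters, nb_filters, ksize, stride = 128, 64, 4, 2
fan_in = input_nb_filters * (ksize ** 2) / (stride ** 2)  # 512.0
fan_out = nb_filters * (ksize ** 2)                       # 1024
filters_stdev = np.sqrt(4.0 / (fan_in + fan_out))         # ~0.051
print(filters_stdev)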
Code Example #6
File: encoder_cifar10.py Project: yanzhicong/VAE-GAN
    def __call__(self, i, condition=None):

        act_fn = get_activation(self.config.get('activation', 'relu'))

        norm_fn, norm_params = get_normalization(
            self.config.get('batch_norm', 'batch_norm'),
            self.config.get('batch_norm_params', self.normalizer_params))

        winit_fn = get_weightsinit(
            self.config.get('weightsinit', 'normal 0.00 0.02'))

        nb_fc_nodes = self.config.get('nb_fc_nodes', [1024, 1024])

        output_dims = self.config.get("output dims", 3)
        output_act_fn = get_activation(
            self.config.get('output_activation', 'none'))

        x, end_points = self.network(i)

        x = tcl.flatten(x)
        if condition is not None:
            x = tf.concat([x, condition], axis=-1)

        with tf.variable_scope(self.name):
            if self.reuse:
                tf.get_variable_scope().reuse_variables()
            else:
                assert tf.get_variable_scope().reuse is False
                self.reuse = True

            for ind, nb_nodes in enumerate(nb_fc_nodes):
                x = tcl.fully_connected(x,
                                        nb_nodes,
                                        activation_fn=act_fn,
                                        normalizer_fn=norm_fn,
                                        normalizer_params=norm_params,
                                        weights_initializer=winit_fn,
                                        scope='fc%d' % ind)

            if self.output_distribution == 'gaussian':
                mean = tcl.fully_connected(x,
                                           output_dims,
                                           activation_fn=output_act_fn,
                                           weights_initializer=winit_fn,
                                           scope='fc_out_mean')
                log_var = tcl.fully_connected(x,
                                              output_dims,
                                              activation_fn=output_act_fn,
                                              weights_initializer=winit_fn,
                                              scope='fc_out_log_var')
                return mean, log_var

            elif self.output_distribution == 'mean':
                mean = tcl.fully_connected(x,
                                           output_dims,
                                           activation_fn=output_act_fn,
                                           weights_initializer=winit_fn,
                                           scope='fc_out_mean')
                return mean

            elif self.output_distribution == 'none':
                out = tcl.fully_connected(x,
                                          output_dims,
                                          activation_fn=output_act_fn,
                                          weights_initializer=winit_fn,
                                          scope='fc_out_mean')
                return out
            else:
                raise Exception("Unknown output distribution: " +
                                self.output_distribution)
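
The optional condition tensor is concatenated onto the flattened features before the fully connected stack, which is how a label-conditional encoder would be wired up. A minimal standalone sketch (the feature width and class count are illustrative assumptions):

import tensorflow as tf

features = tf.placeholder(tf.float32, [None, 4096])  # stands in for tcl.flatten(x)
labels = tf.placeholder(tf.int32, [None])
condition = tf.one_hot(labels, depth=10)             # e.g. 10 CIFAR-10 classes
x = tf.concat([features, condition], axis=-1)        # shape [None, 4106]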
Code Example #7
File: unet.py Project: yanzhicong/VAE-GAN
    def __call__(self, x):

        conv_nb_blocks = self.config.get("conv nb blocks", 4)
        conv_nb_layers = self.config.get("conv nb layers", [2, 2, 2, 2, 2])
        conv_nb_filters = self.config.get("conv nb filters",
                                          [64, 128, 256, 512, 1024])
        conv_ksize = self.config.get("conv ksize",
                                     [3 for i in range(conv_nb_blocks + 1)])
        no_maxpooling = self.config.get("no maxpooling", False)
        no_upsampling = self.config.get('no_upsampling', False)

        output_dims = self.config.get("output dims",
                                      0)  # zero for no output layer
        output_act_fn = get_activation(
            self.config.get('output_activation', 'none'))

        debug = self.config.get('debug', False)

        with tf.variable_scope(self.name):
            if self.reuse:
                tf.get_variable_scope().reuse_variables()
            else:
                assert tf.get_variable_scope().reuse is False
                self.reuse = True

            block_end = {}
            self.end_points = {}

            if debug:
                print('UNet : (' + str(self.name) + ')')
                print('downsample network : ')

            for block_ind in range(conv_nb_blocks):
                for layer_ind in range(conv_nb_layers[block_ind]):

                    conv_name = 'conv_ds%d_%d' % (block_ind + 1, layer_ind)
                    maxpool_name = 'maxpool_ds%d' % (block_ind + 1)

                    if layer_ind == conv_nb_layers[block_ind] - 1:
                        if no_maxpooling:
                            block_end[block_ind] = x
                            x = self.conv2d(conv_name,
                                            x,
                                            conv_nb_filters[block_ind],
                                            conv_ksize[block_ind],
                                            stride=2,
                                            **self.conv_args,
                                            disp=debug)
                        else:
                            x = self.conv2d(conv_name,
                                            x,
                                            conv_nb_filters[block_ind],
                                            conv_ksize[block_ind],
                                            stride=1,
                                            **self.conv_args,
                                            disp=debug)
                            block_end[block_ind] = x
                            x = self.maxpool2d(maxpool_name,
                                               x,
                                               2,
                                               stride=2,
                                               padding='SAME')
                    else:
                        x = self.conv2d(conv_name,
                                        x,
                                        conv_nb_filters[block_ind],
                                        conv_ksize[block_ind],
                                        stride=1,
                                        **self.conv_args,
                                        disp=debug)

            if debug:
                print('bottleneck network : ')
            for layer_ind in range(conv_nb_layers[conv_nb_blocks]):
                conv_name = 'conv_bn%d' % (layer_ind)
                x = self.conv2d(conv_name,
                                x,
                                conv_nb_filters[conv_nb_blocks],
                                conv_ksize[conv_nb_blocks],
                                stride=1,
                                **self.conv_args,
                                disp=debug)

            if debug:
                print('upsample network : ')

            for block_ind in range(conv_nb_blocks)[::-1]:
                for layer_ind in range(conv_nb_layers[block_ind]):
                    deconv_name = 'deconv_us%d' % block_ind
                    conv_name = 'conv_us%d_%d' % (block_ind + 1, layer_ind)
                    if layer_ind == 0:
                        x = self.deconv2d(deconv_name,
                                          x,
                                          conv_nb_filters[block_ind],
                                          conv_ksize[block_ind],
                                          stride=2,
                                          **self.conv_args,
                                          disp=debug)
                        x = self.concat(deconv_name + '_concat',
                                        [x, block_end[block_ind]])
                        x = self.conv2d(conv_name,
                                        x,
                                        conv_nb_filters[block_ind],
                                        conv_ksize[block_ind],
                                        stride=1,
                                        **self.conv_args,
                                        disp=debug)
                    else:
                        x = self.conv2d(conv_name,
                                        x,
                                        conv_nb_filters[block_ind],
                                        conv_ksize[block_ind],
                                        stride=1,
                                        **self.conv_args,
                                        disp=debug)

            if debug:
                print('output network : ')
            if output_dims != 0:
                x = self.conv2d('conv_out',
                                x,
                                output_dims,
                                1,
                                stride=1,
                                **self.out_conv_args,
                                disp=debug)

            if debug:
                print('')

            return x, self.end_points
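
The UNet above is driven entirely by its config dictionary. An illustrative config using the keys read at the top of __call__ (the values are assumptions, not taken from the repository's own config files):

unet_config = {
    'conv nb blocks': 4,
    'conv nb layers': [2, 2, 2, 2, 2],            # per block; last entry is the bottleneck
    'conv nb filters': [64, 128, 256, 512, 1024],
    'conv ksize': [3, 3, 3, 3, 3],
    'no maxpooling': False,                       # False: stride-1 conv followed by maxpool
    'output dims': 1,                             # adds the 1x1 'conv_out' head
    'output_activation': 'sigmoid',
    'debug': True,                                # print layer shapes while building
}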