Example #1
    def f(x, chosen, emotion):
        time_steps = int(x.get_shape()[1])

        # Shift target one note to the left.
        shift_chosen = Lambda(lambda x: tf.pad(
            x[:, :, :-1, :], tf.constant([[0, 0], [0, 0], [1, 0], [0, 0]])))(
                chosen)

        x = Concatenate(axis=3)([x, shift_chosen])

        for l in range(NOTE_AXIS_LAYERS):
            # Integrate emotion
            if l not in dense_layer_cache:
                dense_layer_cache[l] = Dense(int(x.get_shape()[3]))

            emotion_proj = dense_layer_cache[l](emotion)
            emotion_proj = TimeDistributed(
                RepeatVector(NUM_NOTES))(emotion_proj)
            emotion_proj = Activation('tanh')(emotion_proj)
            emotion_proj = Dropout(dropout)(emotion_proj)
            x = Add()([x, emotion_proj])

            if l not in lstm_layer_cache:
                lstm_layer_cache[l] = LSTM(NOTE_AXIS_UNITS,
                                           return_sequences=True)

            x = TimeDistributed(lstm_layer_cache[l])(x)
            x = Dropout(dropout)(x)

        return Concatenate()([note_dense(x), volume_dense(x)])
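
The Lambda above drops the last entry along the note axis and pads a zero at the front, which is how the chosen targets get shifted by one note. A minimal standalone check of that padding call, assuming TensorFlow 2.x eager execution (the toy tensor below is made up for illustration):

import tensorflow as tf

# Toy tensor shaped [batch, time, notes, features] = [1, 1, 3, 1].
x = tf.reshape(tf.constant([1., 2., 3.]), (1, 1, 3, 1))

# Drop the last note and pad one zero at the start of the note axis,
# exactly as in the Lambda above.
shifted = tf.pad(x[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]])

print(tf.squeeze(shifted).numpy())  # [0. 1. 2.] -- each slot now holds the previous note's value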
Example #2
    def f(x, chosen, style):
        time_steps = int(x.get_shape()[1])

        # Shift target one note to the left.
        shift_chosen = Lambda(lambda x: tf.pad(x[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]]))(chosen)

        # [batch, time, notes, 1]
        shift_chosen = Reshape((time_steps, NUM_NOTES, -1))(shift_chosen)
        # [batch, time, notes, features + 1]
        x = Concatenate(axis=3)([x, shift_chosen])
        
        for l in range(NOTE_AXIS_LAYERS):
            # Integrate style
            if l not in dense_layer_cache:
                dense_layer_cache[l] = Dense(int(x.get_shape()[3]))

            style_proj = dense_layer_cache[l](style)
            style_proj = TimeDistributed(RepeatVector(NUM_NOTES))(style_proj)
            style_proj = Activation('tanh')(style_proj)
            style_proj = Dropout(dropout)(style_proj)
            x = Add()([x, style_proj])

            if l not in lstm_layer_cache:
                lstm_layer_cache[l] = LSTM(NOTE_AXIS_UNITS, return_sequences=True)

            x = TimeDistributed(lstm_layer_cache[l])(x)
            x = Dropout(dropout)(x)

        return Concatenate()([note_dense(x), volume_dense(x)])
Example #3
    def generate_critic(self, state_shape, action_size, optimizer, LEARNING_RATE):
        assert len(state_shape) == 3, "shape mismatch"
        nr_day, nr_feature, nr_seller = state_shape
        assert action_size == nr_seller
        inp = Input(shape=(nr_day, nr_feature, nr_seller))
        print("inp shape=", inp.shape)
        action = Input(shape=(nr_seller,))

        # Similarly, first the background part; ASSUME the seller data is ordered.
        dsf_inp = Permute((1, 3, 2))(inp)
        reshape_inp = Reshape((nr_day, nr_seller * nr_feature))(dsf_inp)
        background_feat = Bidirectional(GRU(self.bh))(reshape_inp)  # (batch, 2*bh): forward and backward GRU states are concatenated
        repeated_background_feat = RepeatVector(nr_seller)(background_feat)  # (batch, nr_seller, 2*bh)
        #individual part
        sdf_inp = Permute((3,1,2))(inp)
        ind_model = self.generate_individual_model(nr_day, nr_feature)
        individual_feat = TimeDistributed(ind_model)(sdf_inp) #(batch, nr_seller, ih)

        eval_model  = self.generate_eval_model(2*(self.ih+self.bh)+1)
        concat_feat = Concatenate(axis=-1)([repeated_background_feat, individual_feat, Reshape((-1, 1))(action)])  # (batch, nr_seller, 2*(bh+ih) + 1), matching eval_model's input size
        print("concat_feat shape = ",concat_feat.get_shape())
        values = TimeDistributed(eval_model)(concat_feat) #(batch, nr_seller ,1)
        print("values shape = ",values.get_shape())
        flat_values = Reshape((-1,))(values)
        # then reduce to one value per batch element by addition
        value = ReduceSum(axis=1)(flat_values)  # (batch,) -- ReduceSum is presumably a project-specific layer; see the sketch below
        #value  = Dense(1, activation="linear")(flat_values)
        print("value shape = ",value.get_shape())
        model = Model(inputs=[inp,action], outputs=[value])
        opt = optimizer(LEARNING_RATE)
        model.compile(loss="mse", optimizer=opt)
        return model, action, inp
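
ReduceSum is not a stock Keras layer, so it is presumably defined elsewhere in this project. A minimal sketch of what such a layer could look like (an assumption for illustration, not the project's actual definition):

import tensorflow as tf
from tensorflow.keras.layers import Layer

class ReduceSum(Layer):
    """Hypothetical helper: sums the input tensor over a single axis."""

    def __init__(self, axis=1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.reduce_sum(inputs, axis=self.axis)

    def get_config(self):
        config = super().get_config()
        config.update({'axis': self.axis})
        return config

With a definition like this, ReduceSum(axis=1)(flat_values) collapses the per-seller values into one scalar per batch element, matching the (batch,) shape noted in the comment above.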
Example #4
    def f(x, chosen, style):
        time_steps = int(x.get_shape()[1])

        # Shift target one note to the left.
        shift_chosen = Lambda(lambda x: tf.pad(x[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]]))(chosen)

        # [batch, time, notes, 1]
        shift_chosen = Reshape((time_steps, NUM_NOTES, -1))(shift_chosen)
        # [batch, time, notes, features + 1]
        x = Concatenate(axis=3)([x, shift_chosen])

        for l in range(NOTE_AXIS_LAYERS):
            # Integrate style
            if l not in dense_layer_cache:
                dense_layer_cache[l] = Dense(int(x.get_shape()[3]))

            style_proj = dense_layer_cache[l](style)
            style_proj = TimeDistributed(RepeatVector(NUM_NOTES))(style_proj)
            style_proj = Activation('tanh')(style_proj)
            style_proj = Dropout(dropout)(style_proj)
            x = Add()([x, style_proj])

            if l not in lstm_layer_cache:
                lstm_layer_cache[l] = LSTM(NOTE_AXIS_UNITS, return_sequences=True)

            x = TimeDistributed(lstm_layer_cache[l])(x)
            x = Dropout(dropout)(x)

        return Concatenate()([note_dense(x), volume_dense(x)])
Example #5
def _collect_inputs(ae_input_shapes, ae_input_names,
        conditioning_input_shapes, conditioning_input_names,):

    if not isinstance(ae_input_shapes, list):
        ae_input_shapes = [ae_input_shapes]

    if ae_input_names is None:
        ae_input_names = ['input_{}'.format(ii) for ii in range(len(ae_input_shapes))]

    ae_inputs = []
    for ii, input_shape in enumerate(ae_input_shapes):
        ae_inputs.append(Input(input_shape, name='input_{}'.format(ae_input_names[ii])))

    ae_stack = Concatenate(name='concat_inputs', axis=-1)(ae_inputs)
    ae_stack_shape = ae_stack.get_shape().as_list()[1:]

    # collect conditioning inputs, and concatenate them into a stack
    if not isinstance(conditioning_input_shapes, list):
        conditioning_input_shapes = [conditioning_input_shapes]
    if conditioning_input_names is None:
        conditioning_input_names = ['cond_input_{}'.format(ii) for ii in range(len(conditioning_input_shapes))]

    conditioning_inputs = []
    for ii, input_shape in enumerate(conditioning_input_shapes):
        conditioning_inputs.append(Input(input_shape, name=conditioning_input_names[ii]))

    cond_stack = Concatenate(name='concat_cond_inputs', axis=-1)(conditioning_inputs)
    return ae_inputs, ae_stack, conditioning_inputs, cond_stack
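
A hedged usage sketch of _collect_inputs as defined above (the shapes, input names, and the toy Dense head are illustrative only):

from tensorflow.keras.layers import Concatenate, Dense, Flatten
from tensorflow.keras.models import Model

ae_inputs, ae_stack, cond_inputs, cond_stack = _collect_inputs(
    ae_input_shapes=[(32, 32, 1), (32, 32, 1)],
    ae_input_names=['frame_a', 'frame_b'],
    conditioning_input_shapes=[(32, 32, 3), (32, 32, 1)],
    conditioning_input_names=['cond_image', 'cond_mask'],
)

# Toy head on top of the returned stacks, just to show they are ordinary Keras tensors.
x = Concatenate(axis=-1)([ae_stack, cond_stack])
out = Dense(1)(Flatten()(x))
model = Model(inputs=ae_inputs + cond_inputs, outputs=out)
model.summary()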
Example #6
def transform_dense_encoder_model(
    input_shapes,
    input_names=None,
    latent_shape=(50, ),
    model_name='VTE_transform_encoder',
    enc_params=None,
):
    '''
    Generic encoder for a stack of inputs

    :param input_shapes:
    :param latent_shape:
    :param model_name:
    :param enc_params:
    :return:
    '''
    if not isinstance(input_shapes, list):
        input_shapes = [input_shapes]

    if input_names is None:
        input_names = [
            'input_{}'.format(ii) for ii in range(len(input_shapes))
        ]

    inputs = []
    for ii, input_shape in enumerate(input_shapes):
        inputs.append(
            Input(input_shape, name='input_{}'.format(input_names[ii])))

    if len(inputs) > 1:
        inputs_stacked = Concatenate(name='concat_inputs', axis=-1)(inputs)
    else:
        inputs_stacked = inputs[0]
    input_stack_shape = inputs_stacked.get_shape().as_list()[1:]
    n_dims = len(input_stack_shape) - 1

    assert len(latent_shape) == 1
    latent_size = latent_shape[0]

    z = Dense(latent_size * 2, name='dense_1')(inputs_stacked)
    z = LeakyReLU(0.2)(z)
    z = Dense(latent_size * 2)(z)

    # the last dense layer above has no activation, so we should activate after it
    x_transform_enc = LeakyReLU(0.2)(z)
    x_transform_enc = Flatten()(x_transform_enc)

    z_mean = Dense(latent_size,
                   name='latent_mean',
                   kernel_initializer=keras_init.RandomNormal(
                       mean=0., stddev=0.00001))(x_transform_enc)
    z_logvar = Dense(
        latent_size,
        name='latent_logvar',
        bias_initializer=keras_init.RandomNormal(mean=-2., stddev=1e-10),
        kernel_initializer=keras_init.RandomNormal(mean=0., stddev=1e-10),
    )(x_transform_enc)

    return Model(inputs=inputs, outputs=[z_mean, z_logvar], name=model_name)
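
The model above outputs a mean and a log-variance, so a sampling step is needed downstream to actually draw z. A minimal reparameterization sketch, with stand-in inputs shaped like the encoder outputs (the names here are illustrative, not from this codebase):

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model

latent_size = 50

def sample_z(args):
    # Reparameterization trick: z = mean + exp(logvar / 2) * eps, with eps ~ N(0, I).
    z_mean, z_logvar = args
    eps = tf.random.normal(shape=tf.shape(z_mean))
    return z_mean + tf.exp(0.5 * z_logvar) * eps

z_mean_in = Input((latent_size,), name='z_mean')
z_logvar_in = Input((latent_size,), name='z_logvar')
z = Lambda(sample_z, name='lambda_sample_z')([z_mean_in, z_logvar_in])
sampler = Model([z_mean_in, z_logvar_in], z, name='sampler')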
Example #7
def _collect_inputs(
    ae_input_shapes,
    ae_input_names,
    conditioning_input_shapes,
    conditioning_input_names,
):

    ae_inputs = []
    ae_stack = None
    if ae_input_shapes is not None:
        if not isinstance(ae_input_shapes, list):
            ae_input_shapes = [ae_input_shapes]

        if ae_input_names is None:
            ae_input_names = [
                'input_{}'.format(ii) for ii in range(len(ae_input_shapes))
            ]

        for ii, input_shape in enumerate(ae_input_shapes):
            ae_inputs.append(
                Input(input_shape, name='input_{}'.format(ae_input_names[ii])))

        ae_stack = Concatenate(name='concat_inputs', axis=-1)(ae_inputs)
        ae_stack_shape = ae_stack.get_shape().as_list()[1:]

    # collect conditioning inputs, and concatenate them into a stack
    if not isinstance(conditioning_input_shapes, list):
        conditioning_input_shapes = [conditioning_input_shapes]
    if conditioning_input_names is None:
        conditioning_input_names = [
            'cond_input_{}'.format(ii)
            for ii in range(len(conditioning_input_shapes))
        ]

    conditioning_inputs = []
    for ii, input_shape in enumerate(conditioning_input_shapes):
        conditioning_inputs.append(
            Input(input_shape, name=conditioning_input_names[ii]))

    if len(conditioning_inputs) > 1:
        cond_stack = Concatenate(name='concat_cond_inputs',
                                 axis=-1)(conditioning_inputs)
    else:
        cond_stack = conditioning_inputs[0]
    return ae_inputs, ae_stack, conditioning_inputs, cond_stack
Example #8
def transformer_concat_model(
    conditioning_input_shapes,
    conditioning_input_names=None,
    output_shape=None,
    model_name='CVAE_transformer',
    transform_latent_shape=(100, ),
    enc_params=None,
    condition_on_image=True,
    n_concat_scales=3,
    transform_activation=None,
    clip_output_range=None,
    source_input_idx=None,
):
    # collect conditioning inputs, and concatenate them into a stack
    if not isinstance(conditioning_input_shapes, list):
        conditioning_input_shapes = [conditioning_input_shapes]
    if conditioning_input_names is None:
        conditioning_input_names = [
            'cond_input_{}'.format(ii)
            for ii in range(len(conditioning_input_shapes))
        ]

    conditioning_inputs = []
    for ii, input_shape in enumerate(conditioning_input_shapes):
        conditioning_inputs.append(
            Input(input_shape, name=conditioning_input_names[ii]))

    if len(conditioning_inputs) > 1:
        conditioning_input_stack = Concatenate(name='concat_cond_inputs',
                                               axis=-1)(conditioning_inputs)
    else:
        conditioning_input_stack = conditioning_inputs[0]

    conditioning_input_shape = tuple(
        conditioning_input_stack.get_shape().as_list()[1:])
    n_dims = len(conditioning_input_shape) - 1

    # we will always give z as a flattened vector
    z_input = Input((np.prod(transform_latent_shape), ), name='z_input')

    # determine what we should apply the transformation to
    if source_input_idx is None:
        # the image we want to transform is exactly the input of the conditioning branch
        x_source = conditioning_input_stack
        source_input_shape = conditioning_input_shape
    else:
        # slice conditioning input to get the single source im that we will apply the transform to
        source_input_shape = conditioning_input_shapes[source_input_idx]
        x_source = conditioning_inputs[source_input_idx]

    # assume the output is going to be the transformed source input, so it should be the same shape
    if output_shape is None:
        output_shape = source_input_shape

    layer_prefix = 'color'
    decoder_output_shape = output_shape

    if condition_on_image:  # assume we always condition by concat since it's better than other forms
        # simply concatenate the conditioning stack (at various scales) with the decoder volumes
        include_fullres = True

        concat_decoder_outputs_with = [None] * len(enc_params['nf_dec'])
        concat_skip_sizes = [None] * len(enc_params['nf_dec'])

        # make sure x_I is the same shape as the output, including in the channels dimension
        if not np.all(output_shape <= conditioning_input_shape):
            tile_factor = [
                int(round(output_shape[i] / conditioning_input_shape[i]))
                for i in range(len(output_shape))
            ]
            print('Tile factor: {}'.format(tile_factor))
            conditioning_input_stack = Lambda(
                lambda x: tf.tile(x, [1] + tile_factor),
                name='lambda_tile_cond_input')(conditioning_input_stack)

        # downscale the conditioning inputs by the specified number of times
        xs_downscaled = [conditioning_input_stack]
        for si in range(n_concat_scales):
            curr_x_scaled = network_utils.Blur_Downsample(
                n_chans=conditioning_input_shape[-1],
                n_dims=n_dims,
                do_blur=True,
                name='downsample_scale-1/{}'.format(2**(si + 1)))(
                    xs_downscaled[-1])
            xs_downscaled.append(curr_x_scaled)

        if not include_fullres:
            xs_downscaled = xs_downscaled[1:]  # exclude the full-res volume

        print('Including downsampled input sizes {}'.format(
            [x.get_shape().as_list() for x in xs_downscaled]))

        # the smallest decoder volume will be the same as the smallest encoder volume, so we need to make sure we match volume sizes appropriately
        n_enc_scales = len(enc_params['nf_enc'])
        n_ds = len(xs_downscaled)
        concat_decoder_outputs_with[n_enc_scales - n_ds +
                                    1:n_enc_scales] = list(
                                        reversed(xs_downscaled))
        concat_skip_sizes[n_enc_scales - n_ds + 1:n_enc_scales] = list(
            reversed([
                np.asarray(x.get_shape().as_list()[1:-1])
                for x in xs_downscaled if x is not None
            ]))

    else:
        # just ignore the conditioning input
        concat_decoder_outputs_with = None
        concat_skip_sizes = None

    if 'ks' not in enc_params:
        enc_params['ks'] = 3

    # determine what size to reshape the latent vector to
    reshape_encoding_to = get_encoded_shape(
        img_shape=conditioning_input_shape,
        conv_chans=enc_params['nf_enc'],
    )

    x_enc = Dense(np.prod(reshape_encoding_to),
                  name='dense_encoding_to_vol')(z_input)
    x_enc = LeakyReLU(0.2)(x_enc)

    x_enc = Reshape(reshape_encoding_to)(x_enc)
    print('Decoder starting shape: {}'.format(reshape_encoding_to))

    x_transformation = decoder(
        x_enc,
        decoder_output_shape,
        encoded_shape=reshape_encoding_to,
        prefix='{}_dec'.format(layer_prefix),
        conv_chans=enc_params['nf_dec'],
        ks=enc_params['ks'] if 'ks' in enc_params else 3,
        n_convs_per_stage=enc_params['n_convs_per_stage']
        if 'n_convs_per_stage' in enc_params else 1,
        use_upsample=enc_params['use_upsample']
        if 'use_upsample' in enc_params else False,
        kernel_initializer=enc_params['kernel_initializer']
        if 'kernel_initializer' in enc_params else None,
        bias_initializer=enc_params['bias_initializer']
        if 'bias_initializer' in enc_params else None,
        include_skips=concat_decoder_outputs_with,
        target_vol_sizes=concat_skip_sizes)

    if transform_activation is not None:
        x_transformation = Activation(
            transform_activation,
            name='activation_transform_{}'.format(transform_activation))(
                x_transformation)

        if transform_activation == 'tanh':
            # TODO: maybe move this logic
            # if we are learning a color delta with a tanh, make sure to multiply it by 2
            x_transformation = Lambda(
                lambda x: x * 2, name='lambda_scale_tanh')(x_transformation)

    im_out = Add()([x_source, x_transformation])

    if clip_output_range is not None:
        im_out = Lambda(lambda x: tf.clip_by_value(x, clip_output_range[0],
                                                   clip_output_range[1]),
                        name='lambda_clip_output_{}-{}'.format(
                            clip_output_range[0],
                            clip_output_range[1]))(im_out)

    return Model(inputs=conditioning_inputs + [z_input],
                 outputs=[im_out, x_transformation],
                 name=model_name)
Example #9
def transform_encoder_model(
        input_shapes,
        input_names=None,
        latent_shape=(50, ),
        model_name='encoder',
        enc_params=None,
):
    '''
    Generic encoder for a stack of inputs

    :param input_shapes:
    :param latent_shape:
    :param model_name:
    :param enc_params:
    :return:
    '''
    if not isinstance(input_shapes, list):
        input_shapes = [input_shapes]

    if input_names is None:
        input_names = [
            'input_{}'.format(ii) for ii in range(len(input_shapes))
        ]

    inputs = []
    for ii, input_shape in enumerate(input_shapes):
        inputs.append(
            Input(input_shape, name='input_{}'.format(input_names[ii])))
    if len(inputs) > 1:
        inputs_stacked = Concatenate(name='concat_inputs', axis=-1)(inputs)
    else:
        inputs_stacked = inputs[0]
    input_stack_shape = inputs_stacked.get_shape().as_list()[1:]
    n_dims = len(input_stack_shape) - 1

    x_transform_enc = encoder(
        x=inputs_stacked,
        img_shape=input_stack_shape,
        conv_chans=enc_params['nf_enc'],
        n_convs_per_stage=enc_params['n_convs_per_stage']
        if 'n_convs_per_stage' in enc_params else 1,
        use_residuals=enc_params['use_residuals']
        if 'use_residuals' in enc_params else False,
        use_maxpool=enc_params['use_maxpool']
        if 'use_maxpool' in enc_params else False,
        kernel_initializer=enc_params['kernel_initializer']
        if 'kernel_initializer' in enc_params else None,
        bias_initializer=enc_params['bias_initializer']
        if 'bias_initializer' in enc_params else None,
        prefix="cvae")

    latent_size = np.prod(latent_shape)

    # the last layer in the basic encoder will be a convolution, so we should activate after it
    x_transform_enc = LeakyReLU(0.2)(x_transform_enc)
    x_transform_enc = Flatten()(x_transform_enc)

    z_mean = Dense(latent_size,
                   name='latent_mean',
                   kernel_initializer=keras_init.RandomNormal(
                       mean=0., stddev=0.00001))(x_transform_enc)
    z_logvar = Dense(
        latent_size,
        name='latent_logvar',
        bias_initializer=keras_init.RandomNormal(mean=-2., stddev=1e-10),
        kernel_initializer=keras_init.RandomNormal(mean=0., stddev=1e-10),
    )(x_transform_enc)

    return Model(inputs=inputs, outputs=[z_mean, z_logvar], name=model_name)
Example #10
def transformer_selector_model(
    conditioning_input_shapes,
    conditioning_input_names=None,
    output_shape=None,
    model_name='CVAE_transformer',
    transform_latent_shape=(100, ),
    n_segs=64,
    latent_distribution='normal',
    transform_type=None,
    color_transform_type=None,
    enc_params=None,
    condition_on_image=True,
    n_concat_scales=3,
    transform_activation=None,
    clip_output_range=None,
    source_input_idx=None,
    mask_by_conditioning_input_idx=None,
):

    # collect conditioning inputs, and concatenate them into a stack
    if not isinstance(conditioning_input_shapes, list):
        conditioning_input_shapes = [conditioning_input_shapes]
    if conditioning_input_names is None:
        conditioning_input_names = [
            'cond_input_{}'.format(ii)
            for ii in range(len(conditioning_input_shapes))
        ]

    conditioning_inputs = []
    for ii, input_shape in enumerate(conditioning_input_shapes):
        conditioning_inputs.append(
            Input(input_shape, name=conditioning_input_names[ii]))

    conditioning_input_stack = Concatenate(name='concat_cond_inputs',
                                           axis=-1)(conditioning_inputs)
    conditioning_input_shape = tuple(
        conditioning_input_stack.get_shape().as_list()[1:])

    n_dims = len(conditioning_input_shape) - 1

    segs = conditioning_inputs[-1]

    if latent_distribution == 'normal':
        # we will always give z as a flattened vector
        z_input = Input((np.prod(transform_latent_shape), ), name='z_input')
        z = Dense(int(n_segs / 2), name='dense_z_1')(z_input)
        z = LeakyReLU(0.2)(z)
        z = Dense(n_segs, name='dense_z_2')(z)
        z = LeakyReLU(0.2)(z)
        z = Reshape((1, 1, n_segs), name='reshape_z')(z)
    else:
        # we will always give z as a flattened vector
        z_input = Input((np.prod(transform_latent_shape), ), name='z_input')
        # z = Dense(int(n_segs / 2), name='dense_z_1')(z_input)
        # z = LeakyReLU(0.2)(z)
        # z = Dense(n_segs, name='dense_z_2')(z)
        # z = LeakyReLU(0.2)(z)
        z = Reshape((1, 1, n_segs), name='reshape_z')(z_input)

    out_map = Multiply(name='mult_segs_z')([segs, z])
    out_map = Lambda(lambda x: tf.reduce_sum(x, axis=-1, keepdims=True),
                     name='lamda_sum_chans')(out_map)
    return Model(inputs=conditioning_inputs + [z_input],
                 outputs=[out_map],
                 name=model_name)
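
Multiplying the segmentation channels by z and summing over the channel axis assigns each segment the scalar stored at its index in z. A standalone check of that broadcast with toy values (assuming TensorFlow 2.x eager execution):

import tensorflow as tf

n_segs = 3

# Toy one-hot segmentation map, shaped [batch, H, W, n_segs] = [1, 1, 2, 3]:
# pixel 0 belongs to segment 0, pixel 1 to segment 2.
segs = tf.constant([[[[1., 0., 0.],
                      [0., 0., 1.]]]])

# One scalar per segment, broadcast over the spatial dimensions.
z = tf.reshape(tf.constant([10., 20., 30.]), (1, 1, 1, n_segs))

out_map = tf.reduce_sum(segs * z, axis=-1, keepdims=True)
print(tf.squeeze(out_map).numpy())  # [10. 30.] -- each pixel picks up its segment's value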
Example #11
def transform_categorical_encoder_model(
    input_shapes,
    input_names=None,
    latent_shape=(50, ),
    model_name='VTE_transform_encoder',
    enc_params=None,
):
    '''
    Generic encoder for a stack of inputs

    :param input_shapes:
    :param latent_shape:
    :param model_name:
    :param enc_params:
    :return:
    '''
    if not isinstance(input_shapes, list):
        input_shapes = [input_shapes]

    if input_names is None:
        input_names = [
            'input_{}'.format(ii) for ii in range(len(input_shapes))
        ]

    inputs = []
    for ii, input_shape in enumerate(input_shapes):
        inputs.append(
            Input(input_shape, name='input_{}'.format(input_names[ii])))

    inputs_stacked = Concatenate(name='concat_inputs', axis=-1)(inputs)
    input_stack_shape = inputs_stacked.get_shape().as_list()[1:]
    n_dims = len(input_stack_shape) - 1

    x_transform_enc = basic_networks.encoder(
        x=inputs_stacked,
        img_shape=input_stack_shape,
        conv_chans=enc_params['nf_enc'],
        min_h=None,
        min_c=None,
        n_convs_per_stage=enc_params['n_convs_per_stage'],
        use_residuals=enc_params['use_residuals'],
        use_maxpool=enc_params['use_maxpool'],
        prefix='vte')

    latent_size = np.prod(latent_shape)

    if not enc_params['fully_conv']:
        # the last layer in the basic encoder will be a convolution, so we should activate after it
        x_transform_enc = LeakyReLU(0.2)(x_transform_enc)
        x_transform_enc = Flatten()(x_transform_enc)

        z = Dense(latent_size, name='dense_1')(x_transform_enc)
        z = LeakyReLU(0.2)(z)
        z = Dense(latent_size, name='dense_2')(z)
        z = LeakyReLU(0.2)(z)
        z_logits = Dense(
            latent_size,
            bias_initializer=keras_init.RandomNormal(mean=1., stddev=1e-10),
            name='dense_latent')(z)  # not yet softmaxed (normalized to [0, 1])
#        z_categorical = Activation('softmax', name='softmax_latent_categorical')(z)

#tau = 0.5
#z = Lambda(lambda x:gumbel_softmax(x, tau, hard=False))(z)

# z_mean = Dense(latent_size, name='latent_mean',
#     kernel_initializer=keras_init.RandomNormal(mean=0., stddev=0.00001))(x_transform_enc)
# z_logvar = Dense(latent_size, name='latent_logvar',
#                 bias_initializer=keras_init.RandomNormal(mean=-2., stddev=1e-10),
#                 kernel_initializer=keras_init.RandomNormal(mean=0., stddev=1e-10),
#             )(x_transform_enc)

    return Model(inputs=inputs, outputs=[z_logits], name=model_name)
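
The commented-out lines reference a gumbel_softmax helper that is not shown in this example. A minimal sketch of what such a helper usually looks like (an assumption for illustration, not the project's actual implementation):

import tensorflow as tf

def gumbel_softmax(logits, tau, hard=False, eps=1e-20):
    # Hypothetical sketch: sample from a temperature-relaxed categorical distribution.
    u = tf.random.uniform(tf.shape(logits), minval=0., maxval=1.)
    gumbel_noise = -tf.math.log(-tf.math.log(u + eps) + eps)
    y = tf.nn.softmax((logits + gumbel_noise) / tau, axis=-1)
    if hard:
        # Straight-through estimator: one-hot forward pass, soft gradients backward.
        y_hard = tf.one_hot(tf.argmax(y, axis=-1), tf.shape(logits)[-1])
        y = tf.stop_gradient(y_hard - y) + y
    return y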
Example #12
def transform_encoder_model(
    input_shapes,
    input_names=None,
    latent_shape=(50, ),
    model_name='VTE_transform_encoder',
    enc_params=None,
):
    '''
    Generic encoder for a stack of inputs

    :param input_shapes:
    :param latent_shape:
    :param model_name:
    :param enc_params:
    :return:
    '''
    if not isinstance(input_shapes, list):
        input_shapes = [input_shapes]

    if input_names is None:
        input_names = [
            'input_{}'.format(ii) for ii in range(len(input_shapes))
        ]

    inputs = []
    for ii, input_shape in enumerate(input_shapes):
        inputs.append(
            Input(input_shape, name='input_{}'.format(input_names[ii])))
    if len(inputs) > 1:
        inputs_stacked = Concatenate(name='concat_inputs', axis=-1)(inputs)
    else:
        inputs_stacked = inputs[0]
    input_stack_shape = inputs_stacked.get_shape().as_list()[1:]
    n_dims = len(input_stack_shape) - 1

    x_transform_enc = basic_networks.encoder(
        x=inputs_stacked,
        img_shape=input_stack_shape,
        conv_chans=enc_params['nf_enc'],
        min_h=None,
        min_c=None,
        n_convs_per_stage=enc_params['n_convs_per_stage']
        if 'n_convs_per_stage' in enc_params else 1,
        use_residuals=enc_params['use_residuals']
        if 'use_residuals' in enc_params else False,
        use_maxpool=enc_params['use_maxpool']
        if 'use_maxpool' in enc_params else False,
        kernel_initializer=enc_params['kernel_initializer']
        if 'kernel_initializer' in enc_params else None,
        bias_initializer=enc_params['bias_initializer']
        if 'bias_initializer' in enc_params else None,
        prefix='vte')

    latent_size = np.prod(latent_shape)

    if not enc_params['fully_conv']:
        # the last layer in the basic encoder will be a convolution, so we should activate after it
        x_transform_enc = LeakyReLU(0.2)(x_transform_enc)
        x_transform_enc = Flatten()(x_transform_enc)

        z_mean = Dense(latent_size,
                       name='latent_mean',
                       kernel_initializer=keras_init.RandomNormal(
                           mean=0., stddev=0.00001))(x_transform_enc)
        z_logvar = Dense(
            latent_size,
            name='latent_logvar',
            bias_initializer=keras_init.RandomNormal(mean=-2., stddev=1e-10),
            kernel_initializer=keras_init.RandomNormal(mean=0., stddev=1e-10),
        )(x_transform_enc)
    else:
        emb_shape = basic_networks.get_encoded_shape(
            input_stack_shape, conv_chans=enc_params['nf_enc'])
        n_chans = emb_shape[-1]

        if n_dims == 3:
            # convolve rather than Lambda since we want to set the initialization
            z_mean = Conv3D(latent_shape[-1],
                            kernel_size=2,
                            strides=2,
                            padding='same',
                            kernel_initializer=keras_init.RandomNormal(
                                mean=0., stddev=0.001))(x_transform_enc)
            z_mean = Flatten(name='latent_mean')(z_mean)

            z_logvar = Conv3D(
                latent_shape[-1],
                kernel_size=2,
                strides=2,
                padding='same',
                bias_initializer=keras_init.RandomNormal(mean=-2.,
                                                         stddev=1e-10),
                kernel_initializer=keras_init.RandomNormal(mean=0.,
                                                           stddev=1e-10),
            )(x_transform_enc)
            z_logvar = Flatten(name='latent_logvar')(z_logvar)
        else:
            # TODO: also convolve to latent mean and logvar for 2D?
            # use integer division so the slice index and output shape stay ints (Python 3)
            z_mean = Lambda(lambda x: x[:, :, :, :n_chans // 2],
                            output_shape=emb_shape[:-1] +
                            (n_chans // 2, ))(x_transform_enc)
            z_mean = Flatten(name='latent_mean')(z_mean)

            z_logvar = Lambda(lambda x: x[:, :, :, n_chans // 2:],
                              output_shape=emb_shape[:-1] +
                              (n_chans // 2, ))(x_transform_enc)
            z_logvar = Flatten(name='latent_logvar')(z_logvar)

    return Model(inputs=inputs, outputs=[z_mean, z_logvar], name=model_name)
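
Both branches end in z_mean and z_logvar, which in a VAE are regularized towards a unit Gaussian. The corresponding KL term has the standard closed form; a minimal sketch (illustrative only, the project's own loss code is not shown here):

import tensorflow as tf

def kl_unit_gaussian(z_mean, z_logvar):
    # KL( N(mean, exp(logvar)) || N(0, I) ), summed over the latent dimensions.
    return 0.5 * tf.reduce_sum(
        tf.exp(z_logvar) + tf.square(z_mean) - 1. - z_logvar, axis=-1)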
Example #13
def transform_encoder_model(input_shapes, input_names=None,
                            latent_shape=(50,),
                            model_name='VTE_transform_encoder',
                            enc_params=None,
                            ):
    '''
    Generic encoder for a stack of inputs

    :param input_shapes:
    :param latent_shape:
    :param model_name:
    :param enc_params:
    :return:
    '''
    if not isinstance(input_shapes, list):
        input_shapes = [input_shapes]

    if input_names is None:
        input_names = ['input_{}'.format(ii) for ii in range(len(input_shapes))]

    inputs = []
    for ii, input_shape in enumerate(input_shapes):
        inputs.append(Input(input_shape, name='input_{}'.format(input_names[ii])))

    inputs_stacked = Concatenate(name='concat_inputs', axis=-1)(inputs)
    input_stack_shape = inputs_stacked.get_shape().as_list()[1:]
    n_dims = len(input_stack_shape) - 1

    x_transform_enc = basic_networks.encoder(
        x=inputs_stacked,
        img_shape=input_stack_shape,
        conv_chans=enc_params['nf_enc'],
        min_h=None, min_c=None,
        n_convs_per_stage=enc_params['n_convs_per_stage'],
        use_residuals=enc_params['use_residuals'],
        use_maxpool=enc_params['use_maxpool'],
        prefix='vte'
    )

    latent_size = np.prod(latent_shape)

    if not enc_params['fully_conv']:
        # the last layer in the basic encoder will be a convolution, so we should activate after it
        x_transform_enc = LeakyReLU(0.2)(x_transform_enc)
        x_transform_enc = Flatten()(x_transform_enc)

        z_mean = Dense(latent_size, name='latent_mean',
            kernel_initializer=keras_init.RandomNormal(mean=0., stddev=0.00001))(x_transform_enc)
        z_logvar = Dense(latent_size, name='latent_logvar',
                        bias_initializer=keras_init.RandomNormal(mean=-2., stddev=1e-10),
                        kernel_initializer=keras_init.RandomNormal(mean=0., stddev=1e-10),
                    )(x_transform_enc)
    else:
        emb_shape = basic_networks.get_encoded_shape(input_stack_shape, conv_chans=enc_params['nf_enc'])
        n_chans = emb_shape[-1]

        if n_dims == 3:
            # convolve rather than Lambda since we want to set the initialization
            z_mean = Conv3D(latent_shape[-1], kernel_size=2, strides=2, padding='same',
                            kernel_initializer=keras_init.RandomNormal(mean=0., stddev=0.001))(x_transform_enc)
            z_mean = Flatten(name='latent_mean')(z_mean)

            z_logvar = Conv3D(latent_shape[-1], kernel_size=2,
                              strides=2, padding='same',
                              bias_initializer=keras_init.RandomNormal(mean=-2., stddev=1e-10),
                              kernel_initializer=keras_init.RandomNormal(mean=0., stddev=1e-10),
                              )(x_transform_enc)
            z_logvar = Flatten(name='latent_logvar')(z_logvar)
        else:
            # TODO: also convolve to latent mean and logvar for 2D?
            # use integer division so the slice index and output shape stay ints (Python 3)
            z_mean = Lambda(lambda x: x[:, :, :, :n_chans // 2],
                            output_shape=emb_shape[:-1] + (n_chans // 2,))(x_transform_enc)
            z_mean = Flatten(name='latent_mean')(z_mean)

            z_logvar = Lambda(lambda x: x[:, :, :, n_chans // 2:],
                              output_shape=emb_shape[:-1] + (n_chans // 2,))(x_transform_enc)
            z_logvar = Flatten(name='latent_logvar')(z_logvar)

    return Model(inputs=inputs, outputs=[z_mean, z_logvar], name=model_name)
Example #14
def transformer_model(conditioning_input_shapes, conditioning_input_names=None,
                      output_shape=None,
                      model_name='CVAE_transformer',
                      transform_latent_shape=(100,),
                      transform_type=None,
                      color_transform_type=None,
                      enc_params=None,
                      condition_on_image=True,
                      n_concat_scales=3,
                      transform_activation=None, clip_output_range=None,
                      source_input_idx=None,
                      mask_by_conditioning_input_idx=None,
                      ):

    # collect conditioning inputs, and concatenate them into a stack
    if not isinstance(conditioning_input_shapes, list):
        conditioning_input_shapes = [conditioning_input_shapes]
    if conditioning_input_names is None:
        conditioning_input_names = ['cond_input_{}'.format(ii) for ii in range(len(conditioning_input_shapes))]

    conditioning_inputs = []
    for ii, input_shape in enumerate(conditioning_input_shapes):
        conditioning_inputs.append(Input(input_shape, name=conditioning_input_names[ii]))

    conditioning_input_stack = Concatenate(name='concat_cond_inputs', axis=-1)(conditioning_inputs)
    conditioning_input_shape = tuple(conditioning_input_stack.get_shape().as_list()[1:])

    n_dims = len(conditioning_input_shape) - 1

    # we will always give z as a flattened vector
    z_input = Input((np.prod(transform_latent_shape),), name='z_input')

    # determine what we should apply the transformation to
    if source_input_idx is None:
        # the image we want to transform is exactly the input of the conditioning branch
        x_source = conditioning_input_stack
        source_input_shape = conditioning_input_shape
    else:
        # slice conditioning input to get the single source im that we will apply the transform to
        source_input_shape = conditioning_input_shapes[source_input_idx]
        x_source = conditioning_inputs[source_input_idx]

    # assume the output is going to be the transformed source input, so it should be the same shape
    if output_shape is None:
        output_shape = source_input_shape

    if mask_by_conditioning_input_idx is None:
        x_mask = None
    else:
        print('Masking output by input {} with name {}'.format(
            mask_by_conditioning_input_idx,
            conditioning_input_names[mask_by_conditioning_input_idx]
        ))
        mask_shape = conditioning_input_shapes[mask_by_conditioning_input_idx]
        x_mask = conditioning_inputs[mask_by_conditioning_input_idx]

    if transform_type == 'flow':
        layer_prefix = 'flow'
    elif transform_type == 'color':
        layer_prefix = 'color'
    else:
        layer_prefix = 'synth'

    if condition_on_image: # assume we always condition by concat since it's better than other forms
        # simply concatenate the conditioning stack (at various scales) with the decoder volumes
        include_fullres = True

        concat_decoder_outputs_with = [None] * len(enc_params['nf_dec'])
        concat_skip_sizes = [None] * len(enc_params['nf_dec'])

        # make sure x_I is the same shape as the output, including in the channels dimension
        if not np.all(output_shape <= conditioning_input_shape):
            tile_factor = [int(round(output_shape[i] / conditioning_input_shape[i])) for i in
                           range(len(output_shape))]
            print('Tile factor: {}'.format(tile_factor))
            conditioning_input_stack = Lambda(lambda x: tf.tile(x, [1] + tile_factor), name='lambda_tile_cond_input')(conditioning_input_stack)

        # downscale the conditioning inputs by the specified number of times
        xs_downscaled = [conditioning_input_stack]
        for si in range(n_concat_scales):
            curr_x_scaled = network_layers.Blur_Downsample(
                n_chans=conditioning_input_shape[-1], n_dims=n_dims,
                do_blur=True,
                name='downsample_scale-1/{}'.format(2**(si + 1))
            )(xs_downscaled[-1])
            xs_downscaled.append(curr_x_scaled)

        if not include_fullres:
            xs_downscaled = xs_downscaled[1:]  # exclude the full-res volume

        print('Including downsampled input sizes {}'.format([x.get_shape().as_list() for x in xs_downscaled]))

        concat_decoder_outputs_with[:len(xs_downscaled)] = list(reversed(xs_downscaled))
        concat_skip_sizes[:len(xs_downscaled)] = list(reversed(
            [np.asarray(x.get_shape().as_list()[1:-1]) for x in xs_downscaled if
             x is not None]))

    else:
        # just ignore the conditioning input
        concat_decoder_outputs_with = None
        concat_skip_sizes = None


    if 'ks' not in enc_params:
        enc_params['ks'] = 3

    if not enc_params['fully_conv']:
        # determine what size to reshape the latent vector to
        reshape_encoding_to = basic_networks.get_encoded_shape(
            img_shape=conditioning_input_shape,
            conv_chans=enc_params['nf_enc'],
        )

        if np.all(reshape_encoding_to[:n_dims] > concat_skip_sizes[-1][:n_dims]):
            raise RuntimeWarning(
                'Attempting to concatenate reshaped latent vector of shape {} with downsampled input of shape {}!'.format(
                    reshape_encoding_to,
                    concat_skip_sizes[-1]
                ))

        x_enc = Dense(np.prod(reshape_encoding_to))(z_input)
    else:
        # latent representation is already in correct shape
        reshape_encoding_to = transform_latent_shape
        x_enc = z_input

    x_enc = Reshape(reshape_encoding_to)(x_enc)


    print('Decoder starting shape: {}'.format(reshape_encoding_to))

    x_transformation = basic_networks.decoder(
        x_enc, output_shape,
        encoded_shape=reshape_encoding_to,
        prefix='{}_dec'.format(layer_prefix),
        conv_chans=enc_params['nf_dec'], ks=enc_params['ks'],
        n_convs_per_stage=enc_params['n_convs_per_stage'],
        use_upsample=enc_params['use_upsample'],
        include_skips=concat_decoder_outputs_with,
        target_vol_sizes=concat_skip_sizes
    )

    if transform_activation is not None:
        x_transformation = Activation(
            transform_activation,
            name='activation_transform_{}'.format(transform_activation))(x_transformation)

        if transform_type == 'color' and 'delta' in color_transform_type and transform_activation == 'tanh':
            # TODO: maybe move this logic
            # if we are learning a color delta with a tanh, make sure to multiply it by 2
            x_transformation = Lambda(lambda x: x * 2, name='lambda_scale_tanh')(x_transformation)

    if mask_by_conditioning_input_idx is not None:
        x_transformation = Multiply(name='mult_mask_transformation')([x_transformation, x_mask])
    if transform_type is not None:
        im_out, transform_out = apply_transformation(x_source, x_transformation, 
            output_shape=source_input_shape, conditioning_input_shape=conditioning_input_shape, transform_name=transform_type,
            apply_flow_transform=transform_type=='flow',
            apply_color_transform=transform_type=='color',
            color_transform_type=color_transform_type
            )
    else:
        im_out = x_transformation

    if clip_output_range is not None:
        im_out = Lambda(lambda x: tf.clip_by_value(x, clip_output_range[0], clip_output_range[1]),
            name='lambda_clip_output_{}-{}'.format(clip_output_range[0], clip_output_range[1]))(im_out)

    if transform_type is not None:
        return Model(inputs=conditioning_inputs + [z_input], outputs=[im_out, transform_out], name=model_name)
    else:
        return Model(inputs=conditioning_inputs + [z_input], outputs=[im_out], name=model_name)
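
Most of the architecture above is driven by the enc_params dict. A hedged construction sketch follows; all shape and channel values are illustrative, it assumes the project's basic_networks, network_layers and apply_transformation helpers are importable, and compatibility of these exact numbers with the decoder is not guaranteed:

# Illustrative configuration only; the real project defines its own settings.
enc_params = {
    'nf_enc': [32, 64, 64, 64],       # encoder channels per stage
    'nf_dec': [64, 64, 32, 16],       # decoder channels per stage
    'n_convs_per_stage': 1,
    'use_residuals': False,
    'use_maxpool': False,
    'use_upsample': False,
    'fully_conv': False,
    'ks': 3,
}

model = transformer_model(
    conditioning_input_shapes=[(64, 64, 3), (64, 64, 1)],
    conditioning_input_names=['cond_image', 'cond_mask'],
    transform_latent_shape=(100,),
    transform_type='color',
    color_transform_type='delta',
    enc_params=enc_params,
    condition_on_image=True,
    n_concat_scales=3,
)
model.summary()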