def CapsNet(input_shape, n_class, routings):
    """
    Assemble a capsule classifier together with its reconstruction decoder.

    The size of each output capsule is read from the module-level name
    `digitcap_dim`; it is not a parameter of this function.

    :param input_shape: data shape, 3d, [width, height, channels]
    :param n_class: number of classes
    :param routings: number of routing iterations
    :return: Three Keras Models `(train_model, eval_model, digit_caps_model)`.
        `train_model` maps `[image, label]` to `[out_caps, reconstruction]`,
        `eval_model` maps `image` to `[out_caps, reconstruction]`, and
        `digit_caps_model` maps `image` to the raw `digitcaps` output.
    """
    image_in = layers.Input(shape=input_shape)

    # Stage 1: a plain convolution in front of the capsule layers.
    conv_features = layers.Conv2D(
        filters=256,
        kernel_size=9,
        strides=1,
        padding='valid',
        activation='relu',
        name='conv1',
    )(image_in)

    # Stage 2: primary capsules built on top of the convolution output.
    primary = PrimaryCap(
        conv_features,
        dim_capsule=8,
        n_channels=32,
        kernel_size=9,
        strides=2,
        padding='valid',
    )

    # Stage 3: one capsule per class; this is where routing takes place.
    digit_capsules = CapsuleLayer(
        num_capsule=n_class,
        dim_capsule=digitcap_dim,
        routings=routings,
        name='digitcaps',
    )(primary)

    # Stage 4: swap every capsule for its length so the output lines up with
    # the label shape.
    class_lengths = Length(name='capsnet')(digit_capsules)

    # Decoder inputs. During training the true label selects the capsule; for
    # prediction the capsule with maximal length is used instead.
    label_in = layers.Input(shape=(n_class,))
    label_masked = Mask()([digit_capsules, label_in])
    length_masked = Mask()(digit_capsules)

    # One Sequential decoder instance, reused by both models below.
    decoder = models.Sequential(name='decoder')
    flat_capsule_size = digitcap_dim * n_class
    decoder.add(layers.Dense(512, activation='relu',
                             input_dim=flat_capsule_size))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction).
    train_recon = decoder(label_masked)
    train_model = models.Model([image_in, label_in],
                               [class_lengths, train_recon])

    eval_recon = decoder(length_masked)
    eval_model = models.Model(image_in, [class_lengths, eval_recon])

    digit_caps_model = models.Model(image_in, digit_capsules)

    return train_model, eval_model, digit_caps_model
def CapsNetR3(input_shape, n_class=2):
    """
    Build and compile an encoder/decoder network of convolutional capsules.

    The network downsamples through strided `ConvCapsuleLayer`s, upsamples
    through `DeconvCapsuleLayer`s, and concatenates matching encoder outputs
    back in along the capsule axis (skip connections). A small 1x1-convolution
    decoder reconstructs a single-channel image from the label-masked
    segmentation capsules.

    :param input_shape: shape of one input sample, without the batch axis.
        Must be a tuple, because it is sliced and concatenated with `(1,)` to
        form the label input shape.
    :param n_class: number of classes, forwarded to the `Length` output layer
    :return: a single compiled Keras Model mapping `[image, label_mask]` to
        `[out_seg, out_recon]`. It is compiled with the `rmsprop` optimizer,
        `binary_crossentropy` loss and the `mean_iou` metric.
    """
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=16, kernel_size=5, strides=1,
                          padding='same', activation='relu',
                          name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms.
    # as_list() returns plain ints (or None), so this works whether the
    # backend hands back Dimension objects or integers.
    _, H, W, C = conv1.get_shape().as_list()
    conv1_reshaped = layers.Reshape((H, W, 1, C))(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=2,
                                    num_atoms=16, strides=2, padding='same',
                                    routings=1,
                                    name='primarycaps')(conv1_reshaped)

    # Layer 2: Convolutional Capsules
    conv_cap_2_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=4,
                                    num_atoms=16, strides=1, padding='same',
                                    routings=3,
                                    name='conv_cap_2_1')(primary_caps)
    conv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4,
                                    num_atoms=32, strides=2, padding='same',
                                    routings=3,
                                    name='conv_cap_2_2')(conv_cap_2_1)

    # Layer 3: Convolutional Capsules
    conv_cap_3_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8,
                                    num_atoms=32, strides=1, padding='same',
                                    routings=3,
                                    name='conv_cap_3_1')(conv_cap_2_2)
    conv_cap_3_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=8,
                                    num_atoms=64, strides=2, padding='same',
                                    routings=3,
                                    name='conv_cap_3_2')(conv_cap_3_1)

    # Layer 4: Convolutional Capsule
    conv_cap_4_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8,
                                    num_atoms=32, strides=1, padding='same',
                                    routings=3,
                                    name='conv_cap_4_1')(conv_cap_3_2)

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=8,
                                        num_atoms=32, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_1_1')(conv_cap_4_1)

    # Skip connection
    up_1 = layers.Concatenate(axis=-2,
                              name='up_1')([deconv_cap_1_1, conv_cap_3_1])

    # Layer 1 Up: Convolutional Capsule on the merged features
    deconv_cap_1_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4,
                                      num_atoms=32, strides=1, padding='same',
                                      routings=3, name='deconv_cap_1_2')(up_1)

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=4,
                                        num_atoms=16, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_2_1')(deconv_cap_1_2)

    # Skip connection
    up_2 = layers.Concatenate(axis=-2,
                              name='up_2')([deconv_cap_2_1, conv_cap_2_1])

    # Layer 2 Up: Convolutional Capsule on the merged features
    deconv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4,
                                      num_atoms=16, strides=1, padding='same',
                                      routings=3, name='deconv_cap_2_2')(up_2)

    # Layer 3 Up: Deconvolutional Capsule
    deconv_cap_3_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=2,
                                        num_atoms=16, upsamp_type='deconv',
                                        scaling=2, padding='same', routings=3,
                                        name='deconv_cap_3_1')(deconv_cap_2_2)

    # Skip connection
    up_3 = layers.Concatenate(axis=-2,
                              name='up_3')([deconv_cap_3_1, conv1_reshaped])

    # Layer 4: Convolutional Capsule: 1x1
    seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16,
                                strides=1, padding='same', routings=3,
                                name='seg_caps')(up_3)

    # Auxiliary layer to replace each capsule with its length, just to match
    # the true label's shape.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)

    # Decoder network.
    _, H, W, _, A = seg_caps.get_shape().as_list()
    y = layers.Input(shape=input_shape[:-1] + (1, ))

    # The true label is used to mask the output of the capsule layer.
    masked_by_y = Mask()([seg_caps, y])

    # Reconstruction decoder: drop the capsule axis, then three 1x1 convs.
    recon_remove_dim = layers.Reshape((H, W, A))(masked_by_y)
    recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same',
                            kernel_initializer='he_normal',
                            activation='relu',
                            name='recon_1')(recon_remove_dim)
    recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same',
                            kernel_initializer='he_normal',
                            activation='relu', name='recon_2')(recon_1)
    out_recon = layers.Conv2D(filters=1, kernel_size=1, padding='same',
                              kernel_initializer='he_normal',
                              activation='sigmoid',
                              name='out_recon')(recon_2)

    # Only the training model is returned, so it is the only one built.
    train_model = models.Model(inputs=[x, y], outputs=[out_seg, out_recon])
    train_model.compile(optimizer='rmsprop',
                        loss='binary_crossentropy',
                        metrics=[mean_iou])
    return train_model
def DiagnosisCapsules(input_shape, n_class=2, k_size=5, output_atoms=16,
                      routings1=3, routings2=3):
    """
    A Capsule Network on Medical Image Diagnosis.

    :param input_shape: data shape; `input_shape[0]` and `input_shape[1]` are
        each integer-divided by 64 to size the reconstruction seed
    :param n_class: number of classes; a value of 2 is collapsed to a single
        (binary) output capsule
    :param k_size: kernel size for convolutional capsules
    :param output_atoms: number of atoms in D-Caps layer
    :param routings1: number of routing iterations when stride is 1
    :param routings2: number of routing iterations when stride is > 1
    :return: Three Keras Models `(train_model, eval_model, manipulate_model)`.
        `train_model` maps `[image, label]` to `[out_caps, out_recon]`,
        `eval_model` maps `image` to `[out_caps, out_recon]`, and
        `manipulate_model` maps `[image, label, noise]` to `out_recon`.
    """
    if n_class == 2:
        n_class = 1  # binary output

    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=16, kernel_size=k_size, strides=2,
                          padding='same', activation='relu',
                          name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms
    conv1_reshaped = ExpandDim(name='expand_dim')(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=k_size, num_capsule=2,
                                    num_atoms=16, strides=2, padding='same',
                                    routings=1,
                                    name='primarycaps')(conv1_reshaped)

    # Layer 2: Convolutional Capsules
    conv_cap_2_1 = ConvCapsuleLayer(kernel_size=k_size, num_capsule=4,
                                    num_atoms=16, strides=1, padding='same',
                                    routings=routings1,
                                    name='conv_cap_2_1')(primary_caps)
    conv_cap_2_2 = ConvCapsuleLayer(kernel_size=k_size, num_capsule=4,
                                    num_atoms=32, strides=2, padding='same',
                                    routings=routings2,
                                    name='conv_cap_2_2')(conv_cap_2_1)

    # Layer 3: Convolutional Capsules
    conv_cap_3_1 = ConvCapsuleLayer(kernel_size=k_size, num_capsule=8,
                                    num_atoms=32, strides=1, padding='same',
                                    routings=routings1,
                                    name='conv_cap_3_1')(conv_cap_2_2)
    conv_cap_3_2 = ConvCapsuleLayer(kernel_size=k_size, num_capsule=8,
                                    num_atoms=64, strides=2, padding='same',
                                    routings=routings2,
                                    name='conv_cap_3_2')(conv_cap_3_1)

    # Layer 4: Convolutional Capsules; the last one has one capsule per class
    conv_cap_4_1 = ConvCapsuleLayer(kernel_size=k_size, num_capsule=8,
                                    num_atoms=32, strides=1, padding='same',
                                    routings=routings1,
                                    name='conv_cap_4_1')(conv_cap_3_2)
    conv_cap_4_2 = ConvCapsuleLayer(kernel_size=k_size, num_capsule=n_class,
                                    num_atoms=output_atoms, strides=2,
                                    padding='same', routings=routings2,
                                    name='conv_cap_4_2')(conv_cap_4_1)

    if n_class > 1:
        # Perform GAP on each capsule type.
        in_shape = conv_cap_4_2.get_shape().as_list()
        slice_shape = in_shape[1:3] + in_shape[4:]
        class_caps_list = []
        for i in range(n_class):
            # Bind the loop index as a default argument: a bare closure over
            # `i` would pick up the final loop value if the Lambda is ever
            # invoked again after the loop has finished.
            one_class_capsule = layers.Lambda(
                lambda caps, idx=i: caps[:, :, :, idx, :],
                output_shape=slice_shape)(conv_cap_4_2)
            gap = layers.GlobalAveragePooling2D(
                name='gap_{}'.format(i))(one_class_capsule)
            # Put capsule dimension back for length and recon
            class_caps_list.append(
                ExpandDim(name='expand_gap_{}'.format(i))(gap))
        class_caps = layers.Concatenate(axis=-2,
                                        name='class_caps')(class_caps_list)
    else:
        # Remove capsule dim, perform GAP, put capsule dim back
        conv_cap_4_2_reshaped = RemoveDim(
            name='conv_cap_4_2_reshaped')(conv_cap_4_2)
        gap = layers.GlobalAveragePooling2D(name='gap')(conv_cap_4_2_reshaped)
        class_caps = ExpandDim(name='expand_gap')(gap)

    # Output layer which predicts classes
    out_caps = Length(num_classes=n_class, name='out_caps')(class_caps)

    # Decoder network.
    # as_list() returns plain ints (or None), so this works whether the
    # backend hands back Dimension objects or integers.
    _, C, A = class_caps.get_shape().as_list()
    y = layers.Input(shape=(n_class, ))

    # The true label is used to mask the output of capsule layer (training).
    masked_by_y = Mask()([class_caps, y])
    # Mask using the capsule with maximal length (prediction).
    masked = Mask()(class_caps)

    # Spatial size of the reconstruction seed. The Dense width must be the
    # product of the two already-floored sides so that it always matches the
    # Reshape target below. Without the grouping, `a // 64 * b // 64` is
    # evaluated left to right and can give a different number.
    recon_rows = input_shape[0] // (2**6)
    recon_cols = input_shape[1] // (2**6)

    def shared_reconstructor(mask_layer):
        # NOTE: every call instantiates new layer objects, so the three
        # models built below do not share reconstructor weights.
        recon_1 = layers.Dense(recon_rows * recon_cols,
                               kernel_initializer='he_normal',
                               activation='relu',
                               name='recon_1',
                               input_shape=(A, ))(mask_layer)
        recon_1a = layers.Reshape((recon_rows, recon_cols, 1),
                                  name='recon_1a')(recon_1)
        # Two stride-8 transposed convolutions upsample the seed by 64x.
        recon_2 = layers.Conv2DTranspose(filters=128, kernel_size=5,
                                         strides=(8, 8), padding='same',
                                         kernel_initializer='he_normal',
                                         activation='relu',
                                         name='recon_2')(recon_1a)
        recon_3 = layers.Conv2DTranspose(filters=64, kernel_size=5,
                                         strides=(8, 8), padding='same',
                                         kernel_initializer='he_normal',
                                         activation='relu',
                                         name='recon_3')(recon_2)
        out_recon = layers.Conv2D(filters=3, kernel_size=3, padding='same',
                                  kernel_initializer='he_normal',
                                  activation='tanh',
                                  name='out_recon')(recon_3)
        return out_recon

    # Models for training and evaluation (prediction)
    train_model = models.Model(
        inputs=[x, y],
        outputs=[out_caps, shared_reconstructor(masked_by_y)])
    eval_model = models.Model(
        inputs=x,
        outputs=[out_caps, shared_reconstructor(masked)])

    # Manipulate model: noise is added to the class capsules before masking.
    noise = layers.Input(shape=(C, A))
    noised_class_caps = layers.Add()([class_caps, noise])
    masked_noised_y = Mask()([noised_class_caps, y])
    manipulate_model = models.Model(
        inputs=[x, y, noise],
        outputs=shared_reconstructor(masked_noised_y))

    return train_model, eval_model, manipulate_model
def CapsNetBasic(input_shape, n_class=2):
    """
    Build a shallow capsule segmentation network with a reconstruction head.

    The network is one Conv2D layer, one primary `ConvCapsuleLayer` and a 1x1
    `ConvCapsuleLayer` that produces the segmentation capsules. A decoder of
    three 1x1 convolutions reconstructs a single-channel image from the
    masked segmentation capsules.

    :param input_shape: shape of one input sample, without the batch axis.
        Must be a tuple, because it is sliced and concatenated with `(1,)` to
        form the label input shape.
    :param n_class: number of classes, forwarded to the `Length` output layer
    :return: Three Keras Models `(train_model, eval_model, manipulate_model)`.
        `train_model` maps `[image, label_mask]` to `[out_seg, out_recon]`,
        `eval_model` maps `image` to `[out_seg, out_recon]`, and
        `manipulate_model` maps `[image, label_mask, noise]` to `out_recon`.
    """
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=256, kernel_size=5, strides=1,
                          padding='same', activation='relu',
                          name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms.
    # as_list() returns plain ints (or None), so this works whether the
    # backend hands back Dimension objects or integers.
    _, H, W, C = conv1.get_shape().as_list()
    conv1_reshaped = layers.Reshape((H, W, 1, C))(conv1)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=8,
                                    num_atoms=32, strides=1, padding='same',
                                    routings=1,
                                    name='primarycaps')(conv1_reshaped)

    # Layer 4: Convolutional Capsule: 1x1
    seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16,
                                strides=1, padding='same', routings=3,
                                name='seg_caps')(primary_caps)

    # Auxiliary layer to replace each capsule with its length, just to match
    # the true label's shape.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)

    # Decoder network.
    _, H, W, C, A = seg_caps.get_shape().as_list()
    y = layers.Input(shape=input_shape[:-1] + (1, ))

    # The true label is used to mask the output of capsule layer (training).
    masked_by_y = Mask()([seg_caps, y])
    # Mask using the capsule with maximal length (prediction).
    masked = Mask()(seg_caps)

    def shared_decoder(mask_layer):
        # NOTE: every call instantiates new layer objects, so the three
        # models built below do not share decoder weights.
        recon_remove_dim = layers.Reshape((H, W, A))(mask_layer)
        recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same',
                                kernel_initializer='he_normal',
                                activation='relu',
                                name='recon_1')(recon_remove_dim)
        recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same',
                                kernel_initializer='he_normal',
                                activation='relu', name='recon_2')(recon_1)
        out_recon = layers.Conv2D(filters=1, kernel_size=1, padding='same',
                                  kernel_initializer='he_normal',
                                  activation='sigmoid',
                                  name='out_recon')(recon_2)
        return out_recon

    # Models for training and evaluation (prediction)
    train_model = models.Model(
        inputs=[x, y],
        outputs=[out_seg, shared_decoder(masked_by_y)])
    eval_model = models.Model(
        inputs=x,
        outputs=[out_seg, shared_decoder(masked)])

    # Manipulate model: noise is added to the capsules before masking.
    noise = layers.Input(shape=(H, W, C, A))
    noised_seg_caps = layers.Add()([seg_caps, noise])
    masked_noised_y = Mask()([noised_seg_caps, y])
    manipulate_model = models.Model(
        inputs=[x, y, noise],
        outputs=shared_decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model
def CapsNet(input_shape, n_class=5, routings=3, noactiv=False):
    """
    Assemble a capsule classifier with a dense reconstruction decoder.

    :param input_shape: shape of one input sample, without the batch axis
    :param n_class: number of output capsules (and width of the label input)
    :param routings: number of routing iterations in the `malcaps` layer
    :param noactiv: when truthy, the `Length` output is the class output
        (named `out_mal`); otherwise a softmax `Activation` named `out_mal`
        is applied on top of a `Length` layer named `mal_mag`
    :return: Three Keras Models `(train_model, eval_model, manipulate_model)`.
        `train_model` maps `[image, label]` to `[out_mal, reconstruction]`,
        `eval_model` maps `image` to `[out_mal, reconstruction]`, and
        `manipulate_model` maps `[image, label, noise]` to
        `[out_mal, reconstruction]`.
    """
    # Atoms per class capsule; also fixes the decoder input width and the
    # shape of the noise input.
    class_atoms = 16

    image_in = layers.Input(shape=input_shape)

    # Stage 1: a plain convolution in front of the capsule layers.
    conv_features = layers.Conv2D(
        filters=256,
        kernel_size=9,
        strides=1,
        padding='valid',
        activation='relu',
        name='conv1',
    )(image_in)

    # Add a capsule axis: 1 capsule x [filters] atoms.
    conv_as_capsule = ExpandDim(name='expand_dim')(conv_features)

    # Stage 2: primary capsules.
    primary = ConvCapsuleLayer(
        kernel_size=9,
        num_capsule=32,
        num_atoms=8,
        strides=2,
        padding='same',
        routings=1,
        name='primary_caps',
    )(conv_as_capsule)

    # Stage 3: one capsule per class; this is where routing takes place.
    class_capsules = FullCapsuleLayer(
        num_capsule=n_class,
        num_atoms=class_atoms,
        routings=routings,
        name='malcaps',
    )(primary)

    # Stage 4: swap every capsule for its length so the output lines up with
    # the label shape, optionally followed by a softmax.
    length_name = 'out_mal' if noactiv else 'mal_mag'
    class_scores = Length(num_classes=n_class,
                          name=length_name)(class_capsules)
    if not noactiv:
        class_scores = layers.Activation('softmax',
                                         name='out_mal')(class_scores)

    # Decoder inputs. During training the true label selects the capsule; for
    # prediction the capsule with maximal length is used instead.
    label_in = layers.Input(shape=(n_class, ))
    label_masked = Mask(n_class)([class_capsules, label_in])
    length_masked = Mask(n_class)(class_capsules)

    # One Sequential decoder instance, reused by all three models below.
    decoder = models.Sequential(name='out_recon')
    decoder.add(layers.Dense(512, activation='relu',
                             input_dim=class_atoms * n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction).
    train_recon = decoder(label_masked)
    train_model = models.Model([image_in, label_in],
                               [class_scores, train_recon])

    eval_recon = decoder(length_masked)
    eval_model = models.Model(image_in, [class_scores, eval_recon])

    # Manipulate model: noise is added to the class capsules before masking.
    noise_in = layers.Input(shape=(n_class, class_atoms))
    perturbed_capsules = layers.Add()([class_capsules, noise_in])
    perturbed_masked = Mask(n_class)([perturbed_capsules, label_in])
    manipulate_recon = decoder(perturbed_masked)
    manipulate_model = models.Model([image_in, label_in, noise_in],
                                    [class_scores, manipulate_recon])

    return train_model, eval_model, manipulate_model
def CapsNetR3(input_shape, modalities=1, n_class=2):
    """
    Build an encoder/decoder network of convolutional capsules.

    The network downsamples through strided `ConvCapsuleLayer`s, upsamples
    through `DeconvCapsuleLayer`s, and concatenates matching encoder outputs
    back in along the capsule axis (skip connections). Capsule counts and
    atom counts are multiples of the local `capsules_base` / `atoms_base`
    constants. Several intermediate shapes are printed to stdout while the
    graph is built.

    :param input_shape: shape of one input sample, without the batch axis.
        Must be a tuple, because it is sliced and concatenated with `(1,)` to
        form the label input shape. `input_shape[0]` and `input_shape[1]` are
        also used as the spatial size in the decoder's Reshape.
    :param modalities: number of output channels of the reconstruction
    :param n_class: number of segmentation capsules; also forwarded to the
        `Length` output layer
    :return: Three Keras Models `(train_model, eval_model, manipulate_model)`.
        `train_model` maps `[image, label_mask]` to `[out_seg, out_recon]`,
        `eval_model` maps `image` to `[out_seg, out_recon]`, and
        `manipulate_model` maps `[image, label_mask, noise]` to `out_recon`.
    """
    capsules_base = 2
    filter_multiplier = 1
    atoms_base = 16 * filter_multiplier

    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=16 * filter_multiplier, kernel_size=5,
                          strides=1, padding='same', activation='relu',
                          name='conv1')(x)

    # Reshape layer to be 1 capsule x [filters] atoms.
    # as_list() returns plain ints (or None), so this works whether the
    # backend hands back Dimension objects or integers.
    _, H, W, C = conv1.get_shape().as_list()
    print(conv1.shape)
    conv1_reshaped = layers.Reshape((H, W, 1, C))(conv1)
    print(conv1_reshaped.shape)

    # Layer 1: Primary Capsule: Conv cap with routing 1
    primary_caps = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base,
                                    num_atoms=atoms_base, strides=2,
                                    padding='same', routings=1,
                                    name='primarycaps')(conv1_reshaped)

    # Layer 2: Convolutional Capsules
    conv_cap_2_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 2,
                                    num_atoms=atoms_base, strides=1,
                                    padding='same', routings=3,
                                    name='conv_cap_2_1')(primary_caps)
    conv_cap_2_2 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 2,
                                    num_atoms=atoms_base * 2, strides=2,
                                    padding='same', routings=3,
                                    name='conv_cap_2_2')(conv_cap_2_1)

    # Layer 3: Convolutional Capsules
    conv_cap_3_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 4,
                                    num_atoms=atoms_base * 2, strides=1,
                                    padding='same', routings=3,
                                    name='conv_cap_3_1')(conv_cap_2_2)
    conv_cap_3_2 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 4,
                                    num_atoms=atoms_base * 4, strides=2,
                                    padding='same', routings=3,
                                    name='conv_cap_3_2')(conv_cap_3_1)

    # Layer 4: Convolutional Capsule
    conv_cap_4_1 = ConvCapsuleLayer(kernel_size=5,
                                    num_capsule=capsules_base * 4,
                                    num_atoms=atoms_base * 2, strides=1,
                                    padding='same', routings=3,
                                    name='conv_cap_4_1')(conv_cap_3_2)

    # Layer 1 Up: Deconvolutional Capsule
    deconv_cap_1_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=capsules_base * 4,
                                        num_atoms=atoms_base * 2,
                                        upsamp_type='deconv', scaling=2,
                                        padding='same', routings=3,
                                        name='deconv_cap_1_1')(conv_cap_4_1)

    # Skip connection
    up_1 = layers.Concatenate(axis=-2,
                              name='up_1')([deconv_cap_1_1, conv_cap_3_1])

    # Layer 1 Up: Convolutional Capsule on the merged features
    deconv_cap_1_2 = ConvCapsuleLayer(kernel_size=5,
                                      num_capsule=capsules_base * 2,
                                      num_atoms=atoms_base * 2, strides=1,
                                      padding='same', routings=3,
                                      name='deconv_cap_1_2')(up_1)

    # Layer 2 Up: Deconvolutional Capsule
    deconv_cap_2_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=capsules_base * 2,
                                        num_atoms=atoms_base,
                                        upsamp_type='deconv', scaling=2,
                                        padding='same', routings=3,
                                        name='deconv_cap_2_1')(deconv_cap_1_2)

    # Skip connection
    up_2 = layers.Concatenate(axis=-2,
                              name='up_2')([deconv_cap_2_1, conv_cap_2_1])

    # Layer 2 Up: Convolutional Capsule on the merged features
    deconv_cap_2_2 = ConvCapsuleLayer(kernel_size=5,
                                      num_capsule=capsules_base * 2,
                                      num_atoms=atoms_base, strides=1,
                                      padding='same', routings=3,
                                      name='deconv_cap_2_2')(up_2)

    # Layer 3 Up: Deconvolutional Capsule
    deconv_cap_3_1 = DeconvCapsuleLayer(kernel_size=4,
                                        num_capsule=capsules_base * 2,
                                        num_atoms=atoms_base,
                                        upsamp_type='deconv', scaling=2,
                                        padding='same', routings=3,
                                        name='deconv_cap_3_1')(deconv_cap_2_2)

    # Skip connection
    up_3 = layers.Concatenate(axis=-2,
                              name='up_3')([deconv_cap_3_1, conv1_reshaped])

    # Segmentation capsules: 1x1 convolutional capsule, one per class.
    seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=n_class,
                                num_atoms=atoms_base, strides=1,
                                padding='same', routings=3,
                                name='seg_caps')(up_3)

    # Auxiliary layer to replace each capsule with its length, just to match
    # the true label's shape.
    out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps)
    print(out_seg.shape)

    # Decoder network.
    _, H, W, C, A = seg_caps.get_shape().as_list()
    print(H, W, C, A)
    y = layers.Input(shape=input_shape[:-1] + (1, ), name='recon_input')

    # The true label is used to mask the output of capsule layer (training).
    masked_by_y = Mask()([seg_caps, y])
    print('masked by y ' + str(masked_by_y.shape))
    print('Y ' + str(y.shape))
    # Mask using the capsule with maximal length (prediction).
    masked = Mask()(seg_caps)

    def shared_decoder(mask_layer):
        # NOTE: every call instantiates new layer objects, so the three
        # models built below do not share decoder weights.
        recon_remove_dim = layers.Reshape(
            (input_shape[0], input_shape[1],
             n_class * atoms_base))(mask_layer)
        recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same',
                                kernel_initializer='he_normal',
                                activation='relu',
                                name='recon_1')(recon_remove_dim)
        recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same',
                                kernel_initializer='he_normal',
                                activation='relu', name='recon_2')(recon_1)
        out_recon = layers.Conv2D(filters=modalities, kernel_size=1,
                                  padding='same',
                                  kernel_initializer='he_normal',
                                  activation='sigmoid',
                                  name='out_recon')(recon_2)
        return out_recon

    # Models for training and evaluation (prediction)
    train_outputs = [out_seg, shared_decoder(masked_by_y)]
    print("Train outputs")
    print(train_outputs[0].shape)
    print(train_outputs[1].shape)
    train_model = models.Model(inputs=[x, y], outputs=train_outputs)

    # TODO: Check masked by y for testing!
    eval_model = models.Model(inputs=x,
                              outputs=[out_seg, shared_decoder(masked)])

    # Manipulate model: noise is added to the capsules before masking.
    noise = layers.Input(shape=(H, W, C, A))
    noised_seg_caps = layers.Add()([seg_caps, noise])
    masked_noised_y = Mask()([noised_seg_caps, y])
    manipulate_model = models.Model(
        inputs=[x, y, noise],
        outputs=shared_decoder(masked_noised_y))

    return train_model, eval_model, manipulate_model