Example #1
 def __init__(self, network_name, latent_dim, name='enc'):
     super(VAEEncoder, self).__init__(name=name)
     self.net = get_network_builder(network_name)()
     self.loc = tf.keras.layers.Dense(latent_dim, name='encoder_loc')
     self.scale = tf.keras.layers.Dense(latent_dim,
                                        name='encoder_scale_diag_raw')
     self.flatten = tf.keras.layers.Flatten()
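A minimal sketch of a matching call method for this encoder (not shown in the example; the softplus transform of the raw scale head and the use of TensorFlow Probability, imported as `tfp`, are assumptions):

 def call(self, x):
     # hypothetical forward pass for the VAEEncoder above
     h = self.flatten(self.net(x))
     loc = self.loc(h)
     # 'encoder_scale_diag_raw' suggests the raw output is transformed later;
     # softplus is an assumed choice to keep the scale positive
     scale_diag = tf.nn.softplus(self.scale(h)) + 1e-6
     return tfp.distributions.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)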
Example #2
def build_classifier(network, num_classes, **network_kwargs):
    if isinstance(network, str):
        from pyroclast.common.models import get_network_builder
        network = get_network_builder(network)(**network_kwargs)

    return snt.Sequential([
        lambda x: network(x, **network_kwargs),
        snt.BatchFlatten(),
        snt.Linear(num_classes)
    ])
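A hedged usage sketch for the builder above: the 'mnist_conv' network name is borrowed from Example #4, while the batch size and 28x28x1 input shape are assumptions about that registered network, and the repo's Sonnet/TensorFlow versions are presumed.

import tensorflow as tf

# hypothetical call; output logits have one column per class
classifier = build_classifier('mnist_conv', num_classes=10)
logits = classifier(tf.zeros([8, 28, 28, 1]))  # expected shape: [8, 10]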
Example #3
 def __init__(self, network_name, output_channels, name='dec'):
     super(VAEDecoder, self).__init__(name=name)
     self.net = get_network_builder(network_name)()
     self.loc = tf.keras.layers.Conv2D(output_channels,
                                       3,
                                       padding="same",
                                       name='decoder_loc')
     self.log_scale = tf.keras.layers.Conv2D(output_channels,
                                             3,
                                             padding="same",
                                             name='decoder_log_scale')
Example #4
 def setUp(self):
     self.args = dict()
     self.args['batch_size'] = 8
     self.args['data_limit'] = 24
     self.args['num_prototypes'] = 100
     self.args['prototype_dim'] = 64
     self.ds = setup_tfds('mnist', self.args['batch_size'], None,
                          self.args['data_limit'])
     conv_stack = get_network_builder('mnist_conv')()
     self.model = ProtoPNet(conv_stack, self.args['num_prototypes'],
                            self.args['prototype_dim'],
                            self.ds['num_classes'])
Example #5
def build_entity_model(network, **network_kwargs):
    if isinstance(network, str):
        from pyroclast.common.models import get_network_builder
        network = get_network_builder(network)(**network_kwargs)

    def entity_model_builder(data, concept):
        broadcast_concept = tf.broadcast_to(concept, data.shape)
        broadcast_concept = tf.dtypes.cast(broadcast_concept, tf.float32)
        network_input = tf.concat([data, broadcast_concept], axis=-1)
        return network(network_input)

    return entity_model_builder
Example #6
def build_attention_model(network, **network_kwargs):
    if isinstance(network, str):
        from pyroclast.common.models import get_network_builder
        network = get_network_builder(network)(**network_kwargs)

    def attention_model_builder(data, attention, concept):
        attended_data = tf.nn.sigmoid(attention) * data
        broadcast_concept = tf.broadcast_to(concept, attended_data.shape)
        broadcast_concept = tf.dtypes.cast(broadcast_concept, tf.float32)
        network_input = tf.concat([attended_data, broadcast_concept], axis=-1)
        return tf.square(network(network_input))

    return attention_model_builder
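A hedged sketch of calling the attention-model builder; the network name and tensor shapes are assumptions, and Example #5's entity-model builder is invoked the same way, minus the attention argument.

import tensorflow as tf

# hypothetical shapes; the scalar concept is broadcast to the data's shape inside the builder
attention_model = build_attention_model('mnist_conv')
data = tf.random.normal([8, 28, 28, 1])
attention = tf.random.normal([8, 28, 28, 1])  # pre-sigmoid attention logits
concept = tf.ones([])
scores = attention_model(data, attention, concept)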
Example #7
def build_encoder(network, latent_dim, **network_kwargs):
    if isinstance(network, str):
        from pyroclast.common.models import get_network_builder
        network = get_network_builder(network)(**network_kwargs)

    def encoder_builder(x, y):
        latent = network(x, **network_kwargs)
        latent = snt.BatchFlatten()(latent)
        latent = tf.concat(
            [tf.cast(latent, dtype=tf.float32),
             tf.cast(y, dtype=tf.float32)],
            axis=1)
        loc = snt.Linear(latent_dim)(latent)
        scale = snt.Linear(latent_dim)(latent)
        return loc, scale

    return encoder_builder
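A hedged usage sketch for the encoder builder; the network name, batch size, label width, and latent size are illustrative assumptions, under the repo's Sonnet/TensorFlow setup.

import tensorflow as tf

# hypothetical inputs: a batch of images and one-hot labels
encode = build_encoder('mnist_conv', latent_dim=16)
x = tf.random.normal([8, 28, 28, 1])
y = tf.one_hot(tf.zeros([8], dtype=tf.int32), depth=10)
loc, scale = encode(x, y)  # each of shape [8, 16]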
Example #8
def build_decoder(network, num_channels=3, **network_kwargs):
    if isinstance(network, str):
        from pyroclast.common.models import get_network_builder
        network = get_network_builder(network)(**network_kwargs)

    def decoder_builder(y, z):
        # TODO: remove the linear layer
        # `n` is a local parameter for the size of the internal linear layer
        n = 32
        latent = snt.Linear(n)(tf.concat(
            [tf.cast(y, dtype=tf.float32),
             tf.cast(z, dtype=tf.float32)], 1))
        latent = tf.reshape(latent, [-1, 1, 1, n])
        output = network(latent, **network_kwargs)
        output = snt.Conv2D(num_channels, 1, padding=snt.SAME)(output)
        return output

    return decoder_builder
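A hedged usage sketch for the decoder builder; 'upsample_conv' is a placeholder, not an actual registered network name, standing in for whichever network upsamples the [-1, 1, 1, 32] latent back to image resolution, and the label width and latent size are assumptions.

import tensorflow as tf

# hypothetical inputs: one-hot labels and a latent sample
decode = build_decoder('upsample_conv', num_channels=1)
y = tf.one_hot(tf.zeros([8], dtype=tf.int32), depth=10)
z = tf.random.normal([8, 16])
image = decode(y, z)  # [8, H, W, 1], with H and W determined by the network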
Example #9
def learn(data_dict,
          seed,
          output_dir,
          debug,
          conv_stack='vgg19_conv',
          epochs_phase_1=10,
          epochs_phase_3=10,
          learning_rate=3e-3,
          cluster_coeff=0.8,
          l1_coeff=1e-4,
          separation_coeff=0.08,
          clip_norm=None,
          num_prototypes=20,
          prototype_dim=128,
          is_class_specific=False,
          delay_conv_stack_training=False):
    writer = tf.summary.create_file_writer(output_dir)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                         beta_1=0.5,
                                         epsilon=0.01)

    # Using VGG19 here necessitates a 3 channel image input
    conv_stack = get_network_builder(conv_stack)()
    model = ProtoPNet(conv_stack,
                      num_prototypes,
                      prototype_dim,
                      data_dict['num_classes'],
                      class_specific=is_class_specific)

    # define minibatch fn
    def run_minibatch(epoch, batch, phase, is_train=True):
        """
        Args:
            epoch (int): Epoch of training for logging
            batch (dict): dict from dataset
            phase (int): Value in {1, 3}; determines which objective terms and variables are used in training updates.
            is_train (bool): Optional, run backwards pass if True
        """
        x = tf.cast(batch['image'], tf.float32) / 255.
        labels = tf.cast(batch['label'], tf.int32)

        with tf.GradientTape() if is_train else dummy_context_mgr() as tape:
            global_step.assign_add(1)
            y_hat, minimum_distances, _ = model(x)
            classification_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=y_hat)
            loss_term_dict = model.conv_prototype_objective(minimum_distances,
                                                            label=labels)

            # build loss
            assert phase in [1, 3]
            loss = classification_loss
            if 'cluster' in loss_term_dict and phase == 1:
                loss += cluster_coeff * loss_term_dict['cluster']
            if 'l1' in loss_term_dict:
                loss += l1_coeff * loss_term_dict['l1']
            if 'separation' in loss_term_dict and phase == 1:
                loss += separation_coeff * loss_term_dict['separation']

            loss = tf.reduce_mean(loss)

        # calculate gradients for current loss
        if is_train:
            # choose which variables are being trained
            if phase == 1:
                # train the prototypes, and the conv stack unless its training is delayed
                train_vars = model.trainable_prototype_vars
                if not delay_conv_stack_training or epoch >= 5:
                    train_vars = train_vars + model.trainable_conv_stack_vars
            elif phase == 3:
                train_vars = model.trainable_classifier_vars

            # calculate gradients and apply update
            gradients = tape.gradient(loss, train_vars)
            if clip_norm:
                clipped_gradients, pre_clip_global_norm = tf.clip_by_global_norm(
                    gradients, clip_norm)
            else:
                clipped_gradients = gradients
            optimizer.apply_gradients(zip(clipped_gradients, train_vars))

        # log to TensorBoard
        prefix = 'train_' if is_train else 'validate_'
        with writer.as_default():
            prediction = tf.math.argmax(y_hat, axis=1, output_type=tf.int32)
            classification_rate = tf.reduce_mean(
                tf.cast(tf.equal(prediction, labels), tf.float32))
            tf.summary.scalar(prefix + "classification_rate",
                              classification_rate,
                              step=global_step)
            tf.summary.scalar(prefix + "loss/mean classification",
                              tf.reduce_mean(classification_loss),
                              step=global_step)
            if 'cluster' in loss_term_dict:
                tf.summary.scalar(prefix + "loss/mean cluster",
                                  cluster_coeff *
                                  tf.reduce_mean(loss_term_dict['cluster']),
                                  step=global_step)
            if 'l1' in loss_term_dict:
                tf.summary.scalar(prefix + "loss/mean l1",
                                  l1_coeff *
                                  tf.reduce_mean(loss_term_dict['l1']),
                                  step=global_step)
            if 'separation' in loss_term_dict:
                tf.summary.scalar(prefix + "loss/mean separation",
                                  separation_coeff *
                                  tf.reduce_mean(loss_term_dict['separation']),
                                  step=global_step)
            tf.summary.scalar(prefix + "loss/total loss",
                              loss,
                              step=global_step)

    ### PHASE 1
    # run training loop
    if debug: print("PHASE 1 - TRAINING CONV STACK AND PROTOTYPES")
    for epoch in range(epochs_phase_1):
        # train
        train_batches = data_dict['train']
        if debug:
            print("Epoch", epoch)
            print("TRAIN")
            train_batches = tqdm(train_batches, total=data_dict['train_bpe'])
        for batch in train_batches:
            run_minibatch(epoch, batch, phase=1, is_train=True)

        # test
        test_batches = data_dict['test']
        if debug:
            print("TEST")
            test_batches = tqdm(test_batches, total=data_dict['test_bpe'])
        for batch in test_batches:
            run_minibatch(epoch, batch, phase=1, is_train=False)

    ### PHASE 2
    if debug: print("PHASE 2 - PUSHING PROTOTYPES")

    def classification_rate(ds):
        numerator = 0.
        denominator = 0.
        for batch in ds:
            x = tf.cast(batch['image'], tf.float32) / 255.
            labels = tf.cast(batch['label'], tf.int64)
            y_hat, _, _ = model(x)
            numerator += tf.reduce_sum(
                tf.cast(tf.equal(labels, tf.argmax(y_hat, axis=1)), tf.float32))
            denominator += labels.shape[0]
        return numerator / denominator

    def calculate_patches(image):
        x = tf.cast(image, tf.float32) / 255.
        _, _, patches = model(x)
        return patches

    print("Classification rate before prototype push: ",
          classification_rate(data_dict['train']))
    # calculate the input to the prototype layer for each datum
    patches = list(data_dict['train'].map(
        lambda x: (calculate_patches(x['image']), x['label'])).apply(
            lambda ds: ds.unbatch()))

    # group the data into classes
    class_patches = [[] for i in range(data_dict['num_classes'])]
    for img, label in patches:
        class_patches[label.numpy()].append(img)
    class_patches = [np.stack(cp) for cp in class_patches]

    # set new prototype values
    new_prototypes = model.prototypes.numpy()
    for i, proto in enumerate(model.prototypes.numpy()):
        # get the associated class of the prototype
        label = np.argmax(model.prototype_class_identity[i])
        # calculate distance between prototype and all patches of the associated class
        distances = np.reshape(
            l2_convolution(class_patches[label], np.expand_dims(proto, 0)),
            [-1])
        # set new value
        new_prototypes[i] = np.reshape(class_patches[label],
                                       [-1])[np.argmin(distances)]
    model.prototypes = new_prototypes
    print("Classification rate after prototype push: ",
          classification_rate(data_dict['train']))

    ### PHASE 3
    if debug: print("PHASE 3 - TRAINING CLASSIFIER")
    # run training loop
    for epoch in range(epochs_phase_3):
        # train
        train_batches = data_dict['train']
        if debug:
            print("Epoch", epoch)
            print("TRAIN")
            train_batches = tqdm(train_batches, total=data_dict['train_bpe'])
        for batch in train_batches:
            run_minibatch(epoch, batch, phase=3, is_train=True)

        # test
        test_batches = data_dict['test']
        if debug:
            print("TEST")
            test_batches = tqdm(test_batches, total=data_dict['test_bpe'])
        for batch in test_batches:
            run_minibatch(epoch, batch, phase=3, is_train=False)
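A hedged driver sketch tying Examples #4 and #9 together; the argument values are illustrative only, and setup_tfds is assumed to be importable as in Example #4.

# data_dict must provide 'train', 'test', 'train_bpe', 'test_bpe', and 'num_classes'
data_dict = setup_tfds('mnist', 8, None, 24)
learn(data_dict,
      seed=0,
      output_dir='./protopnet_logs',
      debug=True,
      conv_stack='mnist_conv',  # MNIST images are single-channel, so avoid the VGG19 default
      epochs_phase_1=1,
      epochs_phase_3=1)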
Example #10
 def __init__(self, network_name, image_size, name='dec'):
     super(Decoder, self).__init__(name=name)
     self.net = get_network_builder(network_name)({'image_size': image_size})
     self.final = tf.keras.layers.Conv2D(3, 3, padding="same")