Example #1
    def load_shots(self,
                   shot_list,
                   is_inference=False,
                   as_list=False,
                   num_samples=np.inf):
        X = []
        Y = []
        Disr = []
        print("loading...")
        pbar = Progbar(len(shot_list))

        sample_prob_d, sample_prob_nd = self.get_sample_probs(
            shot_list, num_samples)
        fn = partial(self.load_shot,
                     is_inference=is_inference,
                     sample_prob_d=sample_prob_d,
                     sample_prob_nd=sample_prob_nd)
        pool = mp.Pool()
        print('loading data in parallel on {} processes'.format(
            pool._processes))
        for x, y, disr in pool.imap(fn, shot_list):
            X.append(x)
            Y.append(y)
            Disr.append(disr)
            pbar.add(1.0)
        pool.close()
        pool.join()
        return X, Y, np.array(Disr)
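The load_shots pattern above (a worker bound with functools.partial, fanned out over multiprocessing.Pool.imap, with the Progbar ticking once per completed item) is reusable on its own. A minimal self-contained sketch, with a hypothetical square() worker standing in for load_shot:

import multiprocessing as mp
from functools import partial

import numpy as np
from tensorflow.keras.utils import Progbar


def square(x, offset=0.0):
    # Hypothetical stand-in for the per-shot worker.
    return (x + offset) ** 2


if __name__ == '__main__':
    items = list(range(100))
    fn = partial(square, offset=0.5)

    pbar = Progbar(len(items))
    results = []
    with mp.Pool() as pool:
        # imap yields results in input order as workers finish.
        for r in pool.imap(fn, items):
            results.append(r)
            pbar.add(1)

    print(np.mean(results))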
Example #2
def make_predictions(conf, shot_list, loader, custom_path=None):
    feature_extractor = FeatureExtractor(loader)
    # save_prepath = feature_extractor.get_save_prepath()
    if custom_path is None:
        model_path = conf['paths']['model_save_path'] + \
            model_filename  # save_prepath + model_filename
    else:
        model_path = custom_path
    model = joblib.load(model_path)
    # shot_list = shot_list.random_sublist(10)

    y_prime = []
    y_gold = []
    disruptive = []

    pbar = Progbar(len(shot_list))
    fn = partial(predict_single_shot,
                 model=model,
                 feature_extractor=feature_extractor)
    pool = mp.Pool()
    print('predicting in parallel on {} processes'.format(pool._processes))
    # for (y_p, y, disr) in map(fn, shot_list):
    for (y_p, y, disr) in pool.imap(fn, shot_list):
        # y_p, y, disr = predict_single_shot(model, feature_extractor,shot)
        y_prime += [np.expand_dims(y_p, axis=1)]
        y_gold += [np.expand_dims(y, axis=1)]
        disruptive += [disr]
        pbar.add(1.0)

    pool.close()
    pool.join()
    return y_prime, y_gold, disruptive
Example #3
    def transform(self, texts, verbose=False):

        if type(texts) is str:
            texts = [texts]

        texts = list(map(self._preprocessor, texts))
        n_samples = len(texts)

        blank_idx = []
        for i, text in enumerate(texts):
            if len(text) == 0:
                texts[i] = self._space_escape
                blank_idx.append(i)

        bar = Progbar(n_samples)

        mats = []
        for bi, text_batch in enumerate(batch(texts, self.batch_size)):

            self._data_container.set(text_batch)
            features = next(self._predict_fn)['output']
            mats.append(features)

            if verbose:
                bar.add(len(text_batch))

        mat = np.vstack(mats)
        if len(blank_idx):
            blank_idx = np.array(blank_idx)
            mat[blank_idx] = 0.0

        return mat
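Both this example and Example #8 below rely on a batch helper that is not shown in the snippets. A minimal sketch of what such a chunking generator typically looks like (an assumption about the missing helper, not its actual implementation):

def batch(iterable, batch_size):
    """Yield successive chunks of `iterable` with at most `batch_size` items."""
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) == batch_size:
            yield buf
            buf = []
    if buf:
        yield buf


# list(batch(['a', 'b', 'c', 'd', 'e'], 2)) -> [['a', 'b'], ['c', 'd'], ['e']]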
Example #4
    def prepare(self):
        self._create_own_dir()
        unzip_dir = self.unzip_dir
        image_names = os.listdir(unzip_dir)
        labels_dict = dict()

        face_detector = FaceDetector()
        face_aligner = FaceAligner(padding=0.1)
        progbar = Progbar(target=len(image_names))

        for image_name in image_names:
            progbar.add(1)
            age = image_name.split('_')[3]
            labels_dict[image_name] = int(age)

            image = read_image_like_rgb(os.path.join(unzip_dir, image_name))
            image = img_as_ubyte(exposure.equalize_adapthist(image))
            face_bboxes = face_detector.safe_detect_face_bboxes(image, include_cnn=False).clip(min=0)

            if face_bboxes.shape[0] == 0:
                cropped_image = self._crop_center(image, image.shape[0]//2, image.shape[1]//2)

            else:
                cropped_image = face_aligner.align_and_crop(image, bboxes=face_bboxes, bbox_number=0)

            image_path = os.path.join(self.data_dir, self.dataset_name, 'images', image_name)
            imsave(image_path, cropped_image)

        labels_path = os.path.join(self.data_dir, self.dataset_name, 'labels_dict.npy')
        np.save(labels_path, labels_dict)
Example #5
    def prepare(self):
        self._create_own_dir()
        unzip_dir = os.path.join(self.unzip_dir, 'wiki')
        mat_file = loadmat(os.path.join(unzip_dir, 'wiki.mat'))
        labels_dict = self._parse_mat(mat_file)

        face_detector = FaceDetector()
        face_aligner = FaceAligner(padding=0.1)
        progbar = Progbar(target=len(labels_dict))

        for image_subpath in list(labels_dict.keys()):
            progbar.add(1)
            image = read_image_like_rgb(os.path.join(unzip_dir, image_subpath))
            image = img_as_ubyte(exposure.equalize_adapthist(image))
            face_bboxes = face_detector.safe_detect_face_bboxes(image, include_cnn=False).clip(min=0)

            if face_bboxes.shape[0] == 0:
                continue

            else:
                cropped_image = face_aligner.align_and_crop(image, bboxes=face_bboxes, bbox_number=0)
                image_name = image_subpath.split('/')[1]
                image_path = os.path.join(self.data_dir, self.dataset_name, 'images', image_name)
                imsave(image_path, cropped_image)

        labels_path = os.path.join(self.data_dir, self.dataset_name, 'labels_dict.npy')
        labels_dict = {key.split('/')[1]: value for key, value in labels_dict.items()}
        np.save(labels_path, labels_dict)
Example #6
    def predict_depth_generator(self, data_iterator, depth, steps, progbar: Progbar = None):
        """
        Args:
            data_iterator: should provide data in the form of a dict containing keys
                IterativeARTResNet.imgs_input_name and IterativeARTResNet.sinos_input_name
        """
        if progbar is not None:
            progbar.add(1)

        if depth == 0:
            return data_iterator

        new_actual = []
        new_sino = []
        new_good_reco = []

        # Lots of optimisation needed. Outputting sinograms and good_reconstructions could be optimized.
        nr_steps = steps or len(data_iterator)
        progbar_sublevel = Progbar(target=nr_steps, verbose=1)
        for i in range(nr_steps):
            data_batch = next(data_iterator)
            reconstructions_output, bad_sinograms, good_reconstructions = \
                self._predict_depth_generator_step(data_batch)
            new_actual.append(reconstructions_output.numpy())
            new_sino.append(bad_sinograms.numpy())
            new_good_reco.append(good_reconstructions.numpy())

            progbar_sublevel.update(i + 1)

        new_data_iterator = RecSinoArrayIterator(new_actual, new_sino, new_good_reco)
        return self.predict_depth_generator(new_data_iterator, depth - 1, steps=None, progbar=progbar)
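This example mixes the two Progbar APIs: add() advances the outer bar by an increment, while update() sets the sub-level bar to an absolute position. A short sketch of the difference:

from tensorflow.keras.utils import Progbar

n = 5

bar_a = Progbar(n)
for _ in range(n):
    bar_a.add(1)          # relative: advance by 1 each step

bar_b = Progbar(n)
for i in range(n):
    bar_b.update(i + 1)   # absolute: set current position to i + 1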
Example #7
    def pre_fit(self, batches, epochs=100):
        """Pre-trains the model.

        Args:
            batches (Dataset): Pre-training batches containing samples.
            epochs (int): The maximum number of pre-training epochs.

        """

        logger.info('Pre-fitting generator ...')

        # Gathering the amount of batches
        n_batches = tf.data.experimental.cardinality(batches).numpy()

        # Iterate through all generator epochs
        for e in range(epochs):
            logger.info('Epoch %d/%d', e + 1, epochs)

            # Resetting state to further append losses
            self.G_loss.reset_states()

            # Defining a customized progress bar
            b = Progbar(n_batches, stateful_metrics=['loss(G)'])

            # Iterate through all possible pre-training batches
            for x_batch, y_batch in batches:
                # Performs the optimization step over the generator
                self.G_pre_step(x_batch, y_batch)

                # Adding corresponding values to the progress bar
                b.add(1, values=[('loss(G)', self.G_loss.result())])

            logger.file('Loss(G): %s', self.G_loss.result().numpy())
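Passing stateful_metrics tells Progbar to display those values as reported (here the running mean held by self.G_loss) rather than averaging them itself. A minimal sketch of that usage, with a hypothetical running-mean value standing in for the metric object:

from tensorflow.keras.utils import Progbar

steps = 10
bar = Progbar(steps, stateful_metrics=['loss(G)'])
for step in range(steps):
    running_mean_loss = 1.0 / (step + 1)  # stand-in for self.G_loss.result()
    bar.add(1, values=[('loss(G)', running_mean_loss)])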
Example #8
def vectorize(text, verbose=False):
    x = []
    bar = Progbar(len(text))
    for text_batch in batch(text, batch_size):
        container.set(text_batch)
        x.append(next(predict_fn)['output'])
        if verbose:
            bar.add(len(text_batch))
    r = np.vstack(x)
    return r
Example #9
def lipschitz_lb(f, X1, X2, iterations=1000, verbose=True):

    optimizer = Adam(lr=0.0001)

    X1 = tf.Variable(X1, name='x1', dtype='float32')
    X2 = tf.Variable(X2, name='x2', dtype='float32')
    
    max_L = None

    if verbose:
        pb = Progbar(iterations, stateful_metrics=['LC'])
    
    for _ in range(iterations):
        with tf.GradientTape() as tape:
            y1 = f(X1)
            y2 = f(X2)
            
            # The definition of the margin is not entirely symmetric: the top
            # class must remain the same when measuring both points. We assume
            # X1 is the reference point for determining the top class.
            original_predictions = tf.cast(
                tf.equal(y1, tf.reduce_max(y1, axis=1, keepdims=True)), 
                'float32')
            
            # This takes the logit at the top class for both X1 and X2.
            y1_j = tf.reduce_sum(
                y1 * original_predictions, axis=1, keepdims=True)
            y2_j = tf.reduce_sum(
                y2 * original_predictions, axis=1, keepdims=True)
            
            margin1 = y1_j - y1
            margin2 = y2_j - y2

            axes = tuple((tf.range(len(X1.shape) - 1) + 1).numpy())
            
            L = tf.abs(margin1 - margin2) / (tf.sqrt(
                tf.reduce_sum((X1 - X2)**2, axis=axes)) + EPS)[:,None]

            loss = -tf.reduce_max(L, axis=1)
            
        grad = tape.gradient(loss, [X1, X2])

        optimizer.apply_gradients(zip(grad, [X1, X2]))
        
        if max_L is None:
            max_L = L
        else:
            max_L = tf.maximum(max_L, L)

        if verbose:
            pb.add(1, [('LC', tf.reduce_max(max_L))])
        
    return tf.reduce_max(max_L)
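A usage sketch for lipschitz_lb, assuming the function above is in scope and that its module defines EPS (e.g. 1e-9) and imports Adam; the tiny model and the random point pairs are placeholders:

import numpy as np
import tensorflow as tf

f = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(3),
])

X1 = np.random.randn(8, 4).astype('float32')               # reference points
X2 = X1 + 0.1 * np.random.randn(8, 4).astype('float32')    # nearby points

lb = lipschitz_lb(f, X1, X2, iterations=100)
print('margin-Lipschitz lower bound:', float(lb))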
Example #10
def causality_checking(model_path, dtype='float32'):
    # Building model
    depth = 64
    height = 64
    width = 64
    n_channel = 1
    output_channel = 2
    box = np.random.randint(0, 2, (1, depth, height, width, n_channel))
    box = box.astype(dtype)
    voxelDNN = VoxelDNN(depth, height, width, n_channel, output_channel)
    #     voxel_DNN = voxelDNN.build_voxelDNN_model()
    voxel_DNN = voxelDNN.restore_voxelDNN(model_path)
    predicted_box1 = voxel_DNN(box)
    predicted_box1 = np.asarray(predicted_box1, dtype=dtype)
    probs1 = tf.nn.softmax(predicted_box1[0, :, :, :, :], axis=-1)
    predicted_box2 = voxel_DNN(box)
    predicted_box2 = np.asarray(predicted_box2, dtype=dtype)
    err = predicted_box2 - predicted_box1
    print(err.max(), err.min())
    i = 0
    predicted_box2 = np.zeros((1, depth, height, width, output_channel),
                              dtype=dtype)
    probs2 = np.zeros((1, depth, height, width, output_channel), dtype=dtype)
    progbar = Progbar(depth * height * width)
    for d in range(depth):
        for h in range(height):
            for w in range(width):
                if i > 9:
                    break
                tmp_box = np.random.randint(
                    0, 2, (1, depth, height, width, n_channel)
                )  # np.zeros((1, depth, height, width, n_channel), dtype='float32')
                tmp_box = tmp_box.astype(dtype=dtype)
                tmp_box[:, :d, :, :, :] = box[:, :d, :, :, :]
                tmp_box[:, d, :h, :, :] = box[:, d, :h, :, :]
                tmp_box[:, d, h, :w, :] = box[:, d, h, :w, :]
                predicted = voxel_DNN(tmp_box)
                predicted_box2[:, d, h, w, :] = predicted[:, d, h, w, :]
                probs2[0, d, h, w, :] = tf.nn.softmax(predicted_box2[0, d, h,
                                                                     w, :],
                                                      axis=-1)
                i += 1
                progbar.add(1)
    predicted_box2 = np.asarray(predicted_box2, dtype=dtype)
    compare = predicted_box2 == predicted_box1
    print('Check 4: ', np.count_nonzero(compare), compare.all())
    print(probs2[0, 0, 0, 0, :])
    print(probs1[0, 0, 0, :].numpy())
    err = predicted_box2 - predicted_box1
    print(err.max(), err.min())
Example #11
    def fit(self, batches, good_batches, epochs=100):
        """Trains the model.

        Args:
            batches (Dataset): Training batches containing samples.
            epochs (int): The maximum number of training epochs.

        """

        logger.info('Fitting model ...')

        # Gathering the amount of batches
        n_batches = tf.data.experimental.cardinality(batches).numpy()
        print(n_batches)

        good_batches = list(good_batches.as_numpy_iterator())

        # Iterate through all epochs
        for e in range(epochs):
            logger.info('Epoch %d/%d', e + 1, epochs)

            # Resetting states to further append losses
            self.G_loss.reset_states()
            self.D_loss.reset_states()

            # Defining a customized progress bar
            b = Progbar(n_batches, stateful_metrics=['loss(G)', 'loss(D)'])

            i = 0
            # Iterate through all possible training batches
            for batch, tar in enumerate(batches):
                # Performs the optimization step
                self.step(tar, good_batches[i])

                # Adding corresponding values to the progress bar
                b.add(1,
                      values=[('loss(G)', self.G_loss.result()),
                              ('loss(D)', self.D_loss.result())])
                i += 1

            # Exponentially annealing the Gumbel-Softmax temperature
            self.G.tau = self.init_tau**((epochs - e) / epochs)

            # Dumps the losses to history
            self.history['G_loss'].append(self.G_loss.result().numpy())
            self.history['D_loss'].append(self.D_loss.result().numpy())

            logger.to_file('Loss(G): %s | Loss(D): %s',
                           self.G_loss.result().numpy(),
                           self.D_loss.result().numpy())
Example #12
def pre_train(model, optimizer, loops=10000, batch_size=300, seed=None):
    '''Pre-train the model to learn a simpler value function.

    The value function is computed by the random board generator
    and is the (normalized) column that the ball is in.

    Inputs
    ------
    model: a model to train that supports get_all_values=True
        when called on the forward pass in order to return
        the computed values (instead of the best move).

    optimizer: the optimizer used to update the model parameters.

    loops: number of batches to do

    batch_size: number of randomly generated boards

    seed: None or int, a random seed for the board generator.
        If None, no seed is given.
    '''

    if seed is not None:
        np.random.seed(seed)

    device = next(model.parameters()).device

    min_density = 0
    max_density = 0.3

    bar = Progbar(loops)
    for _ in range(loops):
        boards, targets = random_board_batch(min_density, max_density,
                                             batch_size, device)

        targets.mul_(1 / config.cols)

        predictions = model(boards, get_all_values=True)
        loss = F.mse_loss(predictions, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        bar.add(1, values=[('loss', loss.item())])
Example #13
    def fit(self, batches, epochs=100, critic_steps=5):
        """Trains the model.

        Args:
            batches (Dataset): Training batches containing samples.
            epochs (int): The maximum number of training epochs.
            critic_steps (int): Amount of discriminator epochs per training epoch.

        """

        logger.info('Fitting model ...')

        # Gathering the amount of batches
        n_batches = tf.data.experimental.cardinality(batches).numpy()

        # Iterate through all epochs
        for e in range(epochs):
            logger.info('Epoch %d/%d', e + 1, epochs)

            # Resetting states to further append losses
            self.G_loss.reset_states()
            self.D_loss.reset_states()

            # Defining a customized progress bar
            b = Progbar(n_batches, stateful_metrics=['loss(G)', 'loss(D)'])

            # Iterate through all possible training batches
            for batch in batches:
                # Iterate through all possible critic steps
                for _ in range(critic_steps):
                    # Performs the optimization step over the discriminator
                    self.D_step(batch)

                # Performs the optimization step over the generator
                self.G_step(batch)

                # Adding corresponding values to the progress bar
                b.add(1,
                      values=[('loss(G)', self.G_loss.result()),
                              ('loss(D)', self.D_loss.result())])

            # Dumps the losses to history
            self.history['G_loss'].append(self.G_loss.result().numpy())
            self.history['D_loss'].append(self.D_loss.result().numpy())

            logger.to_file('Loss(G): %s | Loss(D): %s',
                           self.G_loss.result().numpy(),
                           self.D_loss.result().numpy())
Example #14
    def run_offline_transformations(
            self,
            offline_transformation: DicomOfflineTransformation,
            array_stream: ArrayStream,
            verbose=True):

        if verbose:
            print("We are starting offline transformation:")
            #progress: ProgressNumber = utility.ProgressNumber(max_value=len(self))
            progbar = Progbar(len(self))

        for scanw_batch in self:
            patient_ids = [scanw.patient_id for scanw in scanw_batch]
            data_batch = [scanw.load_data() for scanw in scanw_batch]

            nrs_imgs = [data['len'] for data in data_batch]
            volumes = np.concatenate([data['volume'] for data in data_batch],
                                     axis=0)
            intercepts = np.concatenate(
                [data['intercepts'] for data in data_batch], axis=0)
            slopes = np.concatenate([data['slopes'] for data in data_batch],
                                    axis=0)

            imgs_boundaries = [
                sum(nrs_imgs[:i + 1]) for i in range(len(nrs_imgs) - 1)
            ]

            output_data_batch = offline_transformation(volumes,
                                                       intercepts=intercepts,
                                                       slopes=slopes)
            patient_data_batch = [
                output_data_batch[i:j]
                for i, j in zip([0] + imgs_boundaries, imgs_boundaries +
                                [None])
            ]

            for patient_id, patient_data in zip(patient_ids,
                                                patient_data_batch):
                array_stream.switch_dir(patient_id)
                for idex, img_data in enumerate(patient_data):
                    arrname = '{:04}'.format(idex)
                    array_stream.save_arrays(arrname, img_data)

            if verbose:
                #progress.update_add(len(scanw_batch))
                progbar.add(len(scanw_batch))
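The per-patient splitting above hinges on turning nrs_imgs into cumulative boundaries and slicing the concatenated output back apart. A toy sketch of just that arithmetic (hypothetical counts and data):

nrs_imgs = [3, 2, 4]                             # images per patient
output_data_batch = list(range(sum(nrs_imgs)))   # stand-in for transformed images

imgs_boundaries = [sum(nrs_imgs[:i + 1]) for i in range(len(nrs_imgs) - 1)]
# -> [3, 5]

patient_data_batch = [
    output_data_batch[i:j]
    for i, j in zip([0] + imgs_boundaries, imgs_boundaries + [None])
]
# -> [[0, 1, 2], [3, 4], [5, 6, 7, 8]]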
Example #15
    def predict(self):

        prediction_dict = dict()
        progbar = Progbar(target=len(self.dataset.image_names))

        for image_name in self.dataset.image_names:
            progbar.add(1)
            image = read_by_pyvips(self.dataset.get_absolute_path(image_name))
            image = imresize(image, (self.model_wrapper.input_shape[0],
                                     self.model_wrapper.input_shape[1]))
            image = image / 255.

            prediction = self.model_wrapper.model.predict(
                np.expand_dims(image, axis=0))[0][0]
            prediction_dict[image_name] = prediction

        return prediction_dict
Example #16
    def compute_emb(data_iter, n_branches):
        embs = []
        progbar = Progbar(len(data_iter.files_arr))

        if n_branches == 1:
            test_init_op = 'test_init_op'
            output = 'model/output/output:0'
            # output = 'output/BiasAdd:0'
        else:
            test_init_op = []
            output = []
            for i in range(n_branches):
                test_init_op.append('test_init_op_' + str(i + 1))
                output.append('model/output/vb{}/output:0'.format(i + 1))

        for batch in data_iter:
            if isinstance(batch, np.ndarray):
                sess.run(test_init_op,
                         feed_dict={
                             'x:0': batch,
                             'batch_size:0': len(batch)
                         })
                e = sess.run(output)
            else:
                # Test augmentation
                aug_outputs = []
                for aug_batch in batch:
                    sess.run(test_init_op,
                             feed_dict={
                                 'x:0': aug_batch,
                                 'batch_size:0': len(aug_batch)
                             })
                    aug_outputs.append(sess.run(output))
                # Mean has better performance vs. concatenate
                e = np.mean(aug_outputs, axis=0)

            # if n_branches > 1:
            #     e = np.concatenate(e, axis=-1)

            # print(e.shape)
            embs.append(e)
            progbar.add(e.shape[-2])

        return embs
Example #17
def eval(model, eval_dataset):
    print("Evaluating model..")
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')
    test_accuracy.reset_states()
    num_class = config.num_classes

    start = time.time()
    batches_per_epoch = tf.data.experimental.cardinality(eval_dataset).numpy()
    pb = Progbar(batches_per_epoch, width=30)
    for x_context, x_face, y in eval_dataset:
        scores = model(x_face, x_context, training=False)
        test_accuracy(y, scores)  # update the metric
        y_pred = tf.argmax(scores, axis=1)
        pb.add(1)

    end = time.time()
    print("Evaluating time: %d seconds" % ((end - start)))

    val_acc = test_accuracy.result().numpy()
    print("Evaluation accuracy: {:.4f}".format(val_acc))
Example #18
class ProgressBar:
    """ tf keras progress bar to print the evaluations results """
    def __init__(self, metrics):
        self.metrics_train = ['train_' + metric for metric in metrics]
        self.metrics_val = ['val_' + metric for metric in metrics]
        self.metrics = self.metrics_train + self.metrics_val

    def prepare(self, n_samples, verbose=1):
        self.n_samples = n_samples
        self.prog_bar = Progbar(n_samples,
                                stateful_metrics=self.metrics,
                                verbose=verbose)

    def update_bar(self, d_train, d_val=[]):
        values = []
        for key in d_train:
            values.append(('train_' + key, d_train[key]))
        for key in d_val:
            values.append(('val_' + key, d_val[key]))
        self.prog_bar.add(n=1, values=values)
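A minimal usage sketch for the ProgressBar wrapper above; the metric values are hypothetical and would normally come from the train and validation steps:

bar = ProgressBar(metrics=['loss', 'acc'])
bar.prepare(n_samples=100)

for _ in range(100):
    d_train = {'loss': 0.3, 'acc': 0.9}
    d_val = {'loss': 0.4, 'acc': 0.85}
    bar.update_bar(d_train, d_val)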
Example #19
    def fit(self, batches, epochs=100):
        """Trains the model.

        Args:
            batches (Dataset): Training batches containing samples.
            epochs (int): The maximum number of training epochs.

        """

        logger.info('Fitting model ...')

        # Gathering the amount of batches
        n_batches = tf.data.experimental.cardinality(batches).numpy()

        # Iterate through all epochs
        for e in range(epochs):
            logger.info('Epoch %d/%d', e + 1, epochs)

            # Resetting states to further append losses
            self.G_loss.reset_states()
            self.D_loss.reset_states()

            # Defining a customized progress bar
            b = Progbar(n_batches, stateful_metrics=['loss(G)', 'loss(D)'])

            # Iterate through all possible training batches
            for x_batch, y_batch in batches:
                # Performs the optimization step
                self.step(x_batch, y_batch)

                # Adding corresponding values to the progress bar
                b.add(1,
                      values=[('loss(G)', self.G_loss.result()),
                              ('loss(D)', self.D_loss.result())])

            # Exponentially annealing the Gumbel-Softmax temperature
            self.G.tau = 5**((epochs - e) / epochs)

            logger.file('Loss(G): %s | Loss(D): %s',
                        self.G_loss.result().numpy(),
                        self.D_loss.result().numpy())
Example #20
def test_model(ds_name, encoder, paths, categorical=False):
    """The main function for executing network testing. It loads the specified
       dataset iterator and optimized saliency model. By default, when no model
       checkpoint is found locally, the pretrained weights will be downloaded.

    Args:
        ds_name (str): Denotes the dataset that was used during training.
        encoder (str): The name of the encoder to use for prediction.
        paths (dict): A dictionary with all path elements.
    """

    w_filename_template = "/%s_%s_%s_weights.h5"  # [encoder]_[ds_name]_[loss_fn_name]_weights.h5

    (test_ds, n_test) = data.load_test_dataset(ds_name, paths["data"], categorical)
    
    print(">> Preparing model with encoder %s..." % encoder)

    model = MyModel(encoder, ds_name, "test")

    weights_path = paths["weights"] + w_filename_template % (encoder, ds_name, loss_fn_name)
    if os.path.exists(weights_path):
        print("Weights are loaded!\n    %s"%weights_path)
    else:
        download.download_pretrained_weights(paths["weights"], encoder, ds_name, loss_fn_name)
    model.load_weights(weights_path)
    del weights_path

    print(">> Start predicting using model trained on %s..." % ds_name.upper())
    results_path = paths["results"] + "%s/%s/%s/" % (ds_name, encoder, loss_fn_name)

    # Preparing progbar
    test_progbar = Progbar(n_test)
    for test_images, test_ori_sizes, test_filenames in test_ds:
        pred = test_step(test_images, model)
        for pred_map, filename, ori_size in zip(pred, test_filenames.numpy(), test_ori_sizes):
            img = data.postprocess_saliency_map(pred_map, ori_size, as_image=True)
            tf.io.write_file(results_path + filename.decode("utf-8"), img)
        test_progbar.add(test_images.shape[0])
Example #21
def train_loop(epochs):
    for epoch in range(epochs):
        progbar = Progbar(train_ds_len, unit_name='batch')
        for images, labels in train_ds:
            train_step(images, labels)
            progbar.add(1)

        for test_images, test_labels in test_ds:
            test_step(test_images, test_labels)

        template = ('Epoch {}, Loss: {}, Accuracy: {}, '
                    'Test Loss: {}, Test Accuracy: {}')

        print(
            template.format(epoch + 1, train_loss.result(),
                            train_accuracy.result() * 100, test_loss.result(),
                            test_accuracy.result() * 100))

        # Reset the metrics for the next epoch
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()
Example #22
def eval_verification(m,
                      N,
                      x_test,
                      y_test,
                      epsilon=0.25,
                      batch_size=100,
                      timeout=60,
                      seed=0):

    np.random.seed(seed)

    n_robust = 0
    n_nonrobust = 0
    n_timeout = 0
    n_unknown = 0
    n_robust_correct = 0
    times = []
    n_visited = []
    pb = Progbar(N, stateful_metrics=['vra'])

    for j in range(N):
        start_time = time()

        r = check(m,
                  x_test[j],
                  epsilon,
                  recap=False,
                  timeout=timeout,
                  keepgoing=True,
                  return_num_visited=True,
                  batch_size=batch_size,
                  lowerbound=False)

        end_time = time()

        n_visited.append(r[1])

        if r[0] is ROBUST:
            n_robust += 1
            if m.predict(x_test[j:j +
                                1]).argmax(axis=1)[0] == y_test[j].argmax():
                n_robust_correct += 1

        elif r[0] is INCONCLUSIVE:
            n_unknown += 1

        elif r[0] is NOT_ROBUST:
            n_nonrobust += 1

        if r[0] is TIMED_OUT:
            n_timeout += 1

        else:
            times.append(end_time - start_time)

        pb.add(1, [('vra', float(n_robust_correct) / float(j + 1)),
                   ('ro', 1 if r[0] is ROBUST else 0),
                   ('adv', 1 if r[0] is NOT_ROBUST else 0),
                   ('unk', 1 if r[0] is INCONCLUSIVE else 0),
                   ('to', 1 if r[0] is TIMED_OUT else 0),
                   ('rt', end_time - start_time), ('vis', r[1])])

    print('# robust: {}'.format(n_robust))
    print('# non-robust: {}'.format(n_nonrobust))
    print('# timeout: {}'.format(n_timeout))
    print('# unknown: {}'.format(n_unknown))
    print('med time: {:.2f}'.format(
        np.sort(np.array(times))[int(len(times) / 2)]))
    print('mean time: {:.2f}'.format(np.array(times).mean()))
    print('mean # visited: {:.1f}'.format(np.array(n_visited).mean()))
    print('median # regions visited: {:.1f}'.format(
        np.sort(np.array(n_visited))[int(len(n_visited) / 2)]))
    print('verified robust accuracy: {:.2}'.format(
        float(n_robust_correct) / float(N)))
Example #23
    def train_voxelDNN(self, batch, epochs, model_path, saved_model, dataset,
                       portion_data):
        #log directory
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = model_path + 'log' + current_time + '/train'
        test_log_dir = model_path + 'log' + current_time + '/test'
        train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        test_summary_writer = tf.summary.create_file_writer(test_log_dir)
        #initialize model and optimizer, loss
        voxelDNN = self.build_voxelDNN_model()
        [train_dataset, test_dataset, number_training_data
         ] = self.calling_dataset(training_dirs=dataset,
                                  batch_size=batch,
                                  portion_data=portion_data)
        learning_rate = 1e-3
        optimizer = tf.optimizers.Adam(lr=learning_rate)
        compute_loss = keras.losses.CategoricalCrossentropy(from_logits=True, )
        n_epochs = epochs
        n_iter = int(number_training_data / batch)
        #early stopping setting
        best_val_loss, best_val_epoch = None, None
        max_patience = 10
        early_stop = False
        #Load lastest checkpoint
        vars_to_load = {
            "Weight_biases": voxelDNN.trainable_variables,
            "optimizer": optimizer
        }
        checkpoint = tf.train.Checkpoint(**vars_to_load)
        latest_ckpt = tf.train.latest_checkpoint(saved_model)
        if latest_ckpt is not None:
            checkpoint.restore(latest_ckpt)
            print('Loaded last checkpoint')
        else:
            print('Training from scratch')
        ckpt_manager = tf.train.CheckpointManager(checkpoint,
                                                  checkpoint_name='ckpt_',
                                                  directory=model_path,
                                                  max_to_keep=40)
        losses = []
        #training
        for epoch in range(n_epochs):
            progbar = Progbar(n_iter)
            print('Epoch {:}/{:}'.format(epoch + 1, n_epochs))
            loss_per_epochs = []
            for i_iter, batch_x in enumerate(train_dataset):
                batch_y = tf.cast(batch_x, tf.int32)
                with tf.GradientTape() as ae_tape:

                    logits = voxelDNN(batch_x, training=True)
                    y_true = tf.one_hot(batch_y, self.output_channel)
                    y_true = tf.reshape(
                        y_true, (y_true.shape[0], self.depth, self.height,
                                 self.width, self.output_channel))
                    loss = compute_loss(y_true, logits)

                    metrics = compute_acc(y_true, logits, loss,
                                          train_summary_writer,
                                          int(epoch * n_iter + i_iter))
                gradients = ae_tape.gradient(loss,
                                             voxelDNN.trainable_variables)
                gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
                optimizer.apply_gradients(
                    zip(gradients, voxelDNN.trainable_variables))
                loss_per_epochs.append(loss / batch_x.shape[0])
                progbar.add(1, values=[('loss', loss), ('f1', metrics[8])])
            avg_train_loss = np.average(loss_per_epochs)
            losses.append(avg_train_loss)

            # Validation dataset
            test_losses = []
            test_metrics = []
            for i_iter, batch_x in enumerate(test_dataset):
                batch_y = tf.cast(batch_x, tf.int32)
                logits = voxelDNN(batch_x, training=True)
                y_true = tf.one_hot(batch_y, self.output_channel)
                y_true = tf.reshape(y_true,
                                    (y_true.shape[0], self.depth, self.height,
                                     self.width, self.output_channel))

                loss = compute_loss(y_true, logits)
                metrics = compute_acc(y_true, logits, loss,
                                      test_summary_writer, i_iter)
                test_losses.append(loss / batch_x.shape[0])
                test_metrics.append(metrics)

            test_metrics = np.asarray(test_metrics)
            avg_metrics = np.average(test_metrics, axis=0)
            avg_test_loss = np.average(test_losses)

            print("Testing result on epoch: %i, test loss: %f " %
                  (epoch, avg_test_loss))
            tf.print(' tp: ',
                     avg_metrics[0],
                     ' tn: ',
                     avg_metrics[1],
                     ' fp: ',
                     avg_metrics[2],
                     ' fn: ',
                     avg_metrics[3],
                     ' precision: ',
                     avg_metrics[4],
                     ' recall: ',
                     avg_metrics[5],
                     ' accuracy: ',
                     avg_metrics[6],
                     ' specificity ',
                     avg_metrics[7],
                     ' f1 ',
                     avg_metrics[8],
                     output_stream=sys.stdout)

            if best_val_loss is None or best_val_loss > avg_test_loss:
                best_val_loss, best_val_epoch = avg_test_loss, epoch
                ckpt_manager.save()
                print('Saved model')
            if best_val_epoch < epoch - max_patience:
                print('Early stopping')
                break
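The early-stopping bookkeeping at the end of each epoch is compact enough to isolate. A standalone sketch of the same logic with hypothetical validation losses and the same patience of 10 (the checkpoint save is elided):

max_patience = 10
best_val_loss, best_val_epoch = None, None
val_losses = [0.9, 0.7, 0.65, 0.66, 0.68, 0.70, 0.71, 0.72, 0.73, 0.74,
              0.75, 0.76, 0.77, 0.78]   # hypothetical avg_test_loss per epoch

for epoch, avg_test_loss in enumerate(val_losses):
    if best_val_loss is None or best_val_loss > avg_test_loss:
        best_val_loss, best_val_epoch = avg_test_loss, epoch
        # ckpt_manager.save() would go here
    if best_val_epoch < epoch - max_patience:
        print('Early stopping at epoch', epoch)
        break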
Example #24
def local_lipschitz_lb(f, X1, X2, eps, iterations=1000, verbose=True):

    optimizer = Adam(lr=0.0001)

    X0 = tf.constant(X1, dtype='float32')
    y0 = f(X0)

    X1 = tf.Variable(X1, name='x1', dtype='float32')
    X2 = tf.Variable(X2, name='x2', dtype='float32')
    
    max_L = None

    if verbose:
        pb = Progbar(iterations, stateful_metrics=['max_LC', 'mean_LC'])
    
    for i in range(iterations):
        with tf.GradientTape() as tape:

            axes = tuple((tf.range(len(X1.shape) - 1) + 1).numpy())

            # Project so that X1 and X2 are at distance at most `eps` from X0.
            delta1 = X1 - X0
            dist1 = tf.sqrt(tf.reduce_sum(
                delta1 * delta1, axis=axes, keepdims=True))
            
            delta2 = X2 - X0
            dist2 = tf.sqrt(tf.reduce_sum(
                delta2 * delta2, axis=axes, keepdims=True))
            
            # Only project if `dist` > `eps`.
            where_dist_gt_eps1 = tf.cast(dist1 > eps, 'float32')
            where_dist_gt_eps2 = tf.cast(dist2 > eps, 'float32')

            X1.assign(
                (X0 + eps * delta1 / (dist1 + EPS)) * where_dist_gt_eps1 + 
                X1 * (1 - where_dist_gt_eps1))
            X2.assign(
                (X0 + eps * delta2 / (dist2 + EPS)) * where_dist_gt_eps2 + 
                X2 * (1 - where_dist_gt_eps2))

            y1 = f(X1)
            y2 = f(X2)
            
            # The definition of the margin is not entirely symmetric: the top
            # class must remain the same when measuring both points. We assume
            # X0 is the reference point for determining the top class.
            original_predictions = tf.cast(
                tf.equal(y0, tf.reduce_max(y0, axis=1, keepdims=True)), 
                'float32')
            
            # This takes the logit at the top class for both X1 and X2.
            y1_j = tf.reduce_sum(
                y1 * original_predictions, axis=1, keepdims=True)
            y2_j = tf.reduce_sum(
                y2 * original_predictions, axis=1, keepdims=True)
            
            margin1 = y1_j - y1
            margin2 = y2_j - y2

            L = tf.abs(margin1 - margin2) / (tf.sqrt(
                tf.reduce_sum((X1 - X2)**2, axis=axes)) + EPS)[:,None]

            loss = -tf.reduce_max(L, axis=1)
           
        if i < iterations - 1: 
            grad = tape.gradient(loss, [X1, X2])

            optimizer.apply_gradients(zip(grad, [X1, X2]))
        
        if max_L is None:
            max_L = L
        else:
            max_L = tf.maximum(max_L, L)

        if verbose:
            metrics = [
                ('max_LC', tf.reduce_max(max_L)), 
                ('mean_LC', tf.reduce_mean(max_L))
            ]
            
            pb.add(1, metrics)
        
    return tf.reduce_max(max_L), tf.reduce_mean(max_L)
Example #25
def extract_faces(kind):
    """
    :param kind: "train", "val" or "test"
    Extract cropped faces using dlib face detector
    """
    if kind == 'train':
        image_path = config.train_images
        crop_path = config.train_crop
    elif kind == 'val':
        image_path = config.val_images
        crop_path = config.val_crop
    elif kind == 'test':
        image_path = config.test_images
        crop_path = config.test_crop
    else:
        raise ValueError(
            'Wrong type of dataset ("train", "val" or "test" is acceptable)')
    if not os.path.exists(crop_path):
        os.makedirs(crop_path)
    dnnFaceDetector = dlib.cnn_face_detection_model_v1(
        "detector/mmod_human_face_detector.dat")
    imgs = []

    valid_images = [".jpg", ".gif", ".png", ".tga"]

    print(f"Extracting faces in {kind} set...")

    imgs_cnt = len([
        1 for category in os.listdir(image_path)
        for f in os.listdir(os.path.join(image_path, category))
    ])
    pb = Progbar(target=imgs_cnt, width=30)
    cropped_cnt = 0
    for category in os.listdir(image_path):
        if not os.path.exists(os.path.join(crop_path, category)):
            os.makedirs(os.path.join(crop_path, category))
        for f in os.listdir(os.path.join(image_path, category)):
            fname, ext = os.path.splitext(f)
            # print(os.path.join(path,category,f))
            if ext.lower() not in valid_images:
                continue
            img = cv2.imread(os.path.join(image_path, category, f))
            h, w, _ = img.shape
            # if (h>400):
            #     print(img.shape)
            #     print(os.path.join(path,category,f))
            # continue
            result = dnnFaceDetector(img, 1)
            max_confidence = -100
            final_rect = None
            for rect in result:
                if rect.confidence > max_confidence:
                    max_confidence = rect.confidence
                    final_rect = rect.rect

            # print(max_confidence)
            if (max_confidence == -100):
                fo = open(os.path.join(crop_path, category, fname + '.txt'),
                          "w")
                fo.write(",".join([str(0), str(0), str(0), str(0)]))
                fo.close()
                continue
            x1 = final_rect.left()
            y1 = final_rect.top()
            x2 = final_rect.right()
            y2 = final_rect.bottom()
            fo = open(os.path.join(crop_path, category, fname + '.txt'), "w")
            fo.write(",".join([str(x1), str(y1), str(x2), str(y2)]))
            fo.close()
            pb.add(1)
            cropped_cnt += 1
    print(f"Successfully cropped {cropped_cnt} images!")
Example #26
    def represent(self, molecules):
        """
        Provides the bag-of-bonds representation for input molecules.

        Parameters
        ----------
        molecules : chemml.chem.Molecule object or array
            If list, it must be a list of chemml.chem.Molecule objects, otherwise we raise a ValueError.
            In addition, all the molecule objects must provide the XYZ information. Please make sure the XYZ geometry has been
            stored or optimized in advance.

        Returns
        -------
        features : pandas data frame, shape: (n_molecules, max_length_of_combinations)
            The bag of bond features.

        """
        if isinstance(molecules, (list, np.ndarray)):
            molecules = np.array(molecules)
        elif isinstance(molecules, Molecule):
            molecules = np.array([molecules])
        else:
            msg = "The input molecules must be a chemml.chem.Molecule object or a list of objects."
            raise ValueError(msg)

        if molecules.ndim > 1:
            msg = "The molecule must be a chemml.chem.Molecule object or a list of objects."
            raise ValueError(msg)

        # pool of processes
        if self.n_jobs == -1:
            self.n_jobs = cpu_count()
        pool = Pool(processes=self.n_jobs)

        # Create an iterator
        # http://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
        def chunks(l, n):
            """Yield successive n-sized chunks from l."""
            for i in range(0, len(l), n):
                yield l[i:i + n]

        # find size of each batch
        batch_size = int(len(molecules) / self.n_jobs)
        if batch_size == 0:
            batch_size = 1

        molecule_chunks = chunks(molecules, batch_size)

        # MAP: CM in parallel
        map_function = partial(self._represent)
        if self.verbose:
            print('featurizing molecules in batches of %i ...' % batch_size)
            pbar = Progbar(len(molecules), width=50)
            bbs_info = []
            for tensors in pool.imap(map_function, molecule_chunks):
                pbar.add(len(tensors[0]))
                bbs_info.append(tensors)
            print('Merging batch features ...    ', end='')
        else:
            bbs_info = pool.map(map_function, molecule_chunks)
        if self.verbose:
            print('[DONE]')

        # REDUCE: Concatenate the obtained tensors
        pool.close()
        pool.join()
        return self.concat_mol_features(bbs_info)
Example #27
    def represent(self, molecules):
        """
        Provides the Coulomb matrix representation for input molecules.

        Parameters
        ----------
        molecules : chemml.chem.Molecule object or array
            If list, it must be a list of chemml.chem.Molecule objects, otherwise we raise a ValueError.
            In addition, all the molecule objects must provide the XYZ information. Please make sure the XYZ geometry has been
            stored or optimized in advance.

        Returns
        -------
        features : Pandas DataFrame
            A data frame with same number of rows as number of molecules will be returned.
            The exact shape of the dataframe depends on the type of CM as follows:
                - shape of Unsorted_Matrix (UM): (n_molecules, max_n_atoms**2)
                - shape of Unsorted_Triangular (UT): (n_molecules, max_n_atoms*(max_n_atoms+1)/2)
                - shape of eigenspectrums (E): (n_molecules, max_n_atoms)
                - shape of Sorted_Coulomb (SC): (n_molecules, max_n_atoms*(max_n_atoms+1)/2)
                - shape of Random_Coulomb (RC): (n_molecules, nPerm * max_n_atoms * (max_n_atoms+1)/2)
        """
        # check input molecules
        if isinstance(molecules, (list, np.ndarray)):
            molecules = np.array(molecules)
        elif isinstance(molecules, Molecule):
            molecules = np.array([molecules])
        else:
            msg = "The molecule must be a chemml.chem.Molecule object or a list of objects."
            raise ValueError(msg)

        if molecules.ndim > 1:
            msg = "The molecule must be a chemml.chem.Molecule object or a list of objects."
            raise ValueError(msg)

        self.n_molecules_ = molecules.shape[0]

        # max number of atoms based on the list of molecules
        if self.max_n_atoms_ == 'auto':
            try:
                self.max_n_atoms_ = max(
                    [m.xyz.atomic_numbers.shape[0] for m in molecules])
            except:
                msg = "The xyz representation of molecules is not available."
                raise ValueError(msg)

        # pool of processes
        if self.n_jobs == -1:
            self.n_jobs = cpu_count()
        pool = Pool(processes=self.n_jobs)

        # Create an iterator
        # http://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
        def chunks(l, n):
            """Yield successive n-sized chunks from l."""
            for i in range(0, len(l), n):
                yield l[i:i + n]

        # find size of each batch
        batch_size = int(len(molecules) / self.n_jobs)
        if batch_size == 0:
            batch_size = 1

        molecule_chunks = chunks(molecules, batch_size)

        # MAP: CM in parallel
        map_function = partial(self._represent)
        if self.verbose:
            print('featurizing molecules in batches of %i ...' % batch_size)
            pbar = Progbar(len(molecules), width=50)
            tensor_list = []
            for tensors in pool.imap(map_function, molecule_chunks):
                pbar.add(len(tensors[0]))
                tensor_list.append(tensors)
            print('Merging batch features ...    ', end='')
        else:
            tensor_list = pool.map(map_function, molecule_chunks)
        if self.verbose:
            print('[DONE]')

        # REDUCE: Concatenate the obtained tensors
        pool.close()
        pool.join()
        return pd.concat(tensor_list, axis=0, ignore_index=True)
Example #28
    def train(self,
              train_dataset,
              valid_dataset,
              perturbation_generator,
              epochs,
              batch_size,
              learning_rate,
              probability_train_perturbation,
              train_perturbation_mode='fgsm',
              **kwargs):
        """
        all training happens here
        Args:
            train_dataset: tf.data.Dataset object of (x_data, y_labels), prepared with
            batch_size, shuffle etc.
            valid_dataset: tf.data.Dataset object of (x_data, y_labels), prepared with
            batch_size etc.
            perturbation_generator: object from class PerturbationGenerator
            epochs: int
            batch_size: int
            learning_rate: float
            probability_train_perturbation: float, proportion of training with
                adversarial calibration loss
            train_perturbation_mode: string, 'fgsm' for Falcon model
        """

        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.probability_train_perturbation = probability_train_perturbation
        self.train_perturbation_mode = train_perturbation_mode

        self.set_save_dir()
        self.log_dir_train = self.save_path + "train/"
        self.log_dir_test = self.save_path + "test/"
        self.model.build_optimizers(learning_rate)

        # create log metrics
        log_dict_train = {}
        log_dict_train["loss"] = tf.keras.metrics.Mean(name="train_loss")
        log_dict_train["accuracy"] = tf.keras.metrics.CategoricalAccuracy(
            name="train_accuracy")
        log_dict_train["entropy"] = tf.keras.metrics.Mean(name="train_entropy")
        log_dict_train["loss_sub_models"] = Log_extend_mean_array(
            name="loss_sub_models")
        log_dict_valid = {}
        log_dict_valid["loss"] = tf.keras.metrics.Mean(name="valid_test_loss")
        log_dict_valid["accuracy"] = tf.keras.metrics.CategoricalAccuracy(
            name="valid_test_accuracy")

        # set up summary writeres for Tensorboard
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        tb_train_log_dir = self.log_dir_train + current_time + "/train"
        tb_valid_log_dir = self.log_dir_train + current_time + "/valid"
        train_summary_writer = tf.summary.create_file_writer(tb_train_log_dir)
        with train_summary_writer.as_default():
            tf.summary.trace_on(graph=True, profiler=True)

        # create log history over training
        log_history_dict_train = {}
        log_history_dict_train["loss"] = []
        log_history_dict_train["accuracy"] = []
        log_history_dict_train["entropy"] = []
        log_history_dict_train["loss_sub_models"] = []
        log_history_dict_valid = {}
        log_history_dict_valid["loss"] = []
        log_history_dict_valid["accuracy"] = []

        ##Training##
        global_step = 0
        for epoch in range(self.epochs):
            global_step += 1
            tf.summary.trace_off()

            num_training_samples = self.dataset.train_steps_per_epoch
            num_valid_samples = self.dataset.valid_steps_per_epoch
            pb_i = Progbar(num_training_samples,
                           stateful_metrics=['acc', 'loss', 'ent'])
            pb_t = Progbar(num_valid_samples, stateful_metrics=['acc', 'loss'])
            for x_data, y_data in train_dataset:
                if np.random.uniform() < probability_train_perturbation:
                    log_dict_train = self.train_step(x_data, y_data,
                                                     log_dict_train,
                                                     global_step)
                    if perturbation_generator is not None:
                        n_eps = len(
                            perturbation_generator.possible_epsilons(
                                perturb_type=train_perturbation_mode))
                        epsilon = random.randint(0, n_eps - 1)
                        x_data = perturbation_generator.perturb_batch(
                            x_data,
                            perturb_type=train_perturbation_mode,
                            epsilon=epsilon,
                            model=self,
                            label=y_data,
                        )
                    self.train_step_advcalib(x_data, y_data, log_dict_train,
                                             global_step)
                else:
                    log_dict_train = self.train_step(x_data, y_data,
                                                     log_dict_train,
                                                     global_step)
                #progressbar
                pb_i.add(1,
                         values=[('acc', log_dict_train["accuracy"].result()),
                                 ('loss', log_dict_train["loss"].result()),
                                 ('ent', log_dict_train["entropy"].result())])
            # test on validation set
            if valid_dataset is not None:
                for x_data, y_data in valid_dataset:
                    log_dict_valid = self.test_step(x_data,
                                                    y_data,
                                                    log_dict_valid,
                                                    n_test_iter_now=1)
            tf.summary.trace_on()
            print("Epoch " + str(epoch) + ", Minibatch Loss= " +
                  "{:.4f}".format(log_dict_train["loss"].result()) +
                  ", Training Accuracy= " +
                  "{:.3f}".format(log_dict_train["accuracy"].result()) +
                  ", Valid Accuracy= " +
                  "{:.3f}".format(log_dict_valid["accuracy"].result()))

            # save to summary for Tensorboard
            with train_summary_writer.as_default():
                tf.summary.scalar("loss_train",
                                  log_dict_train["loss"].result(),
                                  step=epoch)
                tf.summary.scalar("accuracy_train",
                                  log_dict_train["accuracy"].result(),
                                  step=epoch)
                tf.summary.scalar("entropy_train",
                                  log_dict_train["entropy"].result(),
                                  step=epoch)
                tf.summary.scalar("loss_valid",
                                  log_dict_valid["loss"].result(),
                                  step=epoch)
                tf.summary.scalar("accuracy_valid",
                                  log_dict_valid["accuracy"].result(),
                                  step=epoch)

            # add certain metrics to history per epoch
            log_history_dict_train["loss"].append(
                log_dict_train["loss"].result().numpy())
            log_history_dict_train["accuracy"].append(
                log_dict_train["accuracy"].result().numpy())
            log_history_dict_train["entropy"].append(
                log_dict_train["entropy"].result().numpy())
            log_history_dict_valid["loss"].append(
                log_dict_valid["loss"].result().numpy())
            log_history_dict_valid["accuracy"].append(
                log_dict_valid["accuracy"].result().numpy())

            # reset log metrics
            log_dict_train["loss"].reset_states()
            log_dict_train["accuracy"].reset_states()
            log_dict_train["entropy"].reset_states()
            log_dict_valid["loss"].reset_states()
            log_dict_valid["accuracy"].reset_states()

        # plot accuracy and loss lists (log_histories)
        save_line_chart(
            self.log_dir_train,
            "accuracy_train_history",
            log_history_dict_train["accuracy"],
            diag_title="Train Accuracy",
            xlabel="epoch",
            ylabel="accuracy",
            pyplotBool=True,
            write_to_txtfile_Bool=True,
        )
        save_line_chart(
            self.log_dir_train,
            "loss_train_history",
            log_history_dict_train["loss"],
            diag_title="Train Loss",
            xlabel="epoch",
            ylabel="loss",
            pyplotBool=True,
            write_to_txtfile_Bool=True,
        )
        save_line_chart(
            self.log_dir_train,
            "entropy_train_history",
            log_history_dict_train["entropy"],
            diag_title="Train Entropy",
            xlabel="epoch",
            ylabel="loss",
            pyplotBool=True,
            write_to_txtfile_Bool=True,
        )
        save_line_chart(
            self.log_dir_train,
            "accuracy_test_history",
            log_history_dict_valid["accuracy"],
            diag_title="Valid Accuracy",
            xlabel="epoch",
            ylabel="accuracy",
            pyplotBool=True,
            write_to_txtfile_Bool=True,
        )
        save_line_chart(
            self.log_dir_train,
            "loss_test_history",
            log_history_dict_valid["loss"],
            diag_title="Valid Loss",
            xlabel="epoch",
            ylabel="loss",
            pyplotBool=True,
            write_to_txtfile_Bool=True,
        )

        # plot results
        keys_to_plot_train = {"loss", "accuracy"}
        keys_to_plot_valid = {"loss", "accuracy"}
        log_dict_test_plot_train = {
            key + "_train": value[-1]
            for key, value in log_history_dict_train.items()
            if key in keys_to_plot_train
        }
        log_dict_test_plot_test = {
            key + "_valid": value[-1]
            for key, value in log_history_dict_valid.items()
            if key in keys_to_plot_valid
        }
        log_dict_test_plot = {
            **log_dict_test_plot_train,
            **log_dict_test_plot_test
        }
        save_dict_to_structured_txt(self.log_dir_train,
                                    log_dict_test_plot,
                                    filename="results_train")

        print("Training Finished!")
    optimizer.apply_gradients(zip(gradients, gated_pixelcnn.trainable_variables))

    return loss

# Training loop
n_epochs = 50
n_iter = int(np.ceil(x_train_quantised.shape[0] / batch_size))
for epoch in range(n_epochs):
    progbar = Progbar(n_iter)
    print('Epoch {:}/{:}'.format(epoch + 1, n_epochs))

    for i_iter, (batch_x, batch_y) in enumerate(train_dataset):
        optimizer.lr = optimizer.lr * lr_decay
        loss = train_step(batch_x, batch_y)

        progbar.add(1, values=[('loss', loss)])

# Test set performance
test_loss = []
for batch_x, batch_y in test_dataset:
    logits = gated_pixelcnn(batch_x, training=False)

    # Calculate cross-entropy (= negative log-likelihood)
    loss = compute_loss(tf.squeeze(tf.one_hot(batch_y, q_levels)), logits)

    test_loss.append(loss)
print('nll : {:} nats'.format(np.array(test_loss).mean()))
print('bits/dim : {:}'.format(np.array(test_loss).mean() / np.log(2)))

# Generating new images
samples = np.zeros((100, height, width, n_channel), dtype='float32')
    start_time = time.time()
    train(train_dist_dataset, args.beta)

    if t % valid_inc == 0:
        test(valid_dist_dataset, args.beta)
        step_time = round(time.time() - start_time, 1)

        metrics = {
            metric_name: log(metric)
            for metric_name, metric in trainer.metrics.items()
        }
        metrics['step_time'] = step_time

        # validation plotting
        progbar.add(valid_inc, [('Train Loss', metrics['train_loss']),
                                ('Validation Loss', metrics['valid_loss']),
                                ('Time (s)', step_time)])
        #Plot on Comet
        experiment.log_metrics(metrics, step=t)
        # Plot on WandB
        wandb.log(metrics, step=t)

    if (t + 1) % save_inc == 0:
        trainer.save_weights(model_path,
                             run_id=wandb.run.id,
                             experiment_key=experiment.get_key())
        if not args.gcbc and not args.images:
            z_enc, z_plan = produce_cluster_fig(next(plotting_dataset),
                                                encoder,
                                                planner,
                                                TEST_DATA_PATHS[0],