Example #1
0
def test_dataset_works_with_keras_api(test_png_path):
    """
    Check that the output of pw.loaders.dataset() can be handed directly
    to the high-level keras API (Model.fit).
    """
    num_images = 10
    filepaths = [test_png_path] * num_images
    targets = np.random.randint(0, 2, size=num_images)

    train_ds, steps = dataset(
        filepaths,
        ys=targets,
        imshape=(11, 17),
        num_channels=3,
        norm=255,
        batch_size=5,
        augment=False,
    )

    # tiny classifier: 1x1 conv -> global max pool -> 2-way softmax
    layers = tf.keras.layers
    image_in = layers.Input((None, None, 3))
    hidden = layers.Conv2D(2, 1)(image_in)
    hidden = layers.GlobalMaxPool2D()(hidden)
    probs = layers.Dense(2, activation="softmax")(hidden)
    model = tf.keras.Model(image_in, probs)

    model.compile("SGD", loss=tf.keras.losses.sparse_categorical_crossentropy)

    history = model.fit(train_ds, steps_per_epoch=steps, epochs=1)
    assert isinstance(history, tf.keras.callbacks.History)
    print(history)
Example #2
0
def test_get_features_from_dataset(test_png_path, test_jpg_path):
    """_get_features() should return a 2D feature matrix and a label array
    with one label per feature row."""
    paths = [test_png_path, test_jpg_path] * 5
    targets = [0, 1] * 5
    # dataset() returns (dataset, num_steps); only the dataset is needed
    ds = dataset(paths, ys=targets, imshape=(20, 20), batch_size=2)[0]
    feats, lbls = _get_features(fcn, ds)

    assert isinstance(feats, np.ndarray)
    assert feats.ndim == 2
    assert isinstance(lbls, np.ndarray)
    assert feats.shape[0] == len(lbls)
def test_linear_classification_test_dataset_input(test_png_path,
                                                  test_jpg_path):
    """linear_classification_test() should accept a dataset and return an
    accuracy in [0, 1] plus a confusion matrix."""
    paths = [test_png_path, test_jpg_path] * 5
    targets = [0, 1] * 5
    ds = dataset(paths, ys=targets, imshape=(20, 20), batch_size=2)[0]

    accuracy, confusion = linear_classification_test(fcn, ds)

    assert isinstance(accuracy, float)
    assert accuracy <= 1
    assert 0 <= accuracy
    assert isinstance(confusion, np.ndarray)
Example #4
0
def linear_classification_test(fcn, downstream_labels, **input_config):
    """
    Train a linear classifier on a fully-convolutional network
    and return out-of-sample results.

    Every third item (positions 0, 3, 6, ...) of ``downstream_labels`` is
    held out for testing; the rest are used for training (roughly a 2:1
    train:test split).

    :fcn: Keras fully-convolutional network
    :downstream_labels: dictionary mapping image file paths to labels
    :input_config: kwargs for patchwork.loaders.dataset(). ``shuffle`` is
        always forced to False so features stay aligned with their labels.

    Returns
    :acc: float; test accuracy
    :cm: 2D numpy array; confusion matrix
    """
    # Both loaders must preserve file order so predictions line up with Y.
    # Setting shuffle here, rather than passing it next to **input_config,
    # also avoids a "multiple values for keyword argument 'shuffle'"
    # TypeError when the caller's config already contains it.
    loader_kwargs = {**input_config, "shuffle": False}

    # do a 2:1 train:test split (boolean mask; also correct for an empty dict)
    split = np.arange(len(downstream_labels)) % 3 == 0
    X = np.array(list(downstream_labels.keys()))
    Y = np.array(list(downstream_labels.values()))

    def _pooled_features(filepaths):
        # Run the FCN and average its output over axes 1 and 2. These are
        # presumably the spatial axes of a (batch, H, W, C) output; TODO confirm.
        ds, num_steps = dataset(filepaths, **loader_kwargs)
        return fcn.predict(ds, steps=num_steps).mean(axis=1).mean(axis=1)

    trainvecs = _pooled_features(X[~split])
    testvecs = _pooled_features(X[split])
    # rescale train and test using statistics fit on the training set only
    scaler = StandardScaler().fit(trainvecs)
    trainvecs = scaler.transform(trainvecs)
    testvecs = scaler.transform(testvecs)
    # scikit-learn renamed loss="log" to "log_loss" in 1.1 and removed the
    # old name in 1.3; use whichever name the installed version supports.
    supported_losses = getattr(SGDClassifier, "loss_functions", {})
    loss = "log_loss" if "log_loss" in supported_losses else "log"
    # train a logistic-loss linear classifier (SGDClassifier handles
    # multiclass problems one-vs-all)
    logreg = SGDClassifier(loss=loss, max_iter=1000, n_jobs=-1,
                           learning_rate="adaptive", eta0=1e-2)
    logreg.fit(trainvecs, Y[~split])
    # make predictions on test set
    preds = logreg.predict(testvecs)
    # compute metrics and return
    acc = accuracy_score(Y[split], preds)
    cm = confusion_matrix(Y[split], preds)
    return acc, cm
Example #5
0
def test_dataset_with_augmentation(test_png_path):
    """An unlabeled, augmented dataset yields image batches of the
    requested shape and reports the expected number of steps."""
    filepaths = [test_png_path] * 10

    ds, ns = dataset(filepaths, ys=None, imshape=(11, 17),
                     num_channels=3, norm=255,
                     batch_size=5, augment={})

    # only the first batch is inspected
    for first_batch in ds:
        x = first_batch.numpy()
        break

    expected_shape = (5, 11, 17, 3)
    assert isinstance(ds, tf.data.Dataset)
    assert ns == 2
    assert x.shape == expected_shape
Example #6
0
def test_dataset_with_labels(test_png_path):
    """With labels 0..9 and batch_size 5, the first batch should carry
    labels 0..4 in order."""
    filepaths = [test_png_path] * 10
    label_values = np.arange(10)

    ds, _num_steps = dataset(filepaths, ys=label_values, imshape=(11, 17),
                             num_channels=3, norm=255,
                             batch_size=5, augment=False)

    for images, targets in ds:
        x = images.numpy()
        y = targets.numpy()
        break

    assert np.all(y == np.arange(5))
Example #7
0
def test_dataset_with_unlabeled_images(test_png_path, test_tif_path):
    """Passing unlab_fps without labels yields pairs of image batches that
    both have the requested shape."""
    primary_files = [test_png_path] * 10
    extra_files = [test_tif_path] * 15

    ds, _num_steps = dataset(primary_files, unlab_fps=extra_files,
                             imshape=(11, 17),
                             num_channels=3, norm=255,
                             batch_size=5, augment={})

    for first, second in ds:
        x = first.numpy()
        y = second.numpy()
        break

    expected_shape = (5, 11, 17, 3)
    assert x.shape == expected_shape
    assert y.shape == expected_shape
Example #8
0
 def _pred_dataset(self, batch_size=32):
     """
     Build an unshuffled, unaugmented dataset over every row of self.df
     for making predictions.

     Returns a (dataset, num_steps) pair. When features are extracted live,
     this is the pair returned by dataset(); when self.feature_vecs holds
     pre-extracted features, those vectors are batched directly and
     num_steps is ceil(len(self.df) / batch_size).
     """
     num_steps = int(np.ceil(len(self.df) / batch_size))

     # PRE-EXTRACTED FEATURE CASE
     if self.feature_vecs is not None:
         vec_ds = tf.data.Dataset.from_tensor_slices(self.feature_vecs)
         return vec_ds.batch(batch_size), num_steps

     # LIVE FEATURE EXTRACTOR CASE: dataset() already supplies its own
     # step count as the second element of its return value
     filepaths = self.df["filepath"].values
     return dataset(filepaths,
                    imshape=self._imshape,
                    num_channels=self._num_channels,
                    num_parallel_calls=self._num_parallel_calls,
                    batch_size=batch_size,
                    shuffle=False,
                    augment=False)
Example #9
0
    def _training_dataset(self, batch_size=32, num_samples=None):
        """
        Build a single-epoch training set from a stratified sample of
        self.df.

        :batch_size: batch size for the returned dataset
        :num_samples: how many labeled samples to draw; defaults to len(self.df)

        In semi-supervised mode, unlabeled examples (per find_unlabeled) are
        also supplied: as sampled file paths in the live-extractor case, or
        as row indices in the pre-extracted feature case.
        """
        if num_samples is None:
            num_samples = len(self.df)

        # PRE-EXTRACTED FEATURE CASE
        if self.feature_vecs is not None:
            inds, ys = stratified_sample(self.df, num_samples,
                                         return_indices=True)
            unlabeled_indices = None
            if self._semi_supervised:
                unlab_mask = find_unlabeled(self.df)
                unlabeled_indices = np.arange(len(self.df))[unlab_mask]
            return _build_in_memory_dataset(self.feature_vecs,
                                            inds,
                                            ys,
                                            batch_size=batch_size,
                                            unlabeled_indices=unlabeled_indices)

        # LIVE FEATURE EXTRACTOR CASE
        files, ys = stratified_sample(self.df, num_samples)
        if self._semi_supervised:
            unlab_mask = find_unlabeled(self.df)
            candidates = self.df.filepath.values[unlab_mask]
            # resample unlabeled files (with replacement) to match num_samples
            unlab_fps = np.random.choice(candidates, replace=True,
                                         size=num_samples)
        else:
            unlab_fps = None
        # dataset() returns (dataset, num_steps); keep only the dataset
        ds_and_steps = dataset(files,
                               ys,
                               imshape=self._imshape,
                               num_channels=self._num_channels,
                               num_parallel_calls=self._num_parallel_calls,
                               batch_size=batch_size,
                               augment=self._aug,
                               unlab_fps=unlab_fps)
        return ds_and_steps[0]
Example #10
0
def test_dataset_with_both(test_png_path, test_tif_path):
    """With both labels and unlab_fps, batches are ((images, images), labels)
    with the requested image shape and one label per image."""
    primary_files = [test_png_path] * 10
    label_values = np.arange(10)
    extra_files = [test_tif_path] * 15

    ds, _num_steps = dataset(primary_files, ys=label_values,
                             unlab_fps=extra_files,
                             imshape=(11, 17),
                             num_channels=3, norm=255,
                             batch_size=5, augment={})

    for (first_imgs, second_imgs), targets in ds:
        w = first_imgs.numpy()
        x = second_imgs.numpy()
        y = targets.numpy()
        break

    image_shape = (5, 11, 17, 3)
    assert w.shape == image_shape
    assert x.shape == image_shape
    assert y.shape == (5,)
Example #11
0
def test_dataset_with_custom_dataset():
    """dataset() should accept an existing tf.data.Dataset of images; it
    batches it but cannot report a step count."""
    zero_images = np.zeros((7, 11, 17, 3)).astype(np.float32)
    source_ds = tf.data.Dataset.from_tensor_slices(zero_images)

    ds, ns = dataset(source_ds,
                     ys=None,
                     imshape=(11, 17),
                     num_channels=3,
                     norm=255,
                     batch_size=5,
                     augment=False)

    for first_batch in ds:
        x = first_batch.numpy()
        break

    assert isinstance(ds, tf.data.Dataset)
    # the number of steps can't be precomputed for a user-supplied dataset
    assert ns is None
    assert x.shape == (5, 11, 17, 3)
Example #12
0
    def _val_dataset(self, batch_size=32):
        """
        Build an unshuffled, unaugmented dataset of just the validation rows
        of self.df (rows where the ``validation`` column is True).

        Missing labels (NaN) are replaced with -1.

        :batch_size: batch size for the returned dataset
        Returns a (tf.data.Dataset, num_steps) pair in both the live-extractor
        and pre-extracted feature cases.
        """
        val_df = self.df[self.df.validation]
        ys = val_df[self.classes].values.copy()
        # mark missing labels with -1
        ys[np.isnan(ys)] = -1

        # LIVE FEATURE EXTRACTOR CASE: dataset() returns (dataset, num_steps)
        if self.feature_vecs is None:
            files = val_df["filepath"].values
            return dataset(files,
                           ys=ys,
                           imshape=self._imshape,
                           num_channels=self._num_channels,
                           num_parallel_calls=self._num_parallel_calls,
                           batch_size=batch_size,
                           shuffle=False,
                           augment=False)
        # PRE-EXTRACTED FEATURE CASE: return the same (dataset, num_steps)
        # shape as the live branch rather than a bare 1-tuple ``(ds,)``
        vecs = self.feature_vecs[self.df.validation.values]
        num_steps = int(np.ceil(len(vecs) / batch_size))
        ds = tf.data.Dataset.from_tensor_slices((vecs, ys)).batch(batch_size)
        return ds, num_steps
Example #13
0
def _fixmatch_dataset(labeled_filepaths, labels, unlabeled_filepaths, imshape,
                      num_parallel_calls, norm, num_channels, single_channel,
                      batch_size, weak_aug, strong_aug, mu):
    """
    Build the FixMatch training dataset by zipping two datasets together:

        - a supervised dataset yielding weakly-augmented (image, label)
          batches, shuffled
        - an unlabeled dataset yielding two views of each image, one weakly
          augmented and one strongly augmented, for the semi-supervised
          training component

    Unlabeled batches are ``mu`` times larger than labeled ones. Elements
    of the returned dataset have structure ((x, y), (x_wk, x_str)).
    """
    # loader settings shared by both halves of the pipeline
    shared = {
        "imshape": imshape,
        "num_parallel_calls": num_parallel_calls,
        "norm": norm,
        "num_channels": num_channels,
        "single_channel": single_channel,
    }

    # supervised half; dataset() returns (dataset, num_steps)
    labeled_ds = dataset(labeled_filepaths,
                         ys=labels,
                         batch_size=batch_size,
                         augment=weak_aug,
                         shuffle=True,
                         **shared)[0]

    # unsupervised half, with a larger batch size
    unlabeled_ds = _fixmatch_unlab_dataset(unlabeled_filepaths,
                                           weak_aug,
                                           strong_aug,
                                           batch_size=mu * batch_size,
                                           **shared)

    zipped = tf.data.Dataset.zip((labeled_ds, unlabeled_ds))
    return zipped.prefetch(1)
Example #14
0
 def __init__(self, logdir, trainingdata, testdata=None, fcn=None, augment=True,
              pca_dim=256, k=1000, dense=[4096], mult=1,
              kmeans_max_iter=100, kmeans_batch_size=100, lr=0.05,
              lr_decay=100000, decay_type="exponential", opt_type="momentum",
              imshape=(256,256), num_channels=3,
              norm=255, batch_size=64, num_parallel_calls=None,
              single_channel=False, notes="",
              downstream_labels=None):
     """
     Set up a DeepCluster trainer: summary writer, models, optimizer,
     evaluation/prediction datasets, training step, and config YAML.

     :logdir: (string) path to log directory
     :trainingdata: (list) list of paths to training images
     :testdata: (list) filepaths of a batch of images to use for eval
     :fcn: (keras Model) fully-convolutional network to train as feature extractor;
             if None, a BNAlexNetFCN is built
     :augment: (dict) dictionary of augmentation parameters, True for defaults or
             False to disable augmentation
     :pca_dim: (int) dimension to reduce FCN outputs to using principal component analysis
     :k: (int) number of clusters
     :dense: (list of ints) number of hidden units in dense layers between the
         FCN and the softmax layer. THERE NEEDS TO BE AT LEAST ONE DENSE LAYER. The
         DeepCluster paper used [4096, 4096].
     :mult: (int) not in paper; multiplication factor to increase
             number of steps/epoch. set to 1 to get paper algorithm
     :kmeans_max_iter: max iterations over dataset for minibatch k-means
     :kmeans_batch_size: batch size for minibatch k-means
     :lr: (float) initial learning rate
     :lr_decay: (int) number of steps for one decay period (0 to disable)
     :decay_type: (string) how to decay the learning rate- "exponential" (smooth exponential decay), "staircase" (non-smooth exponential decay), or "cosine"
     :opt_type: (str) which optimizer to use; "momentum" or "adam"
     :imshape: (tuple) image dimensions in H,W
     :num_channels: (int) number of image channels
     :norm: (int or float) normalization constant for images (for rescaling to
            unit interval); applied to both the eval and clustering datasets
     :batch_size: (int) batch size for training
     :num_parallel_calls: (int) number of threads for loader mapping
     :single_channel: if True, expect a single-channel input image and 
             stack it num_channels times.
     :notes: (string) any notes on the experiment that you want saved in the
             config.yml file
     :downstream_labels: dictionary mapping image file paths to labels
     """
     self.logdir = logdir
     self.trainingdata = trainingdata
     self._downstream_labels = downstream_labels
     
     self._file_writer = tf.summary.create_file_writer(logdir, flush_millis=10000)
     self._file_writer.set_as_default()
     
     # if no FCN is passed- build one
     if fcn is None:
         fcn = BNAlexNetFCN(num_channels)
     self.fcn = fcn
     self._models = {"fcn":fcn}    
     
     # build model for training    
     prediction_model, training_model, output_layer = _build_model(fcn, 
                             imshape=imshape, num_channels=num_channels, 
                             dense=dense, k=k)
     self._models["full"] = training_model
     self._pred_model = prediction_model
     self._output_layer = output_layer
     
     # create optimizer
     self._optimizer = self._build_optimizer(lr, lr_decay, opt_type=opt_type,
                                             decay_type=decay_type)
     
     # build evaluation dataset
     if testdata is not None:
         self._test_ds, self._test_steps = dataset(testdata,
                                  imshape=imshape,norm=norm,
                                  num_channels=num_channels,
                                  single_channel=single_channel)
         self._test = True
     else:
         self._test = False
     self._test_labels = None
     self._old_test_labels = None
     
     # build prediction dataset for clustering. Pass `norm` through so the
     # features we cluster on are scaled the same way as the eval dataset
     # above; without it this call falls back to dataset()'s default
     # normalization regardless of the configured value.
     ds, num_steps = dataset(trainingdata, imshape=imshape, num_channels=num_channels,
              norm=norm,
              num_parallel_calls=num_parallel_calls, batch_size=batch_size, 
              augment=False, single_channel=single_channel)
     self._pred_ds = ds
     self._pred_steps = num_steps
     
     # finally, build a Glorot initializer we'll use for resetting the last
     # layer of the network
     self._initializer = tf.initializers.glorot_uniform()
     self._old_cluster_assignments = None
     
     
     # build training step function- this step makes sure that this object
     # has its own @tf.function-decorated training function so if we have
     # multiple deepcluster objects (for example, for hyperparameter tuning)
     # they won't interfere with each other.
     self._training_step = build_deepcluster_training_step()
     self.step = 0
     
     # parse and write out config YAML
     self._parse_configs(augment=augment, k=k, pca_dim=pca_dim, lr=lr, 
                         lr_decay=lr_decay, opt_type=opt_type,
                         mult=mult, dense=dense,
                         kmeans_max_iter=kmeans_max_iter, 
                         kmeans_batch_size=kmeans_batch_size,
                         imshape=imshape, num_channels=num_channels,
                         norm=norm, batch_size=batch_size,
                         num_parallel_calls=num_parallel_calls, 
                         single_channel=single_channel, notes=notes,
                         trainer="deepcluster")
Example #15
0
    def __init__(self,
                 logdir,
                 trainingdata,
                 unlabeled_filepaths,
                 valdata,
                 model,
                 weak_aug=DEFAULT_WEAK_AUG,
                 strong_aug=DEFAULT_STRONG_AUG,
                 lam=1,
                 tau=0.95,
                 mu=4,
                 passes_per_epoch=1,
                 weight_decay=3e-4,
                 lr=1e-3,
                 lr_decay=0,
                 decay_type="exponential",
                 opt_type="momentum",
                 imshape=(256, 256),
                 num_channels=3,
                 norm=255,
                 batch_size=64,
                 num_parallel_calls=None,
                 single_channel=False,
                 notes="",
                 strategy=None):
        """
        Set up a FixMatch trainer: optimizer, training/validation datasets,
        (optionally distributed) training step, summary writer, and config.

        :logdir: (string) path to log directory
        :trainingdata: pandas dataframe of training data; needs a "filepath"
            column, and every column not in PROTECTED_COLUMN_NAMES is treated
            as a class label
        :unlabeled_filepaths: file paths of unlabeled images, used for the
            FixMatch weak/strong augmentation consistency component
        :valdata: pandas dataframe of validation data (same columns as trainingdata)
        :model: Keras model to be trained
        :weak_aug: dictionary of weak augmentation parameters- usually just flipping. This
            will be used for FixMatch pseudolabels as well as supervised training
        :strong_aug: dictionary of strong augmentation parameters, for FixMatch predictions
        :lam: FixMatch lambda parameter; semisupervised loss weight
        :tau: FixMatch threshold parameter
        :mu: FixMatch batch size multiplier for unlabeled batches
        :passes_per_epoch: for small labeled datasets- run through the data this many times
            per epoch
        :weight_decay: (float) coefficient for L2-norm loss.
        :lr: learning rate
        :lr_decay: learning rate decay (set to 0 to disable)
        :decay_type: (str) how to decay learning rate; "exponential", "cosine", or "staircase"
        :opt_type: (str) optimizer type
        :imshape: (tuple) image dimensions in H,W
        :num_channels: (int) number of image channels
        :norm: (int or float) normalization constant for images (for rescaling to
               unit interval)
        :batch_size: (int) batch size for labeled training data
        :num_parallel_calls: (int) number of threads for loader mapping
        :single_channel: if True, expect a single-channel input image and
            stack it num_channels times.
        :notes: any experimental notes you want recorded in the config.yml file
        :strategy: if distributing across multiple GPUs, pass a tf.distribute
            Strategy object here
        """
        self.logdir = logdir
        self._weak_aug = weak_aug
        self._strong_aug = strong_aug
        self.strategy = strategy
        self.model = model
        # every dataframe column that isn't reserved is a binary class label
        self.categories = [col for col in trainingdata.columns
                           if col not in PROTECTED_COLUMN_NAMES]
        self.valdata = valdata
        self._val_labels = valdata[self.categories].values
        self._models = {"full": model}
        self._passes_per_epoch = passes_per_epoch

        self._optimizer = self._build_optimizer(lr, lr_decay,
                                                decay_type=decay_type,
                                                opt_type=opt_type)

        # loader settings shared by the training and validation pipelines
        loader_kwargs = {
            "imshape": imshape,
            "num_parallel_calls": num_parallel_calls,
            "norm": norm,
            "num_channels": num_channels,
            "single_channel": single_channel,
        }

        # training data: zipped labeled + unlabeled FixMatch pipeline
        labeled_files = trainingdata["filepath"].values
        labeled_ys = trainingdata[self.categories].values.astype(np.float32)
        fixmatch_ds = _fixmatch_dataset(labeled_files,
                                        labeled_ys,
                                        unlabeled_filepaths,
                                        batch_size=batch_size,
                                        weak_aug=weak_aug,
                                        strong_aug=strong_aug,
                                        mu=mu,
                                        **loader_kwargs)
        self._ds = self._distribute_dataset(fixmatch_ds)

        # validation data: unshuffled, unaugmented; keep only the dataset
        # from dataset()'s (dataset, num_steps) return value
        val_files = valdata["filepath"].values
        self._val_ds = dataset(val_files,
                               augment=False,
                               shuffle=False,
                               **loader_kwargs)[0]

        step_fn = _build_fixmatch_training_step(model,
                                                self._optimizer,
                                                lam=lam,
                                                tau=tau,
                                                weight_decay=weight_decay)
        self._training_step = self._distribute_training_function(step_fn)

        self._file_writer = tf.summary.create_file_writer(logdir,
                                                          flush_millis=10000)
        self._file_writer.set_as_default()
        self.step = 0

        self._parse_configs(augment=strong_aug,
                            lam=lam,
                            tau=tau,
                            mu=mu,
                            weight_decay=weight_decay,
                            lr=lr,
                            lr_decay=lr_decay,
                            decay_type=decay_type,
                            opt_type=opt_type,
                            imshape=imshape,
                            num_channels=num_channels,
                            norm=norm,
                            batch_size=batch_size,
                            num_parallel_calls=num_parallel_calls,
                            single_channel=single_channel,
                            notes=notes,
                            trainer="fixmatch")
Example #16
0
    def _training_dataset(self, batch_size=32, num_samples=None):
        """
        Build a single-epoch training set.

        Supervised case: returns a tf.data.Dataset with structure (x, y).

        Semi-supervised case: returns a tf.data.Dataset with structure
            ((x, y), unlab), where ``unlab`` comes from
            _fixmatch_unlab_dataset() and yields (weakly augmented image,
            strongly augmented image) pairs drawn from every row not flagged
            for validation. Unlabeled batches are ``mu`` times larger, with
            mu taken from self._model_params["fixmatch"]["mu"].

        Pre-extracted feature case: returns the dataset built by
            _build_in_memory_dataset() from self.feature_vecs.

        :batch_size: batch size for the labeled data
        :num_samples: number of labeled samples to draw with
            stratified_sample(); defaults to len(self.df)
        """
        if num_samples is None:
            num_samples = len(self.df)
        # LIVE FEATURE EXTRACTOR CASE
        if self.feature_vecs is None:
            files, ys = stratified_sample(self.df, num_samples)
            # (x,y) dataset; dataset() returns (dataset, num_steps)
            # NOTE(review): this loader does not pass norm=self._norm while the
            # unlabeled FixMatch loader below does; confirm both are meant to
            # use the same normalization.
            ds = dataset(files,
                         ys,
                         imshape=self._imshape,
                         num_channels=self._num_channels,
                         num_parallel_calls=self._num_parallel_calls,
                         batch_size=batch_size,
                         augment=self._aug)[0]

            # include unlabeled data as well if
            # we're doing semisupervised learning
            if self._semi_supervised:
                # train on anything not specifically labeled validation
                all_filepaths = self.df.filepath[~self.df.validation].values
                # in the FixMatch paper they use a larger batch size for
                # unlabeled data
                qN = batch_size * self._model_params["fixmatch"]["mu"]
                # Make a tensorflow dataset that will return batches of
                # (weakly augmented image, strongly augmented image) pairs
                unlab_ds = _fixmatch_unlab_dataset(
                    all_filepaths,
                    self._aug,
                    self._fixmatch_aug,
                    imshape=self._imshape,
                    norm=self._norm,
                    num_channels=self._num_channels,
                    num_parallel_calls=self._num_parallel_calls,
                    batch_size=qN)
                # stitch together the labeled and unlabeled datasets
                ds = tf.data.Dataset.zip((ds, unlab_ds))

            return ds

        # PRE-EXTRACTED FEATURE CASE
        else:
            inds, ys = stratified_sample(self.df,
                                         num_samples,
                                         return_indices=True)
            if self._semi_supervised:
                unlabeled_indices = np.arange(len(self.df))[find_unlabeled(
                    self.df)]
            else:
                unlabeled_indices = None

            return _build_in_memory_dataset(
                self.feature_vecs,
                inds,
                ys,
                batch_size=batch_size,
                unlabeled_indices=unlabeled_indices)
Example #17
0
    def __init__(self,
                 logdir,
                 trainingdata,
                 testdata=None,
                 fcn=None,
                 full_model=None,
                 conv_layers=[32, 48, 64, 128],
                 dropout=0.5,
                 augment=True,
                 lr=1e-3,
                 lr_decay=100000,
                 imshape=(256, 256),
                 num_channels=3,
                 norm=255,
                 batch_size=64,
                 num_parallel_calls=None,
                 single_channel=False,
                 notes="",
                 downstream_labels=None):
        """
        Set up an autoencoder trainer: summary writer, models, optimizer,
        training/evaluation datasets, training step, and config YAML.

        :logdir: (string) path to log directory
        :trainingdata: (list) list of paths to training images
        :testdata: (list) filepaths of a batch of images to use for eval
        :fcn: (keras Model) fully-convolutional network to train as feature extractor
        :full_model: (keras model) full autoencoder
        :conv_layers: (list) layer specification passed to _build_autoencoder()
                when a new autoencoder is built (used only if fcn or full_model
                is None)
        :dropout: (float) dropout value passed to _build_autoencoder() under
                the same condition
        :augment: (dict) dictionary of augmentation parameters, True for defaults or
                False to disable augmentation
        :lr: (float) initial learning rate
        :lr_decay: (int) steps for learning rate to decay by half (0 to disable)
        :imshape: (tuple) image dimensions in H,W
        :num_channels: (int) number of image channels
        :norm: (int or float) normalization constant for images (for rescaling to
               unit interval)
        :batch_size: (int) batch size for training
        :num_parallel_calls: (int) number of threads for loader mapping
        :single_channel: if True, expect a single-channel input image and 
                stack it num_channels times.
        :notes: (string) any notes on the experiment that you want saved in the
                config.yml file
        :downstream_labels: dictionary mapping image file paths to labels
        """
        self.logdir = logdir
        self.trainingdata = trainingdata
        self._downstream_labels = downstream_labels

        self._file_writer = tf.summary.create_file_writer(logdir,
                                                          flush_millis=10000)
        self._file_writer.set_as_default()

        # build models if necessary
        if fcn is None or full_model is None:
            print("building new autoencoder")
            fcn, full_model = _build_autoencoder(num_channels, conv_layers,
                                                 dropout)
        self.fcn = fcn
        self._models = {"fcn": fcn, "full": full_model}

        # create optimizer
        self._optimizer = self._build_optimizer(lr, lr_decay)

        # training dataset
        self._train_ds, _ = dataset(trainingdata,
                                    imshape=imshape,
                                    norm=norm,
                                    sobel=False,
                                    num_channels=num_channels,
                                    augment=augment,
                                    single_channel=single_channel,
                                    batch_size=batch_size,
                                    shuffle=True)
        # build evaluation dataset. `trainer` is a config key, not a loader
        # option, so it is passed to _parse_configs() below, as the other
        # trainers do, rather than to dataset().
        if testdata is not None:
            self._test_ds, self._test_steps = dataset(
                testdata,
                imshape=imshape,
                norm=norm,
                sobel=False,
                num_channels=num_channels,
                single_channel=single_channel,
                batch_size=batch_size,
                shuffle=False)
            self._test = True
        else:
            self._test = False

        # build training step function- this step makes sure that this object
        # has its own @tf.function-decorated training function so if we have
        # multiple autoencoder objects (for example, for hyperparameter tuning)
        # they won't interfere with each other.
        self._training_step = _build_training_step(full_model, self._optimizer)
        self.step = 0

        # parse and write out config YAML
        self._parse_configs(augment=augment,
                            conv_layers=conv_layers,
                            dropout=dropout,
                            lr=lr,
                            lr_decay=lr_decay,
                            imshape=imshape,
                            num_channels=num_channels,
                            norm=norm,
                            batch_size=batch_size,
                            num_parallel_calls=num_parallel_calls,
                            single_channel=single_channel,
                            notes=notes,
                            trainer="autoencoder")