Example No. 1
    def post_init(self):
        """Load the pretrained VAE model and build its TensorFlow graph."""
        super().post_init()
        import json
        import os

        from cvae import cvae
        import cvae.lib.model_iaf as model
        import tensorflow as tf

        # Locate the parameter file next to the pretrained model, if the model directory exists.
        params_path = None
        if self.model_path and os.path.exists(self.model_path):
            params_path = os.path.join(self.model_path, 'params.json')

        if params_path and os.path.exists(params_path):

            config = tf.ConfigProto(log_device_placement=False)
            config.gpu_options.allow_growth = True

            with tf.Graph().as_default():

                # Load parameter file.
                with open(params_path, 'r') as f:
                    param = json.load(f)

                net = model.VAEModel(param,
                                     None,
                                     input_dim=param['dim_feature'],
                                     keep_prob=tf.placeholder_with_default(
                                         input=tf.cast(1.0, dtype=tf.float32),
                                         shape=(),
                                         name="KeepProb"),
                                     initializer='orthogonal')
                # Placeholder for data features
                self.data_feature_placeholder = tf.placeholder_with_default(
                    input=tf.zeros([64, param['dim_feature']],
                                   dtype=tf.float32),
                    shape=[None, param['dim_feature']])

                self.embeddings = net.embed(self.data_feature_placeholder)

                self.sess = tf.Session(config=config)
                init = tf.global_variables_initializer()
                self.sess.run(init)

                # Saver for loading checkpoints of the model.
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                cvae.load(saver, self.sess, self.model_path)

                self.to_device()
        else:
            raise PretrainedModelFileDoesNotExist()
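
For orientation, a hedged sketch (not part of the original example) of how the attributes created above might be used once post_init() has run; the helper name embed_batch is an assumption.

    def embed_batch(self, features):
        """Hypothetical helper: map a (batch, dim_feature) float32 array to latent
        embeddings using the session, placeholder, and embeddings tensor built in
        post_init() above."""
        return self.sess.run(
            self.embeddings,
            feed_dict={self.data_feature_placeholder: features})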
Example No. 2
    def __init__(self,
                 X=None,
                 X_valid=None,
                 train_valid_split=0.9,
                 dim_latent=2,
                 iaf_flow_length=5,
                 cells_encoder=None,
                 initializer='orthogonal',
                 batch_size=64,
                 batch_size_test=64,
                 logdir='temp',
                 feature_normalization=True,
                 tb_logging=False):

        self.dim_latent = dim_latent
        self.iaf_flow_length = iaf_flow_length
        self.cells_encoder = cells_encoder
        self.initializer = initializer
        self.batch_size = batch_size
        self.batch_size_test = batch_size_test
        self.logdir = os.path.abspath(logdir)
        self.feature_normalization = feature_normalization
        self.tb_logging = tb_logging

        self.trained_once_this_session = False

        # --- Check for existing model ---

        # Set flag to indicate that the model has not been trained yet
        self.is_trained = False

        # If using temporary model directory (default), delete any previously stored models
        if logdir == 'temp' and os.path.exists(self.logdir):
            shutil.rmtree(self.logdir)

        # Check if a model with the same name already exists.
        # If not, create the directory.
        if not os.path.exists(self.logdir):
            os.makedirs(self.logdir)

        # Do checkpoint, parameter, and norm files exist?
        self.has_checkpoint = os.path.exists(f'{self.logdir}/checkpoint')
        self.has_params = os.path.exists(f'{self.logdir}/params.json')
        self.has_norm = os.path.exists(f'{self.logdir}/norm.pkl')
        self.has_dataset_file = os.path.exists(f'{self.logdir}/data_train.pkl')
        self.has_data = False

        # --- Prepare data ---

        # Check if data is provided as array or as directory.
        self.dataset_type = None

        if isinstance(X, str):
            self.dataset_type = 'string'

            if self.has_dataset_file:
                print('This model has already been associated with a dataset from a directory. To create a new '
                      'dataset, delete the data_train.pkl and data_valid.pkl files in the model directory.')
                _, _, self.dim_feature = dr.load_dataset_file(f'{self.logdir}/data_train.pkl')
            else:
                print(f'Preparing train and validation datasets from feature directory {X}.')
                self.dim_feature = fun.prepare_dataset(data_dir=os.path.abspath(X),
                                                       logdir=self.logdir,
                                                       train_ratio=train_valid_split)
                self.has_dataset_file = True

            self.X = None
            self.X_valid = None
            self.has_data = True

        elif isinstance(X, np.ndarray):
            self.dataset_type = 'array'
            self.dim_feature = X.shape[1]
            # Split data into train and validation or use provided validation data
            if X_valid is not None:
                assert X_valid.shape[1] == self.dim_feature, "Train and validation data have different feature dimensions!"
                self.X = X.astype(np.float32)
                self.X_valid = X_valid.astype(np.float32)
            else:
                # Randomize data
                num_data = len(X)
                indices = list(range(num_data))
                random.shuffle(indices)
                # Split data (and ensure it's float)
                split_index = int(train_valid_split * num_data)
                train_indices = indices[:split_index]
                valid_indices = indices[split_index:]
                self.X = X[train_indices].astype(np.float32)
                self.X_valid = X[valid_indices].astype(np.float32)
            self.has_data = True

        # No new data provided: reuse an existing dataset file or rely on a stored checkpoint.
        elif X is None:
            self.X = None
            self.X_valid = None
            if self.has_dataset_file:
                print(f'Reloading dataset file {self.logdir}/data_train.pkl from previous instance of this model.')
                _, _, self.dim_feature = dr.load_dataset_file(f'{self.logdir}/data_train.pkl')
                self.has_data = True
            else:
                if self.has_checkpoint:
                    self.has_data = False
                else:
                    raise Exception(
                        'Model needs to be initialised with X provided as numpy array or path to directory.')
        else:
            raise Exception('Unsupported input type for X. It needs to be a numpy array or a string with the path '
                            'to a directory containing npy files.')

        # --- Prepare parameter file ---

        # If parameter file for this model already exists, load it. Otherwise create one.
        if self.has_params:
            print(f'Existing parameter file found for model {self.logdir}.\n'
                  f'Loading stored parameters. Some input parameters might be ignored.')
            with open(f'{self.logdir}/params.json', 'r') as f:
                self.param = json.load(f)
            # Set dim_feature if not previously known
            if X is None and not self.has_dataset_file:
                self.dim_feature = self.param['dim_feature']
        else:
            # If not given, determine model structure
            # NOTE: The reasoning here is a bit arbitrary/dodgy, should probably put some more thought into this
            # and improve it.
            # TODO: This does not give very good results yet...
            if cells_encoder is None:
                # Get all the powers of two between latent dim and feature dim
                smallest_power = int(2 ** (self.dim_latent - 1).bit_length())
                largest_power = int(2 ** self.dim_feature.bit_length() / 2)
                powers_of_two = [smallest_power]
                while powers_of_two[-1] <= largest_power:
                    powers_of_two.append(powers_of_two[-1]*2)

                # By default, use two layers: the first with the largest power of two,
                # the second roughly halfway between the input and output dimensions.
                l2_index = int(len(powers_of_two) / 2)
                try:
                    model_layers = [largest_power,
                                    powers_of_two[l2_index+1]]
                except IndexError:
                    model_layers = [largest_power,
                                    int(largest_power/2)]

            else:
                model_layers = cells_encoder

            # Number of hidden cells is the smaller of the last layer size and 64
            cells_hidden = min(model_layers[-1], 64)

            if self.dataset_type == 'string':
                dataset_file = f'{self.logdir}/data_train.pkl'
                dataset_file_valid = f'{self.logdir}/data_valid.pkl'
            else:
                dataset_file = None
                dataset_file_valid = None

            self.param = {
                "dataset_file": dataset_file,
                "dataset_file_valid": dataset_file_valid,
                "dim_latent": self.dim_latent,
                "dim_feature": self.dim_feature,
                "cells_encoder": model_layers,
                "cells_hidden": cells_hidden,
                "iaf_flow_length": self.iaf_flow_length,
                "dim_autoregressive_nl": cells_hidden,
                "initial_s_offset": 1.0,
                "feature_normalization": self.feature_normalization
            }

            # Write to json for future re-use of this model
            with open(f'{self.logdir}/params.json', 'w') as outfile:
                json.dump(self.param, outfile, indent=2)

        # --- Set up VAE model ---
        self.graph = tf.Graph()

        with self.graph.as_default():

            # Create coordinator.
            self.coord = tf.train.Coordinator()

            # Set up batchers.
            with tf.name_scope('create_inputs'):
                if self.dataset_type == 'string':
                    self.reader = dr.DataReader(self.param['dataset_file'],
                                                self.param,
                                                f'{self.logdir}/params.json',
                                                self.coord,
                                                self.logdir)
                    self.test_batcher = dr.Batcher(self.param['dataset_file_valid'],
                                                   self.param,
                                                   f'{self.logdir}/params.json',
                                                   self.logdir)
                else:
                    self.reader = dra.DataReader(self.X, self.feature_normalization, self.coord, self.logdir)
                    self.test_batcher = dra.Batcher(self.X_valid, self.feature_normalization, self.logdir)
                self.train_batch = self.reader.dequeue_feature(self.batch_size)

            # Get normalisation data
            if self.feature_normalization:
                self.mean = self.test_batcher.mean
                self.norm = self.test_batcher.norm

            num_test_data = self.test_batcher.num_data
            self.test_batches_full = int(self.test_batcher.num_data / self.batch_size_test)
            self.test_batch_last = num_test_data - (self.test_batches_full * self.batch_size_test)

            # Placeholder for test features
            self.test_feature_placeholder = tf.placeholder_with_default(
                input=tf.zeros([self.batch_size, self.dim_feature], dtype=tf.float32),
                shape=[None, self.dim_feature])

            # Placeholder for dropout
            self.dropout_placeholder = tf.placeholder_with_default(input=tf.cast(1.0, dtype=tf.float32), shape=(),
                                                                   name="KeepProb")

            # Placeholder for learning rate
            self.lr_placeholder = tf.placeholder_with_default(input=tf.cast(1e-4, dtype=tf.float32), shape=(),
                                                                  name="LearningRate")

            print('Creating model.')
            self.net = model.VAEModel(self.param,
                                      self.batch_size,
                                      input_dim=self.dim_feature,
                                      keep_prob=self.dropout_placeholder,
                                      initializer=self.initializer)
            print('Model created.')

            self.embeddings = self.net.embed(self.test_feature_placeholder)

            print('Setting up loss.')
            self.loss = self.net.loss(self.train_batch)
            self.loss_test = self.net.loss(self.test_feature_placeholder, test=True)
            print('Loss set up.')

            optimizer = tf.train.AdamOptimizer(learning_rate=self.lr_placeholder,
                                               epsilon=1e-4)
            trainable = tf.trainable_variables()
            # for var in trainable:
            #     print(var)
            self.optim = optimizer.minimize(self.loss, var_list=trainable)

            # Set up logging for TensorBoard.
            if self.tb_logging:
                self.writer = tf.summary.FileWriter(self.logdir)
                self.writer.add_graph(tf.get_default_graph())
                run_metadata = tf.RunMetadata()
                self.summaries = tf.summary.merge_all()

            # Set up session
            print('Setting up session.')
            config = tf.ConfigProto(log_device_placement=False)
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)
            init = tf.global_variables_initializer()
            self.sess.run(init)
            print('Session set up.')

            # Saver for storing checkpoints of the model.
            self.saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=2)

            # Try to load model
            try:
                self.saved_global_step = load(self.saver, self.sess, self.logdir)
                if self.saved_global_step is None:
                    # The first training step will be saved_global_step + 1,
                    # therefore we put -1 here for new or overwritten trainings.
                    self.saved_global_step = -1
                    print('No model found to restore. Initialising new model.')
                else:
                    print(f'Restored trained model from step {self.saved_global_step}.')
            except Exception:
                print("Something went wrong while restoring checkpoint.")
                raise
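
For orientation, a hedged sketch (not part of the original example) of how the graph and session set up in this constructor might be queried afterwards; the method name embed_numpy is an assumption, and any feature-normalisation step (self.mean / self.norm) is deliberately left out.

    def embed_numpy(self, features):
        """Hypothetical helper: evaluate the embeddings tensor created in __init__
        above for a (batch, dim_feature) array. Whether the features should first
        be normalised with self.mean and self.norm is left open here."""
        batch = np.asarray(features, dtype=np.float32)
        return self.sess.run(
            self.embeddings,
            feed_dict={self.test_feature_placeholder: batch})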