# NOTE: assumed imports for this module; the package-local helpers
# (BaseModel, Factory, Logger, EarlyStopping, config, get_model_name,
# file_utils, data_utils) are provided elsewhere in this package.
import copy
import gc
import itertools
import os
import sys

import dask.array as da
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm


class TrainerAE(BaseModel):
    '''
    ------------------------------------------------------------------------------
                                    SET ARGUMENTS
    ------------------------------------------------------------------------------
    '''

    def __init__(self, **kwrds):
        self.config = copy.deepcopy(config())
        for key in kwrds.keys():
            assert key in self.config.keys(), \
                '{} is not a keyword, \n acceptable keywords: {}'.format(
                    key, self.config.keys())
            self.config[key] = kwrds[key]

        self.latent_data = None
        self.experiments_root_dir = 'experiments'
        file_utils.create_dirs([self.experiments_root_dir])

        self.config.model_name = get_model_name(self.config.graph_type, self.config)
        if self.config.colab:
            self.google2colab()

        self.config.checkpoint_dir = os.path.join(
            self.experiments_root_dir, self.config.checkpoint_dir, self.config.model_name)
        self.config.config_dir = os.path.join(
            self.experiments_root_dir, self.config.config_dir, self.config.model_name)
        self.config.log_dir = os.path.join(
            self.experiments_root_dir, self.config.log_dir, self.config.model_name)

        log_dir_subfolders = ['reconst', 'latent2d', 'latent3d', 'samples',
                              'total_random', 'pretoss_random', 'interpolate']
        config_dir_subfolders = ['extra']

        file_utils.create_dirs([self.config.checkpoint_dir,
                                self.config.config_dir,
                                self.config.log_dir])
        file_utils.create_dirs([os.path.join(self.config.log_dir, dir_)
                                for dir_ in log_dir_subfolders])
        file_utils.create_dirs([os.path.join(self.config.config_dir, dir_)
                                for dir_ in config_dir_subfolders])

        load_config = {}
        try:
            load_config = file_utils.load_args(
                self.config.model_name, self.config.config_dir,
                ['latent_mean', 'latent_std', 'samples', 'y_uniqs'])
            self.config.update(load_config)
            print('Loading previous configuration ...')
        except Exception:
            print('Unable to load previous configuration ...')

        file_utils.save_args(self.config.dict(), self.config.model_name,
                             self.config.config_dir,
                             ['latent_mean', 'latent_std', 'samples', 'y_uniqs'])

        if hasattr(self.config, 'height'):
            try:
                self.config.restore = True
                self.build_model(self.config.height, self.config.width,
                                 self.config.num_channels)
            except Exception:
                self.config.isBuilt = False
        else:
            self.config.isBuilt = False
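    # Example: constructing a trainer with keyword overrides (a sketch; the
    # exact set of acceptable keys is whatever config() exposes):
    #
    #   trainer = TrainerAE(batch_size=32, epochs=100, plot=True)
    #
    # An unknown keyword fails the assert above and prints the acceptable keys.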
    '''
    ------------------------------------------------------------------------------
                                   EPOCH FUNCTIONS
    ------------------------------------------------------------------------------
    '''

    def _train(self, data_train, session, logger, num_batches):
        losses = list()
        iterator = data_train.make_one_shot_iterator()
        # Build the get_next op once; calling iterator.get_next() inside the
        # loop would add a new op to the graph on every iteration.
        next_batch = iterator.get_next()
        for t in tqdm(range(num_batches)):
            batch = session.run(next_batch)
            loss_curr = self.model_graph.train_epoch(
                session, da.from_array(batch['image'] / 255, chunks=100))
            losses.append(loss_curr)

            cur_it = self.model_graph.global_step_tensor.eval(session)
            summaries_dict = dict(zip(self.model_graph.losses,
                                      np.mean(np.vstack(losses), axis=0)))
            logger.summarize(cur_it, summarizer='iter_train', log_dict=summaries_dict)

        losses = np.mean(np.vstack(losses), axis=0)
        cur_it = self.model_graph.global_step_tensor.eval(session)
        summaries_dict = dict(zip(self.model_graph.losses, losses))
        logger.summarize(cur_it, summarizer='epoch_train', log_dict=summaries_dict)
        return losses

    def _test(self, data_test, session, logger, num_batches):
        losses = list()
        iterator = data_test.make_one_shot_iterator()
        next_batch = iterator.get_next()
        for t in tqdm(range(num_batches)):
            batch = session.run(next_batch)
            loss_curr = self.model_graph.test_epoch(
                session, da.from_array(batch['image'] / 255, chunks=100))
            losses.append(loss_curr)

            cur_it = self.model_graph.global_step_tensor.eval(session)
            summaries_dict = dict(zip(self.model_graph.losses,
                                      np.mean(np.vstack(losses), axis=0)))
            logger.summarize(cur_it, summarizer='iter_test', log_dict=summaries_dict)

        losses = np.mean(np.vstack(losses), axis=0)
        cur_it = self.model_graph.global_step_tensor.eval(session)
        summaries_dict = dict(zip(self.model_graph.losses, losses))
        logger.summarize(cur_it, summarizer='epoch_test', log_dict=summaries_dict)
        return losses

    '''
    ------------------------------------------------------------------------------
                                         FIT
    ------------------------------------------------------------------------------
    '''

    def fit(self, dataset):
        # Accept only image builders from tensorflow_datasets
        # (their classes live under the tensorflow_datasets.image module).
        assert dataset.__class__.__module__.split('.')[:2] == ['tensorflow_datasets', 'image'], \
            'The dataset type is not image tensorflow_datasets'

        self.data_train = dataset.as_dataset(split=tfds.Split.TRAIN, shuffle_files=True,
                                             batch_size=self.config.batch_size)
        try:
            self.data_test = dataset.as_dataset(split=tfds.Split.TEST, shuffle_files=True,
                                                batch_size=self.config.batch_size)
        except Exception:
            # Fall back to the train split when the dataset has no test split.
            self.data_test = dataset.as_dataset(split=tfds.Split.TRAIN, shuffle_files=True,
                                                batch_size=self.config.batch_size)

        # Image features are shaped (height, width, channels).
        height = dataset.info.features['image'].shape[0]
        width = dataset.info.features['image'].shape[1]
        num_channels = dataset.info.features['image'].shape[2]

        self.config.ntrain_batches = dataset.info.splits['train'].num_examples // self.config.batch_size
        try:
            self.config.ntest_batches = dataset.info.splits['test'].num_examples // self.config.batch_size
        except Exception:
            self.config.ntest_batches = self.config.ntrain_batches

        if not self.config.isBuilt:
            self.config.restore = True
            self.build_model(height, width, num_channels)
        else:
            assert (self.config.height == height) and (self.config.width == width) and \
                (self.config.num_channels == num_channels), \
                'Wrong dimension of data. Expected shape {}, and got {}'.format(
                    (self.config.height, self.config.width, self.config.num_channels),
                    (height, width, num_channels))

        '''
        -------------------------------------------------------------------------------
                                     TRAIN THE MODEL
        -------------------------------------------------------------------------------
        '''
        print('\nTraining a model...')
        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(self.config.seeds)
            self.session = session
            logger = Logger(self.session, self.config.log_dir)
            self.saver = tf.train.Saver()
            early_stopper = EarlyStopping(name='total loss', decay_fn=self.decay_fn)

            if self.config.restore and self.load(self.session, self.saver):
                load_config = file_utils.load_args(
                    self.config.model_name, self.config.config_dir,
                    ['latent_mean', 'latent_std', 'samples', 'y_uniqs'])
                self.config.update(load_config)
                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            if self.config.plot:
                if self.config.y_uniqs is None:
                    print('\nFinding the unique categories...')
                    y_uniqs = list()
                    iterator = self.data_train.make_one_shot_iterator()
                    next_batch = iterator.get_next()
                    for t in tqdm(range(self.config.ntrain_batches)):
                        batch = session.run(next_batch)
                        # np.unique on the fetched batch avoids adding a new
                        # tf.unique op to the graph on every iteration.
                        y_uniqs += np.unique(batch[self.config.y_index]).tolist()
                    self.config.y_uniqs = np.unique(y_uniqs)

                if self.config.samples is None:
                    y_uniqs = self.config.y_uniqs[:10]
                    # Tile the categories so there are always exactly 10 of them.
                    y_uniqs = np.array(list(itertools.repeat(y_uniqs, 10))).flatten()[:10]
                    print('\nSampling from the unique categories...')
                    # A dict comprehension gives each category its own list
                    # (itertools.repeat(list(), ...) would alias a single list).
                    samples = {yi: [] for yi in y_uniqs}
                    iterator = self.data_train.make_one_shot_iterator()
                    next_batch = iterator.get_next()
                    for t in tqdm(range(self.config.ntrain_batches)):
                        batch = session.run(next_batch)
                        for yi in y_uniqs:
                            if len(samples[yi]) <= 10:
                                # Numpy boolean indexing; a tf.boolean_mask op
                                # here would grow the graph on every batch.
                                matches = batch['image'][batch[self.config.y_index] == yi]
                                samples[yi] = samples[yi] + matches.tolist()
                                samples[yi] = samples[yi][:10]
                    # Stack the per-category samples into one array.
                    self.config.samples = np.vstack([np.asarray(v) for v in samples.values()])

            if not self.config.isTrained:
                for cur_epoch in range(self.model_graph.cur_epoch_tensor.eval(self.session),
                                       self.config.epochs + 1, 1):
                    print('EPOCH: ', cur_epoch)
                    self.current_epoch = cur_epoch

                    losses_tr = self._train(self.data_train, self.session, logger,
                                            self.config.ntrain_batches)
                    if np.isnan(losses_tr[0]):
                        print('Encountered NaN, stopping training. '
                              'Please check the learning_rate settings and the momentum.')
                        for lname, lval in zip(self.model_graph.losses, losses_tr):
                            print(lname, lval)
                        sys.exit()

                    losses_test = self._test(self.data_test, self.session, logger,
                                             self.config.ntest_batches)

                    train_msg = 'TRAIN: \n'
                    for lname, lval in zip(self.model_graph.losses, losses_tr):
                        train_msg += str(lname) + ': ' + str(lval) + ' | '
                    eval_msg = 'TEST: \n'
                    for lname, lval in zip(self.model_graph.losses, losses_test):
                        eval_msg += str(lname) + ': ' + str(lval) + ' | '
                    print(train_msg)
                    print(eval_msg)
                    print()

                    if (cur_epoch == 1) or ((cur_epoch % self.config.save_epoch == 0)
                                            and (cur_epoch != 0)):
                        self.save_model()
                        if self.config.plot:
                            self.plot_latent(cur_epoch)
                            self.plot_reconst(cur_epoch)

                    self.session.run(self.model_graph.increment_cur_epoch_tensor)

                    # Early stopping
                    if self.config.early_stopping and early_stopper.stop(losses_test[0]):
                        print('Early Stopping!')
                        break

                    if cur_epoch % self.config.colab_save == 0:
                        if self.config.colab:
                            self.push_colab()

                self.config.isTrained = True
                self.save_model()
                if self.config.plot:
                    self.plot_latent(cur_epoch)
                    self.plot_reconst(cur_epoch)

            if self.config.colab:
                self.push_colab()

    def save_model(self):
        self.save(self.session, self.saver,
                  self.model_graph.global_step_tensor.eval(self.session))
        self.compute_distribution(self.data_train, self.session, self.config.ntrain_batches)
        file_utils.save_args(self.config.dict(), self.config.model_name,
                             self.config.config_dir,
                             ['latent_mean', 'latent_std', 'samples', 'y_uniqs'])
        gc.collect()

    def compute_distribution(self, images, session, num_batches):
        self.generate_latent(images, session, num_batches)
        print("Computing the latent's distribution ... ")
        self.model_graph.config.latent_mean = self.latent_data['latent'].mean(axis=0).compute()
        self.model_graph.config.latent_std = self.latent_data['latent'].std(axis=0).compute()
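    # Sketch of the intent: latent_mean/latent_std summarize the encoder's
    # output distribution, so random latents can later be drawn roughly as
    #   z = latent_mean + std_scales * latent_std * eps,  eps ~ N(0, 1).
    # (The exact sampling lives in the model graph's _sampling_reconst.)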
") latents = list() labels = list() iterator = images.make_one_shot_iterator() for t in tqdm(range(num_batches)): batch = session.run(iterator.get_next()) latents_batch = self.model_graph.encode(session, da.from_array(batch['image']/255, chunks=100)) y_index = list(batch.keys()).index(self.config.y_index)-1 label_batch = da.from_array(np.array([batch[k] for k in batch.keys() if k !='image']), chunks=100) latents.append(latents_batch) labels.append(label_batch) latents = da.from_array(delayed(np.vstack(latents)).compute(), chunks=100) labels = da.from_array(delayed(np.vstack(labels)).compute(), chunks=100) self.latent_data = {'latent': latents.reshape((-1, self.config.latent_dim)), 'label': labels.reshape((-1, 1)), 'y_index': y_index} ''' ------------------------------------------------------------------------------ SET NETWORK PARAMS ------------------------------------------------------------------------------ ''' def build_model(self, height, width, num_channels): self.config.height = height self.config.width = width self.config.num_channels = num_channels self.graph = tf.Graph() with self.graph.as_default(): self.model_graph = Factory(self.config) print(self.model_graph) self.trainable_count = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) print('\nNumber of trainable paramters', self.trainable_count) self.test_graph() self.config.isBuilt=True def test_graph(self): with tf.Session(graph=self.graph) as session: tf.set_random_seed(self.config.seeds) self.session = session self.saver = tf.train.Saver() if (self.config.restore and self.load(self.session, self.saver)): load_config = file_utils.load_args(self.config.model_name, self.config.config_dir, ['latent_mean', 'latent_std', 'samples', 'y_uniqs']) self.config.update(load_config) num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session) print('EPOCHS trained: ', num_epochs_trained) else: print('Initializing Variables ...') tf.global_variables_initializer().run() print('random latent batch ...') samples = self.model_graph._sampling_reconst(session, std_scales=np.ones(self.config.latent_dim))[0] print('random latent shape {}'.format(samples.shape)) def _sampling_reconst(self, std_scales, random_latent=None): def aux_fun(session, rand_samp): return self.model_graph._sampling_reconst(session=session, std_scales=std_scales, random_latent=rand_samp) with tf.Session(graph=self.graph) as session: tf.set_random_seed(self.config.seeds) self.session = session self.saver = tf.train.Saver() if (self.config.restore and self.load(self.session, self.saver)): load_config = file_utils.load_args(self.config.model_name, self.config.config_dir, ['latent_mean', 'latent_std', 'samples', 'y_uniqs']) self.config.update(load_config) num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session) print('EPOCHS trained: ', num_epochs_trained) else: print('Initializing Variables ...') tf.global_variables_initializer().run() samples = list() if random_latent is None: while True: samples.append(self.model_graph._sampling_reconst(session=session, std_scales=std_scales)[0]) if len(samples) >= (100//self.config.batch_size)+1: samples = da.vstack(samples) samples = samples[:100] break else: samples = self.batch_function(aux_fun, random_latent) scaler = MinMaxScaler() return scaler.fit_transform(samples.flatten().reshape(-1, 1).astype(np.float32)).reshape(samples.shape)
class AE(BaseModel):
    '''
    ------------------------------------------------------------------------------
                                    SET ARGUMENTS
    ------------------------------------------------------------------------------
    '''

    def __init__(self, **kwrds):
        self.config = copy.deepcopy(config())
        for key in kwrds.keys():
            assert key in self.config.keys(), \
                '{} is not a keyword, \n acceptable keywords: {}'.format(
                    key, self.config.keys())
            self.config[key] = kwrds[key]

        self.experiments_root_dir = 'experiments'
        file_utils.create_dirs([self.experiments_root_dir])

        self.config.model_name = get_model_name(self.config.graph_type, self.config)

        self.config.checkpoint_dir = os.path.join(
            self.experiments_root_dir, self.config.checkpoint_dir, self.config.model_name)
        self.config.config_dir = os.path.join(
            self.experiments_root_dir, self.config.config_dir, self.config.model_name)
        self.config.log_dir = os.path.join(
            self.experiments_root_dir, self.config.log_dir, self.config.model_name)

        log_dir_subfolders = ['reconst', 'latent2d', 'latent3d', 'samples',
                              'total_random', 'pretoss_random', 'interpolate']
        config_dir_subfolders = ['extra']

        file_utils.create_dirs([self.config.checkpoint_dir,
                                self.config.config_dir,
                                self.config.log_dir])
        file_utils.create_dirs([os.path.join(self.config.log_dir, dir_)
                                for dir_ in log_dir_subfolders])
        file_utils.create_dirs([os.path.join(self.config.config_dir, dir_)
                                for dir_ in config_dir_subfolders])

        load_config = {}
        try:
            load_config = file_utils.load_args(self.config.model_name,
                                               self.config.config_dir,
                                               ['latent_mean', 'latent_std'])
            self.config.update(load_config)
            # Restore the non-serializable defaults (initializers and
            # activation functions) from a fresh config instance.
            defaults = config()
            self.config.update({key: defaults[key]
                                for key in ['kinit', 'bias_init', 'act_out', 'transfer_fct']})
            print('Loading previous configuration ...')
        except Exception:
            print('Unable to load previous configuration ...')

        file_utils.save_args(self.config.dict(), self.config.model_name,
                             self.config.config_dir, ['latent_mean', 'latent_std'])

        if self.config.plot:
            self.latent_space_files = list()
            self.latent_space3d_files = list()
            self.recons_files = list()

        if hasattr(self.config, 'height'):
            try:
                self.config.restore = True
                self.build_model(self.config.height, self.config.width,
                                 self.config.num_channels)
            except Exception:
                self.config.isBuild = False
        else:
            self.config.isBuild = False

    '''
    ------------------------------------------------------------------------------
                                   EPOCH FUNCTIONS
    ------------------------------------------------------------------------------
    '''

    def _train(self, data_train, session, logger):
        losses = list()
        for _ in tqdm(range(data_train.num_batches(self.config.batch_size))):
            batch_x = next(data_train.next_batch(self.config.batch_size))
            loss_curr = self.model_graph.train_epoch(session, batch_x)
            losses.append(loss_curr)

            cur_it = self.model_graph.global_step_tensor.eval(session)
            summaries_dict = dict(zip(self.model_graph.losses,
                                      np.mean(np.vstack(losses), axis=0)))
            logger.summarize(cur_it, summarizer='iter_train', log_dict=summaries_dict)

        losses = np.mean(np.vstack(losses), axis=0)
        cur_it = self.model_graph.global_step_tensor.eval(session)
        summaries_dict = dict(zip(self.model_graph.losses, losses))
        logger.summarize(cur_it, summarizer='epoch_train', log_dict=summaries_dict)
        return losses
    def _evaluate(self, data_eval, session, logger):
        losses = list()
        for _ in tqdm(range(data_eval.num_batches(self.config.batch_size))):
            batch_x = next(data_eval.next_batch(self.config.batch_size))
            loss_curr = self.model_graph.evaluate_epoch(session, batch_x)
            losses.append(loss_curr)

            cur_it = self.model_graph.global_step_tensor.eval(session)
            summaries_dict = dict(zip(self.model_graph.losses,
                                      np.mean(np.vstack(losses), axis=0)))
            logger.summarize(cur_it, summarizer='iter_evaluate', log_dict=summaries_dict)

        losses = np.mean(np.vstack(losses), axis=0)
        cur_it = self.model_graph.global_step_tensor.eval(session)
        summaries_dict = dict(zip(self.model_graph.losses, losses))
        logger.summarize(cur_it, summarizer='epoch_evaluate', log_dict=summaries_dict)
        return losses

    '''
    ------------------------------------------------------------------------------
                                         FIT
    ------------------------------------------------------------------------------
    '''

    def fit(self, X, y=None):
        print('\nProcessing data...')
        self.data_train, self.data_eval = data_utils.process_data(X, y)
        if self.config.plot:
            self.data_plot = self.data_train

        self.config.num_batches = self.data_train.num_batches(self.config.batch_size)

        if not self.config.isBuild:
            self.config.restore = True
            self.build_model(self.data_train.height, self.data_train.width,
                             self.data_train.num_channels)
        else:
            assert (self.config.height == self.data_train.height) and \
                (self.config.width == self.data_train.width) and \
                (self.config.num_channels == self.data_train.num_channels), \
                'Wrong dimension of data. Expected shape {}, and got {}'.format(
                    (self.config.height, self.config.width, self.config.num_channels),
                    (self.data_train.height, self.data_train.width,
                     self.data_train.num_channels))

        '''
        -------------------------------------------------------------------------------
                                     TRAIN THE MODEL
        -------------------------------------------------------------------------------
        '''
        print('\nTraining a model...')
        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(self.config.seeds)
            self.session = session
            logger = Logger(self.session, self.config.log_dir)
            self.saver = tf.train.Saver()
            early_stopper = EarlyStopping(name='total loss', decay_fn=self.decay_fn)

            if self.config.restore and self.load(self.session, self.saver):
                load_config = file_utils.load_args(self.config.model_name,
                                                   self.config.config_dir,
                                                   ['latent_mean', 'latent_std'])
                self.config.update(load_config)
                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            for cur_epoch in range(self.model_graph.cur_epoch_tensor.eval(self.session),
                                   self.config.epochs + 1, 1):
                print('EPOCH: ', cur_epoch)
                self.current_epoch = cur_epoch

                losses_tr = self._train(self.data_train, self.session, logger)
                if np.isnan(losses_tr[0]):
                    print('Encountered NaN, stopping training. '
                          'Please check the learning_rate settings and the momentum.')
                    for lname, lval in zip(self.model_graph.losses, losses_tr):
                        print(lname, lval)
                    sys.exit()

                losses_eval = self._evaluate(self.data_eval, self.session, logger)

                train_msg = 'TRAIN: \n'
                for lname, lval in zip(self.model_graph.losses, losses_tr):
                    train_msg += str(lname) + ': ' + str(lval) + ' | '
                eval_msg = 'EVALUATE: \n'
                for lname, lval in zip(self.model_graph.losses, losses_eval):
                    eval_msg += str(lname) + ': ' + str(lval) + ' | '
                print(train_msg)
                print(eval_msg)
                print()

                if (cur_epoch == 1) or ((cur_epoch % self.config.save_epoch == 0)
                                        and (cur_epoch != 0)):
                    self.save_model()
                    if self.config.plot:
                        self.plot_latent(cur_epoch)
                        self.plot_reconst(cur_epoch)

                self.session.run(self.model_graph.increment_cur_epoch_tensor)

                # Early stopping
                if self.config.early_stopping and early_stopper.stop(losses_eval[0]):
                    print('Early Stopping!')
                    break

                if cur_epoch % self.config.colab_save == 0:
                    if self.config.colab:
                        self.push_colab()

            self.save_model()
            if self.config.plot:
                self.plot_latent(cur_epoch)
                self.plot_reconst(cur_epoch)
            if self.config.colab:
                self.push_colab()

    def save_model(self):
        self.save(self.session, self.saver,
                  self.model_graph.global_step_tensor.eval(self.session))
        self.compute_distribution(self.data_train.x)
        file_utils.save_args(self.config.dict(), self.config.model_name,
                             self.config.config_dir, ['latent_mean', 'latent_std'])
        gc.collect()

    def compute_distribution(self, x):
        z = self.encode(x)
        self.model_graph.config.latent_mean = z.mean(axis=0).compute()
        self.model_graph.config.latent_std = z.std(axis=0).compute()
        del z

    '''
    ------------------------------------------------------------------------------
                                  SET NETWORK PARAMS
    ------------------------------------------------------------------------------
    '''

    def build_model(self, height, width, num_channels):
        self.config.height = height
        self.config.width = width
        self.config.num_channels = num_channels

        self.graph = tf.Graph()
        with self.graph.as_default():
            self.model_graph = Factory(self.config)
            print(self.model_graph)
            self.trainable_count = np.sum(
                [np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
            print('\nNumber of trainable parameters', self.trainable_count)

        self.test_graph()

        '''
        -------------------------------------------------------------------------------
                                       GOOGLE COLAB
        -------------------------------------------------------------------------------
        '''
        if self.config.colab:
            self.push_colab()
        self.config.push_colab = self.push_colab

        self.config.isBuild = True
        file_utils.save_args(self.config.dict(), self.config.model_name,
                             self.config.config_dir, ['latent_mean', 'latent_std'])

    def test_graph(self):
        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(self.config.seeds)
            self.session = session
            logger = Logger(self.session, self.config.log_dir)
            self.saver = tf.train.Saver()
            if self.config.restore and self.load(self.session, self.saver):
                load_config = file_utils.load_args(self.config.model_name,
                                                   self.config.config_dir,
                                                   ['latent_mean', 'latent_std'])
                self.config.update(load_config)
                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            print('random latent batch ...')
            samples = self.model_graph._sampling_reconst(
                session, std_scales=np.ones(self.config.latent_dim))[0]
            print('random latent shape {}'.format(samples.shape))
    def _sampling_reconst(self, std_scales, random_latent=None):
        def aux_fun(session, rand_samp):
            return self.model_graph._sampling_reconst(
                session=session, std_scales=std_scales, random_latent=rand_samp)

        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(self.config.seeds)
            self.session = session
            logger = Logger(self.session, self.config.log_dir)
            self.saver = tf.train.Saver()
            if self.config.restore and self.load(self.session, self.saver):
                load_config = file_utils.load_args(self.config.model_name,
                                                   self.config.config_dir,
                                                   ['latent_mean', 'latent_std'])
                self.config.update(load_config)
                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            samples = list()
            if random_latent is None:
                # Accumulate batches until at least 100 samples exist, then trim.
                while True:
                    samples.append(self.model_graph._sampling_reconst(
                        session=session, std_scales=std_scales)[0])
                    if len(samples) >= (100 // self.config.batch_size) + 1:
                        samples = da.vstack(samples)
                        samples = samples[:100]
                        break
            else:
                samples = self.batch_function(aux_fun, random_latent)

        # Rescale pixel values to [0, 1] for plotting.
        scaler = MinMaxScaler()
        return scaler.fit_transform(
            samples.flatten().reshape(-1, 1).astype(np.float32)).reshape(samples.shape)
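
# Example usage of AE (a minimal sketch; assumes X_train is an array of images
# shaped (N, height, width, channels) that data_utils.process_data accepts):
#
#   model = AE(batch_size=64, epochs=20)
#   model.fit(X_train)
#   images = model._sampling_reconst(std_scales=np.ones(model.config.latent_dim))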