import glob
import os

import sonnet as snt
import tensorflow as tf
from tensorflow_addons.image import gaussian_filter2d

# build_dataset, AutoEncoder, VectorQuantizerVariationalAutoEncoder,
# AbstractModule, TrainOneEpoch, vanilla_training_loop and MODEL_MAP are
# defined elsewhere in this repo.


def train_VQVAE(data_dir):
    # strategy = get_distribution_strategy(use_cpus=False, logical_per_physical_factor=1, memory_limit=10000)

    # Build datasets from the tfrecord files under 'train' and 'test'.
    train_dataset = build_dataset(os.path.join(data_dir, 'train'))
    test_dataset = build_dataset(os.path.join(data_dir, 'test'))

    # Keep only the image component of each (graph, img, c) example.
    train_dataset = train_dataset.map(lambda graph, img, c: (img,)).batch(batch_size=32)
    test_dataset = test_dataset.map(lambda graph, img, c: (img,)).batch(batch_size=32)

    # with strategy.scope():
    model = VectorQuantizerVariationalAutoEncoder(embedding_dim=32,
                                                  num_embeddings=1024,
                                                  kernel_size=4)

    learning_rate = 1e-4
    opt = snt.optimizers.Adam(learning_rate)

    def loss(model_outputs, batch):
        (img,) = batch
        vq_loss, decoded_img = model_outputs
        # Alternative reconstruction term: per-pixel binary cross-entropy
        # summed over the spatial axes:
        # reconstruction_loss = tf.reduce_mean(tf.reduce_sum(
        #     keras.losses.binary_crossentropy(img, decoded_img), axis=(1, 2)))
        reconstruction_loss = tf.reduce_mean((img - decoded_img) ** 2)
        tf.summary.scalar('reconstruction loss', reconstruction_loss, step=model.step)
        tf.summary.scalar('vq_loss', vq_loss, step=model.step)
        # Total objective: pixel-wise MSE plus the vector-quantizer
        # (codebook/commitment) loss returned by the model.
        total_loss = reconstruction_loss + vq_loss
        return total_loss

    train_one_epoch = TrainOneEpoch(model, loss, opt, strategy=None)

    log_dir = 'VQVAE_log_dir_16_1024'
    checkpoint_dir = 'VQVAE_checkpointing_16_1024'
    model_dir = 'trained_VAE_model_16_1024'

    vanilla_training_loop(train_one_epoch=train_one_epoch,
                          training_dataset=train_dataset,
                          test_dataset=test_dataset,
                          num_epochs=10000,
                          early_stop_patience=10000,
                          checkpoint_dir=checkpoint_dir,
                          log_dir=log_dir,
                          debug=False,
                          save_model_dir=model_dir)

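# Note: the tf.summary.scalar calls in the loss above only record when a
# default summary writer is active; vanilla_training_loop presumably installs
# one from log_dir. A minimal standalone sketch of that mechanism
# (illustrative only, not part of the training path):
def _summary_writer_sketch():
    writer = tf.summary.create_file_writer('VQVAE_log_dir_16_1024')
    with writer.as_default():
        tf.summary.scalar('reconstruction loss', 0.5, step=0)
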
def train_autoencoder(data_dir):
    # strategy = get_distribution_strategy(use_cpus=False, logical_per_physical_factor=1, memory_limit=10000)

    # Build datasets from the tfrecord files under 'train' and 'test'.
    train_dataset = build_dataset(os.path.join(data_dir, 'train'))
    test_dataset = build_dataset(os.path.join(data_dir, 'test'))

    # Keep only the image component of each (graph, img, c) example.
    train_dataset = train_dataset.map(lambda graph, img, c: (img,)).batch(batch_size=32)
    test_dataset = test_dataset.map(lambda graph, img, c: (img,)).batch(batch_size=32)

    # with strategy.scope():
    model = AutoEncoder()

    learning_rate = 1.0e-5
    opt = snt.optimizers.Adam(learning_rate)

    def loss(model_outputs, batch):
        (img,) = batch
        decoded_img = model_outputs
        # Alternative: compare against a Gaussian-smoothed target instead:
        # return tf.reduce_mean((gaussian_filter2d(img, filter_shape=[6, 6]) - decoded_img) ** 2)
        # Scaled pixel-wise MSE.
        return 100 * tf.reduce_mean((img - decoded_img) ** 2)

    train_one_epoch = TrainOneEpoch(model, loss, opt, strategy=None)

    log_dir = 'autoencoder_log_dir'
    checkpoint_dir = 'autoencoder_checkpointing'

    vanilla_training_loop(train_one_epoch=train_one_epoch,
                          training_dataset=train_dataset,
                          test_dataset=test_dataset,
                          num_epochs=1000,
                          early_stop_patience=1000,
                          checkpoint_dir=checkpoint_dir,
                          log_dir=log_dir,
                          debug=False)

def test_vanilla_training_loop():
    class Model(AbstractModule):
        def __init__(self, name=None):
            super(Model, self).__init__(name=name)
            self.net = snt.nets.MLP([10, 1], activate_final=False)

        def _build(self, batch):
            (inputs, _) = batch
            return self.net(inputs)

    def loss(model_output, batch):
        (_, target) = batch
        return tf.reduce_mean((target - model_output) ** 2)

    # Toy regression problem: 100 random 5-d inputs with scalar targets.
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random.normal((100, 5)), tf.random.normal((100, 1)))).batch(10)

    training = TrainOneEpoch(Model(), loss, snt.optimizers.Adam(1e-4))
    vanilla_training_loop(train_one_epoch=training,
                          training_dataset=dataset,
                          num_epochs=100,
                          debug=False)

def train_autoencoder(data_dir):
    train_tfrecords = glob.glob(os.path.join(data_dir, 'train', '*.tfrecords'))
    test_tfrecords = glob.glob(os.path.join(data_dir, 'test', '*.tfrecords'))

    print(f'Number of training tfrecord files : {len(train_tfrecords)}')
    print(f'Number of test tfrecord files : {len(test_tfrecords)}')
    print(f'Total : {len(train_tfrecords) + len(test_tfrecords)}')

    train_dataset = build_dataset(train_tfrecords)
    test_dataset = build_dataset(test_tfrecords)

    # Keep only the image component of each (graph, img, c) example.
    train_dataset = train_dataset.map(lambda graph, img, c: (img,)).batch(batch_size=32)
    test_dataset = test_dataset.map(lambda graph, img, c: (img,)).batch(batch_size=32)

    model = AutoEncoder(kernel_size=4)

    learning_rate = 1e-5
    opt = snt.optimizers.Adam(learning_rate)

    def loss(model_outputs, batch):
        (img,) = batch
        decoded_img = model_outputs
        # MSE between a Gaussian-smoothed target and the central crop of the
        # decoder output (the decoder emits a 12-pixel border on each side).
        return tf.reduce_mean((gaussian_filter2d(img, filter_shape=[6, 6]) -
                               decoded_img[:, 12:-12, 12:-12, :]) ** 2)

    train_one_epoch = TrainOneEpoch(model, loss, opt, strategy=None)

    log_dir = 'autoencoder_log_dir'
    checkpoint_dir = 'autoencoder_checkpointing'

    vanilla_training_loop(train_one_epoch=train_one_epoch,
                          training_dataset=train_dataset,
                          test_dataset=test_dataset,
                          num_epochs=50,
                          early_stop_patience=5,
                          checkpoint_dir=checkpoint_dir,
                          log_dir=log_dir,
                          debug=False)

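# Shape sanity check for the cropped loss above. The sizes are hypothetical:
# the 12-pixel crop implies the decoder output is 24 pixels larger per
# spatial dimension than the target image, e.g. 280x280 against 256x256.
def _crop_shape_sketch():
    img = tf.zeros([32, 256, 256, 1])
    decoded_img = tf.zeros([32, 280, 280, 1])
    assert decoded_img[:, 12:-12, 12:-12, :].shape == img.shape
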
def build_training(model_type, model_parameters, optimizer_parameters,
                   loss_parameters, strategy=None, **kwargs) -> TrainOneEpoch:
    model_cls = MODEL_MAP[model_type]
    model = model_cls(**model_parameters, **kwargs)

    def build_opt(**kwargs):
        opt_type = kwargs.get('opt_type')
        if opt_type == 'adam':
            learning_rate = kwargs.get('learning_rate', 1e-4)
            opt = snt.optimizers.Adam(learning_rate,
                                      beta1=1 - 1 / 100,
                                      beta2=1 - 1 / 500)
        else:
            raise ValueError('Opt {} invalid'.format(opt_type))
        return opt

    def build_loss(**loss_parameters):
        def loss(model_outputs, batch):
            graph = batch
            decoded_graph, nn_index = model_outputs
            # Squared error between the decoded node properties and the
            # nearest-neighbour-gathered targets; the constant mask keeps
            # only channel 3 of the sliced properties in the loss.
            return tf.reduce_mean(
                (tf.gather(graph.nodes[:, 3:], nn_index) - decoded_graph.nodes) ** 2
                * tf.constant([0, 0, 0, 1, 0, 0, 0], dtype=graph.nodes.dtype))
        return loss

    loss = build_loss(**loss_parameters)
    opt = build_opt(**optimizer_parameters)
    training = TrainOneEpoch(model, loss, opt, strategy=strategy)
    return training

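# The constant mask in the loss above zeroes every property channel except
# index 3, so only that channel contributes to the mean. A minimal sketch
# with made-up numbers:
def _channel_mask_sketch():
    squared_error = tf.ones([5, 7])  # per-node squared errors, 7 channels
    mask = tf.constant([0, 0, 0, 1, 0, 0, 0], dtype=squared_error.dtype)
    # Mean over all 35 entries, but only 5 are non-zero -> 5 / 35 = 1 / 7.
    print(tf.reduce_mean(squared_error * mask))
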
def build_training(model_type, model_parameters, optimizer_parameters,
                   loss_parameters, strategy=None, **kwargs) -> TrainOneEpoch:
    model_cls = MODEL_MAP[model_type]
    model = model_cls(**model_parameters, **kwargs)

    def build_opt(**kwargs):
        opt_type = kwargs.get('opt_type')
        if opt_type == 'adam':
            learning_rate = kwargs.get('learning_rate', 1e-4)
            opt = snt.optimizers.Adam(learning_rate,
                                      beta1=1 - 1 / 100,
                                      beta2=1 - 1 / 500)
        else:
            raise ValueError('Opt {} invalid'.format(opt_type))
        return opt

    def build_loss(**loss_parameters):
        def loss(model_outputs, batch):
            (encoded_graphs, decoded_graphs) = model_outputs
            (graph, positions) = batch
            # RMSE between the true node properties and the nodes of the
            # last decoded graph.
            return tf.math.sqrt(
                tf.reduce_mean(
                    tf.math.square(graph.nodes[:, 3:] - decoded_graphs[-1].nodes)))
        return loss

    loss = build_loss(**loss_parameters)
    opt = build_opt(**optimizer_parameters)
    training = TrainOneEpoch(model, loss, opt, strategy=strategy)
    return training

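# A minimal usage sketch for build_training. The model_type key and parameter
# names are hypothetical; valid values depend on MODEL_MAP and the model
# constructors defined elsewhere in the repo:
def _build_training_sketch():
    training = build_training(
        model_type='auto_encoder',  # assumed key in MODEL_MAP
        model_parameters=dict(kernel_size=4),
        optimizer_parameters=dict(opt_type='adam', learning_rate=1e-4),
        loss_parameters=dict())
    return training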