def train_VQVAE(data_dir):
    # strategy = get_distribution_strategy(use_cpus=False, logical_per_physical_factor=1, memory_limit=10000)

    # build datasets from the tfrecord files in the train/test directories
    train_dataset = build_dataset(os.path.join(data_dir, 'train'))
    test_dataset = build_dataset(os.path.join(data_dir, 'test'))

    train_dataset = train_dataset.map(lambda graph, img, c: (img, )).batch(
        batch_size=32)
    test_dataset = test_dataset.map(lambda graph, img, c: (img, )).batch(
        batch_size=32)

    # with strategy.scope():
    model = VectorQuantizerVariationalAutoEncoder(embedding_dim=32,
                                                  num_embeddings=1024,
                                                  kernel_size=4)

    learning_rate = 1e-4
    opt = snt.optimizers.Adam(learning_rate)

    def loss(model_outputs, batch):
        (img, ) = batch
        vq_loss, decoded_img = model_outputs
        # Debug: log the input and reconstruction shapes.
        print('img shape', img.shape)
        print('decoded img shape', decoded_img.shape)
        # reconstruction_loss = tf.reduce_mean(tf.reduce_sum(
        #     keras.losses.binary_crossentropy(img, decoded_img), axis=(1, 2)
        #         ))
        reconstruction_loss = tf.reduce_mean((img - decoded_img)**2)
        tf.summary.scalar('reconstruction loss',
                          reconstruction_loss,
                          step=model.step)
        tf.summary.scalar('vq_loss', vq_loss, step=model.step)
        total_loss = reconstruction_loss + vq_loss
        return total_loss

    train_one_epoch = TrainOneEpoch(model, loss, opt, strategy=None)

    log_dir = 'VQVAE_log_dir_16_1024'
    checkpoint_dir = 'VQVAE_checkpointing_16_1024'
    model_dir = 'trained_VQVAE_model_16_1024'

    vanilla_training_loop(train_one_epoch=train_one_epoch,
                          training_dataset=train_dataset,
                          test_dataset=test_dataset,
                          num_epochs=10000,
                          early_stop_patience=10000,
                          checkpoint_dir=checkpoint_dir,
                          log_dir=log_dir,
                          debug=False,
                          save_model_dir=model_dir)
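
# A minimal, self-contained sketch of the objective above: reconstruction MSE
# plus the codebook (vq) term. Shapes and the vq_loss value are placeholders,
# not the real outputs of VectorQuantizerVariationalAutoEncoder.
def _vqvae_loss_sketch():
    import tensorflow as tf

    img = tf.random.normal((32, 64, 64, 1))          # batch of input images
    decoded_img = tf.random.normal((32, 64, 64, 1))  # decoder reconstructions
    vq_loss = tf.constant(0.1)                       # quantizer's codebook loss
    reconstruction_loss = tf.reduce_mean((img - decoded_img)**2)
    return reconstruction_loss + vq_loss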

def train_autoencoder(data_dir):
    # strategy = get_distribution_strategy(use_cpus=False, logical_per_physical_factor=1, memory_limit=10000)

    # build datasets from the tfrecord files in the train/test directories
    train_dataset = build_dataset(os.path.join(data_dir, 'train'))
    test_dataset = build_dataset(os.path.join(data_dir, 'test'))

    # (An explicit tfrecord-globbing variant of this setup appears in Example #4.)

    train_dataset = train_dataset.map(lambda graph, img, c: (img, )).batch(
        batch_size=32)
    test_dataset = test_dataset.map(lambda graph, img, c: (img, )).batch(
        batch_size=32)

    # with strategy.scope():
    model = AutoEncoder()

    learning_rate = 1.0e-5

    opt = snt.optimizers.Adam(learning_rate)

    def loss(model_outputs, batch):
        (img, ) = batch
        decoded_img = model_outputs
        # return tf.reduce_mean((gaussian_filter2d(img, filter_shape=[6, 6]) - decoded_img[:, :, :, :]) ** 2)
        # MSE, scaled by a constant factor of 100.
        return 100 * tf.reduce_mean((img - decoded_img)**2)

    train_one_epoch = TrainOneEpoch(model, loss, opt, strategy=None)

    log_dir = 'autoencoder_log_dir'
    checkpoint_dir = 'autoencoder_checkpointing'

    vanilla_training_loop(train_one_epoch=train_one_epoch,
                          training_dataset=train_dataset,
                          test_dataset=test_dataset,
                          num_epochs=1000,
                          early_stop_patience=1000,
                          checkpoint_dir=checkpoint_dir,
                          log_dir=log_dir,
                          debug=False)
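
# Both training functions above reduce the three-element (graph, img, c)
# records to image-only tuples before batching. A sketch of that mapping on
# toy tensors standing in for the tfrecord-backed output of build_dataset:
def _image_pipeline_sketch():
    import tensorflow as tf

    graphs = tf.random.normal((100, 4))        # toy stand-in for graph data
    imgs = tf.random.normal((100, 64, 64, 1))  # toy images
    cs = tf.random.normal((100, 1))            # toy conditioning values
    dataset = tf.data.Dataset.from_tensor_slices((graphs, imgs, cs))
    # Keep only the image component and batch, exactly as above.
    return dataset.map(lambda graph, img, c: (img, )).batch(batch_size=32)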
Example #3
def test_vanilla_training_loop():
    import sonnet as snt

    class Model(AbstractModule):
        def __init__(self, name=None):
            super(Model, self).__init__(name=name)
            self.net = snt.nets.MLP([10, 1], activate_final=False)

        def _build(self, batch):
            (inputs, _) = batch
            return self.net(inputs)

    def loss(model_output, batch):
        (_, target) = batch
        return tf.reduce_mean((target - model_output)**2)

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random.normal((100, 5)), tf.random.normal((100, 1)))).batch(10)

    training = TrainOneEpoch(Model(), loss, snt.optimizers.Adam(1e-4))
    vanilla_training_loop(train_one_epoch=training,
                          training_dataset=dataset,
                          num_epochs=100,
                          debug=False)
Example #4
def train_autoencoder(data_dir):
    train_tfrecords = glob.glob(os.path.join(data_dir, 'train', '*.tfrecords'))
    test_tfrecords = glob.glob(os.path.join(data_dir, 'test', '*.tfrecords'))

    print(f'Number of training tfrecord files : {len(train_tfrecords)}')
    print(f'Number of test tfrecord files : {len(test_tfrecords)}')
    print(f'Total : {len(train_tfrecords) + len(test_tfrecords)}')

    train_dataset = build_dataset(train_tfrecords)
    test_dataset = build_dataset(test_tfrecords)

    train_dataset = train_dataset.map(lambda graph, img, c: (img, )).batch(
        batch_size=32)
    test_dataset = test_dataset.map(lambda graph, img, c: (img, )).batch(
        batch_size=32)

    model = AutoEncoder(kernel_size=4)

    learning_rate = 1e-5
    opt = snt.optimizers.Adam(learning_rate)

    def loss(model_outputs, batch):
        (img, ) = batch
        decoded_img = model_outputs
        # Compare the decoder output, with its 12-pixel border cropped,
        # against a blurred version of the target image.
        return tf.reduce_mean((gaussian_filter2d(img, filter_shape=[6, 6]) -
                               decoded_img[:, 12:-12, 12:-12, :])**2)

    train_one_epoch = TrainOneEpoch(model, loss, opt, strategy=None)

    log_dir = 'autoencoder_log_dir'
    checkpoint_dir = 'autoencoder_checkpointing'

    vanilla_training_loop(train_one_epoch=train_one_epoch,
                          training_dataset=train_dataset,
                          test_dataset=test_dataset,
                          num_epochs=50,
                          early_stop_patience=5,
                          checkpoint_dir=checkpoint_dir,
                          log_dir=log_dir,
                          debug=False)
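
# A sketch of the blurred-target loss above, assuming gaussian_filter2d is
# tensorflow_addons' tfa.image.gaussian_filter2d and that the decoder output
# carries a 12-pixel border per side, so the crop aligns the spatial shapes.
# All shapes are illustrative.
def _blurred_target_loss_sketch():
    import tensorflow as tf
    from tensorflow_addons.image import gaussian_filter2d

    img = tf.random.normal((32, 64, 64, 1))          # ground-truth images
    decoded_img = tf.random.normal((32, 88, 88, 1))  # 64 + 2 * 12 per axis
    blurred_target = gaussian_filter2d(img, filter_shape=[6, 6])
    return tf.reduce_mean(
        (blurred_target - decoded_img[:, 12:-12, 12:-12, :])**2)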
Example #5
def build_training(model_type,
                   model_parameters,
                   optimizer_parameters,
                   loss_parameters,
                   strategy=None,
                   **kwargs) -> TrainOneEpoch:
    model_cls = MODEL_MAP[model_type]
    model = model_cls(**model_parameters, **kwargs)

    def build_opt(**kwargs):
        opt_type = kwargs.get('opt_type')
        if opt_type == 'adam':
            learning_rate = kwargs.get('learning_rate', 1e-4)
            opt = snt.optimizers.Adam(learning_rate,
                                      beta1=1 - 1 / 100,
                                      beta2=1 - 1 / 500)
        else:
            raise ValueError('Opt {} invalid'.format(opt_type))
        return opt

    def build_loss(**loss_parameters):
        def loss(model_outputs, batch):
            graph = batch
            decoded_graph, nn_index = model_outputs
            print('decoded nodes shape', decoded_graph.nodes.shape)
            # tf.gather aligns the true node properties by nearest-neighbour
            # index; the 0/1 mask keeps only the fourth property's squared error.
            return tf.reduce_mean(
                (tf.gather(graph.nodes[:, 3:], nn_index) - decoded_graph.nodes)
                **2 *
                tf.constant([0, 0, 0, 1, 0, 0, 0], dtype=graph.nodes.dtype))

        return loss

    loss = build_loss(**loss_parameters)
    opt = build_opt(**optimizer_parameters)

    training = TrainOneEpoch(model, loss, opt, strategy=strategy)

    return training
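
# A minimal sketch of the masked objective above on placeholder data: tf.gather
# reorders the true node properties by nearest-neighbour index, and the
# [0, 0, 0, 1, 0, 0, 0] mask zeroes every squared error except the fourth
# property's. Shapes and index values are invented for illustration.
def _masked_node_loss_sketch():
    import tensorflow as tf

    nodes = tf.random.normal((5, 10))         # 3 position + 7 property columns
    decoded_nodes = tf.random.normal((5, 7))  # decoder's predicted properties
    nn_index = tf.constant([2, 0, 1, 4, 3])   # nearest-neighbour assignment
    target = tf.gather(nodes[:, 3:], nn_index)
    mask = tf.constant([0, 0, 0, 1, 0, 0, 0], dtype=nodes.dtype)
    return tf.reduce_mean((target - decoded_nodes)**2 * mask)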
Example #6
def build_training(model_type,
                   model_parameters,
                   optimizer_parameters,
                   loss_parameters,
                   strategy=None,
                   **kwargs) -> TrainOneEpoch:
    model_cls = MODEL_MAP[model_type]
    model = model_cls(**model_parameters, **kwargs)

    def build_opt(**kwargs):
        opt_type = kwargs.get('opt_type')
        if opt_type == 'adam':
            learning_rate = kwargs.get('learning_rate', 1e-4)
            opt = snt.optimizers.Adam(learning_rate,
                                      beta1=1 - 1 / 100,
                                      beta2=1 - 1 / 500)
        else:
            raise ValueError('Opt {} invalid'.format(opt_type))
        return opt

    def build_loss(**loss_parameters):
        def loss(model_outputs, batch):
            (encoded_graphs, decoded_graphs) = model_outputs
            (graph, positions) = batch
            # RMSE between the true node properties and the last decoded graph's nodes.
            return tf.math.sqrt(
                tf.reduce_mean(
                    tf.math.square(graph.nodes[:, 3:] -
                                   decoded_graphs[-1].nodes)))

        return loss

    loss = build_loss(**loss_parameters)
    opt = build_opt(**optimizer_parameters)

    training = TrainOneEpoch(model, loss, opt, strategy=strategy)

    return training
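
# For contrast with Example #5's masked loss, a sketch of the RMSE objective
# used here, on placeholder tensors standing in for graph.nodes[:, 3:] and
# decoded_graphs[-1].nodes.
def _rmse_node_loss_sketch():
    import tensorflow as tf

    node_properties = tf.random.normal((5, 7))
    decoded_nodes = tf.random.normal((5, 7))
    return tf.math.sqrt(
        tf.reduce_mean(tf.math.square(node_properties - decoded_nodes)))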