Пример #1
0
    def train(self, hparams_string):
        """Run training of the network.

        Parses the training hyper-parameters and builds a tf.data input
        pipeline over ``self.dateset_filenames``. Then builds the model, loss,
        optimizer and summaries. Resumes from the latest checkpoint in
        ``self.dir_checkpoints`` if one exists, trains until
        ``self.epoch_max`` and saves a checkpoint after every epoch.

        Args:
            hparams_string: Hyper-parameter string passed to
                ``hparams_parser_train``; its ``batch_size`` and
                ``epoch_max`` fields are used here.

        Returns:
            None. Summaries are written to ``self.dir_logs`` and checkpoints
            to ``self.dir_checkpoints``.
        """
        args_train = hparams_parser_train(hparams_string)

        self.batch_size = args_train.batch_size
        self.epoch_max = args_train.epoch_max

        utils.save_model_configuration(args_train, self.dir_base)

        # Use dataset for loading in datasamples from .tfrecord (https://www.tensorflow.org/programmers_guide/datasets#consuming_tfrecord_data)
        # The iterator will get a new batch from the dataset each time a sess.run() is executed on the graph.
        dataset = tf.data.TFRecordDataset(self.dateset_filenames)
        dataset = dataset.map(util_data.decode_image)  # decoding the tfrecord
        dataset = dataset.map(self._preProcessData)  # potential local preprocessing of data
        dataset = dataset.shuffle(buffer_size=10000, seed=None)
        dataset = dataset.batch(batch_size=self.batch_size)
        iterator = dataset.make_initializable_iterator()

        # The element structure depends on self._preProcessData.
        in_image, in_label = iterator.get_next()

        # Define model, loss, optimizer and summaries.
        outputs = self._create_inference(in_image)
        loss = self._create_losses(outputs, in_label)
        optimizer_op = self._create_optimizer(loss)
        summary_op = self._create_summaries(loss)

        # Show network architecture. This must come after the model is built,
        # otherwise its variables do not exist yet.
        utils.show_all_variables()

        checkpoint_prefix = os.path.join(self.dir_checkpoints,
                                         self.model + '.model')

        with tf.Session() as sess:

            # Initialize all model Variables.
            sess.run(tf.global_variables_initializer())

            # Create Saver object for loading and storing checkpoints
            saver = tf.train.Saver()

            # Create Writer object for storing graph and summaries for TensorBoard
            writer = tf.summary.FileWriter(self.dir_logs, sess.graph)
            try:
                # Reload Tensor values from latest checkpoint. Checkpoints are
                # saved with global_step=<finished epoch>, so resume with the
                # next epoch instead of repeating the finished one.
                ckpt = tf.train.get_checkpoint_state(self.dir_checkpoints)
                epoch_start = 0
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
                    epoch_start = int(ckpt_name.split('-')[-1]) + 1

                # Global step for TensorBoard summaries (restarts at 0 on resume).
                iteration_cnt = 0
                # Do training loops
                for epoch_n in range(epoch_start, self.epoch_max):

                    # Initiate or Re-initiate iterator
                    sess.run(iterator.initializer)

                    # Test model output before any training. summary_op is
                    # built from the iterator-fed loss, so this presumably
                    # also pulls one batch from the iterator.
                    if epoch_n == 0:
                        summary = sess.run(summary_op)
                        writer.add_summary(summary, global_step=-1)

                    utils.show_message(
                        'Running training epoch no: {0}'.format(epoch_n))
                    while True:
                        try:
                            _, summary = sess.run([optimizer_op, summary_op])
                        except tf.errors.OutOfRangeError:
                            # Iterator exhausted: the epoch is complete.
                            break
                        writer.add_summary(summary, global_step=iteration_cnt)
                        iteration_cnt += 1

                    # Checkpoint after every epoch.
                    saver.save(sess, checkpoint_prefix, global_step=epoch_n)
            finally:
                # Flush pending events even if training is interrupted.
                writer.close()
Пример #2
0
    def train(self, hparams_string):
        """Run training of the network.

        Parses the training hyper-parameters and builds a tf.data input
        pipeline. Batches are fetched from the pipeline and fed into
        placeholders. The model can optionally be initialized from ImageNet
        VGG16/VGG19 checkpoints, which are downloaded into
        ``self.dir_checkpoints`` when missing. Training resumes from the
        latest checkpoint if present, runs until ``self.epoch_max`` and saves
        a checkpoint after every epoch.

        Args:
            hparams_string: Hyper-parameter string passed to
                ``hparams_parser_train``; its ``batch_size``, ``epoch_max``,
                ``use_imagenet`` and ``model_version`` fields are used here.

        Returns:
            None. Summaries are written to ``self.dir_logs`` and checkpoints
            to ``self.dir_checkpoints``.

        Raises:
            ValueError: if ``use_imagenet`` is set and ``model_version`` has
                no known ImageNet checkpoint.
        """

        args_train = hparams_parser_train(hparams_string)
        self.batch_size = args_train.batch_size
        self.epoch_max = args_train.epoch_max
        self.use_imagenet = args_train.use_imagenet
        self.model_version = args_train.model_version

        # Pretrained ImageNet weights per supported model version:
        # (model variable scope == checkpoint basename, download URL).
        imagenet_models = {
            'VGG16': ('vgg_16',
                      "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"),
            'VGG19': ('vgg_19',
                      "http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz"),
        }
        # Fail fast instead of hitting an unbound init_fn once the session runs.
        if self.use_imagenet and self.model_version not in imagenet_models:
            raise ValueError(
                'No ImageNet checkpoint for model_version {0!r}; expected one '
                'of {1}'.format(self.model_version, sorted(imagenet_models)))

        utils.save_model_configuration(args_train, self.dir_base)

        # Use dataset for loading in datasamples from .tfrecord (https://www.tensorflow.org/programmers_guide/datasets#consuming_tfrecord_data)
        # The iterator will get a new batch from the dataset each time a sess.run() is executed on the graph.
        dataset = tf.data.TFRecordDataset(self.dateset_filenames)
        dataset = dataset.map(util_data.decode_image)  # decoding the tfrecord
        dataset = dataset.map(self._preProcessData)  # potential local preprocessing of data
        dataset = dataset.shuffle(buffer_size=10000, seed=None)
        dataset = dataset.batch(batch_size=self.batch_size)
        iterator = dataset.make_initializable_iterator()
        input_getBatch = iterator.get_next()

        input_images = tf.placeholder(dtype=tf.float32,
                                      shape=[None] + self.image_dims,
                                      name='input_images')
        input_lbls = tf.placeholder(dtype=tf.float32,
                                    shape=[None, self.lbls_dim],
                                    name='input_lbls')

        # Define model, loss, optimizer and summaries.
        output_logits = self._create_inference(input_images)
        loss = self._create_losses(output_logits, input_lbls)
        optimizer_op = self._create_optimizer(loss)
        summary_op = self._create_summaries(loss)

        # Show network architecture.
        utils.show_all_variables()

        init_fn = None
        if self.use_imagenet:
            scope_name, url_imagenet_model = imagenet_models[self.model_version]
            path_imagenet_ckpt = os.path.join(self.dir_checkpoints,
                                              scope_name + '.ckpt')
            if not tf.gfile.Exists(path_imagenet_ckpt):
                utils.download_and_uncompress_tarball(url_imagenet_model,
                                                      self.dir_checkpoints)

            # Restore all model variables except the last 6 (the fc layers),
            # which are trained from scratch.
            variables_to_restore = slim.get_model_variables(scope_name)[:-6]
            init_fn = slim.assign_from_checkpoint_fn(path_imagenet_ckpt,
                                                     variables_to_restore)

        checkpoint_prefix = os.path.join(self.dir_checkpoints,
                                         self.model + '.model')

        with tf.Session() as sess:

            # Initialize all model Variables, then overwrite with ImageNet weights if requested.
            sess.run(tf.global_variables_initializer())
            if init_fn is not None:
                init_fn(sess)

            # Create Saver object for loading and storing checkpoints
            saver = tf.train.Saver()

            # Create Writer object for storing graph and summaries for TensorBoard
            writer = tf.summary.FileWriter(self.dir_logs, sess.graph)
            try:
                # Reload Tensor values from latest checkpoint. Checkpoints are
                # saved with global_step=<finished epoch>, so resume with the
                # next epoch instead of repeating the finished one.
                ckpt = tf.train.get_checkpoint_state(self.dir_checkpoints)
                epoch_start = 0
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
                    epoch_start = int(ckpt_name.split('-')[-1]) + 1

                # Global step for TensorBoard summaries (restarts at 0 on resume).
                iteration_cnt = 0
                # Do training loops
                for epoch_n in range(epoch_start, self.epoch_max):

                    # Initiate or Re-initiate iterator
                    sess.run(iterator.initializer)

                    utils.show_message(
                        'Running training epoch no: {0}'.format(epoch_n))
                    while True:
                        try:
                            image_batch, lbl_batch = sess.run(input_getBatch)
                        except tf.errors.OutOfRangeError:
                            # Iterator exhausted: the epoch is complete.
                            break

                        _, summary_loss = sess.run(
                            [optimizer_op, summary_op],
                            feed_dict={input_images: image_batch,
                                       input_lbls: lbl_batch})
                        writer.add_summary(summary_loss,
                                           global_step=iteration_cnt)
                        iteration_cnt += 1

                    # Checkpoint after every epoch.
                    saver.save(sess, checkpoint_prefix, global_step=epoch_n)
            finally:
                # Flush pending events even if training is interrupted.
                writer.close()
Пример #3
0
    def train(self, hparams_string):
        """Run discriminator/generator training of the network.

        Parses the training hyper-parameters and builds a tf.data input
        pipeline whose elements are (image, label, unstructured noise). Builds
        the model, losses, optimizers and summaries. Each training step does
        ``d_iter`` discriminator updates on fresh batches, then one generator
        update on the last of those batches. Images generated from a fixed
        test input are logged and saved before training and after every
        epoch. A checkpoint is saved every ``backup_frequency`` epochs and
        after the final epoch.

        Args:
            hparams_string: Hyper-parameter string passed to
                ``hparams_parser_train``.

        Returns:
            None. Summaries are written to ``self.dir_logs``, images to
            ``<self.dir_results>/Training`` and checkpoints to
            ``self.dir_checkpoints``.

        Raises:
            ValueError: if ``d_iter`` or ``backup_frequency`` is smaller than 1.
        """
        args_train = hparams_parser_train(hparams_string)

        self.batch_size = args_train.batch_size
        self.epoch_max = args_train.epoch_max
        self.unstructured_noise_dim = args_train.unstructured_noise_dim

        self.d_learning_rate = args_train.lr_discriminator
        self.g_learning_rate = args_train.lr_generator

        self.d_iter = args_train.d_iter
        self.n_testsamples = args_train.n_testsamples

        self.class_scale_d = args_train.class_scale_d
        self.class_scale_g = args_train.class_scale_g

        self.backup_frequency = args_train.backup_frequency

        # Fail fast on values that would otherwise only crash mid-training:
        # with d_iter < 1 no discriminator step runs, so its summary and batch
        # are unbound. A non-positive backup_frequency is meaningless, and 0
        # makes the checkpoint modulo raise ZeroDivisionError.
        if self.d_iter < 1:
            raise ValueError(
                'd_iter must be >= 1, got {0}'.format(self.d_iter))
        if self.backup_frequency < 1:
            raise ValueError('backup_frequency must be >= 1, got {0}'.format(
                self.backup_frequency))

        utils.save_model_configuration(args_train, self.dir_base)

        # Use dataset for loading in datasamples from .tfrecord (https://www.tensorflow.org/programmers_guide/datasets#consuming_tfrecord_data)
        # The iterator will get a new batch from the dataset each time a sess.run() is executed on the graph.
        dataset = tf.data.TFRecordDataset(self.dateset_filenames)
        dataset = dataset.map(util_data.decode_image)  # decoding the tfrecord
        dataset = dataset.map(self._genLatentCodes)
        dataset = dataset.shuffle(buffer_size=10000, seed=None)
        dataset = dataset.batch(batch_size=self.batch_size)
        iterator = dataset.make_initializable_iterator()
        input_getBatch = iterator.get_next()

        # Create input placeholders
        input_images = tf.placeholder(
            dtype=tf.float32,
            shape=[None] + self.image_dims,
            name='input_images')
        input_lbls = tf.placeholder(
            dtype=tf.float32,
            shape=[None, self.lbls_dim],
            name='input_lbls')
        input_unstructured_noise = tf.placeholder(
            dtype=tf.float32,
            shape=[None, self.unstructured_noise_dim],
            name='input_unstructured_noise')
        input_test_lbls = tf.placeholder(
            dtype=tf.float32,
            shape=[self.n_testsamples * self.lbls_dim, self.lbls_dim],
            name='input_test_lbls')
        input_test_noise = tf.placeholder(
            dtype=tf.float32,
            shape=[self.n_testsamples * self.lbls_dim,
                   self.unstructured_noise_dim],
            name='input_test_noise')

        # Define model, loss, optimizer and summaries.
        logits_source, logits_class, _ = self._create_inference(
            input_images, input_lbls, input_unstructured_noise)
        loss_discriminator, loss_generator = self._create_losses(
            logits_source, logits_class, input_lbls)
        train_op_discriminator, train_op_generator = self._create_optimizer(
            loss_discriminator, loss_generator)
        summary_op_dloss, summary_op_gloss, summary_op_img, summary_img = self._create_summaries(
            loss_discriminator, loss_generator, input_test_noise,
            input_test_lbls)

        # Show network architecture.
        utils.show_all_variables()

        # Create constant test input to inspect changes in the model.
        test_noise, test_lbls = self._genTestInput(
            self.lbls_dim, n_samples=self.n_testsamples)

        dir_results_train = os.path.join(self.dir_results, 'Training')
        utils.checkfolder(dir_results_train)

        checkpoint_prefix = os.path.join(self.dir_checkpoints,
                                         self.model + '.model')

        def _log_test_images(sess, writer, step):
            """Generate images from the fixed test input, log them to TensorBoard at `step` and save them as 'Epoch_<step>'."""
            summaryImg_tb, summaryImg = sess.run(
                [summary_op_img, summary_img],
                feed_dict={input_test_noise: test_noise,
                           input_test_lbls: test_lbls})
            writer.add_summary(summaryImg_tb, global_step=step)
            utils.save_image_local(summaryImg, dir_results_train,
                                   'Epoch_' + str(step))

        with tf.Session() as sess:
            # Initialize all model Variables.
            sess.run(tf.global_variables_initializer())

            # Create Saver object for loading and storing checkpoints
            saver = tf.train.Saver()

            # Create Writer object for storing graph and summaries for TensorBoard
            writer = tf.summary.FileWriter(self.dir_logs, sess.graph)
            try:
                # Reload Tensor values from latest checkpoint; resume with
                # the epoch after the one that was saved.
                ckpt = tf.train.get_checkpoint_state(self.dir_checkpoints)
                epoch_start = 0
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
                    epoch_start = int(ckpt_name.split('-')[-1]) + 1

                # Global step for TensorBoard summaries (restarts at 0 on resume).
                iteration_cnt = 0
                for epoch_n in range(epoch_start, self.epoch_max):

                    # Test model output before any training.
                    if epoch_n == 0:
                        _log_test_images(sess, writer, -1)

                    # Initiate or Re-initiate iterator
                    sess.run(iterator.initializer)

                    ### ------------------------------------------------------
                    ### Update model
                    print(datetime.datetime.now(),
                          '- Running training epoch no:', epoch_n)
                    while True:
                        try:
                            # d_iter discriminator updates, each on a fresh batch.
                            for _ in range(self.d_iter):
                                image_batch, lbl_batch, unst_noise_batch = sess.run(
                                    input_getBatch)
                                feed_train = {
                                    input_images: image_batch,
                                    input_lbls: lbl_batch,
                                    input_unstructured_noise: unst_noise_batch
                                }
                                # sess.run on a 2-element fetch list returns
                                # exactly 2 values.
                                _, summary_dloss = sess.run(
                                    [train_op_discriminator, summary_op_dloss],
                                    feed_dict=feed_train)
                        except tf.errors.OutOfRangeError:
                            # Iterator exhausted: the epoch is complete.
                            break

                        writer.add_summary(summary_dloss,
                                           global_step=iteration_cnt)

                        # One generator update on the last discriminator batch.
                        _, summary_gloss = sess.run(
                            [train_op_generator, summary_op_gloss],
                            feed_dict=feed_train)
                        writer.add_summary(summary_gloss,
                                           global_step=iteration_cnt)
                        iteration_cnt += 1

                    # Test current model
                    _log_test_images(sess, writer, epoch_n)

                    # Save model variables to checkpoint; always keep the
                    # final epoch even if it is not a multiple of
                    # backup_frequency.
                    if ((epoch_n + 1) % self.backup_frequency == 0
                            or epoch_n + 1 == self.epoch_max):
                        saver.save(sess, checkpoint_prefix,
                                   global_step=epoch_n)
            finally:
                # Flush pending events even if training is interrupted.
                writer.close()
Пример #4
0
    def train(self, hparams_string):
        """Run discriminator/generator training with class and latent codes.

        Parses the training hyper-parameters and selects the training shards.
        Shards listed in ``shards_idx_test`` are held out, unless that list
        contains 0, in which case all shards are used. Builds a dataset-
        specific preprocessing pipeline and the tf.data input pipeline, then
        the model, losses, optimizers and summaries. Each training step does
        ``d_iter`` discriminator updates on fresh full batches, then one
        generator update on the last of those batches. For every class and
        latent-variable combination, images generated from a fixed test input
        are saved before training (as epoch -1) and after every epoch. A
        checkpoint is saved every ``backup_frequency`` epochs and after the
        final epoch.

        Args:
            hparams_string: Hyper-parameter string passed to
                ``hparams_parser_train``.

        Returns:
            None. Summaries are written to ``self.dir_logs``, images to
            ``<self.dir_results>/Training/<class>`` and checkpoints to
            ``self.dir_checkpoints``.

        Raises:
            ValueError: if ``d_iter`` or ``backup_frequency`` is smaller than 1.
        """
        args_train = hparams_parser_train(hparams_string)

        self.batch_size = args_train.batch_size
        self.epoch_max = args_train.epoch_max

        self.unstructured_noise_dim = args_train.unstructured_noise_dim
        self.info_var_dim = args_train.info_var_dim
        self.n_testsamples = args_train.n_testsamples

        self.d_learning_rate = args_train.lr_discriminator
        self.g_learning_rate = args_train.lr_generator
        self.d_iter = args_train.d_iter

        self.gp_lambda = args_train.gp_lambda
        self.class_scale_d = args_train.class_scale_d
        self.class_scale_g = args_train.class_scale_g

        self.info_scale_d = args_train.info_scale_d
        self.info_scale_g = args_train.info_scale_g

        self.backup_frequency = args_train.backup_frequency

        self.shards_idx_test = args_train.shards_idx_test

        # Fail fast on values that would otherwise only crash mid-training:
        # with d_iter < 1 no discriminator step runs, so its summary and batch
        # are unbound. A non-positive backup_frequency is meaningless, and 0
        # makes the checkpoint modulo raise ZeroDivisionError.
        if self.d_iter < 1:
            raise ValueError(
                'd_iter must be >= 1, got {0}'.format(self.d_iter))
        if self.backup_frequency < 1:
            raise ValueError('backup_frequency must be >= 1, got {0}'.format(
                self.backup_frequency))

        utils.save_model_configuration(args_train, self.dir_base)

        # Create folders for saving training results: one sub-folder per class.
        dir_results_train = os.path.join(self.dir_results, 'Training')
        utils.checkfolder(dir_results_train)

        dirs_results_class = [
            os.path.join(dir_results_train, str(class_n).zfill(2))
            for class_n in range(self.lbls_dim)
        ]
        for dir_result_train_class in dirs_results_class:
            utils.checkfolder(dir_result_train_class)

        # A 0 in shards_idx_test means "train on all shards". Otherwise the
        # listed shard indices (shifted by -1 to 0-based) are held out.
        if 0 in self.shards_idx_test:
            dataset_filenames = self.dataset_filenames
        else:
            self.shards_idx_test = np.subtract(self.shards_idx_test, 1)
            shards_idx_training = np.delete(range(len(self.dataset_filenames)),
                                            self.shards_idx_test)
            dataset_filenames = [
                self.dataset_filenames[i] for i in shards_idx_training
            ]

            utils.show_message('Training Data:')
            print(dataset_filenames)

        # Setup preprocessing pipeline
        preprocessing = preprocess_factory.preprocess_factory()

        # Dataset specific preprocessing
        if self.dataset == 'MNIST':
            pass

        elif self.dataset == 'PSD_Nonsegmented':
            pass

        elif self.dataset == 'PSD_Segmented':
            preprocessing.prep_pipe_from_string(
                "pad_to_size;{'height': 566, 'width': 566, 'constant': -1.0};random_rotation;{};crop_to_size;{'height': 400, 'width': 400};resize;{'height': 128, 'width': 128}"
            )

        # Use dataset for loading in datasamples from .tfrecord (https://www.tensorflow.org/programmers_guide/datasets#consuming_tfrecord_data)
        # The iterator will get a new batch from the dataset each time a sess.run() is executed on the graph.
        dataset = tf.data.TFRecordDataset(dataset_filenames)
        dataset = dataset.shuffle(buffer_size=10000, seed=None)
        dataset = dataset.map(util_data.decode_image)  # decoding the tfrecord
        dataset = dataset.map(
            self._genLatentCodes)  # preprocess data and perform augmentation
        dataset = dataset.map(preprocessing.pipe)
        dataset = dataset.batch(batch_size=self.batch_size)
        iterator = dataset.make_initializable_iterator()
        input_getBatch = iterator.get_next()

        # Number of samples in each test input (shared by all test placeholders).
        n_test_images = self.n_testsamples**np.minimum(2, self.info_var_dim)

        # Create input placeholders. input_images has a fixed batch dimension,
        # so only full batches can be fed.
        input_images = tf.placeholder(dtype=tf.float32,
                                      shape=[self.batch_size] +
                                      self.image_dims,
                                      name='input_images')
        input_lbls = tf.placeholder(dtype=tf.float32,
                                    shape=[None, self.lbls_dim],
                                    name='input_lbls')
        input_unstructured_noise = tf.placeholder(
            dtype=tf.float32,
            shape=[None, self.unstructured_noise_dim],
            name='input_unstructured_noise')
        input_info_noise = tf.placeholder(dtype=tf.float32,
                                          shape=[None, self.info_var_dim],
                                          name='input_info_noise')
        input_test_lbls = tf.placeholder(dtype=tf.float32,
                                         shape=[n_test_images, self.lbls_dim],
                                         name='input_test_lbls')
        input_test_noise = tf.placeholder(
            dtype=tf.float32,
            shape=[n_test_images, self.unstructured_noise_dim],
            name='input_test_noise')
        input_test_info_noise = tf.placeholder(
            dtype=tf.float32,
            shape=[n_test_images, self.info_var_dim],
            name='input_test_info_noise')

        # Define model, loss, optimizer and summaries.
        logits_source, logits_class, logits_info, artificial_images = self._create_inference(
            input_images, input_lbls, input_unstructured_noise,
            input_info_noise)
        loss_discriminator, loss_generator = self._create_losses(
            logits_source, logits_class, logits_info, artificial_images,
            input_lbls, input_info_noise)
        train_op_discriminator, train_op_generator = self._create_optimizer(
            loss_discriminator, loss_generator)
        summary_op_dloss, summary_op_gloss, summary_op_img, summary_img = self._create_summaries(
            loss_discriminator, loss_generator, input_test_noise,
            input_test_lbls, input_test_info_noise)

        # Show network architecture.
        utils.show_all_variables()

        # Create constant test input to inspect changes in the model.
        self.combinations_info_var = list(
            itertools.combinations(range(self.info_var_dim), 2))

        test_noise, test_info = self._genTestInput()

        checkpoint_prefix = os.path.join(self.dir_checkpoints,
                                         self.model + '.model')

        def _save_test_images(sess, epoch_label):
            """Generate images from the fixed test input for every class and latent combination and save them to the class folders."""
            for class_n, dir_result_train_class in enumerate(dirs_results_class):
                # One-hot labels selecting class_n for every test sample.
                test_lbls = np.zeros([n_test_images, self.lbls_dim])
                test_lbls[:, class_n] = 1

                for i, test_info_combi in enumerate(test_info):
                    summaryImg = sess.run(
                        summary_img,
                        feed_dict={
                            input_test_noise: test_noise,
                            input_test_lbls: test_lbls,
                            input_test_info_noise: test_info_combi
                        })

                    if self.info_var_dim < 2:
                        filename_temp = 'Epoch_{0}_LatentVar_1'.format(
                            epoch_label)
                    else:
                        filename_temp = 'Epoch_{0}_LatentCombi_{1}_{2}'.format(
                            epoch_label, self.combinations_info_var[i][0],
                            self.combinations_info_var[i][1])

                    utils.save_image_local(summaryImg, dir_result_train_class,
                                           filename_temp)

        def _next_full_batch(sess):
            """Return the next training batch, or None when the epoch is over (iterator exhausted or only a partial batch left)."""
            try:
                batch = sess.run(input_getBatch)
            except tf.errors.OutOfRangeError:
                return None
            # A trailing partial batch cannot be fed into the fixed-size
            # input_images placeholder, so it also ends the epoch.
            if batch[0].shape[0] != self.batch_size:
                return None
            return batch

        with tf.Session() as sess:
            # Initialize all model Variables.
            sess.run(tf.global_variables_initializer())

            # Create Saver object for loading and storing checkpoints
            saver = tf.train.Saver(max_to_keep=500)

            # Create Writer object for storing graph and summaries for TensorBoard
            writer = tf.summary.FileWriter(self.dir_logs, sess.graph)
            try:
                # Reload Tensor values from latest checkpoint; resume with
                # the epoch after the one that was saved.
                ckpt = tf.train.get_checkpoint_state(self.dir_checkpoints)
                epoch_start = 0
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
                    epoch_start = int(ckpt_name.split('-')[-1]) + 1

                # Global step for TensorBoard summaries (restarts at 0 on resume).
                iteration_cnt = 0
                for epoch_n in range(epoch_start, self.epoch_max):

                    # Test model output before any training. Labelled epoch -1
                    # so the images saved after epoch 0 do not overwrite it.
                    if epoch_n == 0:
                        _save_test_images(sess, -1)

                    # Initiate or Re-initiate iterator
                    sess.run(iterator.initializer)

                    ### ------------------------------------------------------
                    ### Update model
                    if (np.mod(epoch_n, 100) == 0) or epoch_n < 25:
                        utils.show_message(
                            'Running training epoch no: {0}'.format(epoch_n))

                    while True:
                        # d_iter discriminator updates, each on a fresh batch.
                        epoch_done = False
                        for _ in range(self.d_iter):
                            batch = _next_full_batch(sess)
                            if batch is None:
                                epoch_done = True
                                break
                            image_batch, lbl_batch, unst_noise_batch, info_noise_batch = batch
                            feed_train = {
                                input_images: image_batch,
                                input_lbls: lbl_batch,
                                input_unstructured_noise: unst_noise_batch,
                                input_info_noise: info_noise_batch
                            }
                            _, summary_dloss = sess.run(
                                [train_op_discriminator, summary_op_dloss],
                                feed_dict=feed_train)
                        if epoch_done:
                            break

                        writer.add_summary(summary_dloss,
                                           global_step=iteration_cnt)

                        # One generator update on the last discriminator batch.
                        _, summary_gloss = sess.run(
                            [train_op_generator, summary_op_gloss],
                            feed_dict=feed_train)
                        writer.add_summary(summary_gloss,
                                           global_step=iteration_cnt)
                        iteration_cnt += 1

                    # Test current model
                    _save_test_images(sess, epoch_n)

                    # Save model variables to checkpoint; always keep the
                    # final epoch even if it is not a multiple of
                    # backup_frequency.
                    if ((epoch_n + 1) % self.backup_frequency == 0
                            or epoch_n + 1 == self.epoch_max):
                        saver.save(sess, checkpoint_prefix,
                                   global_step=epoch_n)
            finally:
                # Flush pending events even if training is interrupted.
                writer.close()