Пример #1
0
    def train(self):
        """Run the training/validation loop and write checkpoints.

        Epochs run from ``self.config.RESTORE_EPOCH`` (or 1 when that value
        is falsy) through ``self.config.EPOCHS`` inclusive. Each epoch:

        1. zips the train and SMPL datasets and feeds every batch to
           ``self._train_step``, then calls ``self._log_train``;
        2. feeds every validation batch to ``self._val_step``, then calls
           ``self._log_val``;
        3. prints the elapsed wall-clock time;
        4. saves a checkpoint when the epoch number is a multiple of 5.

        Once all epochs are done the summary writer is flushed and a final
        checkpoint numbered ``EPOCHS + 1`` is saved.
        """
        # Create the datasets inside the CPU device scope
        with tf.device('/CPU:0'):
            data = Dataset()
            train_ds = data.get_train()
            smpl_ds = data.get_smpl()
            val_ds = data.get_val()

        # Resume from the configured epoch when one is set, else begin at 1
        first_epoch = self.config.RESTORE_EPOCH or 1

        for epoch in range(first_epoch, self.config.EPOCHS + 1):

            epoch_started = time.time()
            print('Start of Epoch {}'.format(epoch))

            # ---- training pass: each element is (image_data, theta) ----
            paired_train = ExceptionHandlingIterator(
                tf.data.Dataset.zip((train_ds, smpl_ds)))
            train_batches = int(
                self.config.NUM_TRAINING_SAMPLES / self.config.BATCH_SIZE)
            train_bar = tqdm(paired_train,
                             total=train_batches,
                             position=0,
                             desc='training')

            for image_data, theta in train_bar:
                images = image_data[0]
                kp2d = image_data[1]
                kp3d = image_data[2]
                has3d = image_data[3]
                self._train_step(images, kp2d, kp3d, has3d, theta)

            self._log_train(epoch=epoch)

            # ---- validation pass ----
            val_batches = int(
                self.config.NUM_VALIDATION_SAMPLES / self.config.BATCH_SIZE)
            val_bar = tqdm(val_ds,
                           total=val_batches,
                           position=0,
                           desc='validate')

            for image_data in val_bar:
                images = image_data[0]
                kp2d = image_data[1]
                kp3d = image_data[2]
                has3d = image_data[3]
                self._val_step(images, kp2d, kp3d, has3d)

            self._log_val(epoch=epoch)

            elapsed = time.time() - epoch_started
            print('Time taken for epoch {} is {} sec\n'.format(epoch, elapsed))

            # saving (checkpoint) the model every 5 epochs
            if epoch % 5 == 0:
                print('saving checkpoint\n')
                self.checkpoint_manager.save(epoch)

        # Flush pending summaries, then store the final state under the
        # checkpoint number following the last epoch.
        self.summary_writer.flush()
        self.checkpoint_manager.save(self.config.EPOCHS + 1)
Пример #2
0
    class DastasetConfig(LocalConfig):
        """Configuration subclass that overrides ``TRANS_MAX``.

        Only ``TRANS_MAX`` is set here; every other setting is inherited
        from ``LocalConfig``. The commented-out assignments below are
        alternative overrides that are currently disabled.
        """
        # DATA_DIR = join('/', 'data', 'ssd1', 'russales', 'new_records')
        # DATASETS = ['coco'] #['lsp', 'lsp_ext', 'mpii', 'coco', 'mpii_3d', 'h36m']
        # SMPL_DATASETS = ['cmu', 'joint_lim']
        # NOTE(review): presumably the maximum translation used by the data
        # pipeline; unit and consumer are not visible here - confirm against
        # LocalConfig / Dataset.
        TRANS_MAX = 20

    # class Config is implemented as a singleton, so initialize the subclass first!
    config = DastasetConfig()

    import tensorflow as tf

    # Place tensors on the CPU
    with tf.device('/CPU:0'):
        dataset = Dataset()
        ds_train = dataset.get_train()
        ds_smpl = dataset.get_smpl()
        ds_val = dataset.get_val()

    import matplotlib.pyplot as plt

    for images, kp2d, kp3d, has3d in ds_train.take(1):
        fig = plt.figure(figsize=(9.6, 5.4))
        image_orig = tf.image.decode_jpeg(images[0], channels=3)
        image_orig = image_orig.numpy()
        kp2d = kp2d[0].numpy()
        ax0 = fig.add_subplot(111)
        image_2d = draw_2d_on_image(image_orig, kp2d[:, :2], vis=kp2d[:, 2])
        ax0.imshow(image_2d)

        fig2 = plt.figure(figsize=(9.6, 5.4))
        kp3d = kp3d[0].numpy()