def train(self):
    """Train the model epoch by epoch, validating and logging after each one.

    The first epoch is ``self.config.RESTORE_EPOCH`` when that value is
    truthy, otherwise 1; the last epoch is ``self.config.EPOCHS``.  Every
    epoch feeds the zipped train/SMPL datasets to ``_train_step``, then the
    validation dataset to ``_val_step``, logs both phases and prints the
    elapsed wall-clock time.  A checkpoint is saved on every 5th epoch; after
    the last epoch the summary writer is flushed and one more checkpoint,
    numbered ``EPOCHS + 1``, is saved.
    """
    # Build the datasets on the CPU.
    with tf.device('/CPU:0'):
        data = Dataset()
        train_ds = data.get_train()
        smpl_ds = data.get_smpl()
        val_ds = data.get_val()

    # Resume from the configured epoch when one is set.
    first_epoch = self.config.RESTORE_EPOCH or 1

    for epoch in range(first_epoch, self.config.EPOCHS + 1):
        epoch_began = time.time()
        print('Start of Epoch {}'.format(epoch))

        # --- training pass -------------------------------------------------
        # The zipped dataset is rebuilt each epoch; every element pairs one
        # batch from the training set with one from the SMPL set.
        paired_train = ExceptionHandlingIterator(
            tf.data.Dataset.zip((train_ds, smpl_ds)))
        train_batches = int(
            self.config.NUM_TRAINING_SAMPLES / self.config.BATCH_SIZE)
        train_progress = tqdm(
            paired_train, total=train_batches, position=0, desc='training')
        for batch, theta in train_progress:
            images, kp2d, kp3d, has3d = (batch[i] for i in range(4))
            self._train_step(images, kp2d, kp3d, has3d, theta)

        self._log_train(epoch=epoch)

        # --- validation pass -----------------------------------------------
        val_batches = int(
            self.config.NUM_VALIDATION_SAMPLES / self.config.BATCH_SIZE)
        val_progress = tqdm(
            val_ds, total=val_batches, position=0, desc='validate')
        for batch in val_progress:
            images, kp2d, kp3d, has3d = (batch[i] for i in range(4))
            self._val_step(images, kp2d, kp3d, has3d)

        self._log_val(epoch=epoch)

        elapsed = time.time() - epoch_began
        print('Time taken for epoch {} is {} sec\n'.format(epoch, elapsed))

        # Save (checkpoint) the model every 5 epochs.
        if epoch % 5 == 0:
            print('saving checkpoint\n')
            self.checkpoint_manager.save(epoch)

    self.summary_writer.flush()
    # Final checkpoint, numbered one past the last epoch.
    self.checkpoint_manager.save(self.config.EPOCHS + 1)
class DastasetConfig(LocalConfig):
    """Local configuration variant that only overrides ``TRANS_MAX``."""

    # Further overrides, currently disabled:
    # DATA_DIR = join('/', 'data', 'ssd1', 'russales', 'new_records')
    # DATASETS = ['coco']  # ['lsp', 'lsp_ext', 'mpii', 'coco', 'mpii_3d', 'h36m']
    # SMPL_DATASETS = ['cmu', 'joint_lim']
    TRANS_MAX = 20


# Class Config is implemented as a singleton: initialize the subclass first!
config = DastasetConfig()

import tensorflow as tf

# Place tensors on the CPU.
with tf.device('/CPU:0'):
    dataset = Dataset()
    ds_train = dataset.get_train()
    ds_smpl = dataset.get_smpl()
    ds_val = dataset.get_val()

import matplotlib.pyplot as plt

# Look at the first element of a single training batch.
for images, kp2d, kp3d, has3d in ds_train.take(1):
    # The first image of the batch is decoded from JPEG into a 3-channel array.
    image_orig = tf.image.decode_jpeg(images[0], channels=3).numpy()
    kp2d = kp2d[0].numpy()

    # Figure 1: the image with its 2D keypoints drawn on top.  The first two
    # keypoint columns are passed as coordinates and the third as ``vis``
    # (presumably per-keypoint visibility -- confirm in draw_2d_on_image).
    fig = plt.figure(figsize=(9.6, 5.4))
    ax0 = fig.add_subplot(111)
    image_2d = draw_2d_on_image(image_orig, kp2d[:, :2], vis=kp2d[:, 2])
    ax0.imshow(image_2d)

    # Figure 2: created here together with the first sample's 3D keypoints.
    fig2 = plt.figure(figsize=(9.6, 5.4))
    kp3d = kp3d[0].numpy()