def __init__(self, config):
    """Set up the trainer state from an experiment configuration.

    Args:
        config: experiment configuration object; must expose LEARNING_RATE.
    """
    self.config = config

    # Network and its optimizer (TF1 graph-mode Adam via the compat.v1 API).
    self.dtn = DTN(32, config)
    self.dtn_op = tf.compat.v1.train.AdamOptimizer(config.LEARNING_RATE, beta1=0.5)

    # Running trackers for each loss term.
    self.depth_map_loss = Error()
    self.class_loss = Error()
    self.route_loss = Error()
    self.uniq_loss = Error()

    # Checkpoint bookkeeping (manager list is populated later, elsewhere).
    self.last_epoch = 0
    self.checkpoint_manager = []
def main(argv=None):
    """Train the model, validating and checkpointing after every epoch.

    Builds the train/val graphs via ``_step``, restores the newest
    checkpoint from ``config.LOG_DIR`` when one exists (otherwise
    initializes variables from scratch), then for each epoch runs the
    training steps, saves a checkpoint, and runs the validation steps.
    Loss figures are periodically written to PNG files in the log dir.

    Args:
        argv: unused; present so the function can serve as a tf.app-style
            entry point.
    """
    # Configurations
    config = Config(gpu='1',
                    root_dir='./data/train/',
                    root_dir_val='./data/val/',
                    mode='training')

    # Create data feeding pipeline.
    dataset_train = Dataset(config, 'train')
    dataset_val = Dataset(config, 'val')

    # Train Graph
    losses, g_op, d_op, fig = _step(config, dataset_train, training_nn=True)
    losses_val, _, _, fig_val = _step(config, dataset_val, training_nn=False)

    # Add ops to save and restore all the variables.
    # Fix: use the tf.compat.v1 namespace consistently — the optimizer is
    # already built through tf.compat.v1, and the bare tf.train.Saver /
    # tf.Session / tf.initializers.global_variables aliases do not exist
    # under TF2's default namespace.
    saver = tf.compat.v1.train.Saver(max_to_keep=50)

    with tf.compat.v1.Session(config=config.GPU_CONFIG) as sess:
        # Restore the model if a checkpoint is present.
        ckpt = tf.train.get_checkpoint_state(config.LOG_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoints are named '<LOG_DIR>/ckpt-<epoch>'; recover the
            # epoch number from the filename suffix.
            last_epoch = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
            print('**********************************************************')
            print('Restore from Epoch ' + str(last_epoch))
            print('**********************************************************')
        else:
            init = tf.compat.v1.global_variables_initializer()
            last_epoch = 0
            sess.run(init)
            print('**********************************************************')
            print('Train from scratch.')
            print('**********************************************************')

        avg_loss = Error()
        print_list = {}
        for epoch in range(int(last_epoch), config.MAX_EPOCH):
            start = time.time()

            # Train one epoch.
            for step in range(config.STEPS_PER_EPOCH):
                # The discriminator op runs only every G_D_RATIO-th step.
                if step % config.G_D_RATIO == 0:
                    _losses = sess.run(losses + [g_op, d_op, fig])
                else:
                    _losses = sess.run(losses + [g_op, fig])

                # Logging: feed (name, value) pairs through the running
                # averager and print an in-place progress line.
                print_list['g_loss'] = _losses[0]
                print_list['d_loss'] = _losses[1]
                print_list['a_loss'] = _losses[2]
                display_list = ['Epoch '+str(epoch+1)+'-'+str(step+1)+'/' +
                                str(config.STEPS_PER_EPOCH)+':'] +\
                               [avg_loss(x) for x in print_list.items()]
                print(*display_list + [' '], end='\r')

                # Visualization: dump the rendered loss figure periodically.
                if step % config.LOG_FR_TRAIN == 0:
                    fname = config.LOG_DIR + '/Epoch-' + str(
                        epoch + 1) + '-' + str(step + 1) + '.png'
                    cv2.imwrite(fname, _losses[-1])

            # Model saving (once per epoch, after the training steps).
            saver.save(sess, config.LOG_DIR + '/ckpt', global_step=epoch + 1)
            print('\n', end='\r')

            # Validate one epoch (no optimizer ops are run here).
            for step in range(config.STEPS_PER_EPOCH_VAL):
                _losses = sess.run(losses_val + [fig_val])

                # Logging (val=1 routes the values to the validation averages).
                print_list['g_loss'] = _losses[0]
                print_list['d_loss'] = _losses[1]
                print_list['a_loss'] = _losses[2]
                display_list = ['Epoch '+str(epoch+1)+'-Val-'+str(step+1)+'/' +
                                str(config.STEPS_PER_EPOCH_VAL)+':'] +\
                               [avg_loss(x, val=1) for x in print_list.items()]
                print(*display_list + [' '], end='\r')

                # Visualization
                if step % config.LOG_FR_TEST == 0:
                    fname = config.LOG_DIR + '/Epoch-' + str(
                        epoch + 1) + '-Val-' + str(step + 1) + '.png'
                    cv2.imwrite(fname, _losses[-1])

            # Time of one epoch; reset the running averages for the next one.
            print('\n Time taken for epoch {} is {:3g} sec'.format(
                epoch + 1, time.time() - start))
            avg_loss.reset()