def main(argv):
    ''' Prepare dataset '''
    data_loader = DataLoader(FLAGS.data_dir, val_list=FLAGS.val_list)
    _, val_dataset = data_loader.create_tf_dataset(flags=FLAGS)
    logging.info('Number of validation samples: {}'.format(
        data_loader.val_size))

    ''' Create metric and summary writers '''
    val_metric = tf.keras.metrics.Mean(name='val_average_end_point_error')

    ''' Initialize model '''
    pwcnet = PWCDCNet()
    restore(net=pwcnet, ckpt_path=FLAGS.ckpt_path)

    with tqdm(total=data_loader.val_size) as pbar:
        pbar.set_description('Evaluation progress: ')
        for im_pairs, flo_gt in val_dataset:
            EPE = eval_step(model=pwcnet, image_pairs=im_pairs, flo_gt=flo_gt)
            val_metric.update_state(EPE)
            pbar.update(1)
    logging.info('*****AEPE = {:.5f}*****'.format(val_metric.result()))
    val_metric.reset_states()
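# The eval_step used above is defined elsewhere in the repo. As a rough,
# hedged sketch of what it is assumed to compute: the average end-point error
# (AEPE), i.e. the L2 distance between predicted and ground-truth flow
# vectors, averaged over all pixels. The name eval_step_sketch and the
# single-tensor model output are assumptions for illustration, not the
# repo's actual code.
def eval_step_sketch(model, image_pairs, flo_gt):
    # Forward pass without gradient tracking.
    flo_pred = model(image_pairs, training=False)
    # End-point error: per-pixel L2 norm of the flow difference, then mean.
    return tf.reduce_mean(tf.norm(flo_pred - flo_gt, axis=-1))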
def main(argv):
    ''' Prepare dataset '''
    data_loader = DataLoader(FLAGS.data_dir, FLAGS.train_list, FLAGS.val_list)
    train_dataset, val_dataset = data_loader.create_tf_dataset(flags=FLAGS)

    ''' Declare and set up the optimizer '''
    # Scale the step count and boundaries linearly with batch size
    # (the baseline schedule assumes a batch size of 8).
    num_steps = FLAGS.num_steps // (FLAGS.batch_size // 8) + 1
    lr_boundaries = [
        x // (FLAGS.batch_size // 8) for x in FLAGS.lr_boundaries
    ]  # Adjust the boundaries by batch size
    # Halve the learning rate at each boundary.
    lr_values = [
        FLAGS.lr / (2**i) for i in range(len(FLAGS.lr_boundaries) + 1)
    ]
    lr_scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=lr_boundaries, values=lr_values)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_scheduler)
    logging.info('Learning rate boundaries: {}'.format(lr_boundaries))
    logging.info('Training steps: {}'.format(num_steps))

    ''' Create metrics and summary writers '''
    train_metric = tf.keras.metrics.Mean(name='train_loss')
    val_metric = tf.keras.metrics.Mean(name='val_average_end_point_error')
    time_metric = tf.keras.metrics.Mean(name='elapsed_time_per_step')
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'summaries', 'train'))
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'summaries', 'val'))

    ''' Initialize model '''
    pwcnet = PWCDCNet()

    ''' Restore from an existing checkpoint, if any '''
    ckpt_path = os.path.join(FLAGS.save_dir, 'tf_ckpt')
    ckpt = tf.train.Checkpoint(optimizer=optimizer, net=pwcnet)
    manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=20)
    if manager.latest_checkpoint:
        logging.info("Restored from {}".format(manager.latest_checkpoint))
    else:
        logging.info("Initializing from scratch.")
    status = ckpt.restore(manager.latest_checkpoint).expect_partial()

    ''' Start training '''
    step = optimizer.iterations.numpy()
    while step < num_steps:
        for im_pairs, flo_gt in train_dataset:
            # Run the model and trace gradients with tf.GradientTape().
            start_time = time.time()
            with tf.GradientTape() as tape:
                total_losses, flo_preds = train_step(
                    model=pwcnet,
                    metric=train_metric,
                    summary_writer=train_summary_writer,
                    image_pairs=im_pairs,
                    flo_gt=flo_gt)
            # Update weights: compute gradients and apply them with the optimizer.
            grads = tape.gradient(total_losses, pwcnet.trainable_variables)
            optimizer.apply_gradients(zip(grads, pwcnet.trainable_variables))
            elapsed_time = time.time() - start_time

            # Logging
            train_metric.update_state(total_losses)
            time_metric.update_state(elapsed_time)
            step = optimizer.iterations.numpy()
            if step % FLAGS.log_freq == 0:
                write_summary(summary_writer=train_summary_writer,
                              step=step,
                              metric=train_metric,
                              mode='training',
                              image_pairs=im_pairs,
                              flo_preds=flo_preds,
                              flo_gt=flo_gt)
                logging.info(
                    'Step {:>7}, Training Loss: {:.5f}, ({:.3f} sec/step)'.
                    format(step, train_metric.result(), time_metric.result()))
                train_metric.reset_states()
                time_metric.reset_states()

            # Evaluate
            if step % FLAGS.steps_per_eval == 0:
                for im_pairs, flo_gt in val_dataset:
                    EPE = eval_step(model=pwcnet,
                                    image_pairs=im_pairs,
                                    flo_gt=flo_gt)
                    val_metric.update_state(EPE)
                write_summary(summary_writer=val_summary_writer,
                              step=step,
                              metric=val_metric,
                              mode='validation')
                logging.info('*****Steps {:>7}, AEPE = {:.5f}*****'.format(
                    step, val_metric.result()))
                val_metric.reset_states()

            # Save checkpoints
            if step % FLAGS.steps_per_save == 0:
                manager.save(checkpoint_number=step)
                logging.info(
                    '*****Steps {:>7}, save checkpoints!*****'.format(step))
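# A quick sanity check of the schedule built above: PiecewiseConstantDecay
# returns values[i] while the step is below boundaries[i], halving at each
# boundary. The boundary and learning-rate numbers below are hypothetical,
# chosen only to illustrate the behavior; the actual values come from
# FLAGS.lr and FLAGS.lr_boundaries.
schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[400000, 600000],
    values=[1e-4, 5e-5, 2.5e-5])  # lr / 2**i at each segment
print(schedule(0).numpy())       # 1e-4  (before the first boundary)
print(schedule(500000).numpy())  # 5e-5  (after the first halving)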
def main(argv):
    ''' Prepare dataset '''
    data_loader = DataLoader(FLAGS.data_dir, FLAGS.data_dir2,
                             FLAGS.train_list, FLAGS.train_list2,
                             FLAGS.val_list)
    train_dataset, val_dataset = data_loader.create_tf_dataset(flags=FLAGS)

    ''' Declare and set up the optimizer '''
    num_steps = FLAGS.num_steps
    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr)
    logging.info('Training steps: {}'.format(num_steps))

    ''' Create metrics and summary writers '''
    train_metric = tf.keras.metrics.Mean(name='train_loss')
    SDR_metric = tf.keras.metrics.Mean(name='SDR')
    SIR_metric = tf.keras.metrics.Mean(name='SIR')
    SAR_metric = tf.keras.metrics.Mean(name='SAR')
    NSDR_metric = tf.keras.metrics.Mean(name='NSDR')
    time_metric = tf.keras.metrics.Mean(name='elapsed_time_per_step')
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'summaries', 'train'))
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'summaries', 'val'))

    ''' Initialize model '''
    unet = Unet()

    ''' Restore from an existing checkpoint, if any '''
    ckpt_path = os.path.join(FLAGS.save_dir, 'tf_ckpt')
    ckpt = tf.train.Checkpoint(optimizer=optimizer, net=unet)
    manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=20)
    if manager.latest_checkpoint:
        logging.info("Restored from {}".format(manager.latest_checkpoint))
    else:
        logging.info("Initializing from scratch.")
    status = ckpt.restore(manager.latest_checkpoint).expect_partial()

    ''' Start training '''
    step = optimizer.iterations.numpy()
    while step < num_steps:
        for data in train_dataset:
            # Split the stacked spectrograms (16, 512, 128, 2) into the
            # mixture and vocal channels.
            mix, vocal = tf.split(data, [1, 1], -1)
            # Run the model and trace gradients with tf.GradientTape().
            start_time = time.time()
            with tf.GradientTape() as tape:
                preds, loss = train_step(model=unet, inputs=mix, gt=vocal)
            # Update weights: compute gradients and apply them with the optimizer.
            grads = tape.gradient(loss, unet.trainable_variables)
            optimizer.apply_gradients(zip(grads, unet.trainable_variables))
            elapsed_time = time.time() - start_time

            # Logging
            train_metric.update_state(loss)
            time_metric.update_state(elapsed_time)
            step = optimizer.iterations.numpy()
            if step % FLAGS.log_freq == 0:
                write_summary(summary_writer=train_summary_writer,
                              step=step,
                              metric=train_metric,
                              mode='training',
                              input=mix,
                              preds=preds,
                              gt=vocal)
                logging.info(
                    'Step {:>7}, Training Loss: {:.5f}, ({:.3f} sec/step)'.
                    format(step, train_metric.result(), time_metric.result()))
                train_metric.reset_states()
                time_metric.reset_states()

            # Evaluate
            if step % FLAGS.steps_per_eval == 0:
                for data in val_dataset:
                    mix, vocal = tf.split(data, [1, 1], -1)
                    SDR, SIR, SAR, NSDR = eval_step(model=unet,
                                                    inputs=mix,
                                                    gt=vocal)
                    SDR_metric.update_state(SDR)
                    SIR_metric.update_state(SIR)
                    SAR_metric.update_state(SAR)
                    NSDR_metric.update_state(NSDR)
                val_write_summary(summary_writer=val_summary_writer,
                                  step=step,
                                  SDR_metric=SDR_metric,
                                  SIR_metric=SIR_metric,
                                  SAR_metric=SAR_metric,
                                  NSDR_metric=NSDR_metric,
                                  mode='validation')
                logging.info(
                    '*****Steps {:>7}, SDR = {:.5f}, SIR = {:.5f}, '
                    'SAR = {:.5f}, NSDR = {:.5f}*****'.format(
                        step, SDR_metric.result(), SIR_metric.result(),
                        SAR_metric.result(), NSDR_metric.result()))
                SDR_metric.reset_states()
                SIR_metric.reset_states()
                SAR_metric.reset_states()
                NSDR_metric.reset_states()

            # Save checkpoints
            if step % FLAGS.steps_per_save == 0:
                manager.save(checkpoint_number=step)
                logging.info(
                    '*****Steps {:>7}, save checkpoints!*****'.format(step))
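# The eval_step above returns the BSS-Eval metrics SDR/SIR/SAR plus NSDR
# (normalized SDR: the SDR improvement over using the raw mixture as the
# estimate). A hedged sketch of how NSDR is commonly computed with mir_eval
# follows; mir_eval is a standard dependency for these metrics, but whether
# this repo uses it, and how it reconstructs waveforms from the spectrogram
# predictions, is an assumption. The function name nsdr_sketch and its
# waveform arguments are hypothetical.
import numpy as np
import mir_eval

def nsdr_sketch(vocal_est, vocal_gt, mix):
    # All arguments: 1-D time-domain signals of equal length.
    sdr_est, _, _, _ = mir_eval.separation.bss_eval_sources(
        vocal_gt[np.newaxis, :], vocal_est[np.newaxis, :])
    sdr_mix, _, _, _ = mir_eval.separation.bss_eval_sources(
        vocal_gt[np.newaxis, :], mix[np.newaxis, :])
    # NSDR = SDR(estimate, reference) - SDR(mixture, reference).
    return sdr_est[0] - sdr_mix[0]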