Example 1
def main(argv):
    ''' Prepare dataset '''
    data_loader = DataLoader(FLAGS.data_dir, val_list=FLAGS.val_list)
    _, val_dataset = data_loader.create_tf_dataset(flags=FLAGS)
    logging.info('Number of validation samples: {}'.format(
        data_loader.val_size))
    ''' Create metric and summary writers '''
    val_metric = tf.keras.metrics.Mean(name='val_average_end_point_error')
    ''' Initialize model '''
    pwcnet = PWCDCNet()

    restore(net=pwcnet, ckpt_path=FLAGS.ckpt_path)

    with tqdm(total=data_loader.val_size) as pbar:
        pbar.set_description('Evaluation progress: ')

        for im_pairs, flo_gt in val_dataset:
            EPE = eval_step(model=pwcnet, image_pairs=im_pairs, flo_gt=flo_gt)
            val_metric.update_state(EPE)

            pbar.update(1)

    logging.info('*****AEPE = {:.5f}*****'.format(val_metric.result()))
    val_metric.reset_states()
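
The eval_step helper used above is not part of this snippet. A minimal sketch, assuming the model returns a single dense flow field and that the metric is the average end-point error (EPE):

import tensorflow as tf

def eval_step(model, image_pairs, flo_gt):
    # Hypothetical implementation, for illustration only; the real helper is not shown.
    flo_pred = model(image_pairs, training=False)
    # EPE: mean Euclidean distance between predicted and ground-truth flow vectors.
    return tf.reduce_mean(tf.norm(flo_pred - flo_gt, axis=-1))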
Example 2
def main(argv):
    ''' Prepare dataset '''
    data_loader = DataLoader(FLAGS.data_dir, FLAGS.train_list, FLAGS.val_list)
    train_dataset, val_dataset = data_loader.create_tf_dataset(flags=FLAGS)
    ''' Declare and setup optimizer '''
    num_steps = FLAGS.num_steps // (FLAGS.batch_size // 8) + 1
    lr_boundaries = [
        x // (FLAGS.batch_size // 8) for x in FLAGS.lr_boundaries
    ]  # Adjust the boundaries by batch size
    lr_values = [
        FLAGS.lr / (2**i) for i in range(len(FLAGS.lr_boundaries) + 1)
    ]
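    # e.g. with lr = 1e-4 and two boundaries, lr_values = [1e-4, 5e-5, 2.5e-5] (halved at each boundary)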
    lr_scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=lr_boundaries, values=lr_values)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_scheduler)

    logging.info('Learning rate boundaries: {}'.format(lr_boundaries))
    logging.info('Training steps: {}'.format(num_steps))
    ''' Create metric and summary writers '''
    train_metric = tf.keras.metrics.Mean(name='train_loss')
    val_metric = tf.keras.metrics.Mean(name='val_average_end_point_error')
    time_metric = tf.keras.metrics.Mean(name='elapsed_time_per_step')

    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'summaries', 'train'))
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'summaries', 'val'))
    ''' Initialize model '''
    pwcnet = PWCDCNet()
    ''' Check whether checkpoints already exist '''
    ckpt_path = os.path.join(FLAGS.save_dir, 'tf_ckpt')
    ckpt = tf.train.Checkpoint(optimizer=optimizer, net=pwcnet)
    manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=20)

    if manager.latest_checkpoint:
        logging.info("Restored from {}".format(manager.latest_checkpoint))
    else:
        logging.info("Initializing from scratch.")

    status = ckpt.restore(manager.latest_checkpoint).expect_partial()
    ''' Start training '''
    step = optimizer.iterations.numpy()
    while step < num_steps:
        for im_pairs, flo_gt in train_dataset:
            # Run the forward pass inside 'tf.GradientTape()' so gradients can be traced.
            start_time = time.time()
            with tf.GradientTape() as tape:
                total_losses, flo_preds = train_step(
                    model=pwcnet,
                    metric=train_metric,
                    summary_writer=train_summary_writer,
                    image_pairs=im_pairs,
                    flo_gt=flo_gt)

            # Update weights: compute gradients and apply them with the optimizer.
            grads = tape.gradient(total_losses, pwcnet.trainable_variables)
            optimizer.apply_gradients(zip(grads, pwcnet.trainable_variables))
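            # Note: apply_gradients() also increments optimizer.iterations, which is read below as the global step.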

            elapsed_time = time.time() - start_time

            # Logging
            train_metric.update_state(total_losses)
            time_metric.update_state(elapsed_time)

            step = optimizer.iterations.numpy()
            if step % FLAGS.log_freq == 0:
                write_summary(summary_writer=train_summary_writer,
                              step=step,
                              metric=train_metric,
                              mode='training',
                              image_pairs=im_pairs,
                              flo_preds=flo_preds,
                              flo_gt=flo_gt)

                logging.info(
                    'Step {:>7}, Training Loss: {:.5f}, ({:.3f} sec/step)'.
                    format(step, train_metric.result(), time_metric.result()))
                train_metric.reset_states()
                time_metric.reset_states()

            # Evaluate
            if step % FLAGS.steps_per_eval == 0:
                for im_pairs, flo_gt in val_dataset:
                    EPE = eval_step(model=pwcnet,
                                    image_pairs=im_pairs,
                                    flo_gt=flo_gt)
                    val_metric.update_state(EPE)

                write_summary(summary_writer=val_summary_writer,
                              step=step,
                              metric=val_metric,
                              mode='validation')

                logging.info('*****Steps {:>7}, AEPE = {:.5f}*****'.format(
                    step, val_metric.result()))

                val_metric.reset_states()

            # Save checkpoints
            if step % FLAGS.steps_per_save == 0:
                manager.save(checkpoint_number=step)
                logging.info(
                    '*****Steps {:>7}, save checkpoints!*****'.format(step))
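
The PiecewiseConstantDecay schedule declared in this example simply maps the optimizer's step counter onto the halved learning rates. A standalone illustration with hypothetical numbers (lr = 1e-4, boundaries at 400k and 600k steps; the actual FLAGS values are not shown in this snippet):

import tensorflow as tf

schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[400000, 600000],
    values=[1e-4, 5e-5, 2.5e-5])

# The schedule is queried with the optimizer's iteration count.
print(schedule(0).numpy())       # 1e-04, before the first boundary
print(schedule(500000).numpy())  # 5e-05, between the boundaries
print(schedule(700000).numpy())  # 2.5e-05, after the last boundary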
Example 3
def main(argv):
    ''' Prepare dataset '''
    data_loader = DataLoader(FLAGS.data_dir, FLAGS.data_dir2, FLAGS.train_list,
                             FLAGS.train_list2, FLAGS.val_list)
    train_dataset, val_dataset = data_loader.create_tf_dataset(flags=FLAGS)
    ''' Declare and setup optimizer '''
    num_steps = FLAGS.num_steps
    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr)
    logging.info('Training steps: {}'.format(num_steps))
    ''' Create metric and summary writers '''
    train_metric = tf.keras.metrics.Mean(name='train_loss')
    SDR_metric = tf.keras.metrics.Mean(name='SDR')
    SIR_metric = tf.keras.metrics.Mean(name='SIR')
    SAR_metric = tf.keras.metrics.Mean(name='SAR')
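    # NSDR (normalized SDR): SDR of the separated estimate minus SDR of the raw mixture.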
    NSDR_metric = tf.keras.metrics.Mean(name='NSDR')
    time_metric = tf.keras.metrics.Mean(name='elapsed_time_per_step')
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'summaries', 'train'))
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'summaries', 'val'))
    ''' Initialize model '''
    unet = Unet()
    ''' Check whether checkpoints already exist '''
    ckpt_path = os.path.join(FLAGS.save_dir, 'tf_ckpt')
    ckpt = tf.train.Checkpoint(optimizer=optimizer, net=unet)
    manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=20)

    if manager.latest_checkpoint:
        logging.info("Restored from {}".format(manager.latest_checkpoint))
    else:
        logging.info("Initializing from scratch.")

    status = ckpt.restore(manager.latest_checkpoint).expect_partial()
    ''' Start training '''
    step = optimizer.iterations.numpy()
    while step < num_steps:
        for data in train_dataset:
            mix, vocal = tf.split(data, [1, 1], -1)  # split last axis into mix / vocal spectrograms (16, 512, 128)
            # Run the forward pass inside 'tf.GradientTape()' so gradients can be traced.
            start_time = time.time()
            with tf.GradientTape() as tape:
                preds, loss = train_step(model=unet, inputs=mix, gt=vocal)
            # Update weights: compute gradients and apply them with the optimizer.
            grads = tape.gradient(loss, unet.trainable_variables)
            optimizer.apply_gradients(zip(grads, unet.trainable_variables))
            elapsed_time = time.time() - start_time

            # Logging
            train_metric.update_state(loss)
            time_metric.update_state(elapsed_time)
            step = optimizer.iterations.numpy()
            if step % FLAGS.log_freq == 0:
                #if step % 1 == 0:
                write_summary(summary_writer=train_summary_writer,
                              step=step,
                              metric=train_metric,
                              mode='training',
                              input=mix,
                              preds=preds,
                              gt=vocal)

                logging.info(
                    'Step {:>7}, Training Loss: {:.5f}, ({:.3f} sec/step)'.
                    format(step, train_metric.result(), time_metric.result()))
                train_metric.reset_states()
                time_metric.reset_states()
            # Evaluate
            if step % FLAGS.steps_per_eval == 0:
                #if step % 1 == 0:
                for data in val_dataset:
                    mix, vocal = tf.split(data, [1, 1], -1)
                    SDR, SIR, SAR, NSDR = eval_step(model=unet,
                                                    inputs=mix,
                                                    gt=vocal)
                    SDR_metric.update_state(SDR)
                    SIR_metric.update_state(SIR)
                    SAR_metric.update_state(SAR)
                    NSDR_metric.update_state(NSDR)

                val_write_summary(summary_writer=val_summary_writer,
                                  step=step,
                                  SDR_metric=SDR_metric,
                                  SIR_metric=SIR_metric,
                                  SAR_metric=SAR_metric,
                                  NSDR_metric=NSDR_metric,
                                  mode='validation')

                logging.info(
                    '*****Steps {:>7}, SDR = {:.5f}, SIR = {:.5f}, SAR = {:.5f}, NSDR = {:.5f}*****'
                    .format(step, SDR_metric.result(), SIR_metric.result(),
                            SAR_metric.result(), NSDR_metric.result()))
                SDR_metric.reset_states()
                SIR_metric.reset_states()
                SAR_metric.reset_states()
                NSDR_metric.reset_states()

            # Save checkpoints
            if step % FLAGS.steps_per_save == 0:
                #if step % 1 == 0:
                manager.save(checkpoint_number=step)
                logging.info(
                    '*****Steps {:>7}, save checkpoints!*****'.format(step))
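
The train_step helper for the U-Net is not included in this snippet. A minimal sketch, assuming the network predicts a soft mask that is applied to the mixture spectrogram and is trained with an L1 reconstruction loss against the vocal target:

import tensorflow as tf

def train_step(model, inputs, gt):
    # Hypothetical helper, for illustration only; the real train_step is not shown.
    mask = model(inputs, training=True)        # assumed sigmoid output in [0, 1]
    preds = mask * inputs                      # estimated vocal spectrogram
    loss = tf.reduce_mean(tf.abs(preds - gt))  # L1 reconstruction loss
    return preds, loss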