예제 #1
0
    def _build_network(self):
        """Create the stereo network and its full-resolution reprojection loss.

        Everything is built inside the ``'model'`` variable scope. The method
        stores its results on the instance: ``_net``, ``_predictions``,
        ``_full_res_disp``, ``_inputs`` and ``_loss``.
        """
        with tf.variable_scope('model'):
            # arguments forwarded to the network factory
            net_args = {
                'left_img': self._left_input,
                'right_img': self._right_input,
                'split_layers': [None],
                'sequence': True,
                'train_portion': 'BEGIN',
                # bulkhead is enabled only when running in 'MAD' mode
                'bulkhead': self._mode == 'MAD',
            }
            self._net = Nets.get_stereo_net(self._model_name, net_args)

            # the last entry of the prediction list is used as the
            # full-resolution disparity
            self._predictions = self._net.get_disparities()
            self._full_res_disp = self._predictions[-1]

            # no ground truth is supplied here: the target is an all-zero map
            # shaped [1, image_shape[0], image_shape[1], 1]
            zero_target = tf.zeros(
                [1, self._image_shape[0], self._image_shape[1], 1],
                dtype=tf.float32)
            self._inputs = {
                'left': self._left_input,
                'right': self._right_input,
                'target': zero_target,
            }

            # full resolution loss between warped right image and original
            # left image
            reprojection_loss_fn = loss_factory.get_reprojection_loss(
                'mean_SSIM_l1', reduced=True)
            self._loss = reprojection_loss_fn(self._predictions, self._inputs)
예제 #2
0
 def _build_adaptation_loss(self, current_net, inputs):
     """Return the weighted reprojection loss used for adaptation.

     The non-reduced 'ssim_l1' reprojection error of the first output is
     multiplied by the map produced by ``sharedLayers.weighting_network`` and
     summed. As a side effect the variables returned by the weighting network
     are stored in ``self._weighting_network_vars``.
     """
     error_fn = loss_factory.get_reprojection_loss('ssim_l1', reduced=False)
     errors = error_fn(current_net.get_disparities(), inputs)
     # only the first element of the returned errors is used
     reprojection_error = errors[0]

     weight, self._weighting_network_vars = sharedLayers.weighting_network(
         reprojection_error, reuse=self._reuse, training=True)

     weighted_error = reprojection_error * weight
     return tf.reduce_sum(weighted_error)
예제 #3
0
def get_loss(name, unSupervised, masked=True):
    """Fetch a loss from ``loss_factory`` by name.

    Args:
        name: identifier of the loss, forwarded to the factory.
        unSupervised: when truthy a reprojection loss is returned, otherwise
            a supervised one.
        masked: forwarded as ``mask`` to the supervised factory only; it is
            not used for reprojection losses.

    Returns:
        Whatever the selected ``loss_factory`` getter returns.
    """
    if not unSupervised:
        return loss_factory.get_supervised_loss(
            name, False, False, mask=masked)
    return loss_factory.get_reprojection_loss(name, False, False)
예제 #4
0
    def _MAD_adaptation_ops(self):
        """Create one train op per intermediate disparity prediction.

        For every prediction except the last one, a 'mean_SSIM_l1'
        reprojection loss is built and minimised only over the variables of
        the layers listed in the matching ``self._train_config`` entry. The
        resulting ops are appended to ``self._train_ops`` and a 'PROBABILITY'
        sampler is stored in ``self._sampler``.
        """
        self._load_block_config()

        # keep all predictions except full res
        partial_predictions = self._predictions[:-1]
        module_inputs = self._inputs

        assert (len(partial_predictions) == len(self._train_config))
        for idx, disparity in enumerate(partial_predictions):
            print('Build train ops for disparity {}'.format(idx))

            # scale factor: ratio between the axis-1 sizes of the left input
            # and of this prediction
            input_size = tf.shape(self._left_input)[1]
            prediction_size = tf.shape(disparity)[1]
            multiplier = tf.cast(input_size // prediction_size, tf.float32)

            # rescale prediction to proper resolution
            resized = preprocessing.resize_to_prediction(
                disparity, module_inputs['left'])
            disparity = resized * multiplier

            # compute reprojection error
            with tf.variable_scope('reprojection_' + str(idx)):
                loss_fn = loss_factory.get_reprojection_loss(
                    'mean_SSIM_l1', reduced=True)
                reconstruction_loss = loss_fn([disparity], module_inputs)

            # collect the variables of every layer assigned to this block
            layer_names = self._train_config[idx]
            print('Going to train on {}'.format(layer_names))
            block_variables = []
            for layer_name in layer_names:
                block_variables.extend(self._net.get_variables(layer_name))
            print('Number of variable to train: {}'.format(
                len(block_variables)))

            # add new training op
            block_train_op = self._trainer.minimize(
                reconstruction_loss, var_list=block_variables)
            self._train_ops.append(block_train_op)

            print('Done')
            print('=' * 50)

        # create Sampler to fetch portions to train
        self._sampler = sampler_factory.get_sampler('PROBABILITY', 1, 0)
예제 #5
0
def main(args):
    """Adapt a stereo network online over one sequence and log its metrics.

    Builds the input pipeline, the network, the adaptation loss selected by
    ``args.mode`` and the validation metrics, then runs over the dataset once
    (until the reader raises ``tf.errors.OutOfRangeError``). Aggregate numbers
    are written to ``stats.csv`` and per-frame numbers to ``series.csv``
    inside ``args.output``.

    Loss selection as implemented below:
      * ``'SAD'``: supervised ``'mean_l1'`` loss against the ground truth;
      * ``'WAD'``: ``'ssim_l1'`` reprojection error weighted by
        ``Nets.sharedLayers.weighting_network`` and summed;
      * any other value: mean of the ``'ssim_l1'`` reprojection error;
      * ``'NONE'`` additionally disables training (pure forward pass).

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``sequence``, ``imageSize``, ``modelName``, ``mode``, ``summary``,
            ``kittiEval``, ``badTH``, ``lr``, ``output``, ``weights``,
            ``prefix`` and ``logDispStep``.
    """
    # setup input pipelines
    with tf.variable_scope('input_readers'):

        data_set = data_reader.dataset(args.sequence,
                                       batch_size=1,
                                       crop_shape=args.imageSize,
                                       num_epochs=1,
                                       augment=False,
                                       is_training=False,
                                       shuffle=False)
        left_img_batch, right_img_batch, gt_image_batch = data_set.get_batch()

    # build model
    with tf.variable_scope('model'):
        net_args = {}
        net_args['left_img'] = left_img_batch
        net_args['right_img'] = right_img_batch
        net_args['is_training'] = False
        stereo_net = Nets.factory.getStereoNet(args.modelName, net_args)
        print('Stereo Prediction Model:\n', stereo_net)

        # retrieve full resolution prediction and set its shape
        predictions = stereo_net.get_disparities()
        full_res_disp = predictions[-1]
        full_res_shape = left_img_batch.get_shape().as_list()
        full_res_shape[-1] = 1
        full_res_disp.set_shape(full_res_shape)

        # cast img batch to float32 for further elaboration
        right_input = tf.cast(right_img_batch, tf.float32)
        left_input = tf.cast(left_img_batch, tf.float32)
        gt_input = tf.cast(gt_image_batch, tf.float32)

        inputs = {}
        inputs['left'] = left_input
        inputs['right'] = right_input
        inputs['target'] = gt_input

        if args.mode != 'SAD':
            reprojection_error = loss_factory.get_reprojection_loss(
                'ssim_l1', reduced=False)([full_res_disp], inputs)[0]
            if args.mode == 'WAD':
                # the weighting network sees the error but gradients are not
                # propagated through that input
                weight, _ = Nets.sharedLayers.weighting_network(
                    tf.stop_gradient(reprojection_error), reuse=False)
                adaptation_loss = tf.reduce_sum(reprojection_error * weight)
                if args.summary > 1:
                    masked_loss = reprojection_error * weight
                    tf.summary.image(
                        'weight',
                        preprocessing.colorize_img(weight, cmap='magma'))
                    tf.summary.image(
                        'reprojection_error',
                        preprocessing.colorize_img(reprojection_error,
                                                   cmap='magma'))
                    tf.summary.image(
                        'rescaled_error',
                        preprocessing.colorize_img(masked_loss, cmap='magma'))
            else:
                adaptation_loss = tf.reduce_mean(reprojection_error)
        else:
            adaptation_loss = loss_factory.get_supervised_loss('mean_l1')(
                [full_res_disp], inputs)

    with tf.variable_scope('validation_error'):
        # replace non-finite ground-truth values with 0 so that they are
        # treated as invalid pixels below
        gt_input = tf.where(tf.is_finite(gt_input), gt_input,
                            tf.zeros_like(gt_input))

        # compute error against gt; pixels whose gt is 0 are masked out
        abs_err = tf.abs(full_res_disp - gt_input)
        valid_map = tf.cast(tf.logical_not(tf.equal(gt_input, 0)), tf.float32)
        filtered_error = abs_err * valid_map

        if args.summary > 1:
            tf.summary.image('filtered_error', filtered_error)

        # mean absolute error over valid pixels
        abs_err = tf.reduce_sum(filtered_error) / tf.reduce_sum(valid_map)
        if args.kittiEval:
            # a pixel is bad only if the error exceeds both the absolute
            # threshold and 5% of the ground-truth value
            error_pixels = tf.math.logical_and(
                tf.greater(filtered_error, args.badTH),
                tf.greater(filtered_error, gt_input * 0.05))
        else:
            error_pixels = tf.greater(filtered_error, args.badTH)
        bad_pixel_abs = tf.cast(error_pixels, tf.float32)
        bad_pixel_perc = tf.reduce_sum(bad_pixel_abs) / tf.reduce_sum(
            valid_map)

        # add summary for epe and bad pixel rate
        tf.summary.scalar('EPE', abs_err)
        tf.summary.scalar('bad{}'.format(args.badTH), bad_pixel_perc)

    # setup optimizer and training ops
    num_steps = len(data_set)
    with tf.variable_scope('trainer'):
        if args.mode == 'NONE':
            trainable_variables = []
        else:
            trainable_variables = stereo_net.get_trainable_variables()

        if len(trainable_variables) > 0:
            print('Going to train on {}'.format(len(trainable_variables)))
            optimizer = tf.train.AdamOptimizer(args.lr)
            train_op = optimizer.minimize(adaptation_loss,
                                          var_list=trainable_variables)
        else:
            print('Nothing to train, switching to pure forward')
            train_op = tf.no_op()

    # setup logging info
    tf.summary.scalar("adaptation_loss", adaptation_loss)

    if args.summary > 1:
        tf.summary.image(
            'ground_truth',
            preprocessing.colorize_img(gt_image_batch, cmap='jet'))
        tf.summary.image('prediction',
                         preprocessing.colorize_img(full_res_disp, cmap='jet'))
        tf.summary.image('left', left_img_batch)

    summary_op = tf.summary.merge_all()

    # create writer to save log files
    logger = tf.summary.FileWriter(args.output)

    # adapt
    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # init everything
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        # restore weights
        restored, _ = weights_utils.check_for_weights_or_restore_them(
            args.output,
            sess,
            initial_weights=args.weights,
            prefix=args.prefix,
            ignore_list=['train_model/'])
        print('Restored weights {}, initial step: {}'.format(restored, 0))

        bad3s = []
        epes = []
        global_start_time = time.time()
        start_time = time.time()
        step = 0
        try:
            # The left image is always fetched because it is needed whenever
            # disparities are dumped to disk, regardless of the summary level
            # (previously it was missing when summaries were enabled, which
            # raised a NameError on the first dump).
            fetches = [
                train_op, full_res_disp, adaptation_loss, abs_err,
                bad_pixel_perc, left_img_batch
            ]
            if args.summary > 0:
                fetches.append(summary_op)

            while True:
                # train
                results = sess.run(fetches)
                _, dispy, lossy, current_epe, current_bad3, lefty = results[:6]
                if args.summary > 0:
                    summary_string = results[6]

                epes.append(current_epe)
                bad3s.append(current_bad3)
                if step % 100 == 0:
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    missing_time = ((num_steps - step) // 100) * elapsed_time
                    missing_epochs = 1 - (step / num_steps)
                    print(
                        'Step:{}\tLoss:{:.2}\tf/b-time:{:.3}s\tmissing time: {}\tmissing epochs: {:.3}'
                        .format(step, lossy, elapsed_time / 100,
                                datetime.timedelta(seconds=missing_time),
                                missing_epochs))
                    if args.summary > 0:
                        logger.add_summary(summary_string, step)
                    start_time = time.time()

                if args.logDispStep != -1 and step % args.logDispStep == 0:
                    # Disparity is stored as 16 bit PNG holding the integer
                    # disparity times 256. The product is computed in uint32
                    # and saturated at 65535: a uint16 product wraps to 0 for
                    # a disparity of 256.
                    scaled_disp = np.clip(dispy[0], 0, 256).astype(
                        np.uint32) * 256
                    dispy_to_save = np.minimum(scaled_disp,
                                               65535).astype(np.uint16)
                    cv2.imwrite(
                        os.path.join(
                            args.output,
                            'disparities/disparity_{}.png'.format(step)),
                        dispy_to_save)
                    # channel order is reversed before handing the image to
                    # OpenCV
                    cv2.imwrite(
                        os.path.join(args.output,
                                     'rgbs/left_{}.png'.format(step)),
                        lefty[0, :, :, ::-1].astype(np.uint8))

                step += 1
        except tf.errors.OutOfRangeError:
            # raised by the input pipeline once the sequence is exhausted
            pass
        finally:
            global_end_time = time.time()
            # make sure buffered summaries reach the disk
            logger.close()

            # guard against a run that ended before completing any step
            if step > 0:
                avg_execution_time = (global_end_time -
                                      global_start_time) / step
            else:
                avg_execution_time = 0.0
            fps = 1.0 / avg_execution_time if avg_execution_time > 0 else 0.0

            with open(os.path.join(args.output, 'stats.csv'), 'w+') as f_out:
                bad3_accumulator = np.sum(bad3s)
                epe_accumulator = np.sum(epes)
                # report series
                # NOTE(review): averages divide by num_steps, not by the
                # number of processed frames - confirm this is intended for
                # interrupted runs.
                f_out.write('AVG_bad{},{}\n'.format(
                    args.badTH, bad3_accumulator / num_steps))
                f_out.write('AVG_EPE,{}\n'.format(epe_accumulator / num_steps))
                f_out.write(
                    'AVG Execution time,{}\n'.format(avg_execution_time))
                f_out.write('FPS,{}'.format(fps))

            # first element of every couple is used as the frame identifier
            files = [x[0] for x in data_set.get_couples()]
            with open(os.path.join(args.output, 'series.csv'), 'w+') as f_out:
                f_out.write('Iteration,file,EPE,bad{}\n'.format(args.badTH))
                for i, (f, e, b) in enumerate(zip(files, epes, bad3s)):
                    f_out.write('{},{},{},{}\n'.format(i, f, e, b))

            print('All done shutting down')
예제 #6
0
def main(args):
    """Run online adaptation ('MAD' or 'FULL') over a list of stereo frames.

    Builds the input reader, the network, a full-resolution 'mean_SSIM_l1'
    reprojection loss, validation metrics and the train ops, then iterates
    until the reader is exhausted. In 'MAD' mode one train op exists per
    intermediate prediction and a sampler picks which ones run; in 'FULL'
    mode a single op trains on the full-resolution loss; with any other mode
    no training op is executed. Results are written to ``stats.csv`` and
    ``series.csv`` inside ``args.output``.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``blockConfig``, ``list``, ``imageShape``, ``mode``,
            ``modelName``, ``lr``, ``reprojectionScale``, ``sampleMode``,
            ``numBlocks``, ``fixedID``, ``summary``, ``output``, ``weights``,
            ``sampleFrequency``, ``logDispStep`` and ``SSIMTh``.
    """
    #load json file config
    with open(args.blockConfig) as json_data:
        train_config = json.load(json_data)

    #read input data
    with tf.variable_scope('input_reader'):
        data_set = data_reader.dataset(args.list,
                                       batch_size=1,
                                       crop_shape=args.imageShape,
                                       num_epochs=1,
                                       augment=False,
                                       is_training=False,
                                       shuffle=False)
        left_img_batch, right_img_batch, gt_image_batch = data_set.get_batch()
        inputs = {
            'left': left_img_batch,
            'right': right_img_batch,
            'target': gt_image_batch
        }

    #build inference network
    with tf.variable_scope('model'):
        net_args = {}
        net_args['left_img'] = left_img_batch
        net_args['right_img'] = right_img_batch
        net_args['split_layers'] = [None]
        net_args['sequence'] = True
        net_args['train_portion'] = 'BEGIN'
        net_args['bulkhead'] = True if args.mode == 'MAD' else False
        stereo_net = Nets.get_stereo_net(args.modelName, net_args)
        print('Stereo Prediction Model:\n', stereo_net)
        predictions = stereo_net.get_disparities()
        full_res_disp = predictions[-1]

    #build real full resolution loss
    with tf.variable_scope('full_res_loss'):
        # reconstruction loss between warped right image and original left image
        full_reconstruction_loss = loss_factory.get_reprojection_loss(
            'mean_SSIM_l1', reduced=True)(predictions, inputs)

    #build validation ops
    with tf.variable_scope('validation_error'):
        # compute error against gt; pixels whose gt is 0 are masked out
        abs_err = tf.abs(full_res_disp - gt_image_batch)
        valid_map = tf.where(tf.equal(gt_image_batch, 0),
                             tf.zeros_like(gt_image_batch, dtype=tf.float32),
                             tf.ones_like(gt_image_batch, dtype=tf.float32))
        filtered_error = abs_err * valid_map

        # mean absolute error and bad-pixel rate over valid pixels
        abs_err = tf.reduce_sum(filtered_error) / tf.reduce_sum(valid_map)
        bad_pixel_abs = tf.where(
            tf.greater(filtered_error, PIXEL_TH),
            tf.ones_like(filtered_error, dtype=tf.float32),
            tf.zeros_like(filtered_error, dtype=tf.float32))
        bad_pixel_perc = tf.reduce_sum(bad_pixel_abs) / tf.reduce_sum(
            valid_map)

    #build train ops
    disparity_trainer = tf.train.MomentumOptimizer(args.lr, 0.9)
    train_ops = []
    if args.mode == 'MAD':
        #build train ops for separate portion of the network
        predictions = predictions[:-1]  #remove full res disp

        inputs_modules = {
            'left':
            scale_tensor(left_img_batch, args.reprojectionScale),
            'right':
            scale_tensor(right_img_batch, args.reprojectionScale),
            'target':
            scale_tensor(gt_image_batch, args.reprojectionScale) /
            args.reprojectionScale
        }

        assert (len(predictions) == len(train_config))
        for counter, p in enumerate(predictions):
            print('Build train ops for disparity {}'.format(counter))

            #rescale predictions to proper resolution
            multiplier = tf.cast(
                tf.shape(left_img_batch)[1] // tf.shape(p)[1], tf.float32)
            p = preprocessing.resize_to_prediction(
                p, inputs_modules['left']) * multiplier

            #compute reprojection error
            with tf.variable_scope('reprojection_' + str(counter)):
                reconstruction_loss = loss_factory.get_reprojection_loss(
                    'mean_SSIM_l1', reduced=True)([p], inputs_modules)

            #build train op
            layer_to_train = train_config[counter]
            print('Going to train on {}'.format(layer_to_train))
            var_accumulator = []
            for name in layer_to_train:
                var_accumulator += stereo_net.get_variables(name)
            print('Number of variable to train: {}'.format(
                len(var_accumulator)))

            #add new training op
            train_ops.append(
                disparity_trainer.minimize(reconstruction_loss,
                                           var_list=var_accumulator))

            print('Done')
            print('=' * 50)

        #create Sampler to fetch portions to train
        sampler = sampler_factory.get_sampler(args.sampleMode, args.numBlocks,
                                              args.fixedID)

    elif args.mode == 'FULL':
        #build single train op for the full network
        train_ops.append(disparity_trainer.minimize(full_reconstruction_loss))

    if args.summary:
        #add summaries
        tf.summary.scalar('EPE', abs_err)
        tf.summary.scalar('bad3', bad_pixel_perc)
        tf.summary.image('full_res_disp',
                         preprocessing.colorize_img(full_res_disp, cmap='jet'),
                         max_outputs=1)
        tf.summary.image('gt_disp',
                         preprocessing.colorize_img(gt_image_batch,
                                                    cmap='jet'),
                         max_outputs=1)

        #create summary logger
        summary_op = tf.summary.merge_all()
        logger = tf.summary.FileWriter(args.output)

    #start session
    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        #init stuff
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        #start queue runners
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess=sess, coord=coord)

        #restore disparity inference weights
        var_to_restore = weights_utils.get_var_to_restore_list(
            args.weights, [])
        assert (len(var_to_restore) > 0)
        restorer = tf.train.Saver(var_list=var_to_restore)
        restorer.restore(sess, args.weights)
        print('Disparity Net Restored?: {}, number of restored variables: {}'.
              format(True, len(var_to_restore)))

        num_actions = len(train_ops)
        if args.mode == 'FULL':
            selected_train_ops = train_ops
        else:
            # placeholder until the sampler picks blocks (MAD) or for modes
            # that do not train at all
            selected_train_ops = [tf.no_op()]

        epe_accumulator = []
        bad3_accumulator = []
        time_accumulator = []
        exec_time = 0
        fetch_counter = [0] * num_actions
        sample_distribution = np.zeros(shape=[num_actions])
        temp_score = np.zeros(shape=[num_actions])
        loss_t_2 = 0
        loss_t_1 = 0
        expected_loss = 0
        last_trained_blocks = []
        reset_counter = 0
        step = 0
        max_steps = data_set.get_size()
        # started before the try so that the timing code in the finally
        # clause can always rely on it
        start_time = time.time()
        try:
            while True:
                #fetch new network portion to train
                if step % args.sampleFrequency == 0 and args.mode == 'MAD':
                    #Sample
                    distribution = softmax(sample_distribution)
                    blocks_to_train = sampler.sample(distribution)
                    selected_train_ops = [
                        train_ops[i] for i in blocks_to_train
                    ]

                    #accumulate sampling statistics
                    for l in blocks_to_train:
                        fetch_counter[l] += 1

                #build list of tensorflow operations that needs to be executed
                # Layout of the fetch list (indices are relied upon below):
                #   [0] EPE, [1] bad pixel rate, [2] full resolution loss,
                #   [3] summary (only on logging steps when enabled),
                #   then the train ops, and last the full resolution
                #   disparity (only on steps where it is saved).

                #errors and full resolution loss
                tf_fetches = [
                    abs_err, bad_pixel_perc, full_reconstruction_loss
                ]

                if args.summary and step % 100 == 0:
                    #summaries
                    tf_fetches = tf_fetches + [summary_op]

                #update ops
                tf_fetches = tf_fetches + selected_train_ops

                if args.logDispStep != -1 and step % args.logDispStep == 0:
                    #prediction for serialization to disk
                    tf_fetches = tf_fetches + [full_res_disp]

                #run network
                fetches = sess.run(tf_fetches)
                new_loss = fetches[2]

                if args.mode == 'MAD':
                    #update sampling probabilities
                    if step == 0:
                        loss_t_2 = new_loss
                        loss_t_1 = new_loss
                    # linear extrapolation of the two previous losses
                    expected_loss = 2 * loss_t_1 - loss_t_2
                    gain_loss = expected_loss - new_loss
                    # decay old scores, then reward the blocks trained at the
                    # previous sampling with the observed gain
                    sample_distribution = 0.99 * sample_distribution
                    for i in last_trained_blocks:
                        sample_distribution[i] += 0.01 * gain_loss

                    last_trained_blocks = blocks_to_train
                    loss_t_2 = loss_t_1
                    loss_t_1 = new_loss

                #accumulate performance metrics
                epe_accumulator.append(fetches[0])
                bad3_accumulator.append(fetches[1])

                if step % 100 == 0:
                    #log on terminal
                    fbTime = (time.time() - start_time)
                    exec_time += fbTime
                    fbTime = fbTime / 100
                    if args.summary:
                        logger.add_summary(fetches[3], global_step=step)
                    missing_time = (max_steps - step) * fbTime
                    print(
                        'Step:{:4d}\tbad3:{:.2f}\tEPE:{:.2f}\tSSIM:{:.2f}\tf/b time:{:3f}\tMissing time:{}'
                        .format(step, fetches[1], fetches[0], new_loss, fbTime,
                                datetime.timedelta(seconds=missing_time)))
                    start_time = time.time()

                #reset network if necessary
                if new_loss > args.SSIMTh:
                    restorer.restore(sess, args.weights)
                    reset_counter += 1

                #save disparity if requested
                if args.logDispStep != -1 and step % args.logDispStep == 0:
                    dispy = fetches[-1]
                    # Stored as 16 bit PNG holding the integer disparity times
                    # 256. The product is computed in uint32 and saturated at
                    # 65535 so that it cannot wrap around in uint16
                    # (this happens if MAX_DISP is 256 or more).
                    scaled_disp = np.clip(dispy[0], 0, MAX_DISP).astype(
                        np.uint32) * 256
                    dispy_to_save = np.minimum(scaled_disp,
                                               65535).astype(np.uint16)
                    cv2.imwrite(
                        os.path.join(
                            args.output,
                            'disparities/disparity_{}.png'.format(step)),
                        dispy_to_save)

                step += 1

        except Exception as e:
            # NOTE(review): this also hides real failures, not only the end
            # of the input data; kept as is because the statistics below are
            # meant to be written in every case.
            print('Exception catched {}'.format(e))
            #raise(e)
        finally:
            # exec_time is only updated on logging steps: add the time spent
            # since the last one so that the averages cover every step
            exec_time += time.time() - start_time

            if args.summary:
                # make sure buffered summaries reach the disk
                logger.close()

            # guard the averages against a run that completed no step
            steps_done = max(step, 1)
            step_time = exec_time / steps_done if step > 0 else 0.0
            fps = 1 / step_time if step_time > 0 else 0.0

            epe_array = epe_accumulator
            bad3_array = bad3_accumulator
            epe_accumulator = np.sum(epe_accumulator)
            bad3_accumulator = np.sum(bad3_accumulator)
            with open(os.path.join(args.output, 'stats.csv'), 'w+') as f_out:
                # report series
                f_out.write('Metrics,cumulative,average\n')
                f_out.write('EPE,{},{}\n'.format(epe_accumulator,
                                                 epe_accumulator / steps_done))
                f_out.write('bad3,{},{}\n'.format(
                    bad3_accumulator, bad3_accumulator / steps_done))
                f_out.write('time,{},{}\n'.format(exec_time, step_time))
                f_out.write('FPS,{}\n'.format(fps))
                f_out.write('#resets,{}\n'.format(reset_counter))
                f_out.write('Blocks')
                for n in range(len(predictions)):
                    f_out.write(',{}'.format(n))
                f_out.write(',final\n')
                f_out.write('fetch_counter')
                for c in fetch_counter:
                    f_out.write(',{}'.format(c))
                f_out.write('\n')
                # sampling scores are written as an unlabeled row
                for c in sample_distribution:
                    f_out.write(',{}'.format(c))
                f_out.write('\n')

            # per-frame timestamps are estimated from the average step time
            time_array = [str(x * step_time) for x in range(len(epe_array))]

            with open(os.path.join(args.output, 'series.csv'), 'w+') as f_out:
                f_out.write('Iteration,Time,EPE,bad3\n')
                for i, (t, e,
                        b) in enumerate(zip(time_array, epe_array,
                                            bad3_array)):
                    f_out.write('{},{},{},{}\n'.format(i, t, e, b))

            print('All Done, Bye Bye!')
            coord.request_stop()
            coord.join()