Ejemplo n.º 1
0
    def compute_loss(disparities, inputs):
        """Build the reprojection loss for one or more disparity predictions.

        For every selected prediction the right image is warped with the
        resized disparity and compared to the left image through the
        enclosing scope's ``base_loss_function``.

        Args:
            disparities: indexable collection of disparity tensors; it is read
                from the last element backwards.
            inputs: dict providing the 'left' and 'right' image tensors.

        Returns:
            ``tf.reduce_sum`` of the weighted per-scale losses when the
            enclosing ``reduced`` flag is truthy, otherwise the list of
            weighted per-scale losses.
        """
        left_raw, right_raw = inputs['left'], inputs['right']

        # normalize images to the [0, 1] range
        # (assumes 8-bit intensities -- TODO confirm)
        left_img = tf.cast(left_raw, dtype=tf.float32) / 256.0
        right_img = tf.cast(right_raw, dtype=tf.float32) / 256.0

        # a single prediction (the last one) is used unless multiScale is set
        num_scales = len(disparities) if multiScale else 1

        weighted_losses = []
        for scale_idx in range(num_scales):
            prediction = disparities[-(scale_idx + 1)]

            # NOTE(review): this ratio is prediction_width / image_width;
            # confirm this is the scaling preprocessing.warp_image expects.
            pred_width = tf.cast(tf.shape(prediction)[2], tf.float32)
            image_width = tf.cast(tf.shape(left_img)[2], tf.float32)
            width_ratio = pred_width / image_width

            # bring the prediction to the resolution of the left image
            upsampled = preprocessing.resize_to_prediction(
                prediction, left_img)
            scaled_disp = upsampled * width_ratio

            # synthesize the left view from the right one and compare
            warped_right = preprocessing.warp_image(right_img, scaled_disp)
            scale_loss = base_loss_function(warped_right, left_img)
            if logs:
                summary_name = 'Loss_resolution_{}'.format(scale_idx)
                tf.summary.scalar(summary_name, scale_loss)
            weighted_losses.append(weights[scale_idx] * scale_loss)

        if not reduced:
            return weighted_losses
        return tf.reduce_sum(weighted_losses)
Ejemplo n.º 2
0
    def compute_loss(disparities, inputs):
        """Build a supervised loss between predictions and target disparity.

        Target pixels equal to 0 are given weight 0 in the validity map that
        is forwarded to the enclosing scope's ``base_loss_function``.

        Args:
            disparities: indexable collection of disparity tensors; it is read
                from the last element backwards.
            inputs: dict providing the 'left', 'right' and 'target' tensors.

        Returns:
            ``tf.reduce_sum`` of the weighted per-scale losses when the
            enclosing ``reduced`` flag is truthy, otherwise the list of
            weighted per-scale losses.
        """
        left_img = inputs['left']
        right = inputs['right']  # looked up but not used by this loss
        ground_truth = inputs['target']

        # a single prediction (the last one) is used unless multiScale is set
        num_scales = len(disparities) if multiScale else 1

        # 1.0 where the target is non-zero, 0.0 elsewhere
        is_missing = tf.equal(ground_truth, 0)
        zeros = tf.zeros_like(ground_truth, dtype=tf.float32)
        ones = tf.ones_like(ground_truth, dtype=tf.float32)
        valid_map = tf.where(is_missing, zeros, ones)

        weighted_losses = []
        for scale_idx in range(num_scales):
            prediction = disparities[-(scale_idx + 1)]

            # disparity values grow by image_width / prediction_width when
            # the map is upsampled
            image_width = tf.cast(tf.shape(left_img)[2], tf.float32)
            pred_width = tf.cast(tf.shape(prediction)[2], tf.float32)
            width_ratio = image_width / pred_width

            upsampled = preprocessing.resize_to_prediction(
                prediction, ground_truth)
            scaled_disp = upsampled * width_ratio

            scale_loss = base_loss_function(scaled_disp, ground_truth,
                                            valid_map)
            if logs:
                summary_name = 'Loss_resolution_{}'.format(scale_idx)
                tf.summary.scalar(summary_name, scale_loss)
            weighted_losses.append(weights[scale_idx] * scale_loss)

        if not reduced:
            return weighted_losses
        return tf.reduce_sum(weighted_losses)
Ejemplo n.º 3
0
    def _MAD_adaptation_ops(self):
        """Create one training op per network portion plus the block sampler.

        Every prediction except the last one is resized to the resolution of
        the left input, scored with the 'mean_SSIM_l1' reprojection loss and
        paired with the variables listed in the matching entry of
        ``self._train_config``. The resulting ops are appended to
        ``self._train_ops`` and a 'PROBABILITY' sampler is stored in
        ``self._sampler``.
        """
        self._load_block_config()

        # the last (full resolution) prediction gets no dedicated train op
        partial_predictions = self._predictions[:-1]
        module_inputs = self._inputs

        assert (len(partial_predictions) == len(self._train_config))
        for block_idx, prediction in enumerate(partial_predictions):
            print('Build train ops for disparity {}'.format(block_idx))

            # integer ratio between the input size and the prediction size
            # along axis 1, used to rescale the disparity values
            size_ratio = tf.shape(self._left_input)[1] // tf.shape(
                prediction)[1]
            multiplier = tf.cast(size_ratio, tf.float32)
            resized = preprocessing.resize_to_prediction(
                prediction, module_inputs['left'])
            rescaled_prediction = resized * multiplier

            # reprojection error of this prediction alone
            with tf.variable_scope('reprojection_{}'.format(block_idx)):
                loss_builder = loss_factory.get_reprojection_loss(
                    'mean_SSIM_l1', reduced=True)
                block_loss = loss_builder([rescaled_prediction],
                                          module_inputs)

            # gather the variables of the layers tied to this prediction
            layer_names = self._train_config[block_idx]
            print('Going to train on {}'.format(layer_names))
            trainable = [
                variable for layer_name in layer_names
                for variable in self._net.get_variables(layer_name)
            ]
            print('Number of variable to train: {}'.format(len(trainable)))

            # optimizer step restricted to the gathered variables
            block_train_op = self._trainer.minimize(block_loss,
                                                    var_list=trainable)
            self._train_ops.append(block_train_op)

            print('Done')
            print('=' * 50)

        # sampler used to pick which portion gets trained
        self._sampler = sampler_factory.get_sampler('PROBABILITY', 1, 0)
Ejemplo n.º 4
0
    def compute_loss(disparities, inputs):
        """Build a supervised loss with an optional validity mask.

        When the enclosing ``mask`` flag is truthy, target pixels that are 0
        or greater than or equal to ``max_disp`` get weight 0; otherwise
        every pixel gets weight 1.

        Args:
            disparities: indexable collection of disparity tensors; it is read
                from the last element backwards.
            inputs: dict providing the 'left', 'right' and 'target' tensors.

        Returns:
            ``tf.reduce_sum`` of the weighted per-scale losses when the
            enclosing ``reduced`` flag is truthy, otherwise the list of
            weighted per-scale losses.
        """
        left_img = inputs['left']
        right = inputs['right']  # looked up but not used by this loss
        ground_truth = inputs['target']

        # a single prediction (the last one) is used unless multiScale is set
        num_scales = len(disparities) if multiScale else 1

        if mask:
            # a pixel is invalid when its target is 0 or >= max_disp
            is_zero = tf.equal(ground_truth, 0)
            too_large = tf.greater_equal(ground_truth, max_disp)
            invalid = tf.logical_or(is_zero, too_large)
            valid_map = tf.cast(tf.logical_not(invalid), tf.float32)
        else:
            valid_map = tf.ones_like(ground_truth)

        weighted_losses = []
        for scale_idx in range(0, num_scales):
            prediction = disparities[-(scale_idx + 1)]

            # disparity values grow by image_width / prediction_width when
            # the map is upsampled
            image_width = tf.cast(tf.shape(left_img)[2], tf.float32)
            pred_width = tf.cast(tf.shape(prediction)[2], tf.float32)
            width_ratio = image_width / pred_width

            upsampled = preprocessing.resize_to_prediction(
                prediction, ground_truth)
            scaled_disp = upsampled * width_ratio

            scale_loss = base_loss_function(scaled_disp, ground_truth,
                                            valid_map)
            if logs:
                summary_name = 'Loss_resolution_{}'.format(scale_idx)
                tf.summary.scalar(summary_name, scale_loss)
            weighted_losses.append(weights[scale_idx] * scale_loss)

        if not reduced:
            return weighted_losses
        return tf.reduce_sum(weighted_losses)
Ejemplo n.º 5
0
def main(args):
    """Adapt a stereo network online using proxy disparity labels.

    The function builds the input reader, the stereo network, a full
    resolution proxy loss and the validation metrics, then loops over the
    data running the train ops selected by ``args.mode``:

    * 'MAD': one train op per network portion; the portions to train are
      sampled from a distribution updated with the observed loss gain.
    * 'FULL': a single train op minimizing the full resolution proxy loss.

    Per-step EPE and D1 are accumulated and written to ``overall.csv`` and
    ``series.csv`` inside ``args.output``; the sampling counters are appended
    to ``histogram.csv`` every 100 steps.

    Args:
        args: parsed options. The attributes read here are blockConfig, list,
            imageShape, mode, modelName, lr, reprojectionScale, sampleMode,
            numBlocks, fixedID, summary, output, weights, sampleFrequency,
            dilation, logDispStep, decay, uf, SSIMTh and saveWeights.
    """
    #load json file config
    with open(args.blockConfig) as json_data:
        train_config = json.load(json_data)

    #read input data
    with tf.variable_scope('input_reader'):
        data_set = continual_data_reader.dataset(args.list,
                                                 batch_size=1,
                                                 crop_shape=args.imageShape,
                                                 num_epochs=1,
                                                 augment=False,
                                                 is_training=False,
                                                 shuffle=False)
        (left_img_batch, right_img_batch, gt_image_batch, px_image_batch,
         real_width) = data_set.get_batch()
        inputs = {
            'left': left_img_batch,
            'right': right_img_batch,
            'target': gt_image_batch,
            'proxy': px_image_batch,
            'real_width': real_width
        }

    #build inference network
    with tf.variable_scope('model'):
        net_args = {}
        net_args['left_img'] = left_img_batch
        net_args['right_img'] = right_img_batch
        net_args['split_layers'] = [None]
        net_args['sequence'] = True
        net_args['train_portion'] = 'BEGIN'
        net_args['bulkhead'] = args.mode == 'MAD'
        stereo_net = Nets.get_stereo_net(args.modelName, net_args)
        print('Stereo Prediction Model:\n', stereo_net)
        predictions = stereo_net.get_disparities()
        full_res_disp = predictions[-1]

    #build real full resolution loss
    with tf.variable_scope('full_res_loss'):
        # loss with respect to proxy labels
        full_proxy_loss = loss_factory.get_proxy_loss('mean_l1',
                                                      max_disp=192,
                                                      weights=[0.01] * 10,
                                                      reduced=True)(
                                                          predictions, inputs)

    #build validation ops
    with tf.variable_scope('validation_error'):
        # absolute error against gt, restricted to pixels with non-zero gt
        abs_err = tf.abs(full_res_disp - gt_image_batch)
        valid_map = tf.where(tf.equal(gt_image_batch, 0),
                             tf.zeros_like(gt_image_batch, dtype=tf.float32),
                             tf.ones_like(gt_image_batch, dtype=tf.float32))
        filtered_error = abs_err * valid_map

        # mean error and fraction of pixels whose error exceeds PIXEL_TH
        abs_err = tf.reduce_sum(filtered_error) / tf.reduce_sum(valid_map)
        bad_pixel_abs = tf.where(
            tf.greater(filtered_error, PIXEL_TH),
            tf.ones_like(filtered_error, dtype=tf.float32),
            tf.zeros_like(filtered_error, dtype=tf.float32))
        bad_pixel_perc = tf.reduce_sum(bad_pixel_abs) / tf.reduce_sum(
            valid_map)

    #build train ops
    disparity_trainer = tf.train.MomentumOptimizer(args.lr, 0.9)
    train_ops = []
    if args.mode == 'MAD':
        #build train ops for separate portion of the network
        predictions = predictions[:-1]  #remove full res disp

        # images and labels rescaled by reprojectionScale; disparity values
        # are divided by the same factor
        inputs_modules = {
            'left':
            scale_tensor(left_img_batch, args.reprojectionScale),
            'right':
            scale_tensor(right_img_batch, args.reprojectionScale),
            'target':
            scale_tensor(gt_image_batch, args.reprojectionScale) /
            args.reprojectionScale,
            'proxy':
            scale_tensor(px_image_batch, args.reprojectionScale) /
            args.reprojectionScale,
        }

        assert (len(predictions) == len(train_config))
        for counter, p in enumerate(predictions):
            print('Build train ops for disparity {}'.format(counter))

            #rescale predictions to proper resolution
            multiplier = tf.cast(
                tf.shape(left_img_batch)[1] // tf.shape(p)[1], tf.float32)
            p = preprocessing.resize_to_prediction(
                p, inputs_modules['left']) * multiplier

            #compute proxy error
            with tf.variable_scope('proxy_' + str(counter)):
                proxy_loss = loss_factory.get_proxy_loss(
                    'mean_l1', max_disp=192, weights=[0.1] * 10,
                    reduced=True)([p], inputs_modules)

            #build train op
            layer_to_train = train_config[counter]
            print('Going to train on {}'.format(layer_to_train))
            var_accumulator = []
            for name in layer_to_train:
                var_accumulator += stereo_net.get_variables(name)
            print('Number of variable to train: {}'.format(
                len(var_accumulator)))

            #add new training op
            train_ops.append(
                disparity_trainer.minimize(proxy_loss,
                                           var_list=var_accumulator))

            print('Done')
            print('=' * 50)

        #create Sampler to fetch portions to train
        sampler = sampler_factory.get_sampler(args.sampleMode, args.numBlocks,
                                              args.fixedID)

    elif args.mode == 'FULL':
        #build single train op for the full network
        train_ops.append(disparity_trainer.minimize(full_proxy_loss))

    if args.summary:
        #add summaries
        tf.summary.scalar('EPE', abs_err)
        tf.summary.scalar('bad3', bad_pixel_perc)
        tf.summary.image('full_res_disp',
                         preprocessing.colorize_img(full_res_disp, cmap='jet'),
                         max_outputs=1)
        tf.summary.image('proxy_disp',
                         preprocessing.colorize_img(px_image_batch,
                                                    cmap='jet'),
                         max_outputs=1)
        tf.summary.image('gt_disp',
                         preprocessing.colorize_img(gt_image_batch,
                                                    cmap='jet'),
                         max_outputs=1)

        #create summary logger
        summary_op = tf.summary.merge_all()
        logger = tf.summary.FileWriter(args.output)

    save_disparities = args.logDispStep != -1
    if save_disparities:
        # cv2.imwrite does not create missing folders, so make sure the
        # destination exists before the first disparity is written
        disparity_dir = os.path.join(args.output, 'disparities')
        if not os.path.isdir(disparity_dir):
            os.makedirs(disparity_dir)

    #start session
    gpu_options = tf.GPUOptions(allow_growth=True)
    adaptation_saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        #init stuff
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        #restore disparity inference weights
        var_to_restore = weights_utils.get_var_to_restore_list(
            args.weights, [])
        assert (len(var_to_restore) > 0)
        restorer = tf.train.Saver(var_list=var_to_restore)
        restorer.restore(sess, args.weights)
        print('Disparity Net Restored?: {}, number of restored variables: {}'.
              format(True, len(var_to_restore)))

        num_actions = len(train_ops)
        if args.mode == 'FULL':
            selected_train_ops = train_ops
        else:
            selected_train_ops = [tf.no_op()]

        # accumulators
        avg_accumulator = []
        d1_accumulator = []

        exec_time = 0
        fetch_counter = [0] * num_actions
        sample_distribution = np.zeros(shape=[num_actions])
        loss_t_2 = 0
        loss_t_1 = 0
        expected_loss = 0
        last_trained_blocks = []
        reset_counter = 0
        step = 0
        max_steps = data_set.get_max_steps()
        with open(os.path.join(args.output, 'histogram.csv'), 'w') as f_out:
            f_out.write('Histogram\n')

        try:
            start_time = time.time()
            while True:

                #fetch new network portion to train
                if step % args.sampleFrequency == 0 and args.mode == 'MAD':
                    #Sample
                    distribution = softmax(sample_distribution)
                    blocks_to_train = sampler.sample(distribution)
                    selected_train_ops = [
                        train_ops[i] for i in blocks_to_train
                    ]

                    #accumulate sampling statistics
                    for block_id in blocks_to_train:
                        fetch_counter[block_id] += 1

                #build list of tensorflow operations that needs to be executed
                # the first four positions are fixed: they are read back by
                # index after sess.run
                tf_fetches = [
                    full_proxy_loss, full_res_disp, inputs['target'],
                    inputs['real_width']
                ]

                if args.summary and step % 100 == 0:
                    #summaries (read back as fetches[4])
                    tf_fetches = tf_fetches + [summary_op]

                #update ops, executed once every args.dilation steps
                if step % args.dilation == 0:
                    tf_fetches = tf_fetches + selected_train_ops

                tf_fetches = tf_fetches + [abs_err, bad_pixel_perc]

                if save_disparities and step % args.logDispStep == 0:
                    #prediction for serialization to disk (read as fetches[-1])
                    tf_fetches = tf_fetches + [inputs['left']] + [
                        inputs['proxy']
                    ] + [full_res_disp]

                #run network
                fetches = sess.run(tf_fetches)
                new_loss = fetches[0]

                if args.mode == 'MAD':
                    #update sampling probabilities
                    if step == 0:
                        loss_t_2 = new_loss
                        loss_t_1 = new_loss
                    # linear extrapolation of the two previous losses; the
                    # gain is how much the new loss beats that expectation
                    expected_loss = 2 * loss_t_1 - loss_t_2
                    gain_loss = expected_loss - new_loss
                    sample_distribution = args.decay * sample_distribution
                    for i in last_trained_blocks:
                        sample_distribution[i] += args.uf * gain_loss

                    last_trained_blocks = blocks_to_train
                    loss_t_2 = loss_t_1
                    loss_t_1 = new_loss

                disp = fetches[1][-1]
                gt = fetches[2][-1]

                # compute errors on pixels with positive ground truth
                val = gt > 0
                disp_diff = np.abs(gt[val] - disp[val])
                # outlier: error above 3 and at least 5% of the gt value
                outliers = np.logical_and(disp_diff > 3,
                                          (disp_diff / gt[val]) >= 0.05)
                d1 = np.mean(outliers) * 100.
                epe = np.mean(disp_diff)

                d1_accumulator.append(d1)
                avg_accumulator.append(epe)

                if step % 100 == 0:
                    #log on terminal
                    fbTime = (time.time() - start_time)
                    exec_time += fbTime
                    fbTime = fbTime / 100
                    if args.summary:
                        logger.add_summary(fetches[4], global_step=step)
                    missing_time = (max_steps - step) * fbTime

                    with open(os.path.join(args.output, 'histogram.csv'),
                              'a') as f_out:
                        f_out.write('%s\n' % fetch_counter)

                    print('Step: %04d \tEPE:%.3f\tD1:%.3f\t' % (step, epe, d1))
                    start_time = time.time()

                #reset network if necessary
                if new_loss > args.SSIMTh:
                    restorer.restore(sess, args.weights)
                    reset_counter += 1

                #save disparity if requested
                if save_disparities and step % args.logDispStep == 0:
                    dispy = fetches[-1]

                    # clip BEFORE casting: casting first lets negative or
                    # very large values wrap around in uint16
                    # NOTE(review): the * 256 below overflows uint16 unless
                    # MAX_DISP < 256 -- confirm the constant
                    dispy_to_save = np.clip(dispy[0], 0,
                                            MAX_DISP).astype(np.uint16)
                    cv2.imwrite(
                        os.path.join(
                            args.output,
                            'disparities/disparity_{}.png'.format(step)),
                        dispy_to_save * 256)
                step += 1

        # NOTE(review): only InvalidArgumentError ends the loop silently; the
        # original comment mentions OutOfRangeError -- confirm which error the
        # reader raises when the data is exhausted
        except tf.errors.InvalidArgumentError:
            pass
        finally:
            if args.summary:
                # flush pending summaries to disk
                logger.close()

            with open(os.path.join(args.output, 'overall.csv'), 'w+') as f_out:
                print(fetch_counter)

                # report series
                f_out.write('EPE\tD1\n')
                f_out.write('%.3f\t%.3f\n' %
                            (np.asarray(avg_accumulator).mean(),
                             np.asarray(d1_accumulator).mean()))

            # NOTE(review): the header is tab separated while the rows use
            # ' & '; kept as is because consumers may rely on this format
            with open(os.path.join(args.output, 'series.csv'), 'w+') as f_out:
                f_out.write('step\tEPE\tD1\n')
                for i, (a, b) in enumerate(zip(avg_accumulator,
                                               d1_accumulator)):
                    f_out.write('%d & %.3f & %.3f\n' % (i, a, b))

            if args.saveWeights:
                adaptation_saver.save(sess,
                                      args.output + '/weights/model',
                                      global_step=step)
                print('Checkpoint saved in {}/weights'.format(args.output))
            print('Result saved in {}'.format(args.output))
            print('All Done, Bye Bye!')
Ejemplo n.º 6
0
def main(args):
    """Adapt a stereo network online using the reprojection loss.

    The function builds the input reader, the stereo network, a full
    resolution 'mean_SSIM_l1' reprojection loss and the validation metrics,
    then loops over the data running the train ops selected by ``args.mode``:

    * 'MAD': one train op per network portion; the portions to train are
      sampled from a distribution updated with the observed loss gain.
    * 'FULL': a single train op minimizing the full resolution loss.

    The loop stops on the first exception, which is printed and not
    re-raised. Cumulative statistics are then written to ``stats.csv`` and
    the per-step metrics to ``series.csv`` inside ``args.output``.

    Args:
        args: parsed options. The attributes read here are blockConfig, list,
            imageShape, mode, modelName, lr, reprojectionScale, sampleMode,
            numBlocks, fixedID, summary, output, weights, sampleFrequency,
            logDispStep and SSIMTh.
    """
    #load json file config
    with open(args.blockConfig) as json_data:
        train_config = json.load(json_data)

    #read input data
    with tf.variable_scope('input_reader'):
        data_set = data_reader.dataset(args.list,
                                       batch_size=1,
                                       crop_shape=args.imageShape,
                                       num_epochs=1,
                                       augment=False,
                                       is_training=False,
                                       shuffle=False)
        left_img_batch, right_img_batch, gt_image_batch = data_set.get_batch()
        inputs = {
            'left': left_img_batch,
            'right': right_img_batch,
            'target': gt_image_batch
        }

    #build inference network
    with tf.variable_scope('model'):
        net_args = {}
        net_args['left_img'] = left_img_batch
        net_args['right_img'] = right_img_batch
        net_args['split_layers'] = [None]
        net_args['sequence'] = True
        net_args['train_portion'] = 'BEGIN'
        net_args['bulkhead'] = args.mode == 'MAD'
        stereo_net = Nets.get_stereo_net(args.modelName, net_args)
        print('Stereo Prediction Model:\n', stereo_net)
        predictions = stereo_net.get_disparities()
        full_res_disp = predictions[-1]

    #build real full resolution loss
    with tf.variable_scope('full_res_loss'):
        # reconstruction loss between warped right image and original left image
        full_reconstruction_loss = loss_factory.get_reprojection_loss(
            'mean_SSIM_l1', reduced=True)(predictions, inputs)

    #build validation ops
    with tf.variable_scope('validation_error'):
        # absolute error against gt, restricted to pixels with non-zero gt
        abs_err = tf.abs(full_res_disp - gt_image_batch)
        valid_map = tf.where(tf.equal(gt_image_batch, 0),
                             tf.zeros_like(gt_image_batch, dtype=tf.float32),
                             tf.ones_like(gt_image_batch, dtype=tf.float32))
        filtered_error = abs_err * valid_map

        # mean error and fraction of pixels whose error exceeds PIXEL_TH
        abs_err = tf.reduce_sum(filtered_error) / tf.reduce_sum(valid_map)
        bad_pixel_abs = tf.where(
            tf.greater(filtered_error, PIXEL_TH),
            tf.ones_like(filtered_error, dtype=tf.float32),
            tf.zeros_like(filtered_error, dtype=tf.float32))
        bad_pixel_perc = tf.reduce_sum(bad_pixel_abs) / tf.reduce_sum(
            valid_map)

    #build train ops
    disparity_trainer = tf.train.MomentumOptimizer(args.lr, 0.9)
    train_ops = []
    if args.mode == 'MAD':
        #build train ops for separate portion of the network
        predictions = predictions[:-1]  #remove full res disp

        # images and labels rescaled by reprojectionScale; disparity values
        # are divided by the same factor
        inputs_modules = {
            'left':
            scale_tensor(left_img_batch, args.reprojectionScale),
            'right':
            scale_tensor(right_img_batch, args.reprojectionScale),
            'target':
            scale_tensor(gt_image_batch, args.reprojectionScale) /
            args.reprojectionScale
        }

        assert (len(predictions) == len(train_config))
        for counter, p in enumerate(predictions):
            print('Build train ops for disparity {}'.format(counter))

            #rescale predictions to proper resolution
            multiplier = tf.cast(
                tf.shape(left_img_batch)[1] // tf.shape(p)[1], tf.float32)
            p = preprocessing.resize_to_prediction(
                p, inputs_modules['left']) * multiplier

            #compute reprojection error
            with tf.variable_scope('reprojection_' + str(counter)):
                reconstruction_loss = loss_factory.get_reprojection_loss(
                    'mean_SSIM_l1', reduced=True)([p], inputs_modules)

            #build train op
            layer_to_train = train_config[counter]
            print('Going to train on {}'.format(layer_to_train))
            var_accumulator = []
            for name in layer_to_train:
                var_accumulator += stereo_net.get_variables(name)
            print('Number of variable to train: {}'.format(
                len(var_accumulator)))

            #add new training op
            train_ops.append(
                disparity_trainer.minimize(reconstruction_loss,
                                           var_list=var_accumulator))

            print('Done')
            print('=' * 50)

        #create Sampler to fetch portions to train
        sampler = sampler_factory.get_sampler(args.sampleMode, args.numBlocks,
                                              args.fixedID)

    elif args.mode == 'FULL':
        #build single train op for the full network
        train_ops.append(disparity_trainer.minimize(full_reconstruction_loss))

    if args.summary:
        #add summaries
        tf.summary.scalar('EPE', abs_err)
        tf.summary.scalar('bad3', bad_pixel_perc)
        tf.summary.image('full_res_disp',
                         preprocessing.colorize_img(full_res_disp, cmap='jet'),
                         max_outputs=1)
        tf.summary.image('gt_disp',
                         preprocessing.colorize_img(gt_image_batch,
                                                    cmap='jet'),
                         max_outputs=1)

        #create summary logger
        summary_op = tf.summary.merge_all()
        logger = tf.summary.FileWriter(args.output)

    save_disparities = args.logDispStep != -1
    if save_disparities:
        # cv2.imwrite does not create missing folders, so make sure the
        # destination exists before the first disparity is written
        disparity_dir = os.path.join(args.output, 'disparities')
        if not os.path.isdir(disparity_dir):
            os.makedirs(disparity_dir)

    #start session
    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        #init stuff
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        #start queue runners
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess=sess, coord=coord)

        #restore disparity inference weights
        var_to_restore = weights_utils.get_var_to_restore_list(
            args.weights, [])
        assert (len(var_to_restore) > 0)
        restorer = tf.train.Saver(var_list=var_to_restore)
        restorer.restore(sess, args.weights)
        print('Disparity Net Restored?: {}, number of restored variables: {}'.
              format(True, len(var_to_restore)))

        num_actions = len(train_ops)
        if args.mode == 'FULL':
            selected_train_ops = train_ops
        else:
            selected_train_ops = [tf.no_op()]

        epe_accumulator = []
        bad3_accumulator = []
        exec_time = 0
        fetch_counter = [0] * num_actions
        sample_distribution = np.zeros(shape=[num_actions])
        loss_t_2 = 0
        loss_t_1 = 0
        expected_loss = 0
        last_trained_blocks = []
        reset_counter = 0
        step = 0
        max_steps = data_set.get_size()
        try:
            start_time = time.time()
            while True:
                #fetch new network portion to train
                if step % args.sampleFrequency == 0 and args.mode == 'MAD':
                    #Sample
                    distribution = softmax(sample_distribution)
                    blocks_to_train = sampler.sample(distribution)
                    selected_train_ops = [
                        train_ops[i] for i in blocks_to_train
                    ]

                    #accumulate sampling statistics
                    for block_id in blocks_to_train:
                        fetch_counter[block_id] += 1

                #build list of tensorflow operations that needs to be executed

                #errors and full resolution loss; these three positions are
                #fixed because they are read back by index after sess.run
                tf_fetches = [
                    abs_err, bad_pixel_perc, full_reconstruction_loss
                ]

                if args.summary and step % 100 == 0:
                    #summaries (read back as fetches[3])
                    tf_fetches = tf_fetches + [summary_op]

                #update ops
                tf_fetches = tf_fetches + selected_train_ops

                if save_disparities and step % args.logDispStep == 0:
                    #prediction for serialization to disk (read as fetches[-1])
                    tf_fetches = tf_fetches + [full_res_disp]

                #run network
                fetches = sess.run(tf_fetches)
                new_loss = fetches[2]

                if args.mode == 'MAD':
                    #update sampling probabilities
                    if step == 0:
                        loss_t_2 = new_loss
                        loss_t_1 = new_loss
                    # linear extrapolation of the two previous losses; the
                    # gain is how much the new loss beats that expectation
                    expected_loss = 2 * loss_t_1 - loss_t_2
                    gain_loss = expected_loss - new_loss
                    sample_distribution = 0.99 * sample_distribution
                    for i in last_trained_blocks:
                        sample_distribution[i] += 0.01 * gain_loss

                    last_trained_blocks = blocks_to_train
                    loss_t_2 = loss_t_1
                    loss_t_1 = new_loss

                #accumulate performance metrics
                epe_accumulator.append(fetches[0])
                bad3_accumulator.append(fetches[1])

                if step % 100 == 0:
                    #log on terminal
                    fbTime = (time.time() - start_time)
                    exec_time += fbTime
                    fbTime = fbTime / 100
                    if args.summary:
                        logger.add_summary(fetches[3], global_step=step)
                    missing_time = (max_steps - step) * fbTime
                    print(
                        'Step:{:4d}\tbad3:{:.2f}\tEPE:{:.2f}\tSSIM:{:.2f}\tf/b time:{:3f}\tMissing time:{}'
                        .format(step, fetches[1], fetches[0], new_loss, fbTime,
                                datetime.timedelta(seconds=missing_time)))
                    start_time = time.time()

                #reset network if necessary
                if new_loss > args.SSIMTh:
                    restorer.restore(sess, args.weights)
                    reset_counter += 1

                #save disparity if requested
                if save_disparities and step % args.logDispStep == 0:
                    dispy = fetches[-1]
                    # clip BEFORE casting: casting first lets negative or
                    # very large values wrap around in uint16
                    # NOTE(review): the * 256 below overflows uint16 unless
                    # MAX_DISP < 256 -- confirm the constant
                    dispy_to_save = np.clip(dispy[0], 0,
                                            MAX_DISP).astype(np.uint16)
                    cv2.imwrite(
                        os.path.join(
                            args.output,
                            'disparities/disparity_{}.png'.format(step)),
                        dispy_to_save * 256)

                step += 1

        except Exception as e:
            # best effort: the loop has no other exit, so any error
            # (presumably including the reader running out of data) ends the
            # adaptation and is only reported
            print('Exception catched {}'.format(e))
        finally:
            try:
                epe_array = epe_accumulator
                bad3_array = bad3_accumulator
                epe_total = np.sum(epe_array)
                bad3_total = np.sum(bad3_array)

                # when no step completed, averages are reported as nan
                # instead of raising ZeroDivisionError
                divisor = step if step > 0 else float('nan')
                step_time = exec_time / divisor
                fps = 1 / step_time if step_time else float('nan')

                with open(os.path.join(args.output, 'stats.csv'),
                          'w+') as f_out:
                    # report series
                    f_out.write('Metrics,cumulative,average\n')
                    f_out.write('EPE,{},{}\n'.format(epe_total,
                                                     epe_total / divisor))
                    f_out.write('bad3,{},{}\n'.format(bad3_total,
                                                      bad3_total / divisor))
                    f_out.write('time,{},{}\n'.format(exec_time, step_time))
                    f_out.write('FPS,{}\n'.format(fps))
                    f_out.write('#resets,{}\n'.format(reset_counter))
                    f_out.write('Blocks')
                    for n in range(len(predictions)):
                        f_out.write(',{}'.format(n))
                    f_out.write(',final\n')
                    f_out.write('fetch_counter')
                    for c in fetch_counter:
                        f_out.write(',{}'.format(c))
                    f_out.write('\n')
                    # this row has no label: it starts with an empty cell
                    for c in sample_distribution:
                        f_out.write(',{}'.format(c))
                    f_out.write('\n')

                # timestamps are estimated from the average step time
                time_array = [
                    str(x * step_time) for x in range(len(epe_array))
                ]

                with open(os.path.join(args.output, 'series.csv'),
                          'w+') as f_out:
                    f_out.write('Iteration,Time,EPE,bad3\n')
                    for i, (t, e,
                            b) in enumerate(zip(time_array, epe_array,
                                                bad3_array)):
                        f_out.write('{},{},{},{}\n'.format(i, t, e, b))

                if args.summary:
                    # flush pending summaries to disk
                    logger.close()

                print('All Done, Bye Bye!')
            finally:
                # always stop the queue runner threads, even if writing the
                # reports failed
                coord.request_stop()
                coord.join()