Example #1
def sanity_check(model, test_loader, use_gpu=True):
    """
    Run the model over the test set and compare the angular error of the
    raw input estimates ("Original") against the model's predictions
    ("Final"), printing summary statistics for both.
    """
    device = get_device(use_gpu)
    model.to(device)
    model.eval()

    original_loss_list = []
    final_loss_list = []
    with torch.no_grad():  # no gradients needed for evaluation
        for features, labels in test_loader:
            # move tensors to GPU if CUDA is available
            features = features.to(device).float()
            labels = labels.to(device).float()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(features)
            # tensors must come back to the CPU before converting to numpy
            labels_np = np.transpose(labels.cpu().numpy())
            original_loss = [
                utils.angular_error(
                    features[0, index * 3:index * 3 + 3].cpu().numpy(),
                    labels_np)[0]
                for index in range(features.shape[1] // 3)
            ]
            final_loss = utils.angular_error(output.cpu().numpy(),
                                             labels_np)[0][0]

            original_loss_list.append(original_loss[-1].item())
            final_loss_list.append(final_loss)

    print("Original:")
    utils.print_stats(original_loss_list)
    print("Final:")
    utils.print_stats(final_loss_list)
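All of the examples on this page call some project-specific angular_error helper. The exact contract varies (some return radians that the caller converts with math.degrees, others return degrees or per-sample arrays), but a minimal vectorized sketch of the usual color-constancy definition, assuming numpy inputs and radians as output, looks like:

import numpy as np

def angular_error(estimate, ground_truth):
    # Angle in radians between two RGB vectors (row-wise for batches).
    estimate = np.asarray(estimate, dtype=np.float64)
    ground_truth = np.asarray(ground_truth, dtype=np.float64)
    cos = np.sum(estimate * ground_truth, axis=-1) / (
        np.linalg.norm(estimate, axis=-1) *
        np.linalg.norm(ground_truth, axis=-1) + 1e-12)
    # clip guards against floating-point values just outside [-1, 1]
    return np.arccos(np.clip(cos, -1.0, 1.0))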
Example #2
File: fcn.py Project: zyqgmzyq/fc4
    def test_naive(self):
        t = time.time()

        import scipy.io
        std = scipy.io.loadmat('/home/yuanming/colorchecker_shi_greyworld.mat')
        # materialize the list: a lazy map iterator would break names.index() below
        names = [x[0].encode('utf8') for x in std['all_image_names'][0]]
        records = load_data(TEST_FOLDS)

        errors = []
        for r in records:
            # naive grey-world estimate: mean image color, channel order reversed
            est = np.mean(r.img, axis=(0, 1))[::-1]
            est /= np.linalg.norm(est)
            gt = std['groundtruth_illuminants'][names.index(r.fn[:-4])]
            error = math.degrees(angular_error(est, gt))
            errors.append(error)

        print("Full Image:")
        print_angular_errors(errors)
        elapsed = time.time() - t
        print('Test time:', elapsed, 'per image:', elapsed / len(records))

        return errors
Example #3
File: fcn.py Project: zyqgmzyq/fc4
    def test_patch_based(self, scale, patches, pooling='median'):
        records = load_data(TEST_FOLDS)
        avg_errors = []
        median_errors = []
        t = time.time()

        def sample_patch(img):
            s = FCN_INPUT_SIZE
            x = random.randrange(0, img.shape[0] - s + 1)
            y = random.randrange(0, img.shape[1] - s + 1)
            return img[x:x + s, y:y + s]

        for r in records:
            img = cv2.resize(r.img, (0, 0), fx=scale, fy=scale)
            # avoid shadowing the resized image with the list of patches
            patch_imgs = [sample_patch(img) for _ in range(patches)]
            illum_est = []
            batch_size = 4
            for j in range((len(patch_imgs) + batch_size - 1) // batch_size):
                illum_est.append(
                    self.sess.run(
                        self.illum_normalized,
                        feed_dict={
                            self.images:
                            patch_imgs[j * batch_size:(j + 1) * batch_size],
                            self.dropout: 1.0
                        }))
            illum_est = np.vstack(illum_est)
            med = len(illum_est) // 2
            illum_est_median = np.array(
                [sorted(list(illum_est[:, i]))[med] for i in range(3)])
            illum_est_avg = np.mean(illum_est, axis=0)
            avg_error = math.degrees(angular_error(illum_est_avg, r.illum))
            median_error = math.degrees(
                angular_error(illum_est_median, r.illum))
            avg_errors.append(avg_error)
            median_errors.append(median_error)
        print("Avg pooling:")
        print_angular_errors(avg_errors)
        print("Median pooling:")
        print_angular_errors(median_errors)
        ppt = (time.time() - t) / len(records)
        print('Test time:', time.time() - t, 'per image:', ppt)
        if pooling == 'median':
            errors = median_errors
        else:
            errors = avg_errors
        return errors, ppt
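For an odd number of patches, the hand-rolled per-channel median above matches numpy's built-in (np.median averages the two middle values when the count is even, whereas the code above takes the upper one):

illum_est_median = np.median(illum_est, axis=0)  # per-channel median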
Example #4
File: fcn.py Project: zjudzl/fc4
  def test_resize(self):
    records = load_data(TEST_FOLDS)
    t = time.time()

    errors = []
    for r in records:
      img = cv2.resize(r.img, (FCN_INPUT_SIZE, FCN_INPUT_SIZE))
      illum_est = self.sess.run(
          self.illum_normalized,
          feed_dict={self.images: [img],
                     self.dropout: 1.0})
      error = math.degrees(angular_error(illum_est, r.illum))
      errors.append(error)
    print_angular_errors(errors)
    ppt = (time.time() - t) / len(records)
    print('Test time:', time.time() - t, 'per image:', ppt)
    return errors, ppt
Example #5
def get_votes(images_path, correct_illums, debug=False):
    """
    Estimate each image's illuminant with three classic voters
    (grey-world, max-RGB, grey-edge) and compute the angular error of
    every estimate against the supplied ground-truth illuminants.
    Returns the per-voter estimates and errors.
    """

    # constants
    GREY_WORLD = 'grey_world'
    MAX_RGB = 'max_rgb'
    GREY_EDGE = 'grey_edge'

    # keys lists
    keys = [GREY_WORLD, MAX_RGB, GREY_EDGE]
    # voters lists
    voters = {
        GREY_WORLD: lambda x: grey_edge(x, njet=0, mink_norm=1, sigma=0),
        MAX_RGB: lambda x: grey_edge(x, njet=0, mink_norm=-1, sigma=0),
        GREY_EDGE: lambda x: grey_edge(x, njet=1, mink_norm=5, sigma=2)
    }
    # illum lists
    illums = {GREY_WORLD: [], MAX_RGB: [], GREY_EDGE: []}
    # error lists
    errors = {GREY_WORLD: [], MAX_RGB: [], GREY_EDGE: []}
    # estimate illuminations and calculate error
    for index, (image_path,
                correct_illum) in enumerate(zip(images_path, correct_illums)):
        image = skio.imread(image_path)
        if debug:
            print(image_path)
            print('illumination: ' + str(correct_illum))

        for key in keys:
            estim_illum = voters[key](image)
            illums[key].append(estim_illum)
            errors[key].append(utils.angular_error(estim_illum, correct_illum))
            if debug:
                print(key + ": " + str(estim_illum) + " error: " +
                      str(errors[key][-1]))
        if debug:
            print()
        else:
            utils.clear_screen()
            print(str(index + 1) + ' / ' + str(len(images_path)))

    for key in keys:
        print(key + ":")
        utils.print_stats(errors[key])

    return illums, errors
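Since grey_edge follows the standard (njet, mink_norm, sigma) parameterization of the Grey-Edge framework, further classic estimators drop in as one-liners. For instance, a hypothetical Shades-of-Grey voter (zeroth derivative, Minkowski norm 6), not part of the original code, could be registered inside get_votes like this:

SHADES_OF_GREY = 'shades_of_grey'  # hypothetical extra voter
keys.append(SHADES_OF_GREY)
voters[SHADES_OF_GREY] = lambda x: grey_edge(x, njet=0, mink_norm=6, sigma=0)
illums[SHADES_OF_GREY] = []
errors[SHADES_OF_GREY] = []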
Example #6
            global_estimate = local_estimates_aggregation(
                local_estimates, confidences)
        else:
            global_estimate = local_estimates_aggregation_naive(
                local_estimates)

        end_time = timer()
        inference_times.append(end_time - start_time)

        local_rgb_estimates = 1. / local_estimates  # convert gains into rgb triplet
        local_rgb_estimates /= local_rgb_estimates.sum(axis=1, keepdims=True)
        global_rgb_estimate = 1. / global_estimate  # convert gain into rgb triplet
        global_rgb_estimate /= global_rgb_estimate.sum()

        if ground_truth_mode:
            local_angular_errors = angular_error(ground_truth,
                                                 local_rgb_estimates)
            global_angular_error = angular_error(ground_truth,
                                                 global_rgb_estimate)
            angular_errors_statistics.append(global_angular_error)
        else:
            local_angular_errors = global_angular_error = None

        # Save the white balanced image
        if args.save in [1, 2, 3, 4]:
            img = read_image(img_path=img_path,
                             input_bits=INPUT_BITS,
                             valid_bits=VALID_BITS,
                             darkness=DARKNESS,
                             gamma=GAMMA)
            wb_imgs_path = os.path.join(img_dir, 'white_balanced_images')
            if not os.path.exists(wb_imgs_path):
                os.makedirs(wb_imgs_path)
Example #7
def test(FLAGS):
    batch_size = FLAGS.batch_size
    height = width = FLAGS.patch_size
    final_W = FLAGS.final_W
    final_K = FLAGS.final_K
    dataset_dir = os.path.join(FLAGS.dataset_dir)
    dataset_file_name = FLAGS.dataset_file_name
    if FLAGS.use_ms:
        input_image, gt_image = data_provider.load_batch(
            dataset_dir, dataset_file_name, batch_size, height, width,
            channel=final_W, shuffle=False, use_ms=True, is_train=False)
    else:
        input_image, gt_image, label, file_name = data_provider.load_batch(
            dataset_dir, dataset_file_name, batch_size, height, width,
            channel=final_W, shuffle=False, use_ms=False,
            with_file_name_gain=True, is_train=False)

    with tf.variable_scope('generator'):
        if FLAGS.patch_size == 128:
            N_size = 3
        else:
            N_size = 2
        filters = net.convolve_net(input_image, final_K, final_W, ch0=64,
                                   N=N_size, D=3, scope='get_filted',
                                   separable=False, bonus=False)
    predict_image = net.convolve(input_image, filters, final_K, final_W)

    # summaries
    # filters_sum = tf.summary.image('filters', filters)
    # input_image_sum = tf.summary.image('input_image', input_image)
    # gt_image_sum = tf.summary.image('gt_image', gt_image)
    # predict_image_sum = tf.summary.image('predict_image', predict_image)

    sum_total = tf.summary.merge_all()

    config = tf.ConfigProto()
    with tf.Session(config=config) as sess:

        print('Initializing variables')
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        if FLAGS.write_sum:
            writer = tf.summary.FileWriter(FLAGS.save_dir, sess.graph)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        restorer = tf.train.Saver(max_to_keep=None)

        ckpt_path = tf.train.latest_checkpoint(FLAGS.ckpt_path)
        if ckpt_path is not None:
            print('Restoring from', ckpt_path)
            restorer.restore(sess, ckpt_path)

        errors = []

        max_steps = FLAGS.total_test_num // batch_size
        for i_step in range(max_steps):
            if FLAGS.use_ms:
                input_image_, gt_image_, predict_image_, filters_, sum_total_ = \
                    sess.run([input_image, gt_image, predict_image, filters, sum_total])
            else:
                input_image_, gt_image_, predict_image_, filters_, label_, file_name_ , sum_total_ = \
                    sess.run([input_image, gt_image, predict_image, filters, label, file_name, sum_total])

            batch_confidence_r = utils.compute_rate_confidence(
                filters_, input_image_, final_K, final_W, sel_ch=0, ref_ch=[2])
            batch_confidence_b = utils.compute_rate_confidence(
                filters_, input_image_, final_K, final_W, sel_ch=2, ref_ch=[0])

            concat = utils.get_concat(input_image_, gt_image_, predict_image_)
            for batch_i in range(batch_size):
                est = utils.solve_gain(input_image_[batch_i], np.clip(predict_image_[batch_i], 0, 500))
                print('confidence_r:', batch_confidence_r[batch_i])
                print('confidence_b:', batch_confidence_b[batch_i])

                if FLAGS.use_ms:
                    save_file_name = '%03d_%02d.png' % (i_step, batch_i)
                else:
                    current_file_name = file_name_[batch_i][0].decode('utf-8').split('/')[-1]
                    print('{} saved once'.format(current_file_name))
                    gt = label_[batch_i]
                    error = utils.angular_error(est, gt)
                    print('est:', est)
                    print('gt:', gt)
                    print('error:', error)
                    errors.append(error)
                    save_file_name = current_file_name

                est_img_ = np.clip(input_image_[batch_i] * est, 0, 255.0) / 255.0
                all_concat = np.concatenate([concat[batch_i], est_img_], axis=1)
                if FLAGS.save_dir is not None:
                    imsave(os.path.join(FLAGS.save_dir, save_file_name), all_concat * 255.0)

                # np.save(os.path.join(FLAGS.save_dir,'%03d_%02d.npy'%(i_step,batch_i)), predict_image_[batch_i])

            if FLAGS.write_sum and i_step % 20 == 0:
                writer.add_summary(sum_total_, i_step)  # i_step, not the undefined name i
                print('summary saved')

        coord.request_stop()
        coord.join(threads)
    if errors:
        utils.print_angular_errors(errors)
Example #8
def inference(model_level, model_dir, test_img_IDs):
    confidence_estimation_mode = False
    model = model_builder(level=model_level,
                          confidence=False,
                          input_shape=(*PATCH_SIZE, 3))
    model.load_weights(model_dir)
    ground_truth_dict = get_ground_truth_dict(
        r'train\RECommended\ground-truth.txt')
    masks_dict = get_masks_dict(r'train\RECommended\masks.txt')
    angular_errors_statistics = []
    for (counter, test_img_ID) in enumerate(test_img_IDs):
        print('Processing {}/{} images...'.format(counter + 1,
                                                  len(test_img_IDs)),
              end='\r')
        # data generator
        batch, boxes, remained_boxes_indices, ground_truth = img2batch(
            test_img_ID,
            patch_size=PATCH_SIZE,
            input_bits=INPUT_BITS,
            valid_bits=VALID_BITS,
            darkness=DARKNESS,
            ground_truth_dict=ground_truth_dict,
            masks_dict=masks_dict,
            gamma=GAMMA)
        nb_batch = int(np.ceil(PATCHES / BATCH_SIZE))
        batch_size = int(PATCHES / nb_batch)  # actual batch size
        local_estimates, confidences = np.empty(shape=(0, 3)), np.empty(
            shape=(0, ))

        # use batch(es) to feed into the network
        for b in range(nb_batch):
            batch_start_index, batch_end_index = b * batch_size, (
                b + 1) * batch_size
            batch_tmp = batch[batch_start_index:batch_end_index]
            if confidence_estimation_mode:
                # the model requires 2 inputs when confidence estimation mode is activated
                batch_tmp = [batch_tmp, np.zeros((batch_size, 3))]
            outputs = model.predict(batch_tmp)  # model inference
            if confidence_estimation_mode:
                # the model produces 6 outputs when confidence estimation mode is on. See model.py for more details
                # local_estimates is the gain instead of illuminant color!
                local_estimates = np.vstack((local_estimates, outputs[4]))
                confidences = np.hstack((confidences, outputs[5].squeeze()))
            else:
                # local_estimates is the gain instead of illuminant color!
                local_estimates = np.vstack((local_estimates, outputs))
                confidences = None

        if confidence_estimation_mode:
            global_estimate = local_estimates_aggregation(
                local_estimates, confidences)
        else:
            global_estimate = local_estimates_aggregation_naive(
                local_estimates)

        global_rgb_estimate = 1. / global_estimate  # convert gain into rgb triplet

        global_angular_error = angular_error(ground_truth, global_rgb_estimate)
        angular_errors_statistics.append(global_angular_error)

    return np.array(angular_errors_statistics)
Example #9
def test(FLAGS):
    batch_size = FLAGS.batch_size
    height = width = FLAGS.patch_size
    final_W = FLAGS.final_W
    final_K = FLAGS.final_K
    dataset_dir = os.path.join(FLAGS.dataset_dir)
    dataset_file_name = FLAGS.dataset_file_name
    shuffle = FLAGS.shuffle
    input_image = tf.placeholder(tf.float32, shape=(None, height, width, 3))

    with tf.variable_scope('generator'):
        if FLAGS.patch_size == 128:
            N_size = 3
        else:
            N_size = 2
        filters = net.convolve_net(input_image,
                                   final_K,
                                   final_W,
                                   ch0=64,
                                   N=N_size,
                                   D=3,
                                   scope='get_filted',
                                   separable=False,
                                   bonus=False)
    predict_image = convolve(input_image, filters, final_K, final_W)

    config = tf.ConfigProto()
    with tf.Session(config=config) as sess:

        print('Initializing variables')
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        restorer = tf.train.Saver(max_to_keep=None)

        ckpt_path = tf.train.latest_checkpoint(FLAGS.ckpt_path)
        if ckpt_path is not None:
            print('Restoring from', ckpt_path)
            restorer.restore(sess, ckpt_path)

        errors = []

        max_steps = FLAGS.total_test_num // batch_size
        for i_step in range(max_steps):
            if FLAGS.use_ms:
                imgs, imgs_gt, labels, file_names, configs = utils.data_loader_np(
                    data_folder=dataset_dir,
                    data_txt=dataset_file_name,
                    patch_size=FLAGS.patch_size,
                    start_index=i_step * batch_size,
                    batch_size=batch_size,
                    use_ms=True)
            else:
                imgs, imgs_gt, labels, file_names = utils.data_loader_np(
                    data_folder=dataset_dir,
                    data_txt=dataset_file_name,
                    patch_size=FLAGS.patch_size,
                    start_index=i_step * batch_size,
                    batch_size=batch_size,
                    use_ms=False)
            input_image_ = utils.batch_stable_process(
                imgs,
                use_crop=FLAGS.use_crop,
                use_clip=FLAGS.use_clip,
                use_flip=FLAGS.use_flip,
                use_rotate=FLAGS.use_rotate,
                use_noise=FLAGS.use_noise)
            gt_image_ = imgs_gt
            predict_image_, filters_ = sess.run(
                [predict_image, filters],
                feed_dict={input_image: input_image_})
            # [batch, h ,w]
            batch_confidence_r = utils.compute_rate_confidence(filters_,
                                                               input_image_,
                                                               final_K,
                                                               final_W,
                                                               sel_ch=0,
                                                               ref_ch=[2],
                                                               is_spatial=True)
            batch_confidence_b = utils.compute_rate_confidence(filters_,
                                                               input_image_,
                                                               final_K,
                                                               final_W,
                                                               sel_ch=2,
                                                               ref_ch=[0],
                                                               is_spatial=True)

            concat = utils.get_concat(input_image_, gt_image_, predict_image_)
            num_filt = (FLAGS.final_K**2) * (FLAGS.final_W**2)
            for batch_i in range(batch_size):
                est_global = utils.solve_gain(
                    input_image_[batch_i],
                    np.clip(predict_image_[batch_i], 0, 500))

                if FLAGS.use_ms:
                    save_file_name = '%s_%s.png' % (
                        file_names[batch_i][0][:-4],
                        file_names[batch_i][1][:-4])
                else:
                    print('confidence_r:',
                          np.mean(batch_confidence_r[batch_i]))
                    print('confidence_b:',
                          np.mean(batch_confidence_b[batch_i]))
                    current_file_name = file_names[batch_i]
                    print('{} saved once'.format(current_file_name))
                    gt = labels[batch_i]
                    # est_global, not the undefined name est
                    error = utils.angular_error(est_global, gt)
                    print('est:', est_global)
                    print('gt:', gt)
                    print('error:', error)
                    errors.append(error)
                    save_file_name = current_file_name

                est_global_img_ = np.clip(input_image_[batch_i] * est_global,
                                          0, 255.0) / 255.0
                all_concat = np.concatenate([concat[batch_i], est_global_img_],
                                            axis=1)
                if FLAGS.save_dir is not None:
                    imsave(os.path.join(FLAGS.save_dir, save_file_name),
                           all_concat * 255.0)
                    np_concat = np.concatenate(
                        [input_image_[batch_i], predict_image_[batch_i]],
                        axis=0)
                    file_name_np = os.path.join(FLAGS.save_dir,
                                                save_file_name[:-3] + 'npy')
                    np.save(file_name_np, np_concat)
                    if FLAGS.use_ms:
                        if FLAGS.save_clus:
                            print('local gain fitting', save_file_name)
                            gain_box, clus_img, clus_labels = utils.gain_fitting(
                                input_image_[batch_i],
                                predict_image_[batch_i],
                                is_local=True,
                                n_clusters=2,
                                gamma=4.0,
                                with_clus=True)
                            num_multi = len(set(clus_labels))
                            for index_ill in range(num_multi):
                                confi_multi_r = utils.get_confi_multi(
                                    clus_labels,
                                    batch_confidence_r[batch_i],
                                    label=index_ill)
                                confi_multi_b = utils.get_confi_multi(
                                    clus_labels,
                                    batch_confidence_b[batch_i],
                                    label=index_ill)
                                print('confidence_r for ill %d' % index_ill,
                                      confi_multi_r)
                                print('confidence_b for ill %d' % index_ill,
                                      confi_multi_b)
                            imsave(
                                os.path.join(
                                    FLAGS.save_dir,
                                    '%s_clus.png' % (save_file_name[:-4])),
                                clus_img)
                        if FLAGS.save_filt:
                            cur_filt = filters_[batch_i]
                            for filt_index in range(num_filt):
                                cur_ = cur_filt[..., filt_index]
                                imsave(
                                    os.path.join(
                                        FLAGS.save_dir, '%s_filt_%d.png' %
                                        (save_file_name[:-4], filt_index)),
                                    cur_)
                        file_name_json = os.path.join(
                            FLAGS.save_dir, save_file_name[:-3] + 'json')
                        save_dict = configs[batch_i]
                        with open(file_name_json, 'w') as fp:
                            json.dump(save_dict, fp, ensure_ascii=False)
                # np.save(os.path.join(FLAGS.save_dir,'%03d_%02d.npy'%(i_step,batch_i)), predict_image_[batch_i])
    if errors:
        utils.print_angular_errors(errors)
Example #10
File: fcn.py Project: zyqgmzyq/fc4
    def test(self,
             summary=False,
             scales=None,
             weights=None,
             summary_key=0,
             data=None,
             eval_speed=False,
             visualize=False):
        # avoid mutable default arguments
        if scales is None:
            scales = [1.0]
        if not TEST_FOLDS:
            return [0]
        if data is None:
            records = load_data(TEST_FOLDS)
        else:
            records = data
        avg_errors = []
        median_errors = []
        t = time.time()

        summaries = []
        if not weights:
            weights = [1.0] * len(scales)

        outputs = []
        ground_truth = []
        avg_confidence = []

        errors = []
        for r in records:
            all_pixels = []
            for scale, weight in zip(scales, weights):
                img = r.img
                if scale != 1.0:
                    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
                shape = img.shape[:2]
                if shape not in self.test_nets:
                    aspect_ratio = 1.0 * shape[1] / shape[0]
                    if aspect_ratio < 1:
                        target_shape = (MERGED_IMAGE_SIZE,
                                        MERGED_IMAGE_SIZE * aspect_ratio)
                    else:
                        target_shape = (MERGED_IMAGE_SIZE / aspect_ratio,
                                        MERGED_IMAGE_SIZE)
                    target_shape = tuple(map(int, target_shape))

                    test_net = {}
                    test_net['illums'] = tf.placeholder(tf.float32,
                                                        shape=(None, 3),
                                                        name='test_illums')
                    test_net['images'] = tf.placeholder(tf.float32,
                                                        shape=(None, shape[0],
                                                               shape[1], 3),
                                                        name='test_images')
                    with tf.variable_scope("FCN", reuse=True):
                        test_net['pixels'] = FCN.build_branches(
                            test_net['images'], 1.0)
                        test_net['est'] = tf.reduce_sum(test_net['pixels'],
                                                        axis=(1, 2))
                    test_net['merged'] = get_visualization(
                        test_net['images'], test_net['pixels'],
                        test_net['est'], test_net['illums'], target_shape)
                    self.test_nets[shape] = test_net
                test_net = self.test_nets[shape]

                pixels, est, merged = self.sess.run(
                    [test_net['pixels'], test_net['est'], test_net['merged']],
                    feed_dict={
                        test_net['images']: img[None, :, :, :],
                        test_net['illums']: r.illum[None, :]
                    })

                if eval_speed:
                    eval_batch_size = 1
                    eval_packed_input = img[None, :, :, :].copy()
                    eval_packed_input = np.concatenate(
                        [eval_packed_input for i in range(eval_batch_size)],
                        axis=0)
                    eval_packed_input = np.ascontiguousarray(eval_packed_input)
                    eval_start_t = time.time()
                    print(eval_packed_input.shape)
                    eval_rounds = 100
                    images_variable = tf.Variable(
                        tf.random_normal(eval_packed_input.shape,
                                         dtype=tf.float32,
                                         stddev=1e-1))
                    print(images_variable)
                    for eval_t in range(eval_rounds):
                        print(eval_t)
                        pixels, est = self.sess.run(
                            [test_net['pixels'], test_net['est']],
                            feed_dict={
                                test_net['images']:  #images_variable,
                                eval_packed_input,
                            })
                    eval_elapsed_t = time.time() - eval_start_t
                    print('per image evaluation time',
                          eval_elapsed_t / (eval_rounds * eval_batch_size))

                pixels = pixels[0]
                #est = est[0]
                merged = merged[0]

                all_pixels.append(weight * pixels.reshape(-1, 3))

            all_pixels = np.sum(np.concatenate(all_pixels, axis=0), axis=0)
            est = all_pixels / (np.linalg.norm(all_pixels) + 1e-7)
            outputs.append(est)
            ground_truth.append(r.illum)
            error = math.degrees(angular_error(est, r.illum))
            errors.append(error)
            avg_confidence.append(np.mean(np.linalg.norm(all_pixels)))

            summaries.append((r.fn, error, merged))
        print("Full Image:")
        ret = print_angular_errors(errors)
        ppt = (time.time() - t) / len(records)
        print('Test time:', time.time() - t, 'per image:', ppt)

        if summary:
            # note: 'scale' is left over from the last scales-loop iteration
            folder = self.get_ckpt_folder() + '/test%04dsummaries_%4f/' % (
                summary_key, scale)
            os.makedirs(folder, exist_ok=True)
            for fn, error, merged in summaries:
                summary_fn = '%s/%5.3f-%s.png' % (folder, error, fn)
                cv2.imwrite(summary_fn, merged[:, :, ::-1] * 255)

        if visualize:
            for fn, error, merged in summaries:
                cv2.imshow('Testing', merged[:, :, ::-1])
                cv2.waitKey(0)

        return errors, ppt, outputs, ground_truth, ret, avg_confidence
Example #11
    def train_one_epoch(self, epoch, data_loader, is_train=True):
        """
        Train the model for 1 epoch of the training set.
        """
        batch_time = AverageMeter()
        errors = AverageMeter()
        losses_gaze = AverageMeter()

        tic = time.time()
        for i, (input_img, target) in enumerate(data_loader):
            # torch.autograd.Variable is deprecated; plain tensors suffice
            input_var = input_img.float().cuda()
            target_var = target.float().cuda()

            # train gaze net
            pred_gaze = self.model(input_var)

            gaze_error_batch = np.mean(
                angular_error(pred_gaze.cpu().data.numpy(),
                              target_var.cpu().data.numpy()))
            errors.update(gaze_error_batch.item(), input_var.size()[0])

            loss_gaze = F.l1_loss(pred_gaze, target_var)
            self.optimizer.zero_grad()
            loss_gaze.backward()
            self.optimizer.step()
            losses_gaze.update(loss_gaze.item(), input_var.size()[0])

            if i % self.print_freq == 0:
                self.writer.add_scalar('Loss/gaze', losses_gaze.avg,
                                       self.train_iter)

            # report information
            if i % self.print_freq == 0 and i != 0:
                print(
                    '--------------------------------------------------------------------'
                )
                msg = "train error: {:.3f} - loss_gaze: {:.5f}"
                print(msg.format(errors.avg, losses_gaze.avg))

                # measure elapsed time
                print('iteration ', self.train_iter)
                toc = time.time()
                batch_time.update(toc - tic)
                # print('Current batch running time is ', np.round(batch_time.avg / 60.0), ' mins')
                tic = time.time()
                # estimate the finish time
                est_time = (self.epochs - epoch) * (
                    self.num_train / self.batch_size) * batch_time.avg / 60.0
                print('Estimated training time left: ', np.round(est_time),
                      ' mins')

                self.writer.add_scalar('Error/train', errors.avg,
                                       self.train_iter)

                errors.reset()
                losses_gaze.reset()

            self.train_iter = self.train_iter + 1

        toc = time.time()
        batch_time.update(toc - tic)

        print('running time is', batch_time.avg)
        return errors.avg, losses_gaze.avg
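AverageMeter is not shown in this excerpt; judging from the calls above (update(value, n), .avg, .reset()), a plausible minimal implementation is the running-average helper familiar from the PyTorch ImageNet example:

class AverageMeter:
    # Tracks a running sum/count and exposes their average.
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0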
Example #12
            # training phase
            while continue_train:
                b, l, continue_train = get_train_batch()
                logs = model.train_on_batch(b, l)
                train_mse.append(logs[0])
                train_angular_errors.append(logs[1])

            # validation phase
            while continue_val:
                b, l, continue_val = get_val_batch()
                if b.shape[0] > 4:  # only test on images with more than 4 crops
                    estimates = model.predict_on_batch(b)
                    # normalize each estimate by its second (green) channel
                    estimates /= estimates[:, 1][:, np.newaxis]
                    estimates = np.median(estimates, axis=0)
                    val_angular_errors.append(angular_error(l[0, ], estimates))

            mean_val_angular_error_current_epoch = np.mean(val_angular_errors)
            median_val_angular_error_current_epoch = np.median(
                val_angular_errors)
            tri_val_angular_error_current_epoch = (
                np.percentile(val_angular_errors, 25) +
                2 * np.median(val_angular_errors) +
                np.percentile(val_angular_errors, 75)) / 4.
            b25_val_angular_error_current_epoch = percentile_mean(
                np.array(val_angular_errors), 0, 25)
            w25_val_angular_error_current_epoch = percentile_mean(
                np.array(val_angular_errors), 75, 100)
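percentile_mean is likewise not shown; given how it is called (mean of the best 25% and worst 25% of the validation errors), a plausible minimal implementation might be:

import numpy as np

def percentile_mean(values, lower, upper):
    # Mean of the values lying between the lower and upper percentiles.
    lo, hi = np.percentile(values, [lower, upper])
    return values[(values >= lo) & (values <= hi)].mean()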