def cb(self, src, rgb, d, caminfo, *more):
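        """Synchronized callback for one detection frame.

        `src` carries the detections, `rgb` and `d` the color and depth
        images, `caminfo` the camera intrinsics, and `more` optionally holds
        a 3D tracks message. Predicts a head orientation per detection and
        publishes on each publisher only if it has subscribers.
        """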
        # Ugly workaround because approximate sync sometimes jumps back in time.
        if rgb.header.stamp <= self.last_stamp:
            rospy.logwarn("Jump back in time detected and dropped like it's hot")
            return

        self.last_stamp = rgb.header.stamp

        detrects = get_rects(src)

        # Early-exit to minimize CPU usage if possible.
        #if len(detrects) == 0:
        #    return

        # If nobody's listening, why should we be computing?
        if 0 == sum(p.get_num_connections() for p in (self.pub, self.pub_vis, self.pub_pa, self.pub_tracks)):
            return

        header = rgb.header
        bridge = CvBridge()
        rgb = bridge.imgmsg_to_cv2(rgb)[:,:,::-1]  # Need to do BGR-RGB conversion manually.
        d = bridge.imgmsg_to_cv2(d)
        imgs = []
        for detrect in detrects:
            detrect = self.getrect(*detrect)
            det_rgb = cutout(rgb, *detrect)
            det_d = cutout(d, *detrect)

            # Preprocess and stick into the minibatch.
            im = subtractbg(det_rgb, det_d, 1.0, 0.5)
            im = self.preproc(im)
            imgs.append(im)
            # Lightweight progress indicator on stderr.
            sys.stderr.write("\r{}".format(self.counter))
            sys.stderr.flush()
            self.counter += 1

        # TODO: We could further optimize by putting all augmentations in a
        #       single batch and doing only one forward pass. Should be easy.
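        # A minimal sketch of that idea (assuming `self.net.forward` takes any
        # batch size and `augbatch_pred` yields equally-sized batches):
        #
        #     batches = list(self.aug.augbatch_pred(np.array(imgs), fast=True))
        #     allbits = self.net.forward(np.concatenate(batches))
        #     bits = np.split(allbits, len(batches))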
        if len(detrects):
            bits = [self.net.forward(batch) for batch in self.aug.augbatch_pred(np.array(imgs), fast=True)]
            preds = bit2deg(ensemble_biternions(bits)) - 90  # Subtract 90 to correct for "my weird" origin.
            # print(preds)
        else:
            preds = []

        if 0 < self.pub.get_num_connections():
            self.pub.publish(HeadOrientations(
                header=header,
                angles=list(preds),
                confidences=[0.83] * len(imgs)
            ))

        # Visualization
        if 0 < self.pub_vis.get_num_connections():
            rgb_vis = rgb[:,:,::-1].copy()
            for detrect, alpha in zip(detrects, preds):
                l, t, w, h = self.getrect(*detrect)
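                # Arrow endpoint for the predicted angle; y is negated because
                # image coordinates grow downwards.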
                px =  int(round(np.cos(np.deg2rad(alpha))*w/2))
                py = -int(round(np.sin(np.deg2rad(alpha))*h/2))
                cv2.rectangle(rgb_vis, (detrect[0], detrect[1]), (detrect[0]+detrect[2],detrect[1]+detrect[3]), (0,255,255), 1)
                cv2.rectangle(rgb_vis, (l,t), (l+w,t+h), (0,255,0), 2)
                cv2.line(rgb_vis, (l+w//2, t+h//2), (l+w//2+px,t+h//2+py), (0,255,0), 2)
                # cv2.putText(rgb_vis, "{:.1f}".format(alpha), (l, t+25), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,0,255), 2)
            vismsg = bridge.cv2_to_imgmsg(rgb_vis, encoding='bgr8')  # rgb_vis was flipped back to BGR above.
            vismsg.header = header  # TODO: Seems not to work!
            self.pub_vis.publish(vismsg)

        if 0 < self.pub_pa.get_num_connections():
            fx, cx = caminfo.K[0], caminfo.K[2]
            fy, cy = caminfo.K[4], caminfo.K[5]

            poseArray = PoseArray(header=header)

            for (dx, dy, dw, dh, dd), alpha in zip(get_rects(src, with_depth=True), preds):
                dx, dy, dw, dh = self.getrect(dx, dy, dw, dh)

                # PoseArray message for the bounding-box centres.
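                # Back-projection through the pinhole model:
                # X = Z*(u - cx)/fx, Y = Z*(v - cy)/fy, with Z = dd, the detection's depth.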
                poseArray.poses.append(Pose(
                    position=Point(
                        x=dd*((dx+dw/2.0-cx)/fx),
                        y=dd*((dy+dh/2.0-cy)/fy),
                        z=dd
                    ),
                    # TODO: Use global UP vector (0,0,1) and transform into frame used by this message.
                    orientation=Quaternion(*quaternion_about_axis(np.deg2rad(alpha), [0, -1, 0]))
                ))

            self.pub_pa.publish(poseArray)

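        # If a 3D tracks message came along, attach the predicted head
        # orientations to the tracks and re-publish them.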
        if len(more) == 1 and 0 < self.pub_tracks.get_num_connections():
            t3d = more[0]
            try:
                self.listener.waitForTransform(header.frame_id, t3d.header.frame_id, rospy.Time(), rospy.Duration(1))
                for track, alpha in zip(t3d.tracks, preds):
                    track.pose.pose.orientation = self.listener.transformQuaternion(t3d.header.frame_id, QuaternionStamped(
                        header=header,
                        # TODO: Same as above!
                        quaternion=Quaternion(*quaternion_about_axis(np.deg2rad(alpha), [0, -1, 0]))
                    )).quaternion
                self.pub_tracks.publish(t3d)
            except TFException:
                pass
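
  # Training part of the script (separate from the ROS node above); this
  # assumes netlib, aug, the data splits (Xtr, ytr, Xte, yte, nte) and the
  # parsed args have been set up earlier.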
  net = netlib.mknet()
  printnow('Network has {:.3f}M params in {} layers\n', df.utils.count_params(net)/1000.0/1000.0, len(net.modules))

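  # Quick sanity check: print the feature-map shape after the first 21 modules
  # for one augmented training batch.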
  print(net[:21].forward(aug.augbatch_train(Xtr[:100])[0]).shape)

  costs = dotrain(net, crit, aug, Xtr, ytr, nepochs=args.epochs)
  print("Costs: {}".format(' ; '.join(map(str, costs))))

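  # Collect the network's post-training statistics on the training set
  # (presumably for the batch-normalization layers).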
  dostats(net, aug, Xtr, batchsize=64)

  # Save the network.
  printnow("Saving the learned network to {}\n", args.output)
  np.save(args.output, net.__getstate__())

  # Prediction, TODO: Move to ROS node.
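  # Sort the test set (by nte, presumably the sample names) so each image and
  # its mirrored twin end up adjacent; the even/odd ensembling below relies on this.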
  s = np.argsort(nte)
  Xte,yte = Xte[s],yte[s]

  printnow("(TEMP) Doing predictions.\n", args.output)
  y_pred = dopred_bit(net, aug, Xte, batchsize=64)

  # Baseline: MAE of the raw predictions on the test set.
  #res = maad_from_deg(bit2deg(yte), bit2deg(yte))  # Sanity check; should be exactly zero.
  res = maad_from_deg(bit2deg(y_pred), bit2deg(yte))
  printnow("MAE for test images              = {:.2f}\n", res.mean())

  # Ensemble the flips: average each prediction with that of its mirrored twin.
  #y_pred2 = ensemble_biternions([yte[::2], flipbiternions(yte[1::2])])  # Sanity check on the labels.
  y_pred2 = ensemble_biternions([y_pred[::2], flipbiternions(y_pred[1::2])])
  res = maad_from_deg(bit2deg(y_pred2), bit2deg(yte[::2]))
  printnow("MAE for flipped augmented images = {:.2f}\n", res.mean())