class Transporter6dAgent(TransporterAgent):
  """6D Transporter variant."""

  def __init__(self, name, task):
    super().__init__(name, task)

    self.attention_model = Attention(
        input_shape=self.input_shape,
        num_rotations=1,
        preprocess=self.preprocess)
    self.transport_model = Transport(
        input_shape=self.input_shape,
        num_rotations=self.num_rotations,
        crop_size=self.crop_size,
        preprocess=self.preprocess,
        per_pixel_loss=False,
        six_dof=False)

    self.rpz_model = Transport(
        input_shape=self.input_shape,
        num_rotations=self.num_rotations,
        crop_size=self.crop_size,
        preprocess=self.preprocess,
        per_pixel_loss=False,
        six_dof=True)
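    # (`rpz_model` is a second Transport head with six_dof=True: it regresses
    # the out-of-plane z-height, roll, and pitch, while `transport_model`
    # handles the in-plane place position and rotation.)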

    self.transport_model.set_bounds_pixel_size(self.bounds, self.pixel_size)
    self.rpz_model.set_bounds_pixel_size(self.bounds, self.pixel_size)

    self.six_dof = True

    self.p0_pixel_error = tf.keras.metrics.Mean(name='p0_pixel_error')
    self.p1_pixel_error = tf.keras.metrics.Mean(name='p1_pixel_error')
    self.p0_theta_error = tf.keras.metrics.Mean(name='p0_theta_error')
    self.p1_theta_error = tf.keras.metrics.Mean(name='p1_theta_error')
    self.metrics = [
        self.p0_pixel_error, self.p1_pixel_error, self.p0_theta_error,
        self.p1_theta_error
    ]

  def get_six_dof(self,
                  transform_params,
                  heightmap,
                  pose0,
                  pose1,
                  augment=True):
    """Adjust SE(3) poses via the in-plane SE(2) augmentation transform."""
    debug_visualize = False

    p1_position, p1_rotation = pose1[0], pose1[1]
    p0_position, p0_rotation = pose0[0], pose0[1]

    if debug_visualize:
      self.vis = utils.create_visualizer()
      self.transport_model.vis = self.vis

    if augment:
      t_world_center, t_world_centernew = utils.get_se3_from_image_transform(
          *transform_params, heightmap, self.bounds, self.pixel_size)

      if debug_visualize:
        label = 't_world_center'
        utils.make_frame(self.vis, label, h=0.05, radius=0.0012, o=1.0)
        self.vis[label].set_transform(t_world_center)

        label = 't_world_centernew'
        utils.make_frame(self.vis, label, h=0.05, radius=0.0012, o=1.0)
        self.vis[label].set_transform(t_world_centernew)

      t_worldnew_world = t_world_centernew @ np.linalg.inv(t_world_center)
    else:
      t_worldnew_world = np.eye(4)

    p1_quat_wxyz = (p1_rotation[3], p1_rotation[0], p1_rotation[1],
                    p1_rotation[2])
    t_world_p1 = transformations.quaternion_matrix(p1_quat_wxyz)
    t_world_p1[0:3, 3] = np.array(p1_position)

    t_worldnew_p1 = t_worldnew_world @ t_world_p1

    p0_quat_wxyz = (p0_rotation[3], p0_rotation[0], p0_rotation[1],
                    p0_rotation[2])
    t_world_p0 = transformations.quaternion_matrix(p0_quat_wxyz)
    t_world_p0[0:3, 3] = np.array(p0_position)
    t_worldnew_p0 = t_worldnew_world @ t_world_p0
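
    # Note: environment poses store quaternions as (x, y, z, w), while the
    # `transformations` library expects (w, x, y, z), hence the reordering
    # above. A hypothetical helper capturing the repeated pattern:
    #
    #   def pose_to_se3(position, quat_xyzw):
    #     quat_wxyz = (quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2])
    #     t = transformations.quaternion_matrix(quat_wxyz)
    #     t[0:3, 3] = np.array(position)
    #     return t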

    if debug_visualize:
      label = 't_worldnew_p1'
      utils.make_frame(self.vis, label, h=0.05, radius=0.0012, o=1.0)
      self.vis[label].set_transform(t_worldnew_p1)

      label = 't_world_p1'
      utils.make_frame(self.vis, label, h=0.05, radius=0.0012, o=1.0)
      self.vis[label].set_transform(t_world_p1)

      label = 't_worldnew_p0-0thetaoriginally'
      utils.make_frame(self.vis, label, h=0.05, radius=0.0021, o=1.0)
      self.vis[label].set_transform(t_worldnew_p0)

    # PICK FRAME: zero out the rotation, since the suction gripper is
    # rotationally symmetric about its approach axis.
    t_worldnew_p0theta0 = t_worldnew_p0.copy()
    t_worldnew_p0theta0[0:3, 0:3] = np.eye(3)

    if debug_visualize:
      label = 'PICK'
      utils.make_frame(self.vis, label, h=0.05, radius=0.0021, o=1.0)
      self.vis[label].set_transform(t_worldnew_p0theta0)

    # PLACE FRAME, adjusted for this 0 rotation on pick
    t_p0_p0theta0 = np.linalg.inv(t_worldnew_p0) @ t_worldnew_p0theta0
    t_worldnew_p1theta0 = t_worldnew_p1 @ t_p0_p0theta0
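
    # Equivalently, t_worldnew_p1theta0 = t_worldnew_p1 @
    # inv(t_worldnew_p0) @ t_worldnew_p0theta0: the place pose receives the
    # same relative correction that zeroed the pick rotation, so it is
    # expressed relative to the rotation-free pick frame.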

    if debug_visualize:
      label = 'PLACE'
      utils.make_frame(self.vis, label, h=0.05, radius=0.0021, o=1.0)
      self.vis[label].set_transform(t_worldnew_p1theta0)

    # Convert the adjusted place rotation back to Euler angles.
    quatwxyz_worldnew_p1theta0 = transformations.quaternion_from_matrix(
        t_worldnew_p1theta0)
    q = quatwxyz_worldnew_p1theta0
    quatxyzw_worldnew_p1theta0 = (q[1], q[2], q[3], q[0])
    p1_rotation = quatxyzw_worldnew_p1theta0
    p1_euler = utils.quatXYZW_to_eulerXYZ(p1_rotation)
    roll = p1_euler[0]
    pitch = p1_euler[1]
    p1_theta = -p1_euler[2]

    p0_theta = 0
    z = p1_position[2]
    return p0_theta, p1_theta, z, roll, pitch

  def get_sample(self, dataset, augment=True):
    (obs, act, _, _), _ = dataset.sample()
    img = self.get_image(obs)

    # Get training labels from data sample.
    p0_xyz, p0_xyzw = act['pose0']
    p1_xyz, p1_xyzw = act['pose1']
    p0 = utils.xyz_to_pix(p0_xyz, self.bounds, self.pixel_size)
    p0_theta = -np.float32(utils.quatXYZW_to_eulerXYZ(p0_xyzw)[2])
    p1 = utils.xyz_to_pix(p1_xyz, self.bounds, self.pixel_size)
    p1_theta = -np.float32(utils.quatXYZW_to_eulerXYZ(p1_xyzw)[2])
    # Express the place rotation relative to the pick rotation; the pick is
    # treated as rotation-symmetric (suction), so its theta is zeroed.
    p1_theta = p1_theta - p0_theta
    p0_theta = 0

    transforms = None
    if augment:
      img, _, (p0, p1), transforms = utils.perturb(img, [p0, p1])
    # get_six_dof also handles the unaugmented case (identity transform), so
    # z, roll, and pitch are always defined at the return below.
    p0_theta, p1_theta, z, roll, pitch = self.get_six_dof(
        transforms, img[:, :, 3], (p0_xyz, p0_xyzw), (p1_xyz, p1_xyzw),
        augment=augment)

    return img, p0, p0_theta, p1, p1_theta, z, roll, pitch

  def train(self, dataset, num_iter, writer, validation_dataset):
    """Train on dataset for a specific number of iterations.

    Args:
      dataset: a ravens.Dataset.
      num_iter: int, number of iterations to train.
      writer: a TF summary writer (for tensorboard).
      validation_dataset: a ravens.Dataset.
    """

    validation_rate = 200

    for i in range(num_iter):

      tf.keras.backend.set_learning_phase(1)
      input_image, p0, p0_theta, p1, p1_theta, z, roll, pitch = self.get_sample(
          dataset)

      # Compute training losses.
      loss0 = self.attention_model.train(input_image, p0, p0_theta)

      loss1 = self.transport_model.train(input_image, p0, p1, p1_theta, z, roll,
                                         pitch)

      loss2 = self.rpz_model.train(input_image, p0, p1, p1_theta, z, roll,
                                   pitch)
      del loss2
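      # (loss2 is deleted above because the scalar is unused; the rpz model's
      # z/roll/pitch losses are logged below via its individual metrics.)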

      with writer.as_default():
        tf.summary.scalar(
            'attention_loss',
            self.attention_model.metric.result(),
            step=self.total_iter + i)

        tf.summary.scalar(
            'transport_loss',
            self.transport_model.metric.result(),
            step=self.total_iter + i)
        tf.summary.scalar(
            'z_loss',
            self.rpz_model.z_metric.result(),
            step=self.total_iter + i)
        tf.summary.scalar(
            'roll_loss',
            self.rpz_model.roll_metric.result(),
            step=self.total_iter + i)
        tf.summary.scalar(
            'pitch_loss',
            self.rpz_model.pitch_metric.result(),
            step=self.total_iter + i)

      print(f'Train Iter: {self.total_iter + i} Loss: {loss0:.4f} {loss1:.4f}')

      if (self.total_iter + i) % validation_rate == 0:
        print('Validating!')
        tf.keras.backend.set_learning_phase(0)
        input_image, p0, p0_theta, p1, p1_theta, z, roll, pitch = self.get_sample(
            validation_dataset, augment=False)

        loss1 = self.transport_model.train(
            input_image, p0, p1, p1_theta, z, roll, pitch, validate=True)

        loss2 = self.rpz_model.train(
            input_image, p0, p1, p1_theta, z, roll, pitch, validate=True)

        # Optionally compute pixel/theta validation errors by running the
        # policy on validation samples (disabled by default):
        # [metric.reset_states() for metric in self.metrics]
        # for _ in range(30):
        #   obs, act, info = validation_dataset.random_sample()
        #   self.act(obs, info, compute_error=True, gt_act=act)

        with writer.as_default():
          tf.summary.scalar(
              'val_transport_loss',
              self.transport_model.metric.result(),
              step=self.total_iter + i)
          tf.summary.scalar(
              'val_z_loss',
              self.rpz_model.z_metric.result(),
              step=self.total_iter + i)
          tf.summary.scalar(
              'val_roll_loss',
              self.rpz_model.roll_metric.result(),
              step=self.total_iter + i)
          tf.summary.scalar(
              'val_pitch_loss',
              self.rpz_model.pitch_metric.result(),
              step=self.total_iter + i)

          tf.summary.scalar(
              'p0_pixel_error',
              self.p0_pixel_error.result(),
              step=self.total_iter + i)
          tf.summary.scalar(
              'p1_pixel_error',
              self.p1_pixel_error.result(),
              step=self.total_iter + i)
          tf.summary.scalar(
              'p0_theta_error',
              self.p0_theta_error.result(),
              step=self.total_iter + i)
          tf.summary.scalar(
              'p1_theta_error',
              self.p1_theta_error.result(),
              step=self.total_iter + i)

        tf.keras.backend.set_learning_phase(1)

    self.total_iter += num_iter
    self.save()
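
  # A minimal usage sketch (hypothetical setup; `dataset`, `writer`, and
  # `val_dataset` are assumed to be constructed elsewhere):
  #
  #   agent = Transporter6dAgent('6d-agent', 'stacking-task')
  #   agent.train(dataset, num_iter=1000, writer=writer,
  #               validation_dataset=val_dataset)
  #   action = agent.act(obs, info=None)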

  def act(self, obs, info, compute_error=False, gt_act=None):
    """Run inference and return best action given visual observations."""

    # Get heightmap from RGB-D images.
    colormap, heightmap = self.get_heightmap(obs, self.camera_config)

    # Concatenate color with depth images.
    input_image = np.concatenate(
        (colormap, heightmap[Ellipsis, None], heightmap[Ellipsis, None],
         heightmap[Ellipsis, None]),
        axis=2)

    # Attention model forward pass.
    attention = self.attention_model.forward(input_image)
    argmax = np.argmax(attention)
    argmax = np.unravel_index(argmax, shape=attention.shape)
    p0_pixel = argmax[:2]
    p0_theta = argmax[2] * (2 * np.pi / attention.shape[2])

    # Transport model forward pass.
    transport = self.transport_model.forward(input_image, p0_pixel)
    _, z, roll, pitch = self.rpz_model.forward(input_image, p0_pixel)

    argmax = np.argmax(transport)
    argmax = np.unravel_index(argmax, shape=transport.shape)

    # Index into 3D discrete tensor, grab z, roll, pitch activations
    z_best = z[:, argmax[0], argmax[1], argmax[2]][Ellipsis, None]
    roll_best = roll[:, argmax[0], argmax[1], argmax[2]][Ellipsis, None]
    pitch_best = pitch[:, argmax[0], argmax[1], argmax[2]][Ellipsis, None]

    # Send through regressors for each of z, roll, pitch
    z_best = self.rpz_model.z_regressor(z_best)[0, 0]
    roll_best = self.rpz_model.roll_regressor(roll_best)[0, 0]
    pitch_best = self.rpz_model.pitch_regressor(pitch_best)[0, 0]
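    # (Each regressor is assumed to return a (1, 1) tensor, so `[0, 0]`
    # extracts the scalar prediction.)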

    p1_pixel = argmax[:2]
    p1_theta = argmax[2] * (2 * np.pi / transport.shape[2])

    # Pixels to end effector poses.
    p0_position = utils.pix_to_xyz(p0_pixel, heightmap, self.bounds,
                                   self.pixel_size)
    p1_position = utils.pix_to_xyz(p1_pixel, heightmap, self.bounds,
                                   self.pixel_size)

    p1_position = (p1_position[0], p1_position[1], z_best)

    p0_rotation = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
    p1_rotation = utils.eulerXYZ_to_quatXYZW(
        (roll_best, pitch_best, -p1_theta))

    if compute_error:
      gt_p0_position, gt_p0_rotation = gt_act['params']['pose0']
      gt_p1_position, gt_p1_rotation = gt_act['params']['pose1']

      gt_p0_pixel = np.array(
          utils.xyz_to_pix(gt_p0_position, self.bounds, self.pixel_size))
      gt_p1_pixel = np.array(
          utils.xyz_to_pix(gt_p1_position, self.bounds, self.pixel_size))

      self.p0_pixel_error(np.linalg.norm(gt_p0_pixel - np.array(p0_pixel)))
      self.p1_pixel_error(np.linalg.norm(gt_p1_pixel - np.array(p1_pixel)))

      gt_p0_theta = -np.float32(
          utils.quatXYZW_to_eulerXYZ(gt_p0_rotation)[2])
      gt_p1_theta = -np.float32(
          utils.quatXYZW_to_eulerXYZ(gt_p1_rotation)[2])

      self.p0_theta_error(
          abs((np.rad2deg(gt_p0_theta - p0_theta) + 180) % 360 - 180))
      self.p1_theta_error(
          abs((np.rad2deg(gt_p1_theta - p1_theta) + 180) % 360 - 180))

      return None

    return {'pose0': (p0_position, p0_rotation),
            'pose1': (p1_position, p1_rotation)}


class Form2FitAgent:
  """Form-2-fit Agent (https://form2fit.github.io/)."""

  def __init__(self, name, task):
    self.name = name
    self.task = task
    self.total_iter = 0
    self.num_rotations = 24
    self.descriptor_dim = 16
    self.pixel_size = 0.003125
    self.input_shape = (320, 160, 6)
    self.camera_config = cameras.RealSenseD415.CONFIG
    self.models_dir = os.path.join('checkpoints', self.name)
    self.bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.28]])

    self.pick_model = Attention(
        input_shape=self.input_shape,
        num_rotations=1,
        preprocess=self.preprocess,
        lite=True)
    self.place_model = Attention(
        input_shape=self.input_shape,
        num_rotations=1,
        preprocess=self.preprocess,
        lite=True)
    self.match_model = Matching(
        input_shape=self.input_shape,
        descriptor_dim=self.descriptor_dim,
        num_rotations=self.num_rotations,
        preprocess=self.preprocess,
        lite=True)
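
    # Per Form2Fit, the pick and place models are single-rotation attention
    # heatmaps, while the matching model embeds the scene into per-pixel
    # descriptors at each of `num_rotations` orientations; rotation is
    # recovered by nearest-neighbor descriptor matching in act() below.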

  def train(self, dataset, num_iter, writer, validation_dataset=None):
    """Train on dataset for a specific number of iterations."""
    del validation_dataset

    for i in range(num_iter):
      obs, act, _ = dataset.random_sample()

      # Get heightmap from RGB-D images.
      configs = act['camera_config']
      colormap, heightmap = self.get_heightmap(obs, configs)

      # Get training labels from data sample.
      pose0, pose1 = act['params']['pose0'], act['params']['pose1']
      p0_position, p0_rotation = pose0[0], pose0[1]
      p0 = utils.xyz_to_pix(p0_position, self.bounds, self.pixel_size)
      p0_theta = -np.float32(
          utils.quatXYZW_to_eulerXYZ(p0_rotation)[2])
      p1_position, p1_rotation = pose1[0], pose1[1]
      p1 = utils.xyz_to_pix(p1_position, self.bounds, self.pixel_size)
      p1_theta = -np.float32(
          utils.quatXYZW_to_eulerXYZ(p1_rotation)[2])
      p1_theta = p1_theta - p0_theta
      p0_theta = 0

      # Concatenate color with depth images.
      input_image = np.concatenate((colormap, heightmap[Ellipsis, None],
                                    heightmap[Ellipsis, None], heightmap[Ellipsis, None]),
                                   axis=2)

      # Do data augmentation (perturb rotation and translation).
      input_image, _, roundedpixels, _ = utils.perturb(input_image, [p0, p1])
      p0, p1 = roundedpixels

      # Compute training loss.
      loss0 = self.pick_model.train(input_image, p0, theta=0)
      loss1 = self.place_model.train(input_image, p1, theta=0)
      loss2 = self.match_model.train(input_image, p0, p1, theta=p1_theta)
      with writer.as_default():
        tf.summary.scalar(
            'pick_loss',
            self.pick_model.metric.result(),
            step=self.total_iter + i)
        tf.summary.scalar(
            'place_loss',
            self.place_model.metric.result(),
            step=self.total_iter + i)
        tf.summary.scalar(
            'match_loss',
            self.match_model.metric.result(),
            step=self.total_iter + i)
      print(
          f'Train Iter: {self.total_iter + i} Loss: {loss0:.4f} {loss1:.4f} {loss2:.4f}'
      )

    self.total_iter += num_iter
    self.save()

  def act(self, obs, info):
    """Run inference and return best action given visual observations."""
    del info

    act = {'camera_config': self.camera_config, 'primitive': None}
    if not obs:
      return act

    # Get heightmap from RGB-D images.
    colormap, heightmap = self.get_heightmap(obs, self.camera_config)

    # Concatenate color with depth images.
    input_image = np.concatenate(
        (colormap, heightmap[Ellipsis, None], heightmap[Ellipsis, None],
         heightmap[Ellipsis, None]),
        axis=2)

    # Get top-k pixels from pick and place heatmaps.
    k = 100
    pick_heatmap = self.pick_model.forward(
        input_image, apply_softmax=True).squeeze()
    place_heatmap = self.place_model.forward(
        input_image, apply_softmax=True).squeeze()
    descriptors = np.float32(self.match_model.forward(input_image))

    # V4 (active): blur both heatmaps, take the best pick pixel, then match
    # its descriptor against the strongest place peak across all rotations.
    pick_heatmap = cv2.GaussianBlur(pick_heatmap, (49, 49), 0)
    place_heatmap = cv2.GaussianBlur(place_heatmap, (49, 49), 0)
    pick_topk = np.int32(
        np.unravel_index(
            np.argsort(pick_heatmap.reshape(-1))[-k:], pick_heatmap.shape)).T
    pick_pixel = pick_topk[-1, :]
    from skimage.feature import peak_local_max  # pylint: disable=g-import-not-at-top
    place_peaks = peak_local_max(place_heatmap, num_peaks=1)
    # Initialize descriptor distances with a large sentinel value.
    distances = np.ones((place_peaks.shape[0], self.num_rotations)) * 10
    pick_descriptor = descriptors[0, pick_pixel[0],
                                  pick_pixel[1], :].reshape(1, -1)
    for i in range(place_peaks.shape[0]):
      peak = place_peaks[i, :]
      place_descriptors = descriptors[:, peak[0], peak[1], :]
      distances[i, :] = np.linalg.norm(
          place_descriptors - pick_descriptor, axis=1)
    ibest = np.unravel_index(np.argmin(distances), shape=distances.shape)
    p0_pixel = pick_pixel
    p0_theta = 0
    p1_pixel = place_peaks[ibest[0], :]
    p1_theta = ibest[1] * (2 * np.pi / self.num_rotations)
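
    # Earlier matching variants (V1-V3), kept for reference: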

    # # V3
    # pick_heatmap = cv2.GaussianBlur(pick_heatmap, (49, 49), 0)
    # place_heatmap = cv2.GaussianBlur(place_heatmap, (49, 49), 0)
    # pick_topk = np.int32(
    #     np.unravel_index(
    #         np.argsort(pick_heatmap.reshape(-1))[-k:], pick_heatmap.shape)).T
    # place_topk = np.int32(
    #     np.unravel_index(
    #         np.argsort(place_heatmap.reshape(-1))[-k:],
    #         place_heatmap.shape)).T
    # pick_pixel = pick_topk[-1, :]
    # place_pixel = place_topk[-1, :]
    # pick_descriptor = descriptors[0, pick_pixel[0],
    #                               pick_pixel[1], :].reshape(1, -1)
    # place_descriptor = descriptors[:, place_pixel[0], place_pixel[1], :]
    # distances = np.linalg.norm(place_descriptor - pick_descriptor, axis=1)
    # irotation = np.argmin(distances)
    # p0_pixel = pick_pixel
    # p0_theta = 0
    # p1_pixel = place_pixel
    # p1_theta = irotation * (2 * np.pi / self.num_rotations)

    # # V2
    # pick_topk = np.int32(
    #     np.unravel_index(
    #         np.argsort(pick_heatmap.reshape(-1))[-k:], pick_heatmap.shape)).T
    # place_topk = np.int32(
    #     np.unravel_index(
    #         np.argsort(place_heatmap.reshape(-1))[-k:],
    #         place_heatmap.shape)).T
    # pick_pixel = pick_topk[-1, :]
    # pick_descriptor = descriptors[0, pick_pixel[0],
    #                               pick_pixel[1], :].reshape(1, 1, 1, -1)
    # distances = np.linalg.norm(descriptors - pick_descriptor, axis=3)
    # distances = np.transpose(distances, [1, 2, 0])
    # max_distance = int(np.round(np.max(distances)))
    # for i in range(self.num_rotations):
    #   distances[:, :, i] = cv2.circle(distances[:, :, i],
    #                                   (pick_pixel[1], pick_pixel[0]), 50,
    #                                   max_distance, -1)
    # ibest = np.unravel_index(np.argmin(distances), shape=distances.shape)
    # p0_pixel = pick_pixel
    # p0_theta = 0
    # p1_pixel = ibest[:2]
    # p1_theta = ibest[2] * (2 * np.pi / self.num_rotations)

    # # V1
    # pick_topk = np.int32(
    #     np.unravel_index(
    #         np.argsort(pick_heatmap.reshape(-1))[-k:], pick_heatmap.shape)).T
    # place_topk = np.int32(
    #     np.unravel_index(
    #         np.argsort(place_heatmap.reshape(-1))[-k:],
    #         place_heatmap.shape)).T
    # distances = np.zeros((k, k, self.num_rotations))
    # for ipick in range(k):
    #   pick_descriptor = descriptors[0, pick_topk[ipick, 0],
    #                                 pick_topk[ipick, 1], :].reshape(1, -1)
    #   for iplace in range(k):
    #     place_descriptors = descriptors[:, place_topk[iplace, 0],
    #                                     place_topk[iplace, 1], :]
    #     distances[ipick, iplace, :] = np.linalg.norm(
    #         place_descriptors - pick_descriptor, axis=1)
    # ibest = np.unravel_index(np.argmin(distances), shape=distances.shape)
    # p0_pixel = pick_topk[ibest[0], :]
    # p0_theta = 0
    # p1_pixel = place_topk[ibest[1], :]
    # p1_theta = ibest[2] * (2 * np.pi / self.num_rotations)

    # Pixels to end effector poses.
    p0_position = utils.pix_to_xyz(p0_pixel, heightmap, self.bounds,
                                   self.pixel_size)
    p1_position = utils.pix_to_xyz(p1_pixel, heightmap, self.bounds,
                                   self.pixel_size)
    p0_rotation = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
    p1_rotation = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

    act['primitive'] = 'pick_place'
    if self.task == 'sweeping':
      act['primitive'] = 'sweep'
    elif self.task == 'pushing':
      act['primitive'] = 'push'
    params = {
        'pose0': (p0_position, p0_rotation),
        'pose1': (p1_position, p1_rotation)
    }
    act['params'] = params
    return act

  #-------------------------------------------------------------------------
  # Helper Functions
  #-------------------------------------------------------------------------

  def preprocess(self, image):
    """Pre-process images (subtract mean, divide by std)."""
    color_mean = 0.18877631
    depth_mean = 0.00509261
    color_std = 0.07276466
    depth_std = 0.00903967
    image[:, :, :3] = (image[:, :, :3] / 255 - color_mean) / color_std
    image[:, :, 3:] = (image[:, :, 3:] - depth_mean) / depth_std
    return image

  def get_heightmap(self, obs, configs):
    """Reconstruct orthographic heightmaps with segmentation masks."""
    heightmaps, colormaps = utils.reconstruct_heightmaps(
        obs['color'], obs['depth'], configs, self.bounds, self.pixel_size)
    colormaps = np.float32(colormaps)
    heightmaps = np.float32(heightmaps)

    # Fuse maps from different views: average each pixel over the views that
    # observed it (a pixel with zero color is treated as unobserved).
    valid = np.sum(colormaps, axis=3) > 0
    repeat = np.sum(valid, axis=0)
    repeat[repeat == 0] = 1  # Avoid division by zero for unobserved pixels.
    colormap = np.sum(colormaps, axis=0) / repeat[Ellipsis, None]
    colormap = np.uint8(np.round(colormap))
    heightmap = np.sum(heightmaps, axis=0) / repeat
    return colormap, heightmap

  def load(self, num_iter):
    """Load pre-trained models."""
    pick_fname = 'pick-ckpt-%d.h5' % num_iter
    place_fname = 'place-ckpt-%d.h5' % num_iter
    match_fname = 'match-ckpt-%d.h5' % num_iter
    pick_fname = os.path.join(self.models_dir, pick_fname)
    place_fname = os.path.join(self.models_dir, place_fname)
    match_fname = os.path.join(self.models_dir, match_fname)
    self.pick_model.load(pick_fname)
    self.place_model.load(place_fname)
    self.match_model.load(match_fname)
    self.total_iter = num_iter

  def save(self):
    """Save models."""
    if not os.path.exists(self.models_dir):
      os.makedirs(self.models_dir)
    pick_fname = 'pick-ckpt-%d.h5' % self.total_iter
    place_fname = 'place-ckpt-%d.h5' % self.total_iter
    match_fname = 'match-ckpt-%d.h5' % self.total_iter
    pick_fname = os.path.join(self.models_dir, pick_fname)
    place_fname = os.path.join(self.models_dir, place_fname)
    match_fname = os.path.join(self.models_dir, match_fname)
    self.pick_model.save(pick_fname)
    self.place_model.save(place_fname)
    self.match_model.save(match_fname)
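

# A minimal usage sketch for Form2FitAgent (hypothetical; assumes a ravens
# dataset, observations `obs`, and a TF summary writer exist):
#
#   agent = Form2FitAgent('form2fit-agent', 'insertion')
#   agent.train(dataset, num_iter=1000, writer=writer)
#   action = agent.act(obs, info=None)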