Example #1
0
def score_and_select_spatially_separated_keypoints(
        metadata,  # metadata.p,
        confidence_score_data,  # data.p file
        K,  # number of reference descriptors
        position_diff_threshold,  # threshold in pixels
        output_dir,
        visualize=False,
        multi_episode_dict=None,  # needed if you want to visualize
):
    """
    Scores keypoints according to their confidence.
    Selects the top keypoints that are "spatially separated"
    Saves out a file 'spatial_descriptors.p' that records this data
    """
    data = confidence_score_data
    heatmap_values = data['heatmap_values']
    scoring_func = create_scoring_function(gamma=3)
    score_data = score_heatmap_values(heatmap_values,
                                      scoring_func=scoring_func)
    sorted_idx = score_data['sorted_idx']

    keypoint_idx = select_spatially_separated_keypoints(
        sorted_idx,
        metadata['indices'],
        position_diff_threshold=position_diff_threshold,
        K=K,
        verbose=False)

    ref_descriptors = metadata['ref_descriptors'][keypoint_idx]  # [K, D]
    spatial_descriptors_data = score_data
    spatial_descriptors_data['spatial_descriptors'] = ref_descriptors
    spatial_descriptors_data['spatial_descriptors_idx'] = keypoint_idx
    save_pickle(spatial_descriptors_data,
                os.path.join(output_dir, 'spatial_descriptors.p'))
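# ---------------------------------------------------------------------------
# Hedged usage sketch (editor addition, not from the original source). It assumes
# the metadata.p / data.p files produced by compute_descriptor_confidences
# (Example #8) and the module-level load_pickle helper used in Example #3; the
# directory path and parameter values are illustrative only.
# ---------------------------------------------------------------------------
import os

confidence_dir = '/path/to/descriptor_confidence_scores'       # placeholder path
metadata = load_pickle(os.path.join(confidence_dir, 'metadata.p'))
confidence_score_data = load_pickle(os.path.join(confidence_dir, 'data.p'))

score_and_select_spatially_separated_keypoints(
    metadata,
    confidence_score_data,
    K=5,                         # number of reference descriptors to keep
    position_diff_threshold=30,  # pixels; same value as in Example #3
    output_dir=confidence_dir)
# writes <confidence_dir>/spatial_descriptors.p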
Example #2
0
    def save_data(
        self,
        save_dir=None,
    ):
        """
        Saves data from the following objects:

        - PlanContainer
        - OnlineEpisode
        """

        if save_dir is None:
            save_dir = os.path.join(
                utils.get_data_root(),
                'hardware_experiments/closed_loop_rollouts/sandbox',
                utils.get_current_YYYY_MM_DD_hh_mm_ss_ms())

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        print("saving MPC rollout data at: %s" % (save_dir))
        save_data = {
            'episode': self._state_dict['episode'].get_save_data(),
            'plan': self._state_dict['plan'].get_save_data(),
        }

        save_file = os.path.join(save_dir, "data.p")
        utils.save_pickle(save_data, save_file)
        print("done saving data")
Example #3
0
def select_spatially_separated_descriptors(K=5,  # number of reference descriptors
                                           output_dir=None,
                                           visualize=False):
    raise ValueError("deprecated")
    multi_episode_dict = exp_utils.load_episodes()['multi_episode_dict']
    model_file = exp_utils.get_DD_model_file()

    confidence_scores_folder = os.path.join(get_data_root(),
                                            "dev/experiments/09/descriptor_confidence_scores/2020-03-25-19-57-26-556093_constant_velocity_500/2020-03-30-14-21-13-371713")

    # folder = "dev/experiments/07/descriptor_confidence_scores/real_push_box/2020-03-10-15-57-43-867147"
    folder = confidence_scores_folder
    folder = os.path.join(get_data_root(), folder)
    data_file = os.path.join(folder, 'data.p')
    data = load_pickle(data_file)

    heatmap_values = data['heatmap_values']
    scoring_func = keypoint_selection.create_scoring_function(gamma=3)
    score_data = keypoint_selection.score_heatmap_values(heatmap_values,
                                                         scoring_func=scoring_func)
    sorted_idx = score_data['sorted_idx']

    metadata_file = os.path.join(folder, 'metadata.p')
    metadata = load_pickle(metadata_file)
    camera_name = metadata['camera_name']

    keypoint_idx = keypoint_selection.select_spatially_separated_keypoints(sorted_idx,
                                                                           metadata['indices'],
                                                                           position_diff_threshold=30,
                                                                           K=K,
                                                                           verbose=False)

    ref_descriptors = metadata['ref_descriptors'][keypoint_idx]  # [K, D]
    spatial_descriptors_data = score_data
    spatial_descriptors_data['spatial_descriptors'] = ref_descriptors
    spatial_descriptors_data['spatial_descriptors_idx'] = keypoint_idx
    save_pickle(spatial_descriptors_data, os.path.join(folder, 'spatial_descriptors.p'))
Example #4
0
def train_dynamics(
    config,
    train_dir,  # str: directory to save output
    multi_episode_dict=None,
    spatial_descriptors_idx=None,
    metadata=None,
    spatial_descriptors_data=None,
):

    assert multi_episode_dict is not None
    # assert spatial_descriptors_idx is not None

    # set random seed for reproduction
    set_seed(config['train']['random_seed'])

    st_epoch = config['train'][
        'resume_epoch'] if config['train']['resume_epoch'] > 0 else 0
    tee = Tee(os.path.join(train_dir, 'train_st_epoch_%d.log' % st_epoch), 'w')

    tensorboard_dir = os.path.join(train_dir, "tensorboard")
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    writer = SummaryWriter(log_dir=tensorboard_dir)

    # save the config
    save_yaml(config, os.path.join(train_dir, "config.yaml"))

    if metadata is not None:
        save_pickle(metadata, os.path.join(train_dir, 'metadata.p'))

    if spatial_descriptors_data is not None:
        save_pickle(spatial_descriptors_data,
                    os.path.join(train_dir, 'spatial_descriptors.p'))

    training_stats = dict()
    training_stats_file = os.path.join(train_dir, 'training_stats.yaml')

    # load the data

    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.function_from_config(
        config)

    datasets = {}
    dataloaders = {}
    data_n_batches = {}
    for phase in ['train', 'valid']:
        print("Loading data for %s" % phase)
        datasets[phase] = MultiEpisodeDataset(
            config,
            action_function=action_function,
            observation_function=observation_function,
            episodes=multi_episode_dict,
            phase=phase)

        dataloaders[phase] = DataLoader(
            datasets[phase],
            batch_size=config['train']['batch_size'],
            shuffle=True if phase == 'train' else False,
            num_workers=config['train']['num_workers'],
            drop_last=True)

        data_n_batches[phase] = len(dataloaders[phase])

    use_gpu = torch.cuda.is_available()

    # compute normalization parameters if not starting from pre-trained network . . .
    '''
    Build model for dynamics prediction
    '''
    model_dy = build_dynamics_model(config)
    camera_name = config['vision_net']['camera_name']

    # criterion
    criterionMSE = nn.MSELoss()
    l1Loss = nn.L1Loss()
    smoothL1 = nn.SmoothL1Loss()

    # optimizer
    params = model_dy.parameters()
    lr = float(config['train']['lr'])
    optimizer = optim.Adam(params,
                           lr=lr,
                           betas=(config['train']['adam_beta1'], 0.999))

    # setup scheduler
    sc = config['train']['lr_scheduler']
    scheduler = None

    if config['train']['lr_scheduler']['enabled']:
        if config['train']['lr_scheduler']['type'] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=sc['factor'],
                                          patience=sc['patience'],
                                          threshold_mode=sc['threshold_mode'],
                                          cooldown=sc['cooldown'],
                                          verbose=True)
        elif config['train']['lr_scheduler']['type'] == "StepLR":
            step_size = config['train']['lr_scheduler']['step_size']
            gamma = config['train']['lr_scheduler']['gamma']
            scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
        else:
            raise ValueError("unknown scheduler type: %s" %
                             (config['train']['lr_scheduler']['type']))

    if use_gpu:
        print("using gpu")
        model_dy = model_dy.cuda()

    # print("model_dy.vision_net._ref_descriptors.device", model_dy.vision_net._ref_descriptors.device)
    # print("model_dy.vision_net #params: %d" %(count_trainable_parameters(model_dy.vision_net)))

    best_valid_loss = np.inf
    valid_loss_type = config['train']['valid_loss_type']
    global_iteration = 0
    counters = {'train': 0, 'valid': 0}
    epoch_counter_external = 0
    loss = 0

    index_map = get_object_and_robot_state_indices(config)
    object_state_indices = torch.LongTensor(index_map['object_indices'])
    robot_state_indices = torch.LongTensor(index_map['robot_indices'])

    object_state_shape = config['dataset']['object_state_shape']

    try:
        for epoch in range(st_epoch, config['train']['n_epoch']):
            phases = ['train', 'valid']
            epoch_counter_external = epoch

            writer.add_scalar("Training Params/epoch", epoch, global_iteration)
            for phase in phases:

                # only validate at a certain frequency
                if (phase == "valid") and (
                    (epoch % config['train']['valid_frequency']) != 0):
                    continue

                model_dy.train(phase == 'train')

                average_meter_container = dict()

                step_duration_meter = AverageMeter()

                # bar = ProgressBar(max_value=data_n_batches[phase])
                loader = dataloaders[phase]

                for i, data in enumerate(loader):

                    loss_container = dict()  # store the losses for this step

                    step_start_time = time.time()

                    global_iteration += 1
                    counters[phase] += 1

                    with torch.set_grad_enabled(phase == 'train'):
                        n_his, n_roll = config['train']['n_history'], config[
                            'train']['n_rollout']
                        n_samples = n_his + n_roll

                        if DEBUG:
                            print("global iteration: %d" % (global_iteration))
                            print("n_samples", n_samples)

                        # [B, n_samples, obs_dim]
                        observations = data['observations']
                        visual_observations_list = data[
                            'visual_observations_list']

                        # [B, n_samples, action_dim]
                        actions = data['actions']
                        B = actions.shape[0]

                        if use_gpu:
                            observations = observations.cuda()
                            actions = actions.cuda()

                        # compile the visual observations
                        # compute the output of the visual model for all timesteps
                        visual_model_output_list = []
                        for visual_obs in visual_observations_list:
                            # visual_obs is a dict containing observation for a single
                            # time step (of course across a batch however)
                            # visual_obs[<camera_name>]['rgb_tensor'] has shape [B, 3, H, W]

                            # probably need to cast input to cuda
                            # [B, -1, 3]
                            keypoints = visual_obs[camera_name][
                                'descriptor_keypoints_3d_world_frame']

                            # [B, K, 3] where K = len(spatial_descriptors_idx)
                            keypoints = keypoints[:, spatial_descriptors_idx]

                            B, K, _ = keypoints.shape

                            # [B, K*3]
                            keypoints_reshape = keypoints.reshape([B, K * 3])

                            if DEBUG:
                                print("keypoints.shape", keypoints.shape)
                                print("keypoints_reshape.shape",
                                      keypoints_reshape.shape)
                            visual_model_output_list.append(keypoints_reshape)

                        visual_model_output = None
                        if len(visual_model_output_list) > 0:
                            # concatenate this into a tensor
                            # [B, n_samples, vision_model_out_dim]
                            visual_model_output = torch.stack(
                                visual_model_output_list, dim=1)

                        else:
                            visual_model_output = torch.Tensor(
                            )  # empty tensor

                        # states, actions = data
                        assert actions.shape[1] == n_samples

                        # cast this to float so it can be concatenated below
                        visual_model_output = visual_model_output.type_as(
                            observations)

                        # states are obtained by concatenating visual_model_output and observations
                        # (if there are no visual observations, states reduce to just the observations)
                        # [B, n_samples, vision_model_out_dim + obs_dim]
                        states = torch.cat((visual_model_output, observations),
                                           dim=-1)

                        # state_cur: B x n_his x state_dim
                        # state_cur = states[:, :n_his]

                        # [B, n_his, state_dim]
                        state_init = states[:, :n_his]

                        # We want to rollout n_roll steps
                        # actions has shape [B, n_his + n_roll, -1]; we pass everything except
                        # the final action, so action_seq.shape = [B, n_his + n_roll - 1, -1]
                        action_start_idx = 0
                        action_end_idx = n_his + n_roll - 1
                        action_seq = actions[:, action_start_idx:
                                             action_end_idx, :]

                        if DEBUG:
                            print("states.shape", states.shape)
                            print("state_init.shape", state_init.shape)
                            print("actions.shape", actions.shape)
                            print("action_seq.shape", action_seq.shape)

                        # try using models_dy.rollout_model instead of doing this manually
                        rollout_data = rollout_model(state_init=state_init,
                                                     action_seq=action_seq,
                                                     dynamics_net=model_dy,
                                                     compute_debug_data=False)

                        # [B, n_roll, state_dim]
                        state_rollout_pred = rollout_data['state_pred']

                        # [B, n_roll, state_dim]
                        state_rollout_gt = states[:, n_his:]

                        if DEBUG:
                            print("state_rollout_gt.shape",
                                  state_rollout_gt.shape)
                            print("state_rollout_pred.shape",
                                  state_rollout_pred.shape)

                        # the loss function is between
                        # [B, n_roll, state_dim]
                        state_pred_err = state_rollout_pred - state_rollout_gt

                        # [B, n_roll, object_state_dim]
                        object_state_err = state_pred_err[:, :,
                                                          object_state_indices]
                        B, n_roll, object_state_dim = object_state_err.shape

                        # [B, n_roll, *object_state_shape]
                        object_state_err_reshape = object_state_err.reshape(
                            [B, n_roll, *object_state_shape])

                        # num weights
                        J = object_state_err_reshape.shape[2]
                        weights = model_dy.weight_matrix
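                        # editor note (assumption): weight_matrix is assumed to hold one scalar
                        # weight per object keypoint, hence the check below that
                        # len(weights) == J before forming the weighted loss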

                        assert len(
                            weights) == J, "len(weights) = %d, but J = %d" % (
                                len(weights), J)

                        # loss mse object, note the use of broadcasting semantics
                        # [B, n_roll]
                        object_state_loss_mse = weights * torch.pow(
                            object_state_err_reshape, 2).sum(dim=-1)
                        object_state_loss_mse = object_state_loss_mse.mean()

                        l2_object = (weights * torch.norm(
                            object_state_err_reshape, dim=-1)).mean()

                        l2_object_final_step = (weights * torch.norm(
                            object_state_err_reshape[:, -1], dim=-1)).mean()

                        # [B, n_roll, robot_state_dim]
                        robot_state_err = state_pred_err[:, :,
                                                         robot_state_indices]
                        robot_state_loss_mse = torch.pow(robot_state_err,
                                                         2).sum(dim=-1).mean()

                        loss_container[
                            'object_state_loss_mse'] = object_state_loss_mse
                        loss_container[
                            'robot_state_loss_mse'] = robot_state_loss_mse
                        loss_container['l2_object'] = l2_object
                        loss_container[
                            'l2_object_final_step'] = l2_object_final_step

                        # total loss
                        loss = object_state_loss_mse + robot_state_loss_mse
                        loss_container['loss'] = loss

                        for key, val in loss_container.items():
                            if key not in average_meter_container:
                                average_meter_container[key] = AverageMeter()

                            average_meter_container[key].update(val.item(), B)

                    step_duration_meter.update(time.time() - step_start_time)

                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    if (i % config['train']['log_per_iter']
                            == 0) or (global_iteration %
                                      config['train']['log_per_iter'] == 0):
                        log = '%s [%d/%d][%d/%d] LR: %.6f' % (
                            phase, epoch, config['train']['n_epoch'], i,
                            data_n_batches[phase], get_lr(optimizer))

                        # log += ', l2: %.6f' % (loss_container['l2'].item())
                        # log += ', l2_final_step: %.6f' %(loss_container['l2_final_step'].item())

                        log += ', step time %.6f' % (step_duration_meter.avg)
                        step_duration_meter.reset()

                        print(log)

                        # log data to tensorboard
                        # only do it once we have reached 100 iterations
                        if global_iteration > 100:
                            writer.add_scalar("Params/learning rate",
                                              get_lr(optimizer),
                                              global_iteration)
                            writer.add_scalar("Loss_train/%s" % (phase),
                                              loss.item(), global_iteration)

                            for loss_type, loss_obj in loss_container.items():
                                plot_name = "Loss/%s/%s" % (loss_type, phase)
                                writer.add_scalar(plot_name, loss_obj.item(),
                                                  counters[phase])

                            # only plot the weights if we are in the train phase . . . .
                            if phase == "train":
                                # avoid shadowing the dataloader index i (used below
                                # for checkpoint naming) with a separate loop variable
                                for w_idx in range(len(weights)):
                                    plot_name = "Weights/%d" % (w_idx)
                                    writer.add_scalar(plot_name,
                                                      weights[w_idx].item(),
                                                      counters[phase])

                    if phase == 'train' and global_iteration % config['train'][
                            'ckp_per_iter'] == 0:
                        save_model(
                            model_dy, '%s/net_dy_epoch_%d_iter_%d' %
                            (train_dir, epoch, i))

                log = '%s [%d/%d] Loss: %.6f, Best valid: %.6f' % (
                    phase, epoch, config['train']['n_epoch'],
                    average_meter_container[valid_loss_type].avg,
                    best_valid_loss)
                print(log)

                # record all average_meter losses
                for key, meter in average_meter_container.items():
                    writer.add_scalar("AvgMeter/%s/%s" % (key, phase),
                                      meter.avg, epoch)

                if phase == "train":
                    if (scheduler is not None) and (
                            config['train']['lr_scheduler']['type']
                            == "StepLR"):
                        scheduler.step()

                if phase == 'valid':
                    if (scheduler is not None) and (
                            config['train']['lr_scheduler']['type']
                            == "ReduceLROnPlateau"):
                        scheduler.step(
                            average_meter_container[valid_loss_type].avg)

                    if average_meter_container[
                            valid_loss_type].avg < best_valid_loss:
                        best_valid_loss = average_meter_container[
                            valid_loss_type].avg
                        training_stats['epoch'] = epoch
                        training_stats['global_iteration'] = counters['valid']
                        save_yaml(training_stats, training_stats_file)
                        save_model(model_dy, '%s/net_best_dy' % (train_dir))

                writer.flush()  # flush SummaryWriter events to disk

    except KeyboardInterrupt:
        # save network if we have a keyboard interrupt
        save_model(
            model_dy, '%s/net_dy_epoch_%d_keyboard_interrupt' %
            (train_dir, epoch_counter_external))
        writer.flush()  # flush SummaryWriter events to disk
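# ---------------------------------------------------------------------------
# Hedged usage sketch (editor addition, not from the original source). It assumes
# a YAML training config on disk, episodes loaded as in Example #3, and the
# spatial_descriptors.p file written by Example #1 / #3. All paths are
# illustrative placeholders.
# ---------------------------------------------------------------------------
import os
import yaml

with open('config.yaml') as f:                                    # placeholder path
    config = yaml.safe_load(f)
multi_episode_dict = exp_utils.load_episodes()['multi_episode_dict']
sd_data = load_pickle('/path/to/scores/spatial_descriptors.p')    # placeholder path

train_dynamics(config,
               train_dir='/path/to/train_dir',                    # placeholder path
               multi_episode_dict=multi_episode_dict,
               spatial_descriptors_idx=sd_data['spatial_descriptors_idx'],
               spatial_descriptors_data=sd_data)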
Example #5
0
def train_dynamics(
    config,
    train_dir,  # str: directory to save output
    multi_episode_dict=None,
    visual_observation_function=None,
    metadata=None,
    spatial_descriptors_data=None,
):
    assert multi_episode_dict is not None
    # assert spatial_descriptors_idx is not None

    # set random seed for reproduction
    set_seed(config['train']['random_seed'])

    st_epoch = config['train'][
        'resume_epoch'] if config['train']['resume_epoch'] > 0 else 0
    tee = Tee(os.path.join(train_dir, 'train_st_epoch_%d.log' % st_epoch), 'w')

    tensorboard_dir = os.path.join(train_dir, "tensorboard")
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    writer = SummaryWriter(log_dir=tensorboard_dir)

    # save the config
    save_yaml(config, os.path.join(train_dir, "config.yaml"))

    if metadata is not None:
        save_pickle(metadata, os.path.join(train_dir, 'metadata.p'))

    if spatial_descriptors_data is not None:
        save_pickle(spatial_descriptors_data,
                    os.path.join(train_dir, 'spatial_descriptors.p'))

    training_stats = dict()
    training_stats_file = os.path.join(train_dir, 'training_stats.yaml')

    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.function_from_config(
        config)

    datasets = {}
    dataloaders = {}
    data_n_batches = {}
    for phase in ['train', 'valid']:
        print("Loading data for %s" % phase)
        datasets[phase] = MultiEpisodeDataset(
            config,
            action_function=action_function,
            observation_function=observation_function,
            episodes=multi_episode_dict,
            phase=phase,
            visual_observation_function=visual_observation_function)

        print("len(datasets[phase])", len(datasets[phase]))
        dataloaders[phase] = DataLoader(
            datasets[phase],
            batch_size=config['train']['batch_size'],
            shuffle=True if phase == 'train' else False,
            num_workers=config['train']['num_workers'],
            drop_last=True)

        data_n_batches[phase] = len(dataloaders[phase])

    use_gpu = torch.cuda.is_available()

    # compute normalization parameters if not starting from pre-trained network . . .

    if False:
        dataset = datasets["train"]
        data = dataset[0]
        print("data['observations_combined'].shape",
              data['observations_combined'].shape)
        print("data.keys()", data.keys())

        print("data['observations_combined']",
              data['observations_combined'][0])
        print("data['observations_combined'].shape",
              data['observations_combined'].shape)
        print("data['actions'].shape", data['actions'].shape)
        print("data['actions']\n", data['actions'])
        quit()
    '''
    Build model for dynamics prediction
    '''
    model_dy = build_dynamics_model(config)
    if config['dynamics_net'] == "mlp_weight_matrix":
        raise ValueError("can't use weight matrix with standard setup")

    # criterion
    criterionMSE = nn.MSELoss()
    l1Loss = nn.L1Loss()
    smoothL1 = nn.SmoothL1Loss()

    # optimizer
    params = model_dy.parameters()
    lr = float(config['train']['lr'])
    optimizer = optim.Adam(params,
                           lr=lr,
                           betas=(config['train']['adam_beta1'], 0.999))

    # setup scheduler
    sc = config['train']['lr_scheduler']
    scheduler = None

    if config['train']['lr_scheduler']['enabled']:
        if config['train']['lr_scheduler']['type'] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=sc['factor'],
                                          patience=sc['patience'],
                                          threshold_mode=sc['threshold_mode'],
                                          cooldown=sc['cooldown'],
                                          verbose=True)
        elif config['train']['lr_scheduler']['type'] == "StepLR":
            step_size = config['train']['lr_scheduler']['step_size']
            gamma = config['train']['lr_scheduler']['gamma']
            scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
        else:
            raise ValueError("unknown scheduler type: %s" %
                             (config['train']['lr_scheduler']['type']))

    if use_gpu:
        print("using gpu")
        model_dy = model_dy.cuda()

    # print("model_dy.vision_net._ref_descriptors.device", model_dy.vision_net._ref_descriptors.device)
    # print("model_dy.vision_net #params: %d" %(count_trainable_parameters(model_dy.vision_net)))

    best_valid_loss = np.inf
    valid_loss_type = config['train']['valid_loss_type']
    global_iteration = 0
    counters = {'train': 0, 'valid': 0}
    epoch_counter_external = 0
    loss = 0

    try:
        for epoch in range(st_epoch, config['train']['n_epoch']):
            phases = ['train', 'valid']
            epoch_counter_external = epoch

            writer.add_scalar("Training Params/epoch", epoch, global_iteration)
            for phase in phases:

                # only validate at a certain frequency
                if (phase == "valid") and (
                    (epoch % config['train']['valid_frequency']) != 0):
                    continue

                model_dy.train(phase == 'train')

                average_meter_container = dict()

                step_duration_meter = AverageMeter()

                # bar = ProgressBar(max_value=data_n_batches[phase])
                loader = dataloaders[phase]

                for i, data in enumerate(loader):

                    loss_container = dict()  # store the losses for this step

                    step_start_time = time.time()

                    global_iteration += 1
                    counters[phase] += 1

                    with torch.set_grad_enabled(phase == 'train'):
                        n_his, n_roll = config['train']['n_history'], config[
                            'train']['n_rollout']
                        n_samples = n_his + n_roll

                        if DEBUG:
                            print("global iteration: %d" % (global_iteration))
                            print("n_samples", n_samples)

                        # [B, n_samples, obs_dim]
                        states = data['observations_combined']

                        # [B, n_samples, action_dim]
                        actions = data['actions']
                        B = actions.shape[0]

                        if use_gpu:
                            states = states.cuda()
                            actions = actions.cuda()

                        # state_cur: B x n_his x state_dim
                        # state_cur = states[:, :n_his]

                        # [B, n_his, state_dim]
                        state_init = states[:, :n_his]

                        # We want to rollout n_roll steps
                        # actions has shape [B, n_his + n_roll, -1]; we pass everything except
                        # the final action, so action_seq.shape = [B, n_his + n_roll - 1, -1]
                        action_start_idx = 0
                        action_end_idx = n_his + n_roll - 1
                        action_seq = actions[:, action_start_idx:
                                             action_end_idx, :]

                        if DEBUG:
                            print("states.shape", states.shape)
                            print("state_init.shape", state_init.shape)
                            print("actions.shape", actions.shape)
                            print("action_seq.shape", action_seq.shape)

                        # try using models_dy.rollout_model instead of doing this manually
                        rollout_data = rollout_model(state_init=state_init,
                                                     action_seq=action_seq,
                                                     dynamics_net=model_dy,
                                                     compute_debug_data=False)

                        # [B, n_roll, state_dim]
                        state_rollout_pred = rollout_data['state_pred']

                        # [B, n_roll, state_dim]
                        state_rollout_gt = states[:, n_his:]

                        if DEBUG:
                            print("state_rollout_gt.shape",
                                  state_rollout_gt.shape)
                            print("state_rollout_pred.shape",
                                  state_rollout_pred.shape)

                        # the loss function is between
                        # [B, n_roll, state_dim]
                        state_pred_err = state_rollout_pred - state_rollout_gt

                        # everything is in 3D space now so no need to do any scaling
                        # all the losses would be in meters . . . .
                        loss_mse = criterionMSE(state_rollout_pred,
                                                state_rollout_gt)
                        loss_l1 = l1Loss(state_rollout_pred, state_rollout_gt)
                        loss_l2 = torch.norm(state_pred_err, dim=-1).mean()
                        loss_smoothl1 = smoothL1(state_rollout_pred,
                                                 state_rollout_gt)
                        loss_smoothl1_final_step = smoothL1(
                            state_rollout_pred[:, -1], state_rollout_gt[:, -1])

                        # compute losses at final step of the rollout
                        mse_final_step = criterionMSE(
                            state_rollout_pred[:, -1], state_rollout_gt[:, -1])
                        l2_final_step = torch.norm(state_pred_err[:, -1],
                                                   dim=-1).mean()
                        l1_final_step = l1Loss(state_rollout_pred[:, -1],
                                               state_rollout_gt[:, -1])

                        loss_container['mse'] = loss_mse
                        loss_container['l1'] = loss_l1
                        loss_container['mse_final_step'] = mse_final_step
                        loss_container['l1_final_step'] = l1_final_step
                        loss_container['l2_final_step'] = l2_final_step
                        loss_container['l2'] = loss_l2
                        loss_container['smooth_l1'] = loss_smoothl1
                        loss_container[
                            'smooth_l1_final_step'] = loss_smoothl1_final_step

                        # compute the loss
                        loss = 0
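                        # editor note (assumption): config['loss_function'] is taken to map the
                        # loss names populated in loss_container above (e.g. 'mse', 'l2',
                        # 'smooth_l1_final_step') to dicts like {'enabled': True, 'weight': 1.0};
                        # only the enabled terms contribute to the weighted total below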
                        for key, val in config['loss_function'].items():
                            if val['enabled']:
                                loss += loss_container[key] * val['weight']

                        loss_container['loss'] = loss

                        for key, val in loss_container.items():
                            if key not in average_meter_container:
                                average_meter_container[key] = AverageMeter()

                            average_meter_container[key].update(val.item(), B)

                    step_duration_meter.update(time.time() - step_start_time)

                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    if (i % config['train']['log_per_iter']
                            == 0) or (global_iteration %
                                      config['train']['log_per_iter'] == 0):
                        log = '%s [%d/%d][%d/%d] LR: %.6f' % (
                            phase, epoch, config['train']['n_epoch'], i,
                            data_n_batches[phase], get_lr(optimizer))

                        log += ', l2: %.6f' % (loss_container['l2'].item())
                        log += ', l2_final_step: %.6f' % (
                            loss_container['l2_final_step'].item())

                        log += ', step time %.6f' % (step_duration_meter.avg)
                        step_duration_meter.reset()

                        print(log)

                        # log data to tensorboard
                        # only do it once we have reached 100 iterations
                        if global_iteration > 100:
                            writer.add_scalar("Params/learning rate",
                                              get_lr(optimizer),
                                              global_iteration)
                            writer.add_scalar("Loss_train/%s" % (phase),
                                              loss.item(), global_iteration)

                            for loss_type, loss_obj in loss_container.items():
                                plot_name = "Loss/%s/%s" % (loss_type, phase)
                                writer.add_scalar(plot_name, loss_obj.item(),
                                                  counters[phase])

                    if phase == 'train' and global_iteration % config['train'][
                            'ckp_per_iter'] == 0:
                        save_model(
                            model_dy, '%s/net_dy_epoch_%d_iter_%d' %
                            (train_dir, epoch, i))

                log = '%s [%d/%d] Loss: %.6f, Best valid: %.6f' % (
                    phase, epoch, config['train']['n_epoch'],
                    average_meter_container[valid_loss_type].avg,
                    best_valid_loss)
                print(log)

                # record all average_meter losses
                for key, meter in average_meter_container.items():
                    writer.add_scalar("AvgMeter/%s/%s" % (key, phase),
                                      meter.avg, epoch)

                if phase == "train":
                    if (scheduler is not None) and (
                            config['train']['lr_scheduler']['type']
                            == "StepLR"):
                        scheduler.step()

                if phase == 'valid':
                    if (scheduler is not None) and (
                            config['train']['lr_scheduler']['type']
                            == "ReduceLROnPlateau"):
                        scheduler.step(
                            average_meter_container[valid_loss_type].avg)

                    if average_meter_container[
                            valid_loss_type].avg < best_valid_loss:
                        best_valid_loss = average_meter_container[
                            valid_loss_type].avg
                        training_stats['epoch'] = epoch
                        training_stats['global_iteration'] = counters['valid']
                        save_yaml(training_stats, training_stats_file)
                        save_model(model_dy, '%s/net_best_dy' % (train_dir))

                writer.flush()  # flush SummaryWriter events to disk

    except KeyboardInterrupt:
        # save network if we have a keyboard interrupt
        save_model(
            model_dy, '%s/net_dy_epoch_%d_keyboard_interrupt' %
            (train_dir, epoch_counter_external))
        writer.flush()  # flush SummaryWriter events to disk
Example #6
0
def precompute_transporter_keypoints(
    multi_episode_dict,
    model_kp,
    output_dir,  # str
    batch_size=10,
    num_workers=10,
    camera_names=None,
    model_file=None,
):

    assert model_file is not None
    metadata = dict()
    metadata['model_file'] = model_file

    save_yaml(metadata, os.path.join(output_dir, 'metadata.yaml'))
    start_time = time.time()

    log_freq = 10

    device = next(model_kp.parameters()).device
    model_kp = model_kp.eval()  # make sure model is in eval mode

    image_data_config = {
        'rgb': True,
        'mask': True,
        'depth_int16': True,
    }

    # build all the dataset
    datasets = {}
    dataloaders = {}
    for episode_name, episode in multi_episode_dict.items():
        single_episode_dict = {episode_name: episode}
        config = model_kp.config

        # need to do this since transporter type data sampling only works
        # with tuple_size = 1
        dataset_config = copy.deepcopy(config)
        dataset_config['dataset']['use_transporter_type_data_sampling'] = False

        datasets[episode_name] = ImageTupleDataset(
            dataset_config,
            single_episode_dict,
            phase="all",
            image_data_config=image_data_config,
            tuple_size=1,
            compute_K_inv=True,
            camera_names=camera_names)

        dataloaders[episode_name] = DataLoader(datasets[episode_name],
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=False)

    episode_counter = 0
    num_episodes = len(multi_episode_dict)

    for episode_name, dataset in datasets.items():
        episode_counter += 1
        print("\n\n")

        episode = multi_episode_dict[episode_name]
        hdf5_file = None
        try:
            hdf5_file = os.path.basename(episode.image_data_file)
        except AttributeError:
            hdf5_file = "%s.h5" % (episode.name)

        hdf5_file_fullpath = os.path.join(output_dir, hdf5_file)

        str_split = hdf5_file_fullpath.split(".")
        assert len(str_split) == 2
        pickle_file_fullpath = str_split[0] + ".p"

        # print("episode_name", episode_name)
        # print("hdf5_file_fullpath", hdf5_file_fullpath)
        # print("pickle_file_fullpath", pickle_file_fullpath)

        if os.path.isfile(hdf5_file_fullpath):
            os.remove(hdf5_file_fullpath)

        if os.path.isfile(pickle_file_fullpath):
            os.remove(pickle_file_fullpath)

        episode_keypoint_data = dict()

        episode_start_time = time.time()
        with h5py.File(hdf5_file_fullpath, 'w') as hf:
            for i, data in enumerate(dataloaders[episode_name]):
                data = data[0]
                rgb_crop_tensor = data['rgb_crop_tensor'].to(device)
                crop_params = data['crop_param']
                depth_int16 = data['depth_int16']
                key_tree_joined = data['key_tree_joined']

                # print("\n\n i = %d, idx = %d, camera_name = %s" %(i, data['idx'], data['camera_name']))

                depth = depth_int16.float() * 1.0 / DEPTH_IM_SCALE

                if (i % log_freq) == 0:
                    log_msg = "computing [%d/%d][%d/%d]" % (
                        episode_counter, num_episodes, i + 1,
                        len(dataloaders[episode_name]))
                    print(log_msg)

                B = rgb_crop_tensor.shape[0]

                _, H, W, _ = data['rgb'].shape

                kp_pred = None
                kp_pred_full_pixels = None
                with torch.no_grad():
                    kp_pred = model_kp.predict_keypoint(rgb_crop_tensor)

                    # [B, n_kp, 2]
                    kp_pred_full_pixels = transporter_utils.map_cropped_pixels_to_full_pixels_torch(
                        kp_pred, crop_params)

                    xy = kp_pred_full_pixels.clone()
                    xy[:, :, 0] = (xy[:, :, 0]) * 2.0 / W - 1.0
                    xy[:, :, 1] = (xy[:, :, 1]) * 2.0 / H - 1.0

                    # debug
                    # print("xy[0,0]", xy[0,0])

                    # get depth values
                    kp_pred_full_pixels_int = kp_pred_full_pixels.type(
                        torch.LongTensor)

                    z = pdc_utils.index_into_batch_image_tensor(
                        depth.unsqueeze(1),
                        kp_pred_full_pixels_int.transpose(1, 2))

                    z = z.squeeze(1)
                    K_inv = data['K_inv']
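                    # editor note (assumption): pinhole_unprojection is taken to implement the
                    # standard back-projection p_camera = z * K_inv @ [u, v, 1]^T for each
                    # predicted keypoint pixel (u, v) with sampled depth z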
                    pts_camera_frame = pdc_torch_utils.pinhole_unprojection(
                        kp_pred_full_pixels, z, K_inv)

                    # print("pts_camera_frame.shape", pts_camera_frame.shape)

                    pts_world_frame = pdc_torch_utils.transform_points_3D(
                        data['T_W_C'], pts_camera_frame)

                    # print("pts_world_frame.shape", pts_world_frame.shape)

                for j in range(B):

                    keypoint_data = {}

                    # this goes from [-1,1]
                    keypoint_data['xy'] = torch_utils.cast_to_numpy(xy[j])
                    keypoint_data['uv'] = torch_utils.cast_to_numpy(
                        kp_pred_full_pixels[j])
                    keypoint_data['uv_int'] = torch_utils.cast_to_numpy(
                        kp_pred_full_pixels_int[j])
                    keypoint_data['z'] = torch_utils.cast_to_numpy(z[j])
                    keypoint_data[
                        'pos_world_frame'] = torch_utils.cast_to_numpy(
                            pts_world_frame[j])
                    keypoint_data[
                        'pos_camera_frame'] = torch_utils.cast_to_numpy(
                            pts_camera_frame[j])

                    # save out some data in both hdf5 and pickle format
                    for key, val in keypoint_data.items():
                        save_key = key_tree_joined[
                            j] + "/transporter_keypoints/%s" % (key)
                        hf.create_dataset(save_key, data=val)
                        episode_keypoint_data[save_key] = val

            save_pickle(episode_keypoint_data, pickle_file_fullpath)
            print("duration: %.3f seconds" %
                  (time.time() - episode_start_time))
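# ---------------------------------------------------------------------------
# Hedged usage sketch (editor addition, not from the original source). The
# checkpoint path, loader helper and output directory are hypothetical; only
# the camera name 'd415_01' is taken from Example #7.
# ---------------------------------------------------------------------------
model_file = '/path/to/transporter/net_best_model.pth'           # placeholder path
model_kp = load_transporter_model(model_file)                    # hypothetical loader
multi_episode_dict = exp_utils.load_episodes()['multi_episode_dict']

precompute_transporter_keypoints(multi_episode_dict,
                                 model_kp=model_kp,
                                 output_dir='/path/to/transporter_keypoints',
                                 camera_names=['d415_01'],
                                 model_file=model_file)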
Example #7
0
def main():
    d = load_model_and_data()
    model_dy = d['model_dy']
    dataset = d['dataset']
    config = d['config']
    multi_episode_dict = d['multi_episode_dict']
    planner = d['planner']
    planner_config = planner.config

    idx_dict = get_object_and_robot_state_indices(config)
    object_indices = idx_dict['object_indices']
    robot_indices = idx_dict['robot_indices']

    n_his = config['train']['n_history']

    # save_dir = os.path.join(get_project_root(),  'sandbox/mpc/', get_current_YYYY_MM_DD_hh_mm_ss_ms())
    save_dir = os.path.join(get_project_root(),
                            'sandbox/mpc/push_right_box_horizontal')
    print("save_dir", save_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # rotate
    # episode_names = dataset.get_episode_names()
    # print("len(episode_names)", len(episode_names))
    # episode_name = episode_names[0]
    # start_idx = 1
    # n_roll = 15

    # # straight + rotate
    # episode_name = "2020-06-29-21-04-16"
    # print('episode_name', episode_name)
    # start_idx = 1
    # n_roll = 15

    # this is a nice straight push . . .
    # push with box in horizontal position
    episode_name = "2020-06-29-22-03-45"
    start_idx = 2
    n_roll = 10

    # # validation set episodes
    # episode_names = dataset.get_episode_names()
    # print("len(episode_names)", len(episode_names))
    # episode_name = episode_names[1]
    # start_idx = 2
    # n_roll = 15

    camera_name = "d415_01"
    episode = multi_episode_dict[episode_name]
    print("episode_name", episode_name)

    vis = meshcat_utils.make_default_visualizer_object()
    vis.delete()

    idx_list = list(range(start_idx, start_idx + n_roll + 1))
    idx_list_GT = idx_list
    goal_idx = idx_list[-1]
    print("idx_list", idx_list)

    # visualize ground truth rollout
    if True:
        for display_idx, episode_idx in enumerate(idx_list):
            visualize_episode_data_single_timestep(
                vis=vis,
                dataset=dataset,
                episode=episode,
                camera_name=camera_name,
                episode_idx=episode_idx,
                display_idx=episode_idx,
            )

    data_goal = dataset._getitem(episode,
                                 goal_idx,
                                 rollout_length=1,
                                 n_history=1)
    states_goal = data_goal['observations_combined'][0]
    z_states_goal = model_dy.compute_z_state(states_goal)['z']

    print("states_goal.shape", states_goal.shape)
    print("z_states_goal.shape", z_states_goal.shape)

    ##### VISUALIZE PREDICTED ROLLOUT ##########
    data = dataset._getitem(episode, start_idx, rollout_length=n_roll)

    states = data['observations_combined'].unsqueeze(0)
    z = model_dy.compute_z_state(states)['z']
    actions = data['actions'].unsqueeze(0)
    idx_range_model_dy_input = data['idx_range']

    print("data.keys()", data.keys())
    print("data['idx_range']", data['idx_range'])

    # z_init
    z_init = z[:, :n_his]

    # actions_init
    action_start_idx = 0
    action_end_idx = n_his + n_roll - 1
    action_seq = actions[:, action_start_idx:action_end_idx]

    print("action_seq GT\n", action_seq)

    with torch.no_grad():
        rollout_data = rollout_model(state_init=z_init.cuda(),
                                     action_seq=action_seq.cuda(),
                                     dynamics_net=model_dy,
                                     compute_debug_data=False)

    # [B, n_roll, state_dim]
    # state_rollout_pred = rollout_data['state_pred']
    z_rollout_pred = rollout_data['state_pred'].squeeze()
    print("z_rollout_pred.shape", z_rollout_pred.shape)

    if True:
        for idx in range(len(z_rollout_pred)):
            display_idx = data['idx_range'][idx + n_his]
            visualize_model_prediction_single_timestep(
                vis,
                config,
                z_pred=z_rollout_pred[idx],
                display_idx=display_idx)

        print("z_rollout_pred.shape", z_rollout_pred.shape)

    # compute loss when rolled out using GT action sequence
    eval_indices = object_indices
    obs_goal = z_states_goal[object_indices].cuda()
    reward_data = planner_utils.evaluate_model_rollout(
        state_pred=rollout_data['state_pred'],
        obs_goal=obs_goal,
        eval_indices=eval_indices,
        terminal_cost_only=planner_config['mpc']['mppi']['terminal_cost_only'],
        p=planner_config['mpc']['mppi']['cost_norm'])

    print("reward_data using action_seq_GT\n", reward_data['reward'])

    ##### MPC ##########
    data = dataset._getitem(episode,
                            start_idx,
                            rollout_length=0,
                            n_history=config['train']['n_history'])

    state_cur = data['observations_combined'].cuda()
    z_state_cur = model_dy.compute_z_state(state_cur)['z']
    action_his = data['actions'][:(n_his - 1)].cuda()

    print("z_state_cur.shape", state_cur.shape)
    print("action_his.shape", action_his.shape)

    # don't seed with nominal actions just yet
    action_seq_rollout_init = None

    set_seed(SEED)
    mpc_out = planner.trajectory_optimization(
        state_cur=z_state_cur,
        action_his=action_his,
        obs_goal=obs_goal,
        model_dy=model_dy,
        action_seq_rollout_init=action_seq_rollout_init,
        n_look_ahead=n_roll,
        eval_indices=object_indices,
        rollout_best_action_sequence=True,
        verbose=True,
        add_grid_action_samples=True,
    )

    print("\n\n------MPC output-------\n\n")
    print("action_seq:\n", mpc_out['action_seq'])
    mpc_state_pred = mpc_out['state_pred']

    # current shape is [n_roll + 1, state_dim] but really should be
    # [n_roll, state_dim] . . . something is up
    print("mpc_state_pred.shape", mpc_state_pred.shape)
    print("mpc_out['action_seq'].shape", mpc_out['action_seq'].shape)
    print("n_roll", n_roll)

    # visualize
    for idx in range(n_roll):
        episode_idx = start_idx + idx + 1
        visualize_model_prediction_single_timestep(vis,
                                                   config,
                                                   z_pred=mpc_state_pred[idx],
                                                   display_idx=episode_idx,
                                                   name_prefix="mpc",
                                                   color=[255, 0, 0])

    ######## MPC w/ dynamics model input builder #############
    print("\n\n-----DynamicsModelInputBuilder-----")

    # dynamics model input builder
    online_episode = OnlineEpisodeReader(no_copy=True)

    ref_descriptors = d['spatial_descriptor_data']['spatial_descriptors']
    ref_descriptors = torch_utils.cast_to_torch(ref_descriptors).cuda()
    K_matrix = episode.image_episode.camera_K_matrix(camera_name)
    T_world_camera = episode.image_episode.camera_pose(camera_name, 0)
    visual_observation_function = \
        VisualObservationFunctionFactory.descriptor_keypoints_3D(config=config,
                                                                 camera_name=camera_name,
                                                                 model_dd=d['model_dd'],
                                                                 ref_descriptors=ref_descriptors,
                                                                 K_matrix=K_matrix,
                                                                 T_world_camera=T_world_camera,
                                                                 )

    input_builder = DynamicsModelInputBuilder(
        observation_function=d['observation_function'],
        visual_observation_function=visual_observation_function,
        action_function=d['action_function'],
        episode=online_episode)

    compute_control_action_msg = dict()
    compute_control_action_msg['type'] = "COMPUTE_CONTROL_ACTION"

    for i in range(n_his):
        episode_idx = idx_range_model_dy_input[i]
        print("episode_idx", episode_idx)

        # add image information to the episode data
        data = add_images_to_episode_data(episode, episode_idx, camera_name)

        online_episode.add_data(copy.deepcopy(data))
        compute_control_action_msg['data'] = data

        # hack for seeing how much the history matters ...
        # online_episode.add_data(copy.deepcopy(data))

    # save information for running the zmq controller
    save_pickle(compute_control_action_msg,
                os.path.join(save_dir, 'compute_control_action_msg.p'))
    goal_idx = idx_list_GT[-1]
    goal_data = add_images_to_episode_data(episode, goal_idx, camera_name)
    goal_data['observations']['timestamp_system'] = time.time()
    plan_msg = {
        'type': "PLAN",
        'data': [goal_data],
        'n_roll': n_roll,
        'K_matrix': K_matrix,
        'T_world_camera': T_world_camera,
    }
    save_pickle(plan_msg, os.path.join(save_dir, "plan_msg.p"))

    print("len(online_episode)", len(online_episode))

    # use this to construct input
    # verify it's the same as what we got from using the dataset directly
    idx = online_episode.get_latest_idx()
    mpc_input_data = input_builder.get_dynamics_model_input(idx,
                                                            n_history=n_his)

    # print("mpc_input_data\n", mpc_input_data)
    state_cur_ib = mpc_input_data['states'].cuda()
    action_his_ib = mpc_input_data['actions'].cuda()

    z_state_cur_ib = model_dy.compute_z_state(state_cur_ib)['z']

    set_seed(SEED)
    mpc_out = planner.trajectory_optimization(
        state_cur=z_state_cur_ib,
        action_his=action_his_ib,
        obs_goal=obs_goal,
        model_dy=model_dy,
        action_seq_rollout_init=None,
        n_look_ahead=n_roll,
        eval_indices=object_indices,
        rollout_best_action_sequence=True,
        verbose=True,
        add_grid_action_samples=True,
    )

    # visualize
    for idx in range(n_roll):
        episode_idx = start_idx + idx + 1
        visualize_model_prediction_single_timestep(
            vis,
            config,
            z_pred=mpc_out['state_pred'][idx],
            display_idx=episode_idx,
            name_prefix="mpc_input_builder",
            color=[255, 255, 0])
Example #8
0
def compute_descriptor_confidences(multi_episode_dict,
                                   model,
                                   output_dir,  # str
                                   batch_size=10,
                                   num_workers=10,
                                   model_file=None,
                                   ref_descriptors=None,
                                   episode_name_arg=None,
                                   episode_idx=None,
                                   camera_name=None,
                                   num_ref_descriptors=None,
                                   localization_type="spatial_expectation",  # ['spatial_expectation', 'argmax']
                                   num_batches=None,
                                   ):
    """
    Computes confidence scores for different reference descriptors.
    Saves two files

    metadata.p: has information about reference descriptors, etc.
    data.p: descriptor confidence scores
    """

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    start_time = time.time()

    log_freq = 10

    device = next(model.parameters()).device
    model.eval()  # make sure model is in eval mode

    # build all the dataset
    config = None
    dataset = ImageDataset(config, multi_episode_dict, phase="all")
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            shuffle=True)

    # sample ref descriptors
    if ref_descriptors is None:
        if episode_name_arg is None:
            episode_list = list(multi_episode_dict.keys())
            episode_list.sort()
            episode_name_arg = episode_list[0]

        if episode_idx is None:
            episode_idx = 0

        # can't be left blank
        assert camera_name is not None
        episode = multi_episode_dict[episode_name_arg]
        data = dataset._getitem(episode,
                                episode_idx,
                                camera_name)

        rgb_tensor = data['rgb_tensor'].unsqueeze(0).to(device)

        # compute descriptor image from model
        with torch.no_grad():
            out = model.forward(rgb_tensor)
            des_img = out['descriptor_image']

            # get image mask and cast it to a tensor on the appropriate
            # device
            img_mask = episode.get_image(camera_name,
                                         episode_idx,
                                         type="mask")
            img_mask = torch.tensor(img_mask).to(des_img.device)

            ref_descriptors_dict = sample_descriptors(des_img.squeeze(), img_mask, num_ref_descriptors)

            ref_descriptors = ref_descriptors_dict['descriptors']
            ref_indices = ref_descriptors_dict['indices']

            print("ref_descriptors_dict\n", ref_descriptors_dict)

    # save metadata in dict
    metadata = {'model_file': model_file,
                'ref_descriptors': ref_descriptors.cpu().numpy(),  # [N, D]
                'indices': ref_indices.cpu().numpy() if ref_indices is not None else None,
                'episode_name': episode_name_arg,
                'episode_idx': episode_idx,
                'camera_name': camera_name}

    metadata_file = os.path.join(output_dir, 'metadata.p')
    save_pickle(metadata, metadata_file)

    scores = dict()
    heatmap_value_list = []

    for i, data in enumerate(dataloader):
        # optionally limit the number of batches processed
        if num_batches is not None and i >= num_batches:
            break
        rgb_tensor = data['rgb_tensor'].to(device)
        key_tree_joined = data['key_tree_joined']

        if (i % log_freq) == 0:
            log_msg = "computing %d" % (i)
            print(log_msg)

        # don't use gradients
        with torch.no_grad():
            out = model.forward(rgb_tensor)

            # [B, D, H, W]
            des_img = out['descriptor_image']

            B, _, H, W = rgb_tensor.shape

            heatmap_values = None
            if localization_type == "spatial_expectation":
                sigma_descriptor_heatmap = 5  # default
                try:
                    sigma_descriptor_heatmap = model.config['network']['sigma_descriptor_heatmap']
                except (KeyError, TypeError):
                    # fall back to the default if the config doesn't specify a value
                    pass

                # print("ref_descriptors.shape", ref_descriptors.shape)
                # print("des_img.shape", des_img.shape)
                d = get_spatial_expectation(ref_descriptors,
                                            des_img,
                                            sigma=sigma_descriptor_heatmap,
                                            type='exp',
                                            compute_heatmap_values=True,
                                            return_heatmap=True,
                                            )

                # [B, K]
                heatmap_values = d['heatmap_values']
            else:
                raise ValueError("unknown localization type: %s" % (localization_type))

            heatmap_value_list.append(heatmap_values)

    heatmap_values_tensor = torch.cat(heatmap_value_list)
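    # heatmap_values_tensor has shape [num_images_processed, K]: one row of
    # per-keypoint confidence scores for each image processed above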
    heatmap_values_np = heatmap_values_tensor.cpu().numpy()
    save_data = {'heatmap_values': heatmap_values_np}

    data_file = os.path.join(output_dir, 'data.p')
    save_pickle(save_data, data_file)
    print("total time to compute descriptors: %.3f seconds" % (time.time() - start_time))
Example #9
def precompute_descriptor_keypoints(multi_episode_dict,
                                    model,
                                    output_dir,  # str
                                    ref_descriptors_metadata,
                                    batch_size=10,
                                    num_workers=10,
                                    localization_type="spatial_expectation",  # ['spatial_expectation', 'argmax']
                                    compute_3D=True,  # in world frame
                                    camera_names=None,
                                    ):
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    start_time = time.time()

    log_freq = 10

    device = next(model.parameters()).device
    model = model.eval()  # make sure model is in eval mode

    # build a dataset and dataloader for each episode
    datasets = {}
    dataloaders = {}
    for episode_name, episode in iteritems(multi_episode_dict):
        single_episode_dict = {episode_name: episode}
        config = None
        datasets[episode_name] = ImageDataset(config,
                                              single_episode_dict,
                                              phase="all",
                                              camera_names=camera_names)
        dataloaders[episode_name] = DataLoader(datasets[episode_name],
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=False)

    # K = num_ref_descriptors
    metadata = ref_descriptors_metadata
    ref_descriptors = torch.Tensor(metadata['ref_descriptors'])
    ref_descriptors = ref_descriptors.cuda()
    K, _ = ref_descriptors.shape

    metadata_file = os.path.join(output_dir, 'metadata.p')
    save_pickle(metadata, metadata_file)

    episode_counter = 0
    num_episodes = len(multi_episode_dict)

    for episode_name, dataset in iteritems(datasets):
        episode_counter += 1
        print("\n\n")

        episode = multi_episode_dict[episode_name]
        hdf5_file = None
        try:
            hdf5_file = os.path.basename(episode.image_data_file)
        except AttributeError:
            hdf5_file = "%s.h5" % (episode.name)

        hdf5_file_fullpath = os.path.join(output_dir, hdf5_file)

        # derive the pickle filename from the hdf5 filename, e.g. "episode.h5" -> "episode.p"
        pickle_file_fullpath = os.path.splitext(hdf5_file_fullpath)[0] + ".p"

        # print("hdf5_file_fullpath", hdf5_file_fullpath)
        # print("pickle_file_fullpath", pickle_file_fullpath)

        if os.path.isfile(hdf5_file_fullpath):
            os.remove(hdf5_file_fullpath)

        if os.path.isfile(pickle_file_fullpath):
            os.remove(pickle_file_fullpath)

        dataloader = dataloaders[episode_name]


        episode_keypoint_data = dict()

        episode_start_time = time.time()
        with h5py.File(hdf5_file_fullpath, 'w') as hf:
            for i, data in enumerate(dataloaders[episode_name]):
                rgb_tensor = data['rgb_tensor'].to(device)
                key_tree_joined = data['key_tree_joined']

                if (i % log_freq) == 0:
                    log_msg = "computing [%d/%d][%d/%d]" % (episode_counter, num_episodes, i + 1, len(dataloader))
                    print(log_msg)

                # don't use gradients
                tmp_time = time.time()
                with torch.no_grad():
                    out = model.forward(rgb_tensor)

                    # [B, D, H, W]
                    des_img = out['descriptor_image']

                    B, _, H, W = rgb_tensor.shape

                    # [B, N, 2]
                    batch_indices = None
                    pred_3d = None  # populated below only when compute_3D is True
                    if localization_type == "spatial_expectation":
                        sigma_descriptor_heatmap = 5  # default
                        try:
                            sigma_descriptor_heatmap = model.config['network']['sigma_descriptor_heatmap']
                        except (KeyError, TypeError):
                            # fall back to the default if the config doesn't specify a value
                            pass

                        # print("ref_descriptors.shape", ref_descriptors.shape)
                        # print("des_img.shape", des_img.shape)
                        d = get_spatial_expectation(ref_descriptors,
                                                    des_img,
                                                    sigma=sigma_descriptor_heatmap,
                                                    type='exp',
                                                    return_heatmap=True,
                                                    )

                        batch_indices = d['uv']

                        # [B, K, H, W]
                        if compute_3D:
                            # [B*K, H, W]
                            heatmaps_no_batch = d['heatmap_no_batch']

                            # [B, H, W]
                            depth = data['depth_int16'].to(device)

                            # expand depth to one copy per keypoint and convert from
                            # int16 (mm) to float (meters)
                            # [B*K, H, W]
                            depth_expand = depth.unsqueeze(1).expand([B, K, H, W]).reshape([B * K, H, W])
                            depth_expand = depth_expand.float() / constants.DEPTH_IM_SCALE
                            depth_expand = depth_expand.to(heatmaps_no_batch.device)

                            pred_3d = get_integral_preds_3d(heatmaps_no_batch,
                                                            depth_images=depth_expand,
                                                            compute_uv=True)

                            pred_3d['uv'] = pred_3d['uv'].reshape([B, K, 2])
                            pred_3d['xy'] = pred_3d['xy'].reshape([B, K, 2])
                            pred_3d['z'] = pred_3d['z'].reshape([B, K])


                    elif localization_type == "argmax":
                        # localize descriptors
                        best_match_dict = get_argmax_l2(ref_descriptors,
                                                        des_img)

                        # [B, N, 2]
                        # where N is num_ref_descriptors
                        batch_indices = best_match_dict['indices']
                    else:
                        raise ValueError("unknown localization type: %s" % (localization_type))

                    print("computing keypoints took", time.time() - tmp_time)

                    tmp_time = time.time()
                    # iterate over elements in the batch
                    for j in range(B):
                        keypoint_data = {} # dict that stores information to save out

                        # [N,2]
                        # indices = batch_indices[j].cpu().numpy()
                        # key = key_tree_joined[j] + "/descriptor_keypoints"


                        # hf.create_dataset(key, data=indices)
                        # keypoint_indices_dict[key] = indices

                        # stored 3D keypoint locations (in both camera and world frame)
                        if pred_3d is not None:


                            # key_3d_W = key_tree_joined[j] + "/descriptor_keypoints_3d_world_frame"
                            # key_3d_C = key_tree_joined[j] + "/descriptor_keypoints_3d_camera_frame"

                            # T_W_C = data['T_world_camera'][j].cpu().numpy()
                            # K_matrix = data['K'][j].cpu().numpy()

                            T_W_C = torch_utils.cast_to_numpy(data['T_world_camera'][j])
                            K_matrix = torch_utils.cast_to_numpy(data['K'][j])

                            uv = torch_utils.cast_to_numpy(pred_3d['uv'][j])
                            xy = torch_utils.cast_to_numpy(pred_3d['xy'][j])
                            z = torch_utils.cast_to_numpy(pred_3d['z'][j])

                            # [K, 3]
                            # this is in camera frame
                            pts_3d_C = pdc_utils.pinhole_unprojection(uv, z, K_matrix)
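                            # standard pinhole-camera unprojection (an assumption about the
                            # helper's internals): with fx, fy, cx, cy taken from K_matrix,
                            #   X = (u - cx) * z / fx,   Y = (v - cy) * z / fy,   Z = z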
                            # hf.create_dataset(key_3d_C, data=pts_3d_C)
                            # keypoint_indices_dict[key_3d_C] = pts_3d_C

                            # project into world frame
                            pts_3d_W = transform_utils.transform_points_3D(transform=T_W_C,
                                                                           points=pts_3d_C)

                            # hf.create_dataset(key_3d_W, data=pts_3d_W)
                            # keypoint_indices_dict[key_3d_W] = pts_3d_W

                            keypoint_data['xy'] = torch_utils.cast_to_numpy(xy)
                            keypoint_data['uv'] = torch_utils.cast_to_numpy(uv)
                            keypoint_data['z'] = torch_utils.cast_to_numpy(z)
                            keypoint_data['pos_world_frame'] = torch_utils.cast_to_numpy(pts_3d_W)
                            keypoint_data['pos_camera_frame'] = torch_utils.cast_to_numpy(pts_3d_C)

                        # save out some data in both hdf5 and pickle format
                        for key, val in keypoint_data.items():
                            save_key = key_tree_joined[j] + "/descriptor_keypoints/%s" % (key)
                            hf.create_dataset(save_key, data=val)
                            episode_keypoint_data[save_key] = val


                    print("saving to disk took", time.time() - tmp_time)

        # save_pickle(keypoint_indices_dict, pickle_file_fullpath)
        save_pickle(episode_keypoint_data, pickle_file_fullpath)
        print("duration: %.3f seconds" % (time.time() - episode_start_time))

    print("total time to compute descriptors: %.3f seconds" % (time.time() - start_time))