Exemplo n.º 1
0
    def forward(self, input_dict, target):
        """Predict contact points and per-step forces, then roll them out in the physics layer.

        Two towers share the image features produced by ``self.resnet_features``:

        * a contact-point tower that predicts ``self.number_of_cp`` 3-D contact
          points per frame and keeps only the prediction of the last frame;
        * a force tower whose LSTM encoder summarises the clip and whose LSTM
          cell decoder predicts one flattened force vector per transition.

        Each predicted force is applied through ``self.environment_layer``,
        starting from the given initial pose, and the resulting poses are
        projected to keypoints.

        Args:
            input_dict: dict with ``'initial_position'``, ``'initial_rotation'``,
                ``'rgb'`` (unpacked as ``batch, seq, c, w, h``) and
                ``'object_name'`` (a length-1 sequence; only one object is
                supported).
            target: target dict; ``'object_name'`` is written into it.

        Returns:
            ``(output_dict, target)``. ``output_dict`` holds ``'keypoints'``,
            ``'rotation'``, ``'position'``, ``'force_success_flag'``,
            ``'force_applied'``, ``'force_direction'`` (batch x seq_len - 1 x
            number_of_cp x 3) and ``'contact_points'`` (the last-frame contact
            point prediction repeated over ``seq_len``).
        """
        initial_position = input_dict['initial_position']
        initial_rotation = input_dict['initial_rotation']
        rgb = input_dict['rgb']
        batch_size, seq_len, c, w, h = rgb.shape
        object_name = input_dict['object_name']
        assert len(object_name) == 1  # only support one object
        object_name = object_name[0]

        image_features = self.resnet_features(rgb)
        # Both towers embed the same per-frame feature maps.
        flat_image_features = image_features.view(batch_size * seq_len, 512, 7, 7)
        # Both towers embed the same initial pose.
        initial_pose = torch.cat([initial_position, initial_rotation], dim=-1)

        # ---- Contact point prediction tower ----
        image_features_contact_point = self.contact_point_image_embed(
            flat_image_features).view(batch_size, seq_len, 64 * 7 * 7)
        initial_object_features_contact_point = \
            self.contact_point_input_object_embed(initial_pose)
        # Add a sequence dimension and repeat the pose embedding for every frame.
        object_features_contact_point = initial_object_features_contact_point.unsqueeze(
            1).repeat(1, self.sequence_length, 1)

        input_embedded_contact_point = torch.cat(
            [image_features_contact_point, object_features_contact_point],
            dim=-1)
        embedded_sequence_contact_point, _ = self.contact_point_encoder(
            input_embedded_contact_point)
        # Predict contact points for every frame and keep the last frame's prediction.
        contact_points_prediction = self.contact_point_decoder(
            embedded_sequence_contact_point).view(
                batch_size, seq_len, self.number_of_cp, 3)[:, -1, :, :]

        # ---- Force prediction tower ----
        image_features_force = self.image_embed(flat_image_features).view(
            batch_size, seq_len, 64 * 7 * 7)
        initial_object_features_force = self.input_object_embed(initial_pose)
        # Add a sequence dimension and repeat the pose embedding for every frame.
        object_features_force = initial_object_features_force.unsqueeze(
            1).repeat(1, self.sequence_length, 1)

        input_embedded_force = torch.cat(
            [image_features_force, object_features_force], dim=-1)
        embedded_sequence_force, (
            hidden_force, cell_force) = self.lstm_encoder(input_embedded_force)

        # Reshape to (num_layers, num_directions=1, batch=1, hidden) and take the
        # last layer's hidden/cell state to seed the decoder cell.
        hn = hidden_force.view(self.num_layers, 1, 1,
                               self.hidden_size)[-1, -1, :, :]
        cn = cell_force.view(self.num_layers, 1, 1,
                             self.hidden_size)[-1, -1, :, :]

        # The predicted contact points are fixed for the whole rollout, so
        # flatten them once (3 values per contact point) instead of every step.
        contact_point_as_input = contact_points_prediction.view(
            1, self.number_of_cp * 3)
        flat_contact_points = contact_point_as_input.squeeze(0)

        env_state = EnvState(object_name=object_name,
                             rotation=initial_rotation[0],
                             position=initial_position[0],
                             velocity=None,
                             omega=None)
        resulting_position = []
        resulting_rotation = []
        resulting_force_success = []
        forces_directions = []
        all_force_applied = []

        for seq_ind in range(self.sequence_length - 1):
            # Condition the decoder on the previous simulated state plus the
            # contact points, and on the encoder output of the next frame.
            prev_location = env_state.toTensorCoverName().unsqueeze(0)
            prev_state_and_cp = torch.cat(
                [prev_location, contact_point_as_input], dim=-1)
            prev_state_embedded = self.state_embed(prev_state_and_cp)
            next_state_embedded = embedded_sequence_force[:, seq_ind + 1]
            input_lstm_cell = torch.cat(
                [prev_state_embedded, next_state_embedded], dim=-1)

            hn, cn = self.lstm_decoder(input_lstm_cell, (hn, cn))
            force = self.forces_directions_decoder(hn)
            assert force.shape[0] == 1
            force = force.squeeze(0)
            assert force.shape[0] == (self.number_of_cp * 3)

            # initial_velocity is whatever it was in the last frame;
            # note that the gradients are not backproped here
            env_state, force_success, force_applied = \
                self.environment_layer.apply(self.environment, env_state.toTensor(), force,
                                             flat_contact_points)
            env_state = EnvState.fromTensor(env_state)

            resulting_position.append(env_state.position)
            resulting_rotation.append(env_state.rotation)
            resulting_force_success.append(force_success)
            forces_directions.append(force.view(self.number_of_cp, 3))
            all_force_applied.append(force_applied)

        # Stack over time, then add the batch dimension back because the loss needs it.
        resulting_position = torch.stack(resulting_position, dim=0).unsqueeze(0)
        resulting_rotation = torch.stack(resulting_rotation, dim=0).unsqueeze(0)
        resulting_force_success = torch.stack(resulting_force_success,
                                              dim=0).unsqueeze(0)
        forces_directions = torch.stack(forces_directions, dim=0).unsqueeze(0)
        all_force_applied = torch.stack(all_force_applied, dim=0).unsqueeze(0)

        all_keypoints = get_keypoint_projection(
            object_name, resulting_position, resulting_rotation,
            self.all_objects_keypoint_tensor[object_name]).unsqueeze(0)

        # Report the (single, last-frame) contact point prediction for every frame.
        contact_points_prediction = contact_points_prediction.unsqueeze(
            1).repeat(1, seq_len, 1, 1)

        output_dict = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction':
            forces_directions,  # batch size x seq len -1 x number of cp x 3
            'contact_points': contact_points_prediction,
        }

        target['object_name'] = input_dict['object_name']

        return output_dict, target
Exemplo n.º 2
0
    def forward(self, input, target):
        """Fit per-step forces by gradient descent so the simulated keypoints match the target.

        The forces are initialised from a copy of ``target['forces']`` and
        treated as free parameters. They are optimised with SGD for a fixed
        number of iterations. Each iteration does four things:

        1. rolls the forces out through ``self.environment_layer`` from the
           given initial pose;
        2. projects the resulting poses to keypoints;
        3. evaluates ``self.this_loss_func`` on them;
        4. takes one optimiser step.

        Args:
            input: dict with ``'initial_position'``, ``'initial_rotation'``,
                ``'contact_points'`` (leading batch dimension of size 1) and
                ``'object_name'`` (length-1 sequence).
            target: dict with ``'forces'``; ``'object_name'`` is written into it.
                ``target['forces']`` itself is copied and never modified.

        Returns:
            ``(output, target)``. ``output`` describes the rollout of the final
            iteration, i.e. the forces as they were before the last optimiser
            step. Its keys are ``'keypoints'``, ``'rotation'``, ``'position'``,
            ``'force_success_flag'``, ``'force_applied'`` and
            ``'force_direction'`` (batch x seq_len - 1 x number_of_cp x 3).
        """
        initial_position = input['initial_position']
        initial_rotation = input['initial_rotation']
        contact_points = input['contact_points']
        assert contact_points.shape[0] == 1
        contact_points = contact_points.squeeze(0)

        object_name = input['object_name']
        assert len(object_name) == 1
        object_name = object_name[0]
        target['object_name'] = input['object_name']

        # Copy the reference forces so the optimiser's in-place updates never
        # write back into target['forces'].
        predefined_force = target['forces'].squeeze(0).view(
            self.sequence_length - 1, self.number_of_cp * 3)
        initial_forces = predefined_force.detach().clone()

        loss_function = self.this_loss_func
        if self.gpu_ids != -1:
            initial_forces = initial_forces.cuda()
            loss_function = loss_function.cuda()
        # Wrap as a Parameter only after the device move. Moving and then
        # detaching a Parameter produces a leaf with requires_grad=False, which
        # the optimiser cannot train.
        all_forces = torch.nn.Parameter(initial_forces)
        optimizer = torch.optim.SGD([all_forces], lr=self.base_lr)

        number_of_iterations = 20

        for t in range(number_of_iterations):
            # Step schedule: base_lr, then /10 after step_size iterations, then
            # /100 after 2 * step_size iterations (kept from then on).
            if t <= self.step_size:
                lr = self.base_lr
            elif t <= self.step_size * 2:
                lr = self.base_lr * 0.1
            else:
                lr = self.base_lr * 0.01

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            env_state = EnvState(object_name=object_name, rotation=initial_rotation[0], position=initial_position[0], velocity=None, omega=None)
            resulting_force_success = []
            forces_directions = []
            all_force_applied = []
            resulting_position = []
            resulting_rotation = []

            for seq_ind in range(self.sequence_length - 1):
                force = all_forces[seq_ind]

                assert force.shape[0] == (self.number_of_cp * 3)

                # initial_velocity is whatever it was in the last frame;
                # note that the gradients are not backproped here
                env_state, force_success, force_applied = self.environment_layer.apply(
                    self.environment, env_state.toTensor(), force, contact_points)
                env_state = EnvState.fromTensor(env_state)

                resulting_position.append(env_state.position)
                resulting_rotation.append(env_state.rotation)
                resulting_force_success.append(force_success)
                # `force` is a view into all_forces, which optimizer.step()
                # updates in place; clone it so the reported force is the one
                # used in this rollout.
                forces_directions.append(force.clone().view(self.number_of_cp, 3))
                all_force_applied.append(force_applied)

            # Stack over time and add the batch dimension back because the loss needs it.
            resulting_position = torch.stack(resulting_position, dim=0).unsqueeze(0)
            resulting_rotation = torch.stack(resulting_rotation, dim=0).unsqueeze(0)

            all_keypoints = get_keypoint_projection(
                object_name, resulting_position, resulting_rotation,
                self.all_objects_keypoint_tensor[object_name]).unsqueeze(0)

            loss = loss_function({'keypoints': all_keypoints}, target)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        resulting_force_success = torch.stack(resulting_force_success, dim=0).unsqueeze(0)
        forces_directions = torch.stack(forces_directions, dim=0).unsqueeze(0)
        all_force_applied = torch.stack(all_force_applied, dim=0).unsqueeze(0)

        # Return the keypoints as a fresh leaf that requires grad, detached from
        # the optimisation graph. This is the form recommended by PyTorch instead
        # of torch.tensor(tensor, requires_grad=True).
        all_keypoints = all_keypoints.clone().detach().requires_grad_(True)

        output = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction': forces_directions,  # batch size x seq len -1 x number of cp x 3
        }

        return output, target
    def forward(self, input, target):
        """Regress the initial pose from keypoints, predict forces and simulate them.

        The initial object pose (3 position values followed by 4 rotation values,
        the rotation L2-normalised) is regressed from two inputs: the scaled
        initial image keypoints and the object's model keypoints.

        An LSTM encodes the pose embedding together with the embedded RGB frames.
        An LSTM cell then predicts one flattened force vector per transition,
        conditioned on the previous simulated state and the given contact points.
        Forces are clamped to ``[-1, 1]`` and applied through
        ``self.environment_layer``.

        Args:
            input: dict with these keys:

                * ``'object_name'`` — length-1 sequence.
                * ``'initial_keypoint'`` — 20 values once flattened.
                * ``'rgb'`` — unpacked as ``batch, seq, c, w, h``.
                * ``'contact_points'`` — leading batch dimension of size 1.

                The predicted ``'initial_rotation'`` and ``'initial_position'``
                are written back into it for visualisation.
            target: target dict; ``'object_name'`` is written into it.

        Returns:
            ``(output, target)``. ``output`` holds ``'keypoints'``,
            ``'rotation'``, ``'position'``, ``'force_success_flag'``,
            ``'force_applied'`` and ``'force_direction'`` (batch x seq_len - 1 x
            number_of_cp x 3).
        """
        object_name = input['object_name']
        assert len(object_name) == 1
        object_name = object_name[0]

        # Scale by DEFAULT_IMAGE_SIZE (presumably pixels -> [0, 1]) and flatten.
        # The views fix the keypoint count at 10: 2 values per image keypoint,
        # 3 per model keypoint.
        initial_keypoint = input['initial_keypoint']
        initial_keypoint = initial_keypoint / DEFAULT_IMAGE_SIZE
        initial_keypoint = initial_keypoint.view(1, 2 * 10)

        object_points = self.all_objects_keypoint_tensor[object_name].view(
            1, 3 * 10)

        rgb = input['rgb']
        contact_points = input['contact_points']
        assert contact_points.shape[0] == 1
        contact_points = contact_points.squeeze(0)

        # Batch size (1) x 3 * number of contact points. Use number_of_cp rather
        # than a hard-coded count so this agrees with the force decoder's output.
        contact_point_as_input = input['contact_points'].view(
            1, self.number_of_cp * 3)
        batch_size, seq_len, c, w, h = rgb.shape

        image_features = self.resnet_features(rgb)
        image_features = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)

        initial_state = self.predict_initial_pose(
            torch.cat([initial_keypoint, object_points],
                      dim=-1)).view(1, 3 + 4)
        initial_position = initial_state[:, :3]
        initial_rotation = F.normalize(initial_state[:, 3:], dim=-1)

        # Just for visualization purposes
        input['initial_rotation'] = initial_rotation
        input['initial_position'] = initial_position

        initial_object_features = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        # Add a sequence dimension and repeat the pose embedding for every frame.
        object_features = initial_object_features.unsqueeze(1).repeat(
            1, self.sequence_length, 1)

        input_embedded = torch.cat([image_features, object_features], dim=-1)
        embedded_sequence, (hidden, cell) = self.lstm_encoder(input_embedded)

        # Reshape to (num_layers, num_directions=1, batch=1, hidden) and take the
        # last layer's hidden/cell state to seed the decoder cell.
        hn = hidden.view(self.num_layers, 1, 1, self.hidden_size)[-1, -1, :, :]
        cn = cell.view(self.num_layers, 1, 1, self.hidden_size)[-1, -1, :, :]

        resulting_force_success = []
        forces_directions = []
        all_force_applied = []

        env_state = EnvState(object_name=object_name,
                             rotation=initial_rotation[0],
                             position=initial_position[0],
                             velocity=None,
                             omega=None)
        resulting_position = []
        resulting_rotation = []

        for seq_ind in range(self.sequence_length - 1):
            # Condition the decoder on the previous simulated state plus the
            # contact points, and on the encoder output of the next frame.
            prev_location = env_state.toTensorCoverName().unsqueeze(0)

            prev_state_and_cp = torch.cat(
                [prev_location, contact_point_as_input], dim=-1)
            prev_state_embedded = self.state_embed(prev_state_and_cp)
            next_state_embedded = embedded_sequence[:, seq_ind + 1]
            input_lstm_cell = torch.cat(
                [prev_state_embedded, next_state_embedded], dim=-1)

            hn, cn = self.lstm_decoder(input_lstm_cell, (hn, cn))
            force = self.forces_directions_decoder(hn)
            assert force.shape[0] == 1
            force = force.squeeze(0)
            assert force.shape[0] == (self.number_of_cp * 3)

            # Cap forces.
            # NOTE(review): the [-1, 1] bound is provisional ("think about this more").
            force = force.clamp(-1., 1.)

            # initial_velocity is whatever it was in the last frame;
            # note that the gradients are not backproped here
            env_state, force_success, force_applied = self.environment_layer.apply(
                self.environment, env_state.toTensor(), force, contact_points)
            env_state = EnvState.fromTensor(env_state)

            resulting_position.append(env_state.position)
            resulting_rotation.append(env_state.rotation)
            resulting_force_success.append(force_success)
            forces_directions.append(force.view(self.number_of_cp, 3))
            all_force_applied.append(force_applied)

        # Stack over time, then add the batch dimension back because the loss needs it.
        resulting_position = torch.stack(resulting_position, dim=0).unsqueeze(0)
        resulting_rotation = torch.stack(resulting_rotation, dim=0).unsqueeze(0)
        resulting_force_success = torch.stack(resulting_force_success,
                                              dim=0).unsqueeze(0)
        forces_directions = torch.stack(forces_directions, dim=0).unsqueeze(0)
        all_force_applied = torch.stack(all_force_applied, dim=0).unsqueeze(0)

        all_keypoints = get_keypoint_projection(
            object_name, resulting_position, resulting_rotation,
            self.all_objects_keypoint_tensor[object_name]).unsqueeze(0)

        output = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction':
            forces_directions,  # batch size x seq len -1 x number of cp x 3
        }

        target['object_name'] = input['object_name']

        return output, target
Exemplo n.º 4
0
    def forward(self, input, target):
        """Regress per-step forces from RGB frames, the initial pose and the contact points.

        An LSTM encodes the embedded frames together with the initial-pose
        embedding. Each encoded frame is combined with a contact-point embedding
        and decoded into ``self.number_of_cp`` force vectors. The first frame's
        prediction is dropped, leaving one force set per transition. Forces are
        clamped to ``[-1.5, 1.5]``.

        When ``self.train_mode`` is false, the predicted forces of the first
        batch element are also rolled out through ``self.environment_layer``.
        The resulting poses and keypoints are then added to the output.

        Args:
            input: dict with ``'initial_position'``, ``'initial_rotation'``,
                ``'rgb'`` (unpacked as ``batch, seq, c, w, h``),
                ``'contact_points'`` and ``'object_name'``. In evaluation mode,
                the batch size must be 1 and ``object_name`` must have length 1.
            target: target dict; ``'object_name'`` is written into it.

        Returns:
            ``(output, target)``. ``output['forces']`` is batch x seq_len - 1 x
            number_of_cp x 3. In evaluation mode, ``output`` also has
            ``'keypoints'``, ``'rotation'``, ``'position'``, ``'force_applied'``
            and ``'force_direction'``.
        """
        initial_position = input['initial_position']
        initial_rotation = input['initial_rotation']
        rgb = input['rgb']
        batch_size, seq_len, c, w, h = rgb.shape
        contact_point_as_input = input['contact_points'].view(
            batch_size, self.number_of_cp * 3)

        image_features = self.resnet_features(rgb)
        image_features = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)

        initial_object_features = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        # Add a sequence dimension and repeat the pose embedding for every frame.
        object_features = initial_object_features.unsqueeze(1).repeat(
            1, self.sequence_length, 1)

        input_embedded = torch.cat([image_features, object_features], dim=-1)
        embedded_sequence, _ = self.lstm_encoder(input_embedded)

        contact_point_embedding = self.contact_point_embed(
            contact_point_as_input).unsqueeze(1).repeat(1, seq_len, 1)
        combined_w_cp = torch.cat([embedded_sequence, contact_point_embedding],
                                  dim=-1)
        # Predict forces for each frame, then drop the first frame so there is
        # one force set per transition.
        forces_prediction = self.force_decoder(combined_w_cp).view(
            batch_size, seq_len, self.number_of_cp, 3)[:, 1:, :, :]
        forces_prediction = forces_prediction.clamp(-1.5, 1.5)

        output = {
            'forces':
            forces_prediction,  # batchsize x seq len - 1 x number of cp x 3
        }
        target['object_name'] = input['object_name']

        if self.train_mode:
            return output, target

        # Evaluation only: simulate the predicted forces to obtain poses/keypoints.
        object_name = input['object_name']
        assert len(object_name) == 1
        object_name = object_name[0]

        contact_points = input['contact_points']
        assert contact_points.shape[0] == 1
        contact_points = contact_points.squeeze(0)

        env_state = EnvState(object_name=object_name,
                             rotation=initial_rotation[0],
                             position=initial_position[0],
                             velocity=None,
                             omega=None)
        resulting_position = []
        resulting_rotation = []
        for seq_ind in range(self.sequence_length - 1):
            force = forces_prediction[0, seq_ind].view(self.number_of_cp * 3)
            # initial_velocity is whatever it was in the last frame;
            # note that the gradients are not backproped here
            env_state, _, _ = self.environment_layer.apply(
                self.environment, env_state.toTensor(), force, contact_points)
            env_state = EnvState.fromTensor(env_state)

            resulting_position.append(env_state.position)
            resulting_rotation.append(env_state.rotation)

        # Stack over time, then add the batch dimension back because the loss needs it.
        resulting_position = torch.stack(resulting_position, dim=0).unsqueeze(0)
        resulting_rotation = torch.stack(resulting_rotation, dim=0).unsqueeze(0)

        all_keypoints = get_keypoint_projection(
            object_name, resulting_position, resulting_rotation,
            self.all_objects_keypoint_tensor[object_name]).unsqueeze(0)
        output['keypoints'] = all_keypoints
        output['rotation'] = resulting_rotation
        output['position'] = resulting_position

        # Both force outputs simply expose the predicted forces here.
        output['force_applied'] = output['forces']
        output['force_direction'] = output['forces']

        return output, target
Exemplo n.º 5
0
    def forward(ctx, environment, initial_state, forces_tensor,
                list_of_contact_points):
        """Apply forces in the environment and cache finite-difference Jacobians on ``ctx``.

        Runs one ``environment.init_location_and_apply_force`` step from
        ``initial_state`` with ``forces_tensor`` at ``list_of_contact_points``.

        If any input needs a gradient (``ctx.needs_input_grad``), the step is
        re-run once per perturbed entry. Each perturbation uses a
        randomly-signed step:

        * ``0.05`` for each contact-point coordinate;
        * ``environment.force_h`` for each force component;
        * ``environment.state_h`` for each state entry.

        The forward-difference quotients ``(f(x + h) - f(x)) / h`` are stacked
        row-wise and stored on ``ctx``. They are presumably consumed by the
        matching ``backward`` (not visible here):

        * ``ctx.contact_pt_x_plus_h_diff`` — one row per contact-point coordinate;
        * ``ctx.f_x_plus_h_diff`` — one row per force component;
        * ``ctx.initial_state_plus_h_diff`` — one row per state entry. The row
          for ``EnvState.OBJECT_TYPE_INDEX`` is all zeros and is not simulated.

        Args:
            ctx: autograd context.
            environment: simulator providing ``number_of_cp``, ``force_h``,
                ``state_h`` and ``init_location_and_apply_force``.
            initial_state: tensor decodable by ``EnvState.fromTensor``.
            forces_tensor: 1-D force vector, reshaped to ``(number_of_cp, 3)``.
            list_of_contact_points: contact points, reshaped to ``(5, 3)``.

        Returns:
            A tuple ``(env_state_tensor, finite_diff_success_flags,
            force_that_applied)``:

            * ``env_state_tensor`` — the unperturbed next state, moved to
              ``forces_tensor.device``.
            * ``finite_diff_success_flags`` — the success flags of the
              force-perturbed runs, then the state-perturbed runs (only when
              gradients are needed), followed by the unperturbed run's flag.
            * ``force_that_applied`` — the stacked applied-force values of the
              unperturbed run.
        """
        assert len(forces_tensor.shape) == 1, 'The force should be flattened'
        initial_state = EnvState.fromTensor(initial_state)
        forces = ForceValOnly.fromForceArray(
            forces_tensor.view(environment.number_of_cp, 3))
        # NOTE(review): the contact-point count is hard-coded to 5 here, while
        # the forces use environment.number_of_cp.
        list_of_contact_points = list_of_contact_points.view(5, 3)

        finite_diff_success_flags = []

        # Unperturbed step f(x): this is the value returned to the caller.
        with torch.no_grad():
            env_state, force_success_flag, force_that_applied = environment.init_location_and_apply_force(
                initial_state=initial_state,
                forces=forces,
                list_of_contact_points=list_of_contact_points)
            f_x_env_state_tensor = env_state.toTensor()

        # True when no input needs a gradient; the perturbed runs are then skipped.
        no_grad_on = not (True in ctx.needs_input_grad)

        if not no_grad_on:  # at least one of them needs gradients
            with torch.no_grad():

                # --- d(next state) / d(contact points) ---
                cp_h = 0.05
                # position and linear velocity should not depend, force similar, just angular
                manual_tweaks_cp = []
                contact_point_tensor = list_of_contact_points.view(-1)
                for arg_ind in range(len(contact_point_tensor)):

                    changed_contact_point = contact_point_tensor * 1  # Just copy
                    # Random step sign per coordinate (consumes one random.random() draw).
                    tweak_value_cp = cp_h
                    if random.random() > 0.5:
                        tweak_value_cp *= -1

                    changed_contact_point[arg_ind] += tweak_value_cp

                    changed_contact_point = changed_contact_point.view(5, 3)

                    env_state_h, _, _ = environment.init_location_and_apply_force(
                        initial_state=initial_state,
                        forces=forces,
                        list_of_contact_points=changed_contact_point)

                    # Forward difference (f(x + h) - f(x)) / h.
                    manual_tweaks_cp.append(
                        (env_state_h.toTensor() - f_x_env_state_tensor) /
                        (tweak_value_cp))

                ctx.contact_pt_x_plus_h_diff = torch.stack(manual_tweaks_cp,
                                                           dim=-2)

                # --- d(next state) / d(forces) ---
                force_h = environment.force_h
                manual_tweaks_force = []
                for arg_ind in range(len(forces_tensor)):

                    changed_force = forces_tensor * 1  # Just copy
                    tweak_value_force = force_h
                    if random.random() > 0.5:
                        tweak_value_force *= -1

                    changed_force[arg_ind] += tweak_value_force

                    changed_force = ForceValOnly.fromForceArray(
                        changed_force.view(environment.number_of_cp, -1))

                    env_state_h, force_success_flag_h, force_that_applied_h = environment.init_location_and_apply_force(
                        initial_state=initial_state,
                        forces=changed_force,
                        list_of_contact_points=list_of_contact_points)

                    manual_tweaks_force.append(
                        (env_state_h.toTensor() - f_x_env_state_tensor) /
                        (tweak_value_force))

                    # Success flags are collected for force and state perturbations only.
                    finite_diff_success_flags.append(force_success_flag_h)

                ctx.f_x_plus_h_diff = torch.stack(manual_tweaks_force, dim=-2)

                # --- d(next state) / d(current state) ---
                state = initial_state.toTensor()
                manual_tweaks_initial_state = []
                state_h = environment.state_h
                for arg_ind in range(len(state)):
                    # The object-type entry is not perturbed; its Jacobian row is zero
                    # and no simulation (or success flag) is produced for it.
                    if arg_ind == EnvState.OBJECT_TYPE_INDEX:
                        manual_tweaks_initial_state.append(
                            torch.zeros([len(state)]))
                        continue

                    changed_state = state * 1  # just to copy

                    tweak_value_state = state_h
                    if random.random() >= 0.5:
                        tweak_value_state *= -1

                    changed_state[arg_ind] += tweak_value_state

                    env_state_h, force_success_flag_h, force_that_applied_h = environment.init_location_and_apply_force(
                        initial_state=EnvState.fromTensor(changed_state),
                        forces=forces,
                        list_of_contact_points=list_of_contact_points)

                    manual_tweaks_initial_state.append(
                        (env_state_h.toTensor() - f_x_env_state_tensor) /
                        (tweak_value_state))

                    finite_diff_success_flags.append(force_success_flag_h)
                ctx.initial_state_plus_h_diff = torch.stack(
                    manual_tweaks_initial_state, dim=-2)

        # The unperturbed run's flag always comes last.
        finite_diff_success_flags += [force_success_flag]
        finite_diff_success_flags = torch.Tensor(finite_diff_success_flags)
        force_that_applied = torch.stack(force_that_applied, dim=0)

        # Move the state and cached Jacobians to the device of the input forces.
        # NOTE(review): finite_diff_success_flags and force_that_applied are not
        # moved to `device`, unlike in the batched variant of this forward.
        env_state_tensor = env_state.toTensor()
        device = forces_tensor.device
        env_state_tensor = env_state_tensor.to(device=device)
        if not no_grad_on:
            ctx.f_x_plus_h_diff = ctx.f_x_plus_h_diff.to(device=device)
            ctx.initial_state_plus_h_diff = ctx.initial_state_plus_h_diff.to(
                device=device)
            ctx.contact_pt_x_plus_h_diff = ctx.contact_pt_x_plus_h_diff.to(
                device=device)

        return env_state_tensor, finite_diff_success_flags, force_that_applied
Exemplo n.º 6
0
    def forward(ctx, environment, initial_state, forces_tensor,
                list_of_contact_points):
        """Batched variant: apply forces and cache finite-difference Jacobians on ``ctx``.

        Builds a single batch for ``environment.batch_init_locations_and_apply_force``.
        The batch layout, by index, is:

        * ``[0]`` — the unperturbed input;
        * the next ``len(contact_point_tensor)`` entries — one per contact-point
          coordinate, perturbed by a randomly-signed ``0.05``;
        * the next ``len(forces_tensor)`` entries — one per force component,
          perturbed by a randomly-signed ``environment.force_h``;
        * the final ``len(state) - 1`` entries — one per state entry except the
          object type (asserted to be the last index), perturbed by a
          randomly-signed ``environment.state_h``.

        Forward-difference quotients ``(f(x + h) - f(x)) / h`` are stacked
        row-wise. A zero row is appended for the object-type entry. When any
        input needs a gradient, the results are stored on ``ctx`` as
        ``contact_pt_x_plus_h_diff``, ``f_x_plus_h_diff`` and
        ``initial_state_plus_h_diff``, presumably for the matching ``backward``
        (not visible here).

        Args:
            ctx: autograd context.
            environment: simulator providing ``number_of_cp``, ``force_h``,
                ``state_h`` and ``batch_init_locations_and_apply_force``.
            initial_state: tensor decodable by ``EnvState.fromTensor``.
            forces_tensor: 1-D force vector, reshaped to ``(number_of_cp, 3)``.
            list_of_contact_points: contact points, reshaped to ``(5, 3)``.

        Returns:
            A tuple ``(f_x_env_state_tensor, finite_diff_success_flags,
            force_that_applied)``, all on ``forces_tensor.device``:

            * ``f_x_env_state_tensor`` — the unperturbed next state.
            * ``finite_diff_success_flags`` — the flags of the force- and
              state-perturbed runs followed by the unperturbed flag (only the
              unperturbed flag when no gradient is needed).
            * ``force_that_applied`` — built from the unperturbed run's
              ``'loc'`` result.
        """
        assert len(forces_tensor.shape) == 1, 'The force should be flattened'
        initial_state = EnvState.fromTensor(initial_state)
        forces = ForceValOnly.fromForceArray(
            forces_tensor.view(environment.number_of_cp, 3))
        # NOTE(review): the contact-point count is hard-coded to 5 here, while
        # the forces use environment.number_of_cp.
        list_of_contact_points = list_of_contact_points.view(5, 3)
        # print("With batch Input:", initial_state, forces_tensor, list_of_contact_points)

        # Entry 0 of the batch is the unperturbed input; it has no tweak value.
        batch_phy_input = [{
            'forces': [force.tolist() for force in forces],
            'initial_state':
            initial_state.to_dict(),
            'object_num':
            None,
            'list_of_contact_points':
            list_of_contact_points.cpu().tolist()
        }]
        batch_tweak_values = [None]
        # True when no input needs a gradient.
        # NOTE(review): even then, all perturbed inputs below are still built and
        # simulated; only the Jacobian storage on ctx is skipped.
        no_grad_on = not (True in ctx.needs_input_grad)

        with torch.no_grad(
        ):  # manually compute the gradients via finite difference
            cp_h = 0.05
            contact_point_tensor = list_of_contact_points.view(-1)

            # to compute dS/dC
            for arg_ind in range(len(contact_point_tensor)):
                changed_contact_point = contact_point_tensor * 1  # Just copy
                # Random step sign per coordinate (consumes one random.random() draw).
                tweak_value_cp = cp_h
                if random.random() > 0.5:
                    tweak_value_cp *= -1

                changed_contact_point[arg_ind] += tweak_value_cp

                changed_contact_point = changed_contact_point.view(5, 3)
                batch_tweak_values.append(tweak_value_cp)
                batch_phy_input.append({
                    'forces': [force.tolist() for force in forces],
                    'initial_state':
                    initial_state.to_dict(),
                    'object_num':
                    None,
                    'list_of_contact_points':
                    changed_contact_point.cpu().tolist()
                })

            # to compute dS/dF
            force_h = environment.force_h
            for arg_ind in range(len(forces_tensor)):

                changed_force = forces_tensor * 1  # Just copy
                tweak_value_force = force_h
                if random.random() > 0.5:
                    tweak_value_force *= -1

                changed_force[arg_ind] += tweak_value_force

                changed_force = ForceValOnly.fromForceArray(
                    changed_force.view(environment.number_of_cp, -1))
                batch_tweak_values.append(tweak_value_force)
                batch_phy_input.append({
                    'forces': [force.tolist() for force in changed_force],
                    'initial_state':
                    initial_state.to_dict(),
                    'object_num':
                    None,
                    'list_of_contact_points':
                    list_of_contact_points.cpu().tolist()
                })

            # to compute dS_next/dS_current
            state = initial_state.toTensor()
            state_h = environment.state_h
            # The object-type entry must be last so it can be skipped here and
            # given a zero Jacobian row at the end of state_fd.
            assert EnvState.OBJECT_TYPE_INDEX == len(
                state) - 1, "we asssume the last idx is object type"
            for arg_ind in range(len(state) - 1):
                changed_state = state * 1  # just to copy

                tweak_value_state = state_h
                if random.random() >= 0.5:
                    tweak_value_state *= -1

                changed_state[arg_ind] += tweak_value_state
                batch_tweak_values.append(tweak_value_state)
                batch_phy_input.append({
                    'forces': [force.tolist() for force in forces],
                    'initial_state':
                    EnvState.fromTensor(changed_state).to_dict(),
                    'object_num':
                    None,
                    'list_of_contact_points':
                    list_of_contact_points.cpu().tolist()
                })

            # Simulate the unperturbed and all perturbed inputs in one call.
            batch_data = environment.batch_init_locations_and_apply_force(
                batch_data=batch_phy_input)
            batch_state_tensors = [
                build_env_state_from_dict(one_d['state']).toTensor()
                for one_d in batch_data
            ]
            batch_succ_flags = [one_d['succ'] for one_d in batch_data]
            batch_force_locations = [one_d['loc'] for one_d in batch_data]
            expected_length = 1 + len(contact_point_tensor) + len(
                forces_tensor) + len(state) - 1
            assert len(batch_state_tensors) == expected_length, \
                "Result tensor dimension is unexpected, expected %d, but get %d" % (expected_length,
                                                                                    len(batch_state_tensors))
            f_x_env_state_tensor = batch_state_tensors[0]
            # [start, end) slices of the batch for the contact-point, force and
            # state perturbations respectively.
            indexes = [(1, 1 + len(contact_point_tensor)),
                       (1 + len(contact_point_tensor),
                        1 + len(contact_point_tensor) + len(forces_tensor)),
                       (1 + len(contact_point_tensor) + len(forces_tensor),
                        1 + len(contact_point_tensor) + len(forces_tensor) +
                        len(state) - 1)]
            cp_tweak_state_tensors, force_tweak_state_tensors, next_state_tweak_tensors = \
                [batch_state_tensors[a: b] for a, b in indexes]
            cp_tweak_values, force_tweak_values, next_state_tweak_values = \
                [batch_tweak_values[a: b] for a, b in indexes]
            # Same flag ordering as the non-batched forward: force and state
            # perturbation flags first, the unperturbed flag last.
            if no_grad_on:
                finite_diff_success_flags = [batch_succ_flags[0]]
            else:
                finite_diff_success_flags = batch_succ_flags[
                    1 + len(contact_point_tensor):] + [batch_succ_flags[0]]
            # Forward differences (f(x + h) - f(x)) / h for each perturbation group.
            cp_fd = [
                (fx_delta - f_x_env_state_tensor) / delta
                for (fx_delta,
                     delta) in zip(cp_tweak_state_tensors, cp_tweak_values)
            ]
            force_fd = [(fx_delta - f_x_env_state_tensor) / delta for (
                fx_delta,
                delta) in zip(force_tweak_state_tensors, force_tweak_values)]
            # The trailing zero row stands for the unperturbed object-type entry.
            state_fd = [(fx_delta - f_x_env_state_tensor) / delta for (fx_delta, delta) in zip(next_state_tweak_tensors, next_state_tweak_values)] + \
                       [torch.zeros([len(state)])]
            if not no_grad_on:
                ctx.contact_pt_x_plus_h_diff = torch.stack(cp_fd, dim=-2)
                ctx.f_x_plus_h_diff = torch.stack(force_fd, dim=-2)
                ctx.initial_state_plus_h_diff = torch.stack(state_fd, dim=-2)

        # summarize results: move outputs (and cached Jacobians) to the forces' device
        device = forces_tensor.device
        force_that_applied = torch.Tensor(
            batch_force_locations[0]).to(device=device)
        finite_diff_success_flags = torch.Tensor(finite_diff_success_flags).to(
            device=device)
        f_x_env_state_tensor = f_x_env_state_tensor.to(device=device)
        if not no_grad_on:  # if it requires grads.
            ctx.contact_pt_x_plus_h_diff = ctx.contact_pt_x_plus_h_diff.to(
                device=device)
            ctx.f_x_plus_h_diff = ctx.f_x_plus_h_diff.to(device=device)
            ctx.initial_state_plus_h_diff = ctx.initial_state_plus_h_diff.to(
                device=device)
        # print("With batch output:", f_x_env_state_tensor, force_that_applied)
        return f_x_env_state_tensor, finite_diff_success_flags, force_that_applied