Example #1
def project_and_cal_kp_dis(pos1, rot1, pos2, rot2, obj_name, kp_tensors):
    kp1 = get_keypoint_projection(
        object_name=obj_name,
        resulting_positions=pos1.unsqueeze(0).unsqueeze(0),
        resulting_rotations=rot1.unsqueeze(0).unsqueeze(0),
        keypoints=kp_tensors)
    kp2 = get_keypoint_projection(
        object_name=obj_name,
        resulting_positions=pos2.unsqueeze(0).unsqueeze(0),
        resulting_rotations=rot2.unsqueeze(0).unsqueeze(0),
        keypoints=kp_tensors)
    kp_dis = cal_kp_dis(kp1, kp2)
    kp_loss_fn = torch.nn.SmoothL1Loss()
    kp_loss = cal_kp_loss(kp1, kp2, kp_loss_fn)

    return kp_dis, kp_loss
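
For reference, the distance and loss returned above just compare two sets of projected keypoints. Below is a minimal, self-contained sketch of that comparison in plain PyTorch; the tensor shapes are illustrative assumptions, and the exact reduction inside cal_kp_dis / cal_kp_loss in the repository may differ.

import torch

# two hypothetical keypoint projections: batch x seq_len x num_keypoints x 2
kp1 = torch.rand(1, 10, 10, 2)
kp2 = torch.rand(1, 10, 10, 2)

# per-keypoint Euclidean distance, averaged over keypoints and frames
kp_dis = (kp1 - kp2).norm(dim=-1).mean()

# SmoothL1 (Huber) loss between the two projections, as used above
kp_loss = torch.nn.SmoothL1Loss()(kp1, kp2)

print(kp_dis.item(), kp_loss.item())
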
Example #2
    def forward(self, input_dict, target):
        initial_position = input_dict['initial_position']
        initial_rotation = input_dict['initial_rotation']
        rgb = input_dict['rgb']
        batch_size, seq_len, c, w, h = rgb.shape
        object_name = input_dict['object_name']
        assert len(object_name) == 1  # only support one object
        object_name = object_name[0]

        image_features = self.resnet_features(rgb)

        # Contact point prediction tower
        image_features_contact_point = self.contact_point_image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)
        initial_object_features_contact_point = \
            self.contact_point_input_object_embed(torch.cat([initial_position, initial_rotation], dim=-1))
        object_features_contact_point = initial_object_features_contact_point.unsqueeze(1).\
            repeat(1, self.sequence_length, 1)  # add a dimension for sequence length and then repeat that

        # Predict contact point
        input_embedded_contact_point = torch.cat(
            [image_features_contact_point, object_features_contact_point],
            dim=-1)
        embedded_sequence_contact_point, (
            _, _) = self.contact_point_encoder(input_embedded_contact_point)
        # Predict contact points for each image and keep only the last prediction
        contact_points_prediction = self.contact_point_decoder(
            embedded_sequence_contact_point).view(batch_size, seq_len,
                                                  self.number_of_cp,
                                                  3)[:, -1, :, :]

        # Force prediction tower
        image_features_force = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)
        initial_object_features_force = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        object_features_force = initial_object_features_force.unsqueeze(
            1).repeat(1, self.sequence_length, 1)
        # add a dimension for sequence length and then repeat that

        input_embedded_force = torch.cat(
            [image_features_force, object_features_force], dim=-1)
        embedded_sequence_force, (
            hidden_force, cell_force) = self.lstm_encoder(input_embedded_force)

        last_hidden = hidden_force.view(self.num_layers, 1, 1,
                                        self.hidden_size)[-1, -1, :, :]
        # num_layers, num direction, batchsize, hidden and then take the last hidden layer
        last_cell = cell_force.view(self.num_layers, 1, 1,
                                    self.hidden_size)[-1, -1, :, :]
        # num_layers, num direction, batchsize, cell and then take the last hidden layer

        hn = last_hidden
        cn = last_cell

        resulting_force_success = []
        forces_directions = []
        all_force_applied = []

        env_state = EnvState(object_name=object_name,
                             rotation=initial_rotation[0],
                             position=initial_position[0],
                             velocity=None,
                             omega=None)
        resulting_position = []
        resulting_rotation = []

        for seq_ind in range(self.sequence_length - 1):
            prev_location = env_state.toTensorCoverName().unsqueeze(0)
            contact_point_as_input = contact_points_prediction.view(1, 3 * 5)

            prev_state_and_cp = torch.cat(
                [prev_location, contact_point_as_input], dim=-1)
            prev_state_embedded = self.state_embed(prev_state_and_cp)
            next_state_embedded = embedded_sequence_force[:, seq_ind + 1]
            input_lstm_cell = torch.cat(
                [prev_state_embedded, next_state_embedded], dim=-1)

            (hn, cn) = self.lstm_decoder(input_lstm_cell, (hn, cn))
            force = self.forces_directions_decoder(hn)
            assert force.shape[0] == 1
            force = force.squeeze(0)
            assert force.shape[0] == (self.number_of_cp * 3)

            # the initial velocity is whatever it was in the last frame;
            # note that the gradients are not backpropagated here
            env_state, force_success, force_applied = \
                self.environment_layer.apply(self.environment, env_state.toTensor(), force,
                                             contact_point_as_input.squeeze(0))

            env_state = EnvState.fromTensor(env_state)

            resulting_position.append(env_state.position)
            resulting_rotation.append(env_state.rotation)
            resulting_force_success.append(force_success)
            forces_directions.append(force.view(self.number_of_cp, 3))
            all_force_applied.append(force_applied)

        resulting_position = torch.stack(resulting_position, dim=0)
        resulting_rotation = torch.stack(resulting_rotation, dim=0)
        resulting_force_success = torch.stack(resulting_force_success, dim=0)
        forces_directions = torch.stack(forces_directions, dim=0)
        all_force_applied = torch.stack(all_force_applied, dim=0)

        resulting_position = resulting_position.unsqueeze(
            0)  # adding batchsize back because we need it in the loss
        resulting_rotation = resulting_rotation.unsqueeze(
            0)  # adding batchsize back because we need it in the loss
        forces_directions = forces_directions.unsqueeze(0)
        resulting_force_success = resulting_force_success.unsqueeze(0)
        all_force_applied = all_force_applied.unsqueeze(0)

        all_keypoints = get_keypoint_projection(
            object_name, resulting_position, resulting_rotation,
            self.all_objects_keypoint_tensor[object_name])
        all_keypoints = all_keypoints.unsqueeze(
            0)  # adding batchsize back because we need it in the loss

        contact_points_prediction = contact_points_prediction.unsqueeze(
            1).repeat(1, seq_len, 1, 1)

        output_dict = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction':
            forces_directions,  # batch size x seq len -1 x number of cp x 3
            'contact_points': contact_points_prediction,
        }

        target['object_name'] = input_dict['object_name']

        return output_dict, target
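
The force tower in the forward pass above encodes the frame sequence with nn.LSTM, keeps only the last layer's final hidden and cell state (the .view(...)[-1, -1, :, :] lines), and uses them to seed an LSTMCell decoder that is stepped once per frame. A stripped-down sketch of that seeding pattern follows; the sizes and the random stand-in for the embedded previous state are assumptions.

import torch
import torch.nn as nn

num_layers, hidden_size, input_size, seq_len = 3, 512, 256, 10

encoder = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
decoder = nn.LSTMCell(hidden_size * 2, hidden_size)

x = torch.rand(1, seq_len, input_size)   # batch x seq_len x features
encoded, (hidden, cell) = encoder(x)     # hidden: num_layers x batch x hidden

# for a unidirectional LSTM with batch size 1 this is equivalent to
# hidden.view(num_layers, 1, 1, hidden_size)[-1, -1, :, :]
hn, cn = hidden[-1], cell[-1]

for t in range(seq_len - 1):
    prev_state_embedded = torch.rand(1, hidden_size)          # stand-in for state_embed(...)
    step_input = torch.cat([prev_state_embedded, encoded[:, t + 1]], dim=-1)
    hn, cn = decoder(step_input, (hn, cn))
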
    def forward(self, input, target):

        initial_position = input['initial_position']
        initial_rotation = input['initial_rotation']
        contact_points = input['contact_points']
        assert contact_points.shape[0] == 1
        contact_points = contact_points.squeeze(0)

        object_name = input['object_name']
        assert len(object_name) == 1
        object_name = object_name[0]
        target['object_name'] = input['object_name']

        predefined_force = target['forces'].squeeze(0).view(self.sequence_length - 1, 5 * 3)

        loss_function = self.this_loss_func

        if self.gpu_ids != -1:
            predefined_force = predefined_force.cuda()
            loss_function = loss_function.cuda()

        # wrap the detached initial forces in a leaf Parameter so the optimizer can update them
        all_forces = torch.nn.Parameter(predefined_force.detach())
        optimizer = torch.optim.SGD([all_forces], lr=self.base_lr)

        number_of_states = 20

        for t in range(number_of_states):

            # all_forces = all_forces.clamp(-1.5, 1.5)

            if t <= self.step_size:
                lr = self.base_lr
            elif t <= self.step_size * 2:
                lr = self.base_lr * 0.1
            elif t <= self.step_size * 3:
                lr = self.base_lr * 0.01

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            env_state = EnvState(object_name=object_name, rotation=initial_rotation[0], position=initial_position[0], velocity=None, omega=None)
            resulting_force_success = []
            forces_directions = []
            all_force_applied = []
            resulting_position = []
            resulting_rotation = []

            for seq_ind in range(self.sequence_length - 1):
                force = all_forces[seq_ind]

                assert force.shape[0] == (self.number_of_cp * 3)

                # the initial velocity is whatever it was in the last frame; note that the gradients are not backpropagated here
                env_state, force_success, force_applied = self.environment_layer.apply(self.environment, env_state.toTensor(), force, contact_points)
                env_state = EnvState.fromTensor(env_state)

                resulting_position.append(env_state.position)
                resulting_rotation.append(env_state.rotation)
                resulting_force_success.append(force_success)
                forces_directions.append(force.view(self.number_of_cp, 3))
                all_force_applied.append(force_applied)

            resulting_position = torch.stack(resulting_position, dim=0)
            resulting_rotation = torch.stack(resulting_rotation, dim=0)

            resulting_position = resulting_position.unsqueeze(0)  # adding batchsize back because we need it in the loss
            resulting_rotation = resulting_rotation.unsqueeze(0)  # adding batchsize back because we need it in the loss

            all_keypoints = get_keypoint_projection(object_name, resulting_position, resulting_rotation, self.all_objects_keypoint_tensor[object_name])
            all_keypoints = all_keypoints.unsqueeze(0)  # adding batchsize back because we need it in the loss

            output = {
                'keypoints': all_keypoints,
            }

            loss = loss_function(output, target)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        resulting_force_success = torch.stack(resulting_force_success, dim=0)
        forces_directions = torch.stack(forces_directions, dim=0)
        all_force_applied = torch.stack(all_force_applied, dim=0)

        forces_directions = forces_directions.unsqueeze(0)
        resulting_force_success = resulting_force_success.unsqueeze(0)
        all_force_applied = all_force_applied.unsqueeze(0)

        all_keypoints = all_keypoints.clone().detach().requires_grad_(True)

        output = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction': forces_directions,  # batch size x seq len -1 x number of cp x 3
        }

        return output, target
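
Unlike the learned models, this variant treats the per-step forces themselves as the parameters and optimizes them with SGD under a step-wise learning-rate decay. The sketch below reproduces that optimization loop with a differentiable linear map standing in for the environment rollout (the real code steps the physics environment and projects keypoints instead); the sizes and schedule constants are assumptions.

import torch

seq_len, num_cp = 10, 5
base_lr, step_size, num_iters = 0.1, 6, 20

forces = torch.nn.Parameter(torch.zeros(seq_len - 1, num_cp * 3))
target_kp = torch.rand(seq_len - 1, 10, 2)            # hypothetical target keypoints
simulate = torch.nn.Linear(num_cp * 3, 10 * 2)        # stand-in for the environment + projection
optimizer = torch.optim.SGD([forces], lr=base_lr)
loss_fn = torch.nn.SmoothL1Loss()

for t in range(num_iters):
    # step-wise learning-rate decay, mirroring the base_lr / 0.1x / 0.01x schedule above
    lr = base_lr * (0.1 ** min(t // step_size, 2))
    for group in optimizer.param_groups:
        group['lr'] = lr

    predicted_kp = simulate(forces).view(seq_len - 1, 10, 2)
    loss = loss_fn(predicted_kp, target_kp)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
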
    def forward(self, input, target):
        object_name = input['object_name']
        assert len(object_name) == 1
        object_name = object_name[0]

        initial_keypoint = input['initial_keypoint']
        initial_keypoint = initial_keypoint / DEFAULT_IMAGE_SIZE
        initial_keypoint = initial_keypoint.view(1, 2 * 10)

        object_points = self.all_objects_keypoint_tensor[object_name].view(
            1, 3 * 10)

        rgb = input['rgb']
        contact_points = input['contact_points']
        assert contact_points.shape[0] == 1
        contact_points = contact_points.squeeze(0)

        contact_point_as_input = input['contact_points'].view(
            1, 3 * 5)  # Batchsize , 3 * number of contact points
        batch_size, seq_len, c, w, h = rgb.shape

        image_features = self.resnet_features(rgb)
        image_features = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)

        initial_state = self.predict_initial_pose(
            torch.cat([initial_keypoint, object_points],
                      dim=-1)).view(1, 3 + 4)
        initial_position = initial_state[:, :3]
        initial_rotation = initial_state[:, 3:]
        initial_rotation = F.normalize(initial_rotation, dim=-1)

        # Just for visualization purposes
        input['initial_rotation'] = initial_rotation
        input['initial_position'] = initial_position

        initial_object_features = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        object_features = initial_object_features.unsqueeze(1).repeat(
            1, self.sequence_length,
            1)  # add a dimension for sequence length and then repeat that

        input_embedded = torch.cat([image_features, object_features], dim=-1)
        embedded_sequence, (hidden, cell) = self.lstm_encoder(input_embedded)

        last_hidden = hidden.view(self.num_layers, 1, 1,
                                  self.hidden_size)[-1, -1, :, :]
        # num_layers, num directions, batch size, hidden; take the last hidden layer
        last_cell = cell.view(self.num_layers, 1, 1,
                              self.hidden_size)[-1, -1, :, :]
        # num_layers, num directions, batch size, cell; take the last hidden layer

        hn = last_hidden
        cn = last_cell

        resulting_force_success = []
        forces_directions = []
        all_force_applied = []

        env_state = EnvState(object_name=object_name,
                             rotation=initial_rotation[0],
                             position=initial_position[0],
                             velocity=None,
                             omega=None)
        resulting_position = []
        resulting_rotation = []

        for seq_ind in range(self.sequence_length - 1):
            prev_location = env_state.toTensorCoverName().unsqueeze(0)

            prev_state_and_cp = torch.cat(
                [prev_location, contact_point_as_input], dim=-1)
            prev_state_embedded = self.state_embed(prev_state_and_cp)
            next_state_embedded = embedded_sequence[:, seq_ind + 1]
            input_lstm_cell = torch.cat(
                [prev_state_embedded, next_state_embedded], dim=-1)

            (hn, cn) = self.lstm_decoder(input_lstm_cell, (hn, cn))
            force = self.forces_directions_decoder(hn)
            assert force.shape[0] == 1
            force = force.squeeze(0)
            assert force.shape[0] == (self.number_of_cp * 3)

            # Cap the predicted forces to a fixed range
            force = force.clamp(-1., 1.)

            # the initial velocity is whatever it was in the last frame; note that the gradients are not backpropagated here
            env_state, force_success, force_applied = self.environment_layer.apply(
                self.environment, env_state.toTensor(), force, contact_points)
            env_state = EnvState.fromTensor(env_state)

            resulting_position.append(env_state.position)
            resulting_rotation.append(env_state.rotation)
            resulting_force_success.append(force_success)
            forces_directions.append(force.view(self.number_of_cp, 3))
            all_force_applied.append(force_applied)

        resulting_position = torch.stack(resulting_position, dim=0)
        resulting_rotation = torch.stack(resulting_rotation, dim=0)
        resulting_force_success = torch.stack(resulting_force_success, dim=0)
        forces_directions = torch.stack(forces_directions, dim=0)
        all_force_applied = torch.stack(all_force_applied, dim=0)

        resulting_position = resulting_position.unsqueeze(
            0)  # adding batchsize back because we need it in the loss
        resulting_rotation = resulting_rotation.unsqueeze(
            0)  # adding batchsize back because we need it in the loss
        forces_directions = forces_directions.unsqueeze(0)
        resulting_force_success = resulting_force_success.unsqueeze(0)
        all_force_applied = all_force_applied.unsqueeze(0)

        all_keypoints = get_keypoint_projection(
            object_name, resulting_position, resulting_rotation,
            self.all_objects_keypoint_tensor[object_name])
        all_keypoints = all_keypoints.unsqueeze(
            0)  # adding batchsize back because we need it in the loss

        output = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction':
            forces_directions,  # batch size x seq len -1 x number of cp x 3
        }

        target['object_name'] = input['object_name']

        return output, target
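
One detail specific to this variant: the initial pose is regressed from the normalized 2D keypoints and the object's 3D keypoints, and the last four outputs are renormalized so they form a unit quaternion before being fed to the environment. A small sketch of that split-and-normalize step; the pose head here is a plain linear layer and all sizes are assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

# hypothetical pose head: 10 image keypoints (x, y) + 10 object keypoints (x, y, z) -> 3 + 4 values
predict_initial_pose = nn.Linear(2 * 10 + 3 * 10, 3 + 4)

initial_keypoint = torch.rand(1, 2 * 10)   # keypoints already divided by the image size
object_points = torch.rand(1, 3 * 10)      # object-frame keypoint coordinates

initial_state = predict_initial_pose(torch.cat([initial_keypoint, object_points], dim=-1))
initial_position = initial_state[:, :3]
initial_rotation = F.normalize(initial_state[:, 3:], dim=-1)   # unit-norm quaternion

print(initial_position.shape, initial_rotation.norm(dim=-1))   # (1, 3), ~1.0
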
    def forward(self, input_dict, target, contact_points_prediction,
                force_predictions):
        initial_position = input_dict['initial_position']
        initial_rotation = input_dict['initial_rotation']
        rgb = input_dict['rgb']
        dev = rgb.device
        batch_size, seq_len, c, w, h = rgb.shape

        object_name = input_dict['object_name']
        assert len(object_name) == 1  # only support one object
        object_name = object_name[0]
        image_features = forward_resnet_feature(
            x=rgb,
            feature_extractor=self.feature_extractor,
            train_res=self.train_res)

        # hooks to vis gradients
        if self.vis_grad:

            def grad_fn(grad):
                self.grad_value = grad.abs().mean()

            with torch.no_grad():
                if image_features.requires_grad:
                    handler1 = image_features.register_hook(grad_fn)

        frame_features = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)
        initial_object_features_force = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        object_features_force = initial_object_features_force.unsqueeze(
            1).repeat(1, self.sequence_length, 1)
        # add a dimension for sequence length and then repeat that
        input_embedded = torch.cat([frame_features, object_features_force],
                                   dim=-1)
        ns_sequence_features, (hidden_force,
                               cell_force) = self.lstm_encoder(input_embedded)

        # state initialization
        initial_state = NoGradEnvState(object_name=object_name,
                                       rotation=initial_rotation[0],
                                       position=initial_position[0],
                                       velocity=None,
                                       omega=None)
        initial_state_tensor = initial_state.toTensorCoverName().unsqueeze(0)
        ns_state_tensor = initial_state_tensor.clone()
        phy_env_state = nograd_envstate_from_tensor(
            object_name=object_name,
            env_tensor=initial_state_tensor[0],
            clear_velocity=True)

        phy_positions, phy_rotations, ns_positions, ns_rotations, final_forces = [], [], [], [], []
        ns_hidden, ns_cell = None, None
        for seq_ind in range(self.sequence_length - 1):
            contact_point_as_input = contact_points_prediction.view(1, 3 * 5)
            if self.only_first_img_feature:
                current_frame_feature = ns_sequence_features[:, 0]
            else:
                current_frame_feature = ns_sequence_features[:, seq_ind + 1]

            # step physical simulator
            cur_force_pred = force_predictions[seq_ind]

            next_phy_env_state, succ_flags, force_locations, force_values = \
                self.environment.init_location_and_apply_force(forces=cur_force_pred[0].reshape(5, -1),
                                                               initial_state=phy_env_state,
                                                               list_of_contact_points=contact_point_as_input[0].reshape(5, -1),
                                                               no_grad=True, return_force_value=True)
            next_phy_state_tensor = next_phy_env_state.toTensorCoverName()
            phy_positions.append(next_phy_state_tensor[:3].to(dev))
            phy_rotations.append(next_phy_state_tensor[3:7].to(dev))

            # step neural force simulator
            # cur_obj_ns_layer = self.ns_layer[object_name]
            # predicted_state_tensor = cur_obj_ns_layer(ns_state_tensor, force, contact_point_as_input)
            if self.clean_force:
                force_values = [ele.to(dev) for ele in force_values]
                force_locations = [ele.to(dev) for ele in force_locations]
                cleaned_force, cleaned_cp = torch.stack(force_values).unsqueeze(0), \
                                            torch.stack(force_locations).unsqueeze(0)
            else:
                cleaned_force, cleaned_cp = cur_force_pred, contact_point_as_input

            if self.use_lstm:
                if seq_ind == 0:
                    ns_hidden, ns_cell = None, None
                next_ns_state_tensor, ns_hidden, ns_cell = self.one_ns_layer(
                    ns_state_tensor[:, :7], cleaned_force, cleaned_cp,
                    current_frame_feature, ns_hidden, ns_cell)
            else:
                next_ns_state_tensor = self.one_ns_layer(
                    ns_state_tensor[:, :7], cleaned_force, cleaned_cp,
                    current_frame_feature)

            # collect ns results
            ns_positions.append(next_ns_state_tensor[0][:3])
            ns_rotations.append(next_ns_state_tensor[0][3:7])
            final_forces.append(cleaned_force[0])

            # update neural and physical simulator
            if seq_ind == self.sequence_length - 2:  # no next state
                break
            rand_num = random.random()
            if 0 <= rand_num < self.ns_ratio:  # update to ns states
                next_ns_env_state = nograd_envstate_from_tensor(
                    object_name=object_name,
                    env_tensor=next_ns_state_tensor[0],
                    clear_velocity=True)
                ns_state_tensor = next_ns_state_tensor
                phy_env_state = next_ns_env_state
            elif self.ns_ratio <= rand_num < self.ns_ratio + self.phy_ratio:  # update to phy states
                ns_state_tensor = next_phy_env_state.toTensorCoverName(
                ).unsqueeze(0).to(dev)
                phy_env_state = next_phy_env_state
            else:  # update to gt states
                gt_env_state = NoGradEnvState(
                    object_name=object_name,
                    rotation=target['rotation'][0][seq_ind + 1],
                    position=target['position'][0][seq_ind + 1],
                    velocity=None,
                    omega=None)
                gt_env_state_tensor = gt_env_state.toTensorCoverName(
                ).unsqueeze(0).to(dev)
                ns_state_tensor = gt_env_state_tensor
                phy_env_state = gt_env_state
        phy_positions = torch.stack(phy_positions).unsqueeze(0)
        phy_rotations = torch.stack(phy_rotations).unsqueeze(0)
        phy_kps = get_keypoint_projection(
            object_name, phy_positions, phy_rotations,
            self.all_objects_keypoint_tensor[object_name]).unsqueeze(0)
        ns_positions = torch.stack(ns_positions).unsqueeze(0)
        ns_rotations = torch.stack(ns_rotations).unsqueeze(0)
        final_forces = torch.stack(final_forces).unsqueeze(0)
        ns_kps = get_keypoint_projection(
            object_name, ns_positions, ns_rotations,
            self.all_objects_keypoint_tensor[object_name]).unsqueeze(0)

        contact_points_prediction = contact_points_prediction.unsqueeze(
            1).repeat(1, seq_len, 1, 1)

        output_dict = {
            'phy_position': phy_positions,
            'phy_rotation': phy_rotations,
            'phy_keypoints': phy_kps,
            'ns_position': ns_positions,
            'ns_rotation': ns_rotations,
            'ns_keypoints': ns_kps,
            'contact_points': contact_points_prediction,
            'force': final_forces,
        }
        return output_dict
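
During this rollout the next state is drawn from one of three sources according to ns_ratio and phy_ratio: the neural simulator, the physics engine, or ground truth. A minimal sketch of that scheduled-sampling-style switch; the ratios and the placeholder states are assumptions.

import random

ns_ratio, phy_ratio = 0.4, 0.4      # remaining probability mass falls through to ground truth

def pick_next_state(ns_state, phy_state, gt_state):
    """Choose which state the next rollout step starts from."""
    r = random.random()
    if r < ns_ratio:
        return ns_state              # follow the neural simulator
    elif r < ns_ratio + phy_ratio:
        return phy_state             # follow the physics engine
    else:
        return gt_state              # teacher-force with the ground-truth state

print(pick_next_state("ns", "phy", "gt"))
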
Example #6
    def forward(self, input_dict, target):
        initial_position = input_dict['initial_position']
        initial_rotation = input_dict['initial_rotation']
        rgb = input_dict['rgb']
        batch_size, seq_len, c, w, h = rgb.shape
        object_name = input_dict['object_name']
        assert len(object_name) == 1  # only support one object
        object_name = object_name[0]
        dev = rgb.device

        image_features = self.resnet_features(rgb)

        # hooks to vis gradients
        if self.vis_grad:

            def grad_fn(grad):
                self.grad_vis = grad.abs().mean()

            with torch.no_grad():
                if image_features.requires_grad:
                    handler1 = image_features.register_hook(grad_fn)

        if self.use_gt_cp:
            contact_points_prediction = input_dict['contact_points']
        else:
            # Contact point prediction tower
            image_features_contact_point = self.contact_point_image_embed(
                image_features.view(batch_size * seq_len, 512, 7,
                                    7)).view(batch_size, seq_len, 64 * 7 * 7)
            initial_object_features_contact_point = \
                self.contact_point_input_object_embed(torch.cat([initial_position, initial_rotation], dim=-1))
            object_features_contact_point = initial_object_features_contact_point.unsqueeze(1).\
                repeat(1, self.sequence_length, 1)  # add a dimension for sequence length and then repeat that

            # Predict contact point
            input_embedded_contact_point = torch.cat(
                [image_features_contact_point, object_features_contact_point],
                dim=-1)
            embedded_sequence_contact_point, (
                _,
                _) = self.contact_point_encoder(input_embedded_contact_point)
            contact_points_prediction = self.contact_point_decoder(
                embedded_sequence_contact_point).view(
                    batch_size, seq_len, self.number_of_cp,
                    3)[:, -1, :, :]  # Predict contact point for each image

        # Force prediction tower
        image_features_force = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)
        initial_object_features_force = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        object_features_force = initial_object_features_force.unsqueeze(
            1).repeat(1, self.sequence_length, 1)
        # add a dimension for sequence length and then repeat that

        input_embedded_force = torch.cat(
            [image_features_force, object_features_force], dim=-1)
        embedded_sequence_force, (
            hidden_force, cell_force) = self.lstm_encoder(input_embedded_force)

        last_hidden = hidden_force.view(self.num_layers, 1, 1,
                                        self.hidden_size)[-1, -1, :, :]
        # num_layers, num direction, batchsize, hidden and then take the last hidden layer
        last_cell = cell_force.view(self.num_layers, 1, 1,
                                    self.hidden_size)[-1, -1, :, :]
        # num_layers, num direction, batchsize, cell and then take the last hidden layer

        hn = last_hidden
        cn = last_cell
        initial_state = NoGradEnvState(object_name=object_name,
                                       rotation=initial_rotation[0],
                                       position=initial_position[0],
                                       velocity=None,
                                       omega=None)
        initial_state_tensor = initial_state.toTensorCoverName().unsqueeze(0)
        ns_state_tensor = initial_state_tensor.clone()
        phy_env_state = nograd_envstate_from_tensor(
            object_name=object_name,
            env_tensor=initial_state_tensor[0],
            clear_velocity=True)
        phy_positions, ns_positions, phy_rotations, ns_rotations, final_forces = [], [], [], [], []

        for seq_ind in range(self.sequence_length - 1):
            contact_point_as_input = contact_points_prediction.view(1, 3 * 5)
            initial_state_and_cp = torch.cat(
                [initial_state_tensor, contact_point_as_input], dim=-1)
            initial_frame_feature = self.state_embed(initial_state_and_cp)
            current_frame_feature = embedded_sequence_force[:, seq_ind + 1]
            input_lstm_cell = torch.cat(
                [initial_frame_feature, current_frame_feature], dim=-1)
            (hn, cn) = self.lstm_decoder(input_lstm_cell, (hn, cn))
            force = self.forces_directions_decoder(hn)

            # step physical simulator
            assert len(ns_state_tensor) == 1, "only support bs = 1."
            phy_env_state, succ_flags, force_locations, force_values = \
                self.environment.init_location_and_apply_force(forces=force[0].reshape(5, -1),
                                                               initial_state=phy_env_state,
                                                               list_of_contact_points=contact_point_as_input[0].reshape(5, -1),
                                                               no_grad=True, return_force_value=True)
            phy_state_tensor = phy_env_state.toTensorCoverName()
            phy_positions.append(phy_state_tensor[:3])
            phy_rotations.append(phy_state_tensor[3:7])

            # step neural force simulator
            # cur_obj_ns_layer = self.ns_layer[object_name]
            # predicted_state_tensor = cur_obj_ns_layer(ns_state_tensor, force, contact_point_as_input)
            if self.clean_force:
                force_values = [ele.to(dev) for ele in force_values]
                force_locations = [ele.to(dev) for ele in force_locations]
                cleaned_force, cleaned_cp = torch.stack(force_values).unsqueeze(0), \
                                            torch.stack(force_locations).unsqueeze(0)
            else:
                cleaned_force, cleaned_cp = force, contact_point_as_input

            if self.use_image_feature:
                predicted_state_tensor = self.one_ns_layer(
                    ns_state_tensor[:, :7], cleaned_force, cleaned_cp,
                    current_frame_feature)
            else:
                predicted_state_tensor = self.one_ns_layer(
                    ns_state_tensor, cleaned_force, cleaned_cp)

            # collect ns results
            ns_positions.append(predicted_state_tensor[0][:3])
            ns_rotations.append(predicted_state_tensor[0][3:7])
            final_forces.append(cleaned_force[0])

            # update to next state
            ns_state_tensor = predicted_state_tensor

        phy_positions = torch.stack(phy_positions).unsqueeze(0)
        phy_rotations = torch.stack(phy_rotations).unsqueeze(0)
        phy_kps = get_keypoint_projection(
            object_name, phy_positions, phy_rotations,
            self.all_objects_keypoint_tensor[object_name]).unsqueeze(0)
        ns_positions = torch.stack(ns_positions).unsqueeze(0)
        ns_rotations = torch.stack(ns_rotations).unsqueeze(0)
        final_forces = torch.stack(final_forces).unsqueeze(0)
        ns_kps = get_keypoint_projection(
            object_name, ns_positions, ns_rotations,
            self.all_objects_keypoint_tensor[object_name]).unsqueeze(0)

        contact_points_prediction = contact_points_prediction.unsqueeze(
            1).repeat(1, seq_len, 1, 1)

        output = {
            'phy_position': phy_positions,
            'phy_rotation': phy_rotations,
            'phy_keypoints': phy_kps,
            'ns_position': ns_positions,
            'ns_rotation': ns_rotations,
            'ns_keypoints': ns_kps,
            'contact_points': contact_points_prediction,
            'force': final_forces,
        }

        target['object_name'] = input_dict['object_name']

        return output, target
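
All of these towers use the same trick to run a 2D convolution per frame: batch and time are folded into one dimension, the 512 x 7 x 7 ResNet maps are embedded down to 64 channels, and the result is unfolded back and flattened to 64 * 7 * 7 features per frame. A self-contained sketch of that reshape pattern; the 1x1 convolution used as the embedding layer here is an assumption.

import torch
import torch.nn as nn

batch_size, seq_len = 1, 10
image_features = torch.rand(batch_size, seq_len, 512, 7, 7)   # per-frame ResNet features

image_embed = nn.Conv2d(512, 64, kernel_size=1)               # hypothetical 512 -> 64 embedding

embedded = image_embed(
    image_features.view(batch_size * seq_len, 512, 7, 7)      # fold batch and time together
).view(batch_size, seq_len, 64 * 7 * 7)                       # unfold and flatten per frame

print(embedded.shape)   # torch.Size([1, 10, 3136])
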
    def forward(self, input, target):

        initial_position = input['initial_position']
        initial_rotation = input['initial_rotation']
        rgb = input['rgb']
        batch_size, seq_len, c, w, h = rgb.shape
        contact_point_as_input = input['contact_points'].view(
            batch_size, 5 * 3)

        image_features = self.resnet_features(rgb)
        image_features = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)

        initial_object_features = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        object_features = initial_object_features.unsqueeze(1).repeat(
            1, self.sequence_length,
            1)  # add a dimension for sequence length and then repeat that

        input_embedded = torch.cat([image_features, object_features], dim=-1)
        embedded_sequence, (hidden, cell) = self.lstm_encoder(input_embedded)

        contact_point_embedding = self.contact_point_embed(
            contact_point_as_input).unsqueeze(1).repeat(1, seq_len, 1)
        combined_w_cp = torch.cat([embedded_sequence, contact_point_embedding],
                                  dim=-1)
        forces_prediction = self.force_decoder(combined_w_cp).view(
            batch_size, seq_len, self.number_of_cp,
            3)  # Predict a force per contact point for each image
        forces_prediction = forces_prediction[:, 1:, :, :]

        forces_prediction = forces_prediction.clamp(-1.5, 1.5)

        output = {
            'forces':
            forces_prediction,  # batchsize x seq len x number of cp x 3
        }
        target['object_name'] = input['object_name']

        if not self.train_mode:
            object_name = input['object_name']
            assert len(object_name) == 1
            object_name = object_name[0]

            contact_points = input['contact_points']
            assert contact_points.shape[0] == 1
            contact_points = contact_points.squeeze(0)
            resulting_position = []
            resulting_rotation = []

            env_state = EnvState(object_name=object_name,
                                 rotation=initial_rotation[0],
                                 position=initial_position[0],
                                 velocity=None,
                                 omega=None)
            for seq_ind in range(self.sequence_length - 1):
                force = forces_prediction[0,
                                          seq_ind].view(self.number_of_cp * 3)
                # the initial velocity is whatever it was in the last frame; note that the gradients are not backpropagated here
                env_state, force_success, force_applied = self.environment_layer.apply(
                    self.environment, env_state.toTensor(), force,
                    contact_points)
                env_state = EnvState.fromTensor(env_state)

                resulting_position.append(env_state.position)
                resulting_rotation.append(env_state.rotation)
            resulting_position = torch.stack(resulting_position, dim=0)
            resulting_rotation = torch.stack(resulting_rotation, dim=0)
            resulting_position = resulting_position.unsqueeze(
                0)  # adding batchsize back because we need it in the loss
            resulting_rotation = resulting_rotation.unsqueeze(
                0)  # adding batchsize back because we need it in the loss

            all_keypoints = get_keypoint_projection(
                object_name, resulting_position, resulting_rotation,
                self.all_objects_keypoint_tensor[object_name])
            all_keypoints = all_keypoints.unsqueeze(
                0)  # adding batchsize back because we need it in the loss
            output['keypoints'] = all_keypoints
            output['rotation'] = resulting_rotation
            output['position'] = resulting_position

            output['force_applied'] = output['forces']
            output['force_direction'] = output['forces']

        return output, target
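
Finally, the force decoder in this last example embeds the contact points once, tiles that embedding along the time axis, concatenates it with the per-frame LSTM features, and decodes a force per contact point for every frame before clamping and dropping the first step. A compact sketch of that tiling-and-decoding pattern; all layer sizes are illustrative assumptions.

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size, num_cp = 1, 10, 512, 5

embedded_sequence = torch.rand(batch_size, seq_len, hidden_size)   # LSTM encoder output
contact_points = torch.rand(batch_size, num_cp * 3)

contact_point_embed = nn.Linear(num_cp * 3, 64)
force_decoder = nn.Linear(hidden_size + 64, num_cp * 3)

# embed the contact points once, then repeat the embedding for every frame
cp_embedding = contact_point_embed(contact_points).unsqueeze(1).repeat(1, seq_len, 1)
combined = torch.cat([embedded_sequence, cp_embedding], dim=-1)

forces = force_decoder(combined).view(batch_size, seq_len, num_cp, 3)
forces = forces[:, 1:, :, :].clamp(-1.5, 1.5)   # drop the first frame and cap the range
print(forces.shape)   # torch.Size([1, 9, 5, 3])
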