Beispiel #1
0
    def forward(self, input_dict, target):
        """Predict contact points and forces, then roll the forces out step by step.

        A contact-point tower and a force tower both consume the ResNet
        features of ``input_dict['rgb']`` together with an embedding of the
        initial pose.  The contact-point prediction of the last time step is
        kept and fed, with the previous object state, to an LSTM-cell decoder
        that emits one force vector per transition.  Every force is pushed
        through ``self.environment_layer`` to obtain the next object state.

        Only one sample with one object is supported (see the asserts and the
        batch dimension of 1 used in several ``view`` calls).

        Args:
            input_dict: Mapping providing ``'initial_position'``,
                ``'initial_rotation'``, ``'rgb'`` (a 5-D tensor whose first two
                dimensions are read as batch size and sequence length) and
                ``'object_name'`` (a length-1 sequence).
            target: Ground-truth mapping.  It is mutated in place:
                ``target['object_name']`` is set from ``input_dict``.

        Returns:
            A ``(output_dict, target)`` tuple.  ``output_dict`` contains
            ``'keypoints'``, ``'rotation'``, ``'position'``,
            ``'force_success_flag'``, ``'force_applied'``,
            ``'force_direction'`` and ``'contact_points'``.
        """
        initial_position = input_dict['initial_position']
        initial_rotation = input_dict['initial_rotation']
        rgb = input_dict['rgb']
        # Unpacking five names also enforces that rgb is 5-D.
        batch_size, seq_len, _, _, _ = rgb.shape
        object_name = input_dict['object_name']
        assert len(object_name) == 1  # only support one object
        object_name = object_name[0]

        image_features = self.resnet_features(rgb)

        # Contact point prediction tower
        image_features_contact_point = self.contact_point_image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)
        initial_object_features_contact_point = \
            self.contact_point_input_object_embed(
                torch.cat([initial_position, initial_rotation], dim=-1))
        # add a dimension for sequence length and then repeat that
        object_features_contact_point = \
            initial_object_features_contact_point.unsqueeze(1).repeat(
                1, self.sequence_length, 1)

        # Predict contact point for each image and get the last prediction
        input_embedded_contact_point = torch.cat(
            [image_features_contact_point, object_features_contact_point],
            dim=-1)
        embedded_sequence_contact_point, (
            _, _) = self.contact_point_encoder(input_embedded_contact_point)
        contact_points_prediction = self.contact_point_decoder(
            embedded_sequence_contact_point).view(batch_size, seq_len,
                                                  self.number_of_cp,
                                                  3)[:, -1, :, :]

        # Force prediction tower
        image_features_force = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)
        initial_object_features_force = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        # add a dimension for sequence length and then repeat that
        object_features_force = initial_object_features_force.unsqueeze(
            1).repeat(1, self.sequence_length, 1)

        input_embedded_force = torch.cat(
            [image_features_force, object_features_force], dim=-1)
        embedded_sequence_force, (
            hidden_force, cell_force) = self.lstm_encoder(input_embedded_force)

        # num_layers, num direction, batchsize, hidden and then take the
        # last layer; the result seeds the decoder cell.
        hn = hidden_force.view(self.num_layers, 1, 1,
                               self.hidden_size)[-1, -1, :, :]
        cn = cell_force.view(self.num_layers, 1, 1,
                             self.hidden_size)[-1, -1, :, :]

        resulting_force_success = []
        forces_directions = []
        all_force_applied = []
        resulting_position = []
        resulting_rotation = []

        env_state = EnvState(object_name=object_name,
                             rotation=initial_rotation[0],
                             position=initial_position[0],
                             velocity=None,
                             omega=None)

        # The predicted contact points are constant during the rollout, so
        # flatten them once.  The width is derived from self.number_of_cp so
        # it always agrees with the decoder view above.
        contact_point_as_input = contact_points_prediction.view(
            1, 3 * self.number_of_cp)
        contact_points_flat = contact_point_as_input.squeeze(0)

        for seq_ind in range(self.sequence_length - 1):
            prev_location = env_state.toTensorCoverName().unsqueeze(0)

            prev_state_and_cp = torch.cat(
                [prev_location, contact_point_as_input], dim=-1)
            prev_state_embedded = self.state_embed(prev_state_and_cp)
            # The decoder at step i is conditioned on the encoding of frame i + 1.
            next_state_embedded = embedded_sequence_force[:, seq_ind + 1]
            input_lstm_cell = torch.cat(
                [prev_state_embedded, next_state_embedded], dim=-1)

            (hn, cn) = self.lstm_decoder(input_lstm_cell, (hn, cn))
            force = self.forces_directions_decoder(hn)
            assert force.shape[0] == 1
            force = force.squeeze(0)
            assert force.shape[0] == (self.number_of_cp * 3)

            # initial initial_velocity is whatever it was the last frame,
            # note that the gradients are not backproped here
            env_state, force_success, force_applied = \
                self.environment_layer.apply(self.environment,
                                             env_state.toTensor(), force,
                                             contact_points_flat)

            env_state = EnvState.fromTensor(env_state)

            resulting_position.append(env_state.position)
            resulting_rotation.append(env_state.rotation)
            resulting_force_success.append(force_success)
            forces_directions.append(force.view(self.number_of_cp, 3))
            all_force_applied.append(force_applied)

        # Stack along the step axis, then add the batch axis back because the
        # loss needs it.
        resulting_position = torch.stack(resulting_position,
                                         dim=0).unsqueeze(0)
        resulting_rotation = torch.stack(resulting_rotation,
                                         dim=0).unsqueeze(0)
        resulting_force_success = torch.stack(resulting_force_success,
                                              dim=0).unsqueeze(0)
        forces_directions = torch.stack(forces_directions, dim=0).unsqueeze(0)
        all_force_applied = torch.stack(all_force_applied, dim=0).unsqueeze(0)

        all_keypoints = get_keypoint_projection(
            object_name, resulting_position, resulting_rotation,
            self.all_objects_keypoint_tensor[object_name])
        all_keypoints = all_keypoints.unsqueeze(
            0)  # adding batchsize back because we need it in the loss

        # Repeat the single (last-step) contact point prediction for every
        # frame of the sequence.
        contact_points_prediction = contact_points_prediction.unsqueeze(
            1).repeat(1, seq_len, 1, 1)

        output_dict = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction':
            forces_directions,  # batch size x seq len -1 x number of cp x 3
            'contact_points': contact_points_prediction,
        }

        target['object_name'] = input_dict['object_name']

        return output_dict, target
    def forward(self, input, target):
        """Estimate the initial pose from keypoints, then predict and roll out forces.

        The initial 2-D keypoints (divided by ``DEFAULT_IMAGE_SIZE``) and the
        object's stored keypoint tensor are fed to ``self.predict_initial_pose``
        which yields a 3-value position and a 4-value rotation (the rotation is
        L2-normalised).  The image sequence is then encoded and an LSTM-cell
        decoder emits one force vector per transition; each force is clamped
        to ``[-1, 1]`` and applied through ``self.environment_layer``.

        Only one sample with one object is supported (see the asserts and the
        batch dimension of 1 used in the ``view`` calls).

        Args:
            input: Mapping providing ``'object_name'`` (length-1 sequence),
                ``'initial_keypoint'``, ``'rgb'`` (5-D tensor) and
                ``'contact_points'`` (leading dimension 1).  It is mutated in
                place: the predicted ``'initial_rotation'`` and
                ``'initial_position'`` are stored for visualization.
            target: Ground-truth mapping; ``target['object_name']`` is set
                from ``input``.

        Returns:
            A ``(output, target)`` tuple.  ``output`` contains ``'keypoints'``,
            ``'rotation'``, ``'position'``, ``'force_success_flag'``,
            ``'force_applied'`` and ``'force_direction'``.
        """
        object_name = input['object_name']
        assert len(object_name) == 1
        object_name = object_name[0]

        # The views below hard-code 10 keypoints per object.
        initial_keypoint = input['initial_keypoint']
        initial_keypoint = initial_keypoint / DEFAULT_IMAGE_SIZE
        initial_keypoint = initial_keypoint.view(1, 2 * 10)

        object_points = self.all_objects_keypoint_tensor[object_name].view(
            1, 3 * 10)

        rgb = input['rgb']
        batched_contact_points = input['contact_points']
        assert batched_contact_points.shape[0] == 1
        # Un-batched version handed to the environment layer.
        contact_points = batched_contact_points.squeeze(0)

        # Batchsize , 3 * number of contact points
        contact_point_as_input = batched_contact_points.view(
            1, 3 * self.number_of_cp)
        # Unpacking five names also enforces that rgb is 5-D.
        batch_size, seq_len, _, _, _ = rgb.shape

        image_features = self.resnet_features(rgb)
        image_features = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)

        initial_state = self.predict_initial_pose(
            torch.cat([initial_keypoint, object_points],
                      dim=-1)).view(1, 3 + 4)
        initial_position = initial_state[:, :3]
        initial_rotation = initial_state[:, 3:]
        initial_rotation = F.normalize(initial_rotation, dim=-1)

        # Just for visualization purposes
        input['initial_rotation'] = initial_rotation
        input['initial_position'] = initial_position

        initial_object_features = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        # add a dimension for sequence length and then repeat that
        object_features = initial_object_features.unsqueeze(1).repeat(
            1, self.sequence_length, 1)

        input_embedded = torch.cat([image_features, object_features], dim=-1)
        embedded_sequence, (hidden, cell) = self.lstm_encoder(input_embedded)

        # num_layers, num direction, batchsize, hidden and then take the
        # last layer; the result seeds the decoder cell.
        hn = hidden.view(self.num_layers, 1, 1,
                         self.hidden_size)[-1, -1, :, :]
        cn = cell.view(self.num_layers, 1, 1,
                       self.hidden_size)[-1, -1, :, :]

        resulting_force_success = []
        forces_directions = []
        all_force_applied = []
        resulting_position = []
        resulting_rotation = []

        env_state = EnvState(object_name=object_name,
                             rotation=initial_rotation[0],
                             position=initial_position[0],
                             velocity=None,
                             omega=None)

        for seq_ind in range(self.sequence_length - 1):
            prev_location = env_state.toTensorCoverName().unsqueeze(0)

            prev_state_and_cp = torch.cat(
                [prev_location, contact_point_as_input], dim=-1)
            prev_state_embedded = self.state_embed(prev_state_and_cp)
            # The decoder at step i is conditioned on the encoding of frame i + 1.
            next_state_embedded = embedded_sequence[:, seq_ind + 1]
            input_lstm_cell = torch.cat(
                [prev_state_embedded, next_state_embedded], dim=-1)

            (hn, cn) = self.lstm_decoder(input_lstm_cell, (hn, cn))
            force = self.forces_directions_decoder(hn)
            assert force.shape[0] == 1
            force = force.squeeze(0)
            assert force.shape[0] == (self.number_of_cp * 3)

            # Cap forces
            # think about this more
            force = force.clamp(-1., 1.)

            # initial initial_velocity is whatever it was the last frame,
            # note that the gradients are not backproped here
            env_state, force_success, force_applied = \
                self.environment_layer.apply(self.environment,
                                             env_state.toTensor(), force,
                                             contact_points)
            env_state = EnvState.fromTensor(env_state)

            resulting_position.append(env_state.position)
            resulting_rotation.append(env_state.rotation)
            resulting_force_success.append(force_success)
            forces_directions.append(force.view(self.number_of_cp, 3))
            all_force_applied.append(force_applied)

        # Stack along the step axis, then add the batch axis back because the
        # loss needs it.
        resulting_position = torch.stack(resulting_position,
                                         dim=0).unsqueeze(0)
        resulting_rotation = torch.stack(resulting_rotation,
                                         dim=0).unsqueeze(0)
        resulting_force_success = torch.stack(resulting_force_success,
                                              dim=0).unsqueeze(0)
        forces_directions = torch.stack(forces_directions, dim=0).unsqueeze(0)
        all_force_applied = torch.stack(all_force_applied, dim=0).unsqueeze(0)

        all_keypoints = get_keypoint_projection(
            object_name, resulting_position, resulting_rotation,
            self.all_objects_keypoint_tensor[object_name])
        all_keypoints = all_keypoints.unsqueeze(
            0)  # adding batchsize back because we need it in the loss

        output = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction':
            forces_directions,  # batch size x seq len -1 x number of cp x 3
        }

        target['object_name'] = input['object_name']

        return output, target
    def forward(self, input, target):
        """Refine the target forces by gradient descent through the environment layer.

        The forces in ``target['forces']`` are copied into a trainable
        parameter and optimised with SGD for a fixed number of iterations.  In
        every iteration the forces are rolled out through
        ``self.environment_layer`` from the given initial pose, the resulting
        keypoints are scored with ``self.this_loss_func`` and one optimizer
        step is taken.

        Only one sample with one object is supported (see the asserts).

        Args:
            input: Mapping providing ``'initial_position'``,
                ``'initial_rotation'``, ``'contact_points'`` (leading
                dimension 1) and ``'object_name'`` (length-1 sequence).
            target: Ground-truth mapping providing ``'forces'``;
                ``target['object_name']`` is set from ``input``.
                ``target['forces']`` itself is not modified.

        Returns:
            A ``(output, target)`` tuple.  ``output`` contains ``'keypoints'``
            (a detached copy that requires grad), ``'rotation'``,
            ``'position'``, ``'force_success_flag'``, ``'force_applied'`` and
            ``'force_direction'`` taken from the last optimisation iteration.
        """
        initial_position = input['initial_position']
        initial_rotation = input['initial_rotation']
        contact_points = input['contact_points']
        assert contact_points.shape[0] == 1
        contact_points = contact_points.squeeze(0)

        object_name = input['object_name']
        assert len(object_name) == 1
        object_name = object_name[0]
        target['object_name'] = input['object_name']

        # detach() alone would share storage with target['forces'], and the
        # in-place optimizer updates below would then overwrite the ground
        # truth.  clone() gives the optimisation its own buffer.
        predefined_force = target['forces'].squeeze(0).view(
            self.sequence_length - 1,
            self.number_of_cp * 3).detach().clone()

        loss_function = self.this_loss_func

        if self.gpu_ids != -1:
            predefined_force = predefined_force.cuda()
            loss_function = loss_function.cuda()

        # Wrap as a Parameter only after the device move: calling
        # .cuda().detach() on a Parameter yields a tensor that does not
        # require grad, which leaves the optimizer with nothing to update.
        all_forces = torch.nn.Parameter(predefined_force)
        optimizer = torch.optim.SGD([all_forces], lr=self.base_lr)

        number_of_states = 20  # number of optimisation iterations

        # Bound up front so the schedule below can never leave lr undefined.
        lr = self.base_lr

        for t in range(number_of_states):
            # Step schedule: base rate, then x0.1, then x0.01.  Past
            # 3 * step_size no branch matches and the previous rate is kept.
            if t <= self.step_size:
                lr = self.base_lr
            elif t <= self.step_size * 2:
                lr = self.base_lr * 0.1
            elif t <= self.step_size * 3:
                lr = self.base_lr * 0.01

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # Every iteration restarts the rollout from the initial pose.
            env_state = EnvState(object_name=object_name,
                                 rotation=initial_rotation[0],
                                 position=initial_position[0],
                                 velocity=None,
                                 omega=None)
            resulting_force_success = []
            forces_directions = []
            all_force_applied = []
            resulting_position = []
            resulting_rotation = []

            for seq_ind in range(self.sequence_length - 1):
                force = all_forces[seq_ind]

                assert force.shape[0] == (self.number_of_cp * 3)

                # initial initial_velocity is whatever it was the last frame,
                # note that the gradients are not backproped here
                env_state, force_success, force_applied = \
                    self.environment_layer.apply(self.environment,
                                                 env_state.toTensor(), force,
                                                 contact_points)
                env_state = EnvState.fromTensor(env_state)

                resulting_position.append(env_state.position)
                resulting_rotation.append(env_state.rotation)
                resulting_force_success.append(force_success)
                forces_directions.append(force.view(self.number_of_cp, 3))
                all_force_applied.append(force_applied)

            # Stack along the step axis, then add the batch axis back because
            # the loss needs it.
            resulting_position = torch.stack(resulting_position,
                                             dim=0).unsqueeze(0)
            resulting_rotation = torch.stack(resulting_rotation,
                                             dim=0).unsqueeze(0)

            all_keypoints = get_keypoint_projection(
                object_name, resulting_position, resulting_rotation,
                self.all_objects_keypoint_tensor[object_name])
            all_keypoints = all_keypoints.unsqueeze(
                0)  # adding batchsize back because we need it in the loss

            output = {
                'keypoints': all_keypoints,
            }

            loss = loss_function(output, target)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # The per-step lists still hold the values of the last iteration.
        resulting_force_success = torch.stack(resulting_force_success,
                                              dim=0).unsqueeze(0)
        forces_directions = torch.stack(forces_directions, dim=0).unsqueeze(0)
        all_force_applied = torch.stack(all_force_applied, dim=0).unsqueeze(0)

        # Detached copy that requires grad; equivalent to the copy-constructing
        # torch.tensor(tensor, requires_grad=True) but without its warning.
        all_keypoints = all_keypoints.detach().clone().requires_grad_(True)

        output = {
            'keypoints': all_keypoints,
            'rotation': resulting_rotation,
            'position': resulting_position,
            'force_success_flag': resulting_force_success,
            'force_applied': all_force_applied,
            'force_direction':
            forces_directions,  # batch size x seq len -1 x number of cp x 3
        }

        return output, target
    def forward(self, input, target):
        """Regress per-step forces directly; outside train mode also roll them out.

        The image sequence and the initial pose are encoded with
        ``self.lstm_encoder``; every encoded step is concatenated with an
        embedding of the contact points and decoded into one force per contact
        point.  The prediction for the first frame is dropped and the rest is
        clamped to ``[-1.5, 1.5]``.

        When ``self.train_mode`` is false the predicted forces are additionally
        applied through ``self.environment_layer`` to produce positions,
        rotations and projected keypoints; that path supports only one sample
        with one object (see the asserts).

        Args:
            input: Mapping providing ``'initial_position'``,
                ``'initial_rotation'``, ``'rgb'`` (5-D tensor whose first two
                dimensions are read as batch size and sequence length),
                ``'contact_points'`` and ``'object_name'``.
            target: Ground-truth mapping; ``target['object_name']`` is set
                from ``input``.

        Returns:
            A ``(output, target)`` tuple.  ``output['forces']`` has shape
            ``(batch, seq_len - 1, number_of_cp, 3)``.  Outside train mode
            ``output`` also contains ``'keypoints'``, ``'rotation'``,
            ``'position'``, and ``'force_applied'`` / ``'force_direction'``
            (both aliases of ``output['forces']``).
        """
        initial_position = input['initial_position']
        initial_rotation = input['initial_rotation']
        rgb = input['rgb']
        # Unpacking five names also enforces that rgb is 5-D.
        batch_size, seq_len, _, _, _ = rgb.shape
        # Flatten to (batch, 3 * number of contact points); sized from
        # self.number_of_cp so it agrees with the decoder view below.
        contact_point_as_input = input['contact_points'].view(
            batch_size, self.number_of_cp * 3)

        image_features = self.resnet_features(rgb)
        image_features = self.image_embed(
            image_features.view(batch_size * seq_len, 512, 7,
                                7)).view(batch_size, seq_len, 64 * 7 * 7)

        initial_object_features = self.input_object_embed(
            torch.cat([initial_position, initial_rotation], dim=-1))
        # add a dimension for sequence length and then repeat that
        object_features = initial_object_features.unsqueeze(1).repeat(
            1, self.sequence_length, 1)

        input_embedded = torch.cat([image_features, object_features], dim=-1)
        embedded_sequence, _ = self.lstm_encoder(input_embedded)

        contact_point_embedding = self.contact_point_embed(
            contact_point_as_input).unsqueeze(1).repeat(1, seq_len, 1)
        combined_w_cp = torch.cat([embedded_sequence, contact_point_embedding],
                                  dim=-1)
        # One force per contact point for each image.
        forces_prediction = self.force_decoder(combined_w_cp).view(
            batch_size, seq_len, self.number_of_cp, 3)
        # Drop the first frame: forces are only kept for the
        # seq_len - 1 remaining steps.
        forces_prediction = forces_prediction[:, 1:, :, :]

        forces_prediction = forces_prediction.clamp(-1.5, 1.5)

        output = {
            'forces':
            forces_prediction,  # batchsize x seq len - 1 x number of cp x 3
        }
        target['object_name'] = input['object_name']

        if not self.train_mode:
            object_name = input['object_name']
            assert len(object_name) == 1
            object_name = object_name[0]

            contact_points = input['contact_points']
            assert contact_points.shape[0] == 1
            contact_points = contact_points.squeeze(0)
            resulting_position = []
            resulting_rotation = []

            env_state = EnvState(object_name=object_name,
                                 rotation=initial_rotation[0],
                                 position=initial_position[0],
                                 velocity=None,
                                 omega=None)
            for seq_ind in range(self.sequence_length - 1):
                force = forces_prediction[0,
                                          seq_ind].view(self.number_of_cp * 3)
                # initial initial_velocity is whatever it was the last frame,
                # note that the gradients are not backproped here
                env_state, _, _ = self.environment_layer.apply(
                    self.environment, env_state.toTensor(), force,
                    contact_points)
                env_state = EnvState.fromTensor(env_state)

                resulting_position.append(env_state.position)
                resulting_rotation.append(env_state.rotation)

            # Stack along the step axis, then add the batch axis back because
            # the loss needs it.
            resulting_position = torch.stack(resulting_position,
                                             dim=0).unsqueeze(0)
            resulting_rotation = torch.stack(resulting_rotation,
                                             dim=0).unsqueeze(0)

            all_keypoints = get_keypoint_projection(
                object_name, resulting_position, resulting_rotation,
                self.all_objects_keypoint_tensor[object_name])
            all_keypoints = all_keypoints.unsqueeze(
                0)  # adding batchsize back because we need it in the loss
            output['keypoints'] = all_keypoints
            output['rotation'] = resulting_rotation
            output['position'] = resulting_position

            # The predicted forces are reported under both of these keys.
            output['force_applied'] = output['forces']
            output['force_direction'] = output['forces']

        return output, target