def forward(self, input_dict, target):
    """Predict contact points and per-step forces from RGB, then roll out physics.

    Two towers share the ``resnet_features`` backbone output:

    * the contact-point tower embeds the image sequence together with the
      initial pose, runs ``contact_point_encoder``/``contact_point_decoder``
      and keeps only the prediction made at the last frame;
    * the force tower encodes the same inputs with ``lstm_encoder`` and then,
      step by step, decodes a force vector with ``lstm_decoder`` +
      ``forces_directions_decoder`` and applies it through
      ``environment_layer`` to advance the object state.

    Only a single object / batch size 1 is supported: ``object_name`` must
    hold exactly one entry, and the hidden-state and ``view(1, ...)``
    reshapes below assume a batch of one.

    Args:
        input_dict: dict with keys ``'initial_position'``,
            ``'initial_rotation'``, ``'rgb'`` (5-D: batch, seq, c, w, h) and
            ``'object_name'`` (a sequence with exactly one name).
        target: dict; its ``'object_name'`` entry is overwritten with
            ``input_dict['object_name']``.

    Returns:
        ``(output_dict, target)`` where ``output_dict`` has keys
        ``'keypoints'``, ``'rotation'``, ``'position'``,
        ``'force_success_flag'``, ``'force_applied'``, ``'force_direction'``
        (1 x (sequence_length - 1) x number_of_cp x 3) and
        ``'contact_points'`` (last-frame prediction repeated ``seq_len``
        times along dim 1).

    Raises:
        AssertionError: if ``object_name`` does not hold exactly one entry,
            or a decoded force does not have ``number_of_cp * 3`` values.
    """
    initial_position = input_dict['initial_position']
    initial_rotation = input_dict['initial_rotation']
    rgb = input_dict['rgb']
    batch_size, seq_len, _, _, _ = rgb.shape
    object_name = input_dict['object_name']
    assert len(object_name) == 1  # only support one object
    object_name = object_name[0]

    image_features = self.resnet_features(rgb)
    # Both towers consume the backbone features as (batch * seq, 512, 7, 7).
    flat_image_features = image_features.view(batch_size * seq_len, 512, 7, 7)

    # --- Contact point prediction tower ---
    image_features_contact_point = self.contact_point_image_embed(
        flat_image_features).view(batch_size, seq_len, 64 * 7 * 7)
    initial_object_features_contact_point = self.contact_point_input_object_embed(
        torch.cat([initial_position, initial_rotation], dim=-1))
    # Add a sequence dimension and tile the initial-pose embedding over it.
    object_features_contact_point = initial_object_features_contact_point.unsqueeze(
        1).repeat(1, self.sequence_length, 1)
    input_embedded_contact_point = torch.cat(
        [image_features_contact_point, object_features_contact_point], dim=-1)
    embedded_sequence_contact_point, _ = self.contact_point_encoder(
        input_embedded_contact_point)
    # A contact-point set is predicted for every frame; keep the last one.
    contact_points_prediction = self.contact_point_decoder(
        embedded_sequence_contact_point).view(
            batch_size, seq_len, self.number_of_cp, 3)[:, -1, :, :]
    # Flattened contact points; constant across the rollout, so built once.
    # Sized from number_of_cp (not a hard-coded 5) to match the decoder above.
    contact_point_as_input = contact_points_prediction.view(
        1, self.number_of_cp * 3)

    # --- Force prediction tower ---
    image_features_force = self.image_embed(flat_image_features).view(
        batch_size, seq_len, 64 * 7 * 7)
    initial_object_features_force = self.input_object_embed(
        torch.cat([initial_position, initial_rotation], dim=-1))
    # Add a sequence dimension and tile the initial-pose embedding over it.
    object_features_force = initial_object_features_force.unsqueeze(1).repeat(
        1, self.sequence_length, 1)
    input_embedded_force = torch.cat(
        [image_features_force, object_features_force], dim=-1)
    embedded_sequence_force, (hidden_force, cell_force) = self.lstm_encoder(
        input_embedded_force)
    # View as (num_layers, num_directions, batch, hidden) and seed the decoder
    # with the final hidden/cell state of the last layer.
    hn = hidden_force.view(self.num_layers, 1, 1, self.hidden_size)[-1, -1]
    cn = cell_force.view(self.num_layers, 1, 1, self.hidden_size)[-1, -1]

    # --- Step-by-step rollout through the environment layer ---
    env_state = EnvState(object_name=object_name,
                         rotation=initial_rotation[0],
                         position=initial_position[0],
                         velocity=None,
                         omega=None)
    resulting_position = []
    resulting_rotation = []
    resulting_force_success = []
    forces_directions = []
    all_force_applied = []
    for seq_ind in range(self.sequence_length - 1):
        prev_location = env_state.toTensorCoverName().unsqueeze(0)
        prev_state_and_cp = torch.cat(
            [prev_location, contact_point_as_input], dim=-1)
        prev_state_embedded = self.state_embed(prev_state_and_cp)
        # Encoder output for the frame this step is moving towards.
        next_state_embedded = embedded_sequence_force[:, seq_ind + 1]
        input_lstm_cell = torch.cat(
            [prev_state_embedded, next_state_embedded], dim=-1)
        hn, cn = self.lstm_decoder(input_lstm_cell, (hn, cn))
        force = self.forces_directions_decoder(hn)
        assert force.shape[0] == 1
        force = force.squeeze(0)
        assert force.shape[0] == (self.number_of_cp * 3)
        # initial velocity is whatever it was the last frame,
        # note that the gradients are not backproped here
        env_state, force_success, force_applied = self.environment_layer.apply(
            self.environment, env_state.toTensor(), force,
            contact_point_as_input.squeeze(0))
        env_state = EnvState.fromTensor(env_state)
        resulting_position.append(env_state.position)
        resulting_rotation.append(env_state.rotation)
        resulting_force_success.append(force_success)
        forces_directions.append(force.view(self.number_of_cp, 3))
        all_force_applied.append(force_applied)

    # Stack over time, then add the batch dimension back (needed by the loss).
    resulting_position = torch.stack(resulting_position, dim=0).unsqueeze(0)
    resulting_rotation = torch.stack(resulting_rotation, dim=0).unsqueeze(0)
    resulting_force_success = torch.stack(
        resulting_force_success, dim=0).unsqueeze(0)
    forces_directions = torch.stack(forces_directions, dim=0).unsqueeze(0)
    all_force_applied = torch.stack(all_force_applied, dim=0).unsqueeze(0)

    all_keypoints = get_keypoint_projection(
        object_name, resulting_position, resulting_rotation,
        self.all_objects_keypoint_tensor[object_name])
    # adding batchsize back because we need it in the loss
    all_keypoints = all_keypoints.unsqueeze(0)

    # Report the single last-frame contact prediction for every frame.
    contact_points_prediction = contact_points_prediction.unsqueeze(1).repeat(
        1, seq_len, 1, 1)

    output_dict = {
        'keypoints': all_keypoints,
        'rotation': resulting_rotation,
        'position': resulting_position,
        'force_success_flag': resulting_force_success,
        'force_applied': all_force_applied,
        'force_direction': forces_directions,  # batch size x seq len -1 x number of cp x 3
        'contact_points': contact_points_prediction,
    }
    target['object_name'] = input_dict['object_name']
    return output_dict, target
def forward(self, input, target):
    """Predict an initial pose from keypoints, then roll out predicted forces.

    The initial pose is regressed by ``predict_initial_pose`` from the
    normalized 2-D ``initial_keypoint`` and the object's stored keypoints. The
    RGB sequence and that pose are encoded with ``lstm_encoder``. Then, step by
    step, a force vector is decoded with ``lstm_decoder`` +
    ``forces_directions_decoder``, clamped to [-1, 1] and applied at the
    caller-supplied contact points through ``environment_layer``.

    Only a single object / batch size 1 is supported (``object_name`` and
    ``contact_points`` must have a leading size of 1, and the ``view(1, ...)``
    reshapes assume a batch of one).

    Args:
        input: dict with keys ``'object_name'`` (one name),
            ``'initial_keypoint'`` (divided by ``DEFAULT_IMAGE_SIZE``, so
            presumably pixel coordinates — confirm), ``'rgb'`` (5-D: batch,
            seq, c, w, h) and ``'contact_points'``. Side effect: the predicted
            ``'initial_rotation'`` and ``'initial_position'`` are written back
            into this dict for visualization.
        target: dict; its ``'object_name'`` entry is overwritten with
            ``input['object_name']``.

    Returns:
        ``(output, target)`` where ``output`` has keys ``'keypoints'``,
        ``'rotation'``, ``'position'``, ``'force_success_flag'``,
        ``'force_applied'`` and ``'force_direction'``
        (1 x (sequence_length - 1) x number_of_cp x 3, after clamping).

    Raises:
        AssertionError: if ``object_name`` or ``contact_points`` does not have
            a leading size of 1, or a decoded force does not have
            ``number_of_cp * 3`` values.
    """
    object_name = input['object_name']
    assert len(object_name) == 1
    object_name = object_name[0]

    # NOTE(review): the keypoint count (10) is hard-coded here and must match
    # predict_initial_pose's input size — confirm against the model config.
    initial_keypoint = input['initial_keypoint'] / DEFAULT_IMAGE_SIZE
    initial_keypoint = initial_keypoint.view(1, 2 * 10)
    object_points = self.all_objects_keypoint_tensor[object_name].view(
        1, 3 * 10)

    rgb = input['rgb']
    contact_points = input['contact_points']
    assert contact_points.shape[0] == 1
    contact_points = contact_points.squeeze(0)
    # Batchsize x (3 * number of contact points); sized from number_of_cp
    # (not a hard-coded 5) to agree with the force checks below.
    contact_point_as_input = input['contact_points'].view(
        1, self.number_of_cp * 3)

    batch_size, seq_len, _, _, _ = rgb.shape
    image_features = self.resnet_features(rgb)
    image_features = self.image_embed(
        image_features.view(batch_size * seq_len, 512, 7, 7)).view(
            batch_size, seq_len, 64 * 7 * 7)

    # First 3 values are the position, the remaining 4 the rotation
    # (unit-normalized 4-vector, presumably a quaternion).
    initial_state = self.predict_initial_pose(
        torch.cat([initial_keypoint, object_points], dim=-1)).view(1, 3 + 4)
    initial_position = initial_state[:, :3]
    initial_rotation = F.normalize(initial_state[:, 3:], dim=-1)

    # Just for visualization purposes
    input['initial_rotation'] = initial_rotation
    input['initial_position'] = initial_position

    initial_object_features = self.input_object_embed(
        torch.cat([initial_position, initial_rotation], dim=-1))
    # Add a sequence dimension and tile the initial-pose embedding over it.
    object_features = initial_object_features.unsqueeze(1).repeat(
        1, self.sequence_length, 1)
    input_embedded = torch.cat([image_features, object_features], dim=-1)
    embedded_sequence, (hidden, cell) = self.lstm_encoder(input_embedded)
    # View as (num_layers, num_directions, batch, hidden) and seed the decoder
    # with the final hidden/cell state of the last layer.
    hn = hidden.view(self.num_layers, 1, 1, self.hidden_size)[-1, -1]
    cn = cell.view(self.num_layers, 1, 1, self.hidden_size)[-1, -1]

    # --- Step-by-step rollout through the environment layer ---
    env_state = EnvState(object_name=object_name,
                         rotation=initial_rotation[0],
                         position=initial_position[0],
                         velocity=None,
                         omega=None)
    resulting_position = []
    resulting_rotation = []
    resulting_force_success = []
    forces_directions = []
    all_force_applied = []
    for seq_ind in range(self.sequence_length - 1):
        prev_location = env_state.toTensorCoverName().unsqueeze(0)
        prev_state_and_cp = torch.cat(
            [prev_location, contact_point_as_input], dim=-1)
        prev_state_embedded = self.state_embed(prev_state_and_cp)
        # Encoder output for the frame this step is moving towards.
        next_state_embedded = embedded_sequence[:, seq_ind + 1]
        input_lstm_cell = torch.cat(
            [prev_state_embedded, next_state_embedded], dim=-1)
        hn, cn = self.lstm_decoder(input_lstm_cell, (hn, cn))
        force = self.forces_directions_decoder(hn)
        assert force.shape[0] == 1
        force = force.squeeze(0)
        assert force.shape[0] == (self.number_of_cp * 3)
        # Cap forces. TODO: think about this more.
        force = force.clamp(-1., 1.)
        # initial velocity is whatever it was the last frame, note that the
        # gradients are not backproped here.
        # NOTE(review): contact points are passed in the caller's layout minus
        # the leading batch dim — confirm environment_layer expects this.
        env_state, force_success, force_applied = self.environment_layer.apply(
            self.environment, env_state.toTensor(), force, contact_points)
        env_state = EnvState.fromTensor(env_state)
        resulting_position.append(env_state.position)
        resulting_rotation.append(env_state.rotation)
        resulting_force_success.append(force_success)
        forces_directions.append(force.view(self.number_of_cp, 3))
        all_force_applied.append(force_applied)

    # Stack over time, then add the batch dimension back (needed by the loss).
    resulting_position = torch.stack(resulting_position, dim=0).unsqueeze(0)
    resulting_rotation = torch.stack(resulting_rotation, dim=0).unsqueeze(0)
    resulting_force_success = torch.stack(
        resulting_force_success, dim=0).unsqueeze(0)
    forces_directions = torch.stack(forces_directions, dim=0).unsqueeze(0)
    all_force_applied = torch.stack(all_force_applied, dim=0).unsqueeze(0)

    all_keypoints = get_keypoint_projection(
        object_name, resulting_position, resulting_rotation,
        self.all_objects_keypoint_tensor[object_name])
    # adding batchsize back because we need it in the loss
    all_keypoints = all_keypoints.unsqueeze(0)

    output = {
        'keypoints': all_keypoints,
        'rotation': resulting_rotation,
        'position': resulting_position,
        'force_success_flag': resulting_force_success,
        'force_applied': all_force_applied,
        'force_direction': forces_directions,  # batch size x seq len -1 x number of cp x 3
    }
    target['object_name'] = input['object_name']
    return output, target