Ejemplo n.º 1
0
    def eval_at_state(self, model_options, frame):
        """Build a ``ModelInput`` from one raw frame and run the model on it.

        The frame is resized with ``resnet_input_transform(frame, 224)``, given a
        leading batch dimension, passed through ``self.preprocess_frame`` and then
        ``self.resnet18``; the resulting features become ``state``. The recurrent
        hidden state, the target GloVe embedding (moved with ``gpuify`` to
        ``self.gpu_id``) and the previous action probabilities are attached too.

        Args:
            model_options: passed through unchanged to ``self.model.forward``.
            frame: raw observation accepted by ``resnet_input_transform``.

        Returns:
            Tuple ``(model_input, output)`` where ``output`` is the result of
            ``self.model.forward(model_input, model_options)``.
        """
        inputs = ModelInput()

        # Add a leading batch axis; presumably yields (1, 3, 224, 224) for
        # the ResNet-18 backbone — TODO confirm against resnet_input_transform.
        # NOTE(review): the frame is already resized here before being handed to
        # self.preprocess_frame; if that method resizes/unsqueezes again, the
        # input is transformed twice — verify.
        batched = resnet_input_transform(frame, 224).unsqueeze(0)
        inputs.state = self.resnet18(self.preprocess_frame(batched))

        inputs.hidden = self.hidden
        glove = torch.Tensor(self.target_glove_embedding)
        inputs.target_class_embedding = gpuify(glove, gpu_id=self.gpu_id)
        inputs.action_probs = self.last_action_probs

        output = self.model.forward(inputs, model_options)
        return inputs, output
Ejemplo n.º 2
0
 def preprocess_frame(self, frame):
     """Turn a raw frame into a batched tensor ready for the model.

     The frame is resized via ``resnet_input_transform(frame, 84)``, wrapped in
     a ``torch.Tensor``, given a leading batch dimension of size 1, and handed to
     ``gpuify`` together with ``self.gpu_id``.
     """
     resized = resnet_input_transform(frame, 84)
     batched = torch.Tensor(resized).unsqueeze(0)
     return gpuify(batched, self.gpu_id)