Example No. 1
    def initialize(self):
        self.shared_network = SharedNetwork().to(self.device)
        # 12 actions: roll, pitch, yaw, surge, sway, heave
        self.scene_networks = {
            key: SceneSpecificNetwork(8).to(self.device)
            for key in TASK_LIST.keys()
        }

        self.shared_network.share_memory()
        for net in self.scene_networks.values():
            net.share_memory()

        parameters = list(self.shared_network.parameters())
        for net in self.scene_networks.values():
            parameters.extend(net.parameters())

        optimizer = SharedRMSprop(parameters,
                                  eps=self.rmsp_epsilon,
                                  alpha=self.rmsp_alpha,
                                  lr=self.learning_rate)
        optimizer.share_memory()

        scheduler = AnnealingLRScheduler(optimizer, self.total_epochs)

        optimizer_wrapper = TrainingOptimizer(self.grad_norm, optimizer,
                                              scheduler)
        self.optimizer = optimizer_wrapper
        optimizer_wrapper.share_memory()

        self.saver = TrainingSaver(self.shared_network, self.scene_networks,
                                   self.optimizer, self.config)
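
The pattern above recurs in every example below: a single SharedNetwork backbone and one SceneSpecificNetwork head per scene, all placed in shared memory so that the asynchronous worker threads can read and update the same parameters. A minimal sketch of how a per-scene policy is assembled from the two parts; it simply mirrors what Training.run() does further down and is an illustration, not additional project API:

# Illustration only: compose the shared backbone with one scene head,
# exactly as Training.run() does before spawning a TrainingThread.
import torch.nn as nn

def build_scene_policy(shared_network, scene_networks, scene):
    policy = nn.Sequential(shared_network, scene_networks[scene])
    policy.share_memory()  # must happen before handing the module to workers
    return policy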
    def __init__(self, config):
        self.config = config
        self.device = config.get('device', torch.device('cuda:0'))
        self.shared_net = SharedNetwork().to(self.device)
        self.scene_nets = {
            key: SceneSpecificNetwork(ACTION_SPACE_SIZE).to(self.device)
            for key in TASK_LIST.keys()
        }
    def __init__(self, config):
        self.config = config
        self.method = config['method']
        gpu_id = get_first_free_gpu(2000)
        self.device = torch.device("cuda:" + str(gpu_id))
        if self.method != "random":
            self.shared_net = SharedNetwork(self.config['method'],
                                            self.config.get('mask_size',
                                                            5)).to(self.device)
            self.scene_net = SceneSpecificNetwork(
                self.config['action_size']).to(self.device)

        self.checkpoints = []
        self.checkpoint_id = 0
        self.saver = None
        self.chk_numbers = None
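
get_first_free_gpu is a project helper that is not shown in these excerpts; judging from the call get_first_free_gpu(2000), it presumably returns the index of the first GPU with at least the requested amount of free memory (in MiB). A hedged sketch of such a helper, using nvidia-smi:

import subprocess

def get_first_free_gpu(min_free_mb):
    # Assumed behaviour: index of the first GPU whose free memory (MiB),
    # as reported by nvidia-smi, is at least min_free_mb.
    out = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=memory.free',
         '--format=csv,noheader,nounits']).decode()
    for idx, line in enumerate(out.strip().splitlines()):
        if int(line) >= min_free_mb:
            return idx
    raise RuntimeError('no GPU with enough free memory')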
Example No. 4
    def initialize(self):
        # Shared network
        self.shared_network = SharedNetwork()
        self.scene_networks = {
            key: SceneSpecificNetwork(4)
            for key in TASK_LIST.keys()
        }

        # Share memory
        self.shared_network.share_memory()
        for net in self.scene_networks.values():
            net.share_memory()

        # Collect all parameters from all networks
        parameters = list(self.shared_network.parameters())
        for net in self.scene_networks.values():
            parameters.extend(net.parameters())

        # Create optimizer
        optimizer = SharedRMSprop(parameters,
                                  eps=self.rmsp_epsilon,
                                  alpha=self.rmsp_alpha,
                                  lr=self.learning_rate)
        optimizer.share_memory()

        # Create scheduler
        scheduler = AnnealingLRScheduler(optimizer, self.total_epochs)

        # Create optimizer wrapper
        optimizer_wrapper = TrainingOptimizer(self.grad_norm, optimizer,
                                              scheduler)
        self.optimizer = optimizer_wrapper
        optimizer_wrapper.share_memory()

        # Initialize saver
        self.saver = TrainingSaver(self.shared_network, self.scene_networks,
                                   self.optimizer, self.config)
    def __init__(self, id: int, network: torch.nn.Module, saver, optimizer,
                 scene: str, **kwargs):

        super(TrainingThread, self).__init__()

        # Initialize the environment
        self.env = None
        self.init_args = kwargs
        self.scene = scene
        self.saver = saver
        self.local_backbone_network = SharedNetwork()
        self.id = id

        self.master_network = network
        self.optimizer = optimizer
Example No. 6
    def _initialize_thread(self):
        h5_file_path = self.init_args.get('h5_file_path')
        # self.logger = logging.getLogger('agent')
        # self.logger.setLevel(logging.INFO)
        self.init_args['h5_file_path'] = lambda scene: h5_file_path.replace('{scene}', scene)
        self.env = THORDiscreteEnvironment(self.scene, **self.init_args)
        self.gamma: float = self.init_args.get('gamma', 0.99)
        self.grad_norm: float = self.init_args.get('grad_norm', 40.0)
        entropy_beta: float = self.init_args.get('entropy_beta', 0.01)
        self.max_t: int = self.init_args.get('max_t', 1)  # TODO: 5
        self.local_t = 0
        self.action_space_size = self.get_action_space_size()

        self.criterion = ActorCriticLoss(entropy_beta)
        self.policy_network = nn.Sequential(SharedNetwork(), SceneSpecificNetwork(self.get_action_space_size()))

        # Initialize the episode
        self._reset_episode()
        self._sync_network()
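
_reset_episode and _sync_network are not included in this excerpt. Given that the thread keeps a local policy_network and receives the master network from the parent process, synchronisation most plausibly copies the master weights into the local copy before each rollout; a hedged sketch of that step:

    def _sync_network(self):
        # Assumed implementation: refresh the thread-local policy with the
        # latest weights of the shared master network before a rollout.
        self.policy_network.load_state_dict(self.master_network.state_dict())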
Example No. 7
    def _initialize_thread(self):
        # self.logger = logging.getLogger('agent')
        # self.logger.setLevel(logging.INFO)
        #self.init_args['h5_file_path'] = lambda scene: h5_file_path.replace('{scene}', scene)
        #self.env = THORDiscreteEnvironment(self.scene, **self.init_args)
        self.env = HabitatDiscreteEnvironment(
            self.scene_glb, terminal_image=self.terminal_image)
        self.gamma = 0.99
        self.grad_norm = 40.0
        entropy_beta = 0.01
        self.local_t = 0
        self.action_space_size = self.get_action_space_size()

        self.criterion = ActorCriticLoss(entropy_beta)
        self.policy_network = nn.Sequential(
            SharedNetwork(),
            SceneSpecificNetwork(self.get_action_space_size())).to(self.device)
        # Initialize the episode
        self._reset_episode()
        self._sync_network()
Example No. 8
    def __init__(self, id: int, optimizer, device, network: torch.nn.Module,
                 scene_glb: str, saver, max_t, terminal_image):

        super(TrainingThread, self).__init__()

        # Initialize the environment
        self.env = None
        self.scene_glb = scene_glb
        self.saver = saver
        self.max_t = max_t
        self.local_backbone_network = SharedNetwork().to(device)
        self.id = id
        self.terminal_image = terminal_image

        self.master_network = network
        self.optimizer = optimizer
        self.device = device

        self.fig = plt.figure()
        self.time_list = []
        self.reward_list = []
Example No. 9
class Training:
    def __init__(self, device, config):
        self.device = device
        self.config = config
        self.logger: logging.Logger = self._init_logger()
        self.learning_rate = config.get('learning_rate')
        self.rmsp_alpha = config.get('rmsp_alpha')
        self.rmsp_epsilon = config.get('rmsp_epsilon')
        self.grad_norm = config.get('grad_norm', 40.0)
        self.tasks = config.get('tasks', TASK_LIST)
        self.checkpoint_path = config.get('checkpoint_path',
                                          'model/checkpoint-{checkpoint}.pth')
        self.max_t = config.get('max_t', 5)
        self.total_epochs = TOTAL_PROCESSED_FRAMES // self.max_t
        self.initialize()

    @staticmethod
    def load_checkpoint(config, fail=True):
        device = torch.device('cpu')
        checkpoint_path = config.get('checkpoint_path',
                                     'model/checkpoint-{checkpoint}.pth')
        max_t = config.get('max_t', 5)
        total_epochs = TOTAL_PROCESSED_FRAMES // max_t
        files = os.listdir(os.path.dirname(checkpoint_path))
        base_name = os.path.basename(checkpoint_path)

        # Find latest checkpoint
        # TODO: improve speed
        restore_point = None
        if base_name.find('{checkpoint}') != -1:
            regex = re.escape(base_name).replace(re.escape('{checkpoint}'),
                                                 r'(\d+)')
            points = [(fname, int(match.group(1))) for (fname, match) in ((
                fname,
                re.match(regex, fname),
            ) for fname in files) if match is not None]
            if len(points) == 0:
                if fail:
                    raise Exception('Restore point not found')
                else:
                    return None

            (base_name, restore_point) = max(points, key=lambda x: x[1])

        print(f'Restoring from checkpoint {restore_point}')
        state = torch.load(
            open(os.path.join(os.path.dirname(checkpoint_path), base_name),
                 'rb'))
        training = Training(device,
                            state['config'] if 'config' in state else config)
        training.saver.restore(state)
        return training

    def initialize(self):
        # Shared network
        self.shared_network = SharedNetwork()
        self.scene_networks = {
            key: SceneSpecificNetwork(4)
            for key in TASK_LIST.keys()
        }

        # Share memory
        self.shared_network.share_memory()
        for net in self.scene_networks.values():
            net.share_memory()

        # Collect all parameters from all networks
        parameters = list(self.shared_network.parameters())
        for net in self.scene_networks.values():
            parameters.extend(net.parameters())

        # Create optimizer
        optimizer = SharedRMSprop(parameters,
                                  eps=self.rmsp_epsilon,
                                  alpha=self.rmsp_alpha,
                                  lr=self.learning_rate)
        optimizer.share_memory()

        # Create scheduler
        scheduler = AnnealingLRScheduler(optimizer, self.total_epochs)

        # Create optimizer wrapper
        optimizer_wrapper = TrainingOptimizer(self.grad_norm, optimizer,
                                              scheduler)
        self.optimizer = optimizer_wrapper
        optimizer_wrapper.share_memory()

        # Initialize saver
        self.saver = TrainingSaver(self.shared_network, self.scene_networks,
                                   self.optimizer, self.config)

    def run(self):
        self.logger.info("Training started")
        self.print_parameters()

        # Prepare threads
        branches = [(scene, int(target)) for scene in TASK_LIST.keys()
                    for target in TASK_LIST.get(scene)]

        def _createThread(id, task):
            (scene, target) = task
            net = nn.Sequential(self.shared_network,
                                self.scene_networks[scene])
            net.share_memory()
            return TrainingThread(id=id,
                                  optimizer=self.optimizer,
                                  network=net,
                                  scene=scene,
                                  saver=self.saver,
                                  max_t=self.max_t,
                                  terminal_state_id=target,
                                  **self.config)

        self.threads = [
            _createThread(i, task) for i, task in enumerate(branches)
        ]

        for thread in self.threads:
            thread.start()

        for thread in self.threads:
            thread.join()

    def _init_logger(self):
        logger = logging.getLogger('agent')
        logger.setLevel(logging.INFO)
        logger.addHandler(logging.StreamHandler(sys.stdout))
        return logger

    def print_parameters(self):
        self.logger.info(f"- gamma: {self.config.get('gamma')}")
        self.logger.info(
            f"- learning rate: {self.config.get('learning_rate')}")
class Evaluation:
    def __init__(self, config):
        self.config = config
        self.device = config.get('device', torch.device('cuda:0'))
        self.shared_net = SharedNetwork().to(self.device)
        self.scene_nets = {
            key: SceneSpecificNetwork(ACTION_SPACE_SIZE).to(self.device)
            for key in TASK_LIST.keys()
        }

    @staticmethod
    def load_checkpoint(config, fail=True):
        checkpoint_path = config.get('checkpoint_path',
                                     'model/checkpoint-{checkpoint}.pth')

        (base_name, restore_point) = find_restore_point(checkpoint_path, fail)
        print(f'Restoring from checkpoint {restore_point}')
        state = torch.load(
            open(os.path.join(os.path.dirname(checkpoint_path), base_name),
                 'rb'))
        evaluation = Evaluation(config)
        saver = TrainingSaver(evaluation.shared_net, evaluation.scene_nets,
                              None, evaluation.config)
        print('Configuration')
        saver.restore(state)
        saver.print_config(offset=4)
        return evaluation

    def build_agent(self, scene_name):
        parent = self
        net = torch.nn.Sequential(parent.shared_net,
                                  parent.scene_nets[scene_name])

        class Agent:
            def __init__(self, initial_state, target):
                # Build the environment for the scene chosen in build_agent.
                self.env = HabitatDiscreteEnvironment(scene_name)
                self.env.reset()
                self.net = net

            @staticmethod
            def get_parameters():
                return net.parameters()

            def act(self):
                with torch.no_grad():
                    state = torch.Tensor(self.env.render()).to(parent.device)
                    target = torch.Tensor(self.env.render_target()).to(
                        parent.device)
                    (
                        policy,
                        value,
                    ) = net.forward((
                        state,
                        target,
                    ))
                    action = F.softmax(
                        policy, dim=0).multinomial(1).cpu().data.numpy()[0]

                self.env.step(action)
                return (self.env.is_terminal, self.env.collided,
                        self.env.reward)

        return Agent

    def run(self):
        scene_stats = dict()
        resultData = []

        for scene_scope, image in TASK_LIST.items():
            scene_net = self.scene_nets[scene_scope]
            scene_stats[scene_scope] = list()
            env = HabitatDiscreteEnvironment(scene_scope, terminal_image=image)
            ep_rewards = []
            ep_lengths = []
            ep_collisions = []
            ep_normalized_lengths = []

            env.reset()
            terminal = False
            ep_reward = 0
            ep_collision = 0
            ep_t = 0

            while not terminal:
                state = torch.Tensor(env.render()).permute(0, 3, 1,
                                                           2).to(self.device)

                target = torch.Tensor(env.render_target()).permute(
                    0, 3, 1, 2).to(self.device)

                (
                    policy,
                    value,
                ) = scene_net.forward(
                    self.shared_net.forward((
                        state,
                        target,
                    )))

                with torch.no_grad():
                    action = F.softmax(
                        policy, dim=0).multinomial(1).cpu().data.numpy()[0]
                print("Applied action: ", action)
                env.step(action)
                terminal = env.is_terminal

                ep_reward += env.reward
                ep_t += 1

            ep_lengths.append(ep_t)
            ep_rewards.append(ep_reward)
            ep_collisions.append(ep_collision)
            ep_normalized_lengths.append(ep_t)
            if VERBOSE:
                print("episode ends after {} steps".format(ep_t))

            print('evaluation: %s' % (scene_scope))
            print('mean episode reward: %.2f' % np.mean(ep_rewards))
            print('mean episode length: %.2f' % np.mean(ep_lengths))
            print('mean episode collision: %.2f' % np.mean(ep_collisions))
            print('mean normalized episode length: %.2f' %
                  np.mean(ep_normalized_lengths))
            scene_stats[scene_scope].extend(ep_lengths)
            resultData.append((
                scene_scope,
                np.mean(ep_rewards),
                np.mean(ep_lengths),
                np.mean(ep_collisions),
                np.mean(ep_normalized_lengths),
            ))

        print('\nResults (average trajectory length):')
        for scene_scope in scene_stats:
            print('%s: %.2f steps' %
                  (scene_scope, np.mean(scene_stats[scene_scope])))

        if 'csv_file' in self.config and self.config['csv_file'] is not None:
            export_to_csv(resultData, self.config['csv_file'])
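
build_agent returns a class rather than an instance, so callers control when the environment is created. An illustrative use, assuming evaluation was obtained from Evaluation.load_checkpoint(config); the scene name comes from TASK_LIST, while initial_state and target are placeholders that this Habitat-based Agent does not use:

# Hypothetical usage of the Agent factory above.
AgentClass = evaluation.build_agent('bathroom_02')
agent = AgentClass(initial_state=0, target=26)
terminal, collided, reward = agent.act()  # one sampled action step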
Example No. 11
class Evaluation:
    def __init__(self, config):
        self.config = config
        self.device = config.get('device', torch.device('cpu'))
        self.shared_net = SharedNetwork().to(self.device)
        self.scene_nets = { key:SceneSpecificNetwork(ACTION_SPACE_SIZE).to(self.device) for key in TASK_LIST.keys() }

    @staticmethod
    def load_checkpoint(config, fail=True):
        checkpoint_path = config.get('checkpoint_path', 'model/checkpoint-{checkpoint}.pth')
        
        (base_name, restore_point) = find_restore_point(checkpoint_path, fail)
        print(f'Restoring from checkpoint {restore_point}')
        state = torch.load(open(os.path.join(os.path.dirname(checkpoint_path), base_name), 'rb'))
        evaluation = Evaluation(config)
        saver = TrainingSaver(evaluation.shared_net, evaluation.scene_nets, None, evaluation.config)
        print('Configuration')
        saver.restore(state)
        saver.print_config(offset=4)
        return evaluation

    def build_agent(self, scene_name):
        parent = self
        net = torch.nn.Sequential(parent.shared_net, parent.scene_nets[scene_name])
        class Agent:
            def __init__(self, initial_state, target):
                self.env = THORDiscreteEnvironment(
                    scene_name=scene_name,
                    initial_state_id=initial_state,
                    terminal_state_id=target,
                    h5_file_path=(lambda scene: parent.config["h5_file_path"].replace("{scene}", scene_name))
                )

                self.env.reset()
                self.net = net

            @staticmethod
            def get_parameters():
                return net.parameters()

            def act(self):
                with torch.no_grad():
                    state = torch.Tensor(self.env.render(mode='resnet_features')).to(parent.device)
                    target = torch.Tensor(self.env.render_target(mode='resnet_features')).to(parent.device)
                    (policy, value,) = net.forward((state, target,))
                    action = F.softmax(policy, dim=0).multinomial(1).cpu().data.numpy()[0]

                self.env.step(action)
                return (self.env.is_terminal, self.env.collided, self.env.reward)
        return Agent

    def run(self):
        scene_stats = dict()
        resultData = []
        for scene_scope, items in TASK_LIST.items():
            if len(self.config['test_scenes']) != 0 and not scene_scope in self.config['test_scenes']:
                continue

            scene_net = self.scene_nets[scene_scope]
            scene_stats[scene_scope] = list()
            for task_scope in items:
                env = THORDiscreteEnvironment(
                    scene_name=scene_scope,
                    h5_file_path=(lambda scene: self.config.get("h5_file_path", "D:\\datasets\\visual_navigation_precomputed\\{scene}.h5").replace('{scene}', scene)),
                    terminal_state_id=int(task_scope),
                )

                graph = env._get_graph_handle()
                hitting_times = graph['hitting_times'][()]
                shortest_paths = graph['shortest_path_distance'][()]

                ep_rewards = []
                ep_lengths = []
                ep_collisions = []
                ep_normalized_lengths = []
                for (i_episode, start) in enumerate(env.get_initial_states(int(task_scope))):
                    env.reset(initial_state_id=start)
                    terminal = False
                    ep_reward = 0
                    ep_collision = 0
                    ep_t = 0
                    hitting_time = hitting_times[start, int(task_scope)]
                    shortest_path = shortest_paths[start, int(task_scope)]

                    while not terminal:
                        state = torch.Tensor(env.render(mode='resnet_features'))
                        target = torch.Tensor(env.render_target(mode='resnet_features'))
                        (policy, value,) = scene_net.forward(self.shared_net.forward((state, target,)))

                        with torch.no_grad():
                            action = F.softmax(policy, dim=0).multinomial(1).data.numpy()[0]
                        env.step(action)
                        terminal = env.is_terminal

                        if ep_t == hitting_time: break
                        if env.collided: ep_collision += 1
                        ep_reward += env.reward
                        ep_t += 1                   


                    ep_lengths.append(ep_t)
                    ep_rewards.append(ep_reward)
                    ep_collisions.append(ep_collision)
                    ep_normalized_lengths.append(min(ep_t, hitting_time) / shortest_path)
                    if VERBOSE: print("episode #{} ends after {} steps".format(i_episode, ep_t))

                    
                print('evaluation: %s %s' % (scene_scope, task_scope))
                print('mean episode reward: %.2f' % np.mean(ep_rewards))
                print('mean episode length: %.2f' % np.mean(ep_lengths))
                print('mean episode collision: %.2f' % np.mean(ep_collisions))
                print('mean normalized episode length: %.2f' % np.mean(ep_normalized_lengths))
                scene_stats[scene_scope].extend(ep_lengths)
                resultData.append((scene_scope, str(task_scope), np.mean(ep_rewards), np.mean(ep_lengths), np.mean(ep_collisions), np.mean(ep_normalized_lengths),))

        print('\nResults (average trajectory length):')
        for scene_scope in scene_stats:
            print('%s: %.2f steps'%(scene_scope, np.mean(scene_stats[scene_scope])))
        
        if 'csv_file' in self.config and self.config['csv_file'] is not None:
            export_to_csv(resultData, self.config['csv_file'])
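
Putting this example together, a hedged end-to-end driver (paths and file names are placeholders):

# Hypothetical driver for the evaluation above.
config = {
    'checkpoint_path': 'model/checkpoint-{checkpoint}.pth',
    'h5_file_path': '/data/thor/{scene}.h5',
    'test_scenes': [],          # empty list -> evaluate every scene in TASK_LIST
    'csv_file': 'results.csv',
}
evaluation = Evaluation.load_checkpoint(config)
evaluation.run()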
    def __init__(self, config):
        self.config = config
        self.shared_net = SharedNetwork()
        self.scene_nets = {key: SceneSpecificNetwork(ACTION_SPACE_SIZE) for key in TASK_LIST.keys()}
class Evaluation:
    def __init__(self, config):
        self.config = config
        self.shared_net = SharedNetwork()
        self.scene_nets = { key:SceneSpecificNetwork(ACTION_SPACE_SIZE) for key in TASK_LIST.keys() }

    @staticmethod
    def load_checkpoint(config, fail=True):
        checkpoint_path = config.get('checkpoint_path', 'model/checkpoint-{checkpoint}.pth')
        
        (base_name, restore_point) = find_restore_point(checkpoint_path, fail)
        print(f'Restoring from checkpoint {restore_point}')
        state = torch.load(open(os.path.join(os.path.dirname(checkpoint_path), base_name), 'rb'))
        evaluation = Evaluation(config)
        saver = TrainingSaver(evaluation.shared_net, evaluation.scene_nets, None, evaluation.config)
        saver.restore(state)
        return evaluation

    def run(self):
        scene_stats = dict()
        resultData = []
        for scene_scope, items in TASK_LIST.items():
            scene_net = self.scene_nets[scene_scope]
            scene_stats[scene_scope] = list()
            for task_scope in items:
                env = THORDiscreteEnvironment(
                    scene_name=scene_scope,
                    h5_file_path=(lambda scene: self.config.get("h5_file_path", "D:\\datasets\\visual_navigation_precomputed\\{scene}.h5").replace('{scene}', scene)),
                    terminal_state_id=int(task_scope)
                )

                ep_rewards = []
                ep_lengths = []
                ep_collisions = []
                for i_episode in range(NUM_EVAL_EPISODES):
                    env.reset()
                    terminal = False
                    ep_reward = 0
                    ep_collision = 0
                    ep_t = 0
                    while not terminal:
                        state = torch.Tensor(env.render(mode='resnet_features'))
                        target = torch.Tensor(env.render_target(mode='resnet_features'))
                        (policy, value,) = scene_net.forward(self.shared_net.forward((state, target,)))

                        with torch.no_grad():
                            action = F.softmax(policy, dim=0).multinomial(1).data.numpy()[0]
                        env.step(action)
                        terminal = env.is_terminal

                        if ep_t == 10000: break
                        if env.collided: ep_collision += 1
                        ep_reward += env.reward
                        ep_t += 1

                    ep_lengths.append(ep_t)
                    ep_rewards.append(ep_reward)
                    ep_collisions.append(ep_collision)
                    if VERBOSE: print("episode #{} ends after {} steps".format(i_episode, ep_t))

                print('evaluation: %s %s' % (scene_scope, task_scope))
                print('mean episode reward: %.2f' % np.mean(ep_rewards))
                print('mean episode length: %.2f' % np.mean(ep_lengths))
                print('mean episode collision: %.2f' % np.mean(ep_collisions))
                scene_stats[scene_scope].extend(ep_lengths)
                resultData.append((scene_scope, str(task_scope), np.mean(ep_rewards), np.mean(ep_lengths), np.mean(ep_collisions),))

        print('\nResults (average trajectory length):')
        for scene_scope in scene_stats:
            print('%s: %.2f steps'%(scene_scope, np.mean(scene_stats[scene_scope])))
        
        if 'csv_file' in self.config and self.config['csv_file'] is not None:
            export_to_csv(resultData, self.config['csv_file'])
class FeatureEvaluation:
    def __init__(self, config):
        self.config = config
        self.method = config['method']
        gpu_id = get_first_free_gpu(2000)
        self.device = torch.device("cuda:" + str(gpu_id))
        if self.method != "random":
            self.shared_net = SharedNetwork(self.config['method'],
                                            self.config.get('mask_size',
                                                            5)).to(self.device)
            self.scene_net = SceneSpecificNetwork(
                self.config['action_size']).to(self.device)

        self.checkpoints = []
        self.checkpoint_id = 0
        self.saver = None
        self.chk_numbers = None

    @staticmethod
    def load_checkpoints(config, fail=True):
        evaluation = FeatureEvaluation(config)
        checkpoint_path = config.get('checkpoint_path',
                                     'model/checkpoint-{checkpoint}.pth')

        checkpoints = []
        (base_name, chk_numbers) = find_restore_points(checkpoint_path, fail)
        if evaluation.method != "random":
            try:
                for chk_name in base_name:
                    state = torch.load(
                        open(
                            os.path.join(os.path.dirname(checkpoint_path),
                                         chk_name), 'rb'))
                    checkpoints.append(state)
            except Exception as e:
                print("Error loading", e)
                exit()
            evaluation.saver = TrainingSaver(evaluation.shared_net,
                                             evaluation.scene_net, None,
                                             evaluation.config)
        evaluation.chk_numbers = chk_numbers
        evaluation.checkpoints = checkpoints
        return evaluation

    def restore(self):
        print('Restoring from checkpoint',
              self.chk_numbers[self.checkpoint_id])
        self.saver.restore(self.checkpoints[self.checkpoint_id])

    def next_checkpoint(self):
        self.checkpoint_id = (self.checkpoint_id + 1) % len(self.checkpoints)

    def run(self):
        random.seed(200)
        num_episode_eval = 10

        self.method_class = None
        if self.method in ('word2vec', 'word2vec_noconv', 'word2vec_notarget',
                           'word2vec_nosimi'):
            self.method_class = SimilarityGrid(self.method)
        elif self.method in ('aop', 'aop_we'):
            self.method_class = AOP(self.method)
        elif self.method == 'target_driven':
            self.method_class = TargetDriven(self.method)
        elif self.method == 'gcn':
            self.method_class = GCN(self.method)

        for chk_id in self.chk_numbers:
            scene_stats = dict()
            for scene_scope, items in self.config['task_list'].items():
                self.restore()
                scene_stats[scene_scope] = dict()
                scene_stats[scene_scope]["length"] = list()
                scene_stats[scene_scope]["spl"] = list()
                scene_stats[scene_scope]["success"] = list()
                scene_stats[scene_scope]["spl_long"] = list()
                scene_stats[scene_scope]["success_long"] = list()

                for task_scope in items:
                    env = THORDiscreteEnvironmentFile(
                        scene_name=scene_scope,
                        method=self.method,
                        reward=self.config['reward'],
                        h5_file_path=(lambda scene: self.config.get(
                            "h5_file_path").replace('{scene}', scene)),
                        terminal_state=task_scope,
                        action_size=self.config['action_size'],
                        mask_size=self.config.get('mask_size', 5))
                    print("Current task:", env.terminal_state['object'])
                    for i_episode in range(num_episode_eval):
                        if not env.reset():
                            continue
                        ep_t = 0
                        terminal = False
                        while not terminal:
                            if self.method != "random":
                                policy, value, state = self.method_class.forward_policy(
                                    env, self.device, lambda x: self.scene_net(
                                        self.shared_net(x)))
                                policy_softmax = F.softmax(policy, dim=0)
                                action = policy_softmax.multinomial(
                                    1).data.cpu().numpy()[0]
                            else:
                                # "random" baseline: sample a uniform random action.
                                action = random.randrange(
                                    self.config['action_size'])

                            env.step(action)
                            if ep_t == 500:
                                terminal = True
                                break
                            ep_t += 1
                            env.reward
                            terminal = env.terminal
                            # Compute CAM only for terminal state
                            if terminal and env.success:
                                # Retrieve the feature from the convolution layer (similarity grid)
                                conv_output = self.shared_net.net.conv_output

                                state, x_processed, object_mask = self.method_class.extract_input(
                                    env, self.device)

                                # Create one hot vector for outputted action
                                one_hot_vector = torch.zeros(
                                    (1, env.action_size), dtype=torch.float32)
                                one_hot_vector[0][action] = 500
                                one_hot_vector = one_hot_vector.to(self.device)

                                # Reset grad
                                self.shared_net.zero_grad()
                                self.scene_net.zero_grad()

                                # Backward pass with specified action
                                policy.backward(gradient=one_hot_vector,
                                                retain_graph=True)

                                # Get hooked gradients for CAM
                                guided_gradients = self.shared_net.net.gradient.cpu(
                                ).data.numpy()[0]

                                # Get hooked gradients for Vanilla
                                vanilla_grad = self.shared_net.net.gradient_vanilla.cpu(
                                )
                                vanilla_grad = vanilla_grad.data.numpy()[0]

                                # Get convolution outputs
                                target = conv_output.cpu().data.numpy()[0]

                                # Get weights from gradients
                                # Take averages for each gradient
                                weights = np.mean(guided_gradients,
                                                  axis=(1, 2))
                                # Create empty numpy array for cam
                                cam = np.ones(target.shape[1:],
                                              dtype=np.float32)

                                # Multiply each weight with its conv output and then, sum
                                for i, w in enumerate(weights):
                                    cam += w * target[i, :, :]
                                cam = np.maximum(cam, 0)
                                cam = (cam - np.min(cam)) / (
                                    np.max(cam) - np.min(cam)
                                )  # Normalize between 0-1
                                cam = np.uint8(
                                    cam *
                                    255)  # Scale between 0-255 to visualize
                                cam = np.uint8(
                                    Image.fromarray(cam).resize(
                                        (object_mask.shape[2],
                                         object_mask.shape[3]),
                                        Image.ANTIALIAS)) / 255

                                # Create vanilla saliency img
                                vanilla_grad = vanilla_grad - vanilla_grad.min(
                                )
                                vanilla_grad /= vanilla_grad.max()

                                fig = plt.figure(figsize=(7 * 1.5, 2 * 1.5))
                                obs_plt = fig.add_subplot(141)
                                simi_grid_plt = fig.add_subplot(142)
                                cam_plt = fig.add_subplot(143)
                                vanilla_plt = fig.add_subplot(144)

                                # Observation visualization
                                obs_plt.title.set_text(
                                    'Observation, Target:' +
                                    env.terminal_state['object'])
                                obs_plt.imshow(env.observation)

                                # Simliratity grid visualization
                                simi_grid_plt.title.set_text("Similarity grid")
                                ob_mask_viz = object_mask.cpu().squeeze()
                                ob_mask_viz = np.flip(np.rot90(ob_mask_viz),
                                                      axis=0)
                                simi_grid_plt.imshow(ob_mask_viz,
                                                     vmin=0,
                                                     vmax=1,
                                                     cmap='gray')

                                # CAM visualisation
                                cam_plt.title.set_text("CAM visualization")
                                cam_viz = np.flip(np.rot90(cam), axis=0)
                                cam_plt.imshow(cam_viz,
                                               vmin=0,
                                               vmax=1,
                                               cmap='plasma')

                                # Vanilla saliency visualization
                                vanilla_grad = vanilla_grad.squeeze(0)
                                van_viz = np.uint8(
                                    Image.fromarray(vanilla_grad).resize(
                                        (object_mask.shape[2],
                                         object_mask.shape[3]),
                                        Image.ANTIALIAS)) / 255
                                vanilla_plt.title.set_text(
                                    "Vanilla saliency visualization")
                                van_viz = np.flip(np.rot90(vanilla_grad),
                                                  axis=0)
                                vanilla_plt.imshow(van_viz,
                                                   vmin=0,
                                                   vmax=1,
                                                   cmap='gray')

                                plt.tight_layout()
                                plt.show()

            break
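
FeatureEvaluation walks over the saved checkpoints and, for successful terminal steps, renders the observation, the similarity grid, a CAM heat-map and a vanilla saliency map. A hedged driver; every config value is illustrative, and the task_list entries only need to expose the fields the code reads (an 'object' name per target):

# Hypothetical usage; all values below are example placeholders.
config = {
    'method': 'word2vec',
    'checkpoint_path': 'model/checkpoint-{checkpoint}.pth',
    'h5_file_path': '/data/thor/{scene}.h5',
    'reward': 'default',
    'action_size': 4,
    'mask_size': 5,
    'task_list': {'bathroom_02': [{'object': 'Towel'}]},
}
feature_eval = FeatureEvaluation.load_checkpoints(config)
feature_eval.run()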
import pickle
import os
import numpy as np
import torch
# NOTE: SharedNetwork and SceneSpecificNetwork are assumed to be imported
# from the project's model definitions (the import is omitted in this snippet).

TASK_LIST = {
    'bathroom_02': ['26', '37', '43', '53', '69'],
    'bedroom_04': ['134', '264', '320', '384', '387'],
    'kitchen_02': ['90', '136', '157', '207', '329'],
    'living_room_08': ['92', '135', '193', '228', '254']
}

ACTION_SPACE_SIZE = 4
NUM_EVAL_EPISODES = 100
VERBOSE = False

shared_net = SharedNetwork()
scene_nets = {
    key: SceneSpecificNetwork(ACTION_SPACE_SIZE)
    for key in TASK_LIST.keys()
}

# Load weights trained with TensorFlow
data = pickle.load(open(os.path.join(os.path.dirname(__file__), '..',
                                     'weights.p'), 'rb'),
                   encoding='latin1')


def convertToStateDict(data):
    return {key: torch.Tensor(v) for (key, v) in data.items()}


shared_net.load_state_dict(convertToStateDict(data['navigation']))
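
Only the shared backbone is restored from the converted TensorFlow weights above; if the converted parameters need to be reused later without the pickle file, a hedged follow-up is to re-save them in PyTorch's native format:

# Illustrative only: persist the converted weights as a regular PyTorch file.
torch.save({'navigation': shared_net.state_dict()}, 'weights_converted.pth')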