Example No. 1
def load_envs_and_config(model_file):
    save_dict = torch.load(model_file)

    config = save_dict['config']
    config['device'] = 'cpu'
    config['envs']['CartPole-v0'][
        'solved_reward'] = 100000  # something big enough to prevent early out triggering

    env_factory = EnvFactory(config=config)
    reward_env = env_factory.generate_reward_env()
    reward_env.load_state_dict(save_dict['model'])
    real_env = env_factory.generate_real_env()

    return reward_env, real_env, config
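A minimal usage sketch for the helper above (illustrative only: it assumes a checkpoint saved by the repo's GTN training exists on disk; the path below is a placeholder, not from the repo):

reward_env, real_env, config = load_envs_and_config('path/to/model.pt')
print(config['device'])                                 # forced to 'cpu' above
print(config['envs']['CartPole-v0']['solved_reward'])   # 100000, early out disabled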
def load_envs_and_config(model_file):
    save_dict = torch.load(model_file)

    config = save_dict['config']
    if BREAK == 'solved':
        config['envs']['Cliff']['solved_reward'] = -20  # a reachable threshold so the early-out can trigger
    else:
        config['envs']['Cliff']['solved_reward'] = 100000  # something big enough to prevent early out triggering

    env_factory = EnvFactory(config=config)
    reward_env = env_factory.generate_reward_env()
    reward_env.load_state_dict(save_dict['model'])
    real_env = env_factory.generate_real_env()

    return reward_env, real_env, config
Example No. 3
    def compute(self, working_dir, bohb_id, config_id, cso, budget, *args,
                **kwargs):
        with open(CONFIG_FILE, 'r') as stream:
            default_config = yaml.safe_load(stream)

        config = self.get_specific_config(cso, default_config, budget)
        print('----------------------------')
        print("START BOHB ITERATION")
        print('CONFIG: ' + str(config))
        print('CSO:    ' + str(cso))
        print('BUDGET: ' + str(budget))
        print('----------------------------')

        info = {}

        # generate environment
        env_fac = EnvFactory(config)

        real_env = env_fac.generate_real_env()
        reward_env = env_fac.generate_reward_env()
        save_dict = torch.load(SAVE_FILE)
        #config = save_dict['config']
        reward_env.load_state_dict(save_dict['model'])

        score = 0
        for i in range(NUM_EVALS):
            td3 = TD3(env=reward_env,
                      max_action=reward_env.get_max_action(),
                      config=config)
            reward_list_train, _, _ = td3.train(reward_env, test_env=real_env)
            reward_list_test, _, _ = td3.test(real_env)
            avg_reward_test = statistics.mean(reward_list_test)

            unsolved_weight = config["agents"]["gtn"]["unsolved_weight"]
            score += len(reward_list_train) + max(
                0, (real_env.get_solved_reward() -
                    avg_reward_test)) * unsolved_weight

        score = score / NUM_EVALS

        info['config'] = str(config)

        print('----------------------------')
        print('FINAL SCORE: ' + str(score))
        print("END BOHB ITERATION")
        print('----------------------------')

        return {"loss": score, "info": info}
    def compute(self, working_dir, bohb_id, config_id, cso, budget, *args,
                **kwargs):
        # with open("default_config_halfcheetah.yaml", 'r') as stream:
        #     default_config = yaml.safe_load(stream)
        with open("default_config_cmc.yaml", 'r') as stream:
            default_config = yaml.safe_load(stream)

        config = self.get_specific_config(cso, default_config, budget)
        print('----------------------------')
        print("START BOHB ITERATION")
        print('CONFIG: ' + str(config))
        print('CSO:    ' + str(cso))
        print('BUDGET: ' + str(budget))
        print('----------------------------')

        info = {}

        # generate environment
        env_fac = EnvFactory(config)
        env = env_fac.generate_real_env()

        td3 = TD3(env=env,
                  max_action=env.get_max_action(),
                  config=config,
                  icm=True)

        # score_list = []
        rewards_list = []
        for _ in range(5):
            rewards, _, _ = td3.train(env)
            rewards_list.append(sum(rewards))
            # score_i = len(rewards)
            # score_list.append(score_i)

        score = -np.mean(rewards_list)
        # score = np.mean(score_list)

        info['config'] = str(config)

        print('----------------------------')
        print('FINAL SCORE: ' + str(score))
        print("END BOHB ITERATION")
        print('----------------------------')

        return {"loss": score, "info": info}
def load_envs_and_config(file_name, model_dir, device):
    file_path = os.path.join(model_dir, file_name)
    save_dict = torch.load(file_path)
    config = save_dict['config']
    config['device'] = device
    env_factory = EnvFactory(config=config)
    virtual_env = env_factory.generate_virtual_env()
    virtual_env.load_state_dict(save_dict['model'])
    real_env = env_factory.generate_real_env()

    # load additional agent configs
    with open("../default_config_acrobot.yaml", "r") as stream:
        config_new = yaml.safe_load(stream)["agents"]

    config["agents"]["duelingddqn"] = config_new["duelingddqn"]
    config["agents"]["duelingddqn_vary"] = config_new["duelingddqn_vary"]

    return virtual_env, real_env, config
    def compute(self, working_dir, bohb_id, config_id, cso, budget, *args, **kwargs):
        with open("default_config_cartpole.yaml", 'r') as stream:
            default_config = yaml.safe_load(stream)

        config = self.get_specific_config(cso, default_config, budget)
        print('----------------------------')
        print("START BOHB ITERATION")
        print('CONFIG: ' + str(config))
        print('CSO:    ' + str(cso))
        print('BUDGET: ' + str(budget))
        print('----------------------------')

        info = {}

        # generate environment
        env_fac = EnvFactory(config)
        env = env_fac.generate_real_env()

        ddqn = DDQN(env=env,
                    config=config,
                    icm=True)

        score_list = []
        for _ in range(5):
            rewards, _, _ = ddqn.train(env)
            score_i = len(rewards)
            score_list.append(score_i)

        score = np.mean(score_list)

        info['config'] = str(config)

        print('----------------------------')
        print('FINAL SCORE: ' + str(score))
        print("END BOHB ITERATION")
        print('----------------------------')

        return {
                "loss": score,
                "info": info
                }
Example No. 7
    def compute(self, working_dir, bohb_id, config_id, cso, budget, *args, **kwargs):
        with open("default_config_pendulum.yaml", 'r') as stream:
            default_config = yaml.safe_load(stream)

        config = self.get_specific_config(cso, default_config, budget)
        print('----------------------------')
        print("START BOHB ITERATION")
        print('CONFIG: ' + str(config))
        print('CSO:    ' + str(cso))
        print('BUDGET: ' + str(budget))
        print('----------------------------')

        config["agents"]["ppo"]["train_episodes"] = int(budget)

        info = {}

        # generate environment
        env_fac = EnvFactory(config)
        real_env = env_fac.generate_real_env()

        score = 0
        for i in range(NUM_EVALS):
            ppo = PPO(env=real_env,
                      config=config)
            rewards, _, _ = ppo.train(real_env)
            score += len(rewards)

        score = score/NUM_EVALS

        info['config'] = str(config)

        print('----------------------------')
        print('FINAL SCORE: ' + str(score))
        print("END BOHB ITERATION")
        print('----------------------------')

        return {
            "loss": score,
            "info": info
        }
Example No. 8
        config_mod['agents'][
            self.agent_name]['hidden_size'] = config['hidden_size']
        config_mod['agents'][
            self.agent_name]['hidden_layer'] = config['hidden_layer']

        print("full config: ", config_mod['agents'][self.agent_name])

        return config_mod


if __name__ == "__main__":
    with open("../default_config_cartpole.yaml", "r") as stream:
        config = yaml.safe_load(stream)

    torch.set_num_threads(1)

    # generate environment
    env_fac = EnvFactory(config)
    virt_env = env_fac.generate_virtual_env()
    real_env = env_fac.generate_real_env()

    timing = []
    for i in range(10):
        ddqn = DDQN_vary(env=real_env, config=config, icm=True)
        # ddqn.train(env=virt_env, time_remaining=50)
        print('TRAIN')
        t_start = time.time()  # requires: import time
        ddqn.train(env=real_env, time_remaining=500)
        timing.append(time.time() - t_start)  # record per-run training time
        # print('TEST')
        # ddqn.test(env=real_env, time_remaining=500)
    print('avg. ' + str(sum(timing) / len(timing)))
Example No. 9
class GTN_Worker(GTN_Base):
    def __init__(self, id, bohb_id=-1):
        """
        params:
          id: identifies the different workers
          bohb_id: identifies the different BOHB runs
        """
        super().__init__(bohb_id)

        torch.manual_seed(id + bohb_id * id * 1000 + int(time.time()))
        torch.cuda.manual_seed_all(id + bohb_id * id * 1000 + int(time.time()))

        # for identifying the different workers
        self.id = id

        # flag to stop worker
        self.quit_flag = False
        self.time_sleep_worker = 3
        self.timeout = None

        # delete corresponding sync files if existent
        for file in [
                self.get_input_file_name(self.id),
                self.get_input_check_file_name(self.id),
                self.get_result_file_name(self.id),
                self.get_result_check_file_name(self.id)
        ]:
            if os.path.isfile(file):
                os.remove(file)

        print('Starting GTN Worker with bohb_id {} and id {}'.format(
            bohb_id, id))

    def late_init(self, config):
        gtn_config = config["agents"]["gtn"]
        self.noise_std = gtn_config["noise_std"]
        self.num_grad_evals = gtn_config["num_grad_evals"]
        self.grad_eval_type = gtn_config["grad_eval_type"]
        self.mirrored_sampling = gtn_config["mirrored_sampling"]
        self.time_sleep_worker = gtn_config["time_sleep_worker"]
        self.agent_name = gtn_config["agent_name"]
        self.synthetic_env_type = gtn_config["synthetic_env_type"]
        self.unsolved_weight = gtn_config["unsolved_weight"]

        # make it faster on single PC
        if gtn_config["mode"] == 'single':
            self.time_sleep_worker /= 10

        self.env_factory = EnvFactory(config)
        if self.synthetic_env_type == 0:
            generate_synthetic_env_fn = self.env_factory.generate_virtual_env
        elif self.synthetic_env_type == 1:
            generate_synthetic_env_fn = self.env_factory.generate_reward_env
        else:
            raise NotImplementedError("Unknown synthetic_env_type value: " +
                                      str(self.synthetic_env_type))

        self.synthetic_env_orig = generate_synthetic_env_fn(
            print_str='GTN_Base: ')
        self.synthetic_env = generate_synthetic_env_fn(
            print_str='GTN_Worker' + str(self.id) + ': ')
        self.eps = generate_synthetic_env_fn(
            print_str='GTN_Worker' + str(self.id) + ': ')

    def run(self):
        # read data from master
        while not self.quit_flag:
            self.read_worker_input()

            time_start = time.time()

            # get score for network of the last outer loop iteration
            score_orig = self.calc_score(env=self.synthetic_env_orig,
                                         time_remaining=self.timeout -
                                         (time.time() - time_start))

            self.get_random_noise()

            # first mirrored noise +N
            self.add_noise_to_synthetic_env()

            score_add = []
            for i in range(self.num_grad_evals):
                score = self.calc_score(env=self.synthetic_env,
                                        time_remaining=self.timeout -
                                        (time.time() - time_start))
                score_add.append(score)

            # second mirrored noise -N
            self.subtract_noise_from_synthetic_env()

            score_sub = []
            for i in range(self.num_grad_evals):
                score = self.calc_score(env=self.synthetic_env,
                                        time_remaining=self.timeout -
                                        (time.time() - time_start))
                score_sub.append(score)

            score_best = self.calc_best_score(score_add=score_add,
                                              score_sub=score_sub)

            self.write_worker_result(score=score_best,
                                     score_orig=score_orig,
                                     time_elapsed=time.time() - time_start)

            if self.quit_flag:
                print('QUIT FLAG')
                break

        print('Worker ' + str(self.id) + ' quitting')

    def read_worker_input(self):
        file_name = self.get_input_file_name(id=self.id)
        check_file_name = self.get_input_check_file_name(id=self.id)

        while not os.path.isfile(check_file_name):
            time.sleep(self.time_sleep_worker)
        time.sleep(self.time_sleep_worker)

        data = torch.load(file_name)

        self.timeout = data['timeout']
        self.quit_flag = data['quit_flag']
        self.config = data['config']

        self.late_init(self.config)

        self.synthetic_env_orig.load_state_dict(data['synthetic_env_orig'])
        self.synthetic_env.load_state_dict(data['synthetic_env_orig'])

        os.remove(check_file_name)
        os.remove(file_name)

    def write_worker_result(self, score, score_orig, time_elapsed):
        file_name = self.get_result_file_name(id=self.id)
        check_file_name = self.get_result_check_file_name(id=self.id)

        # wait until master has deleted the file (i.e. acknowledged the previous result)
        while os.path.isfile(file_name):
            time.sleep(self.time_sleep_worker)

        data = {}
        data["eps"] = self.eps.state_dict()
        data["synthetic_env"] = self.synthetic_env.state_dict(
        )  # for debugging
        data["time_elapsed"] = time_elapsed
        data["score"] = score
        data["score_orig"] = score_orig
        torch.save(data, file_name)
        torch.save({}, check_file_name)

    def get_random_noise(self):
        for l_virt, l_eps in zip(self.synthetic_env.modules(),
                                 self.eps.modules()):
            if isinstance(l_virt, nn.Linear):
                l_eps.weight = torch.nn.Parameter(
                    torch.normal(mean=torch.zeros_like(l_virt.weight),
                                 std=torch.ones_like(l_virt.weight)) *
                    self.noise_std)
                if l_eps.bias is not None:
                    l_eps.bias = torch.nn.Parameter(
                        torch.normal(mean=torch.zeros_like(l_virt.bias),
                                     std=torch.ones_like(l_virt.bias)) *
                        self.noise_std)

    def add_noise_to_synthetic_env(self, add=True):
        for l_orig, l_virt, l_eps in zip(self.synthetic_env_orig.modules(),
                                         self.synthetic_env.modules(),
                                         self.eps.modules()):
            if isinstance(l_virt, nn.Linear):
                if add:  # add eps
                    l_virt.weight = torch.nn.Parameter(l_orig.weight +
                                                       l_eps.weight)
                    if l_virt.bias is not None:
                        l_virt.bias = torch.nn.Parameter(l_orig.bias +
                                                         l_eps.bias)
                else:  # subtract eps
                    l_virt.weight = torch.nn.Parameter(l_orig.weight -
                                                       l_eps.weight)
                    if l_virt.bias is not None:
                        l_virt.bias = torch.nn.Parameter(l_orig.bias -
                                                         l_eps.bias)

    def subtract_noise_from_synthetic_env(self):
        self.add_noise_to_synthetic_env(add=False)

    def invert_eps(self):
        for l_eps in self.eps.modules():
            if isinstance(l_eps, nn.Linear):
                l_eps.weight = torch.nn.Parameter(-l_eps.weight)
                if l_eps.bias is not None:
                    l_eps.bias = torch.nn.Parameter(-l_eps.bias)

    def calc_score(self, env, time_remaining):
        time_start = time.time()

        agent = select_agent(config=self.config, agent_name=self.agent_name)
        real_env = self.env_factory.generate_real_env()

        reward_list_train, episode_length_train, _ = agent.train(
            env=env,
            test_env=real_env,
            time_remaining=time_remaining - (time.time() - time_start))
        reward_list_test, _, _ = agent.test(env=real_env,
                                            time_remaining=time_remaining -
                                            (time.time() - time_start))
        avg_reward_test = statistics.mean(reward_list_test)

        if env.is_virtual_env():
            return avg_reward_test
        else:
            # # when timeout occurs, reward_list_train is padded (with min. reward values) and episode_length_train is not
            # if len(episode_length_train) < len(reward_list_train):
            #     print("due to timeout, reward_list_train has been padded")
            #     print(f"shape rewards: {np.shape(reward_list_train)}, shape episode lengths: {np.shape(episode_length_train)}")
            #     reward_list_train = reward_list_train[:len(episode_length_train)]

            print("AVG REWARD: ", avg_reward_test)
            return avg_reward_test

            # print("AUC: ", np.dot(reward_list_train, episode_length_train))
            # return np.dot(reward_list_train, episode_length_train)

            # if not real_env.can_be_solved():
            #     return avg_reward_test
            # else:
            #     print(str(sum(episode_length_train)) + ' ' + str(max(0, (real_env.get_solved_reward()-avg_reward_test))) + ' ' + str(self.unsolved_weight))
            #     # we maximize the objective
            #     # sum(episode_length_train) + max(0, (real_env.get_solved_reward()-avg_reward_test))*self.unsolved_weight
            #     return -sum(episode_length_train) - max(0, (real_env.get_solved_reward()-avg_reward_test))*self.unsolved_weight

    def calc_best_score(self, score_sub, score_add):
        if self.grad_eval_type == 'mean':
            score_sub = statistics.mean(score_sub)
            score_add = statistics.mean(score_add)
        elif self.grad_eval_type == 'minmax':
            score_sub = min(score_sub)
            score_add = min(score_add)
        else:
            raise NotImplementedError(
                'Unknown parameter for grad_eval_type: ' +
                str(self.grad_eval_type))

        if self.mirrored_sampling:
            score_best = max(score_add, score_sub)
            if score_sub > score_add:
                self.invert_eps()
            else:
                self.add_noise_to_synthetic_env()
        else:
            score_best = score_add
            self.add_noise_to_synthetic_env()

        return score_best
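The run() loop above implements mirrored sampling: the same noise eps is first added to and then subtracted from the synthetic env, both directions are scored, and calc_best_score() keeps the better one (inverting eps when the -eps direction wins). A condensed NumPy sketch of that decision rule (illustrative only, assuming grad_eval_type == 'mean' and mirrored_sampling enabled):

import numpy as np

def pick_direction(score_add, score_sub, eps):
    # mirrors calc_best_score: average the repeated evaluations per direction,
    # keep the better direction and its score
    s_add, s_sub = np.mean(score_add), np.mean(score_sub)
    if s_sub > s_add:
        return -eps, s_sub      # -eps direction won: invert eps before reporting
    return eps, s_add           # +eps direction won (or tie): keep eps as sampled

eps = np.array([0.1, -0.2])
print(pick_direction([1.0, 1.2], [0.8, 0.9], eps))   # keeps +eps with score 1.1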
Example No. 10
class GTN_Master(GTN_Base):
    def __init__(self, config, bohb_id=-1, bohb_working_dir=None):
        super().__init__(bohb_id)
        self.config = config
        self.device = config["device"]
        self.env_name = config['env_name']

        gtn_config = config["agents"]["gtn"]
        self.max_iterations = gtn_config["max_iterations"]
        self.agent_name = gtn_config["agent_name"]
        self.num_workers = gtn_config["num_workers"]
        self.step_size = gtn_config["step_size"]
        self.nes_step_size = gtn_config["nes_step_size"]
        self.weight_decay = gtn_config["weight_decay"]
        self.score_transform_type = gtn_config["score_transform_type"]
        self.time_mult = gtn_config["time_mult"]
        self.time_max = gtn_config["time_max"]
        self.time_sleep_master = gtn_config["time_sleep_master"]
        self.quit_when_solved = gtn_config["quit_when_solved"]
        self.synthetic_env_type = gtn_config["synthetic_env_type"]
        self.unsolved_weight = gtn_config["unsolved_weight"]

        # make it faster on single PC
        if gtn_config["mode"] == 'single':
            self.time_sleep_master /= 10

        # to store results from workers
        self.time_elapsed_list = [None] * self.num_workers  # for debugging
        self.score_list = [None] * self.num_workers
        self.score_orig_list = [None] * self.num_workers  # for debugging
        self.score_transform_list = [None] * self.num_workers

        # to keep track of the reference virtual env
        self.env_factory = EnvFactory(config)
        if self.synthetic_env_type == 0:
            generate_synthetic_env_fn = self.env_factory.generate_virtual_env
        elif self.synthetic_env_type == 1:
            generate_synthetic_env_fn = self.env_factory.generate_reward_env
        else:
            raise NotImplementedError("Unknown synthetic_env_type value: " +
                                      str(self.synthetic_env_type))

        self.synthetic_env_orig = generate_synthetic_env_fn(
            print_str='GTN_Base: ')
        self.synthetic_env_list = [
            generate_synthetic_env_fn(print_str='GTN_Master: ')
            for _ in range(self.num_workers)
        ]
        self.eps_list = [
            generate_synthetic_env_fn(print_str='GTN_Master: ')
            for _ in range(self.num_workers)
        ]

        # for early out
        self.avg_runtime = None
        self.real_env = self.env_factory.generate_real_env()

        # to store models
        if bohb_working_dir:
            self.model_dir = str(
                os.path.join(bohb_working_dir, 'GTN_models_' + self.env_name))
        else:
            self.model_dir = str(
                os.path.join(os.getcwd(), "results",
                             'GTN_models_' + self.env_name))
        self.model_name = self.get_model_file_name(
            self.env_name + '_' +
            ''.join(random.choices(string.ascii_uppercase +
                                   string.digits, k=6)) + '.pt')
        self.best_score = -float('Inf')

        os.makedirs(self.model_dir, exist_ok=True)

        print('Starting GTN Master with bohb_id {}'.format(bohb_id))

    def get_model_file_name(self, file_name):
        return os.path.join(self.model_dir, file_name)

    def run(self):
        mean_score_orig_list = []

        for it in range(self.max_iterations):
            t1 = time.time()
            print('-- Master: Iteration ' + str(it) + ' ' +
                  str(time.time() - t1))
            print('-- Master: write worker inputs' + ' ' +
                  str(time.time() - t1))
            self.write_worker_inputs(it)
            print('-- Master: read worker results' + ' ' +
                  str(time.time() - t1))
            self.read_worker_results()

            mean_score = np.mean(self.score_orig_list)
            mean_score_orig_list.append(mean_score)
            solved_flag = self.save_good_model(mean_score)

            if solved_flag and self.quit_when_solved:
                print('ENV SOLVED')
                break

            print('-- Master: rank transform' + ' ' + str(time.time() - t1))
            self.score_transform()
            print('-- Master: update env' + ' ' + str(time.time() - t1))
            self.update_env()
            print('-- Master: print statistics' + ' ' + str(time.time() - t1))
            self.print_statistics(it=it, time_elapsed=time.time() - t1)

        print('Master quitting')

        self.print_statistics(it=-1, time_elapsed=-1)

        # error handling
        if len(mean_score_orig_list) > 0:
            return np.mean(
                self.score_orig_list), mean_score_orig_list, self.model_name
        else:
            return 1e9, mean_score_orig_list, self.model_name

    def save_good_model(self, mean_score):
        if self.synthetic_env_orig.is_virtual_env():
            if (mean_score > self.real_env.get_solved_reward()
                    and mean_score > self.best_score):
                self.save_model()
                self.best_score = mean_score
                return True
        else:
            # we save all models and select the best from the log
            # whether we can solve an environment is irrelevant for reward_env since we optimize for speed here
            if mean_score > self.best_score:
                self.save_model()
                self.best_score = mean_score

        return False

    def save_model(self):
        save_dict = {}
        save_dict['model'] = self.synthetic_env_orig.state_dict()
        save_dict['config'] = self.config
        save_path = os.path.join(self.model_dir, self.model_name)
        print('save model: ' + str(save_path))
        torch.save(save_dict, save_path)

    def calc_worker_timeout(self):
        if self.time_elapsed_list[0] is None:
            return self.time_max
        else:
            return statistics.mean(self.time_elapsed_list) * self.time_mult

    def write_worker_inputs(self, it):
        timeout = self.calc_worker_timeout()
        print('timeout: ' + str(timeout))

        for id in range(self.num_workers):

            file_name = self.get_input_file_name(id=id)
            check_file_name = self.get_input_check_file_name(id=id)

            # wait until worker has deleted the file (i.e. acknowledged the previous input)
            while os.path.isfile(file_name):
                time.sleep(self.time_sleep_master)

            time.sleep(self.time_sleep_master)

            # if we are not using bohb, shut everything down after last iteration
            if self.bohb_id < 0:
                quit_flag = it == self.max_iterations - 1
            else:
                quit_flag = False

            data = {}
            data['timeout'] = timeout
            data['quit_flag'] = quit_flag
            data['config'] = self.config
            data['synthetic_env_orig'] = self.synthetic_env_orig.state_dict()

            torch.save(data, file_name)
            torch.save({}, check_file_name)

    def read_worker_results(self):
        for id in range(self.num_workers):
            file_name = self.get_result_file_name(id)
            check_file_name = self.get_result_check_file_name(id)

            # wait until worker has finished calculations
            while not os.path.isfile(check_file_name):
                time.sleep(self.time_sleep_master)

            data = torch.load(file_name)
            self.time_elapsed_list[id] = data['time_elapsed']
            self.score_list[id] = data['score']
            self.eps_list[id].load_state_dict(data['eps'])
            self.score_orig_list[id] = data['score_orig']
            self.synthetic_env_list[id].load_state_dict(
                data['synthetic_env'])  # for debugging

            os.remove(check_file_name)
            os.remove(file_name)

    def score_transform(self):
        scores = np.asarray(self.score_list)
        scores_orig = np.asarray(self.score_orig_list)

        if self.score_transform_type == 0:
            # convert [1, 0, 5] to [0.2, 0, 1]
            scores = (scores - min(scores)) / (max(scores) - min(scores) +
                                               1e-9)

        elif self.score_transform_type == 1:
            # convert [1, 0, 5] to [0.5, 0, 1]
            s = np.argsort(scores)
            n = len(scores)
            for i in range(n):
                scores[s[i]] = i / (n - 1)

        elif self.score_transform_type == 2 or self.score_transform_type == 3:
            # fitness shaping from "Natural Evolution Strategies" (Wierstra 2014) paper, either with zero mean (2) or without (3)
            lmbda = len(scores)
            s = np.argsort(-scores)
            for i in range(lmbda):
                scores[s[i]] = i + 1
            scores = scores.astype(float)
            for i in range(lmbda):
                scores[i] = max(0, np.log(lmbda / 2 + 1) - np.log(scores[i]))

            scores = scores / sum(scores)

            if self.score_transform_type == 2:
                scores -= 1 / lmbda

            scores /= max(scores)

        elif self.score_transform_type == 4:
            # consider single best eps
            scores_tmp = np.zeros(scores.size)
            scores_tmp[np.argmax(scores)] = 1
            scores = scores_tmp

        elif self.score_transform_type == 5:
            # consider single best eps that is better than the average
            avg_score_orig = np.mean(scores_orig)

            scores_idx = np.where(scores > avg_score_orig + 1e-6, 1,
                                  0)  # 1e-6 to counter numerical errors
            if sum(scores_idx) > 0:
                scores_tmp = np.zeros(scores.size)
                scores_tmp[np.argmax(scores)] = 1
                scores = scores_tmp
            else:
                scores = scores_idx

        elif self.score_transform_type == 6 or self.score_transform_type == 7:
            # consider all eps that are better than the average, normalize weight sum to 1
            avg_score_orig = np.mean(scores_orig)

            scores_idx = np.where(scores > avg_score_orig + 1e-6, 1,
                                  0)  # 1e-6 to counter numerical errors
            if sum(scores_idx) > 0:
                #if sum(scores_idx) > 0:
                scores = scores_idx * (scores - avg_score_orig) / (
                    max(scores) - avg_score_orig + 1e-9)
                if self.score_transform_type == 6:
                    scores /= max(scores)
                else:
                    scores /= sum(scores)
            else:
                scores = scores_idx

        else:
            raise ValueError("Unknown rank transform type: " +
                             str(self.score_transform_type))

        self.score_transform_list = scores.tolist()

    def update_env(self):
        ss = self.step_size

        if self.nes_step_size:
            ss = ss / self.num_workers

        # print('-- update env --')
        print('score_orig_list      ' + str(self.score_orig_list))
        print('score_list           ' + str(self.score_list))
        print('score_transform_list ' + str(self.score_transform_list))
        print('venv weights         ' + str([
            calc_abs_param_sum(elem).item() for elem in self.synthetic_env_list
        ]))

        print('weights before: ' +
              str(calc_abs_param_sum(self.synthetic_env_orig).item()))

        # weight decay
        for l_orig in self.synthetic_env_orig.modules():
            if isinstance(l_orig, nn.Linear):
                l_orig.weight = torch.nn.Parameter(l_orig.weight *
                                                   (1 - self.weight_decay))
                if l_orig.bias is not None:
                    l_orig.bias = torch.nn.Parameter(l_orig.bias *
                                                     (1 - self.weight_decay))

        print('weights after weight decay: ' +
              str(calc_abs_param_sum(self.synthetic_env_orig).item()))

        # weight update
        for eps, score_transform in zip(self.eps_list,
                                        self.score_transform_list):
            for l_orig, l_eps in zip(self.synthetic_env_orig.modules(),
                                     eps.modules()):
                if isinstance(l_orig, nn.Linear):
                    l_orig.weight = torch.nn.Parameter(l_orig.weight +
                                                       ss * score_transform *
                                                       l_eps.weight)
                    if l_orig.bias is not None:
                        l_orig.bias = torch.nn.Parameter(l_orig.bias +
                                                         ss * score_transform *
                                                         l_eps.bias)

        print('weights after update: ' +
              str(calc_abs_param_sum(self.synthetic_env_orig).item()))

    def print_statistics(self, it, time_elapsed):
        orig_score = statistics.mean(self.score_orig_list)
        mean_time_elapsed = statistics.mean(self.time_elapsed_list)
        print('--------------')
        print('GTN iteration:    ' + str(it))
        print('GTN mstr t_elaps: ' + str(time_elapsed))
        print('GTN avg wo t_elaps: ' + str(mean_time_elapsed))
        print('GTN avg eval score:   ' + str(orig_score))
        print('--------------')
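For reference, the fitness-shaping branch above (score_transform_type 2 and 3) follows the rank-based utility weights from the "Natural Evolution Strategies" (Wierstra et al., 2014) paper. A standalone NumPy sketch reproducing that transform on a toy score vector (illustrative, not the repo's API):

import numpy as np

def nes_fitness_shaping(scores, zero_mean=True):
    # rank 1 goes to the highest raw score, as in score_transform() above
    lmbda = len(scores)
    ranks = np.empty(lmbda)
    ranks[np.argsort(-scores)] = np.arange(1, lmbda + 1)
    util = np.maximum(0.0, np.log(lmbda / 2 + 1) - np.log(ranks))
    util = util / util.sum()
    if zero_mean:               # score_transform_type == 2
        util -= 1.0 / lmbda
    return util / util.max()

print(nes_fitness_shaping(np.array([1.0, 0.0, 5.0])))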