コード例 #1
0
class Learner:
	"""Learner object encapsulating a local learner

		Parameters:
		algo_name (str): Algorithm Identifier
		state_dim (int): State size
		action_dim (int): Action size
		actor_lr (float): Actor learning rate
		critic_lr (float): Critic learning rate
		gamma (float): DIscount rate
		tau (float): Target network sync generate
		init_w (bool): Use kaimling normal to initialize?
		**td3args (**kwargs): arguments for TD3 algo


	"""

	def __init__(self, wwid, algo_name, state_dim, action_dim, actor_lr, critic_lr, gamma, tau, init_w = True, **td3args):
		self.td3args = td3args; self.id = id
		self.wwid = wwid
		self.algo = Off_Policy_Algo(wwid, algo_name, state_dim, action_dim, actor_lr, critic_lr, gamma, tau, init_w)
		self.args = td3args['cerl_args']

		#LEARNER STATISTICS
		self.fitnesses = []
		self.ep_lens = []
		self.value = None
		self.visit_count = 0
		self.private_replay_buffer = Buffer(1000000, self.args.buffer_gpu)   #

	def share_memory(self):
		self.algo.share_memory()

	def act(self, state, eps=None): #eps not used, to have common interface with dqn.act
		return self.algo.act(state)

	def step(self, state, action, reward, next_state, done):  #for training blue agent, add experience to reply buffer, and do one learning iteration
		self.private_replay_buffer.add(state, action, reward, next_state, done)
		self.update_parameters(self.private_replay_buffer, self.args.buffer_gpu, self.args.batch_size, iterations=1)

	def save_net(self, path):
		self.algo.save_net(path)

	def update_parameters(self, replay_buffer, buffer_gpu, batch_size, iterations):
		for _ in range(iterations):
			s, ns, a, r, done = replay_buffer.sample(batch_size)
			if not buffer_gpu and torch.cuda.is_available():
				s = s.cuda(); ns = ns.cuda(); a = a.cuda(); r = r.cuda(); done = done.cuda()
			self.algo.update_parameters(s, ns, a, r, done, 1, **self.td3args)


	def update_stats(self, fitness, ep_len, gamma=0.2):  #ms:fitness is the cum reward each whole episode
		self.visit_count += 1	#ms:visit_count is the number of workers per policy
		self.fitnesses.append(fitness)
		self.ep_lens.append(ep_len)

		if self.value == None: self.value = fitness	#ms: moving avg of fitness as value, value used in ucb for policy reallocation
		else: self.value = gamma * fitness + (1-gamma) * self.value
コード例 #2
0
ファイル: egrl_trainer.py プロジェクト: Anonymous991/EGRL
class EGRL_Trainer:
    """Main CERL class containing all methods for CERL

		Parameters:
		args (object): Parameter class with all the parameters

	"""
    def __init__(self, args, model_constructor, env_constructor,
                 observation_space, action_space, env, state_template,
                 test_envs, platform):
        self.args = args
        model_constructor.state_dim += 2
        self.platform = platform

        self.policy_string = self.compute_policy_type()
        self.device = torch.device("cuda" if torch.cuda.is_available(
        ) else "cpu") if self.args.gpu else torch.device('cpu')

        #Evolution
        dram_action = torch.ones((len(state_template.x), 2)) + 1
        state_template.x = torch.cat([state_template.x, dram_action], axis=1)
        self.evolver = MixedSSNE(
            self.args, state_template
        )  #GA(self.args) if args.boltzman else SSNE(self.args)
        self.env_constructor = env_constructor

        self.test_tracker = utils.Tracker(
            self.args.plot_folder,
            ['score_' + self.args.savetag, 'speedup_' + self.args.savetag],
            '.csv')  # Tracker class to log progress
        self.time_tracker = utils.Tracker(self.args.plot_folder, [
            'timed_score_' + self.args.savetag,
            'timed_speedup_' + self.args.savetag
        ], '.csv')
        self.champ_tracker = utils.Tracker(self.args.plot_folder, [
            'champ_score_' + self.args.savetag,
            'champ_speedup_' + self.args.savetag
        ], '.csv')
        self.pg_tracker = utils.Tracker(self.args.plot_folder, [
            'pg_noisy_speedup_' + self.args.savetag,
            'pg_clean_speedup_' + self.args.savetag
        ], '.csv')
        self.migration_tracker = utils.Tracker(self.args.plot_folder, [
            'selection_rate_' + self.args.savetag,
            'elite_rate_' + self.args.savetag
        ], '.csv')

        #Generalization Trackers
        self.r50_tracker = utils.Tracker(self.args.plot_folder, [
            'r50_score_' + self.args.savetag,
            'r50_speedup_' + self.args.savetag
        ], '.csv')
        self.r101_tracker = utils.Tracker(self.args.plot_folder, [
            'r101_score_' + self.args.savetag,
            'r101_speedup_' + self.args.savetag
        ], '.csv')
        self.bert_tracker = utils.Tracker(self.args.plot_folder, [
            'bert_score_' + self.args.savetag,
            'bert_speedup_' + self.args.savetag
        ], '.csv')

        self.r50_frames_tracker = utils.Tracker(self.args.plot_folder, [
            'r50_score_' + self.args.savetag,
            'r50_speedup_' + self.args.savetag
        ], '.csv')
        self.r101_frames_tracker = utils.Tracker(self.args.plot_folder, [
            'r101_score_' + self.args.savetag,
            'r101_speedup_' + self.args.savetag
        ], '.csv')
        self.bert_frames_tracker = utils.Tracker(self.args.plot_folder, [
            'bert_score_' + self.args.savetag,
            'bert_speedup_' + self.args.savetag
        ], '.csv')

        #Genealogy tool
        self.genealogy = Genealogy()

        self.env = env
        self.test_envs = test_envs

        if self.args.use_mp:
            #MP TOOLS
            self.manager = Manager()
            #Initialize Mixed Population
            self.population = self.manager.list()

        else:
            self.population = []

        boltzman_count = int(args.pop_size * args.ratio)
        rest = args.pop_size - boltzman_count
        for _ in range(boltzman_count):
            self.population.append(
                BoltzmannChromosome(model_constructor.num_nodes,
                                    model_constructor.action_dim))

        for _ in range(rest):
            self.population.append(
                model_constructor.make_model(self.policy_string))
            self.population[-1].eval()

        #Save best policy
        self.best_policy = model_constructor.make_model(self.policy_string)

        #Init BUFFER
        self.replay_buffer = Buffer(args.buffer_size, state_template,
                                    action_space,
                                    args.aux_folder + args.savetag)
        self.data_bucket = self.replay_buffer.tuples

        #Intialize portfolio of learners
        self.portfolio = []
        if args.rollout_size > 0:
            self.portfolio = initialize_portfolio(self.portfolio, self.args,
                                                  self.genealogy,
                                                  args.portfolio_id,
                                                  model_constructor)

        #Initialize Rollout Bucket
        self.rollout_bucket = self.manager.list() if self.args.use_mp else []
        for _ in range(len(self.portfolio)):
            self.rollout_bucket.append(
                model_constructor.make_model(self.policy_string))

        if self.args.use_mp:
            ############## MULTIPROCESSING TOOLS ###################
            #Evolutionary population Rollout workers
            data_bucket = self.data_bucket if args.rollout_size > 0 else None  #If Strictly Evo - don;t store data
            self.evo_task_pipes = [Pipe() for _ in range(args.pop_size)]
            self.evo_result_pipes = [Pipe() for _ in range(args.pop_size)]
            self.evo_workers = [
                Process(target=rollout_worker,
                        args=(id, 'evo', self.evo_task_pipes[id][1],
                              self.evo_result_pipes[id][0], data_bucket,
                              self.population, env_constructor))
                for id in range(args.pop_size)
            ]
            for worker in self.evo_workers:
                worker.start()

            #Learner rollout workers
            self.task_pipes = [Pipe() for _ in range(args.rollout_size)]
            self.result_pipes = [Pipe() for _ in range(args.rollout_size)]
            self.workers = [
                Process(target=rollout_worker,
                        args=(id, 'pg', self.task_pipes[id][1],
                              self.result_pipes[id][0], data_bucket,
                              self.rollout_bucket, env_constructor))
                for id in range(args.rollout_size)
            ]
            for worker in self.workers:
                worker.start()

        self.roll_flag = [True for _ in range(args.rollout_size)]
        self.evo_flag = [True for _ in range(args.pop_size)]

        #Meta-learning controller (Resource Distribution)
        self.allocation = [
        ]  #Allocation controls the resource allocation across learners
        for i in range(args.rollout_size):
            self.allocation.append(
                i % len(self.portfolio))  #Start uniformly (equal resources)

        #Trackers
        self.best_score = -float('inf')
        self.gen_frames = 0
        self.total_frames = 0
        self.best_speedup = -float('inf')
        self.champ_type = None

    def checkpoint(self):

        utils.pickle_obj(self.args.ckpt_folder + 'test_tracker',
                         self.test_tracker)
        utils.pickle_obj(self.args.ckpt_folder + 'time_tracker',
                         self.time_tracker)
        utils.pickle_obj(self.args.ckpt_folder + 'champ_tracker',
                         self.champ_tracker)
        for i in range(len(self.population)):
            net = self.population[i]

            if net.model_type == 'BoltzmanChromosome':
                utils.pickle_obj(self.args.ckpt_folder + 'Boltzman/' + str(i),
                                 net)

            else:
                torch.save(net.state_dict(),
                           self.args.ckpt_folder + 'Gumbel/' + str(i))

            self.population[i] = net

    def load_checkpoint(self):

        #Try to load trackers
        try:
            self.test_tracker = utils.unpickle_obj(self.args.ckpt_folder +
                                                   'test_tracker')
            self.time_tracker = utils.unpickle_obj(self.args.ckpt_folder +
                                                   'time_tracker')
            self.champ_tracker = utils.unpickle_obj(self.args.ckpt_folder +
                                                    'champ_tracker')
        except:
            None

        gumbel_template = False
        for i in range(len(self.population)):
            if self.population[i].model_type == 'GumbelPolicy':
                gumbel_template = self.population[i]
                break

        boltzman_nets = os.listdir(self.args.ckpt_folder + 'Boltzman/')
        gumbel_nets = os.listdir(self.args.ckpt_folder + 'Gumbel/')

        print('Boltzman seeds', boltzman_nets, 'Gumbel seeds', gumbel_nets)

        gumbel_models = []
        boltzman_models = []

        for fname in boltzman_nets:
            try:
                net = utils.unpickle_obj(self.args.ckpt_folder + 'Boltzman/' +
                                         fname)
                boltzman_models.append(net)
            except:
                print('Failed to load',
                      self.args.ckpt_folder + 'Boltzman/' + fname)

        for fname in gumbel_nets:
            try:
                model_template = copy.deepcopy(gumbel_template)
                model_template.load_state_dict(
                    torch.load(self.args.ckpt_folder + 'Gumbel/' + fname))
                model_template.eval()
                gumbel_models.append(model_template)
            except:
                print('Failed to load',
                      self.args.ckpt_folder + 'Gumbel/' + fname)

        for i in range(len(self.population)):
            net = self.population[i]

            if net.model_type == 'GumbelPolicy' and len(gumbel_models) >= 1:
                seed_model = gumbel_models.pop()
                utils.hard_update(net, seed_model)

            elif net.model_type == 'BoltzmanChromosome' and len(
                    boltzman_models) >= 1:
                seed_model = boltzman_models.pop()
                net = seed_model

            self.population[i] = net

        print()
        print()
        print()
        print()
        print('Checkpoint Loading Phase Completed')
        print()
        print()
        print()
        print()

    def forward_generation(self, gen, time_start):
        ################ START ROLLOUTS ##############

        #Start Evolution rollouts
        if self.args.pop_size >= 1 and self.args.use_mp:
            for id, actor in enumerate(self.population):
                if self.evo_flag[id]:
                    self.evo_task_pipes[id][0].send(id)
                    self.evo_flag[id] = False

        #If Policy Gradient
        if self.args.rollout_size > 0:
            #Sync all learners actor to cpu (rollout) actor
            for i, learner in enumerate(self.portfolio):
                learner.algo.actor.cpu()
                utils.hard_update(self.rollout_bucket[i], learner.algo.actor)
                learner.algo.actor.to(self.device)

            # Start Learner rollouts
            if self.args.use_mp:
                for rollout_id, learner_id in enumerate(self.allocation):
                    if self.roll_flag[rollout_id]:
                        self.task_pipes[rollout_id][0].send(learner_id)
                        self.roll_flag[rollout_id] = False

            ############# UPDATE PARAMS USING GRADIENT DESCENT ##########
            if self.replay_buffer.__len__(
            ) > self.args.learning_start and not self.args.random_baseline:  ###BURN IN PERIOD

                print('INSIDE GRAD DESCENT')

                for learner in self.portfolio:
                    learner.update_parameters(
                        self.replay_buffer, self.args.batch_size,
                        int(self.gen_frames * self.args.gradperstep))

                self.gen_frames = 0

            else:
                print('BURN IN PERIOD')

        gen_best = -float('inf')
        gen_best_speedup = -float("inf")
        gen_champ = None
        ########## SOFT -JOIN ROLLOUTS FOR EVO POPULATION ############
        if self.args.pop_size >= 1:
            for i in range(self.args.pop_size):

                if self.args.use_mp:
                    entry = self.evo_result_pipes[i][1].recv()
                else:
                    entry = rollout_function(
                        i,
                        'evo',
                        self.population[i],
                        self.env,
                        store_data=self.args.rollout_size > 0)

                self.gen_frames += entry[2]
                self.total_frames += entry[2]
                speedup = entry[3][0]
                score = entry[1]

                net = self.population[entry[0]]
                net.fitness_stats['speedup'] = speedup
                net.fitness_stats['score'] = score
                net.fitness_stats['shaped'][:] = entry[5]
                self.population[entry[0]] = net

                self.test_tracker.update([score, speedup], self.total_frames)
                self.time_tracker.update([score, speedup],
                                         time.time() - time_start)

                if speedup > self.best_speedup:
                    self.best_speedup = speedup

                if score > gen_best:
                    gen_best = score
                    gen_champ = self.population[i]

                if speedup > gen_best_speedup:
                    gen_best_speedup = speedup

                if score > self.best_score:
                    self.best_score = score
                    champ_index = i
                    self.champ_type = net.model_type
                    try:
                        torch.save(
                            self.population[champ_index].state_dict(),
                            self.args.models_folder + 'bestChamp_' +
                            self.args.savetag)
                    except:
                        None
                    # TODO
                    print("Best Evo Champ saved with score", '%.2f' % score)

                if self.args.rollout_size > 0:
                    self.replay_buffer.add(entry[4])

                self.evo_flag[i] = True

        try:
            torch.save(
                gen_champ.state_dict(),
                self.args.models_folder + 'genChamp_' + str(gen) +
                '_speedup_' + str(gen_best_speedup) + '_' + self.args.savetag)
        except:
            None

        ############################# GENERALIZATION EXPERIMENTS ########################
        _, resnet50_score, _, resnet50_speedup, _, _ = rollout_function(
            0, 'evo', gen_champ, self.test_envs[0], store_data=False)
        _, resnet101_score, _, resnet101_speedup, _, _ = rollout_function(
            0, 'evo', gen_champ, self.test_envs[1], store_data=False)
        resnet50_speedup = resnet50_speedup[0]
        resnet101_speedup = resnet101_speedup[0]
        self.r50_tracker.update([resnet50_score, resnet50_speedup], gen)
        self.r101_tracker.update([resnet101_score, resnet101_speedup], gen)
        self.r50_frames_tracker.update([resnet50_score, resnet50_speedup],
                                       self.total_frames)
        self.r101_frames_tracker.update([resnet101_score, resnet101_speedup],
                                        self.total_frames)
        bert_speedup, bert_score = None, None

        if self.platform != 'wpa':
            _, bert_score, _, bert_speedup, _, _ = rollout_function(
                0, 'evo', gen_champ, self.test_envs[2], store_data=False)
            bert_speedup = bert_speedup[0]
            self.bert_tracker.update([bert_score, bert_speedup], gen)
            self.bert_frames_tracker.update([bert_score, bert_speedup],
                                            self.total_frames)

        ############################# GENERALIZATION EXPERIMENTS ########################

        ########## HARD -JOIN ROLLOUTS FOR LEARNER ROLLOUTS ############
        if self.args.rollout_size > 0:
            for i in range(self.args.rollout_size):

                #NOISY PG
                if self.args.use_mp:
                    entry = self.result_pipes[i][1].recv()
                else:
                    entry = rollout_function(i,
                                             'pg',
                                             self.rollout_bucket[i],
                                             self.env,
                                             store_data=True)

                learner_id = entry[0]
                fitness = entry[1]
                num_frames = entry[2]
                speedup = entry[3][0]
                self.portfolio[learner_id].update_stats(fitness, num_frames)
                self.replay_buffer.add(entry[4])

                self.test_tracker.update([fitness, speedup], self.total_frames)
                self.time_tracker.update([fitness, speedup],
                                         time.time() - time_start)

                gen_best = max(fitness, gen_best)
                self.best_speedup = max(speedup, self.best_speedup)
                gen_best_speedup = max(speedup, gen_best_speedup)
                self.gen_frames += num_frames
                self.total_frames += num_frames
                if fitness > self.best_score:
                    self.best_score = fitness
                    torch.save(
                        self.rollout_bucket[i].state_dict(),
                        self.args.models_folder + 'noisy_bestPG_' +
                        str(speedup) + '_' + self.args.savetag)
                    print("Best Rollout Champ saved with score",
                          '%.2f' % fitness)
                noisy_speedup = speedup

                # Clean PG Measurement
                entry = rollout_function(i,
                                         'evo',
                                         self.rollout_bucket[i],
                                         self.env,
                                         store_data=True)
                learner_id = entry[0]
                fitness = entry[1]
                num_frames = entry[2]
                speedup = entry[3][0]
                self.portfolio[learner_id].update_stats(fitness, num_frames)
                self.replay_buffer.add(entry[4])

                self.test_tracker.update([fitness, speedup], self.total_frames)
                self.time_tracker.update([fitness, speedup],
                                         time.time() - time_start)

                gen_best = max(fitness, gen_best)
                self.best_speedup = max(speedup, self.best_speedup)
                gen_best_speedup = max(speedup, gen_best_speedup)
                self.gen_frames += num_frames
                self.total_frames += num_frames
                if fitness > self.best_score:
                    self.best_score = fitness
                    torch.save(
                        self.rollout_bucket[i].state_dict(),
                        self.args.models_folder + 'clean_bestPG_' +
                        str(speedup) + '_' + self.args.savetag)
                    print("Best Clean Evo Champ saved with score",
                          '%.2f' % fitness)

                self.pg_tracker.update([noisy_speedup, speedup],
                                       self.total_frames)
                self.roll_flag[i] = True

        self.champ_tracker.update([gen_best, gen_best_speedup],
                                  self.total_frames)

        #NeuroEvolution's probabilistic selection and recombination step
        if self.args.pop_size >= 1 and not self.args.random_baseline:

            if gen % 1 == 0:
                self.population = self.evolver.epoch(self.population,
                                                     self.rollout_bucket)
            else:
                self.population = self.evolver.epoch(self.population, [])

            if self.evolver.selection_stats['total'] > 0:
                selection_rate = (
                    1.0 * self.evolver.selection_stats['selected'] +
                    self.evolver.selection_stats['elite']
                ) / self.evolver.selection_stats['total']
                elite_rate = selection_rate = (
                    1.0 * self.evolver.selection_stats['elite']
                ) / self.evolver.selection_stats['total']
                self.migration_tracker.update([selection_rate, elite_rate],
                                              self.total_frames)

        if gen % 1 == 0:
            self.checkpoint()

        return gen_best

    def train(self, frame_limit):

        time_start = time.time()

        for gen in range(1, 1000000000):  # Infinite generations

            # Train one iteration
            gen_best = self.forward_generation(gen, time_start)

            print()
            print('Gen/Frames', gen, '/', self.total_frames, 'Gen_Score',
                  '%.2f' % gen_best, 'Best_Score', '%.2f' % self.best_score,
                  ' Speedup', '%.2f' % self.best_speedup, ' Frames/sec:',
                  '%.2f' % (self.total_frames / (time.time() - time_start)),
                  'Buffer', self.replay_buffer.__len__(), 'Savetag',
                  self.args.savetag)
            for net in self.population:

                print(net.model_type, net.fitness_stats)
                if net.model_type == 'BoltzmanChromosome':
                    print(net.temperature_stats)
                print()
            print()

            try:
                print('Initial Ratio', self.args.ratio, 'Current Ratio',
                      self.evolver.ratio, 'Chamption Type', self.champ_type)
            except:
                None

            if gen % 5 == 0:
                print('Learner Fitness', [
                    utils.pprint(learner.value) for learner in self.portfolio
                ])

            if self.total_frames > frame_limit:
                break

        ###Kill all processes
        try:
            for p in self.task_pipes:
                p[0].send('TERMINATE')
            for p in self.test_task_pipes:
                p[0].send('TERMINATE')
            for p in self.evo_task_pipes:
                p[0].send('TERMINATE')
        except:
            None

    def compute_policy_type(self):

        if self.args.algo == 'ddqn':
            return 'DDQN'

        elif self.args.algo == 'sac':
            return 'Gaussian_FF'

        elif self.args.algo == 'td3':
            return 'Deterministic_FF'

        elif self.args.algo == 'sac_discrete':
            return 'GumbelPolicy'
コード例 #3
0
class ERL_Trainer:
    def __init__(self, args, model_constructor, env_constructor):

        self.args = args
        self.policy_string = 'CategoricalPolicy' if env_constructor.is_discrete else 'Gaussian_FF'
        self.manager = Manager()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        #Evolution
        self.evolver = SSNE(self.args)

        #Initialize population
        self.population = self.manager.list()
        for _ in range(args.pop_size):
            self.population.append(
                model_constructor.make_model(self.policy_string))

        #Save best policy
        self.best_policy = model_constructor.make_model(self.policy_string)

        #PG Learner
        if env_constructor.is_discrete:
            from algos.ddqn import DDQN
            self.learner = DDQN(args, model_constructor)
        else:
            from algos.sac import SAC
            self.learner = SAC(args, model_constructor)

        #Replay Buffer
        self.replay_buffer = Buffer(args.buffer_size)

        #Initialize Rollout Bucket
        self.rollout_bucket = self.manager.list()
        for _ in range(args.rollout_size):
            self.rollout_bucket.append(
                model_constructor.make_model(self.policy_string))

        ############## MULTIPROCESSING TOOLS ###################
        #Evolutionary population Rollout workers
        self.evo_task_pipes = [Pipe() for _ in range(args.pop_size)]
        self.evo_result_pipes = [Pipe() for _ in range(args.pop_size)]
        self.evo_workers = [
            Process(target=rollout_worker,
                    args=(id, 'evo', self.evo_task_pipes[id][1],
                          self.evo_result_pipes[id][0], args.rollout_size > 0,
                          self.population, env_constructor))
            for id in range(args.pop_size)
        ]
        for worker in self.evo_workers:
            worker.start()
        self.evo_flag = [True for _ in range(args.pop_size)]

        #Learner rollout workers
        self.task_pipes = [Pipe() for _ in range(args.rollout_size)]
        self.result_pipes = [Pipe() for _ in range(args.rollout_size)]
        self.workers = [
            Process(target=rollout_worker,
                    args=(id, 'pg', self.task_pipes[id][1],
                          self.result_pipes[id][0], True, self.rollout_bucket,
                          env_constructor)) for id in range(args.rollout_size)
        ]
        for worker in self.workers:
            worker.start()
        self.roll_flag = [True for _ in range(args.rollout_size)]

        #Test bucket
        self.test_bucket = self.manager.list()
        self.test_bucket.append(
            model_constructor.make_model(self.policy_string))

        # Test workers
        self.test_task_pipes = [Pipe() for _ in range(args.num_test)]
        self.test_result_pipes = [Pipe() for _ in range(args.num_test)]
        self.test_workers = [
            Process(target=rollout_worker,
                    args=(id, 'test', self.test_task_pipes[id][1],
                          self.test_result_pipes[id][0], False,
                          self.test_bucket, env_constructor))
            for id in range(args.num_test)
        ]
        for worker in self.test_workers:
            worker.start()
        self.test_flag = False

        #Trackers
        self.best_score = -float('inf')
        self.gen_frames = 0
        self.total_frames = 0
        self.test_score = None
        self.test_std = None

    def forward_generation(self, gen, tracker):

        gen_max = -float('inf')

        #Start Evolution rollouts
        if self.args.pop_size > 1:
            for id, actor in enumerate(self.population):
                self.evo_task_pipes[id][0].send(id)

        #Sync all learners actor to cpu (rollout) actor and start their rollout
        self.learner.actor.cpu()
        for rollout_id in range(len(self.rollout_bucket)):
            utils.hard_update(self.rollout_bucket[rollout_id],
                              self.learner.actor)
            self.task_pipes[rollout_id][0].send(0)
        self.learner.actor.to(device=self.device)

        #Start Test rollouts
        if gen % self.args.test_frequency == 0:
            self.test_flag = True
            for pipe in self.test_task_pipes:
                pipe[0].send(0)

        ############# UPDATE PARAMS USING GRADIENT DESCENT ##########
        if self.replay_buffer.__len__(
        ) > self.args.learning_start:  ###BURN IN PERIOD
            for _ in range(int(self.gen_frames * self.args.gradperstep)):
                s, ns, a, r, done = self.replay_buffer.sample(
                    self.args.batch_size)
                self.learner.update_parameters(s, ns, a, r, done)

            self.gen_frames = 0

        ########## JOIN ROLLOUTS FOR EVO POPULATION ############
        all_fitness = []
        all_eplens = []
        if self.args.pop_size > 1:
            for i in range(self.args.pop_size):
                _, fitness, frames, trajectory = self.evo_result_pipes[i][
                    1].recv()

                all_fitness.append(fitness)
                all_eplens.append(frames)
                self.gen_frames += frames
                self.total_frames += frames
                self.replay_buffer.add(trajectory)
                self.best_score = max(self.best_score, fitness)
                gen_max = max(gen_max, fitness)

        ########## JOIN ROLLOUTS FOR LEARNER ROLLOUTS ############
        rollout_fitness = []
        rollout_eplens = []
        if self.args.rollout_size > 0:
            for i in range(self.args.rollout_size):
                _, fitness, pg_frames, trajectory = self.result_pipes[i][
                    1].recv()
                self.replay_buffer.add(trajectory)
                self.gen_frames += pg_frames
                self.total_frames += pg_frames
                self.best_score = max(self.best_score, fitness)
                gen_max = max(gen_max, fitness)
                rollout_fitness.append(fitness)
                rollout_eplens.append(pg_frames)

        ######################### END OF PARALLEL ROLLOUTS ################

        ############ FIGURE OUT THE CHAMP POLICY AND SYNC IT TO TEST #############
        if self.args.pop_size > 1:
            champ_index = all_fitness.index(max(all_fitness))
            utils.hard_update(self.test_bucket[0],
                              self.population[champ_index])
            if max(all_fitness) > self.best_score:
                self.best_score = max(all_fitness)
                utils.hard_update(self.best_policy,
                                  self.population[champ_index])
                torch.save(self.population[champ_index].state_dict(),
                           self.args.aux_folder + '_best' + self.args.savetag)
                print("Best policy saved with score",
                      '%.2f' % max(all_fitness))

        else:  #If there is no population, champion is just the actor from policy gradient learner
            utils.hard_update(self.test_bucket[0], self.rollout_bucket[0])

        ###### TEST SCORE ######
        if self.test_flag:
            self.test_flag = False
            test_scores = []
            for pipe in self.test_result_pipes:  #Collect all results
                _, fitness, _, _ = pipe[1].recv()
                self.best_score = max(self.best_score, fitness)
                gen_max = max(gen_max, fitness)
                test_scores.append(fitness)
            test_scores = np.array(test_scores)
            test_mean = np.mean(test_scores)
            test_std = (np.std(test_scores))
            tracker.update([test_mean], self.total_frames)

        else:
            test_mean, test_std = None, None

        #NeuroEvolution's probabilistic selection and recombination step
        if self.args.pop_size > 1:
            self.evolver.epoch(gen, self.population, all_fitness,
                               self.rollout_bucket)

        #Compute the champion's eplen
        champ_len = all_eplens[all_fitness.index(
            max(all_fitness))] if self.args.pop_size > 1 else rollout_eplens[
                rollout_fitness.index(max(rollout_fitness))]

        return gen_max, champ_len, all_eplens, test_mean, test_std, rollout_fitness, rollout_eplens

    def train(self, frame_limit):
        # Define Tracker class to track scores
        test_tracker = utils.Tracker(self.args.savefolder,
                                     ['score_' + self.args.savetag],
                                     '.csv')  # Tracker class to log progress
        time_start = time.time()

        for gen in range(1, 1000000000):  # Infinite generations

            # Train one iteration
            max_fitness, champ_len, all_eplens, test_mean, test_std, rollout_fitness, rollout_eplens = self.forward_generation(
                gen, test_tracker)
            if test_mean:
                self.args.writer.add_scalar('test_score', test_mean, gen)

            print(
                'Gen/Frames:', gen, '/', self.total_frames, ' Gen_max_score:',
                '%.2f' % max_fitness, ' Champ_len', '%.2f' % champ_len,
                ' Test_score u/std', utils.pprint(test_mean),
                utils.pprint(test_std), ' Rollout_u/std:',
                utils.pprint(np.mean(np.array(rollout_fitness))),
                utils.pprint(np.std(np.array(rollout_fitness))),
                ' Rollout_mean_eplen:',
                utils.pprint(sum(rollout_eplens) /
                             len(rollout_eplens)) if rollout_eplens else None)

            if gen % 5 == 0:
                print(
                    'Best_score_ever:'
                    '/', '%.2f' % self.best_score, ' FPS:',
                    '%.2f' % (self.total_frames / (time.time() - time_start)),
                    'savetag', self.args.savetag)
                print()

            if self.total_frames > frame_limit:
                break

        ###Kill all processes
        try:
            for p in self.task_pipes:
                p[0].send('TERMINATE')
            for p in self.test_task_pipes:
                p[0].send('TERMINATE')
            for p in self.evo_task_pipes:
                p[0].send('TERMINATE')
        except:
            None