    def __init__(self,
                 make_env_fs,
                 *args,
                 gpus=get_available_gpus() * 4,
                 **kwargs):
        tlogger.info("=== Calling MTConcurrentWorkers()")
        self.sess = None
        if not gpus:
            gpus = ['/cpu:0']
        print("GPUS: {}".format(gpus))
        with tf.Session() as sess:
            import gym_tensorflow
            self.workers = []
            for i in range(len(gpus)):
                # alternate between games for multi-task learning
                if (i + 1) % 2 == 0:
                    game_index = 1  # second game
                else:
                    game_index = 0  # first game
                game_make_env = make_env_fs[game_index]
                ref_batch = gym_tensorflow.get_ref_batch(
                    game_make_env, sess, 128, game_max_action_space=4)
                ref_batch = ref_batch[:, ...]
                worker = RLEvalutionWorkerCappedActionSpace(
                    game_index,
                    game_make_env,
                    *args,
                    ref_batch=ref_batch,
                    **dict(kwargs, device=gpus[i]))
                self.workers.append(worker)
            self.model = self.workers[0].model
            self.steps_counter = sum([w.steps_counter for w in self.workers])
            self.async_hub = AsyncTaskHub()
            self.hub = WorkerHub(self.workers, self.async_hub.input_queue,
                                 self.async_hub)
Example 2
    def _loop(self):
        running = np.zeros((self.batch_size, ), dtype=np.bool)
        cumrews = np.zeros((self.batch_size, ), dtype=np.float32)
        cumlen = np.zeros((self.batch_size, ), dtype=np.int32)

        tlogger.info('RLEvalutionWorker._loop')

        while True:
            # nothing loaded, block
            if not any(running):
                idx = self.queue.get()
                if idx is None:
                    break
                running[idx] = True
            while not self.queue.empty():
                idx = self.queue.get()
                if idx is None:
                    break
                running[idx] = True

            indices = np.nonzero(running)[0]
            rews, is_done, _ = self.sess.run(
                [self.rew_op, self.done_op, self.incr_counter],
                {self.placeholder_indices: indices})
            cumrews[running] += rews
            cumlen[running] += 1
            if any(is_done):
                for idx in indices[is_done]:
                    self.sample_callback[idx](
                        self, idx,
                        (self.model.seeds[idx], cumrews[idx], cumlen[idx]))
                cumrews[indices[is_done]] = 0.
                cumlen[indices[is_done]] = 0.
                running[indices[is_done]] = False
    def __init__(self, config):
        self.num_frames = 0
        self.population = []
        self.timesteps_so_far = 0
        self.time_elapsed = 0
        self.validation_timesteps_so_far = 0
        self.elite = None
        self.it = 0
        self.mutation_power = make_schedule(config['mutation_power'])
        self.curr_solution = None
        self.curr_solution_val = float('-inf')
        self.curr_solution_test = float('-inf')

        if isinstance(config['episode_cutoff_mode'], int):
            self.tslimit = config['episode_cutoff_mode']
            self.incr_tslimit_threshold = None
            self.tslimit_incr_ratio = None
            self.adaptive_tslimit = False
        elif config['episode_cutoff_mode'].startswith('adaptive:'):
            _, args = config['episode_cutoff_mode'].split(':')
            arg0, arg1, arg2, arg3 = args.split(',')
            self.tslimit, self.incr_tslimit_threshold, self.tslimit_incr_ratio, self.tslimit_max = int(
                arg0), float(arg1), float(arg2), float(arg3)
            self.adaptive_tslimit = True
            tlogger.info(
                'Starting timestep limit set to {}. When {}% of rollouts hit the limit, it will be increased by a factor of {}'
                .format(self.tslimit, self.incr_tslimit_threshold * 100,
                        self.tslimit_incr_ratio))
        elif config['episode_cutoff_mode'] == 'env_default':
            self.tslimit, self.incr_tslimit_threshold, self.tslimit_incr_ratio = None, None, None
            self.adaptive_tslimit = False
        else:
            raise NotImplementedError(config['episode_cutoff_mode'])
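
For reference, a minimal sketch of the three episode_cutoff_mode settings this constructor accepts; the numeric values below are illustrative assumptions, not settings taken from the original experiments.

config_fixed = {'mutation_power': 0.002, 'episode_cutoff_mode': 5000}

# 'adaptive:<initial limit>,<fraction of rollouts that must hit the limit>,
#           <increase ratio>,<max limit>'
config_adaptive = {'mutation_power': 0.002,
                   'episode_cutoff_mode': 'adaptive:1000,0.5,1.5,20000'}

config_env_default = {'mutation_power': 0.002,
                      'episode_cutoff_mode': 'env_default'}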
Example 4
def main():
    print('Number of mutations:', len(seeds))

    env = gym_tensorflow.make(game, 1)

    model = Model()
    obs_op = env.observation()
    reset_op = env.reset()

    action_op = model.make_net(tf.expand_dims(obs_op, axis=1), env.action_space, batch_size=1)
    if env.discrete_action:
        action_op = tf.argmax(action_op, axis=-1, output_type=tf.int32)
    rew_op, done_op = env.step(action_op)

    from gym.envs.classic_control import rendering
    viewer = rendering.SimpleImageViewer()
    if hasattr(env.unwrapped, 'render'):
        obs_op = env.unwrapped.render()
        def display_obs(im):
            im = im[0, 0, ...]

            viewer.imshow(im)
    else:
        def display_obs(im):
            im = im[0, :, :, -1]
            im = np.stack([im] * 3, axis=-1)
            im = (im * 255).astype(np.uint8)

            im = np.array(Image.fromarray(im).resize((256, 256), resample=Image.BILINEAR), dtype=np.uint8)
            viewer.imshow(im)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model.initialize()
        tlogger.info(model.description)

        noise = SharedNoiseTable()

        weights = model.compute_weights_from_seeds(noise, seeds)
        model.load(sess, 0, weights, seeds)

        sess.run(reset_op)
        display_obs(sess.run(obs_op))

        total_rew = 0
        num_frames = 0
        while True:
            rew, done = sess.run([rew_op, done_op])
            num_frames += 1
            total_rew += rew[0]
            display_obs(sess.run(obs_op))
            time.sleep(4/60)
            if done[0]:
                print('Final reward: ', total_rew, 'after', num_frames, 'steps')
                break
Example 5
    def _handle_output(self):
        try:
            while True:
                result = self.results_queue.get()
                if result is None:
                    tlogger.info('AsyncTaskHub._handle_output done')
                    break
                self.put(result)
        except:
            tlogger.exception('AsyncTaskHub._handle_output exception thrown')
            raise
Example 6
    def _handle_output(self):
        try:
            while True:
                result = self.done_buffer.get()
                if result is None:
                    tlogger.info('WorkerHub._handle_output done')
                    break
                self.done_queue.put(result)
        except:
            tlogger.exception('WorkerHub._handle_output exception thrown')
            raise
Example 7
    def __enter__(self, *args, **kwargs):
        self._sess = tf.Session(*args, **kwargs)
        self._sess.run(tf.global_variables_initializer())
        self._worker.initialize(self._sess)

        tlogger.info(self._worker.model.description)

        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(self._sess,
                                                    self.coord,
                                                    start=True)

        return self._sess
    def _handle_output(self):
        try:
            while True:
                result = self.done_buffer.get()
                if result is None:
                    tlogger.info('WorkerHub._handle_output done')
                    break

                # print("== putting result in done queue: result={}".format(result))
                self.done_queue.put(result)
        except:
            tlogger.exception('WorkerHub._handle_output exception thrown')
            raise
    def __initialize_handlers(self):
        self._input_handler = threading.Thread(
            target=WorkerHub._handle_input,
            args=(self,)
            )
        self._input_handler._state = 0
        tlogger.info('WorkerHub: _input_handler initialized')

        self._output_handler = threading.Thread(
            target=WorkerHub._handle_output,
            args=(self,)
            )
        self._output_handler._state = 0
        tlogger.info('WorkerHub: _output_handler initialized')
Example 10
    def _handle_output(self):
        try:
            while True:
                result_0 = self.results_queue_0.get()
                result_1 = self.results_queue_1.get()
                if result_0 is None and result_1 is None:
                    tlogger.info('AsyncTaskHub._handle_output done')
                    break
                if result_0 is not None:
                    self.put(result_0)
                if result_1 is not None:
                    self.put(result_1)
        except:
            tlogger.exception('AsyncTaskHub._handle_output exception thrown')
            raise
Example 11
def main():
    model = None
    if learning == 'recurrent':
        model = RecurrentLargeModel()
    if learning == 'ga':
        model = LargeModel()

    X_t = tf.placeholder(tf.float32, [None] + image_shape, name='X_t')
    action_op = model.make_net(tf.expand_dims(X_t, axis=1),
                               game_action_counts,
                               batch_size=1)

    all_saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model.initialize()
        tlogger.info(model.description)
        noise = SharedNoiseTable()
        weights = model.compute_weights_from_seeds(noise, seeds)
        model.load(sess, 0, weights, seeds)

        # -----------------------------------------------------------------
        node_names = [
            node.name for node in tf.get_default_graph().as_graph_def().node
        ]
        # Collect node names under the 'ga' scope (kept for inspection; not used below).
        new_node = []

        for i in node_names:
            if "ga" in i:
                new_node.append(i)

        graph_const = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), ['ga/Reshape_1'])

    outgraph = optimize_for_inference_lib.optimize_for_inference(
        graph_const,
        ['X_t'],  # an array of the input node(s)
        ['ga/Reshape_1'],  # an array of output nodes
        tf.float32.as_datatype_enum)

    # write frozen model to LOGDIR
    tf.train.write_graph(outgraph, LOGDIR, file_name + '.pb', as_text=False)

    # human readable format
    # tf.train.write_graph(outgraph, LOGDIR, file_name + '.pbtxt', as_text=True)

    print("Freezed graph for {} with {} game actions to {}{}.pb.".format(
        game, game_action_counts, LOGDIR, file_name))
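
A hedged companion sketch (not part of the original script) of how such a frozen .pb could be loaded back for inference with the TF1 API; the tensor names X_t:0 and ga/Reshape_1:0 follow the input/output node names passed to optimize_for_inference above, while the path and observation shape are assumptions.

import numpy as np
import tensorflow as tf

def load_frozen_graph(pb_path):
    # Read the serialized GraphDef written by tf.train.write_graph(..., as_text=False).
    with tf.gfile.GFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph

# graph = load_frozen_graph(LOGDIR + file_name + '.pb')
# x_t = graph.get_tensor_by_name('X_t:0')
# out = graph.get_tensor_by_name('ga/Reshape_1:0')
# with tf.Session(graph=graph) as sess:
#     action = sess.run(out, {x_t: np.zeros([1] + image_shape, dtype=np.float32)})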
Example 12
    def make_net(self, model_constructor, device, ref_batch=None):
        self.model = model_constructor()
        tlogger.info('RLEvalutionWorker.make_net %s' % self.model)

        with tf.variable_scope(None, default_name='model'):
            with tf.device('/cpu:0'):
                self.env = self.make_env_f(self.batch_size)

                self.placeholder_indices = tf.placeholder(tf.int32,
                                                          shape=(None, ))
                self.placeholder_max_frames = tf.placeholder(tf.int32,
                                                             shape=(None, ))
                self.reset_op = self.env.reset(
                    indices=self.placeholder_indices,
                    max_frames=self.placeholder_max_frames)

                with tf.device(device):
                    self.obs_op = self.env.observation(
                        indices=self.placeholder_indices)
                    obs = tf.expand_dims(self.obs_op, axis=1)
                    self.action_op = self.model.make_net(
                        obs,
                        self.env.action_space,
                        indices=self.placeholder_indices,
                        batch_size=self.batch_size,
                        ref_batch=ref_batch)
                self.model.initialize()

                if self.env.discrete_action:
                    self.action_op = tf.argmax(
                        self.action_op[:tf.shape(self.placeholder_indices)[0]],
                        axis=-1,
                        output_type=tf.int32)
                with tf.device(device):
                    self.rew_op, self.done_op = self.env.step(
                        self.action_op, indices=self.placeholder_indices)

                self.steps_counter = tf.Variable(np.zeros((),
                                                          dtype=np.float32),
                                                 name="steps_counter",
                                                 dtype=tf.float32)
                self.incr_counter = tf.assign_add(
                    self.steps_counter,
                    tf.cast(tf.reduce_prod(tf.shape(self.placeholder_indices)),
                            dtype=tf.float32))
Example 13
    def _handle_input(self):
        try:
            while True:
                worker_task = self.available_workers.get()
                if worker_task is None:
                    tlogger.info('WorkerHub._handle_input done')
                    break
                worker, subworker = worker_task

                task = self.input_queue.get()
                if task is None:
                    tlogger.info('WorkerHub._handle_input done')
                    break
                task_id, task = task
                self._cache[worker_task] = task_id
                worker.run_async(subworker, task, self.worker_callback)
        except:
            tlogger.exception('WorkerHub._handle_input exception thrown')
            raise
    def _handle_input(self):
        try:
            while True:
                worker_task = self.available_workers.get()
                if worker_task is None:
                    tlogger.info('WorkerHub._handle_input NO MORE WORKERS AVAILABLE')
                    break
                worker, subworker = worker_task

                task = self.input_queue.get()
                if task is None:
                    tlogger.info('WorkerHub._handle_input NO MORE INPUTS AVAILABLE')
                    break
                task_id, task = task
                self._cache[worker_task] = task_id
                # tlogger.info('WorkerHub: put task id: %s in cache keyed by worker task: %s' % (task_id, worker_task))

                worker.run_async(subworker, task, callback=self.worker_callback)
        except:
            tlogger.exception('WorkerHub._handle_input exception thrown')
            raise
Example 15
    def monitor_eval(self, it, max_frames):
        logging_interval = 5
        last_timesteps = self.sess.run(self.steps_counter)
        tstart_all = time.time()
        tstart = time.time()

        tasks = []
        for t in it:
            tasks.append(self.eval_async(*t, max_frames=max_frames))
            if time.time() - tstart > logging_interval:
                cur_timesteps = self.sess.run(self.steps_counter)
                tlogger.info('Num timesteps:', cur_timesteps, 'per second:', (cur_timesteps-last_timesteps)//(time.time()-tstart), 'num episodes finished: {}/{}'.format(sum([1 if t.ready() else 0 for t in tasks]), len(tasks)))
                tstart = time.time()
                last_timesteps = cur_timesteps

        while not all([t.ready() for t in tasks]):
            if time.time() - tstart > logging_interval:
                cur_timesteps = self.sess.run(self.steps_counter)
                tlogger.info('Num timesteps:', cur_timesteps, 'per second:', (cur_timesteps-last_timesteps)//(time.time()-tstart), 'num episodes:', sum([1 if t.ready() else 0 for t in tasks]))
                tstart = time.time()
                last_timesteps = cur_timesteps
            time.sleep(0.1)
        tlogger.info('Done evaluating {} episodes in {:.2f} seconds'.format(len(tasks), time.time()-tstart_all))

        return [t.get() for t in tasks]
Example 16
def maybe_allreduce_grads(model):
    if hvd.size() > 1:
        tstart_reduce = time.time()
        named_parameters = list(
            sorted(model.named_parameters(), key=lambda a: a[0]))
        grad_handles = []
        for name, p in named_parameters:
            if p.requires_grad:
                if p.grad is None:
                    p.grad = torch.zeros_like(p)
                with torch.no_grad():
                    grad_handles.append(hvd.allreduce_async_(p.grad,
                                                             name=name))
        for handle in grad_handles:
            hvd.synchronize(handle)
        tlogger.record_tabular("TimeElapsedAllReduce",
                               time.time() - tstart_reduce)
        if time.time() - tstart_reduce > 5:
            import socket
            tlogger.info(
                "Allreduce took more than 5 seconds for node {} (rank {})".
                format(socket.gethostname(), hvd.rank()))
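
A minimal usage sketch for maybe_allreduce_grads in a Horovod/PyTorch training step; the model, optimizer, and loss below are placeholder assumptions, and only the backward/allreduce/step ordering is the point.

import horovod.torch as hvd
import torch

hvd.init()

# Placeholder model and optimizer; only the call order matters here.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(batch_x, batch_y):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(batch_x), batch_y)
    loss.backward()
    # Average gradients across workers before stepping
    # (a no-op when hvd.size() == 1, as in maybe_allreduce_grads above).
    maybe_allreduce_grads(model)
    optimizer.step()
    return loss.item()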
    def monitor_eval_repeated(self, it, max_frames, num_episodes):
        print("== monitor_eval_repeated called from MTConcurrentWorkers()")
        logging_interval = 30
        last_timesteps = self.sess.run(self.steps_counter)
        tstart_all = time.time()
        tstart = time.time()

        tasks = []
        for t in it:
            for _ in range(num_episodes):
                tasks.append(self.eval_async(*t, max_frames=max_frames))
                if time.time() - tstart > logging_interval:
                    cur_timesteps = self.sess.run(self.steps_counter)
                    logstr = 'Num timesteps: {} per second: {} num episodes finished: {}/{}'.format(
                        cur_timesteps,
                        (cur_timesteps - last_timesteps) // (time.time() - tstart),
                        sum([1 if task.ready() else 0 for task in tasks]),
                        len(tasks))
                    tlogger.info(logstr)
                    print("=== " + logstr)
                    tstart = time.time()
                    last_timesteps = cur_timesteps
        print("== monitor_eval_repeated -> for loop ended")

        while not all([t.ready() for t in tasks]):
            for t in tasks:
                if t.ready():
                    pass  #print(t.get())
            if time.time() - tstart > 5:
                cur_timesteps = self.sess.run(self.steps_counter)
                tlogger.info('Num timesteps:', cur_timesteps, 'per second:',
                             (cur_timesteps - last_timesteps) //
                             (time.time() - tstart), 'num episodes:',
                             sum([1 if t.ready() else 0 for t in tasks]))
                tstart = time.time()
                last_timesteps = cur_timesteps
            time.sleep(0.1)
        tlogger.info('Done evaluating {} episodes in {:.2f} seconds'.format(
            len(tasks),
            time.time() - tstart_all))

        print("== monitor_eval_repeated -> while loop ended")

        results = [t.get() for t in tasks]
        print("== results = {}".format(results))
        # Group episodes
        results = zip(*[iter(results)] * num_episodes)

        l = []
        for evals in results:
            game_index, seeds, rews, length = zip(*evals)
            for s in seeds[1:]:
                assert s == seeds[0]
            l.append((game_index, seeds[0], np.array(rews), np.array(length)))
        print("=== monitor_eval_repetead returns l = {}".format(l))
        return l
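
The line results = zip(*[iter(results)] * num_episodes) above groups the flat per-episode results into consecutive chunks of num_episodes; a tiny standalone illustration of the idiom:

results = ['a1', 'a2', 'b1', 'b2', 'c1', 'c2']
num_episodes = 2
grouped = list(zip(*[iter(results)] * num_episodes))
# grouped == [('a1', 'a2'), ('b1', 'b2'), ('c1', 'c2')]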
Example 18
    def __init__(self,
                 make_env_f,
                 *args,
                 gpus=get_available_gpus() * 4,
                 input_queue=None,
                 done_queue=None,
                 **kwargs):
        self.sess = None
        if not gpus:
            gpus = ['/cpu:0']
        with tf.Session() as sess:
            import gym_tensorflow
            ref_batch = gym_tensorflow.get_ref_batch(make_env_f, sess, 128)
            ref_batch = ref_batch[:, ...]
        if input_queue is None and done_queue is None:
            self.workers = [
                RLEvalutionWorker(make_env_f,
                                  *args,
                                  ref_batch=ref_batch,
                                  **dict(kwargs, device=gpus[i]))
                for i in range(len(gpus))
            ]
            self.model = self.workers[0].model
            self.steps_counter = sum([w.steps_counter for w in self.workers])
            self.async_hub = AsyncTaskHub()
            self.hub = WorkerHub(self.workers, self.async_hub.input_queue,
                                 self.async_hub)
        else:
            fake_worker = RLEvalutionWorker(*args,
                                            **dict(kwargs, device=gpus[0]))
            self.model = fake_worker.model
            self.workers = []
            self.hub = None
            self.steps_counter = tf.constant(0)
            self.async_hub = AsyncTaskHub(input_queue, done_queue)

        tlogger.info('CW: Steps counter %s' % self.steps_counter)
Example 19
    def _handle_input(self):
        try:
            while True:
                # worker_task = self.available_workers.get()
                # if worker_task is None:
                #     tlogger.info('WorkerHub._handle_input done')
                #     break
                # worker, subworker = worker_task
                if self.next_game_type_must_be == 0:
                    # print("Getting queue 0")
                    current_q = self.available_workers_0
                if self.next_game_type_must_be == 1:
                    # print("Getting queue 1")
                    current_q = self.available_workers_1

                worker_task = current_q.get()
                if worker_task is None:
                    tlogger.info('WorkerHub._handle_input done')
                    break
                worker, subworker = worker_task
                if self.next_game_type_must_be == 0:
                    self.next_game_type_must_be = 1
                else:
                    self.next_game_type_must_be = 0

                task = self.input_queue.get()
                if task is None:
                    tlogger.info('WorkerHub._handle_input done')
                    break
                task_id, task = task
                self._cache[worker_task] = task_id

                worker.run_async(subworker, task, self.worker_callback)
        except:
            tlogger.exception('WorkerHub._handle_input exception thrown')
            raise
    def monitor_eval_repeated(self, it, max_frames, num_episodes):
        logging_interval = 30
        last_timesteps = self.sess.run(self.steps_counter)
        tstart_all = time.time()
        tstart = time.time()

        tasks = []
        for t in it:
            for _ in range(num_episodes):
                tasks.append(self.eval_async(*t, max_frames=max_frames))
                if time.time() - tstart > logging_interval:
                    cur_timesteps = self.sess.run(self.steps_counter)
                    tlogger.info(
                        'Num timesteps:', cur_timesteps, 'per second:',
                        (cur_timesteps - last_timesteps) //
                        (time.time() - tstart),
                        'num episodes finished: {}/{}'.format(
                            sum([1 if task.ready() else 0 for task in tasks]),
                            len(tasks)))
                    tstart = time.time()
                    last_timesteps = cur_timesteps

        while not all([t.ready() for t in tasks]):
            if time.time() - tstart > 5:
                cur_timesteps = self.sess.run(self.steps_counter)
                tlogger.info('Num timesteps:', cur_timesteps, 'per second:',
                             (cur_timesteps - last_timesteps) //
                             (time.time() - tstart), 'num episodes:',
                             sum([1 if t.ready() else 0 for t in tasks]))
                tstart = time.time()
                last_timesteps = cur_timesteps
            time.sleep(0.1)
        tlogger.info('Done evaluating {} episodes in {:.2f} seconds'.format(
            len(tasks),
            time.time() - tstart_all))

        results = [t.get() for t in tasks]

        # Group episodes
        results = zip(*[iter(results)] * num_episodes)

        l = []
        for evals in results:
            seeds, rews, length = zip(*evals)
            for s in seeds[1:]:
                assert s == seeds[0]
            l.append((seeds[0], np.array(rews), np.array(length)))
        return l
def main(config, out_dir):
    if out_dir is not None:
        tlogger.set_log_dir(out_dir)

    log_dir = tlogger.log_dir()

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    tlogger.info(json.dumps(config, indent=4, sort_keys=True))
    tlogger.info('Logging to: {}'.format(log_dir))

    Model = neuroevolution.models.__dict__[config['model']]
    all_tstart = time.time()

    def make_env(b):
        tlogger.info('GA: Creating environment for game: %s' % config["game"])
        return gym_tensorflow.make(game=config["game"], batch_size=b)

    tlogger.info('GA: Creating Concurrent Workers')
    worker = ConcurrentWorkers(make_env, Model, batch_size=64)
    tlogger.info('GA: Concurrent Workers Created')
    with WorkerSession(worker) as sess:
        noise = SharedNoiseTable()
        rs = np.random.RandomState()

        cached_parents = []
        results = []

        def make_offspring():
            if len(cached_parents) == 0:
                return worker.model.randomize(rs, noise)
            else:
                assert len(cached_parents) == config['selection_threshold']
                parent = cached_parents[rs.randint(len(cached_parents))]
                theta, seeds = worker.model.mutate(parent,
                                                   rs,
                                                   noise,
                                                   mutation_power=state.sample(
                                                       state.mutation_power))
                #print("tetha len: %d, seeds len: %d" % (len(theta), len(seeds)))
                return theta, seeds

        tlogger.info('GA: Start timing')
        tstart = time.time()

        try:
            load_file = os.path.join(log_dir, 'snapshot.pkl')
            with open(load_file, 'rb+') as file:
                state = pickle.load(file)
            tlogger.info("Loaded iteration {} from {}".format(
                state.it, load_file))
        except FileNotFoundError:
            tlogger.info('Failed to load snapshot')
            state = TrainingState(config)

        if 'load_population' in config:
            tlogger.info('Loading population')
            state.copy_population(config['load_population'])

        # Cache first population if needed (on restart)
        if state.population and config['selection_threshold'] > 0:
            tlogger.info("Caching parents")
            cached_parents.clear()
            if state.elite in state.population[:config['selection_threshold']]:
                cached_parents.extend([
                    (worker.model.compute_weights_from_seeds(noise,
                                                             o.seeds), o.seeds)
                    for o in state.population[:config['selection_threshold']]
                ])
            else:
                cached_parents.append((worker.model.compute_weights_from_seeds(
                    noise, state.elite.seeds), state.elite.seeds))
                cached_parents.extend([
                    (worker.model.compute_weights_from_seeds(noise,
                                                             o.seeds), o.seeds)
                    for o in state.population[:config['selection_threshold'] -
                                              1]
                ])
            tlogger.info("Done caching parents")

        while True:
            tstart_iteration = time.time()
            if state.timesteps_so_far >= config['timesteps']:
                tlogger.info('Training terminated after {} timesteps'.format(
                    state.timesteps_so_far))
                break
            frames_computed_so_far = sess.run(worker.steps_counter)
            assert (len(cached_parents) == 0 and state.it == 0
                    ) or len(cached_parents) == config['selection_threshold']

            tasks = [
                make_offspring() for _ in range(config['population_size'])
            ]
            for seeds, episode_reward, episode_length in worker.monitor_eval(
                    tasks, max_frames=state.tslimit * 4):
                results.append(
                    Offspring(seeds, [episode_reward], [episode_length]))
            state.num_frames += sess.run(
                worker.steps_counter) - frames_computed_so_far

            state.it += 1
            tlogger.record_tabular('Iteration', state.it)
            tlogger.record_tabular('MutationPower',
                                   state.sample(state.mutation_power))

            # Trim unwanted results
            results = results[:config['population_size']]
            assert len(results) == config['population_size']
            rewards = np.array([a.fitness for a in results])
            population_timesteps = sum([a.training_steps for a in results])

            state.population = sorted(results,
                                      key=lambda x: x.fitness,
                                      reverse=True)
            tlogger.record_tabular('PopulationEpRewMax', np.max(rewards))
            tlogger.record_tabular('PopulationEpRewMean', np.mean(rewards))
            tlogger.record_tabular('PopulationEpCount', len(rewards))
            tlogger.record_tabular('PopulationTimesteps', population_timesteps)
            tlogger.record_tabular('NumSelectedIndividuals',
                                   config['selection_threshold'])

            tlogger.info('Evaluate population')
            validation_population = state.population[:config[
                'validation_threshold']]
            if state.elite is not None:
                validation_population = [state.elite
                                         ] + validation_population[:-1]

            validation_tasks = [(worker.model.compute_weights_from_seeds(
                noise, validation_population[x].seeds,
                cache=cached_parents), validation_population[x].seeds)
                                for x in range(config['validation_threshold'])]
            _, population_validation, population_validation_len = zip(
                *worker.monitor_eval_repeated(
                    validation_tasks,
                    max_frames=state.tslimit * 4,
                    num_episodes=config['num_validation_episodes']))
            population_validation = [np.mean(x) for x in population_validation]
            population_validation_len = [
                np.sum(x) for x in population_validation_len
            ]

            time_elapsed_this_iter = time.time() - tstart_iteration
            state.time_elapsed += time_elapsed_this_iter

            population_elite_idx = np.argmax(population_validation)
            state.elite = validation_population[population_elite_idx]
            elite_theta = worker.model.compute_weights_from_seeds(
                noise, state.elite.seeds, cache=cached_parents)
            _, population_elite_evals, population_elite_evals_timesteps = worker.monitor_eval_repeated(
                [(elite_theta, state.elite.seeds)],
                max_frames=None,
                num_episodes=config['num_test_episodes'])[0]

            # Log Results
            validation_timesteps = sum(population_validation_len)
            timesteps_this_iter = population_timesteps + validation_timesteps
            state.timesteps_so_far += timesteps_this_iter
            state.validation_timesteps_so_far += validation_timesteps

            # Log
            tlogger.record_tabular(
                'TruncatedPopulationRewMean',
                np.mean([a.fitness for a in validation_population]))
            tlogger.record_tabular('TruncatedPopulationValidationRewMean',
                                   np.mean(population_validation))
            tlogger.record_tabular('TruncatedPopulationEliteValidationRew',
                                   np.max(population_validation))
            tlogger.record_tabular("TruncatedPopulationEliteIndex",
                                   population_elite_idx)
            tlogger.record_tabular('TruncatedPopulationEliteSeeds',
                                   state.elite.seeds)
            tlogger.record_tabular('TruncatedPopulationEliteTestRewMean',
                                   np.mean(population_elite_evals))
            tlogger.record_tabular('TruncatedPopulationEliteTestEpCount',
                                   len(population_elite_evals))
            tlogger.record_tabular('TruncatedPopulationEliteTestEpLenSum',
                                   np.sum(population_elite_evals_timesteps))

            if np.mean(population_validation) > state.curr_solution_val:
                state.curr_solution = state.elite.seeds
                state.curr_solution_val = np.mean(population_validation)
                state.curr_solution_test = np.mean(population_elite_evals)

            tlogger.record_tabular('ValidationTimestepsThisIter',
                                   validation_timesteps)
            tlogger.record_tabular('ValidationTimestepsSoFar',
                                   state.validation_timesteps_so_far)
            tlogger.record_tabular('TimestepsThisIter', timesteps_this_iter)
            tlogger.record_tabular(
                'TimestepsPerSecondThisIter',
                timesteps_this_iter / (time.time() - tstart_iteration))
            tlogger.record_tabular('TimestepsComputed', state.num_frames)
            tlogger.record_tabular('TimestepsSoFar', state.timesteps_so_far)
            tlogger.record_tabular('TimeElapsedThisIter',
                                   time_elapsed_this_iter)
            tlogger.record_tabular('TimeElapsedThisIterTotal',
                                   time.time() - tstart_iteration)
            tlogger.record_tabular('TimeElapsed', state.time_elapsed)
            tlogger.record_tabular('TimeElapsedTotal',
                                   time.time() - all_tstart)

            tlogger.dump_tabular()
            tlogger.info('Current elite: {}'.format(state.elite.seeds))
            fps = state.timesteps_so_far / (time.time() - tstart)
            tlogger.info(
                'Timesteps Per Second: {:.0f}. Elapsed: {:.2f}h ETA {:.2f}h'.
                format(fps, (time.time() - all_tstart) / 3600,
                       (config['timesteps'] - state.timesteps_so_far) / fps /
                       3600))

            if state.adaptive_tslimit:
                if np.mean(
                    [a.training_steps >= state.tslimit
                     for a in results]) > state.incr_tslimit_threshold:
                    state.tslimit = min(
                        state.tslimit * state.tslimit_incr_ratio,
                        state.tslimit_max)
                    tlogger.info('Increased threshold to {}'.format(
                        state.tslimit))

            os.makedirs(log_dir, exist_ok=True)
            save_file = os.path.join(log_dir, 'snapshot.pkl')
            with open(save_file, 'wb+') as file:
                pickle.dump(state, file)
            #copyfile(save_file, os.path.join(log_dir, 'snapshot_gen{:04d}.pkl'.format(state.it)))
            tlogger.info("Saved iteration {} to {}".format(
                state.it, save_file))

            if state.timesteps_so_far >= config['timesteps']:
                tlogger.info('Training terminated after {} timesteps'.format(
                    state.timesteps_so_far))
                break
            results.clear()

            if config['selection_threshold'] > 0:
                tlogger.info("Caching parents")
                new_parents = []
                if state.elite in state.population[:config[
                        'selection_threshold']]:
                    new_parents.extend([
                        (worker.model.compute_weights_from_seeds(
                            noise, o.seeds, cache=cached_parents), o.seeds) for
                        o in state.population[:config['selection_threshold']]
                    ])
                else:
                    new_parents.append(
                        (worker.model.compute_weights_from_seeds(
                            noise, state.elite.seeds,
                            cache=cached_parents), state.elite.seeds))
                    new_parents.extend([
                        (worker.model.compute_weights_from_seeds(
                            noise, o.seeds, cache=cached_parents), o.seeds)
                        for o in
                        state.population[:config['selection_threshold'] - 1]
                    ])

                cached_parents.clear()
                cached_parents.extend(new_parents)
                tlogger.info("Done caching parents")

    return float(state.curr_solution_test), float(state.curr_solution_val)
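
For orientation, a sketch of a config dict covering the keys this main(config, out_dir) actually reads; every value is an illustrative assumption, not a setting taken from the original experiments.

config = {
    'game': 'frostbite',            # passed to gym_tensorflow.make (assumed game)
    'model': 'LargeModel',          # looked up in neuroevolution.models
    'mutation_power': 0.002,        # consumed by make_schedule via TrainingState
    'episode_cutoff_mode': 5000,    # or 'adaptive:...' / 'env_default'
    'timesteps': int(1e9),
    'population_size': 1000,
    'selection_threshold': 20,
    'validation_threshold': 10,
    'num_validation_episodes': 30,
    'num_test_episodes': 200,
    # 'load_population': '/path/to/other/run',  # optional, see state.copy_population
}
# test_score, val_score = main(config, out_dir='/tmp/ga_run')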
Example 22
def main(**exp):
    log_dir = tlogger.log_dir()

    tlogger.info(json.dumps(exp, indent=4, sort_keys=True))
    tlogger.info('Logging to: {}'.format(log_dir))
    Model = neuroevolution.models.__dict__[exp['model']]
    all_tstart = time.time()

    def make_env(b):
        return gym_tensorflow.make(game=exp["game"], batch_size=b)

    worker = ConcurrentWorkers(make_env, Model, batch_size=64)
    with WorkerSession(worker) as sess:
        noise = SharedNoiseTable()
        rs = np.random.RandomState()
        tlogger.info('Start timing')
        tstart = time.time()

        try:
            load_file = os.path.join(log_dir, 'snapshot.pkl')
            with open(load_file, 'rb+') as file:
                state = pickle.load(file)
            tlogger.info("Loaded iteration {} from {}".format(
                state.it, load_file))
        except FileNotFoundError:
            tlogger.info('Failed to load snapshot')
            state = TrainingState(exp)

            if 'load_from' in exp:
                dirname = os.path.join(os.path.dirname(__file__), '..',
                                       'neuroevolution', 'ga_legacy.py')
                load_from = exp['load_from'].format(**exp)
                os.system('python {} {} seeds.pkl'.format(dirname, load_from))
                with open('seeds.pkl', 'rb+') as file:
                    seeds = pickle.load(file)
                    state.set_theta(
                        worker.model.compute_weights_from_seeds(noise, seeds))
                tlogger.info('Loaded initial theta from {}'.format(load_from))
            else:
                state.initialize(rs, noise, worker.model)

        def make_offspring(state):
            for i in range(exp['population_size'] // 2):
                idx = noise.sample_index(rs, worker.model.num_params)
                mutation_power = state.sample(state.mutation_power)
                pos_theta = worker.model.compute_mutation(
                    noise, state.theta, idx, mutation_power)

                yield (pos_theta, idx)
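                # Mirrored (antithetic) perturbation: the same noise index with
                # the sign of the mutation flipped, so pos_theta and neg_theta
                # average back to state.theta (checked by the assert below).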
                neg_theta = worker.model.compute_mutation(
                    noise, state.theta, idx, -mutation_power)
                diff = (np.max(
                    np.abs((pos_theta + neg_theta) / 2 - state.theta)))
                assert diff < 1e-5, 'Diff too large: {}'.format(diff)

                yield (neg_theta, idx)

        tlogger.info('Start training')
        _, initial_performance, _ = worker.monitor_eval_repeated(
            [(state.theta, 0)],
            max_frames=None,
            num_episodes=exp['num_test_episodes'])[0]
        while True:
            tstart_iteration = time.time()
            if state.timesteps_so_far >= exp['timesteps']:
                tlogger.info('Training terminated after {} timesteps'.format(
                    state.timesteps_so_far))
                break
            frames_computed_so_far = sess.run(worker.steps_counter)

            tlogger.info('Evaluating perturbations')
            iterator = iter(
                worker.monitor_eval(make_offspring(state),
                                    max_frames=state.tslimit * 4))
            results = []
            for pos_seeds, pos_reward, pos_length in iterator:
                neg_seeds, neg_reward, neg_length = next(iterator)
                assert pos_seeds == neg_seeds
                results.append(
                    Offspring(pos_seeds, [pos_reward, neg_reward],
                              [pos_length, neg_length]))
            state.num_frames += sess.run(
                worker.steps_counter) - frames_computed_so_far

            state.it += 1
            tlogger.record_tabular('Iteration', state.it)
            tlogger.record_tabular('MutationPower',
                                   state.sample(state.mutation_power))
            tlogger.record_tabular('TimestepLimitPerEpisode', state.tslimit)

            # Trim unwanted results
            results = results[:exp['population_size'] // 2]
            assert len(results) == exp['population_size'] // 2
            rewards = np.array([b for a in results for b in a.rewards])

            results_timesteps = np.array([a.training_steps for a in results])
            timesteps_this_iter = sum([a.training_steps for a in results])
            state.timesteps_so_far += timesteps_this_iter

            tlogger.record_tabular('PopulationEpRewMax', np.max(rewards))
            tlogger.record_tabular('PopulationEpRewMean', np.mean(rewards))
            tlogger.record_tabular('PopulationEpRewMedian', np.median(rewards))
            tlogger.record_tabular('PopulationEpCount', len(rewards))
            tlogger.record_tabular('PopulationTimesteps', timesteps_this_iter)

            # Update Theta
            returns_n2 = np.array([a.rewards for a in results])
            noise_inds_n = [a.seeds for a in results]

            if exp['return_proc_mode'] == 'centered_rank':
                proc_returns_n2 = compute_centered_ranks(returns_n2)
            else:
                raise NotImplementedError(exp['return_proc_mode'])
            # Compute and take step
            g, count = batched_weighted_sum(
                proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
                (noise.get(idx, worker.model.num_params)
                 for idx in noise_inds_n),
                batch_size=500)
            # NOTE: gradients are scaled by \theta
            g /= returns_n2.size

            assert g.shape == (
                worker.model.num_params,
            ) and g.dtype == np.float32 and count == len(noise_inds_n)
            update_ratio, state.theta = state.optimizer.update(-g +
                                                               exp['l2coeff'] *
                                                               state.theta)

            time_elapsed_this_iter = time.time() - tstart_iteration
            state.time_elapsed += time_elapsed_this_iter
            tlogger.info('Evaluate elite')
            _, test_evals, test_timesteps = worker.monitor_eval_repeated(
                [(state.theta, 0)],
                max_frames=None,
                num_episodes=exp['num_test_episodes'])[0]
            test_timesteps = sum(test_timesteps)
            # Log Results
            tlogger.record_tabular('TestRewMean', np.mean(test_evals))
            tlogger.record_tabular('TestRewMedian', np.median(test_evals))
            tlogger.record_tabular('TestEpCount', len(test_evals))
            tlogger.record_tabular('TestEpLenSum', test_timesteps)
            tlogger.record_tabular('InitialRewMax',
                                   np.max(initial_performance))
            tlogger.record_tabular('InitialRewMean',
                                   np.mean(initial_performance))
            tlogger.record_tabular('InitialRewMedian',
                                   np.median(initial_performance))

            tlogger.record_tabular('TimestepsThisIter', timesteps_this_iter)
            tlogger.record_tabular(
                'TimestepsPerSecondThisIter',
                timesteps_this_iter / (time.time() - tstart_iteration))
            tlogger.record_tabular('TimestepsComputed', state.num_frames)
            tlogger.record_tabular('TimestepsSoFar', state.timesteps_so_far)
            tlogger.record_tabular('TimeElapsedThisIter',
                                   time_elapsed_this_iter)
            tlogger.record_tabular('TimeElapsedThisIterTotal',
                                   time.time() - tstart_iteration)
            tlogger.record_tabular('TimeElapsed', state.time_elapsed)
            tlogger.record_tabular('TimeElapsedTotal',
                                   time.time() - all_tstart)

            tlogger.dump_tabular()
            fps = state.timesteps_so_far / (time.time() - tstart)
            tlogger.info(
                'Timesteps Per Second: {:.0f}. Elapsed: {:.2f}h ETA {:.2f}h'.
                format(fps, (time.time() - all_tstart) / 3600,
                       (exp['timesteps'] - state.timesteps_so_far) / fps /
                       3600))

            if state.adaptive_tslimit:
                if np.mean(
                    [a.training_steps >= state.tslimit
                     for a in results]) > state.incr_tslimit_threshold:
                    state.tslimit = min(
                        state.tslimit * state.tslimit_incr_ratio,
                        state.tslimit_max)
                    tlogger.info('Increased threshold to {}'.format(
                        state.tslimit))

            os.makedirs(log_dir, exist_ok=True)
            save_file = os.path.join(log_dir, 'snapshot.pkl')
            with open(save_file, 'wb+') as file:
                pickle.dump(state, file)
            #copyfile(save_file, os.path.join(log_dir, 'snapshot_gen{:04d}.pkl'.format(state.it)))
            tlogger.info("Saved iteration {} to {}".format(
                state.it, save_file))

            if state.timesteps_so_far >= exp['timesteps']:
                tlogger.info('Training terminated after {} timesteps'.format(
                    state.timesteps_so_far))
                break
            results.clear()
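
compute_centered_ranks itself is not shown in these examples; below is a sketch of the usual centered-rank transform from the OpenAI ES starter code, which is assumed to match the function used above, mapping returns to ranks rescaled into [-0.5, 0.5].

import numpy as np

def compute_ranks(x):
    # Rank each entry of a 1-D array; the smallest value gets rank 0.
    ranks = np.empty(len(x), dtype=int)
    ranks[x.argsort()] = np.arange(len(x))
    return ranks

def compute_centered_ranks(x):
    # Rank-transform x elementwise, then rescale the ranks into [-0.5, 0.5].
    y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= 0.5
    return y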
Example 23
def main(**exp):
    log_dir = tlogger.log_dir()

    tlogger.info(json.dumps(exp, indent=4, sort_keys=True))
    tlogger.info('Logging to: {}'.format(log_dir))
    Model = neuroevolution.models.__dict__[exp['model']]
    all_tstart = time.time()

    noise = SharedNoiseTable()
    rs = np.random.RandomState()

    def make_env0(b):
        return gym_tensorflow.make(game=exp["games"][0], batch_size=b)

    def make_env1(b):
        return gym_tensorflow.make(game=exp["games"][1], batch_size=b)

    workers = [
        ConcurrentWorkers(make_env0, Model, batch_size=64),
        ConcurrentWorkers(make_env1, Model, batch_size=64)
    ]

    saver = tf.train.Saver()

    tlogger.info('Start timing')
    tstart = time.time()
    tf_sess = tf.Session()
    tf_sess.run(tf.global_variables_initializer())
    state = TrainingState(exp)
    state.initialize(rs, noise, workers[0].model)

    workers[0].initialize(tf_sess)
    workers[1].initialize(tf_sess)

    for iteration in range(exp['iterations']):
        tlogger.info("BEGINNING ITERATION: {}".format(iteration))

        ##############
        ### GAME 0 ###
        ##############
        worker = workers[0]
        frames_computed_so_far = tf_sess.run(worker.steps_counter)
        game0_results = []
        game0_rewards = []
        game0_episode_lengths = []

        iterator = iter(
            worker.monitor_eval(make_offspring(exp, noise, rs, worker, state),
                                max_frames=state.tslimit * 4))

        for pos_seeds, pos_reward, pos_length in iterator:
            neg_seeds, neg_reward, neg_length = next(iterator)
            assert pos_seeds == neg_seeds
            result = Offspring(pos_seeds, [pos_reward, neg_reward],
                               [pos_length, neg_length])
            rewards = result.rewards
            game0_results.append(result)
            game0_rewards.append(rewards)
            game0_episode_lengths.append(result.ep_len)
        state.num_frames += tf_sess.run(
            worker.steps_counter) - frames_computed_so_far
        game0_returns_n2 = np.array([a.rewards for a in game0_results])
        game0_noise_inds_n = [a.seeds for a in game0_results]
        # tlogger.info("game0 rewards: {}".format(np.mean(game0_rewards)))
        # tlogger.info("game0 eplens: {}".format(game0_episode_lengths))
        save_pickle(iteration, log_dir, "game0_rewards", game0_rewards)
        save_pickle(iteration, log_dir, "game0_episode_lengths",
                    game0_episode_lengths)

        ##############
        ### GAME 1 ###
        ##############
        worker = workers[1]
        frames_computed_so_far = tf_sess.run(worker.steps_counter)
        game1_results = []
        game1_rewards = []
        game1_episode_lengths = []
        seeds_vector = np.array(game0_noise_inds_n)
        iterator = iter(
            worker.monitor_eval(make_offspring(exp, noise, rs, worker, state,
                                               seeds_vector),
                                max_frames=state.tslimit * 4))

        for pos_seeds, pos_reward, pos_length in iterator:
            neg_seeds, neg_reward, neg_length = next(iterator)
            assert pos_seeds == neg_seeds
            result = Offspring(pos_seeds, [pos_reward, neg_reward],
                               [pos_length, neg_length])
            rewards = result.rewards
            game1_results.append(result)
            game1_rewards.append(rewards)
            game1_episode_lengths.append(result.ep_len)
        state.num_frames += tf_sess.run(
            worker.steps_counter) - frames_computed_so_far
        game1_returns_n2 = np.array([a.rewards for a in game1_results])
        game1_noise_inds_n = [a.seeds for a in game1_results]
        # tlogger.info("game1 rewards: {}".format(np.mean(game1_rewards)))
        # tlogger.info("game1 eplens: {}".format(game0_episode_lengths))
        save_pickle(iteration, log_dir, "game1_rewards", game1_rewards)
        save_pickle(iteration, log_dir, "game1_episode_lengths",
                    game1_episode_lengths)

        tlogger.info("Saving offsprings seeds")
        save_pickle(iteration, log_dir, "offsprings_seeds", game1_noise_inds_n)

        ####################
        ### UPDATE THETA ###
        ####################
        game_returns = [game0_returns_n2, game1_returns_n2]
        proc_returns = obtain_proc_returns(exp['learn_option'], game_returns)

        assert game0_noise_inds_n == game1_noise_inds_n
        noise_inds_n = game0_noise_inds_n + game1_noise_inds_n  # concatenate the two lists

        # TOP 100 offspring
        #        dx = proc_returns[:, 0]
        #        dy = proc_returns[:, 1]
        #        dist_squared = (np.ones(dx.shape) - np.abs(dx))**2 + (np.ones(dy.shape) - np.abs(dy))**2
        #        top_n_rewards = dist_squared.argsort()[-100:][::-1]
        #        batched_weighted_indices = (noise.get(idx, worker.model.num_params) for idx in noise_inds_n)
        #        proc_returns = proc_returns[top_n_rewards, :]
        #        batched_weighted_args = {
        #            'deltas': proc_returns[:, 0] - proc_returns[:, 1],
        #            'indices': [myval for myidx, myval in enumerate(batched_weighted_indices) if myidx in top_n_rewards]
        #        }
        #        noise_inds_n = batched_weighted_args['indices']
        #        g, count = batched_weighted_sum(batched_weighted_args['deltas'], batched_weighted_args['indices'], batch_size=len(batched_weighted_args['deltas']))

        # ALL offspring
        g, count = batched_weighted_sum(
            proc_returns[:, 0] - proc_returns[:, 1],
            (noise.get(idx, worker.model.num_params) for idx in noise_inds_n),
            batch_size=500)

        # NOTE: gradients are scaled by \theta
        returns_n2 = np.array([a.rewards for a in game0_results] +
                              [a.rewards for a in game1_results])

        # Only if using top 100
        #        returns_n2 = returns_n2[top_n_rewards]

        g /= returns_n2.size

        assert g.shape == (
            worker.model.num_params,
        ) and g.dtype == np.float32 and count == len(noise_inds_n)
        update_ratio, state.theta = state.optimizer.update(-g +
                                                           exp['l2coeff'] *
                                                           state.theta)

        save_pickle(iteration, log_dir, "state", state)

        ######################
        ### EVALUATE ELITE ###
        ######################
        _, test_evals, test_timesteps = workers[0].monitor_eval_repeated(
            [(state.theta, 0)],
            max_frames=None,
            num_episodes=exp['num_test_episodes'] // 2)[0]
        tlogger.info("game0 elite: {}".format(np.mean(test_evals)))
        save_pickle(iteration, log_dir, 'game0_elite', test_evals)
        save_pickle(iteration, log_dir, 'game0_elite_timestemps',
                    test_timesteps)

        _, test_evals, test_timesteps = workers[1].monitor_eval_repeated(
            [(state.theta, 0)],
            max_frames=None,
            num_episodes=exp['num_test_episodes'] // 2)[0]
        tlogger.info("game1 elite: {}".format(np.mean(test_evals)))
        save_pickle(iteration, log_dir, "game1_elite", test_evals)
        save_pickle(iteration, log_dir, 'game1_elite_timestemps',
                    test_timesteps)

        state.num_frames += tf_sess.run(
            worker.steps_counter) - frames_computed_so_far

        saver.save(tf_sess, "{}/model-{}".format(log_dir, state.it))

        state.it += 1

    os.kill(os.getpid(), signal.SIGTERM)
Example 24
def load_pickle(iteration, log_dir, pickle_filename):
    save_file = os.path.join(
        log_dir, "{:04d}-{}.pkl".format(iteration, pickle_filename))
    tlogger.info("Loading {}".format(save_file))
    with open(save_file, 'rb') as file:
        return pickle.load(file)
Example 25
def save_pickle(iteration, log_dir, pickle_filename, dat):
    save_file = os.path.join(
        log_dir, "{:04d}-{}.pkl".format(iteration, pickle_filename))
    with open(save_file, 'wb+') as file:
        pickle.dump(dat, file, pickle.HIGHEST_PROTOCOL)
    tlogger.info("Saved {}".format(save_file))
Example 26
def main(**exp):

    log_dir = tlogger.log_dir()

    tlogger.info(json.dumps(exp, indent=4, sort_keys=True))
    tlogger.info('Logging to: {}'.format(log_dir))
    Model = neuroevolution.models.__dict__[exp['model']]
    all_tstart = time.time()

    noise = SharedNoiseTable()
    rs = np.random.RandomState()

    def make_env0(b):
        return gym_tensorflow.make(game=exp["games"][0], batch_size=b)

    def make_env1(b):
        return gym_tensorflow.make(game=exp["games"][1], batch_size=b)

    workers = [
        ConcurrentWorkers(make_env0, Model, batch_size=64),
        ConcurrentWorkers(make_env1, Model, batch_size=64)
    ]

    tlogger.info('Start timing')
    tstart = time.time()
    tf_sess = tf.Session()
    tf_sess.run(tf.global_variables_initializer())
    state = TrainingState(exp)
    state.initialize(rs, noise, workers[0].model)

    workers[0].initialize(tf_sess)
    workers[1].initialize(tf_sess)

    for iteration in range(exp['iterations']):
        tlogger.info("BEGINNING ITERATION: {}".format(iteration))

        ##############
        ### GAME 0 ###
        ##############
        worker = workers[0]
        frames_computed_so_far = tf_sess.run(worker.steps_counter)
        game0_results = []
        game0_rewards = []
        game0_episode_lengths = []

        iterator = iter(
            worker.monitor_eval(make_offspring(exp, noise, rs, worker, state),
                                max_frames=state.tslimit * 4))

        for pos_seeds, pos_reward, pos_length in iterator:
            neg_seeds, neg_reward, neg_length = next(iterator)
            assert pos_seeds == neg_seeds
            result = Offspring(pos_seeds, [pos_reward, neg_reward],
                               [pos_length, neg_length])
            rewards = result.rewards
            game0_results.append(result)
            game0_rewards.append(rewards)
            game0_episode_lengths.append(result.ep_len)
        state.num_frames += tf_sess.run(
            worker.steps_counter) - frames_computed_so_far
        game0_returns_n2 = np.array([a.rewards for a in game0_results])
        game0_noise_inds_n = [a.seeds for a in game0_results]
        save_pickle(iteration, log_dir, "game0_rewards", game0_rewards)
        save_pickle(iteration, log_dir, "game0_episode_lengths",
                    game0_episode_lengths)

        ##############
        ### GAME 1 ###
        ##############
        if f_isSingleTask(exp):
            game1_results = []
            game1_rewards = []
            game1_episode_lengths = []
            game1_returns_n2 = game0_returns_n2
            game1_noise_inds_n = game0_noise_inds_n
        else:
            worker = workers[1]
            frames_computed_so_far = tf_sess.run(worker.steps_counter)
            game1_results = []
            game1_rewards = []
            game1_episode_lengths = []
            seeds_vector = np.array(game0_noise_inds_n)
            iterator = iter(
                worker.monitor_eval(make_offspring(exp, noise, rs, worker,
                                                   state, seeds_vector),
                                    max_frames=state.tslimit * 4))

            for pos_seeds, pos_reward, pos_length in iterator:
                neg_seeds, neg_reward, neg_length = next(iterator)
                assert pos_seeds == neg_seeds
                result = Offspring(pos_seeds, [pos_reward, neg_reward],
                                   [pos_length, neg_length])
                rewards = result.rewards
                game1_results.append(result)
                game1_rewards.append(rewards)
                game1_episode_lengths.append(result.ep_len)
            state.num_frames += tf_sess.run(
                worker.steps_counter) - frames_computed_so_far
            game1_returns_n2 = np.array([a.rewards for a in game1_results])
            game1_noise_inds_n = [a.seeds for a in game1_results]
        save_pickle(iteration, log_dir, "game1_rewards", game1_rewards)
        save_pickle(iteration, log_dir, "game1_episode_lengths",
                    game1_episode_lengths)

        tlogger.info("Saving offsprings seeds")
        save_pickle(iteration, log_dir, "offsprings_seeds", game1_noise_inds_n)

        ####################
        ### UPDATE THETA ###
        ####################

        if f_isSingleTask(exp):
            proc_frames = compute_centered_ranks(
                np.asarray(game0_episode_lengths))
            proc_returns = compute_centered_ranks(game0_returns_n2)
            noise_inds_n = game0_noise_inds_n
        else:
            game_returns = [game0_returns_n2, game1_returns_n2]
            proc_returns = obtain_proc_returns(exp['learn_option'],
                                               game_returns)
            # proc_frames mirrors the single-task branch, computed over both
            # games' episode lengths so it lines up with the concatenated
            # noise_inds_n below.
            proc_frames = compute_centered_ranks(
                np.asarray(game0_episode_lengths + game1_episode_lengths))

            assert game0_noise_inds_n == game1_noise_inds_n
            noise_inds_n = game0_noise_inds_n + game1_noise_inds_n  # concatenate the two lists

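        # ES gradient estimate: each noise vector is weighted by the difference
        # of centered ranks between its +sigma and -sigma evaluations, summed
        # in batches of 500 perturbations. g_frames applies the same estimator
        # to episode lengths, so the weight w below can trade reward against
        # frame count.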
        g_returns, count_returns = batched_weighted_sum(
            proc_returns[:, 0] - proc_returns[:, 1],
            (noise.get(idx, worker.model.num_params) for idx in noise_inds_n),
            batch_size=500)

        g_frames, count_frames = batched_weighted_sum(
            proc_frames[:, 0] - proc_frames[:, 1],
            (noise.get(idx, worker.model.num_params) for idx in noise_inds_n),
            batch_size=500)

        assert count_frames == count_returns
        count = count_returns

        w = exp['w']
        g = w * g_returns + (1 - w) * g_frames

        returns_n2 = np.array([a.rewards for a in game0_results] +
                              [a.rewards for a in game1_results])
        g /= returns_n2.size

        assert g.shape == (
            worker.model.num_params,
        ) and g.dtype == np.float32 and count == len(noise_inds_n)
        update_ratio, state.theta = state.optimizer.update(-g +
                                                           exp['l2coeff'] *
                                                           state.theta)

        save_pickle(iteration, log_dir, "state", state)

        ######################
        ### EVALUATE ELITE ###
        ######################
        _, test_evals, test_timesteps = workers[0].monitor_eval_repeated(
            [(state.theta, 0)],
            max_frames=None,
            num_episodes=exp['num_test_episodes'] //
            (2**(1 - f_isSingleTask(exp))))[0]
        tlogger.info("game0 elite: {}".format(np.mean(test_evals)))
        tlogger.info("game0 elite frames max: {}".format(
            np.max(test_timesteps)))
        tlogger.info("game0 elite frames mean: {}".format(
            np.mean(test_timesteps)))
        tlogger.info("game0 elite frames min: {}".format(
            np.min(test_timesteps)))
        tlogger.info("game0 offspring frames max: {}".format(
            np.max(game0_episode_lengths)))
        tlogger.info("game0 offspring frames mean: {}".format(
            np.mean(game0_episode_lengths)))
        tlogger.info("game0 offspring frames min: {}".format(
            np.min(game0_episode_lengths)))
        save_pickle(iteration, log_dir, 'game0_elite', test_evals)
        save_pickle(iteration, log_dir, 'game0_elite_timesteps',
                    test_timesteps)

        if not (f_isSingleTask(exp)):
            _, test_evals, test_timesteps = workers[1].monitor_eval_repeated(
                [(state.theta, 0)],
                max_frames=None,
                num_episodes=exp['num_test_episodes'] // 2)[0]

        tlogger.info("game1 elite: {}".format(np.mean(test_evals)))
        save_pickle(iteration, log_dir, "game1_elite", test_evals)
        save_pickle(iteration, log_dir, 'game1_elite_timesteps',
                    test_timesteps)

        state.num_frames += tf_sess.run(
            worker.steps_counter) - frames_computed_so_far
        state.it += 1

    os.kill(os.getpid(), signal.SIGTERM)
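# For reference, a minimal sketch of the centered-rank transform used above
# (compute_centered_ranks). The actual helper is defined elsewhere in the
# repository; this version follows the common OpenAI-ES convention of mapping
# fitness values to ranks in [-0.5, 0.5], which makes the update invariant to
# the scale of rewards. Assumes numpy is imported as np, as in the snippets
# above.
def _compute_ranks_sketch(x):
    # Rank each entry of a 1-D array from 0 (lowest) to len(x) - 1 (highest).
    ranks = np.empty(len(x), dtype=int)
    ranks[x.argsort()] = np.arange(len(x))
    return ranks


def _compute_centered_ranks_sketch(x):
    # Rank all entries jointly, then rescale to the symmetric range [-0.5, 0.5].
    y = _compute_ranks_sketch(x.ravel()).reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= 0.5
    return y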
Example No. 27
def main(game,
         filename=None,
         outfile=None,
         model_name="LargeModel",
         no_video=False,
         add_text=False,
         num_runs=RUNS,
         graph=None):

    seeds = default_seeds
    outvid = None
    viewer = None
    iteration = None
    state = None

    if filename:
        with open(filename, 'rb+') as file:
            state = pickle.load(file)
            #if hasattr(state, 'best_score'):
            #    seeds = state.best_score.seeds
            #    iteration = len(seeds)
            #    print("Loading GA snapshot from best_score, iteration: ", len(seeds))
            if hasattr(state, 'elite'):
                seeds = state.elite.seeds
                iteration = state.it
                print("Loading GA snapshot from elite, iteration: {} / {}",
                      len(seeds), iteration)
            else:
                seeds = None
                iteration = state.it
                print("Loading ES snapshot, iteration: {}".format(state.it))

    if outfile:
        fourcc = cv.VideoWriter_fourcc(*'MJPG')
        outvid = cv.VideoWriter(outfile, fourcc, 16, (VIDEO_SIZE, VIDEO_SIZE))

    env = gym_tensorflow.make(game, 1)

    model = get_model(model_name)
    obs_op = env.observation()
    reset_op = env.reset()

    if model.requires_ref_batch:

        def make_env(b):
            return gym_tensorflow.make(game=game, batch_size=b)

        with tf.Session() as sess:
            ref_batch = gym_tensorflow.get_ref_batch(make_env, sess, 128)
            ref_batch = ref_batch[:, ...]
    else:
        ref_batch = None

    action_op = model.make_net(tf.expand_dims(obs_op, axis=1),
                               env.action_space,
                               batch_size=1,
                               ref_batch=ref_batch)
    if env.discrete_action:
        action_op = tf.argmax(action_op, axis=-1, output_type=tf.int32)
    rew_op, done_op = env.step(action_op)

    if not no_video:
        from gym.envs.classic_control import rendering
        viewer = rendering.SimpleImageViewer()

    if hasattr(env.unwrapped, 'render'):
        obs_op = env.unwrapped.render()

        def display_obs(im):
            # pdb.set_trace()
            if im.shape[1] > 1:
                im = np.bitwise_or(im[0, 0, ...], im[0, 1, ...])
            else:
                im = im[0, 0, ...]
            handle_frame(im, outvid, viewer, game, iteration, add_text)
    else:

        def display_obs(im):
            # pdb.set_trace()
            im = im[0, :, :, -1]
            im = np.stack([im] * 3, axis=-1)
            im = (im * 255).astype(np.uint8)
            handle_frame(im, outvid, viewer, game, iteration, add_text)

    rewards = []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model.initialize()
        tlogger.info(model.description)

        # import pdb; pdb.set_trace()
        if seeds:
            noise = SharedNoiseTable()
            weights = model.compute_weights_from_seeds(noise, seeds)
            model.load(sess, 0, weights, seeds)
        else:
            weights = state.theta
            model.load(sess, 0, weights, (weights, 0))

        if graph:
            # No Saver is constructed elsewhere in this snippet, so create one
            # here before exporting the graph.
            saver = tf.train.Saver()
            saver.save(sess, graph)

        for i in range(num_runs):
            sess.run(reset_op)
            sess.run(obs_op)
            #recorder.capture_frame()
            display_obs(sess.run(obs_op))

            total_rew = 0
            num_frames = 0
            while True:
                rew, done = sess.run([rew_op, done_op])
                num_frames += 1
                total_rew += rew[0]
                display_obs(sess.run(obs_op))
                time.sleep(4 / 60)
                if done[0]:
                    break

            rewards += [total_rew]
            print('Final reward: ', total_rew, 'after', num_frames, 'steps')

    print(rewards)
    print("Mean: ", np.mean(rewards))
    print("Std: ", np.std(rewards))

    if outvid:
        outvid.release()
Example No. 28
def make_env(b):
    tlogger.info('GA: Creating environment for game: %s' % config["game"])
    return gym_tensorflow.make(game=config["game"], batch_size=b)
Example No. 29
def main(exp, log_dir):
    log_dir = tlogger.log_dir(log_dir)

    snap_idx = 0
    snapshots = []

    tlogger.info(json.dumps(exp, indent=4, sort_keys=True))
    tlogger.info('Logging to: {}'.format(log_dir))
    Model = neuroevolution.models.__dict__[exp['model']]
    all_tstart = time.time()

    def make_env(b):
        return gym_tensorflow.make(game=exp["game"], batch_size=b)

    worker = ConcurrentWorkers(make_env, Model, batch_size=64)
    with WorkerSession(worker) as sess:
        rs = np.random.RandomState()
        noise = None
        state = None
        cached_parents = []
        results = []

        def make_offspring():
            if len(cached_parents) == 0:
                return worker.model.randomize(rs, noise)
            else:
                assert len(cached_parents) == exp['selection_threshold']
                parent = cached_parents[rs.randint(len(cached_parents))]
                return worker.model.mutate(parent,
                                           rs,
                                           noise,
                                           mutation_power=state.sample(
                                               state.mutation_power))

        tlogger.info('Start timing')
        tstart = time.time()

        load_file = os.path.join(log_dir, 'snapshot.pkl')

        if 'load_from' in exp:
            filename = os.path.join(log_dir, exp['load_from'])
            with open(filename, 'rb+') as file:
                state = pickle.load(file)
                state.timesteps_so_far = 0  # Reset timesteps to 0
                state.it = 0
                state.max_reward = 0
                state.max_avg = 0
                state.max_sd = 0
            tlogger.info('Loaded initial policy from {}'.format(filename))
        elif os.path.exists(load_file):
            try:
                with open(load_file, 'rb+') as file:
                    state = pickle.load(file)
                tlogger.info("Loaded iteration {} from {}".format(
                    state.it, load_file))
            except FileNotFoundError:
                tlogger.info('Failed to load snapshot')

        if not noise:
            tlogger.info("Generating new noise table")
            noise = SharedNoiseTable()
        else:
            tlogger.info("Using noise table from snapshot")

        if not state:
            tlogger.info("Generation new TrainingState")
            state = TrainingState(exp)

        if 'load_population' in exp:
            state.copy_population(exp['load_population'])

        # Cache first population if needed (on restart)
        if state.population and exp['selection_threshold'] > 0:
            tlogger.info("Caching parents")
            cached_parents.clear()
            if state.elite in state.population[:exp['selection_threshold']]:
                cached_parents.extend([
                    (worker.model.compute_weights_from_seeds(noise,
                                                             o.seeds), o.seeds)
                    for o in state.population[:exp['selection_threshold']]
                ])
            else:
                cached_parents.append((worker.model.compute_weights_from_seeds(
                    noise, state.elite.seeds), state.elite.seeds))
                cached_parents.extend([
                    (worker.model.compute_weights_from_seeds(noise,
                                                             o.seeds), o.seeds)
                    for o in state.population[:exp['selection_threshold'] - 1]
                ])
            tlogger.info("Done caching parents")

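        # Main GA loop: each generation evaluates a population of mutated
        # offspring, sorts them by single-episode fitness, re-evaluates the top
        # validation_threshold individuals, and selects the next elite from
        # those validation scores.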
        while True:
            tstart_iteration = time.time()
            if state.timesteps_so_far >= exp['timesteps']:
                tlogger.info('Training terminated after {} timesteps'.format(
                    state.timesteps_so_far))
                break
            frames_computed_so_far = sess.run(worker.steps_counter)
            assert (len(cached_parents) == 0 and state.it
                    == 0) or len(cached_parents) == exp['selection_threshold']

            tasks = [make_offspring() for _ in range(exp['population_size'])]
            for seeds, episode_reward, episode_length in worker.monitor_eval(
                    tasks, max_frames=state.tslimit * 4):
                results.append(
                    Offspring(seeds, [episode_reward], [episode_length]))
            state.num_frames += sess.run(
                worker.steps_counter) - frames_computed_so_far

            state.it += 1
            tlogger.record_tabular('Iteration', state.it)
            tlogger.record_tabular('MutationPower',
                                   state.sample(state.mutation_power))

            # Trim unwanted results
            results = results[:exp['population_size']]
            assert len(results) == exp['population_size']
            rewards = np.array([a.fitness for a in results])
            population_timesteps = sum([a.training_steps for a in results])
            state.population = sorted(results,
                                      key=lambda x: x.fitness,
                                      reverse=True)
            state.max_reward = save_best_pop_member(state.max_reward,
                                                    np.max(rewards), state,
                                                    state.population[0])
            tlogger.record_tabular('PopulationEpRewMax', np.max(rewards))
            tlogger.record_tabular('PopulationEpRewMean', np.mean(rewards))
            tlogger.record_tabular('PopulationEpCount', len(rewards))
            tlogger.record_tabular('PopulationTimesteps', population_timesteps)
            tlogger.record_tabular('NumSelectedIndividuals',
                                   exp['selection_threshold'])

            tlogger.info('Evaluate population')
            validation_population = state.population[:exp[
                'validation_threshold']]
            if state.elite is not None:
                validation_population = [state.elite
                                         ] + validation_population[:-1]

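            # Re-evaluate the truncated population (with the current elite
            # injected) for num_validation_episodes each; the per-individual
            # mean/std of these runs are normalized by the running max_avg and
            # max_sd and scored with fitness() below.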
            validation_tasks = [(worker.model.compute_weights_from_seeds(
                noise, validation_population[x].seeds,
                cache=cached_parents), validation_population[x].seeds)
                                for x in range(exp['validation_threshold'])]
            _, population_validation, population_validation_len = zip(
                *worker.monitor_eval_repeated(
                    validation_tasks,
                    max_frames=state.tslimit * 4,
                    num_episodes=exp['num_validation_episodes']))

            it_max_avg = np.max([np.mean(x) for x in population_validation])
            it_max_sd = np.max([np.std(x) for x in population_validation])

            state.max_avg = np.max([state.max_avg, it_max_avg])
            state.max_sd = np.max([state.max_sd, it_max_sd])

            tlogger.info("Max Average: {}".format(state.max_avg))
            tlogger.info("Max Std: {}".format(state.max_sd))

            fitness_results = [(np.mean(x), np.std(x))
                               for x in population_validation]
            with open(os.path.join(log_dir, 'fitness.log'), 'a') as f:
                f.write("{},{},{}: {}\n".format(
                    state.it, state.max_avg, state.max_sd, ','.join([
                        "({},{})".format(x[0], x[1]) for x in fitness_results
                    ])))

            population_fitness = [
                fitness(x[0], x[1], state.max_avg, state.max_sd)
                for x in fitness_results
            ]
            tlogger.info("Fitness: {}".format(population_fitness))
            population_validation_len = [
                np.sum(x) for x in population_validation_len
            ]

            time_elapsed_this_iter = time.time() - tstart_iteration
            state.time_elapsed += time_elapsed_this_iter

            population_elite_idx = np.argmin(population_fitness)
            state.elite = validation_population[population_elite_idx]
            elite_theta = worker.model.compute_weights_from_seeds(
                noise, state.elite.seeds, cache=cached_parents)
            _, population_elite_evals, population_elite_evals_timesteps = worker.monitor_eval_repeated(
                [(elite_theta, state.elite.seeds)],
                max_frames=None,
                num_episodes=exp['num_test_episodes'])[0]

            # Log Results
            validation_timesteps = sum(population_validation_len)
            timesteps_this_iter = population_timesteps + validation_timesteps
            state.timesteps_so_far += timesteps_this_iter
            state.validation_timesteps_so_far += validation_timesteps

            # Log
            tlogger.record_tabular(
                'TruncatedPopulationRewMean',
                np.mean([a.fitness for a in validation_population]))
            tlogger.record_tabular('TruncatedPopulationValidationFitMean',
                                   np.mean(population_fitness))
            tlogger.record_tabular('TruncatedPopulationValidationFitMax',
                                   np.max(population_fitness))
            tlogger.record_tabular('TruncatedPopulationValidationFitMin',
                                   np.min(population_fitness))
            tlogger.record_tabular('TruncatedPopulationValidationMaxAvg',
                                   state.max_avg)
            tlogger.record_tabular('TruncatedPopulationValidationMaxStd',
                                   state.max_sd)
            tlogger.record_tabular('TruncatedPopulationEliteValidationFitMin',
                                   np.min(population_fitness))
            tlogger.record_tabular("TruncatedPopulationEliteIndex",
                                   population_elite_idx)
            tlogger.record_tabular('TruncatedPopulationEliteSeeds',
                                   state.elite.seeds)
            tlogger.record_tabular('TruncatedPopulationEliteTestRewMean',
                                   np.mean(population_elite_evals))
            tlogger.record_tabular('TruncatedPopulationEliteTestRewStd',
                                   np.std(population_elite_evals))
            tlogger.record_tabular('TruncatedPopulationEliteTestEpCount',
                                   len(population_elite_evals))
            tlogger.record_tabular('TruncatedPopulationEliteTestEpLenSum',
                                   np.sum(population_elite_evals_timesteps))

            if np.mean(population_validation) > state.curr_solution_val:
                state.curr_solution = state.elite.seeds
                state.curr_solution_val = np.mean(population_validation)
                state.curr_solution_test = np.mean(population_elite_evals)

            tlogger.record_tabular('ValidationTimestepsThisIter',
                                   validation_timesteps)
            tlogger.record_tabular('ValidationTimestepsSoFar',
                                   state.validation_timesteps_so_far)
            tlogger.record_tabular('TimestepsThisIter', timesteps_this_iter)
            tlogger.record_tabular(
                'TimestepsPerSecondThisIter',
                timesteps_this_iter / (time.time() - tstart_iteration))
            tlogger.record_tabular('TimestepsComputed', state.num_frames)
            tlogger.record_tabular('TimestepsSoFar', state.timesteps_so_far)
            tlogger.record_tabular('TimeElapsedThisIter',
                                   time_elapsed_this_iter)
            tlogger.record_tabular('TimeElapsedThisIterTotal',
                                   time.time() - tstart_iteration)
            tlogger.record_tabular('TimeElapsed', state.time_elapsed)
            tlogger.record_tabular('TimeElapsedTotal',
                                   time.time() - all_tstart)

            tlogger.dump_tabular()
            # tlogger.info('Current elite: {}'.format(state.elite.seeds))
            fps = state.timesteps_so_far / (time.time() - tstart)
            tlogger.info(
                'Timesteps Per Second: {:.0f}. Elapsed: {:.2f}h ETA {:.2f}h'.
                format(fps, (time.time() - all_tstart) / 3600,
                       (exp['timesteps'] - state.timesteps_so_far) / fps /
                       3600))

            if state.adaptive_tslimit:
                if np.mean(
                    [a.training_steps >= state.tslimit
                     for a in results]) > state.incr_tslimit_threshold:
                    state.tslimit = min(
                        state.tslimit * state.tslimit_incr_ratio,
                        state.tslimit_max)
                    tlogger.info('Increased threshold to {}'.format(
                        state.tslimit))

            snap_idx, snapshots = save_snapshot(state, log_dir, snap_idx,
                                                snapshots)
            # os.makedirs(log_dir, exist_ok=True)
            # copyfile(save_file, os.path.join(log_dir, 'snapshot_gen{:04d}.pkl'.format(state.it)))
            tlogger.info("Saved iteration {} to {}".format(
                state.it, snapshots[snap_idx - 1]))

            if state.timesteps_so_far >= exp['timesteps']:
                tlogger.info('Training terminated after {} timesteps'.format(
                    state.timesteps_so_far))
                break
            results.clear()

            if exp['selection_threshold'] > 0:
                tlogger.info("Caching parents")
                new_parents = []
                if state.elite in state.population[:
                                                   exp['selection_threshold']]:
                    new_parents.extend([
                        (worker.model.compute_weights_from_seeds(
                            noise, o.seeds, cache=cached_parents), o.seeds)
                        for o in state.population[:exp['selection_threshold']]
                    ])
                else:
                    new_parents.append(
                        (worker.model.compute_weights_from_seeds(
                            noise, state.elite.seeds,
                            cache=cached_parents), state.elite.seeds))
                    new_parents.extend([
                        (worker.model.compute_weights_from_seeds(
                            noise, o.seeds, cache=cached_parents), o.seeds)
                        for o in state.population[:exp['selection_threshold'] -
                                                  1]
                    ])

                cached_parents.clear()
                cached_parents.extend(new_parents)
                tlogger.info("Done caching parents")
    return float(state.curr_solution_test), {
        'val': float(state.curr_solution_val)
    }
Example No. 30
def main(
        num_inner_iterations=64,
        noise_size=128,
        inner_loop_init_lr=0.2,
        inner_loop_init_momentum=0.5,
        training_iterations_schedule=5,
        min_training_iterations=4,
        lr=0.1,
        rms_momentum=0.9,
        final_relative_lr=1e-2,
        generator_batch_size=128,
        meta_batch_size=512,
        adam_epsilon=1e-8,
        adam_beta1=0.0,
        adam_beta2=0.999,
        num_meta_iterations=1000,
        starting_meta_iteration=1,
        max_elapsed_time=None,
        gradient_block_size=16,
        use_intermediate_losses=0,
        intermediate_losses_ratio=1.0,
        data_path='./data',
        meta_optimizer="adam",
        dataset='MNIST',
        logging_period=10,
        generator_type="cgtn",
        learner_type="base",
        validation_learner_type=None,
        warmup_iterations=None,
        warmup_learner="base",
        final_batch_norm_forward=False,
        # The following flag is used for architecture search (it maps iteration to a specific architecture)
        iteration_maps_seed=False,
        use_dataset_augmentation=False,
        training_schedule_backwards=True,
        evenly_distributed_labels=True,
        iterations_depth_schedule=100,
        use_encoder=True,
        decoder_loss_multiplier=1.0,
        load_from=None,
        virtual_batch_size=1,
        deterministic=False,
        seed=1,
        grad_bound=None,
        version=None,  # dummy variable
        enable_checkpointing=True,
        randomize_width=False,
        step_by_step_validation=True,
        semisupervised_classifier_loss=True,
        semisupervised_student_loss=True,
        automl_class=None,
        inner_loop_optimizer="SGD",
        meta_learn_labels=False,
        device='cuda'):
    validation_learner_type = validation_learner_type or learner_type
    hvd.init()
    assert hvd.mpi_threads_supported()
    from mpi4py import MPI
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    lr = lr * virtual_batch_size * hvd.size()
    torch.cuda.set_device(hvd.local_rank())
    # Load dataset
    img_shape, trainset, validationset, (testset_x, testset_y) = get_dataset(
        dataset,
        data_path,
        seed,
        device,
        with_augmentation=use_dataset_augmentation)
    validation_x, validation_y = zip(*validationset)
    validation_x = torch.stack(validation_x).to(device)
    validation_y = torch.as_tensor(validation_y).to(device)
    # Make each worker slightly different
    torch.manual_seed(seed + hvd.rank())
    np.random.seed(seed + hvd.rank())

    if generator_type == "semisupervised":
        unlabelled_trainset, trainset = torch.utils.data.random_split(
            trainset, [49500, 500])

    data_loader = torch.utils.data.DataLoader(trainset,
                                              batch_size=meta_batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              num_workers=1,
                                              pin_memory=True)
    data_loader = EndlessDataLoader(data_loader)

    if generator_type == "cgtn":
        generator = CGTN(
            generator=Generator(noise_size + 10, img_shape),
            num_inner_iterations=num_inner_iterations,
            generator_batch_size=generator_batch_size,
            noise_size=noise_size,
            evenly_distributed_labels=evenly_distributed_labels,
            meta_learn_labels=bool(meta_learn_labels),
        )
    elif generator_type == "cgtn_all_shuffled":
        generator = CGTNAllShuffled(
            generator=Generator(noise_size + 10, img_shape),
            num_inner_iterations=num_inner_iterations,
            generator_batch_size=generator_batch_size,
            noise_size=noise_size,
            evenly_distributed_labels=evenly_distributed_labels,
        )
    elif generator_type == "cgtn_batch_shuffled":
        generator = CGTNBatchShuffled(
            generator=Generator(noise_size + 10, img_shape),
            num_inner_iterations=num_inner_iterations,
            generator_batch_size=generator_batch_size,
            noise_size=noise_size,
            evenly_distributed_labels=evenly_distributed_labels,
        )
    elif generator_type == "gtn":
        generator = GTN(
            generator=Generator(noise_size + 10, img_shape),
            generator_batch_size=generator_batch_size,
            noise_size=noise_size,
        )
    elif generator_type == "gaussian_cgtn":
        generator = GaussianCGTN(
            generator=Generator(noise_size + 10, img_shape),
            num_inner_iterations=num_inner_iterations,
            generator_batch_size=generator_batch_size,
            noise_size=noise_size,
        )
    elif generator_type == "dataset":
        generator = UniformSamplingGenerator(
            torch.utils.data.DataLoader(trainset,
                                        batch_size=generator_batch_size,
                                        shuffle=True,
                                        drop_last=True),
            num_inner_iterations=num_inner_iterations,
            device=device,
        )
    elif generator_type == "distillation":
        generator = DatasetDistillation(
            img_shape=img_shape,
            num_inner_iterations=num_inner_iterations,
            generator_batch_size=generator_batch_size,
        )
    elif generator_type == "semisupervised":
        generator = SemisupervisedGenerator(
            torch.utils.data.DataLoader(unlabelled_trainset,
                                        batch_size=generator_batch_size,
                                        shuffle=True,
                                        drop_last=True),
            num_inner_iterations=num_inner_iterations,
            device=device,
            classifier=models.ClassifierLarger2(img_shape,
                                                batch_norm_momentum=0.9,
                                                randomize_width=False))
    else:
        raise NotImplementedError()

    # Create meta-objective models
    if inner_loop_optimizer == "SGD":
        optimizers = [
            inner_optimizers.SGD(inner_loop_init_lr, inner_loop_init_momentum,
                                 num_inner_iterations)
        ]
    elif inner_loop_optimizer == "RMSProp":
        optimizers = [
            inner_optimizers.RMSProp(inner_loop_init_lr,
                                     inner_loop_init_momentum,
                                     num_inner_iterations)
        ]
    elif inner_loop_optimizer == "Adam":
        optimizers = [
            inner_optimizers.Adam(inner_loop_init_lr, inner_loop_init_momentum,
                                  num_inner_iterations)
        ]
    else:
        raise ValueError(
            f"Inner loop optimizer '{inner_loop_optimizer}' not available")

    automl = (automl_class or AutoML)(
        generator=generator,
        optimizers=optimizers,
    )
    if use_encoder:
        automl.encoder = Encoder(img_shape, output_size=noise_size)

    automl = automl.to(device)

    if meta_optimizer == "adam":
        optimizer = torch.optim.Adam(automl.parameters(),
                                     lr=lr,
                                     betas=(adam_beta1, adam_beta2),
                                     eps=adam_epsilon)
    elif meta_optimizer == "sgd":
        optimizer = torch.optim.SGD(automl.parameters(),
                                    lr=lr,
                                    momentum=rms_momentum)
    elif meta_optimizer == "RMS":
        optimizer = torch.optim.RMSprop(automl.parameters(),
                                        lr=lr,
                                        alpha=adam_beta1,
                                        momentum=rms_momentum,
                                        eps=adam_epsilon)
    else:
        raise NotImplementedError()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, num_meta_iterations, lr * final_relative_lr)
    if hvd.rank() == 0:
        if load_from:
            state = torch.load(load_from)
            automl.load_state_dict(state["model"])
            if lr > 0:
                optimizer.load_state_dict(state["optimizer"])
            del state
            tlogger.info("loaded from:", load_from)
        total_num_parameters = 0
        for name, value in automl.named_parameters():
            tlogger.info("Optimizing parameter:", name, value.shape)
            total_num_parameters += np.prod(value.shape)
        tlogger.info("Total number of parameters:", int(total_num_parameters))

    def compute_learner(learner,
                        iterations=num_inner_iterations,
                        keep_grad=True,
                        callback=None):
        learner.model.train()
        names, params = list(zip(*learner.model.get_parameters()))
        buffers = list(zip(*learner.model.named_buffers()))
        if buffers:
            buffer_names, buffers = buffers
        else:
            buffer_names, buffers = None, ()
        param_shapes = [p.shape for p in params]
        param_sizes = [np.prod(shape) for shape in param_shapes]
        param_end_point = np.cumsum(param_sizes)

        buffer_shapes = [p.shape for p in buffers]
        buffer_sizes = [np.prod(shape) for shape in buffer_shapes]
        buffer_end_point = np.cumsum(buffer_sizes)

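        # The learner's parameters and buffers are fused into single flat
        # vectors for the unrolled inner loop; split_params / split_buffer map
        # slices of those fused vectors back to each tensor's original shape.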
        def split_params(fused_params):
            # return fused_params
            assert len(fused_params) == 1
            return [
                fused_params[0][end - size:end].reshape(shape) for end, size,
                shape in zip(param_end_point, param_sizes, param_shapes)
            ]

        def split_buffer(fused_params):
            if fused_params:
                # return fused_params
                assert len(fused_params) == 1
                return [
                    fused_params[0][end - size:end].reshape(shape)
                    for end, size, shape in zip(buffer_end_point, buffer_sizes,
                                                buffer_shapes)
                ]
            return fused_params

        # test = split_params(torch.cat([p.flatten() for p in params]))
        # assert all([np.allclose(params[i].detach().cpu(), test[i].detach().cpu()) for i in range(len(test))])
        params = [torch.cat([p.flatten() for p in params])]
        buffers = [torch.cat([p.flatten()
                              for p in buffers])] if buffers else buffers
        optimizer_state = learner.optimizer.initial_state(params)

        params = params, buffers
        initial_params = nest.map_structure(lambda p: None, params)

        losses = {}
        accuracies = {}

        def intermediate_loss(params):
            params = nest.pack_sequence_as(initial_params, params[1:])
            params, buffers = params
            x, y = next(meta_generator)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            learner.model.set_parameters(list(zip(names,
                                                  split_params(params))))
            if buffer_names:
                learner.model.set_buffers(
                    list(zip(buffer_names, split_buffer(buffers))))
            learner.model.eval()
            pred = learner.model(x)
            if isinstance(pred, tuple):
                pred, aux_pred = pred
                loss = F.nll_loss(pred, y) + F.nll_loss(aux_pred, y)
            else:
                loss = F.nll_loss(pred, y)
            return loss * intermediate_losses_ratio

        if hasattr(automl.generator, "init"):
            generator_args = [automl.generator.init()]
        else:
            generator_args = []

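        # One unrolled inner-loop step: draw a batch from the (learned)
        # generator, compute the learner's loss on it, differentiate w.r.t. the
        # fused parameters, and apply the learned inner optimizer; optional
        # intermediate losses on real data are spliced in via SurrogateLoss.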
        def body(args):
            it, params, optimizer_state = args
            if training_schedule_backwards:
                x, y_one_hot = automl.generator(iterations - it - 1,
                                                *generator_args)
            else:
                x, y_one_hot = automl.generator(it, *generator_args)
            with torch.enable_grad():
                if use_intermediate_losses > 0 and (
                        it >= use_intermediate_losses
                        and it % use_intermediate_losses == 0):
                    params = SurrogateLoss.apply(intermediate_loss, it,
                                                 *nest.flatten(params))
                    params = nest.pack_sequence_as(initial_params, params[1:])
                params, buffers = params
                for p in params:
                    if not p.requires_grad:
                        p.requires_grad = True

                learner.model.set_parameters(
                    list(zip(names, split_params(params))))
                if buffer_names:
                    learner.model.set_buffers(
                        list(zip(buffer_names, split_buffer(buffers))))
                learner.model.train()
                output = learner.model(x)
                if isinstance(output, tuple):
                    output1, output2 = output
                    loss = -(output1 * y_one_hot).sum() * (1 /
                                                           output1.shape[0])
                    loss = loss - (output2 *
                                   y_one_hot).sum() * (1 / output2.shape[0])
                    pred = output1
                else:
                    loss = -(output * y_one_hot).sum() * (1 / output.shape[0])
                    pred = output
                if it.item() not in losses:
                    losses[it.item()] = loss.detach().cpu().item()
                    accuracies[it.item()] = (
                        pred.max(-1).indices == y_one_hot.max(-1).indices).to(
                            torch.float).mean().item()

                grads = grad(loss,
                             params,
                             create_graph=x.requires_grad,
                             allow_unused=True)
            # assert len(grads) == len(names)
            new_params, optimizer_state = learner.optimizer(
                it, params, grads, optimizer_state)
            buffers = list(learner.model.buffers())
            buffers = [torch.cat([b.flatten()
                                  for b in buffers])] if buffers else buffers
            if callback is not None:
                learner.model.set_parameters(
                    list(zip(names, split_params(params))))
                if buffer_names:
                    learner.model.set_buffers(
                        list(zip(buffer_names, split_buffer(buffers))))
                callback(learner)

            return (it + 1, (
                new_params,
                buffers,
            ), optimizer_state)

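        # Unroll `body` for `iterations` steps under gradient checkpointing:
        # intermediate states are kept only every `gradient_block_size` steps
        # and recomputed on the backward pass, trading compute for activation
        # memory across the unroll.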
        last_state, params, optimizer_state = gradient_checkpointing(
            (torch.as_tensor(0), params, optimizer_state),
            body,
            iterations,
            block_size=gradient_block_size)
        assert last_state.item() == iterations
        params, buffers = params
        learner.model.set_parameters(list(zip(names, split_params(params))))
        if buffer_names:
            learner.model.set_buffers(
                list(zip(buffer_names, split_buffer(buffers))))

        if final_batch_norm_forward:
            x, _ = automl.generator(torch.randint(iterations, size=()))
            learner.model.train()
            learner.model(x)

        return learner, losses, accuracies

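    # Outer (meta) loop: each iteration unrolls the inner training above,
    # evaluates the resulting learner on a real meta-batch, and backpropagates
    # through the unroll to update the generator, encoder and inner optimizer.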
    tstart = time.time()
    meta_generator = iter(data_loader)
    hvd.broadcast_parameters(automl.state_dict(), root_rank=0)
    best_optimizers = {}
    validation_accuracy = None
    total_inner_iterations_so_far = 0
    for iteration in range(starting_meta_iteration, num_meta_iterations + 1):
        last_iteration = time.time()
        # basic logging
        tlogger.record_tabular('Iteration', iteration)
        tlogger.record_tabular('lr', optimizer.param_groups[0]['lr'])

        # Train learner
        if training_iterations_schedule > 0:
            training_iterations = int(
                min(
                    num_inner_iterations, min_training_iterations +
                    (iteration - starting_meta_iteration) //
                    training_iterations_schedule))
        else:
            training_iterations = num_inner_iterations
        tlogger.record_tabular('training_iterations', training_iterations)
        total_inner_iterations_so_far += training_iterations
        tlogger.record_tabular('training_iterations_so_far',
                               total_inner_iterations_so_far * hvd.size())

        optimizer.zero_grad()

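        # Accumulate gradients over virtual_batch_size unrolls before the
        # single optimizer.step() below; the meta learning rate was scaled by
        # virtual_batch_size * hvd.size() above to compensate.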
        for _ in range(virtual_batch_size):
            torch.cuda.empty_cache()
            meta_x, meta_y = next(meta_generator)
            meta_x = meta_x.to('cuda', non_blocking=True)
            meta_y = meta_y.to('cuda', non_blocking=True)

            tstart_forward = time.time()
            if generator_type != "semisupervised" or semisupervised_student_loss:
                # TODO: Learn batchnorm momentum and eps
                sample_learner_type = learner_type
                if warmup_iterations is not None and iteration < warmup_iterations:
                    sample_learner_type = warmup_learner
                learner, encoding = automl.sample_learner(
                    img_shape,
                    device,
                    allow_nas=False,
                    randomize_width=randomize_width,
                    learner_type=sample_learner_type,
                    iteration_maps_seed=iteration_maps_seed,
                    iteration=iteration,
                    deterministic=deterministic,
                    iterations_depth_schedule=iterations_depth_schedule)
                automl.train()

                if lr == 0.0:
                    with torch.no_grad():
                        learner, intermediate_losses, intermediate_accuracies = compute_learner(
                            learner,
                            iterations=training_iterations,
                            keep_grad=lr > 0.0)
                else:
                    learner, intermediate_losses, intermediate_accuracies = compute_learner(
                        learner,
                        iterations=training_iterations,
                        keep_grad=lr > 0.0)
                # TODO: remove this requirement
                params = list(learner.model.get_parameters())
                learner.model.eval()

                # Evaluate learner on training and back-prop
                torch.cuda.empty_cache()
                pred = learner.model(meta_x)
                if isinstance(pred, tuple):
                    pred, aux_pred = pred
                    loss = F.nll_loss(pred, meta_y) + F.nll_loss(
                        aux_pred, meta_y)
                else:
                    loss = F.nll_loss(pred, meta_y)
                accuracy = (pred.max(-1).indices == meta_y).to(
                    torch.float).mean()
                tlogger.record_tabular("TimeElapsedForward",
                                       time.time() - tstart_forward)
                num_parameters = sum([a[1].size().numel() for a in params])
                tlogger.record_tabular("TrainingLearnerParameters",
                                       num_parameters)
                tlogger.record_tabular("optimizer",
                                       type(learner.optimizer).__name__)
                tlogger.record_tabular('meta_training_loss', loss.item())
                tlogger.record_tabular('meta_training_accuracy',
                                       accuracy.item())
                tlogger.record_tabular('training_accuracies',
                                       intermediate_accuracies)
                tlogger.record_tabular('training_losses', intermediate_losses)
                tlogger.record_tabular("dag", encoding)
            else:
                loss = torch.as_tensor(0.0)

            if lr > 0.0:
                tstart_backward = time.time()
                if generator_type != "semisupervised" or semisupervised_student_loss:
                    loss.backward()

                if generator_type == "semisupervised" and semisupervised_classifier_loss:
                    automl.generator.classifier.train()
                    pred = automl.generator.classifier(meta_x)
                    accuracy = (pred.max(-1).indices == meta_y).to(
                        torch.float).mean()
                    loss2 = F.nll_loss(pred, meta_y)
                    loss2.backward()
                    tlogger.record_tabular('meta_training_generator_loss',
                                           loss2.item())
                    tlogger.record_tabular('meta_training_generator_accuracy',
                                           accuracy.item())
                    loss = loss + loss2
                    del loss2

                tlogger.record_tabular("TimeElapsedBackward",
                                       time.time() - tstart_backward)

                if use_encoder:
                    # TODO: add loss weight
                    meta_encoding = automl.encoder(meta_x)

                    meta_y_one_hot = torch.zeros(meta_x.shape[0],
                                                 10,
                                                 device=device)
                    meta_y_one_hot.scatter_(1, meta_y.unsqueeze(-1), 1)
                    meta_encoding = torch.cat([meta_encoding, meta_y_one_hot],
                                              -1)
                    reconstruct = automl.generator.generator(meta_encoding)
                    ae_loss = decoder_loss_multiplier * F.mse_loss(
                        reconstruct, meta_x)
                    ae_loss.backward()
                    tlogger.record_tabular("decoder_loss", ae_loss.item())

        if lr > 0.0:
            # If using distributed training, aggregate gradients with Horovod
            maybe_allreduce_grads(automl)
            if grad_bound is not None:
                nn.utils.clip_grad_norm_(automl.parameters(), grad_bound)
            optimizer.step()
            if max_elapsed_time is not None:
                scheduler.step(
                    round((time.time() - tstart) / max_elapsed_time *
                          num_meta_iterations))
            else:
                scheduler.step(iteration - 1)

        is_last_iteration = iteration == num_meta_iterations or (
            max_elapsed_time is not None
            and time.time() - tstart > max_elapsed_time)
        if np.isnan(loss.item()):
            tlogger.info("NaN training loss, terminating")
            is_last_iteration = True
        is_last_iteration = MPI.COMM_WORLD.bcast(is_last_iteration, root=0)
        if iteration == 1 or iteration % logging_period == 0 or is_last_iteration:
            tstart_validation = time.time()

            val_loss, val_accuracy = [], []
            test_loss, test_accuracy = [], []

            if generator_type == "semisupervised":
                # Validation set
                evaluate_set(generator.classifier, validation_x, validation_y,
                             "generator_validation")
                # Test set
                evaluate_set(generator.classifier, testset_x, testset_y,
                             "generator_test")
            else:

                def compute_learner_callback(learner):
                    # Validation set
                    validation_loss, single_validation_accuracy, validation_accuracy = evaluate_set(
                        learner.model, validation_x, validation_y,
                        "validation")
                    val_loss.append(validation_loss)
                    val_accuracy.append(validation_accuracy)
                    best_optimizers[type(
                        learner.optimizer
                    ).__name__] = single_validation_accuracy.item()
                    # Test set
                    loss, _, accuracy = evaluate_set(learner.model, testset_x,
                                                     testset_y, "test")
                    test_loss.append(loss)
                    test_accuracy.append(accuracy)

                tlogger.info(
                    "sampling another learner_type ({}) for validation".format(
                        validation_learner_type))
                learner, _ = automl.sample_learner(
                    img_shape,
                    device,
                    allow_nas=False,
                    learner_type=validation_learner_type,
                    iteration_maps_seed=iteration_maps_seed,
                    iteration=iteration,
                    deterministic=deterministic,
                    iterations_depth_schedule=iterations_depth_schedule)
                if step_by_step_validation:
                    compute_learner_callback(learner)
                with torch.no_grad():
                    learner, _, _ = compute_learner(
                        learner,
                        iterations=training_iterations,
                        keep_grad=False,
                        callback=compute_learner_callback
                        if step_by_step_validation else None)
                if not step_by_step_validation:
                    compute_learner_callback(learner)

                tlogger.record_tabular("validation_losses", val_loss)
                tlogger.record_tabular("validation_accuracies", val_accuracy)
                validation_accuracy = val_accuracy[-1]
                tlogger.record_tabular("test_losses", test_loss)
                tlogger.record_tabular("test_accuracies", test_accuracy)

            # Extra logging
            tlogger.record_tabular('TimeElapsedIter',
                                   (tstart_validation - last_iteration) /
                                   virtual_batch_size)
            tlogger.record_tabular('TimeElapsedValidation',
                                   time.time() - tstart_validation)
            tlogger.record_tabular('TimeElapsed', time.time() - tstart)

            for k, v in best_optimizers.items():
                tlogger.record_tabular("{}_last_accuracy".format(k), v)

            if hvd.rank() == 0:
                tlogger.dump_tabular()

                if (iteration == 1 or iteration % 1000 == 0
                        or is_last_iteration):
                    with torch.no_grad():
                        if enable_checkpointing:
                            batches = []
                            for it in range(num_inner_iterations):
                                if training_schedule_backwards:
                                    x, y = automl.generator(
                                        num_inner_iterations - it - 1)
                                else:
                                    x, y = automl.generator(it)
                                batches.append(
                                    (x.cpu().numpy(), y.cpu().numpy()))
                            batches = list(reversed(batches))
                            with open(
                                    os.path.join(
                                        tlogger.get_dir(),
                                        'samples_{}.pkl'.format(iteration)),
                                    'wb') as file:
                                pickle.dump(batches, file)
                            del batches
                            tlogger.info(
                                "Saved:",
                                os.path.join(
                                    tlogger.get_dir(),
                                    'samples_{}.pkl'.format(iteration)))
                        ckpt = os.path.join(
                            tlogger.get_dir(),
                            'checkpoint_{}.pkl'.format(iteration))
                        torch.save(
                            {
                                "optimizer": optimizer.state_dict(),
                                "model": automl.state_dict(),
                            }, ckpt)
                        tlogger.info("Saved:", ckpt)

            if is_last_iteration:
                break
        elif hvd.rank() == 0:
            tlogger.info("training_loss:", loss.item())
    return validation_accuracy