Example #1
0
    def train(self):
        """Run supervised training over the collected experiences.

        Splits the experience buffer into a fixed validation prefix and a
        training remainder, then for each epoch shuffles the training set,
        trains on mini-batches, evaluates on the validation batches, and
        checkpoints the model.
        """
        valid_set = self.experiences[:self.valid_size]
        train_set = self.experiences[self.valid_size:]

        # Validation batches stay fixed across epochs.
        valid_batches = util.chunk(valid_set, self.batch_size)

        for epoch in range(self.epochs):
            print("Epoch", epoch)
            start_time = time.time()

            # Reshuffle and re-batch the training data every epoch.
            shuffle(train_set)
            for batch in util.chunk(train_set, self.batch_size):
                self.rl.train(batch, log=False)

            print(time.time() - start_time)

            # Evaluation pass: train=False disables parameter updates.
            for batch in valid_batches:
                self.rl.train(batch, train=False)

            self.rl.save()
Example #2
0
    def __init__(self, load=None, **kwargs):
        """Set up the trainer and load its experience dataset.

        load: optional path of a previous run whose 'train' params are
            reloaded as the base configuration; None starts fresh.
        kwargs: overrides merged on top of the loaded params and handed to
            Default.__init__, which is assumed to set the attributes read
            below (init, rl, data, file_limit, valid_batches, batch_size).
        """
        if load is None:
            args = {}
        else:
            # Restore the saved 'train' configuration from the given run.
            args = util.load_params(load, 'train')

        # Force a fixed experience length regardless of caller overrides.
        kwargs.update(experience_length=6000, )
        util.update(args, mode=RL.Mode.TRAIN, **kwargs)
        util.pp.pprint(args)
        Default.__init__(self, **args)

        # Either initialize fresh model weights or restore a checkpoint.
        if self.init:
            self.rl.init()
            self.rl.save()
        else:
            self.rl.restore()

        # Default experience directory lives next to the model checkpoint.
        if self.data is None:
            self.data = os.path.join(self.rl.path, 'experience')

        print("Loading experiences from", self.data)

        # NOTE(review): os.listdir order is arbitrary, so file_limit selects
        # an unspecified subset — confirm determinism is not required here.
        files = os.listdir(self.data)

        if self.file_limit:
            files = files[:self.file_limit]

        data_paths = [os.path.join(self.data, f) for f in files]

        print("Loading %d experiences." % len(files))

        self.experiences = []
        parallel = True  # serial branch below kept for debugging

        if parallel:
            # Load in chunks of 100 files concurrently.
            for paths in util.chunk(data_paths, 100):
                self.experiences.extend(
                    util.async_map(load_experience, paths)())
        else:
            for path in data_paths:
                with open(path, 'rb') as f:
                    self.experiences.append(pickle.load(f))

        # Reserve the first valid_size experiences for validation.
        self.valid_size = self.valid_batches * self.batch_size
Example #3
0
    def train(self):
        """Continuously collect experiences, train, and publish parameters.

        Pre-fills the buffer with sweep_size experiences, then loops
        forever: block for min_collect fresh experiences, drain any extras
        without blocking, run self.sweeps shuffled passes over the buffer,
        broadcast the updated model blob, and checkpoint.

        Fixes vs. original: `from random import shuffle` hoisted out of the
        inner sweep loop (it was re-executed every pass), unused exception
        binding removed, stale commented-out debug prints dropped.
        """
        from random import shuffle  # hoisted: was re-imported each sweep

        before = count_objects()

        sweeps = 0

        # Pre-fill the replay buffer before entering the training loop.
        for _ in range(self.sweep_size):
            self.buffer.push(self.experience_socket.recv_pyobj())

        print("Buffer filled")

        while True:
            start_time = time.time()

            # Block for the minimum number of fresh experiences.
            for _ in range(self.min_collect):
                self.buffer.push(self.experience_socket.recv_pyobj())

            collected = self.min_collect

            # Drain whatever else is already queued without blocking.
            while True:
                try:
                    self.buffer.push(
                        self.experience_socket.recv_pyobj(zmq.NOBLOCK))
                    collected += 1
                except zmq.ZMQError:
                    # EAGAIN: nothing left to receive right now.
                    break

            collect_time = time.time()

            experiences = self.buffer.as_list()

            # Several shuffled passes over the whole buffer per sweep.
            for _ in range(self.sweeps):
                shuffle(experiences)

                for batch in util.chunk(experiences, self.batch_size):
                    self.model.train(batch, self.batch_steps)

            print('After train: %s' %
                  resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
            train_time = time.time()

            # Broadcast updated parameters to the workers.
            self.params_socket.send_pyobj(self.model.blob())

            self.save()

            save_time = time.time()

            sweeps += 1

            if False:  # disabled object-growth debugging aid
                after = count_objects()
                print(diff_objects(after, before))
                before = after

            # Convert absolute timestamps into per-phase durations.
            save_time -= train_time
            train_time -= collect_time
            collect_time -= start_time

            print(sweeps, self.sweep_size, collected, collect_time, train_time,
                  save_time)
            print('Memory usage: %s (kb)' %
                  resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
Example #4
0
    def train(self):
        """Main training loop with age- and KL-based buffer pruning.

        Each sweep: drop experiences older than max_age (by global step),
        collect fresh ones (blocking up to the shortfall, then draining the
        socket non-blockingly), train one pass over the buffer, prune the
        buffer by per-example KL and size cap, optionally evolve, broadcast
        parameters, and checkpoint. Runs until sweeps == sweep_limit
        (forever if sweep_limit is never reached, e.g. -1).
        """
        before = count_objects()

        sweeps = 0
        step = 0
        global_step = self.model.get_global_step()

        # Moving-average timers for each phase of the sweep.
        times = ['min_collect', 'extra_collect', 'train', 'save']
        averages = {name: util.MovingAverage(.9) for name in times}

        timer = util.Timer()

        def split(name):
            # Record the time since the last split under `name`.
            averages[name].append(timer.split())

        experiences = []

        while sweeps != self.sweep_limit:
            sweeps += 1
            timer.reset()

            #print('Start: %s' % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)

            # Age pruning: discard experiences generated too long ago.
            old_len = len(experiences)
            if self.max_age is not None:
                # print("global_step", global_step)
                age_limit = global_step - self.max_age
                is_valid = lambda exp: exp['global_step'] >= age_limit
                experiences = list(filter(is_valid, experiences))
            else:
                is_valid = lambda _: True
            dropped = old_len - len(experiences)

            def pull_experience(block=True):
                # Receive one pickled experience; non-blocking when block=False.
                exp = self.experience_socket.recv(
                    flags=0 if block else nnpy.DONTWAIT)
                return pickle.loads(exp)

            # Refill the buffer to sweep_size, but always take at least
            # min_collect fresh experiences.
            to_collect = max(self.sweep_size - len(experiences),
                             self.min_collect)
            new_experiences = []

            # print("Collecting experiences", len(experiences))
            doa = 0  # dead on arrival: received but already too old
            while len(new_experiences) < to_collect:
                #print("Waiting for experience")
                exp = pull_experience()
                if is_valid(exp):
                    new_experiences.append(exp)
                else:
                    #print("dead on arrival", doa)
                    doa += 1

            split('min_collect')
            #print('min_collected')

            # pull in all the extra experiences
            for _ in range(self.sweep_size):
                try:
                    exp = pull_experience(False)
                    if is_valid(exp):
                        new_experiences.append(exp)
                    else:
                        doa += 1
                except nnpy.NNError as e:
                    if e.error_no == nnpy.EAGAIN:
                        # nothing to receive
                        break
                    # a real error
                    raise e

            experiences += new_experiences

            # Report how stale the buffer is, in optimizer steps.
            ages = np.array(
                [global_step - exp['global_step'] for exp in experiences])
            print("Mean age:", ages.mean())

            #print('After collect: %s' % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
            split('extra_collect')

            #shuffle(experiences)

            # Recompute a batch size that divides the buffer as evenly as
            # possible into the same number of batches (ceiling division).
            batches = len(experiences) // self.batch_size
            batch_size = (len(experiences) + batches - 1) // batches

            kls = []

            try:
                for batch in util.chunk(experiences, batch_size):
                    train_out = self.model.train(batch,
                                                 self.batch_steps,
                                                 log=(step %
                                                      self.log_interval == 0),
                                                 kls=True)[-1]
                    global_step = train_out['global_step']
                    kls.extend(train_out['kls'].tolist())
                    step += 1
            except tf.errors.InvalidArgumentError:
                # always a NaN in histogram summary for entropy - what's up with that?
                # Recovery: throw away the whole buffer and start the sweep over.
                experiences = []
                continue

            print("Mean KL", np.mean(kls))

            # KL/size pruning: keep at most max_buffer newest entries and
            # drop entries whose policy has drifted beyond max_kl.
            old_len = len(experiences)
            kl_exps = zip(kls, experiences)
            if self.max_buffer and old_len > self.max_buffer:
                kl_exps = list(kl_exps)[-self.max_buffer:]
            if self.max_kl:
                kl_exps = [ke for ke in kl_exps if ke[0] <= self.max_kl]
            kls, experiences = zip(*kl_exps)
            experiences = list(experiences)
            dropped += old_len - len(experiences)

            #print('After train: %s' % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
            split('train')

            # Evolutionary step: selection may invalidate the buffer.
            if self.evolve and sweeps % self.evo_period == 0:
                if self.selection():
                    experiences = []
                self.model.mutation()

            if self.send:
                #self.params_socket.send_string("", zmq.SNDMORE)
                params = self.model.blob()
                blob = pickle.dumps(params)
                #print('After blob: %s' % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
                self.params_socket.send(blob)
                #print('After send: %s' % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)

            self.save()

            #print('After save: %s' % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
            split('save')

            if False:  # disabled object-growth debugging aid
                after = count_objects()
                print(diff_objects(after, before))
                before = after

            # Print per-phase time fractions plus sweep statistics.
            time_avgs = [averages[name].avg for name in times]
            total_time = sum(time_avgs)
            time_avgs = ['%.3f' % (t / total_time) for t in time_avgs]
            print(sweeps, len(experiences), len(new_experiences), dropped, doa,
                  total_time, *time_avgs)
            print('Memory usage: %s (kb)' %
                  resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)

            if self.objgraph:
                import objgraph
                #gc.collect()  # don't care about stuff that would be garbage collected properly
                objgraph.show_growth()
Example #5
0
    def train(self):
        """Collect experiences with an age cutoff, train, and publish.

        Each sweep: drop experiences older than max_age (by global step),
        refill the buffer to sweep_size (blocking), drain extras without
        blocking, run self.sweeps shuffled passes split into self.batches
        batches, optionally broadcast the pickled model blob, and
        checkpoint.  Runs until sweeps == sweep_limit.

        Fixes vs. original: `from random import shuffle` hoisted out of the
        per-sweep loop (it was re-executed every pass).
        """
        from random import shuffle  # hoisted: was re-imported each sweep

        before = count_objects()

        sweeps = 0
        step = 0

        experiences = []

        while sweeps != self.sweep_limit:
            start_time = time.time()

            print('Start: %s' %
                  resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)

            # Age pruning: discard experiences generated too long ago.
            age_limit = self.model.get_global_step() - self.max_age
            is_valid = lambda exp: exp['global_step'] >= age_limit
            experiences = list(filter(is_valid, experiences))

            # Block until the buffer is back up to sweep_size.
            collected = 0
            while len(experiences) < self.sweep_size:
                exp = self.experience_socket.recv_pyobj()
                if is_valid(exp):
                    experiences.append(exp)
                    collected += 1

            # pull in all the extra experiences
            for _ in range(self.sweep_size):
                try:
                    exp = self.experience_socket.recv_pyobj(zmq.NOBLOCK)
                    if is_valid(exp):
                        experiences.append(exp)
                        collected += 1
                except zmq.ZMQError:
                    # EAGAIN: nothing left to receive right now.
                    break

            print('After collect: %s' %
                  resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
            collect_time = time.time()

            for _ in range(self.sweeps):
                shuffle(experiences)

                # Split the buffer into self.batches near-equal batches.
                batch_size = len(experiences) // self.batches
                for batch in util.chunk(experiences, batch_size):
                    self.model.train(batch,
                                     self.batch_steps,
                                     log=(step % self.log_interval == 0))
                    step += 1

            print('After train: %s' %
                  resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
            train_time = time.time()

            if self.send:
                # Serialize and broadcast the current parameters.
                params = self.model.blob()
                blob = pickle.dumps(params)
                print('After blob: %s' %
                      resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
                self.params_socket.send(blob)
                print('After send: %s' %
                      resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)

            self.save()

            save_time = time.time()

            sweeps += 1

            if False:  # disabled object-growth debugging aid
                after = count_objects()
                print(diff_objects(after, before))
                before = after

            # Convert absolute timestamps into per-phase durations.
            save_time -= train_time
            train_time -= collect_time
            collect_time -= start_time

            print(sweeps, len(experiences), collected, collect_time,
                  train_time, save_time)
            print('Memory usage: %s (kb)' %
                  resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)

            if self.objgraph:
                import objgraph
                #gc.collect()  # don't care about stuff that would be garbage collected properly
                objgraph.show_growth()