コード例 #1
0
def worker_train(ps, replay_buffer, opt, learner_index):
    """Learner worker: sync weights from the parameter server, then train forever.

    Pulls the current values for every weight key the local ``Learner``
    exposes and starts a ``Cache`` over ``replay_buffer``. It then loops
    without end:

    - take a batch from ``cache.q1``;
    - for ``opt.model == "cnn"``, run every observation through ``unpack``;
    - run one training step;
    - every 100 steps, put the current weights on ``cache.q2``.

    Forwarding ``q2`` contents to ``ps`` is presumably done by ``Cache``;
    that is not visible here.

    Args:
        ps: parameter-server actor handle exposing ``pull.remote(keys)``.
        replay_buffer: handed unchanged to ``Cache``.
        opt: run options; ``opt.model`` selects observation decoding.
        learner_index: not used in this body.

    Never returns.
    """
    learner = Learner(opt, job="learner")
    param_keys = learner.get_weights()[0]
    initial_weights = ray.get(ps.pull.remote(param_keys))
    learner.set_weights(param_keys, initial_weights)

    batch_cache = Cache(replay_buffer)
    batch_cache.start()

    step = 0
    while True:
        step += 1
        batch = batch_cache.q1.get()

        if opt.model == "cnn":
            # Decode every stored observation (presumably compressed frames —
            # confirm against unpack()) and rebuild the nested array.
            batch['obs'] = np.array([list(map(unpack, row)) for row in batch['obs']])

        learner.train(batch, step)

        # Push interval: an earlier revision of this worker used 300.
        if step % 100 == 0:
            batch_cache.q2.put(learner.get_weights())
コード例 #2
0
ファイル: train.py プロジェクト: XrosLiang/Distributed-DRL
    def __init__(self, opt, weights_file, checkpoint_path, ps_index):
        """Set up the parameter server's initial weights (one server per node).

        Besides filling ``self.weights`` this creates ``<save_dir>`` and
        ``<save_dir>/checkpoint`` and dumps every option in ``opt`` to
        ``<save_dir>/All_Parameters.json``.

        Weight source, in order of precedence:

        1. ``weights_file`` (a pickled weights mapping), if given;
        2. ``<checkpoint_path>/checkpoint_weights.pickle`` when ``opt.recover``;
        3. otherwise, fresh copies of the weights of a new ``Learner``.

        Args:
            opt: run configuration; read for ``save_dir`` and ``recover`` and
                converted with ``vars()`` for the JSON dump.
            weights_file: optional path to a pickled weights mapping.
            checkpoint_path: checkpoint directory; defaults to
                ``<save_dir>/checkpoint`` when falsy.
            ps_index: not used in this body.

        Raises:
            SystemExit: (status 1) if ``weights_file`` is given but cannot be
                loaded.
        """
        self.opt = opt
        self.learner_step = 0
        net = Learner(opt, job="ps")
        keys, values = net.get_weights()

        # --- make dir for all nodes and save parameters ---
        # One makedirs call creates save_dir and its checkpoint sub-dir.
        # exist_ok keeps an already-existing save_dir from aborting before the
        # checkpoint dir is created. Remaining failures stay best-effort.
        try:
            os.makedirs(os.path.join(opt.save_dir, "checkpoint"), exist_ok=True)
        except OSError:
            pass
        all_parameters = copy.deepcopy(vars(opt))
        # Space objects are blanked before dumping (presumably not
        # JSON-serializable).
        all_parameters["obs_space"] = ""
        all_parameters["act_space"] = ""
        with open(os.path.join(opt.save_dir, 'All_Parameters.json'), 'w') as fp:
            json.dump(all_parameters, fp, indent=4, sort_keys=True)
        # --- end ---

        self.weights = None

        if not checkpoint_path:
            checkpoint_path = os.path.join(opt.save_dir, "checkpoint")

        if opt.recover:
            # NOTE(review): pickle.load runs arbitrary code embedded in the
            # file. Only load checkpoints this project produced.
            with open(os.path.join(checkpoint_path, "checkpoint_weights.pickle"),
                      "rb") as pickle_in:
                self.weights = pickle.load(pickle_in)
            print("****** weights restored! ******")

        if weights_file:
            try:
                # NOTE(review): same pickle trust caveat as above.
                with open(weights_file, "rb") as pickle_in:
                    self.weights = pickle.load(pickle_in)
            except Exception as err:
                # Narrower than a bare ``except:`` so Ctrl-C is not swallowed.
                # Also distinguishes a missing file from an unreadable one.
                print("------------------------------------------------")
                print(weights_file)
                if isinstance(err, FileNotFoundError):
                    print("------ error: weights file doesn't exist! ------")
                else:
                    print("------ error: cannot load weights file: %r ------" % (err,))
                # Non-zero status: loading failed. Avoids the site-only exit().
                raise SystemExit(1) from err
            print("****** weights restored! ******")

        if not opt.recover and not weights_file:
            # Copy each value so the server's weights are independent of the
            # local ``net`` (values presumably numpy arrays — confirm).
            self.weights = dict(zip(keys, [value.copy() for value in values]))
コード例 #3
0
ファイル: train.py プロジェクト: XrosLiang/Distributed-DRL
def worker_train(ps, node_buffer, opt, learner_index):
    """Run a learner that trains on cached batches forever.

    The local ``Learner`` is first synchronised with ``ps`` by pulling the
    values for all of its weight keys. A ``Cache`` over ``node_buffer`` then
    supplies batches through ``q1``. Every ``opt.push_freq`` steps the current
    weights are put on ``q2`` for delivery elsewhere (not visible here).

    Args:
        ps: parameter-server actor handle exposing ``pull.remote(keys)``.
        node_buffer: handed unchanged to ``Cache``.
        opt: run options; ``opt.push_freq`` is re-read on every step.
        learner_index: not used in this body.

    Never returns.
    """
    learner = Learner(opt, job="learner")
    weight_keys = learner.get_weights()[0]
    pulled = ray.get(ps.pull.remote(weight_keys))
    learner.set_weights(weight_keys, pulled)

    batch_cache = Cache(node_buffer)
    batch_cache.start()

    step = 0
    while True:
        step += 1
        next_batch = batch_cache.q1.get()
        learner.train(next_batch, step)

        due_for_push = step % opt.push_freq == 0
        if due_for_push:
            batch_cache.q2.put(learner.get_weights())
コード例 #4
0
def worker_train(ps, replay_buffer, opt, learner_index):
    """Learner loop: initial sync from ``ps``, then endless training.

    Batches come from ``cache.q1`` of a ``Cache`` built on ``replay_buffer``.
    Every 300 steps a snapshot of the weights goes onto ``cache.q2``. Pushing
    to ``ps`` is left to whoever consumes ``q2``, rather than calling
    ``ps.push`` directly here.

    Args:
        ps: parameter-server actor handle exposing ``pull.remote(keys)``.
        replay_buffer: handed unchanged to ``Cache``.
        opt: run options, used to build the ``Learner``.
        learner_index: not used in this body.

    Never returns.
    """
    learner = Learner(opt, job="learner")
    weight_keys = learner.get_weights()[0]
    pulled = ray.get(ps.pull.remote(weight_keys))
    learner.set_weights(weight_keys, pulled)

    batch_cache = Cache(replay_buffer)
    batch_cache.start()

    step = 0
    while True:
        step += 1
        learner_batch = batch_cache.q1.get()
        learner.train(learner_batch)
        if step % 300 != 0:
            continue
        batch_cache.q2.put(learner.get_weights())
コード例 #5
0
def worker_train(ps, node_buffer, opt, model_type):
    """Learner loop for one model type: sync from ``ps``, then train forever.

    The ``Learner`` for ``model_type`` is initialised from
    ``ps.pull.remote(model_type)``. Batches are then read from the
    ``model_type`` entry of ``cache.q1``. Every ``opt.push_freq`` steps the
    current weights are put on ``cache.q2``.

    Args:
        ps: parameter-server actor handle exposing ``pull.remote(model_type)``.
        node_buffer: handed unchanged to ``Cache``.
        opt: run options; ``opt.push_freq`` is re-read on every step.
        model_type: key selecting the model and its batch queue.

    Never returns.
    """
    agent = Learner(opt, model_type)
    weights = ray.get(ps.pull.remote(model_type))
    agent.set_weights(weights)

    cache = Cache(node_buffer)
    cache.start()

    cnt = 1
    while True:
        # Read the next batch for this model type. ``.get()`` is the read used
        # by the other workers; ``.start()`` does not read a batch from a queue.
        batch = cache.q1[model_type].get()
        agent.train(batch, cnt)

        if cnt % opt.push_freq == 0:
            # Call get_weights() so q2 receives the weights themselves, not the
            # bound method. NOTE(review): q2 is not keyed by model_type here —
            # confirm the consumer can tell which model the weights belong to.
            cache.q2.put(agent.get_weights())
        cnt += 1
コード例 #6
0
    All_Parameters["obs_space"] = ""
    All_Parameters["act_space"] = ""

    try:
        os.makedirs(opt.save_dir)
    except OSError:
        pass
    with open(opt.save_dir + "/" + 'All_Parameters.json', 'w') as fp:
        json.dump(All_Parameters, fp, indent=4, sort_keys=True)

    # ------ end ------

    if FLAGS.weights_file:
        ps = ParameterServer.remote([], [], weights_file=FLAGS.weights_file)
    else:
        net = Learner(opt, job="main")
        all_keys, all_values = net.get_weights()
        ps = ParameterServer.remote(all_keys, all_values)

    # Experience buffer
    # Methods called on different actors can execute in parallel,
    # and methods called on the same actor are executed serially in the order that they are called.
    # we need more buffer for more workers to keep high store speed.
    replay_buffer = [ReplayBuffer.remote(opt) for i in range(opt.num_buffers)]

    # Start some training tasks.
    for i in range(FLAGS.num_workers):
        worker_rollout.remote(ps, replay_buffer, opt, i)
        time.sleep(0.05)
    # task_rollout = [worker_rollout.remote(ps, replay_buffer, opt, i) for i in range(FLAGS.num_workers)]