def __init__(self,
                 rank: int,
                 num_callees: int = 1,
                 num_callers: int = 1,
                 threads_process: int = 1,
                 caller_class: object = None,
                 caller_args=None,
                 future_keys: list = None):

        # ASSERTIONS
        assert num_callees > 0
        assert num_callers > 0

        # caller_class must be given
        assert caller_class is not None

        # caller_class is a correct subclass of RpcCaller
        # import here to avoid a circular import
        # pylint: disable=import-outside-toplevel
        from ..agents.rpc_caller import RpcCaller
        assert issubclass(caller_class, RpcCaller)
        assert isinstance(future_keys, list)

        # ATTRIBUTES

        # RPC
        self.rank = rank
        # pylint: disable=invalid-name
        self.id = rpc.get_worker_info().id
        self.name = rpc.get_worker_info().name
        self.rref = RRef(self)

        self.shutdown = False
        self._shutdown_done = False

        # COUNTERS
        self._t_start = time.time()
        self._loop_iteration = 0

        # STORAGE
        self._caller_rrefs = []
        self._pending_rpcs = deque()
        self._future_answers = {k: Future() for k in future_keys}
        self._current_futures = deque(maxlen=len(future_keys))

        # THREADS
        self.lock_batching = mp.Lock()
        self._processing_threads = [
            Thread(target=self._process_batch,
                   daemon=True,
                   name='processing_thread_%d' % i)
            for i in range(threads_process)
        ]

        for thread in self._processing_threads:
            thread.start()

        # spawn actors
        self._spawn_callers(caller_class, num_callees, num_callers,
                            *caller_args)
Example #2
    def __init__(self, batch_size, batch, state_size, nlayers, out_features):
        r"""
        Coordinator object to run on worker.  Only one coordinator exists.  Responsible
        for facilitating communication between agent and observers and recording benchmark
        throughput and latency data.
        Args:
            batch_size (int): Number of observer requests to process in a batch
            batch (bool): Whether to process and respond to observer requests as a batch or 1 at a time
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of out features in the model
        """
        self.batch_size = batch_size
        self.batch = batch

        self.agent_rref = None  # Agent RRef
        self.ob_rrefs = []  # Observer RRef

        agent_info = rpc.get_worker_info(AGENT_NAME)
        self.agent_rref = rpc.remote(agent_info, AgentBase)

        for rank in range(batch_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))
            ob_ref = rpc.remote(ob_info, ObserverBase)
            self.ob_rrefs.append(ob_ref)

            ob_ref.rpc_sync().set_state(state_size, batch)

        self.agent_rref.rpc_sync().set_world(batch_size, state_size, nlayers,
                                             out_features, self.batch)
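The coordinator above assumes the agent and observer processes have already joined the RPC group under AGENT_NAME and OBSERVER_NAME.format(rank). A hedged sketch of that bootstrapping (the name values and world size here are illustrative assumptions; the real constants are defined elsewhere in the benchmark):

import torch.distributed.rpc as rpc

AGENT_NAME = "agent"          # assumed value, for illustration only
OBSERVER_NAME = "observer{}"  # assumed value, for illustration only

def run_worker(rank, world_size):
    # Rank 0 hosts the coordinator, rank 1 the agent, ranks 2.. the observers.
    if rank == 0:
        rpc.init_rpc("coordinator", rank=rank, world_size=world_size)
        # coordinator = CoordinatorBase(batch_size, batch, state_size,
        #                               nlayers, out_features)
    elif rank == 1:
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
    else:
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
    rpc.shutdown()  # blocks until every worker is done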
Example #3
    def test_worker_id(self):
        n = self.rank + 1
        peer_rank = n % self.world_size
        self_worker_info = rpc.get_worker_info()
        peer_worker_info = rpc.get_worker_info("worker{}".format(peer_rank))

        self.assertEqual(self_worker_info.name, "worker{}".format(self.rank))
        self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank))

        with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
            unknown_worker_id = rpc.get_worker_info("WorkerUnknown")
Example #4
    def __init__(self, world_size, batch=True):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.policy = Policy(batch).cuda()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.running_reward = 0

        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer, args=(batch, )))
            self.rewards[ob_info.id] = []

        self.states = torch.zeros(len(self.ob_rrefs), 1, 4)
        self.batch = batch
        # With batching, saved_log_probs contains a list of tensors, where each
        # tensor contains probs from all observers in one step.
        # Without batching, saved_log_probs is a dictionary where the key is the
        # observer id and the value is a list of probs for that observer.
        self.saved_log_probs = [] if self.batch else {
            k: []
            for k in range(len(self.ob_rrefs))
        }
        self.future_actions = torch.futures.Future()
        self.lock = threading.Lock()
        self.pending_states = len(self.ob_rrefs)
Example #5
 def _test_self_remote_rref_as_rpc_arg(self, dst):
     self_worker_info = rpc.get_worker_info()
     rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
     fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
     ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, torch.ones(2, 2) + 1))
     self.assertEqual(ret, torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + 1)
     self.assertEqual(fut.wait(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))
Example #6
    def load_new_data(self, loader, kwargs):
        print(rpc.get_worker_info().name + ": Reloading %d - %d" %
              (kwargs['start'], kwargs['end']))

        jobs = kwargs.pop('jobs')
        self.module.data = loader(jobs, **kwargs)
        return True
Example #7
    def test_owner_equality(self):
        a = RRef(40)
        b = RRef(50)

        other_rank = (self.rank + 1) % self.world_size
        other_a = rpc.remote("worker{}".format(other_rank),
                             torch.add,
                             args=(torch.ones(1), 1))
        other_b = rpc.remote("worker{}".format(other_rank),
                             torch.add,
                             args=(torch.ones(1), 1))
        other_a.to_here()  # to ensure clean termination
        other_b.to_here()

        self.assertNotEqual(a.owner(), 23)
        self.assertEqual(other_a.owner(), other_b.owner())
        self.assertNotEqual(a.owner(), other_a.owner())
        self.assertEqual(other_a.owner(), other_a.owner())
        self.assertEqual(other_a.owner(), other_b.owner())
        self.assertEqual(a.owner(), a.owner())
        self.assertEqual(a.owner(), b.owner())
        self.assertEqual(a.owner(), rpc.get_worker_info())
        x = dict()
        x[a.owner()] = a
        x[other_a.owner()] = other_a
        self.assertEqual(x[a.owner()], a)
        self.assertEqual(x[b.owner()], a)
        self.assertEqual(x[other_a.owner()], other_a)
        self.assertEqual(x[other_b.owner()], other_a)
        self.assertEqual(len(x), 2)
Example #8
    def _spawn_callers(self, caller_class: object, num_callees: int,
                       num_callers: int, *args):
        """Spawns instances of :py:attr:`caller_class` on RPC workers.

        Parameters
        ----------
        caller_class: Child class of :py:class:`~.RpcCaller`
            Class used to spawn callers.
        num_callees: `int`
            Total number of callees spawned by the parent process.
        num_callers: `int`
            Total number of callers to spawn.
        *args:
            Arguments to pass to :py:attr:`caller_class`.
        """
        for i in range(num_callers):
            rank = i + num_callees
            callers_info = rpc.get_worker_info("actor%d" % (rank))

            # Store RRef of spawned caller
            self._caller_rrefs.append(
                rpc.remote(callers_info,
                           caller_class,
                           args=(rank, self.rref, *args)))
        print("{} callers spawned, awaiting start.".format(num_callers))
Example #9
 def test_self_add(self):
     self_worker_info = rpc.get_worker_info()
     self_worker_name = "worker{}".format(self.rank)
     fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
     ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
     self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
     self.assertEqual(ret, torch.ones(2, 2) + 1)
Example #10
    def train(self):

        name = rpc.get_worker_info().name
        self.m = self.ps_rref.rpc_sync().get_model().to(self.DEVICE)


        # Now we compute the gradient based on the model m:
        # we play one episode of the environment
        self.env = stock_env.trading_spy(C.max_simulation_length,C.min_history_length,C.max_position,C.init_cash_value)
        #self.env = gym.make('CartPole-v1')

        for loss, sum_of_rewards in self.get_next_batch(self.env):
            
            #utils.timed_log(f"reward is {sum_of_rewards}")

            loss.backward()
            #utils.timed_log(f"{name} reporting grads")

            self.m = rpc.rpc_sync(
                self.ps_rref.owner(),
                bups.BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in self.m.cpu().parameters()],sum_of_rewards),
            ).to(self.DEVICE)

            #utils.timed_log(f"{name} got updated model")
Example #11
    def set_world(self, batch_size, state_size, nlayers, out_features, batch=True):
        r"""
        Further initializes agent to be aware of rpc environment
        Args:
            batch_size (int): size of batches of observer requests to process
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of out features in the model
            batch (bool): Whether to process and respond to observer requests as a batch or 1 at a time
        """
        self.batch = batch
        self.policy = Policy(reduce((lambda x, y: x * y), state_size), nlayers, out_features)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)

        self.batch_size = batch_size
        for rank in range(batch_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))

            self.rewards[ob_info.id] = []

        self.saved_log_probs = [] if self.batch else {
            k: [] for k in range(self.batch_size)}

        self.pending_states = self.batch_size
        self.state_size = state_size
        self.states = torch.zeros(self.batch_size, *state_size)
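For illustration, with a hypothetical state_size of [4, 84, 84] the reduce call above flattens the state into 4 * 84 * 84 = 28224 input features for the Policy:

from functools import reduce

state_size = [4, 84, 84]  # hypothetical observation dimensions
in_features = reduce((lambda x, y: x * y), state_size)
assert in_features == 28224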
Example #12
    def calc_loss(self, zs, nratio, pred):
        # Uses masked val edges if module is set to eval()
        if self.module.training:
            partition = self.module.data.tr
        else:
            partition = self.module.data.va

        # Generate negative edges
        p, n, z = g.link_prediction(self.module.data,
                                    partition,
                                    zs,
                                    include_tr=not self.module.training,
                                    nratio=nratio)

        if pred:
            p, n, z = p[1:], n[1:], z[:-1]

        T = len(z)

        # Edge case: this worker ends up with zero time slices (T == 0).
        # It can only happen on the last worker when each worker has only
        # 1 delta; all others have overlap built in to prevent this.
        if T == 0:
            print("%s returning null loss" % rpc.get_worker_info().name)
            if self.module.training:
                return torch.zeros(0)
            else:
                return torch.zeros(0), torch.zeros(0)

        p_scores = []
        n_scores = []

        for i in range(T):
            p_scores.append(self.decode(p[i], z[i]))
            n_scores.append(self.decode(n[i], z[i]))

        p_scores = torch.cat(p_scores, dim=0)
        n_scores = torch.cat(n_scores, dim=0)

        if self.module.training:
            loss = self.nll(p_scores, n_scores)
            print("%s returning loss" % rpc.get_worker_info().name)
            return loss

        else:
            return p_scores, n_scores
Example #13
 def test_py_multi_async_call(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     dst_worker_info = rpc.get_worker_info("worker{}".format(dst_rank))
     fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
     fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
     self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
     self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
Example #14
    def _init_rpc(self):
        # Validate PG and RPC ranks match.
        pg_rank = dist.get_rank()
        rpc_rank = rpc.get_worker_info().id
        if pg_rank != rpc_rank:
            raise ValueError(
                f'Default ProcessGroup and RPC ranks must be '
                f'the same for ShardedTensor, found process group rank: '
                f'{pg_rank} and RPC rank: {rpc_rank}')

        self._remote_shards = {}

        # Gather all the sharded tensor ids.
        world_size = dist.get_world_size(self._process_group)
        worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
        rank_to_name = {}
        name_to_rank = {}

        for worker_info in worker_infos:
            rank_to_name[worker_info.id] = worker_info.name
            name_to_rank[worker_info.name] = worker_info.id

        all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id)

        # Share the local shards to the entire world.
        futs = []
        rpc_rank = rpc.get_worker_info().id
        for rank in range(dist.get_world_size()):
            # Skip self.
            if rank == dist.get_rank():
                continue

            if len(self.local_shards()) != 0:
                rrefs: List[rpc.RRef[Shard]] = [
                    rpc.RRef(shard) for shard in self.local_shards()
                ]
                fut = rpc.rpc_async(rank,
                                    _register_remote_shards,
                                    args=(all_tensor_ids[rank_to_name[rank]],
                                          rrefs, rpc_rank))
                futs.append(fut)

        torch.futures.wait_all(futs)

        # Barrier for all RPCs to finish on all ranks.
        rpc.api._all_gather(None)
Example #15
    def test_add_with_id(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        worker_info = rpc.get_worker_info("worker{}".format(dst_rank))

        ret = rpc.rpc_sync(worker_info,
                           torch.add,
                           args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)
Example #16
 def _test_self_remote_rref_as_remote_arg(self, dst):
     self_worker_info = rpc.get_worker_info()
     rref = rpc.remote(self_worker_info,
                       my_function,
                       args=(torch.ones(2, 2), 1, 3))
     ret_rref = rpc.remote(dst,
                           add_rref_to_value,
                           args=(rref, torch.ones(2, 2)))
     self.assertEqual(ret_rref.to_here(),
                      torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))
Example #17
    def _init_rpc(self):
        self._rpc_initialized = True
        self._remote_shards = {}

        # Gather all the sharded tensor ids.
        world_size = dist.get_world_size(self._process_group)
        worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
        rank_to_name = {}
        name_to_rank = {}

        for worker_info in worker_infos:
            rank_to_name[worker_info.id] = worker_info.name
            name_to_rank[worker_info.name] = worker_info.id

        rpc_workers = set()
        for rank in range(world_size):
            if self._process_group == distributed_c10d._get_default_group():
                global_rank = rank
            else:
                global_rank = distributed_c10d._get_global_rank(
                    self._process_group, rank)
            rpc_workers.add(rank_to_name[global_rank])

        all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id,
                                             rpc_workers)

        # Share the local shards to the entire world.
        futs = []
        rpc_rank = rpc.get_worker_info().id
        for rank in range(world_size):
            # Skip self.
            if rank == dist.get_rank(self._process_group):
                continue

            if self._process_group == distributed_c10d._get_default_group():
                global_rank = rank
            else:
                global_rank = distributed_c10d._get_global_rank(
                    self._process_group, rank)

            if len(self.local_shards()) != 0:
                rrefs: List[rpc.RRef[Shard]] = [
                    rpc.RRef(shard) for shard in self.local_shards()
                ]
                fut = rpc.rpc_async(
                    global_rank,
                    _register_remote_shards,
                    args=(all_tensor_ids[rank_to_name[global_rank]], rrefs,
                          rpc_rank))
                futs.append(fut)

        torch.futures.wait_all(futs)

        # Barrier for all RPCs to finish on all ranks.
        rpc.api._barrier(rpc_workers)
Example #18
    def __init__(self, rank: int, callee_rref: rpc.RRef):
        # ASSERTIONS
        # check that callee_rref's type is a subclass of RpcCallee
        # import here to avoid a circular import
        from ..agents.rpc_callee import RpcCallee
        assert issubclass(callee_rref._get_type(), RpcCallee)

        # ATTRIBUTES

        # RPC
        self.callee_rref = callee_rref
        self.rank = rank
        # pylint: disable=invalid-name
        self.id = rpc.get_worker_info().id
        self.name = rpc.get_worker_info().name

        # COUNTER
        self._loop_iteration = 0

        self.shutdown = False
Example #19
    def __init__(self, h_dim, loader, load_args, tail=False):
        print(rpc.get_worker_info().name + ": Loading ts from %d to %d" %
              (load_args['start'], load_args['end']))

        jobs = load_args.pop('jobs')
        data = loader(jobs, **load_args)
        super(R_GAE, self).__init__(data.x.size(1), h_dim, h_dim)

        self.data = data
        self.x_dim = data.x.size(1)
        self.tail = tail
Example #20
    def decode_all(self, zs):
        '''
        Given node embeddings, return edge likelihoods for
        all subgraphs held by this model.

        For the static model this is simple: just decode
        ei[n] given zs[n].
        '''
        assert not zs.size(0) < self.module.data.T, \
            "%s was given fewer embeddings than it has time slices"\
            % rpc.get_worker_info().name

        assert not zs.size(0) > self.module.data.T, \
            "%s was given more embeddings than it has time slices"\
            % rpc.get_worker_info().name

        preds = []
        for i in range(self.module.data.T):
            preds.append(self.decode(self.module.data.eis[i], zs[i]))

        return preds
Example #21
    def forward(self, xs):
        x_futs = []
        for i in range(xs.size(0)):
            x_futs.append(
                _remote_method_async(DDP.forward,
                                     self.remote_embs[i % self.n_workers],
                                     xs[i]))

        xs = torch.stack([f.wait() for f in x_futs])
        print(rpc.get_worker_info().name + ' running RNN')
        h = self.rnn(xs)[1]
        return self.out(h).squeeze(0)
Example #22
 def train(self):
     name = rpc.get_worker_info().name
     m = self.ps_rref.rpc_sync().get_model()
     for inputs, labels in self.get_next_batch():
         timed_log(f"{name} processing one batch")
         self.loss_fn(m(inputs), labels).backward()
         timed_log(f"{name} reporting grads")
         m = rpc.rpc_sync(
             self.ps_rref.owner(),
             BatchUpdateParameterServer.update_and_fetch_model,
             args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
         )
         timed_log(f"{name} got updated model")
Example #23
    def test_self_add(self):
        self_worker_info = rpc.get_worker_info()
        self_worker_name = "worker{}".format(self.rank)

        with self.assertRaisesRegex(
            RuntimeError, "does not support making RPC calls to self"
        ):
            rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))

        with self.assertRaisesRegex(
            RuntimeError, "does not support making RPC calls to self"
        ):
            rpc.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
Example #24
    def forward(self, *inputs: Tensor) -> rpc.RRef:  # type: ignore
        for i, input in enumerate(inputs):
            microbatch.check(input)

        # Divide a mini-batch into micro-batches.
        batches_list = [
            microbatch.scatter(input, self.chunks) for input in inputs
        ]

        # Create a DistributedPipelineRecord, one per partition, and make connections between them (i.e.
        # set list of consumers).
        pipeline_records: Dict[DistributedPipeline.Partition, rpc.RRef] = {}
        for partition in reversed(self.partitions):
            r_handler = partition.handler.remote()
            consumers = []
            # Identify consumers of the outputs of the partition
            for consumer in partition.nodes[-1].output_consumers:
                consumer_partition = next(p for p in self.partitions
                                          if p.nodes[0] is consumer.consumer)
                # Index of a consumer partition should be greater than index of the partition.
                assert consumer_partition in pipeline_records
                consumers.append(
                    DistributedPipelineRecord.DataConsumer(
                        pipeline_records[consumer_partition],
                        consumer.consumer_input_idx, consumer.output_idx))
            pipeline_records[partition] = r_handler.make_pipeline_record(
                consumers)
            # Let the pipeline handler for the partition start processing the pipeline record for that partition.
            this_result = r_handler.run_pipeline(pipeline_records[partition])
            # If this is the last partition, we expect the result of the model to be the output of this partition.
            if partition is self.partitions[-1]:
                result = this_result

        # Start feeding model input to the partitions that need them.
        for i, b in enumerate(zip(*batches_list)):
            for input_consumer in self.input_consumers:
                pipeline_record = pipeline_records[input_consumer.consumer]
                # TODO: Debug why we need this special handling
                if pipeline_record.owner().name == rpc.get_worker_info(
                ).name:  # type: ignore
                    pipeline_record.local_value().feed(
                        i, input_consumer.consumer_input_idx,
                        b[input_consumer.output_idx].value)
                else:
                    pipeline_record.rpc_async().feed(
                        i, input_consumer.consumer_input_idx,
                        b[input_consumer.output_idx].value)  # type: ignore

        return result
Example #25
 def __init__(self, world_size):
     self.ob_rrefs = []
     self.agent_rref = RRef(self)
     self.rewards = {}
     self.saved_log_probs = {}
     self.policy = Policy()
     self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
     self.eps = numpy.finfo(numpy.float32).eps.item()
     self.running_reward = 0
     self.reward_threshold = gym.make(ENV).spec.reward_threshold
     for ob_rank in range(1, world_size):
         ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
         self.ob_rrefs.append(remote(ob_info, Observer))
         self.rewards[ob_info.id] = []
         self.saved_log_probs[ob_info.id] = []
Example #26
 def __init__(self, world_size):
     self.ob_rrefs = []
     self.agent_rref = RRef(self)
     self.rewards = {}
     self.saved_log_probs = {}
     self.policy = Policy()
     self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
     self.eps = np.finfo(np.float32).eps.item()
     self.running_reward = 0
     self.reward_threshold = DummyEnv().reward_threshold
     for ob_rank in range(1, world_size):
         ob_info = rpc.get_worker_info(worker_name(ob_rank))
         self.ob_rrefs.append(remote(ob_info, Observer))
         self.rewards[ob_info.id] = []
         self.saved_log_probs[ob_info.id] = []
Example #27
    def __init__(self):
        r"""
        Inits agent class
        """
        self.id = rpc.get_worker_info().id
        self.running_reward = 0
        self.eps = 1e-7

        self.rewards = {}

        self.future_actions = torch.futures.Future()
        self.lock = threading.Lock()

        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []
Example #28
 def __init__(self, config, world_size):
     self.e = 0
     self.config = config.config_NeuralPlayer
     self.preprocessor = None
     self._init_dataset(self.config.config_Datasets)
     self._init_agent(self.config.config_Agent)
     self.agent_rref = RRef(self.agent)
     self.world_size = world_size  #nb of remote agents
     self.worker_rrefs = []
     self.data_gatherer = ScoreDataGatherer()
     for worker_rank in range(1, self.world_size):
         worker_info = rpc.get_worker_info(f"worker{worker_rank}")
         self.worker_rrefs.append(
             remote(worker_info,
                    CentralAgentWorker,
                    args=(config, worker_rank),
                    timeout=600))
Example #29
    def __init__(self, data_load, data_kws, h_dim, z_dim):
        super(GCN, self).__init__()

        # Load in the data before initializing params
        # Note: passing None as the start or end data_kw skips the
        # actual loading part, and just pulls the x-dim
        print("%s loading %s-%s" %
              (rpc.get_worker_info().name, str(
                  data_kws['start']), str(data_kws['end'])))

        self.data = data_load(data_kws.pop("jobs"), **data_kws)

        # Params
        self.c1 = GCNConv(self.data.x_dim, h_dim, add_self_loops=True)
        self.relu = nn.ReLU()
        self.c2 = GCNConv(h_dim, z_dim, add_self_loops=True)
        self.drop = nn.Dropout(0.25)
        self.tanh = nn.Tanh()
Example #30
    def __init__(self, world_size, log_interval, save_dir):
        env = create_env("SuperMarioBros-1-1-v0")

        self.logger = MetricLogger(save_dir)
        self.agent = MarioAgent(state_dim=(4, 84, 84),
                                action_dim=env.action_space.n,
                                save_dir=save_dir)
        self.learner_rref = RRef(self)
        self.actor_rrefs = []

        for actor_rank in range(1, world_size):
            actor_info = rpc.get_worker_info(ACTOR_NAME.format(actor_rank))
            self.actor_rrefs.append(
                remote(actor_info, Actor, args=(actor_rank, )))

        self.update_lock = threading.Lock()
        self.episode_lock = threading.Lock()
        self.episode = 0

        self.log_interval = log_interval