def dumpkvs(self):
    """Write the accumulated key/value pairs to every KVWriter, then reset.

    When ``self.comm`` is None, the local accumulator ``self.name2val`` itself
    is handed to the writers. Otherwise every value is paired with its count
    from ``self.name2cnt`` (1 when no count was recorded) and the pairs are
    combined across the communicator with ``mpi_util.mpi_weighted_mean``.
    Ranks other than 0 then add a placeholder ``'dummy'`` entry.

    Both ``self.name2val`` and ``self.name2cnt`` are cleared afterwards.

    Returns:
        A copy of the dict that was passed to the writers, taken before the
        accumulators are cleared (used by unit tests).
    """
    if self.comm is not None:
        from baselines.common import mpi_util
        weighted_pairs = {
            key: (value, self.name2cnt.get(key, 1))
            for key, value in self.name2val.items()
        }
        merged = mpi_util.mpi_weighted_mean(self.comm, weighted_pairs)
        if self.comm.rank != 0:
            # Placeholder so we don't get a warning about an empty dict.
            merged['dummy'] = 1
    else:
        # No communicator: the writers receive the local accumulator object.
        merged = self.name2val

    # Use .copy() (not dict(...)) so the snapshot keeps the source mapping
    # type; it must be taken before the accumulators are cleared below.
    snapshot = merged.copy()
    kv_writers = (fmt for fmt in self.output_formats if isinstance(fmt, KVWriter))
    for writer in kv_writers:
        writer.writekvs(merged)

    self.name2val.clear()
    self.name2cnt.clear()
    return snapshot
def test_mpi_weighted_mean():
    """Check count-weighted averaging of logged values across two MPI ranks.

    Rank 0 contributes ``a=10`` (count 2) and ``b=20`` (count 3). Rank 1
    contributes ``a=19`` (count 1) and ``c=42`` (count 3). On rank 0, two
    results must equal ``{'a': 13.0, 'b': 20, 'c': 42}``:

    - ``mpi_util.mpi_weighted_mean``;
    - ``logger.dumpkvs()``, after the same samples are fed through
      ``logger.logkv_mean`` under ``logger.scoped_configure(comm=comm)``.

    Raises:
        NotImplementedError: if the world communicator does not have exactly
            two ranks.
    """
    # NOTE(review): relies on a module-level ``MPI`` name (presumably
    # ``from mpi4py import MPI``) that is not visible here -- confirm.
    # NOTE(review): this file appears to define test_mpi_weighted_mean more
    # than once; at module scope only the last definition survives, so pytest
    # would collect just that one -- confirm and remove the duplicate.
    comm = MPI.COMM_WORLD
    # Check the world size on every rank before any cross-rank call.
    # mpi_weighted_mean presumably communicates across all ranks, so a rank
    # that raised part-way through could leave its peers waiting.
    if comm.size != 2:
        raise NotImplementedError(
            'test_mpi_weighted_mean needs exactly 2 MPI ranks, got {}'.format(comm.size))
    with logger.scoped_configure(comm=comm):
        # Each rank's (value, count) samples, selected by rank.
        name2valcount = {
            0: {'a': (10, 2), 'b': (20, 3)},
            1: {'a': (19, 1), 'c': (42, 3)},
        }[comm.rank]
        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
        # 'a' is on both ranks: (10*2 + 19*1) / (2 + 1) = 13.0.
        # 'b' and 'c' are on a single rank and keep that rank's value.
        correctval = {'a': (10 * 2 + 19) / 3.0, 'b': 20, 'c': 42}
        if comm.rank == 0:
            assert d == correctval, '{} != {}'.format(d, correctval)

        # Feed each value to the logger `count` times, then dump.
        for name, (val, count) in name2valcount.items():
            for _ in range(count):
                logger.logkv_mean(name, val)
        d2 = logger.dumpkvs()
        if comm.rank == 0:
            assert d2 == correctval, '{} != {}'.format(d2, correctval)
def test_mpi_weighted_mean():
    """Check count-weighted averaging of logged values across two MPI ranks.

    Rank 0 contributes ``a=10`` (count 2) and ``b=20`` (count 3). Rank 1
    contributes ``a=19`` (count 1) and ``c=42`` (count 3). On rank 0, two
    results must equal ``{'a': 13.0, 'b': 20, 'c': 42}``:

    - ``mpi_util.mpi_weighted_mean``;
    - ``logger.dumpkvs()``, after the same samples are fed through
      ``logger.logkv_mean`` under ``logger.scoped_configure(comm=comm)``.

    Raises:
        NotImplementedError: if the world communicator does not have exactly
            two ranks.
    """
    # Imported locally so the module can be imported without mpi4py.
    from mpi4py import MPI
    # NOTE(review): this file appears to define test_mpi_weighted_mean more
    # than once; at module scope only the last definition survives, so pytest
    # would collect just that one -- confirm and remove the duplicate.
    comm = MPI.COMM_WORLD
    # Check the world size on every rank before any cross-rank call.
    # mpi_weighted_mean presumably communicates across all ranks, so a rank
    # that raised part-way through could leave its peers waiting.
    if comm.size != 2:
        raise NotImplementedError(
            'test_mpi_weighted_mean needs exactly 2 MPI ranks, got {}'.format(comm.size))
    with logger.scoped_configure(comm=comm):
        # Each rank's (value, count) samples, selected by rank.
        name2valcount = {
            0: {'a': (10, 2), 'b': (20, 3)},
            1: {'a': (19, 1), 'c': (42, 3)},
        }[comm.rank]
        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
        # 'a' is on both ranks: (10*2 + 19*1) / (2 + 1) = 13.0.
        # 'b' and 'c' are on a single rank and keep that rank's value.
        correctval = {'a': (10 * 2 + 19) / 3.0, 'b': 20, 'c': 42}
        if comm.rank == 0:
            assert d == correctval, '{} != {}'.format(d, correctval)

        # Feed each value to the logger `count` times, then dump.
        for name, (val, count) in name2valcount.items():
            for _ in range(count):
                logger.logkv_mean(name, val)
        d2 = logger.dumpkvs()
        if comm.rank == 0:
            assert d2 == correctval, '{} != {}'.format(d2, correctval)