Esempio n. 1
0
 def dumpkvs(self):
     """Flush the accumulated key/value pairs to every KVWriter output format.

     Without a communicator (``self.comm is None``) the locally accumulated
     ``name2val`` dict is written directly. With a communicator, each value
     is paired with its count from ``name2cnt`` (1 when absent) and combined
     across ranks via ``mpi_util.mpi_weighted_mean``. A ``'dummy'`` entry is
     added on non-zero ranks so writers don't warn about an empty dict.

     Accumulated values and counts are cleared afterwards.

     Returns:
         A copy of the dict handed to the writers (kept for unit testing).
     """
     comm = self.comm
     if comm is not None:
         from baselines.common import mpi_util
         weighted_pairs = {}
         for key, value in self.name2val.items():
             weighted_pairs[key] = (value, self.name2cnt.get(key, 1))
         kvs = mpi_util.mpi_weighted_mean(comm, weighted_pairs)
         if comm.rank != 0:
             # Placeholder so we don't get a warning about an empty dict.
             kvs['dummy'] = 1
     else:
         kvs = self.name2val
     # Snapshot before clearing: when comm is None, kvs aliases name2val.
     snapshot = kvs.copy()
     for sink in self.output_formats:
         if not isinstance(sink, KVWriter):
             continue
         sink.writekvs(kvs)
     self.name2val.clear()
     self.name2cnt.clear()
     return snapshot
Esempio n. 2
0
def test_mpi_weighted_mean():
    """Check count-weighted averaging of values across two MPI ranks.

    Each rank contributes ``{name: (value, count)}`` pairs. The test first
    calls ``mpi_util.mpi_weighted_mean`` directly, then feeds the same data
    through ``logger.logkv_mean`` (one call per unit of count) followed by
    ``logger.dumpkvs``. On rank 0, both results must equal the expected
    weighted means; other ranks do not assert.

    Raises:
        NotImplementedError: if run on a rank other than 0 or 1, since the
            test data is only defined for two processes.
    """
    comm = MPI.COMM_WORLD
    with logger.scoped_configure(comm=comm):
        if comm.rank == 0:
            name2valcount = {'a': (10, 2), 'b': (20, 3)}
        elif comm.rank == 1:
            name2valcount = {'a': (19, 1), 'c': (42, 3)}
        else:
            raise NotImplementedError(
                'test_mpi_weighted_mean only supports ranks 0 and 1, '
                'got rank {}'.format(comm.rank))
        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
        # 'a' is reported by both ranks, so its expected value is the
        # count-weighted mean (10 * 2 + 19 * 1) / (2 + 1); 'b' and 'c' each
        # come from a single rank and keep their value.
        correctval = {'a': (10 * 2 + 19) / 3.0, 'b': 20, 'c': 42}
        if comm.rank == 0:
            assert d == correctval, '{} != {}'.format(d, correctval)

        # logkv_mean is called `count` times per key, presumably so the
        # logger's accumulated count matches the weights used above.
        for name, (val, count) in name2valcount.items():
            for _ in range(count):
                logger.logkv_mean(name, val)
        d2 = logger.dumpkvs()
        if comm.rank == 0:
            assert d2 == correctval, '{} != {}'.format(d2, correctval)
Esempio n. 3
0
def test_mpi_weighted_mean():
    """Check count-weighted averaging of values across two MPI ranks.

    Each rank contributes ``{name: (value, count)}`` pairs. The test first
    calls ``mpi_util.mpi_weighted_mean`` directly, then feeds the same data
    through ``logger.logkv_mean`` (one call per unit of count) followed by
    ``logger.dumpkvs``. On rank 0, both results must equal the expected
    weighted means; other ranks do not assert.

    Raises:
        NotImplementedError: if run on a rank other than 0 or 1, since the
            test data is only defined for two processes.
    """
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    with logger.scoped_configure(comm=comm):
        if comm.rank == 0:
            name2valcount = {'a' : (10, 2), 'b' : (20,3)}
        elif comm.rank == 1:
            name2valcount = {'a' : (19, 1), 'c' : (42,3)}
        else:
            raise NotImplementedError(
                'test_mpi_weighted_mean only supports ranks 0 and 1, '
                'got rank {}'.format(comm.rank))

        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
        # 'a' is reported by both ranks, so its expected value is the
        # count-weighted mean (10 * 2 + 19 * 1) / (2 + 1); 'b' and 'c' each
        # come from a single rank and keep their value.
        correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
        if comm.rank == 0:
            assert d == correctval, '{} != {}'.format(d, correctval)

        # logkv_mean is called `count` times per key, presumably so the
        # logger's accumulated count matches the weights used above.
        for name, (val, count) in name2valcount.items():
            for _ in range(count):
                logger.logkv_mean(name, val)
        d2 = logger.dumpkvs()
        if comm.rank == 0:
            assert d2 == correctval, '{} != {}'.format(d2, correctval)