Example #1
    def __init__(self,
                 channels=12,
                 local_steps=8,
                 Tf=1.0,
                 serial_nn=None,
                 open_nn=None,
                 close_nn=None):
        super(SerialNet, self).__init__()

        if open_nn is None:
            self.open_nn = OpenFlatLayer(channels)
        else:
            self.open_nn = open_nn

        if serial_nn is None:
            step_layer = lambda: StepLayer(channels)
            parallel_nn = torchbraid.LayerParallel(MPI.COMM_SELF,
                                                   step_layer,
                                                   local_steps,
                                                   Tf,
                                                   max_levels=1,
                                                   max_iters=1)
            parallel_nn.setPrintLevel(0)

            self.serial_nn = parallel_nn.buildSequentialOnRoot()
        else:
            self.serial_nn = serial_nn

        if close_nn is None:
            self.close_nn = CloseLayer(channels)
        else:
            self.close_nn = close_nn
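The SerialNet above only defines its constructor; the forward pass simply chains the open, serial, and close stages. A minimal sketch of what that method could look like, assuming the OpenFlatLayer, StepLayer, and CloseLayer modules used above are defined elsewhere with compatible shapes:

    def forward(self, x):
        # apply the opening layer, the serialized step layers, then the closing layer
        x = self.open_nn(x)
        x = self.serial_nn(x)
        x = self.close_nn(x)
        return x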
Example #2
    def __init__(self,
                 channels=4,
                 local_steps=2,
                 Tf=1.0,
                 max_levels=1,
                 max_iters=1,
                 print_level=0):
        super(ParallelNet, self).__init__()

        self.rank = MPI.COMM_WORLD.Get_rank()

        self.channels = channels
        step_layer = lambda: StepLayer(channels)

        self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD,
                                                    step_layer,
                                                    local_steps,
                                                    Tf,
                                                    max_levels=max_levels,
                                                    max_iters=max_iters)
        self.parallel_nn.setPrintLevel(print_level)
        self.parallel_nn.setCFactor(4)
        self.o = self.parallel_nn.comp_op()  # get tool to build up composition neural networks

        # in this case, because OpenLayer/CloseLayer are classes, these return None on processors
        # other than rank 0 ... this might be too cute
        self.open_nn = self.o(OpenLayer, channels)
        self.close_nn = self.o(CloseLayer, channels)
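A matching forward method would route the open and close layers through the composition tool self.o, so that on ranks other than 0 (where those layers are None) only the LayerParallel sweep does work. A minimal sketch following the pattern used in the torchbraid examples; treat it as illustrative rather than the exact method from this repository:

    def forward(self, x):
        # composing through self.o means open/close only execute on rank 0
        x = self.o(self.open_nn, x)
        x = self.parallel_nn(x)
        x = self.o(self.close_nn, x)
        return x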
Example #3
    def __init__(self, Tstop=10.0, width=4, local_steps=10, max_levels=1, max_iters=1, fwd_max_iters=0, print_level=0, braid_print_level=0, cfactor=4, fine_fcf=False, skip_downcycle=True, fmg=False):
        super(ParallelNet, self).__init__()

        step_layer = lambda: StepLayer(width)

        # Create and store parallel net
        self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD, step_layer, local_steps, Tstop, max_levels=max_levels, max_iters=max_iters)

        # Set options
        if fwd_max_iters > 0:
            # print("FWD_max_iter = ", fwd_max_iters)
            self.parallel_nn.setFwdMaxIters(fwd_max_iters)
        self.parallel_nn.setPrintLevel(print_level,True)
        self.parallel_nn.setPrintLevel(braid_print_level,False)
        self.parallel_nn.setCFactor(cfactor)
        self.parallel_nn.setSkipDowncycle(skip_downcycle)

        if fmg:
            self.parallel_nn.setFMG()
        self.parallel_nn.setNumRelax(1)         # FCF relaxation elsewhere (coarse levels)
        if not fine_fcf:
            self.parallel_nn.setNumRelax(0,level=0) # F-relaxation on the fine grid
        else:
            self.parallel_nn.setNumRelax(1,level=0) # FCF-relaxation on the fine grid

        # this object ensures that only the LayerParallel code runs on ranks!=0
        compose = self.compose = self.parallel_nn.comp_op()

        # by passing these through 'compose' (meaning composition, e.g. OpenLayer o width),
        # on processors other than rank 0 these will be None (there are no parameters to train there)
        self.openlayer = compose(OpenLayer,width)
        self.closinglayer = compose(ClosingLayer)
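With max_levels > 1 the layer-parallel solver actually coarsens in time, so the relaxation options above start to matter. A hypothetical instantiation (the keyword values are illustrative, not taken from this repository), launched under MPI with something like mpirun -n 4 python script.py:

    # illustrative settings only; all names come from the constructor above
    net = ParallelNet(Tstop=10.0, width=4, local_steps=10,
                      max_levels=3, max_iters=2, cfactor=4,
                      fine_fcf=False, skip_downcycle=True, fmg=False)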
Example #4
def main():
  comm = MPI.COMM_WORLD

  local_steps = 3
  levels = 2

  step_layer = lambda: StepLayer(channels=2)

  parallel_nn = torchbraid.LayerParallel(comm,step_layer,local_steps,Tf=comm.Get_size()*local_steps,max_levels=levels,max_iters=1)
  parallel_nn.setPrintLevel(0)

  fwd_lower,fwd_upper = parallel_nn.fwd_app.getStepBounds()
  bwd_lower,bwd_upper = parallel_nn.bwd_app.getStepBounds()


  parallel_nn.train()

  x = torch.rand(10,2,9,9)
  root_print(comm.Get_rank(),'FORWARD')
  root_print(comm.Get_rank(),'======================')
  y = parallel_nn(x)

  print('  %d) fwd lower,upper = ' % comm.Get_rank(),fwd_lower,fwd_upper)
  sys.stdout.flush()
  comm.barrier()

  root_print(comm.Get_rank(),'\n\nBACKWARD')
  root_print(comm.Get_rank(),'======================')
  y.backward(torch.ones(y.shape))

  print('  %d) bwd lower,upper = ' % comm.Get_rank(),bwd_lower,bwd_upper)
  sys.stdout.flush()
  comm.barrier()
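The snippet assumes a root_print helper that is not shown here. A minimal version, sketched under the assumption that it only needs to suppress duplicate output from non-root ranks:

def root_print(rank, s):
  # print only from the root rank so each message appears once
  if rank == 0:
    print(s)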
Example #5
    def __init__(self,
                 out_channels=12,
                 local_steps=8,
                 Tf=1.0,
                 max_levels=1,
                 max_iters=1,
                 print_level=0):
        super(ParallelNet, self).__init__()

        step_layer = lambda: StepLayer(out_channels)

        # print("Rank in parallel net ", MPI.COMM_WORLD.Get_rank())
        self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD,
                                                    step_layer,
                                                    local_steps,
                                                    Tf,
                                                    max_levels=max_levels,
                                                    max_iters=max_iters)
        self.parallel_nn.setPrintLevel(print_level)
        self.parallel_nn.setCFactor(4)
        self.parallel_nn.setSkipDowncycle(True)
        self.parallel_nn.setNumRelax(1)  # FCF relaxation elsewhere (coarse levels)
        self.parallel_nn.setNumRelax(0,
                                     level=0)  # F-Relaxation on the fine grid

        # this object ensures that only the LayerParallel code runs on ranks!=0
        compose = self.compose = self.parallel_nn.comp_op()

        # by passing these through 'compose' (meaning composition, e.g. OpenLayer o out_channels),
        # on processors other than rank 0 these will be None (there are no parameters to train there)
        self.open_nn = compose(OpenLayer, out_channels)
        self.close_nn = compose(CloseLayer, out_channels)
Example #6
    def __init__(self, channels=12, local_steps=8, Tf=1.0):
        super(SerialNet, self).__init__()

        step_layer = lambda: StepLayer(channels)

        self.open_nn = OpenLayer(channels)
        self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD,
                                                    step_layer,
                                                    local_steps,
                                                    Tf,
                                                    max_levels=1,
                                                    max_iters=1)
        self.parallel_nn.setPrintLevel(0)

        self.serial_nn = self.parallel_nn.buildSequentialOnRoot()
        self.close_nn = CloseLayer(channels)
Example #7
    def __init__(self,
                 channels=12,
                 local_steps=8,
                 Tf=1.0,
                 max_levels=1,
                 max_iters=1,
                 print_level=0,
                 cfactor=4,
                 fine_fcf=False,
                 skip_downcycle=True,
                 fmg=False):
        super(ParallelNet, self).__init__()

        step_layer = lambda: StepLayer(channels)

        self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD,
                                                    step_layer,
                                                    local_steps,
                                                    Tf,
                                                    max_levels=max_levels,
                                                    max_iters=max_iters)
        self.parallel_nn.setPrintLevel(print_level)
        self.parallel_nn.setCFactor(cfactor)
        self.parallel_nn.setSkipDowncycle(skip_downcycle)

        if fmg:
            self.parallel_nn.setFMG()
        self.parallel_nn.setNumRelax(1)  # FCF relaxation elsewhere (coarse levels)
        if not fine_fcf:
            self.parallel_nn.setNumRelax(
                0, level=0)  # F-relaxation on the fine grid
        else:
            self.parallel_nn.setNumRelax(
                1, level=0)  # FCF-relaxation on the fine grid

        # this object ensures that only the LayerParallel code runs on ranks!=0
        compose = self.compose = self.parallel_nn.comp_op()

        # by passing these through 'compose' (meaning composition, e.g. OpenLayer o channels),
        # on processors other than rank 0 these will be None (there are no parameters to train there)
        self.open_nn = compose(OpenLayer, channels)
        self.close_nn = compose(CloseLayer, channels)
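Once a forward method like the one sketched after Example #2 is in place, training looks like ordinary PyTorch: every rank runs the same loop, and torchbraid coordinates the time-parallel forward and backward sweeps. A hypothetical skeleton (train_loader, the label tensors, and the torch/torchbraid imports are assumed, not defined here); composing the loss through net.compose keeps its evaluation on rank 0:

    net = ParallelNet(channels=12, local_steps=8, Tf=1.0, max_levels=3, max_iters=2)
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    for x, labels in train_loader:                   # assumed data loader
        optimizer.zero_grad()
        out = net(x)                                 # parallel-in-time forward sweep
        loss = net.compose(criterion, out, labels)   # evaluate the loss on rank 0 only
        loss.backward()                              # parallel-in-time backward sweep
        optimizer.step()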
Example #8
    def backForwardProp(self,
                        dim,
                        basic_block,
                        x0,
                        w0,
                        max_levels,
                        max_iters,
                        test_tol,
                        prefix,
                        ref_pair=None):
        Tf = 2.0
        num_steps = 4

        # this is the torchbraid class being tested
        #######################################
        m = torchbraid.LayerParallel(MPI.COMM_WORLD,
                                     basic_block,
                                     num_steps,
                                     Tf,
                                     max_levels=max_levels,
                                     max_iters=max_iters,
                                     spatial_ref_pair=ref_pair)
        m.setPrintLevel(0)

        w0 = m.copyVectorFromRoot(w0)

        # this is the reference torch "solution"
        #######################################
        dt = Tf / num_steps
        f = m.buildSequentialOnRoot()

        # run forward/backward propagation
        lr = 1e-3

        # propagation with torchbraid
        #######################################
        xm = x0.clone()
        xm.requires_grad = True

        wm = m(xm)
        wm.backward(w0)
        wm0 = m.getFinalOnRoot(wm)

        with torch.no_grad():
            for p in m.parameters():
                p -= p.grad * lr
        m.zero_grad()

        if xm.grad is not None:
            xm.grad.zero_()
        wm = m(xm)
        wm.backward(w0)

        m_param_grad = self.copyParameterGradToRoot(m)
        wm = m.getFinalOnRoot(wm)

        # print time results
        timer_str = m.getTimersString()

        # check some values
        if m.getMPIComm().Get_rank() == 0:

            # this is too much to print out every test run, but I'd like to make sure the
            # code is executed
            self.assertTrue(len(timer_str) > 0)

            compute_grad = True

            # propagation with torch
            #######################################
            xf = x0.clone()
            xf.requires_grad = compute_grad

            wf = f(xf)
            wf.backward(w0)

            with torch.no_grad():
                for p in f.parameters():
                    p -= p.grad * lr
                f.zero_grad()

            xf.grad.zero_()
            wf = f(xf)
            wf.backward(w0)

            # compare the solutions
            #######################################

            self.assertTrue(torch.norm(wm) > 0.0)
            self.assertTrue(torch.norm(wf) > 0.0)

            print('\n')
            print('%s: fwd error = %.6e (%.6e, %.6e)' %
                  (prefix, torch.norm(wm - wf) / torch.norm(wf),
                   torch.norm(wf), torch.norm(wm)))
            print('%s: grad error = %.6e (%.6e, %.6e)' %
                  (prefix, torch.norm(xm.grad - xf.grad) / torch.norm(xf.grad),
                   torch.norm(xf.grad), torch.norm(xm.grad)))

            self.assertTrue(torch.norm(wm - wf) / torch.norm(wf) <= test_tol)
            self.assertTrue((torch.norm(xm.grad - xf.grad) /
                             torch.norm(xf.grad)) <= test_tol)

            param_errors = []
            for pf, pm_grad in zip(list(f.parameters()), m_param_grad):
                self.assertTrue(pm_grad is not None)

                # accumulate parameter errors for testing purposes
                param_errors += [(torch.norm(pf.grad - pm_grad) /
                                  (1e-15 + torch.norm(pf.grad))).item()]

                # check the error conditions
                #print('%s: p_grad error = %.6e (%.6e %.6e)' % (prefix,torch.norm(pf.grad-pm_grad),torch.norm(pf.grad),torch.norm(pm_grad)))
                #sys.stdout.flush()
                self.assertTrue(torch.norm(pf.grad - pm_grad) <= test_tol)

            if len(param_errors) > 0:
                print('%s: p grad error (mean,stddev) = %.6e, %.6e' %
                      (prefix, stats.mean(param_errors),
                       stats.stdev(param_errors)))

            print('\n')
Example #9
    layers = [basic_block() for i in range(num_steps)]
    serial_nn = torch.nn.Sequential(*layers)

    t0_fwd_parallel = time.time()
    y_fwd_serial = serial_nn(x)
    tf_fwd_parallel = time.time()

    t0_bwd_parallel = time.time()
    y_fwd_serial.backward(w)
    tf_bwd_parallel = time.time()
else:
    root_print(my_rank, 'Running TorchBraid: %d' % comm.Get_size())
    # build the parallel neural network
    parallel_nn = torchbraid.LayerParallel(comm,
                                           basic_block,
                                           local_num_steps,
                                           Tf,
                                           max_levels=max_levels,
                                           max_iters=max_iters)
    parallel_nn.setPrintLevel(print_level)
    parallel_nn.setSkipDowncycle(True)
    parallel_nn.setCFactor(cfactor)
    parallel_nn.setNumRelax(nrelax)
    parallel_nn.setNumRelax(0, level=0)  # F-Relaxation on the fine grid

    w0 = parallel_nn.copyVectorFromRoot(w)
    x0 = x.clone()
    x0.requires_grad = True

    t0_fwd_parallel = time.time()
    y_fwd_parallel = parallel_nn(x0)
    comm.barrier()