def __init__(self, channels=12, local_steps=8, Tf=1.0, serial_nn=None, open_nn=None, close_nn=None):
    super(SerialNet, self).__init__()

    if open_nn is None:
        self.open_nn = OpenFlatLayer(channels)
    else:
        self.open_nn = open_nn

    if serial_nn is None:
        step_layer = lambda: StepLayer(channels)
        parallel_nn = torchbraid.LayerParallel(MPI.COMM_SELF, step_layer, local_steps, Tf,
                                               max_levels=1, max_iters=1)
        parallel_nn.setPrintLevel(0)
        self.serial_nn = parallel_nn.buildSequentialOnRoot()
    else:
        self.serial_nn = serial_nn

    if close_nn is None:
        self.close_nn = CloseLayer(channels)
    else:
        self.close_nn = close_nn
def __init__(self, channels=4, local_steps=2, Tf=1.0, max_levels=1, max_iters=1, print_level=0):
    super(ParallelNet, self).__init__()

    self.rank = MPI.COMM_WORLD.Get_rank()
    self.channels = channels

    step_layer = lambda: StepLayer(channels)

    self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD, step_layer, local_steps, Tf,
                                                max_levels=max_levels, max_iters=max_iters)
    self.parallel_nn.setPrintLevel(print_level)
    self.parallel_nn.setCFactor(4)

    # get the composition operator used to build up composite neural networks;
    # in this case, because OpenLayer/CloseLayer are classes, these return None
    # on processors away from rank==0...this might be too cute
    self.o = self.parallel_nn.comp_op()

    self.open_nn = self.o(OpenLayer, channels)
    self.close_nn = self.o(CloseLayer, channels)
def __init__(self, Tstop=10.0, width=4, local_steps=10, max_levels=1, max_iters=1, fwd_max_iters=0,
             print_level=0, braid_print_level=0, cfactor=4, fine_fcf=False, skip_downcycle=True, fmg=False):
    super(ParallelNet, self).__init__()

    step_layer = lambda: StepLayer(width)

    # create and store the layer-parallel network
    self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD, step_layer, local_steps, Tstop,
                                                max_levels=max_levels, max_iters=max_iters)

    # set solver options
    if fwd_max_iters > 0:
        # print("FWD_max_iter = ", fwd_max_iters)
        self.parallel_nn.setFwdMaxIters(fwd_max_iters)

    self.parallel_nn.setPrintLevel(print_level, True)
    self.parallel_nn.setPrintLevel(braid_print_level, False)
    self.parallel_nn.setCFactor(cfactor)
    self.parallel_nn.setSkipDowncycle(skip_downcycle)

    if fmg:
        self.parallel_nn.setFMG()

    self.parallel_nn.setNumRelax(1)  # FCF relaxation everywhere else
    if not fine_fcf:
        self.parallel_nn.setNumRelax(0, level=0)  # F-relaxation on the fine grid
    else:
        self.parallel_nn.setNumRelax(1, level=0)  # FCF-relaxation on the fine grid

    # this object ensures that only the LayerParallel code runs on ranks != 0
    compose = self.compose = self.parallel_nn.comp_op()

    # by passing these through 'compose' (meaning composition, e.g. OpenLayer o width),
    # on processors not equal to 0 they will be None (there are no parameters to train there)
    self.openlayer = compose(OpenLayer, width)
    self.closinglayer = compose(ClosingLayer)
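# A minimal sketch (assumed, not shown in the snippet above) of how the forward pass of
# this ParallelNet is typically assembled: the open/closing layers are routed through
# self.compose so that ranks != 0, where those members are None, skip them while still
# participating in the layer-parallel sweep.
def forward(self, x):
    # on ranks != 0, self.openlayer is None; compose applies it only where it exists
    x = self.compose(self.openlayer, x)
    # layer-parallel propagation through the step layers on every rank
    x = self.parallel_nn(x)
    # only rank 0 owns the closing layer's parameters
    x = self.compose(self.closinglayer, x)
    return x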
def main():
    comm = MPI.COMM_WORLD

    local_steps = 3
    levels = 2

    step_layer = lambda: StepLayer(channels=2)
    parallel_nn = torchbraid.LayerParallel(comm, step_layer, local_steps,
                                           Tf=comm.Get_size() * local_steps,
                                           max_levels=levels, max_iters=1)
    parallel_nn.setPrintLevel(0)

    fwd_lower, fwd_upper = parallel_nn.fwd_app.getStepBounds()
    bwd_lower, bwd_upper = parallel_nn.bwd_app.getStepBounds()

    parallel_nn.train()

    x = torch.rand(10, 2, 9, 9)

    root_print(comm.Get_rank(), 'FORWARD')
    root_print(comm.Get_rank(), '======================')
    y = parallel_nn(x)

    print(' %d) fwd lower,upper = ' % comm.Get_rank(), fwd_lower, fwd_upper)
    sys.stdout.flush()
    comm.barrier()

    root_print(comm.Get_rank(), '\n\nBACKWARD')
    root_print(comm.Get_rank(), '======================')
    y.backward(torch.ones(y.shape))

    print(' %d) bwd lower,upper = ' % comm.Get_rank(), bwd_lower, bwd_upper)
    sys.stdout.flush()
    comm.barrier()
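# A minimal usage sketch, not part of the original snippet (assumptions: the driver above
# lives in a file such as step_bounds_driver.py -- a hypothetical name -- and mpi4py and
# torchbraid are installed). Launched under MPI, e.g.
#
#   mpirun -n 3 python step_bounds_driver.py
#
# each of the 3 ranks owns local_steps=3 time steps (9 total) and prints its own
# forward/backward step bounds.
if __name__ == '__main__':
    main()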
def __init__(self, out_channels=12, local_steps=8, Tf=1.0, max_levels=1, max_iters=1, print_level=0):
    super(ParallelNet, self).__init__()

    step_layer = lambda: StepLayer(out_channels)
    # print("Rank in parallel net ", MPI.COMM_WORLD.Get_rank())

    self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD, step_layer, local_steps, Tf,
                                                max_levels=max_levels, max_iters=max_iters)
    self.parallel_nn.setPrintLevel(print_level)
    self.parallel_nn.setCFactor(4)
    self.parallel_nn.setSkipDowncycle(True)

    self.parallel_nn.setNumRelax(1)           # FCF relaxation everywhere else
    self.parallel_nn.setNumRelax(0, level=0)  # F-relaxation on the fine grid

    # this object ensures that only the LayerParallel code runs on ranks != 0
    compose = self.compose = self.parallel_nn.comp_op()

    # by passing these through 'compose' (meaning composition, e.g. OpenLayer o out_channels),
    # on processors not equal to 0 they will be None (there are no parameters to train there)
    self.open_nn = compose(OpenLayer, out_channels)
    self.close_nn = compose(CloseLayer, out_channels)
def __init__(self, channels=12, local_steps=8, Tf=1.0):
    super(SerialNet, self).__init__()

    step_layer = lambda: StepLayer(channels)

    self.open_nn = OpenLayer(channels)

    # build a single-level LayerParallel only to unroll it into a serial
    # torch.nn.Sequential on the root rank
    self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD, step_layer, local_steps, Tf,
                                                max_levels=1, max_iters=1)
    self.parallel_nn.setPrintLevel(0)
    self.serial_nn = self.parallel_nn.buildSequentialOnRoot()

    self.close_nn = CloseLayer(channels)
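# A minimal sketch (assumed, not shown above) of the matching forward pass for this
# SerialNet: the unrolled sequential network is applied between the opening and
# closing layers on the root rank.
def forward(self, x):
    x = self.open_nn(x)
    x = self.serial_nn(x)
    x = self.close_nn(x)
    return x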
def __init__(self, channels=12, local_steps=8, Tf=1.0, max_levels=1, max_iters=1, print_level=0,
             cfactor=4, fine_fcf=False, skip_downcycle=True, fmg=False):
    super(ParallelNet, self).__init__()

    step_layer = lambda: StepLayer(channels)

    self.parallel_nn = torchbraid.LayerParallel(MPI.COMM_WORLD, step_layer, local_steps, Tf,
                                                max_levels=max_levels, max_iters=max_iters)
    self.parallel_nn.setPrintLevel(print_level)
    self.parallel_nn.setCFactor(cfactor)
    self.parallel_nn.setSkipDowncycle(skip_downcycle)
    if fmg:
        self.parallel_nn.setFMG()

    self.parallel_nn.setNumRelax(1)  # FCF relaxation everywhere else
    if not fine_fcf:
        self.parallel_nn.setNumRelax(0, level=0)  # F-relaxation on the fine grid
    else:
        self.parallel_nn.setNumRelax(1, level=0)  # FCF-relaxation on the fine grid

    # this object ensures that only the LayerParallel code runs on ranks != 0
    compose = self.compose = self.parallel_nn.comp_op()

    # by passing these through 'compose' (meaning composition, e.g. OpenLayer o channels),
    # on processors not equal to 0 they will be None (there are no parameters to train there)
    self.open_nn = compose(OpenLayer, channels)
    self.close_nn = compose(CloseLayer, channels)
def backForwardProp(self, dim, basic_block, x0, w0, max_levels, max_iters, test_tol, prefix, ref_pair=None):
    Tf = 2.0
    num_steps = 4

    # this is the torchbraid class being tested
    #######################################
    m = torchbraid.LayerParallel(MPI.COMM_WORLD, basic_block, num_steps, Tf,
                                 max_levels=max_levels, max_iters=max_iters,
                                 spatial_ref_pair=ref_pair)
    m.setPrintLevel(0)

    w0 = m.copyVectorFromRoot(w0)

    # this is the reference torch "solution"
    #######################################
    dt = Tf / num_steps
    f = m.buildSequentialOnRoot()

    # run forward/backward propagation
    lr = 1e-3

    # propagation with torchbraid
    #######################################
    xm = x0.clone()
    xm.requires_grad = True

    wm = m(xm)
    wm.backward(w0)
    wm0 = m.getFinalOnRoot(wm)

    with torch.no_grad():
        for p in m.parameters():
            p -= p.grad * lr
    m.zero_grad()
    if xm.grad is not None:
        xm.grad.zero_()

    wm = m(xm)
    wm.backward(w0)

    m_param_grad = self.copyParameterGradToRoot(m)
    wm = m.getFinalOnRoot(wm)

    # print time results
    timer_str = m.getTimersString()

    # check some values
    if m.getMPIComm().Get_rank() == 0:
        # this is too much to print out every test run, but make sure the
        # code is executed
        self.assertTrue(len(timer_str) > 0)

        compute_grad = True

        # propagation with torch
        #######################################
        xf = x0.clone()
        xf.requires_grad = compute_grad

        wf = f(xf)
        wf.backward(w0)

        with torch.no_grad():
            for p in f.parameters():
                p -= p.grad * lr
        f.zero_grad()
        xf.grad.zero_()

        wf = f(xf)
        wf.backward(w0)

        # compare the solutions
        #######################################
        self.assertTrue(torch.norm(wm) > 0.0)
        self.assertTrue(torch.norm(wf) > 0.0)

        print('\n')
        print('%s: fwd error = %.6e (%.6e, %.6e)' % (prefix, torch.norm(wm - wf) / torch.norm(wf),
                                                     torch.norm(wf), torch.norm(wm)))
        print('%s: grad error = %.6e (%.6e, %.6e)' % (prefix, torch.norm(xm.grad - xf.grad) / torch.norm(xf.grad),
                                                      torch.norm(xf.grad), torch.norm(xm.grad)))

        self.assertTrue(torch.norm(wm - wf) / torch.norm(wf) <= test_tol)
        self.assertTrue((torch.norm(xm.grad - xf.grad) / torch.norm(xf.grad)) <= test_tol)

        param_errors = []
        for pf, pm_grad in zip(list(f.parameters()), m_param_grad):
            self.assertTrue(pm_grad is not None)

            # accumulate parameter errors for testing purposes
            param_errors += [(torch.norm(pf.grad - pm_grad) / (1e-15 + torch.norm(pf.grad))).item()]

            # check the error conditions
            #print('%s: p_grad error = %.6e (%.6e %.6e)' % (prefix,torch.norm(pf.grad-pm_grad),torch.norm(pf.grad),torch.norm(pm_grad)))
            #sys.stdout.flush()
            self.assertTrue(torch.norm(pf.grad - pm_grad) <= test_tol)

        if len(param_errors) > 0:
            print('%s: p grad error (mean,stddev) = %.6e, %.6e' % (prefix, stats.mean(param_errors),
                                                                   stats.stdev(param_errors)))

        print('\n')
    layers = [basic_block() for i in range(num_steps)]
    serial_nn = torch.nn.Sequential(*layers)

    t0_fwd_parallel = time.time()
    y_fwd_serial = serial_nn(x)
    tf_fwd_parallel = time.time()

    t0_bwd_parallel = time.time()
    y_fwd_serial.backward(w)
    tf_bwd_parallel = time.time()
else:
    root_print(my_rank, 'Running TorchBraid: %d' % comm.Get_size())

    # build the parallel neural network
    parallel_nn = torchbraid.LayerParallel(comm, basic_block, local_num_steps, Tf,
                                           max_levels=max_levels, max_iters=max_iters)
    parallel_nn.setPrintLevel(print_level)
    parallel_nn.setSkipDowncycle(True)
    parallel_nn.setCFactor(cfactor)
    parallel_nn.setNumRelax(nrelax)
    parallel_nn.setNumRelax(0, level=0)  # F-relaxation on the fine grid

    w0 = parallel_nn.copyVectorFromRoot(w)

    x0 = x.clone()
    x0.requires_grad = True

    t0_fwd_parallel = time.time()
    y_fwd_parallel = parallel_nn(x0)
    comm.barrier()