# assumed module-level imports for these tests:
#   import sys, time
#   import torch
#   import torch.nn as nn
#   import torchbraid
#   from mpi4py import MPI

def backwardProp(
    self,
    max_levels=1,  # for testing parallel rnn
    max_iters=1,  # for testing parallel rnn
    sequence_length=6,  # total number of time steps for each sequence
    input_size=28,  # input size for each time step in a sequence
    hidden_size=20,
    num_layers=1,
    batch_size=1,
    tol=1e-6,
    applications=1):
  comm = MPI.COMM_WORLD
  num_procs = comm.Get_size()
  my_rank = comm.Get_rank()

  Tf = float(sequence_length)
  channels = 1
  images = 10
  image_size = 28
  print_level = 0
  nrelax = 1
  cfactor = 2
  num_batch = int(images / batch_size)

  # wait for serial processor
  comm.barrier()

  # preprocess and distribute input data
  x_block = preprocess_distribute_input_data_parallel(my_rank, num_procs,
                                                      num_batch, batch_size,
                                                      channels,
                                                      sequence_length,
                                                      input_size, comm)

  num_steps = x_block[0].shape[1]

  basic_block_parallel = lambda: RNN_build_block_with_dim(
      input_size, hidden_size, num_layers)

  parallel_rnn = torchbraid.RNN_Parallel(comm, basic_block_parallel(),
                                         num_steps, hidden_size, num_layers,
                                         Tf,
                                         max_levels=max_levels,
                                         max_iters=max_iters)

  parallel_rnn.setPrintLevel(print_level)
  parallel_rnn.setSkipDowncycle(True)
  parallel_rnn.setCFactor(cfactor)
  parallel_rnn.setNumRelax(nrelax)

  torch.manual_seed(20)
  rand_w = torch.randn([1, x_block[0].size(0), hidden_size])

  for i in range(applications):
    h_0 = torch.zeros(num_layers, x_block[i].size(0), hidden_size,
                      requires_grad=True)
    c_0 = torch.zeros(num_layers, x_block[i].size(0), hidden_size,
                      requires_grad=True)

    with torch.enable_grad():
      y_parallel_hn, y_parallel_cn = parallel_rnn(x_block[i], (h_0, c_0))

    comm.barrier()

    w_h = torch.zeros(y_parallel_hn.shape)
    w_c = torch.zeros(y_parallel_hn.shape)
    w_h[-1, :, :] = rand_w

    y_parallel_hn.backward(w_h)

    # step the weights between applications so each pass tests a different
    # parameter state
    if i < applications - 1:
      with torch.no_grad():
        for p in parallel_rnn.parameters():
          p += p.grad
      parallel_rnn.zero_grad()

  # compute serial solution
  #############################################
  if my_rank == 0:
    torch.manual_seed(20)
    serial_rnn = nn.LSTM(input_size, hidden_size, num_layers,
                         batch_first=True)
    image_all, x_block_all = preprocess_input_data_serial_test(
        num_procs, num_batch, batch_size, channels, sequence_length,
        input_size)

    for i in range(applications):
      y_serial_hn_0 = torch.zeros(num_layers, image_all[i].size(0),
                                  hidden_size, requires_grad=True)
      y_serial_cn_0 = torch.zeros(num_layers, image_all[i].size(0),
                                  hidden_size, requires_grad=True)

      with torch.enable_grad():
        q, (y_serial_hn, y_serial_cn) = serial_rnn(
            image_all[i], (y_serial_hn_0, y_serial_cn_0))

      w_q = torch.zeros(q.shape)
      w_q[:, -1, :] = rand_w.detach().clone()

      q.backward(w_q)

      if i < applications - 1:
        with torch.no_grad():
          for p in serial_rnn.parameters():
            p += p.grad
        serial_rnn.zero_grad()
  # end if my_rank

  # now check the answers
  #############################################

  # send the final inference step to root
  if comm.Get_size() > 1 and my_rank == comm.Get_size() - 1:
    comm.send(y_parallel_hn, 0)
    comm.send(y_parallel_cn, 0)

  if my_rank == 0:
    if comm.Get_size() > 1:
      # receive the final inference step
      parallel_hn = comm.recv(source=comm.Get_size() - 1)
      parallel_cn = comm.recv(source=comm.Get_size() - 1)
    else:
      parallel_hn = y_parallel_hn
      parallel_cn = y_parallel_cn

    print('\n\n')
    print(torch.norm(y_serial_cn - parallel_cn).item() /
          torch.norm(y_serial_cn).item(), 'forward cn')
    print(torch.norm(y_serial_hn - parallel_hn).item() /
          torch.norm(y_serial_hn).item(), 'forward hn')
    sys.stdout.flush()

    self.assertTrue(
        torch.norm(y_serial_cn - parallel_cn) / torch.norm(y_serial_cn) < tol,
        'cn value')
    self.assertTrue(
        torch.norm(y_serial_hn - parallel_hn) / torch.norm(y_serial_hn) < tol,
        'hn value')

    print(torch.norm(h_0.grad - y_serial_hn_0.grad).item(), 'back soln hn')
    print(torch.norm(c_0.grad - y_serial_cn_0.grad).item(), 'back soln cn')

    self.assertTrue(torch.norm(h_0.grad - y_serial_hn_0.grad).item() < tol)
    self.assertTrue(torch.norm(c_0.grad - y_serial_cn_0.grad).item() < tol)

    root_grads = [p.grad for p in serial_rnn.parameters()]
  else:
    root_grads = None

  ref_grads = comm.bcast(root_grads, root=0)
  for pa_grad, pb in zip(ref_grads, parallel_rnn.parameters()):
    if torch.norm(pa_grad).item() == 0.0:
      # reference gradient is identically zero: compare absolutely
      print(my_rank, torch.norm(pa_grad - pb.grad).item(), 'param grad')
      self.assertTrue(torch.norm(pa_grad - pb.grad).item() < 1e1 * tol,
                      'param grad')
    else:
      # otherwise compare relative to the reference norm
      print(my_rank,
            torch.norm(pa_grad - pb.grad).item() /
            torch.norm(pa_grad).item(), 'param grad')
      self.assertTrue(
          torch.norm(pa_grad - pb.grad).item() /
          torch.norm(pa_grad).item() < 1e1 * tol, 'param grad')
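# The test above relies on RNN_build_block_with_dim, which is defined
# elsewhere in this harness. A minimal sketch of what it is assumed to
# return (an assumption, not the actual torchbraid helper): the LSTM block
# that RNN_Parallel propagates on each rank. Because both the serial and
# parallel paths call torch.manual_seed(20) before construction, the two
# models draw identical initial weights, which is what makes the norm
# comparisons above meaningful.
def _sketch_RNN_build_block_with_dim(input_size, hidden_size, num_layers):
  # same dimensions and batch_first layout as the serial reference LSTM
  return nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)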
# preprocess and distribute input data
###########################################
# x_block = preprocess_distribute_synthetic_image_sequences_parallel(
#     my_rank, num_procs, num_batch, batch_size, channels, sequence_length,
#     input_size)
x_block = preprocess_distribute_MNIST_image_sequences_parallel(
    train_loader, my_rank, num_procs, num_batch, batch_size, channels,
    sequence_length, input_size)

max_levels = 1  # for testing parallel rnn
max_iters = 1  # for testing parallel rnn
num_steps = sequence_length

parallel_nn = torchbraid.RNN_Parallel(comm, basic_block_parallel, num_steps,
                                      hidden_size, num_layers, Tf,
                                      max_levels=max_levels,
                                      max_iters=max_iters)

parallel_nn.setPrintLevel(print_level)
parallel_nn.setSkipDowncycle(True)
parallel_nn.setCFactor(cfactor)
parallel_nn.setNumRelax(nrelax)
# parallel_nn.setNumRelax(nrelax, level=0)

t0_parallel = time.time()

for epoch in range(num_epochs):
  # x_block: (num_batch, batch_size, seq_len/num_procs, input_size)
  for i in range(len(x_block)):
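    # The loop body is truncated in this fragment. What follows is a
    # hypothetical sketch of a typical epoch body; the loss function,
    # optimizer, and label source are all assumptions, not part of the
    # original:
    #   optimizer.zero_grad()
    #   y_hn, y_cn = parallel_nn(x_block[i])
    #   loss = criterion(y_hn[-1, :, :], labels[i])
    #   loss.backward()
    #   optimizer.step()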
def forwardProp(
    self,
    max_levels=1,  # for testing parallel rnn
    max_iters=1,  # for testing parallel rnn
    sequence_length=28,  # total number of time steps for each sequence
    input_size=28,  # input size for each time step in a sequence
    hidden_size=20,
    num_layers=2,
    batch_size=1):
  comm = MPI.COMM_WORLD
  num_procs = comm.Get_size()
  my_rank = comm.Get_rank()

  Tf = float(sequence_length)
  channels = 1
  images = 10
  image_size = 28
  print_level = 0
  nrelax = 1
  cfactor = 2
  num_batch = int(images / batch_size)

  # compute serial solution
  if my_rank == 0:
    with torch.no_grad():
      torch.manual_seed(20)
      serial_rnn = nn.LSTM(input_size, hidden_size, num_layers,
                           batch_first=True)
      # equivalent to the num_procs variable used for the parallel
      # implementation
      num_blocks = 2
      image_all, x_block_all = preprocess_input_data_serial_test(
          num_blocks, num_batch, batch_size, channels, sequence_length,
          input_size)

      for i in range(1):
        y_serial_hn = torch.zeros(num_layers, image_all[i].size(0),
                                  hidden_size)
        y_serial_cn = torch.zeros(num_layers, image_all[i].size(0),
                                  hidden_size)

        _, (y_serial_hn, y_serial_cn) = serial_rnn(
            image_all[i], (y_serial_hn, y_serial_cn))

  # wait for serial processor
  comm.barrier()

  basic_block_parallel = lambda: RNN_build_block_with_dim(
      input_size, hidden_size, num_layers)

  # preprocess and distribute input data
  ###########################################
  x_block = preprocess_distribute_input_data_parallel(my_rank, num_procs,
                                                      num_batch, batch_size,
                                                      channels,
                                                      sequence_length,
                                                      input_size, comm)

  num_steps = x_block[0].shape[1]

  # RNN_parallel.py -> RNN_Parallel() class
  parallel_rnn = torchbraid.RNN_Parallel(comm, basic_block_parallel(),
                                         num_steps, hidden_size, num_layers,
                                         Tf,
                                         max_levels=max_levels,
                                         max_iters=max_iters)

  parallel_rnn.setPrintLevel(print_level)
  parallel_rnn.setSkipDowncycle(True)
  parallel_rnn.setCFactor(cfactor)
  parallel_rnn.setNumRelax(nrelax)

  # for i in range(len(x_block)):
  for i in range(1):
    y_parallel_hn, y_parallel_cn = parallel_rnn(x_block[i])

  comm.barrier()

  # send the final inference step to root
  if comm.Get_size() > 1 and my_rank == comm.Get_size() - 1:
    comm.send(y_parallel_hn, 0)
    comm.send(y_parallel_cn, 0)

  if my_rank == 0:
    # receive the final inference step
    if comm.Get_size() > 1:
      parallel_hn = comm.recv(source=comm.Get_size() - 1)
      parallel_cn = comm.recv(source=comm.Get_size() - 1)
    else:
      parallel_hn = y_parallel_hn
      parallel_cn = y_parallel_cn

    print('cn values = ',
          torch.norm(y_serial_cn - parallel_cn).item() /
          torch.norm(y_serial_cn).item())
    print('hn values = ',
          torch.norm(y_serial_hn - parallel_hn).item() /
          torch.norm(y_serial_hn).item())

    self.assertTrue(
        torch.norm(y_serial_cn - parallel_cn).item() /
        torch.norm(y_serial_cn).item() < 1e-6, 'check cn')
    self.assertTrue(
        torch.norm(y_serial_hn - parallel_hn).item() /
        torch.norm(y_serial_hn).item() < 1e-6, 'check hn')
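# How these tests assume preprocess_distribute_input_data_parallel behaves:
# the sequence (time) dimension is what gets split across ranks, so each rank
# holds sequence_length / num_procs contiguous steps of every batch, and
# num_steps = x_block[0].shape[1] is the local step count. A minimal
# self-contained sketch under that assumption (the real torchbraid helper,
# its data source, and the seed are assumptions):
def _sketch_distribute_sequences(my_rank, num_procs, num_batch, batch_size,
                                 sequence_length, input_size, seed=9):
  # every rank draws identical data from a shared seed, then keeps only its
  # own contiguous slice of the time dimension
  torch.manual_seed(seed)
  steps_local = sequence_length // num_procs  # assumes an even split
  t0 = my_rank * steps_local
  x_block = []
  for _ in range(num_batch):
    x_full = torch.randn(batch_size, sequence_length, input_size)
    x_block.append(x_full[:, t0:t0 + steps_local, :])
  return x_block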
def backwardProp(
    self,
    max_levels=1,  # for testing parallel rnn
    max_iters=1,  # for testing parallel rnn
    sequence_length=128,  # total number of time steps for each sequence
    input_size=1,  # input size for each time step in a sequence
    hidden_size=20,
    num_layers=2,
    batch_size=100,
    tol=1e-6):
  comm = MPI.COMM_WORLD
  num_procs = comm.Get_size()
  my_rank = comm.Get_rank()

  Tf = 2.0
  channels = 1
  images = 1000
  # image_size = 28
  print_level = 1
  nrelax = 3
  cfactor = 4
  num_batch = int(images / batch_size)

  # wait for serial processor
  comm.barrier()

  # preprocess and distribute input data
  x_block = preprocess_distribute_synthetic_image_sequences_parallel(
      my_rank, num_procs, num_batch, batch_size, channels, sequence_length,
      input_size, comm)

  num_steps = x_block[0].shape[1]

  basic_block_parallel = lambda: RNN_build_block_with_dim(
      input_size, hidden_size, num_layers)

  parallel_rnn = torchbraid.RNN_Parallel(comm, basic_block_parallel(),
                                         num_steps, hidden_size, num_layers,
                                         Tf,
                                         max_levels=max_levels,
                                         max_iters=max_iters)

  parallel_rnn.setPrintLevel(print_level)
  parallel_rnn.setSkipDowncycle(True)
  parallel_rnn.setCFactor(cfactor)
  parallel_rnn.setNumRelax(nrelax)

  h_0 = torch.zeros(num_layers, x_block[0].size(0), hidden_size,
                    requires_grad=True)
  c_0 = torch.zeros(num_layers, x_block[0].size(0), hidden_size,
                    requires_grad=True)

  with torch.enable_grad():
    y_parallel_hn, y_parallel_cn = parallel_rnn(x_block[0], (h_0, c_0))

  comm.barrier()

  w_h = torch.zeros(y_parallel_hn.shape)
  w_c = torch.zeros(y_parallel_hn.shape)
  torch.manual_seed(20)
  w_h[-1, :, :] = torch.randn(y_parallel_hn[-1, :, :].shape)

  y_parallel_hn.backward(w_h)

  # compute serial solution
  #############################################
  if my_rank == 0:
    torch.manual_seed(20)
    serial_rnn = nn.LSTM(input_size, hidden_size, num_layers,
                         batch_first=True)
    image_all, x_block_all = preprocess_synthetic_image_sequences_serial(
        num_procs, num_batch, batch_size, channels, sequence_length,
        input_size)

    i = 0
    y_serial_hn_0 = torch.zeros(num_layers, image_all[i].size(0), hidden_size,
                                requires_grad=True)
    y_serial_cn_0 = torch.zeros(num_layers, image_all[i].size(0), hidden_size,
                                requires_grad=True)

    with torch.enable_grad():
      q, (y_serial_hn, y_serial_cn) = serial_rnn(
          image_all[i], (y_serial_hn_0, y_serial_cn_0))

    print('\n\n')
    print('fore in: ',
          (torch.norm(y_serial_hn_0).item(),
           torch.norm(y_serial_cn_0).item()),
          'out: ',
          (torch.norm(y_serial_hn).item(), torch.norm(y_serial_cn).item()))

    w_q = torch.zeros(q.shape)
    w_q[:, -1, :] = w_h[-1, :, :].detach().clone()

    q.backward(w_q)

    print('back in: ', torch.norm(w_q).item(), 'out: ',
          (torch.norm(y_serial_hn_0.grad).item(),
           torch.norm(y_serial_cn_0.grad).item()))
    print('')
  # end if my_rank

  # now check the answers
  #############################################

  # send the final inference step to root
  if my_rank == comm.Get_size() - 1 and comm.Get_size() > 1:
    comm.send(y_parallel_hn, 0)
    comm.send(y_parallel_cn, 0)

  if my_rank == 0:
    # receive the final inference step
    if comm.Get_size() > 1:
      parallel_hn = comm.recv(source=comm.Get_size() - 1)
      parallel_cn = comm.recv(source=comm.Get_size() - 1)
    else:
      parallel_hn = y_parallel_hn
      parallel_cn = y_parallel_cn

    print(torch.norm(y_serial_cn - parallel_cn).item() /
          torch.norm(y_serial_cn).item(), 'forward cn')
    print(torch.norm(y_serial_hn - parallel_hn).item() /
          torch.norm(y_serial_hn).item(), 'forward hn')

    # self.assertTrue(torch.norm(y_serial_cn - parallel_cn) /
    #                 torch.norm(y_serial_cn) < tol, 'cn value')
    # self.assertTrue(torch.norm(y_serial_hn - parallel_hn) /
    #                 torch.norm(y_serial_hn) < tol, 'hn value')

    print('back hn', torch.norm(h_0.grad).item(),
          torch.norm(y_serial_hn_0.grad).item())
    print('back cn', torch.norm(c_0.grad).item(),
          torch.norm(y_serial_cn_0.grad).item())

    # self.assertTrue(torch.norm(h_0.grad - y_serial_hn_0.grad).item() < tol)
    # self.assertTrue(torch.norm(c_0.grad - y_serial_cn_0.grad).item() < tol)

    root_grads = [p.grad for p in serial_rnn.parameters()]
  else:
    root_grads = None

  ref_grads = comm.bcast(root_grads, root=0)
  for pa_grad, pb in zip(ref_grads, parallel_rnn.parameters()):
    print('grad values = ',
          torch.norm(pa_grad - pb.grad).item() / torch.norm(pa_grad).item(),
          torch.norm(pa_grad).item())
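# The checks in these tests repeat one pattern: a relative error
# ||ref - approx|| / ||ref||, with an absolute fallback when the reference
# norm is zero (as in the parameter-gradient loop of the first backwardProp).
# A small helper capturing that pattern (not part of the original file):
def _rel_err(ref, approx):
  ref_norm = torch.norm(ref).item()
  diff_norm = torch.norm(ref - approx).item()
  # fall back to the absolute difference when the reference is identically zero
  return diff_norm if ref_norm == 0.0 else diff_norm / ref_norm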