Example #1
    def backwardProp(
            self,
            max_levels=1,  # for testing parallel rnn
            max_iters=1,  # for testing parallel rnn
            sequence_length=6,  # total number of time steps for each sequence
            input_size=28,  # input size for each time step in a sequence
            hidden_size=20,
            num_layers=1,
            batch_size=1,
            tol=1e-6,
            applications=1):

        comm = MPI.COMM_WORLD
        num_procs = comm.Get_size()
        my_rank = comm.Get_rank()

        Tf = float(sequence_length)
        channels = 1
        images = 10
        image_size = 28
        print_level = 0
        nrelax = 1
        cfactor = 2
        num_batch = images // batch_size

        # synchronize all ranks before the parallel run
        comm.barrier()

        # preprocess and distribute input data
        x_block = preprocess_distribute_input_data_parallel(
            my_rank, num_procs, num_batch, batch_size, channels,
            sequence_length, input_size, comm)

        num_steps = x_block[0].shape[1]

        basic_block_parallel = lambda: RNN_build_block_with_dim(
            input_size, hidden_size, num_layers)
        parallel_rnn = torchbraid.RNN_Parallel(comm,
                                               basic_block_parallel(),
                                               num_steps,
                                               hidden_size,
                                               num_layers,
                                               Tf,
                                               max_levels=max_levels,
                                               max_iters=max_iters)

        parallel_rnn.setPrintLevel(print_level)
        parallel_rnn.setSkipDowncycle(True)
        parallel_rnn.setCFactor(cfactor)
        parallel_rnn.setNumRelax(nrelax)

        torch.manual_seed(20)
        rand_w = torch.randn([1, x_block[0].size(0), hidden_size])
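        # (rand_w is reused below to seed the serial backward pass, so both
        # runs receive the same adjoint input)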

        for i in range(applications):
            h_0 = torch.zeros(num_layers,
                              x_block[i].size(0),
                              hidden_size,
                              requires_grad=True)
            c_0 = torch.zeros(num_layers,
                              x_block[i].size(0),
                              hidden_size,
                              requires_grad=True)

            with torch.enable_grad():
                y_parallel_hn, y_parallel_cn = parallel_rnn(
                    x_block[i], (h_0, c_0))

            comm.barrier()

            w_h = torch.zeros(y_parallel_hn.shape)
            w_c = torch.zeros(y_parallel_cn.shape)  # allocated but unused; only hn is seeded

            w_h[-1, :, :] = rand_w

            # seed reverse mode with w_h: backward() then computes the
            # vector-Jacobian product, weighted only at the final time level
            y_parallel_hn.backward(w_h)

            # between applications, step the weights by their gradients so the
            # next pass runs with different parameters
            if i < applications - 1:
                with torch.no_grad():
                    for p in parallel_rnn.parameters():
                        p += p.grad
                parallel_rnn.zero_grad()

        # compute serial solution
        #############################################

        if my_rank == 0:

            torch.manual_seed(20)
            serial_rnn = nn.LSTM(input_size,
                                 hidden_size,
                                 num_layers,
                                 batch_first=True)
            image_all, x_block_all = preprocess_input_data_serial_test(
                num_procs, num_batch, batch_size, channels, sequence_length,
                input_size)

            for i in range(applications):
                y_serial_hn_0 = torch.zeros(num_layers,
                                            image_all[i].size(0),
                                            hidden_size,
                                            requires_grad=True)
                y_serial_cn_0 = torch.zeros(num_layers,
                                            image_all[i].size(0),
                                            hidden_size,
                                            requires_grad=True)

                with torch.enable_grad():
                    q, (y_serial_hn, y_serial_cn) = serial_rnn(
                        image_all[i], (y_serial_hn_0, y_serial_cn_0))

                w_q = torch.zeros(q.shape)
                w_q[:, -1, :] = rand_w.detach().clone()

                q.backward(w_q)

                if i < applications - 1:
                    with torch.no_grad():
                        for p in serial_rnn.parameters():
                            p += p.grad
                    serial_rnn.zero_grad()
        # end if my_rank

        # now check the answers
        #############################################

        # send the final inference step to root
        if comm.Get_size() > 1 and my_rank == comm.Get_size() - 1:
            comm.send(y_parallel_hn, 0)
            comm.send(y_parallel_cn, 0)

        if my_rank == 0:
            if comm.Get_size() > 1:
                # receive the final inference step
                parallel_hn = comm.recv(source=comm.Get_size() - 1)
                parallel_cn = comm.recv(source=comm.Get_size() - 1)
            else:
                parallel_hn = y_parallel_hn
                parallel_cn = y_parallel_cn

            print('\n\n')
            print(
                torch.norm(y_serial_cn - parallel_cn).item() /
                torch.norm(y_serial_cn).item(), 'forward cn')
            print(
                torch.norm(y_serial_hn - parallel_hn).item() /
                torch.norm(y_serial_hn).item(), 'forward hn')
            sys.stdout.flush()

            self.assertTrue(
                torch.norm(y_serial_cn - parallel_cn) / torch.norm(y_serial_cn)
                < tol, 'cn value')
            self.assertTrue(
                torch.norm(y_serial_hn - parallel_hn) / torch.norm(y_serial_hn)
                < tol, 'hn value')

            print(
                torch.norm(h_0.grad - y_serial_hn_0.grad).item(),
                'back soln hn')
            print(
                torch.norm(c_0.grad - y_serial_cn_0.grad).item(),
                'back soln cn')
            self.assertTrue(
                torch.norm(h_0.grad - y_serial_hn_0.grad).item() < tol)
            self.assertTrue(
                torch.norm(c_0.grad - y_serial_cn_0.grad).item() < tol)

            root_grads = [p.grad for p in serial_rnn.parameters()]
        else:
            root_grads = None

        ref_grads = comm.bcast(root_grads, root=0)
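        # compare parameter gradients on every rank, falling back to an
        # absolute check when the reference gradient is identically zero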
        for pa_grad, pb in zip(ref_grads, parallel_rnn.parameters()):
            if torch.norm(pa_grad).item() == 0.0:
                print(my_rank,
                      torch.norm(pa_grad - pb.grad).item(),
                      'param grad')
                self.assertTrue(
                    torch.norm(pa_grad - pb.grad).item() < 1e1 * tol,
                    'param grad')
            else:
                print(
                    my_rank,
                    torch.norm(pa_grad - pb.grad).item() /
                    torch.norm(pa_grad).item(), 'param grad')
                self.assertTrue(
                    torch.norm(pa_grad - pb.grad).item() /
                    torch.norm(pa_grad).item() < 1e1 * tol, 'param grad')
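
A note on the backward calls above: passing a weight tensor to backward() seeds reverse mode, so autograd computes the vector-Jacobian product w_h^T (dy/dx) instead of the gradient of a scalar loss, and only the weighted final time level contributes. A minimal self-contained sketch of the same mechanism (the sin() function is just a stand-in for the RNN):

import torch

x = torch.randn(5, 3, requires_grad=True)
y = x.sin()                        # elementwise, so y has x's shape

w = torch.zeros_like(y)
w[-1, :] = torch.randn(3)          # weight only the final "time step"

y.backward(w)                      # computes the VJP w^T (dy/dx)

# only the last row of x receives a nonzero gradient
print(x.grad[:-1].abs().max().item())  # 0.0
print(x.grad[-1])                      # equals w[-1] * cos(x[-1])
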
Example #2
    # preprocess and distribute input data
    ###########################################
    # x_block = preprocess_distribute_synthetic_image_sequences_parallel(my_rank,num_procs,num_batch,batch_size,channels,sequence_length,input_size)
    x_block = preprocess_distribute_MNIST_image_sequences_parallel(
        train_loader, my_rank, num_procs, num_batch, batch_size, channels,
        sequence_length, input_size)

    max_levels = 1  # for testing parallel rnn
    max_iters = 1  # for testing parallel rnn

    num_steps = sequence_length

    parallel_nn = torchbraid.RNN_Parallel(comm,
                                          basic_block_parallel,
                                          num_steps,
                                          hidden_size,
                                          num_layers,
                                          Tf,
                                          max_levels=max_levels,
                                          max_iters=max_iters)

    parallel_nn.setPrintLevel(print_level)
    parallel_nn.setSkipDowncycle(True)
    parallel_nn.setCFactor(cfactor)
    parallel_nn.setNumRelax(nrelax)
    # parallel_nn.setNumRelax(nrelax,level=0)

    t0_parallel = time.time()

    for epoch in range(num_epochs):
        # x_block: (num_batch,batch_size,seq_len/num_procs,input_size)
        for i in range(len(x_block)):
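
The comment above shows x_block with the time axis already divided by num_procs: each rank owns a contiguous slice of every sequence. A minimal sketch of that layout, assuming an even split (split_time_axis is an illustrative stand-in for the preprocess_distribute_* helpers, which are not shown here):

import torch

def split_time_axis(x, num_procs, my_rank):
    # x: (batch_size, seq_len, input_size); return this rank's
    # (batch_size, seq_len // num_procs, input_size) slice
    chunks = torch.chunk(x, num_procs, dim=1)
    return chunks[my_rank].contiguous()

x = torch.randn(1, 28, 28)                          # one 28-step MNIST row sequence
local = split_time_axis(x, num_procs=4, my_rank=2)
print(local.shape)                                  # torch.Size([1, 7, 28])
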
Example #3
    def forwardProp(
            self,
            max_levels=1,  # for testing parallel rnn
            max_iters=1,  # for testing parallel rnn
            sequence_length=28,  # total number of time steps for each sequence
            input_size=28,  # input size for each time step in a sequence
            hidden_size=20,
            num_layers=2,
            batch_size=1):

        comm = MPI.COMM_WORLD
        num_procs = comm.Get_size()
        my_rank = comm.Get_rank()

        Tf = float(sequence_length)
        channels = 1
        images = 10
        image_size = 28
        print_level = 0
        nrelax = 1
        cfactor = 2
        num_batch = images // batch_size

        # compute serial solution
        #############################################
        if my_rank == 0:
            with torch.no_grad():

                torch.manual_seed(20)
                serial_rnn = nn.LSTM(input_size,
                                     hidden_size,
                                     num_layers,
                                     batch_first=True)
                num_blocks = 2  # equivalent to the num_procs variable used for parallel implementation
                image_all, x_block_all = preprocess_input_data_serial_test(
                    num_blocks, num_batch, batch_size, channels,
                    sequence_length, input_size)

                for i in range(1):

                    y_serial_hn = torch.zeros(num_layers, image_all[i].size(0),
                                              hidden_size)
                    y_serial_cn = torch.zeros(num_layers, image_all[i].size(0),
                                              hidden_size)

                    _, (y_serial_hn,
                        y_serial_cn) = serial_rnn(image_all[i],
                                                  (y_serial_hn, y_serial_cn))

        # wait for serial processor
        comm.barrier()

        basic_block_parallel = lambda: RNN_build_block_with_dim(
            input_size, hidden_size, num_layers)

        # preprocess and distribute input data
        ###########################################
        x_block = preprocess_distribute_input_data_parallel(
            my_rank, num_procs, num_batch, batch_size, channels,
            sequence_length, input_size, comm)

        num_steps = x_block[0].shape[1]
        # RNN_parallel.py -> RNN_Parallel() class
        parallel_rnn = torchbraid.RNN_Parallel(comm,
                                               basic_block_parallel(),
                                               num_steps,
                                               hidden_size,
                                               num_layers,
                                               Tf,
                                               max_levels=max_levels,
                                               max_iters=max_iters)

        parallel_rnn.setPrintLevel(print_level)
        parallel_rnn.setSkipDowncycle(True)
        parallel_rnn.setCFactor(cfactor)
        parallel_rnn.setNumRelax(nrelax)

        # for i in range(len(x_block)):
        for i in range(1):

            y_parallel_hn, y_parallel_cn = parallel_rnn(x_block[i])

            comm.barrier()

            # send the final inference step to root
            if comm.Get_size() > 1 and my_rank == comm.Get_size() - 1:
                comm.send(y_parallel_hn, 0)
                comm.send(y_parallel_cn, 0)

            if my_rank == 0:
                # receive the final inference step
                if comm.Get_size() > 1:
                    parallel_hn = comm.recv(source=comm.Get_size() - 1)
                    parallel_cn = comm.recv(source=comm.Get_size() - 1)
                else:
                    parallel_hn = y_parallel_hn
                    parallel_cn = y_parallel_cn

                print(
                    'cn values = ',
                    torch.norm(y_serial_cn - parallel_cn).item() /
                    torch.norm(y_serial_cn).item())
                print(
                    'hn values = ',
                    torch.norm(y_serial_hn - parallel_hn).item() /
                    torch.norm(y_serial_hn).item())
                self.assertTrue(
                    torch.norm(y_serial_cn - parallel_cn).item() /
                    torch.norm(y_serial_cn).item() < 1e-6, 'check cn')
                self.assertTrue(
                    torch.norm(y_serial_hn - parallel_hn).item() /
                    torch.norm(y_serial_hn).item() < 1e-6, 'check hn')
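
The send/recv exchange above recurs in each of these tests: the last rank owns the final time step, so it forwards its states to rank 0 for comparison against the serial run. A minimal sketch of the pattern (mpi4py's lowercase send/recv pickle arbitrary Python objects, torch tensors included); the payload is a stand-in:

from mpi4py import MPI
import torch

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

payload = torch.full((2, 3), float(rank))   # stand-in for y_parallel_hn

if size > 1 and rank == size - 1:
    comm.send(payload, 0)                   # forward the final state to root

if rank == 0:
    final = comm.recv(source=size - 1) if size > 1 else payload
    print('root received the state from rank', int(final[0, 0].item()))

Run it under MPI, e.g. mpirun -np 4 python gather_sketch.py (the file name is illustrative).
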
Example #4
    def backwardProp(
            self,
            max_levels=1,  # for testing parallel rnn
            max_iters=1,  # for testing parallel rnn
            sequence_length=128,  # total number of time steps for each sequence
            input_size=1,  # input size for each time step in a sequence
            hidden_size=20,
            num_layers=2,
            batch_size=100,
            tol=1e-6):

        comm = MPI.COMM_WORLD
        num_procs = comm.Get_size()
        my_rank = comm.Get_rank()

        Tf = 2.0
        channels = 1
        images = 1000
        # image_size      = 28
        print_level = 1
        nrelax = 3
        cfactor = 4
        num_batch = images // batch_size

        # synchronize all ranks before the parallel run
        comm.barrier()

        # preprocess and distribute input data
        x_block = preprocess_distribute_synthetic_image_sequences_parallel(
            my_rank, num_procs, num_batch, batch_size, channels,
            sequence_length, input_size, comm)

        num_steps = x_block[0].shape[1]

        basic_block_parallel = lambda: RNN_build_block_with_dim(
            input_size, hidden_size, num_layers)
        parallel_rnn = torchbraid.RNN_Parallel(comm,
                                               basic_block_parallel(),
                                               num_steps,
                                               hidden_size,
                                               num_layers,
                                               Tf,
                                               max_levels=max_levels,
                                               max_iters=max_iters)

        parallel_rnn.setPrintLevel(print_level)
        parallel_rnn.setSkipDowncycle(True)
        parallel_rnn.setCFactor(cfactor)
        parallel_rnn.setNumRelax(nrelax)

        h_0 = torch.zeros(num_layers,
                          x_block[0].size(0),
                          hidden_size,
                          requires_grad=True)
        c_0 = torch.zeros(num_layers,
                          x_block[0].size(0),
                          hidden_size,
                          requires_grad=True)

        with torch.enable_grad():
            y_parallel_hn, y_parallel_cn = parallel_rnn(x_block[0], (h_0, c_0))

        comm.barrier()

        w_h = torch.zeros(y_parallel_hn.shape)
        w_c = torch.zeros(y_parallel_cn.shape)  # allocated but unused; only hn is seeded
        torch.manual_seed(20)
        w_h[-1, :, :] = torch.randn(y_parallel_hn[-1, :, :].shape)

        y_parallel_hn.backward(w_h)

        # compute serial solution
        #############################################

        if my_rank == 0:

            torch.manual_seed(20)
            serial_rnn = nn.LSTM(input_size,
                                 hidden_size,
                                 num_layers,
                                 batch_first=True)
            image_all, x_block_all = preprocess_synthetic_image_sequences_serial(
                num_procs, num_batch, batch_size, channels, sequence_length,
                input_size)

            i = 0
            y_serial_hn_0 = torch.zeros(num_layers,
                                        image_all[i].size(0),
                                        hidden_size,
                                        requires_grad=True)
            y_serial_cn_0 = torch.zeros(num_layers,
                                        image_all[i].size(0),
                                        hidden_size,
                                        requires_grad=True)

            with torch.enable_grad():
                q, (y_serial_hn,
                    y_serial_cn) = serial_rnn(image_all[i],
                                              (y_serial_hn_0, y_serial_cn_0))

            print('\n\n')
            print('fore in: ', (torch.norm(y_serial_hn_0).item(),
                                torch.norm(y_serial_cn_0).item()),
                  'out: ', (torch.norm(y_serial_hn).item(),
                            torch.norm(y_serial_cn).item()))

            w_q = torch.zeros(q.shape)
            w_q[:, -1, :] = w_h[-1, :, :].detach().clone()

            q.backward(w_q)

            print('back in: ',
                  torch.norm(w_q).item(), 'out: ',
                  (torch.norm(y_serial_hn_0.grad).item(),
                   torch.norm(y_serial_cn_0.grad).item()))
            print('')
        # end if my_rank

        # now check the answers
        #############################################

        # send the final inference step to root
        if my_rank == comm.Get_size() - 1 and comm.Get_size() > 1:
            comm.send(y_parallel_hn, 0)
            comm.send(y_parallel_cn, 0)

        if my_rank == 0:
            # receive the final inference step
            if comm.Get_size() > 1:
                parallel_hn = comm.recv(source=comm.Get_size() - 1)
                parallel_cn = comm.recv(source=comm.Get_size() - 1)
            else:
                parallel_hn = y_parallel_hn
                parallel_cn = y_parallel_cn

            print(
                torch.norm(y_serial_cn - parallel_cn).item() /
                torch.norm(y_serial_cn).item(), 'forward cn')
            print(
                torch.norm(y_serial_hn - parallel_hn).item() /
                torch.norm(y_serial_hn).item(), 'forward hn')

            # self.assertTrue(torch.norm(y_serial_cn - parallel_cn) / torch.norm(y_serial_cn) < tol, 'cn value')
            # self.assertTrue(torch.norm(y_serial_hn - parallel_hn) / torch.norm(y_serial_hn) < tol, 'hn value')

            print('back hn',
                  torch.norm(h_0.grad).item(),
                  torch.norm(y_serial_hn_0.grad).item())
            print('back cn',
                  torch.norm(c_0.grad).item(),
                  torch.norm(y_serial_cn_0.grad).item())
            # self.assertTrue(torch.norm(h_0.grad - y_serial_hn_0.grad).item() < tol)
            # self.assertTrue(torch.norm(c_0.grad - y_serial_cn_0.grad).item() < tol)

            root_grads = [p.grad for p in serial_rnn.parameters()]
        else:
            root_grads = None

        ref_grads = comm.bcast(root_grads, root=0)
        for pa_grad, pb in zip(ref_grads, parallel_rnn.parameters()):
            print(
                'grad values = ',
                torch.norm(pa_grad - pb.grad).item() /
                torch.norm(pa_grad).item(),
                torch.norm(pa_grad).item())
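
A closing note on the gradient check: only rank 0 computes the serial reference, so comm.bcast replicates its gradient list to every rank before the per-parameter comparison (non-root ranks pass None and receive root's object). A minimal sketch of that collective, with stand-in gradients:

from mpi4py import MPI
import torch

comm = MPI.COMM_WORLD

if comm.Get_rank() == 0:
    ref_grads = [torch.ones(2, 2), torch.zeros(3)]  # stand-in reference gradients
else:
    ref_grads = None

ref_grads = comm.bcast(ref_grads, root=0)   # every rank now holds root's list
print(comm.Get_rank(), [tuple(g.shape) for g in ref_grads])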