        def _run():
            inputs = self.get_inputs()
            args_defaults = {
                'num_layers': 2,
                'hidden_size': 128,
                'num_attention_heads': 8,
                'max_position_embeddings': 128,
            }

            model = get_gpt2_model(args_defaults)
            model = self.get_deepspeed_model(model, tmpdir)

            model.eval()
            baseline = model(inputs[0].cuda(), inputs[1].cuda(),
                             inputs[2].cuda())

            tag = 'mp_1'
            state_dict = {'checkpoint_version': get_megatron_version()}
            model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)
            # Make sure every rank has finished writing before any rank loads.
            dist.barrier()
            model.load_checkpoint(tmpdir,
                                  tag=tag,
                                  load_optimizer_states=False,
                                  load_lr_scheduler_states=False)

            test = model(inputs[0].cuda(), inputs[1].cuda(),
                         inputs[2].cuda())
            assert torch.allclose(
                baseline, test, atol=1e-07
            ), f"Baseline output {baseline} is not equal to save-then-load output {test}"
        # Pipeline-parallel variant of the save-then-load check above.
        def _run():
            args_defaults = {
                'num_layers': 8,
                'hidden_size': 128,
                'num_attention_heads': 8,
                'max_position_embeddings': 128,
            }

            topo = self.get_topology(mp_size, pp_size, world_size)
            gpt2_pipe_model = GPT2ModelPipe(num_layers=8,
                                            num_stages=pp_size,
                                            mp_size=mp_size,
                                            args_others=args_defaults,
                                            topo=topo)
            model = self.get_deepspeed_model(gpt2_pipe_model, tmpdir)

            tag = 'pp_basic'
            state_dict = {'checkpoint_version': get_megatron_version()}
            model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)

            # Only the first and last pipeline stages feed or consume data;
            # intermediate stages receive activations over the pipe.
            if model.is_first_stage() or model.is_last_stage():
                inputs = self.get_inputs()
                loader = RepeatingLoader([(inputs[0], 0)])
                data_iter = iter(loader)
            else:
                data_iter = None

            baseline = model.eval_batch(data_iter=data_iter,
                                        compute_loss=False,
                                        reduce_output=None)

            # Synchronize so the checkpoint is fully written before loading.
            dist.barrier()
            model.load_checkpoint(tmpdir,
                                  tag=tag,
                                  load_optimizer_states=False,
                                  load_lr_scheduler_states=False)
            dist.barrier()

            test = model.eval_batch(data_iter=data_iter,
                                    compute_loss=False,
                                    reduce_output=None)

            # eval_batch only returns outputs on ranks that hold them (the
            # last stage); other ranks see None and skip the comparison.
            if test is not None:
                assert len(baseline) == len(test)
                # Compare outputs of each microbatch
                for mb in range(len(baseline)):
                    for b, t in zip(baseline[mb], test[mb]):
                        if b.is_floating_point():  # don't compare masks
                            assert torch.allclose(
                                b, t, atol=1e-07
                            ), f"Baseline output {baseline} is not equal to save-then-load output {test}"
        # Baseline runner for the pipeline-parallel case: runs in its own
        # process, publishes its eval output via `output`, saves a checkpoint,
        # then waits to be released.
        def _run_baseline(inputs, tag, output, quit_event):
            reset_random()
            args_defaults = {
                'num_layers': 8,
                'hidden_size': 128,
                'num_attention_heads': 8,
                'max_position_embeddings': 128,
            }

            topo = self.get_topology(mp_size, pp_size, mp_size * pp_size)
            gpt2_pipe_model = GPT2ModelPipe(num_layers=8,
                                            num_stages=pp_size,
                                            mp_size=mp_size,
                                            args_others=args_defaults,
                                            topo=topo)
            model = self.get_deepspeed_model(gpt2_pipe_model, tmpdir)

            with torch.no_grad():
                inputs = [x.cuda() for x in inputs]
                if model.is_first_stage() or model.is_last_stage():
                    loader = RepeatingLoader([((inputs[0], inputs[1]), 0)])
                    data_iter = iter(loader)
                else:
                    data_iter = None

                baseline = model.eval_batch(data_iter=data_iter,
                                            compute_loss=False,
                                            reduce_output=None)

                if baseline is not None:
                    # baseline should be [[hidden, True]]
                    assert len(baseline) == 1
                    assert len(baseline[0]) == 2
                    assert torch.is_tensor(baseline[0][0])
                    assert baseline[0][1].numel() == 1
                    output.put(baseline[0][0].cpu())

                state_dict = {'checkpoint_version': get_megatron_version()}
                model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)
                # Keep this process alive until the parent has consumed the
                # baseline and finished its checks.
                quit_event.wait()
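        # Driver-side protocol sketch (assumption: `output` is a
        # multiprocessing.Queue and `quit_event` a multiprocessing.Event, as
        # output.put()/quit_event.wait() above suggest). The parent process
        # would collect the baseline and release the worker roughly like so:
        #
        #   import torch.multiprocessing as mp
        #   output, quit_event = mp.Queue(), mp.Event()
        #   proc = mp.Process(target=_run_baseline,
        #                     args=(inputs, tag, output, quit_event))
        #   proc.start()
        #   baseline = output.get()  # blocks until the worker publishes
        #   ...                      # run the save-then-load check
        #   quit_event.set()         # let the worker exit
        #   proc.join()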
        # Baseline runner for the model-parallel (non-pipeline) variant of the
        # same save-then-load check.
        def _run_baseline(inputs, tag, output, quit_event):
            reset_random()
            args_defaults = {
                'num_layers': 2,
                'hidden_size': 128,
                'num_attention_heads': 8,
                'max_position_embeddings': 128,
            }

            model = get_gpt2_model(args_defaults, mp_size=mp_size)
            model = self.get_deepspeed_model(model, tmpdir)

            model.eval()

            with torch.no_grad():
                baseline = model(inputs[0].cuda(), inputs[1].cuda(),
                                 inputs[2].cuda())
                if dist.get_rank() == 0:
                    output.put(baseline.cpu())

                state_dict = {'checkpoint_version': get_megatron_version()}
                model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)
                quit_event.wait()
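        # Hypothetical counterpart sketch (assumption: a matching test worker,
        # here called `_run_test`, rebuilds the model, loads the checkpoint
        # saved above, and compares against the queued baseline, mirroring the
        # first `_run` in this file; the name `_run_test` and the handshake at
        # the end are assumptions):
        #
        #   def _run_test(inputs, tag, output, quit_event):
        #       reset_random()
        #       model = get_gpt2_model(args_defaults, mp_size=mp_size)
        #       model = self.get_deepspeed_model(model, tmpdir)
        #       model.load_checkpoint(tmpdir, tag=tag,
        #                             load_optimizer_states=False,
        #                             load_lr_scheduler_states=False)
        #       model.eval()
        #       with torch.no_grad():
        #           test = model(inputs[0].cuda(), inputs[1].cuda(),
        #                        inputs[2].cuda())
        #           if dist.get_rank() == 0:
        #               baseline = output.get()
        #               assert torch.allclose(baseline.cuda(), test,
        #                                     atol=1e-07)
        #       quit_event.set()  # release the baseline worker (assumed)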