Exemplo n.º 1
0
 def _test_zero_empty_partition(args):
     """Initialize a two-parameter model and take a single optimizer step.

     The model is built with a hidden size of 1 and is asserted to expose
     exactly two parameters; the original intent (per the inline comment) is
     that spreading two parameters over a data-parallel group of 3 leaves one
     partition empty. The test passes if ``step()`` runs without raising.
     """
     net = SimpleModel(1)
     # Two parameters over DP=3 ranks -> at least one rank's partition is empty.
     param_count = len(list(net.parameters()))
     assert param_count == 2
     engine, _, _, _ = deepspeed.initialize(args=args,
                                            model=net,
                                            model_parameters=net.parameters())
     engine.step()
Exemplo n.º 2
0
    def _test_adam_fp16_zero_onecycle_compatibility(args, zero_stage,
                                                    hidden_dim):
        """Run a short training loop with the configuration carried by ``args``.

        Builds a ``SimpleModel`` of width ``hidden_dim``, wraps it with
        ``deepspeed.initialize``, and performs forward/backward/step over 50
        random samples.

        ``zero_stage`` is accepted but never read in this body; presumably the
        caller has already encoded it into the config referenced by ``args``
        (NOTE(review): confirm against the caller).
        """
        net = SimpleModel(hidden_dim)

        engine, _, _, _ = deepspeed.initialize(args=args,
                                               model=net,
                                               model_parameters=net.parameters())
        loader = random_dataloader(model=engine,
                                   total_samples=50,
                                   hidden_dim=hidden_dim,
                                   device=engine.device)
        for batch in loader:
            engine.backward(engine(batch[0], batch[1]))
            engine.step()
Exemplo n.º 3
0
    def _go(args):
        """Train a ``SimpleModel`` for one pass over 10 random samples.

        ``hidden_dim`` is read from the enclosing scope; ``args`` supplies the
        DeepSpeed configuration used by ``deepspeed.initialize``.
        """
        net = SimpleModel(hidden_dim)
        engine, _, _, _ = deepspeed.initialize(args=args,
                                               model=net,
                                               model_parameters=net.parameters())

        loader = random_dataloader(model=engine,
                                   total_samples=10,
                                   hidden_dim=hidden_dim,
                                   device=engine.device)

        for batch in loader:
            engine.backward(engine(batch[0], batch[1]))
            engine.step()
Exemplo n.º 4
0
    def _test_zero_empty_partition(args):
        """Train briefly on a model too small to fill every partition.

        The model has hidden size 1 and is asserted to expose exactly two
        parameters; per the inline comment, this is meant to leave one
        partition empty when the data-parallel group has 3 ranks. A single
        random sample is then pushed through forward/backward/step.
        """
        dim = 1
        net = SimpleModel(dim)
        # Two parameters over DP=3 ranks -> at least one rank's partition is empty.
        param_count = len(list(net.parameters()))
        assert param_count == 2
        engine, _, _, _ = deepspeed.initialize(args=args,
                                               model=net,
                                               model_parameters=net.parameters())

        # Exercise a full train step to confirm the engine still works.
        loader = random_dataloader(model=engine,
                                   total_samples=1,
                                   hidden_dim=dim,
                                   device=engine.device)
        for batch in loader:
            engine.backward(engine(batch[0], batch[1]))
            engine.step()
Exemplo n.º 5
0
    def _test_zero_static_scale(args):
        """Check that a static loss scale of 138 is in effect, then train.

        The optimizer returned by ``deepspeed.initialize`` must report
        ``dynamic_loss_scale`` as ``False`` and a ``loss_scaler.loss_scale`` of
        exactly 138; the expected values presumably mirror the config behind
        ``args`` (NOTE(review): confirm against the caller). Afterwards, 10
        random samples are run through forward/backward/step.
        """
        hidden_dim = 10
        model = SimpleModel(hidden_dim)
        model, optim, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())

        # Ensure the static scaler is configured.
        assert optim.dynamic_loss_scale is False, (
            f"expected static loss scaling, got "
            f"dynamic_loss_scale={optim.dynamic_loss_scale!r}")
        # Exact float comparison is intended: a static scale is stored verbatim.
        assert optim.loss_scaler.loss_scale == 138., (
            f"expected loss_scale 138.0, got {optim.loss_scaler.loss_scale!r}")

        # Now make sure things work..
        data_loader = random_dataloader(model=model,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
Exemplo n.º 6
0
    def _helper():
        """Check that ``backward()`` and ``step()`` raise ``AssertionError``.

        Builds command-line style args pointing ``deepscale_config`` at
        ``config_path`` (from the enclosing scope) with ``local_rank`` 0, and
        initializes the engine without passing ``model_parameters``. For every
        one of 5 random samples, a forward pass is run and then both
        ``backward`` and ``step`` are expected to raise ``AssertionError``.
        Why they raise is decided by DeepSpeed and the config, which are not
        visible here (NOTE(review): presumably tied to the config or the
        missing parameters; confirm).
        """
        args = argparse.ArgumentParser().parse_args(args='')
        args.deepscale_config = config_path
        args.local_rank = 0

        dim = 10
        engine, _, _, _ = deepspeed.initialize(args=args,
                                               model=SimpleModel(hidden_dim=dim))
        loader = random_dataloader(model=engine,
                                   total_samples=5,
                                   hidden_dim=dim,
                                   device=engine.device)
        for batch in loader:
            loss = engine(batch[0], batch[1])
            with pytest.raises(AssertionError):
                engine.backward(loss)
            with pytest.raises(AssertionError):
                engine.step()