# Example no. 1 (rating: 0)
def create_model_config(args, benchmark_config=None, model_specs=None):
    """Return a dict with the given model, dataset and optimizer.

    Args:
        args: Parsed arguments. ``args.model_name`` selects the model
            ("lm" or "seq"); for "lm", ``args.use_synthetic_data`` chooses
            between the synthetic and the real dataloaders.
        benchmark_config: Benchmark configuration forwarded to the "lm"
            dataloader and to ``get_model_and_optimizer``.
        model_specs: Model configuration forwarded the same way as
            ``benchmark_config``.

    Returns:
        A dict with the keys ``"model"``, ``"optimizer"`` and ``"data"``.

    Raises:
        RuntimeError: If ``args.model_name`` is neither "lm" nor "seq".
    """
    # NOTE: execution is deliberately pinned to the CPU. CUDA auto-selection
    # (torch.device("cuda") if torch.cuda.is_available() else ...) was disabled.
    device = torch.device("cpu")

    if args.model_name == "lm":
        if args.use_synthetic_data:
            dataloader_fn = get_synthetic_dataloaders
        else:
            dataloader_fn = get_real_dataloaders
        data = dataloader_fn(args, device, benchmark_config, model_specs)
    elif args.model_name == "seq":
        # NOTE(review): "seq" data is always built from offload_seq's own
        # configs, while the model below still uses the caller-supplied
        # benchmark_config/model_specs. Confirm this asymmetry is intended.
        data = get_synthetic_dataloaders(
            args, device, offload_seq.get_benchmark_config(), offload_seq.get_model_config()
        )
    else:
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")

    # Both supported models build the model/optimizer pair the same way.
    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {
        "model": model,
        "optimizer": optimizer,
        "data": data,
    }
# Example no. 2 (rating: 0)
def get_model_specs(model_name):
    """Return a dict with configurations required for configuring `model_name` model.

    Args:
        model_name: Name of the model, either "lm" or "seq".

    Returns:
        The value of ``lm_wikitext2.get_model_config()`` for "lm", or of
        ``offload_seq.get_model_config()`` for "seq".

    Raises:
        RuntimeError: If ``model_name`` is not recognized.
    """
    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    if model_name == "seq":
        return offload_seq.get_model_config()
    # Report the parameter actually received. The previous message referenced an
    # undefined `args` and used `%` without a placeholder, so it crashed with
    # NameError/TypeError instead of raising the intended RuntimeError.
    raise RuntimeError(f"Unrecognized model_name {model_name}")