import argparse
from typing import Any, Dict

import pytorch_lightning as pl
import torch

# ARL, DRO, IPW, BaselineModel, OPT_BY_NAME and FairnessDataset are assumed
# to be provided by the surrounding project modules.


def get_model(config: Dict[str, Any],
              args: argparse.Namespace,
              dataset: FairnessDataset) -> pl.LightningModule:
    """Selects and inits a model instance for training.

    Args:
        config: Dict with hyperparameters (learning rate, batch size, eta).
        args: Object from the argument parser that defines various settings
            of the model, dataset and training.
        dataset: Dataset instance that will be used for training.

    Returns:
        An instantiated model; one of the following:
            Model based on Adversarially Reweighted Learning (ARL).
            Model based on Distributionally Robust Optimization (DRO).
            Model based on Inverse Probability Weighting (IPW).
            Baseline model; simple fully-connected or convolutional (TODO)
            network.
    """
    model: pl.LightningModule

    if args.model == 'ARL':
        model = ARL(config=config,  # for hparam tuning
                    input_shape=dataset.dimensionality,
                    pretrain_steps=args.pretrain_steps,
                    prim_hidden=args.prim_hidden,
                    adv_hidden=args.adv_hidden,
                    optimizer=OPT_BY_NAME[args.opt],
                    dataset_type=args.dataset_type,
                    adv_input=set(args.adv_input),
                    num_groups=len(dataset.protected_index2value),
                    opt_kwargs={"initial_accumulator_value": 0.1} if args.tf_mode else {})

    elif args.model == 'ARL_strong':
        model = ARL(config=config,  # for hparam tuning
                    input_shape=dataset.dimensionality,
                    pretrain_steps=args.pretrain_steps,
                    prim_hidden=args.prim_hidden,
                    adv_hidden=args.adv_hidden,
                    optimizer=OPT_BY_NAME[args.opt],
                    dataset_type=args.dataset_type,
                    adv_input=set(args.adv_input),
                    num_groups=len(dataset.protected_index2value),
                    adv_cnn_strength='strong',
                    opt_kwargs={"initial_accumulator_value": 0.1} if args.tf_mode else {})

    elif args.model == 'ARL_weak':
        model = ARL(config=config,  # for hparam tuning
                    input_shape=dataset.dimensionality,
                    pretrain_steps=args.pretrain_steps,
                    prim_hidden=args.prim_hidden,
                    adv_hidden=args.adv_hidden,
                    optimizer=OPT_BY_NAME[args.opt],
                    dataset_type=args.dataset_type,
                    adv_input=set(args.adv_input),
                    num_groups=len(dataset.protected_index2value),
                    adv_cnn_strength='weak',
                    opt_kwargs={"initial_accumulator_value": 0.1} if args.tf_mode else {})

    elif args.model == 'DRO':
        model = DRO(config=config,  # for hparam tuning
                    num_features=dataset.dimensionality,
                    hidden_units=args.prim_hidden,
                    pretrain_steps=args.pretrain_steps,
                    k=args.k,
                    optimizer=OPT_BY_NAME[args.opt],
                    opt_kwargs={"initial_accumulator_value": 0.1} if args.tf_mode else {})

    elif args.model == 'IPW':
        model = IPW(config=config,  # for hparam tuning
                    num_features=dataset.dimensionality,
                    hidden_units=args.prim_hidden,
                    optimizer=OPT_BY_NAME[args.opt],
                    group_probs=dataset.group_probs,
                    sensitive_label=args.sensitive_label,
                    opt_kwargs={"initial_accumulator_value": 0.1} if args.tf_mode else {})
        args.pretrain_steps = 0  # NO PRETRAINING

    elif args.model == 'baseline':
        model = BaselineModel(config=config,  # for hparam tuning
                              num_features=dataset.dimensionality,
                              hidden_units=args.prim_hidden,
                              optimizer=OPT_BY_NAME[args.opt],
                              dataset_type=args.dataset_type,
                              opt_kwargs={"initial_accumulator_value": 0.1} if args.tf_mode else {})
        args.pretrain_steps = 0  # NO PRETRAINING

    else:
        raise ValueError(f'Unknown model: {args.model!r}')

    # If TensorFlow mode is active, we use the TF default initialization,
    # which means Xavier/Glorot uniform (with gain 1) for the weights
    # and 0 bias.
    if args.tf_mode:
        def init_weights(layer):
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(layer.weight)
                torch.nn.init.zeros_(layer.bias)
        model.apply(init_weights)

    return model
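# A minimal usage sketch (hypothetical: the CLI flags, the config key names
# and the FairnessDataset constructor arguments below are illustrative
# assumptions, not guaranteed to match the rest of the repo; only the model
# names mirror the branches above):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--model', choices=['ARL', 'ARL_strong', 'ARL_weak',
#                                             'DRO', 'IPW', 'baseline'])
#     ...
#     args = parser.parse_args()
#     config = {'lr': 1e-3, 'batch_size': 256, 'eta': 0.5}  # assumed key names
#     dataset = FairnessDataset(...)
#     model = get_model(config, args, dataset)
#     pl.Trainer().fit(model, ...)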