def __init__(
    self,
    cardinality: int = 2,
    verbose: bool = True,
    device: str = "cpu",
    metric: str = "accuracy",
    tie_break_policy: str = "abstain",
    n_epochs: int = 100,
    lr: float = 0.01,
    l2: float = 0.0,
    optimizer: str = "sgd",
    optimizer_config: Optional[OptimizerConfig] = None,
    lr_scheduler: str = "constant",
    lr_scheduler_config: Optional[LRSchedulerConfig] = None,
    prec_init: float = 0.7,
    seed: int = np.random.randint(1e6),
    log_freq: int = 10,
    mu_eps: Optional[float] = None,
    class_balance: Optional[List[float]] = None,
    **kwargs: Any,
) -> None:
    """Record every hyperparameter on the instance and build a ``LabelModel``.

    Each named argument is stored unchanged as an attribute of the same
    name; the two ``*_config`` arguments fall back to a freshly constructed
    default config object when ``None`` is given.

    Parameters
    ----------
    cardinality
        Forwarded to ``LabelModel`` as ``cardinality``
    verbose
        Forwarded to ``LabelModel`` as ``verbose``
    device
        Forwarded to ``LabelModel`` as ``device``
    metric
        Stored as ``self.metric``; not otherwise used in this method
    tie_break_policy
        Stored as ``self.tie_break_policy``; not otherwise used in this method
    n_epochs, lr, l2, optimizer, lr_scheduler, prec_init, log_freq, mu_eps
        Stored as attributes of the same name (presumably training settings
        consumed later when fitting -- confirm against the rest of the class)
    optimizer_config
        Stored as given; if None, a new ``OptimizerConfig()`` is created
    lr_scheduler_config
        Stored as given; if None, a new ``LRSchedulerConfig()`` is created
    seed
        Stored as ``self.seed``. The default is drawn once, when this
        function is defined, so every call that omits ``seed`` within one
        process receives the same value
    class_balance
        Stored as ``self.class_balance``; not otherwise used in this method
    **kwargs
        Accepted but not read anywhere in this method
    """
    # Settings that are also handed to the wrapped LabelModel below.
    self.cardinality = cardinality
    self.verbose = verbose
    self.device = device

    self.metric = metric
    self.tie_break_policy = tie_break_policy

    # Optimisation settings.
    self.n_epochs = n_epochs
    self.lr = lr
    self.l2 = l2
    self.optimizer = optimizer
    # Build the default lazily so each instance gets its own config object.
    if optimizer_config is not None:
        self.optimizer_config = optimizer_config
    else:
        self.optimizer_config = OptimizerConfig()  # type: ignore
    self.lr_scheduler = lr_scheduler
    if lr_scheduler_config is not None:
        self.lr_scheduler_config = lr_scheduler_config
    else:
        self.lr_scheduler_config = LRSchedulerConfig()  # type: ignore

    self.prec_init = prec_init
    self.seed = seed
    self.log_freq = log_freq
    self.mu_eps = mu_eps
    self.class_balance = class_balance

    # The underlying model is configured from the attributes just stored.
    self.label_model = LabelModel(
        cardinality=self.cardinality,
        verbose=self.verbose,
        device=self.device,
    )
class TrainConfig(Config):
    """Settings for the fit() method of LabelModel.

    Parameters
    ----------
    n_epochs
        The number of epochs to train (where each epoch is a single optimization step)
    lr
        Base learning rate (will also be affected by lr_scheduler choice and settings)
    l2
        Centered L2 regularization strength
    optimizer
        Which optimizer to use (one of ["sgd", "adam", "adamax"])
    optimizer_config
        Settings for the optimizer
    lr_scheduler
        Which lr_scheduler to use (one of ["constant", "linear", "exponential", "step"])
    lr_scheduler_config
        Settings for the LRScheduler
    prec_init
        LF precision initializations / priors
    seed
        A random seed to initialize the random number generator with
    log_freq
        Report loss every this many epochs (steps)
    mu_eps
        Restrict the learned conditional probabilities to [mu_eps, 1-mu_eps]
    """

    n_epochs: int = 100
    lr: float = 0.01
    l2: float = 0.0
    optimizer: str = "sgd"
    # NOTE: default config objects are instantiated once, when this class body
    # is executed, and that single instance is the default for every
    # TrainConfig (presumably safe because Config objects are immutable --
    # confirm before mutating one).
    optimizer_config: OptimizerConfig = OptimizerConfig()  # type: ignore
    lr_scheduler: str = "constant"
    lr_scheduler_config: LRSchedulerConfig = LRSchedulerConfig()  # type: ignore
    # Accepts a scalar or a list / array / tensor (presumably one entry per
    # LF -- confirm against the code that consumes it).
    prec_init: Union[float, List[float], np.ndarray, torch.Tensor] = 0.7
    # NOTE: the default seed is drawn once at class-definition time, so it is
    # random per process but identical for every TrainConfig created in that
    # process without an explicit seed.
    seed: int = np.random.randint(1e6)
    log_freq: int = 10
    # None is the default; presumably it means "let the caller pick a value"
    # -- confirm where mu_eps is read.
    mu_eps: Optional[float] = None
class TrainerConfig(Config):
    """Settings for the Trainer.

    Parameters
    ----------
    seed
        A random seed to set before training; if None, no seed is set
    n_epochs
        The number of epochs to train
    lr
        Base learning rate (will also be affected by lr_scheduler choice and settings)
    l2
        L2 regularization coefficient (weight decay)
    grad_clip
        The value that the gradient norm will be clipped to if it exceeds it
    train_split
        The name of the split to use as the training set
    valid_split
        The name of the split to use as the validation set
    test_split
        The name of the split to use as the test set
    progress_bar
        If True, print a tqdm progress bar during training
    model_config
        Settings for the MultitaskClassifier
    log_manager_config
        Settings for the LogManager
    checkpointing
        If True, use a Checkpointer to save the best model during training
    checkpointer_config
        Settings for the Checkpointer
    logging
        If True, log metrics (to file or Tensorboard) during training
    log_writer
        The type of LogWriter to use (one of ["json", "tensorboard"])
    log_writer_config
        Settings for the LogWriter
    optimizer
        Which optimizer to use (one of ["sgd", "adam", "adamax"])
    optimizer_config
        Settings for the optimizer
    lr_scheduler
        Which lr_scheduler to use (one of ["constant", "linear", "exponential", "step"])
    lr_scheduler_config
        Settings for the LRScheduler
    batch_scheduler
        Which batch scheduler to use (in what order batches will be drawn from multiple tasks)
    """

    # Core training-loop settings.
    seed: Optional[int] = None
    n_epochs: int = 1
    lr: float = 0.01
    l2: float = 0.0
    grad_clip: float = 1.0
    # Names of the data splits.
    train_split: str = "train"
    valid_split: str = "valid"
    test_split: str = "test"
    progress_bar: bool = True
    # NOTE: every nested *Config default below is instantiated once, when this
    # class body is executed, and that single instance is the default for all
    # TrainerConfig objects (presumably safe because Config objects are
    # immutable -- confirm before mutating one).
    model_config: ClassifierConfig = ClassifierConfig()  # type:ignore
    log_manager_config: LogManagerConfig = LogManagerConfig()  # type:ignore
    # Checkpointing and metric logging are both off by default.
    checkpointing: bool = False
    checkpointer_config: CheckpointerConfig = CheckpointerConfig()  # type:ignore
    logging: bool = False
    log_writer: str = "tensorboard"
    log_writer_config: LogWriterConfig = LogWriterConfig()  # type:ignore
    # Optimizer, learning-rate schedule and batch ordering.
    optimizer: str = "adam"
    optimizer_config: OptimizerConfig = OptimizerConfig()  # type:ignore
    lr_scheduler: str = "constant"
    lr_scheduler_config: LRSchedulerConfig = LRSchedulerConfig()  # type:ignore
    batch_scheduler: str = "shuffled"