def __init__(
    self,
    *,
    adv: nn.Module,
    enc: nn.Module,
    clf: nn.Module,
    lr: float = 3.0e-4,
    weight_decay: float = 0.0,
    lr_initial_restart: int = 10,
    lr_restart_mult: int = 2,
    lr_sched_interval: TrainingMode = TrainingMode.epoch,
    lr_sched_freq: int = 1,
) -> None:
    super().__init__(
        lr=lr,
        weight_decay=weight_decay,
        lr_initial_restart=lr_initial_restart,
        lr_restart_mult=lr_restart_mult,
        lr_sched_interval=lr_sched_interval,
        lr_sched_freq=lr_sched_freq,
    )
    self.adv = adv
    self.enc = enc
    self.clf = clf
    self._loss_adv_fn = CrossEntropyLoss()
    self._loss_clf_fn = CrossEntropyLoss()
    self.automatic_optimization = False  # Mark for manual optimization
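# With `automatic_optimization` disabled, the class must step its optimizers
# itself. Below is a minimal sketch (not this repo's actual `training_step`)
# of how manual optimization typically proceeds with Lightning's
# `self.optimizers()` and `self.manual_backward()`; the (x, y, s) batch layout
# and the two-optimizer split are assumptions:
def training_step(self, batch, batch_idx: int) -> None:
    x, y, s = batch  # input, class label, sensitive attribute (assumed)
    main_opt, adv_opt = self.optimizers()

    # Update encoder + classifier on the prediction loss.
    z = self.enc(x)
    loss_clf = self._loss_clf_fn(input=self.clf(z), target=y)
    main_opt.zero_grad()
    self.manual_backward(loss_clf)
    main_opt.step()

    # Update the adversary on detached embeddings so its gradients
    # never reach the encoder.
    loss_adv = self._loss_adv_fn(input=self.adv(z.detach()), target=s)
    adv_opt.zero_grad()
    self.manual_backward(loss_adv)
    adv_opt.step()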
def __init__(
    self,
    encoder: nn.Module,
    classifier: nn.Module,
    lr: float = 3.0e-4,
    weight_decay: float = 0.0,
    lr_initial_restart: int = 10,
    lr_restart_mult: int = 2,
    lr_sched_interval: TrainingMode = TrainingMode.epoch,
    lr_sched_freq: int = 1,
    loss_fn: Optional[Loss] = None,
) -> None:
    # Freeze the encoder so that gradients flow only into the classifier head.
    encoder = make_no_grad(encoder).eval()
    model = nn.Sequential(encoder, classifier)
    super().__init__(
        model=model,
        lr=lr,
        weight_decay=weight_decay,
        lr_initial_restart=lr_initial_restart,
        lr_restart_mult=lr_restart_mult,
        lr_sched_interval=lr_sched_interval,
        lr_sched_freq=lr_sched_freq,
        # Avoid a shared mutable default argument: fall back to mean-reduced
        # cross-entropy only when no loss is supplied.
        loss_fn=CrossEntropyLoss(reduction=ReductionType.mean) if loss_fn is None else loss_fn,
    )
    self.encoder = encoder
    self.classifier = classifier
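# `make_no_grad` is this repo's helper; a rough plain-PyTorch stand-in for its
# assumed effect (not its actual implementation) would be:
from torch import nn

def freeze_parameters(module: nn.Module) -> nn.Module:
    """Disable gradient tracking for every parameter of `module`."""
    for param in module.parameters():
        param.requires_grad_(False)
    return module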
def __init__(
    self,
    encoder: nn.Module,
    clf: nn.Module,
    lr: float,
    weight_decay: float,
    fairness: FairnessType,
    mixup_lambda: Optional[float] = None,
    alpha: float = 1.0,
    lr_initial_restart: int = 10,
    lr_restart_mult: int = 2,
    lr_sched_interval: TrainingMode = TrainingMode.epoch,
    lr_sched_freq: int = 1,
) -> None:
    super().__init__(
        lr=lr,
        weight_decay=weight_decay,
        lr_initial_restart=lr_initial_restart,
        lr_restart_mult=lr_restart_mult,
        lr_sched_interval=lr_sched_interval,
        lr_sched_freq=lr_sched_freq,
    )
    self.encoder = encoder
    self.clf = clf
    self.net = nn.Sequential(self.encoder, self.clf)
    self.fairness = fairness
    self.mixup_lambda = mixup_lambda
    self.alpha = alpha
    self._loss_fn = CrossEntropyLoss(reduction=ReductionType.mean)
    self.test_acc = torchmetrics.Accuracy()
    self.train_acc = torchmetrics.Accuracy()
    self.val_acc = torchmetrics.Accuracy()
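# The mixup hyper-parameters stored above are typically consumed along these
# lines: use the fixed `mixup_lambda` when given, otherwise draw a coefficient
# from Beta(alpha, alpha). A hypothetical sketch, not this repo's code:
from typing import Optional, Tuple

import torch

def sample_mixup_coeff(mixup_lambda: Optional[float], alpha: float) -> float:
    """Return the convex-combination coefficient for one mixup step."""
    if mixup_lambda is not None:
        return mixup_lambda
    return float(torch.distributions.Beta(alpha, alpha).sample())

def mixup(x: torch.Tensor, lam: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Blend each sample with a random partner; return the mixed batch and the permutation."""
    perm = torch.randperm(x.size(0))
    return lam * x + (1.0 - lam) * x[perm], perm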
def test_xent(out_dim: int, dtype: str, reduction_type: ReductionType) -> None:
    # Class indices in [0, out_dim), with at least two classes in the binary case.
    target = torch.randint(0, max(out_dim, 2), (BATCH_SIZE, 1), dtype=getattr(torch, dtype))
    pred = torch.randn(BATCH_SIZE, out_dim)
    iw = torch.randn(BATCH_SIZE)
    loss_fn = CrossEntropyLoss(reduction=reduction_type)
    # Targets should be accepted both with a trailing singleton dim and flattened.
    loss_fn(input=pred, target=target, instance_weight=iw)
    loss_fn(input=pred, target=target.squeeze(), instance_weight=iw)
    if out_dim == 1:
        # Binary case: 1D predictions should also be handled.
        loss_fn(input=pred.squeeze(), target=target, instance_weight=iw)
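# The arguments above are presumably injected by pytest; a hypothetical
# parametrization consistent with the signature (values illustrative only):
import pytest

BATCH_SIZE = 8  # assumed module-level constant

@pytest.mark.parametrize("reduction_type", list(ReductionType))
@pytest.mark.parametrize("dtype", ["long"])
@pytest.mark.parametrize("out_dim", [1, 2, 5])
def test_xent(out_dim: int, dtype: str, reduction_type: ReductionType) -> None:
    ...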
def __init__(
    self,
    *,
    adv: nn.Module,
    enc: nn.Module,
    clf: nn.Module,
    lr: float = 3.0e-4,
    weight_decay: float = 0.0,
    grl_lambda: float = 1.0,
    lr_initial_restart: int = 10,
    lr_restart_mult: int = 2,
    lr_sched_interval: TrainingMode = TrainingMode.epoch,
    lr_sched_freq: int = 1,
) -> None:
    super().__init__(
        lr=lr,
        weight_decay=weight_decay,
        lr_initial_restart=lr_initial_restart,
        lr_restart_mult=lr_restart_mult,
        lr_sched_interval=lr_sched_interval,
        lr_sched_freq=lr_sched_freq,
    )
    self.grl_lambda = grl_lambda
    self.learning_rate = lr
    self.weight_decay = weight_decay
    self.lr_initial_restart = lr_initial_restart
    self.lr_restart_mult = lr_restart_mult
    self.lr_sched_interval = lr_sched_interval
    self.lr_sched_freq = lr_sched_freq
    self.adv = adv
    self.enc = enc
    self.clf = clf
    self._loss_adv_fn = CrossEntropyLoss()
    self._loss_clf_fn = CrossEntropyLoss()
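# `grl_lambda` points at a DANN-style gradient-reversal layer: identity in the
# forward pass, gradients scaled by -lambda in the backward pass, so the
# encoder learns to hurt the adversary. The repo will have its own version;
# the standard formulation is:
import torch

class GradReverse(torch.autograd.Function):
    """Identity forward; multiply incoming gradients by -lambda on the way back."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambda_: float) -> torch.Tensor:
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # No gradient with respect to lambda_ itself, hence the trailing None.
        return -ctx.lambda_ * grad_output, None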
def __init__(
    self,
    *,
    lr: float,
    weight_decay: float,
    disc_steps: int,
    fairness: FairnessType,
    recon_weight: float,
    clf_weight: float,
    adv_weight: float,
    enc: nn.Module,
    dec: nn.Module,
    adv: nn.Module,
    clf: nn.Module,
    lr_initial_restart: int = 10,
    lr_restart_mult: int = 2,
    lr_sched_interval: TrainingMode = TrainingMode.epoch,
    lr_sched_freq: int = 1,
) -> None:
    super().__init__(
        lr=lr,
        weight_decay=weight_decay,
        lr_initial_restart=lr_initial_restart,
        lr_restart_mult=lr_restart_mult,
        lr_sched_interval=lr_sched_interval,
        lr_sched_freq=lr_sched_freq,
    )
    self.enc = enc
    self.dec = dec
    self.adv = adv
    self.clf = clf
    self._clf_loss = CrossEntropyLoss(reduction=ReductionType.mean)
    self._recon_loss = nn.L1Loss(reduction="mean")
    self._adv_clf_loss = nn.L1Loss(reduction="none")
    self.disc_steps = disc_steps
    self.fairness = fairness
    self.clf_weight = clf_weight
    self.adv_weight = adv_weight
    self.recon_weight = recon_weight
    self.test_acc = torchmetrics.Accuracy()
    self.train_acc = torchmetrics.Accuracy()
    self.val_acc = torchmetrics.Accuracy()
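# `disc_steps` and the three weights suggest the usual alternating scheme:
# several adversary updates per encoder/decoder update, with the generator-side
# objective combining reconstruction, classification, and (negated) adversarial
# terms. A hypothetical sketch of that weighted loss; the batch layout and sign
# convention are assumptions, not this repo's code:
def generator_loss(self, x, y, s):
    z = self.enc(x)
    recon_term = self._recon_loss(self.dec(z), x)
    clf_term = self._clf_loss(input=self.clf(z), target=y)
    adv_term = self._adv_clf_loss(self.adv(z), s).mean()
    # The encoder is trained to maximize the adversary's error, hence the minus.
    return (
        self.recon_weight * recon_term
        + self.clf_weight * clf_term
        - self.adv_weight * adv_term
    )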
def __init__(
    self,
    model: nn.Module,
    lr: float = 3.0e-4,
    weight_decay: float = 0.0,
    lr_initial_restart: int = 10,
    lr_restart_mult: int = 2,
    lr_sched_interval: TrainingMode = TrainingMode.epoch,
    lr_sched_freq: int = 1,
    loss_fn: Optional[Loss] = None,
) -> None:
    super().__init__(
        lr=lr,
        weight_decay=weight_decay,
        lr_initial_restart=lr_initial_restart,
        lr_restart_mult=lr_restart_mult,
        lr_sched_interval=lr_sched_interval,
        lr_sched_freq=lr_sched_freq,
    )
    self.model = model
    self.loss_fn = CrossEntropyLoss(reduction=ReductionType.mean) if loss_fn is None else loss_fn
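# Hypothetical usage (the wrapping class name is a placeholder): any model can
# be passed in, and omitting `loss_fn` falls back to mean-reduced cross-entropy.
from torch import nn

module = ErmModule(  # placeholder name for the class defining this __init__
    model=nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)),
    lr=1e-3,
    weight_decay=1e-5,
)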
def __init__(
    self,
    *,
    encoder: nn.Module,
    clf: nn.Module,
    lr: float = 3.0e-4,
    weight_decay: float = 0.0,
    lr_initial_restart: int = 10,
    lr_restart_mult: int = 2,
    lr_sched_interval: TrainingMode = TrainingMode.epoch,
    lr_sched_freq: int = 1,
) -> None:
    super().__init__(
        encoder=encoder,
        clf=clf,
        lr=lr,
        weight_decay=weight_decay,
        lr_initial_restart=lr_initial_restart,
        lr_restart_mult=lr_restart_mult,
        lr_sched_interval=lr_sched_interval,
        lr_sched_freq=lr_sched_freq,
        loss_fn=CrossEntropyLoss(reduction=ReductionType.mean),
    )
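# Hypothetical instantiation (names are placeholders): this subclass pins the
# loss to mean-reduced cross-entropy, so only the architecture and optimizer
# settings vary.
from torch import nn

finetuner = FineTuner(  # placeholder name for the class defining this __init__
    encoder=pretrained_encoder,  # some pretrained nn.Module (placeholder)
    clf=nn.Linear(128, 10),
    lr=1e-4,
)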