import torch
import torch.nn as nn

# NetBinary, binarize_model, linear_features, AdamCustomDecay and the trainer
# classes used below are project-level helpers imported from elsewhere in this
# repository.


def train_tempering(model: nn.Module = None, dataset_name="MNIST"):
    if model is None:
        model = NetBinary(fc_sizes=linear_features[dataset_name], batch_norm=False)
    model = binarize_model(model)
    # Parallel tempering: run several Gibbs MCMC chains at different temperatures.
    trainer = ParallelTempering(model,
                                criterion=nn.CrossEntropyLoss(),
                                dataset_name=dataset_name,
                                trainer_cls=TrainerMCMCGibbs,
                                n_chains=5,
                                monitor_kwargs=dict(watch_parameters=False))
    trainer.train(n_epoch=100, save=False, with_mutual_info=False, epoch_update_step=1)
    return model
def train_mcmc(model: nn.Module = None, dataset_name="MNIST"):
    if model is None:
        model = NetBinary(fc_sizes=linear_features[dataset_name], batch_norm=False)
    model = binarize_model(model)
    # Single-chain MCMC training of the binarized model.
    trainer = TrainerMCMC(model,
                          criterion=nn.CrossEntropyLoss(),
                          dataset_name=dataset_name,
                          flip_ratio=0.1)
    trainer.train(n_epoch=500, save=False, with_mutual_info=False, epoch_update_step=3)
def train_gradient(model: nn.Module = None, is_binary=True, dataset_name="MNIST"):
    if model is None:
        model = NetBinary(fc_sizes=linear_features[dataset_name])
    optimizer = AdamCustomDecay(model.parameters(), lr=1e-2, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                           patience=10, threshold=1e-3,
                                                           min_lr=1e-4)
    # Gradient-based training: either on the binarized model or in full precision.
    if is_binary:
        model = binarize_model(model, keep_data=False)
        trainer_cls = TrainerGradBinary
    else:
        trainer_cls = TrainerGradFullPrecision
    trainer = trainer_cls(model,
                          criterion=nn.CrossEntropyLoss(),
                          dataset_name=dataset_name,
                          optimizer=optimizer,
                          scheduler=scheduler)
    trainer.train(n_epoch=500, save=False, with_mutual_info=True)
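

if __name__ == '__main__':
    # Minimal usage sketch (an assumption, not part of the original file):
    # each entry point trains a NetBinary classifier on the named dataset.
    # Uncomment the regime you want to compare.
    train_gradient(is_binary=True, dataset_name="MNIST")
    # train_mcmc(dataset_name="MNIST")
    # train_tempering(dataset_name="MNIST")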