def __init__(self,
             model_creator,
             optimizer_creator,
             loss_creator=None,
             scheduler_creator=None,
             training_operator_cls=None,
             config=None,
             use_tqdm=False,
             scheduler_step_freq=None):
    """Initialize a single-process ("torch-local") training runner.

    Only stores the creator callables and configuration; the actual
    models/optimizers/loaders are built later (the state attributes
    below start out as None).

    Args:
        model_creator: Callable that constructs the model(s).
        optimizer_creator: Callable that constructs the optimizer(s).
        loss_creator: Optional callable that constructs the loss/criterion.
        scheduler_creator: Optional callable that constructs LR scheduler(s).
        training_operator_cls: Operator class driving train/validation
            steps; defaults to ``TrainingOperator`` when falsy.
        config: Optional dict of user configuration; an empty dict is
            substituted when ``None`` is passed (the caller's dict object
            is kept otherwise).
        use_tqdm: Whether to display a tqdm progress bar.
        scheduler_step_freq: When/how often to step the scheduler
            (semantics defined by the training operator — see its docs).
    """
    # Deferred-construction callables.
    self.model_creator = model_creator
    self.optimizer_creator = optimizer_creator
    self.loss_creator = loss_creator
    self.scheduler_creator = scheduler_creator
    self.training_operator_cls = training_operator_cls or TrainingOperator
    self.config = {} if config is None else config

    # Bookkeeping.
    self.timers = utils.TimerCollection()
    self.epochs = 0

    # Runtime state, populated during setup rather than here.
    self.models = self.optimizers = self.criterion = self.schedulers = None
    self.train_loader = self.validation_loader = None
    self.training_operator = None

    self.use_tqdm = use_tqdm
    self.scheduler_step_freq = scheduler_step_freq

    # Local (non-distributed) execution: single worker, fixed identity.
    self.backend = "torch-local"
    self.rank = 0
    self.size = 0
def __init__(self,
             model_creator,
             data_creator,
             optimizer_creator,
             loss_creator=None,
             scheduler_creator=None,
             training_operator_cls=None,
             config=None,
             serialize_data_creation=True,
             use_tqdm=False,
             scheduler_step_freq=None):
    """Initialize a training runner that also owns data creation.

    Only stores the creator callables and configuration; models,
    optimizers and data loaders are constructed later (the state
    attributes below start out as None).

    Args:
        model_creator: Callable that constructs the model(s).
        data_creator: Callable that constructs the data loader(s).
        optimizer_creator: Callable that constructs the optimizer(s).
        loss_creator: Optional callable that constructs the loss/criterion.
        scheduler_creator: Optional callable that constructs LR scheduler(s).
        training_operator_cls: Operator class driving train/validation
            steps; defaults to ``TrainingOperator`` when falsy.
        config: Optional dict of user configuration; an empty dict is
            substituted when ``None`` is passed (the caller's dict object
            is kept otherwise).
        serialize_data_creation: Flag stored for later use — presumably
            forces workers to create data one at a time (verify against
            the setup code that reads it).
        use_tqdm: Whether to display a tqdm progress bar.
        scheduler_step_freq: When/how often to step the scheduler
            (semantics defined by the training operator — see its docs).
    """
    # Deferred-construction callables.
    self.model_creator = model_creator
    self.optimizer_creator = optimizer_creator
    self.loss_creator = loss_creator
    self.data_creator = data_creator
    self.scheduler_creator = scheduler_creator
    self.training_operator_cls = training_operator_cls or TrainingOperator
    self.config = {} if config is None else config

    # Bookkeeping.
    self.timers = utils.TimerCollection()
    self.epochs = 0

    # Runtime state, populated during setup rather than here.
    self.models = self.optimizers = self.criterion = self.schedulers = None
    self.train_loader = self.validation_loader = None
    self.training_operator = None

    self.serialize_data_creation = serialize_data_creation
    self.use_tqdm = use_tqdm
    self.scheduler_step_freq = scheduler_step_freq