def __init__(self, _run, lr, loss_scale):
    """Assemble the full training setup: data, model, optimizer, loss, engines.

    Args:
        _run: experiment run handle (captured by the framework; presumably a
            sacred ``_run`` object — not referenced directly here).
        lr: learning rate passed to the Adam optimizer.
        loss_scale: loss scaling factor (not referenced directly here —
            TODO confirm it is consumed elsewhere via config capture).
    """
    # Logging first so everything below can emit INFO messages.
    logging.basicConfig(level=logging.INFO)
    self.logger = logging.getLogger(__name__)
    self.compute_metrics = False

    # Output location and the dataset everything else derives from.
    self.output_dir = get_output_dir()
    self.dataset = get_dataset()

    # Per-layer mass weights reshaped to broadcast over (z, y, x)-style
    # fields, plus the vertical coordinate, both as float tensors.
    layer_mass = self.dataset.layer_mass.values
    self.mass = torch.tensor(layer_mass).view(-1, 1, 1).float()
    self.z = torch.tensor(self.dataset.z.values).float()
    self.time_step = get_timestep(self.dataset)

    # Train and test loaders are built over the same dataset.
    self.train_loader = get_data_loader(self.dataset, train=True)
    self.test_loader = get_data_loader(self.dataset, train=False)

    # Model, optimizer, and a mass-weighted MSE criterion (weights
    # normalized by their mean, applied along dim=-3).
    self.model = get_model(*get_pre_post(self.dataset, self.train_loader))
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    self.criterion = weighted_mean_squared_error(
        weights=self.mass / self.mass.mean(), dim=-3)

    self.plot_manager = TrainingPlotManager(ex, self.model, self.dataset)
    self.setup_validation_engine()
    self.setup_engine()
def __init__(self, _run, lr, loss_scale):
    """Set up logging, data, model, optimizer, and the training engine.

    Args:
        _run: experiment run handle (not referenced directly here).
        lr: learning rate for the Adam optimizer.
        loss_scale: loss scaling factor (not referenced directly here —
            TODO confirm it is consumed elsewhere via config capture).
    """
    # setup logging
    logging.basicConfig(level=logging.INFO)
    self.logger = logging.getLogger(__name__)
    # NOTE: removed commented-out MongoDBLogger/Experiment lines that
    # embedded a hard-coded API key; credentials must not live in source.

    # get output directory and dataset
    self.output_dir = get_output_dir()
    self.dataset = get_dataset()

    # Per-layer mass weights reshaped to broadcast over 3-D fields,
    # plus the vertical coordinate, both as float tensors.
    self.mass = torch.tensor(self.dataset.layer_mass.values).view(
        -1, 1, 1).float()
    self.z = torch.tensor(self.dataset.z.values).float()
    self.time_step = get_timestep(self.dataset)

    self.train_loader = get_data_loader(self.dataset)
    self.model = get_model(*get_pre_post(self.dataset))
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    # MSE weighted by mean-normalized layer mass, applied along dim=-3.
    self.criterion = weighted_mean_squared_error(
        weights=self.mass / self.mass.mean(), dim=-3)
    self.plot_manager = TrainingPlotManager(ex, self.model, self.dataset)
    self.setup_engine()
def __init__(self, _run, lr, loss_scale, train_data, test_data,
             lr_decay_rate=None, lr_step_size=5):
    """Build the trainer from separate train/test data sources.

    Args:
        _run: experiment run handle (not referenced directly here).
        lr: initial learning rate for the Adam optimizer.
        loss_scale: loss scaling factor (not referenced directly here —
            TODO confirm it is consumed elsewhere via config capture).
        train_data: argument forwarded to ``get_dataset`` for training data.
        test_data: argument forwarded to ``get_dataset`` for test data.
        lr_decay_rate: multiplicative LR decay factor; ``None`` disables
            the scheduler entirely.
        lr_step_size: number of steps between LR decays (default 5).
    """
    # Logging setup.
    logging.basicConfig(level=logging.INFO)
    self.compute_metrics = False
    self.logger = logging.getLogger(__name__)

    # Output directory plus the two datasets driving train and test.
    self.output_dir = get_output_dir()
    train_dataset = get_dataset(train_data)
    test_dataset = get_dataset(test_data)

    # Tensors derived from the training dataset: broadcastable layer-mass
    # weights, vertical coordinate, and the time step.
    self.mass = (torch.tensor(train_dataset.layer_mass.values)
                 .view(-1, 1, 1)
                 .float())
    self.z = torch.tensor(train_dataset.z.values).float()
    self.time_step = get_timestep(train_dataset)

    self.train_loader = get_data_loader(train_dataset)
    self.test_loader = get_data_loader(test_dataset)

    self.model = get_model(*get_pre_post(train_dataset, self.train_loader))
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    # Optional step-wise LR decay; disabled when no decay rate is given.
    self.lr_scheduler = (
        None if lr_decay_rate is None
        else StepLR(self.optimizer,
                    step_size=lr_step_size,
                    gamma=lr_decay_rate))

    # MSE weighted by mean-normalized layer mass, applied along dim=-3.
    self.criterion = weighted_mean_squared_error(
        weights=self.mass / self.mass.mean(), dim=-3)
    self.plot_manager = get_plot_manager(self.model)
    self.setup_validation_engine()
    self.setup_engine()
def train_pre_post(prepost):
    """Build the pre/post processing modules and save them to disk.

    Args:
        prepost: mapping with a ``'path'`` key naming the save destination.
    """
    dataset = get_dataset()
    # Lazy %-args so the message is only formatted when INFO is enabled.
    logging.info("Saving Pre/Post module to %s", prepost['path'])
    # NOTE(review): other call sites pass a second loader argument to
    # get_pre_post — confirm the single-argument form is intended here.
    torch.save(get_pre_post(dataset), prepost['path'])