def validate(self, split="val", epoch=None):
    print("### Evaluating on {}.".format(split))
    self.model.eval()

    meter = Meter(split=split)

    loader = self.val_loader if split == "val" else self.test_loader

    for i, batch in enumerate(loader):
        # Forward.
        out, metrics = self._forward(batch)
        loss = self._compute_loss(out, batch)

        # Update meter.
        meter_update_dict = {"loss": loss.item()}
        meter_update_dict.update(metrics)
        meter.update(meter_update_dict)

    # Make plots.
    if self.logger is not None and epoch is not None:
        log_dict = meter.get_scalar_dict()
        log_dict.update({"epoch": epoch + 1})
        self.logger.log(
            log_dict,
            step=(epoch + 1) * len(self.train_loader),
            split=split,
        )

    print(meter)
def load_extras(self):
    # learning rate scheduler.
    scheduler_lambda_fn = lambda x: warmup_lr_lambda(x, self.config["optim"])
    self.scheduler = optim.lr_scheduler.LambdaLR(
        self.optimizer, lr_lambda=scheduler_lambda_fn
    )

    # metrics.
    self.meter = Meter(split="train")
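# The `warmup_lr_lambda` helper used above lives elsewhere in the repo. A minimal
# sketch of such a multiplier function (linear warmup followed by step decay at
# `lr_milestones`) is given below; the config keys `warmup_steps`, `warmup_factor`,
# `lr_milestones`, and `lr_gamma` are assumptions about the optim config, not a
# verbatim copy of the repo implementation.
from bisect import bisect


def warmup_lr_lambda_sketch(current_step, optim_config):
    # Linearly ramp the LR multiplier from `warmup_factor` to 1 over `warmup_steps`.
    if current_step <= optim_config["warmup_steps"]:
        alpha = current_step / float(optim_config["warmup_steps"])
        return optim_config["warmup_factor"] * (1.0 - alpha) + alpha
    # Afterwards, multiply by `lr_gamma` once per milestone already passed.
    idx = bisect(optim_config["lr_milestones"], current_step)
    return pow(optim_config["lr_gamma"], idx)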
def load_extras(self):
    # learning rate scheduler.
    self.scheduler = optim.lr_scheduler.MultiStepLR(
        self.optimizer,
        milestones=self.config["optim"]["lr_milestones"],
        gamma=self.config["optim"]["lr_gamma"],
    )

    # metrics.
    self.meter = Meter(split="train")
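# Example (hypothetical) optim config consumed by the MultiStepLR variant above.
# The key names follow the lookups in the code (`lr_milestones`, `lr_gamma`);
# the values are illustrative only.
example_optim_config = {
    "lr_initial": 1.0e-3,
    "lr_milestones": [100, 150],  # epochs at which the LR is decayed
    "lr_gamma": 0.1,              # multiplicative decay factor
    "batch_size": 32,
    "max_epochs": 200,
}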
def validate_relaxation(self, split="val", epoch=None):
    print("### Evaluating ML-relaxation")
    self.model.eval()

    metrics = {}
    meter = Meter(split=split)

    mae_energy, mae_structure = relax_eval(
        trainer=self,
        traj_dir=self.config["task"]["relaxation_dir"],
        metric=self.config["task"]["metric"],
        steps=self.config["task"].get("relaxation_steps", 300),
        fmax=self.config["task"].get("relaxation_fmax", 0.01),
        results_dir=self.config["cmd"]["results_dir"],
    )

    metrics["relaxed_energy/{}".format(self.config["task"]["metric"])] = mae_energy
    metrics["relaxed_structure/{}".format(self.config["task"]["metric"])] = mae_structure

    meter.update(metrics)

    # Make plots.
    if self.logger is not None and epoch is not None:
        log_dict = meter.get_scalar_dict()
        log_dict.update({"epoch": epoch + 1})
        self.logger.log(
            log_dict,
            step=(epoch + 1) * len(self.train_loader),
            split=split,
        )

    print(meter)

    return mae_energy, mae_structure
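# Example (hypothetical) task config consumed by `validate_relaxation` above.
# Key names mirror the lookups in the code; the path and values are placeholders.
example_relaxation_task = {
    "relaxation_dir": "path/to/relaxation_trajectories",
    "metric": "mae",
    "relaxation_steps": 300,  # falls back to 300 if omitted
    "relaxation_fmax": 0.01,  # falls back to 0.01 if omitted
}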
def validate(self, split="val", epoch=None):
    print("### Evaluating on {}.".format(split))
    self.model.eval()

    meter = Meter(split=split)

    loader = self.val_loader if split == "val" else self.test_loader

    for i, batch in enumerate(loader):
        batch = batch.to(self.device)

        # Forward.
        out, metrics = self._forward(batch)
        loss = self._compute_loss(out, batch)

        # Update meter.
        meter_update_dict = {"loss": loss.item()}
        meter_update_dict.update(metrics)
        meter.update(meter_update_dict)

    # Make plots.
    if self.logger is not None and epoch is not None:
        log_dict = meter.get_scalar_dict()
        log_dict.update({"epoch": epoch + 1})
        self.logger.log(
            log_dict,
            step=(epoch + 1) * len(self.train_loader),
            split=split,
        )

    print(meter)

    return (
        float(meter.loss.global_avg),
        float(
            meter.meters[
                self.config["task"]["labels"][0]
                + "/"
                + self.config["task"]["metric"]
            ].global_avg
        ),
    )
class BaseTrainer:
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
        name="base_trainer",
    ):
        self.name = name
        if torch.cuda.is_available():
            self.device = local_rank
        else:
            self.device = "cpu"

        if run_dir is None:
            run_dir = os.getcwd()
        run_dir = Path(run_dir)

        timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(self.device)
        # create directories from master rank only
        distutils.broadcast(timestamp, 0)
        timestamp = datetime.datetime.fromtimestamp(timestamp).strftime(
            "%Y-%m-%d-%H-%M-%S"
        )
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "amp": amp,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": str(run_dir / "checkpoints" / timestamp),
                "results_dir": str(run_dir / "results" / timestamp),
                "logs_dir": str(run_dir / "logs" / logger / timestamp),
            },
        }

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master():
            os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

        self.is_debug = is_debug
        self.is_vis = is_vis

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))

        self.load()

        self.evaluator = Evaluator(task=name)

    def load(self):
        self.load_seed_from_config()
        self.load_logger()
        self.load_task()
        self.load_model()
        self.load_criterion()
        self.load_optimizer()
        self.load_extras()

    # Note: this function is now deprecated. We build config outside of trainer.
    # See build_config in ocpmodels.common.utils.py.
    def load_config_from_yaml_and_cmd(self, args):
        self.config = build_config(args)

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if self.config["amp"] else None

        # device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Are we just running sanity checks?
        self.is_debug = args.debug
        self.is_vis = args.vis

        # timestamps and directories
        args.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if args.identifier:
            args.timestamp += "-{}".format(args.identifier)

        args.checkpoint_dir = os.path.join("checkpoints", args.timestamp)
        args.results_dir = os.path.join("results", args.timestamp)
        args.logs_dir = os.path.join("logs", self.config["logger"], args.timestamp)

        print(yaml.dump(self.config, default_flow_style=False))
        for arg in vars(args):
            print("{:<20}: {}".format(arg, getattr(args, arg)))

        # TODO(abhshkdz): Handle these parameters better. Maybe move to yaml.
self.config["cmd"] = args.__dict__ del args if not self.is_debug: os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) # Dump config parameters json.dump( self.config, open( os.path.join(self.config["cmd"]["checkpoint_dir"], "config.json"), "w", ), ) def load_seed_from_config(self): # https://pytorch.org/docs/stable/notes/randomness.html seed = self.config["cmd"]["seed"] if seed is None: return random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def load_logger(self): self.logger = None if not self.is_debug and distutils.is_master(): assert (self.config["logger"] is not None), "Specify logger in config" self.logger = registry.get_logger_class(self.config["logger"])( self.config) def load_task(self): print("### Loading dataset: {}".format(self.config["task"]["dataset"])) dataset = registry.get_dataset_class(self.config["task"]["dataset"])( self.config["dataset"]) if self.config["task"]["dataset"] in ["qm9", "dogss"]: num_targets = dataset.data.y.shape[-1] if ("label_index" in self.config["task"] and self.config["task"]["label_index"] is not False): dataset.data.y = dataset.data.y[:, int(self.config["task"] ["label_index"])] num_targets = 1 else: num_targets = 1 self.num_targets = num_targets ( self.train_loader, self.val_loader, self.test_loader, ) = dataset.get_dataloaders( batch_size=int(self.config["optim"]["batch_size"])) # Normalizer for the dataset. # Compute mean, std of training set labels. self.normalizers = {} if self.config["dataset"].get("normalize_labels", True): self.normalizers["target"] = Normalizer( self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], self.device, ) # If we're computing gradients wrt input, set mean of normalizer to 0 -- # since it is lost when compute dy / dx -- and std to forward target std if "grad_input" in self.config["task"]: if self.config["dataset"].get("normalize_labels", True): self.normalizers["grad_target"] = Normalizer( self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], self.device, ) self.normalizers["grad_target"].mean.fill_(0) if self.is_vis and self.config["task"]["dataset"] != "qm9": # Plot label distribution. plots = [ plot_histogram( self.train_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: train", ), plot_histogram( self.val_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: val", ), plot_histogram( self.test_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: test", ), ] self.logger.log_plots(plots) def load_model(self): # Build model if distutils.is_master(): print("### Loading model: {}".format(self.config["model"])) # TODO(abhshkdz): Eventually move towards computing features on-the-fly # and remove dependence from `.edge_attr`. 
        bond_feat_dim = None
        if self.config["task"]["dataset"] in [
            "trajectory_lmdb",
            "single_point_lmdb",
        ]:
            bond_feat_dim = self.config["model_attributes"].get("num_gaussians", 50)
        else:
            raise NotImplementedError

        self.model = registry.get_model_class(self.config["model"])(
            self.train_loader.dataset[0].x.shape[-1]
            if hasattr(self.train_loader.dataset[0], "x")
            and self.train_loader.dataset[0].x is not None
            else None,
            bond_feat_dim,
            self.num_targets,
            **self.config["model_attributes"],
        ).to(self.device)

        if distutils.is_master():
            print(
                "### Loaded {} with {} parameters.".format(
                    self.model.__class__.__name__, self.model.num_params
                )
            )

        if self.logger is not None:
            self.logger.watch(self.model)

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1,
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(
                self.model, device_ids=[self.device]
            )

    def load_pretrained(self, checkpoint_path=None, ddp_to_dp=False):
        if checkpoint_path is None or os.path.isfile(checkpoint_path) is False:
            print(f"Checkpoint: {checkpoint_path} not found!")
            return False

        print("### Loading checkpoint from: {}".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)

        # Load model, optimizer, normalizer state dict.
        # if trained with ddp and want to load in non-ddp, modify keys from
        # module.module.. -> module..
        if ddp_to_dp:
            new_dict = OrderedDict()
            for k, v in checkpoint["state_dict"].items():
                name = k[7:]
                new_dict[name] = v
            self.model.load_state_dict(new_dict)
        else:
            self.model.load_state_dict(checkpoint["state_dict"])

        self.optimizer.load_state_dict(checkpoint["optimizer"])

        for key in checkpoint["normalizers"]:
            if key in self.normalizers:
                self.normalizers[key].load_state_dict(
                    checkpoint["normalizers"][key]
                )

        if self.scaler and checkpoint["amp"]:
            self.scaler.load_state_dict(checkpoint["amp"])

        return True

    # TODO(abhshkdz): Rename function to something nicer.
    # TODO(abhshkdz): Support multiple loss functions.
    def load_criterion(self):
        self.criterion = self.config["optim"].get("criterion", nn.L1Loss())

    def load_optimizer(self):
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            self.config["optim"]["lr_initial"],
            # weight_decay=3.0
        )

    def load_extras(self):
        # learning rate scheduler.
        scheduler_lambda_fn = lambda x: warmup_lr_lambda(x, self.config["optim"])
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=scheduler_lambda_fn
        )

        # metrics.
        self.meter = Meter(split="train")

    def save(self, epoch, metrics):
        if not self.is_debug and distutils.is_master():
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "normalizers": {
                        key: value.state_dict()
                        for key, value in self.normalizers.items()
                    },
                    "config": self.config,
                    "val_metrics": metrics,
                    "amp": self.scaler.state_dict() if self.scaler else None,
                },
                self.config["cmd"]["checkpoint_dir"],
            )

    def train(self, max_epochs=None, return_metrics=False):
        # TODO(abhshkdz): Timers for dataloading and forward pass.
        num_epochs = (
            max_epochs
            if max_epochs is not None
            else self.config["optim"]["max_epochs"]
        )
        for epoch in range(num_epochs):
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                batch = batch.to(self.device)

                # Forward, loss, backward.
                out, metrics = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Update meter.
                meter_update_dict = {
                    "epoch": epoch + (i + 1) / len(self.train_loader),
                    "loss": loss.item(),
                }
                meter_update_dict.update(metrics)
                self.meter.update(meter_update_dict)

                # Make plots.
                if self.logger is not None:
                    self.logger.log(
                        meter_update_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                # Print metrics.
                if i % self.config["cmd"]["print_every"] == 0:
                    print(self.meter)

            self.scheduler.step()

            with torch.no_grad():
                if self.val_loader is not None:
                    v_loss, v_mae = self.validate(split="val", epoch=epoch)

                if self.test_loader is not None:
                    test_loss, test_mae = self.validate(split="test", epoch=epoch)

            if not self.is_debug:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                        "amp": self.scaler.state_dict() if self.scaler else None,
                    },
                    self.config["cmd"]["checkpoint_dir"],
                )

        if return_metrics:
            return {
                "training_loss": float(self.meter.loss.global_avg),
                "training_mae": float(
                    self.meter.meters[
                        self.config["task"]["labels"][0]
                        + "/"
                        + self.config["task"]["metric"]
                    ].global_avg
                ),
                "validation_loss": v_loss,
                "validation_mae": v_mae,
                "test_loss": test_loss,
                "test_mae": test_mae,
            }

    def validate(self, split="val", epoch=None):
        if distutils.is_master():
            print("### Evaluating on {}.".format(split))

        self.model.eval()
        evaluator, metrics = Evaluator(task=self.name), {}
        rank = distutils.get_rank()

        loader = self.val_loader if split == "val" else self.test_loader
        for i, batch in tqdm(
            enumerate(loader),
            total=len(loader),
            position=rank,
            desc="device {}".format(rank),
        ):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
            loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(
                    metrics[k]["total"], average=False, device=self.device
                ),
                "numel": distutils.all_reduce(
                    metrics[k]["numel"], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
            )
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        if distutils.is_master():
            log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
            print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def _forward(self, batch, compute_metrics=True):
        out = {}

        # enable gradient wrt input.
        if "grad_input" in self.config["task"]:
            inp_for_grad = batch.pos
            batch.pos = batch.pos.requires_grad_(True)

        # forward pass.
if self.config["model_attributes"].get("regress_forces", False): output, output_forces = self.model(batch) else: output = self.model(batch) if output.shape[-1] == 1: output = output.view(-1) out["output"] = output force_output = None if self.config["model_attributes"].get("regress_forces", False): out["force_output"] = output_forces force_output = output_forces if ("grad_input" in self.config["task"] and self.config["model_attributes"].get( "regress_forces", False) is False): force_output = -1 * torch.autograd.grad( output, inp_for_grad, grad_outputs=torch.ones_like(output), create_graph=True, retain_graph=True, )[0] out["force_output"] = force_output if not compute_metrics: return out, None metrics = {} if self.config["dataset"].get("normalize_labels", True): errors = eval(self.config["task"]["metric"])( self.normalizers["target"].denorm(output).cpu(), batch.y.cpu()).view(-1) else: errors = eval(self.config["task"]["metric"])( output.cpu(), batch.y.cpu()).view(-1) if ("label_index" in self.config["task"] and self.config["task"]["label_index"] is not False): # TODO(abhshkdz): Get rid of this edge case for QM9. # This is only because QM9 has multiple targets and we can either # jointly predict all of them or one particular target. metrics["{}/{}".format( self.config["task"]["labels"][self.config["task"] ["label_index"]], self.config["task"]["metric"], )] = errors[0] else: for i, label in enumerate(self.config["task"]["labels"]): metrics["{}/{}".format( label, self.config["task"]["metric"])] = errors[i] if "grad_input" in self.config["task"]: force_pred = force_output force_target = batch.force if self.config["task"].get("eval_on_free_atoms", True): mask = batch.fixed == 0 force_pred = force_pred[mask] force_target = force_target[mask] if self.config["dataset"].get("normalize_labels", True): grad_input_errors = eval(self.config["task"]["metric"])( self.normalizers["grad_target"].denorm(force_pred).cpu(), force_target.cpu(), ) else: grad_input_errors = eval(self.config["task"]["metric"])( force_pred.cpu(), force_target.cpu()) metrics["force_x/{}".format( self.config["task"]["metric"])] = grad_input_errors[0] metrics["force_y/{}".format( self.config["task"]["metric"])] = grad_input_errors[1] metrics["force_z/{}".format( self.config["task"]["metric"])] = grad_input_errors[2] return out, metrics def _compute_loss(self, out, batch): loss = [] if self.config["dataset"].get("normalize_labels", True): target_normed = self.normalizers["target"].norm(batch.y) else: target_normed = batch.y loss.append(self.criterion(out["output"], target_normed)) # TODO(abhshkdz): Test support for gradients wrt input. # TODO(abhshkdz): Make this general; remove dependence on `.forces`. if "grad_input" in self.config["task"]: if self.config["dataset"].get("normalize_labels", True): grad_target_normed = self.normalizers["grad_target"].norm( batch.force) else: grad_target_normed = batch.force # Force coefficient = 30 has been working well for us. force_mult = self.config["optim"].get("force_coefficient", 30) if self.config["task"].get("train_on_free_atoms", False): mask = batch.fixed == 0 loss.append(force_mult * self.criterion( out["force_output"][mask], grad_target_normed[mask])) else: loss.append( force_mult * self.criterion(out["force_output"], grad_target_normed)) # Sanity check to make sure the compute graph is correct. for lc in loss: assert hasattr(lc, "grad_fn") loss = sum(loss) return loss def _backward(self, loss): self.optimizer.zero_grad() loss.backward() # TODO(abhshkdz): Add support for gradient clipping. 
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

    def save_results(self, predictions, results_file, keys):
        if results_file is None:
            return

        results_file_path = os.path.join(
            self.config["cmd"]["results_dir"],
            f"{self.name}_{results_file}_{distutils.get_rank()}.npz",
        )
        np.savez_compressed(
            results_file_path,
            ids=predictions["id"],
            **{key: predictions[key] for key in keys},
        )

        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                f"{self.name}_{results_file}.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"{self.name}_{results_file}_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                for key in keys:
                    gather_results[key].extend(rank_results[key])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            for k in keys:
                if k == "forces":
                    gather_results[k] = np.array(
                        gather_results[k], dtype=object
                    )[idx]
                else:
                    gather_results[k] = np.array(gather_results[k])[idx]

            print(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)
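# Illustrative (hypothetical) input to `save_results` above: per-rank predictions
# keyed by sample id. The "forces" key matches the object-dtype special case in
# the merging logic; "energy" and the concrete values are made up for the example.
import numpy as np

example_predictions = {
    "id": ["0_0", "0_1"],
    "energy": [-1.23, -4.56],
    "forces": [np.zeros((3, 3)), np.zeros((5, 3))],
}
# trainer.save_results(example_predictions, results_file="example", keys=["energy", "forces"])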
def load_extras(self):
    self.scheduler = LRScheduler(self.optimizer, self.config["optim"])

    # metrics.
    self.meter = Meter(split="train")
class BaseTrainer:
    def __init__(self, args=None, local_rank=0):
        # defaults.
        self.device = "cpu"
        self.is_debug = True
        self.is_vis = True

        # load config.
        if args is not None:
            self.load_config_from_yaml_and_cmd(args)

    def load(self):
        self.load_seed_from_config()
        self.load_logger()
        self.load_task()
        self.load_model()
        self.load_criterion()
        self.load_optimizer()
        self.load_extras()

    # Note: this function is now deprecated. We build config outside of trainer.
    # See build_config in ocpmodels.common.utils.py.
    def load_config_from_yaml_and_cmd(self, args):
        self.config = build_config(args)

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if self.config["amp"] else None

        # device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Are we just running sanity checks?
        self.is_debug = args.debug
        self.is_vis = args.vis

        # timestamps and directories
        args.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if args.identifier:
            args.timestamp += "-{}".format(args.identifier)

        args.checkpoint_dir = os.path.join("checkpoints", args.timestamp)
        args.results_dir = os.path.join("results", args.timestamp)
        args.logs_dir = os.path.join("logs", self.config["logger"], args.timestamp)

        print(yaml.dump(self.config, default_flow_style=False))
        for arg in vars(args):
            print("{:<20}: {}".format(arg, getattr(args, arg)))

        # TODO(abhshkdz): Handle these parameters better. Maybe move to yaml.
        self.config["cmd"] = args.__dict__
        del args

        if not self.is_debug:
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

            # Dump config parameters
            json.dump(
                self.config,
                open(
                    os.path.join(
                        self.config["cmd"]["checkpoint_dir"], "config.json"
                    ),
                    "w",
                ),
            )

    def load_seed_from_config(self):
        # https://pytorch.org/docs/stable/notes/randomness.html
        seed = self.config["cmd"]["seed"]
        if seed is None:
            return

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def load_logger(self):
        self.logger = None
        if not self.is_debug and distutils.is_master():
            assert self.config["logger"] is not None, "Specify logger in config"
            self.logger = registry.get_logger_class(self.config["logger"])(
                self.config
            )

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))
        dataset = registry.get_dataset_class(self.config["task"]["dataset"])(
            self.config["dataset"]
        )

        if self.config["task"]["dataset"] in ["qm9", "dogss"]:
            num_targets = dataset.data.y.shape[-1]
            if (
                "label_index" in self.config["task"]
                and self.config["task"]["label_index"] is not False
            ):
                dataset.data.y = dataset.data.y[
                    :, int(self.config["task"]["label_index"])
                ]
                num_targets = 1
        else:
            num_targets = 1

        self.num_targets = num_targets
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
        ) = dataset.get_dataloaders(
            batch_size=int(self.config["optim"]["batch_size"])
        )

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", True):
            self.normalizers["target"] = Normalizer(
                self.train_loader.dataset.data.y[
                    self.train_loader.dataset.__indices__
                ],
                self.device,
            )

        # If we're computing gradients wrt input, set mean of normalizer to 0 --
        # since it is lost when compute dy / dx -- and std to forward target std
        if "grad_input" in self.config["task"]:
            if self.config["dataset"].get("normalize_labels", True):
                self.normalizers["grad_target"] = Normalizer(
                    self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__
                    ],
                    self.device,
                )
                self.normalizers["grad_target"].mean.fill_(0)

        if self.is_vis and self.config["task"]["dataset"] != "qm9":
            # Plot label distribution.
            plots = [
                plot_histogram(
                    self.train_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: train",
                ),
                plot_histogram(
                    self.val_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: val",
                ),
                plot_histogram(
                    self.test_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: test",
                ),
            ]
            self.logger.log_plots(plots)

    def load_model(self):
        # Build model
        print("### Loading model: {}".format(self.config["model"]))

        # TODO(abhshkdz): Eventually move towards computing features on-the-fly
        # and remove dependence from `.edge_attr`.
        bond_feat_dim = None
        if self.config["task"]["dataset"] in [
            "ulissigroup_co",
            "ulissigroup_h",
            "xie_grossman_mat_proj",
        ]:
            bond_feat_dim = self.train_loader.dataset[0].edge_attr.shape[-1]
        elif self.config["task"]["dataset"] in [
            "gasdb",
            "trajectory",
            "trajectory_lmdb",
            "single_point_lmdb",
        ]:
            bond_feat_dim = self.config["model_attributes"].get("num_gaussians", 50)
        else:
            raise NotImplementedError

        self.model = registry.get_model_class(self.config["model"])(
            self.train_loader.dataset[0].x.shape[-1]
            if hasattr(self.train_loader.dataset[0], "x")
            and self.train_loader.dataset[0].x is not None
            else None,
            bond_feat_dim,
            self.num_targets,
            **self.config["model_attributes"],
        ).to(self.device)

        print(
            "### Loaded {} with {} parameters.".format(
                self.model.__class__.__name__, self.model.num_params
            )
        )

        if self.logger is not None:
            self.logger.watch(self.model)

    def load_pretrained(self, checkpoint_path=None, ddp_to_dp=False):
        if checkpoint_path is None or os.path.isfile(checkpoint_path) is False:
            return False

        print("### Loading checkpoint from: {}".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)

        # Load model, optimizer, normalizer state dict.
        # if trained with ddp and want to load in non-ddp, modify keys from
        # module.module.. -> module..
        if ddp_to_dp:
            new_dict = OrderedDict()
            for k, v in checkpoint["state_dict"].items():
                name = k[7:]
                new_dict[name] = v
            self.model.load_state_dict(new_dict)
        else:
            self.model.load_state_dict(checkpoint["state_dict"])

        self.optimizer.load_state_dict(checkpoint["optimizer"])

        for key in checkpoint["normalizers"]:
            if key in self.normalizers:
                self.normalizers[key].load_state_dict(
                    checkpoint["normalizers"][key]
                )

        if self.scaler and checkpoint["amp"]:
            self.scaler.load_state_dict(checkpoint["amp"])

        return True

    # TODO(abhshkdz): Rename function to something nicer.
    # TODO(abhshkdz): Support multiple loss functions.
    def load_criterion(self):
        self.criterion = self.config["optim"].get("criterion", nn.L1Loss())

    def load_optimizer(self):
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            self.config["optim"]["lr_initial"],
            # weight_decay=3.0
        )

    def load_extras(self):
        # learning rate scheduler.
        scheduler_lambda_fn = lambda x: warmup_lr_lambda(x, self.config["optim"])
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=scheduler_lambda_fn
        )

        # metrics.
        self.meter = Meter(split="train")

    def train(self, max_epochs=None, return_metrics=False):
        # TODO(abhshkdz): Timers for dataloading and forward pass.
        num_epochs = (
            max_epochs
            if max_epochs is not None
            else self.config["optim"]["max_epochs"]
        )
        for epoch in range(num_epochs):
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                batch = batch.to(self.device)

                # Forward, loss, backward.
                out, metrics = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Update meter.
                meter_update_dict = {
                    "epoch": epoch + (i + 1) / len(self.train_loader),
                    "loss": loss.item(),
                }
                meter_update_dict.update(metrics)
                self.meter.update(meter_update_dict)

                # Make plots.
                if self.logger is not None:
                    self.logger.log(
                        meter_update_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                # Print metrics.
                if i % self.config["cmd"]["print_every"] == 0:
                    print(self.meter)

            self.scheduler.step()

            with torch.no_grad():
                if self.val_loader is not None:
                    v_loss, v_mae = self.validate(split="val", epoch=epoch)

                if self.test_loader is not None:
                    test_loss, test_mae = self.validate(split="test", epoch=epoch)

            if not self.is_debug:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                        "amp": self.scaler.state_dict() if self.scaler else None,
                    },
                    self.config["cmd"]["checkpoint_dir"],
                )

        if return_metrics:
            return {
                "training_loss": float(self.meter.loss.global_avg),
                "training_mae": float(
                    self.meter.meters[
                        self.config["task"]["labels"][0]
                        + "/"
                        + self.config["task"]["metric"]
                    ].global_avg
                ),
                "validation_loss": v_loss,
                "validation_mae": v_mae,
                "test_loss": test_loss,
                "test_mae": test_mae,
            }

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))
        self.model.eval()

        meter = Meter(split=split)

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in enumerate(loader):
            batch = batch.to(self.device)

            # Forward.
            out, metrics = self._forward(batch)
            loss = self._compute_loss(out, batch)

            # Update meter.
            meter_update_dict = {"loss": loss.item()}
            meter_update_dict.update(metrics)
            meter.update(meter_update_dict)

        # Make plots.
        if self.logger is not None and epoch is not None:
            log_dict = meter.get_scalar_dict()
            log_dict.update({"epoch": epoch + 1})
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        print(meter)

        return (
            float(meter.loss.global_avg),
            float(
                meter.meters[
                    self.config["task"]["labels"][0]
                    + "/"
                    + self.config["task"]["metric"]
                ].global_avg
            ),
        )

    def _forward(self, batch, compute_metrics=True):
        out = {}

        # enable gradient wrt input.
        if "grad_input" in self.config["task"]:
            inp_for_grad = batch.pos
            batch.pos = batch.pos.requires_grad_(True)

        # forward pass.
if self.config["model_attributes"].get("regress_forces", False): output, output_forces = self.model(batch) else: output = self.model(batch) if output.shape[-1] == 1: output = output.view(-1) out["output"] = output force_output = None if self.config["model_attributes"].get("regress_forces", False): out["force_output"] = output_forces force_output = output_forces if ("grad_input" in self.config["task"] and self.config["model_attributes"].get( "regress_forces", False) is False): force_output = (-1 * torch.autograd.grad( output, inp_for_grad, grad_outputs=torch.ones_like(output), create_graph=True, retain_graph=True, )[0]) out["force_output"] = force_output if not compute_metrics: return out, None metrics = {} if self.config["dataset"].get("normalize_labels", True): errors = eval(self.config["task"]["metric"])( self.normalizers["target"].denorm(output).cpu(), batch.y.cpu()).view(-1) else: errors = eval(self.config["task"]["metric"])( output.cpu(), batch.y.cpu()).view(-1) if ("label_index" in self.config["task"] and self.config["task"]["label_index"] is not False): # TODO(abhshkdz): Get rid of this edge case for QM9. # This is only because QM9 has multiple targets and we can either # jointly predict all of them or one particular target. metrics["{}/{}".format( self.config["task"]["labels"][self.config["task"] ["label_index"]], self.config["task"]["metric"], )] = errors[0] else: for i, label in enumerate(self.config["task"]["labels"]): metrics["{}/{}".format( label, self.config["task"]["metric"])] = errors[i] if "grad_input" in self.config["task"]: force_pred = force_output force_target = batch.force if self.config["task"].get("eval_on_free_atoms", True): mask = batch.fixed == 0 force_pred = force_pred[mask] force_target = force_target[mask] if self.config["dataset"].get("normalize_labels", True): grad_input_errors = eval(self.config["task"]["metric"])( self.normalizers["grad_target"].denorm(force_pred).cpu(), force_target.cpu(), ) else: grad_input_errors = eval(self.config["task"]["metric"])( force_pred.cpu(), force_target.cpu()) metrics["force_x/{}".format( self.config["task"]["metric"])] = grad_input_errors[0] metrics["force_y/{}".format( self.config["task"]["metric"])] = grad_input_errors[1] metrics["force_z/{}".format( self.config["task"]["metric"])] = grad_input_errors[2] return out, metrics def _compute_loss(self, out, batch): loss = [] if self.config["dataset"].get("normalize_labels", True): target_normed = self.normalizers["target"].norm(batch.y) else: target_normed = batch.y loss.append(self.criterion(out["output"], target_normed)) # TODO(abhshkdz): Test support for gradients wrt input. # TODO(abhshkdz): Make this general; remove dependence on `.forces`. if "grad_input" in self.config["task"]: if self.config["dataset"].get("normalize_labels", True): grad_target_normed = self.normalizers["grad_target"].norm( batch.force) else: grad_target_normed = batch.force # Force coefficient = 30 has been working well for us. force_mult = self.config["optim"].get("force_coefficient", 30) if self.config["task"].get("train_on_free_atoms", False): mask = batch.fixed == 0 loss.append(force_mult * self.criterion( out["force_output"][mask], grad_target_normed[mask])) else: loss.append( force_mult * self.criterion(out["force_output"], grad_target_normed)) # Sanity check to make sure the compute graph is correct. for lc in loss: assert hasattr(lc, "grad_fn") loss = sum(loss) return loss def _backward(self, loss): self.optimizer.zero_grad() loss.backward() # TODO(abhshkdz): Add support for gradient clipping. 
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
class DOGSSTrainer(BaseTrainer):
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="wandb",
    ):
        if run_dir is None:
            run_dir = os.getcwd()

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "dataset": dataset,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", timestamp),
            },
        }

        os.makedirs(self.config["cmd"]["checkpoint_dir"])
        os.makedirs(self.config["cmd"]["results_dir"])
        os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.load()
        print(yaml.dump(self.config, default_flow_style=False))

        initial_train_loss = self.get_initial_loss(self.train_loader)
        initial_val_loss = self.get_initial_loss(self.val_loader)
        initial_test_loss = self.get_initial_loss(self.test_loader)
        print(
            "### initial train loss: %f\n" % initial_train_loss,
            "### initial val loss: %f\n" % initial_val_loss,
            "### initial test loss: %f\n" % initial_test_loss,
        )

    def load_criterion(self):
        self.criterion = mean_l2_distance

    def load_optimizer(self):
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            self.config["optim"]["lr_initial"],
            weight_decay=self.config["optim"]["weight_decay"],
        )

    def load_extras(self):
        # learning rate scheduler.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=self.config["optim"]["lr_milestones"],
            gamma=self.config["optim"]["lr_gamma"],
        )

        # metrics.
        self.meter = Meter(split="train")

    def train(self):
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                batch = batch.to(self.device)

                # Forward, loss, backward.
                out, metrics = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Update meter.
                meter_update_dict = {
                    "epoch": epoch + (i + 1) / len(self.train_loader),
                    "loss": loss.item(),
                }
                meter_update_dict.update(metrics)
                self.meter.update(meter_update_dict)

                # Make plots.
                if self.logger is not None:
                    self.logger.log(
                        meter_update_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                # Print metrics.
                if i % self.config["cmd"]["print_every"] == 0:
                    print(self.meter)

            self.scheduler.step()

            if self.val_loader is not None:
                self.validate(split="val", epoch=epoch)

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

            if not self.is_debug:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                    },
                    self.config["cmd"]["checkpoint_dir"],
                )

    def get_initial_loss(self, dataset):
        distances = []
        for data in dataset:
            free_atom_idx = np.where(data.fixed_base.cpu() == 0)[0]
            atom_pos = data.atom_pos[free_atom_idx]
            y = data.y
            dist = torch.sqrt(torch.sum((atom_pos - y) ** 2, dim=1))
            distances.append(dist)
        mae = torch.mean(torch.cat(distances))
        return mae
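# `mean_l2_distance`, used as the criterion in `load_criterion` above, is defined
# elsewhere in the repo. A minimal sketch consistent with `get_initial_loss`,
# which averages per-atom Euclidean distances between predicted and target
# positions; this is an assumption about its behaviour, not a verbatim copy.
import torch


def mean_l2_distance_sketch(pred_pos, target_pos):
    # Per-atom Euclidean distance, averaged over all atoms in the batch.
    return torch.sqrt(torch.sum((pred_pos - target_pos) ** 2, dim=1)).mean()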