class BaseTrainer: def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, name="base_trainer", ): self.name = name if torch.cuda.is_available(): self.device = local_rank else: self.device = "cpu" if run_dir is None: run_dir = os.getcwd() run_dir = Path(run_dir) timestamp = torch.tensor(datetime.datetime.now().timestamp()).to( self.device) # create directories from master rank only distutils.broadcast(timestamp, 0) timestamp = datetime.datetime.fromtimestamp(timestamp).strftime( "%Y-%m-%d-%H-%M-%S") if identifier: timestamp += "-{}".format(identifier) self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp": timestamp, "checkpoint_dir": str(run_dir / "checkpoints" / timestamp), "results_dir": str(run_dir / "results" / timestamp), "logs_dir": str(run_dir / "logs" / logger / timestamp), }, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if isinstance(dataset, list): self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] if len(dataset) > 2: self.config["test_dataset"] = dataset[2] else: self.config["dataset"] = dataset if not is_debug and distutils.is_master(): os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) self.is_debug = is_debug self.is_vis = is_vis if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = Evaluator(task=name) def load(self): self.load_seed_from_config() self.load_logger() self.load_task() self.load_model() self.load_criterion() self.load_optimizer() self.load_extras() # Note: this function is now deprecated. We build config outside of trainer. # See build_config in ocpmodels.common.utils.py. def load_config_from_yaml_and_cmd(self, args): self.config = build_config(args) # AMP Scaler self.scaler = (torch.cuda.amp.GradScaler() if self.config["amp"] else None) # device self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu") # Are we just running sanity checks? self.is_debug = args.debug self.is_vis = args.vis # timestamps and directories args.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") if args.identifier: args.timestamp += "-{}".format(args.identifier) args.checkpoint_dir = os.path.join("checkpoints", args.timestamp) args.results_dir = os.path.join("results", args.timestamp) args.logs_dir = os.path.join("logs", self.config["logger"], args.timestamp) print(yaml.dump(self.config, default_flow_style=False)) for arg in vars(args): print("{:<20}: {}".format(arg, getattr(args, arg))) # TODO(abhshkdz): Handle these parameters better. Maybe move to yaml. 
self.config["cmd"] = args.__dict__ del args if not self.is_debug: os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) # Dump config parameters json.dump( self.config, open( os.path.join(self.config["cmd"]["checkpoint_dir"], "config.json"), "w", ), ) def load_seed_from_config(self): # https://pytorch.org/docs/stable/notes/randomness.html seed = self.config["cmd"]["seed"] if seed is None: return random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def load_logger(self): self.logger = None if not self.is_debug and distutils.is_master(): assert (self.config["logger"] is not None), "Specify logger in config" self.logger = registry.get_logger_class(self.config["logger"])( self.config) def load_task(self): print("### Loading dataset: {}".format(self.config["task"]["dataset"])) dataset = registry.get_dataset_class(self.config["task"]["dataset"])( self.config["dataset"]) if self.config["task"]["dataset"] in ["qm9", "dogss"]: num_targets = dataset.data.y.shape[-1] if ("label_index" in self.config["task"] and self.config["task"]["label_index"] is not False): dataset.data.y = dataset.data.y[:, int(self.config["task"] ["label_index"])] num_targets = 1 else: num_targets = 1 self.num_targets = num_targets ( self.train_loader, self.val_loader, self.test_loader, ) = dataset.get_dataloaders( batch_size=int(self.config["optim"]["batch_size"])) # Normalizer for the dataset. # Compute mean, std of training set labels. self.normalizers = {} if self.config["dataset"].get("normalize_labels", True): self.normalizers["target"] = Normalizer( self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], self.device, ) # If we're computing gradients wrt input, set mean of normalizer to 0 -- # since it is lost when compute dy / dx -- and std to forward target std if "grad_input" in self.config["task"]: if self.config["dataset"].get("normalize_labels", True): self.normalizers["grad_target"] = Normalizer( self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], self.device, ) self.normalizers["grad_target"].mean.fill_(0) if self.is_vis and self.config["task"]["dataset"] != "qm9": # Plot label distribution. plots = [ plot_histogram( self.train_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: train", ), plot_histogram( self.val_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: val", ), plot_histogram( self.test_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: test", ), ] self.logger.log_plots(plots) def load_model(self): # Build model if distutils.is_master(): print("### Loading model: {}".format(self.config["model"])) # TODO(abhshkdz): Eventually move towards computing features on-the-fly # and remove dependence from `.edge_attr`. 
bond_feat_dim = None if self.config["task"]["dataset"] in [ "trajectory_lmdb", "single_point_lmdb", ]: bond_feat_dim = self.config["model_attributes"].get( "num_gaussians", 50) else: raise NotImplementedError self.model = registry.get_model_class(self.config["model"])( self.train_loader.dataset[0].x.shape[-1] if hasattr(self.train_loader.dataset[0], "x") and self.train_loader.dataset[0].x is not None else None, bond_feat_dim, self.num_targets, **self.config["model_attributes"], ).to(self.device) if distutils.is_master(): print("### Loaded {} with {} parameters.".format( self.model.__class__.__name__, self.model.num_params)) if self.logger is not None: self.logger.watch(self.model) self.model = OCPDataParallel( self.model, output_device=self.device, num_gpus=1, ) if distutils.initialized(): self.model = DistributedDataParallel(self.model, device_ids=[self.device]) def load_pretrained(self, checkpoint_path=None, ddp_to_dp=False): if checkpoint_path is None or os.path.isfile(checkpoint_path) is False: print(f"Checkpoint: {checkpoint_path} not found!") return False print("### Loading checkpoint from: {}".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) # Load model, optimizer, normalizer state dict. # if trained with ddp and want to load in non-ddp, modify keys from # module.module.. -> module.. if ddp_to_dp: new_dict = OrderedDict() for k, v in checkpoint["state_dict"].items(): name = k[7:] new_dict[name] = v self.model.load_state_dict(new_dict) else: self.model.load_state_dict(checkpoint["state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer"]) for key in checkpoint["normalizers"]: if key in self.normalizers: self.normalizers[key].load_state_dict( checkpoint["normalizers"][key]) if self.scaler and checkpoint["amp"]: self.scaler.load_state_dict(checkpoint["amp"]) return True # TODO(abhshkdz): Rename function to something nicer. # TODO(abhshkdz): Support multiple loss functions. def load_criterion(self): self.criterion = self.config["optim"].get("criterion", nn.L1Loss()) def load_optimizer(self): self.optimizer = optim.AdamW( self.model.parameters(), self.config["optim"]["lr_initial"], # weight_decay=3.0 ) def load_extras(self): # learning rate scheduler. scheduler_lambda_fn = lambda x: warmup_lr_lambda( x, self.config["optim"]) self.scheduler = optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=scheduler_lambda_fn) # metrics. self.meter = Meter(split="train") def save(self, epoch, metrics): if not self.is_debug and distutils.is_master(): save_checkpoint( { "epoch": epoch, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": metrics, "amp": self.scaler.state_dict() if self.scaler else None, }, self.config["cmd"]["checkpoint_dir"], ) def train(self, max_epochs=None, return_metrics=False): # TODO(abhshkdz): Timers for dataloading and forward pass. num_epochs = (max_epochs if max_epochs is not None else self.config["optim"]["max_epochs"]) for epoch in range(num_epochs): self.model.train() for i, batch in enumerate(self.train_loader): batch = batch.to(self.device) # Forward, loss, backward. out, metrics = self._forward(batch) loss = self._compute_loss(out, batch) self._backward(loss) # Update meter. meter_update_dict = { "epoch": epoch + (i + 1) / len(self.train_loader), "loss": loss.item(), } meter_update_dict.update(metrics) self.meter.update(meter_update_dict) # Make plots. 
if self.logger is not None: self.logger.log( meter_update_dict, step=epoch * len(self.train_loader) + i + 1, split="train", ) # Print metrics. if i % self.config["cmd"]["print_every"] == 0: print(self.meter) self.scheduler.step() with torch.no_grad(): if self.val_loader is not None: v_loss, v_mae = self.validate(split="val", epoch=epoch) if self.test_loader is not None: test_loss, test_mae = self.validate(split="test", epoch=epoch) if not self.is_debug: save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "amp": self.scaler.state_dict() if self.scaler else None, }, self.config["cmd"]["checkpoint_dir"], ) if return_metrics: return { "training_loss": float(self.meter.loss.global_avg), "training_mae": float(self.meter.meters[ self.config["task"]["labels"][0] + "/" + self.config["task"]["metric"]].global_avg), "validation_loss": v_loss, "validation_mae": v_mae, "test_loss": test_loss, "test_mae": test_mae, } def validate(self, split="val", epoch=None): if distutils.is_master(): print("### Evaluating on {}.".format(split)) self.model.eval() evaluator, metrics = Evaluator(task=self.name), {} rank = distutils.get_rank() loader = self.val_loader if split == "val" else self.test_loader for i, batch in tqdm( enumerate(loader), total=len(loader), position=rank, desc="device {}".format(rank), ): # Forward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) # Compute metrics. metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce(metrics[k]["total"], average=False, device=self.device), "numel": distutils.all_reduce(metrics[k]["numel"], average=False, device=self.device), } aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]) metrics = aggregated_metrics log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": epoch + 1}) if distutils.is_master(): log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()] print(", ".join(log_str)) # Make plots. if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) return metrics def _forward(self, batch, compute_metrics=True): out = {} # enable gradient wrt input. if "grad_input" in self.config["task"]: inp_for_grad = batch.pos batch.pos = batch.pos.requires_grad_(True) # forward pass. 
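        # When force regression is enabled the model returns an (energy, forces)
        # pair; otherwise only the scalar target is predicted and forces, if
        # needed, are recovered below as gradients of the output wrt positions.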
if self.config["model_attributes"].get("regress_forces", False): output, output_forces = self.model(batch) else: output = self.model(batch) if output.shape[-1] == 1: output = output.view(-1) out["output"] = output force_output = None if self.config["model_attributes"].get("regress_forces", False): out["force_output"] = output_forces force_output = output_forces if ("grad_input" in self.config["task"] and self.config["model_attributes"].get( "regress_forces", False) is False): force_output = -1 * torch.autograd.grad( output, inp_for_grad, grad_outputs=torch.ones_like(output), create_graph=True, retain_graph=True, )[0] out["force_output"] = force_output if not compute_metrics: return out, None metrics = {} if self.config["dataset"].get("normalize_labels", True): errors = eval(self.config["task"]["metric"])( self.normalizers["target"].denorm(output).cpu(), batch.y.cpu()).view(-1) else: errors = eval(self.config["task"]["metric"])( output.cpu(), batch.y.cpu()).view(-1) if ("label_index" in self.config["task"] and self.config["task"]["label_index"] is not False): # TODO(abhshkdz): Get rid of this edge case for QM9. # This is only because QM9 has multiple targets and we can either # jointly predict all of them or one particular target. metrics["{}/{}".format( self.config["task"]["labels"][self.config["task"] ["label_index"]], self.config["task"]["metric"], )] = errors[0] else: for i, label in enumerate(self.config["task"]["labels"]): metrics["{}/{}".format( label, self.config["task"]["metric"])] = errors[i] if "grad_input" in self.config["task"]: force_pred = force_output force_target = batch.force if self.config["task"].get("eval_on_free_atoms", True): mask = batch.fixed == 0 force_pred = force_pred[mask] force_target = force_target[mask] if self.config["dataset"].get("normalize_labels", True): grad_input_errors = eval(self.config["task"]["metric"])( self.normalizers["grad_target"].denorm(force_pred).cpu(), force_target.cpu(), ) else: grad_input_errors = eval(self.config["task"]["metric"])( force_pred.cpu(), force_target.cpu()) metrics["force_x/{}".format( self.config["task"]["metric"])] = grad_input_errors[0] metrics["force_y/{}".format( self.config["task"]["metric"])] = grad_input_errors[1] metrics["force_z/{}".format( self.config["task"]["metric"])] = grad_input_errors[2] return out, metrics def _compute_loss(self, out, batch): loss = [] if self.config["dataset"].get("normalize_labels", True): target_normed = self.normalizers["target"].norm(batch.y) else: target_normed = batch.y loss.append(self.criterion(out["output"], target_normed)) # TODO(abhshkdz): Test support for gradients wrt input. # TODO(abhshkdz): Make this general; remove dependence on `.forces`. if "grad_input" in self.config["task"]: if self.config["dataset"].get("normalize_labels", True): grad_target_normed = self.normalizers["grad_target"].norm( batch.force) else: grad_target_normed = batch.force # Force coefficient = 30 has been working well for us. force_mult = self.config["optim"].get("force_coefficient", 30) if self.config["task"].get("train_on_free_atoms", False): mask = batch.fixed == 0 loss.append(force_mult * self.criterion( out["force_output"][mask], grad_target_normed[mask])) else: loss.append( force_mult * self.criterion(out["force_output"], grad_target_normed)) # Sanity check to make sure the compute graph is correct. for lc in loss: assert hasattr(lc, "grad_fn") loss = sum(loss) return loss def _backward(self, loss): self.optimizer.zero_grad() loss.backward() # TODO(abhshkdz): Add support for gradient clipping. 
if self.scaler: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() def save_results(self, predictions, results_file, keys): if results_file is None: return results_file_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}_{distutils.get_rank()}.npz", ) np.savez_compressed( results_file_path, ids=predictions["id"], **{key: predictions[key] for key in keys}, ) distutils.synchronize() if distutils.is_master(): gather_results = defaultdict(list) full_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}.npz", ) for i in range(distutils.get_world_size()): rank_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}_{i}.npz", ) rank_results = np.load(rank_path, allow_pickle=True) gather_results["ids"].extend(rank_results["ids"]) for key in keys: gather_results[key].extend(rank_results[key]) os.remove(rank_path) # Because of how distributed sampler works, some system ids # might be repeated to make no. of samples even across GPUs. _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = np.array(gather_results["ids"])[idx] for k in keys: if k == "forces": gather_results[k] = np.array(gather_results[k], dtype=object)[idx] else: gather_results[k] = np.array(gather_results[k])[idx] print(f"Writing results to {full_path}") np.savez_compressed(full_path, **gather_results)
class BaseTrainer(ABC): def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, is_hpo=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, cpu=False, name="base_trainer", ): self.name = name self.cpu = cpu self.start_step = 0 if torch.cuda.is_available() and not self.cpu: self.device = local_rank else: self.device = "cpu" self.cpu = True # handle case when `--cpu` isn't specified # but there are no gpu devices available if run_dir is None: run_dir = os.getcwd() timestamp = torch.tensor(datetime.datetime.now().timestamp()).to( self.device) # create directories from master rank only distutils.broadcast(timestamp, 0) timestamp = datetime.datetime.fromtimestamp( timestamp.int()).strftime("%Y-%m-%d-%H-%M-%S") if identifier: timestamp += "-{}".format(identifier) try: commit_hash = (subprocess.check_output([ "git", "-C", ocpmodels.__path__[0], "describe", "--always", ]).strip().decode("ascii")) # catch instances where code is not being run from a git repo except Exception: commit_hash = None self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "gpus": distutils.get_world_size() if not self.cpu else 0, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp": timestamp, "commit": commit_hash, "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp), "results_dir": os.path.join(run_dir, "results", timestamp), "logs_dir": os.path.join(run_dir, "logs", logger, timestamp), }, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if isinstance(dataset, list): self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] if len(dataset) > 2: self.config["test_dataset"] = dataset[2] else: self.config["dataset"] = dataset if not is_debug and distutils.is_master() and not is_hpo: os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) self.is_debug = is_debug self.is_vis = is_vis self.is_hpo = is_hpo if self.is_hpo: # sets the hpo checkpoint frequency # default is no checkpointing self.hpo_checkpoint_every = self.config["optim"].get( "checkpoint_every", -1) if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = Evaluator(task=name) def load(self): self.load_seed_from_config() self.load_logger() self.load_task() self.load_model() self.load_criterion() self.load_optimizer() self.load_extras() # Note: this function is now deprecated. We build config outside of trainer. # See build_config in ocpmodels.common.utils.py. def load_config_from_yaml_and_cmd(self, args): self.config = build_config(args) # AMP Scaler self.scaler = (torch.cuda.amp.GradScaler() if self.config["amp"] else None) # device self.device = torch.device("cuda" if ( torch.cuda.is_available() and not self.cpu) else "cpu") # Are we just running sanity checks? 
self.is_debug = args.debug self.is_vis = args.vis # timestamps and directories args.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") if args.identifier: args.timestamp += "-{}".format(args.identifier) args.checkpoint_dir = os.path.join("checkpoints", args.timestamp) args.results_dir = os.path.join("results", args.timestamp) args.logs_dir = os.path.join("logs", self.config["logger"], args.timestamp) print(yaml.dump(self.config, default_flow_style=False)) for arg in vars(args): print("{:<20}: {}".format(arg, getattr(args, arg))) # TODO(abhshkdz): Handle these parameters better. Maybe move to yaml. self.config["cmd"] = args.__dict__ del args if not self.is_debug: os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) # Dump config parameters json.dump( self.config, open( os.path.join(self.config["cmd"]["checkpoint_dir"], "config.json"), "w", ), ) def load_seed_from_config(self): # https://pytorch.org/docs/stable/notes/randomness.html seed = self.config["cmd"]["seed"] if seed is None: return random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def load_logger(self): self.logger = None if not self.is_debug and distutils.is_master() and not self.is_hpo: assert (self.config["logger"] is not None), "Specify logger in config" self.logger = registry.get_logger_class(self.config["logger"])( self.config) @abstractmethod def load_task(self): """Derived classes should implement this function.""" def load_model(self): # Build model if distutils.is_master(): print("### Loading model: {}".format(self.config["model"])) # TODO(abhshkdz): Eventually move towards computing features on-the-fly # and remove dependence from `.edge_attr`. bond_feat_dim = None if self.config["task"]["dataset"] in [ "trajectory_lmdb", "single_point_lmdb", ]: bond_feat_dim = self.config["model_attributes"].get( "num_gaussians", 50) else: raise NotImplementedError self.model = registry.get_model_class(self.config["model"])( self.train_loader.dataset[0].x.shape[-1] if hasattr(self.train_loader.dataset[0], "x") and self.train_loader.dataset[0].x is not None else None, bond_feat_dim, self.num_targets, **self.config["model_attributes"], ).to(self.device) if distutils.is_master(): print("### Loaded {} with {} parameters.".format( self.model.__class__.__name__, self.model.num_params)) if self.logger is not None: self.logger.watch(self.model) self.model = OCPDataParallel( self.model, output_device=self.device, num_gpus=1 if not self.cpu else 0, ) if distutils.initialized(): self.model = DistributedDataParallel(self.model, device_ids=[self.device]) def load_pretrained(self, checkpoint_path=None, ddp_to_dp=False): if checkpoint_path is None or os.path.isfile(checkpoint_path) is False: print(f"Checkpoint: {checkpoint_path} not found!") return False print("### Loading checkpoint from: {}".format(checkpoint_path)) checkpoint = torch.load( checkpoint_path, map_location=(torch.device("cpu") if self.cpu else None), ) self.start_step = checkpoint.get("step", 0) # Load model, optimizer, normalizer state dict. # if trained with ddp and want to load in non-ddp, modify keys from # module.module.. -> module.. 
if ddp_to_dp: new_dict = OrderedDict() for k, v in checkpoint["state_dict"].items(): name = k[7:] new_dict[name] = v self.model.load_state_dict(new_dict) else: self.model.load_state_dict(checkpoint["state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer"]) if "scheduler" in checkpoint and checkpoint["scheduler"] is not None: self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"]) for key in checkpoint["normalizers"]: if key in self.normalizers: self.normalizers[key].load_state_dict( checkpoint["normalizers"][key]) if self.scaler and checkpoint["amp"]: self.scaler.load_state_dict(checkpoint["amp"]) return True # TODO(abhshkdz): Rename function to something nicer. # TODO(abhshkdz): Support multiple loss functions. def load_criterion(self): self.criterion = self.config["optim"].get("criterion", nn.L1Loss()) def load_optimizer(self): optimizer = self.config["optim"].get("optimizer", "AdamW") optimizer = getattr(optim, optimizer) self.optimizer = optimizer( params=self.model.parameters(), lr=self.config["optim"]["lr_initial"], **self.config["optim"].get("optimizer_params", {}), ) def load_extras(self): self.scheduler = LRScheduler(self.optimizer, self.config["optim"]) # metrics. self.meter = Meter(split="train") def save_state(self, epoch, step, metrics): state = { "epoch": epoch, "step": step, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler.scheduler.state_dict() if self.scheduler.scheduler_type != "Null" else None, "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": metrics, "amp": self.scaler.state_dict() if self.scaler else None, } return state def save(self, epoch, step, metrics): if not self.is_debug and distutils.is_master() and not self.is_hpo: save_checkpoint( self.save_state(epoch, step, metrics), self.config["cmd"]["checkpoint_dir"], ) def save_hpo(self, epoch, step, metrics, checkpoint_every): # default is no checkpointing # checkpointing frequency can be adjusted by setting checkpoint_every in steps # to checkpoint every time results are communicated to Ray Tune set checkpoint_every=1 if checkpoint_every != -1 and step % checkpoint_every == 0: with tune.checkpoint_dir(step=step) as checkpoint_dir: path = os.path.join(checkpoint_dir, "checkpoint") torch.save(self.save_state(epoch, step, metrics), path) def hpo_update(self, epoch, step, train_metrics, val_metrics, test_metrics=None): progress = { "steps": step, "epochs": epoch, "act_lr": self.optimizer.param_groups[0]["lr"], } # checkpointing must occur before reporter # default is no checkpointing self.save_hpo( epoch, step, val_metrics, self.hpo_checkpoint_every, ) # report metrics to tune tune_reporter( iters=progress, train_metrics={ k: train_metrics[k]["metric"] for k in self.metrics }, val_metrics={k: val_metrics[k]["metric"] for k in val_metrics}, test_metrics=test_metrics, ) @abstractmethod def train(self): """Derived classes should implement this function.""" @torch.no_grad() def validate(self, split="val", epoch=None, disable_tqdm=False): if distutils.is_master(): print("### Evaluating on {}.".format(split)) if self.is_hpo: disable_tqdm = True self.model.eval() evaluator, metrics = Evaluator(task=self.name), {} rank = distutils.get_rank() loader = self.val_loader if split == "val" else self.test_loader for i, batch in tqdm( enumerate(loader), total=len(loader), position=rank, desc="device {}".format(rank), disable=disable_tqdm, ): # Forward. 
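            # Run the forward pass under autocast when AMP is enabled so
            # validation uses the same mixed-precision numerics as training.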
with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) # Compute metrics. metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce(metrics[k]["total"], average=False, device=self.device), "numel": distutils.all_reduce(metrics[k]["numel"], average=False, device=self.device), } aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]) metrics = aggregated_metrics log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": epoch + 1}) if distutils.is_master(): log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()] print(", ".join(log_str)) # Make plots. if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) return metrics @abstractmethod def _forward(self, batch_list): """Derived classes should implement this function.""" @abstractmethod def _compute_loss(self, out, batch_list): """Derived classes should implement this function.""" def _backward(self, loss): self.optimizer.zero_grad() loss.backward() # TODO(abhshkdz): Add support for gradient clipping. if self.scaler: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() def save_results(self, predictions, results_file, keys): if results_file is None: return results_file_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}_{distutils.get_rank()}.npz", ) np.savez_compressed( results_file_path, ids=predictions["id"], **{key: predictions[key] for key in keys}, ) distutils.synchronize() if distutils.is_master(): gather_results = defaultdict(list) full_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}.npz", ) for i in range(distutils.get_world_size()): rank_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}_{i}.npz", ) rank_results = np.load(rank_path, allow_pickle=True) gather_results["ids"].extend(rank_results["ids"]) for key in keys: gather_results[key].extend(rank_results[key]) os.remove(rank_path) # Because of how distributed sampler works, some system ids # might be repeated to make no. of samples even across GPUs. _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = np.array(gather_results["ids"])[idx] for k in keys: if k == "forces": gather_results[k] = np.concatenate( np.array(gather_results[k])[idx]) elif k == "chunk_idx": gather_results[k] = np.cumsum( np.array(gather_results[k])[idx])[:-1] else: gather_results[k] = np.array(gather_results[k])[idx] print(f"Writing results to {full_path}") np.savez_compressed(full_path, **gather_results)
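# ---------------------------------------------------------------------------
# Hedged sketch (not from the original repo): the minimal surface a concrete
# trainer must implement on top of this abstract BaseTrainer. The method bodies
# are illustrative placeholders; a real subclass builds the dataloaders,
# num_targets and normalizers in load_task() before load_model() runs, and
# implements the actual energy/force logic in _forward()/_compute_loss().
class _ExampleEnergyTrainer(BaseTrainer):
    def load_task(self):
        # Build self.train_loader / self.val_loader / self.test_loader,
        # self.num_targets and self.normalizers here.
        self.num_targets = 1
        self.normalizers = {}

    def train(self):
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

    def _forward(self, batch_list):
        # OCPDataParallel hands the wrapped model a list of batches.
        return {"energy": self.model(batch_list)}

    def _compute_loss(self, out, batch_list):
        target = torch.cat([b.y for b in batch_list]).to(self.device)
        return self.criterion(out["energy"].view(-1), target.view(-1))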
def main(args): """ Main function for training, evaluating, and checkpointing. Args: args: `argparse` object. """ # Print arguments. print('\nusing arguments:') _print_arguments(args) print() # Check if GPU is available. if not args.use_gpu and torch.cuda.is_available(): print('warning: GPU is available but args.use_gpu = False') print() local_rank = args.local_rank # world_size = torch.cuda.device_count() # assume all local GPUs # Set up distributed process group rank = setup_dist(local_rank) # Set up datasets. train_dataset = QADataset(args, args.train_path) dev_dataset = QADataset(args, args.dev_path) # Create vocabulary and tokenizer. vocabulary = Vocabulary(train_dataset.samples, args.vocab_size) tokenizer = Tokenizer(vocabulary) for dataset in (train_dataset, dev_dataset): dataset.register_tokenizer(tokenizer) args.vocab_size = len(vocabulary) args.pad_token_id = tokenizer.pad_token_id print(f'vocab words = {len(vocabulary)}') # Print number of samples. print(f'train samples = {len(train_dataset)}') print(f'dev samples = {len(dev_dataset)}') print() # Select model. model = _select_model(args) #model = model.to(rank) #model = DDP(model, device_ids=[rank], output_device=rank) num_pretrained = model.load_pretrained_embeddings( vocabulary, args.embedding_path ) pct_pretrained = round(num_pretrained / len(vocabulary) * 100., 2) print(f'using pre-trained embeddings from \'{args.embedding_path}\'') print( f'initialized {num_pretrained}/{len(vocabulary)} ' f'embeddings ({pct_pretrained}%)' ) print() # device = torch.device(f'cuda:{rank}') model = model.to(rank) model = DDP(model, device_ids=[rank], output_device=rank) # if args.use_gpu: # model = cuda(args, model) if args.resume and args.model_path: map_location = {"cuda:0": "cuda:{}".format(rank)} model.load_state_dict(torch.load(args.model_path, map_location=map_location)) params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f'using model \'{args.model}\' ({params} params)') print(model) print() if args.do_train: # Track training statistics for checkpointing. eval_history = [] best_eval_loss = float('inf') # Begin training. for epoch in range(1, args.epochs + 1): # Perform training and evaluation steps. try: train_loss = train(args, epoch, model, train_dataset) except RuntimeError: print(f'NCCL Wait Timeout, rank: \'{args.local_rank}\' (exit)') exit(1) eval_loss = evaluate(args, epoch, model, dev_dataset) # If the model's evaluation loss yields a global improvement, # checkpoint the model. if rank == 0: eval_history.append(eval_loss < best_eval_loss) if eval_loss < best_eval_loss: best_eval_loss = eval_loss torch.save(model.state_dict(), args.model_path) print( f'epoch = {epoch} | ' f'train loss = {train_loss:.6f} | ' f'eval loss = {eval_loss:.6f} | ' f"{'saving model!' if eval_history[-1] else ''}" ) # If early stopping conditions are met, stop training. if _early_stop(args, eval_history): suffix = 's' if args.early_stop > 1 else '' print( f'no improvement after {args.early_stop} epoch{suffix}. ' 'early stopping...' ) print() cleanup_dist() break if args.do_test and rank == 0: # Write predictions to the output file. Use the printed command # below to obtain official EM/F1 metrics. write_predictions(args, model, dev_dataset) eval_cmd = ( 'python3 evaluate.py ' f'--dataset_path {args.dev_path} ' f'--output_path {args.output_path}' ) print() print(f'predictions written to \'{args.output_path}\'') print(f'compute EM/F1 with: \'{eval_cmd}\'') print()
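# ---------------------------------------------------------------------------
# Hedged sketch of the distributed helpers used above. setup_dist() and
# cleanup_dist() are not shown in this listing; the example functions below are
# one plausible implementation on top of torch.distributed (assuming the usual
# MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment variables, e.g. as set by
# torchrun), not the project's actual code.
import torch.distributed as dist


def example_setup_dist(local_rank):
    # Bind this process to its GPU and join the default process group.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    return dist.get_rank()


def example_cleanup_dist():
    dist.destroy_process_group()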
output_device=args.local_rank) # load pre-trained parameters if args.load_from != 'none': with torch.cuda.device(args.local_rank): pretrained_dict = torch.load( os.path.join(args.workspace_prefix, 'models', args.load_from + '.pt'), map_location=lambda storage, loc: storage.cuda()) model_dict = model.state_dict() pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } model_dict.update(pretrained_dict) model.load_state_dict(model_dict) decoding_path = os.path.join( args.workspace_prefix, 'decodes', args.load_from if args.mode == 'test' else (args.prefix + hp_str)) if (args.local_rank == 0) and (not os.path.exists(decoding_path)): os.mkdir(decoding_path) name_suffix = 'b={}_a={}.txt'.format(args.beam_size, args.alpha) names = [ '{}.src.{}'.format(args.test_set, name_suffix), '{}.trg.{}'.format(args.test_set, name_suffix), '{}.dec.{}'.format(args.test_set, name_suffix) ] # start running
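# ---------------------------------------------------------------------------
# The block above loads a checkpoint partially: only parameters whose names
# also exist in the current model are copied, and everything else keeps its
# fresh initialization. A hedged, standalone restatement of that pattern
# (the function name and the shape check are illustrative additions, not part
# of the original code):
def load_partial_state_dict(model, checkpoint_path, device="cpu"):
    pretrained = torch.load(checkpoint_path, map_location=device)
    model_dict = model.state_dict()
    # Keep only tensors that exist in the model and match in shape.
    filtered = {
        k: v for k, v in pretrained.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(filtered)
    model.load_state_dict(model_dict)
    # Return the parameter names that were left at their fresh initialization.
    return sorted(set(model_dict) - set(filtered))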
class BaseTrainer(ABC): def __init__( self, task, model, dataset, optimizer, identifier, normalizer=None, timestamp_id=None, run_dir=None, is_debug=False, is_vis=False, is_hpo=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, cpu=False, name="base_trainer", slurm={}, ): self.name = name self.cpu = cpu self.epoch = 0 self.step = 0 if torch.cuda.is_available() and not self.cpu: self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") self.cpu = True # handle case when `--cpu` isn't specified # but there are no gpu devices available if run_dir is None: run_dir = os.getcwd() if timestamp_id is None: timestamp = torch.tensor(datetime.datetime.now().timestamp()).to( self.device ) # create directories from master rank only distutils.broadcast(timestamp, 0) timestamp = datetime.datetime.fromtimestamp( timestamp.int() ).strftime("%Y-%m-%d-%H-%M-%S") if identifier: self.timestamp_id = f"{timestamp}-{identifier}" else: self.timestamp_id = timestamp else: self.timestamp_id = timestamp_id try: commit_hash = ( subprocess.check_output( [ "git", "-C", ocpmodels.__path__[0], "describe", "--always", ] ) .strip() .decode("ascii") ) # catch instances where code is not being run from a git repo except Exception: commit_hash = None self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "gpus": distutils.get_world_size() if not self.cpu else 0, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp_id": self.timestamp_id, "commit": commit_hash, "checkpoint_dir": os.path.join( run_dir, "checkpoints", self.timestamp_id ), "results_dir": os.path.join( run_dir, "results", self.timestamp_id ), "logs_dir": os.path.join( run_dir, "logs", logger, self.timestamp_id ), }, "slurm": slurm, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if "SLURM_JOB_ID" in os.environ and "folder" in self.config["slurm"]: self.config["slurm"]["job_id"] = os.environ["SLURM_JOB_ID"] self.config["slurm"]["folder"] = self.config["slurm"][ "folder" ].replace("%j", self.config["slurm"]["job_id"]) if isinstance(dataset, list): if len(dataset) > 0: self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] if len(dataset) > 2: self.config["test_dataset"] = dataset[2] elif isinstance(dataset, dict): self.config["dataset"] = dataset.get("train", None) self.config["val_dataset"] = dataset.get("val", None) self.config["test_dataset"] = dataset.get("test", None) else: self.config["dataset"] = dataset self.normalizer = normalizer # This supports the legacy way of providing norm parameters in dataset if self.config.get("dataset", None) is not None and normalizer is None: self.normalizer = self.config["dataset"] if not is_debug and distutils.is_master() and not is_hpo: os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) self.is_debug = is_debug self.is_vis = is_vis self.is_hpo = is_hpo if self.is_hpo: # conditional import is necessary for checkpointing from ray import tune from ocpmodels.common.hpo_utils import tune_reporter # sets the hpo checkpoint frequency # default is no checkpointing self.hpo_checkpoint_every = self.config["optim"].get( "checkpoint_every", -1 ) if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = 
Evaluator(task=name) def load(self): self.load_seed_from_config() self.load_logger() self.load_task() self.load_model() self.load_loss() self.load_optimizer() self.load_extras() def load_seed_from_config(self): # https://pytorch.org/docs/stable/notes/randomness.html seed = self.config["cmd"]["seed"] if seed is None: return random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def load_logger(self): self.logger = None if not self.is_debug and distutils.is_master() and not self.is_hpo: assert ( self.config["logger"] is not None ), "Specify logger in config" self.logger = registry.get_logger_class(self.config["logger"])( self.config ) def get_sampler(self, dataset, batch_size, shuffle): if "load_balancing" in self.config["optim"]: balancing_mode = self.config["optim"]["load_balancing"] force_balancing = True else: balancing_mode = "atoms" force_balancing = False sampler = BalancedBatchSampler( dataset, batch_size=batch_size, num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), device=self.device, mode=balancing_mode, shuffle=shuffle, force_balancing=force_balancing, ) return sampler def get_dataloader(self, dataset, sampler): loader = DataLoader( dataset, collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, batch_sampler=sampler, ) return loader @abstractmethod def load_task(self): """Derived classes should implement this function.""" def load_model(self): # Build model if distutils.is_master(): logging.info(f"Loading model: {self.config['model']}") # TODO(abhshkdz): Eventually move towards computing features on-the-fly # and remove dependence from `.edge_attr`. bond_feat_dim = None if self.config["task"]["dataset"] in [ "trajectory_lmdb", "single_point_lmdb", ]: bond_feat_dim = self.config["model_attributes"].get( "num_gaussians", 50 ) else: raise NotImplementedError loader = self.train_loader or self.val_loader or self.test_loader self.model = registry.get_model_class(self.config["model"])( loader.dataset[0].x.shape[-1] if loader and hasattr(loader.dataset[0], "x") and loader.dataset[0].x is not None else None, bond_feat_dim, self.num_targets, **self.config["model_attributes"], ).to(self.device) if distutils.is_master(): logging.info( f"Loaded {self.model.__class__.__name__} with " f"{self.model.num_params} parameters." ) if self.logger is not None: self.logger.watch(self.model) self.model = OCPDataParallel( self.model, output_device=self.device, num_gpus=1 if not self.cpu else 0, ) if distutils.initialized(): self.model = DistributedDataParallel( self.model, device_ids=[self.device] ) def load_checkpoint(self, checkpoint_path): if not os.path.isfile(checkpoint_path): raise FileNotFoundError( errno.ENOENT, "Checkpoint file not found", checkpoint_path ) logging.info(f"Loading checkpoint from: {checkpoint_path}") map_location = torch.device("cpu") if self.cpu else self.device checkpoint = torch.load(checkpoint_path, map_location=map_location) self.epoch = checkpoint.get("epoch", 0) self.step = checkpoint.get("step", 0) # Load model, optimizer, normalizer state dict. # if trained with ddp and want to load in non-ddp, modify keys from # module.module.. -> module.. 
first_key = next(iter(checkpoint["state_dict"])) if not distutils.initialized() and first_key.split(".")[1] == "module": # No need for OrderedDict since dictionaries are technically ordered # since Python 3.6 and officially ordered since Python 3.7 new_dict = {k[7:]: v for k, v in checkpoint["state_dict"].items()} self.model.load_state_dict(new_dict) else: self.model.load_state_dict(checkpoint["state_dict"]) if "optimizer" in checkpoint: self.optimizer.load_state_dict(checkpoint["optimizer"]) if "scheduler" in checkpoint and checkpoint["scheduler"] is not None: self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"]) if "ema" in checkpoint and checkpoint["ema"] is not None: self.ema.load_state_dict(checkpoint["ema"]) for key in checkpoint["normalizers"]: if key in self.normalizers: self.normalizers[key].load_state_dict( checkpoint["normalizers"][key] ) if self.scaler and checkpoint["amp"]: self.scaler.load_state_dict(checkpoint["amp"]) def load_loss(self): self.loss_fn = {} self.loss_fn["energy"] = self.config["optim"].get("loss_energy", "mae") self.loss_fn["force"] = self.config["optim"].get("loss_force", "mae") for loss, loss_name in self.loss_fn.items(): if loss_name in ["l1", "mae"]: self.loss_fn[loss] = nn.L1Loss() elif loss_name == "mse": self.loss_fn[loss] = nn.MSELoss() elif loss_name == "l2mae": self.loss_fn[loss] = L2MAELoss() else: raise NotImplementedError( f"Unknown loss function name: {loss_name}" ) if distutils.initialized(): self.loss_fn[loss] = DDPLoss(self.loss_fn[loss]) def load_optimizer(self): optimizer = self.config["optim"].get("optimizer", "AdamW") optimizer = getattr(optim, optimizer) if self.config["optim"].get("weight_decay", 0) > 0: # Do not regularize bias etc. params_decay = [] params_no_decay = [] for name, param in self.model.named_parameters(): if param.requires_grad: if "embedding" in name: params_no_decay += [param] elif "frequencies" in name: params_no_decay += [param] elif "bias" in name: params_no_decay += [param] else: params_decay += [param] self.optimizer = optimizer( [ {"params": params_no_decay, "weight_decay": 0}, { "params": params_decay, "weight_decay": self.config["optim"]["weight_decay"], }, ], lr=self.config["optim"]["lr_initial"], **self.config["optim"].get("optimizer_params", {}), ) else: self.optimizer = optimizer( params=self.model.parameters(), lr=self.config["optim"]["lr_initial"], **self.config["optim"].get("optimizer_params", {}), ) def load_extras(self): self.scheduler = LRScheduler(self.optimizer, self.config["optim"]) self.clip_grad_norm = self.config["optim"].get("clip_grad_norm") self.ema_decay = self.config["optim"].get("ema_decay") if self.ema_decay: self.ema = ExponentialMovingAverage( self.model.parameters(), self.ema_decay, ) else: self.ema = None def save( self, metrics=None, checkpoint_file="checkpoint.pt", training_state=True, ): if not self.is_debug and distutils.is_master(): if training_state: save_checkpoint( { "epoch": self.epoch, "step": self.step, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler.scheduler.state_dict() if self.scheduler.scheduler_type != "Null" else None, "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": metrics, "ema": self.ema.state_dict() if self.ema else None, "amp": self.scaler.state_dict() if self.scaler else None, }, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) else: if self.ema: self.ema.store() 
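                # Temporarily swap the EMA-averaged weights into the model so
                # the inference-only checkpoint stores the averaged parameters;
                # the live training weights are restored right after saving.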
self.ema.copy_to() save_checkpoint( { "state_dict": self.model.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": metrics, "amp": self.scaler.state_dict() if self.scaler else None, }, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) if self.ema: self.ema.restore() def save_hpo(self, epoch, step, metrics, checkpoint_every): # default is no checkpointing # checkpointing frequency can be adjusted by setting checkpoint_every in steps # to checkpoint every time results are communicated to Ray Tune set checkpoint_every=1 if checkpoint_every != -1 and step % checkpoint_every == 0: with tune.checkpoint_dir( # noqa: F821 step=step ) as checkpoint_dir: path = os.path.join(checkpoint_dir, "checkpoint") torch.save(self.save_state(epoch, step, metrics), path) def hpo_update( self, epoch, step, train_metrics, val_metrics, test_metrics=None ): progress = { "steps": step, "epochs": epoch, "act_lr": self.optimizer.param_groups[0]["lr"], } # checkpointing must occur before reporter # default is no checkpointing self.save_hpo( epoch, step, val_metrics, self.hpo_checkpoint_every, ) # report metrics to tune tune_reporter( # noqa: F821 iters=progress, train_metrics={ k: train_metrics[k]["metric"] for k in self.metrics }, val_metrics={k: val_metrics[k]["metric"] for k in val_metrics}, test_metrics=test_metrics, ) @abstractmethod def train(self): """Derived classes should implement this function.""" @torch.no_grad() def validate(self, split="val", disable_tqdm=False): if distutils.is_master(): logging.info(f"Evaluating on {split}.") if self.is_hpo: disable_tqdm = True self.model.eval() if self.ema: self.ema.store() self.ema.copy_to() evaluator, metrics = Evaluator(task=self.name), {} rank = distutils.get_rank() loader = self.val_loader if split == "val" else self.test_loader for i, batch in tqdm( enumerate(loader), total=len(loader), position=rank, desc="device {}".format(rank), disable=disable_tqdm, ): # Forward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) # Compute metrics. metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce( metrics[k]["total"], average=False, device=self.device ), "numel": distutils.all_reduce( metrics[k]["numel"], average=False, device=self.device ), } aggregated_metrics[k]["metric"] = ( aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"] ) metrics = aggregated_metrics log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": self.epoch}) if distutils.is_master(): log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()] logging.info(", ".join(log_str)) # Make plots. 
if self.logger is not None: self.logger.log( log_dict, step=self.step, split=split, ) if self.ema: self.ema.restore() return metrics @abstractmethod def _forward(self, batch_list): """Derived classes should implement this function.""" @abstractmethod def _compute_loss(self, out, batch_list): """Derived classes should implement this function.""" def _backward(self, loss): self.optimizer.zero_grad() loss.backward() # Scale down the gradients of shared parameters if hasattr(self.model, "shared_parameters"): for p, factor in self.model.shared_parameters: if p.grad is not None: p.grad.detach().div_(factor) if self.clip_grad_norm: if self.scaler: self.scaler.unscale_(self.optimizer) grad_norm = torch.nn.utils.clip_grad_norm_( self.model.parameters(), max_norm=self.clip_grad_norm, ) if self.logger is not None: self.logger.log( {"grad_norm": grad_norm}, step=self.step, split="train" ) if self.scaler: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() if self.ema: self.ema.update() def save_results(self, predictions, results_file, keys): if results_file is None: return results_file_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}_{distutils.get_rank()}.npz", ) np.savez_compressed( results_file_path, ids=predictions["id"], **{key: predictions[key] for key in keys}, ) distutils.synchronize() if distutils.is_master(): gather_results = defaultdict(list) full_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}.npz", ) for i in range(distutils.get_world_size()): rank_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}_{i}.npz", ) rank_results = np.load(rank_path, allow_pickle=True) gather_results["ids"].extend(rank_results["ids"]) for key in keys: gather_results[key].extend(rank_results[key]) os.remove(rank_path) # Because of how distributed sampler works, some system ids # might be repeated to make no. of samples even across GPUs. _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = np.array(gather_results["ids"])[idx] for k in keys: if k == "forces": gather_results[k] = np.concatenate( np.array(gather_results[k])[idx] ) elif k == "chunk_idx": gather_results[k] = np.cumsum( np.array(gather_results[k])[idx] )[:-1] else: gather_results[k] = np.array(gather_results[k])[idx] logging.info(f"Writing results to {full_path}") np.savez_compressed(full_path, **gather_results)
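# ---------------------------------------------------------------------------
# Hedged, standalone sketch of the mixed-precision update that _backward()
# above performs inside the trainer. Here the loss is scaled explicitly before
# backward(); in the trainer that scaling is assumed to happen in the derived
# class before _backward() is called. Function and argument names are
# illustrative, not part of the original code.
def amp_training_step(model, optimizer, scaler, loss_fn, batch, target, max_norm=None):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = loss_fn(model(batch), target)
    if scaler is not None:
        scaler.scale(loss).backward()
        if max_norm is not None:
            scaler.unscale_(optimizer)  # bring gradients back to fp32 scale
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)  # skips the step on inf/NaN gradients
        scaler.update()
    else:
        loss.backward()
        if max_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()
    return loss.detach()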
class Trainer: def __init__(self, config: BaseConfig): self._config = config self._model = DeepLab(num_classes=9, output_stride=8, sync_bn=False).to(self._config.device) self._border_loss = TotalLoss(self._config) self._direction_loss = CrossEntropyLoss() self._loaders = get_data_loaders(config) self._writer = SummaryWriter() self._optimizer = torch.optim.SGD(self._model.parameters(), lr=self._config.lr, weight_decay=1e-4, nesterov=True, momentum=0.9) self._scheduler = torch.optim.lr_scheduler.ExponentialLR( self._optimizer, gamma=0.97) if self._config.parallel: self._model = DistributedDataParallel(self._model, device_ids=[ self._config.device, ]) # self._load() def train(self): for epoch in range(self._config.num_epochs): t = tqdm(self._loaders[DataMode.train]) self._model.train() for idx, data in enumerate(t): self._optimizer.zero_grad() imgs, borders, masks, crop_info, opened_img, opened_mask = data imgs, borders, masks = imgs.to( self._config.device), borders.to( self._config.device), masks.to(self._config.device) # borders = borders.unsqueeze(1) output = self._model(imgs) border_output = output[:, :1, :, :].squeeze() direction_output = output[:, 1:, :, :] seg_loss = self._direction_loss(direction_output, masks) loss = self._border_loss(border_output, borders) + seg_loss t.set_description(f"LOSS: {seg_loss.item()}") loss.backward() self._optimizer.step() self._writer.add_scalar( "Loss/training", loss.item(), global_step=epoch * len(self._loaders[DataMode.train]) + idx) if idx % self._config.frequency_visualization[ DataMode.train] == 0: self._tensorboard_visualization( loss=loss, epoch=epoch, idx=idx, imgs=imgs, border_gt=borders.unsqueeze(1), border=border_output.unsqueeze(1), direction_gt=masks, direction=direction_output) if self._config.live_visualization: self._live_visualization(imgs, borders, output) self._save() @staticmethod def _get_mask(masks): masks_new = torch.zeros(masks.shape[0], 2, masks.shape[1], masks.shape[2], device=masks.device) for idx in range(2): masks_new[:, idx, :, :][masks == idx] = 1 return masks_new def validate(self, epoch): t = tqdm(self._loaders[DataMode.eval]) self._model.eval() for idx, data in enumerate(t): imgs, masks, _, _, _ = data imgs, masks = imgs.to(self._config.device), masks.to( self._config.device) output = self._model(imgs) loss = self._border_loss(masks, output) def _tensorboard_visualization(self, loss, epoch, idx, imgs, border_gt, direction_gt, border, direction): self._writer.add_images(f"Images{idx}/training", imgs, epoch) self._writer.add_images(f"Border{idx}/training", border, epoch) self._writer.add_images(f"BorderGT{idx}/training", border_gt, epoch) self._writer.add_images(f"DirectionGT{idx}/training", direction_gt.unsqueeze(1) / 8., epoch) self._writer.add_images(f"Direction{idx}/training", direction.argmax(1).unsqueeze(1) / 8., epoch) def _save(self): path = os.path.join(self._config.checkpoint_path, self._config.EXPERIMENT_NAME) self.check_and_mkdir(path) torch.save(self._model.state_dict(), os.path.join(path, "corrector.pth")) def _load(self): path = os.path.join(self._config.checkpoint_path, self._config.EXPERIMENT_NAME) weights = glob(os.path.join(path, "*.pth")) if len(weights): state_dict = torch.load(weights[0]) try: self._model.load_state_dict(state_dict) except RuntimeError as e: print("ERROR while loading weights: {}".format(e)) @staticmethod def check_and_mkdir(path): if not os.path.exists(path): os.makedirs(path, exist_ok=True) @staticmethod def _live_visualization(imgs, masks, output): out = np.expand_dims( 
(output[1, 0, :, :].detach().cpu().numpy()).astype(np.float32), axis=2) > 0.7 out2 = np.expand_dims(np.argmax( (output[1, 1:, :, :].detach().cpu().numpy()).astype(np.float32), axis=0), axis=2) / 7. out3 = np.expand_dims( (output[1, 2, :, :].detach().cpu().numpy()).astype(np.float32), axis=2) show = np.zeros((masks.shape[2], 5 * masks.shape[3], 3), dtype=np.float32) show[:, :masks.shape[3]] = cv2.applyColorMap( (masks[1, 0, :, :] * 255).detach().cpu().numpy().astype(np.uint8), cv2.COLORMAP_BONE) show[:, masks.shape[3]:2 * masks.shape[3], :] = np.concatenate( [out, out, out], axis=2) show[:, 2 * masks.shape[3]:3 * masks.shape[3]] = cv2.cvtColor( imgs[1, ...].permute(1, 2, 0).cpu().detach().numpy(), cv2.COLOR_BGR2RGB) show[:, 3 * masks.shape[3]:4 * masks.shape[3]] = np.concatenate( [out2, out2, out2], axis=2) show[:, 4 * masks.shape[3]:5 * masks.shape[3]] = np.concatenate( [out3, out3, out3], axis=2) cv2.imshow("training", show) cv2.waitKey(10)
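# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): running the segmentation
# trainer above with a hypothetical config. The attribute names mirror the
# fields read in __init__()/train()/_save(); the concrete values are
# placeholders.
def _example_segmentation_training():
    config = BaseConfig()  # assumed to expose the fields used by Trainer
    config.device = "cuda:0" if torch.cuda.is_available() else "cpu"
    config.lr = 1e-2
    config.num_epochs = 50
    config.parallel = False  # skip DistributedDataParallel wrapping
    config.live_visualization = False
    config.frequency_visualization = {DataMode.train: 100}
    config.checkpoint_path = "checkpoints"
    config.EXPERIMENT_NAME = "border_segmentation_example"
    trainer = Trainer(config)
    trainer.train()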