def train(epochs: int, train_data_loader: DataLoader, valid_data_loader: DataLoader = None, rank=None):
    device = torch.device(f'cuda:{rank}')
    # `model_type` and `lr` are module-level settings defined elsewhere in this script.
    model = create_model(model_type).to(device)
    model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    optimizer = AdamW(model.parameters(), lr=lr)
    tokenizer = BertTokenizer.from_pretrained(model_type)

    def update_weights(bi, di, num_batches, batch_loss):
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if bi % 100 == 0:
            logger.info(
                f'training: device={di}; batch={bi+1}/{num_batches}; batch_error={batch_loss.item()};'
            )

    def valid_loss_progress_log(bi, di, num_batches, batch_loss):
        if bi % 100 == 0:
            logger.info(
                f'validation: device={di}; batch={bi+1}/{num_batches}; val_batch_error={batch_loss.item()};'
            )

    for i in range(epochs):
        model.train()
        train_data_loader.sampler.set_epoch(i)
        # Only touch the validation sampler if a validation loader was actually given.
        if valid_data_loader is not None:
            valid_data_loader.sampler.set_epoch(i)
        train_loss = run(model, train_data_loader, tokenizer, device, update_weights)
        if valid_data_loader is not None:
            with torch.no_grad():
                model.eval()
                val_loss = run(model, valid_data_loader, tokenizer, device, valid_loss_progress_log)
        else:
            val_loss = 'N/A'
        logger.info(
            f'epoch={i}; device={rank}; train_error={train_loss}; valid_error={val_loss};'
        )
    return model.module
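# NOTE: train() above delegates each epoch to a run() helper that is not part of this
# listing. The sketch below shows one plausible shape for it, assuming each batch is a
# list of raw strings that the tokenizer encodes on the fly and that the model returns
# an object with a scalar .loss when labels are passed (as HuggingFace LM heads do).
# The batch format and the callback contract are assumptions, not the original code.
def run(model, data_loader, tokenizer, device, batch_callback):
    total_loss = 0.0
    num_batches = len(data_loader)
    for bi, batch in enumerate(data_loader):
        encoded = tokenizer(batch, padding=True, truncation=True,
                            return_tensors='pt').to(device)
        batch_loss = model(**encoded, labels=encoded['input_ids']).loss
        # During training the callback steps the optimizer; during validation it only logs.
        batch_callback(bi, device.index, num_batches, batch_loss)
        total_loss += batch_loss.item()
    return total_loss / max(num_batches, 1)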
class BaseTrainer: def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, name="base_trainer", ): self.name = name if torch.cuda.is_available(): self.device = local_rank else: self.device = "cpu" if run_dir is None: run_dir = os.getcwd() run_dir = Path(run_dir) timestamp = torch.tensor(datetime.datetime.now().timestamp()).to( self.device) # create directories from master rank only distutils.broadcast(timestamp, 0) timestamp = datetime.datetime.fromtimestamp(timestamp).strftime( "%Y-%m-%d-%H-%M-%S") if identifier: timestamp += "-{}".format(identifier) self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp": timestamp, "checkpoint_dir": str(run_dir / "checkpoints" / timestamp), "results_dir": str(run_dir / "results" / timestamp), "logs_dir": str(run_dir / "logs" / logger / timestamp), }, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if isinstance(dataset, list): self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] if len(dataset) > 2: self.config["test_dataset"] = dataset[2] else: self.config["dataset"] = dataset if not is_debug and distutils.is_master(): os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) self.is_debug = is_debug self.is_vis = is_vis if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = Evaluator(task=name) def load(self): self.load_seed_from_config() self.load_logger() self.load_task() self.load_model() self.load_criterion() self.load_optimizer() self.load_extras() # Note: this function is now deprecated. We build config outside of trainer. # See build_config in ocpmodels.common.utils.py. def load_config_from_yaml_and_cmd(self, args): self.config = build_config(args) # AMP Scaler self.scaler = (torch.cuda.amp.GradScaler() if self.config["amp"] else None) # device self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu") # Are we just running sanity checks? self.is_debug = args.debug self.is_vis = args.vis # timestamps and directories args.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") if args.identifier: args.timestamp += "-{}".format(args.identifier) args.checkpoint_dir = os.path.join("checkpoints", args.timestamp) args.results_dir = os.path.join("results", args.timestamp) args.logs_dir = os.path.join("logs", self.config["logger"], args.timestamp) print(yaml.dump(self.config, default_flow_style=False)) for arg in vars(args): print("{:<20}: {}".format(arg, getattr(args, arg))) # TODO(abhshkdz): Handle these parameters better. Maybe move to yaml. 
self.config["cmd"] = args.__dict__ del args if not self.is_debug: os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) # Dump config parameters json.dump( self.config, open( os.path.join(self.config["cmd"]["checkpoint_dir"], "config.json"), "w", ), ) def load_seed_from_config(self): # https://pytorch.org/docs/stable/notes/randomness.html seed = self.config["cmd"]["seed"] if seed is None: return random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def load_logger(self): self.logger = None if not self.is_debug and distutils.is_master(): assert (self.config["logger"] is not None), "Specify logger in config" self.logger = registry.get_logger_class(self.config["logger"])( self.config) def load_task(self): print("### Loading dataset: {}".format(self.config["task"]["dataset"])) dataset = registry.get_dataset_class(self.config["task"]["dataset"])( self.config["dataset"]) if self.config["task"]["dataset"] in ["qm9", "dogss"]: num_targets = dataset.data.y.shape[-1] if ("label_index" in self.config["task"] and self.config["task"]["label_index"] is not False): dataset.data.y = dataset.data.y[:, int(self.config["task"] ["label_index"])] num_targets = 1 else: num_targets = 1 self.num_targets = num_targets ( self.train_loader, self.val_loader, self.test_loader, ) = dataset.get_dataloaders( batch_size=int(self.config["optim"]["batch_size"])) # Normalizer for the dataset. # Compute mean, std of training set labels. self.normalizers = {} if self.config["dataset"].get("normalize_labels", True): self.normalizers["target"] = Normalizer( self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], self.device, ) # If we're computing gradients wrt input, set mean of normalizer to 0 -- # since it is lost when compute dy / dx -- and std to forward target std if "grad_input" in self.config["task"]: if self.config["dataset"].get("normalize_labels", True): self.normalizers["grad_target"] = Normalizer( self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], self.device, ) self.normalizers["grad_target"].mean.fill_(0) if self.is_vis and self.config["task"]["dataset"] != "qm9": # Plot label distribution. plots = [ plot_histogram( self.train_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: train", ), plot_histogram( self.val_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: val", ), plot_histogram( self.test_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: test", ), ] self.logger.log_plots(plots) def load_model(self): # Build model if distutils.is_master(): print("### Loading model: {}".format(self.config["model"])) # TODO(abhshkdz): Eventually move towards computing features on-the-fly # and remove dependence from `.edge_attr`. 
bond_feat_dim = None if self.config["task"]["dataset"] in [ "trajectory_lmdb", "single_point_lmdb", ]: bond_feat_dim = self.config["model_attributes"].get( "num_gaussians", 50) else: raise NotImplementedError self.model = registry.get_model_class(self.config["model"])( self.train_loader.dataset[0].x.shape[-1] if hasattr(self.train_loader.dataset[0], "x") and self.train_loader.dataset[0].x is not None else None, bond_feat_dim, self.num_targets, **self.config["model_attributes"], ).to(self.device) if distutils.is_master(): print("### Loaded {} with {} parameters.".format( self.model.__class__.__name__, self.model.num_params)) if self.logger is not None: self.logger.watch(self.model) self.model = OCPDataParallel( self.model, output_device=self.device, num_gpus=1, ) if distutils.initialized(): self.model = DistributedDataParallel(self.model, device_ids=[self.device]) def load_pretrained(self, checkpoint_path=None, ddp_to_dp=False): if checkpoint_path is None or os.path.isfile(checkpoint_path) is False: print(f"Checkpoint: {checkpoint_path} not found!") return False print("### Loading checkpoint from: {}".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) # Load model, optimizer, normalizer state dict. # if trained with ddp and want to load in non-ddp, modify keys from # module.module.. -> module.. if ddp_to_dp: new_dict = OrderedDict() for k, v in checkpoint["state_dict"].items(): name = k[7:] new_dict[name] = v self.model.load_state_dict(new_dict) else: self.model.load_state_dict(checkpoint["state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer"]) for key in checkpoint["normalizers"]: if key in self.normalizers: self.normalizers[key].load_state_dict( checkpoint["normalizers"][key]) if self.scaler and checkpoint["amp"]: self.scaler.load_state_dict(checkpoint["amp"]) return True # TODO(abhshkdz): Rename function to something nicer. # TODO(abhshkdz): Support multiple loss functions. def load_criterion(self): self.criterion = self.config["optim"].get("criterion", nn.L1Loss()) def load_optimizer(self): self.optimizer = optim.AdamW( self.model.parameters(), self.config["optim"]["lr_initial"], # weight_decay=3.0 ) def load_extras(self): # learning rate scheduler. scheduler_lambda_fn = lambda x: warmup_lr_lambda( x, self.config["optim"]) self.scheduler = optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=scheduler_lambda_fn) # metrics. self.meter = Meter(split="train") def save(self, epoch, metrics): if not self.is_debug and distutils.is_master(): save_checkpoint( { "epoch": epoch, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": metrics, "amp": self.scaler.state_dict() if self.scaler else None, }, self.config["cmd"]["checkpoint_dir"], ) def train(self, max_epochs=None, return_metrics=False): # TODO(abhshkdz): Timers for dataloading and forward pass. num_epochs = (max_epochs if max_epochs is not None else self.config["optim"]["max_epochs"]) for epoch in range(num_epochs): self.model.train() for i, batch in enumerate(self.train_loader): batch = batch.to(self.device) # Forward, loss, backward. out, metrics = self._forward(batch) loss = self._compute_loss(out, batch) self._backward(loss) # Update meter. meter_update_dict = { "epoch": epoch + (i + 1) / len(self.train_loader), "loss": loss.item(), } meter_update_dict.update(metrics) self.meter.update(meter_update_dict) # Make plots. 
if self.logger is not None: self.logger.log( meter_update_dict, step=epoch * len(self.train_loader) + i + 1, split="train", ) # Print metrics. if i % self.config["cmd"]["print_every"] == 0: print(self.meter) self.scheduler.step() with torch.no_grad(): if self.val_loader is not None: v_loss, v_mae = self.validate(split="val", epoch=epoch) if self.test_loader is not None: test_loss, test_mae = self.validate(split="test", epoch=epoch) if not self.is_debug: save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "amp": self.scaler.state_dict() if self.scaler else None, }, self.config["cmd"]["checkpoint_dir"], ) if return_metrics: return { "training_loss": float(self.meter.loss.global_avg), "training_mae": float(self.meter.meters[ self.config["task"]["labels"][0] + "/" + self.config["task"]["metric"]].global_avg), "validation_loss": v_loss, "validation_mae": v_mae, "test_loss": test_loss, "test_mae": test_mae, } def validate(self, split="val", epoch=None): if distutils.is_master(): print("### Evaluating on {}.".format(split)) self.model.eval() evaluator, metrics = Evaluator(task=self.name), {} rank = distutils.get_rank() loader = self.val_loader if split == "val" else self.test_loader for i, batch in tqdm( enumerate(loader), total=len(loader), position=rank, desc="device {}".format(rank), ): # Forward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) # Compute metrics. metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce(metrics[k]["total"], average=False, device=self.device), "numel": distutils.all_reduce(metrics[k]["numel"], average=False, device=self.device), } aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]) metrics = aggregated_metrics log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": epoch + 1}) if distutils.is_master(): log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()] print(", ".join(log_str)) # Make plots. if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) return metrics def _forward(self, batch, compute_metrics=True): out = {} # enable gradient wrt input. if "grad_input" in self.config["task"]: inp_for_grad = batch.pos batch.pos = batch.pos.requires_grad_(True) # forward pass. 
if self.config["model_attributes"].get("regress_forces", False): output, output_forces = self.model(batch) else: output = self.model(batch) if output.shape[-1] == 1: output = output.view(-1) out["output"] = output force_output = None if self.config["model_attributes"].get("regress_forces", False): out["force_output"] = output_forces force_output = output_forces if ("grad_input" in self.config["task"] and self.config["model_attributes"].get( "regress_forces", False) is False): force_output = -1 * torch.autograd.grad( output, inp_for_grad, grad_outputs=torch.ones_like(output), create_graph=True, retain_graph=True, )[0] out["force_output"] = force_output if not compute_metrics: return out, None metrics = {} if self.config["dataset"].get("normalize_labels", True): errors = eval(self.config["task"]["metric"])( self.normalizers["target"].denorm(output).cpu(), batch.y.cpu()).view(-1) else: errors = eval(self.config["task"]["metric"])( output.cpu(), batch.y.cpu()).view(-1) if ("label_index" in self.config["task"] and self.config["task"]["label_index"] is not False): # TODO(abhshkdz): Get rid of this edge case for QM9. # This is only because QM9 has multiple targets and we can either # jointly predict all of them or one particular target. metrics["{}/{}".format( self.config["task"]["labels"][self.config["task"] ["label_index"]], self.config["task"]["metric"], )] = errors[0] else: for i, label in enumerate(self.config["task"]["labels"]): metrics["{}/{}".format( label, self.config["task"]["metric"])] = errors[i] if "grad_input" in self.config["task"]: force_pred = force_output force_target = batch.force if self.config["task"].get("eval_on_free_atoms", True): mask = batch.fixed == 0 force_pred = force_pred[mask] force_target = force_target[mask] if self.config["dataset"].get("normalize_labels", True): grad_input_errors = eval(self.config["task"]["metric"])( self.normalizers["grad_target"].denorm(force_pred).cpu(), force_target.cpu(), ) else: grad_input_errors = eval(self.config["task"]["metric"])( force_pred.cpu(), force_target.cpu()) metrics["force_x/{}".format( self.config["task"]["metric"])] = grad_input_errors[0] metrics["force_y/{}".format( self.config["task"]["metric"])] = grad_input_errors[1] metrics["force_z/{}".format( self.config["task"]["metric"])] = grad_input_errors[2] return out, metrics def _compute_loss(self, out, batch): loss = [] if self.config["dataset"].get("normalize_labels", True): target_normed = self.normalizers["target"].norm(batch.y) else: target_normed = batch.y loss.append(self.criterion(out["output"], target_normed)) # TODO(abhshkdz): Test support for gradients wrt input. # TODO(abhshkdz): Make this general; remove dependence on `.forces`. if "grad_input" in self.config["task"]: if self.config["dataset"].get("normalize_labels", True): grad_target_normed = self.normalizers["grad_target"].norm( batch.force) else: grad_target_normed = batch.force # Force coefficient = 30 has been working well for us. force_mult = self.config["optim"].get("force_coefficient", 30) if self.config["task"].get("train_on_free_atoms", False): mask = batch.fixed == 0 loss.append(force_mult * self.criterion( out["force_output"][mask], grad_target_normed[mask])) else: loss.append( force_mult * self.criterion(out["force_output"], grad_target_normed)) # Sanity check to make sure the compute graph is correct. for lc in loss: assert hasattr(lc, "grad_fn") loss = sum(loss) return loss def _backward(self, loss): self.optimizer.zero_grad() loss.backward() # TODO(abhshkdz): Add support for gradient clipping. 
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

    def save_results(self, predictions, results_file, keys):
        if results_file is None:
            return

        results_file_path = os.path.join(
            self.config["cmd"]["results_dir"],
            f"{self.name}_{results_file}_{distutils.get_rank()}.npz",
        )
        np.savez_compressed(
            results_file_path,
            ids=predictions["id"],
            **{key: predictions[key] for key in keys},
        )

        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                f"{self.name}_{results_file}.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"{self.name}_{results_file}_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                for key in keys:
                    gather_results[key].extend(rank_results[key])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            for k in keys:
                if k == "forces":
                    gather_results[k] = np.array(gather_results[k], dtype=object)[idx]
                else:
                    gather_results[k] = np.array(gather_results[k])[idx]

            print(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)
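# Illustrative only: one way a trainer built on BaseTrainer might be constructed from
# already-parsed config dictionaries. The top-level keys mirror what the class reads
# above (task / model / dataset / optim); the concrete names and values are
# placeholders, not settings from the original project.
trainer = BaseTrainer(
    task={"dataset": "qm9", "labels": ["U0"], "metric": "mae"},
    model={"name": "example_model", "hidden_channels": 128},
    dataset={"src": "data/qm9", "normalize_labels": True},
    optimizer={"lr_initial": 1e-3, "batch_size": 32, "max_epochs": 100},
    identifier="example-run",
)
trainer.train()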
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default=None, type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file", default=None, type=str,
                        help="The output data file name.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default='', type=str, required=True,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str, required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing rate for the LM loss.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass."
) parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', help= "Whether to use 32-bit float precision instead of 16-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument( '--from_scratch', action='store_true', help= "Initialize parameters with random values (i.e., training from scratch)." ) parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--max_len_a', type=int, default=0, help="Truncate_config: maximum length of segment A.") parser.add_argument('--max_len_b', type=int, default=0, help="Truncate_config: maximum length of segment B.") parser.add_argument( '--trunc_seg', default='', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument( "--mask_prob", default=0.15, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument( "--mask_prob_eos", default=0, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument('--max_pred', type=int, default=20, help="Max tokens of prediction.") parser.add_argument("--num_workers", default=0, type=int, help="Number of workers for the data loader.") parser.add_argument('--mask_source_words', action='store_true', help="Whether to mask source words for training") parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask') parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask') parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.") parser.add_argument('--do_l2r_training', action='store_true', help="Whether to do left to right training") parser.add_argument( '--has_sentence_oracle', action='store_true', help="Whether to have sentence level oracle for training. 
" "Only useful for summary generation") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") args = parser.parse_args() assert Path( args.model_recover_path).exists(), "--model_recover_path doesn't exist" args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if not args.do_train and not args.do_eval: raise ValueError( "At least one of `do_train` or `do_eval` must be True.") if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings data_tokenizer = WhitespaceTokenizer( ) if args.tokenized_input else tokenizer if args.local_rank == 0: dist.barrier() if args.do_train: print("Loading Train Dataset", args.data_dir) bi_uni_pipeline = [ seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_a': args.max_len_a, 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mask_source_words=args.mask_source_words, skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size, mask_whole_word=args.mask_whole_word, mode="s2s", has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift) ] file_oracle = None if args.has_sentence_oracle: file_oracle = os.path.join(args.data_dir, 'train.oracle') fn_src = os.path.join(args.data_dir, args.src_file if args.src_file else 'train.src') fn_tgt = os.path.join(args.data_dir, args.tgt_file if args.tgt_file else 'train.tgt') train_dataset = seq2seq_loader.Seq2SeqDataset( fn_src, fn_tgt, args.train_batch_size, data_tokenizer, args.max_seq_length, file_oracle=file_oracle, bi_uni_pipeline=bi_uni_pipeline) if args.local_rank == -1: train_sampler = RandomSampler(train_dataset, replacement=False) _batch_size = args.train_batch_size else: train_sampler = DistributedSampler(train_dataset) _batch_size = args.train_batch_size // dist.get_world_size() train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size) t_total = int( len(train_dataloader) * args.num_train_epochs / args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 + \ (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2 num_sentlvl_labels = 2 if args.has_sentence_oracle else 0 relax_projection = 4 if args.relax_projection else 0 if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() if 
(recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init _state_dict = {} if args.from_scratch else None model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load(os.path.join( args.output_dir, "model.{0}.bin".format(recover_step)), map_location='cpu') # recover_step == number of epochs global_step = math.floor(recover_step * t_total / args.num_train_epochs) elif args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path, map_location='cpu') global_step = 0 model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb) if args.local_rank == 0: dist.barrier() if args.fp16: model.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() model.to(device) if args.local_rank != -1: try: from torch.nn.parallel.distributed import DistributedDataParallel as DDP except ImportError: raise ImportError("DistributedDataParallel") model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: #model = torch.nn.DataParallel(model) model = DataParallelImbalance(model) # Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] if args.fp16: try: #from apex.optimizers.fp16_optimizer import FP16_Optimizer from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers.fused_adam import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." 
) optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False) if args.loss_scale == 0: optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=t_total) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load(os.path.join( args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu') if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) model.train() if recover_step: start_epoch = recover_step + 1 else: start_epoch = 1 for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1, desc="Epoch", disable=args.local_rank not in (-1, 0)): if args.local_rank != -1: train_sampler.set_epoch(i_epoch) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)', disable=args.local_rank not in (-1, 0)) for step, batch in enumerate(iter_bar): batch = [ t.to(device) if t is not None else None for t in batch ] if args.has_sentence_oracle: input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch else: input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch oracle_pos, oracle_weights, oracle_labels = None, None, None loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, masked_pos_2=oracle_pos, masked_weights_2=oracle_weights, masked_labels_2=oracle_labels, mask_qkv=mask_qkv) masked_lm_loss, next_sentence_loss = loss_tuple if n_gpu > 1: # mean() to average on multi-gpu. 
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step / t_total, args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info("** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
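# The loop above rescales the learning rate with warmup_linear(), which is imported
# from pytorch_pretrained_bert and not reproduced in this listing. The sketch below is
# a minimal stand-in with the commonly used linear warmup / linear decay shape — an
# assumption for illustration, not a verified copy of that library's function.
def warmup_linear(x, warmup=0.002):
    # x = global_step / t_total, i.e. the fraction of training completed.
    if x < warmup:
        return x / warmup  # linear ramp from 0 up to the base learning rate
    return 1.0 - x         # then linear decay towards 0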
class DistributedForcesTrainer(BaseTrainer): def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, ): if run_dir is None: run_dir = os.getcwd() timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") if identifier: timestamp += "-{}".format(identifier) self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp": timestamp, "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp), "results_dir": os.path.join(run_dir, "results", timestamp), "logs_dir": os.path.join(run_dir, "logs", logger, timestamp), }, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if isinstance(dataset, list): self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] else: self.config["dataset"] = dataset if not is_debug and distutils.is_master(): os.makedirs(self.config["cmd"]["checkpoint_dir"]) os.makedirs(self.config["cmd"]["results_dir"]) os.makedirs(self.config["cmd"]["logs_dir"]) self.is_debug = is_debug self.is_vis = is_vis if torch.cuda.is_available(): self.device = local_rank else: self.device = "cpu" if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = Evaluator(task="s2ef") def load_task(self): print("### Loading dataset: {}".format(self.config["task"]["dataset"])) self.parallel_collater = ParallelCollater(1) if self.config["task"]["dataset"] == "trajectory_lmdb": self.train_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["dataset"]) self.train_loader = DataLoader( self.train_dataset, batch_size=self.config["optim"]["batch_size"], shuffle=True, collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, ) self.val_loader = self.test_loader = None if "val_dataset" in self.config: self.val_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["val_dataset"]) self.val_loader = DataLoader( self.val_dataset, self.config["optim"].get("eval_batch_size", 64), shuffle=False, collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, ) else: self.dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["dataset"]) ( self.train_loader, self.val_loader, self.test_loader, ) = self.dataset.get_dataloaders( batch_size=self.config["optim"]["batch_size"], collate_fn=self.parallel_collater, ) self.num_targets = 1 # Normalizer for the dataset. # Compute mean, std of training set labels. 
self.normalizers = {} if self.config["dataset"].get("normalize_labels", True): if "target_mean" in self.config["dataset"]: self.normalizers["target"] = Normalizer( mean=self.config["dataset"]["target_mean"], std=self.config["dataset"]["target_std"], device=self.device, ) else: self.normalizers["target"] = Normalizer( tensor=self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], device=self.device, ) # If we're computing gradients wrt input, set mean of normalizer to 0 -- # since it is lost when compute dy / dx -- and std to forward target std if self.config["model_attributes"].get("regress_forces", True): if self.config["dataset"].get("normalize_labels", True): if "grad_target_mean" in self.config["dataset"]: self.normalizers["grad_target"] = Normalizer( mean=self.config["dataset"]["grad_target_mean"], std=self.config["dataset"]["grad_target_std"], device=self.device, ) else: self.normalizers["grad_target"] = Normalizer( tensor=self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], device=self.device, ) self.normalizers["grad_target"].mean.fill_(0) if (self.is_vis and self.config["task"]["dataset"] != "qm9" and distutils.is_master()): # Plot label distribution. plots = [ plot_histogram( self.train_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: train", ), plot_histogram( self.val_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: val", ), plot_histogram( self.test_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: test", ), ] self.logger.log_plots(plots) def load_model(self): super(DistributedForcesTrainer, self).load_model() self.model = OCPDataParallel( self.model, output_device=self.device, num_gpus=1, ) self.model = DistributedDataParallel(self.model, device_ids=[self.device], find_unused_parameters=True) # Takes in a new data source and generates predictions on it. 
def predict(self, dataset, batch_size=32): if isinstance(dataset, dict): if self.config["task"]["dataset"] == "trajectory_lmdb": print("### Generating predictions on {}.".format( dataset["src"])) else: print("### Generating predictions on {}.".format( dataset["src"] + dataset["traj"])) dataset = registry.get_dataset_class( self.config["task"]["dataset"])(dataset) data_loader = DataLoader( dataset, batch_size=batch_size, shuffle=False, collate_fn=self.parallel_collater, ) elif isinstance(dataset, torch_geometric.data.Batch): data_loader = [[dataset]] else: raise NotImplementedError self.model.eval() predictions = {"energy": [], "forces": []} for i, batch_list in enumerate(data_loader): with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch_list) if self.normalizers is not None and "target" in self.normalizers: out["energy"] = self.normalizers["target"].denorm( out["energy"]) out["forces"] = self.normalizers["grad_target"].denorm( out["forces"]) atoms_sum = 0 predictions["energy"].extend(out["energy"].tolist()) batch_natoms = torch.cat([batch.natoms for batch in batch_list]) for natoms in batch_natoms: predictions["forces"].append( out["forces"][atoms_sum:natoms + atoms_sum].cpu().detach().numpy()) atoms_sum += natoms return predictions def train(self): self.best_val_mae = 1e9 eval_every = self.config["optim"].get("eval_every", -1) iters = 0 self.metrics = {} for epoch in range(self.config["optim"]["max_epochs"]): self.model.train() for i, batch in enumerate(self.train_loader): # Forward, loss, backward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) loss = self.scaler.scale(loss) if self.scaler else loss self._backward(loss) scale = self.scaler.get_scale() if self.scaler else 1.0 # Compute metrics. self.metrics = self._compute_metrics( out, batch, self.evaluator, self.metrics, ) self.metrics = self.evaluator.update("loss", loss.item() / scale, self.metrics) # Print metrics, make plots. log_dict = {k: self.metrics[k]["metric"] for k in self.metrics} log_dict.update( {"epoch": epoch + (i + 1) / len(self.train_loader)}) if i % self.config["cmd"]["print_every"] == 0: log_str = [ "{}: {:.4f}".format(k, v) for k, v in log_dict.items() ] print(", ".join(log_str)) self.metrics = {} if self.logger is not None: self.logger.log( log_dict, step=epoch * len(self.train_loader) + i + 1, split="train", ) iters += 1 # Evaluate on val set every `eval_every` iterations. if eval_every != -1 and iters % eval_every == 0: if self.val_loader is not None: val_metrics = self.validate(split="val", epoch=epoch) if (val_metrics[self.evaluator. task_primary_metric["s2ef"]]["metric"] < self.best_val_mae): self.best_val_mae = val_metrics[ self.evaluator. task_primary_metric["s2ef"]]["metric"] if not self.is_debug and distutils.is_master(): save_checkpoint( { "epoch": epoch + (i + 1) / len(self.train_loader), "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": val_metrics, }, self.config["cmd"]["checkpoint_dir"], ) self.scheduler.step() torch.cuda.empty_cache() if eval_every == -1: if self.val_loader is not None: val_metrics = self.validate(split="val", epoch=epoch) if (val_metrics[self.evaluator.task_primary_metric["s2ef"]] ["metric"] < self.best_val_mae): self.best_val_mae = val_metrics[ self.evaluator. 
task_primary_metric["s2ef"]]["metric"] if not self.is_debug and distutils.is_master(): save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": val_metrics, }, self.config["cmd"]["checkpoint_dir"], ) if self.test_loader is not None: self.validate(split="test", epoch=epoch) if ("relaxation_dir" in self.config["task"] and self.config["task"].get("ml_relax", "end") == "train"): self.validate_relaxation( split="val", epoch=epoch, ) if ("relaxation_dir" in self.config["task"] and self.config["task"].get("ml_relax", "end") == "end"): self.validate_relaxation( split="val", epoch=epoch, ) def validate(self, split="val", epoch=None): print("### Evaluating on {}.".format(split)) self.model.eval() evaluator, metrics = Evaluator(task="s2ef"), {} loader = self.val_loader if split == "val" else self.test_loader for i, batch in tqdm(enumerate(loader), total=len(loader)): # Forward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) # Compute metrics. metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce(metrics[k]["total"], average=False, device=self.device), "numel": distutils.all_reduce(metrics[k]["numel"], average=False, device=self.device), } aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]) metrics = aggregated_metrics log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": epoch + 1}) if distutils.is_master(): log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()] print(", ".join(log_str)) # Make plots. if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) return metrics def validate_relaxation(self, split="val", epoch=None): print("### Evaluating ML-relaxation") self.model.eval() mae_energy, mae_structure = relax_eval( trainer=self, traj_dir=self.config["task"]["relaxation_dir"], metric=self.config["task"]["metric"], steps=self.config["task"].get("relaxation_steps", 300), fmax=self.config["task"].get("relaxation_fmax", 0.01), results_dir=self.config["cmd"]["results_dir"], ) mae_energy = distutils.all_reduce(mae_energy, average=True, device=self.device) mae_structure = distutils.all_reduce(mae_structure, average=True, device=self.device) log_dict = { "relaxed_energy_mae": mae_energy, "relaxed_structure_mae": mae_structure, "epoch": epoch + 1, } # Make plots. if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) print(log_dict) return mae_energy, mae_structure def _forward(self, batch_list): # forward pass. if self.config["model_attributes"].get("regress_forces", True): out_energy, out_forces = self.model(batch_list) else: out_energy = self.model(batch_list) if out_energy.shape[-1] == 1: out_energy = out_energy.view(-1) out = { "energy": out_energy, } if self.config["model_attributes"].get("regress_forces", True): out["forces"] = out_forces return out def _compute_loss(self, out, batch_list): loss = [] # Energy loss. 
        energy_target = torch.cat(
            [batch.y.to(self.device) for batch in batch_list], dim=0)
        if self.config["dataset"].get("normalize_labels", True):
            energy_target = self.normalizers["target"].norm(energy_target)
        energy_mult = self.config["optim"].get("energy_coefficient", 1)
        loss.append(energy_mult * self.criterion(out["energy"], energy_target))

        # Force loss.
        if self.config["model_attributes"].get("regress_forces", True):
            force_target = torch.cat(
                [batch.force.to(self.device) for batch in batch_list], dim=0)
            if self.config["dataset"].get("normalize_labels", True):
                force_target = self.normalizers["grad_target"].norm(force_target)

            # Force coefficient = 30 has been working well for us.
            force_mult = self.config["optim"].get("force_coefficient", 30)
            if self.config["task"].get("train_on_free_atoms", False):
                fixed = torch.cat(
                    [batch.fixed.to(self.device) for batch in batch_list])
                mask = fixed == 0
                loss.append(
                    force_mult
                    * self.criterion(out["forces"][mask], force_target[mask]))
            else:
                loss.append(
                    force_mult * self.criterion(out["forces"], force_target))

        # Sanity check to make sure the compute graph is correct.
        for lc in loss:
            assert hasattr(lc, "grad_fn")
        loss = sum(loss)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        target = {
            "energy": torch.cat(
                [batch.y.to(self.device) for batch in batch_list], dim=0),
            "forces": torch.cat(
                [batch.force.to(self.device) for batch in batch_list], dim=0),
        }

        if self.config["task"].get("eval_on_free_atoms", True):
            fixed = torch.cat(
                [batch.fixed.to(self.device) for batch in batch_list])
            mask = fixed == 0
            out["forces"] = out["forces"][mask]
            target["forces"] = target["forces"][mask]

        if self.config["dataset"].get("normalize_labels", True):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])
            out["forces"] = self.normalizers["grad_target"].denorm(out["forces"])

        metrics = evaluator.eval(out, target, prev_metrics=metrics)
        return metrics
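# Tiny, self-contained illustration of the free-atom masking used in _compute_loss /
# _compute_metrics above: forces on fixed atoms (fixed == 1) are dropped before the
# loss or metric is computed. The numbers below are made up.
import torch
forces = torch.tensor([[0.1, 0.0, 0.0],
                       [0.3, 0.2, 0.1],
                       [0.0, 0.5, 0.0]])
fixed = torch.tensor([1, 0, 0])  # the first atom is constrained
mask = fixed == 0
print(forces[mask])              # only the rows for free atoms remain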
class EnergyTrainer(BaseTrainer):
    """
    Trainer class for the Initial Structure to Relaxed Energy (IS2RE) task.

    .. note::

        Examples of configurations for task, model, dataset and optimizer
        can be found in `configs/ocp_is2re <https://github.com/Open-Catalyst-Project/baselines/tree/master/configs/ocp_is2re/>`_.

    Args:
        task (dict): Task configuration.
        model (dict): Model configuration.
        dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset.
        optimizer (dict): Optimizer configuration.
        identifier (str): Experiment identifier that is appended to log directory.
        run_dir (str, optional): Path to the run directory where logs are to be saved.
            (default: :obj:`None`)
        is_debug (bool, optional): Run in debug mode.
            (default: :obj:`False`)
        is_vis (bool, optional): Run in visualization mode.
            (default: :obj:`False`)
        print_every (int, optional): Frequency of printing logs.
            (default: :obj:`100`)
        seed (int, optional): Random number seed.
            (default: :obj:`None`)
        logger (str, optional): Type of logger to be used.
            (default: :obj:`tensorboard`)
        local_rank (int, optional): Local rank of the process, only applicable for distributed training.
            (default: :obj:`0`)
        amp (bool, optional): Run using automatic mixed precision.
            (default: :obj:`False`)
    """

    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    ):
        if run_dir is None:
            run_dir = os.getcwd()

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
            "amp": amp,
        }

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master():
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis

        if torch.cuda.is_available():
            self.device = local_rank
        else:
            self.device = "cpu"

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))

        self.load()
        self.evaluator = Evaluator(task="is2re")

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            1, self.config["model_attributes"].get("otf_graph", False)
        )
        if self.config["task"]["dataset"] == "single_point_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])

            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=True,
            )
            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.train_sampler,
            )

            self.val_loader = self.test_loader = None
            self.val_sampler = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["val_dataset"])
                self.val_sampler = DistributedSampler(
                    self.val_dataset,
                    num_replicas=distutils.get_world_size(),
                    rank=distutils.get_rank(),
                    shuffle=False,
                )
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                    sampler=self.val_sampler,
                )
            if "test_dataset" in self.config:
                self.test_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["test_dataset"])
                self.test_sampler = DistributedSampler(
                    self.test_dataset,
                    num_replicas=distutils.get_world_size(),
                    rank=distutils.get_rank(),
                    shuffle=False,
                )
                self.test_loader = DataLoader(
                    self.test_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                    sampler=self.test_sampler,
                )
        else:
            raise NotImplementedError

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", False):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                raise NotImplementedError

    def load_model(self):
        super(EnergyTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(
                self.model, device_ids=[self.device]
            )

    def train(self):
        self.best_val_mae = 1e9
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch)
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item() / scale, self.metrics
                )

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)}
                )
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if (
                    val_metrics[self.evaluator.task_primary_metric["is2re"]][
                        "metric"
                    ]
                    < self.best_val_mae
                ):
                    self.best_val_mae = val_metrics[
                        self.evaluator.task_primary_metric["is2re"]
                    ]["metric"]
                    if not self.is_debug and distutils.is_master():
                        save_checkpoint(
                            {
                                "epoch": epoch + 1,
                                "state_dict": self.model.state_dict(),
                                "optimizer": self.optimizer.state_dict(),
                                "normalizers": {
                                    key: value.state_dict()
                                    for key, value in self.normalizers.items()
                                },
                                "config": self.config,
                                "val_metrics": val_metrics,
                                "amp": self.scaler.state_dict()
                                if self.scaler
                                else None,
                            },
                            self.config["cmd"]["checkpoint_dir"],
                        )
            else:
                if not self.is_debug and distutils.is_master():
                    save_checkpoint(
                        {
                            "epoch": epoch + 1,
                            "state_dict": self.model.state_dict(),
                            "optimizer": self.optimizer.state_dict(),
                            "normalizers": {
                                key: value.state_dict()
                                for key, value in self.normalizers.items()
                            },
                            "config": self.config,
                            "metrics": self.metrics,
                            "amp": self.scaler.state_dict()
                            if self.scaler
                            else None,
                        },
                        self.config["cmd"]["checkpoint_dir"],
                    )

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))
        self.model.eval()
        evaluator, metrics = Evaluator(task="is2re"), {}

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(enumerate(loader), total=len(loader)):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(
                    metrics[k]["total"], average=False, device=self.device
                ),
                "numel": distutils.all_reduce(
                    metrics[k]["numel"], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
            )
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
        print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def _forward(self, batch_list):
        output = self.model(batch_list)

        if output.shape[-1] == 1:
            output = output.view(-1)

        return {
            "energy": output,
        }

    def _compute_loss(self, out, batch_list):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", False):
            target_normed = self.normalizers["target"].norm(energy_target)
        else:
            target_normed = energy_target

        loss = self.criterion(out["energy"], target_normed)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", False):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])

        metrics = evaluator.eval(
            out,
            {"energy": energy_target},
            prev_metrics=metrics,
        )

        return metrics

    def predict(self, loader, results_file=None, disable_tqdm=False):
        assert isinstance(loader, torch.utils.data.dataloader.DataLoader)

        self.model.eval()
        if self.normalizers is not None and "target" in self.normalizers:
            self.normalizers["target"].to(self.device)
        predictions = []

        for i, batch in tqdm(
            enumerate(loader), total=len(loader), disable=disable_tqdm
        ):
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)

            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(out["energy"])
            predictions.extend(out["energy"].tolist())

        if results_file is not None:
            print(f"Writing results to {results_file}")
            # EvalAI expects a list of energies
            with open(results_file, "w") as resfile:
                json.dump(predictions, resfile)

        return predictions
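
# --- Usage sketch (illustration only) ---
# A minimal, hypothetical way to drive EnergyTrainer directly. The concrete
# values below (model name, LMDB paths under the "src" key, label statistics,
# batch sizes) are illustrative assumptions; real experiments are configured
# through the YAML files under configs/ocp_is2re and the repository's own
# config-building utilities.
if __name__ == "__main__":
    task_config = {"dataset": "single_point_lmdb"}
    model_config = {"name": "schnet"}  # remaining keys become model_attributes
    dataset_config = [
        {
            "src": "data/is2re/train/data.lmdb",  # hypothetical path/key
            "normalize_labels": True,
            "target_mean": -1.5,  # hypothetical label statistics
            "target_std": 2.3,
        },
        {"src": "data/is2re/val/data.lmdb"},  # hypothetical path/key
    ]
    optimizer_config = {
        "batch_size": 32,
        "eval_batch_size": 32,
        "num_workers": 4,
        "max_epochs": 30,
    }

    trainer = EnergyTrainer(
        task=task_config,
        model=model_config,
        dataset=dataset_config,
        optimizer=optimizer_config,
        identifier="is2re-example",
        print_every=100,
        seed=0,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    )
    trainer.train()
    predictions = trainer.predict(trainer.val_loader)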

class Trainer:
    def __init__(self, config: BaseConfig):
        self._config = config
        self._model = DeepLab(
            num_classes=9, output_stride=8, sync_bn=False
        ).to(self._config.device)
        self._border_loss = TotalLoss(self._config)
        self._direction_loss = CrossEntropyLoss()
        self._loaders = get_data_loaders(config)
        self._writer = SummaryWriter()
        self._optimizer = torch.optim.SGD(
            self._model.parameters(),
            lr=self._config.lr,
            weight_decay=1e-4,
            nesterov=True,
            momentum=0.9,
        )
        self._scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self._optimizer, gamma=0.97
        )
        if self._config.parallel:
            self._model = DistributedDataParallel(
                self._model, device_ids=[self._config.device]
            )
        # self._load()

    def train(self):
        for epoch in range(self._config.num_epochs):
            t = tqdm(self._loaders[DataMode.train])
            self._model.train()
            for idx, data in enumerate(t):
                self._optimizer.zero_grad()
                imgs, borders, masks, crop_info, opened_img, opened_mask = data
                imgs, borders, masks = (
                    imgs.to(self._config.device),
                    borders.to(self._config.device),
                    masks.to(self._config.device),
                )
                # borders = borders.unsqueeze(1)
                output = self._model(imgs)
                border_output = output[:, :1, :, :].squeeze()
                direction_output = output[:, 1:, :, :]
                seg_loss = self._direction_loss(direction_output, masks)
                loss = self._border_loss(border_output, borders) + seg_loss
                t.set_description(f"LOSS: {seg_loss.item()}")
                loss.backward()
                self._optimizer.step()
                self._writer.add_scalar(
                    "Loss/training",
                    loss.item(),
                    global_step=epoch * len(self._loaders[DataMode.train]) + idx,
                )
                if idx % self._config.frequency_visualization[DataMode.train] == 0:
                    self._tensorboard_visualization(
                        loss=loss,
                        epoch=epoch,
                        idx=idx,
                        imgs=imgs,
                        border_gt=borders.unsqueeze(1),
                        border=border_output.unsqueeze(1),
                        direction_gt=masks,
                        direction=direction_output,
                    )
                if self._config.live_visualization:
                    self._live_visualization(imgs, borders, output)
            self._save()

    @staticmethod
    def _get_mask(masks):
        masks_new = torch.zeros(
            masks.shape[0], 2, masks.shape[1], masks.shape[2], device=masks.device
        )
        for idx in range(2):
            masks_new[:, idx, :, :][masks == idx] = 1
        return masks_new

    def validate(self, epoch):
        t = tqdm(self._loaders[DataMode.eval])
        self._model.eval()
        for idx, data in enumerate(t):
            imgs, masks, _, _, _ = data
            imgs, masks = (
                imgs.to(self._config.device),
                masks.to(self._config.device),
            )
            output = self._model(imgs)
            loss = self._border_loss(masks, output)

    def _tensorboard_visualization(
        self, loss, epoch, idx, imgs, border_gt, direction_gt, border, direction
    ):
        self._writer.add_images(f"Images{idx}/training", imgs, epoch)
        self._writer.add_images(f"Border{idx}/training", border, epoch)
        self._writer.add_images(f"BorderGT{idx}/training", border_gt, epoch)
        self._writer.add_images(
            f"DirectionGT{idx}/training", direction_gt.unsqueeze(1) / 8.0, epoch
        )
        self._writer.add_images(
            f"Direction{idx}/training",
            direction.argmax(1).unsqueeze(1) / 8.0,
            epoch,
        )

    def _save(self):
        path = os.path.join(
            self._config.checkpoint_path, self._config.EXPERIMENT_NAME
        )
        self.check_and_mkdir(path)
        torch.save(self._model.state_dict(), os.path.join(path, "corrector.pth"))

    def _load(self):
        path = os.path.join(
            self._config.checkpoint_path, self._config.EXPERIMENT_NAME
        )
        weights = glob(os.path.join(path, "*.pth"))
        if len(weights):
            state_dict = torch.load(weights[0])
            try:
                self._model.load_state_dict(state_dict)
            except RuntimeError as e:
                print("ERROR while loading weights: {}".format(e))

    @staticmethod
    def check_and_mkdir(path):
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)

    @staticmethod
    def _live_visualization(imgs, masks, output):
        out = (
            np.expand_dims(
                (output[1, 0, :, :].detach().cpu().numpy()).astype(np.float32),
                axis=2,
            )
            > 0.7
        )
        out2 = (
            np.expand_dims(
                np.argmax(
                    (output[1, 1:, :, :].detach().cpu().numpy()).astype(np.float32),
                    axis=0,
                ),
                axis=2,
            )
            / 7.0
        )
        out3 = np.expand_dims(
            (output[1, 2, :, :].detach().cpu().numpy()).astype(np.float32), axis=2
        )
        show = np.zeros((masks.shape[2], 5 * masks.shape[3], 3), dtype=np.float32)
        show[:, : masks.shape[3]] = cv2.applyColorMap(
            (masks[1, 0, :, :] * 255).detach().cpu().numpy().astype(np.uint8),
            cv2.COLORMAP_BONE,
        )
        show[:, masks.shape[3] : 2 * masks.shape[3], :] = np.concatenate(
            [out, out, out], axis=2
        )
        show[:, 2 * masks.shape[3] : 3 * masks.shape[3]] = cv2.cvtColor(
            imgs[1, ...].permute(1, 2, 0).cpu().detach().numpy(),
            cv2.COLOR_BGR2RGB,
        )
        show[:, 3 * masks.shape[3] : 4 * masks.shape[3]] = np.concatenate(
            [out2, out2, out2], axis=2
        )
        show[:, 4 * masks.shape[3] : 5 * masks.shape[3]] = np.concatenate(
            [out3, out3, out3], axis=2
        )
        cv2.imshow("training", show)
        cv2.waitKey(10)
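
# --- Configuration sketch (assumption, not the project's actual BaseConfig) ---
# The Trainer above only touches a handful of attributes on its config object
# (device, lr, parallel, num_epochs, frequency_visualization, live_visualization,
# checkpoint_path, EXPERIMENT_NAME). The dataclass below is a hypothetical,
# minimal stand-in that satisfies those accesses; the real BaseConfig in this
# project likely carries additional fields (data paths, augmentation settings, etc.).
from dataclasses import dataclass, field


@dataclass
class MinimalTrainerConfig:
    device: str = "cuda:0"            # target device for model and batches
    lr: float = 1e-2                  # SGD learning rate
    parallel: bool = False            # wrap the model in DistributedDataParallel
    num_epochs: int = 50
    live_visualization: bool = False  # OpenCV preview window during training
    checkpoint_path: str = "./checkpoints"
    EXPERIMENT_NAME: str = "border_segmentation"
    # How often (in batches) to push images to TensorBoard, per data mode.
    frequency_visualization: dict = field(
        default_factory=lambda: {DataMode.train: 200, DataMode.eval: 200}
    )


# Example (assuming the data loaders resolve from this config):
# trainer = Trainer(MinimalTrainerConfig())
# trainer.train()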

def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args, classifier_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    if args.noised:
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion
        )

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(
                    args.resume_checkpoint, map_location=dist_util.dev()
                )
            )

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(
        model=model,
        use_fp16=args.classifier_use_fp16,
        initial_lg_loss_scale=16.0,
    )

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
        )
    else:
        val_data = None

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        opt_checkpoint = bf.join(
            bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
        )
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
        )

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())

        batch = batch.to(dist_util.dev())
        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
            split_microbatches(args.microbatch, batch, labels, t)
        ):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(
                logits, sub_labels, k=1, reduction="none"
            )
            losses[f"{prefix}_acc@5"] = compute_top_k(
                logits, sub_labels, k=5, reduction="none"
            )
            log_loss_dict(diffusion, sub_t, losses)
            del losses
            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (
            step
            and dist.get_rank() == 0
            and not (step + resume_step) % args.save_interval
        ):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()
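
# --- Helper sketches (assumptions) ---
# forward_backward_log above relies on a split_microbatches helper that yields
# (sub_batch, sub_labels, sub_t) slices of at most `microbatch` samples (or the
# full tensors when microbatching is disabled), and on a set_annealed_lr helper
# for linear learning-rate decay. Minimal versions consistent with that usage
# might look like the following; the helpers actually shipped with this codebase
# may differ in details.
def split_microbatches(microbatch, *args):
    bs = len(args[0])
    if microbatch == -1 or microbatch >= bs:
        # No splitting: yield the full batch once.
        yield tuple(args)
    else:
        for i in range(0, bs, microbatch):
            # Slice every tensor in lockstep so labels/timesteps stay aligned.
            yield tuple(x[i : i + microbatch] for x in args)


def set_annealed_lr(opt, base_lr, frac_done):
    # Linearly decay the learning rate to zero over the course of training.
    lr = base_lr * (1 - frac_done)
    for param_group in opt.param_groups:
        param_group["lr"] = lr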

class DistributedEnergyTrainer(BaseTrainer):
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    ):
        if run_dir is None:
            run_dir = os.getcwd()

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
            "amp": amp,
        }

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master():
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis

        if torch.cuda.is_available():
            self.device = local_rank
        else:
            self.device = "cpu"

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))

        self.load()
        self.evaluator = Evaluator(task="is2re")

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(1)
        if self.config["task"]["dataset"] == "single_point_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])

            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=True,
            )
            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.train_sampler,
            )

            self.val_loader = self.test_loader = None
            self.val_sampler = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["val_dataset"])
                self.val_sampler = DistributedSampler(
                    self.val_dataset,
                    num_replicas=distutils.get_world_size(),
                    rank=distutils.get_rank(),
                    shuffle=False,
                )
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                    sampler=self.val_sampler,
                )
        else:
            raise NotImplementedError

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", True):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                raise NotImplementedError

    def load_model(self):
        super(DistributedEnergyTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        self.model = DistributedDataParallel(
            self.model, device_ids=[self.device], find_unused_parameters=True
        )

    def train(self):
        self.best_val_mae = 1e9
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch)
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item() / scale, self.metrics
                )

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)}
                )
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if self.test_loader is not None:
                    self.validate(split="test", epoch=epoch)

                if (
                    val_metrics[self.evaluator.task_primary_metric["is2re"]][
                        "metric"
                    ]
                    < self.best_val_mae
                ):
                    self.best_val_mae = val_metrics[
                        self.evaluator.task_primary_metric["is2re"]
                    ]["metric"]
                    if not self.is_debug:
                        save_checkpoint(
                            {
                                "epoch": epoch + 1,
                                "state_dict": self.model.state_dict(),
                                "optimizer": self.optimizer.state_dict(),
                                "normalizers": {
                                    key: value.state_dict()
                                    for key, value in self.normalizers.items()
                                },
                                "config": self.config,
                                "val_metrics": val_metrics,
                                "amp": self.scaler.state_dict()
                                if self.scaler
                                else None,
                            },
                            self.config["cmd"]["checkpoint_dir"],
                        )

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))
        self.model.eval()
        evaluator, metrics = Evaluator(task="is2re"), {}

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(enumerate(loader), total=len(loader)):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(
                    metrics[k]["total"], average=False, device=self.device
                ),
                "numel": distutils.all_reduce(
                    metrics[k]["numel"], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
            )
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
        print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def _forward(self, batch_list):
        output = self.model(batch_list)

        if output.shape[-1] == 1:
            output = output.view(-1)

        return {
            "energy": output,
        }

    def _compute_loss(self, out, batch_list):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", True):
            target_normed = self.normalizers["target"].norm(energy_target)
        else:
            target_normed = energy_target

        loss = self.criterion(out["energy"], target_normed)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", True):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])

        metrics = evaluator.eval(
            out,
            {"energy": energy_target},
            prev_metrics=metrics,
        )

        return metrics
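
# --- Normalizer sketch (assumption) ---
# Both energy trainers above normalize relaxed-energy targets through a
# Normalizer object with norm()/denorm()/state_dict()/to() methods. The class
# below is a minimal, hypothetical implementation consistent with that
# interface; the project's own Normalizer (which can also estimate the mean and
# standard deviation from a tensor of training labels) is the authoritative one.
import torch


class SimpleNormalizer:
    def __init__(self, mean, std, device="cpu"):
        # Store statistics as tensors so norm/denorm broadcast over batches.
        self.mean = torch.tensor(mean, device=device)
        self.std = torch.tensor(std, device=device)

    def to(self, device):
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        return self

    def norm(self, tensor):
        # Map raw targets to zero-mean, unit-variance space for training.
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        # Map model outputs back to physical units for metrics/predictions.
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {"mean": self.mean, "std": self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict["mean"]
        self.std = state_dict["std"]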