def load_vaes_256(H, logprint):
    vae = VAE_256(H)
    if H.restore_path:
        logprint(f'Restoring vae from {H.restore_path}')
        restore_params(vae, H.restore_path, map_cpu=True,
                       local_rank=H.local_rank, mpi_size=H.mpi_size)

    ema_vae = VAE_256(H)
    if H.restore_ema_path:
        logprint(f'Restoring ema vae from {H.restore_ema_path}')
        restore_params(ema_vae, H.restore_ema_path, map_cpu=True,
                       local_rank=H.local_rank, mpi_size=H.mpi_size)
    else:
        ema_vae.load_state_dict(vae.state_dict())
    ema_vae.requires_grad_(False)

    vae = vae.cuda(H.local_rank)
    ema_vae = ema_vae.cuda(H.local_rank)

    if H.image_size == 64:
        vae.decoder.requires_grad_(False)
    if H.image_size == 256:
        vae.encoder.requires_grad_(False)

    vae = DistributedDataParallel(vae, device_ids=[H.local_rank],
                                  output_device=H.local_rank,
                                  find_unused_parameters=True)

    if len(list(vae.named_parameters())) != len(list(vae.parameters())):
        raise ValueError('Some params are not named. Please name all params.')
    total_params = 0
    for name, p in vae.named_parameters():
        total_params += np.prod(p.shape)
    logprint(total_params=total_params, readable=f'{total_params:,}')
    return vae, ema_vae
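
# For context: the frozen `ema_vae` returned above is typically kept in sync
# with the training weights by an exponential moving average after each
# optimizer step. A minimal sketch, assuming a hypothetical `ema_rate`
# hyperparameter -- the actual update helper is not shown in this file.
import torch


@torch.no_grad()
def update_ema(vae, ema_vae, ema_rate=0.999):
    # Blend each EMA parameter toward the current training parameter.
    # `vae` may be DDP-wrapped; parameters() still yields the module tensors
    # in a consistent order, so pairwise zipping is safe here.
    for p, ema_p in zip(vae.parameters(), ema_vae.parameters()):
        ema_p.mul_(ema_rate).add_(p, alpha=1 - ema_rate)
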
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default=None, type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file", default=None, type=str,
                        help="The output data file name.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default='', type=str, required=True,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str, required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization.\n"
                             "Sequences longer than this will be truncated, and sequences shorter\n"
                             "than this will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing factor for the LM loss.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus.")
    parser.add_argument('--seed', type=int, default=42,
                        help="Random seed for initialization.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit.")
    parser.add_argument('--fp32_embedding', action='store_true',
                        help="Whether to use 32-bit float precision instead of 16-bit for embeddings.")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 is set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16.")
    parser.add_argument('--from_scratch', action='store_true',
                        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: whether we should always truncate the tail.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Masking probability. The number of predictions can be "
                             "less than max_pred when the sequence is short.")
    parser.add_argument("--mask_prob_eos", default=0, type=float,
                        help="Masking probability for the EOS token. The number of predictions "
                             "can be less than max_pred when the sequence is short.")
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training.")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help="Probability of n-gram masking.")
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help="The max size of the n-gram mask.")
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether to mask whole words.")
    parser.add_argument('--do_l2r_training', action='store_true',
                        help="Whether to do left-to-right training.")
    parser.add_argument('--has_sentence_oracle', action='store_true',
                        help="Whether to have a sentence-level oracle for training. "
                             "Only useful for summary generation.")
    parser.add_argument('--max_position_embeddings', type=int, default=None,
                        help="Max position embeddings.")
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as f:
        json.dump(args.__dict__, f, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids, args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={'max_len_a': args.max_len_a,
                                 'max_len_b': args.max_len_b,
                                 'trunc_seg': args.trunc_seg,
                                 'always_truncate_tail': args.always_truncate_tail},
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(
            args.data_dir, args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(
            args.data_dir, args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src, fn_tgt, args.train_batch_size, data_tokenizer,
            args.max_seq_length, file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=_batch_size, sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(len(train_dataloader) * args.num_train_epochs /
                  args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=_state_dict,
            num_labels=cls_num_labels, num_rel=0,
            type_vocab_size=type_vocab_size, config_path=args.config_path,
            task_idx=3, num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv, seg_emb=args.seg_emb)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(
                recover_step * t_total / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(
                args.model_recover_path, map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=model_recover,
            num_labels=cls_num_labels, num_rel=0,
            type_vocab_size=type_vocab_size, config_path=args.config_path,
            task_idx=3, num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv, seg_emb=args.seg_emb)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel.distributed import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model, device_ids=[args.local_rank],
                    output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            # from apex.optimizers.fp16_optimizer import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers.fused_adam import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate, bias_correction=False)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(
                optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(
                optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
            map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) if t is not None else None
                         for t in batch]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, \
                        lm_label_ids, masked_pos, masked_weights, is_next, \
                        task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, \
                        lm_label_ids, masked_pos, masked_weights, is_next, \
                        task_idx = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                loss_tuple = model(input_ids, segment_ids, input_mask,
                                   lm_label_ids, is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv)
                masked_lm_loss, next_sentence_loss = loss_tuple
                if n_gpu > 1:    # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step / t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
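
# For reference: `warmup_linear`, used in the manual LR update inside the
# training loop above, is imported elsewhere in this script. The sketch below
# mirrors the schedule shipped with pytorch_pretrained_bert's optimization
# module (linear warmup to the peak LR over the `warmup` fraction of training,
# then linear decay to zero); the version actually imported may differ.
def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0 - x
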
class BaseTrainer(ABC):
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        normalizer=None,
        timestamp_id=None,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        is_hpo=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
        cpu=False,
        name="base_trainer",
        slurm={},
    ):
        self.name = name
        self.cpu = cpu
        self.epoch = 0
        self.step = 0

        if torch.cuda.is_available() and not self.cpu:
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")
            # handle case when `--cpu` isn't specified
            # but there are no gpu devices available
            self.cpu = True

        if run_dir is None:
            run_dir = os.getcwd()

        if timestamp_id is None:
            timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
                self.device
            )
            # create directories from master rank only
            distutils.broadcast(timestamp, 0)
            timestamp = datetime.datetime.fromtimestamp(
                timestamp.int()
            ).strftime("%Y-%m-%d-%H-%M-%S")
            if identifier:
                self.timestamp_id = f"{timestamp}-{identifier}"
            else:
                self.timestamp_id = timestamp
        else:
            self.timestamp_id = timestamp_id

        try:
            commit_hash = (
                subprocess.check_output(
                    [
                        "git",
                        "-C",
                        ocpmodels.__path__[0],
                        "describe",
                        "--always",
                    ]
                )
                .strip()
                .decode("ascii")
            )
        # catch instances where code is not being run from a git repo
        except Exception:
            commit_hash = None

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "amp": amp,
            "gpus": distutils.get_world_size() if not self.cpu else 0,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp_id": self.timestamp_id,
                "commit": commit_hash,
                "checkpoint_dir": os.path.join(
                    run_dir, "checkpoints", self.timestamp_id
                ),
                "results_dir": os.path.join(
                    run_dir, "results", self.timestamp_id
                ),
                "logs_dir": os.path.join(
                    run_dir, "logs", logger, self.timestamp_id
                ),
            },
            "slurm": slurm,
        }

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if "SLURM_JOB_ID" in os.environ and "folder" in self.config["slurm"]:
            self.config["slurm"]["job_id"] = os.environ["SLURM_JOB_ID"]
            self.config["slurm"]["folder"] = self.config["slurm"][
                "folder"
            ].replace("%j", self.config["slurm"]["job_id"])

        if isinstance(dataset, list):
            if len(dataset) > 0:
                self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
        elif isinstance(dataset, dict):
            self.config["dataset"] = dataset.get("train", None)
            self.config["val_dataset"] = dataset.get("val", None)
            self.config["test_dataset"] = dataset.get("test", None)
        else:
            self.config["dataset"] = dataset

        self.normalizer = normalizer
        # This supports the legacy way of providing norm parameters in dataset
        if self.config.get("dataset", None) is not None and normalizer is None:
            self.normalizer = self.config["dataset"]

        if not is_debug and distutils.is_master() and not is_hpo:
            os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

        self.is_debug = is_debug
        self.is_vis = is_vis
        self.is_hpo = is_hpo

        if self.is_hpo:
            # conditional import is necessary for checkpointing
            from ray import tune

            from ocpmodels.common.hpo_utils import tune_reporter

            # sets the hpo checkpoint frequency
            # default is no checkpointing
            self.hpo_checkpoint_every = self.config["optim"].get(
                "checkpoint_every", -1
            )

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))

        self.load()

        self.evaluator = Evaluator(task=name)

    def load(self):
        self.load_seed_from_config()
        self.load_logger()
        self.load_task()
        self.load_model()
        self.load_loss()
        self.load_optimizer()
        self.load_extras()

    def load_seed_from_config(self):
        # https://pytorch.org/docs/stable/notes/randomness.html
        seed = self.config["cmd"]["seed"]
        if seed is None:
            return

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def load_logger(self):
        self.logger = None
        if not self.is_debug and distutils.is_master() and not self.is_hpo:
            assert (
                self.config["logger"] is not None
            ), "Specify logger in config"
            self.logger = registry.get_logger_class(self.config["logger"])(
                self.config
            )

    def get_sampler(self, dataset, batch_size, shuffle):
        if "load_balancing" in self.config["optim"]:
            balancing_mode = self.config["optim"]["load_balancing"]
            force_balancing = True
        else:
            balancing_mode = "atoms"
            force_balancing = False

        sampler = BalancedBatchSampler(
            dataset,
            batch_size=batch_size,
            num_replicas=distutils.get_world_size(),
            rank=distutils.get_rank(),
            device=self.device,
            mode=balancing_mode,
            shuffle=shuffle,
            force_balancing=force_balancing,
        )
        return sampler

    def get_dataloader(self, dataset, sampler):
        loader = DataLoader(
            dataset,
            collate_fn=self.parallel_collater,
            num_workers=self.config["optim"]["num_workers"],
            pin_memory=True,
            batch_sampler=sampler,
        )
        return loader

    @abstractmethod
    def load_task(self):
        """Derived classes should implement this function."""

    def load_model(self):
        # Build model
        if distutils.is_master():
            logging.info(f"Loading model: {self.config['model']}")

        # TODO(abhshkdz): Eventually move towards computing features on-the-fly
        # and remove dependence from `.edge_attr`.
        bond_feat_dim = None
        if self.config["task"]["dataset"] in [
            "trajectory_lmdb",
            "single_point_lmdb",
        ]:
            bond_feat_dim = self.config["model_attributes"].get(
                "num_gaussians", 50
            )
        else:
            raise NotImplementedError

        loader = self.train_loader or self.val_loader or self.test_loader
        self.model = registry.get_model_class(self.config["model"])(
            loader.dataset[0].x.shape[-1]
            if loader
            and hasattr(loader.dataset[0], "x")
            and loader.dataset[0].x is not None
            else None,
            bond_feat_dim,
            self.num_targets,
            **self.config["model_attributes"],
        ).to(self.device)

        if distutils.is_master():
            logging.info(
                f"Loaded {self.model.__class__.__name__} with "
                f"{self.model.num_params} parameters."
            )

        if self.logger is not None:
            self.logger.watch(self.model)

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1 if not self.cpu else 0,
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(
                self.model, device_ids=[self.device]
            )

    def load_checkpoint(self, checkpoint_path):
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                errno.ENOENT, "Checkpoint file not found", checkpoint_path
            )

        logging.info(f"Loading checkpoint from: {checkpoint_path}")
        map_location = torch.device("cpu") if self.cpu else self.device
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        self.epoch = checkpoint.get("epoch", 0)
        self.step = checkpoint.get("step", 0)

        # Load model, optimizer, normalizer state dict.
        # if trained with ddp and want to load in non-ddp, modify keys from
        # module.module.. -> module..
        first_key = next(iter(checkpoint["state_dict"]))
        if not distutils.initialized() and first_key.split(".")[1] == "module":
            # No need for OrderedDict since dictionaries are technically ordered
            # since Python 3.6 and officially ordered since Python 3.7
            new_dict = {k[7:]: v for k, v in checkpoint["state_dict"].items()}
            self.model.load_state_dict(new_dict)
        else:
            self.model.load_state_dict(checkpoint["state_dict"])

        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
        if "scheduler" in checkpoint and checkpoint["scheduler"] is not None:
            self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"])
        if "ema" in checkpoint and checkpoint["ema"] is not None:
            self.ema.load_state_dict(checkpoint["ema"])

        for key in checkpoint["normalizers"]:
            if key in self.normalizers:
                self.normalizers[key].load_state_dict(
                    checkpoint["normalizers"][key]
                )

        if self.scaler and checkpoint["amp"]:
            self.scaler.load_state_dict(checkpoint["amp"])

    def load_loss(self):
        self.loss_fn = {}
        self.loss_fn["energy"] = self.config["optim"].get("loss_energy", "mae")
        self.loss_fn["force"] = self.config["optim"].get("loss_force", "mae")
        for loss, loss_name in self.loss_fn.items():
            if loss_name in ["l1", "mae"]:
                self.loss_fn[loss] = nn.L1Loss()
            elif loss_name == "mse":
                self.loss_fn[loss] = nn.MSELoss()
            elif loss_name == "l2mae":
                self.loss_fn[loss] = L2MAELoss()
            else:
                raise NotImplementedError(
                    f"Unknown loss function name: {loss_name}"
                )
            if distutils.initialized():
                self.loss_fn[loss] = DDPLoss(self.loss_fn[loss])

    def load_optimizer(self):
        optimizer = self.config["optim"].get("optimizer", "AdamW")
        optimizer = getattr(optim, optimizer)

        if self.config["optim"].get("weight_decay", 0) > 0:
            # Do not regularize bias etc.
            params_decay = []
            params_no_decay = []
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    if "embedding" in name:
                        params_no_decay += [param]
                    elif "frequencies" in name:
                        params_no_decay += [param]
                    elif "bias" in name:
                        params_no_decay += [param]
                    else:
                        params_decay += [param]

            self.optimizer = optimizer(
                [
                    {"params": params_no_decay, "weight_decay": 0},
                    {
                        "params": params_decay,
                        "weight_decay": self.config["optim"]["weight_decay"],
                    },
                ],
                lr=self.config["optim"]["lr_initial"],
                **self.config["optim"].get("optimizer_params", {}),
            )
        else:
            self.optimizer = optimizer(
                params=self.model.parameters(),
                lr=self.config["optim"]["lr_initial"],
                **self.config["optim"].get("optimizer_params", {}),
            )

    def load_extras(self):
        self.scheduler = LRScheduler(self.optimizer, self.config["optim"])
        self.clip_grad_norm = self.config["optim"].get("clip_grad_norm")
        self.ema_decay = self.config["optim"].get("ema_decay")
        if self.ema_decay:
            self.ema = ExponentialMovingAverage(
                self.model.parameters(),
                self.ema_decay,
            )
        else:
            self.ema = None

    def save(
        self,
        metrics=None,
        checkpoint_file="checkpoint.pt",
        training_state=True,
    ):
        if not self.is_debug and distutils.is_master():
            if training_state:
                save_checkpoint(
                    {
                        "epoch": self.epoch,
                        "step": self.step,
                        "state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "scheduler": self.scheduler.scheduler.state_dict()
                        if self.scheduler.scheduler_type != "Null"
                        else None,
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                        "val_metrics": metrics,
                        "ema": self.ema.state_dict() if self.ema else None,
                        "amp": self.scaler.state_dict()
                        if self.scaler
                        else None,
                    },
                    checkpoint_dir=self.config["cmd"]["checkpoint_dir"],
                    checkpoint_file=checkpoint_file,
                )
            else:
                if self.ema:
                    self.ema.store()
                    self.ema.copy_to()
                save_checkpoint(
                    {
                        "state_dict": self.model.state_dict(),
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                        "val_metrics": metrics,
                        "amp": self.scaler.state_dict()
                        if self.scaler
                        else None,
                    },
                    checkpoint_dir=self.config["cmd"]["checkpoint_dir"],
                    checkpoint_file=checkpoint_file,
                )
                if self.ema:
                    self.ema.restore()

    def save_hpo(self, epoch, step, metrics, checkpoint_every):
        # default is no checkpointing
        # checkpointing frequency can be adjusted by setting checkpoint_every in steps
        # to checkpoint every time results are communicated to Ray Tune set checkpoint_every=1
        if checkpoint_every != -1 and step % checkpoint_every == 0:
            with tune.checkpoint_dir(step=step) as checkpoint_dir:  # noqa: F821
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save(self.save_state(epoch, step, metrics), path)

    def hpo_update(
        self, epoch, step, train_metrics, val_metrics, test_metrics=None
    ):
        progress = {
            "steps": step,
            "epochs": epoch,
            "act_lr": self.optimizer.param_groups[0]["lr"],
        }
        # checkpointing must occur before reporter
        # default is no checkpointing
        self.save_hpo(
            epoch,
            step,
            val_metrics,
            self.hpo_checkpoint_every,
        )
        # report metrics to tune
        tune_reporter(  # noqa: F821
            iters=progress,
            train_metrics={
                k: train_metrics[k]["metric"] for k in self.metrics
            },
            val_metrics={k: val_metrics[k]["metric"] for k in val_metrics},
            test_metrics=test_metrics,
        )

    @abstractmethod
    def train(self):
        """Derived classes should implement this function."""

    @torch.no_grad()
    def validate(self, split="val", disable_tqdm=False):
        if distutils.is_master():
            logging.info(f"Evaluating on {split}.")
        if self.is_hpo:
            disable_tqdm = True

        self.model.eval()
        if self.ema:
            self.ema.store()
            self.ema.copy_to()

        evaluator, metrics = Evaluator(task=self.name), {}
        rank = distutils.get_rank()

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(
            enumerate(loader),
            total=len(loader),
            position=rank,
            desc="device {}".format(rank),
            disable=disable_tqdm,
        ):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
            loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(
                    metrics[k]["total"], average=False, device=self.device
                ),
                "numel": distutils.all_reduce(
                    metrics[k]["numel"], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"]
                / aggregated_metrics[k]["numel"]
            )
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": self.epoch})
        if distutils.is_master():
            log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
            logging.info(", ".join(log_str))

        # Make plots.
        if self.logger is not None:
            self.logger.log(
                log_dict,
                step=self.step,
                split=split,
            )

        if self.ema:
            self.ema.restore()

        return metrics

    @abstractmethod
    def _forward(self, batch_list):
        """Derived classes should implement this function."""

    @abstractmethod
    def _compute_loss(self, out, batch_list):
        """Derived classes should implement this function."""

    def _backward(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        # Scale down the gradients of shared parameters
        if hasattr(self.model, "shared_parameters"):
            for p, factor in self.model.shared_parameters:
                if p.grad is not None:
                    p.grad.detach().div_(factor)
        if self.clip_grad_norm:
            if self.scaler:
                self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                max_norm=self.clip_grad_norm,
            )
            if self.logger is not None:
                self.logger.log(
                    {"grad_norm": grad_norm}, step=self.step, split="train"
                )
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        if self.ema:
            self.ema.update()

    def save_results(self, predictions, results_file, keys):
        if results_file is None:
            return

        results_file_path = os.path.join(
            self.config["cmd"]["results_dir"],
            f"{self.name}_{results_file}_{distutils.get_rank()}.npz",
        )
        np.savez_compressed(
            results_file_path,
            ids=predictions["id"],
            **{key: predictions[key] for key in keys},
        )

        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                f"{self.name}_{results_file}.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"{self.name}_{results_file}_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                for key in keys:
                    gather_results[key].extend(rank_results[key])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            for k in keys:
                if k == "forces":
                    gather_results[k] = np.concatenate(
                        np.array(gather_results[k])[idx]
                    )
                elif k == "chunk_idx":
                    gather_results[k] = np.cumsum(
                        np.array(gather_results[k])[idx]
                    )[:-1]
                else:
                    gather_results[k] = np.array(gather_results[k])[idx]

            logging.info(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)
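

# Illustration only: a hypothetical minimal subclass showing the abstract
# contract BaseTrainer imposes. The bodies below are placeholders, not the
# real EnergyTrainer/ForcesTrainer implementations from ocpmodels.
class ToyTrainer(BaseTrainer):
    def load_task(self):
        # Derived trainers set up datasets/loaders and self.num_targets here.
        self.num_targets = 1

    def train(self):
        # A real trainer iterates over its train loader, calls
        # self._forward / self._compute_loss, and steps via self._backward.
        raise NotImplementedError

    def _forward(self, batch_list):
        # Placeholder forward pass through the wrapped model.
        return self.model(batch_list)

    def _compute_loss(self, out, batch_list):
        # Placeholder loss; real trainers combine energy/force terms
        # using the loss functions built in load_loss().
        return self.loss_fn["energy"](out, batch_list)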