def _check_and_init_distributed_model(self):
    if not self.options.use_data_parallel_distributed:
        return

    if not dist.is_initialized():
        world_size = self.options.dist_world_size
        url = self.options.dist_url
        rank = self.options.dist_rank
        # This is for SLURM's special use case
        if rank == -1:
            rank = int(os.environ.get("SLURM_NODEID"))

        print("=> Distributed training: world size: {}, rank: {}, URL: {}".format(
            world_size, rank, url))

        dist.init_process_group(backend="nccl",
                                init_method=url,
                                rank=rank,
                                world_size=world_size)

    # Initialize the distributed data parallel model
    master_gpu = self.options.gpu
    if master_gpu is None or master_gpu < 0:
        raise RuntimeError("Distributed training requires the model to be "
                           "placed on a GPU, but no GPU was given in the "
                           "arguments")
    # Distributed model initialization requires the model to be on the GPU.
    # Later code will move the same model to self.options.gpu again, so this
    # is safe.
    self.resnet.cuda(master_gpu)
    self.resnet = nn.parallel.DistributedDataParallel(
        self.resnet, output_device=master_gpu)
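# A hedged standalone sketch of the same initialization flow (the function
# name, default URL, and fallback rank of 0 are assumptions, not from the
# original class):
import os
import torch.distributed as dist

def init_distributed(url="tcp://127.0.0.1:23456", world_size=2, rank=-1):
    # SLURM convenience: derive the rank from the node id when it is not given
    if rank == -1:
        rank = int(os.environ.get("SLURM_NODEID", 0))
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method=url,
                                rank=rank, world_size=world_size)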
def barrier():
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
def train(
        self,
        base_path: Union[Path, str],
        fix_len=20,
        min_freq=2,
        buckets=1000,
        batch_size=5000,
        lr=2e-3,
        mu=.9,
        nu=.9,
        epsilon=1e-12,
        clip=5.0,
        decay=.75,
        decay_steps=5000,
        patience=100,
        max_epochs=10,
        wandb=None
):
    r"""
    Train any class that implements the model interface.

    Args:
        base_path: Main path to which all output during training is logged
            and models are saved.
        fix_len (int): Maximum number of subword pieces kept per token.
        min_freq (int): Minimum frequency for a word to enter the vocabulary.
        buckets (int): Number of buckets used to group sentences by length.
        batch_size (int): Number of tokens per training batch.
        lr (float): Learning rate of the Adam optimizer.
        mu (float): First Adam beta coefficient.
        nu (float): Second Adam beta coefficient.
        epsilon (float): Adam numerical-stability term.
        clip (float): Gradient-norm clipping threshold.
        decay (float): Exponential learning-rate decay factor.
        decay_steps (int): Number of steps over which ``decay`` is applied.
        patience (int): Number of epochs without improvement on dev before
            training stops early.
        max_epochs (int): Maximum number of epochs to train. Terminates
            training if this number is surpassed.
        wandb: Optional Weights & Biases run used for logging.
    """
    ####################################################################
    # BUILD
    ####################################################################
    feat = self.parser.feat
    embed = self.parser.embed
    os.makedirs(os.path.dirname(base_path), exist_ok=True)

    logger.info("Building the fields")
    WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
    if feat == 'char':
        FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos,
                            fix_len=fix_len)
    elif feat == 'bert':
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.parser.bert)
        FEAT = SubwordField('bert',
                            pad=tokenizer.pad_token,
                            unk=tokenizer.unk_token,
                            bos=tokenizer.bos_token or tokenizer.cls_token,
                            fix_len=fix_len,
                            tokenize=tokenizer.tokenize)
        FEAT.vocab = tokenizer.get_vocab()
    else:
        FEAT = Field('tags', bos=bos)
    ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
    REL = Field('rels', bos=bos)
    if feat in ('char', 'bert'):
        transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL)
    else:
        transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=ARC, DEPREL=REL)

    train = Dataset(transform, self.corpus.train)
    WORD.build(train, min_freq,
               (Embedding.load(embed, unk) if self.parser.embed else None))
    FEAT.build(train)
    REL.build(train)

    n_words = WORD.vocab.n_init
    n_feats = len(FEAT.vocab)
    n_rels = len(REL.vocab)
    pad_index = WORD.pad_index
    unk_index = WORD.unk_index
    feat_pad_index = FEAT.pad_index

    parser = DependencyParser(
        n_words=n_words,
        n_feats=n_feats,
        n_rels=n_rels,
        pad_index=pad_index,
        unk_index=unk_index,
        feat_pad_index=feat_pad_index,
        transform=transform,
        feat=self.parser.feat,
        bert=self.parser.bert
    )
    parser.embeddings = self.parser.embeddings
    parser.load_pretrained(WORD.embed).to(device)

    ####################################################################
    # TRAIN
    ####################################################################
    if wandb:
        wandb.watch(parser)

    parser.transform.train()
    if dist.is_initialized():
        batch_size = batch_size // dist.get_world_size()

    logger.info('Loading the data')
    train = Dataset(parser.transform, self.corpus.train)
    dev = Dataset(parser.transform, self.corpus.dev)
    test = Dataset(parser.transform, self.corpus.test)
    train.build(batch_size, buckets, True, dist.is_initialized())
    dev.build(batch_size, buckets)
    test.build(batch_size, buckets)
    logger.info(
        f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

    logger.info(f'{parser}')
    if dist.is_initialized():
        parser = DDP(parser,
                     device_ids=[dist.get_rank()],
                     find_unused_parameters=True)
    optimizer = Adam(parser.parameters(), lr, (mu, nu), epsilon)
    scheduler = ExponentialLR(optimizer, decay ** (1 / decay_steps))

    elapsed = timedelta()
    best_e, best_metric = 1, Metric()

    for epoch in range(1, max_epochs + 1):
        start = datetime.now()

        logger.info(f'Epoch {epoch} / {max_epochs}:')
        parser.train()
        bar = progress_bar(train.loader)
        metric = AttachmentMetric()
        for words, feats, arcs, rels in bar:
            optimizer.zero_grad()

            mask = words.ne(parser.WORD.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_rel = parser.forward(words, feats)
            loss = parser.forward_loss(s_arc, s_rel, arcs, rels, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(parser.parameters(), clip)
            optimizer.step()
            scheduler.step()

            arc_preds, rel_preds = parser.decode(s_arc, s_rel, mask)
            # ignore all punctuation if not specified
            if not self.parser.args['punct']:
                mask &= words.unsqueeze(-1).ne(parser.puncts).all(-1)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(
                f'lr: {scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}')

        dev_loss, dev_metric = parser.evaluate(dev.loader)
        logger.info(f"{'dev:':6} - loss: {dev_loss:.4f} - {dev_metric}")
        test_loss, test_metric = parser.evaluate(test.loader)
        logger.info(f"{'test:':6} - loss: {test_loss:.4f} - {test_metric}")
        if wandb:
            wandb.log({"test_loss": test_loss})
            wandb.log({"test_metric_uas": test_metric.uas})
            wandb.log({"test_metric_las": test_metric.las})

        t = datetime.now() - start
        # save the model if it is the best so far
        if dev_metric > best_metric:
            best_e, best_metric = epoch, dev_metric
            if is_master():
                parser.save(base_path)
            logger.info(f'{t}s elapsed (saved)\n')
        else:
            logger.info(f'{t}s elapsed\n')
        elapsed += t
        if epoch - best_e >= patience:
            break

    loss, metric = parser.load(base_path).evaluate(test.loader)
    logger.info(f'Epoch {best_e} saved')
    logger.info(f"{'dev:':6} - {best_metric}")
    logger.info(f"{'test:':6} - {metric}")
    logger.info(f'{elapsed}s elapsed, {elapsed / epoch}s/epoch')
def __init__(
    self,
    datadir,
    crop_size=(512, 512),
    target_transform=None,
    common_transforms=None,
    transform=None,
    val=False,
    band_norm=True,
):
    super(ICVLDataset, self).__init__()
    datadir = Path(datadir)
    self.files = [datadir / f for f in os.listdir(datadir) if f.endswith(".npy")]
    if dist.is_initialized():
        random.shuffle(self.files)

    # load all the data up front
    self.loadfrom = []
    self.band_norm = band_norm
    for f in self.files:
        # the images are already in [bands, height, width]
        loaded = torch.tensor(np.load(f), dtype=torch.float32)
        self.loadfrom.append(loaded)
    self.loadfrom = tuple(self.loadfrom)

    if not val:
        self.base_transforms = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    crop_size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333)
                ),
                hyde_transforms.RandomBandPerm(10),
                hyde_transforms.RandChoice(
                    [
                        hyde_transforms.RandRot90Transform(),
                        transforms.RandomVerticalFlip(p=0.9),
                        transforms.RandomAffine(degrees=180, shear=20),
                        transforms.RandomHorizontalFlip(p=0.9),
                        transforms.RandomPerspective(p=0.88),
                    ],
                    p=None,
                    combos=True,
                ),
            ]
        )
    else:
        self.base_transforms = transforms.CenterCrop(crop_size)

    self.target_transform = target_transform
    self.common_transforms = common_transforms
    self.length = len(self.files)
    self.transform = transform
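# Hypothetical usage sketch (the data path and loader settings are
# assumptions; __getitem__/__len__ are assumed to be defined elsewhere in
# the class):
train_set = ICVLDataset("data/icvl/train", crop_size=(512, 512), val=False)
loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)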
def reduce_mean(tensor):
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor
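# Usage sketch, assuming a typical DDP step where `loss` is a scalar tensor
# computed independently on each rank: the reduced value is identical on
# every rank, which makes it safe to log from rank 0 only.
avg_loss = reduce_mean(loss.detach())
if not dist.is_initialized() or dist.get_rank() == 0:
    print(f"mean loss across ranks: {avg_loss.item():.4f}")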
def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()
def barrier(self, name: Optional[str] = None):
    if torch_distrib.is_initialized():
        torch_distrib.barrier()
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger
    will be directly returned. During initialization, a StreamHandler will
    always be added. If `log_file` is specified and the process rank is 0, a
    FileHandler will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True

    return logger
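# Usage sketch, assuming the module-level `logger_initialized = {}` registry
# referenced above: rank 0 gets a console and a file handler, while the other
# ranks are raised to ERROR and stay quiet.
log = get_logger('train', log_file='work_dir/train.log', log_level=logging.INFO)
log.info('training started')  # visible on rank 0 only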
def print0(message):
    """Print only on rank 0, or unconditionally when not distributed."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(message, flush=True)
def create_from_context() -> Optional["_NativeDistModel"]:
    if not (dist.is_available() and dist.is_initialized()):
        return None
    return _NativeDistModel()
def is_distributed():
    """
    Return whether we are in distributed mode.
    """
    return TORCH_AVAILABLE and dist.is_available() and dist.is_initialized()
def retrieve(
    self,
    combined_hidden_states: np.ndarray,
    current_hidden_states: np.ndarray,
    history_hidden_states: np.ndarray,
    n_docs: int,
    dialog_lengths: List[Tuple] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
    """
    Retrieves documents for the specified hidden states.

    The main process, which has access to the index stored in memory,
    gathers queries from all the processes in the main training process
    group, performs the retrieval and scatters back the results.

    Args:
        combined_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
            A batch of combined query vectors to retrieve with.
        current_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
            A batch of query vectors for the current turn.
        history_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
            A batch of query vectors for the dialog history.
        n_docs (:obj:`int`):
            The number of docs retrieved per query.
        dialog_lengths (:obj:`List[Tuple]`, `optional`):
            Lengths of the dialogs in the batch.

    Output:
        retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
            The retrieval embeddings of the retrieved docs per query.
        doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
            The ids of the documents in the index.
        doc_scores (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
            The retrieval scores of the retrieved docs per query.
        doc_dicts (:obj:`List[dict]`):
            The retrieved_doc_embeds examples per query.
    """
    # single GPU training
    if not dist.is_initialized():
        doc_ids, retrieved_doc_embeds, doc_scores = self._main_retrieve(
            combined_hidden_states, current_hidden_states,
            history_hidden_states, n_docs, dialog_lengths)
        return (retrieved_doc_embeds, doc_ids, doc_scores,
                self.index.get_doc_dicts(doc_ids))

    # distributed training
    world_size = dist.get_world_size(group=self.process_group)

    # gather logic
    gather_list_1 = None
    gather_list_2 = None
    gather_list_3 = None
    if self._is_main():
        gather_list_1 = [torch.empty(combined_hidden_states.shape, dtype=torch.float32)
                         for _ in range(world_size)]
        gather_list_2 = [torch.empty(current_hidden_states.shape, dtype=torch.float32)
                         for _ in range(world_size)]
        gather_list_3 = [torch.empty(history_hidden_states.shape, dtype=torch.float32)
                         for _ in range(world_size)]
    dist.gather(torch.tensor(combined_hidden_states), dst=0,
                gather_list=gather_list_1, group=self.process_group)
    dist.gather(torch.tensor(current_hidden_states), dst=0,
                gather_list=gather_list_2, group=self.process_group)
    dist.gather(torch.tensor(history_hidden_states), dst=0,
                gather_list=gather_list_3, group=self.process_group)

    # scatter logic
    n_queries = combined_hidden_states.shape[0]
    scatter_ids = []
    scatter_vectors = []
    scatter_scores = []
    if self._is_main():
        assert len(gather_list_1) == len(gather_list_2) == len(gather_list_3) == world_size
        comb_h_s = torch.cat(gather_list_1).numpy()
        curr_h_s = torch.cat(gather_list_2).numpy()
        hist_h_s = torch.cat(gather_list_3).numpy()
        ids, vectors, scores = self._main_retrieve(
            comb_h_s, curr_h_s, hist_h_s, n_docs, dialog_lengths)
        ids, vectors, scores = (torch.tensor(ids), torch.tensor(vectors),
                                torch.tensor(scores))
        scatter_ids = self._chunk_tensor(ids, n_queries)
        scatter_vectors = self._chunk_tensor(vectors, n_queries)
        scatter_scores = self._chunk_tensor(scores, n_queries)
    doc_ids = self._scattered(scatter_ids, [n_queries, n_docs],
                              target_type=torch.int64)
    retrieved_doc_embeds = self._scattered(
        scatter_vectors, [n_queries, n_docs, combined_hidden_states.shape[1]])
    doc_scores = self._scattered(scatter_scores, [n_queries, n_docs],
                                 torch.float64)

    return (retrieved_doc_embeds.numpy(), doc_ids.numpy(), doc_scores.numpy(),
            self.index.get_doc_dicts(doc_ids))
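# A stripped-down, hedged sketch of the gather/compute/scatter protocol used
# above: every rank sends its queries to rank 0, rank 0 runs the expensive
# lookup once, and each rank receives back its own slice. `expensive_lookup`
# is a stand-in for `_main_retrieve` and is assumed to preserve the per-query
# shape; all ranks must contribute tensors of identical shape.
import torch
import torch.distributed as dist

def centralized_compute(x: torch.Tensor, expensive_lookup):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    gather_list = ([torch.empty_like(x) for _ in range(world_size)]
                   if rank == 0 else None)
    dist.gather(x, gather_list=gather_list, dst=0)

    out = torch.empty_like(x)
    scatter_list = None
    if rank == 0:
        result = expensive_lookup(torch.cat(gather_list))  # [world_size * n, ...]
        scatter_list = [t.contiguous() for t in result.chunk(world_size)]
    dist.scatter(out, scatter_list=scatter_list, src=0)
    return out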
def tearDown(self):
    if dist.is_initialized():
        dist.destroy_process_group()
def print_once(msg):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(msg)
def get_rank():
    return dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
def get_world_size():
    return dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
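# Usage sketch for the two guards above: shard a list of work items across
# ranks, which degenerates to a single shard when not running distributed.
items = list(range(100))
my_items = items[get_rank()::get_world_size()]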
def barrier(self, *args, **kwargs):
    if torch_distrib.is_initialized():
        torch_distrib.barrier()
def __init__(self,
             args,
             model,
             optimizer=None,
             model_parameters=None,
             training_data=None,
             lr_scheduler=None,
             mpu=None,
             dist_init_required=None,
             collate_fn=None):
    super(DeepSpeedLight, self).__init__()

    logging.basicConfig(level=logging.INFO,
                        format="[%(levelname)s %(asctime)s] %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S")

    self.client_optimizer = optimizer
    self.client_model_parameters = model_parameters
    self.client_lr_scheduler = lr_scheduler
    self.training_data = training_data
    self.collate_fn = collate_fn
    self.mpu = mpu
    self.data_parallel_group = None
    self.global_steps = 0
    self.micro_steps = 0
    self.skipped_steps = 0
    self.gradient_predivide_factor = 1.0
    self.gradient_average = True
    self.warn_unscaled_loss = True

    if dist_init_required is None:
        dist_init_required = not dist.is_initialized()

    self._mpi_check(args, dist_init_required)

    self.dist_backend = "nccl"
    if dist_init_required:
        if not dist.is_initialized():
            logging.info("Initializing torch distributed with backend: {}".format(
                self.dist_backend))
            dist.init_process_group(backend=self.dist_backend)
        else:
            logging.warning(
                "Was given dist_init_required=True but detected that torch "
                "distributed was already initialized, cannot initialize twice.")

    self._do_args_sanity_check(args)
    self._configure_with_arguments(args, mpu)
    self._do_sanity_check()

    self.sample_count = 0
    if self.tensorboard_enabled():
        self.summary_writer = self.get_summary_writer()

    self._init_distributed(dist_init_required)

    # Throughput timer
    self.tput_timer = ThroughputTimer(
        batch_size=self.train_micro_batch_size_per_gpu(),
        num_workers=self.world_size,
        monitor_memory=False)

    self.training_dataloader = self.deepspeed_io(
        training_data) if training_data else None

    # Configure distributed model
    self._configure_distributed_model(model)

    # Configure optimizer and scheduler
    self.optimizer = None
    self.lr_scheduler = None
    if model_parameters or optimizer:
        self._configure_optimizer(optimizer, model_parameters)
        self._configure_lr_scheduler(lr_scheduler)
        self._report_progress(0)

    # Configure wall clock timer
    self.timers = SynchronizedWallClockTimer()

    # Bookkeeping for csr support
    self.csr_tensor_module_names = set()
    if self.sparse_gradients_enabled():
        for name, module in self.module.named_modules():
            if isinstance(module, torch.nn.Embedding):
                self.csr_tensor_module_names.add(name)
                logging.info("Will convert {} to sparse (csr) "
                             "tensor during training".format(name))

    self.save_non_zero_checkpoint = False
    self.save_zero_checkpoint = False
    self._configure_checkpointing(dist_init_required)

    if self.global_rank == 0:
        self._config.print('DeepSpeedLight configuration')
        if self.dump_state():
            print_configuration(self, 'DeepSpeedLight')
def save(self, model, ema_model, optimizer, epoch, step, best_wer,
         is_best=False):
    """Saves model checkpoint for inference/resuming training.

    Args:
        model: the model, optionally wrapped by DistributedDataParallel
        ema_model: model with averaged weights, can be None
        optimizer: optimizer
        epoch (int): epoch during which the model is saved
        step (int): number of steps since beginning of training
        best_wer (float): lowest recorded WER on the dev set
        is_best (bool, optional): set name of checkpoint to 'best'
            and overwrite the previous one
    """
    rank = 0
    if dist.is_initialized():
        dist.barrier()
        rank = dist.get_rank()

    if rank != 0:
        return

    # Checkpoint already saved
    if not is_best and epoch in self.tracked:
        return

    unwrap_ddp = lambda model: getattr(model, 'module', model)
    state = {
        'epoch': epoch,
        'step': step,
        'best_wer': best_wer,
        'state_dict': unwrap_ddp(model).state_dict(),
        'ema_state_dict': (unwrap_ddp(ema_model).state_dict()
                           if ema_model is not None else None),
        'optimizer': optimizer.state_dict(),
        'amp': amp.state_dict() if self.use_amp else None,
    }

    if is_best:
        fpath = os.path.join(self.save_dir,
                             f"{self.model_name}_best_checkpoint.pt")
    else:
        fpath = os.path.join(
            self.save_dir, f"{self.model_name}_epoch{epoch}_checkpoint.pt")

    print_once(f"Saving {fpath}...")
    torch.save(state, fpath)

    if not is_best:
        # Remove old checkpoints; keep milestones and the last two
        self.tracked[epoch] = fpath
        # use a distinct loop variable so the `epoch` argument is not shadowed
        for old_epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
            try:
                os.remove(self.tracked[old_epoch])
            except OSError:
                pass
            del self.tracked[old_epoch]
def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
def losses(self, shifts, gt_instances, box_cls, box_delta, box_center):
    box_cls_flattened = [
        permute_to_N_HWA_K(x, self.num_classes) for x in box_cls
    ]
    box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_center_flattened = [permute_to_N_HWA_K(x, 1) for x in box_center]
    pred_class_logits = cat(box_cls_flattened, dim=1)
    pred_shift_deltas = cat(box_delta_flattened, dim=1)
    pred_obj_logits = cat(box_center_flattened, dim=1)

    pred_class_probs = pred_class_logits.sigmoid()
    pred_obj_probs = pred_obj_logits.sigmoid()
    pred_box_probs = []
    num_foreground = pred_class_logits.new_zeros(1)
    num_background = pred_class_logits.new_zeros(1)
    positive_losses = []
    gaussian_norm_losses = []

    for shifts_per_image, gt_instances_per_image, \
        pred_class_probs_per_image, pred_shift_deltas_per_image, \
        pred_obj_probs_per_image in zip(
            shifts, gt_instances, pred_class_probs, pred_shift_deltas,
            pred_obj_probs):
        locations = torch.cat(shifts_per_image, dim=0)
        labels = gt_instances_per_image.gt_classes
        gt_boxes = gt_instances_per_image.gt_boxes

        target_shift_deltas = self.shift2box_transform.get_deltas(
            locations, gt_boxes.tensor.unsqueeze(1))
        is_in_boxes = target_shift_deltas.min(dim=-1).values > 0

        foreground_idxs = torch.nonzero(is_in_boxes, as_tuple=True)

        with torch.no_grad():
            # predicted_boxes_per_image: a_{j}^{loc}, shape: [j, 4]
            predicted_boxes_per_image = self.shift2box_transform.apply_deltas(
                pred_shift_deltas_per_image, locations)
            # gt_pred_iou: IoU_{ij}^{loc}, shape: [i, j]
            gt_pred_iou = pairwise_iou(
                gt_boxes, Boxes(predicted_boxes_per_image)).max(
                    dim=0, keepdim=True).values.repeat(
                        len(gt_instances_per_image), 1)

            # pred_box_prob_per_image: P{a_{j} \in A_{+}}, shape: [j, c]
            pred_box_prob_per_image = torch.zeros_like(
                pred_class_probs_per_image)
            box_prob = 1 / (1 - gt_pred_iou[foreground_idxs]).clamp_(1e-12)
            for i in range(len(gt_instances_per_image)):
                idxs = foreground_idxs[0] == i
                if idxs.sum() > 0:
                    box_prob[idxs] = normalize(box_prob[idxs])
            pred_box_prob_per_image[foreground_idxs[1],
                                    labels[foreground_idxs[0]]] = box_prob
            pred_box_probs.append(pred_box_prob_per_image)

        normal_probs = []
        for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):
            gt_shift_deltas = self.shift2box_transform.get_deltas(
                shifts_i, gt_boxes.tensor.unsqueeze(1))
            distances = (gt_shift_deltas[..., :2] -
                         gt_shift_deltas[..., 2:]) / 2
            normal_probs.append(
                normal_distribution(distances / stride,
                                    self.mu[labels].unsqueeze(1),
                                    self.sigma[labels].unsqueeze(1)))
        normal_probs = torch.cat(normal_probs, dim=1).prod(dim=-1)

        composed_cls_prob = pred_class_probs_per_image[:, labels] \
            * pred_obj_probs_per_image

        # matched_gt_shift_deltas: P_{ij}^{loc}
        loss_box_reg = iou_loss(pred_shift_deltas_per_image.unsqueeze(0),
                                target_shift_deltas,
                                box_mode="ltrb",
                                loss_type=self.iou_loss_type,
                                reduction="none") * self.reg_weight
        pred_reg_probs = (-loss_box_reg).exp()

        # positive_losses: { -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) }
        positive_losses.append(
            positive_bag_loss(
                composed_cls_prob.transpose(1, 0) * pred_reg_probs,
                is_in_boxes.float(), normal_probs))

        num_foreground += len(gt_instances_per_image)
        num_background += normal_probs[foreground_idxs].sum().item()
        gaussian_norm_losses.append(
            len(gt_instances_per_image) /
            normal_probs[foreground_idxs].sum().clamp_(1e-12))

    if dist.is_initialized():
        dist.all_reduce(num_foreground)
        num_foreground /= dist.get_world_size()
        dist.all_reduce(num_background)
        num_background /= dist.get_world_size()

    # positive_loss: \sum_{i}{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||
    positive_loss = torch.cat(positive_losses).sum() / max(1, num_foreground)
    # pred_box_probs: P{a_{j} \in A_{+}}
    pred_box_probs = torch.stack(pred_box_probs, dim=0)
    # negative_loss: \sum_{j}{ FL( (1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||
    negative_loss = negative_bag_loss(
        pred_class_probs * pred_obj_probs * (1 - pred_box_probs),
        self.focal_loss_gamma).sum() / max(1, num_background)

    loss_pos = positive_loss * self.focal_loss_alpha
    loss_neg = negative_loss * (1 - self.focal_loss_alpha)
    loss_norm = torch.stack(gaussian_norm_losses).mean() * (
        1 - self.focal_loss_alpha)

    return {
        "loss_pos": loss_pos,
        "loss_neg": loss_neg,
        "loss_norm": loss_norm,
    }
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True
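# Typical guard usage (sketch): only issue collective calls when a process
# group actually exists, so the same code runs single-process too.
if is_dist_avail_and_initialized():
    dist.barrier()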
def train(self, train, dev, test, buckets=32, batch_size=5000, lr=2e-3,
          mu=.9, nu=.9, epsilon=1e-12, clip=5.0, decay=.75, decay_steps=5000,
          epochs=5000, patience=100, verbose=True, **kwargs):
    args = self.args.update(locals())
    init_logger(logger, verbose=args.verbose)

    self.transform.train()
    if dist.is_initialized():
        args.batch_size = args.batch_size // dist.get_world_size()
    logger.info("Loading the data")
    train = Dataset(self.transform, args.train, **args)
    dev = Dataset(self.transform, args.dev)
    test = Dataset(self.transform, args.test)
    logger.info("Building the datasets")
    train.build(args.batch_size, args.buckets, True, dist.is_initialized())
    logger.info("train built")
    dev.build(args.batch_size, args.buckets)
    logger.info("dev built")
    test.build(args.batch_size, args.buckets)
    logger.info(
        f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

    logger.info(f"{self.model}\n")
    if dist.is_initialized():
        self.model = DDP(self.model,
                         device_ids=[args.local_rank],
                         find_unused_parameters=True)
    self.optimizer = Adam(self.model.parameters(), args.lr,
                          (args.mu, args.nu), args.epsilon)
    self.scheduler = ExponentialLR(self.optimizer,
                                   args.decay**(1 / args.decay_steps))

    elapsed = timedelta()
    best_e, best_metric = 1, Metric()

    for epoch in range(1, args.epochs + 1):
        start = datetime.now()

        logger.info(f"Epoch {epoch} / {args.epochs}:")
        self._train(train.loader)
        loss, dev_metric = self._evaluate(dev.loader)
        logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}")
        loss, test_metric = self._evaluate(test.loader)
        logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}")

        t = datetime.now() - start
        # save the model if it is the best so far
        if dev_metric > best_metric:
            best_e, best_metric = epoch, dev_metric
            if is_master():
                self.save(args.path)
            logger.info(f"{t}s elapsed (saved)\n")
        else:
            logger.info(f"{t}s elapsed\n")
        elapsed += t
        if epoch - best_e >= args.patience:
            break

    loss, metric = self.load(**args)._evaluate(test.loader)
    logger.info(f"Epoch {best_e} saved")
    logger.info(f"{'dev:':6} - {best_metric}")
    logger.info(f"{'test:':6} - {metric}")
    logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")
def setup(self, config):
    self.args = args = config["args"]

    start = time.time()
    self.tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    logger.info(f"tokenizer instantiation time: {time.time() - start}")

    # Load data.
    train_dataset = load_and_cache_examples(
        args, args.task_name, self.tokenizer, evaluate=False
    )
    train_sampler = (
        RandomSampler(train_dataset) if not dist.is_initialized() else None
    )
    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.per_device_train_batch_size,
    )

    # Create model.
    with FileLock(os.path.expanduser("~/.download.lock")):
        processor = processors[args.task_name]()
        label_list = processor.get_labels()
        num_labels = len(label_list)
        model_config = AutoConfig.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=args.task_name,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=model_config,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )

    # Create optimizer.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon
    )

    # Register components.
    self.model, self.optimizer = self.register(
        models=model,
        optimizers=optimizer,
        apex_args={"opt_level": args.fp16_opt_level},
    )
    self.register_data(train_loader=train_loader, validation_loader=None)

    self.train_data_len = len(self.train_loader)
    self._warmup_scheduler = get_linear_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=self.calculate_t_total(),
    )
    self._global_step = 0
    announce_training(args, self.train_data_len, self.calculate_t_total())
def forward(ctx, x):
    if (dist.is_available() and dist.is_initialized()
            and (dist.get_world_size() > 1)):
        x = x.contiguous() / dist.get_world_size()
        dist.all_reduce(x)
    return x
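# A minimal self-contained sketch of how a forward like the one above is
# usually paired with a backward in a torch.autograd.Function. The class name
# and the choice to mirror the averaging all-reduce in backward are
# assumptions, not taken from the original source.
import torch
import torch.distributed as dist

class AllReduceMean(torch.autograd.Function):
    """Averages a tensor across ranks; gradients are averaged the same way."""

    @staticmethod
    def forward(ctx, x):
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            x = x.contiguous() / dist.get_world_size()
            dist.all_reduce(x)  # SUM of x / world_size == mean over ranks
        return x

    @staticmethod
    def backward(ctx, grad):
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            grad = grad.contiguous() / dist.get_world_size()
            dist.all_reduce(grad)
        return grad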
def barrier(self, name: str = None):
    if torch_distrib.is_initialized():
        torch_distrib.barrier()
def __init__(
    self,
    device: torch.device,
    max_epochs: int,
    data_loader: Union[Iterable, DataLoader],
    epoch_length: Optional[int] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = default_prepare_batch,
    iteration_update: Optional[Callable] = None,
    post_transform: Optional[Callable] = None,
    key_metric: Optional[Dict[str, Metric]] = None,
    additional_metrics: Optional[Dict[str, Metric]] = None,
    handlers: Optional[Sequence] = None,
    amp: bool = False,
    event_names: Optional[List[Union[str, EventEnum]]] = None,
    event_to_attr: Optional[dict] = None,
) -> None:
    if iteration_update is not None:
        super().__init__(iteration_update)
    else:
        super().__init__(self._iteration)
    if not isinstance(device, torch.device):
        raise TypeError(f"device must be a torch.device but is {type(device).__name__}.")

    if isinstance(data_loader, DataLoader):
        sampler = data_loader.__dict__["sampler"]
        if isinstance(sampler, DistributedSampler):

            @self.on(Events.EPOCH_STARTED)
            def set_sampler_epoch(engine: Engine):
                sampler.set_epoch(engine.state.epoch)

        if epoch_length is None:
            epoch_length = len(data_loader)
    else:
        if epoch_length is None:
            raise ValueError("if data_loader is not PyTorch DataLoader, must specify the epoch_length.")

    # set all sharable data for the workflow based on Ignite engine.state
    self.state = State(
        rank=dist.get_rank() if dist.is_available() and dist.is_initialized() else 0,
        seed=0,
        iteration=0,
        epoch=0,
        max_epochs=max_epochs,
        epoch_length=epoch_length,
        output=None,
        batch=None,
        metrics={},
        metric_details={},
        dataloader=None,
        device=device,
        # we can set many metrics, only use key_metric to compare and save the best model
        key_metric_name=None,
        best_metric=-1,
        best_metric_epoch=-1,
    )
    self.data_loader = data_loader
    self.non_blocking = non_blocking
    self.prepare_batch = prepare_batch
    self.amp = amp

    event_names = [IterationEvents] if event_names is None else event_names + [IterationEvents]
    for name in event_names:
        if isinstance(name, str):
            self.register_events(name, event_to_attr=event_to_attr)
        elif issubclass(name, EventEnum):
            self.register_events(*name, event_to_attr=event_to_attr)
        else:
            raise ValueError("event_names must be a list or string or EventEnum.")

    if post_transform is not None:
        self._register_post_transforms(post_transform)
    if key_metric is not None:
        self._register_metrics(key_metric, additional_metrics)
    if handlers is not None:
        self._register_handlers(handlers)