Example #1
    def _check_and_init_distributed_model(self):
        if not self.options.use_data_parallel_distributed:
            return

        if not dist.is_initialized():
            world_size = self.options.dist_world_size
            url = self.options.dist_url
            rank = self.options.dist_rank
            # This is for SLURM's special use case
            if rank == -1:
                rank = int(os.environ.get("SLURM_NODEID"))

            print("=> Distributed training: world size: {}, rank: {}, URL: {}".
                  format(world_size, rank, url))

            dist.init_process_group(backend="nccl",
                                    init_method=url,
                                    rank=rank,
                                    world_size=world_size)

        # Initialize the distributed data parallel model
        master_gpu = self.options.gpu
        if master_gpu is None or master_gpu < 0:
            raise RuntimeError("Distributed training requires "
                               "to put the model on the GPU, but the GPU is "
                               "not given in the argument")
        # This is needed for distributed model since the distributed model
        # initialization will require the model be on the GPU, even though
        # the later code will put the same model on the GPU again with
        # self.options.gpu, so this should be ok
        self.resnet.cuda(master_gpu)
        self.resnet = nn.parallel.DistributedDataParallel(
            self.resnet,
            output_device=master_gpu)
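The method above reads all of its settings from `self.options`; a minimal sketch of the fields it assumes (the field names follow the snippet, the concrete values are illustrative only):

import argparse

# Hypothetical options object mirroring the fields read by
# _check_and_init_distributed_model; the values here are made up.
options = argparse.Namespace(
    use_data_parallel_distributed=True,
    dist_world_size=4,                  # total number of processes
    dist_url="tcp://127.0.0.1:23456",   # init_method for init_process_group
    dist_rank=-1,                       # -1 -> fall back to SLURM_NODEID
    gpu=0,                              # master GPU required by the check above
)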
Example #2
def barrier():
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
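A hedged usage sketch for the helper above; the call site and write_results are placeholders, not part of the original snippet:

# Hypothetical call site: let every rank finish its shard of work before the
# main process touches shared files.
barrier()
if not dist.is_initialized() or dist.get_rank() == 0:
    write_results("results.json")  # placeholder for whatever rank 0 should do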
Example #3
    def train(
        self, base_path: Union[Path, str],
        fix_len=20,
        min_freq=2,
        buckets=1000,
        batch_size=5000,
        lr=2e-3,
        mu=.9,
        nu=.9,
        epsilon=1e-12,
        clip=5.0,
        decay=.75,
        decay_steps=5000,
        patience=100,
        max_epochs=10,
        wandb=None
    ):
        r"""
        Train any class that implements the parser model interface.

        Args:
            base_path (Union[Path, str]): Main path to which all output during training is logged and models are saved.
            fix_len (int): Maximum length of subword/character sequences per token.
            min_freq (int): Minimum frequency for a token to be kept in the word vocabulary.
            buckets (int): Number of buckets used to group sentences of similar length.
            batch_size (int): Number of tokens in each training batch.
            lr (float): Learning rate of the Adam optimizer.
            mu (float): First beta coefficient of the Adam optimizer.
            nu (float): Second beta coefficient of the Adam optimizer.
            epsilon (float): Epsilon term of the Adam optimizer.
            clip (float): Threshold for gradient norm clipping.
            decay (float): Decay factor of the exponential learning-rate scheduler.
            decay_steps (int): Number of scheduler steps over which the learning rate is multiplied by the decay factor.
            patience (int): Number of epochs without improvement on dev before training stops early.
            max_epochs (int): Maximum number of epochs to train. Terminates training if this number is surpassed.
            wandb: Optional Weights & Biases run used to log metrics.
        """
        ################################################################################################################
        # BUILD
        ################################################################################################################
        feat = self.parser.feat
        embed = self.parser.embed
        os.makedirs(os.path.dirname(base_path), exist_ok=True)
        logger.info("Building the fields")
        WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
        if feat == 'char':
            FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=fix_len)
        elif feat == 'bert':
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(self.parser.bert)
            FEAT = SubwordField('bert',
                                pad=tokenizer.pad_token,
                                unk=tokenizer.unk_token,
                                bos=tokenizer.bos_token or tokenizer.cls_token,
                                fix_len=fix_len,
                                tokenize=tokenizer.tokenize)
            FEAT.vocab = tokenizer.get_vocab()
        else:
            FEAT = Field('tags', bos=bos)

        ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
        REL = Field('rels', bos=bos)
        if feat in ('char', 'bert'):
            transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL)
        else:
            transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=ARC, DEPREL=REL)

        train = Dataset(transform, self.corpus.train)
        WORD.build(train, min_freq, (Embedding.load(embed, unk) if self.parser.embed else None))
        FEAT.build(train)
        REL.build(train)
        n_words = WORD.vocab.n_init
        n_feats = len(FEAT.vocab)
        n_rels = len(REL.vocab)
        pad_index = WORD.pad_index
        unk_index = WORD.unk_index
        feat_pad_index = FEAT.pad_index
        parser = DependencyParser(
            n_words=n_words,
            n_feats=n_feats,
            n_rels=n_rels,
            pad_index=pad_index,
            unk_index=unk_index,
            feat_pad_index=feat_pad_index,
            transform=transform,
            feat=self.parser.feat,
            bert=self.parser.bert
        )
        # word_field_embeddings = self.parser.embeddings[0]
        # word_field_embeddings.n_vocab = 100
        parser.embeddings = self.parser.embeddings
        # parser.embeddings[0] = word_field_embeddings
        parser.load_pretrained(WORD.embed).to(device)

        ################################################################################################################
        # TRAIN
        ################################################################################################################
        if wandb:
            wandb.watch(parser)
        parser.transform.train()
        if dist.is_initialized():
            batch_size = batch_size // dist.get_world_size()
        logger.info('Loading the data')
        train = Dataset(parser.transform, self.corpus.train)
        dev = Dataset(parser.transform, self.corpus.dev)
        test = Dataset(parser.transform, self.corpus.test)
        train.build(batch_size, buckets, True, dist.is_initialized())
        dev.build(batch_size, buckets)
        test.build(batch_size, buckets)
        logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")
        logger.info(f'{parser}')
        if dist.is_initialized():
            parser = DDP(parser, device_ids=[dist.get_rank()], find_unused_parameters=True)

        optimizer = Adam(parser.parameters(), lr, (mu, nu), epsilon)
        scheduler = ExponentialLR(optimizer, decay ** (1 / decay_steps))

        elapsed = timedelta()
        best_e, best_metric = 1, Metric()

        for epoch in range(1, max_epochs + 1):
            start = datetime.now()
            logger.info(f'Epoch {epoch} / {max_epochs}:')

            parser.train()

            bar = progress_bar(train.loader)
            metric = AttachmentMetric()
            for words, feats, arcs, rels in bar:
                optimizer.zero_grad()

                mask = words.ne(parser.WORD.pad_index)
                # ignore the first token of each sentence
                mask[:, 0] = 0
                s_arc, s_rel = parser.forward(words, feats)
                loss = parser.forward_loss(s_arc, s_rel, arcs, rels, mask)
                loss.backward()
                nn.utils.clip_grad_norm_(parser.parameters(), clip)
                optimizer.step()
                scheduler.step()

                arc_preds, rel_preds = parser.decode(s_arc, s_rel, mask)
                # ignore all punctuation if not specified
                if not self.parser.args['punct']:
                    mask &= words.unsqueeze(-1).ne(parser.puncts).all(-1)
                metric(arc_preds, rel_preds, arcs, rels, mask)
                bar.set_postfix_str(f'lr: {scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}')

            dev_loss, dev_metric = parser.evaluate(dev.loader)
            logger.info(f"{'dev:':6} - loss: {dev_loss:.4f} - {dev_metric}")
            test_loss, test_metric = parser.evaluate(test.loader)
            logger.info(f"{'test:':6} - loss: {test_loss:.4f} - {test_metric}")
            if wandb:
                wandb.log({"test_loss": test_loss})
                wandb.log({"test_metric_uas": test_metric.uas})
                wandb.log({"test_metric_las": test_metric.las})

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric:
                best_e, best_metric = epoch, dev_metric
                if is_master():
                    parser.save(base_path)
                logger.info(f'{t}s elapsed (saved)\n')
            else:
                logger.info(f'{t}s elapsed\n')
            elapsed += t
            if epoch - best_e >= patience:
                break
        loss, metric = parser.load(base_path).evaluate(test.loader)

        logger.info(f'Epoch {best_e} saved')
        logger.info(f"{'dev:':6} - {best_metric}")
        logger.info(f"{'test:':6} - {metric}")
        logger.info(f'{elapsed}s elapsed, {elapsed / epoch}s/epoch')
Example #4
    def __init__(
        self,
        datadir,
        crop_size=(512, 512),
        target_transform=None,
        common_transforms=None,
        transform=None,
        val=False,
        band_norm=True,
    ):
        super(ICVLDataset, self).__init__()
        datadir = Path(datadir)
        self.files = [datadir / f for f in os.listdir(datadir) if f.endswith(".npy")]
        if dist.is_initialized():
            random.shuffle(self.files)

        # load all the data at the top
        self.loadfrom = []  # np.zeros(first, dtype=np.float32)
        self.band_norm = band_norm
        for c, f in enumerate(self.files):
            # the images are already in [bands, height, width]
            # loaded, _ = utils.normalize(
            #     torch.tensor(np.load(f), dtype=torch.float32), by_band=band_norm, band_dim=0
            # )
            loaded = torch.tensor(np.load(f), dtype=torch.float32)
            self.loadfrom.append(loaded)

        self.loadfrom = tuple(self.loadfrom)

        if not val:
            self.base_transforms = transforms.Compose(
                [
                    # transforms.CenterCrop(crop_size),
                    # transforms.RandomCrop(crop_size),
                    transforms.RandomResizedCrop(
                        crop_size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333)
                    ),
                    hyde_transforms.RandomBandPerm(10),
                    hyde_transforms.RandChoice(
                        [
                            hyde_transforms.RandRot90Transform(),
                            transforms.RandomVerticalFlip(p=0.9),
                            transforms.RandomAffine(
                                degrees=180,
                                # scale=(0.1, 10), # old (0.1, 3)
                                shear=20,
                            ),
                            transforms.RandomHorizontalFlip(p=0.9),
                            transforms.RandomPerspective(p=0.88),
                        ],
                        p=None,  # 0.5,
                        combos=True,
                    ),
                ]
            )
        else:
            self.base_transforms = transforms.CenterCrop(crop_size)  # RandomCrop(crop_size)

        self.target_transform = target_transform
        self.common_transforms = common_transforms
        self.length = len(self.files)

        self.transform = transform
Example #5
def reduce_mean(tensor):
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor
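Because the helper divides by the world size before the SUM all-reduce, the result is the cross-rank mean. A hedged usage sketch (the loss computation and surrounding names are assumptions):

# Hypothetical usage: log the loss averaged over all processes on rank 0 only.
loss = criterion(outputs, targets)  # 'criterion', 'outputs', 'targets' are placeholders
mean_loss = reduce_mean(loss.detach())
if not dist.is_initialized() or dist.get_rank() == 0:
    print(f"mean loss across ranks: {mean_loss.item():.4f}")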
Example #6
def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()
Example #7
 def barrier(self, name: Optional[str] = None):
     if torch_distrib.is_initialized():
         torch_distrib.barrier()
Example #8
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True

    return logger
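A hedged usage sketch; the logger name and file path are placeholders:

# Hypothetical usage: every process logs to the console, but only rank 0
# writes the file and logs at INFO level; other ranks are silenced to ERROR.
logger = get_logger("my_project", log_file="work_dir/train.log", log_level=logging.INFO)
logger.info("training started")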
Example #9
def print0(message):
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)
Example #10
 def create_from_context() -> Optional["_NativeDistModel"]:
     if not (dist.is_available() and dist.is_initialized()):
         return None
     return _NativeDistModel()
Example #11
def is_distributed():
    """
    Return if we are in distributed mode.
    """
    return TORCH_AVAILABLE and dist.is_available() and dist.is_initialized()
Example #12
    def retrieve(self, combined_hidden_states: np.ndarray, current_hidden_states: np.ndarray,
                 history_hidden_states: np.ndarray, n_docs: int, dialog_lengths: List[Tuple] = None) -> \
            Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
        """
        Retrieves documents for the specified hidden states. The main process, which has access to the index
        stored in memory, gathers queries from all processes in the main training process group, performs the
        retrieval and scatters the results back.

        Args:
            combined_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            current_hidden_states (:obj:`np.ndarray`):
                Hidden states of the current dialog turn.
            history_hidden_states (:obj:`np.ndarray`):
                Hidden states of the dialog history.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.
            dialog_lengths (:obj:`List[Tuple]`, optional):
                Lengths of the dialogs in the batch.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
                The ids of the documents in the index.
            doc_scores (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
                The retrieval scores of the retrieved docs.
            doc_dicts (:obj:`List[dict]`):
                The retrieved_doc_embeds examples per query.
        """

        # single GPU training
        if not dist.is_initialized():
            # doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
            doc_ids, retrieved_doc_embeds, doc_scores = self._main_retrieve(combined_hidden_states,
                                                                current_hidden_states,
                                                                history_hidden_states,
                                                                n_docs,
                                                                dialog_lengths)
            # return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
            return retrieved_doc_embeds, doc_ids, doc_scores, self.index.get_doc_dicts(doc_ids)

        # distributed training
        world_size = dist.get_world_size(group=self.process_group)

        # gather logic
        gather_list_1 = None
        gather_list_2 = None
        gather_list_3 = None
        if self._is_main():
            gather_list_1 = [torch.empty(combined_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
            gather_list_2 = [torch.empty(current_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
            gather_list_3 = [torch.empty(history_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
        dist.gather(torch.tensor(combined_hidden_states), dst=0, gather_list=gather_list_1, group=self.process_group)
        dist.gather(torch.tensor(current_hidden_states), dst=0, gather_list=gather_list_2, group=self.process_group)
        dist.gather(torch.tensor(history_hidden_states), dst=0, gather_list=gather_list_3, group=self.process_group)

        # scatter logic
        n_queries = combined_hidden_states.shape[0]
        scatter_ids = []
        scatter_vectors = []
        scatter_scores = []
        if self._is_main():
            assert len(gather_list_1) == len(gather_list_2) == len(gather_list_3) == world_size
            comb_h_s = torch.cat(gather_list_1).numpy()
            curr_h_s = torch.cat(gather_list_2).numpy()
            hist_h_s = torch.cat(gather_list_3).numpy()
            ids, vectors, scores = self._main_retrieve(comb_h_s, curr_h_s, hist_h_s, n_docs, dialog_lengths)
            ids, vectors, scores = torch.tensor(ids), torch.tensor(vectors), torch.tensor(scores)
            scatter_ids = self._chunk_tensor(ids, n_queries)
            scatter_vectors = self._chunk_tensor(vectors, n_queries)
            scatter_scores = self._chunk_tensor(scores, n_queries)

        doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
        retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, combined_hidden_states.shape[1]])
        doc_scores = self._scattered(scatter_scores, [n_queries, n_docs], torch.float64)

        return retrieved_doc_embeds.numpy(), doc_ids.numpy(), doc_scores.numpy(), self.index.get_doc_dicts(doc_ids)
Example #13
 def tearDown(self):
     if dist.is_initialized():
         dist.destroy_process_group()
Example #14
def print_once(msg):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(msg)
Example #15
def get_rank():
    return dist.get_rank(
    ) if dist.is_available() and dist.is_initialized() else 0
Example #16
def get_world_size():
    return dist.get_world_size(
    ) if dist.is_available() and dist.is_initialized() else 1
Example #17
 def barrier(self, *args, **kwargs):
     if torch_distrib.is_initialized():
         torch_distrib.barrier()
Example #18
    def __init__(self,
                 args,
                 model,
                 optimizer=None,
                 model_parameters=None,
                 training_data=None,
                 lr_scheduler=None,
                 mpu=None,
                 dist_init_required=None,
                 collate_fn=None):
        super(DeepSpeedLight, self).__init__()

        logging.basicConfig(level=logging.INFO,
                            format="[%(levelname)s %(asctime)s] %(message)s",
                            datefmt="%Y-%m-%d %H:%M:%S")

        self.client_optimizer = optimizer
        self.client_model_parameters = model_parameters
        self.client_lr_scheduler = lr_scheduler
        self.training_data = training_data
        self.collate_fn = collate_fn
        self.mpu = mpu
        self.data_parallel_group = None
        self.global_steps = 0
        self.micro_steps = 0
        self.skipped_steps = 0
        self.gradient_predivide_factor = 1.0
        self.gradient_average = True
        self.warn_unscaled_loss = True

        if dist_init_required is None:
            dist_init_required = not dist.is_initialized()

        self._mpi_check(args, dist_init_required)

        self.dist_backend = "nccl"
        if dist_init_required:
            if not dist.is_initialized():
                logging.info(
                    "Initializing torch distributed with backend: {}".format(
                        self.dist_backend))
                dist.init_process_group(backend=self.dist_backend)
            else:
                logging.warning(
                    "Was given dist_init_required=True but detected that torch "
                    "distributed was already initialized, cannot initialize twice."
                )

        self._do_args_sanity_check(args)
        self._configure_with_arguments(args, mpu)
        self._do_sanity_check()

        self.sample_count = 0
        if self.tensorboard_enabled():
            self.summary_writer = self.get_summary_writer()

        self._init_distributed(dist_init_required)

        # Throughput timer
        self.tput_timer = ThroughputTimer(
            batch_size=self.train_micro_batch_size_per_gpu(),
            num_workers=self.world_size,
            monitor_memory=False)

        self.training_dataloader = self.deepspeed_io(
            training_data) if training_data else None

        # Configure distributed model
        self._configure_distributed_model(model)

        # Configure optimizer and scheduler
        self.optimizer = None
        self.lr_scheduler = None
        if model_parameters or optimizer:
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler(lr_scheduler)
            self._report_progress(0)

        # Configure wall clock timer
        self.timers = SynchronizedWallClockTimer()

        # Bookkeeping for csr support
        self.csr_tensor_module_names = set()
        if self.sparse_gradients_enabled():
            for name, module in self.module.named_modules():
                if isinstance(module, torch.nn.Embedding):
                    self.csr_tensor_module_names.add(name)
                    logging.info("Will convert {} to sparse (csr) "
                                 "tensor during training".format(name))

        self.save_non_zero_checkpoint = False
        self.save_zero_checkpoint = False
        self._configure_checkpointing(dist_init_required)

        if self.global_rank == 0:
            self._config.print('DeepSpeedLight configuration')
            if self.dump_state():
                print_configuration(self, 'DeepSpeedLight')
Example #19
    def save(self,
             model,
             ema_model,
             optimizer,
             epoch,
             step,
             best_wer,
             is_best=False):
        """Saves model checkpoint for inference/resuming training.

        Args:
            model: the model, optionally wrapped by DistributedDataParallel
            ema_model: model with averaged weights, can be None
            optimizer: optimizer
            epoch (int): epoch during which the model is saved
            step (int): number of steps since beginning of training
            best_wer (float): lowest recorded WER on the dev set
            is_best (bool, optional): set name of checkpoint to 'best'
                and overwrite the previous one
        """
        rank = 0
        if dist.is_initialized():
            dist.barrier()
            rank = dist.get_rank()

        if rank != 0:
            return

        # Checkpoint already saved
        if not is_best and epoch in self.tracked:
            return

        unwrap_ddp = lambda model: getattr(model, 'module', model)
        state = {
            'epoch': epoch,
            'step': step,
            'best_wer': best_wer,
            'state_dict': unwrap_ddp(model).state_dict(),
            'ema_state_dict': (unwrap_ddp(ema_model).state_dict()
                               if ema_model is not None else None),
            'optimizer': optimizer.state_dict(),
            'amp': amp.state_dict() if self.use_amp else None,
        }

        if is_best:
            fpath = os.path.join(self.save_dir,
                                 f"{self.model_name}_best_checkpoint.pt")
        else:
            fpath = os.path.join(
                self.save_dir, f"{self.model_name}_epoch{epoch}_checkpoint.pt")

        print_once(f"Saving {fpath}...")
        torch.save(state, fpath)

        if not is_best:
            # Remove old checkpoints; keep milestones and the last two
            self.tracked[epoch] = fpath
            for epoch in set(list(self.tracked)[:-2]) - set(
                    self.keep_milestones):
                try:
                    os.remove(self.tracked[epoch])
                except:
                    pass
                del self.tracked[epoch]
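A hedged sketch of a call site for the checkpointer above; `checkpointer`, `evaluate`, `dev_loader`, and the other loop variables are assumptions, not part of the original snippet:

# Hypothetical call site at the end of a validation pass.
wer = evaluate(model, dev_loader)
if wer < best_wer:
    best_wer = wer
    checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer, is_best=True)
checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)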
Example #20
def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
Example #21
    def losses(self, shifts, gt_instances, box_cls, box_delta, box_center):
        box_cls_flattened = [
            permute_to_N_HWA_K(x, self.num_classes) for x in box_cls
        ]
        box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
        box_center_flattened = [permute_to_N_HWA_K(x, 1) for x in box_center]
        pred_class_logits = cat(box_cls_flattened, dim=1)
        pred_shift_deltas = cat(box_delta_flattened, dim=1)
        pred_obj_logits = cat(box_center_flattened, dim=1)

        pred_class_probs = pred_class_logits.sigmoid()
        pred_obj_probs = pred_obj_logits.sigmoid()
        pred_box_probs = []
        num_foreground = pred_class_logits.new_zeros(1)
        num_background = pred_class_logits.new_zeros(1)
        positive_losses = []
        gaussian_norm_losses = []

        for shifts_per_image, gt_instances_per_image, \
            pred_class_probs_per_image, pred_shift_deltas_per_image, \
            pred_obj_probs_per_image in zip(
                shifts, gt_instances, pred_class_probs, pred_shift_deltas,
                pred_obj_probs):
            locations = torch.cat(shifts_per_image, dim=0)
            labels = gt_instances_per_image.gt_classes
            gt_boxes = gt_instances_per_image.gt_boxes

            target_shift_deltas = self.shift2box_transform.get_deltas(
                locations, gt_boxes.tensor.unsqueeze(1))
            is_in_boxes = target_shift_deltas.min(dim=-1).values > 0

            foreground_idxs = torch.nonzero(is_in_boxes, as_tuple=True)

            with torch.no_grad():
                # predicted_boxes_per_image: a_{j}^{loc}, shape: [j, 4]
                predicted_boxes_per_image = self.shift2box_transform.apply_deltas(
                    pred_shift_deltas_per_image, locations)
                # gt_pred_iou: IoU_{ij}^{loc}, shape: [i, j]
                gt_pred_iou = pairwise_iou(
                    gt_boxes, Boxes(predicted_boxes_per_image)).max(
                        dim=0, keepdim=True).values.repeat(
                            len(gt_instances_per_image), 1)

                # pred_box_prob_per_image: P{a_{j} \in A_{+}}, shape: [j, c]
                pred_box_prob_per_image = torch.zeros_like(
                    pred_class_probs_per_image)
                box_prob = 1 / (1 - gt_pred_iou[foreground_idxs]).clamp_(1e-12)
                for i in range(len(gt_instances_per_image)):
                    idxs = foreground_idxs[0] == i
                    if idxs.sum() > 0:
                        box_prob[idxs] = normalize(box_prob[idxs])
                pred_box_prob_per_image[foreground_idxs[1],
                                        labels[foreground_idxs[0]]] = box_prob
                pred_box_probs.append(pred_box_prob_per_image)

            normal_probs = []
            for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):
                gt_shift_deltas = self.shift2box_transform.get_deltas(
                    shifts_i, gt_boxes.tensor.unsqueeze(1))
                distances = (gt_shift_deltas[..., :2] -
                             gt_shift_deltas[..., 2:]) / 2
                normal_probs.append(
                    normal_distribution(distances / stride,
                                        self.mu[labels].unsqueeze(1),
                                        self.sigma[labels].unsqueeze(1)))
            normal_probs = torch.cat(normal_probs, dim=1).prod(dim=-1)

            composed_cls_prob = pred_class_probs_per_image[:,
                                                           labels] * pred_obj_probs_per_image

            # matched_gt_shift_deltas: P_{ij}^{loc}
            loss_box_reg = iou_loss(pred_shift_deltas_per_image.unsqueeze(0),
                                    target_shift_deltas,
                                    box_mode="ltrb",
                                    loss_type=self.iou_loss_type,
                                    reduction="none") * self.reg_weight
            pred_reg_probs = (-loss_box_reg).exp()

            # positive_losses: { -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) }
            positive_losses.append(
                positive_bag_loss(
                    composed_cls_prob.transpose(1, 0) * pred_reg_probs,
                    is_in_boxes.float(), normal_probs))

            num_foreground += len(gt_instances_per_image)
            num_background += normal_probs[foreground_idxs].sum().item()

            gaussian_norm_losses.append(
                len(gt_instances_per_image) /
                normal_probs[foreground_idxs].sum().clamp_(1e-12))

        if dist.is_initialized():
            dist.all_reduce(num_foreground)
            num_foreground /= dist.get_world_size()
            dist.all_reduce(num_background)
            num_background /= dist.get_world_size()

        # positive_loss: \sum_{i}{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||
        positive_loss = torch.cat(positive_losses).sum() / max(
            1, num_foreground)

        # pred_box_probs: P{a_{j} \in A_{+}}
        pred_box_probs = torch.stack(pred_box_probs, dim=0)
        # negative_loss: \sum_{j}{ FL( (1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||
        negative_loss = negative_bag_loss(
            pred_class_probs * pred_obj_probs * (1 - pred_box_probs),
            self.focal_loss_gamma).sum() / max(1, num_background)

        loss_pos = positive_loss * self.focal_loss_alpha
        loss_neg = negative_loss * (1 - self.focal_loss_alpha)
        loss_norm = torch.stack(gaussian_norm_losses).mean() * (
            1 - self.focal_loss_alpha)

        return {
            "loss_pos": loss_pos,
            "loss_neg": loss_neg,
            "loss_norm": loss_norm,
        }
Example #22
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True
Example #23
    def train(self,
              train,
              dev,
              test,
              buckets=32,
              batch_size=5000,
              lr=2e-3,
              mu=.9,
              nu=.9,
              epsilon=1e-12,
              clip=5.0,
              decay=.75,
              decay_steps=5000,
              epochs=5000,
              patience=100,
              verbose=True,
              **kwargs):
        args = self.args.update(locals())
        init_logger(logger, verbose=args.verbose)

        self.transform.train()
        if dist.is_initialized():
            args.batch_size = args.batch_size // dist.get_world_size()
        logger.info("Loading the data")
        train = Dataset(self.transform, args.train, **args)
        dev = Dataset(self.transform, args.dev)
        test = Dataset(self.transform, args.test)
        logger.info("Building the datasets")
        train.build(args.batch_size, args.buckets, True, dist.is_initialized())
        logger.info("train built")
        dev.build(args.batch_size, args.buckets)
        logger.info("dev built")
        test.build(args.batch_size, args.buckets)
        logger.info(
            f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

        logger.info(f"{self.model}\n")
        if dist.is_initialized():
            self.model = DDP(self.model,
                             device_ids=[args.local_rank],
                             find_unused_parameters=True)
        self.optimizer = Adam(self.model.parameters(), args.lr,
                              (args.mu, args.nu), args.epsilon)
        self.scheduler = ExponentialLR(self.optimizer,
                                       args.decay**(1 / args.decay_steps))

        elapsed = timedelta()
        best_e, best_metric = 1, Metric()

        for epoch in range(1, args.epochs + 1):
            start = datetime.now()

            logger.info(f"Epoch {epoch} / {args.epochs}:")
            self._train(train.loader)
            loss, dev_metric = self._evaluate(dev.loader)
            logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}")
            loss, test_metric = self._evaluate(test.loader)
            logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}")

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric:
                best_e, best_metric = epoch, dev_metric
                if is_master():
                    self.save(args.path)
                logger.info(f"{t}s elapsed (saved)\n")
            else:
                logger.info(f"{t}s elapsed\n")
            elapsed += t
            if epoch - best_e >= args.patience:
                break
        loss, metric = self.load(**args)._evaluate(test.loader)

        logger.info(f"Epoch {best_e} saved")
        logger.info(f"{'dev:':6} - {best_metric}")
        logger.info(f"{'test:':6} - {metric}")
        logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")
Example #24
    def setup(self, config):
        self.args = args = config["args"]
        start = time.time()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
        logger.info(f"tokenizer instantiation time: {time.time() - start}")

        # Load data.
        train_dataset = load_and_cache_examples(
            args, args.task_name, self.tokenizer, evaluate=False
        )
        train_sampler = (
            RandomSampler(train_dataset) if not dist.is_initialized() else None
        )
        train_loader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=args.per_device_train_batch_size,
        )

        # Create model.
        with FileLock(os.path.expanduser("~/.download.lock")):
            processor = processors[args.task_name]()
            label_list = processor.get_labels()
            num_labels = len(label_list)
            model_config = AutoConfig.from_pretrained(
                args.config_name if args.config_name else args.model_name_or_path,
                num_labels=num_labels,
                finetuning_task=args.task_name,
                cache_dir=args.cache_dir if args.cache_dir else None,
            )
            model = AutoModelForSequenceClassification.from_pretrained(
                args.model_name_or_path,
                from_tf=bool(".ckpt" in args.model_name_or_path),
                config=model_config,
                cache_dir=args.cache_dir if args.cache_dir else None,
            )

        # Create optimizer.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]

        optimizer = AdamW(
            optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon
        )

        # Register components.
        self.model, self.optimizer = self.register(
            models=model,
            optimizers=optimizer,
            apex_args={"opt_level": args.fp16_opt_level},
        )

        self.register_data(train_loader=train_loader, validation_loader=None)

        self.train_data_len = len(self.train_loader)
        self._warmup_scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=self.calculate_t_total(),
        )
        self._global_step = 0

        announce_training(args, self.train_data_len, self.calculate_t_total())
Example #25
 def forward(ctx, x):
     if (dist.is_available() and dist.is_initialized()
             and (dist.get_world_size() > 1)):
         x = x.contiguous() / dist.get_world_size()
         dist.all_reduce(x)
     return x
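The forward above averages its input across ranks; a minimal sketch of the autograd.Function it would typically live in, with a symmetric backward, is shown below. The class name and the backward pass are assumptions, not part of the original snippet:

import torch
import torch.distributed as dist

class AllReduceMean(torch.autograd.Function):
    """Hypothetical wrapper around the forward shown above."""

    @staticmethod
    def forward(ctx, x):
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            x = x.contiguous() / dist.get_world_size()
            dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Average the incoming gradients across ranks in the same way.
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            grad_output = grad_output.contiguous() / dist.get_world_size()
            dist.all_reduce(grad_output)
        return grad_output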
Example #26
 def barrier(self, name: str = None):
     if torch_distrib.is_initialized():
         torch_distrib.barrier()
Example #27
    def __init__(
        self,
        device: torch.device,
        max_epochs: int,
        data_loader: Union[Iterable, DataLoader],
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        post_transform: Optional[Callable] = None,
        key_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        handlers: Optional[Sequence] = None,
        amp: bool = False,
        event_names: Optional[List[Union[str, EventEnum]]] = None,
        event_to_attr: Optional[dict] = None,
    ) -> None:
        if iteration_update is not None:
            super().__init__(iteration_update)
        else:
            super().__init__(self._iteration)
        if not isinstance(device, torch.device):
            raise TypeError(f"device must be a torch.device but is {type(device).__name__}.")

        if isinstance(data_loader, DataLoader):
            sampler = data_loader.__dict__["sampler"]
            if isinstance(sampler, DistributedSampler):

                @self.on(Events.EPOCH_STARTED)
                def set_sampler_epoch(engine: Engine):
                    sampler.set_epoch(engine.state.epoch)

            if epoch_length is None:
                epoch_length = len(data_loader)
        else:
            if epoch_length is None:
                raise ValueError("if data_loader is not PyTorch DataLoader, must specify the epoch_length.")

        # set all sharable data for the workflow based on Ignite engine.state
        self.state = State(
            rank=dist.get_rank() if dist.is_available() and dist.is_initialized() else 0,
            seed=0,
            iteration=0,
            epoch=0,
            max_epochs=max_epochs,
            epoch_length=epoch_length,
            output=None,
            batch=None,
            metrics={},
            metric_details={},
            dataloader=None,
            device=device,
            key_metric_name=None,  # we can set many metrics, only use key_metric to compare and save the best model
            best_metric=-1,
            best_metric_epoch=-1,
        )
        self.data_loader = data_loader
        self.non_blocking = non_blocking
        self.prepare_batch = prepare_batch
        self.amp = amp

        event_names = [IterationEvents] if event_names is None else event_names + [IterationEvents]
        for name in event_names:
            if isinstance(name, str):
                self.register_events(name, event_to_attr=event_to_attr)
            elif issubclass(name, EventEnum):
                self.register_events(*name, event_to_attr=event_to_attr)
            else:
                raise ValueError("event_names must be a list or string or EventEnum.")

        if post_transform is not None:
            self._register_post_transforms(post_transform)
        if key_metric is not None:
            self._register_metrics(key_metric, additional_metrics)
        if handlers is not None:
            self._register_handlers(handlers)