Example #1
File: main.py  Project: Afchis/STEm-Seg-1
    def __init__(self, cfg, model_save_dir, args, logger):
        self.num_gpus = dist_utils.get_world_size()
        self.local_rank = dist_utils.get_rank()
        self.local_device = dist_utils.get_device()
        self.is_main_process = dist_utils.is_main_process()

        self.console_logger = logger

        self.model_save_dir = model_save_dir
        self.log_dir = os.path.join(self.model_save_dir, 'logs')

        if self.is_main_process:
            os.makedirs(self.log_dir, exist_ok=True)

        self.model = build_model(restore_pretrained_backbone_wts=True, logger=self.console_logger).to(self.local_device)

        # create optimizer
        self.optimizer = create_optimizer(self.model, cfg, self.console_logger.info)

        # wrap model and optimizer around apex if mixed precision training is enabled
        if cfg.MIXED_PRECISION:
            assert APEX_IMPORTED
            self.console_logger.info("Mixed precision training is enabled.")
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=cfg.MIXED_PRECISION_OPT_LEVEL)

        if dist_utils.is_distributed():
            self.model = nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank], output_device=self.local_rank,
                find_unused_parameters=cfg.FREEZE_BACKBONE
            )

        self.total_iterations = cfg.MAX_ITERATIONS

        # create LR scheduler
        self.lr_scheduler = create_lr_scheduler(self.optimizer, cfg, self.console_logger.info)

        # create parameter logger
        self.logger = None
        if self.is_main_process:
            self.logger = TrainingLogger(self.log_dir)

        self.interrupt_detector = InterruptDetector()
        self.cfg = cfg

        self.elapsed_iterations = 0

        assert not (args.restore_session and args.initial_ckpt)

        if args.restore_session:
            self.console_logger.info("Restoring session from {}".format(args.restore_session))
            self.restore_session(torch.load(args.restore_session, map_location=self.local_device))
        elif args.initial_ckpt:
            self.console_logger.info("Loading model weights from checkpoint at: {}".format(args.initial_ckpt))
            self._model.load_state_dict(torch.load(args.initial_ckpt, map_location=self.local_device)['model'])
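The last line loads weights through `self._model` rather than `self.model`; once the model has been wrapped in DistributedDataParallel, checkpoints are normally loaded on the unwrapped module. A minimal sketch of such a property, assuming that is how this class exposes it (the property itself is not part of the excerpt):

    @property
    def _model(self):
        # Unwrap DistributedDataParallel so checkpoint keys match the plain module.
        if isinstance(self.model, nn.parallel.DistributedDataParallel):
            return self.model.module
        return self.model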
Example #2
def create_training_data_loader(dataset, batch_size, shuffle, collate_fn=None, num_workers=0, elapsed_iters=0):
    is_distributed = dist_utils.is_distributed()
    if is_distributed:
        sampler = CustomDistributedSampler(dataset, dist_utils.get_world_size(), dist_utils.get_rank(), shuffle)
    elif shuffle:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)

    batch_sampler = BatchSampler(sampler, batch_size, drop_last=False)
    if elapsed_iters > 0:
        print("Elapsed iters: {}".format(elapsed_iters))
        batch_sampler = IterationBasedBatchSampler(batch_sampler, int(len(dataset) / batch_size), elapsed_iters)

    return DataLoader(dataset,
                      collate_fn=collate_fn,
                      batch_sampler=batch_sampler,
                      num_workers=num_workers)
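A hypothetical call for context (the dataset, collate function, and iteration count below are illustrative stand-ins, not part of the excerpt):

# Illustrative usage only; `train_dataset` and `clip_collate_fn` are placeholder names.
train_loader = create_training_data_loader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=clip_collate_fn,   # or None to use the default collation
    num_workers=8,
    elapsed_iters=0               # set to the restored iteration count when resuming a session
)

for batch in train_loader:
    ...  # one training iteration per batch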
Example #3
File: main.py  Project: Afchis/STEm-Seg-1
def create_logger(args):
    logger = logging.getLogger("MaskTCNNTrainLogger")
    if dist_utils.is_main_process():
        logger.setLevel(args.log_level)
    else:
        logger.setLevel(args.subprocess_log_level)

    ch = logging.StreamHandler()
    formatter = logging.Formatter("[%(proc_id)d] %(asctime)s - %(levelname)s - %(message)s", "%H:%M:%S")
    extra = {"proc_id": dist_utils.get_rank()}
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logger.propagate = False

    logger = logging.LoggerAdapter(logger, extra)
    logger.propagate = False

    return logger
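Because the process rank is injected through `extra`, every record is prefixed with the process id. A hypothetical invocation, with an argparse namespace standing in for the real CLI arguments:

# Illustrative usage; the two log-level fields mirror the attributes read by create_logger.
import argparse
import logging

args = argparse.Namespace(log_level=logging.INFO, subprocess_log_level=logging.WARNING)
logger = create_logger(args)
logger.info("Starting training")  # prints e.g. "[0] 12:00:00 - INFO - Starting training"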
Example #4
def start(args, cfg):
    # suppress Python warnings from sub-processes to prevent duplicate warnings being printed to console
    if dist_utils.get_rank() > 0:
        warnings.filterwarnings("ignore")

    logger = create_logger(args)
    model_save_dir = os.path.join(ModelPaths.checkpoint_base_dir(), cfg.MODE,
                                  args.model_dir)

    if dist_utils.is_main_process():
        os.makedirs(model_save_dir, exist_ok=True)

    # check if a checkpoint already exists in the model save directory. If it does, and the 'no_resume' flag is not set,
    # training should resume from the last pre-existing checkpoint.
    existing_ckpts = sorted(glob(os.path.join(model_save_dir, "*.pth")))
    if existing_ckpts and not args.no_resume:
        args.restore_session = existing_ckpts[-1]
        args.initial_ckpt = None  # when jobs auto-restart on the cluster, this might be set,
        # however we want to use the latest checkpoint instead

    # backup config to model directory
    if dist_utils.is_main_process():
        with open(os.path.join(model_save_dir, 'config.yaml'), 'w') as writefile:
            yaml.dump(global_cfg.d(), writefile)

    trainer = Trainer(cfg, model_save_dir, args, logger)

    try:
        trainer.start(args)
    except InterruptException as _:
        if dist_utils.is_main_process():
            print("Interrupt signal received. Saving checkpoint...")
            trainer.backup_session()
            dist_utils.synchronize()
        exit(1)
    except Exception as err:
        if dist_utils.is_main_process():
            print("Exception occurred. Saving checkpoint...")
            print(err)
            trainer.backup_session()
            if dist_utils.is_distributed():
                dist.destroy_process_group()
        raise err
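A minimal entry-point sketch that could drive start(); the flag names are assumptions inferred from the attributes accessed above, not the project's actual CLI:

# Hypothetical launcher; argument names follow the attributes used in the snippets above.
import argparse
import logging

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model_dir')
    parser.add_argument('--initial_ckpt', default=None)
    parser.add_argument('--restore_session', default=None)
    parser.add_argument('--no_resume', action='store_true')
    parser.add_argument('--log_level', type=int, default=logging.INFO)
    parser.add_argument('--subprocess_log_level', type=int, default=logging.WARNING)
    args = parser.parse_args()

    start(args, global_cfg)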
Example #5
    def forward(self, embedding_map, targets, output_dict, *args, **kwargs):
        """
        Computes the embedding loss.
        :param embedding_map: Tensor of shape [N, C, T, H, W] (C = embedding dims + variance dims + seediness dims)
        :param targets: List (length N) of dicts, each containing a 'masks' field containing a tensor of
        shape (I (instances), T, H, W)
        :param output_dict: dict to populate with loss values.
        :return: Scalar loss
        """
        assert embedding_map.shape[1] == self.num_input_channels, \
            "Expected {} channels in input tensor, got {}".format(self.num_input_channels, embedding_map.shape[1])

        embedding_map = embedding_map.permute(0, 2, 3, 4, 1)  # [N, T, H, W, C]

        embedding_map, bandwidth_map, seediness_map = embedding_map.split(
            self.split_sizes, dim=-1)
        assert bandwidth_map.shape[-1] + self.n_free_dims == embedding_map.shape[-1], \
            "Number of predicted bandwidth dims {} + number of free dims {} should equal number of total embedding " \
            "dims {}".format(bandwidth_map.shape[-1], self.n_free_dims, embedding_map.shape[-1])

        total_instances = 0.
        lovasz_loss = 0.
        seediness_loss = 0.
        bandwidth_smoothness_loss = 0.

        torch_zero = torch.tensor(0).to(embedding_map).requires_grad_(False)

        for idx, (embeddings_per_seq, bandwidth_per_seq, seediness_per_seq, targets_per_seq) in \
                enumerate(zip(embedding_map, bandwidth_map, seediness_map, targets)):

            masks = targets_per_seq['masks']
            if masks.numel() == 0:
                continue

            ignore_masks = targets_per_seq['ignore_masks']

            assert masks.shape[-2:] == ignore_masks.shape[-2:], \
                "Masks tensor has shape {} while ignore mask has shape {}".format(masks.shape, ignore_masks.shape)

            assert masks.shape[-2:] == embedding_map.shape[2:4], \
                "Masks tensor has shape {} while embedding map has shape {}".format(masks.shape, embedding_map.shape)

            nonzero_mask_pts = masks.nonzero(as_tuple=False)
            if nonzero_mask_pts.shape[0] == 0:
                print("[ WARN] No valid mask points exist in sample.")
                continue

            _, instance_pt_counts = nonzero_mask_pts[:, 0].unique(
                sorted=True, return_counts=True)
            instance_id_sort_idx = nonzero_mask_pts[:, 0].argsort()
            nonzero_mask_pts = nonzero_mask_pts[instance_id_sort_idx]
            nonzero_mask_pts = nonzero_mask_pts.split(
                tuple(instance_pt_counts.tolist()))
            nonzero_mask_pts = tuple([
                nonzero_mask_pts[i].unbind(1)[1:]
                for i in range(len(nonzero_mask_pts))
            ])

            instance_embeddings = [
                embeddings_per_seq[nonzero_mask_pts[n]]
                for n in range(len(nonzero_mask_pts))
            ]  # list(tensor[I, E])

            instance_bandwidths = [
                bandwidth_per_seq[nonzero_mask_pts[n]]
                for n in range(len(nonzero_mask_pts))
            ]  # list(tensor[I, E])

            instance_seediness = [
                seediness_per_seq[nonzero_mask_pts[n]]
                for n in range(len(nonzero_mask_pts))
            ]  # list(tensor[I, E])

            total_instances += len(nonzero_mask_pts)

            # regress seediness values for background to 0
            bg_mask_pts = (masks == 0).all(0).nonzero(as_tuple=False).unbind(1)
            bg_seediness_pts = seediness_per_seq[bg_mask_pts]
            bg_seediness_loss = F.mse_loss(bg_seediness_pts,
                                           torch.zeros_like(bg_seediness_pts),
                                           reduction='none')

            # ignore loss for ignore mask points
            ignore_mask_pts = ignore_masks[bg_mask_pts].unsqueeze(1)
            seediness_loss = seediness_loss + torch.where(
                ignore_mask_pts, torch_zero, bg_seediness_loss).mean()

            # compute bandwidth smoothness loss before applying activation
            bandwidth_smoothness_loss = bandwidth_smoothness_loss + self.compute_bandwidth_smoothness_loss(
                instance_bandwidths)

            # apply activation to bandwidths
            instance_bandwidths = [
                bandwidth_per_instance.exp() * 10.
                for bandwidth_per_instance in instance_bandwidths
            ]

            for n in range(len(nonzero_mask_pts)):  # iterate over instances
                probs_map = self.compute_prob_map(embeddings_per_seq,
                                                  instance_embeddings[n],
                                                  instance_bandwidths[n])
                logits_map = (probs_map * 2.) - 1.
                instance_target = masks[n].flatten()
                if instance_target.sum(dtype=torch.long) == 0:
                    continue

                lovasz_loss = lovasz_loss + self.lovasz_hinge_loss(
                    logits_map.flatten(), instance_target)
                instance_probs = probs_map.unsqueeze(3)[
                    nonzero_mask_pts[n]].detach()
                seediness_loss = seediness_loss + F.mse_loss(
                    instance_seediness[n], instance_probs, reduction='mean')

        if total_instances == 0:
            print("Process {}: Zero instances case occurred embedding loss".
                  format(dist_utils.get_rank()))
            lovasz_loss = (bandwidth_map.sum() + embedding_map.sum()) * 0
            bandwidth_smoothness_loss = bandwidth_map.sum() * 0
            seediness_loss = seediness_map.sum() * 0
        else:
            # compute weighted sum of lovasz and variance losses based on number of instances per batch sample
            lovasz_loss = lovasz_loss / total_instances
            bandwidth_smoothness_loss = bandwidth_smoothness_loss / embedding_map.shape[0]  # divide by batch size
            seediness_loss = seediness_loss / float(total_instances + 1)

        total_loss = (lovasz_loss * self.w_lovasz) + \
                     (bandwidth_smoothness_loss * self.w_variance_smoothness) + \
                     (seediness_loss * self.w_seediness)

        output_dict[ModelOutputConsts.OPTIMIZATION_LOSSES] = {
            LossConsts.EMBEDDING: total_loss * self.w
        }

        output_dict[ModelOutputConsts.OTHERS] = {
            LossConsts.LOVASZ_LOSS: lovasz_loss,
            LossConsts.VARIANCE_SMOOTHNESS: bandwidth_smoothness_loss,
        }

        output_dict[ModelOutputConsts.OTHERS][
            LossConsts.SEEDINESS_LOSS] = seediness_loss
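To make the docstring's shape conventions concrete, a hypothetical call with dummy tensors; `criterion` is a placeholder for an already constructed instance of this loss module (its constructor is not part of the excerpt), and C must match the number of input channels it expects:

# Illustrative shapes only; `criterion` stands in for an instance of the loss module above.
import torch

N, C, T, H, W = 2, 7, 8, 64, 64                    # batch, channels, clip length, height, width
embedding_map = torch.randn(N, C, T, H, W)
targets = [
    {
        'masks': torch.randint(0, 2, (3, T, H, W)),              # 3 instances per clip
        'ignore_masks': torch.zeros(T, H, W, dtype=torch.bool),  # nothing ignored in this toy case
    }
    for _ in range(N)
]

output_dict = {}
criterion(embedding_map, targets, output_dict)
total_loss = output_dict[ModelOutputConsts.OPTIMIZATION_LOSSES][LossConsts.EMBEDDING]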