def __init__(self, cfg, model_save_dir, args, logger):
    self.num_gpus = dist_utils.get_world_size()
    self.local_rank = dist_utils.get_rank()
    self.local_device = dist_utils.get_device()
    self.is_main_process = dist_utils.is_main_process()

    self.console_logger = logger

    self.model_save_dir = model_save_dir
    self.log_dir = os.path.join(self.model_save_dir, 'logs')
    if self.is_main_process:
        os.makedirs(self.log_dir, exist_ok=True)

    self.model = build_model(restore_pretrained_backbone_wts=True,
                             logger=self.console_logger).to(self.local_device)

    # create optimizer
    self.optimizer = create_optimizer(self.model, cfg, self.console_logger.info)

    # wrap model and optimizer around apex if mixed precision training is enabled
    if cfg.MIXED_PRECISION:
        assert APEX_IMPORTED
        self.console_logger.info("Mixed precision training is enabled.")
        self.model, self.optimizer = amp.initialize(
            self.model, self.optimizer, opt_level=cfg.MIXED_PRECISION_OPT_LEVEL)

    if dist_utils.is_distributed():
        self.model = nn.parallel.DistributedDataParallel(
            self.model, device_ids=[self.local_rank], output_device=self.local_rank,
            find_unused_parameters=cfg.FREEZE_BACKBONE
        )

    self.total_iterations = cfg.MAX_ITERATIONS

    # create LR scheduler
    self.lr_scheduler = create_lr_scheduler(self.optimizer, cfg, self.console_logger.info)

    # create parameter logger
    self.logger = None
    if self.is_main_process:
        self.logger = TrainingLogger(self.log_dir)

    self.interrupt_detector = InterruptDetector()

    self.cfg = cfg
    self.elapsed_iterations = 0

    # restoring a full session and loading only initial weights are mutually exclusive
    assert not (args.restore_session and args.initial_ckpt)

    if args.restore_session:
        self.console_logger.info("Restoring session from {}".format(args.restore_session))
        self.restore_session(torch.load(args.restore_session, map_location=self.local_device))
    elif args.initial_ckpt:
        self.console_logger.info("Loading model weights from checkpoint at: {}".format(args.initial_ckpt))
        self._model.load_state_dict(torch.load(args.initial_ckpt, map_location=self.local_device)['model'])

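
# Note: the final branch above loads weights via `self._model` rather than `self.model`, so a
# `_model` accessor is presumably defined elsewhere on the Trainer class. A minimal sketch of
# such a helper, assuming (hypothetically, this is not the confirmed implementation) that it
# only unwraps the DistributedDataParallel wrapper so checkpoints saved without the "module."
# prefix load cleanly:
def _unwrap_model(model):
    """Return the underlying nn.Module if `model` is wrapped in DistributedDataParallel."""
    if isinstance(model, nn.parallel.DistributedDataParallel):
        return model.module
    return model
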
def create_training_data_loader(dataset, batch_size, shuffle, collate_fn=None, num_workers=0, elapsed_iters=0):
    is_distributed = dist_utils.is_distributed()

    if is_distributed:
        sampler = CustomDistributedSampler(dataset, dist_utils.get_world_size(), dist_utils.get_rank(), shuffle)
    elif shuffle:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)

    batch_sampler = BatchSampler(sampler, batch_size, drop_last=False)

    if elapsed_iters > 0:
        print("Elapsed iters: {}".format(elapsed_iters))
        batch_sampler = IterationBasedBatchSampler(batch_sampler, int(len(dataset) / batch_size), elapsed_iters)

    return DataLoader(dataset, collate_fn=collate_fn, batch_sampler=batch_sampler, num_workers=num_workers)

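
# Usage sketch (illustrative only): under torch.distributed each process draws a rank-specific
# shard via CustomDistributedSampler, while single-process runs fall back to plain random or
# sequential sampling. `my_train_dataset` and `my_collate_fn` are hypothetical placeholders,
# and the batch size / worker count are arbitrary, not values mandated by this repository.
#
#     train_loader = create_training_data_loader(
#         my_train_dataset, batch_size=4, shuffle=True,
#         collate_fn=my_collate_fn, num_workers=4,
#         elapsed_iters=trainer.elapsed_iterations)  # > 0 when resuming a session
#     for sub_batch in train_loader:
#         ...
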
def create_logger(args):
    logger = logging.getLogger("MaskTCNNTrainLogger")

    # the main process logs at the requested level; sub-processes use a separate (usually quieter) level
    if dist_utils.is_main_process():
        logger.setLevel(args.log_level)
    else:
        logger.setLevel(args.subprocess_log_level)

    ch = logging.StreamHandler()
    formatter = logging.Formatter("[%(proc_id)d] %(asctime)s - %(levelname)s - %(message)s", "%H:%M:%S")
    extra = {"proc_id": dist_utils.get_rank()}
    ch.setFormatter(formatter)

    logger.addHandler(ch)
    logger.propagate = False

    # wrap in a LoggerAdapter so the process rank is injected into every record as 'proc_id'
    return logging.LoggerAdapter(logger, extra)

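
# Illustrative only: the logger expects `args` to carry `log_level` and `subprocess_log_level`
# (e.g. "INFO" for the main process and "WARN" for the others; these example values are
# assumptions, not values taken from the repository). Every record is prefixed with the
# process rank, so a call such as
#
#     logger = create_logger(args)
#     logger.info("starting training")
#
# prints something like "[0] 12:00:00 - INFO - starting training" on the main process.
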
def start(args, cfg):
    # suppress Python warnings from sub-processes to prevent duplicate warnings being printed to console
    if dist_utils.get_rank() > 0:
        warnings.filterwarnings("ignore")

    logger = create_logger(args)

    model_save_dir = os.path.join(ModelPaths.checkpoint_base_dir(), cfg.MODE, args.model_dir)
    if dist_utils.is_main_process():
        os.makedirs(model_save_dir, exist_ok=True)

    # check if a checkpoint already exists in the model save directory. If it does, and the 'no_resume' flag is
    # not set, training should resume from the last pre-existing checkpoint.
    existing_ckpts = sorted(glob(os.path.join(model_save_dir, "*.pth")))
    if existing_ckpts and not args.no_resume:
        args.restore_session = existing_ckpts[-1]
        # when jobs auto-restart on the cluster, initial_ckpt might be set, however we want to
        # use the latest checkpoint instead
        args.initial_ckpt = None

    # backup config to model directory
    if dist_utils.is_main_process():
        with open(os.path.join(model_save_dir, 'config.yaml'), 'w') as writefile:
            yaml.dump(global_cfg.d(), writefile)

    trainer = Trainer(cfg, model_save_dir, args, logger)

    try:
        trainer.start(args)
    except InterruptException as _:
        if dist_utils.is_main_process():
            print("Interrupt signal received. Saving checkpoint...")
            trainer.backup_session()
        dist_utils.synchronize()
        exit(1)
    except Exception as err:
        if dist_utils.is_main_process():
            print("Exception occurred. Saving checkpoint...")
            print(err)
            trainer.backup_session()
        if dist_utils.is_distributed():
            dist.destroy_process_group()
        raise err

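
# Illustrative only: `start` (together with Trainer.__init__ and create_logger) assumes the
# argument namespace exposes at least the fields below. This argparse sketch is an assumption
# about the calling script, not the repository's actual parser, and the default values shown
# are placeholders.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("model_dir")                          # sub-directory for checkpoints and logs
#     parser.add_argument("--no_resume", action="store_true")
#     parser.add_argument("--restore_session", default="")      # path to a full training session checkpoint
#     parser.add_argument("--initial_ckpt", default="")         # path to model weights only
#     parser.add_argument("--log_level", default="INFO")
#     parser.add_argument("--subprocess_log_level", default="WARN")
#     args = parser.parse_args()
#     start(args, global_cfg)
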
def forward(self, embedding_map, targets, output_dict, *args, **kwargs):
    """
    Computes the embedding loss.
    :param embedding_map: Tensor of shape [N, C, T, H, W] (C = embedding dims + variance dims + seediness dims)
    :param targets: List (length N) of dicts, each containing a 'masks' field containing a tensor of
    shape (I (instances), T, H, W)
    :param output_dict: dict to populate with loss values.
    :return: Scalar loss
    """
    assert embedding_map.shape[1] == self.num_input_channels, \
        "Expected {} channels in input tensor, got {}".format(self.num_input_channels, embedding_map.shape[1])

    embedding_map = embedding_map.permute(0, 2, 3, 4, 1)  # [N, T, H, W, C]

    embedding_map, bandwidth_map, seediness_map = embedding_map.split(self.split_sizes, dim=-1)
    assert bandwidth_map.shape[-1] + self.n_free_dims == embedding_map.shape[-1], \
        "Number of predicted bandwidth dims {} + number of free dims {} should equal number of total embedding " \
        "dims {}".format(bandwidth_map.shape[-1], self.n_free_dims, embedding_map.shape[-1])

    total_instances = 0.
    lovasz_loss = 0.
    seediness_loss = 0.
    bandwidth_smoothness_loss = 0.

    torch_zero = torch.tensor(0).to(embedding_map).requires_grad_(False)

    for idx, (embeddings_per_seq, bandwidth_per_seq, seediness_per_seq, targets_per_seq) in \
            enumerate(zip(embedding_map, bandwidth_map, seediness_map, targets)):

        masks = targets_per_seq['masks']
        if masks.numel() == 0:
            continue

        ignore_masks = targets_per_seq['ignore_masks']
        assert masks.shape[-2:] == ignore_masks.shape[-2:], \
            "Masks tensor has shape {} while ignore mask has shape {}".format(masks.shape, ignore_masks.shape)

        assert masks.shape[-2:] == embedding_map.shape[2:4], \
            "Masks tensor has shape {} while embedding map has shape {}".format(masks.shape, embedding_map.shape)

        nonzero_mask_pts = masks.nonzero(as_tuple=False)
        if nonzero_mask_pts.shape[0] == 0:
            print("[ WARN] No valid mask points exist in sample.")
            continue

        # group the foreground points by instance ID
        _, instance_pt_counts = nonzero_mask_pts[:, 0].unique(sorted=True, return_counts=True)
        instance_id_sort_idx = nonzero_mask_pts[:, 0].argsort()
        nonzero_mask_pts = nonzero_mask_pts[instance_id_sort_idx]
        nonzero_mask_pts = nonzero_mask_pts.split(tuple(instance_pt_counts.tolist()))
        nonzero_mask_pts = tuple([
            nonzero_mask_pts[i].unbind(1)[1:]
            for i in range(len(nonzero_mask_pts))
        ])

        instance_embeddings = [
            embeddings_per_seq[nonzero_mask_pts[n]]
            for n in range(len(nonzero_mask_pts))
        ]  # list(tensor[I, E])

        instance_bandwidths = [
            bandwidth_per_seq[nonzero_mask_pts[n]]
            for n in range(len(nonzero_mask_pts))
        ]  # list(tensor[I, E])

        instance_seediness = [
            seediness_per_seq[nonzero_mask_pts[n]]
            for n in range(len(nonzero_mask_pts))
        ]  # list(tensor[I, E])

        total_instances += len(nonzero_mask_pts)

        # regress seediness values for background to 0
        bg_mask_pts = (masks == 0).all(0).nonzero(as_tuple=False).unbind(1)
        bg_seediness_pts = seediness_per_seq[bg_mask_pts]
        bg_seediness_loss = F.mse_loss(bg_seediness_pts, torch.zeros_like(bg_seediness_pts), reduction='none')

        # ignore loss for ignore mask points
        ignore_mask_pts = ignore_masks[bg_mask_pts].unsqueeze(1)
        seediness_loss = seediness_loss + torch.where(ignore_mask_pts, torch_zero, bg_seediness_loss).mean()

        # compute bandwidth smoothness loss before applying activation
        bandwidth_smoothness_loss = bandwidth_smoothness_loss + \
            self.compute_bandwidth_smoothness_loss(instance_bandwidths)

        # apply activation to bandwidths
        instance_bandwidths = [
            bandwidth_per_instance.exp() * 10.
            for bandwidth_per_instance in instance_bandwidths
        ]

        for n in range(len(nonzero_mask_pts)):  # iterate over instances
            probs_map = self.compute_prob_map(embeddings_per_seq, instance_embeddings[n], instance_bandwidths[n])
            logits_map = (probs_map * 2.) - 1.

            instance_target = masks[n].flatten()
            if instance_target.sum(dtype=torch.long) == 0:
                continue

            lovasz_loss = lovasz_loss + self.lovasz_hinge_loss(logits_map.flatten(), instance_target)

            instance_probs = probs_map.unsqueeze(3)[nonzero_mask_pts[n]].detach()
            seediness_loss = seediness_loss + F.mse_loss(instance_seediness[n], instance_probs, reduction='mean')

    if total_instances == 0:
        print("Process {}: Zero instances case occurred in embedding loss".format(dist_utils.get_rank()))
        # produce zero-valued losses that still depend on the network outputs so backward() remains valid
        lovasz_loss = (bandwidth_map.sum() + embedding_map.sum()) * 0
        bandwidth_smoothness_loss = bandwidth_map.sum() * 0
        seediness_loss = seediness_map.sum() * 0
    else:
        # compute weighted sum of lovasz and variance losses based on number of instances per batch sample
        lovasz_loss = lovasz_loss / total_instances
        bandwidth_smoothness_loss = bandwidth_smoothness_loss / embedding_map.shape[0]  # divide by batch size
        seediness_loss = seediness_loss / float(total_instances + 1)

    total_loss = (lovasz_loss * self.w_lovasz) + \
                 (bandwidth_smoothness_loss * self.w_variance_smoothness) + \
                 (seediness_loss * self.w_seediness)

    output_dict[ModelOutputConsts.OPTIMIZATION_LOSSES] = {
        LossConsts.EMBEDDING: total_loss * self.w
    }

    output_dict[ModelOutputConsts.OTHERS] = {
        LossConsts.LOVASZ_LOSS: lovasz_loss,
        LossConsts.VARIANCE_SMOOTHNESS: bandwidth_smoothness_loss,
    }

    output_dict[ModelOutputConsts.OTHERS][LossConsts.SEEDINESS_LOSS] = seediness_loss

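
# Illustrative only: a sketch of the tensor shapes this forward() expects, based on the
# docstring above. The dimension sizes, the loss instance name `embedding_loss`, and the
# number of instances are placeholders / assumptions, not values taken from the repository.
#
#     N, E, V, S, T, H, W = 2, 32, 3, 1, 8, 64, 64            # E + V + S == num_input_channels
#     embedding_map = torch.randn(N, E + V + S, T, H, W)
#     targets = [{
#         'masks': torch.randint(0, 2, (3, T, H, W)),          # I = 3 instances per sample
#         'ignore_masks': torch.zeros(T, H, W, dtype=torch.bool),
#     } for _ in range(N)]
#     output_dict = {}
#     embedding_loss(embedding_map, targets, output_dict)
#     total = output_dict[ModelOutputConsts.OPTIMIZATION_LOSSES][LossConsts.EMBEDDING]
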