def __init__(self, dataset, samples_per_gpu=1, num_replicas=None, rank=None):
    """
    Args:
        dataset (Dataset): Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """
    _rank = comm.get_rank()
    _num_replicas = comm.get_world_size()
    if num_replicas is None:
        num_replicas = _num_replicas
    if rank is None:
        rank = _rank
    self.dataset = dataset
    self.samples_per_gpu = samples_per_gpu
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0

    assert hasattr(self.dataset, 'aspect_ratios')
    self.aspect_ratios = self.dataset.aspect_ratios
    self.group_sizes = np.bincount(self.aspect_ratios)

    self.num_samples = 0
    for i, _ in enumerate(self.group_sizes):
        self.num_samples += int(
            math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
                      self.num_replicas)) * self.samples_per_gpu
    self.total_size = self.num_samples * self.num_replicas
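# Hedged worked example (not part of the sampler itself) of the padding arithmetic
# above: with two aspect-ratio groups of sizes [7, 5], samples_per_gpu=2 and
# num_replicas=2, each group is rounded up so every replica draws a multiple of
# samples_per_gpu, giving num_samples = 8 per replica and total_size = 16.
import math

group_sizes = [7, 5]
samples_per_gpu, num_replicas = 2, 2
num_samples = sum(
    int(math.ceil(size * 1.0 / samples_per_gpu / num_replicas)) * samples_per_gpu
    for size in group_sizes)
assert num_samples == 8
assert num_samples * num_replicas == 16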
def __init__(self, dataset, repeat_thresh, shuffle=True, seed=None):
    """
    Args:
        dataset (Dataset): dataset used for sampling.
        repeat_thresh (float): frequency threshold below which data is repeated.
        shuffle (bool): whether to shuffle the indices or not.
        seed (int): the initial seed of the shuffle. Must be the same across all workers.
            If None, will use a random seed shared among workers
            (requires synchronization among all workers).
    """
    self._shuffle = shuffle
    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    dataset_dicts = []
    if hasattr(dataset, "datasets"):
        for d in dataset.datasets:
            dataset_dicts += d.dataset_dicts
    else:
        dataset_dicts = dataset.dataset_dicts

    # Get fractional repeat factors and split into whole number (_int_part)
    # and fractional (_frac_part) parts.
    rep_factors = self._get_repeat_factors(dataset_dicts, repeat_thresh)
    self._int_part = torch.trunc(rep_factors)
    self._frac_part = rep_factors - self._int_part
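# Sketch of what `_get_repeat_factors` typically computes, following the LVIS-style
# repeat-factor formula r(c) = max(1, sqrt(t / f(c))). This is an assumption for
# illustration; the actual method in the codebase may differ in details such as
# annotation key names or filtering of crowd boxes.
from collections import defaultdict
import math
import torch

def _get_repeat_factors_sketch(dataset_dicts, repeat_thresh):
    # 1. Per-category image frequency f(c): fraction of images containing category c.
    category_freq = defaultdict(int)
    for dataset_dict in dataset_dicts:  # one dict per training image
        cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    num_images = len(dataset_dicts)
    for cat_id, count in category_freq.items():
        category_freq[cat_id] = count / num_images

    # 2. Per-category repeat factor r(c) = max(1, sqrt(t / f(c))).
    category_rep = {
        cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    # 3. Per-image repeat factor: max over the categories present in the image.
    rep_factors = []
    for dataset_dict in dataset_dicts:
        cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
        rep_factors.append(max({category_rep[c] for c in cat_ids}, default=1.0))
    return torch.tensor(rep_factors, dtype=torch.float32)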
def build_detection_train_loader(cfg):
    """
    A data loader is created by the following steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`,
       and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:

       * Map each metadata dict into another format to be consumed by the model.
       * Batch them by simply putting dicts into a list.

    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config

    Returns:
        an infinite iterator of training data
    """
    # For simulating large-batch training
    num_devices = comm.get_world_size()
    rank = comm.get_rank()

    # use the subdivision batch size
    images_per_minibatch = cfg.SOLVER.IMS_PER_DEVICE // cfg.SOLVER.BATCH_SUBDIVISIONS

    logger = logging.getLogger(__name__)
    transform_gens = build_transform_gen(cfg.INPUT.AUG.TRAIN_PIPELINES)
    logger.info(f"TransformGens used: {transform_gens} in training")
    dataset = build_dataset(cfg,
                            cfg.DATASETS.TRAIN,
                            transforms=transform_gens,
                            is_train=True)

    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    logger.info("Using training sampler {}".format(sampler_name))
    assert sampler_name in SAMPLERS, "{} not found in SAMPLERS".format(sampler_name)

    if sampler_name == "TrainingSampler":
        sampler = SAMPLERS.get(sampler_name)(len(dataset))
    elif sampler_name == "RepeatFactorTrainingSampler":
        sampler = SAMPLERS.get(sampler_name)(dataset,
                                             cfg.DATALOADER.REPEAT_THRESHOLD)
    elif sampler_name == "DistributedGroupSampler":
        sampler = SAMPLERS.get(sampler_name)(dataset, images_per_minibatch,
                                             num_devices, rank)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=images_per_minibatch,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )

    adjust_epoch_and_iter(cfg, data_loader)
    return data_loader
def setup(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.SOLVER.BASE_LR = 0.001  # Avoid NaNs. Not useful in this script anyway.
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    setup_logger(distributed_rank=comm.get_rank())
    return cfg
def forward(self, z_i, z_j):
    device_size = z_i.shape[0]
    batch_size = device_size * comm.get_world_size()
    local_rank = comm.get_rank()

    neg_perm = torch.randperm(batch_size - 1)[:self.K]

    if comm.get_world_size() > 1:
        group = comm._get_global_gloo_group()

        zi_large = [
            torch.zeros_like(z_i) for _ in range(comm.get_world_size())
        ]
        zj_large = [
            torch.zeros_like(z_j) for _ in range(comm.get_world_size())
        ]
        dist.all_gather(zi_large, z_i, group=group)
        dist.all_gather(zj_large, z_j, group=group)

        # share one random choice of negatives across all workers
        choices = [
            torch.zeros_like(neg_perm, dtype=torch.int64)
            for _ in range(comm.get_world_size())
        ]
        dist.all_gather(choices, neg_perm, group=group)
        neg_perm = choices[0]
    else:
        zi_large = [z_i]
        zj_large = [z_j]

    # keep the local shard differentiable (all_gather returns detached tensors)
    zi_large[local_rank] = z_i
    zi_large = torch.cat(zi_large)
    zj_large = torch.cat(zj_large)

    sim_i_large = self.similarity_f(
        zi_large.unsqueeze(1), zj_large.unsqueeze(0)) / self.temperature

    positive_samples_i = sim_i_large[self.pos_mask_i].reshape(batch_size, 1)
    negative_samples_i = sim_i_large[self.neg_mask_i].reshape(batch_size, -1)[:, neg_perm]

    labels_i = torch.zeros(batch_size).to(self.device).long()
    logits_i = torch.cat((positive_samples_i, negative_samples_i), dim=1)

    # EqCo
    loss_i = torch.log(
        torch.exp(positive_samples_i) +
        # self.alpha / negative_samples_i.shape[1] *  # uncomment this when negatives != bs
        torch.exp(negative_samples_i).sum(dim=-1, keepdim=True)
    ) - positive_samples_i
    loss_i = loss_i.sum() / device_size

    acc1, acc5 = accuracy(logits_i, labels_i, topk=(1, 5))
    return loss_i, acc1, acc5
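# Side note (a hedged numerical check, not part of the module above): with the
# alpha re-weighting commented out, the per-sample EqCo term
#   log(exp(s+) + sum_k exp(s_k^-)) - s+
# equals the InfoNCE / cross-entropy loss on logits [s+, s_1^-, ..., s_K^-] with
# label 0, which is why logits_i / labels_i can also feed the accuracy metric.
import torch
import torch.nn.functional as F

pos = torch.randn(4, 1)   # stand-in positive similarities
neg = torch.randn(4, 8)   # stand-in negative similarities
eqco = (torch.log(pos.exp() + neg.exp().sum(dim=-1, keepdim=True)) - pos).squeeze(1)
ce = F.cross_entropy(torch.cat((pos, neg), dim=1),
                     torch.zeros(4, dtype=torch.long), reduction="none")
assert torch.allclose(eqco, ce, atol=1e-5)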
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the cvpods logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (BaseConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    # setup_logger(output_dir, distributed_rank=rank, name="cvpods")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(
        rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Command line arguments: " + str(args))

    if hasattr(args, "config_file") and args.config_file != "":
        logger.info("Contents of args.config_file={}:\n{}".format(
            args.config_file,
            PathManager.open(args.config_file, "r").read()))

    adjust_config(cfg)
    logger.info("Running with full config:\n{}".format(cfg))
    base_config = cfg.__class__.__base__()
    logger.info("different config with base class:\n{}".format(
        cfg.diff(base_config)))

    # if comm.is_main_process() and output_dir:
    #     # Note: some of our scripts may expect the existence of
    #     # config.yaml in output directory
    #     path = os.path.join(output_dir, "config.yaml")
    #     with PathManager.open(path, "w") as f:
    #         f.write(cfg.dump())
    #     logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed = seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + rank)
    # save seed to config for dump
    cfg.SEED = seed

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK

    return cfg, logger
def __init__(self, size: int):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
    """
    self._size = size
    assert size > 0
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    shard_size = (self._size - 1) // self._world_size + 1
    begin = shard_size * self._rank
    end = min(shard_size * (self._rank + 1), self._size)
    self._local_indices = range(begin, end)
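# Hedged worked example of the sharding arithmetic above (hypothetical sizes): with
# size=10 and 3 workers, shard_size = (10 - 1) // 3 + 1 = 4, so the ranks receive
# indices [0..3], [4..7] and [8..9]; together they cover every index exactly once.
size, world_size = 10, 3
shard_size = (size - 1) // world_size + 1
shards = [range(shard_size * r, min(shard_size * (r + 1), size))
          for r in range(world_size)]
assert [list(s) for s in shards] == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]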
def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
        shuffle (bool): whether to shuffle the indices or not
        seed (int): the initial seed of the shuffle. Must be the same across all workers.
            If None, will use a random seed shared among workers
            (requires synchronization among all workers).
    """
    self._size = size
    assert size > 0
    self._shuffle = shuffle
    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
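# For context, the fields set above are typically consumed by an infinite index
# generator like the sketch below (detectron2-style; the actual __iter__ in this
# codebase may differ): every worker builds the same shuffled stream from the shared
# seed and then takes every world_size-th index starting at its own rank.
import itertools
import torch

def _infinite_indices_sketch(size, shuffle, seed):
    g = torch.Generator()
    g.manual_seed(seed)  # identical seed on every worker -> identical permutation
    while True:
        if shuffle:
            yield from torch.randperm(size, generator=g).tolist()
        else:
            yield from torch.arange(size).tolist()

def _iter_sketch(size, shuffle, seed, rank, world_size):
    # shard the shared stream so each rank sees a disjoint slice of every epoch
    yield from itertools.islice(
        _infinite_indices_sketch(size, shuffle, seed), rank, None, world_size)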
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the cvpods logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (BaseConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        ensure_dir(output_dir)

    rank = comm.get_rank()
    # setup_logger(output_dir, distributed_rank=rank, name="cvpods")
    # NOTE: `logger` is assumed to be a module-level logger here (e.g. a loguru
    # logger imported at the top of the file); setup_logger only configures it.
    setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(
        rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Command line arguments: " + str(args))

    if hasattr(args, "config_file") and args.config_file != "":
        logger.info("Contents of args.config_file={}:\n{}".format(
            args.config_file,
            megfile.smart_open(args.config_file, "r").read()))

    adjust_config(cfg)

    # make sure each worker has a different, yet deterministic seed if specified
    seed = seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + rank)
    # save seed to config for dump
    cfg.SEED = seed

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK

    return cfg
def get_evaluator(cfg, dataset_name, output_folder=None):
    """
    Create evaluator(s) for a given dataset.
    This uses the special metadata "evaluator_type" associated with each builtin dataset.
    For your own dataset, you can simply create an evaluator manually in your
    script and do not have to worry about the hacky if-else logic here.
    """
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    evaluator_list = []
    evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
    if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
        evaluator_list.append(
            SemSegEvaluator(
                dataset_name,
                distributed=True,
                num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
                ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
                output_dir=output_folder,
            ))
    if evaluator_type in ["coco", "coco_panoptic_seg"]:
        evaluator_list.append(
            COCOEvaluator(dataset_name, cfg, True, output_folder))
    if evaluator_type == "coco_panoptic_seg":
        evaluator_list.append(
            COCOPanopticEvaluator(dataset_name, output_folder))
    if evaluator_type == "cityscapes":
        assert (
            torch.cuda.device_count() >= comm.get_rank()
        ), "CityscapesEvaluator currently does not work with multiple machines."
        return CityscapesEvaluator(dataset_name)
    if evaluator_type == "pascal_voc":
        return PascalVOCDetectionEvaluator(dataset_name)
    if evaluator_type == "lvis":
        return LVISEvaluator(dataset_name, cfg, True, output_folder)
    if len(evaluator_list) == 0:
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(
                dataset_name, evaluator_type))
    if len(evaluator_list) == 1:
        return evaluator_list[0]
    return DatasetEvaluators(evaluator_list)
def build_detection_train_loader(cfg):
    """
    A data loader is created by the following steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`,
       and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:

       * Map each metadata dict into another format to be consumed by the model.
       * Batch them by simply putting dicts into a list.

    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config

    Returns:
        an infinite iterator of training data
    """
    num_workers = comm.get_world_size()
    rank = comm.get_rank()
    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
    # Adjust batch size according to BATCH_SUBDIVISIONS
    images_per_batch //= cfg.SOLVER.BATCH_SUBDIVISIONS

    assert (
        images_per_batch % num_workers == 0
    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
        images_per_batch, num_workers)
    assert (
        images_per_batch >= num_workers
    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than or equal to the number of workers ({}).".format(
        images_per_batch, num_workers)
    images_per_worker = images_per_batch // num_workers

    logger = logging.getLogger(__name__)
    transform_gens = build_transform_gen(cfg.INPUT.AUG.TRAIN_PIPELINES)
    logger.info(f"TransformGens used: {transform_gens} in training")
    dataset = build_dataset(cfg,
                            cfg.DATASETS.TRAIN,
                            transforms=transform_gens,
                            is_train=True)

    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    logger.info("Using training sampler {}".format(sampler_name))
    if sampler_name == "TrainingSampler":
        sampler = SAMPLERS.get("TrainingSampler")(len(dataset))
    elif sampler_name == "RepeatFactorTrainingSampler":
        sampler = SAMPLERS.get("RepeatFactorTrainingSampler")(
            dataset, cfg.DATALOADER.REPEAT_THRESHOLD)
    elif sampler_name == "DistributedGroupSampler":
        sampler = SAMPLERS.get("DistributedGroupSampler")(dataset,
                                                          images_per_worker,
                                                          num_workers, rank)
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=images_per_worker,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )

    return data_loader
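# Hedged worked example of the batch arithmetic above (hypothetical config values):
# with IMS_PER_BATCH=16, BATCH_SUBDIVISIONS=2 and 4 workers (GPUs), the loader is
# built with images_per_worker = (16 // 2) // 4 = 2 images per GPU per iteration.
images_per_batch = 16 // 2   # after dividing by BATCH_SUBDIVISIONS
num_workers = 4
assert images_per_batch % num_workers == 0
assert images_per_batch // num_workers == 2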
def build_evaluator(cls, cfg, dataset_name, dataset, output_folder=None):
    """
    Create evaluator(s) for a given dataset.
    This uses the special metadata "evaluator_type" associated with each builtin dataset.
    For your own dataset, you can simply create an evaluator manually in your
    script and do not have to worry about the hacky if-else logic here.
    """
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    dump = config.GLOBAL.DUMP_TRAIN

    evaluator_list = []
    meta = dataset.meta
    evaluator_type = meta.evaluator_type
    if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
        evaluator_list.append(
            SemSegEvaluator(
                dataset_name,
                dataset,
                distributed=True,
                num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
                ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
                output_dir=output_folder,
                dump=dump,
            ))
    if evaluator_type in ["coco", "coco_panoptic_seg", "citypersons"]:
        evaluator_list.append(
            COCOEvaluator(dataset_name, meta, cfg, True, output_folder, dump))
    if evaluator_type == "coco_panoptic_seg":
        evaluator_list.append(
            COCOPanopticEvaluator(dataset_name, meta, output_folder, dump))
    elif evaluator_type == "cityscapes":
        assert (
            torch.cuda.device_count() >= comm.get_rank()
        ), "CityscapesEvaluator currently does not work with multiple machines."
        return CityscapesEvaluator(dataset_name, meta, dump)
    elif evaluator_type == "pascal_voc":
        return PascalVOCDetectionEvaluator(dataset_name, meta, dump)
    elif evaluator_type == "lvis":
        return LVISEvaluator(dataset_name, meta, cfg, True, output_folder, dump)
    elif evaluator_type == "citypersons":
        evaluator_list.append(
            CityPersonsEvaluator(dataset_name, meta, cfg, True, output_folder, dump))
    elif evaluator_type == "widerface":
        return WiderFaceEvaluator(dataset_name, meta, cfg, True, output_folder, dump)

    if evaluator_type == "classification":
        return ClassificationEvaluator(dataset_name, meta, cfg, True, output_folder, dump)

    if hasattr(cfg, "EVALUATORS"):
        for evaluator in cfg.EVALUATORS:
            evaluator_list.append(
                evaluator(dataset_name, meta, True, output_folder, dump=True))
    if len(evaluator_list) == 0:
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(
                dataset_name, evaluator_type))
    elif len(evaluator_list) == 1:
        return evaluator_list[0]
    return DatasetEvaluators(evaluator_list)
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the cvpods logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (BaseConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    # setup_logger(output_dir, distributed_rank=rank, name="cvpods")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(
        rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())
    logger.info("Command line arguments: " + str(args))

    if hasattr(args, "config_file") and args.config_file != "":
        logger.info("Contents of args.config_file={}:\n{}".format(
            args.config_file,
            PathManager.open(args.config_file, "r").read()))

    logger.info("Running with full config:\n{}".format(cfg))
    base_config = cfg.__class__.__base__()
    logger.info("different config with base class:\n{}".format(
        cfg.show_diff(base_config)))

    # if comm.is_main_process() and output_dir:
    #     # Note: some of our scripts may expect the existence of
    #     # config.yaml in output directory
    #     path = os.path.join(output_dir, "config.yaml")
    #     with PathManager.open(path, "w") as f:
    #         f.write(cfg.dump())
    #     logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK

    # dynamically adjust batch size and schedule lengths according to the world size
    base_world_size = int(cfg.SOLVER.IMS_PER_BATCH / cfg.SOLVER.IMS_PER_DEVICE)
    world_size = comm.get_world_size()
    ratio = world_size / base_world_size

    cfg.SOLVER.IMS_PER_BATCH = int(ratio * cfg.SOLVER.IMS_PER_BATCH)
    cfg.SOLVER.LR_SCHEDULER.MAX_ITER = int(cfg.SOLVER.LR_SCHEDULER.MAX_ITER / ratio)

    # Divide by the scale ratio only when using iterations rather than epochs
    if cfg.SOLVER.LR_SCHEDULER.MAX_EPOCH is None:
        cfg.SOLVER.LR_SCHEDULER.STEPS = list(
            (int(step / ratio) for step in cfg.SOLVER.LR_SCHEDULER.STEPS))
    cfg.SOLVER.CHECKPOINT_PERIOD = int(cfg.SOLVER.CHECKPOINT_PERIOD / ratio)
    cfg.TEST.EVAL_PERIOD = int(cfg.TEST.EVAL_PERIOD / ratio)
    cfg.SOLVER.OPTIMIZER.BASE_LR = ratio * cfg.SOLVER.OPTIMIZER.BASE_LR

    assert cfg.SOLVER.IMS_PER_BATCH / cfg.SOLVER.IMS_PER_DEVICE == world_size

    return cfg, logger
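# Hedged worked example of the linear-scaling block above (hypothetical numbers):
# a config written for IMS_PER_BATCH=16 with IMS_PER_DEVICE=2 assumes a base world
# size of 8. Launched on 16 GPUs, ratio = 16 / 8 = 2, so the batch size and BASE_LR
# double while MAX_ITER, STEPS, CHECKPOINT_PERIOD and EVAL_PERIOD halve.
base_world_size = 16 // 2            # IMS_PER_BATCH / IMS_PER_DEVICE
world_size = 16
ratio = world_size / base_world_size
assert int(ratio * 16) == 32         # new IMS_PER_BATCH
assert int(90000 / ratio) == 45000   # new MAX_ITER for a 90k-iteration schedule
assert 32 / 2 == world_size          # matches the final consistency assert above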
def forward(self, z_i, z_j):
    local_rank = comm.get_rank()

    if comm.get_world_size() > 1:
        group = comm._get_global_gloo_group()

        zi_large = [
            torch.zeros_like(z_i) for _ in range(comm.get_world_size())
        ]
        zj_large = [
            torch.zeros_like(z_j) for _ in range(comm.get_world_size())
        ]
        dist.all_gather(zi_large, z_i, group=group)
        dist.all_gather(zj_large, z_j, group=group)
    else:
        zi_large = [z_i]
        zj_large = [z_j]

    z_large = []
    for idx in range(comm.get_world_size()):
        if idx == local_rank:  # current device
            z_large.append(z_i)
            z_large.append(z_j)
        else:
            z_large.append(zi_large[idx])
            z_large.append(zj_large[idx])

    # keep the local shard differentiable (all_gather returns detached tensors)
    zi_large[local_rank] = z_i
    zj_large[local_rank] = z_j
    zi_large = torch.cat(zi_large)
    zj_large = torch.cat(zj_large)

    device_size = z_i.shape[0]
    batch_size = device_size * comm.get_world_size()

    z_large = torch.cat(z_large)

    sim_i_large = self.similarity_f(
        zi_large.unsqueeze(1), z_large.unsqueeze(0)) / self.temperature
    sim_j_large = self.similarity_f(
        zj_large.unsqueeze(1), z_large.unsqueeze(0)) / self.temperature

    positive_samples_i = sim_i_large[self.pos_mask_i].reshape(batch_size, 1)
    negative_samples_i = sim_i_large[self.neg_mask_i].reshape(batch_size, -1)

    positive_samples_j = sim_j_large[self.pos_mask_j].reshape(batch_size, 1)
    negative_samples_j = sim_j_large[self.neg_mask_j].reshape(batch_size, -1)

    labels_i = torch.zeros(batch_size).to(self.device).long()
    logits_i = torch.cat((positive_samples_i, negative_samples_i), dim=1)

    labels_j = torch.zeros(batch_size).to(self.device).long()
    logits_j = torch.cat((positive_samples_j, negative_samples_j), dim=1)

    loss_i = self.criterion(logits_i, labels_i)
    loss_j = self.criterion(logits_j, labels_j)

    loss_i /= device_size
    loss_j /= device_size

    acc1, acc5 = accuracy(logits_i, labels_i, topk=(1, 5))
    return loss_i, loss_j, acc1, acc5
def evaluate(self):
    """
    Returns:
        dict: has a key "segm", whose value is a dict of "AP" and "AP50".
    """
    comm.synchronize()
    if comm.get_rank() > 0:
        return

    os.environ["CITYSCAPES_DATASET"] = os.path.abspath(
        os.path.join(self._metadata.gt_dir, "..", ".."))

    # Load the Cityscapes eval script *after* setting the required env var,
    # since the script reads CITYSCAPES_DATASET into global variables at load time.
    import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as cityscapes_eval

    self._logger.info("Evaluating results under {} ...".format(self._temp_dir))

    # set some global states in cityscapes evaluation API, before evaluating
    cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
    cityscapes_eval.args.predictionWalk = None
    cityscapes_eval.args.JSONOutput = False
    cityscapes_eval.args.colorized = False
    cityscapes_eval.args.gtInstancesFile = os.path.join(self._temp_dir, "gtInstances.json")

    # These lines are adopted from
    # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
    groundTruthImgList = glob.glob(cityscapes_eval.args.groundTruthSearch)
    assert len(
        groundTruthImgList
    ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
        cityscapes_eval.args.groundTruthSearch)
    predictionImgList = []
    for gt in groundTruthImgList:
        predictionImgList.append(cityscapes_eval.getPrediction(gt, cityscapes_eval.args))
    results = cityscapes_eval.evaluateImgLists(
        predictionImgList, groundTruthImgList, cityscapes_eval.args)["averages"]

    ret = OrderedDict()
    ret["segm"] = {"AP": results["allAp"] * 100, "AP50": results["allAp50%"] * 100}
    self._working_dir.cleanup()

    small_table = create_small_table(ret["segm"])
    self._logger.info("Evaluation results for segm: \n" + small_table)

    results_per_category = []
    for cat, ap in results["classes"].items():
        ap = [ap_i * 100 for ap_i in ap.values()]
        results_per_category.append([cat, *ap])
    table = tabulate(
        results_per_category,
        headers=["category", "AP", "AP50"],
        tablefmt="pipe",
        floatfmt=".3f",
        numalign="left",
    )
    self._logger.info("Per-category segm AP: \n" + table)

    if self._dump:
        dump_info_one_task = {
            "task": "segm",
            "tables": [small_table, table],
        }
        _dump_to_markdown([dump_info_one_task])

    return ret
def run_step(self):
    """
    Implement the standard training logic described above.
    """
    assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"

    """
    If you need to accumulate gradients or do something similar, you can
    wrap the optimizer with your custom `zero_grad()` method.
    """
    self.optimizer.zero_grad()

    """
    If you want to do something with the data, you can wrap the dataloader.
    """
    start = time.perf_counter()
    data_time_sum = 0.
    loss_dict_summary = {}

    # for each mini step
    for division_iter in range(self.batch_subdivisions):
        start = time.perf_counter()
        try:
            data = next(self._data_loader_iter)
        except StopIteration:
            # start a new epoch
            self.epoch += 1
            self.data_loader.sampler.set_epoch(self.epoch)
            self._data_loader_iter = iter(self.data_loader)
            data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start
        data_time_sum += data_time

        """
        If you want to do something with the losses, you can wrap the model.
        """
        try:
            loss_dict = self.model(data)
            for metrics_name, metrics_value in loss_dict.items():
                # Some returned metrics are not losses, e.g. top1_acc and
                # top5_acc in classification; only loss values carry gradients.
                if metrics_value.requires_grad:
                    loss_dict[metrics_name] = metrics_value

            losses = sum([
                metrics_value for metrics_value in loss_dict.values()
                if metrics_value.requires_grad
            ]) / self.batch_subdivisions
            self._detect_anomaly(losses, loss_dict)

            # only the last subdivision iteration needs DDP to backward with gradient sync
            if (division_iter != self.batch_subdivisions - 1
                    and isinstance(self.model, DistributedDataParallel)):
                with self.model.no_sync():
                    losses.backward()
            else:
                losses.backward()
        except Exception:
            ckpt = Checkpointer(self.model,
                                save_dir="./log",
                                save_to_disk=True,
                                optimizer=self.optimizer)
            ckpt.save("debug_ckpt_rank{}".format(comm.get_rank()),
                      tag_checkpoint=False,
                      inputs=data)
            raise

        # The values in `loss_dict` fall into two cases:
        #   * case 1. value.requires_grad = True: the value is a loss and is summed
        #     across subdivisions.
        #   * case 2. value.requires_grad = False: e.g. top1_acc, top5_acc in
        #     classification; the last mini-step value is used as the current iter value.
        for metrics_name, metrics_value in loss_dict.items():
            if metrics_name not in loss_dict_summary:
                loss_dict_summary[metrics_name] = metrics_value
            elif metrics_value.requires_grad:
                loss_dict_summary[metrics_name] += metrics_value  # Sum the loss
            else:
                loss_dict_summary[metrics_name] = metrics_value  # Update other metrics

    metrics_dict = {"data_time": data_time_sum}
    metrics_dict.update(loss_dict_summary)
    self._write_metrics(metrics_dict)

    """
    If you need gradient clipping/scaling or other processing, you can
    wrap the optimizer with your custom `step()` method.
    """
    self.optimizer.step()
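# Minimal standalone sketch of the pattern run_step implements (hypothetical names,
# no checkpointing or metric bookkeeping): accumulate gradients over
# batch_subdivisions forward/backward passes, skipping DDP gradient sync on all but
# the last pass, then take a single optimizer step.
def accumulate_and_step(model, optimizer, data_iter, batch_subdivisions):
    optimizer.zero_grad()
    for i in range(batch_subdivisions):
        losses = sum(v for v in model(next(data_iter)).values()
                     if v.requires_grad) / batch_subdivisions
        if i != batch_subdivisions - 1 and hasattr(model, "no_sync"):
            with model.no_sync():   # defer the all-reduce until the final backward
                losses.backward()
        else:
            losses.backward()
    optimizer.step()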
def forward(self, z_i, z_j): """ We do not sample negative examples explicitly. Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples. """ local_rank = comm.get_rank() if comm.get_world_size() > 1: group = comm._get_global_gloo_group() zi_large = [torch.zeros_like(z_i) for _ in range(comm.get_world_size())] zj_large = [torch.zeros_like(z_j) for _ in range(comm.get_world_size())] dist.all_gather(zi_large, z_i, group=group) dist.all_gather(zj_large, z_j, group=group) else: zi_large = [z_i] zj_large = [z_j] z_large = [] for idx in range(comm.get_world_size()): if idx == local_rank: # current device z_large.append(z_i) z_large.append(z_j) else: z_large.append(zi_large[idx]) z_large.append(zj_large[idx]) zi_large[local_rank] = z_i zj_large[local_rank] = z_j zi_large = torch.cat(zi_large) zj_large = torch.cat(zj_large) device_size = z_i.shape[0] batch_size = device_size * comm.get_world_size() z_large = torch.cat(z_large) sim_i_large = self.similarity_f(zi_large.unsqueeze(1), z_large.unsqueeze(0)) / self.temperature sim_j_large = self.similarity_f(zj_large.unsqueeze(1), z_large.unsqueeze(0)) / self.temperature positive_samples_i = sim_i_large[self.pos_mask_i].reshape(batch_size, 1) negative_samples_i = sim_i_large[self.neg_mask_i].reshape(batch_size, -1) r = (positive_samples_i.exp() / negative_samples_i.exp().sum(dim=1, keepdim=True)).mean() if local_rank == 0: print("SimQK to SimQN: ", r) positive_samples_j = sim_j_large[self.pos_mask_j].reshape(batch_size, 1) negative_samples_j = sim_j_large[self.neg_mask_j].reshape(batch_size, -1) labels_i = torch.zeros(batch_size).to(self.device).long() logits_i = torch.cat((positive_samples_i, negative_samples_i), dim=1) labels_j = torch.zeros(batch_size).to(self.device).long() logits_j = torch.cat((positive_samples_j, negative_samples_j), dim=1) loss_i = self.criterion(logits_i, labels_i) loss_j = self.criterion(logits_j, labels_j) loss_i /= device_size loss_j /= device_size return loss_i, loss_j