import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.optim import Optimizer
from torch.utils.data import DataLoader


def fnTrain(
    loader: DataLoader,
    device: str,
    model: nn.Module,
    optimizer: Optimizer,
    fnLoss,
    scaler: GradScaler,
) -> float:
    runningLoss = 0.0
    for idxBatch, (data, targets) in enumerate(loader):
        data = data.to(device=device)
        targets = targets.float().unsqueeze(1).to(device=device)

        # forward pass in mixed precision
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = fnLoss(predictions, targets)

        # backward pass through the gradient scaler
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # print(f"batch {idxBatch + 1} loss {loss.item()}")
        runningLoss += loss.item()

    return runningLoss / len(loader)
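# --- Illustrative sketch (not part of the original code) --------------------
# fnTrain above pairs torch.cuda.amp.autocast with GradScaler. A common
# extension is gradient clipping, which must operate on *unscaled* gradients,
# so scaler.unscale_ is called before clip_grad_norm_. The toy model, data,
# and hyper-parameters below are placeholders, not values from this repo.
import torch
import torch.nn as nn


def amp_step_with_clipping(model, optimizer, scaler, data, targets, max_norm=1.0):
    with torch.cuda.amp.autocast():
        loss = nn.functional.mse_loss(model(data), targets)
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                      # gradients back to fp32 scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)                          # skips the step if grads overflowed
    scaler.update()
    return loss.item()


if __name__ == "__main__" and torch.cuda.is_available():
    _model = nn.Linear(8, 1).cuda()
    _opt = torch.optim.SGD(_model.parameters(), lr=0.1)
    _scaler = torch.cuda.amp.GradScaler()
    _x, _y = torch.randn(4, 8, device="cuda"), torch.randn(4, 1, device="cuda")
    print(amp_step_with_clipping(_model, _opt, _scaler, _x, _y))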
class Fp16OptimizerHook(OptimizerHook):
    def __init__(self, grad_clip=None, grad_scaler_config=None):
        super().__init__(grad_clip)
        self._grad_scaler_config = grad_scaler_config
        self._scaler = None

    def before_train(self):
        # Create the GradScaler lazily so its settings can come from the config.
        if self._grad_scaler_config is None:
            self._scaler = GradScaler()
        else:
            self._scaler = GradScaler(**self._grad_scaler_config)

    def after_train_iter(self):
        # Normalize the loss by the number of accumulation steps, then
        # backpropagate the scaled loss.
        loss = self.trainer.output['loss'] / self.trainer.gradient_accumulation_steps
        self._scaler.scale(loss).backward()

        # Only unscale, clip, and step at accumulation boundaries:
        # GradScaler.unscale_() may be called at most once between updates.
        if (self.trainer.iter + 1) % self.trainer.gradient_accumulation_steps == 0:
            if self._grad_clip is not None:
                self._scaler.unscale_(self.trainer.optimizer)
                self._clip_grad_norm()
            self._scaler.step(self.trainer.optimizer)
            self._scaler.update()
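# --- Illustrative sketch (not part of the original code) --------------------
# Standalone version of the accumulation logic in Fp16OptimizerHook above:
# the loss is divided by the number of accumulation steps, backward runs every
# iteration, and the scaler only steps/updates at accumulation boundaries.
# The model, batch iterable, and accum_steps are made-up placeholders.
import torch
import torch.nn as nn


def accumulate_and_step(model, optimizer, scaler, batches, accum_steps=4):
    optimizer.zero_grad(set_to_none=True)
    for it, (x, y) in enumerate(batches):
        with torch.cuda.amp.autocast():
            # divide so the accumulated gradient matches a full-batch gradient
            loss = nn.functional.mse_loss(model(x), y) / accum_steps
        scaler.scale(loss).backward()
        if (it + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)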
class UDA_Baseline_Trainer(TrainerBase): """ load a model pretrained on the source domain, neglect outliers during training on the target domain """ def __init__(self, cfg): super().__init__() logger = logging.getLogger("fastreid") if not logger.isEnabledFor( logging.INFO): # if setup_logger is not called for fastreid setup_logger() logger.info("==> Load target-domain dataset") self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT) self.tgt_nums = len(tgt.train) cfg = self.auto_scale_hyperparams(cfg, self.tgt_nums) # Create model self.model = self.build_model(cfg, load_model=cfg.MODEL.PRETRAIN, show_model=True, use_dsbn=False) # Optimizer self.optimizer, self.param_wrapper = self.build_optimizer( cfg, self.model) # For training, wrap with DDP. But don't need this for inference. if comm.get_world_size() > 1: # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True` # for part of the parameters is not updated. self.model = DistributedDataParallel( self.model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True) # Learning rate scheduler self.iters_per_epoch = cfg.SOLVER.ITERS self.scheduler = self.build_lr_scheduler(cfg, self.optimizer, self.iters_per_epoch) # Assume no other objects need to be checkpointed. # We can later make it checkpoint the stateful hooks self.checkpointer = Checkpointer( # Assume you want to save checkpoints together with logs/statistics self.model, cfg.OUTPUT_DIR, save_to_disk=comm.is_main_process(), optimizer=self.optimizer, **self.scheduler, ) self.start_epoch = 0 self.max_epoch = cfg.SOLVER.MAX_EPOCH self.max_iter = self.max_epoch * self.iters_per_epoch self.warmup_iters = cfg.SOLVER.WARMUP_ITERS self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS self.cfg = cfg self.register_hooks(self.build_hooks()) if cfg.SOLVER.AMP.ENABLED: unsupported = "AMPTrainer does not support single-process multi-device training!" if isinstance(self.model, DistributedDataParallel): assert not (self.model.device_ids and len(self.model.device_ids) > 1), unsupported from torch.cuda.amp.grad_scaler import GradScaler self.grad_scaler = GradScaler() else: self.grad_scaler = None def train(self): """ Run training. Returns: OrderedDict of results, if evaluation is enabled. Otherwise None. """ super().train(self.start_epoch, self.max_epoch, self.iters_per_epoch) if comm.is_main_process(): assert hasattr(self, "_last_eval_results" ), "No evaluation results obtained during training!" return self._last_eval_results def before_train(self): self.model.train() if self.cfg.SOLVER.AMP.ENABLED: assert torch.cuda.is_available( ), "CUDA is required for AMP training!" 
return super().before_train() def before_epoch(self): logger = logging.getLogger('fastreid') # Calculate distance logger.info("==> Create pseudo labels for unlabeled target domain") with inference_context(self.model), torch.no_grad(): tgt_train = self.build_dataset(self.cfg, self.tgt.train, is_train=False, relabel=False, with_mem_idx=False) tgt_init_feat_loader = self.build_test_loader(self.cfg, tgt_train) tgt_fname_feat_dict, _ = extract_features(self.model, tgt_init_feat_loader) tgt_features = torch.cat([ tgt_fname_feat_dict[f].unsqueeze(0) for f, _, _ in sorted(self.tgt.train) ], 0) tgt_features = F.normalize(tgt_features, dim=1) rerank_dist = compute_jaccard_distance(tgt_features, k1=self.cfg.CLUSTER.JACCARD.K1, k2=self.cfg.CLUSTER.JACCARD.K2) if self.epoch == 0: if self.cfg.CLUSTER.DBSCAN.ADAPTIVE_EPS: logger.info("==> Calculating eps according to rerank_dist...") tri_mat = np.triu(rerank_dist, 1) # tri_mat.dim=2 tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1 tri_mat = np.sort(tri_mat, axis=None) top_num = np.round(self.cfg.SOLVER.RHO * tri_mat.size).astype(int) self.eps = tri_mat[:top_num].mean() logger.info(f"==> epoch {self.epoch} eps: {self.eps}") else: self.eps = self.cfg.CLUSTER.DBSCAN.EPS self.cluster = DBSCAN(eps=self.eps, min_samples=4, metric="precomputed", n_jobs=-1) # select & cluster images as training set of this epochs logger.info(f"Clustering and labeling...") pseudo_labels = self.cluster.fit_predict(rerank_dist) self.num_clusters = num_clusters = len( set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0) num_outliers = pseudo_labels[pseudo_labels == -1].shape[0] # pseudo_labels = self.generate_pseudo_labels(pseudo_labels, num_clusters) # pseudo_labels = self.assign_outlier(pseudo_labels, tgt_features) del tgt_features pseudo_labeled_dataset = [] cluster_centers = collections.defaultdict(list) for i, ((fname, _, cid), label) in enumerate(zip(sorted(self.tgt.train), pseudo_labels)): if label != -1: pseudo_labeled_dataset.append((fname, label, cid)) cluster_centers[label].append(tgt_fname_feat_dict[fname]) del tgt_fname_feat_dict, rerank_dist cluster_centers = [ torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys()) ] cluster_centers = torch.stack(cluster_centers) if isinstance(self.model, DistributedDataParallel): self.model.module.heads.weight.data[:num_clusters].copy_( F.normalize(cluster_centers, dim=1).float().cuda()) else: self.model.heads.weight.data[:num_clusters].copy_( F.normalize(cluster_centers, dim=1).float().cuda()) # statistics of clusters and un-clustered instances # index2label = collections.defaultdict(int) # for label in pseudo_labels: # index2label[label.item()] += 1 # print(f'cluster_label', min(cluster_label), max(cluster_label), len(cluster_label)) # print(f'outlier label', min(outlier_label), max(outlier_label), len(outlier_label)) # index2label = np.fromiter(index2label.values(), dtype=float) logger.info( "==> Statistics for epoch {}: {} clusters, {} un-clustered instances" .format(self.epoch, num_clusters, num_outliers)) pseudo_tgt_train = self.build_dataset( self.cfg, pseudo_labeled_dataset, is_train=True, relabel=True, # relabel? 
# relabel=False, with_mem_idx=False) self.pseudo_tgt_train_loader = self.build_train_loader( self.cfg, train_set=pseudo_tgt_train, sampler=RandomMultipleGallerySampler( pseudo_tgt_train.img_items, self.cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(), self.cfg.DATALOADER.NUM_INSTANCE), with_mem_idx=False) return super().before_epoch() # assign outlier to its nearest neighbor def assign_outlier(self, pseudo_labels, tgt_features): outlier_mask = pseudo_labels == -1 cluster_mask = pseudo_labels != -1 outlier_feat = tgt_features[outlier_mask] cluster_feat = tgt_features[cluster_mask] dist = torch.cdist(outlier_feat, cluster_feat) # 计算特征间L2距离 min_dist_idx = dist.argmin(-1).cpu() pseudo_labels[outlier_mask] = pseudo_labels[cluster_mask][min_dist_idx] return pseudo_labels def run_step(self): assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!" if self.cfg.SOLVER.AMP.ENABLED: assert torch.cuda.is_available( ), f"[{self.__class__.__name__}] CUDA is required for AMP training!" from torch.cuda.amp.autocast_mode import autocast start = time.perf_counter() # load data tgt_inputs = self.pseudo_tgt_train_loader.next() def _parse_data(inputs): imgs, _, pids, _ = inputs return imgs.cuda(), pids.cuda() # process inputs t_inputs, t_targets = _parse_data(tgt_inputs) data_time = time.perf_counter() - start def _forward(): outputs = self.model(t_inputs) f_out_t = outputs['features'] p_out_t = outputs['pred_class_logits'][:, :self.num_clusters] loss_dict = {} loss_ce = cross_entropy_loss(pred_class_outputs=p_out_t, gt_classes=t_targets, eps=self.cfg.MODEL.LOSSES.CE.EPSILON, alpha=self.cfg.MODEL.LOSSES.CE.ALPHA) loss_dict.update({'loss_ce': loss_ce}) if 'TripletLoss' in self.cfg.MODEL.LOSSES.NAME: loss_tri = triplet_loss(f_out_t, t_targets, margin=0.0, norm_feat=True, hard_mining=False) loss_dict.update({'loss_tri': loss_tri}) return loss_dict if self.cfg.SOLVER.AMP.ENABLED: with autocast(): loss_dict = _forward() losses = sum(loss_dict.values()) self.optimizer.zero_grad() self.grad_scaler.scale(losses).backward() self._write_metrics(loss_dict, data_time) self.grad_scaler.step(self.optimizer) self.grad_scaler.update() else: loss_dict = _forward() losses = sum(loss_dict.values()) self.optimizer.zero_grad() losses.backward() self._write_metrics(loss_dict, data_time) self.optimizer.step() if isinstance(self.param_wrapper, ContiguousParams): self.param_wrapper.assert_buffer_is_valid() @classmethod def load_dataset(cls, name): logger = logging.getLogger(__name__) logger.info(f"Preparing {name}") _root = os.getenv("FASTREID_DATASETS", "/root/datasets") data = DATASET_REGISTRY.get(name)(root=_root) if comm.is_main_process(): data.show_train() return data @classmethod def build_dataset(cls, cfg, img_items, is_train=False, relabel=False, transforms=None, with_mem_idx=False): if transforms is None: transforms = build_transforms(cfg, is_train=is_train) if with_mem_idx: sorted_img_items = sorted(img_items) for i in range(len(sorted_img_items)): sorted_img_items[i] += (i, ) return InMemoryDataset(sorted_img_items, transforms, relabel) else: return CommDataset(img_items, transforms, relabel) @classmethod def build_train_loader(cls, cfg, train_set=None, sampler=None, with_mem_idx=False): logger = logging.getLogger('fastreid') logger.info("Prepare training loader") total_batch_size = cfg.SOLVER.IMS_PER_BATCH mini_batch_size = total_batch_size // comm.get_world_size() if sampler is None: num_instance = cfg.DATALOADER.NUM_INSTANCE sampler_name = cfg.DATALOADER.SAMPLER_TRAIN logger.info("Using 
training sampler {}".format(sampler_name)) if sampler_name == "TrainingSampler": sampler = samplers.TrainingSampler(len(train_set)) elif sampler_name == "NaiveIdentitySampler": sampler = samplers.NaiveIdentitySampler( train_set.img_items, mini_batch_size, num_instance) elif sampler_name == "BalancedIdentitySampler": sampler = samplers.BalancedIdentitySampler( train_set.img_items, mini_batch_size, num_instance) elif sampler_name == "SetReWeightSampler": set_weight = cfg.DATALOADER.SET_WEIGHT sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight) elif sampler_name == "ImbalancedDatasetSampler": sampler = samplers.ImbalancedDatasetSampler( train_set.img_items) else: raise ValueError( "Unknown training sampler: {}".format(sampler_name)) iters = cfg.SOLVER.ITERS num_workers = cfg.DATALOADER.NUM_WORKERS batch_sampler = BatchSampler(sampler, mini_batch_size, True) train_loader = IterLoader( DataLoader( Preprocessor(train_set, with_mem_idx), num_workers=num_workers, batch_sampler=batch_sampler, pin_memory=True, ), length=iters, ) # train_loader = DataLoaderX( # comm.get_local_rank(), # dataset=Preprocessor(train_set, with_mem_idx), # num_workers=num_workers, # batch_sampler=batch_sampler, # collate_fn=fast_batch_collator, # pin_memory=True, # ) return train_loader @classmethod def build_test_loader(cls, cfg, test_set): logger = logging.getLogger('fastreid') logger.info("Prepare testing loader") # test_loader = DataLoader( # # Preprocessor(test_set), # test_set, # batch_size=cfg.TEST.IMS_PER_BATCH, # num_workers=cfg.DATALOADER.NUM_WORKERS, # shuffle=False, # pin_memory=True, # ) test_batch_size = cfg.TEST.IMS_PER_BATCH mini_batch_size = test_batch_size // comm.get_world_size() num_workers = cfg.DATALOADER.NUM_WORKERS data_sampler = samplers.InferenceSampler(len(test_set)) batch_sampler = BatchSampler(data_sampler, mini_batch_size, False) test_loader = DataLoaderX( comm.get_local_rank(), dataset=test_set, batch_sampler=batch_sampler, num_workers=num_workers, # save some memory collate_fn=fast_batch_collator, pin_memory=True, ) return test_loader @classmethod def build_evaluator(cls, cfg, dataset_name, output_dir=None): data_loader, num_query = build_reid_test_loader( cfg, dataset_name=dataset_name) return data_loader, ReidEvaluator(cfg, num_query, output_dir) @classmethod def test(cls, cfg, model): """ Args: cfg (CfgNode): model (nn.Module): Returns: dict: a dict of result metrics """ logger = logging.getLogger('fastreid') results = OrderedDict() dataset_name = cfg.DATASETS.TGT logger.info("Prepare testing set") try: data_loader, evaluator = cls.build_evaluator(cfg, dataset_name) except NotImplementedError: logger.warn( "No evaluator found. implement its `build_evaluator` method.") results[dataset_name] = {} results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP.ENABLED) results[dataset_name] = results_i if comm.is_main_process(): assert isinstance( results, dict ), "Evaluator must return a dict on the main process. 
Got {} instead.".format( results) logger.info("Evaluation results for {} in csv format:".format( dataset_name)) results_i['dataset'] = dataset_name print_csv_format(results_i) # if len(results) == 1: # results = list(results.values())[0] return results @classmethod def build_model(cls, cfg, load_model=True, show_model=True, use_dsbn=False): cfg = cfg.clone() # cfg can be modified by model cfg.defrost() cfg.MODEL.DEVICE = "cpu" model = build_model(cfg) logger = logging.getLogger('fastreid') if load_model: pretrain_path = cfg.MODEL.PRETRAIN_PATH try: state_dict = torch.load( pretrain_path, map_location=torch.device("cpu"))['model'] for layer in cfg.MODEL.IGNORE_LAYERS: if layer in state_dict.keys(): del state_dict[layer] logger.info(f"Loading pretrained model from {pretrain_path}") except FileNotFoundError as e: logger.info( f"{pretrain_path} is not found! Please check this path.") raise e except KeyError as e: logger.info( "State dict keys error! Please check the state dict.") raise e incompatible = model.load_state_dict(state_dict, strict=False) if incompatible.missing_keys: logger.info( get_missing_parameters_message(incompatible.missing_keys)) if incompatible.unexpected_keys: logger.info( get_unexpected_parameters_message( incompatible.unexpected_keys)) if use_dsbn: logger.info("==> Convert BN to Domain Specific BN") convert_dsbn(model) if show_model: logger.info("Model:\n{}".format(model)) model.to(torch.device("cuda")) return model @staticmethod def auto_scale_hyperparams(cfg, num_classes): r""" This is used for auto-computation actual training iterations, because some hyper-param, such as MAX_ITER, means training epochs rather than iters, so we need to convert specific hyper-param to training iterations. """ cfg = cfg.clone() frozen = cfg.is_frozen() cfg.defrost() # If you don't hard-code the number of classes, it will compute the number automatically if cfg.MODEL.HEADS.NUM_CLASSES == 0: output_dir = cfg.OUTPUT_DIR cfg.MODEL.HEADS.NUM_CLASSES = num_classes logger = logging.getLogger('fastreid') logger.info( f"Auto-scaling the num_classes={cfg.MODEL.HEADS.NUM_CLASSES}") # Update the saved config file to make the number of classes valid if comm.is_main_process() and output_dir: # Note: some of our scripts may expect the existence of # config.yaml in output directory path = os.path.join(output_dir, "config.yaml") with PathManager.open(path, "w") as f: f.write(cfg.dump()) if frozen: cfg.freeze() return cfg @classmethod def build_optimizer(cls, cfg, model): """ Returns: torch.optim.Optimizer: It now calls :func:`fastreid.solver.build_optimizer`. Overwrite it if you'd like a different optimizer. """ return build_optimizer(cfg, model) @classmethod def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch): """ It now calls :func:`fastreid.solver.build_lr_scheduler`. Overwrite it if you'd like a different scheduler. """ return build_lr_scheduler(cfg, optimizer, iters_per_epoch) def build_hooks(self): """ Build a list of default hooks, including timing, evaluation, checkpointing, lr scheduling, precise BN, writing events. 
Returns: list[HookBase]: """ logger = logging.getLogger(__name__) cfg = self.cfg.clone() cfg.defrost() cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET ]) # set dataset name for PreciseBN ret = [ hooks.IterationTimer(), hooks.LRScheduler(self.optimizer, self.scheduler), ] # if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model): # logger.info("Prepare precise BN dataset") # ret.append(hooks.PreciseBN( # # Run at the same freq as (but before) evaluation. # self.model, # # Build a new data loader to not affect training # self.build_train_loader(cfg), # cfg.TEST.PRECISE_BN.NUM_ITER, # )) if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0: ret.append( hooks.LayerFreeze( self.model, cfg.MODEL.FREEZE_LAYERS, cfg.SOLVER.FREEZE_ITERS, )) # Do PreciseBN before checkpointer, because it updates the model and need to # be saved by checkpointer. # This is not always the best: if checkpointing has a different frequency, # some checkpoints may have more precise statistics than others. def test_and_save_results(): self._last_eval_results = self.test(self.cfg, self.model) return self._last_eval_results # Do evaluation before checkpointer, because then if it fails, # we can use the saved checkpoint to debug. ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) if comm.is_main_process(): ret.append( hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) # run writers in the end, so that evaluation metrics are written ret.append(hooks.PeriodicWriter(self.build_writers(), 200)) return ret def build_writers(self): """ Build a list of writers to be used. By default it contains writers that write metrics to the screen, a json file, and a tensorboard event file respectively. If you'd like a different list of writers, you can overwrite it in your trainer. Returns: list[EventWriter]: a list of :class:`EventWriter` objects. It is now implemented by: .. code-block:: python return [ CommonMetricPrinter(self.max_iter), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), TensorboardXWriter(self.cfg.OUTPUT_DIR), ] """ # Assume the default print/log frequency. # TODO: customize my writers return [ # It may not always print what you want to see, since it prints "common" metrics only. CommonMetricPrinter(self.max_iter), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), TensorboardXWriter(self.cfg.OUTPUT_DIR), ] def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float): """ Args: loss_dict (dict): dict of scalar losses data_time (float): time taken by the dataloader iteration """ device = next(iter(loss_dict.values())).device # Use a new stream so these ops don't wait for DDP or backward with torch.cuda.stream(torch.cuda.Stream() if device.type == "cuda" else None): metrics_dict = { k: v.detach().cpu().item() for k, v in loss_dict.items() } metrics_dict["data_time"] = data_time # Gather metrics among all workers for logging # This assumes we do DDP-style training, which is currently the only # supported method in detectron2. all_metrics_dict = comm.gather(metrics_dict) if comm.is_main_process(): storage = get_event_storage() # data_time among workers can have high variance. The actual latency # caused by data_time is the maximum among workers. 
data_time = np.max([x.pop("data_time") for x in all_metrics_dict]) storage.put_scalar("data_time", data_time) # average the rest metrics metrics_dict = { k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys() } total_losses_reduced = sum(metrics_dict.values()) if not np.isfinite(total_losses_reduced): raise FloatingPointError( f"Loss became infinite or NaN at iteration={self.iter}!\n" f"loss_dict = {metrics_dict}") storage.put_scalar("total_loss", total_losses_reduced) if len(metrics_dict) > 1: storage.put_scalars(**metrics_dict)
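# --- Illustrative sketch (not part of the original code) --------------------
# Condensed version of the pseudo-labeling step in before_epoch above: eps is
# estimated from the smallest RHO-fraction of pairwise distances, DBSCAN runs
# on a precomputed distance matrix, and label -1 marks un-clustered outliers
# that are dropped from the pseudo-labeled training set. The synthetic features
# and rho value below are placeholders; the real code uses Jaccard re-ranking
# distances from compute_jaccard_distance.
import numpy as np
from sklearn.cluster import DBSCAN


def dbscan_pseudo_labels(dist, rho=0.002, min_samples=4):
    tri = np.triu(dist, 1)
    tri = np.sort(tri[np.nonzero(tri)], axis=None)
    top_num = max(1, int(np.round(rho * tri.size)))
    eps = tri[:top_num].mean()                      # adaptive eps from closest pairs
    labels = DBSCAN(eps=eps, min_samples=min_samples,
                    metric="precomputed", n_jobs=-1).fit_predict(dist)
    num_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    num_outliers = int((labels == -1).sum())
    return labels, num_clusters, num_outliers


if __name__ == "__main__":
    feats = np.random.rand(100, 16)
    feats /= np.linalg.norm(feats, axis=1, keepdims=True)
    dist = 1.0 - feats @ feats.T                    # stand-in for the Jaccard distance
    np.fill_diagonal(dist, 0.0)
    print(dbscan_pseudo_labels(dist, rho=0.05))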
class Trainer:
    def __init__(self, cfg):
        self.cfg = cfg
        self.paths = cfg['paths']
        self.net_params = cfg['net']
        self.train_params = cfg['train']
        self.trans_params = cfg['train']['transforms']

        self.checkpoints = self.paths['checkpoints']
        Path(self.checkpoints).mkdir(parents=True, exist_ok=True)
        shutil.copyfile('config.yaml', f'{self.checkpoints}/config.yaml')
        self.update_interval = self.paths['update_interval']

        # amp training
        self.use_amp = self.train_params['mixed_precision']
        self.scaler = GradScaler() if self.use_amp else None

        # data setup
        dataset_name = self.train_params['dataset']
        self.use_multi = dataset_name == 'multi'
        print(f'Using dataset: {dataset_name}')
        self.train_dataset = get_pedestrian_dataset(
            dataset_name,
            self.paths,
            augment=get_train_transforms(self.trans_params),
            mode='train',
            multi_datasets=self.train_params['multi_datasets']
            if self.use_multi else None)
        print(f'Train dataset: {len(self.train_dataset)} samples')
        self.val_dataset = get_pedestrian_dataset(
            dataset_name,
            self.paths,
            augment=get_val_transforms(self.trans_params),
            mode='val',
            multi_datasets=self.train_params['multi_datasets']
            if self.use_multi else None)
        print(f'Val dataset: {len(self.val_dataset)} samples')

        tests_data = self.train_params['test_datasets']
        self.test_datasets = [
            get_pedestrian_dataset(d_name,
                                   self.paths,
                                   augment=get_test_transforms(self.trans_params),
                                   mode='test') for d_name in tests_data
        ]

        self.criterion = AnchorFreeLoss(self.train_params)

        self.writer = Writer(self.paths['log_dir'])
        print('Tensorboard logs are saved to: {}'.format(self.paths['log_dir']))

        self.sched_type = self.train_params['scheduler']
        self.scheduler = None
        self.optimizer = None

    def save_checkpoints(self, epoch, net):
        path = osp.join(self.checkpoints, f'Epoch_{epoch}.pth')
        torch.save(
            {
                'epoch': epoch,
                'net_state_dict': net.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            }, path)

    def train(self):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

        batch_size = self.train_params['batch_size']
        self.batch_size = batch_size
        num_workers = self.train_params['num_workers']
        pin_memory = self.train_params['pin_memory']
        print('Batch-size = {}'.format(batch_size))
        train_loader = DataLoader(self.train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory,
                                  drop_last=True)
        val_loader = DataLoader(self.val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                drop_last=False)

        # net setup
        print('Preparing net: ')
        net = get_fpn_net(self.net_params)

        # train setup
        lr = self.train_params['lr']
        epochs = self.train_params['epochs']
        weight_decay = self.train_params['weight_decay']
        self.optimizer = optim.Adam(net.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    eps=1e-4)
        if self.net_params['pretrained']:
            checkpoint = torch.load(self.net_params['pretrained_model'],
                                    map_location="cuda")
            net.load_state_dict(checkpoint['net_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr
            # move optimizer state loaded onto CPU back to the GPU
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()
            print('CHECKPOINT LOADED')
        net.cuda()

        first_epoch = 0
        # scheduler
        if self.sched_type == 'ocp':
            last_epoch = -1 if first_epoch == 0 else first_epoch * len(train_loader)
            self.scheduler = OneCycleLR(
                self.optimizer,
                max_lr=lr,
                epochs=epochs,
                last_epoch=last_epoch,
                steps_per_epoch=len(train_loader),
                pct_start=self.train_params['ocp_params']['max_lr_pct'])
        elif self.sched_type == 'multi_step':
            last_epoch = -1 if first_epoch == 0 else first_epoch
            self.scheduler = MultiStepLR(
                self.optimizer,
                milestones=self.train_params['multi_params']['milestones'],
                gamma=self.train_params['multi_params']['gamma'],
                last_epoch=last_epoch)

        # start training
        net.train()
        val_rate = self.train_params['val_rate']
        test_rate = self.train_params['test_rate']
        for epoch in range(first_epoch, epochs):
            self.train_epoch(net, train_loader, epoch)
            if self.sched_type != 'ocp':
                # OneCycleLR is stepped per mini-batch inside train_epoch;
                # other schedulers are stepped once per epoch here.
                self.writer.log_lr(epoch, self.scheduler.get_last_lr()[0])
                self.scheduler.step()

            if (epoch + 1) % val_rate == 0 or epoch == epochs - 1:
                self.eval(net, val_loader, epoch * len(train_loader))
            if (epoch + 1) % (val_rate * test_rate) == 0 or epoch == epochs - 1:
                self.test_ap(net, epoch)
                self.save_checkpoints(epoch, net)

    def train_epoch(self, net, loader, epoch):
        net.train()
        loss_metric = LossMetric(self.cfg)
        probs = ProbsAverageMeter()

        for mini_batch_i, read_mini_batch in tqdm(enumerate(loader),
                                                  desc=f'Epoch {epoch}:',
                                                  ascii=True,
                                                  total=len(loader)):
            data, labels = read_mini_batch
            data = data.cuda()
            labels = [label.cuda() for label in labels]

            # autocast and the scaler are only active when mixed precision is enabled
            with amp.autocast(enabled=self.use_amp):
                out = net(data)
                loss_dict, hm_probs = self.criterion(out, labels)
                loss = loss_metric.calculate_loss(loss_dict)

            self.optimizer.zero_grad()
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()
            probs.update(hm_probs)

            if self.sched_type == 'ocp':
                self.scheduler.step()
            loss_metric.add_sample(loss_dict)

            if mini_batch_i % self.update_interval == 0:
                if self.sched_type == 'ocp':
                    # TODO write average lr
                    self.writer.log_lr(epoch * len(loader) + mini_batch_i,
                                       self.scheduler.get_last_lr()[0])
                self.writer.log_training(epoch * len(loader) + mini_batch_i,
                                         loss_metric)

        self.writer.log_probs(epoch, probs.get_average())

    def eval(self, net, loader, step):
        net.eval()
        loss_metric = LossMetric(self.cfg)
        with torch.no_grad():
            for _, read_mini_batch in tqdm(enumerate(loader),
                                           desc='Val:',
                                           ascii=True,
                                           total=len(loader)):
                data, labels = read_mini_batch
                data = data.cuda()
                labels = [label.cuda() for label in labels]

                with amp.autocast(enabled=self.use_amp):
                    out = net(data)
                    loss_dict, _ = self.criterion(out, labels)
                loss_metric.add_sample(loss_dict)
        self.writer.log_eval(step, loss_metric)

    def test_ap(self, net, epoch):
        for dataset in self.test_datasets:
            ap, _ = test(net, dataset, batch_size=self.batch_size)
            self.writer.log_ap(epoch, ap, dataset.name())
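# --- Illustrative sketch (not part of the original code) --------------------
# The Trainer above steps OneCycleLR once per mini-batch (inside train_epoch)
# and MultiStepLR once per epoch (inside train). The skeleton below isolates
# that stepping cadence; the toy model, loader sizes, and milestones are
# placeholders, not values from this repo.
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import MultiStepLR, OneCycleLR


def build_scheduler(optimizer, sched_type, steps_per_epoch, epochs, max_lr):
    if sched_type == 'ocp':
        return OneCycleLR(optimizer, max_lr=max_lr, epochs=epochs,
                          steps_per_epoch=steps_per_epoch, pct_start=0.3)
    return MultiStepLR(optimizer,
                       milestones=[int(epochs * 0.6), int(epochs * 0.8)],
                       gamma=0.1)


def run(sched_type='ocp', epochs=3, steps_per_epoch=5):
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = build_scheduler(optimizer, sched_type, steps_per_epoch, epochs,
                                max_lr=1e-3)
    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            optimizer.step()                 # real code: forward/backward first
            if sched_type == 'ocp':
                scheduler.step()             # per-batch step for OneCycleLR
        if sched_type != 'ocp':
            scheduler.step()                 # per-epoch step for MultiStepLR
        print(epoch, scheduler.get_last_lr()[0])


if __name__ == "__main__":
    run()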
class SpCL_UDA_Trainer(TrainerBase): """ load an un-pretrained model and train on the source & target domain from scratch """ def __init__(self, cfg): super().__init__() logger = logging.getLogger("fastreid") if not logger.isEnabledFor( logging.INFO): # if setup_logger is not called for fastreid setup_logger() # Create datasets logger.info("==> Load source-domain dataset") self.src = src = self.load_dataset(cfg.DATASETS.SRC) self.src_pid_nums = src.get_num_pids(src.train) logger.info("==> Load target-domain dataset") self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT) self.tgt_nums = len(tgt.train) # Create model self.model = self.build_model(cfg, load_model=False, show_model=False, use_dsbn=True) # Create hybrid memorys self.hm = HybridMemory(num_features=cfg.MODEL.BACKBONE.FEAT_DIM, num_samples=self.src_pid_nums + self.tgt_nums, temp=cfg.MEMORY.TEMP, momentum=cfg.MEMORY.MOMENTUM, use_half=cfg.SOLVER.AMP.ENABLED).cuda() # Initialize source-domain class centroids logger.info( "==> Initialize source-domain class centroids in the hybrid memory" ) with inference_context(self.model), torch.no_grad(): src_train = self.build_dataset(cfg, src.train, is_train=False, relabel=False, with_mem_idx=False) src_init_feat_loader = self.build_test_loader(cfg, src_train) src_fname_feat_dict, _ = extract_features(self.model, src_init_feat_loader) src_feat_dict = collections.defaultdict(list) for f, pid, _ in sorted(src.train): src_feat_dict[pid].append(src_fname_feat_dict[f].unsqueeze(0)) src_centers = [ torch.cat(src_feat_dict[pid], 0).mean(0) for pid in sorted(src_feat_dict.keys()) ] src_centers = torch.stack(src_centers, 0) src_centers = F.normalize(src_centers, dim=1) # Initialize target-domain instance features logger.info( "==> Initialize target-domain instance features in the hybrid memory" ) with inference_context(self.model), torch.no_grad(): tgt_train = self.build_dataset(cfg, tgt.train, is_train=False, relabel=False, with_mem_idx=False) tgt_init_feat_loader = self.build_test_loader(cfg, tgt_train) tgt_fname_feat_dict, _ = extract_features(self.model, tgt_init_feat_loader) tgt_features = torch.cat([ tgt_fname_feat_dict[f].unsqueeze(0) for f, _, _ in sorted(self.tgt.train) ], 0) tgt_features = F.normalize(tgt_features, dim=1) self.hm.features = torch.cat((src_centers, tgt_features), dim=0).cuda() del (src_train, src_init_feat_loader, src_fname_feat_dict, src_feat_dict, src_centers, tgt_train, tgt_init_feat_loader, tgt_fname_feat_dict, tgt_features) # Optimizer self.optimizer, self.param_wrapper = self.build_optimizer( cfg, self.model) # For training, wrap with DDP. But don't need this for inference. if comm.get_world_size() > 1: # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True` # for part of the parameters is not updated. self.model = DistributedDataParallel( self.model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True) # Learning rate scheduler self.iters_per_epoch = cfg.SOLVER.ITERS self.scheduler = self.build_lr_scheduler(cfg, self.optimizer, self.iters_per_epoch) # Assume no other objects need to be checkpointed. 
# We can later make it checkpoint the stateful hooks self.checkpointer = Checkpointer( # Assume you want to save checkpoints together with logs/statistics self.model, cfg.OUTPUT_DIR, save_to_disk=comm.is_main_process(), optimizer=self.optimizer, **self.scheduler, ) self.start_epoch = 0 self.max_epoch = cfg.SOLVER.MAX_EPOCH self.max_iter = self.max_epoch * self.iters_per_epoch self.warmup_iters = cfg.SOLVER.WARMUP_ITERS self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS self.cfg = cfg self.register_hooks(self.build_hooks()) if cfg.SOLVER.AMP.ENABLED: unsupported = "AMPTrainer does not support single-process multi-device training!" if isinstance(self.model, DistributedDataParallel): assert not (self.model.device_ids and len(self.model.device_ids) > 1), unsupported from torch.cuda.amp.grad_scaler import GradScaler self.grad_scaler = GradScaler() else: self.grad_scaler = None def train(self): """ Run training. Returns: OrderedDict of results, if evaluation is enabled. Otherwise None. """ super().train(self.start_epoch, self.max_epoch, self.iters_per_epoch) if comm.is_main_process(): assert hasattr(self, "_last_eval_results" ), "No evaluation results obtained during training!" return self._last_eval_results def before_train(self): self.model.train() if self.cfg.SOLVER.AMP.ENABLED: assert torch.cuda.is_available( ), "CUDA is required for AMP training!" return super().before_train() def before_epoch(self): logger = logging.getLogger('fastreid') # Calculate distance logger.info( "==> Create pseudo labels for unlabeled target domain with self-paced policy" ) tgt_features = self.hm.features[self.src_pid_nums:].clone() rerank_dist = compute_jaccard_distance(tgt_features, k1=self.cfg.CLUSTER.JACCARD.K1, k2=self.cfg.CLUSTER.JACCARD.K2) del tgt_features if self.epoch == 0: if self.cfg.CLUSTER.DBSCAN.ADAPTIVE_EPS: logger.info("==> Calculating eps according to rerank_dist...") tri_mat = np.triu(rerank_dist, 1) # tri_mat.dim=2 tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1 tri_mat = np.sort(tri_mat, axis=None) top_num = np.round(self.cfg.SOLVER.RHO * tri_mat.size).astype(int) self.eps = tri_mat[:top_num].mean() logger.info(f"==> epoch {self.epoch} eps: {self.eps}") else: self.eps = self.cfg.CLUSTER.DBSCAN.EPS self.eps_tight = self.eps - self.cfg.CLUSTER.DBSCAN.EPS_GAP self.eps_loose = self.eps + self.cfg.CLUSTER.DBSCAN.EPS_GAP self.cluster = DBSCAN(eps=self.eps, min_samples=4, metric="precomputed", n_jobs=-1) self.cluster_tight = DBSCAN(eps=self.eps_tight, min_samples=4, metric="precomputed", n_jobs=-1) self.cluster_loose = DBSCAN(eps=self.eps_loose, min_samples=4, metric="precomputed", n_jobs=-1) # select & cluster images as training set of this epochs pseudo_labels = self.cluster.fit_predict(rerank_dist) pseudo_labels_tight = self.cluster_tight.fit_predict(rerank_dist) pseudo_labels_loose = self.cluster_loose.fit_predict(rerank_dist) num_ids = len(set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0) num_ids_tight = len( set(pseudo_labels_tight)) - (1 if -1 in pseudo_labels_tight else 0) num_ids_loose = len( set(pseudo_labels_loose)) - (1 if -1 in pseudo_labels_loose else 0) pseudo_labels = self.generate_pseudo_labels(pseudo_labels, num_ids) pseudo_labels_tight = self.generate_pseudo_labels( pseudo_labels_tight, num_ids_tight) pseudo_labels_loose = self.generate_pseudo_labels( pseudo_labels_loose, num_ids_loose) # print(pseudo_labels.min(), pseudo_labels.max()) # exit() # compute R_indep and R_comp N = pseudo_labels.size(0) label_sim = (pseudo_labels.expand(N, N).eq( pseudo_labels.expand(N, 
N).t()).float()) # [N, N] label_sim_tight = (pseudo_labels_tight.expand(N, N).eq( pseudo_labels_tight.expand(N, N).t()).float()) label_sim_loose = (pseudo_labels_loose.expand(N, N).eq( pseudo_labels_loose.expand(N, N).t()).float()) R_comp = 1 - torch.min(label_sim, label_sim_tight).sum(-1) / torch.max( label_sim, label_sim_tight).sum(-1) # [N] R_indep = 1 - torch.min(label_sim, label_sim_loose).sum( -1) / torch.max(label_sim, label_sim_loose).sum(-1) # [N] assert (R_comp.min() >= 0) and (R_comp.max() <= 1) assert (R_indep.min() >= 0) and (R_indep.max() <= 1) cluster_R_comp, cluster_R_indep = ( collections.defaultdict(list), collections.defaultdict(list), ) cluster_img_num = collections.defaultdict(int) for i, (comp, indep, label) in enumerate(zip(R_comp, R_indep, pseudo_labels)): cluster_R_comp[label.item() - self.src_pid_nums].append( comp.item()) cluster_R_indep[label.item() - self.src_pid_nums].append( indep.item()) cluster_img_num[label.item() - self.src_pid_nums] += 1 cluster_R_comp = [ min(cluster_R_comp[i]) for i in sorted(cluster_R_comp.keys()) ] cluster_R_indep = [ min(cluster_R_indep[i]) for i in sorted(cluster_R_indep.keys()) ] cluster_R_indep_noins = [ iou for iou, num in zip(cluster_R_indep, sorted(cluster_img_num.keys())) if cluster_img_num[num] > 1 ] if self.epoch <= self.start_epoch: """ constant threshold α for identifying independent clusters is defined by the top-90% Rindep before the first epoch and remains the same for all the training process """ logger.info("==> calculate independ before first epoch") self.indep_thres = np.sort(cluster_R_indep_noins)[min( len(cluster_R_indep_noins) - 1, np.round(len(cluster_R_indep_noins) * 0.9).astype("int"), )] pseudo_labeled_dataset = [] outliers = 0 for i, ((fname, _, cid), label) in enumerate(zip(sorted(self.tgt.train), pseudo_labels)): indep_score = cluster_R_indep[label.item() - self.src_pid_nums] comp_score = R_comp[i] if (indep_score <= self.indep_thres) and ( comp_score.item() <= cluster_R_comp[label.item() - self.src_pid_nums]): pseudo_labeled_dataset.append((fname, label.item(), cid)) else: pseudo_label = self.src_pid_nums + len( cluster_R_indep) + outliers pseudo_labeled_dataset.append((fname, pseudo_label, cid)) pseudo_labels[i] = pseudo_label outliers += 1 # statistics of clusters and un-clustered instances index2label = collections.defaultdict(int) for label in pseudo_labels: index2label[label.item()] += 1 cluster_label = [] outlier_label = [] for k, v in index2label.items(): if v == 1: outlier_label.append(k) else: cluster_label.append(k) print(f'cluster_label', min(cluster_label), max(cluster_label), len(cluster_label)) print(f'outlier label', min(outlier_label), max(outlier_label), len(outlier_label)) index2label = np.fromiter(index2label.values(), dtype=float) logger.info( "==> Statistics for epoch {}: {} clusters, {} un-clustered instances, R_indep threshold is {}" .format( self.epoch, (index2label > 1).sum(), (index2label == 1).sum(), 1 - self.indep_thres, )) self.hm.labels = torch.cat( (torch.arange(self.src_pid_nums), pseudo_labels)).cuda() src_train = self.build_dataset( self.cfg, self.src.train, is_train=True, relabel=True, # relabel? 
# relabel=False, with_mem_idx=True) self.src_train_loader = self.build_train_loader( self.cfg, train_set=src_train, sampler=RandomMultipleGallerySampler( src_train.img_items, self.cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(), self.cfg.DATALOADER.NUM_INSTANCE, with_mem_idx=True), with_mem_idx=True) # self.src_load_iter = iter(self.src_train_loader) pseudo_tgt_train = self.build_dataset( self.cfg, pseudo_labeled_dataset, is_train=True, relabel=True, # relabel? # relabel=False, with_mem_idx=True) self.pseudo_tgt_train_loader = self.build_train_loader( self.cfg, train_set=pseudo_tgt_train, sampler=RandomMultipleGallerySampler( pseudo_tgt_train.img_items, self.cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(), self.cfg.DATALOADER.NUM_INSTANCE, with_mem_idx=True), with_mem_idx=True) # self.tgt_load_iter = iter(self.pseudo_tgt_train_loader) return super().before_epoch() # generate new dataset and calculate cluster centers def generate_pseudo_labels(self, cluster_id, num): labels = [] outliers = 0 for i, ((fname, _, cid), id) in enumerate(zip(sorted(self.tgt.train), cluster_id)): if id != -1: labels.append(self.src_pid_nums + id) else: labels.append(self.src_pid_nums + num + outliers) outliers += 1 return torch.Tensor(labels).long() def run_step(self): assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!" if self.cfg.SOLVER.AMP.ENABLED: assert torch.cuda.is_available( ), f"[{self.__class__.__name__}] CUDA is required for AMP training!" from torch.cuda.amp.autocast_mode import autocast start = time.perf_counter() # load data src_inputs = self.src_train_loader.next() tgt_inputs = self.pseudo_tgt_train_loader.next() # src_inputs = next(self.src_load_iter) # tgt_inputs = next(self.tgt_load_iter) def _parse_data(inputs): # print(len(inputs)) # for i in range(len(inputs)): # print(i, type(inputs[i]), inputs[i]) imgs, _, pids, _, indices = inputs return imgs.cuda(), pids.cuda(), indices # process inputs s_inputs, s_targets, s_indices = _parse_data(src_inputs) t_inputs, t_targets, t_indices = _parse_data(tgt_inputs) # print('src', s_targets, s_indices) # print('tgt', t_targets, t_indices) # exit() # arrange batch for domain-specific BNP device_num = torch.cuda.device_count() B, C, H, W = s_inputs.size() def reshape(inputs): return inputs.view(device_num, -1, C, H, W) s_inputs, t_inputs = reshape(s_inputs), reshape(t_inputs) inputs = torch.cat((s_inputs, t_inputs), 1).view(-1, C, H, W) data_time = time.perf_counter() - start def _forward(): outputs = self.model(inputs) if isinstance(outputs, dict): f_out = outputs['features'] else: f_out = outputs # de-arrange batch f_out = f_out.view(device_num, -1, f_out.size(-1)) f_out_s, f_out_t = f_out.split(f_out.size(1) // 2, dim=1) f_out_s, f_out_t = f_out_s.contiguous().view( -1, f_out.size(-1)), f_out_t.contiguous().view(-1, f_out.size(-1)) # compute loss with the hybrid memory # with autocast(enabled=False): loss_s = self.hm(f_out_s, s_targets) loss_t = self.hm(f_out_t, t_indices + self.src_pid_nums) loss_dict = {'loss_s': loss_s, 'loss_t': loss_t} return loss_dict if self.cfg.SOLVER.AMP.ENABLED: with autocast(): loss_dict = _forward() losses = sum(loss_dict.values()) self.optimizer.zero_grad() self.grad_scaler.scale(losses).backward() self._write_metrics(loss_dict, data_time) self.grad_scaler.step(self.optimizer) self.grad_scaler.update() else: loss_dict = _forward() losses = sum(loss_dict.values()) self.optimizer.zero_grad() losses.backward() self._write_metrics(loss_dict, data_time) self.optimizer.step() if 
isinstance(self.param_wrapper, ContiguousParams): self.param_wrapper.assert_buffer_is_valid() @classmethod def load_dataset(cls, name): logger = logging.getLogger(__name__) logger.info(f"Preparing {name}") _root = os.getenv("FASTREID_DATASETS", "/root/datasets") data = DATASET_REGISTRY.get(name)(root=_root) if comm.is_main_process(): data.show_train() return data @classmethod def build_dataset(cls, cfg, img_items, is_train=False, relabel=False, transforms=None, with_mem_idx=False): if transforms is None: transforms = build_transforms(cfg, is_train=is_train) if with_mem_idx: sorted_img_items = sorted(img_items) for i in range(len(sorted_img_items)): sorted_img_items[i] += (i, ) return InMemoryDataset(sorted_img_items, transforms, relabel) else: return CommDataset(img_items, transforms, relabel) @classmethod def build_train_loader(cls, cfg, train_set=None, sampler=None, with_mem_idx=False): logger = logging.getLogger('fastreid') logger.info("Prepare training loader") total_batch_size = cfg.SOLVER.IMS_PER_BATCH mini_batch_size = total_batch_size // comm.get_world_size() if sampler is None: num_instance = cfg.DATALOADER.NUM_INSTANCE sampler_name = cfg.DATALOADER.SAMPLER_TRAIN logger.info("Using training sampler {}".format(sampler_name)) if sampler_name == "TrainingSampler": sampler = samplers.TrainingSampler(len(train_set)) elif sampler_name == "NaiveIdentitySampler": sampler = samplers.NaiveIdentitySampler( train_set.img_items, mini_batch_size, num_instance) elif sampler_name == "BalancedIdentitySampler": sampler = samplers.BalancedIdentitySampler( train_set.img_items, mini_batch_size, num_instance) elif sampler_name == "SetReWeightSampler": set_weight = cfg.DATALOADER.SET_WEIGHT sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight) elif sampler_name == "ImbalancedDatasetSampler": sampler = samplers.ImbalancedDatasetSampler( train_set.img_items) else: raise ValueError( "Unknown training sampler: {}".format(sampler_name)) iters = cfg.SOLVER.ITERS num_workers = cfg.DATALOADER.NUM_WORKERS batch_sampler = BatchSampler(sampler, mini_batch_size, True) train_loader = IterLoader( DataLoader( Preprocessor(train_set, with_mem_idx), num_workers=num_workers, batch_sampler=batch_sampler, pin_memory=True, ), length=iters, ) # train_loader = DataLoaderX( # comm.get_local_rank(), # dataset=Preprocessor(train_set, with_mem_idx), # num_workers=num_workers, # batch_sampler=batch_sampler, # collate_fn=fast_batch_collator, # pin_memory=True, # ) return train_loader @classmethod def build_test_loader(cls, cfg, test_set): logger = logging.getLogger('fastreid') logger.info("Prepare testing loader") # test_loader = DataLoader( # # Preprocessor(test_set), # test_set, # batch_size=cfg.TEST.IMS_PER_BATCH, # num_workers=cfg.DATALOADER.NUM_WORKERS, # shuffle=False, # pin_memory=True, # ) test_batch_size = cfg.TEST.IMS_PER_BATCH mini_batch_size = test_batch_size // comm.get_world_size() num_workers = cfg.DATALOADER.NUM_WORKERS data_sampler = samplers.InferenceSampler(len(test_set)) batch_sampler = BatchSampler(data_sampler, mini_batch_size, False) test_loader = DataLoaderX( comm.get_local_rank(), dataset=test_set, batch_sampler=batch_sampler, num_workers=num_workers, # save some memory collate_fn=fast_batch_collator, pin_memory=True, ) return test_loader @classmethod def build_evaluator(cls, cfg, dataset_name, output_dir=None): data_loader, num_query = build_reid_test_loader( cfg, dataset_name=dataset_name) return data_loader, ReidEvaluator(cfg, num_query, output_dir) 
@classmethod def test(cls, cfg, model): """ Args: cfg (CfgNode): model (nn.Module): Returns: dict: a dict of result metrics """ logger = logging.getLogger('fastreid') results = OrderedDict() dataset_name = cfg.DATASETS.TGT logger.info("Prepare testing set") try: data_loader, evaluator = cls.build_evaluator(cfg, dataset_name) except NotImplementedError: logger.warn( "No evaluator found. implement its `build_evaluator` method.") results[dataset_name] = {} results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP.ENABLED) results[dataset_name] = results_i if comm.is_main_process(): assert isinstance( results, dict ), "Evaluator must return a dict on the main process. Got {} instead.".format( results) logger.info("Evaluation results for {} in csv format:".format( dataset_name)) results_i['dataset'] = dataset_name print_csv_format(results_i) # if len(results) == 1: # results = list(results.values())[0] return results @classmethod def build_model(cls, cfg, load_model=True, show_model=True, use_dsbn=False): cfg = cfg.clone() # cfg can be modified by model cfg.defrost() cfg.MODEL.DEVICE = "cpu" model = build_model(cfg) logger = logging.getLogger('fastreid') if load_model: pretrain_path = cfg.MODEL.PRETRAIN_PATH try: state_dict = torch.load( pretrain_path, map_location=torch.device("cpu"))['model'] for layer in cfg.MODEL.IGNORE_LAYERS: if layer in state_dict.keys(): del state_dict[layer] logger.info(f"Loading pretrained model from {pretrain_path}") except FileNotFoundError as e: logger.info( f"{pretrain_path} is not found! Please check this path.") raise e except KeyError as e: logger.info( "State dict keys error! Please check the state dict.") raise e incompatible = model.load_state_dict(state_dict, strict=False) if incompatible.missing_keys: logger.info( get_missing_parameters_message(incompatible.missing_keys)) if incompatible.unexpected_keys: logger.info( get_unexpected_parameters_message( incompatible.unexpected_keys)) if use_dsbn: logger.info("==> Convert BN to Domain Specific BN") convert_dsbn(model) if show_model: logger.info("Model:\n{}".format(model)) model.to(torch.device("cuda")) return model @classmethod def build_optimizer(cls, cfg, model): """ Returns: torch.optim.Optimizer: It now calls :func:`fastreid.solver.build_optimizer`. Overwrite it if you'd like a different optimizer. """ return build_optimizer(cfg, model) @classmethod def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch): """ It now calls :func:`fastreid.solver.build_lr_scheduler`. Overwrite it if you'd like a different scheduler. """ return build_lr_scheduler(cfg, optimizer, iters_per_epoch) def build_hooks(self): """ Build a list of default hooks, including timing, evaluation, checkpointing, lr scheduling, precise BN, writing events. Returns: list[HookBase]: """ logger = logging.getLogger(__name__) cfg = self.cfg.clone() cfg.defrost() cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET ]) # set dataset name for PreciseBN ret = [ hooks.IterationTimer(), hooks.LRScheduler(self.optimizer, self.scheduler), ] # if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model): # logger.info("Prepare precise BN dataset") # ret.append(hooks.PreciseBN( # # Run at the same freq as (but before) evaluation. 
# self.model, # # Build a new data loader to not affect training # self.build_train_loader(cfg), # cfg.TEST.PRECISE_BN.NUM_ITER, # )) if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0: ret.append( hooks.LayerFreeze( self.model, cfg.MODEL.FREEZE_LAYERS, cfg.SOLVER.FREEZE_ITERS, )) # Do PreciseBN before checkpointer, because it updates the model and need to # be saved by checkpointer. # This is not always the best: if checkpointing has a different frequency, # some checkpoints may have more precise statistics than others. def test_and_save_results(): self._last_eval_results = self.test(self.cfg, self.model) return self._last_eval_results # Do evaluation before checkpointer, because then if it fails, # we can use the saved checkpoint to debug. ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) if comm.is_main_process(): ret.append( hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) # run writers in the end, so that evaluation metrics are written ret.append(hooks.PeriodicWriter(self.build_writers(), 200)) return ret def build_writers(self): """ Build a list of writers to be used. By default it contains writers that write metrics to the screen, a json file, and a tensorboard event file respectively. If you'd like a different list of writers, you can overwrite it in your trainer. Returns: list[EventWriter]: a list of :class:`EventWriter` objects. It is now implemented by: .. code-block:: python return [ CommonMetricPrinter(self.max_iter), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), TensorboardXWriter(self.cfg.OUTPUT_DIR), ] """ # Assume the default print/log frequency. # TODO: customize my writers return [ # It may not always print what you want to see, since it prints "common" metrics only. CommonMetricPrinter(self.max_iter), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), TensorboardXWriter(self.cfg.OUTPUT_DIR), ] def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float): """ Args: loss_dict (dict): dict of scalar losses data_time (float): time taken by the dataloader iteration """ device = next(iter(loss_dict.values())).device # Use a new stream so these ops don't wait for DDP or backward with torch.cuda.stream(torch.cuda.Stream() if device.type == "cuda" else None): metrics_dict = { k: v.detach().cpu().item() for k, v in loss_dict.items() } metrics_dict["data_time"] = data_time # Gather metrics among all workers for logging # This assumes we do DDP-style training, which is currently the only # supported method in detectron2. all_metrics_dict = comm.gather(metrics_dict) if comm.is_main_process(): storage = get_event_storage() # data_time among workers can have high variance. The actual latency # caused by data_time is the maximum among workers. data_time = np.max([x.pop("data_time") for x in all_metrics_dict]) storage.put_scalar("data_time", data_time) # average the rest metrics metrics_dict = { k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys() } total_losses_reduced = sum(metrics_dict.values()) if not np.isfinite(total_losses_reduced): raise FloatingPointError( f"Loss became infinite or NaN at iteration={self.iter}!\n" f"loss_dict = {metrics_dict}") storage.put_scalar("total_loss", total_losses_reduced) if len(metrics_dict) > 1: storage.put_scalars(**metrics_dict)
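# --- Illustrative sketch (not part of the original code) --------------------
# The self-paced criterion in SpCL_UDA_Trainer.before_epoch compares the
# clustering at eps with tighter/looser variants through a per-sample IoU of
# label-similarity matrices: R = 1 - |intersection| / |union| over each row.
# This standalone helper reproduces that computation; the label tensors in the
# demo are toy inputs, not real pseudo labels.
import torch


def reliability(labels_a: torch.Tensor, labels_b: torch.Tensor) -> torch.Tensor:
    """Per-sample dissimilarity between two label assignments (0 = identical)."""
    n = labels_a.size(0)
    sim_a = labels_a.expand(n, n).eq(labels_a.expand(n, n).t()).float()
    sim_b = labels_b.expand(n, n).eq(labels_b.expand(n, n).t()).float()
    return 1 - torch.min(sim_a, sim_b).sum(-1) / torch.max(sim_a, sim_b).sum(-1)


if __name__ == "__main__":
    base = torch.tensor([0, 0, 1, 1, 2])
    tight = torch.tensor([0, 3, 1, 1, 2])   # one sample split off its cluster
    loose = torch.tensor([0, 0, 1, 1, 1])   # two clusters merged
    r_comp = reliability(base, tight)        # compactness w.r.t. tighter eps
    r_indep = reliability(base, loose)       # independence w.r.t. looser eps
    print(r_comp, r_indep)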
class MyMixedDefaultTrainer(TrainerBase): """w/o AMP mode""" def __init__(self, cfg): """ Args: cfg (CfgNode): """ super().__init__() logger = logging.getLogger("fastreid") if not logger.isEnabledFor( logging.INFO): # setup_logger is not called for fastreid setup_logger() # Assume these objects must be constructed in this order. data_loader = self.build_train_loader(cfg) cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes) self.model = self.build_model(cfg) self.optimizer, self.param_wrapper = self.build_optimizer( cfg, self.model) # For training, wrap with DDP. But don't need this for inference. if comm.get_world_size() > 1: # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True` # for part of the parameters is not updated. self.model = DistributedDataParallel( self.model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, ) self._data_loader_iter = iter(data_loader) self.iters_per_epoch = len( data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH self.scheduler = self.build_lr_scheduler(cfg, self.optimizer, self.iters_per_epoch) # Assume no other objects need to be checkpointed. # We can later make it checkpoint the stateful hooks self.checkpointer = Checkpointer( # Assume you want to save checkpoints together with logs/statistics self.model, cfg.OUTPUT_DIR, save_to_disk=comm.is_main_process(), optimizer=self.optimizer, **self.scheduler, ) self.start_epoch = 0 self.max_epoch = cfg.SOLVER.MAX_EPOCH self.max_iter = self.max_epoch * self.iters_per_epoch self.warmup_iters = cfg.SOLVER.WARMUP_ITERS self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS self.cfg = cfg self.register_hooks(self.build_hooks()) if cfg.SOLVER.AMP.ENABLED: unsupported = f"[{self.__class__.__name__}] does not support single-process multi-device training!" if isinstance(self.model, DistributedDataParallel): assert not (self.model.device_ids and len(self.model.device_ids) > 1), unsupported from torch.cuda.amp.grad_scaler import GradScaler self.grad_scaler = GradScaler() else: self.grad_scaler = None def resume_or_load(self, resume=True): """ If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by a `last_checkpoint` file), resume from the file. Resuming means loading all available states (eg. optimizer and scheduler) and update iteration counter from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used. Otherwise, this is considered as an independent training. The method will load model weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start from iteration 0. Args: resume (bool): whether to do resume or not """ # The checkpoint stores the training iteration that just finished, thus we start # at the next iteration (or iter zero if there's no checkpoint). checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume) if resume and self.checkpointer.has_checkpoint(): self.start_epoch = checkpoint.get("epoch", -1) + 1 # The checkpoint stores the training iteration that just finished, thus we start # at the next iteration (or iter zero if there's no checkpoint). def build_hooks(self): """ Build a list of default hooks, including timing, evaluation, checkpointing, lr scheduling, precise BN, writing events. 
Returns: list[HookBase]: """ logger = logging.getLogger(__name__) cfg = self.cfg.clone() cfg.defrost() cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET ]) # set dataset name for PreciseBN ret = [ hooks.IterationTimer(), hooks.LRScheduler(self.optimizer, self.scheduler), ] if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model): logger.info("Prepare precise BN dataset") ret.append( hooks.PreciseBN( # Run at the same freq as (but before) evaluation. self.model, # Build a new data loader to not affect training self.build_train_loader(cfg), cfg.TEST.PRECISE_BN.NUM_ITER, )) if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0: ret.append( hooks.LayerFreeze( self.model, cfg.MODEL.FREEZE_LAYERS, cfg.SOLVER.FREEZE_ITERS, )) # Do PreciseBN before checkpointer, because it updates the model and need to # be saved by checkpointer. # This is not always the best: if checkpointing has a different frequency, # some checkpoints may have more precise statistics than others. def test_and_save_results(): self._last_eval_results = self.test(self.cfg, self.model) return self._last_eval_results # Do evaluation before checkpointer, because then if it fails, # we can use the saved checkpoint to debug. ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) if comm.is_main_process(): ret.append( hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) # run writers in the end, so that evaluation metrics are written ret.append(hooks.PeriodicWriter(self.build_writers(), 200)) return ret def build_writers(self): """ Build a list of writers to be used. By default it contains writers that write metrics to the screen, a json file, and a tensorboard event file respectively. If you'd like a different list of writers, you can overwrite it in your trainer. Returns: list[EventWriter]: a list of :class:`EventWriter` objects. It is now implemented by: .. code-block:: python return [ CommonMetricPrinter(self.max_iter), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), TensorboardXWriter(self.cfg.OUTPUT_DIR), ] """ # Assume the default print/log frequency. return [ # It may not always print what you want to see, since it prints "common" metrics only. CommonMetricPrinter(self.max_iter), JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), TensorboardXWriter(self.cfg.OUTPUT_DIR), ] def train(self): """ Run training. Returns: OrderedDict of results, if evaluation is enabled. Otherwise None. """ super().train(self.start_epoch, self.max_epoch, self.iters_per_epoch) if comm.is_main_process(): assert hasattr(self, "_last_eval_results" ), "No evaluation results obtained during training!" return self._last_eval_results def run_step(self): assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!" if self.cfg.SOLVER.AMP.ENABLED: assert torch.cuda.is_available( ), f"[{self.__class__.__name__}] CUDA is required for AMP training!" 
from torch.cuda.amp.autocast_mode import autocast start = time.perf_counter() data = next(self._data_loader_iter) data_time = time.perf_counter() - start if self.cfg.SOLVER.AMP.ENABLED: with autocast(): loss_dict = self.model(data) losses = sum(loss_dict.values()) self.optimizer.zero_grad() self.grad_scaler.scale(losses).backward() self._write_metrics(loss_dict, data_time) self.grad_scaler.step(self.optimizer) self.grad_scaler.update() else: loss_dict = self.model(data) losses = sum(loss_dict.values()) self.optimizer.zero_grad() losses.backward() self._write_metrics(loss_dict, data_time) self.optimizer.step() if isinstance(self.param_wrapper, ContiguousParams): self.param_wrapper.assert_buffer_is_valid() def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float): """ Args: loss_dict (dict): dict of scalar losses data_time (float): time taken by the dataloader iteration """ device = next(iter(loss_dict.values())).device # Use a new stream so these ops don't wait for DDP or backward with torch.cuda.stream(torch.cuda.Stream() if device.type == "cuda" else None): metrics_dict = { k: v.detach().cpu().item() for k, v in loss_dict.items() } metrics_dict["data_time"] = data_time # Gather metrics among all workers for logging # This assumes we do DDP-style training, which is currently the only # supported method in detectron2. all_metrics_dict = comm.gather(metrics_dict) if comm.is_main_process(): storage = get_event_storage() # data_time among workers can have high variance. The actual latency # caused by data_time is the maximum among workers. data_time = np.max([x.pop("data_time") for x in all_metrics_dict]) storage.put_scalar("data_time", data_time) # average the rest metrics metrics_dict = { k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys() } total_losses_reduced = sum(metrics_dict.values()) if not np.isfinite(total_losses_reduced): raise FloatingPointError( f"Loss became infinite or NaN at iteration={self.iter}!\n" f"loss_dict = {metrics_dict}") storage.put_scalar("total_loss", total_losses_reduced) if len(metrics_dict) > 1: storage.put_scalars(**metrics_dict) @classmethod def build_model(cls, cfg, show_model=True): """ Returns: torch.nn.Module: It now calls :func:`fastreid.modeling.build_model`. Overwrite it if you'd like a different model. """ model = build_model(cfg) if show_model: logger = logging.getLogger('fastreid') logger.info("Model:\n{}".format(model)) return model @classmethod def build_optimizer(cls, cfg, model): """ Returns: torch.optim.Optimizer: It now calls :func:`fastreid.solver.build_optimizer`. Overwrite it if you'd like a different optimizer. """ return build_optimizer(cfg, model) @classmethod def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch): """ It now calls :func:`fastreid.solver.build_lr_scheduler`. Overwrite it if you'd like a different scheduler. """ return build_lr_scheduler(cfg, optimizer, iters_per_epoch) @classmethod def build_train_loader(cls, cfg): """ Returns: iterable It now calls :func:`fastreid.data.build_reid_train_loader`. Overwrite it if you'd like a different data loader. """ logger = logging.getLogger(__name__) logger.info("Prepare training set") return build_reid_train_loader(cfg, combineall=cfg.DATASETS.COMBINEALL) @classmethod def build_test_loader(cls, cfg, dataset_name): """ Returns: iterable It now calls :func:`fastreid.data.build_reid_test_loader`. Overwrite it if you'd like a different data loader. 
""" return build_reid_test_loader(cfg, dataset_name=dataset_name) @classmethod def build_evaluator(cls, cfg, dataset_name, output_dir=None): data_loader, num_query = cls.build_test_loader(cfg, dataset_name) return data_loader, ReidEvaluator(cfg, num_query, output_dir) @classmethod def test(cls, cfg, model): """ Args: cfg (CfgNode): model (nn.Module): Returns: dict: a dict of result metrics """ logger = logging.getLogger(__name__) results = OrderedDict() for idx, dataset_name in enumerate(cfg.DATASETS.TESTS): logger.info("Prepare testing set") try: data_loader, evaluator = cls.build_evaluator(cfg, dataset_name) except NotImplementedError: logger.warn( "No evaluator found. implement its `build_evaluator` method." ) results[dataset_name] = {} continue results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP.ENABLED) results[dataset_name] = results_i if comm.is_main_process(): assert isinstance( results, dict ), "Evaluator must return a dict on the main process. Got {} instead.".format( results) logger.info("Evaluation results for {} in csv format:".format( dataset_name)) results_i['dataset'] = dataset_name print_csv_format(results_i) if len(results) == 1: results = list(results.values())[0] return results @staticmethod def auto_scale_hyperparams(cfg, num_classes): r""" This is used for auto-computation actual training iterations, because some hyper-param, such as MAX_ITER, means training epochs rather than iters, so we need to convert specific hyper-param to training iterations. """ cfg = cfg.clone() frozen = cfg.is_frozen() cfg.defrost() # If you don't hard-code the number of classes, it will compute the number automatically if cfg.MODEL.HEADS.NUM_CLASSES == 0: output_dir = cfg.OUTPUT_DIR cfg.MODEL.HEADS.NUM_CLASSES = num_classes logger = logging.getLogger(__name__) logger.info( f"Auto-scaling the num_classes={cfg.MODEL.HEADS.NUM_CLASSES}") # Update the saved config file to make the number of classes valid if comm.is_main_process() and output_dir: # Note: some of our scripts may expect the existence of # config.yaml in output directory path = os.path.join(output_dir, "config.yaml") with PathManager.open(path, "w") as f: f.write(cfg.dump()) if frozen: cfg.freeze() return cfg
def train( epoch: int, data: DistributedDataObject, device: torch.device, rank: int, model: nn.Module, loss_fn: LossFunction, optimizer: optim.Optimizer, args: dict, scaler: GradScaler = None, ): model.train() # Horovod: set epoch to sampler for shuffling data.sampler.set_epoch(epoch) running_loss = torch.tensor(0.0) training_acc = torch.tensor(0.0) if torch.cuda.is_available(): running_loss = running_loss.to(device) training_acc = training_acc.to(device) for batch_idx, (batch, target) in enumerate(data.loader): if torch.cuda.is_available(): batch, target = batch.to(device), target.to(device) optimizer.zero_grad() output = model(batch) loss = loss_fn(output, target) if scaler is not None: scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: loss.backward() optimizer.step() pred = output.data.max(1, keepdim=True)[1] acc = pred.eq(target.data.view_as(pred)).cpu().float().sum() training_acc += acc running_loss += loss.item() if batch_idx % args.log_interval == 0: metrics_ = { 'epoch': epoch, 'batch_loss': loss.item() / args.batch_size, 'running_loss': running_loss / len(data.sampler), 'batch_acc': acc.item() / args.batch_size, 'training_acc': training_acc / len(data.sampler), } jdx = batch_idx * len(batch) frac = 100. * batch_idx / len(data.loader) pre = [ f'[{rank}]', f'[{jdx:>5}/{len(data.sampler):<5} ({frac:>03.1f}%)]' ] io.print_metrics(metrics_, pre=pre, logger=logger) running_loss = running_loss / len(data.sampler) training_acc = training_acc / len(data.sampler) loss_avg = metric_average(running_loss) training_acc = metric_average(training_acc) if rank == 0: logger.log(f'training set; avg loss: {loss_avg:.4g}, ' f'accuracy: {training_acc * 100:.2f}%')
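# metric_average() is called above but not defined in this snippet. Below is a minimal
# sketch of one plausible implementation, assuming Horovod is the distributed backend
# (the sampler.set_epoch shuffling above follows the Horovod recipe).
import torch
import horovod.torch as hvd

def metric_average(val: torch.Tensor) -> float:
    # hvd.allreduce averages the tensor across all workers by default
    avg = hvd.allreduce(val.detach())
    return avg.item()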
def train(rank, cfg: TrainConfig): if cfg.distributed.n_gpus_per_node > 1: init_process_group(backend=cfg.distributed.dist_backend, init_method=cfg.distributed.dist_url, world_size=cfg.distributed.n_nodes * cfg.distributed.n_gpus_per_node, rank=rank) device = torch.device(f'cuda:{rank:d}') model = ConvRNNEmbedder(cfg.model_cfg).to(device) loss_fn = GE2ELoss(device).to(device) logging.info(f"Initialized rank {rank}") if rank == 0: logging.getLogger().setLevel(logging.INFO) logging.info(f"Model initialized as:\n {model}") os.makedirs(cfg.checkpoint_path, exist_ok=True) logging.info(f"checkpoints directory : {cfg.checkpoint_path}") logging.info( f"Model has {sum([p.numel() for p in model.parameters()]):,d} parameters." ) steps = 0 if cfg.resume_checkpoint != '' and os.path.isfile(cfg.resume_checkpoint): state_dict = torch.load(cfg.resume_checkpoint, map_location=device) model.load_state_dict(state_dict['model_state_dict']) loss_fn.load_state_dict(state_dict['loss_fn_state_dict']) steps = state_dict['steps'] + 1 last_epoch = state_dict['epoch'] print( f"Checkpoint loaded from {cfg.resume_checkpoint}. Resuming training from {steps} steps at epoch {last_epoch}" ) else: state_dict = None last_epoch = -1 if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1: if rank == 0: logging.info("Multi-gpu detected") model = DDP(model, device_ids=[rank]).to(device) loss_fn = DDP(loss_fn, device_ids=[rank]).to(device) optim = torch.optim.AdamW(chain(model.parameters(), loss_fn.parameters()), 1.0, betas=cfg.betas) if state_dict is not None: optim.load_state_dict(state_dict['optim_state_dict']) train_df, valid_df = pd.read_csv(cfg.train_csv), pd.read_csv(cfg.valid_csv) trainset = UtteranceDS(train_df, cfg.sample_rate, cfg.n_uttr_per_spk) train_sampler = DistributedSampler( trainset ) if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1 else None train_loader = DataLoader(trainset, num_workers=cfg.num_workers, shuffle=False, sampler=train_sampler, batch_size=cfg.batch_size, pin_memory=False, drop_last=True, collate_fn=SpecialCollater( cfg.min_seq_len, cfg.max_seq_len)) if rank == 0: validset = UtteranceDS(valid_df, cfg.sample_rate, cfg.n_uttr_per_spk) validation_loader = DataLoader(validset, num_workers=cfg.num_workers, shuffle=False, sampler=None, batch_size=cfg.batch_size, pin_memory=False, drop_last=True, collate_fn=SpecialCollater( cfg.min_seq_len, cfg.max_seq_len)) sw = SummaryWriter(os.path.join(cfg.checkpoint_path, 'logs')) total_iters = cfg.n_epochs * len(train_loader) def sched_lam(x): return lin_one_cycle(cfg.start_lr, cfg.max_lr, cfg.end_lr, cfg.warmup_pct, total_iters, x) scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[sched_lam], last_epoch=steps - 1) if state_dict is not None: scheduler.load_state_dict(state_dict['scheduler_state_dict']) if cfg.fp16: scaler = GradScaler() if state_dict is not None and 'scaler_state_dict' in state_dict: scaler.load_state_dict(state_dict['scaler_state_dict']) model.train() if rank == 0: mb = master_bar(range(max(0, last_epoch), cfg.n_epochs)) smooth_loss = None else: mb = range(max(0, last_epoch), cfg.n_epochs) for epoch in mb: if rank == 0: start = time.time() mb.write("Epoch: {}".format(epoch + 1)) if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1: train_sampler.set_epoch(epoch) if rank == 0: pb = progress_bar(enumerate(train_loader), total=len(train_loader), parent=mb) else: pb = enumerate(train_loader) for i, batch in pb: if rank == 0: start_b = time.time() x, xlen = batch x = x.to(device, non_blocking=True) 
            xlen = xlen.to(device, non_blocking=True)

            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=cfg.fp16):
                embeds = model(x, xlen)
                loss = loss_fn(embeds)

            if cfg.fp16:
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                gnorm = torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    loss_fn.parameters(), cfg.grad_clip / 2)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                gnorm = torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    loss_fn.parameters(), cfg.grad_clip / 2)
                optim.step()

            if rank == 0:
                if smooth_loss is None:
                    smooth_loss = float(loss.item())
                else:
                    smooth_loss = smooth_loss + 0.1 * (float(loss.item()) - smooth_loss)

                # STDOUT logging
                if steps % cfg.stdout_interval == 0:
                    mb.write(
                        'steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}, peak mem: {:5.2f}GB'.format(
                            steps, loss.item(), time.time() - start_b,
                            torch.cuda.max_memory_allocated() / 1e9))
                    mb.child.comment = 'steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}'.format(
                        steps, loss.item(), time.time() - start_b)
                    # mb.write(f"lr = {float(optim.param_groups[0]['lr'])}")

                # checkpointing
                if steps % cfg.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = f"{cfg.checkpoint_path}/ckpt_{steps:08d}.pt"
                    torch.save(
                        {
                            'model_state_dict':
                            (model.module if cfg.distributed.n_gpus_per_node *
                             cfg.distributed.n_nodes > 1 else model).state_dict(),
                            'loss_fn_state_dict':
                            (loss_fn.module if cfg.distributed.n_gpus_per_node *
                             cfg.distributed.n_nodes > 1 else loss_fn).state_dict(),
                            'optim_state_dict': optim.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'scaler_state_dict': (scaler.state_dict() if cfg.fp16 else None),
                            'steps': steps,
                            'epoch': epoch
                        }, checkpoint_path)
                    logging.info(f"Saved checkpoint to {checkpoint_path}")

                # Tensorboard summary logging
                if steps % cfg.summary_interval == 0:
                    sw.add_scalar("training/loss_smooth", smooth_loss, steps)
                    sw.add_scalar("training/loss_raw", loss.item(), steps)
                    sw.add_scalar(
                        "ge2e/w",
                        float((loss_fn.module if cfg.distributed.n_gpus_per_node *
                               cfg.distributed.n_nodes > 1 else loss_fn).w.item()), steps)
                    sw.add_scalar(
                        "ge2e/b",
                        float((loss_fn.module if cfg.distributed.n_gpus_per_node *
                               cfg.distributed.n_nodes > 1 else loss_fn).b.item()), steps)
                    sw.add_scalar("opt/lr", float(optim.param_groups[0]['lr']), steps)
                    sw.add_scalar('opt/grad_norm', float(gnorm), steps)

                # Validation
                if steps % cfg.validation_interval == 0 and steps != 0:
                    model.eval()
                    loss_fn.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    flat_embeds = []
                    flat_lbls = []
                    with torch.no_grad():
                        for j, batch in progress_bar(enumerate(validation_loader),
                                                     total=len(validation_loader),
                                                     parent=mb):
                            x, xlen = batch
                            embeds = model(x.to(device), xlen.to(device))
                            val_err_tot += loss_fn(embeds)
                            if j <= 2:
                                lbls = [
                                    f'spk-{j}-{indr:03d}'
                                    for indr in range(cfg.batch_size)
                                    for _ in range(cfg.n_uttr_per_spk)
                                ]
                                fembeds = embeds.view(
                                    cfg.batch_size * cfg.n_uttr_per_spk,
                                    cfg.model_cfg.fc_dim)
                                flat_embeds.append(fembeds.cpu())
                                flat_lbls.extend(lbls)
                            elif j == 3:
                                flat_embeds = torch.cat(flat_embeds, dim=0)
                                sw.add_embedding(flat_embeds,
                                                 metadata=flat_lbls,
                                                 global_step=steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/loss", val_err, steps)
                        mb.write(f"validation run complete at {steps:,d} steps. "
                                 f"validation loss: {val_err:5.4f}")

                    model.train()
                    loss_fn.train()
                    sw.add_scalar("memory/max_allocated_gb",
                                  torch.cuda.max_memory_allocated() / 1e9, steps)
                    sw.add_scalar("memory/max_reserved_gb",
                                  torch.cuda.max_memory_reserved() / 1e9, steps)
                    torch.cuda.reset_peak_memory_stats()
                    torch.cuda.reset_accumulated_memory_stats()

            steps += 1
            scheduler.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))

    sw.add_hparams(flatten_cfg(cfg),
                   metric_dict={'validation/loss': val_err},
                   run_name=f'run-{cfg.checkpoint_path}')
    print("Training completed!")
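# lin_one_cycle() is referenced by the LambdaLR above but not shown. Because the
# optimizer is built with lr=1.0, the lambda's return value acts as the absolute
# learning rate. The linear warmup-then-decay shape below is an assumption, not the
# original implementation.
def lin_one_cycle(start_lr: float, max_lr: float, end_lr: float,
                  warmup_pct: float, total_iters: int, step: int) -> float:
    warmup_iters = max(1, int(warmup_pct * total_iters))
    if step < warmup_iters:
        # linear ramp from start_lr up to max_lr
        return start_lr + (max_lr - start_lr) * step / warmup_iters
    # linear decay from max_lr down to end_lr over the remaining iterations
    decay_iters = max(1, total_iters - warmup_iters)
    frac = min(1.0, (step - warmup_iters) / decay_iters)
    return max_lr + (end_lr - max_lr) * frac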
class Amp: def __init__( self, enabled: bool = False, max_norm: Optional[float] = None, ) -> None: self.grad_scaler = GradScaler(enabled=enabled) self.enabled = enabled self.max_norm = max_norm _logger.info("amp: %s", self.enabled) if self.max_norm: _logger.info( "you are using grad clip, don't forget to pass params in") def autocast(self): return autocast(enabled=self.enabled) def scale(self, outputs: TensorOrIterableTensors) -> TensorOrIterableTensors: return self.grad_scaler.scale(outputs) def unscale_(self, optimizer: Optimizer): return self.grad_scaler.unscale_(optimizer) def step(self, optimizer: Optimizer, *args, **kwargs): return self.grad_scaler.step(optimizer, *args, **kwargs) def update(self, new_scale: Union[float, Tensor, None] = None): return self.grad_scaler.update(new_scale=new_scale) def clip_grad_norm_(self, params: TensorOrIterableTensors): torch.nn.utils.clip_grad_norm_(params, self.max_norm) def state_dict(self) -> dict: return self.grad_scaler.state_dict() def load_state_dict(self, state_dict: dict): return self.grad_scaler.load_state_dict(state_dict) def __call__( self, loss: Tensor, optimizer: torch.optim.Optimizer, parameters: Optional[TensorOrIterableTensors] = None, zero_grad_set_to_none: bool = False, ): self.scale(loss).backward() if self.max_norm is not None: assert parameters is not None self.unscale_(optimizer) self.clip_grad_norm_(parameters) self.grad_scaler.step(optimizer) self.grad_scaler.update() optimizer.zero_grad(set_to_none=zero_grad_set_to_none) def backward( self, loss: Tensor, optimizer: torch.optim.Optimizer, parameters: Optional[TensorOrIterableTensors] = None, ): return self(loss, optimizer, parameters=parameters)
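# Sketch of how the Amp helper above could drive a training loop. The model,
# optimizer, and loss function are placeholders; a CUDA device is assumed.
import torch
import torch.nn as nn

model = nn.Linear(32, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
amp = Amp(enabled=True, max_norm=1.0)  # with max_norm set, parameters must be passed below

for _ in range(10):
    x = torch.randn(8, 32, device="cuda")
    y = torch.randint(0, 4, (8,), device="cuda")
    with amp.autocast():
        loss = loss_fn(model(x), y)
    # __call__ runs scale -> backward -> unscale -> clip -> step -> update -> zero_grad
    amp(loss, optimizer, parameters=model.parameters())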
def train_step( model: FlowModel, config: TrainConfig, action: ActionFn, optimizer: optim.Optimizer, batch_size: int, scheduler: Any = None, scaler: GradScaler = None, pre_model: FlowModel = None, dkl_factor: float = 1., xi: torch.Tensor = None, ): """Perform a single training step. TODO: Add `torch.device` to arguments for DDP. """ t0 = time.time() # layers, prior = model['layers'], model['prior'] optimizer.zero_grad() loss_dkl = torch.tensor(0.0) if torch.cuda.is_available(): loss_dkl = loss_dkl.cuda() if pre_model is not None: pre_xi = pre_model.prior.sample_n(batch_size) x = qed.ft_flow(pre_model.layers, pre_xi) xi = qed.ft_flow_inv(pre_model.layers, x) # with torch.cuda.amp.autocast(): x, xi, logq = apply_flow_to_prior(model.prior, model.layers, xi=xi, batch_size=batch_size) logp = (-1.) * action(x) dkl = calc_dkl(logp, logq) ess = calc_ess(logp, logq) qi = qed.batch_charges(xi) q = qed.batch_charges(x) plaq = logp / (config.beta * config.volume) dq = torch.sqrt((q - qi) ** 2) loss_dkl = dkl_factor * dkl if scaler is not None: scaler.scale(loss_dkl).backward() scaler.step(optimizer) scaler.update() else: loss_dkl.backward() optimizer.step() if scheduler is not None: scheduler.step(loss_dkl) metrics = { 'dt': time.time() - t0, 'ess': grab(ess), 'logp': grab(logp), 'logq': grab(logq), 'loss_dkl': grab(loss_dkl), 'q': grab(q), 'dq': grab(dq), 'plaq': grab(plaq), } return metrics
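# Hedged sketch of aggregating the metrics dict returned by train_step over an epoch.
# The model, action, optimizer, scheduler, and scaler are assumed to be constructed
# elsewhere by the surrounding project and are passed through unchanged.
import numpy as np

def run_epoch(n_iters: int, **train_step_kwargs) -> dict:
    history = {}
    for _ in range(n_iters):
        metrics = train_step(**train_step_kwargs)
        for key, val in metrics.items():
            # grab() returns numpy values, so np.mean handles scalars and arrays alike
            history.setdefault(key, []).append(np.mean(val))
    # per-epoch averages of dt, ess, loss_dkl, dq, plaq, ...
    return {key: float(np.mean(vals)) for key, vals in history.items()}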