Example #1
 def __init__(self, grad_clip=None):
     super().__init__()
     self._grad_clip = grad_clip
     self._level = PriorityStatus.HIGH
     self._user_scale = False if self._grad_clip is None else getattr(
         self._grad_clip, 'use_scale', False)
     self._scaler = GradScaler(enabled=False) if self._user_scale else None
Example #2
 def load_model(self, checkpoint=None):
     print('Loading model from', self.config['model'])
     self.model, self.tokenizer = instantiate_model_and_tokenizer(
         self.config['model'],
         additional_tokens_smart_init=self.config['smart_init'],
         dropout=self.config['dropout'],
         attention_dropout=self.config['attention_dropout'],
         penman_linearization=self.config['penman_linearization'],
         collapse_name_ops=self.config['collapse_name_ops'],
         use_pointer_tokens=self.config['use_pointer_tokens'],
         raw_graph=self.config['raw_graph'])
     self.model.to(self.device)
     # Load optimization components
     self.optimizer = AdamW(self.model.parameters(),
                            lr=self.config['learning_rate'],
                            weight_decay=self.config['weight_decay'])
     self.scheduler = transformers.get_constant_schedule_with_warmup(
         self.optimizer, num_warmup_steps=self.config['warmup_steps'])
     self.scaler = GradScaler(enabled=self.config['fp16'])
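     # With enabled=False the scaler passes values through unchanged and
     # step() simply calls optimizer.step(), so this same code path also
     # serves plain fp32 training.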
     # Reload checkpoint model weights and optimizer params if loading from a checkpoint
     if checkpoint is not None:
         print('Checkpoint %s restored' % checkpoint)
         load_state_dict_from_checkpoint(checkpoint, self.model,
                                         self.optimizer, self.scheduler)
         # Try to load the smatch score and last_epoch from the config in the model directory.
         try:
             with open(os.path.join(self.model_dir, 'config.json')) as f:
                 model_config = json.load(f)
             self.best_smatch = model_config['smatch_dev']
             self.start_epoch = model_config['last_epoch'] + 1
         except Exception:
             logger.exception(
                 'Unable to load config file in model directory')
Example #3
    def __init__(self, cfg):
        self.cfg = cfg
        self.paths = cfg['paths']
        self.net_params = cfg['net']
        self.train_params = cfg['train']
        self.trans_params = cfg['train']['transforms']

        self.checkpoints = self.paths['checkpoints']
        Path(self.checkpoints).mkdir(parents=True, exist_ok=True)
        shutil.copyfile('config.yaml', f'{self.checkpoints}/config.yaml')

        self.update_interval = self.paths['update_interval']

        # amp training
        self.use_amp = self.train_params['mixed_precision']
        self.scaler = GradScaler() if self.use_amp else None

        # data setup
        dataset_name = self.train_params['dataset']
        self.use_multi = dataset_name == 'multi'
        print(f'Using dataset: {dataset_name}')
        self.train_dataset = get_pedestrian_dataset(
            dataset_name,
            self.paths,
            augment=get_train_transforms(self.trans_params),
            mode='train',
            multi_datasets=self.train_params['multi_datasets']
            if self.use_multi else None)
        print(f'Train dataset: {len(self.train_dataset)} samples')

        self.val_dataset = get_pedestrian_dataset(
            dataset_name,
            self.paths,
            augment=get_val_transforms(self.trans_params),
            mode='val',
            multi_datasets=self.train_params['multi_datasets']
            if self.use_multi else None)
        print(f'Val dataset: {len(self.val_dataset)} samples')

        tests_data = self.train_params['test_datasets']
        self.test_datasets = [
            get_pedestrian_dataset(d_name,
                                   self.paths,
                                   augment=get_test_transforms(
                                       self.trans_params),
                                   mode='test') for d_name in tests_data
        ]

        self.criterion = AnchorFreeLoss(self.train_params)

        self.writer = Writer(self.paths['log_dir'])
        print('Tensorboard logs are saved to: {}'.format(
            self.paths['log_dir']))

        self.sched_type = self.train_params['scheduler']
        self.scheduler = None
        self.optimizer = None
Example #4
    def __init__(
        self,
        enabled: bool = False,
        max_norm: Optional[float] = None,
    ) -> None:
        self.grad_scaler = GradScaler(enabled=enabled)
        self.enabled = enabled
        self.max_norm = max_norm

        _logger.info("amp: %s", self.enabled)
        if self.max_norm:
            _logger.info(
                "you are using grad clip, don't forget to pass params in")
Example #5
class OptimizerHook(HookBase):
    def __init__(self, grad_clip=None):
        super().__init__()
        self._grad_clip = grad_clip
        self._level = PriorityStatus.HIGH
        self._user_scale = False if self._grad_clip is None else getattr(
            self._grad_clip, 'use_scale', False)
        self._scaler = GradScaler(enabled=False) if self._user_scale else None

    def _clip_grad_norm(self) -> None:
        clip_norm_params = list(
            filter(lambda parm: parm.requires_grad and parm.grad is not None,
                   self.trainer.model.parameters()))
        if len(clip_norm_params) == 0:
            return
        else:
            if hasattr(self._grad_clip, 'clip_norm_mode'):
                scale = self._scaler.get_scale() if self._user_scale else 1.0
                max_norm = self._grad_clip.max_grad_l2_norm * scale
                grad_norm = clip_grad.clip_grad_norm(clip_norm_params,
                                                     max_norm)
            else:
                grad_norm = clip_grad.clip_grad_norm_(clip_norm_params,
                                                      **self._grad_clip)
            self.trainer.log_buffer.put_scalar('grad_norm', float(grad_norm))

    def after_train_iter(self):
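        # Backward on the accumulation-scaled loss every iteration; the
        # optimizer (or the scaler, when use_scale is set) steps only once per
        # gradient_accumulation_steps window.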
        self.trainer.output['loss'] /= self.trainer.gradient_accumulation_steps
        self.trainer.output['loss'].backward()
        if self._grad_clip is not None:
            self._clip_grad_norm()

        if (self.trainer.iter +
                1) % self.trainer.gradient_accumulation_steps == 0:
            if self._user_scale:
                self._scaler.step(self.trainer.optimizer)
                self._scaler.update()
            else:
                self.trainer.optimizer.step()

    def before_train_iter(self):
        if self.trainer.iter == 0:
            is_clean = True
        elif self.trainer.iter % self.trainer.gradient_accumulation_steps == 0:
            is_clean = True
        else:
            is_clean = False

        if is_clean:
            self.trainer.optimizer.zero_grad()
Example #6
def fnTrain(
    loader: DataLoader,
    device: str,
    model: nn.Module,
    optimizer: Optimizer,
    fnLoss,
    scaler: GradScaler,
) -> float:

    runningLoss = 0
    for _, (data, targets) in enumerate(loader):

        data = data.to(device=device)
        targets = targets.float().unsqueeze(1).to(device=device)

        with torch.cuda.amp.autocast():
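            # Forward pass and loss computed under autocast (mixed precision);
            # the backward pass and optimizer step below go through the scaler.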
            predictions = model(data)
            loss = fnLoss(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # print(f"batch {idxBatch+ 1} loss {loss.item()}")

        runningLoss += loss.item()

    return runningLoss / len(loader)
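A minimal driver sketch for fnTrain above, assuming the model, data loader, optimizer and loss function are constructed elsewhere (num_epochs, loader, model, optimizer and loss_fn below are illustrative placeholders):

# Hypothetical training driver for fnTrain; all names besides fnTrain and
# GradScaler are placeholders assumed to be defined elsewhere.
scaler = GradScaler()
for epoch in range(num_epochs):
    epoch_loss = fnTrain(loader, "cuda", model, optimizer, loss_fn, scaler)
    print(f"epoch {epoch + 1}: mean loss {epoch_loss:.4f}")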
Example #7
    def train_on_epoch(self, fold, epoch):
        logging.info(f'Fold-[{fold}] ==> Training on epoch {epoch}...')
        
        trn_loss_list = []
        trn_acc_list = [] 
        trn_f1_list = []
        self.model.train()
        
        if self.cfg.mixed_precision:
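            # Mixed precision: create a GradScaler here and hand it to
            # self.step() in the batch loop below.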
            scaler = GradScaler()
        else:
            scaler = None
        
        self.reset_scoring_dict()
        with tqdm(self.trn_dl, total=len(self.trn_dl), unit="batch") as train_bar:
            for batch, sample in enumerate(train_bar):
                train_bar.set_description(f"Fold-[{fold}|{self.cfg.n_fold}] ==> Train Epoch [{str(epoch).zfill(len(str(self.cfg.t_epoch)))}|{self.cfg.t_epoch}]")
                
                result = self.step(sample, scaler)
                batch_f1  = result['batch_f1']
                batch_acc = result['batch_acc']
                loss = result['loss']
                
                if not torch.isfinite(loss):
                    print(loss, sample, result['logit'], sample['image'].shape, sample['label'].shape)
                    raise ValueError('WARNING: non-finite loss, ending training ')
                
                trn_f1_list.append(batch_f1)
                trn_acc_list.append(batch_acc)
                trn_loss_list.append(loss.item())
                trn_f1 = np.mean(trn_f1_list)
                trn_acc = np.mean(trn_acc_list)
                trn_loss = np.mean(trn_loss_list)
                if batch % self.cfg.log_interval == 0 or batch == len(self.trn_dl)-1:
                    logging.info(f"Fold-[{fold}] ==> <Train> Epoch: [{str(epoch).zfill(len(str(self.cfg.t_epoch)))}|{str(self.cfg.t_epoch)}]  Batch: [{str(batch).zfill(len(str(len(self.trn_dl))))}|{len(self.trn_dl)}]\t Train Acc: {trn_acc}\t Train F1: {trn_f1}\t Train Loss: {trn_loss}")

                train_bar.set_postfix(train_loss=trn_loss, train_acc=trn_acc, train_f1=trn_f1)
                
                if self.cfg.sched_type == "onecycle":
                    self.sched.step()            
        
        if self.cfg.sched_type == "cosine": self.sched.step()
        reports = report(self.scoring_dict["preds"], self.scoring_dict["labels"])
        logging.info(f"Fold-[{fold}] ==> <Train> Epoch: [{str(epoch).zfill(len(str(self.cfg.t_epoch)))}|{str(self.cfg.t_epoch)}] REPORT\n{reports}\n")
Example #8
class Fp16OptimizerHook(OptimizerHook):
    def __init__(self, grad_clip=None, grad_scaler_config=None):
        super().__init__(grad_clip)
        self._grad_scaler_config = grad_scaler_config
        self._scaler = None

    def before_train(self):
        if self._grad_scaler_config is None:
            self._scaler = GradScaler()
        else:
            self._scaler = GradScaler(**self._grad_scaler_config)

    def after_train_iter(self):
        loss = self.trainer.output[
            'loss'] / self.trainer.gradient_accumulation_steps
        self._scaler.scale(loss).backward()
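        # Unscale the gradients before clipping so the norm threshold applies
        # to the true (unscaled) gradient values.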
        if self._grad_clip is not None:
            self._scaler.unscale_(self.trainer.optimizer)
            self._clip_grad_norm()

        if (self.trainer.iter +
                1) % self.trainer.gradient_accumulation_steps == 0:
            self._scaler.step(self.trainer.optimizer)
            self._scaler.update()
Example #9
    def valid_on_epoch(self, fold, epoch):
        logging.info(f'Fold-[{fold}] ==> Validation on epoch {epoch}...')
        
        val_loss_list = []
        val_acc_list = []
        val_f1_list = []
        self.model.eval()
        
        if self.cfg.mixed_precision: scaler = GradScaler()
        else: scaler = None
            
        self.reset_scoring_dict()
        with torch.no_grad():
            with tqdm(self.val_dl, total=len(self.val_dl), unit="batch") as valid_bar:
                for batch, sample in enumerate(valid_bar):
                    valid_bar.set_description(f"Fold-[{fold}|{self.cfg.n_fold}] ==> Valid Epoch [{str(epoch).zfill(len(str(self.cfg.t_epoch)))}|{self.cfg.t_epoch}]")
                    
                    
                    result = self.step(sample, scaler, valid=True)
                    batch_f1  = result['batch_f1']
                    batch_acc = result['batch_acc']
                    loss = result['loss']
                    
                    val_f1_list.append(batch_f1)
                    val_acc_list.append(batch_acc)
                    val_loss_list.append(loss.item())
                    val_f1 = np.mean(val_f1_list)
                    val_acc = np.mean(val_acc_list)
                    val_loss = np.mean(val_loss_list)
                    if batch % self.cfg.log_interval == 0 or batch == len(self.val_dl)-1:
                        logging.info(f"Fold-[{fold}] ==> <Valid> Epoch: [{str(epoch).zfill(len(str(self.cfg.t_epoch)))}|{str(self.cfg.t_epoch)}]  Batch: [{str(batch).zfill(len(str(len(self.val_dl))))}|{len(self.val_dl)}]\t Valid Acc: {val_acc}\t Valid F1: {val_f1}\t Valid Loss: {val_loss}")

                    valid_bar.set_postfix(valid_loss=val_loss,valid_acc=val_acc, valid_f1=val_f1)
        
        if self.cfg.sched_type == "plateau": self.sched.step(val_loss)
        reports = report(self.scoring_dict["preds"], self.scoring_dict["labels"])
        logging.info(f"Fold-[{fold}] ==> <Valid> Epoch: [{str(epoch).zfill(len(str(self.cfg.t_epoch)))}|{str(self.cfg.t_epoch)}] REPORT\n{reports}\n")
        return val_loss, val_acc, val_f1
Example #10
    def __init__(self, cfg):
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # if setup_logger is not called for fastreid
            setup_logger()

        # Create datasets
        logger.info("==> Load source-domain dataset")
        self.src = src = self.load_dataset(cfg.DATASETS.SRC)
        self.src_pid_nums = src.get_num_pids(src.train)

        logger.info("==> Load target-domain dataset")
        self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT)
        self.tgt_nums = len(tgt.train)

        # Create model
        self.model = self.build_model(cfg,
                                      load_model=False,
                                      show_model=False,
                                      use_dsbn=True)

        # Create hybrid memorys
        self.hm = HybridMemory(num_features=cfg.MODEL.BACKBONE.FEAT_DIM,
                               num_samples=self.src_pid_nums + self.tgt_nums,
                               temp=cfg.MEMORY.TEMP,
                               momentum=cfg.MEMORY.MOMENTUM,
                               use_half=cfg.SOLVER.AMP.ENABLED).cuda()

        # Initialize source-domain class centroids
        logger.info(
            "==> Initialize source-domain class centroids in the hybrid memory"
        )
        with inference_context(self.model), torch.no_grad():
            src_train = self.build_dataset(cfg,
                                           src.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            src_init_feat_loader = self.build_test_loader(cfg, src_train)
            src_fname_feat_dict, _ = extract_features(self.model,
                                                      src_init_feat_loader)
            src_feat_dict = collections.defaultdict(list)
            for f, pid, _ in sorted(src.train):
                src_feat_dict[pid].append(src_fname_feat_dict[f].unsqueeze(0))
            src_centers = [
                torch.cat(src_feat_dict[pid], 0).mean(0)
                for pid in sorted(src_feat_dict.keys())
            ]
            src_centers = torch.stack(src_centers, 0)
            src_centers = F.normalize(src_centers, dim=1)

        # Initialize target-domain instance features
        logger.info(
            "==> Initialize target-domain instance features in the hybrid memory"
        )
        with inference_context(self.model), torch.no_grad():
            tgt_train = self.build_dataset(cfg,
                                           tgt.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            tgt_init_feat_loader = self.build_test_loader(cfg, tgt_train)
            tgt_fname_feat_dict, _ = extract_features(self.model,
                                                      tgt_init_feat_loader)
            tgt_features = torch.cat([
                tgt_fname_feat_dict[f].unsqueeze(0)
                for f, _, _ in sorted(self.tgt.train)
            ], 0)
            tgt_features = F.normalize(tgt_features, dim=1)

        self.hm.features = torch.cat((src_centers, tgt_features), dim=0).cuda()

        del (src_train, src_init_feat_loader, src_fname_feat_dict,
             src_feat_dict, src_centers, tgt_train, tgt_init_feat_loader,
             tgt_fname_feat_dict, tgt_features)

        # Optimizer
        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True)

        # Learning rate scheduler
        self.iters_per_epoch = cfg.SOLVER.ITERS
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
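            # AMP assumes a single device per process under DDP; create the
            # GradScaler only when AMP is enabled, otherwise it stays None.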
            unsupported = "AMPTrainer does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None
Example #11
    def __init__(self, cfg):
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # if setup_logger is not called for fastreid
            setup_logger()

        logger.info("==> Load target-domain dataset")
        self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT)
        self.tgt_nums = len(tgt.train)

        cfg = self.auto_scale_hyperparams(cfg, self.tgt_nums)

        # Create model
        self.model = self.build_model(cfg,
                                      load_model=cfg.MODEL.PRETRAIN,
                                      show_model=True,
                                      use_dsbn=False)

        # Optimizer
        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True)

        # Learning rate scheduler
        self.iters_per_epoch = cfg.SOLVER.ITERS
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = "AMPTrainer does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None
Example #12
def train(rank, cfg: TrainConfig):
    if cfg.distributed.n_gpus_per_node > 1:
        init_process_group(backend=cfg.distributed.dist_backend,
                           init_method=cfg.distributed.dist_url,
                           world_size=cfg.distributed.n_nodes *
                           cfg.distributed.n_gpus_per_node,
                           rank=rank)

    device = torch.device(f'cuda:{rank:d}')

    model = ConvRNNEmbedder(cfg.model_cfg).to(device)
    loss_fn = GE2ELoss(device).to(device)

    logging.info(f"Initialized rank {rank}")

    if rank == 0:
        logging.getLogger().setLevel(logging.INFO)
        logging.info(f"Model initialized as:\n {model}")
        os.makedirs(cfg.checkpoint_path, exist_ok=True)
        logging.info(f"checkpoints directory : {cfg.checkpoint_path}")
        logging.info(
            f"Model has {sum([p.numel() for p in model.parameters()]):,d} parameters."
        )

    steps = 0
    if cfg.resume_checkpoint != '' and os.path.isfile(cfg.resume_checkpoint):
        state_dict = torch.load(cfg.resume_checkpoint, map_location=device)
        model.load_state_dict(state_dict['model_state_dict'])
        loss_fn.load_state_dict(state_dict['loss_fn_state_dict'])
        steps = state_dict['steps'] + 1
        last_epoch = state_dict['epoch']
        print(
            f"Checkpoint loaded from {cfg.resume_checkpoint}. Resuming training from {steps} steps at epoch {last_epoch}"
        )
    else:
        state_dict = None
        last_epoch = -1

    if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1:
        if rank == 0: logging.info("Multi-gpu detected")
        model = DDP(model, device_ids=[rank]).to(device)
        loss_fn = DDP(loss_fn, device_ids=[rank]).to(device)

    optim = torch.optim.AdamW(chain(model.parameters(), loss_fn.parameters()),
                              1.0,
                              betas=cfg.betas)
    if state_dict is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])

    train_df, valid_df = pd.read_csv(cfg.train_csv), pd.read_csv(cfg.valid_csv)

    trainset = UtteranceDS(train_df, cfg.sample_rate, cfg.n_uttr_per_spk)

    train_sampler = DistributedSampler(
        trainset
    ) if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1 else None

    train_loader = DataLoader(trainset,
                              num_workers=cfg.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=cfg.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=SpecialCollater(
                                  cfg.min_seq_len, cfg.max_seq_len))

    if rank == 0:
        validset = UtteranceDS(valid_df, cfg.sample_rate, cfg.n_uttr_per_spk)
        validation_loader = DataLoader(validset,
                                       num_workers=cfg.num_workers,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=cfg.batch_size,
                                       pin_memory=False,
                                       drop_last=True,
                                       collate_fn=SpecialCollater(
                                           cfg.min_seq_len, cfg.max_seq_len))

        sw = SummaryWriter(os.path.join(cfg.checkpoint_path, 'logs'))

    total_iters = cfg.n_epochs * len(train_loader)

    def sched_lam(x):
        return lin_one_cycle(cfg.start_lr, cfg.max_lr, cfg.end_lr,
                             cfg.warmup_pct, total_iters, x)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim,
                                                  lr_lambda=[sched_lam],
                                                  last_epoch=steps - 1)

    if state_dict is not None:
        scheduler.load_state_dict(state_dict['scheduler_state_dict'])

    if cfg.fp16:
        scaler = GradScaler()
        if state_dict is not None and 'scaler_state_dict' in state_dict:
            scaler.load_state_dict(state_dict['scaler_state_dict'])

    model.train()

    if rank == 0:
        mb = master_bar(range(max(0, last_epoch), cfg.n_epochs))
        smooth_loss = None
    else:
        mb = range(max(0, last_epoch), cfg.n_epochs)

    for epoch in mb:
        if rank == 0:
            start = time.time()
            mb.write("Epoch: {}".format(epoch + 1))

        if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1:
            train_sampler.set_epoch(epoch)

        if rank == 0:
            pb = progress_bar(enumerate(train_loader),
                              total=len(train_loader),
                              parent=mb)
        else:
            pb = enumerate(train_loader)

        for i, batch in pb:
            if rank == 0: start_b = time.time()
            x, xlen = batch
            x = x.to(device, non_blocking=True)
            xlen = xlen.to(device, non_blocking=True)

            optim.zero_grad()

            with torch.cuda.amp.autocast(enabled=cfg.fp16):
                embeds = model(x, xlen)
                loss = loss_fn(embeds)
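            # fp16 path: unscale before clipping so clip_grad_norm_ sees the
            # real gradient magnitudes; the else branch is the plain fp32 path.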
            if cfg.fp16:
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                gnorm = torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    loss_fn.parameters(), cfg.grad_clip / 2)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                gnorm = torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    loss_fn.parameters(), cfg.grad_clip / 2)
                optim.step()

            if rank == 0:
                if smooth_loss is None: smooth_loss = float(loss.item())
                else:
                    smooth_loss = smooth_loss + 0.1 * (float(loss.item()) -
                                                       smooth_loss)
                # STDOUT logging
                if steps % cfg.stdout_interval == 0:
                    mb.write('steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}, peak mem: {:5.2f}GB'. \
                            format(steps, loss.item(), time.time() - start_b, torch.cuda.max_memory_allocated()/1e9))
                    mb.child.comment = 'steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}'. \
                            format(steps, loss.item(), time.time() - start_b)
                    # mb.write(f"lr = {float(optim.param_groups[0]['lr'])}")

                # checkpointing
                if steps % cfg.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = f"{cfg.checkpoint_path}/ckpt_{steps:08d}.pt"
                    torch.save(
                        {
                            'model_state_dict':
                            (model.module if cfg.distributed.n_gpus_per_node *
                             cfg.distributed.n_nodes > 1 else
                             model).state_dict(),
                            'loss_fn_state_dict':
                            (loss_fn.module
                             if cfg.distributed.n_gpus_per_node *
                             cfg.distributed.n_nodes > 1 else
                             loss_fn).state_dict(),
                            'optim_state_dict':
                            optim.state_dict(),
                            'scheduler_state_dict':
                            scheduler.state_dict(),
                            'scaler_state_dict':
                            (scaler.state_dict() if cfg.fp16 else None),
                            'steps':
                            steps,
                            'epoch':
                            epoch
                        }, checkpoint_path)
                    logging.info(f"Saved checkpoint to {checkpoint_path}")

                # Tensorboard summary logging
                if steps % cfg.summary_interval == 0:
                    sw.add_scalar("training/loss_smooth", smooth_loss, steps)
                    sw.add_scalar("training/loss_raw", loss.item(), steps)
                    sw.add_scalar(
                        "ge2e/w",
                        float((loss_fn.module
                               if cfg.distributed.n_gpus_per_node *
                               cfg.distributed.n_nodes > 1 else
                               loss_fn).w.item()), steps)
                    sw.add_scalar(
                        "ge2e/b",
                        float((loss_fn.module
                               if cfg.distributed.n_gpus_per_node *
                               cfg.distributed.n_nodes > 1 else
                               loss_fn).b.item()), steps)
                    sw.add_scalar("opt/lr", float(optim.param_groups[0]['lr']),
                                  steps)
                    sw.add_scalar('opt/grad_norm', float(gnorm), steps)

                # Validation
                if steps % cfg.validation_interval == 0 and steps != 0:
                    model.eval()
                    loss_fn.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    flat_embeds = []
                    flat_lbls = []
                    with torch.no_grad():
                        for j, batch in progress_bar(
                                enumerate(validation_loader),
                                total=len(validation_loader),
                                parent=mb):
                            x, xlen = batch
                            embeds = model(x.to(device), xlen.to(device))
                            val_err_tot += loss_fn(embeds)

                            if j <= 2:
                                lbls = [
                                    f'spk-{j}-{indr:03d}'
                                    for indr in range(cfg.batch_size)
                                    for _ in range(cfg.n_uttr_per_spk)
                                ]
                                fembeds = embeds.view(
                                    cfg.batch_size * cfg.n_uttr_per_spk,
                                    cfg.model_cfg.fc_dim)
                                flat_embeds.append(fembeds.cpu())
                                flat_lbls.extend(lbls)
                            elif j == 3:
                                flat_embeds = torch.cat(flat_embeds, dim=0)
                                sw.add_embedding(flat_embeds,
                                                 metadata=flat_lbls,
                                                 global_step=steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/loss", val_err, steps)
                        mb.write(
                            f"validation run complete at {steps:,d} steps. validation loss: {val_err:5.4f}"
                        )

                    model.train()
                    loss_fn.train()
                    sw.add_scalar("memory/max_allocated_gb",
                                  torch.cuda.max_memory_allocated() / 1e9,
                                  steps)
                    sw.add_scalar("memory/max_reserved_gb",
                                  torch.cuda.max_memory_reserved() / 1e9,
                                  steps)
                    torch.cuda.reset_peak_memory_stats()
                    torch.cuda.reset_accumulated_memory_stats()

            steps += 1
            scheduler.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
    # Only rank 0 created the SummaryWriter and tracked val_err, so guard the
    # final hparams logging accordingly.
    if rank == 0:
        sw.add_hparams(flatten_cfg(cfg),
                       metric_dict={'validation/loss': val_err},
                       run_name=f'run-{cfg.checkpoint_path}')
    print("Training completed!")
Example #13
class Amp:
    def __init__(
        self,
        enabled: bool = False,
        max_norm: Optional[float] = None,
    ) -> None:
        self.grad_scaler = GradScaler(enabled=enabled)
        self.enabled = enabled
        self.max_norm = max_norm

        _logger.info("amp: %s", self.enabled)
        if self.max_norm:
            _logger.info(
                "you are using grad clip, don't forget to pass params in")

    def autocast(self):
        return autocast(enabled=self.enabled)

    def scale(self, outputs: TensorOrIterableTensors) -> TensorOrIterableTensors:
        return self.grad_scaler.scale(outputs)

    def unscale_(self, optimizer: Optimizer):
        return self.grad_scaler.unscale_(optimizer)

    def step(self, optimizer: Optimizer, *args, **kwargs):
        return self.grad_scaler.step(optimizer, *args, **kwargs)

    def update(self, new_scale: Union[float, Tensor, None] = None):
        return self.grad_scaler.update(new_scale=new_scale)

    def clip_grad_norm_(self, params: TensorOrIterableTensors):
        torch.nn.utils.clip_grad_norm_(params, self.max_norm)

    def state_dict(self) -> dict:
        return self.grad_scaler.state_dict()

    def load_state_dict(self, state_dict: dict):
        return self.grad_scaler.load_state_dict(state_dict)

    def __call__(
        self,
        loss: Tensor,
        optimizer: torch.optim.Optimizer,
        parameters: Optional[TensorOrIterableTensors] = None,
        zero_grad_set_to_none: bool = False,
    ):
        self.scale(loss).backward()

        if self.max_norm is not None:
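            # Gradient clipping requires unscaling first so max_norm applies to
            # the true gradients; the caller must pass the parameters to clip.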
            assert parameters is not None
            self.unscale_(optimizer)
            self.clip_grad_norm_(parameters)

        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()
        optimizer.zero_grad(set_to_none=zero_grad_set_to_none)

    def backward(
        self,
        loss: Tensor,
        optimizer: torch.optim.Optimizer,
        parameters: Optional[TensorOrIterableTensors] = None,
    ):
        return self(loss, optimizer, parameters=parameters)
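A minimal usage sketch for the Amp wrapper above, assuming a model, optimizer, data loader and criterion already exist (all of those names are illustrative placeholders):

# Hypothetical training step built on the Amp helper above; model, optimizer,
# loader and criterion are placeholders assumed to be defined elsewhere.
amp = Amp(enabled=True, max_norm=1.0)
for inputs, targets in loader:
    with amp.autocast():
        loss = criterion(model(inputs), targets)
    # scale -> backward -> (unscale + clip) -> step -> update -> zero_grad
    amp(loss, optimizer, parameters=model.parameters())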
Example #14
 def before_train(self):
     if self._grad_scaler_config is None:
         self._scaler = GradScaler()
     else:
         self._scaler = GradScaler(**self._grad_scaler_config)
Example #15
def train(
    epoch: int,
    data: DistributedDataObject,
    device: torch.device,
    rank: int,
    model: nn.Module,
    loss_fn: LossFunction,
    optimizer: optim.Optimizer,
    args: dict,
    scaler: GradScaler = None,
):
    model.train()
    # Horovod: set epoch to sampler for shuffling
    data.sampler.set_epoch(epoch)
    running_loss = torch.tensor(0.0)
    training_acc = torch.tensor(0.0)
    if torch.cuda.is_available():
        running_loss = running_loss.to(device)
        training_acc = training_acc.to(device)

    for batch_idx, (batch, target) in enumerate(data.loader):
        if torch.cuda.is_available():
            batch, target = batch.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(batch)
        loss = loss_fn(output, target)

        if scaler is not None:
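            # Scaled backward pass and scaler-managed optimizer step when a
            # GradScaler is supplied; the else branch is standard fp32 training.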
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        pred = output.data.max(1, keepdim=True)[1]
        acc = pred.eq(target.data.view_as(pred)).cpu().float().sum()

        training_acc += acc
        running_loss += loss.item()

        if batch_idx % args.log_interval == 0:
            metrics_ = {
                'epoch': epoch,
                'batch_loss': loss.item() / args.batch_size,
                'running_loss': running_loss / len(data.sampler),
                'batch_acc': acc.item() / args.batch_size,
                'training_acc': training_acc / len(data.sampler),
            }

            jdx = batch_idx * len(batch)
            frac = 100. * batch_idx / len(data.loader)
            pre = [
                f'[{rank}]',
                f'[{jdx:>5}/{len(data.sampler):<5} ({frac:>03.1f}%)]'
            ]
            io.print_metrics(metrics_, pre=pre, logger=logger)

    running_loss = running_loss / len(data.sampler)
    training_acc = training_acc / len(data.sampler)
    loss_avg = metric_average(running_loss)
    training_acc = metric_average(training_acc)
    if rank == 0:
        logger.log(f'training set; avg loss: {loss_avg:.4g}, '
                   f'accuracy: {training_acc * 100:.2f}%')
Example #16
class Trainer:
    def __init__(self, config, device='cuda:0'):
        self.config = config
        self.device = torch.device(device)
        self.model_dir = self.config['model_dir']
        self.dev_gold_path = os.path.join(self.model_dir, 'dev-gold.txt')
        self.dev_pred_path = os.path.join(self.model_dir, 'dev-pred.txt')
        self.best_smatch = -1  # May get reloaded if using a checkpoint
        self.start_epoch = 1  # May get reloaded if using a checkpoint
        os.makedirs(self.model_dir, exist_ok=True)

    def train(self, checkpoint=None):
        self.load_model(
            checkpoint
        )  # sets self.model, tokenizer, optimizer, .., best_smatch, start_epoch
        self.load_training_data()  # sets self.train_loader
        self.load_eval_data()  # sets self.inference, graphs_gold
        # Loop through max epochs
        assert self.start_epoch < self.config[
            'max_epochs']  # May fail if reloading a checkpoint
        for epoch in range(self.start_epoch, self.config['max_epochs'] + 1):
            # Setup batch
            print('Training epoch %d' % epoch)
            trn_amr_loss = RunningAverage()
            self.optimizer.zero_grad()
            pbar = tqdm(total=len(self.train_loader.dataset), ncols=100)
            self.set_pbar_desc_train(pbar, None)
            self.model.train()
            # Loop through all the data
            for bnum, batch in enumerate(self.train_loader):
                x, y, extra = batch
                with autocast(enabled=self.config['fp16']):
                    rdict = self.model(**x, **y)
                    loss = rdict['loss']
                self.scaler.scale(
                    (loss / self.config['accum_steps'])).backward()
                trn_amr_loss.update(loss.item())
                # Perform an update every accum_steps
                if (bnum + 1) % self.config['accum_steps'] == 0:
                    self.step_otimizer()
                # Update progress
                pbar.update(x['input_ids'].shape[0])
                self.set_pbar_desc_train(pbar, trn_amr_loss.value)
            pbar.close()
            # Perform an update with the last batch if it wasn't already done in the loop
            if (bnum + 1) % self.config['accum_steps'] != 0:
                self.step_otimizer()
            # Run evaluate, compute smatch and save the model if it's the new best
            try:
                smatch = self.evaluate()
                if smatch > self.best_smatch:
                    self.best_smatch = smatch
                    self.save_and_remove_checkpoints(epoch, smatch)
            except Exception:
                print('!! Evaluation / save failed !!')
                logger.exception('Evaluation or model save failed')
            print()

    # Run Inference and evaluate the model
    def evaluate(self):
        self.model.eval()
        sents = [g.metadata['snt'] for g in self.graphs_gold]
        graphs_gen = self.inference.parse_sents(sents,
                                                return_penman=True,
                                                disable_progress=False,
                                                pbar_desc='%-14s' %
                                                'Evaluating:')
        assert len(graphs_gen) == len(self.graphs_gold)
        # Detect bad graphs. In Penman 1.2.0, metadata does not impact penman.Graph.__eq__()
        num_bad = sum(g == Inference.invalid_graph for g in graphs_gen)
        print('Out of %d graphs, %d did not generate properly.' %
              (len(graphs_gen), num_bad))
        # Save the final graphs
        print('Generated graphs written to', self.dev_pred_path)
        penman.dump(graphs_gen, self.dev_pred_path, indent=6, model=amr_model)
        # Run smatch
        try:
            gold_entries = get_entries(self.dev_gold_path)
            test_entries = get_entries(self.dev_pred_path)
            precision, recall, f_score = compute_smatch(
                test_entries, gold_entries)
            print('SMATCH -> P: %.3f,  R: %.3f,  F: %.3f' %
                  (precision, recall, f_score))
        except Exception:
            logger.exception('Failed to compute smatch score.')
            precision, recall, f_score = 0, 0, 0
        return f_score

    # Save the checkpoints if this is the best score
    def save_and_remove_checkpoints(self, epoch, smatch):
        prev_checkpoints = [
            fn for fn in os.listdir(self.model_dir) if fn.endswith('.pt')
        ]
        model_fn = 'checkpoint_epoch_%02d_smatch_%04d.pt' % (epoch,
                                                             smatch * 10000)
        model_fpath = os.path.join(self.model_dir, model_fn)
        # Create the dictionary with the optional optimizer and save it
        print('Saving new, best model to', model_fpath)
        save_dict = {'model': self.model.state_dict()}
        if self.config.get('save_optimizer'):
            save_dict['optimizer'] = self.optimizer.state_dict()
            save_dict['scheduler'] = self.scheduler.state_dict()
        torch.save(save_dict, model_fpath)
        # Save the config file
        self.config['smatch_dev'] = smatch
        self.config['last_epoch'] = epoch
        with open(os.path.join(self.model_dir, 'config.json'), 'w') as f:
            json.dump(self.config, f, indent=4)
        # Remove previous checkpoints
        for chkpt_fn in prev_checkpoints:
            os.remove(os.path.join(self.model_dir, chkpt_fn))

    # Load and setup the model, tokenizer, optimizer, etc..
    def load_model(self, checkpoint=None):
        print('Loading model from', self.config['model'])
        self.model, self.tokenizer = instantiate_model_and_tokenizer(
            self.config['model'],
            additional_tokens_smart_init=self.config['smart_init'],
            dropout=self.config['dropout'],
            attention_dropout=self.config['attention_dropout'],
            penman_linearization=self.config['penman_linearization'],
            collapse_name_ops=self.config['collapse_name_ops'],
            use_pointer_tokens=self.config['use_pointer_tokens'],
            raw_graph=self.config['raw_graph'])
        self.model.to(self.device)
        # Load optimization components
        self.optimizer = AdamW(self.model.parameters(),
                               lr=self.config['learning_rate'],
                               weight_decay=self.config['weight_decay'])
        self.scheduler = transformers.get_constant_schedule_with_warmup(
            self.optimizer, num_warmup_steps=self.config['warmup_steps'])
        self.scaler = GradScaler(enabled=self.config['fp16'])
        # Reload checkpoint model weights and optimizer params if loading from a checkpoint
        if checkpoint is not None:
            print('Checkpoint %s restored' % checkpoint)
            load_state_dict_from_checkpoint(checkpoint, self.model,
                                            self.optimizer, self.scheduler)
            # Try to load the smatch score and last_epoch from the config in the model directory.
            try:
                with open(os.path.join(self.model_dir, 'config.json')) as f:
                    model_config = json.load(f)
                self.best_smatch = model_config['smatch_dev']
                self.start_epoch = model_config['last_epoch'] + 1
            except Exception:
                logger.exception(
                    'Unable to load config file in model directory')

    # Setup the training data loader
    def load_training_data(self):
        print('Loading train data from', self.config['train'])
        self.train_loader = get_dataloader(
            self.tokenizer,
            glob_pattern=self.config['train'],
            evaluation=False,
            batch_size=self.config['batch_size'],
            use_recategorization=self.config['use_recategorization'],
            remove_longer_than=self.config['remove_longer_than'],
            remove_wiki=self.config['remove_wiki'],
            dereify=self.config['dereify'],
            device=self.device)

    # Setup the inference object and create the gold data test file
    def load_eval_data(self):
        print('Loading eval data from ', self.config['dev'])
        self.inference = Inference(model=self.model,
                                   tokenizer=self.tokenizer,
                                   device=self.device,
                                   num_beams=self.config['eval_beam_size'],
                                   batch_size=self.config['eval_batch_sents'],
                                   config=self.config)
        self.graphs_gold = read_raw_amr_data(
            self.config['dev'],
            use_recategorization=self.config['use_recategorization'],
            dereify=self.config['dereify'],
            remove_wiki=self.config['remove_wiki'])
        penman.dump(self.graphs_gold,
                    self.dev_gold_path,
                    indent=6,
                    model=amr_model)

    # Function to update the model's parameters for accumulated loss
    def step_otimizer(self):
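        # Unscale so clipping operates on true gradient values, then let the
        # scaler perform the optimizer step and refresh its scale factor.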
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                       self.config['grad_norm'])
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        self.scheduler.step()

    # Update tqdm progress bar description with loss values
    @staticmethod
    def set_pbar_desc_train(pbar, av_loss):
        desc = 'Loss: '
        if av_loss is None:
            desc += ' ' * 8
        else:
            desc += '%8.3f' % av_loss
        pbar.set_description(desc)
Example #17
class MyMixedDefaultTrainer(TrainerBase):
    """Default trainer with optional AMP (mixed-precision) support."""
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)

        self.model = self.build_model(cfg)

        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
            )

        self._data_loader_iter = iter(data_loader)
        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = f"[{self.__class__.__name__}] does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None

    def resume_or_load(self, resume=True):
        """
        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
        a `last_checkpoint` file), resume from the file. Resuming means loading all
        available states (e.g. optimizer and scheduler) and updating the iteration counter
        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
        Otherwise, this is considered an independent training run. The method will load model
        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
        from iteration 0.
        Args:
            resume (bool): whether to do resume or not
        """
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS,
                                                      resume=resume)

        if resume and self.checkpointer.has_checkpoint():
            self.start_epoch = checkpoint.get("epoch", -1) + 1
            # The checkpoint stores the training iteration that just finished, thus we start
            # at the next iteration (or iter zero if there's no checkpoint).

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        logger = logging.getLogger(__name__)
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET
                                    ])  # set dataset name for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
            logger.info("Prepare precise BN dataset")
            ret.append(
                hooks.PreciseBN(
                    # Run at the same freq as (but before) evaluation.
                    self.model,
                    # Build a new data loader to not affect training
                    self.build_train_loader(cfg),
                    cfg.TEST.PRECISE_BN.NUM_ITER,
                ))

        if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0:
            ret.append(
                hooks.LayerFreeze(
                    self.model,
                    cfg.MODEL.FREEZE_LAYERS,
                    cfg.SOLVER.FREEZE_ITERS,
                ))

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation before checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer,
                                           cfg.SOLVER.CHECKPOINT_PERIOD))
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), 200))

        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.
        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        It is now implemented by:
        .. code-block:: python
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training.
        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_epoch, self.max_epoch, self.iters_per_epoch)
        if comm.is_main_process():
            assert hasattr(self, "_last_eval_results"
                           ), "No evaluation results obtained during training!"
            return self._last_eval_results

    def run_step(self):
        assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"
        if self.cfg.SOLVER.AMP.ENABLED:
            assert torch.cuda.is_available(
            ), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
            from torch.cuda.amp.autocast_mode import autocast

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        if self.cfg.SOLVER.AMP.ENABLED:
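            # AMP path: forward under autocast, backward through the
            # GradScaler, then a scaler-managed optimizer step and scale update.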
            with autocast():
                loss_dict = self.model(data)
                losses = sum(loss_dict.values())
            self.optimizer.zero_grad()
            self.grad_scaler.scale(losses).backward()

            self._write_metrics(loss_dict, data_time)

            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss_dict = self.model(data)
            losses = sum(loss_dict.values())
            self.optimizer.zero_grad()
            losses.backward()

            self._write_metrics(loss_dict, data_time)

            self.optimizer.step()

        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()

    def _write_metrics(self, loss_dict: Dict[str, torch.Tensor],
                       data_time: float):
        """
        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration
        """
        device = next(iter(loss_dict.values())).device

        # Use a new stream so these ops don't wait for DDP or backward
        with torch.cuda.stream(torch.cuda.Stream() if device.type ==
                               "cuda" else None):
            metrics_dict = {
                k: v.detach().cpu().item()
                for k, v in loss_dict.items()
            }
            metrics_dict["data_time"] = data_time

            # Gather metrics among all workers for logging
            # This assumes we do DDP-style training, which is currently the only
            # supported method in detectron2.
            all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            storage = get_event_storage()

            # data_time among workers can have high variance. The actual latency
            # caused by data_time is the maximum among workers.
            data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
            storage.put_scalar("data_time", data_time)

            # average the rest metrics
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict])
                for k in all_metrics_dict[0].keys()
            }
            total_losses_reduced = sum(metrics_dict.values())
            if not np.isfinite(total_losses_reduced):
                raise FloatingPointError(
                    f"Loss became infinite or NaN at iteration={self.iter}!\n"
                    f"loss_dict = {metrics_dict}")

            storage.put_scalar("total_loss", total_losses_reduced)
            if len(metrics_dict) > 1:
                storage.put_scalars(**metrics_dict)

    @classmethod
    def build_model(cls, cfg, show_model=True):
        """
        Returns:
            torch.nn.Module:
        It now calls :func:`fastreid.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)

        if show_model:
            logger = logging.getLogger('fastreid')
            logger.info("Model:\n{}".format(model))

        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
        """
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer, iters_per_epoch)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        logger = logging.getLogger(__name__)
        logger.info("Prepare training set")
        return build_reid_train_loader(cfg, combineall=cfg.DATASETS.COMBINEALL)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_test_loader(cfg, dataset_name=dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_dir=None):
        data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
        return data_loader, ReidEvaluator(cfg, num_query, output_dir)

    @classmethod
    def test(cls, cfg, model):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
            logger.info("Prepare testing set")
            try:
                data_loader, evaluator = cls.build_evaluator(cfg, dataset_name)
            except NotImplementedError:
                logger.warning(
                    "No evaluator found. Implement its `build_evaluator` method."
                )
                results[dataset_name] = {}
                continue
            results_i = inference_on_dataset(model,
                                             data_loader,
                                             evaluator,
                                             flip_test=cfg.TEST.FLIP.ENABLED)
            results[dataset_name] = results_i

            if comm.is_main_process():
                assert isinstance(
                    results, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results)
                logger.info("Evaluation results for {} in csv format:".format(
                    dataset_name))
                results_i['dataset'] = dataset_name
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]

        return results

    @staticmethod
    def auto_scale_hyperparams(cfg, num_classes):
        r"""
        This is used to automatically compute the actual number of training iterations,
        because some hyper-parameters, such as MAX_ITER, are specified in epochs rather
        than iterations, so they need to be converted to iterations.
        """
        cfg = cfg.clone()
        frozen = cfg.is_frozen()
        cfg.defrost()

        # If you don't hard-code the number of classes, it will compute the number automatically
        if cfg.MODEL.HEADS.NUM_CLASSES == 0:
            output_dir = cfg.OUTPUT_DIR
            cfg.MODEL.HEADS.NUM_CLASSES = num_classes
            logger = logging.getLogger(__name__)
            logger.info(
                f"Auto-scaling the num_classes={cfg.MODEL.HEADS.NUM_CLASSES}")

            # Update the saved config file to make the number of classes valid
            if comm.is_main_process() and output_dir:
                # Note: some of our scripts may expect the existence of
                # config.yaml in output directory
                path = os.path.join(output_dir, "config.yaml")
                with PathManager.open(path, "w") as f:
                    f.write(cfg.dump())

        if frozen: cfg.freeze()

        return cfg
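
# A minimal, hypothetical sketch (not part of the original example) of the
# epoch-to-iteration conversion that the docstring above alludes to. It only
# reuses cfg.SOLVER.MAX_EPOCH and cfg.SOLVER.IMS_PER_BATCH, which appear in the
# trainers below; the helper name itself is an assumption.
def epochs_to_iters(cfg, num_images):
    iters_per_epoch = num_images // cfg.SOLVER.IMS_PER_BATCH
    return cfg.SOLVER.MAX_EPOCH * iters_per_epoch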
Example No. 18
0
class Trainer:
    def __init__(self, cfg):
        self.cfg = cfg
        self.paths = cfg['paths']
        self.net_params = cfg['net']
        self.train_params = cfg['train']
        self.trans_params = cfg['train']['transforms']

        self.checkpoints = self.paths['checkpoints']
        Path(self.checkpoints).mkdir(parents=True, exist_ok=True)
        shutil.copyfile('config.yaml', f'{self.checkpoints}/config.yaml')

        self.update_interval = self.paths['update_interval']

        # amp training
        self.use_amp = self.train_params['mixed_precision']
        self.scaler = GradScaler() if self.use_amp else None

        # data setup
        dataset_name = self.train_params['dataset']
        self.use_multi = dataset_name == 'multi'
        print(f'Using dataset: {dataset_name}')
        self.train_dataset = get_pedestrian_dataset(
            dataset_name,
            self.paths,
            augment=get_train_transforms(self.trans_params),
            mode='train',
            multi_datasets=self.train_params['multi_datasets']
            if self.use_multi else None)
        print(f'Train dataset: {len(self.train_dataset)} samples')

        self.val_dataset = get_pedestrian_dataset(
            dataset_name,
            self.paths,
            augment=get_val_transforms(self.trans_params),
            mode='val',
            multi_datasets=self.train_params['multi_datasets']
            if self.use_multi else None)
        print(f'Val dataset: {len(self.val_dataset)} samples')

        tests_data = self.train_params['test_datasets']
        self.test_datasets = [
            get_pedestrian_dataset(d_name,
                                   self.paths,
                                   augment=get_test_transforms(
                                       self.trans_params),
                                   mode='test') for d_name in tests_data
        ]

        self.criterion = AnchorFreeLoss(self.train_params)

        self.writer = Writer(self.paths['log_dir'])
        print('Tensorboard logs are saved to: {}'.format(
            self.paths['log_dir']))

        self.sched_type = self.train_params['scheduler']
        self.scheduler = None
        self.optimizer = None

    def save_checkpoints(self, epoch, net):
        path = osp.join(self.checkpoints, f'Epoch_{epoch}.pth')
        torch.save(
            {
                'epoch': epoch,
                'net_state_dict': net.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            }, path)

    def train(self):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
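        # Note: cudnn.benchmark auto-tunes convolution algorithms for speed, which can
        # vary between runs; strict reproducibility usually calls for benchmark=False,
        # so these two flags pull in different directions.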

        batch_size = self.train_params['batch_size']
        self.batch_size = batch_size
        num_workers = self.train_params['num_workers']
        pin_memory = self.train_params['pin_memory']
        print('Batch-size = {}'.format(batch_size))

        train_loader = DataLoader(self.train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory,
                                  drop_last=True)
        val_loader = DataLoader(self.val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                drop_last=False)

        # net setup
        print('Preparing net: ')
        net = get_fpn_net(self.net_params)
        # train setup
        lr = self.train_params['lr']
        epochs = self.train_params['epochs']
        weight_decay = self.train_params['weight_decay']

        self.optimizer = optim.Adam(net.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    eps=1e-4)
        if self.net_params['pretrained']:
            checkpoint = torch.load(self.net_params['pretrained_model'],
                                    map_location="cuda")
            net.load_state_dict(checkpoint['net_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()
            print('CHECKPOINT LOADED')
        net.cuda()

        first_epoch = 0
        # scheduler
        if self.sched_type == 'ocp':
            last_epoch = -1 if first_epoch == 0 else first_epoch * len(
                train_loader)
            self.scheduler = OneCycleLR(
                self.optimizer,
                max_lr=lr,
                epochs=epochs,
                last_epoch=last_epoch,
                steps_per_epoch=len(train_loader),
                pct_start=self.train_params['ocp_params']['max_lr_pct'])
        elif self.sched_type == 'multi_step':
            last_epoch = -1 if first_epoch == 0 else first_epoch
            self.scheduler = MultiStepLR(
                self.optimizer,
                milestones=self.train_params['multi_params']['milestones'],
                gamma=self.train_params['multi_params']['gamma'],
                last_epoch=last_epoch)

        #start training

        net.train()
        val_rate = self.train_params['val_rate']
        test_rate = self.train_params['test_rate']
        for epoch in range(first_epoch, epochs):
            self.train_epoch(net, train_loader, epoch)

            if self.sched_type != 'ocp':
                self.writer.log_lr(epoch, self.scheduler.get_last_lr()[0])
                self.scheduler.step()

            if (epoch + 1) % val_rate == 0 or epoch == epochs - 1:
                self.eval(net, val_loader, epoch * len(train_loader))
            if (epoch + 1) % (val_rate *
                              test_rate) == 0 or epoch == epochs - 1:
                self.test_ap(net, epoch)
                self.save_checkpoints(epoch, net)

    def train_epoch(self, net, loader, epoch):
        net.train()
        loss_metric = LossMetric(self.cfg)
        probs = ProbsAverageMeter()

        for mini_batch_i, read_mini_batch in tqdm(enumerate(loader),
                                                  desc=f'Epoch {epoch}:',
                                                  ascii=True,
                                                  total=len(loader)):
            data, labels = read_mini_batch
            data = data.cuda()
            labels = [label.cuda() for label in labels]

            # Run the forward pass under autocast only when mixed precision is enabled;
            # fall back to a plain backward/step when self.scaler is None.
            with amp.autocast(enabled=self.use_amp):
                out = net(data)
                loss_dict, hm_probs = self.criterion(out, labels)
                loss = loss_metric.calculate_loss(loss_dict)
            self.optimizer.zero_grad()
            if self.use_amp:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()
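            # (When AMP is active, scaler.step() skips the optimizer update if inf/NaN
            # gradients are detected, and scaler.update() then adjusts the loss scale.)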

            probs.update(hm_probs)

            if self.sched_type == 'ocp':
                self.scheduler.step()

            loss_metric.add_sample(loss_dict)

            if mini_batch_i % self.update_interval == 0:
                if self.sched_type == 'ocp':
                    # TODO write average lr
                    self.writer.log_lr(epoch * len(loader) + mini_batch_i,
                                       self.scheduler.get_last_lr()[0])
                self.writer.log_training(epoch * len(loader) + mini_batch_i,
                                         loss_metric)
        self.writer.log_probs(epoch, probs.get_average())

    def eval(self, net, loader, step):
        net.eval()
        loss_metric = LossMetric(self.cfg)
        with torch.no_grad():
            for _, read_mini_batch in tqdm(enumerate(loader),
                                           desc='Val:',
                                           ascii=True,
                                           total=len(loader)):
                data, labels = read_mini_batch
                data = data.cuda()
                labels = [label.cuda() for label in labels]
                with amp.autocast(enabled=self.use_amp):
                    out = net(data)
                    loss_dict, _ = self.criterion(out, labels)

                loss_metric.add_sample(loss_dict)

            self.writer.log_eval(step, loss_metric)

    def test_ap(self, net, epoch):
        for dataset in self.test_datasets:
            ap, _ = test(net, dataset, batch_size=self.batch_size)
            self.writer.log_ap(epoch, ap, dataset.name())
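
# Hypothetical usage sketch (not from the original example): the Trainer above is
# driven by the same config.yaml that __init__ copies into the checkpoint directory.
if __name__ == '__main__':
    import yaml

    with open('config.yaml') as f:
        cfg = yaml.safe_load(f)
    Trainer(cfg).train()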
Example No. 19
0
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)

        self.model = self.build_model(cfg)

        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set `find_unused_parameters=True`
            # when some of the parameters are not updated.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
            )

        self._data_loader_iter = iter(data_loader)
        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )
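        # `build_lr_scheduler` is assumed to return a dict of named schedulers here,
        # which is why it can be unpacked into the Checkpointer as extra checkpointables.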

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = f"[{self.__class__.__name__}] does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None
Example No. 20
0
class UDA_Baseline_Trainer(TrainerBase):
    """
    Load a model pretrained on the source domain and
    ignore outliers during training on the target domain.
    """
    def __init__(self, cfg):
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # if setup_logger is not called for fastreid
            setup_logger()

        logger.info("==> Load target-domain dataset")
        self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT)
        self.tgt_nums = len(tgt.train)

        cfg = self.auto_scale_hyperparams(cfg, self.tgt_nums)

        # Create model
        self.model = self.build_model(cfg,
                                      load_model=cfg.MODEL.PRETRAIN,
                                      show_model=True,
                                      use_dsbn=False)

        # Optimizer
        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set `find_unused_parameters=True`
            # when some of the parameters are not updated.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True)

        # Learning rate scheduler
        self.iters_per_epoch = cfg.SOLVER.ITERS
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = "AMPTrainer does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None

    def train(self):
        """
        Run training.
        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_epoch, self.max_epoch, self.iters_per_epoch)
        if comm.is_main_process():
            assert hasattr(self, "_last_eval_results"
                           ), "No evaluation results obtained during training!"
            return self._last_eval_results

    def before_train(self):
        self.model.train()
        if self.cfg.SOLVER.AMP.ENABLED:
            assert torch.cuda.is_available(
            ), "CUDA is required for AMP training!"
        return super().before_train()

    def before_epoch(self):
        logger = logging.getLogger('fastreid')

        # Calculate distance
        logger.info("==> Create pseudo labels for unlabeled target domain")
        with inference_context(self.model), torch.no_grad():
            tgt_train = self.build_dataset(self.cfg,
                                           self.tgt.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            tgt_init_feat_loader = self.build_test_loader(self.cfg, tgt_train)
            tgt_fname_feat_dict, _ = extract_features(self.model,
                                                      tgt_init_feat_loader)
            tgt_features = torch.cat([
                tgt_fname_feat_dict[f].unsqueeze(0)
                for f, _, _ in sorted(self.tgt.train)
            ], 0)
            tgt_features = F.normalize(tgt_features, dim=1)

        rerank_dist = compute_jaccard_distance(tgt_features,
                                               k1=self.cfg.CLUSTER.JACCARD.K1,
                                               k2=self.cfg.CLUSTER.JACCARD.K2)

        if self.epoch == 0:
            if self.cfg.CLUSTER.DBSCAN.ADAPTIVE_EPS:
                logger.info("==> Calculating eps according to rerank_dist...")
                tri_mat = np.triu(rerank_dist, 1)  # tri_mat.dim=2
                tri_mat = tri_mat[np.nonzero(tri_mat)]  # tri_mat.dim=1
                tri_mat = np.sort(tri_mat, axis=None)
                top_num = np.round(self.cfg.SOLVER.RHO *
                                   tri_mat.size).astype(int)
                self.eps = tri_mat[:top_num].mean()
                logger.info(f"==> epoch {self.epoch} eps: {self.eps}")
            else:
                self.eps = self.cfg.CLUSTER.DBSCAN.EPS
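        # With ADAPTIVE_EPS the DBSCAN radius is computed once (at epoch 0) as the mean
        # of the smallest RHO-fraction of pairwise re-ranked distances; otherwise the
        # fixed EPS from the config is used.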

        self.cluster = DBSCAN(eps=self.eps,
                              min_samples=4,
                              metric="precomputed",
                              n_jobs=-1)

        # select & cluster images as training set of this epochs
        logger.info(f"Clustering and labeling...")
        pseudo_labels = self.cluster.fit_predict(rerank_dist)
        self.num_clusters = num_clusters = len(
            set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0)
        num_outliers = pseudo_labels[pseudo_labels == -1].shape[0]
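        # DBSCAN labels noise points as -1; they are counted as un-clustered outliers
        # and excluded from the pseudo-labeled training set built below.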

        # pseudo_labels = self.generate_pseudo_labels(pseudo_labels, num_clusters)
        # pseudo_labels = self.assign_outlier(pseudo_labels, tgt_features)

        del tgt_features

        pseudo_labeled_dataset = []
        cluster_centers = collections.defaultdict(list)
        for i, ((fname, _, cid),
                label) in enumerate(zip(sorted(self.tgt.train),
                                        pseudo_labels)):
            if label != -1:
                pseudo_labeled_dataset.append((fname, label, cid))
                cluster_centers[label].append(tgt_fname_feat_dict[fname])

        del tgt_fname_feat_dict, rerank_dist

        cluster_centers = [
            torch.stack(cluster_centers[idx]).mean(0)
            for idx in sorted(cluster_centers.keys())
        ]
        cluster_centers = torch.stack(cluster_centers)

        if isinstance(self.model, DistributedDataParallel):
            self.model.module.heads.weight.data[:num_clusters].copy_(
                F.normalize(cluster_centers, dim=1).float().cuda())
        else:
            self.model.heads.weight.data[:num_clusters].copy_(
                F.normalize(cluster_centers, dim=1).float().cuda())
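        # The first num_clusters classifier prototypes are thus re-initialised from the
        # normalised cluster centroids at the start of every epoch.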

        # statistics of clusters and un-clustered instances
        # index2label = collections.defaultdict(int)
        # for label in pseudo_labels:
        #     index2label[label.item()] += 1

        # print(f'cluster_label', min(cluster_label), max(cluster_label), len(cluster_label))
        # print(f'outlier label', min(outlier_label), max(outlier_label), len(outlier_label))

        # index2label = np.fromiter(index2label.values(), dtype=float)
        logger.info(
            "==> Statistics for epoch {}: {} clusters, {} un-clustered instances"
            .format(self.epoch, num_clusters, num_outliers))

        pseudo_tgt_train = self.build_dataset(
            self.cfg,
            pseudo_labeled_dataset,
            is_train=True,
            relabel=True,  # relabel?
            # relabel=False,
            with_mem_idx=False)
        self.pseudo_tgt_train_loader = self.build_train_loader(
            self.cfg,
            train_set=pseudo_tgt_train,
            sampler=RandomMultipleGallerySampler(
                pseudo_tgt_train.img_items,
                self.cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(),
                self.cfg.DATALOADER.NUM_INSTANCE),
            with_mem_idx=False)

        return super().before_epoch()

    # Assign each outlier to its nearest clustered neighbor
    def assign_outlier(self, pseudo_labels, tgt_features):
        outlier_mask = pseudo_labels == -1
        cluster_mask = pseudo_labels != -1
        outlier_feat = tgt_features[outlier_mask]
        cluster_feat = tgt_features[cluster_mask]
        dist = torch.cdist(outlier_feat, cluster_feat)  # pairwise L2 distances between features

        min_dist_idx = dist.argmin(-1).cpu()
        pseudo_labels[outlier_mask] = pseudo_labels[cluster_mask][min_dist_idx]

        return pseudo_labels

    def run_step(self):
        assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"
        if self.cfg.SOLVER.AMP.ENABLED:
            assert torch.cuda.is_available(
            ), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
            from torch.cuda.amp.autocast_mode import autocast

        start = time.perf_counter()

        # load data
        tgt_inputs = self.pseudo_tgt_train_loader.next()

        def _parse_data(inputs):
            imgs, _, pids, _ = inputs
            return imgs.cuda(), pids.cuda()

        # process inputs
        t_inputs, t_targets = _parse_data(tgt_inputs)

        data_time = time.perf_counter() - start

        def _forward():
            outputs = self.model(t_inputs)
            f_out_t = outputs['features']
            p_out_t = outputs['pred_class_logits'][:, :self.num_clusters]

            loss_dict = {}

            loss_ce = cross_entropy_loss(pred_class_outputs=p_out_t,
                                         gt_classes=t_targets,
                                         eps=self.cfg.MODEL.LOSSES.CE.EPSILON,
                                         alpha=self.cfg.MODEL.LOSSES.CE.ALPHA)
            loss_dict.update({'loss_ce': loss_ce})

            if 'TripletLoss' in self.cfg.MODEL.LOSSES.NAME:
                loss_tri = triplet_loss(f_out_t,
                                        t_targets,
                                        margin=0.0,
                                        norm_feat=True,
                                        hard_mining=False)
                loss_dict.update({'loss_tri': loss_tri})

            return loss_dict

        if self.cfg.SOLVER.AMP.ENABLED:
            with autocast():
                loss_dict = _forward()
                losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            self.grad_scaler.scale(losses).backward()

            self._write_metrics(loss_dict, data_time)

            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss_dict = _forward()
            losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            losses.backward()

            self._write_metrics(loss_dict, data_time)

            self.optimizer.step()

        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()

    @classmethod
    def load_dataset(cls, name):
        logger = logging.getLogger(__name__)
        logger.info(f"Preparing {name}")

        _root = os.getenv("FASTREID_DATASETS", "/root/datasets")
        data = DATASET_REGISTRY.get(name)(root=_root)
        if comm.is_main_process():
            data.show_train()

        return data

    @classmethod
    def build_dataset(cls,
                      cfg,
                      img_items,
                      is_train=False,
                      relabel=False,
                      transforms=None,
                      with_mem_idx=False):
        if transforms is None:
            transforms = build_transforms(cfg, is_train=is_train)

        if with_mem_idx:
            sorted_img_items = sorted(img_items)
            for i in range(len(sorted_img_items)):
                sorted_img_items[i] += (i, )
            return InMemoryDataset(sorted_img_items, transforms, relabel)
        else:
            return CommDataset(img_items, transforms, relabel)

    @classmethod
    def build_train_loader(cls,
                           cfg,
                           train_set=None,
                           sampler=None,
                           with_mem_idx=False):
        logger = logging.getLogger('fastreid')
        logger.info("Prepare training loader")

        total_batch_size = cfg.SOLVER.IMS_PER_BATCH
        mini_batch_size = total_batch_size // comm.get_world_size()

        if sampler is None:
            num_instance = cfg.DATALOADER.NUM_INSTANCE
            sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
            logger.info("Using training sampler {}".format(sampler_name))

            if sampler_name == "TrainingSampler":
                sampler = samplers.TrainingSampler(len(train_set))
            elif sampler_name == "NaiveIdentitySampler":
                sampler = samplers.NaiveIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance)
            elif sampler_name == "BalancedIdentitySampler":
                sampler = samplers.BalancedIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance)
            elif sampler_name == "SetReWeightSampler":
                set_weight = cfg.DATALOADER.SET_WEIGHT
                sampler = samplers.SetReWeightSampler(train_set.img_items,
                                                      mini_batch_size,
                                                      num_instance, set_weight)
            elif sampler_name == "ImbalancedDatasetSampler":
                sampler = samplers.ImbalancedDatasetSampler(
                    train_set.img_items)
            else:
                raise ValueError(
                    "Unknown training sampler: {}".format(sampler_name))

        iters = cfg.SOLVER.ITERS
        num_workers = cfg.DATALOADER.NUM_WORKERS
        batch_sampler = BatchSampler(sampler, mini_batch_size, True)

        train_loader = IterLoader(
            DataLoader(
                Preprocessor(train_set, with_mem_idx),
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                pin_memory=True,
            ),
            length=iters,
        )
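        # IterLoader is assumed to cycle the wrapped DataLoader so that each epoch
        # yields exactly `iters` (= cfg.SOLVER.ITERS) mini-batches.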
        # train_loader = DataLoaderX(
        #     comm.get_local_rank(),
        #     dataset=Preprocessor(train_set, with_mem_idx),
        #     num_workers=num_workers,
        #     batch_sampler=batch_sampler,
        #     collate_fn=fast_batch_collator,
        #     pin_memory=True,
        # )

        return train_loader

    @classmethod
    def build_test_loader(cls, cfg, test_set):
        logger = logging.getLogger('fastreid')
        logger.info("Prepare testing loader")

        # test_loader = DataLoader(
        #     # Preprocessor(test_set),
        #     test_set,
        #     batch_size=cfg.TEST.IMS_PER_BATCH,
        #     num_workers=cfg.DATALOADER.NUM_WORKERS,
        #     shuffle=False,
        #     pin_memory=True,
        # )

        test_batch_size = cfg.TEST.IMS_PER_BATCH
        mini_batch_size = test_batch_size // comm.get_world_size()
        num_workers = cfg.DATALOADER.NUM_WORKERS
        data_sampler = samplers.InferenceSampler(len(test_set))
        batch_sampler = BatchSampler(data_sampler, mini_batch_size, False)
        test_loader = DataLoaderX(
            comm.get_local_rank(),
            dataset=test_set,
            batch_sampler=batch_sampler,
            num_workers=num_workers,  # save some memory
            collate_fn=fast_batch_collator,
            pin_memory=True,
        )

        return test_loader

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_dir=None):
        data_loader, num_query = build_reid_test_loader(
            cfg, dataset_name=dataset_name)
        return data_loader, ReidEvaluator(cfg, num_query, output_dir)

    @classmethod
    def test(cls, cfg, model):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger('fastreid')

        results = OrderedDict()
        dataset_name = cfg.DATASETS.TGT

        logger.info("Prepare testing set")
        try:
            data_loader, evaluator = cls.build_evaluator(cfg, dataset_name)
        except NotImplementedError:
            logger.warning(
                "No evaluator found. Implement its `build_evaluator` method.")
            results[dataset_name] = {}
            return results

        results_i = inference_on_dataset(model,
                                         data_loader,
                                         evaluator,
                                         flip_test=cfg.TEST.FLIP.ENABLED)
        results[dataset_name] = results_i

        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results)
            logger.info("Evaluation results for {} in csv format:".format(
                dataset_name))
            results_i['dataset'] = dataset_name
            print_csv_format(results_i)

        # if len(results) == 1:
        #     results = list(results.values())[0]

        return results

    @classmethod
    def build_model(cls,
                    cfg,
                    load_model=True,
                    show_model=True,
                    use_dsbn=False):
        cfg = cfg.clone()  # cfg can be modified by model
        cfg.defrost()
        cfg.MODEL.DEVICE = "cpu"

        model = build_model(cfg)
        logger = logging.getLogger('fastreid')

        if load_model:
            pretrain_path = cfg.MODEL.PRETRAIN_PATH
            try:
                state_dict = torch.load(
                    pretrain_path, map_location=torch.device("cpu"))['model']
                for layer in cfg.MODEL.IGNORE_LAYERS:
                    if layer in state_dict.keys():
                        del state_dict[layer]
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(
                    f"{pretrain_path} is not found! Please check this path.")
                raise e
            except KeyError as e:
                logger.info(
                    "State dict keys error! Please check the state dict.")
                raise e

            incompatible = model.load_state_dict(state_dict, strict=False)
            if incompatible.missing_keys:
                logger.info(
                    get_missing_parameters_message(incompatible.missing_keys))
            if incompatible.unexpected_keys:
                logger.info(
                    get_unexpected_parameters_message(
                        incompatible.unexpected_keys))

        if use_dsbn:
            logger.info("==> Convert BN to Domain Specific BN")
            convert_dsbn(model)

        if show_model:
            logger.info("Model:\n{}".format(model))

        model.to(torch.device("cuda"))
        return model

    @staticmethod
    def auto_scale_hyperparams(cfg, num_classes):
        r"""
        This is used to automatically compute the actual number of training iterations,
        because some hyper-parameters, such as MAX_ITER, are specified in epochs rather
        than iterations, so they need to be converted to iterations.
        """
        cfg = cfg.clone()
        frozen = cfg.is_frozen()
        cfg.defrost()

        # If you don't hard-code the number of classes, it will compute the number automatically
        if cfg.MODEL.HEADS.NUM_CLASSES == 0:
            output_dir = cfg.OUTPUT_DIR
            cfg.MODEL.HEADS.NUM_CLASSES = num_classes
            logger = logging.getLogger('fastreid')
            logger.info(
                f"Auto-scaling the num_classes={cfg.MODEL.HEADS.NUM_CLASSES}")

            # Update the saved config file to make the number of classes valid
            if comm.is_main_process() and output_dir:
                # Note: some of our scripts may expect the existence of
                # config.yaml in output directory
                path = os.path.join(output_dir, "config.yaml")
                with PathManager.open(path, "w") as f:
                    f.write(cfg.dump())

        if frozen: cfg.freeze()

        return cfg

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
        """
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer, iters_per_epoch)

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        logger = logging.getLogger(__name__)
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        # set dataset name for PreciseBN
        cfg.DATASETS.NAMES = (cfg.TEST.PRECISE_BN.DATASET, )

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        # if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
        #     logger.info("Prepare precise BN dataset")
        #     ret.append(hooks.PreciseBN(
        #         # Run at the same freq as (but before) evaluation.
        #         self.model,
        #         # Build a new data loader to not affect training
        #         self.build_train_loader(cfg),
        #         cfg.TEST.PRECISE_BN.NUM_ITER,
        #     ))

        if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0:
            ret.append(
                hooks.LayerFreeze(
                    self.model,
                    cfg.MODEL.FREEZE_LAYERS,
                    cfg.SOLVER.FREEZE_ITERS,
                ))

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation before checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer,
                                           cfg.SOLVER.CHECKPOINT_PERIOD))
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), 200))

        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.
        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        It is now implemented by:
        .. code-block:: python
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        # TODO: customize my writers
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def _write_metrics(self, loss_dict: Dict[str, torch.Tensor],
                       data_time: float):
        """
        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration
        """
        device = next(iter(loss_dict.values())).device

        # Use a new stream so these ops don't wait for DDP or backward
        with torch.cuda.stream(torch.cuda.Stream() if device.type ==
                               "cuda" else None):
            metrics_dict = {
                k: v.detach().cpu().item()
                for k, v in loss_dict.items()
            }
            metrics_dict["data_time"] = data_time

            # Gather metrics among all workers for logging
            # This assumes we do DDP-style training, which is currently the only
            # supported method in detectron2.
            all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            storage = get_event_storage()
            # data_time among workers can have high variance. The actual latency
            # caused by data_time is the maximum among workers.
            data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
            storage.put_scalar("data_time", data_time)

            # average the rest metrics
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict])
                for k in all_metrics_dict[0].keys()
            }
            total_losses_reduced = sum(metrics_dict.values())
            if not np.isfinite(total_losses_reduced):
                raise FloatingPointError(
                    f"Loss became infinite or NaN at iteration={self.iter}!\n"
                    f"loss_dict = {metrics_dict}")

            storage.put_scalar("total_loss", total_losses_reduced)
            if len(metrics_dict) > 1:
                storage.put_scalars(**metrics_dict)
Example No. 21
0
class SpCL_UDA_Trainer(TrainerBase):
    """
    Load an un-pretrained model and train on the source and target domains from scratch.
    """
    def __init__(self, cfg):
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # if setup_logger is not called for fastreid
            setup_logger()

        # Create datasets
        logger.info("==> Load source-domain dataset")
        self.src = src = self.load_dataset(cfg.DATASETS.SRC)
        self.src_pid_nums = src.get_num_pids(src.train)

        logger.info("==> Load target-domain dataset")
        self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT)
        self.tgt_nums = len(tgt.train)

        # Create model
        self.model = self.build_model(cfg,
                                      load_model=False,
                                      show_model=False,
                                      use_dsbn=True)

        # Create hybrid memorys
        self.hm = HybridMemory(num_features=cfg.MODEL.BACKBONE.FEAT_DIM,
                               num_samples=self.src_pid_nums + self.tgt_nums,
                               temp=cfg.MEMORY.TEMP,
                               momentum=cfg.MEMORY.MOMENTUM,
                               use_half=cfg.SOLVER.AMP.ENABLED).cuda()

        # Initialize source-domain class centroids
        logger.info(
            "==> Initialize source-domain class centroids in the hybrid memory"
        )
        with inference_context(self.model), torch.no_grad():
            src_train = self.build_dataset(cfg,
                                           src.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            src_init_feat_loader = self.build_test_loader(cfg, src_train)
            src_fname_feat_dict, _ = extract_features(self.model,
                                                      src_init_feat_loader)
            src_feat_dict = collections.defaultdict(list)
            for f, pid, _ in sorted(src.train):
                src_feat_dict[pid].append(src_fname_feat_dict[f].unsqueeze(0))
            src_centers = [
                torch.cat(src_feat_dict[pid], 0).mean(0)
                for pid in sorted(src_feat_dict.keys())
            ]
            src_centers = torch.stack(src_centers, 0)
            src_centers = F.normalize(src_centers, dim=1)

        # Initialize target-domain instance features
        logger.info(
            "==> Initialize target-domain instance features in the hybrid memory"
        )
        with inference_context(self.model), torch.no_grad():
            tgt_train = self.build_dataset(cfg,
                                           tgt.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            tgt_init_feat_loader = self.build_test_loader(cfg, tgt_train)
            tgt_fname_feat_dict, _ = extract_features(self.model,
                                                      tgt_init_feat_loader)
            tgt_features = torch.cat([
                tgt_fname_feat_dict[f].unsqueeze(0)
                for f, _, _ in sorted(self.tgt.train)
            ], 0)
            tgt_features = F.normalize(tgt_features, dim=1)

        self.hm.features = torch.cat((src_centers, tgt_features), dim=0).cuda()
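        # Memory layout: the first src_pid_nums rows hold source-domain class centroids,
        # the remaining tgt_nums rows hold per-instance target features
        # (sliced as hm.features[src_pid_nums:] in before_epoch).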

        del (src_train, src_init_feat_loader, src_fname_feat_dict,
             src_feat_dict, src_centers, tgt_train, tgt_init_feat_loader,
             tgt_fname_feat_dict, tgt_features)

        # Optimizer
        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set `find_unused_parameters=True`
            # when some of the parameters are not updated.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True)

        # Learning rate scheduler
        self.iters_per_epoch = cfg.SOLVER.ITERS
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = "AMPTrainer does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None

    def train(self):
        """
        Run training.
        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_epoch, self.max_epoch, self.iters_per_epoch)
        if comm.is_main_process():
            assert hasattr(self, "_last_eval_results"
                           ), "No evaluation results obtained during training!"
            return self._last_eval_results

    def before_train(self):
        self.model.train()
        if self.cfg.SOLVER.AMP.ENABLED:
            assert torch.cuda.is_available(
            ), "CUDA is required for AMP training!"
        return super().before_train()

    def before_epoch(self):
        logger = logging.getLogger('fastreid')

        # Calculate distance
        logger.info(
            "==> Create pseudo labels for unlabeled target domain with self-paced policy"
        )
        tgt_features = self.hm.features[self.src_pid_nums:].clone()

        rerank_dist = compute_jaccard_distance(tgt_features,
                                               k1=self.cfg.CLUSTER.JACCARD.K1,
                                               k2=self.cfg.CLUSTER.JACCARD.K2)
        del tgt_features

        if self.epoch == 0:
            if self.cfg.CLUSTER.DBSCAN.ADAPTIVE_EPS:
                logger.info("==> Calculating eps according to rerank_dist...")
                tri_mat = np.triu(rerank_dist, 1)  # tri_mat.dim=2
                tri_mat = tri_mat[np.nonzero(tri_mat)]  # tri_mat.dim=1
                tri_mat = np.sort(tri_mat, axis=None)
                top_num = np.round(self.cfg.SOLVER.RHO *
                                   tri_mat.size).astype(int)
                self.eps = tri_mat[:top_num].mean()
                logger.info(f"==> epoch {self.epoch} eps: {self.eps}")
            else:
                self.eps = self.cfg.CLUSTER.DBSCAN.EPS

        self.eps_tight = self.eps - self.cfg.CLUSTER.DBSCAN.EPS_GAP
        self.eps_loose = self.eps + self.cfg.CLUSTER.DBSCAN.EPS_GAP

        self.cluster = DBSCAN(eps=self.eps,
                              min_samples=4,
                              metric="precomputed",
                              n_jobs=-1)
        self.cluster_tight = DBSCAN(eps=self.eps_tight,
                                    min_samples=4,
                                    metric="precomputed",
                                    n_jobs=-1)
        self.cluster_loose = DBSCAN(eps=self.eps_loose,
                                    min_samples=4,
                                    metric="precomputed",
                                    n_jobs=-1)

        # select & cluster images as training set of this epochs
        pseudo_labels = self.cluster.fit_predict(rerank_dist)
        pseudo_labels_tight = self.cluster_tight.fit_predict(rerank_dist)
        pseudo_labels_loose = self.cluster_loose.fit_predict(rerank_dist)
        num_ids = len(set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0)
        num_ids_tight = len(
            set(pseudo_labels_tight)) - (1 if -1 in pseudo_labels_tight else 0)
        num_ids_loose = len(
            set(pseudo_labels_loose)) - (1 if -1 in pseudo_labels_loose else 0)

        pseudo_labels = self.generate_pseudo_labels(pseudo_labels, num_ids)
        pseudo_labels_tight = self.generate_pseudo_labels(
            pseudo_labels_tight, num_ids_tight)
        pseudo_labels_loose = self.generate_pseudo_labels(
            pseudo_labels_loose, num_ids_loose)

        # print(pseudo_labels.min(), pseudo_labels.max())
        # exit()

        # compute R_indep and R_comp
        N = pseudo_labels.size(0)
        label_sim = (pseudo_labels.expand(N, N).eq(
            pseudo_labels.expand(N, N).t()).float())  # [N, N]
        label_sim_tight = (pseudo_labels_tight.expand(N, N).eq(
            pseudo_labels_tight.expand(N, N).t()).float())
        label_sim_loose = (pseudo_labels_loose.expand(N, N).eq(
            pseudo_labels_loose.expand(N, N).t()).float())

        R_comp = 1 - torch.min(label_sim, label_sim_tight).sum(-1) / torch.max(
            label_sim, label_sim_tight).sum(-1)  # [N]
        R_indep = 1 - torch.min(label_sim, label_sim_loose).sum(
            -1) / torch.max(label_sim, label_sim_loose).sum(-1)  # [N]
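        # R_comp and R_indep are 1 minus the IoU between each instance's cluster membership
        # under the default eps and under the tighter / looser eps respectively,
        # so both scores lie in [0, 1].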
        assert (R_comp.min() >= 0) and (R_comp.max() <= 1)
        assert (R_indep.min() >= 0) and (R_indep.max() <= 1)

        cluster_R_comp, cluster_R_indep = (
            collections.defaultdict(list),
            collections.defaultdict(list),
        )
        cluster_img_num = collections.defaultdict(int)
        for i, (comp, indep,
                label) in enumerate(zip(R_comp, R_indep, pseudo_labels)):
            cluster_R_comp[label.item() - self.src_pid_nums].append(
                comp.item())
            cluster_R_indep[label.item() - self.src_pid_nums].append(
                indep.item())
            cluster_img_num[label.item() - self.src_pid_nums] += 1

        cluster_R_comp = [
            min(cluster_R_comp[i]) for i in sorted(cluster_R_comp.keys())
        ]
        cluster_R_indep = [
            min(cluster_R_indep[i]) for i in sorted(cluster_R_indep.keys())
        ]
        cluster_R_indep_noins = [
            iou for iou, num in zip(cluster_R_indep,
                                    sorted(cluster_img_num.keys()))
            if cluster_img_num[num] > 1
        ]

        if self.epoch <= self.start_epoch:
            """
            The constant threshold α for identifying independent clusters is defined by the
            top-90% R_indep before the first epoch and remains fixed for the whole training process.
            """
            logger.info("==> calculate independ before first epoch")
            self.indep_thres = np.sort(cluster_R_indep_noins)[min(
                len(cluster_R_indep_noins) - 1,
                np.round(len(cluster_R_indep_noins) * 0.9).astype("int"),
            )]

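        # Keep a sample's cluster label only if its cluster is independent enough
        # (indep_score <= indep_thres) and its R_comp does not exceed the smallest
        # R_comp recorded for that cluster; otherwise demote it to a fresh
        # single-instance (outlier) label.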
        pseudo_labeled_dataset = []
        outliers = 0
        for i, ((fname, _, cid),
                label) in enumerate(zip(sorted(self.tgt.train),
                                        pseudo_labels)):
            indep_score = cluster_R_indep[label.item() - self.src_pid_nums]
            comp_score = R_comp[i]
            if (indep_score <= self.indep_thres) and (
                    comp_score.item() <=
                    cluster_R_comp[label.item() - self.src_pid_nums]):
                pseudo_labeled_dataset.append((fname, label.item(), cid))
            else:
                pseudo_label = self.src_pid_nums + len(
                    cluster_R_indep) + outliers
                pseudo_labeled_dataset.append((fname, pseudo_label, cid))
                pseudo_labels[i] = pseudo_label
                outliers += 1

        # statistics of clusters and un-clustered instances
        index2label = collections.defaultdict(int)
        for label in pseudo_labels:
            index2label[label.item()] += 1

        cluster_label = []
        outlier_label = []
        for k, v in index2label.items():
            if v == 1:
                outlier_label.append(k)
            else:
                cluster_label.append(k)
        print('cluster labels:', min(cluster_label), max(cluster_label),
              len(cluster_label))
        print('outlier labels:', min(outlier_label), max(outlier_label),
              len(outlier_label))

        index2label = np.fromiter(index2label.values(), dtype=float)
        logger.info(
            "==> Statistics for epoch {}: {} clusters, {} un-clustered instances, R_indep threshold is {}"
            .format(
                self.epoch,
                (index2label > 1).sum(),
                (index2label == 1).sum(),
                1 - self.indep_thres,
            ))

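        # Hybrid memory label layout: source identities occupy [0, src_pid_nums),
        # target samples carry the pseudo labels computed above (already offset
        # by src_pid_nums).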
        self.hm.labels = torch.cat(
            (torch.arange(self.src_pid_nums), pseudo_labels)).cuda()

        src_train = self.build_dataset(
            self.cfg,
            self.src.train,
            is_train=True,
            relabel=True,  # relabel?
            # relabel=False,
            with_mem_idx=True)
        self.src_train_loader = self.build_train_loader(
            self.cfg,
            train_set=src_train,
            sampler=RandomMultipleGallerySampler(
                src_train.img_items,
                self.cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(),
                self.cfg.DATALOADER.NUM_INSTANCE,
                with_mem_idx=True),
            with_mem_idx=True)
        # self.src_load_iter = iter(self.src_train_loader)

        pseudo_tgt_train = self.build_dataset(
            self.cfg,
            pseudo_labeled_dataset,
            is_train=True,
            relabel=True,  # relabel?
            # relabel=False,
            with_mem_idx=True)
        self.pseudo_tgt_train_loader = self.build_train_loader(
            self.cfg,
            train_set=pseudo_tgt_train,
            sampler=RandomMultipleGallerySampler(
                pseudo_tgt_train.img_items,
                self.cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(),
                self.cfg.DATALOADER.NUM_INSTANCE,
                with_mem_idx=True),
            with_mem_idx=True)
        # self.tgt_load_iter = iter(self.pseudo_tgt_train_loader)

        return super().before_epoch()

    # map DBSCAN cluster assignments to global pseudo labels
    # (cluster ids are offset by src_pid_nums; outliers get fresh unique ids)
    def generate_pseudo_labels(self, cluster_id, num):
        labels = []
        outliers = 0
        for i, ((fname, _, cid),
                id) in enumerate(zip(sorted(self.tgt.train), cluster_id)):
            if id != -1:
                labels.append(self.src_pid_nums + id)
            else:
                labels.append(self.src_pid_nums + num + outliers)
                outliers += 1
        return torch.Tensor(labels).long()

    def run_step(self):
        assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"
        if self.cfg.SOLVER.AMP.ENABLED:
            assert torch.cuda.is_available(
            ), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
            from torch.cuda.amp.autocast_mode import autocast

        start = time.perf_counter()

        # load data
        src_inputs = self.src_train_loader.next()
        tgt_inputs = self.pseudo_tgt_train_loader.next()

        # src_inputs = next(self.src_load_iter)
        # tgt_inputs = next(self.tgt_load_iter)

        def _parse_data(inputs):
            imgs, _, pids, _, indices = inputs
            return imgs.cuda(), pids.cuda(), indices

        # process inputs
        s_inputs, s_targets, s_indices = _parse_data(src_inputs)
        t_inputs, t_targets, t_indices = _parse_data(tgt_inputs)

        # arrange batch for domain-specific BN (each GPU must see both domains)
        device_num = torch.cuda.device_count()
        B, C, H, W = s_inputs.size()

        def reshape(inputs):
            return inputs.view(device_num, -1, C, H, W)

        s_inputs, t_inputs = reshape(s_inputs), reshape(t_inputs)
        inputs = torch.cat((s_inputs, t_inputs), 1).view(-1, C, H, W)
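        # After the reshape/cat, each per-device chunk of `inputs` holds its own
        # source images followed by its target images, so every GPU receives both
        # domains for the domain-specific BN layers.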

        data_time = time.perf_counter() - start

        def _forward():
            outputs = self.model(inputs)
            if isinstance(outputs, dict):
                f_out = outputs['features']
            else:
                f_out = outputs

            # de-arrange batch
            f_out = f_out.view(device_num, -1, f_out.size(-1))

            f_out_s, f_out_t = f_out.split(f_out.size(1) // 2, dim=1)
            f_out_s, f_out_t = f_out_s.contiguous().view(
                -1,
                f_out.size(-1)), f_out_t.contiguous().view(-1, f_out.size(-1))

            # compute loss with the hybrid memory
            # with autocast(enabled=False):
            loss_s = self.hm(f_out_s, s_targets)
            loss_t = self.hm(f_out_t, t_indices + self.src_pid_nums)

            loss_dict = {'loss_s': loss_s, 'loss_t': loss_t}
            return loss_dict

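        # Standard torch.cuda.amp recipe: forward under autocast, backprop the
        # scaled loss, then GradScaler.step() skips the optimizer update (and
        # lowers the scale) whenever inf/NaN gradients are detected.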
        if self.cfg.SOLVER.AMP.ENABLED:
            with autocast():
                loss_dict = _forward()
                losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            self.grad_scaler.scale(losses).backward()

            self._write_metrics(loss_dict, data_time)

            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss_dict = _forward()
            losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            losses.backward()

            self._write_metrics(loss_dict, data_time)

            self.optimizer.step()

        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()

    @classmethod
    def load_dataset(cls, name):
        logger = logging.getLogger(__name__)
        logger.info(f"Preparing {name}")

        _root = os.getenv("FASTREID_DATASETS", "/root/datasets")
        data = DATASET_REGISTRY.get(name)(root=_root)
        if comm.is_main_process():
            data.show_train()

        return data

    @classmethod
    def build_dataset(cls,
                      cfg,
                      img_items,
                      is_train=False,
                      relabel=False,
                      transforms=None,
                      with_mem_idx=False):
        if transforms is None:
            transforms = build_transforms(cfg, is_train=is_train)

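        # with_mem_idx appends each sorted item's position so that every sample
        # carries a stable index into the hybrid memory.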
        if with_mem_idx:
            sorted_img_items = sorted(img_items)
            for i in range(len(sorted_img_items)):
                sorted_img_items[i] += (i, )
            return InMemoryDataset(sorted_img_items, transforms, relabel)
        else:
            return CommDataset(img_items, transforms, relabel)

    @classmethod
    def build_train_loader(cls,
                           cfg,
                           train_set=None,
                           sampler=None,
                           with_mem_idx=False):
        logger = logging.getLogger('fastreid')
        logger.info("Prepare training loader")

        total_batch_size = cfg.SOLVER.IMS_PER_BATCH
        mini_batch_size = total_batch_size // comm.get_world_size()

        if sampler is None:
            num_instance = cfg.DATALOADER.NUM_INSTANCE
            sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
            logger.info("Using training sampler {}".format(sampler_name))

            if sampler_name == "TrainingSampler":
                sampler = samplers.TrainingSampler(len(train_set))
            elif sampler_name == "NaiveIdentitySampler":
                sampler = samplers.NaiveIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance)
            elif sampler_name == "BalancedIdentitySampler":
                sampler = samplers.BalancedIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance)
            elif sampler_name == "SetReWeightSampler":
                set_weight = cfg.DATALOADER.SET_WEIGHT
                sampler = samplers.SetReWeightSampler(train_set.img_items,
                                                      mini_batch_size,
                                                      num_instance, set_weight)
            elif sampler_name == "ImbalancedDatasetSampler":
                sampler = samplers.ImbalancedDatasetSampler(
                    train_set.img_items)
            else:
                raise ValueError(
                    "Unknown training sampler: {}".format(sampler_name))

        iters = cfg.SOLVER.ITERS
        num_workers = cfg.DATALOADER.NUM_WORKERS
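        # drop_last=True keeps every mini-batch at exactly mini_batch_size samples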
        batch_sampler = BatchSampler(sampler, mini_batch_size, True)

        train_loader = IterLoader(
            DataLoader(
                Preprocessor(train_set, with_mem_idx),
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                pin_memory=True,
            ),
            length=iters,
        )
        # train_loader = DataLoaderX(
        #     comm.get_local_rank(),
        #     dataset=Preprocessor(train_set, with_mem_idx),
        #     num_workers=num_workers,
        #     batch_sampler=batch_sampler,
        #     collate_fn=fast_batch_collator,
        #     pin_memory=True,
        # )

        return train_loader

    @classmethod
    def build_test_loader(cls, cfg, test_set):
        logger = logging.getLogger('fastreid')
        logger.info("Prepare testing loader")

        # test_loader = DataLoader(
        #     # Preprocessor(test_set),
        #     test_set,
        #     batch_size=cfg.TEST.IMS_PER_BATCH,
        #     num_workers=cfg.DATALOADER.NUM_WORKERS,
        #     shuffle=False,
        #     pin_memory=True,
        # )

        test_batch_size = cfg.TEST.IMS_PER_BATCH
        mini_batch_size = test_batch_size // comm.get_world_size()
        num_workers = cfg.DATALOADER.NUM_WORKERS
        data_sampler = samplers.InferenceSampler(len(test_set))
        batch_sampler = BatchSampler(data_sampler, mini_batch_size, False)
        test_loader = DataLoaderX(
            comm.get_local_rank(),
            dataset=test_set,
            batch_sampler=batch_sampler,
            num_workers=num_workers,  # save some memory
            collate_fn=fast_batch_collator,
            pin_memory=True,
        )

        return test_loader

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_dir=None):
        data_loader, num_query = build_reid_test_loader(
            cfg, dataset_name=dataset_name)
        return data_loader, ReidEvaluator(cfg, num_query, output_dir)

    @classmethod
    def test(cls, cfg, model):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger('fastreid')

        results = OrderedDict()
        dataset_name = cfg.DATASETS.TGT

        logger.info("Prepare testing set")
        try:
            data_loader, evaluator = cls.build_evaluator(cfg, dataset_name)
        except NotImplementedError:
            logger.warning(
                "No evaluator found. Implement its `build_evaluator` method.")
            results[dataset_name] = {}
            return results

        results_i = inference_on_dataset(model,
                                         data_loader,
                                         evaluator,
                                         flip_test=cfg.TEST.FLIP.ENABLED)
        results[dataset_name] = results_i

        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results)
            logger.info("Evaluation results for {} in csv format:".format(
                dataset_name))
            results_i['dataset'] = dataset_name
            print_csv_format(results_i)

        # if len(results) == 1:
        #     results = list(results.values())[0]

        return results

    @classmethod
    def build_model(cls,
                    cfg,
                    load_model=True,
                    show_model=True,
                    use_dsbn=False):
        cfg = cfg.clone()  # cfg can be modified by model
        cfg.defrost()
        cfg.MODEL.DEVICE = "cpu"

        model = build_model(cfg)
        logger = logging.getLogger('fastreid')

        if load_model:
            pretrain_path = cfg.MODEL.PRETRAIN_PATH
            try:
                state_dict = torch.load(
                    pretrain_path, map_location=torch.device("cpu"))['model']
                for layer in cfg.MODEL.IGNORE_LAYERS:
                    if layer in state_dict.keys():
                        del state_dict[layer]
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(
                    f"{pretrain_path} is not found! Please check this path.")
                raise e
            except KeyError as e:
                logger.info(
                    "State dict keys error! Please check the state dict.")
                raise e

            incompatible = model.load_state_dict(state_dict, strict=False)
            if incompatible.missing_keys:
                logger.info(
                    get_missing_parameters_message(incompatible.missing_keys))
            if incompatible.unexpected_keys:
                logger.info(
                    get_unexpected_parameters_message(
                        incompatible.unexpected_keys))

        if use_dsbn:
            logger.info("==> Convert BN to Domain Specific BN")
            convert_dsbn(model)

        if show_model:
            logger.info("Model:\n{}".format(model))

        model.to(torch.device("cuda"))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
        """
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer, iters_per_epoch)

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        logger = logging.getLogger(__name__)
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        # set dataset name for PreciseBN
        cfg.DATASETS.NAMES = (cfg.TEST.PRECISE_BN.DATASET, )

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        # if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
        #     logger.info("Prepare precise BN dataset")
        #     ret.append(hooks.PreciseBN(
        #         # Run at the same freq as (but before) evaluation.
        #         self.model,
        #         # Build a new data loader to not affect training
        #         self.build_train_loader(cfg),
        #         cfg.TEST.PRECISE_BN.NUM_ITER,
        #     ))

        if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0:
            ret.append(
                hooks.LayerFreeze(
                    self.model,
                    cfg.MODEL.FREEZE_LAYERS,
                    cfg.SOLVER.FREEZE_ITERS,
                ))

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation before checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer,
                                           cfg.SOLVER.CHECKPOINT_PERIOD))
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), 200))

        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.
        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        It is now implemented by:
        .. code-block:: python
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        # TODO: customize my writers
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def _write_metrics(self, loss_dict: Dict[str, torch.Tensor],
                       data_time: float):
        """
        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration
        """
        device = next(iter(loss_dict.values())).device

        # Use a new stream so these ops don't wait for DDP or backward
        with torch.cuda.stream(torch.cuda.Stream() if device.type ==
                               "cuda" else None):
            metrics_dict = {
                k: v.detach().cpu().item()
                for k, v in loss_dict.items()
            }
            metrics_dict["data_time"] = data_time

            # Gather metrics among all workers for logging
            # This assumes we do DDP-style training, which is currently the only
            # supported method in detectron2.
            all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            storage = get_event_storage()
            # data_time among workers can have high variance. The actual latency
            # caused by data_time is the maximum among workers.
            data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
            storage.put_scalar("data_time", data_time)

            # average the rest metrics
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict])
                for k in all_metrics_dict[0].keys()
            }
            total_losses_reduced = sum(metrics_dict.values())
            if not np.isfinite(total_losses_reduced):
                raise FloatingPointError(
                    f"Loss became infinite or NaN at iteration={self.iter}!\n"
                    f"loss_dict = {metrics_dict}")

            storage.put_scalar("total_loss", total_losses_reduced)
            if len(metrics_dict) > 1:
                storage.put_scalars(**metrics_dict)
Exemplo n.º 22
0
def do_train(checkpoint=None,
             direction='amr',
             split_both_decoder=False,
             fp16=False):

    assert direction in ('amr', 'text', 'both')

    model, tokenizer = instantiate_model_and_tokenizer(
        config['model'],
        checkpoint=checkpoint,
        additional_tokens_smart_init=config['smart_init'],
        dropout=config['dropout'],
        attention_dropout=config['attention_dropout'],
        from_pretrained=config['warm_start'],
        init_reverse=split_both_decoder,
        penman_linearization=config['penman_linearization'],
        collapse_name_ops=config['collapse_name_ops'],
        use_pointer_tokens=config['use_pointer_tokens'],
        raw_graph=config.get('raw_graph', False))

    print(model)
    print(model.config)

    if checkpoint is not None:
        print(f'Checkpoint restored ({checkpoint})!')

    if direction == 'both' and split_both_decoder:
        params_dir_enc = list(model.model.encoder.parameters())
        params_dir_enc_check = {id(p) for p in params_dir_enc}
        params_dir_dec = set()
        params_dir_dec |= {
            p
            for p in model.model.decoder.parameters()
            if id(p) not in params_dir_enc_check
        }
        params_dir_dec |= {
            p
            for p in model.rev.model.decoder.parameters()
            if id(p) not in params_dir_enc_check
        }
        params_dir_dec = list(params_dir_dec)
        optimizer = RAdam([
            {
                'params': params_dir_enc,
                'lr': config['learning_rate']
            },
            {
                'params': params_dir_dec,
                'lr': config['learning_rate'] * 2
            },
        ],
                          weight_decay=config['weight_decay'])
    else:
        optimizer = RAdam(model.parameters(),
                          lr=config['learning_rate'],
                          weight_decay=config['weight_decay'])
    if checkpoint is not None:
        optimizer.load_state_dict(torch.load(checkpoint)['optimizer'])

    if config['scheduler'] == 'cosine':
        scheduler = transformers.get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config['warmup_steps'],
            num_training_steps=config['training_steps'])
    elif config['scheduler'] == 'constant':
        scheduler = transformers.get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=config['warmup_steps'])
    else:
        raise ValueError

    scaler = GradScaler(enabled=fp16)
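    # With enabled=False the scaler leaves losses unscaled and step() simply calls
    # optimizer.step(), so the same code path also covers full-precision training.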

    train_loader = instantiate_loader(
        config['train'],
        tokenizer,
        batch_size=config['batch_size'],
        evaluation=False,
        use_recategorization=config['use_recategorization'],
        remove_longer_than=config['remove_longer_than'],
        remove_wiki=config['remove_wiki'],
        dereify=config['dereify'],
    )

    dev_gold_path = ROOT / 'data/tmp/dev-gold.txt'
    dev_pred_path = ROOT / 'data/tmp/dev-pred.txt'
    dev_loader = instantiate_loader(
        config['dev'],
        tokenizer,
        batch_size=config['batch_size'],
        evaluation=True,
        out=dev_gold_path,
        use_recategorization=config['use_recategorization'],
        remove_wiki=config['remove_wiki'],
        dereify=config['dereify'],
    )

    if direction == 'amr':

        def train_step(engine, batch):
            model.train()
            x, y, extra = batch
            model.amr_mode = True
            with autocast(enabled=fp16):
                loss, *_ = model(**x, **y)
            scaler.scale((loss / config['accum_steps'])).backward()
            return loss.item()

        @torch.no_grad()
        def eval_step(engine, batch):
            model.eval()
            x, y, extra = batch
            model.amr_mode = True
            loss, *_ = model(**x, **y)
            return loss.item()

    elif direction == 'text':

        def train_step(engine, batch):
            model.train()
            x, y, extra = batch
            x, y = reverse_direction(x, y)
            model.rev.amr_mode = False
            with autocast(enabled=fp16):
                loss, *_ = model.rev(**x, **y)
            scaler.scale((loss / config['accum_steps'])).backward()
            return loss.item()

        @torch.no_grad()
        def eval_step(engine, batch):
            model.eval()
            x, y, extra = batch
            x, y = reverse_direction(x, y)
            model.rev.amr_mode = False
            loss, *_ = model.rev(**x, **y)  # reverse model, matching train_step
            return loss.item()

    elif direction == 'both':

        def train_step(engine, batch):
            model.train()
            x, y, extra = batch
            model.amr_mode = True
            with autocast(enabled=fp16):
                loss1, *_ = model(**x, **y)
            scaler.scale((loss1 / config['accum_steps'] * 0.5)).backward()
            loss1 = loss1.item()
            x, y = reverse_direction(x, y)
            model.rev.amr_mode = False
            with autocast(enabled=fp16):
                loss2, *_ = model.rev(**x, **y)
            scaler.scale((loss2 / config['accum_steps'] * 0.5)).backward()
            return loss1, loss2.item()

        @torch.no_grad()
        def eval_step(engine, batch):
            model.eval()
            x, y, extra = batch
            model.amr_mode = True
            loss1, *_ = model(**x, **y)
            x, y = reverse_direction(x, y)
            model.rev.amr_mode = False
            loss2, *_ = model.rev(**x, **y)
            return loss1.item(), loss2.item()

    else:
        raise ValueError

    trainer = Engine(train_step)
    evaluator = Engine(eval_step)

    @trainer.on(Events.STARTED)
    def update(engine):
        print('training started!')

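    # Gradient accumulation: train_step only backprops loss / accum_steps, and the
    # optimizer is stepped here every `accum_steps` iterations (and at epoch end).
    # Gradients are unscaled first so clip_grad_norm_ sees their true magnitudes.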
    @trainer.on(Events.EPOCH_COMPLETED)
    @trainer.on(Events.ITERATION_COMPLETED(every=config['accum_steps']))
    def update(engine):
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_norm'])
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_trn_loss(engine):
        log_msg = f"training epoch: {engine.state.epoch}"
        if direction in ('amr', 'both'):
            log_msg += f" | loss_amr: {engine.state.metrics['trn_amr_loss']:.3f}"
        if direction in ('text', 'both'):
            log_msg += f" | loss_text: {engine.state.metrics['trn_text_loss']:.3f}"
        print(log_msg)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_dev_eval(engine):
        dev_loader.batch_size = config['batch_size']
        dev_loader.device = next(model.parameters()).device
        evaluator.run(dev_loader)

    if not config['best_loss']:
        if direction in ('amr', 'both'):

            @evaluator.on(Events.EPOCH_COMPLETED)
            def smatch_eval(engine):
                device = next(model.parameters()).device
                dev_loader.device = device
                graphs = predict_amrs(
                    dev_loader,
                    model,
                    tokenizer,
                    restore_name_ops=config['collapse_name_ops'])
                write_predictions(dev_pred_path, tokenizer, graphs)
                try:
                    smatch = compute_smatch(dev_gold_path, dev_pred_path)
                except Exception:
                    smatch = 0.
                engine.state.metrics['dev_smatch'] = smatch

        if direction in ('text', 'both'):

            @evaluator.on(Events.EPOCH_COMPLETED)
            def smatch_eval(engine):
                device = next(model.parameters()).device
                dev_loader.device = device
                pred_sentences = predict_sentences(
                    dev_loader,
                    model.rev,
                    tokenizer,
                    beam_size=config['beam_size'])
                bleu = compute_bleu(dev_loader.dataset.sentences,
                                    pred_sentences)
                engine.state.metrics['dev_bleu'] = bleu.score

    @evaluator.on(Events.EPOCH_COMPLETED)
    def log_dev_loss(engine):
        log_msg = f"dev epoch: {trainer.state.epoch}"
        if direction in ('amr', 'both'):
            log_msg += f" | loss_amr: {engine.state.metrics['dev_amr_loss']:.3f}"
            if not config['best_loss']:
                log_msg += f" | smatch: {engine.state.metrics['dev_smatch']:.3f}"
        if direction in ('text', 'both'):
            log_msg += f" | loss_text: {engine.state.metrics['dev_text_loss']:.3f}"
            if not config['best_loss']:
                log_msg += f" | bleu: {engine.state.metrics['dev_bleu']:.3f}"
        print(log_msg)

    if direction == 'amr':
        RunningAverage(output_transform=lambda out: out).attach(
            trainer, 'trn_amr_loss')
        RunningAverage(output_transform=lambda out: out).attach(
            evaluator, 'dev_amr_loss')
    elif direction == 'text':
        RunningAverage(output_transform=lambda out: out).attach(
            trainer, 'trn_text_loss')
        RunningAverage(output_transform=lambda out: out).attach(
            evaluator, 'dev_text_loss')
    elif direction == 'both':
        RunningAverage(output_transform=lambda out: out[0]).attach(
            trainer, 'trn_amr_loss')
        RunningAverage(output_transform=lambda out: out[1]).attach(
            trainer, 'trn_text_loss')
        RunningAverage(output_transform=lambda out: out[0]).attach(
            evaluator, 'dev_amr_loss')
        RunningAverage(output_transform=lambda out: out[1]).attach(
            evaluator, 'dev_text_loss')

    if config['log_wandb']:
        from ignite.contrib.handlers.wandb_logger import WandBLogger
        wandb_logger = WandBLogger(init=False)

        if direction == 'amr':
            wandb_logger.attach_output_handler(
                trainer,
                event_name=Events.ITERATION_COMPLETED,
                tag="iterations/trn_amr_loss",
                output_transform=lambda loss: loss)
        elif direction == 'text':
            wandb_logger.attach_output_handler(
                trainer,
                event_name=Events.ITERATION_COMPLETED,
                tag="iterations/trn_text_loss",
                output_transform=lambda loss: loss)
        if direction == 'both':
            wandb_logger.attach_output_handler(
                trainer,
                event_name=Events.ITERATION_COMPLETED,
                tag="iterations/trn_amr_loss",
                output_transform=lambda loss: loss[0])
            wandb_logger.attach_output_handler(
                trainer,
                event_name=Events.ITERATION_COMPLETED,
                tag="iterations/trn_text_loss",
                output_transform=lambda loss: loss[1])

        if direction == 'amr':
            metric_names_trn = ['trn_amr_loss']
            metric_names_dev = ['dev_amr_loss']
            if not config['best_loss']:
                metric_names_dev.append('dev_smatch')
        elif direction == 'text':
            metric_names_trn = ['trn_text_loss']
            metric_names_dev = ['dev_text_loss']
            if not config['best_loss']:
                metric_names_dev.append('dev_bleu')
        elif direction == 'both':
            metric_names_trn = ['trn_amr_loss', 'trn_text_loss']
            metric_names_dev = ['dev_amr_loss', 'dev_text_loss']
            if not config['best_loss']:
                metric_names_dev.extend(['dev_smatch', 'dev_bleu'])

        wandb_logger.attach_output_handler(
            trainer,
            event_name=Events.EPOCH_COMPLETED,
            tag="epochs",
            metric_names=metric_names_trn,
            global_step_transform=lambda *_: trainer.state.iteration,
        )

        wandb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="epochs",
            metric_names=metric_names_dev,
            global_step_transform=lambda *_: trainer.state.iteration,
        )

        @trainer.on(Events.ITERATION_COMPLETED)
        def wandb_log_lr(engine):
            wandb.log({'lr': scheduler.get_last_lr()[0]},
                      step=engine.state.iteration)

    if config['save_checkpoints']:

        if direction in ('amr', 'both'):
            if config['best_loss']:
                prefix = 'best-loss-amr'
                score_function = lambda x: 1 / evaluator.state.metrics[
                    'dev_amr_loss']
            else:
                prefix = 'best-smatch'
                score_function = lambda x: evaluator.state.metrics[
                    'dev_smatch']
        else:
            if config['best_loss']:
                prefix = 'best-loss-text'
                score_function = lambda x: 1 / evaluator.state.metrics[
                    'dev_text_loss']
            else:
                prefix = 'best-bleu'
                score_function = lambda x: evaluator.state.metrics['dev_bleu']

        to_save = {'model': model, 'optimizer': optimizer}
        if config['log_wandb']:
            where_checkpoints = str(wandb_logger.run.dir)
        else:
            root = ROOT / 'runs'
            root.mkdir(exist_ok=True)
            where_checkpoints = root / str(len(list(root.iterdir())))
            where_checkpoints.mkdir(exist_ok=True)
            where_checkpoints = str(where_checkpoints)

        print(where_checkpoints)
        handler = ModelCheckpoint(
            where_checkpoints,
            prefix,
            n_saved=1,
            create_dir=True,
            score_function=score_function,
            global_step_transform=global_step_from_engine(trainer),
        )
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)

    model.cuda()
    device = next(model.parameters()).device
    train_loader.device = device
    trainer.run(train_loader, max_epochs=config['max_epochs'])
Exemplo n.º 23
0
def train_step(
        model: FlowModel,
        config: TrainConfig,
        action: ActionFn,
        optimizer: optim.Optimizer,
        batch_size: int,
        scheduler: Any = None,
        scaler: GradScaler = None,
        pre_model: FlowModel = None,
        dkl_factor: float = 1.,
        xi: torch.Tensor = None,
):
    """Perform a single training step.

    TODO: Add `torch.device` to arguments for DDP.
    """
    t0 = time.time()
    #  layers, prior = model['layers'], model['prior']
    optimizer.zero_grad()

    loss_dkl = torch.tensor(0.0)
    if torch.cuda.is_available():
        loss_dkl = loss_dkl.cuda()

    if pre_model is not None:
        pre_xi = pre_model.prior.sample_n(batch_size)
        x = qed.ft_flow(pre_model.layers, pre_xi)
        xi = qed.ft_flow_inv(pre_model.layers, x)

    #  with torch.cuda.amp.autocast():
    x, xi, logq = apply_flow_to_prior(model.prior,
                                      model.layers,
                                      xi=xi, batch_size=batch_size)
    logp = (-1.) * action(x)
    dkl = calc_dkl(logp, logq)

    ess = calc_ess(logp, logq)
    qi = qed.batch_charges(xi)
    q = qed.batch_charges(x)
    plaq = logp / (config.beta * config.volume)
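    # |Q(x) - Q(xi)|: absolute change in topological charge induced by the flow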
    dq = torch.sqrt((q - qi) ** 2)

    loss_dkl = dkl_factor * dkl

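    # With a GradScaler (mixed precision) the loss is scaled before backward and
    # scaler.step() skips the update on inf/NaN gradients; otherwise do a plain
    # full-precision update.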
    if scaler is not None:
        scaler.scale(loss_dkl).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss_dkl.backward()
        optimizer.step()

    if scheduler is not None:
        scheduler.step(loss_dkl)

    metrics = {
        'dt': time.time() - t0,
        'ess': grab(ess),
        'logp': grab(logp),
        'logq': grab(logq),
        'loss_dkl': grab(loss_dkl),
        'q': grab(q),
        'dq': grab(dq),
        'plaq': grab(plaq),
    }

    return metrics