Example #1
 def _wrap_distributed(self):
     """Wrap modules with distributed wrapper when requested."""
     if not self.distributed_launch and not self.data_parallel_backend:
         return
     elif self.distributed_launch:
         for name, module in self.modules.items():
             if any(p.requires_grad for p in module.parameters()):
                 # for DDP, all modules must run on the same GPU
                 module = SyncBatchNorm.convert_sync_batchnorm(module)
                 module = DDP(module, device_ids=[self.device])
                 self.modules[name] = module
     else:
         # data_parallel_backend
         for name, module in self.modules.items():
             if any(p.requires_grad for p in module.parameters()):
                 # if data_parallel_count == -1, use all GPUs;
                 # otherwise, use the first data_parallel_count GPUs
                 if self.data_parallel_count == -1:
                     module = DP(module)
                 else:
                     module = DP(
                         module,
                         [i for i in range(self.data_parallel_count)],
                     )
                 self.modules[name] = module
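The method above folds SyncBatchNorm conversion into a larger training class. The minimal sketch below shows the same underlying pattern in isolation, assuming the default process group has already been initialized (for example by torchrun) and that local_rank identifies this process's GPU; the helper name wrap_for_ddp is illustrative only.

import torch
import torch.distributed as dist
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: torch.nn.Module, local_rank: int) -> DDP:
    """Convert BatchNorm*d layers to SyncBatchNorm, then wrap the model in DDP."""
    assert dist.is_initialized(), "call dist.init_process_group() first"
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(local_rank)
    return DDP(model, device_ids=[local_rank], output_device=local_rank)

As in the example above, conversion happens before wrapping, so the SyncBatchNorm replacements are the modules that DDP actually manages.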
Example #2
 def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
     """
     Moves metrics and the online evaluator to the correct GPU.
     If training happens via DDP, SyncBatchNorm is enabled for the online evaluator, and it is converted to
     a DDP module.
     """
     for prefix, metrics in [("train", self.train_metrics), ("val", self.val_metrics)]:
         add_submodules_to_same_device(pl_module, metrics, prefix=prefix)
     self.evaluator = SSLEvaluator(n_input=self.z_dim,
                                   n_classes=self.num_classes,
                                   p=self.drop_p,
                                   n_hidden=self.hidden_dim)
     self.evaluator.to(pl_module.device)
     if hasattr(trainer, "accelerator_connector"):
         # This works with Lightning 1.3.8
         accelerator = trainer.accelerator_connector
     elif hasattr(trainer, "_accelerator_connector"):
         # This works with Lightning 1.5.5
         accelerator = trainer._accelerator_connector
     else:
         raise ValueError("Unable to retrieve the accelerator information")
     if accelerator.is_distributed:
         if accelerator.use_ddp:
             self.evaluator = SyncBatchNorm.convert_sync_batchnorm(self.evaluator)
             self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device])  # type: ignore
         else:
             rank_zero_warn("This type of distributed accelerator is not supported. "
                            "The online evaluator will not synchronize across GPUs.")
     self.optimizer = torch.optim.Adam(self.evaluator.parameters(),
                                       lr=self.learning_rate,
                                       weight_decay=self.weight_decay)
     if self.evaluator_state is not None:
         self._wrapped_evaluator().load_state_dict(self.evaluator_state)
     if self.optimizer_state is not None:
         self.optimizer.load_state_dict(self.optimizer_state)
Example #3
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            #             this will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            self.bn = BatchNorm2d(2)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # TODO (Min): check DDP equivalency.

    model = Model()
    # Note: different ranks may wrap modules in a different order due to
    # different random seeds, but the results should be the same.
    if random.randint(0, 1) == 0:
        print("auto_wrap_bn, then convert_sync_batchnorm")
        model = auto_wrap_bn(model)
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    else:
        print("convert_sync_batchnorm, then auto_wrap_bn")
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = auto_wrap_bn(model)
    model = FSDP(model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(2, 2, 2, 2).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
Example #4
def get_model(model_config, device, distributed, sync_bn):
    model_name = model_config['type']
    model = torchvision.models.__dict__[model_name](**model_config['params'])
    if distributed and sync_bn:
        model = SyncBatchNorm.convert_sync_batchnorm(model)

    ckpt_file_path = model_config['ckpt']
    load_ckpt(ckpt_file_path, model=model, strict=True)
    return model.to(device)
Example #5
    def _parser_model(self):
        *model_mod_str_parts, model_class_str = self.cfg.USE_MODEL.split(".")
        model_class = getattr(import_module(".".join(model_mod_str_parts)), model_class_str)
        model = model_class(dictionary=self.dictionary)

        if self.cfg.distributed:
            model = SyncBatchNorm.convert_sync_batchnorm(model).cuda()
        else:
            model = model.cuda()

        return model
Example #6
def get_image_classification_model(model_config,
                                   distributed=False,
                                   sync_bn=False):
    model_name = model_config['name']
    if model_name not in models.__dict__:
        return None

    model = models.__dict__[model_name](**model_config['params'])
    if distributed and sync_bn:
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    return model
Example #7
def distribute_module(module: nn.Module, rank: Optional[int] = None, *args, **kwargs) -> DDP:
    if rank is None:
        rank = distd.get_rank()

    ddp_module: DDP = DDP(
        SBN.convert_sync_batchnorm(module).to(rank),
        device_ids=[rank],
        output_device=rank,
        *args,
        **kwargs,
    )
    sync(ddp_module, rank)

    return ddp_module
Example #8
def get_image_classification_model(model_config, distributed=False):
    model_name = model_config['name']
    quantized = model_config.get('quantized', False)
    if not quantized and model_name in models.__dict__:
        model = models.__dict__[model_name](**model_config['params'])
    elif quantized and model_name in models.quantization.__dict__:
        model = models.quantization.__dict__[model_name](**model_config['params'])
    else:
        return None

    sync_bn = model_config.get('sync_bn', False)
    if distributed and sync_bn:
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    return model
Example #9
    def _configure_model(self):
        ''' Mixed precision '''
        if self.fp16:
            if AMP:
                if self.rank == 0:
                    self.logger('Mixed precision training on torch amp.')
            else:
                self.fp16 = False
                if self.rank == 0:
                    self.logger('No mixed precision training backend found.')

        ''' Parallel training '''
        if self.xla:  # DDP on xla
            self.model.to(self.device)
            if self.rank == 0:
                self.logger(f'Model on {self.device}')

        elif self.parallel == 'dp': # DP on cuda
            self.model = DataParallel(
                self.model, device_ids=self.device_ids).to(self.device)
            if hasattr(self, 'criterion'):
                self.criterion = self.criterion.to(self.device)
            self.logger(f'DataParallel on devices {self.device_ids}')

        elif self.parallel == 'ddp': # DDP on cuda
            if self.ddp_sync_batch_norm:
                self.model = SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = DistributedDataParallel(
                self.model.to(self.rank), device_ids=[self.rank],
                broadcast_buffers=False,
                find_unused_parameters=True
            )
            if hasattr(self, 'criterion'):
                self.criterion = self.criterion.to(self.rank)
            if self.rank == 0:
                self.logger(
                    f'DistributedDataParallel on devices {self.device_ids}')

        elif self.parallel is not None:
            raise ValueError(f'Unknown type of parallel {self.parallel}')

        else:  # Single device
            self.model.to(self.device)
            if hasattr(self, 'criterion'):
                self.criterion = self.criterion.to(self.device)
            self.logger(f'Model on {self.device}')
        
        self._model_ready = True
Example #10
    def to_distributed_data_parallel(self) -> None:
        print(
            f"Run in distributed data parallel w/ {torch.cuda.device_count()} GPUs"
        )

        # setup models
        from torch.nn.parallel import DistributedDataParallel
        from torch.nn import SyncBatchNorm

        local_rank = self.exp_args.local_rank

        self.model = SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.model = DistributedDataParallel(self.model,
                                             device_ids=[local_rank],
                                             output_device=local_rank,
                                             broadcast_buffers=True)

        # setup data loaders
        self.test_loader = self.to_distributed_loader(
            self.test_loader,
            shuffle=False,
            num_workers=self.exp_args.num_workers,
            pin_memory=True,
            drop_last=False)

        self.validation_loader = self.to_distributed_loader(
            self.validation_loader,
            shuffle=False,
            num_workers=self.exp_args.num_workers,
            pin_memory=True,
            drop_last=True)

        self.labeled_train_loader = self.to_distributed_loader(
            self.labeled_train_loader,
            shuffle=True,
            num_workers=self.exp_args.num_workers,
            pin_memory=True,
            drop_last=True)

        self.unlabeled_train_loader = self.to_distributed_loader(
            self.unlabeled_train_loader,
            shuffle=True,
            num_workers=self.exp_args.num_workers,
            pin_memory=True,
            drop_last=True)
Example #11
 def _wrap_distributed(self):
     """Wrap modules with distributed wrapper when requested."""
     if not self.distributed_launch and not self.data_parallel_backend:
         return
     elif self.distributed_launch:
         for name, module in self.modules.items():
             if any(p.requires_grad for p in module.parameters()):
                 module = SyncBatchNorm.convert_sync_batchnorm(module)
                 module = DDP(
                     module,
                     device_ids=[self.device],
                     find_unused_parameters=self.find_unused_parameters,
                 )
                 self.modules[name] = module
     else:
         # data_parallel_backend
         for name, module in self.modules.items():
             if any(p.requires_grad for p in module.parameters()):
                 module = DP(module)
                 self.modules[name] = module
Example #12
def build_trainer(cfg):
    # Create data loaders
    train_loader, valid_loader = get_train_val_dataloaders(cfg)
    if cfg.train.params.steps_per_epoch == 0:
        cfg.train.params.steps_per_epoch = len(train_loader)
    # Create model
    model = builder.build_model(cfg)
    if cfg.experiment.distributed:
        if cfg.experiment.sync_bn:
            model = sbn.convert_sync_batchnorm(model)
        if cfg.experiment.cuda:
            model.to(f'cuda:{cfg.local_rank}')
        model = DistributedDataParallel(model,
                                        device_ids=[cfg.local_rank],
                                        output_device=cfg.local_rank)
    else:
        if cfg.experiment.cuda:
            model.to(f'cuda:{cfg.local_rank}')
    model.train()
    # Create loss
    criterion = builder.build_loss(cfg)
    # Create optimizer
    optimizer = builder.build_optimizer(cfg, model.parameters())
    # Create learning rate scheduler
    scheduler = builder.build_scheduler(cfg, optimizer)
    # Create evaluator
    evaluator = builder.build_evaluator(cfg, valid_loader)
    trainer = getattr(beehive_train, cfg.train.name)
    trainer = trainer(loader=train_loader,
                      model=model,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      criterion=criterion,
                      evaluator=evaluator,
                      logger=logging.getLogger('root'),
                      cuda=cfg.train.params.pop('cuda'),
                      dist=cfg.experiment.distributed)
    return trainer
Example #13
    def build_model(self):

        self.G = networks.get_generator(encoder=self.model_config.arch.encoder,
                                        decoder=self.model_config.arch.decoder)
        self.G.cuda()

        if CONFIG.dist:
            self.logger.info("Using pytorch synced BN")
            self.G = SyncBatchNorm.convert_sync_batchnorm(self.G)

        self.G_optimizer = torch.optim.Adam(
            self.G.parameters(),
            lr=self.train_config.G_lr,
            betas=[self.train_config.beta1, self.train_config.beta2])

        if CONFIG.dist:
            # SyncBatchNorm only supports DistributedDataParallel with single GPU per process
            self.G = DistributedDataParallel(self.G,
                                             device_ids=[CONFIG.local_rank],
                                             output_device=CONFIG.local_rank)
        else:
            self.G = nn.DataParallel(self.G)

        self.build_lr_scheduler()
Example #14
def train_function(gpu, world_size, node_rank, gpus):
    import torch.multiprocessing
    torch.multiprocessing.set_sharing_strategy('file_system')

    torch.manual_seed(25)
    np.random.seed(25)

    rank = node_rank * gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank
    )

    width_size = 512
    batch_size = 32
    accumulation_step = 5
    device = torch.device("cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu")

    if rank == 0:
        wandb.init(project='inception_v3', group=wandb.util.generate_id())
        wandb.config.width_size = width_size
        wandb.config.aspect_rate = 1
        wandb.config.batch_size = batch_size
        wandb.config.accumulation_step = accumulation_step

        shutil.rmtree('tensorboard_runs', ignore_errors=True)
        writer = SummaryWriter(log_dir='tensorboard_runs', filename_suffix=str(time.time()))

    ranzcr_df = pd.read_csv('train_folds.csv')
    ranzcr_train_df = ranzcr_df[ranzcr_df['fold'] != 1]

    chestx_df = pd.read_csv('chestx_pseudolabeled_data_lazy_balancing.csv')
    train_image_transforms = alb.Compose([
        alb.ImageCompression(quality_lower=65, p=0.5),
        alb.HorizontalFlip(p=0.5),
        alb.CLAHE(p=0.5),
        alb.OneOf([
            alb.GridDistortion(
                num_steps=8,
                distort_limit=0.5,
                p=1.0
            ),
            alb.OpticalDistortion(
                distort_limit=0.5,
                shift_limit=0.5,
                p=1.0,
            ),
            alb.ElasticTransform(alpha=3, p=1.0)],
            p=0.7
        ),
        alb.RandomResizedCrop(
            height=width_size,
            width=width_size,
            scale=(0.8, 1.2),
            p=0.7
        ),
        alb.RGBShift(p=0.5),
        alb.RandomSunFlare(p=0.5),
        alb.RandomFog(p=0.5),
        alb.RandomBrightnessContrast(p=0.5),
        alb.HueSaturationValue(
            hue_shift_limit=20,
            sat_shift_limit=20,
            val_shift_limit=20,
            p=0.5
        ),
        alb.ShiftScaleRotate(shift_limit=0.025, scale_limit=0.1, rotate_limit=20, p=0.5),
        alb.CoarseDropout(
            max_holes=12,
            min_holes=6,
            max_height=int(width_size / 6),
            max_width=int(width_size / 6),
            min_height=int(width_size / 6),
            min_width=int(width_size / 20),
            p=0.5
        ),
        alb.IAAAdditiveGaussianNoise(loc=0, scale=(2.5500000000000003, 12.75), per_channel=False, p=0.5),
        alb.IAAAffine(scale=1.0, translate_percent=None, translate_px=None, rotate=0.0, shear=0.0, order=1, cval=0,
                      mode='reflect', p=0.5),
        alb.IAAAffine(rotate=90., p=0.5),
        alb.IAAAffine(rotate=180., p=0.5),
        alb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])
    train_set = NoisyStudentDataset(ranzcr_train_df, chestx_df, train_image_transforms,
                                    '../ranzcr/train', '../data', width_size=width_size)
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=4, sampler=train_sampler)

    ranzcr_valid_df = ranzcr_df[ranzcr_df['fold'] == 1]
    valid_image_transforms = alb.Compose([
        alb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])
    valid_set = ImageDataset(ranzcr_valid_df, valid_image_transforms, '../ranzcr/train', width_size=width_size)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=4, pin_memory=False, drop_last=False)

    # ranzcr_valid_df = ranzcr_df[ranzcr_df['fold'] == 1]
    # valid_image_transforms = alb.Compose([
    #     alb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    #     ToTensorV2()
    # ])
    # valid_set = ImageDataset(ranzcr_valid_df, valid_image_transforms, '../ranzcr/train', width_size=width_size)
    # valid_sampler = DistributedSampler(valid_set, num_replicas=world_size, rank=rank)
    # valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=4, sampler=valid_sampler)

    checkpoints_dir_name = 'inception_v3_noisy_student_{}'.format(width_size)
    os.makedirs(checkpoints_dir_name, exist_ok=True)

    # model = EfficientNetNoisyStudent(11, pretrained_backbone=True,
    #                                  mixed_precision=True, model_name='tf_efficientnet_b7_ns')
    model = Inception(11, pretrained_backbone=True, mixed_precision=False, model_name='inception_v3')
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)
    model = DistributedDataParallel(model, device_ids=[gpu])

    # class_weights = [354.625, 23.73913043478261, 2.777105767812362, 110.32608695652173,
    #                  52.679245283018865, 9.152656621728786, 4.7851333032083145,
    #                  8.437891632878731, 2.4620064899945917, 0.4034751151063363, 31.534942820838626]
    class_names = ['ETT - Abnormal', 'ETT - Borderline', 'ETT - Normal',
                   'NGT - Abnormal', 'NGT - Borderline', 'NGT - Incompletely Imaged', 'NGT - Normal',
                   'CVC - Abnormal', 'CVC - Borderline', 'CVC - Normal', 'Swan Ganz Catheter Present']
    scaler = GradScaler()
    criterion = torch.nn.BCEWithLogitsLoss()

    lr_start = 1e-4
    lr_end = 1e-6
    weight_decay = 0
    epoch_num = 20
    if rank == 0:
        wandb.config.model_name = checkpoints_dir_name
        wandb.config.lr_start = lr_start
        wandb.config.lr_end = lr_end
        wandb.config.weight_decay = weight_decay
        wandb.config.epoch_num = epoch_num
        wandb.config.optimizer = 'adam'
        wandb.config.scheduler = 'CosineAnnealingLR'
        wandb.config.is_loss_weights = 'no'

    optimizer = Adam(model.parameters(), lr=lr_start, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=epoch_num, eta_min=lr_end, last_epoch=-1)

    max_val_auc = 0

    for epoch in range(epoch_num):
        train_loss, train_avg_auc, train_auc, train_rocs, train_data_pr, train_duration = one_epoch_train(
            model, train_loader, optimizer, criterion, device, scaler,
            iters_to_accumulate=accumulation_step, clip_grads=False)
        scheduler.step()

        if rank == 0:
            val_loss, val_avg_auc, val_auc, val_rocs, val_data_pr, val_duration = eval_model(
                model, valid_loader, device, criterion, scaler)

            wandb.log({'train_loss': train_loss, 'val_loss': val_loss,
                       'train_auc': train_avg_auc, 'val_auc': val_avg_auc, 'epoch': epoch})
            for class_name, auc1, auc2 in zip(class_names, train_auc, val_auc):
                wandb.log({'{} train auc'.format(class_name): auc1,
                           '{} val auc'.format(class_name): auc2, 'epoch': epoch})

            if val_avg_auc > max_val_auc:
                max_val_auc = val_avg_auc
                wandb.run.summary["best_accuracy"] = val_avg_auc

            print('EPOCH %d:\tTRAIN [duration %.3f sec, loss: %.3f, avg auc: %.3f]\t\t'
                  'VAL [duration %.3f sec, loss: %.3f, avg auc: %.3f]\tCurrent time %s' %
                  (epoch + 1, train_duration, train_loss, train_avg_auc,
                   val_duration, val_loss, val_avg_auc, str(datetime.now(timezone('Europe/Moscow')))))

            torch.save(model.module.state_dict(),
                       os.path.join(checkpoints_dir_name, '{}_epoch{}_val_auc{}_loss{}_train_auc{}_loss{}.pth'.format(
                           checkpoints_dir_name, epoch + 1, round(val_avg_auc, 3), round(val_loss, 3),
                           round(train_avg_auc, 3), round(train_loss, 3))))
    if rank == 0:
        wandb.finish()
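The training loop above saves model.module.state_dict() on rank 0 so the checkpoint keys carry no DDP "module." prefix. A minimal sketch of that save/restore convention, assuming the model is already wrapped in DistributedDataParallel:

import torch
from torch.nn.parallel import DistributedDataParallel

def save_checkpoint(model: DistributedDataParallel, rank: int, path: str) -> None:
    # Only rank 0 writes; .module strips the DDP wrapper from the state dict keys.
    if rank == 0:
        torch.save(model.module.state_dict(), path)

def load_checkpoint(model: torch.nn.Module, path: str) -> None:
    # Load into the plain module (before wrapping), or into ddp_model.module.
    model.load_state_dict(torch.load(path, map_location="cpu"))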
Example #15
    def __init__(self, params):
        """Creates a Trainer.
    """
        utils.set_default_param_values_and_env_vars(params)
        self.params = params

        # Setup logging & log the version.
        global_utils.setup_logging(params.logging_verbosity)

        self.job_name = self.params.job_name  # "" for local training
        self.is_distributed = bool(self.job_name)
        self.task_index = self.params.task_index
        self.local_rank = self.params.local_rank
        self.start_new_model = self.params.start_new_model
        self.train_dir = self.params.train_dir
        self.num_gpus = self.params.num_gpus
        if self.num_gpus and not self.is_distributed:
            self.batch_size = self.params.batch_size * self.num_gpus
        else:
            self.batch_size = self.params.batch_size

        # print self.params parameters
        if self.start_new_model and self.local_rank == 0:
            pp = pprint.PrettyPrinter(indent=2, compact=True)
            logging.info(pp.pformat(params.values()))

        if self.local_rank == 0:
            logging.info("PyTorch version: {}.".format(torch.__version__))
            logging.info("NCCL Version {}".format(torch.cuda.nccl.version()))
            logging.info("Hostname: {}.".format(socket.gethostname()))

        if self.is_distributed:
            self.num_nodes = len(params.worker_hosts.split(';'))
            self.world_size = self.num_nodes * self.num_gpus
            self.rank = self.task_index * self.num_gpus + self.local_rank
            dist.init_process_group(backend='nccl',
                                    init_method='env://',
                                    timeout=datetime.timedelta(seconds=30))
            if self.local_rank == 0:
                logging.info('World Size={} => Total batch size {}'.format(
                    self.world_size, self.batch_size * self.world_size))
            self.is_master = bool(self.rank == 0)
        else:
            self.world_size = 1
            self.is_master = True

        # create a message builder for logging
        self.message = global_utils.MessageBuilder()

        # load reader and model
        self.reader = readers_config[self.params.dataset](self.params,
                                                          self.batch_size,
                                                          self.num_gpus,
                                                          is_training=True)
        self.model = model_config.get_model_config(self.params.model,
                                                   self.params.dataset,
                                                   self.params,
                                                   self.reader.n_classes,
                                                   is_training=True)
        # define DistributedDataParallel job
        self.model = SyncBatchNorm.convert_sync_batchnorm(self.model)
        torch.cuda.set_device(params.local_rank)
        self.model = self.model.cuda()
        i = params.local_rank
        self.model = DistributedDataParallel(self.model,
                                             device_ids=[i],
                                             output_device=i)
        if self.local_rank == 0:
            logging.info('Model defined with DistributedDataParallel')

        # define set for saved ckpt
        self.saved_ckpts = set([0])

        # define optimizer
        self.optimizer = get_optimizer(self.params.optimizer,
                                       self.params.optimizer_params,
                                       self.params.init_learning_rate,
                                       self.params.weight_decay,
                                       self.model.parameters())

        # define learning rate scheduler
        self.scheduler = get_scheduler(self.optimizer,
                                       self.params.lr_scheduler,
                                       self.params.lr_scheduler_params)

        # if start_new_model is False, we restart training
        if not self.start_new_model:
            if self.local_rank == 0:
                logging.info('Restarting training...')
            self.load_state()

        # define Lipschitz Reg module
        self.lipschitz_reg = LipschitzRegularization(self.model, self.params,
                                                     self.reader,
                                                     self.local_rank)

        # exponential moving average
        self.ema = None
        if getattr(self.params, 'ema', False) > 0:
            self.ema = utils.EMA(self.params.ema)

        # if adversarial training, create the attack class
        if self.params.adversarial_training:
            if self.local_rank == 0:
                logging.info('Adversarial Training')
            attack_params = self.params.adversarial_training_params
            self.attack = utils.get_attack(
                self.model, self.reader.n_classes,
                self.params.adversarial_training_name, attack_params)
Example #16
 def distribute_net(self, net, device):
     net = net.to(self.device)
     return DDP(SyncBatchNorm.convert_sync_batchnorm(net).to(self.device),
                device_ids=[device])
Example #17
def main():
    args = arg_parser()
    # turn on benchmark mode
    torch.backends.cudnn.benchmark = True

    accelerator = Accelerator(fp16=args.use_fp16)

    if accelerator.is_main_process:
        # setup logger
        os.makedirs(args.log_dir, exist_ok=True)
        time_stamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
        logger = get_root_logger(logger_name='MOD', log_file=os.path.join(
            args.log_dir, f'{time_stamp}.log'))
        writer = SummaryWriter(log_dir=os.path.join(args.log_dir, 'tf_logs'))
        # log env info
        logger.info('--------------------Env info--------------------')
        for key, value in sorted(collect_env().items()):
            logger.info(str(key) + ': ' + str(value))
        # log args
        logger.info('----------------------Args-----------------------')
        for key, value in sorted(vars(args).items()):
            logger.info(str(key) + ': ' + str(value))
        logger.info('---------------------------------------------------')

    # train_dataset = MOD(root=args.root, annfile=args.train_annfile)
    train_dataset = MOD_3d(
        root=args.root, annfile=args.train_annfile, clip_length=args.clip_length)
    train_dataloader = DataLoader(train_dataset, batch_size=args.samples_per_gpu,
                                  shuffle=True, num_workers=args.num_workers, pin_memory=True)
    # val dataloader
    # val_dataset = MOD(root=args.root, annfile=args.val_annfile, val=True)
    val_dataset = MOD_3d(root=args.root, annfile=args.val_annfile,
                         val=True, clip_length=args.clip_length)
    val_dataloader = DataLoader(val_dataset, batch_size=args.samples_per_gpu,
                                shuffle=False, num_workers=args.num_workers, pin_memory=True)

    # define model
    # model = TinyUNet(
    #     n_channels=1, n_classes=train_dataset.num_classes, upsample='bilinear')
    # replace2dwith3d(model=model)
    model = TinyUNet3d(n_channels=1, n_classes=2)
    # optimizer
    init_lr = args.base_lr*dist.get_world_size()*args.samples_per_gpu/16
    optimizer = optim.SGD(model.parameters(), lr=init_lr,
                          weight_decay=1e-4, momentum=0.9)
    # recover states
    start_epoch = 1
    if args.resume is not None:
        ckpt: dict = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        start_epoch = ckpt['epoch']+1
        if accelerator.is_main_process:
            logger.info(f"Resume from epoch {start_epoch-1}...")
    else:
        if accelerator.is_main_process:
            logger.info("Start training from scratch...")
    # convert BatchNorm to SyncBatchNorm
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    # prepare to be DDP models
    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader)
    # closed_form lr_scheduler
    total_steps = len(train_dataloader)*args.epochs
    resume_step = len(train_dataloader)*(start_epoch-1)
    lr_scheduler = ClosedFormCosineLRScheduler(
        optimizer, init_lr, total_steps, resume_step)
    # loss criterion
    criterion = CrossEntropyLoss(weight=torch.tensor([1., 10.]), ignore_index=255).to(
        accelerator.device)
    # training
    # Best acc
    best_miou = 0.
    for e in range(start_epoch, args.epochs+1):
        model.train()
        for i, batch in enumerate(train_dataloader):
            img, mask = batch
            logits = model(img)
            loss = criterion(logits, mask)
            accelerator.backward(loss)
            # clip grad if true
            if args.clip_grad_norm is not None:
                grad_norm = accelerator.clip_grad_norm_(
                    model.parameters(), args.clip_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            # sync before logging
            accelerator.wait_for_everyone()
            ## log and tensorboard
            if accelerator.is_main_process:
                if i % args.log_interval == 0:
                    writer.add_scalar('loss', loss.item(),
                                      (e-1)*len(train_dataloader)+i)
                    lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('lr', lr,
                                      (e-1)*len(train_dataloader)+i)
                    loss_str = f"loss: {loss.item():.4f}"
                    epoch_iter_str = f"Epoch: [{e}] [{i}/{len(train_dataloader)}], "
                    if args.clip_grad_norm is not None:
                        logger.info(
                            epoch_iter_str+f'lr: {lr}, '+loss_str+f', grad_norm: {grad_norm}')
                    else:
                        logger.info(epoch_iter_str+f'lr: {lr}, '+loss_str)

            lr_scheduler.step()
        if accelerator.is_main_process:
            if e % args.save_interval == 0:
                save_path = os.path.join(args.log_dir, f'epoch_{e}.pth')
                torch.save(
                    {'state_dict': model.module.state_dict(), 'epoch': e, 'args': args,
                        'optimizer': optimizer.state_dict()}, save_path)
                logger.info(f"Checkpoint has been saved at {save_path}")
        # start to evaluate
        if accelerator.is_main_process:
            logger.info("Evaluate on validation dataset")
            bar = tqdm(total=len(val_dataloader))
        model.eval()
        preds = []
        gts = []
        for batch in val_dataloader:
            img, mask = batch
            with torch.no_grad():
                logits = model(img)
                pred = accelerator.gather(logits)
                gt = accelerator.gather(mask)
            preds.append(pred)
            gts.append(gt)
            if accelerator.is_main_process:
                bar.update(accelerator.num_processes)
        if accelerator.is_main_process:
            bar.close()
            # compute metrics
            # prepare preds
            preds = torch.cat(preds)[:len(val_dataloader.dataset)]
            preds = average_preds(preds, window=args.clip_length)  # NCHW
            preds = F.softmax(preds, dim=1)
            preds = torch.argmax(preds, dim=1)  # NHW
            # prepare gts
            gts = torch.cat(gts)[:len(val_dataloader.dataset)]  # NTHW
            gts = flat_gts(gts, window=args.clip_length)  # NHW
            # accuracy
            acc = accuarcy(preds, gts, ignore_index=0, average='micro')
            # mIoU
            miou = mIoU(preds, gts, ignore_index=0)
            logger.info(f"Accuracy on Val dataset: {acc:.4f}")
            logger.info(f"Mean IoU on Val dataset: {miou:.4f}")
            # save preds
            if miou > best_miou:
                best_miou = miou
                val_results_dir = os.path.join(
                    args.log_dir, 'best_val_results')
                os.makedirs(val_results_dir, exist_ok=True)
                imgpaths = flat_paths(val_dataset.imgpaths)
                assert preds.shape[0] == len(imgpaths)
                preds = preds.cpu().numpy()
                for i in range(preds.shape[0]):
                    imgname = imgpaths[i].split('/')[-1]
                    imgpath = os.path.join(val_results_dir, imgname)
                    result = preds[i].astype(np.uint8)
                    result[result == 1] = 255
                    result = Image.fromarray(result)
                    result.save(imgpath)
        # delete vars that are no longer needed
        del preds
        del gts
        accelerator.wait_for_everyone()
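The script above relies on Hugging Face Accelerate for device placement and DDP wrapping, so the only SyncBatchNorm-specific step is converting the model before accelerator.prepare(). A condensed, illustrative sketch of that ordering (the helper assumes the script is launched via `accelerate launch`):

import torch
from torch.nn import SyncBatchNorm
from accelerate import Accelerator

def prepare_with_sync_bn(model: torch.nn.Module, optimizer, dataloader):
    # Convert BatchNorm*d layers first, then let Accelerate wrap the model
    # (it applies DistributedDataParallel when launched with multiple processes).
    accelerator = Accelerator()
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    return accelerator.prepare(model, optimizer, dataloader)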
Example #18
def train_function(gpu, world_size, node_rank, gpus, fold_number, group_name):
    import torch.multiprocessing
    torch.multiprocessing.set_sharing_strategy('file_system')

    torch.manual_seed(25)
    np.random.seed(25)

    rank = node_rank * gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank
    )

    device = torch.device("cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu")

    batch_size = 64
    width_size = 416
    init_lr = 1e-4
    end_lr = 1e-6
    n_epochs = 20
    emb_size = 512
    margin = 0.5
    dropout = 0.0
    iters_to_accumulate = 1

    if rank == 0:
        wandb.init(project='shopee_effnet0', group=group_name, job_type=str(fold_number))

        checkpoints_dir_name = 'effnet0_{}_{}_{}'.format(width_size, dropout, group_name)
        os.makedirs(checkpoints_dir_name, exist_ok=True)

        wandb.config.model_name = checkpoints_dir_name
        wandb.config.batch_size = batch_size
        wandb.config.width_size = width_size
        wandb.config.init_lr = init_lr
        wandb.config.n_epochs = n_epochs
        wandb.config.emb_size = emb_size
        wandb.config.dropout = dropout
        wandb.config.iters_to_accumulate = iters_to_accumulate
        wandb.config.optimizer = 'adam'
        wandb.config.scheduler = 'ShopeeScheduler'

    df = pd.read_csv('../../dataset/reliable_validation_tm.csv')
    train_df = df[df['fold_group'] != fold_number]
    train_transforms = alb.Compose([
        alb.RandomResizedCrop(width_size, width_size),
        alb.ShiftScaleRotate(shift_limit=0.1, rotate_limit=30),
        alb.HorizontalFlip(),
        alb.OneOf([
            alb.Sequential([
                alb.HueSaturationValue(hue_shift_limit=50),
                alb.RandomBrightnessContrast(),
            ]),
            alb.FancyPCA(),
            alb.ChannelDropout(),
            alb.ChannelShuffle(),
            alb.RGBShift()
        ]),
        alb.CoarseDropout(max_height=int(width_size*0.1), max_width=int(width_size*0.1)),
        alb.OneOf([
            alb.ElasticTransform(),
            alb.OpticalDistortion(),
            alb.GridDistortion()
        ]),
        alb.Resize(width_size, width_size),
        alb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])
    train_set = ImageDataset(train_df, train_df, '../../dataset/train_images', train_transforms)
    sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    train_dataloader = DataLoader(train_set, batch_size=batch_size // world_size, shuffle=False, num_workers=4,
                                  sampler=sampler)

    # valid_df = df[df['fold_strat'] == fold_number]
    valid_transforms = alb.Compose([
        alb.Resize(width_size, width_size),
        alb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])
    # valid_set = ImageDataset(train_df, valid_df, '../../dataset/train_images', valid_transforms)
    # valid_dataloader = DataLoader(valid_set, batch_size=batch_size // world_size, shuffle=False, num_workers=4)

    test_df = df[df['fold_group'] == fold_number]
    test_set = ImageDataset(test_df, test_df, '../../dataset/train_images', valid_transforms)
    test_dataloader = DataLoader(test_set, batch_size=batch_size // world_size, shuffle=False, num_workers=4)

    model = EfficientNetArcFace(emb_size, train_df['label_group'].nunique(), device, dropout=dropout,
                                backbone='tf_efficientnet_b0_ns', pretrained=True, margin=margin, is_amp=True)
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)
    model = DistributedDataParallel(model, device_ids=[gpu])

    scaler = GradScaler()
    criterion = CrossEntropyLoss()
    # criterion = LabelSmoothLoss(smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)
    # scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=end_lr,
    #                               last_epoch=-1)
    # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2000, T_mult=1,
    #                                         eta_min=end_lr, last_epoch=-1)
    scheduler = ShopeeScheduler(optimizer, lr_start=init_lr,
                                lr_max=init_lr*batch_size, lr_min=end_lr)

    for epoch in range(n_epochs):
        train_loss, train_duration, train_f1 = train_one_epoch(
            model, train_dataloader, optimizer, criterion, device, scaler,
            scheduler=None, iters_to_accumulate=iters_to_accumulate)
        scheduler.step()

        if rank == 0:
            # valid_loss, valid_duration, valid_f1 = evaluate(model, valid_dataloader, criterion, device)
            embeddings = get_embeddings(model, test_dataloader, device)
            embeddings_f1 = validate_embeddings_f1(embeddings, test_df)

            wandb.log({'train_loss': train_loss, 'train_f1': train_f1,
                       'embeddings_f1': embeddings_f1, 'epoch': epoch})

            filename = '{}_foldnum{}_epoch{}_train_loss{}_f1{}'.format(
                checkpoints_dir_name, fold_number+1, epoch+1,
                round(train_loss, 3), round(embeddings_f1, 3))
            torch.save(model.module.state_dict(), os.path.join(checkpoints_dir_name, '{}.pth'.format(filename)))
            # np.savez_compressed(os.path.join(checkpoints_dir_name, '{}.npz'.format(filename)), embeddings=embeddings)

            print('FOLD NUMBER %d\tEPOCH %d:\t'
                  'TRAIN [duration %.3f sec, loss: %.3f, avg f1: %.3f]\t'
                  'VALID EMBEDDINGS [avg f1: %.3f]\tCurrent time %s' %
                  (fold_number + 1, epoch + 1, train_duration,
                   train_loss, train_f1, embeddings_f1,
                   str(datetime.now(timezone('Europe/Moscow')))))

    if rank == 0:
        wandb.finish()
Example #19
def main(cfg):
    setup(cfg)
    dataset_names = register_datasets(cfg)
    if cfg.ONLY_REGISTER_DATASETS:
        return {}, cfg
    LOG.info(f"Registered {len(dataset_names)} datasets:" + '\n\t' + '\n\t'.join(dataset_names))

    model = build_model(cfg)

    checkpoint_file = cfg.MODEL.CKPT
    if checkpoint_file:
        if cfg.MODEL.CKPT_REMAPPER:
            if cfg.EVAL_ONLY:
                LOG.warning("Running with 'EVAL_ONLY', but the checkpoint is remapped.")
            checkpoint_file = CHECKPOINT_REMAPPERS[cfg.MODEL.CKPT_REMAPPER](checkpoint_file, model)

        # BatchNorm2d submodules to convert to FrozenBatchNorm2d.
        modules_to_convert_frozenbb = [
            (name, module) for name, module in model.named_modules() if name in cfg.MODEL.CONVERT_TO_FROZEN_BN_MODULES
        ]
        if len(modules_to_convert_frozenbb) > 0:
            module_names, modules = list(zip(*modules_to_convert_frozenbb))
            LOG.info(
                f"Converting BatchNorm2d -> FrozenBatchNorm2d {len(modules)} submodule(s):" + '\n\t' +
                '\n\t'.join(module_names)
            )
            for module in modules:
                FrozenBatchNorm2d.convert_frozen_batchnorm(module)

        # Some checkpoints contain batchnorm layers with a negative value for 'running_var'.
        model = HotFixFrozenBatchNorm2d.convert_frozenbn_to_hotfix_ver(model)
        Checkpointer(model).load(checkpoint_file)

    if cfg.EVAL_ONLY:
        assert cfg.TEST.ENABLED, "'eval-only' mode is not compatible with 'cfg.TEST.ENABLED = False'."
        test_results = do_test(cfg, model, is_last=True)
        if cfg.TEST.AUG.ENABLED:
            test_results.update(do_test(cfg, model, is_last=True, use_tta=True))
        return test_results, cfg

    modules_to_freeze = cfg.MODEL.FREEZE_MODULES
    if modules_to_freeze:
        LOG.info(f"Freezing {len(modules_to_freeze)} submodule(s):" + '\n\t' + '\n\t'.join(modules_to_freeze))
        # `requires_grad=False` must be set *before* wrapping the model with `DistributedDataParallel`
        # modules_to_freeze = [x.strip() for x in cfg.MODEL.FREEZE_MODULES.split(',')]
        # for module_name in cfg.MODEL.FREEZE_MODULES:
        for module_name in modules_to_freeze:
            freeze_submodule(model, module_name)

    if comm.is_distributed():
        assert d2_comm._LOCAL_PROCESS_GROUP is not None
        # Convert all BatchNorm*d layers to nn.SyncBatchNorm.
        # For faster training, the batch stats are computed over only the GPUs of the same machine (usually 8).
        sync_bn_pg = d2_comm._LOCAL_PROCESS_GROUP if cfg.SOLVER.SYNCBN_USE_LOCAL_WORKERS else None
        model = SyncBatchNorm.convert_sync_batchnorm(model, process_group=sync_bn_pg)
        model = DistributedDataParallel(
            model,
            device_ids=[d2_comm.get_local_rank()],
            broadcast_buffers=False,
            find_unused_parameters=cfg.SOLVER.DDP_FIND_UNUSED_PARAMETERS
        )

    do_train(cfg, model)
    test_results = do_test(cfg, model, is_last=True)
    if cfg.TEST.AUG.ENABLED:
        test_results.update(do_test(cfg, model, is_last=True, use_tta=True))
    return test_results, cfg
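Example #19 passes a detectron2-style local process group to convert_sync_batchnorm so that batch statistics are synchronized only within one machine. A similar group can be built directly with torch.distributed.new_group; the sketch below is illustrative and assumes every node runs exactly gpus_per_node ranks.

import torch.distributed as dist
from torch.nn import SyncBatchNorm

def convert_with_local_group(model, gpus_per_node: int):
    """Sync BN statistics only among the ranks on the current machine."""
    rank, world_size = dist.get_rank(), dist.get_world_size()
    local_group = None
    for node in range(world_size // gpus_per_node):
        ranks = list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
        # new_group() must be called by every rank, even for groups it does not join.
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            local_group = group
    return SyncBatchNorm.convert_sync_batchnorm(model, process_group=local_group)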
Example #20
def train_fold(save_dir,
               train_folds,
               val_folds,
               local_rank=0,
               distributed=False,
               pretrain_dir=''):
    folds_data = get_folds_data()

    model = AlaskaModel(PARAMS)
    model.params['nn_module'][1]['pretrained'] = False

    if pretrain_dir:
        pretrain_path = get_best_model_path(pretrain_dir)
        if pretrain_path is not None:
            print(f'Pretrain model path {pretrain_path}')
            load_pretrain_weigths(model, pretrain_path)
        else:
            print(f"Pretrain model not found in '{pretrain_dir}'")

    if USE_AMP:
        initialize_amp(model)

    if distributed:
        model.nn_module = SyncBatchNorm.convert_sync_batchnorm(model.nn_module)
        model.nn_module = DistributedDataParallel(
            model.nn_module.to(local_rank),
            device_ids=[local_rank],
            output_device=local_rank)
        if local_rank:
            model.logger.disabled = True
    else:
        model.set_device(DEVICES)

    if USE_EMA:
        initialize_ema(model, decay=0.9999)
        checkpoint = EmaMonitorCheckpoint
    else:
        checkpoint = MonitorCheckpoint

    for epochs, stage in zip(TRAIN_EPOCHS, STAGE):
        test_transform = get_transforms(train=False)

        if stage == 'train':
            mixer = RandomMixer([BitMix(gamma=0.25), EmptyMix()], p=[0., 1.])
            train_transform = get_transforms(train=True)
        else:
            mixer = EmptyMix()
            train_transform = get_transforms(train=False)

        train_dataset = AlaskaDataset(folds_data,
                                      train_folds,
                                      transform=train_transform,
                                      mixer=mixer)
        val_dataset = AlaskaDataset(folds_data,
                                    val_folds,
                                    transform=test_transform)
        val_sampler = AlaskaSampler(val_dataset, train=False)

        if distributed:
            train_sampler = AlaskaDistributedSampler(train_dataset)
        else:
            train_sampler = AlaskaSampler(train_dataset, train=True)

        train_loader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  num_workers=NUM_WORKERS,
                                  batch_size=BATCH_SIZE)
        val_loader = DataLoader(val_dataset,
                                sampler=val_sampler,
                                num_workers=NUM_WORKERS,
                                batch_size=VAL_BATCH_SIZE)

        callbacks = []
        if local_rank == 0:
            callbacks += [
                checkpoint(save_dir,
                           monitor='val_weighted_auc',
                           max_saves=5,
                           file_format=stage +
                           '-model-{epoch:03d}-{monitor:.6f}.pth'),
                LoggingToFile(save_dir / 'log.txt'),
                LoggingToCSV(save_dir / 'log.csv', append=True)
            ]

        if stage == 'train':
            callbacks += [
                CosineAnnealingLR(T_max=epochs,
                                  eta_min=get_lr(9e-6, WORLD_BATCH_SIZE))
            ]
        elif stage == 'warmup':
            warmup_iterations = epochs * (len(train_sampler) / BATCH_SIZE)
            callbacks += [
                LambdaLR(lambda x: x / warmup_iterations,
                         step_on_iteration=True)
            ]

        if stage == 'train':

            @argus.callbacks.on_epoch_start
            def schedule_mixer_prob(state):
                bitmix_prob = state.epoch / epochs
                mixer.p = [bitmix_prob, 1 - bitmix_prob]
                state.logger.info(f"Mixer probabilities {mixer.p}")

            callbacks += [schedule_mixer_prob]

        if distributed:

            @argus.callbacks.on_epoch_complete
            def schedule_sampler(state):
                train_sampler.set_epoch(state.epoch + 1)

            callbacks += [schedule_sampler]

        metrics = ['weighted_auc', Accuracy('stegano'), Accuracy('quality')]

        model.fit(train_loader,
                  val_loader=val_loader,
                  num_epochs=epochs,
                  callbacks=callbacks,
                  metrics=metrics)
Example #21
def _test_func(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model()
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note: different ranks may wrap modules in a different order due to
        # different random seeds, but the results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then convert_sync_batchnorm")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            model = _bn_converter(model)
        else:
            print(f"convert_sync_batchnorm, then auto_wrap_bn {fsdp_wrap_bn}")
            model = _bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()
        # Print the model for verification.
        if rank == 0:
            print(model)
    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure final state equals to the state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

    teardown()
Example #22
 def apply_dist(self, network):
     # add syncBN if necessary
     network = SyncBatchNorm.convert_sync_batchnorm(network)
     network_dist = DDP(network.cuda(self.rank), device_ids=[self.rank])
     # print('Apply dist for on rank : {}'.format(self.rank))
     return network_dist
Example #23
    def __init__(self, params):
        """Creates a Trainer.
    """
        utils.set_default_param_values_and_env_vars(params)
        self.params = params

        # Setup logging & log the version.
        utils.setup_logging(params.logging_verbosity)

        self.job_name = self.params.job_name  # "" for local training
        self.is_distributed = bool(self.job_name)
        self.task_index = self.params.task_index
        self.local_rank = self.params.local_rank
        self.start_new_model = self.params.start_new_model
        self.train_dir = self.params.train_dir
        self.num_gpus = self.params.num_gpus
        if self.num_gpus and not self.is_distributed:
            self.batch_size = self.params.batch_size * self.num_gpus
        else:
            self.batch_size = self.params.batch_size

        # print self.params parameters
        if self.start_new_model and self.local_rank == 0:
            pp = pprint.PrettyPrinter(indent=2, compact=True)
            logging.info(pp.pformat(params.values()))

        if self.local_rank == 0:
            logging.info("PyTorch version: {}.".format(torch.__version__))
            logging.info("NCCL Version {}".format(torch.cuda.nccl.version()))
            logging.info("Hostname: {}.".format(socket.gethostname()))

        if self.is_distributed:
            self.num_nodes = len(params.worker_hosts.split(';'))
            self.world_size = self.num_nodes * self.num_gpus
            self.rank = self.task_index * self.num_gpus + self.local_rank
            dist.init_process_group(backend='nccl',
                                    init_method='env://',
                                    timeout=datetime.timedelta(seconds=30))
            if self.local_rank == 0:
                logging.info('World Size={} => Total batch size {}'.format(
                    self.world_size, self.batch_size * self.world_size))
            self.is_master = bool(self.rank == 0)
        else:
            self.world_size = 1
            self.is_master = True

        # create a message builder for logging
        self.message = utils.MessageBuilder()

        # load reader and model
        self.reader = readers_config[self.params.dataset](self.params,
                                                          self.batch_size,
                                                          self.num_gpus,
                                                          is_training=True)

        # load model
        self.model = model_config.get_model_config(self.params.model,
                                                   self.params.dataset,
                                                   self.params,
                                                   self.reader.n_classes,
                                                   is_training=True)
        # add normalization as first layer of model
        if self.params.add_normalization:
            # In order to certify radii in original coordinates rather than standardized coordinates, we
            # add the noise _before_ standardizing, which is why we have standardization be the first
            # layer of the classifier rather than as a part of preprocessing as is typical.
            normalize_layer = self.reader.get_normalize_layer()
            self.model = torch.nn.Sequential(normalize_layer, self.model)

        # define DistributedDataParallel job
        self.model = SyncBatchNorm.convert_sync_batchnorm(self.model)
        torch.cuda.set_device(params.local_rank)
        self.model = self.model.cuda()
        i = params.local_rank
        self.model = DistributedDataParallel(self.model,
                                             device_ids=[i],
                                             output_device=i)
        if self.local_rank == 0:
            logging.info('Model defined with DistributedDataParallel')

        # define set for saved ckpt
        self.saved_ckpts = set([0])

        # define optimizer
        self.optimizer = utils.get_optimizer(self.params.optimizer,
                                             self.params.optimizer_params,
                                             self.params.init_learning_rate,
                                             self.params.weight_decay,
                                             self.model.parameters())

        # define learning rate scheduler
        self.scheduler = utils.get_scheduler(self.optimizer,
                                             self.params.lr_scheduler,
                                             self.params.lr_scheduler_params)

        # if start_new_model is False, we restart training
        if not self.start_new_model:
            if self.local_rank == 0:
                logging.info('Restarting training...')
            self._load_state()

        # define Lipschitz regularization module
        if self.params.lipschitz_regularization:
            if self.local_rank == 0:
                logging.info(
                    "Lipschitz regularization with decay {}, start after epoch {}"
                    .format(self.params.lipschitz_decay,
                            self.params.lipschitz_start_epoch))
            self.lipschitz = LipschitzRegularization(self.model, self.params,
                                                     self.reader,
                                                     self.local_rank)

        # exponential moving average
        self.ema = None
        if getattr(self.params, 'ema', 0) > 0:
            self.ema = utils.EMA(self.params.ema)

        # if adversarial training, create the attack class
        if self.params.adversarial_training:
            if self.local_rank == 0:
                logging.info('Adversarial Training')
            attack_params = self.params.adversarial_training_params
            if 'eps_iter' in attack_params and attack_params['eps_iter'] == -1:
                eps = attack_params['eps']
                n_iter = attack_params['nb_iter']
                attack_params['eps_iter'] = eps / n_iter * 2
                if self.local_rank == 0:
                    logging.info('Learning rate for attack: {}'.format(
                        attack_params['eps_iter']))
            self.attack = utils.get_attack(
                self.model, self.reader.n_classes,
                self.params.adversarial_training_name, attack_params)

        # init noise
        if self.params.adaptive_noise and self.params.additive_noise:
            raise ValueError(
                "Adaptive and Additive Noise should not be set together")
        if self.params.adaptive_noise:
            if self.local_rank == 0:
                logging.info('Training with Adaptive Noise: {} {}'.format(
                    self.params.noise_distribution, self.params.noise_scale))
        elif self.params.additive_noise:
            if self.local_rank == 0:
                logging.info('Training with Noise: {} {}'.format(
                    self.params.noise_distribution, self.params.noise_scale))
        if self.params.adaptive_noise or self.params.additive_noise:
            self.noise = utils.Noise(self.params)

        # stability training
        if self.params.stability_training:
            if self.local_rank == 0:
                logging.info("Training with Stability Training: {}".format(
                    self.params.stability_training_lambda))
            if not any([
                    self.params.adversarial_training,
                    self.params.adaptive_noise, self.params.additive_noise
            ]):
                raise ValueError(
                    "Adversarial Training, Adaptive Noise, or Additive Noise "
                    "should be activated")
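

# --- Supplementary sketch (not the repository's `get_normalize_layer`): the
# --- comment above explains that standardization is made the first layer of the
# --- model so that smoothing noise is added in original pixel coordinates. A
# --- minimal per-channel normalization module of that kind could look like the
# --- following; the class name and usage are illustrative assumptions only.
import torch


class _NormalizeLayerSketch(torch.nn.Module):
    """Per-channel standardization applied inside the model, not in preprocessing."""

    def __init__(self, means, stds):
        super().__init__()
        self.register_buffer("means", torch.tensor(means).view(1, -1, 1, 1))
        self.register_buffer("stds", torch.tensor(stds).view(1, -1, 1, 1))

    def forward(self, x):
        # x is NCHW in original coordinates; noise can be added before this layer
        return (x - self.means) / self.stds


# Usage sketch: model = torch.nn.Sequential(_NormalizeLayerSketch(means, stds), model)
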
    def worker(self, gpu_id: int):
        """
        Everything created in this function is used only in this process and is not shared across processes.
        """
        if self.seed is not None:
            make_deterministic(self.seed)
        self.current_rank = self.rank
        if self.distributed:
            if self.multiprocessing:
                # For multiprocessing distributed training, rank needs to be the
                # global rank among all the processes
                self.current_rank = self.rank * self.ngpus_per_node + gpu_id
            dist.init_process_group(backend=self.dist_backend,
                                    init_method=self.dist_url,
                                    world_size=self.world_size,
                                    rank=self.current_rank)
        # set up process logger
        self.logger = logging.getLogger("worker_rank_{}".format(
            self.current_rank))
        self.logger.propagate = False
        handler = QueueHandler(self.logger_queue)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

        # only write in master process
        if self.current_rank == 0:
            self.tb_writer = self.tb_writer_constructor()

        self.logger.info("Use GPU: %d for training, current rank: %d", gpu_id,
                         self.current_rank)
        # get dataset
        train_dataset = get_dataset(self.global_cfg["dataset"]["name"],
                                    self.global_cfg["dataset"]["root"],
                                    split="train")
        val_dataset = get_dataset(self.global_cfg["dataset"]["name"],
                                  self.global_cfg["dataset"]["root"],
                                  split="val")
        # create model
        self.model = get_model(
            model_name=self.global_cfg["model"]["name"],
            num_classes=self.global_cfg["dataset"]["n_classes"])

        self.device = torch.device("cuda:{}".format(gpu_id))
        self.model.to(self.device)

        batch_size = self.global_cfg["training"]["batch_size"]
        n_workers = self.global_cfg["training"]["num_workers"]
        if self.distributed:
            batch_size = int(batch_size / self.ngpus_per_node)
            n_workers = int(
                (n_workers + self.ngpus_per_node - 1) / self.ngpus_per_node)
            if self.global_cfg["training"]["sync_bn"]:
                self.model = SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[gpu_id])
        self.logger.info("batch_size: {}, workers: {}".format(
            batch_size, n_workers))

        # define loss function (criterion) and optimizer
        self.loss_fn = CrossEntropyLoss().to(self.device)

        optimizer_cls = get_optimizer(self.global_cfg["training"]["optimizer"])
        optimizer_params = copy.deepcopy(
            self.global_cfg["training"]["optimizer"])
        optimizer_params.pop("name")
        self.optimizer: Optimizer = optimizer_cls(self.model.parameters(),
                                                  **optimizer_params)
        self.logger.info("Loaded optimizer:\n%s", self.optimizer)

        # scheduler
        self.scheduler = get_scheduler(
            self.optimizer, self.global_cfg["training"]["lr_schedule"])

        if self.distributed:
            train_sampler = DistributedSampler(train_dataset,
                                               shuffle=True,
                                               drop_last=True)
            val_sampler = DistributedSampler(val_dataset, shuffle=False)
        else:
            train_sampler = RandomSampler(train_dataset)
            val_sampler = SequentialSampler(val_dataset)

        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  num_workers=n_workers,
                                  pin_memory=True,
                                  sampler=train_sampler)

        self.val_loader = DataLoader(val_dataset,
                                     batch_size=batch_size,
                                     num_workers=n_workers,
                                     pin_memory=True,
                                     sampler=val_sampler)
        self.logger.info(
            "Load dataset done\nTraining: %d imgs, %d batchs\nEval: %d imgs, %d batchs",
            len(train_dataset), len(train_loader), len(val_dataset),
            len(self.val_loader))
        iter_generator = make_iter_dataloader(train_loader)

        while self.iter < self.global_cfg["training"]["train_iters"]:
            img, label = next(iter_generator)
            self.train_iter(img, label)

            def is_val():
                p1 = self.iter != 0
                p2 = (self.iter +
                      1) % self.global_cfg["training"]["val_interval"] == 0
                p3 = self.iter == self.global_cfg["training"]["train_iters"] - 1
                return (p1 and p2) or p3

            # run validation at the configured interval
            if is_val():
                self.validate()
            # end one iteration
            self.iter += 1
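

# --- Supplementary sketch (an assumption, not this repository's helper) of what
# --- `make_iter_dataloader` used above could look like: it cycles the finite
# --- DataLoader forever for iteration-based training and, when a
# --- DistributedSampler is attached, calls `set_epoch` on each pass so every
# --- rank reshuffles consistently.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def make_iter_dataloader_sketch(loader: DataLoader):
    epoch = 0
    while True:
        if isinstance(loader.sampler, DistributedSampler):
            loader.sampler.set_epoch(epoch)
        for batch in loader:
            yield batch
        epoch += 1
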
def main(batch_size, rank, world_size):

    import os
    import tqdm
    import torch
    import tempfile

    from torch import optim
    from torch import distributed as dist
    from torch.nn import SyncBatchNorm
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter

    from data.aug.compose import Compose
    from data.aug import ops
    from data.dataset import HRSC2016

    from model.rdd import RDD
    from model.backbone import resnet

    from utils.adjust_lr import adjust_lr_multi_step

    torch.manual_seed(0)
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl",
                            init_method='env://',
                            rank=rank,
                            world_size=world_size)

    backbone = resnet.resnet101

    dir_dataset = '<replace with your local path>'
    dir_save = '<replace with your local path>'

    dir_weight = os.path.join(dir_save, 'weight')
    dir_log = os.path.join(dir_save, 'log')
    os.makedirs(dir_weight, exist_ok=True)
    if rank == 0:
        writer = SummaryWriter(dir_log)

    indexes = [
        int(os.path.splitext(path)[0]) for path in os.listdir(dir_weight)
    ]
    current_step = max(indexes) if indexes else 0

    image_size = 768
    lr = 1e-3
    batch_size //= world_size
    num_workers = 4

    max_step = 12000
    lr_cfg = [[7500, lr], [max_step, lr / 10]]
    warm_up = [500, lr / 50, lr]
    save_interval = 1000

    aug = Compose([
        ops.ToFloat(),
        ops.PhotometricDistort(),
        ops.RandomHFlip(),
        ops.RandomVFlip(),
        ops.RandomRotate90(),
        ops.ResizeJitter([0.8, 1.2]),
        ops.PadSquare(),
        ops.Resize(image_size),
    ])
    dataset = HRSC2016(dir_dataset, ['trainval'], aug)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, world_size, rank)
    batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                  batch_size,
                                                  drop_last=True)
    loader = DataLoader(dataset,
                        batch_sampler=batch_sampler,
                        num_workers=num_workers,
                        collate_fn=dataset.collate)
    num_classes = len(dataset.names)

    prior_box = {
        'strides': [8, 16, 32, 64, 128],
        'sizes': [3] * 5,
        'aspects': [[1.5, 3, 5, 8]] * 5,
        'scales': [[2**0, 2**(1 / 3), 2**(2 / 3)]] * 5,
    }

    cfg = {
        'prior_box': prior_box,
        'num_classes': num_classes,
        'extra': 2,
    }
    device = torch.device(f'cuda:{rank}')
    model = RDD(backbone(fetch_feature=True), cfg)
    model.build_pipe(shape=[2, 3, image_size, image_size])
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    if current_step:
        model.module.load_state_dict(
            torch.load(os.path.join(dir_weight, '%d.pth' % current_step),
                       map_location=device))
    else:
        checkpoint = os.path.join(tempfile.gettempdir(), "initial-weights.pth")
        if rank == 0:
            model.module.init()
            torch.save(model.module.state_dict(), checkpoint)
        dist.barrier()
        if rank > 0:
            model.module.load_state_dict(
                torch.load(checkpoint, map_location=device))
        dist.barrier()
        if rank == 0:
            os.remove(checkpoint)

    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    training = True
    while training and current_step < max_step:
        tqdm_loader = tqdm.tqdm(loader) if rank == 0 else loader
        for images, targets, infos in tqdm_loader:
            current_step += 1
            adjust_lr_multi_step(optimizer, current_step, lr_cfg, warm_up)

            images = images.cuda() / 255
            losses = model(images, targets)
            loss = sum(losses.values())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if rank == 0:
                for key, val in list(losses.items()):
                    losses[key] = val.item()
                    writer.add_scalar(key, val, global_step=current_step)
                writer.flush()
                tqdm_loader.set_postfix(losses)
                tqdm_loader.set_description(f'<{current_step}/{max_step}>')

                if current_step % save_interval == 0:
                    save_path = os.path.join(dir_weight,
                                             '%d.pth' % current_step)
                    state_dict = model.module.state_dict()
                    torch.save(state_dict, save_path)
                    cache_file = os.path.join(
                        dir_weight, '%d.pth' % (current_step - save_interval))
                    if os.path.exists(cache_file):
                        os.remove(cache_file)

            if current_step >= max_step:
                training = False
                if rank == 0:
                    writer.close()
                break
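

# --- Supplementary launcher sketch for the `main` above (an assumption, not part
# --- of the original script): `init_method='env://'` expects MASTER_ADDR and
# --- MASTER_PORT to be set, and one process is spawned per GPU on a single node.
# --- The wrapper function and batch size below are illustrative only.
def _spawn_worker(rank, batch_size, world_size):
    main(batch_size, rank, world_size)


if __name__ == '__main__':
    import os
    import torch
    import torch.multiprocessing as mp

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')

    world_size = torch.cuda.device_count()
    total_batch_size = 4 * world_size  # divided by world_size inside main()
    mp.spawn(_spawn_worker,
             args=(total_batch_size, world_size),
             nprocs=world_size)
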
Example #26
            'drop_rate': 0.2,
            'drop_path_rate': 0.2,
        },
        'optimizer': ('AdamW', {
            'lr': get_linear_scaled_lr(args.lr, world_batch_size)
        }),
        'loss': 'CrossEntropyLoss',
        'device': 'cuda',
        'iter_size': args.iter_size,
        'amp': args.amp
    }

    model = CifarModel(params)

    if distributed:
        model.nn_module = SyncBatchNorm.convert_sync_batchnorm(model.nn_module)
        model.nn_module = DistributedDataParallel(model.nn_module.to(local_rank),
                                                  device_ids=[local_rank],
                                                  output_device=local_rank)
        if local_rank:
            model.logger.disabled = True
    else:
        model.set_device('cuda')

    callbacks = []
    if local_rank == 0:
        callbacks += [
            MonitorCheckpoint(dir_path=EXPERIMENT_DIR,
                              monitor='val_dist_accuracy', max_saves=3),
            LoggingToCSV(EXPERIMENT_DIR / 'log.csv'),
            LoggingToFile(EXPERIMENT_DIR / 'log.txt')
Example #27
def main():
    global CONF, DEVICE, TB_LOGGER, RANK, WORLD_SIZE

    parser = ArgumentParser("Probabilistic quantization neural networks.")
    parser.add_argument("--conf-path",
                        "-c",
                        required=True,
                        help="path of configuration file")
    parser.add_argument("--port",
                        "-p",
                        type=int,
                        help="port of distributed backend")
    parser.add_argument("--solo",
                        "-s",
                        action="store_true",
                        help="run this script in solo (local machine) mode")
    parser.add_argument("--evaluate-only",
                        "-e",
                        action="store_true",
                        help="evaluate trained model")
    parser.add_argument("--vis-only",
                        "-v",
                        action="store_true",
                        help="visualize trained activations")
    parser.add_argument("--extra",
                        "-x",
                        type=json.loads,
                        help="extra configurations in json format")
    parser.add_argument("--comment",
                        "-m",
                        default="",
                        help="comment for each experiment")
    parser.add_argument("--debug",
                        action="store_true",
                        help="logging debug info")
    args = parser.parse_args()

    with open(args.conf_path, "r", encoding="utf-8") as f:
        CONF = yaml.load(f, Loader=yaml.SafeLoader)
        cli_conf = {
            k: v
            for k, v in vars(args).items()
            if k != "extra" and not k.startswith("__")
        }
        if args.extra is not None:
            cli_conf.update(args.extra)
        CONF = update_config(CONF, cli_conf)
        CONF = EasyDict(CONF)

    RANK, WORLD_SIZE = dist_init(CONF.port, CONF.arch.gpu_per_model, CONF.solo)
    CONF.dist = WORLD_SIZE > 1

    if CONF.arch.gpu_per_model == 1:
        fp_device = get_devices(CONF.arch.gpu_per_model, RANK)
        q_device = fp_device
    else:
        fp_device, q_device = get_devices(CONF.arch.gpu_per_model, RANK)
    DEVICE = fp_device

    logger = init_log(LOGGER_NAME, CONF.debug,
                      f"{CONF.log.file}_{EXP_DATETIME}.log")

    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.backends.cudnn.benchmark = True

    logger.debug(f"configurations:\n{pformat(CONF)}")
    logger.debug(f"fp device: {fp_device}")
    logger.debug(f"quant device: {q_device}")

    logger.debug(f"building dataset {CONF.data.dataset.type}...")
    train_set, val_set = get_dataset(CONF.data.dataset.type,
                                     **CONF.data.dataset.args)
    logger.debug(f"building training loader...")
    train_loader = DataLoader(train_set,
                              sampler=IterationSampler(
                                  train_set,
                                  rank=RANK,
                                  world_size=WORLD_SIZE,
                                  **CONF.data.train_sampler_conf),
                              **CONF.data.train_loader_conf)
    logger.debug(f"building validation loader...")
    val_loader = DataLoader(
        val_set,
        sampler=DistributedSampler(val_set) if CONF.dist else None,
        **CONF.data.val_loader_conf)

    logger.debug(f"building model `{CONF.arch.type}`...")
    model = models.__dict__[CONF.arch.type](**CONF.arch.args).to(
        DEVICE, non_blocking=True)
    if CONF.dist and CONF.arch.sync_bn:
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model._reinit_multi_domain()
    logger.debug(f"build model {model.__class__.__name__} done:\n{model}")

    param_groups = model.get_param_group(*CONF.param_group.groups,
                                         **CONF.param_group.args)
    opt = HybridOpt(param_groups, CONF.param_group.conf, **CONF.opt.args)
    scheduler = IterationScheduler(
        opt.get_schedulers(),
        CONF.schedule.opt_cfgs,
        CONF.schedule.variable_cfgs,
        iters_per_epoch=len(train_set) // WORLD_SIZE //
        CONF.BATCH_SIZE_PER_GPU,
        quant_start_iter=CONF.schedule.quant_start_iter,
        total_iters=len(train_loader),
        dynamic_variable_scale=CONF.schedule.dynamic_variable_scale)

    if CONF.dist:
        logger.debug(f"building DDP model...")
        model, model_without_ddp = get_ddp_model(model,
                                                 devices=(fp_device, q_device),
                                                 debug=CONF.debug)
    else:
        model_without_ddp = model

    if CONF.log.tb_dir is not None and RANK == 0 and not CONF.evaluate_only:
        tb_dir = f"{EXP_DATETIME}_{CONF.comment}" if CONF.comment is not "" else f"{EXP_DATETIME}"
        tb_dir = os.path.join(CONF.log.tb_dir, tb_dir)
        logger.debug(f"creating TensorBoard at: {tb_dir}...")
        os.makedirs(tb_dir, exist_ok=True)
        TB_LOGGER = SummaryWriter(tb_dir)

    if CONF.resume.path is not None:
        if CONF.resume.path == "latest":
            resume_path = os.path.join(CONF.ckpt.dir, "ckpt_latest.pth")
        elif CONF.resume.path == "best":
            resume_path = os.path.join(CONF.ckpt.dir, "ckpt_best.pth")
        else:
            resume_path = CONF.resume.path
        logger.debug(f"loading checkpoint at: {resume_path}...")
        with open(resume_path, "rb") as f:
            ckpt = torch.load(f, DEVICE)
            model_dict = ckpt["model"] if "model" in ckpt.keys() else ckpt
            try:
                model_without_ddp.load_state_dict(model_dict, strict=False)
            except RuntimeError as e:
                logger.warning(e)
            logger.debug(f"accuracy in state_dict: {ckpt['accuracy']}")
            if CONF.resume.load_opt:
                logger.debug(f"recovering optimizer...")
                opt.load_state_dict(ckpt["opt"])
            if CONF.resume.load_scheduler:
                scheduler.load_state_dict(ckpt["scheduler"])
                train_loader.sampler.set_last_iter(scheduler.last_iter)
                logger.debug(
                    f"recovered opt at iteration: {scheduler.last_iter}")

    if CONF.teacher_arch is not None:
        logger.debug(f"building FP teacher model {CONF.teacher_arch.type}...")
        teacher = models.__dict__[CONF.teacher_arch.type](
            **CONF.teacher_arch.args).to(DEVICE, non_blocking=True)
        with open(CONF.teacher_arch.ckpt, "rb") as f:
            ckpt = torch.load(f, DEVICE)
            teacher.load_state_dict(ckpt)
        for p in teacher.parameters():
            p.requires_grad = False
    else:
        teacher = None

    logger.debug(f"building criterion {CONF.loss.type}...")
    criterion = get_loss(CONF.loss.type, **CONF.loss.args)

    if CONF.debug:
        num_params = 0
        numel_params = 0
        opt_conf = []
        for p in opt.get_param_groups():
            num_params += len(p["params"])
            for param in p["params"]:
                numel_params += param.numel()
            opt_conf.append({k: v for k, v in p.items() if k != "params"})
        logger.debug(f"number of parameter tensors: {num_params}")
        logger.debug(
            f"total numel of parameters: {numel_params / 1024 / 1024:.2f}M")
        logger.debug(f"optimizer conf:\n{pformat(opt_conf)}")

    if CONF.diagnose.enabled:
        logger.debug(
            f"building diagnoser `{CONF.diagnose.diagnoser.type}` with conf: "
            f"\n{pformat(CONF.diagnose.diagnoser.args)}")
        model = get_diagnoser(CONF.diagnose.diagnoser.type,
                              model,
                              logger=TB_LOGGER,
                              **CONF.diagnose.diagnoser.args)
        get_tasks(model,
                  CONF.diagnose.tasks)  # TODO: should we preserve these tasks?

    if CONF.vis_only:
        logger.info("collecting activations...")
        save_activation(model_without_ddp, val_loader, CONF.vis_path,
                        *CONF.vis_names)
        return

    if CONF.evaluate_only:
        if CONF.eval.calibrate:
            logger.info(
                f"calibrating quantization ranges at iteration {scheduler.last_iter}..."
            )
            model_without_ddp.update_ddp_quant_param(
                model,
                val_loader,
                CONF.quant.calib.steps,
                CONF.quant.calib.gamma,
                CONF.quant.calib.update_bn,
            )
        logger.info(f"[Step {scheduler.last_iter}]: evaluating...")
        evaluate(model, val_loader, enable_quant=CONF.eval.quant, verbose=True)
        return

    train(model, criterion, train_loader, val_loader, opt, scheduler, teacher)
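

# --- Supplementary sketch (not part of the example above): every example in this
# --- collection converts BatchNorm layers against the default (global) process
# --- group. `convert_sync_batchnorm` also accepts an optional `process_group`,
# --- so BN statistics can be synchronized only within a subgroup, e.g. the ranks
# --- of one node. The group size below is an illustrative assumption.
import torch.distributed as dist
from torch.nn import SyncBatchNorm


def convert_sync_bn_per_node_sketch(model, ranks_per_node=8):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    my_group = None
    # new_group must be called by all ranks for every group, in the same order
    for start in range(0, world_size, ranks_per_node):
        ranks = list(range(start, min(start + ranks_per_node, world_size)))
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return SyncBatchNorm.convert_sync_batchnorm(model, process_group=my_group)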