def build_model(self):
        '''
        Perform the model building operation only. ``Do not include loading weights``.

        The interface must be used in the model building process from the init
        function (``InterfaceImplement`` refers to your implementation of the
        interface class).

        Returns:
            torch.nn.Module: The PyTorch GPU model.
        '''
        if hasattr(self.config, 'lms'):
            if self.config.lms.enable:
                torch.cuda.set_enabled_lms(True)
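                # the limit is given in GiB; 1 << 30 bytes = 1 GiB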
                byte_limit = self.config.lms.kwargs.limit * (1 << 30)
                torch.cuda.set_limit_lms(byte_limit)
                self.logger.info(
                    'Enable large model support, limit of {}G!'.format(
                        self.config.lms.kwargs.limit))

        self.model = model_entry(self.config.model)
        self.model.cuda()
        # count flops and params
        count_params(self.model)
        count_flops(self.model,
                    input_shape=[
                        1, 3, self.config.data.input_size,
                        self.config.data.input_size
                    ])
        # handle fp16
        if self.config.optimizer.type in [
                'FP16SGD', 'FusedFP16SGD', 'FP16RMSprop'
        ]:
            self.fp16 = True
        else:
            self.fp16 = False

        if self.fp16:
            # if you have modules that must use fp32 parameters, and need fp32 input
            # try use link.fp16.register_float_module(your_module)
            # if you only need fp32 parameters set cast_args=False when call this
            # function, then call link.fp16.init() before call model.half()
            if self.config.optimizer.get('fp16_normal_bn', False):
                self.logger.info('using normal bn for fp16')
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                                cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                                cast_args=False)
            if self.config.optimizer.get('fp16_normal_fc', False):
                self.logger.info('using normal fc for fp16')
                link.fp16.register_float_module(torch.nn.Linear,
                                                cast_args=True)
            link.fp16.init()
            self.model.half()

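        # wrap for distributed training; config.dist.sync presumably selects
        # synchronous gradient all-reduce inside DistModule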
        self.model = DistModule(self.model, self.config.dist.sync)
        return self.model
Example #2
    def build_model(self):
        if hasattr(self.config, 'lms'):
            if self.config.lms.enable:
                torch.cuda.set_enabled_lms(True)
                byte_limit = self.config.lms.kwargs.limit * (1 << 30)
                torch.cuda.set_limit_lms(byte_limit)
                self.logger.info(
                    'Enable large model support, limit of {}G!'.format(
                        self.config.lms.kwargs.limit))

        self.model = model_entry(self.config.model)
        self.prototype_info.model = self.config.model.type
        self.model.cuda()

        count_params(self.model)
        count_flops(self.model,
                    input_shape=[
                        1, 3, self.config.data.input_size,
                        self.config.data.input_size
                    ])

        # handle fp16
        if self.config.optimizer.type in [
                'FP16SGD', 'FusedFP16SGD', 'FP16RMSprop'
        ]:
            self.fp16 = True
        else:
            self.fp16 = False

        if self.fp16:
            # if you have modules that must use fp32 parameters, and need fp32 input
            # try use link.fp16.register_float_module(your_module)
            # if you only need fp32 parameters set cast_args=False when call this
            # function, then call link.fp16.init() before call model.half()
            if self.config.optimizer.get('fp16_normal_bn', False):
                self.logger.info('using normal bn for fp16')
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                                cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                                cast_args=False)
            if self.config.optimizer.get('fp16_normal_fc', False):
                self.logger.info('using normal fc for fp16')
                link.fp16.register_float_module(torch.nn.Linear,
                                                cast_args=True)
            link.fp16.init()
            self.model.half()

        self.model = DistModule(self.model, self.config.dist.sync)

        if 'model' in self.state:
            load_state_model(self.model, self.state['model'])
Example #3
    def build_model(self):
        encoder = model_entry(self.config.model)
        self.model = SimCLR(encoder)
        self.model.cuda()
        count_params(self.model.encoder)
        count_flops(self.model.encoder,
                    input_shape=[
                        1, 3, self.config.data.input_size,
                        self.config.data.input_size
                    ])

        # handle fp16
        if self.config.optimizer.type in [
                'FP16SGD', 'FusedFP16SGD', 'FP16RMSprop'
        ]:
            self.fp16 = True
        else:
            self.fp16 = False

        if self.fp16:
            # if you have modules that must use fp32 parameters, and need fp32 input
            # try use link.fp16.register_float_module(your_module)
            # if you only need fp32 parameters set cast_args=False when call this
            # function, then call link.fp16.init() before call model.half()
            if self.config.optimizer.get('fp16_normal_bn', False):
                self.logger.info('using normal bn for fp16')
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                                cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                                cast_args=False)
            if self.config.optimizer.get('fp16_normal_fc', False):
                self.logger.info('using normal fc for fp16')
                link.fp16.register_float_module(torch.nn.Linear,
                                                cast_args=True)
            link.fp16.init()
            self.model.half()

        self.model = DistModule(self.model, self.config.dist.sync)

        if 'model' in self.state:
            load_state_model(self.model, self.state['model'])
Example #4
    @staticmethod
    def build_model_helper(config_dict=None):
        '''
        Static function. Build a model from the given config_dict.

        Args:
            config_dict (dict): the config that contains model information

        Returns:
            torch.nn.Module: The PyTorch GPU model.
        '''
        if not isinstance(config_dict, EasyDict):
            config_dict = EasyDict(config_dict)
        model = model_entry(config_dict.model)
        model.cuda()

        if config_dict.optimizer.type in [
                'FP16SGD', 'FusedFP16SGD', 'FP16RMSprop'
        ]:
            fp16 = True
        else:
            fp16 = False

        if fp16:
            # if you have modules that must use fp32 parameters, and need fp32 input
            # try use link.fp16.register_float_module(your_module)
            # if you only need fp32 parameters set cast_args=False when call this
            # function, then call link.fp16.init() before call model.half()
            if config_dict.optimizer.get('fp16_normal_bn', False):
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                                cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                                cast_args=False)
            if config_dict.optimizer.get('fp16_normal_fc', False):
                link.fp16.register_float_module(torch.nn.Linear,
                                                cast_args=True)
            link.fp16.init()
            model.half()

        model = DistModule(model, config_dict.dist.sync)
        return model
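
    # Usage sketch (hedged): the nested keys below are assumptions inferred from
    # the attribute accesses above (model/optimizer/dist); the real schema comes
    # from this project's config parser, and 'resnet50' is a hypothetical type.
    #
    #   cfg = {'model': {'type': 'resnet50', 'kwargs': {}},
    #          'optimizer': {'type': 'SGD', 'kwargs': {}},
    #          'dist': {'sync': True}}
    #   model = build_model_helper(cfg)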
Example #5
class ClsSolver(BaseSolver):
    def __init__(self, config_file):
        self.config_file = config_file
        self.prototype_info = EasyDict()
        self.config = parse_config(config_file)
        self.setup_env()
        self.build_model()
        self.build_optimizer()
        self.build_data()
        self.build_lr_scheduler()
        send_info(self.prototype_info)

    def setup_env(self):
        # dist
        self.dist = EasyDict()
        self.dist.rank, self.dist.world_size = link.get_rank(
        ), link.get_world_size()
        self.prototype_info.world_size = self.dist.world_size
        # directories
        self.path = EasyDict()
        self.path.root_path = os.path.dirname(self.config_file)
        self.path.save_path = os.path.join(self.path.root_path, 'checkpoints')
        self.path.event_path = os.path.join(self.path.root_path, 'events')
        self.path.result_path = os.path.join(self.path.root_path, 'results')
        makedir(self.path.save_path)
        makedir(self.path.event_path)
        makedir(self.path.result_path)
        # tb_logger
        if self.dist.rank == 0:
            self.tb_logger = SummaryWriter(self.path.event_path)
        # logger
        create_logger(os.path.join(self.path.root_path, 'log.txt'))
        self.logger = get_logger(__name__)
        self.logger.info(f'config: {pprint.pformat(self.config)}')
        if 'SLURM_NODELIST' in os.environ:
            self.logger.info(f"hostnames: {os.environ['SLURM_NODELIST']}")
        # load pretrain checkpoint
        if hasattr(self.config.saver, 'pretrain'):
            self.state = torch.load(self.config.saver.pretrain.path, 'cpu')
            self.logger.info(
                f"Recovering from {self.config.saver.pretrain.path}, keys={list(self.state.keys())}"
            )
            if hasattr(self.config.saver.pretrain, 'ignore'):
                self.state = modify_state(self.state,
                                          self.config.saver.pretrain.ignore)
        else:
            self.state = {}
            self.state['last_iter'] = 0
        # others
        torch.backends.cudnn.benchmark = True

    def build_model(self):
        if hasattr(self.config, 'lms'):
            if self.config.lms.enable:
                torch.cuda.set_enabled_lms(True)
                byte_limit = self.config.lms.kwargs.limit * (1 << 30)
                torch.cuda.set_limit_lms(byte_limit)
                self.logger.info(
                    'Enable large model support, limit of {}G!'.format(
                        self.config.lms.kwargs.limit))

        self.model = model_entry(self.config.model)
        self.prototype_info.model = self.config.model.type
        self.model.cuda()

        count_params(self.model)
        count_flops(self.model,
                    input_shape=[
                        1, 3, self.config.data.input_size,
                        self.config.data.input_size
                    ])

        # handle fp16
        if self.config.optimizer.type in [
                'FP16SGD', 'FusedFP16SGD', 'FP16RMSprop'
        ]:
            self.fp16 = True
        else:
            self.fp16 = False

        if self.fp16:
            # if you have modules that must use fp32 parameters, and need fp32 input
            # try use link.fp16.register_float_module(your_module)
            # if you only need fp32 parameters set cast_args=False when call this
            # function, then call link.fp16.init() before call model.half()
            if self.config.optimizer.get('fp16_normal_bn', False):
                self.logger.info('using normal bn for fp16')
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                                cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                                cast_args=False)
            if self.config.optimizer.get('fp16_normal_fc', False):
                self.logger.info('using normal fc for fp16')
                link.fp16.register_float_module(torch.nn.Linear,
                                                cast_args=True)
            link.fp16.init()
            self.model.half()

        self.model = DistModule(self.model, self.config.dist.sync)

        if 'model' in self.state:
            load_state_model(self.model, self.state['model'])

    def build_optimizer(self):

        opt_config = self.config.optimizer
        opt_config.kwargs.lr = self.config.lr_scheduler.kwargs.base_lr
        self.prototype_info.optimizer = self.config.optimizer.type

        # make param_groups
        pconfig = {}

        if opt_config.get('no_wd', False):
            pconfig['conv_b'] = {'weight_decay': 0.0}
            pconfig['linear_b'] = {'weight_decay': 0.0}
            pconfig['bn_w'] = {'weight_decay': 0.0}
            pconfig['bn_b'] = {'weight_decay': 0.0}

        if 'pconfig' in opt_config:
            pconfig.update(opt_config['pconfig'])

        param_group, type2num = param_group_all(self.model, pconfig)

        opt_config.kwargs.params = param_group

        self.optimizer = optim_entry(opt_config)

        if 'optimizer' in self.state:
            load_state_optimizer(self.optimizer, self.state['optimizer'])

        # EMA
        if self.config.ema.enable:
            self.config.ema.kwargs.model = self.model
            self.ema = EMA(**self.config.ema.kwargs)
        else:
            self.ema = None

        if self.ema is not None and 'ema' in self.state:
            self.ema.load_state_dict(self.state['ema'])

    def build_lr_scheduler(self):
        self.prototype_info.lr_scheduler = self.config.lr_scheduler.type
        if not getattr(self.config.lr_scheduler.kwargs, 'max_iter', False):
            self.config.lr_scheduler.kwargs.max_iter = self.config.data.max_iter
        self.config.lr_scheduler.kwargs.optimizer = self.optimizer.optimizer if isinstance(self.optimizer, FP16SGD) or \
            isinstance(self.optimizer, FP16RMSprop) else self.optimizer
        self.config.lr_scheduler.kwargs.last_iter = self.state['last_iter']
        self.lr_scheduler = scheduler_entry(self.config.lr_scheduler)

    def build_data(self):
        self.config.data.last_iter = self.state['last_iter']
        if getattr(self.config.lr_scheduler.kwargs, 'max_iter', False):
            self.config.data.max_iter = self.config.lr_scheduler.kwargs.max_iter
        else:
            self.config.data.max_epoch = self.config.lr_scheduler.kwargs.max_epoch

        if self.config.data.get('type', 'imagenet') == 'imagenet':
            self.train_data = build_imagenet_train_dataloader(self.config.data)
        else:
            self.train_data = build_custom_dataloader('train',
                                                      self.config.data)

        if self.config.data.get('type', 'imagenet') == 'imagenet':
            self.val_data = build_imagenet_test_dataloader(self.config.data)
        else:
            self.val_data = build_custom_dataloader('test', self.config.data)

    def pre_train(self):
        self.meters = EasyDict()
        self.meters.batch_time = AverageMeter(self.config.saver.print_freq)
        self.meters.step_time = AverageMeter(self.config.saver.print_freq)
        self.meters.data_time = AverageMeter(self.config.saver.print_freq)
        self.meters.losses = AverageMeter(self.config.saver.print_freq)
        self.meters.top1 = AverageMeter(self.config.saver.print_freq)
        self.meters.top5 = AverageMeter(self.config.saver.print_freq)

        self.model.train()

        label_smooth = self.config.get('label_smooth', 0.0)
        self.num_classes = self.config.model.kwargs.get('num_classes', 1000)
        self.topk = 5 if self.num_classes >= 5 else self.num_classes
        if label_smooth > 0:
            self.logger.info('using label_smooth: {}'.format(label_smooth))
            self.criterion = LabelSmoothCELoss(label_smooth, self.num_classes)
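            # (hedged note) label smoothing conventionally replaces the one-hot
            # target with (1 - eps) * one_hot + eps / num_classes; the exact
            # formula used by LabelSmoothCELoss is defined in this codebase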
        else:
            self.criterion = torch.nn.CrossEntropyLoss()
        self.mixup = self.config.get('mixup', 1.0)
        self.cutmix = self.config.get('cutmix', 0.0)
        if self.mixup < 1.0:
            self.logger.info('using mixup with alpha of: {}'.format(
                self.mixup))
        if self.cutmix > 0.0:
            self.logger.info('using cutmix with alpha of: {}'.format(
                self.cutmix))

    def train(self):

        self.pre_train()
        total_step = len(self.train_data['loader'])
        start_step = self.state['last_iter'] + 1
        end = time.time()
        for i, batch in enumerate(self.train_data['loader']):
            input = batch['image']
            target = batch['label']
            curr_step = start_step + i
            self.lr_scheduler.step(curr_step)
            # lr_scheduler.get_lr()[0] is the main lr
            current_lr = self.lr_scheduler.get_lr()[0]
            # measure data loading time
            self.meters.data_time.update(time.time() - end)
            # transfer input to gpu
            target = target.squeeze().cuda().long()
            input = input.cuda().half() if self.fp16 else input.cuda()
            # mixup
            if self.mixup < 1.0:
                input, target_a, target_b, lam = mixup_data(
                    input, target, self.mixup)
            # cutmix
            if self.cutmix > 0.0:
                input, target_a, target_b, lam = cutmix_data(
                    input, target, self.cutmix)
            # forward
            logits = self.model(input)
            # mixup
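            # (hedged note) mix_criterion conventionally computes
            #   lam * criterion(logits, target_a) + (1 - lam) * criterion(logits, target_b)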
            if self.mixup < 1.0 or self.cutmix > 0.0:
                loss = mix_criterion(self.criterion, logits, target_a,
                                     target_b, lam)
                loss /= self.dist.world_size
            else:
                loss = self.criterion(logits, target) / self.dist.world_size
            # measure accuracy and record loss
            prec1, prec5 = accuracy(logits, target, topk=(1, self.topk))

            reduced_loss = loss.clone()
            reduced_prec1 = prec1.clone() / self.dist.world_size
            reduced_prec5 = prec5.clone() / self.dist.world_size

            self.meters.losses.reduce_update(reduced_loss)
            self.meters.top1.reduce_update(reduced_prec1)
            self.meters.top5.reduce_update(reduced_prec5)

            # compute and update gradient
            self.optimizer.zero_grad()
            if FusedFP16SGD is not None and isinstance(self.optimizer,
                                                       FusedFP16SGD):
                self.optimizer.backward(loss)
                self.model.sync_gradients()
                self.optimizer.step()
            elif isinstance(self.optimizer, FP16SGD) or isinstance(
                    self.optimizer, FP16RMSprop):

                def closure():
                    self.optimizer.backward(loss, False)
                    self.model.sync_gradients()
                    # check overflow, convert to fp32 grads, downscale
                    self.optimizer.update_master_grads()
                    return loss

                self.optimizer.step(closure)
            else:
                loss.backward()
                self.model.sync_gradients()
                self.optimizer.step()

            # EMA
            if self.ema is not None:
                self.ema.step(self.model, curr_step=curr_step)
            # measure elapsed time
            self.meters.batch_time.update(time.time() - end)

            # training logger
            if curr_step % self.config.saver.print_freq == 0 and self.dist.rank == 0:
                self.tb_logger.add_scalar('loss_train', self.meters.losses.avg,
                                          curr_step)
                self.tb_logger.add_scalar('acc1_train', self.meters.top1.avg,
                                          curr_step)
                self.tb_logger.add_scalar('acc5_train', self.meters.top5.avg,
                                          curr_step)
                self.tb_logger.add_scalar('lr', current_lr, curr_step)
                remain_secs = (total_step -
                               curr_step) * self.meters.batch_time.avg
                remain_time = datetime.timedelta(seconds=round(remain_secs))
                finish_time = time.strftime(
                    "%Y-%m-%d %H:%M:%S",
                    time.localtime(time.time() + remain_secs))
                log_msg = f'Iter: [{curr_step}/{total_step}]\t' \
                    f'Time {self.meters.batch_time.val:.3f} ({self.meters.batch_time.avg:.3f})\t' \
                    f'Data {self.meters.data_time.val:.3f} ({self.meters.data_time.avg:.3f})\t' \
                    f'Loss {self.meters.losses.val:.4f} ({self.meters.losses.avg:.4f})\t' \
                    f'Prec@1 {self.meters.top1.val:.3f} ({self.meters.top1.avg:.3f})\t' \
                    f'Prec@5 {self.meters.top5.val:.3f} ({self.meters.top5.avg:.3f})\t' \
                    f'LR {current_lr:.4f}\t' \
                    f'Remaining Time {remain_time} ({finish_time})'
                self.logger.info(log_msg)

            # testing during training
            if curr_step > 0 and curr_step % self.config.saver.val_freq == 0:
                metrics = self.evaluate()
                if self.ema is not None:
                    self.ema.load_ema(self.model)
                    ema_metrics = self.evaluate()
                    self.ema.recover(self.model)
                    if self.dist.rank == 0 and self.config.data.test.evaluator.type == 'imagenet':
                        metric_key = 'top{}'.format(self.topk)
                        self.tb_logger.add_scalars(
                            'acc1_val', {'ema': ema_metrics.metric['top1']},
                            curr_step)
                        self.tb_logger.add_scalars(
                            'acc5_val',
                            {'ema': ema_metrics.metric[metric_key]}, curr_step)

                # testing logger
                if self.dist.rank == 0 and self.config.data.test.evaluator.type == 'imagenet':
                    metric_key = 'top{}'.format(self.topk)
                    self.tb_logger.add_scalar('acc1_val',
                                              metrics.metric['top1'],
                                              curr_step)
                    self.tb_logger.add_scalar('acc5_val',
                                              metrics.metric[metric_key],
                                              curr_step)

                # save ckpt
                if self.dist.rank == 0:
                    if self.config.saver.save_many:
                        ckpt_name = f'{self.path.save_path}/ckpt_{curr_step}.pth.tar'
                    else:
                        ckpt_name = f'{self.path.save_path}/ckpt.pth.tar'
                    self.state['model'] = self.model.state_dict()
                    self.state['optimizer'] = self.optimizer.state_dict()
                    self.state['last_iter'] = curr_step
                    if self.ema is not None:
                        self.state['ema'] = self.ema.state_dict()
                    torch.save(self.state, ckpt_name)

            end = time.time()

    @torch.no_grad()
    def evaluate(self):
        self.model.eval()
        res_file = os.path.join(self.path.result_path,
                                f'results.txt.rank{self.dist.rank}')
        writer = open(res_file, 'w')
        for batch_idx, batch in enumerate(self.val_data['loader']):
            input = batch['image']
            label = batch['label']
            input = input.cuda().half() if self.fp16 else input.cuda()
            label = label.squeeze().view(-1).cuda().long()
            # compute output
            logits = self.model(input)
            scores = F.softmax(logits, dim=1)
            # compute prediction
            _, preds = logits.data.topk(k=1, dim=1)
            preds = preds.view(-1)
            # update batch information
            batch.update({'prediction': preds})
            batch.update({'score': scores})
            # save prediction information
            self.val_data['loader'].dataset.dump(writer, batch)

        writer.close()
        link.barrier()
        if self.dist.rank == 0:
            metrics = self.val_data['loader'].dataset.evaluate(res_file)
            self.logger.info(json.dumps(metrics.metric, indent=2))
        else:
            metrics = {}
        link.barrier()
        # broadcast metrics to other processes
        metrics = broadcast_object(metrics)
        self.model.train()
        return metrics
Example #6
class MoCoSolver(ClsSolver):

    def build_model(self):
        """
        Build encoder_q and encoder_k.
        """
        if hasattr(self.config, 'lms'):
            if self.config.lms.enable:
                torch.cuda.set_enabled_lms(True)
                byte_limit = self.config.lms.kwargs.limit * (1 << 30)
                torch.cuda.set_limit_lms(byte_limit)
                self.logger.info('Enable large model support, limit of {}G!'.format(
                    self.config.lms.kwargs.limit))

        encoder_q = model_entry(self.config.model)
        encoder_k = model_entry(self.config.model)
        self.model = MoCo(encoder_q, encoder_k, **self.config.moco.kwargs)
        self.model.cuda()
        count_params(self.model.encoder_k)
        count_flops(self.model.encoder_k, input_shape=[1, 3, self.config.data.input_size, self.config.data.input_size])

        # handle fp16
        if self.config.optimizer.type in ['FP16SGD', 'FusedFP16SGD', 'FP16RMSprop']:
            self.fp16 = True
        else:
            self.fp16 = False

        if self.fp16:
            # if you have modules that must use fp32 parameters, and need fp32 input
            # try use link.fp16.register_float_module(your_module)
            # if you only need fp32 parameters set cast_args=False when call this
            # function, then call link.fp16.init() before call model.half()
            if self.config.optimizer.get('fp16_normal_bn', False):
                self.logger.info('using normal bn for fp16')
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d, cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d, cast_args=False)
            if self.config.optimizer.get('fp16_normal_fc', False):
                self.logger.info('using normal fc for fp16')
                link.fp16.register_float_module(torch.nn.Linear, cast_args=True)
            link.fp16.init()
            self.model.half()

        self.model = DistModule(self.model, self.config.dist.sync)
        if 'model' in self.state:
            load_state_model(self.model, self.state['model'])

    def build_data(self):
        """
        Unsupervised training: only training data is needed.
        """
        self.config.data.max_iter = self.config.lr_scheduler.kwargs.max_iter
        self.config.data.last_iter = self.state['last_iter']
        if self.config.data.last_iter < self.config.data.max_iter:
            if self.config.data.type == 'imagenet':
                self.train_data = build_imagenet_train_dataloader(self.config.data)
            elif self.config.data.type == 'custom':
                self.train_data = build_custom_dataloader('train', self.config.data)
            else:
                raise RuntimeError("undefined data type!")

    def train(self):

        self.pre_train()
        total_step = len(self.train_data['loader'])
        start_step = self.state['last_iter'] + 1
        end = time.time()
        for i, batch in enumerate(self.train_data['loader']):
            input = batch['image']
            curr_step = start_step + i
            self.lr_scheduler.step(curr_step)
            # lr_scheduler.get_lr()[0] is the main lr
            current_lr = self.lr_scheduler.get_lr()[0]
            # measure data loading time
            self.meters.data_time.update(time.time() - end)
            # transfer input to gpu
            input = input.cuda().half() if self.fp16 else input.cuda()
            # forward
            logits, target = self.model(input)
            loss = self.criterion(logits, target) / self.dist.world_size
            reduced_loss = loss.clone()
            self.meters.losses.reduce_update(reduced_loss)
            self.optimizer.zero_grad()
            if FusedFP16SGD is not None and isinstance(self.optimizer, FusedFP16SGD):
                self.optimizer.backward(loss)
                self.model.sync_gradients()
                self.optimizer.step()
            elif isinstance(self.optimizer, FP16SGD) or isinstance(self.optimizer, FP16RMSprop):
                def closure():
                    self.optimizer.backward(loss, False)
                    self.model.sync_gradients()
                    # check overflow, convert to fp32 grads, downscale
                    self.optimizer.update_master_grads()
                    return loss
                self.optimizer.step(closure)
            else:
                loss.backward()
                self.model.sync_gradients()
                self.optimizer.step()

            # measure elapsed time
            self.meters.batch_time.update(time.time() - end)
            if curr_step % self.config.saver.print_freq == 0 and self.dist.rank == 0:
                self.tb_logger.add_scalar('loss_train', self.meters.losses.avg, curr_step)
                self.tb_logger.add_scalar('lr', current_lr, curr_step)
                remain_secs = (total_step - curr_step) * self.meters.batch_time.avg
                remain_time = datetime.timedelta(seconds=round(remain_secs))
                finish_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()+remain_secs))
                log_msg = f'Iter: [{curr_step}/{total_step}]\t' \
                    f'Time {self.meters.batch_time.val:.3f} ({self.meters.batch_time.avg:.3f})\t' \
                    f'Data {self.meters.data_time.val:.3f} ({self.meters.data_time.avg:.3f})\t' \
                    f'Loss {self.meters.losses.val:.4f} ({self.meters.losses.avg:.4f})\t' \
                    f'LR {current_lr:.4f}\t' \
                    f'Remaining Time {remain_time} ({finish_time})'
                self.logger.info(log_msg)

            if curr_step > 0 and curr_step % self.config.saver.val_freq == 0:
                if self.dist.rank == 0:
                    if self.config.saver.save_many:
                        ckpt_name = f'{self.path.save_path}/ckpt_{curr_step}.pth.tar'
                    else:
                        ckpt_name = f'{self.path.save_path}/ckpt.pth.tar'
                    self.state['model'] = self.model.state_dict()
                    self.state['optimizer'] = self.optimizer.state_dict()
                    self.state['last_iter'] = curr_step
                    torch.save(self.state, ckpt_name)
            end = time.time()
Example #7
class PrototypeHelper(SpringCommonInterface):
    """Unitied interface for spring"""

    external_model_builder = {}

    def __init__(self,
                 config,
                 metric_dict=None,
                 work_dir=None,
                 ckpt_dict=None):
        """
        Args:
            config (dict): All the configs to build the task
            metric_dict (dict): Dict of Prometheus logger system's instances. Currently it contains two keywords:
                         - process: Gauge type. 0 - 100. Indicates the process of training.
                         - eta: Gauge type. hh:mm:ss. Indicates the estimated remaining time.
            work_dir (str): Task's root folder. Please save all the intermediate results in work_dir.
            ckpt_dict (dict): It contains all the saved k-v pairs for resuming. Only resume task when it's not None
        """
        super(PrototypeHelper, self).__init__(config, metric_dict, work_dir,
                                              ckpt_dict)
        # configuration for model training, evaluating, etc.
        self.config = config
        self.config_copy = copy.deepcopy(config)
        self.metric_dict = metric_dict
        self.work_dir = work_dir
        self.ckpt_dict = ckpt_dict

        self._setup_env()
        self._resume(ckpt_dict)
        self._build()
        self._pre_train()
        self.end_time = time.time()

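    # Usage sketch (hedged): parse_config and the config schema are assumed to
    # match ClsSolver above; metric_dict and ckpt_dict may be None for a fresh
    # (non-resumed) run, and the file/dir names here are hypothetical.
    #
    #   helper = PrototypeHelper(parse_config('config.yaml'), work_dir='./exp')
    #   helper.train()
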
    def _setup_env(self):
        # distribution information
        self.dist = EasyDict()
        self.dist.rank, self.dist.world_size = link.get_rank(
        ), link.get_world_size()
        # directories
        self.path = EasyDict()
        self.path.root_path = self.work_dir
        self.path.save_path = os.path.join(self.path.root_path, 'checkpoints')
        self.path.event_path = os.path.join(self.path.root_path, 'events')
        self.path.result_path = os.path.join(self.path.root_path, 'results')
        makedir(self.path.save_path)
        makedir(self.path.event_path)
        makedir(self.path.result_path)
        # create tensorboard logger
        if self.dist.rank == 0:
            self.tb_logger = SummaryWriter(self.path.event_path)
        # create logger
        create_logger(os.path.join(self.path.root_path, 'log.txt'))
        self.logger = get_logger(__name__)
        self.logger.info(f'config: {pprint.pformat(self.config)}')
        self.logger.info(f"hostnames: {os.environ['SLURM_NODELIST']}")
        # others
        torch.backends.cudnn.benchmark = True

    def _resume(self, ckpt_dict=None):
        '''
        The ckpt_dict takes higher priority than the element's own resuming logic.
        '''
        if ckpt_dict:
            self.state = ckpt_dict
            self.curr_step = self.state['last_iter']
            self.logger.info(
                f"Recovering from ckpt_dict, keys={list(self.state.keys())}")
        else:
            # load pretrain
            if hasattr(self.config.saver, 'pretrain'):
                self.state = torch.load(self.config.saver.pretrain.path, 'cpu')
                self.logger.info(
                    f"Recovering from {self.config.saver.pretrain.path}, keys={list(self.state.keys())}"
                )
                if hasattr(self.config.saver.pretrain, 'ignore'):
                    self.state = modify_state(
                        self.state, self.config.saver.pretrain.ignore)
                self.curr_step = self.state['last_iter']
            else:
                self.state = {'last_iter': 0}
                self.curr_step = 0

    def _build(self):
        self.build_model()
        self._build_optimizer()
        self._build_data()
        self._build_lr_scheduler()
        # load pretrain state to model
        if 'model' in self.state:
            load_state_model(self.model, self.state['model'])

    # sci
    def build_model(self):
        '''
        Perform the model building operation only. ``Do not include loading weights``.

        The interface must be used in the model building process from the init
        function (``InterfaceImplement`` refers to your implementation of the
        interface class).

        Returns:
            torch.nn.Module: The PyTorch GPU model.
        '''
        if hasattr(self.config, 'lms'):
            if self.config.lms.enable:
                torch.cuda.set_enabled_lms(True)
                byte_limit = self.config.lms.kwargs.limit * (1 << 30)
                torch.cuda.set_limit_lms(byte_limit)
                self.logger.info(
                    'Enable large model support, limit of {}G!'.format(
                        self.config.lms.kwargs.limit))

        self.model = model_entry(self.config.model)
        self.model.cuda()
        # count flops and params
        count_params(self.model)
        count_flops(self.model,
                    input_shape=[
                        1, 3, self.config.data.input_size,
                        self.config.data.input_size
                    ])
        # handle fp16
        if self.config.optimizer.type in [
                'FP16SGD', 'FusedFP16SGD', 'FP16RMSprop'
        ]:
            self.fp16 = True
        else:
            self.fp16 = False

        if self.fp16:
            # if you have modules that must use fp32 parameters, and need fp32 input
            # try use link.fp16.register_float_module(your_module)
            # if you only need fp32 parameters set cast_args=False when call this
            # function, then call link.fp16.init() before call model.half()
            if self.config.optimizer.get('fp16_normal_bn', False):
                self.logger.info('using normal bn for fp16')
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                                cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                                cast_args=False)
            if self.config.optimizer.get('fp16_normal_fc', False):
                self.logger.info('using normal fc for fp16')
                link.fp16.register_float_module(torch.nn.Linear,
                                                cast_args=True)
            link.fp16.init()
            self.model.half()

        self.model = DistModule(self.model, self.config.dist.sync)
        return self.model

    def _build_optimizer(self):
        opt_config = self.config.optimizer
        opt_config.kwargs.lr = self.config.lr_scheduler.kwargs.base_lr
        # divide param_groups
        pconfig = {}
        if opt_config.get('no_wd', False):
            pconfig['conv_b'] = {'weight_decay': 0.0}
            pconfig['linear_b'] = {'weight_decay': 0.0}
            pconfig['bn_w'] = {'weight_decay': 0.0}
            pconfig['bn_b'] = {'weight_decay': 0.0}
        if 'pconfig' in opt_config:
            pconfig.update(opt_config['pconfig'])
        param_group, type2num = param_group_all(self.model, pconfig)
        opt_config.kwargs.params = param_group
        self.optimizer = optim_entry(opt_config)
        # load optimizer
        if 'optimizer' in self.state:
            load_state_optimizer(self.optimizer, self.state['optimizer'])
        # EMA
        if self.config.ema.enable:
            self.config.ema.kwargs.model = self.model
            self.ema = EMA(**self.config.ema.kwargs)
            # load EMA
            if 'ema' in self.state:
                self.ema.load_state_dict(self.state['ema'])
        else:
            self.ema = None

    def _build_lr_scheduler(self):
        if not getattr(self.config.lr_scheduler.kwargs, 'max_iter', False):
            self.config.lr_scheduler.kwargs.max_iter = self.config.data.max_iter
        self.config.lr_scheduler.kwargs.optimizer = self.optimizer.optimizer if isinstance(self.optimizer, FP16SGD) or \
            isinstance(self.optimizer, FP16RMSprop) else self.optimizer
        self.config.lr_scheduler.kwargs.last_iter = self.state['last_iter']
        self.lr_scheduler = scheduler_entry(self.config.lr_scheduler)

    def _build_data(self):
        self.config.data.last_iter = self.state['last_iter']
        if getattr(self.config.lr_scheduler.kwargs, 'max_iter', False):
            self.config.data.max_iter = self.config.lr_scheduler.kwargs.max_iter
        else:
            self.config.data.max_epoch = self.config.lr_scheduler.kwargs.max_epoch

        self.data_loaders = {}
        key_list = list(self.config.data.keys())
        for data_type in key_list:
            if data_type in ['train', 'test', 'val', 'arch', 'inference']:
                if self.config.data.type == 'imagenet':
                    # imagenet type
                    if data_type == 'train':
                        loader = build_imagenet_train_dataloader(
                            self.config.data)
                    elif data_type == 'test':
                        loader = build_imagenet_test_dataloader(
                            self.config.data)
                    else:
                        loader = build_imagenet_search_dataloader(
                            self.config.data)
                else:
                    # custom type
                    loader = build_custom_dataloader(data_type,
                                                     self.config.data)
                self.data_loaders.update({data_type: loader['loader']})

        if self.data_loaders['train'] is not None:
            self.total_step = len(self.data_loaders['train'])
        else:
            self.total_step = 0

    def _pre_train(self):
        self.meters = EasyDict()
        self.meters.batch_time = AverageMeter(self.config.saver.print_freq)
        self.meters.step_time = AverageMeter(self.config.saver.print_freq)
        self.meters.data_time = AverageMeter(self.config.saver.print_freq)
        self.meters.losses = AverageMeter(self.config.saver.print_freq)
        self.meters.top1 = AverageMeter(self.config.saver.print_freq)
        self.meters.top5 = AverageMeter(self.config.saver.print_freq)

        self.model.train()

        label_smooth = self.config.get('label_smooth', 0.0)
        self.num_classes = self.config.model.kwargs.get('num_classes', 1000)
        self.topk = 5 if self.num_classes >= 5 else self.num_classes
        if label_smooth > 0:
            self.logger.info('using label_smooth: {}'.format(label_smooth))
            self.criterion = LabelSmoothCELoss(label_smooth, self.num_classes)
        else:
            self.criterion = torch.nn.CrossEntropyLoss()
        self.mixup = self.config.get('mixup', 1.0)
        self.cutmix = self.config.get('cutmix', 0.0)
        if self.mixup < 1.0:
            self.logger.info('using mixup with alpha of: {}'.format(
                self.mixup))
        if self.cutmix > 0.0:
            self.logger.info('using cutmix with alpha of: {}'.format(
                self.cutmix))

    # sci
    def get_model(self):
        '''
        Return the PyTorch GPU model.
        The model should not be rebuilt every time this interface is called.

        Returns:
            torch.nn.Module: The PyTorch GPU model to train/mimic/prune/quant/...
        '''
        return self.model

    # sci
    def get_optimizer(self):
        '''
        Return the optimizer of the PyTorch GPU model.
        The optimizer should not be rebuilt every time this interface is called.

        Returns:
            torch.optim.Optimizer: The optimizer of the PyTorch GPU model
        '''
        return self.optimizer

    # sci
    def get_scheduler(self):
        '''
        The scheduler should not be rebuilt every time this interface is called.

        Returns:
            torch.optim.lr_scheduler: The scheduler
        '''
        return self.lr_scheduler

    # sci
    def get_dummy_input(self):
        '''
        Return the input for forwarding the model.
        It will be used to calculate FLOPs, among other things.
        Make sure the returned input supports the call model(get_dummy_input()).

        Returns:
            object: The dummy input for the current task.
        '''
        input = torch.zeros(1, 3, self.config.data.input_size,
                            self.config.data.input_size)
        input = input.cuda().half() if self.fp16 else input.cuda()
        return input

    # sci
    def get_dump_dict(self):
        '''
        Returns:
            dict: Custom dict to be dumped by torch.save
        '''
        dict_to_dump = {}
        dict_to_dump['config'] = self.config_copy
        dict_to_dump['model'] = self.model.state_dict()
        dict_to_dump['optimizer'] = self.optimizer.state_dict()
        dict_to_dump['last_iter'] = self.curr_step
        if self.ema is not None:
            dict_to_dump['ema'] = self.ema.state_dict()
        return dict_to_dump

    # sci
    def get_batch(self, batch_type='train'):
        '''
        Return a batch of the given batch_type. The valid batch_type values are
        set in the config, e.g. if your config file defines 'train', 'test' and
        'val' splits, the interface should return the corresponding batch when
        batch_type is 'train', 'test' or 'val'.

        The returned batch will be used to call the `forward` function of
        SpringCommonInterface. The first item will be used to forward the model,
        like model(get_batch('train')[0]).
        Please make sure the first item in the returned batch is FP32 and on the GPU.
        If your model is FP16, please do the conversion in your model's forward function.

        Args:
            batch_type (str): Default: 'train'. It can also be 'val', 'test' or another custom type.

        Returns:
            tuple: a tuple of batch (input, label)
        '''
        assert batch_type in self.data_loaders
        if not hasattr(self, 'data_iterators'):
            self.data_iterators = {}
        if batch_type not in self.data_iterators:
            iterator = self.data_iterators[batch_type] = iter(
                self.data_loaders[batch_type])
        else:
            iterator = self.data_iterators[batch_type]

        try:
            batch = next(iterator)
        except StopIteration as e:  # noqa
            iterator = self.data_iterators[batch_type] = iter(
                self.data_loaders[batch_type])
            batch = next(iterator)
        return batch['image'], batch['label']
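
    # Usage sketch (hedged): input, label = helper.get_batch('train')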

    # sci
    def get_total_iter(self):
        '''
        Return the total number of iterations of the task.
        Note that even if the task is resumed, the returned value should not change.

        Returns:
            total_iter (int): the total number of iterations of the task. Please convert epochs to iterations if needed
        '''
        return self.config.data.max_iter

    # sci
    @staticmethod
    def load_weights(model, ckpt_dict):
        '''
        Static function. Load weights into the model from a checkpoint dict.

        Args:
            model (torch.nn.Module): PyTorch GPU model to load weights into
            ckpt_dict (dict): checkpoint dict used to resume the task
        '''
        model.load_state_dict(ckpt_dict['model'], strict=True)

    # sci
    def forward(self, batch):
        '''
        Forward with the given batch (e.g. the batch from the `get_batch` interface) and return the loss.
        Do not manually change the magnitude of the loss, e.g. by dividing by the world size.
        Do not manually change the model's state, e.g. by calling model.train() or model.eval().

        Args:
            batch (tuple): the batch for forwarding the model

        Returns:
            loss (torch.cuda.Tensor): loss tensor on the GPU; the loss of the given batch with the current model.
        '''
        # measure data loading time
        self.meters.data_time.update(time.time() - self.end_time)
        input, target = batch[0], batch[1]
        input = input.cuda().half() if self.fp16 else input.cuda()
        target = target.squeeze().view(-1).cuda().long()
        # forward
        logits = self.model(input)
        loss = self.criterion(logits, target)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(logits, target, topk=(1, self.topk))

        reduced_loss = loss.clone() / self.dist.world_size
        reduced_prec1 = prec1.clone() / self.dist.world_size
        reduced_prec5 = prec5.clone() / self.dist.world_size

        self.meters.losses.reduce_update(reduced_loss)
        self.meters.top1.reduce_update(reduced_prec1)
        self.meters.top5.reduce_update(reduced_prec5)
        return loss

    # sci
    def backward(self, loss):
        '''
        Backward with the given loss.
        Please don't modify the magnitude of the loss (divide by the world size or multiply by a loss weight).
        The gradients should be synchronized by all_reduce with the sum operation.
        Do not manually change the model's state, e.g. by calling model.train() or model.eval().

        Args:
            loss (torch.cuda.Tensor): loss tensor on the GPU
        '''
        self.optimizer.zero_grad()
        if self.fp16:
            self.optimizer.backward(loss)
            self.model.sync_gradients()
        else:
            loss.backward()
            self.model.sync_gradients()

    # sci
    def update(self):
        '''
        Update the model with current calculated gradients.
        The scheduler should also be stepped.
        '''
        self.curr_step += 1  # move from forward() to update()
        self.lr_scheduler.step(self.curr_step)
        self.optimizer.step()
        # EMA
        if self.ema is not None:
            self.ema.step(self.model, curr_step=self.curr_step)

        # set metric_dict
        if self.metric_dict is not None and 'eta' in self.metric_dict:
            self.metric_dict['eta'].set(self.get_eta())
        if self.metric_dict is not None and 'progress' in self.metric_dict:
            self.metric_dict['progress'].set(self.get_progress())
        self.meters.batch_time.update(time.time() - self.end_time)
        self.end_time = time.time()

    def get_eta(self):
        return (self.total_step - self.curr_step) * self.meters.batch_time.avg

    def get_progress(self):
        return self.curr_step / self.total_step * 100

    # sci
    def train(self):
        '''
        Perform the entire training process
        '''
        for i in range(self.total_step):
            batch = self.get_batch()
            loss = self.forward(batch)
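            # gradients are all-reduced with a sum across ranks (see backward()),
            # so the loss is pre-divided by world_size to yield averaged gradients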
            self.backward(loss / self.dist.world_size)
            self.update()

            # measure elapsed time
            self.meters.batch_time.update(time.time() - self.end_time)
            # lr_scheduler.get_lr()[0] is the main lr
            current_lr = self.lr_scheduler.get_lr()[0]
            curr_step = self.curr_step
            if curr_step % self.config.saver.print_freq == 0 and self.dist.rank == 0:
                self.tb_logger.add_scalar('loss_train', self.meters.losses.avg,
                                          curr_step)
                self.tb_logger.add_scalar('acc1_train', self.meters.top1.avg,
                                          curr_step)
                self.tb_logger.add_scalar('acc5_train', self.meters.top5.avg,
                                          curr_step)
                self.tb_logger.add_scalar('lr', current_lr, curr_step)
                remain_secs = (self.total_step -
                               curr_step) * self.meters.batch_time.avg
                remain_time = datetime.timedelta(seconds=round(remain_secs))
                finish_time = time.strftime(
                    "%Y-%m-%d %H:%M:%S",
                    time.localtime(time.time() + remain_secs))
                log_msg = f'Iter: [{curr_step}/{self.total_step}]\t' \
                    f'Time {self.meters.batch_time.val:.3f} ({self.meters.batch_time.avg:.3f})\t' \
                    f'Data {self.meters.data_time.val:.3f} ({self.meters.data_time.avg:.3f})\t' \
                    f'Loss {self.meters.losses.val:.4f} ({self.meters.losses.avg:.4f})\t' \
                    f'Prec@1 {self.meters.top1.val:.3f} ({self.meters.top1.avg:.3f})\t' \
                    f'Prec@5 {self.meters.top5.val:.3f} ({self.meters.top5.avg:.3f})\t' \
                    f'LR {current_lr:.4f}\t' \
                    f'Remaining Time {remain_time} ({finish_time})'

                self.logger.info(log_msg)

            if curr_step > 0 and curr_step % self.config.saver.val_freq == 0:
                metrics = self.evaluate()
                if self.ema is not None:
                    self.ema.load_ema(self.model)
                    ema_metrics = self.evaluate()
                    self.ema.recover(self.model)
                    if self.dist.rank == 0 and self.config.data.test.evaluator.type == 'imagenet':
                        self.tb_logger.add_scalars(
                            'acc1_val', {'ema': ema_metrics.metric['top1']},
                            curr_step)
                        self.tb_logger.add_scalars(
                            'acc5_val', {'ema': ema_metrics.metric['top5']},
                            curr_step)

                # testing logger
                if self.dist.rank == 0 and self.config.data.test.evaluator.type == 'imagenet':
                    self.tb_logger.add_scalar('acc1_val',
                                              metrics.metric['top1'],
                                              curr_step)
                    self.tb_logger.add_scalar('acc5_val',
                                              metrics.metric['top5'],
                                              curr_step)

                # save ckpt
                if self.dist.rank == 0:
                    if self.config.saver.save_many:
                        ckpt_name = f'{self.path.save_path}/ckpt_{curr_step}.pth.tar'
                    else:
                        ckpt_name = f'{self.path.save_path}/ckpt.pth.tar'
                    self.state['model'] = self.model.state_dict()
                    self.state['optimizer'] = self.optimizer.state_dict()
                    self.state['last_iter'] = curr_step
                    if self.ema is not None:
                        self.state['ema'] = self.ema.state_dict()
                    torch.save(self.state, ckpt_name)

            self.end_time = time.time()

    # sci
    @torch.no_grad()
    def evaluate(self):
        '''
        Run evaluation and return a Metric class instance.

        Returns:
            Metric: Metric class instance
        '''
        self.model.eval()
        res_file = os.path.join(self.path.result_path,
                                f'results.txt.rank{self.dist.rank}')
        writer = open(res_file, 'w')
        for batch_idx, batch in enumerate(self.data_loaders['test']):
            input = batch['image']
            label = batch['label']
            input = input.cuda().half() if self.fp16 else input.cuda()
            label = label.squeeze().view(-1).cuda().long()
            # compute output
            logits = self.model(input)
            scores = F.softmax(logits, dim=1)
            # compute prediction
            _, preds = logits.data.topk(k=1, dim=1)
            preds = preds.view(-1)
            # update batch information
            batch.update({'prediction': preds})
            batch.update({'score': scores})
            # save prediction information
            self.data_loaders['test'].dataset.dump(writer, batch)

        writer.close()
        link.barrier()
        if self.dist.rank == 0:
            metrics = self.data_loaders['test'].dataset.evaluate(res_file)
            self.logger.info(json.dumps(metrics.metric, indent=2))
        else:
            metrics = {}
        link.barrier()
        # broadcast metrics to other processes
        metrics = broadcast_object(metrics)
        self.model.train()
        return metrics

    # sci
    @staticmethod
    def build_model_helper(config_dict=None):
        '''
        Static function. Build a model from the given config_dict.

        Args:
            config_dict (dict): the config that contains model information

        Returns:
            torch.nn.Module: The PyTorch GPU model.
        '''
        if not isinstance(config_dict, EasyDict):
            config_dict = EasyDict(config_dict)
        model = model_entry(config_dict.model)
        model.cuda()

        if config_dict.optimizer.type in [
                'FP16SGD', 'FusedFP16SGD', 'FP16RMSprop'
        ]:
            fp16 = True
        else:
            fp16 = False

        if fp16:
            # if you have modules that must use fp32 parameters, and need fp32 input
            # try use link.fp16.register_float_module(your_module)
            # if you only need fp32 parameters set cast_args=False when call this
            # function, then call link.fp16.init() before call model.half()
            if config_dict.optimizer.get('fp16_normal_bn', False):
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                                cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                                cast_args=False)
            if config_dict.optimizer.get('fp16_normal_fc', False):
                link.fp16.register_float_module(torch.nn.Linear,
                                                cast_args=True)
            link.fp16.init()
            model.half()

        model = DistModule(model, config_dict.dist.sync)
        return model

    # sci
    def show_log(self):
        '''
        Display the log of the current iteration. This interface will be called
        after forward/backward/update.
        '''
        curr_step = self.curr_step
        current_lr = self.lr_scheduler.get_lr()[0]
        remain_secs = (self.total_step -
                       curr_step) * self.meters.batch_time.avg
        remain_time = datetime.timedelta(seconds=round(remain_secs))
        finish_time = time.strftime("%Y-%m-%d %H:%M:%S",
                                    time.localtime(time.time() + remain_secs))
        log_msg = f'Iter: [{curr_step}/{self.total_step}]\t' \
            f'Time {self.meters.batch_time.val:.3f} ({self.meters.batch_time.avg:.3f})\t' \
            f'Data {self.meters.data_time.val:.3f} ({self.meters.data_time.avg:.3f})\t' \
            f'Loss {self.meters.losses.val:.4f} ({self.meters.losses.avg:.4f})\t' \
            f'Prec@1 {self.meters.top1.val:.3f} ({self.meters.top1.avg:.3f})\t' \
            f'Prec@5 {self.meters.top5.val:.3f} ({self.meters.top5.avg:.3f})\t' \
            f'LR {current_lr:.4f}\t' \
            f'Remaining Time {remain_time} ({finish_time})'

        self.logger.info(log_msg)
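
    # Worked example: with total_step=10000, curr_step=4000 and an average
    # batch time of 0.5s per iteration, remain_secs = (10000 - 4000) * 0.5
    # = 3000, which is logged as a remaining time of 0:50:00.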

    # sci
    @classmethod
    def add_external_model(cls, name, callable_object):
        '''
        Add an external model into the element. After this interface is called, the element
        should be able to build a model by the given ``name``.

        Args:
            name (str): The identifier of ``callable_object``.
            callable_object (callable): A class or function that can be called to build a torch.nn.Module model.
        '''
        cls.external_model_builder[name] = callable_object
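
    # Usage sketch (build_my_net / MyNet are hypothetical; any callable that
    # returns a torch.nn.Module works):
    #
    #   def build_my_net(**kwargs):
    #       return MyNet(**kwargs)
    #
    #   ClsSolver.add_external_model('my_net', build_my_net)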

    # sci
    def convert_model(self, type="skme"):
        '''
        Dump the pytorch model to a deployable model type (skme or caffe).
        More about SKME (aka onnx with oplib): https://confluence.sensetime.com/pages/viewpage.action?pageId=135889068
        Recommended: spring.nart.tools.pytorch.convert_caffe_by_return for caffe,
        spring.nart.tools.pytorch.convert_onnx_by_return for skme.

        Args:
            type (str): Deploy model type, "skme" or "caffe".

        Returns:
            dict: A dict of model files (strings) and the type.
            The format should be {"model": [model1, model2, ...], "type": "skme"}.
        '''
        if type == 'skme':
            self.logger.warning(
                'skme is not supported yet, we support convert to caffemodel for now'
            )
        return {'model': [self.to_caffe()], 'type': type}

    # sci
    def get_kestrel_parameter(self):
        '''
        Get the kestrel plugin parameter file.
        The parameters contain model_files and other options (which bind to the
        related kestrel plugin); model_files must follow this contract:

        Returns:
            str: Kestrel plugin parameter.json content.
        '''
        assert hasattr(self.config, 'to_kestrel')
        kestrel_config = self.config.to_kestrel
        kestrel_param = EasyDict()
        # default: ImageNet statistics
        kestrel_param['pixel_means'] = kestrel_config.get(
            'pixel_means', [123.675, 116.28, 103.53])
        kestrel_param['pixel_stds'] = kestrel_config.get(
            'pixel_stds', [58.395, 57.12, 57.375])
        # default: True/True/UNKNOWN
        kestrel_param['is_rgb'] = kestrel_config.get('is_rgb', True)
        kestrel_param['save_all_label'] = kestrel_config.get(
            'save_all_label', True)
        kestrel_param['type'] = kestrel_config.get('type', 'UNKNOWN')
        # class label
        if hasattr(kestrel_config, 'class_label'):
            kestrel_param['class_label'] = kestrel_config['class_label']
        else:
            # default: imagenet
            kestrel_param['class_label'] = {}
            kestrel_param['class_label']['imagenet'] = {}
            kestrel_param['class_label']['imagenet']['calculator'] = 'bypass'
            kestrel_param['class_label']['imagenet']['labels'] = [
                str(i) for i in np.arange(self.num_classes)
            ]
            kestrel_param['class_label']['imagenet']['feature_start'] = 0
            kestrel_param['class_label']['imagenet']['feature_end'] = 0

        return json.dumps(kestrel_param)
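
    # With an empty `to_kestrel` section and num_classes=1000, the returned
    # JSON roughly contains the defaults above:
    #
    #   {"pixel_means": [123.675, 116.28, 103.53],
    #    "pixel_stds": [58.395, 57.12, 57.375],
    #    "is_rgb": true, "save_all_label": true, "type": "UNKNOWN",
    #    "class_label": {"imagenet": {"calculator": "bypass",
    #                                 "labels": ["0", "1", ..., "999"],
    #                                 "feature_start": 0, "feature_end": 0}}}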

    # sci
    def get_epoch_iter(self, loader_name):
        '''
        Return the epoch iteration count of the loader with the given loader_name.
        The returned value equals one epoch's iteration count, i.e. len(self.xxx_loader).
        This interface differs from `get_total_iter`, which is required to return
        the total iteration count of the training process.

        Args:
            loader_name (str): loader's name.

        Returns:
            int: The epoch iteration of the loader with the given loader_name.
        '''
        return len(self.data_loaders[loader_name])

    # sensespring
    def to_caffe(self, save_prefix='model', input_size=None):
        from spring.nart.tools import pytorch

        # fall back to the configured input size when none is given
        if input_size is None:
            input_size = self.config.data.input_size
        with pytorch.convert_mode():
            pytorch.convert(self.model.float(),
                            [(3, input_size, input_size)],
                            filename=save_prefix,
                            input_names=['data'],
                            output_names=['out'])
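
    # Note: to_caffe() writes '{save_prefix}.prototxt' and
    # '{save_prefix}.caffemodel', the filenames that to_kestrel() below
    # expects, and casts the model back to fp32 first via self.model.float().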

    # sensespring
    def to_kestrel(self, save_to=None):
        assert hasattr(self.config, 'to_kestrel')
        prefix = 'model'
        self.logger.info('Converting Model to Caffe...')
        if self.dist.rank == 0:
            self.to_caffe(save_prefix=prefix)
        link.synchronize()
        self.logger.info('To Caffe Done!')

        prototxt = '{}.prototxt'.format(prefix)
        caffemodel = '{}.caffemodel'.format(prefix)
        # default version '1.0.0'
        # acquire version and model_name
        version = self.config.to_kestrel.get('version', '1.0.0')
        model_name = self.config.to_kestrel.get('model_name',
                                                self.config.model.type)
        kestrel_model = '{}_{}.tar'.format(model_name, version)
        to_kestrel_yml = 'temp_to_kestrel.yml'
        # acquire to_kestrel params
        kestrel_param = json.loads(self.get_kestrel_parameter())
        with open(to_kestrel_yml, 'w') as f:
            yaml.dump(kestrel_param, f)

        cmd = 'python -m spring.nart.tools.kestrel.classifier {} {} -v {} -c {} -n {}'.format(
            prototxt, caffemodel, version, to_kestrel_yml, model_name)
        self.logger.info('Converting Model to Kestrel...')
        if self.dist.rank == 0:
            os.system(cmd)
        link.synchronize()
        self.logger.info('To Kestrel Done!')
        if save_to is None:
            save_to = self.config['to_kestrel']['save_to']
        shutil.move(kestrel_model, save_to)
        self.logger.info('save kestrel model to: {}'.format(save_to))

        # convert model to nnie
        nnie_cfg = self.config.to_kestrel.get('nnie', None)
        if nnie_cfg is not None:
            nnie_model = 'nnie_{}_{}.tar'.format(model_name, version)
            nnie_cfg_path = generate_nnie_config(nnie_cfg, self.config)
            nnie_cmd = 'python -m spring.nart.switch -c {} -t nnie {} {}'.format(
                nnie_cfg_path, prototxt, caffemodel)
            self.logger.info('Converting Model to NNIE...')
            if self.dist.rank == 0:
                os.system(nnie_cmd)
                # refactor json
                assert os.path.exists("parameters.json")
                with open("parameters.json", "r") as f:
                    params = json.load(f)
                params["model_files"]["net"]["net"] = "engine.bin"
                params["model_files"]["net"]["backend"] = "kestrel_nart"
                with open("parameters.json", "w") as f:
                    json.dump(params, f, indent=2)
                tar_cmd = 'tar cvf {} engine.bin engine.bin.json meta.json meta.conf \
                    parameters.json category_param.json'.format(nnie_model)
                os.system(tar_cmd)
                self.logger.info(f"generate {nnie_model} done!")
                # only rank 0 produces the tar file, so move it on rank 0 only
                shutil.move(nnie_model, save_to)
            link.synchronize()
            self.logger.info('To NNIE Done!')

        return save_to
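
    # A hypothetical `to_kestrel` config block covering the keys read above
    # (all values illustrative):
    #
    #   to_kestrel:
    #     version: '1.0.0'          # default: '1.0.0'
    #     model_name: resnet50      # default: config.model.type
    #     save_to: /path/to/output  # destination of the generated tar
    #     nnie: {...}               # optional; enables the NNIE conversion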

    # sci
    @torch.no_grad()
    def inference(self):
        '''
        Run inference on the inference dataset and save the raw results (not the
        evaluation values) to the result file.
        The inference dataset and the corresponding config should be set in the config file.
        The result file is in json format.

        Returns:
            str: The absolute path to the saved results file.
        '''
        assert 'inference' in self.data_loaders.keys()
        self.model.eval()
        res_file = os.path.join(self.path.result_path,
                                f'infer_results.txt.rank{self.dist.rank}')
        writer = open(res_file, 'w')
        for batch_idx, batch in enumerate(self.data_loaders['inference']):
            input = batch['image']
            input = input.cuda().half() if self.fp16 else input.cuda()
            # compute output
            logits = self.model(input)
            scores = F.softmax(logits, dim=1)
            # compute prediction
            _, preds = logits.data.topk(k=1, dim=1)
            preds = preds.view(-1)
            # update batch information
            batch.update({'prediction': preds})
            batch.update({'score': scores})
            # save prediction information
            self.data_loaders['inference'].dataset.dump(writer, batch)

        writer.close()
        link.barrier()
        if self.dist.rank == 0:
            infer_res_file = self.data_loaders['inference'].dataset.inference(
                res_file)
        else:
            infer_res_file = None
        link.barrier()
        # broadcast file to other process
        infer_res_file = broadcast_object(infer_res_file)
        self.model.train()
        return infer_res_file
Exemple #8
class SimCLRSolver(ClsSolver):
    def build_model(self):
        encoder = model_entry(self.config.model)
        self.model = SimCLR(encoder)
        self.model.cuda()
        count_params(self.model.encoder)
        count_flops(self.model.encoder,
                    input_shape=[
                        1, 3, self.config.data.input_size,
                        self.config.data.input_size
                    ])

        # handle fp16
        if self.config.optimizer.type in [
                'FP16SGD', 'FusedFP16SGD', 'FP16RMSprop'
        ]:
            self.fp16 = True
        else:
            self.fp16 = False

        if self.fp16:
            # if you have modules that must use fp32 parameters and need fp32
            # input, try link.fp16.register_float_module(your_module);
            # if you only need fp32 parameters, set cast_args=False when calling
            # this function, then call link.fp16.init() before model.half()
            if self.config.optimizer.get('fp16_normal_bn', False):
                self.logger.info('using normal bn for fp16')
                link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                                cast_args=False)
                link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                                cast_args=False)
            if self.config.optimizer.get('fp16_normal_fc', False):
                self.logger.info('using normal fc for fp16')
                link.fp16.register_float_module(torch.nn.Linear,
                                                cast_args=True)
            link.fp16.init()
            self.model.half()

        self.model = DistModule(self.model, self.config.dist.sync)

        if 'model' in self.state:
            load_state_model(self.model, self.state['model'])

    def pre_train(self):
        super().pre_train()
        self.criterion = NT_Xent(self.config.data.batch_size,
                                 self.config.temperature)
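
    # NT_Xent is SimCLR's normalized temperature-scaled cross-entropy loss.
    # A minimal sketch of the idea (illustrative, not the library
    # implementation), with N = batch_size and L2-normalized projections
    # z_i, z_j of shape [N, dim]:
    #
    #   z = torch.cat([z_i, z_j], dim=0)               # [2N, dim]
    #   sim = z @ z.t() / temperature                  # cosine similarities
    #   sim.fill_diagonal_(float('-inf'))              # mask self-similarity
    #   targets = torch.cat([torch.arange(N, 2 * N),   # row k's positive is
    #                        torch.arange(0, N)])      # its other view
    #   loss = F.cross_entropy(sim, targets)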

    def train(self):

        self.pre_train()
        total_step = len(self.train_data['loader'])
        start_step = self.state['last_iter'] + 1
        end = time.time()

        for i, batch in enumerate(self.train_data['loader']):
            input = batch['image']
            curr_step = start_step + i
            self.lr_scheduler.step(curr_step)
            # lr_scheduler.get_lr()[0] is the main lr
            current_lr = self.lr_scheduler.get_lr()[0]
            # measure data loading time
            self.meters.data_time.update(time.time() - end)
            # transfer input to gpu
            input = input.cuda().half() if self.fp16 else input.cuda()

            # forward
            z_i, z_j = self.model(input)
            # normalize projection feature vectors
            z_i = F.normalize(z_i, dim=1)
            z_j = F.normalize(z_j, dim=1)
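            # dividing by world_size compensates for the gradient allreduce
            # (sum) across ranks performed in sync_gradients()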
            loss = self.criterion(z_i, z_j) / self.dist.world_size
            reduced_loss = loss.clone()
            self.meters.losses.reduce_update(reduced_loss)
            self.optimizer.zero_grad()

            if FusedFP16SGD is not None and isinstance(self.optimizer,
                                                       FusedFP16SGD):
                self.optimizer.backward(loss)
                self.model.sync_gradients()
                self.optimizer.step()
            elif isinstance(self.optimizer, FP16SGD) or isinstance(
                    self.optimizer, FP16RMSprop):

                def closure():
                    self.optimizer.backward(loss, False)
                    self.model.sync_gradients()
                    # check overflow, convert to fp32 grads, downscale
                    self.optimizer.update_master_grads()
                    return loss

                self.optimizer.step(closure)
            else:
                loss.backward()
                self.model.sync_gradients()
                self.optimizer.step()

            # measure elapsed time
            self.meters.batch_time.update(time.time() - end)
            if curr_step % self.config.saver.print_freq == 0 and self.dist.rank == 0:
                self.tb_logger.add_scalar('loss_train', self.meters.losses.avg,
                                          curr_step)
                self.tb_logger.add_scalar('lr', current_lr, curr_step)
                remain_secs = (total_step -
                               curr_step) * self.meters.batch_time.avg
                remain_time = datetime.timedelta(seconds=round(remain_secs))
                finish_time = time.strftime(
                    "%Y-%m-%d %H:%M:%S",
                    time.localtime(time.time() + remain_secs))
                log_msg = f'Iter: [{curr_step}/{total_step}]\t' \
                    f'Time {self.meters.batch_time.val:.3f} ({self.meters.batch_time.avg:.3f})\t' \
                    f'Data {self.meters.data_time.val:.3f} ({self.meters.data_time.avg:.3f})\t' \
                    f'Loss {self.meters.losses.val:.4f} ({self.meters.losses.avg:.4f})\t' \
                    f'LR {current_lr:.4f}\t' \
                    f'Remaining Time {remain_time} ({finish_time})'
                self.logger.info(log_msg)

            if curr_step > 0 and curr_step % self.config.saver.val_freq == 0:
                if self.dist.rank == 0:
                    if self.config.saver.save_many:
                        ckpt_name = f'{self.path.save_path}/ckpt_{curr_step}.pth.tar'
                    else:
                        ckpt_name = f'{self.path.save_path}/ckpt.pth.tar'
                    self.state['model'] = self.model.state_dict()
                    self.state['optimizer'] = self.optimizer.state_dict()
                    self.state['last_iter'] = curr_step
                    torch.save(self.state, ckpt_name)

            end = time.time()