예제 #1
0
 def _load_pretrained_model(cls, model, pretrained_model_file):
     """Load pretrained weights into ``model`` for the active backend.

     The path is first resolved with ``cls._get_abs_path``.

     - torch: ``torch.load`` the file and ``load_state_dict`` it.
     - TensorFlow: ``.pth`` files are converted with
       ``convert_checkpoint_from_pytorch`` and loaded from numpy; other paths
       are resolved with ``cls._get_tf_model_file`` and passed to
       ``model.load_checkpoint``.
     - MindSpore: use ``model.pretrained(path)`` when the model defines it;
       otherwise the path itself if it is a file, or else the first ``.ckpt``
       file listed in that folder.

     :param model: model instance to load the weights into
     :param pretrained_model_file: path of the weights file (or, for
         MindSpore, a folder containing a ``.ckpt`` file)
     :return: ``model`` with the weights loaded
     :raises FileNotFoundError: torch: the path is not a file; MindSpore: the
         folder contains no ``.ckpt`` file
     """
     pretrained_model_file = cls._get_abs_path(pretrained_model_file)
     logging.info("load model weights from file, weights file={}".format(pretrained_model_file))
     if zeus.is_torch_backend():
         if not os.path.isfile(pretrained_model_file):
             # `raise "<str>"` is itself a TypeError in Python 3; raise a real exception
             raise FileNotFoundError("Pretrained model is not existed, model={}".format(pretrained_model_file))
         import torch
         checkpoint = torch.load(pretrained_model_file)
         model.load_state_dict(checkpoint)
     elif zeus.is_tf_backend():
         if pretrained_model_file.endswith('.pth'):
             checkpoint = convert_checkpoint_from_pytorch(pretrained_model_file, model)
             model.load_checkpoint_from_numpy(checkpoint)
         else:
             pretrained_model_file = cls._get_tf_model_file(pretrained_model_file)
             model.load_checkpoint(pretrained_model_file)
     elif zeus.is_ms_backend():
         from mindspore.train.serialization import load_checkpoint
         if hasattr(model, "pretrained"):
             pretrained_weight = model.pretrained(pretrained_model_file)
         elif os.path.isfile(pretrained_model_file):
             pretrained_weight = pretrained_model_file
         else:
             pretrained_weight = next(
                 (os.path.join(pretrained_model_file, file_name)
                  for file_name in os.listdir(pretrained_model_file)
                  if file_name.endswith(".ckpt")),
                 None)
             if pretrained_weight is None:
                 # previously surfaced as an UnboundLocalError on `pretrained_weight`
                 raise FileNotFoundError(
                     "No .ckpt file found in pretrained model folder={}".format(pretrained_model_file))
         load_checkpoint(pretrained_weight, net=model)
     return model
예제 #2
0
    def load_records_from_model_folder(cls, model_folder):
        """Transfer json_file to records.

        Every ``desc_<worker_id>.json`` file in ``model_folder`` becomes one
        ``ReportRecord``. Its weights file is ``model_<worker_id>`` next to the
        json file, with a ``.pth`` suffix for the torch backend and ``.ckpt``
        for MindSpore (no suffix otherwise). It is set to None when that path
        does not exist. Files that cannot be read are logged and skipped.

        :param model_folder: folder containing the ``desc_*.json`` files
        :return: list of ``ReportRecord``; empty if the folder is missing
        """
        if not model_folder or not os.path.exists(model_folder):
            logging.error(
                "Failed to load records from model folder, folder={}".format(
                    model_folder))
            return []
        records = []
        pattern = FileOps.join_path(model_folder, "desc_*.json")
        for _file in glob.glob(pattern):
            try:
                # Parse the id from the file name only: splitting the full path on "."
                # truncated ids containing a dot ("desc_1.5.json" gave "5").
                stem = os.path.splitext(os.path.basename(_file))[0]
                worker_id = stem.split("_")[-1]
                weights_file = os.path.join(os.path.dirname(_file),
                                            "model_{}".format(worker_id))
                if zeus.is_torch_backend():
                    weights_file = '{}.pth'.format(weights_file)
                elif zeus.is_ms_backend():
                    weights_file = '{}.ckpt'.format(weights_file)
                if not os.path.exists(weights_file):
                    weights_file = None
                with open(_file) as f:
                    desc = json.load(f)
                sample = dict(worker_id=worker_id,
                              desc=desc,
                              weights_file=weights_file)
                records.append(ReportRecord().load_dict(sample))
            except Exception as ex:
                # best effort: skip the file, but make the failure visible
                logging.warning(
                    'Can not read records from json because {}'.format(ex))
        return records
예제 #3
0
 def load_model(self):
     """Load model.

     When neither ``self.model_desc`` nor ``self.weights_file`` is set, both are
     derived from the local worker folder: ``desc_<worker_id>.json`` and
     ``model_<worker_id>.pth`` (torch) or ``model_<worker_id>.ckpt`` (MindSpore).
     If ``self.model_desc`` is empty/None or ``'modules'`` is not in it (a key
     test for dicts, a substring test for strings such as the derived path),
     ``ModelConfig.model_desc`` is used instead. The model is then built with
     ``ModelZoo.get_model`` and stored in ``self.model``.
     """
     if not self.model_desc and not self.weights_file:
         self.saved_folder = self.get_local_worker_path(
             self.step_name, self.worker_id)
         self.model_desc = FileOps.join_path(
             self.saved_folder, 'desc_{}.json'.format(self.worker_id))
         if zeus.is_torch_backend():
             self.weights_file = FileOps.join_path(
                 self.saved_folder, 'model_{}.pth'.format(self.worker_id))
         elif zeus.is_ms_backend():
             # this check used to repeat is_torch_backend(), so it could never run
             self.weights_file = FileOps.join_path(
                 self.saved_folder, 'model_{}.ckpt'.format(self.worker_id))
     if not self.model_desc or 'modules' not in self.model_desc:
         # `'modules' in None` raised TypeError; fall back to the global description
         self.model_desc = ModelConfig.model_desc
     self.model = ModelZoo.get_model(self.model_desc, self.weights_file)
예제 #4
0
 def _load_checkpoint(self):
     """Load checkpoint.

     torch only. If ``<local worker path>/<checkpoint_file_name>`` exists,
     restores the model weights (``"weight"``), optimizer state
     (``"optimizer"``) and lr-scheduler state (``"lr_scheduler"``). When
     ``self.trainer._resume_training`` is set, the stored ``"epoch"`` becomes
     ``self.trainer._start_epoch``. Any error while loading is logged and not
     re-raised.
     """
     if zeus.is_torch_backend():
         checkpoint_file = FileOps.join_path(
             self.trainer.get_local_worker_path(),
             self.trainer.checkpoint_file_name)
         if os.path.exists(checkpoint_file):
             try:
                 logging.info("Load checkpoint file, file={}".format(
                     checkpoint_file))
                 checkpoint = torch.load(checkpoint_file)
                 self.trainer.model.load_state_dict(checkpoint["weight"])
                 self.trainer.optimizer.load_state_dict(
                     checkpoint["optimizer"])
                 self.trainer.lr_scheduler.load_state_dict(
                     checkpoint["lr_scheduler"])
                 if self.trainer._resume_training:
                     self.trainer._start_epoch = checkpoint["epoch"]
                     logging.info(
                         "Resume fully train, change start epoch to {}".
                         format(self.trainer._start_epoch))
             except Exception as e:
                 # best effort: the error is not re-raised, so log it visibly
                 logging.warning("Load checkpoint failed {}".format(e))
         else:
             logging.info('Use default model')
예제 #5
0
파일: optim.py 프로젝트: vineetrao25/vega
    def __call__(self, model=None, distributed=False, **kwargs):
        """Call Optimizer class.

        :param model: model whose trainable parameters are optimized
            (torch and MindSpore cases)
        :param distributed: use distributed (torch case, via ``set_distributed``)
        :param kwargs: ``dynamic_lr`` (MindSpore case) overrides the
            ``learning_rate`` parameter for this call only
        :return: optimizer, or None if no known backend is active
        """
        params = self.map_config.get("params", {})
        logging.debug("Call Optimizer. name={}, params={}".format(
            self.optim_cls.__name__, params))
        optimizer = None
        try:
            if zeus.is_torch_backend():
                learnable_params = [
                    param for param in model.parameters()
                    if param.requires_grad
                ]
                optimizer = self.optim_cls(learnable_params, **params)
                if distributed:
                    optimizer = self.set_distributed(optimizer, model)
            elif zeus.is_tf_backend():
                optimizer = dynamic_optimizer(self.optim_cls, **params)
            elif zeus.is_ms_backend():
                if "dynamic_lr" in kwargs:
                    # Build a new dict: `params` is the dict stored in self.map_config,
                    # and updating it in place leaked the override into later calls.
                    params = dict(params, learning_rate=kwargs["dynamic_lr"])
                learnable_params = [
                    param for param in model.trainable_params()
                    if param.requires_grad
                ]
                optimizer = self.optim_cls(learnable_params, **params)
            return optimizer
        except Exception:
            logging.error("Failed to call Optimizer name={}, params={}".format(
                self.optim_cls.__name__, params))
            raise
예제 #6
0
파일: cls_ds.py 프로젝트: vineetrao25/vega
 def _to_tensor(self, data):
     """Convert ``data`` into a tensor of the active backend.

     Uses ``torch.tensor`` for torch and ``tf.convert_to_tensor`` for
     TensorFlow. Returns None for any other backend.
     """
     converter = None
     if zeus.is_torch_backend():
         import torch
         converter = torch.tensor
     elif zeus.is_tf_backend():
         import tensorflow as tf
         converter = tf.convert_to_tensor
     return None if converter is None else converter(data)
예제 #7
0
 def _init_model(self):
     """Load model desc from save path and parse to model.

     Uses ``self.trainer.model`` when it is set. Otherwise the model is built
     with ``ModelZoo.get_model`` from the model description (taken from the
     trainer for detection trainers, else from ``self._get_model_desc()``) and
     the pretrained weights file. A found description is also stored in
     ``ModelConfig.model_desc``.

     torch: the model is moved to CUDA when ``self.trainer.use_cuda`` is set,
     and wrapped in ``torch.nn.DataParallel`` when ``General._parallel`` is on
     and ``General.devices_per_trainer > 1``.
     TensorFlow: the folder of the pretrained weights file is copied into the
     local worker path.

     :return: the (possibly moved / wrapped) model
     :raises Exception: if there is no model and no model description
     """
     model = self.trainer.model
     if self.trainer.config.is_detection_trainer:
         model_desc = self.trainer.model_desc
     else:
         model_desc = self._get_model_desc()
     if model_desc:
         ModelConfig.model_desc = model_desc
     pretrained_model_file = self._get_pretrained_model_file()
     if not model:
         if not model_desc:
             raise Exception(
                 "Failed to Init model, can not get model description.")
         model = ModelZoo.get_model(model_desc, pretrained_model_file)
     if model:
         if zeus.is_torch_backend():
             import torch
             if self.trainer.use_cuda:
                 model = model.cuda()
             if General._parallel and General.devices_per_trainer > 1:
                 # Wrap the local model: self.trainer.model may still be None (model
                 # freshly built above) or the instance before it was moved to CUDA.
                 model = torch.nn.DataParallel(model)
         elif zeus.is_tf_backend():
             if pretrained_model_file:
                 model_folder = os.path.dirname(pretrained_model_file)
                 FileOps.copy_folder(model_folder,
                                     self.trainer.get_local_worker_path())
     return model
예제 #8
0
    def _reset_classifier_model(self):
        """Replace the ``fc`` head of the trainer's model (torch backend only).

        The new head keeps the old ``fc.in_features`` and has
        ``ModelConfig.num_classes`` outputs; it is moved to CUDA with
        ``.cuda()``. torchvision models get a ``torch.nn.Linear`` head. Other
        (vega) models get a ``zeus`` ``ops.Linear`` head, and
        ``ModelConfig.model_desc.backbone.n_class`` is updated to match.
        """
        if not zeus.is_torch_backend():
            return
        num_classes = ModelConfig.num_classes
        model = self.trainer.model
        if "torch_vision_model" in ModelConfig.model_desc["modules"]:
            # torchvision model: plain torch Linear head
            import torch.nn as nn
            in_features = model.fc.in_features
            model.fc = nn.Linear(in_features, num_classes).cuda()
            return
        # vega model: zeus operator Linear head
        in_features = model.fc.in_features
        from zeus.modules.operators import ops
        model.fc = ops.Linear(in_features=in_features,
                              out_features=num_classes).cuda()
        # TODO: keep n_class in the description in sync with the new head
        ModelConfig.model_desc.backbone.n_class = num_classes
        logging.info("Model fine tuned successfully.")
예제 #9
0
 def _train_epoch(self):
     """Run one training epoch with the backend-specific training loop.

     torch: iterates ``self.train_loader`` and wraps each ``self.train_step``
     call with the ``before_train_step`` / ``after_train_step`` callbacks.
     TensorFlow: runs ``self.estimator.train`` for ``len(self.train_loader)``
     steps.
     MindSpore: wraps the network in ``MsModel`` and trains it for
     ``self.epochs`` epochs with checkpoint, loss-monitor and evaluation
     callbacks.
     """
     if zeus.is_torch_backend():
         self.model.train()
         for step, raw_batch in enumerate(self.train_loader):
             batch = self.make_batch(raw_batch)
             step_logs = {'train_batch': batch}
             self.callbacks.before_train_step(step, step_logs)
             step_logs.update(self.train_step(batch))
             if self.config.is_detection_trainer:
                 step_logs['is_detection_trainer'] = True
             self.callbacks.after_train_step(step, step_logs)
         return
     if zeus.is_tf_backend():
         self.estimator.train(input_fn=self.train_input_fn,
                              steps=len(self.train_loader),
                              hooks=self._init_logging_hook())
         return
     if not zeus.is_ms_backend():
         return
     self.ms_model = MsModel(network=self.model,
                             loss_fn=self.loss,
                             optimizer=self.optimizer,
                             metrics={self.metric_name: self.valid_metrics()})
     checkpoint_config = CheckpointConfig(save_checkpoint_steps=self.config.save_steps)
     # checkpoints land in the worker folder so they can be fine-tuned later
     save_path = self.get_local_worker_path(self.step_name, self.worker_id)
     ms_callbacks = [
         ModelCheckpoint(config=checkpoint_config, directory=save_path),
         LossMonitor(per_print_times=self.config.report_freq),
         EvalCallBack(self.ms_model, self.valid_loader),
     ]
     self.ms_model.train(epoch=self.epochs,
                         train_dataset=self.train_loader,
                         callbacks=ms_callbacks,
                         dataset_sink_mode=self.dataset_sink_mode)
예제 #10
0
 def filter(self):
     """Apply mask to linear.

     Keeps the input features selected by ``self.mask_code``; if every entry
     is 0, the first one is switched on so at least one input survives. Any
     weight whose name contains ``kernel`` or ``weight`` is sliced on the
     input axis: axis 1 for torch, axis 0 otherwise. If the resulting output
     size differs from ``self.layer.out_features``, a random permutation of at
     most ``out_features`` outputs is kept. ``out_features`` is then set to
     the number of outputs actually kept.
     """
     if sum(self.mask_code) == 0:
         self.mask_code[0] = 1
     mask_code = np.asarray(self.mask_code)
     idx_in = np.squeeze(np.argwhere(mask_code)).tolist()
     # a single kept index squeezes to a scalar; wrap it back into a list
     idx_in = [idx_in] if not isinstance(idx_in, list) else idx_in
     self.layer.in_features = sum(mask_code)
     weights = self.layer.get_weights()
     out_size = self.layer.out_features
     for name, weight in weights.items():
         if 'kernel' in name or 'weight' in name:
             if is_torch_backend():
                 self.layer.set_weights(name, weight[:, idx_in])
                 out_size = weight.shape[0]
             else:
                 self.layer.set_weights(name, weight[idx_in, :])
                 out_size = weight.shape[1]
     # fineTune out_feature value
     if self.layer.out_features == out_size:
         return
     idx_out = list(
         np.random.permutation(out_size)[:self.layer.out_features])
     for name, weight in self.layer.get_weights().items():
         if 'kernel' in name:
             self.layer.set_weights(name, weight[:, idx_out])
         else:
             self.layer.set_weights(name, weight[idx_out])
     # Must match the sliced weights. Resetting it to out_size (the pre-slice
     # size) was wrong whenever outputs were dropped above.
     self.layer.out_features = len(idx_out)
예제 #11
0
    def get_cls(cls, type_name, t_cls_name=None):
        """Look up a registered class, importing its package lazily if needed.

        :param type_name: registry type (a ``ClassType`` value)
        :param t_cls_name: class name. When None it is derived from
            ``DatasetConfig.type`` (datasets), ``EvaluatorConfig.type``
            (evaluators) or the active backend (trainers).
        :return: the class registered under ``type_name`` / ``t_cls_name``
        :raises ValueError: if the type/class is unknown or no class name can
            be derived
        """
        if not cls.is_exists(type_name, t_cls_name) and t_cls_name:
            cls._import_pkg(type_name, t_cls_name)
        if not cls.is_exists(type_name, t_cls_name):
            raise ValueError("can't find class type {} class name {} in class registry".format(type_name, t_cls_name))
        if t_cls_name is None:
            from zeus.datasets.conf.dataset import DatasetConfig
            from zeus.evaluator.conf import EvaluatorConfig
            if type_name == ClassType.DATASET:
                t_cls_name = DatasetConfig.type
            elif type_name == ClassType.TRAINER:
                import zeus
                backend_trainers = (
                    (zeus.is_torch_backend, "TrainerTorch"),
                    (zeus.is_tf_backend, "TrainerTf"),
                    (zeus.is_ms_backend, "TrainerMs"),
                )
                for backend_active, trainer_name in backend_trainers:
                    if backend_active():
                        t_cls_name = trainer_name
                        break
            elif type_name == ClassType.EVALUATOR:
                t_cls_name = EvaluatorConfig.type
        if t_cls_name is None:
            raise ValueError("can't find class. class type={}".format(type_name))
        return cls.__registry__.get(type_name).get(t_cls_name)
예제 #12
0
    def __new__(cls,
                model=None,
                id=None,
                hps=None,
                load_ckpt_flag=False,
                model_desc=None,
                lazy_build=True,
                **kwargs):
        """Create the trainer class matching the active backend.

        Returns an instance of ``TrainerTorch``, ``TrainerTf`` or, for any
        other backend, ``TrainerMs``. All arguments are forwarded to its
        constructor.
        """
        if zeus.is_torch_backend():
            from zeus.trainer_torch import TrainerTorch as trainer_cls
        elif zeus.is_tf_backend():
            from zeus.trainer_tf import TrainerTf as trainer_cls
        else:
            from zeus.trainer_ms import TrainerMs as trainer_cls
        init_args = dict(model=model,
                         id=id,
                         hps=hps,
                         load_ckpt_flag=load_ckpt_flag,
                         model_desc=model_desc,
                         lazy_build=lazy_build)
        # kwargs cannot repeat the named parameters, so this never overwrites them
        init_args.update(kwargs)
        return trainer_cls(**init_args)
예제 #13
0
    def get_model(cls, model_desc=None, pretrained_model_file=None):
        """Get model from model zoo.

        :param model_desc: the description of the network, passed to
            ``NetworkDesc``.
        :param pretrained_model_file: path of pretrained weights. It is loaded
            through ``cls._load_pretrained_model`` for the torch and MindSpore
            backends only.
        :type pretrained_model_file: str or None.
        :return: model.
        :rtype: model.
        :raises Exception: any error from building the network is logged and
            re-raised.
        """
        try:
            network = NetworkDesc(model_desc)
            model = network.to_model()
        except Exception as e:
            logging.error("Failed to get model, model_desc={}, msg={}".format(
                model_desc, str(e)))
            raise
        logging.info("Model was created.")
        if pretrained_model_file and (zeus.is_torch_backend() or zeus.is_ms_backend()):
            model = cls._load_pretrained_model(model, pretrained_model_file)
        return model
예제 #14
0
def _get_data_format():
    """Return the tensor layout string for the active backend.

    torch and MindSpore use ``'channels_first'``, TensorFlow uses
    ``'channels_last'``, and any other backend yields None.
    """
    channels_first = zeus.is_torch_backend() or zeus.is_ms_backend()
    if channels_first:
        return 'channels_first'
    return 'channels_last' if zeus.is_tf_backend() else None
예제 #15
0
 def is_filtered(self, desc=None):
     """Filter function of latency.

     Measures the model's mean forward latency over 100 runs and compares it
     with ``self.max_latency``.

     :param desc: model description passed to ``self.get_model_input``
     :return: True if the latency exceeds ``self.max_latency``. False when no
         limit is set, when the model is fast enough, or when the backend is
         neither torch nor TensorFlow (latency is not measured there).
     """
     if self.max_latency is None:
         return False
     model, count_input = self.get_model_input(desc)
     num = 100
     if zeus.is_torch_backend():
         start_time = time.time()
         for _ in range(num):
             model(count_input)
         latency = (time.time() - start_time) / num
     elif zeus.is_tf_backend():
         import tensorflow as tf
         input_ph = tf.placeholder(tf.float32,
                                   shape=count_input.get_shape().as_list())
         output = model(input_ph, training=False)
         with tf.compat.v1.Session() as sess:
             input_numpy = count_input.eval(session=sess)
             start_time = time.time()
             for _ in range(num):
                 sess.run(output, feed_dict={input_ph: input_numpy})
             latency = (time.time() - start_time) / num
     else:
         # `latency` used to be unbound here, raising UnboundLocalError below
         logging.warning('Latency is not measured for this backend, model is not filtered.')
         return False
     logging.info('Sampled model\'s latency: {}'.format(latency))
     return latency > self.max_latency
예제 #16
0
    def __init__(self, load_path=None):
        """Construct MobileNetV3Tiny class.

        Builds a stride-2 ``conv_bn_relu6`` stem (3 -> 9 channels) followed by
        one ``InvertedResidual`` per entry of ``self.inverted_residual_setting``.
        Each entry is read as ``[expand_ratio, out_channels, stride, ...]``.
        The layers are wrapped in an ``OutlistSequential`` with
        ``out_list=[3, 6, 13, 17]``.

        :param load_path: path for saved model. Its state dict is loaded
            non-strictly, and only for the torch backend.
        """
        super(MobileNetV3Tiny, self).__init__()
        channels = 9
        layers = [conv_bn_relu6(inchannel=3, outchannel=channels, kernel=3, stride=2)]
        for setting in self.inverted_residual_setting:
            next_channels = setting[1]
            layers.append(InvertedResidual(inp=channels,
                                           oup=next_channels,
                                           stride=setting[2],
                                           expand_ratio=setting[0]))
            channels = next_channels
        self.block = OutlistSequential(*layers, out_list=[3, 6, 13, 17])
        if load_path is None or not is_torch_backend():
            return
        import torch
        self.load_state_dict(torch.load(load_path), strict=False)
예제 #17
0
    def __call__(self, model=None, distributed=False):
        """Instantiate the configured optimizer for the active backend.

        :param model: model whose trainable parameters are optimized (torch
            and MindSpore cases)
        :param distributed: wrap the optimizer for distributed training:
            Horovod for torch; Horovod on GPU or ``NPUDistributedOptimizer``
            otherwise for TensorFlow
        :return: optimizer, or None if no known backend is active
        """
        params = self.map_config.get("params", {})
        optim_name = self.optim_cls.__name__
        logging.debug("Call Optimizer. name={}, params={}".format(optim_name, params))
        optimizer = None
        try:
            if zeus.is_torch_backend():
                trainable = [p for p in model.parameters() if p.requires_grad]
                optimizer = self.optim_cls(trainable, **params)
                if distributed:
                    optimizer = hvd.DistributedOptimizer(optimizer,
                                                         named_parameters=model.named_parameters(),
                                                         compression=hvd.Compression.none)
            elif zeus.is_tf_backend():
                optimizer = dynamic_optimizer(self.optim_cls, **params)
                if distributed:
                    if zeus.is_gpu_device():
                        optimizer = hvd.DistributedOptimizer(optimizer)
                    else:
                        optimizer = NPUDistributedOptimizer(optimizer)
            elif zeus.is_ms_backend():
                trainable = [p for p in model.trainable_params() if p.requires_grad]
                optimizer = self.optim_cls(trainable, **params)
            return optimizer
        except Exception:
            logging.error("Failed to call Optimizer name={}, params={}".format(optim_name, params))
            raise
예제 #18
0
    def _valid_epoch(self):
        """Run one validation pass, surrounded by the validation callbacks.

        torch: evaluates ``self.valid_loader`` batch by batch under
        ``torch.no_grad()`` and passes ``None`` to ``after_valid``.
        TensorFlow / MindSpore: evaluate through the backend API, feed the
        result into ``self.valid_metrics`` and pass
        ``{'cur_valid_perfs': self.valid_metrics.results}`` to ``after_valid``.
        """

        def _metrics_logs(eval_metrics):
            # shared by the TensorFlow and MindSpore paths
            self.valid_metrics.update(eval_metrics)
            return {'cur_valid_perfs': self.valid_metrics.results}

        self.callbacks.before_valid()
        valid_logs = None
        if zeus.is_torch_backend():
            self.model.eval()
            with torch.no_grad():
                for step, raw_batch in enumerate(self.valid_loader):
                    batch = self.make_batch(raw_batch)
                    self.callbacks.before_valid_step(step, {'valid_batch': batch})
                    self.callbacks.after_valid_step(step, self.valid_step(batch))
        elif zeus.is_tf_backend():
            valid_logs = _metrics_logs(self.estimator.evaluate(
                input_fn=self.valid_input_fn, steps=len(self.valid_loader)))
        elif zeus.is_ms_backend():
            valid_logs = _metrics_logs(self.ms_model.eval(
                valid_dataset=self.valid_loader,
                dataset_sink_mode=self.dataset_sink_mode))
        self.callbacks.after_valid(valid_logs)
예제 #19
0
 def _save_best_model(self):
     """Save best model.

     torch: saves the model's state dict to ``self.trainer.weights_file``.
     TensorFlow: copies the files of the latest checkpoint in the worker
     folder into ``<worker path>/model_<worker_id>/``, renamed to
     ``model_<worker_id>.<ext>``, together with the ``checkpoint`` file.
     MindSpore: moves every ``CKP*.ckpt`` file in the worker folder onto
     ``model_<worker_id>.ckpt``, oldest first, so the most recently modified
     checkpoint is the one that remains.
     """
     if zeus.is_torch_backend():
         torch.save(self.trainer.model.state_dict(),
                    self.trainer.weights_file)
     elif zeus.is_tf_backend():
         worker_path = self.trainer.get_local_worker_path()
         model_id = "model_{}".format(self.trainer.worker_id)
         weights_folder = FileOps.join_path(worker_path, model_id)
         FileOps.make_dir(weights_folder)
         checkpoint_file = tf.train.latest_checkpoint(worker_path)
         ckpt_globs = glob.glob("{}.*".format(checkpoint_file))
         for _file in ckpt_globs:
             dst_file = model_id + os.path.splitext(_file)[-1]
             FileOps.copy_file(_file,
                               FileOps.join_path(weights_folder, dst_file))
         FileOps.copy_file(FileOps.join_path(worker_path, 'checkpoint'),
                           weights_folder)
     elif zeus.is_ms_backend():
         worker_path = self.trainer.get_local_worker_path()
         save_path = os.path.join(
             worker_path, "model_{}.ckpt".format(self.trainer.worker_id))
         ckpt_files = [
             FileOps.join_path(worker_path, file_name)
             for file_name in os.listdir(worker_path)
             if file_name.startswith("CKP") and file_name.endswith(".ckpt")
         ]
         # os.listdir order is arbitrary, so the file left at save_path used to be
         # arbitrary too; move oldest first so the newest checkpoint wins.
         ckpt_files.sort(key=lambda path: (os.path.getmtime(path), path))
         for ckpt_file in ckpt_files:
             # NOTE(review): this sets weights_file on the callback, not on
             # self.trainer -- confirm this is intended.
             self.weights_file = ckpt_file
             # os.replace overwrites an existing save_path on every platform
             os.replace(ckpt_file, save_path)
예제 #20
0
 def after_valid(self, logs=None):
     """Be called after validation.

     Runs only when ``self.do_validation`` is set and ``self.valid_metrics``
     is not None. It then:

     - reads ``self.valid_metrics.results``, averaged with
       ``self.trainer._metric_average`` for distributed torch training;
     - drops ``loss`` and ``global_step`` and stores the rest as
       ``self.cur_valid_perfs``;
     - updates ``self.best_valid_perfs`` / ``self.best_valid_changed``;
     - writes ``cur_valid_perfs``, ``best_valid_perfs`` and
       ``best_valid_perfs_changed`` into ``logs``.

     :param logs: dict updated in place. When None, a local dict is used so
         the bookkeeping on ``self`` still happens.
     """
     if not self.do_validation or self.valid_metrics is None:
         return
     if logs is None:
         # `logs.update` on the default None used to raise AttributeError
         logs = {}
     # Get the summary of valid metrics
     metrics_results = self.valid_metrics.results
     if zeus.is_torch_backend() and self.trainer.distributed:
         for key, value in metrics_results.items():
             metrics_results[key] = self.trainer._metric_average(
                 value, key)
     metrics_results.pop('loss', None)
     metrics_results.pop('global_step', None)
     self.cur_valid_perfs = metrics_results
     # update best valid perfs based on current valid perfs
     if self.best_valid_perfs is None:
         self.best_valid_changed = True
         self.best_valid_perfs = self.cur_valid_perfs
     else:
         self.best_valid_changed = self._update_best_perfs(
             self.cur_valid_perfs, self.best_valid_perfs)
     logs.update({
         'cur_valid_perfs': self.cur_valid_perfs,
         'best_valid_perfs': self.best_valid_perfs,
         'best_valid_perfs_changed': self.best_valid_changed
     })
예제 #21
0
    def _train_loop(self):
        """Do the training with data, callbacks and step functions etc.

        ``before_train`` callbacks run first; users may build the trainer
        there when ``lazy_built`` is enabled. Training is skipped entirely when
        ``self.skip_train`` is set. For torch with
        ``use_unsupervised_pretrain``, the model is first pretrained with
        SimCLR. The epoch loop runs ``self.epochs`` times, except on MindSpore,
        where it stops at 1 because ``_train_epoch`` trains all epochs in one
        call there.
        """
        self.callbacks.before_train()
        if self.skip_train:
            return
        if self.use_unsupervised_pretrain and zeus.is_torch_backend():
            from .trainer.simclr.transforms import TransformsSimCLR
            from .trainer.simclr.train import simclr_train
            simclr_loader = self._init_dataloader(mode="train",
                                                  transforms=TransformsSimCLR())
            self.model = simclr_train(self.model, simclr_loader)
        last_epoch = 1 if zeus.is_ms_backend() else self.epochs
        for epoch in range(self._start_epoch, last_epoch):
            epoch_logs = {'train_num_batches': self.batch_num_train}
            if self.do_validation:
                epoch_logs['valid_num_batches'] = self.batch_num_valid
            self.callbacks.before_epoch(epoch, epoch_logs)
            self._train_epoch()
            if self.do_validation and self._should_run_validation(epoch):
                self._valid_epoch()
            self.callbacks.after_epoch(epoch)
        self.callbacks.after_train()
        if self.distributed:
            self._shutdown_distributed()
예제 #22
0
def get_named_modules(layer):
    """Get named modules.

    TensorFlow: a list of ``(op.name, op)`` pairs for the ops in ``layer``.
    torch: ``layer.named_modules()``.
    MindSpore: ``layer._children_scope_recursive()``.
    Returns None for any other backend.
    """
    dispatch = (
        (zeus.is_tf_backend, lambda: [(op.name, op) for op in layer]),
        (zeus.is_torch_backend, lambda: layer.named_modules()),
        (zeus.is_ms_backend, lambda: layer._children_scope_recursive()),
    )
    for backend_active, collect in dispatch:
        if backend_active():
            return collect()
    return None
예제 #23
0
 def _load_pretrained_model(cls, network, model, pretrained_model_file):
     """Load pretrained torch weights into ``model``.

     Only the torch backend is handled; for other backends ``model`` is
     returned unchanged. ``network`` is not used here.

     :param model: model to load the weights into
     :param pretrained_model_file: path of the saved state dict
     :return: ``model``
     :raises FileNotFoundError: torch backend only, if
         ``pretrained_model_file`` is not a file
     """
     if zeus.is_torch_backend():
         import torch
         if not os.path.isfile(pretrained_model_file):
             # `raise "<str>"` is itself a TypeError in Python 3; raise a real exception
             raise FileNotFoundError("Pretrained model is not existed, model={}".format(pretrained_model_file))
         logging.info("load model weights from file, weights file={}".format(pretrained_model_file))
         checkpoint = torch.load(pretrained_model_file)
         model.load_state_dict(checkpoint)
     return model
예제 #24
0
def get_shape(layer):
    """Get weight shape.

    TensorFlow: ``layer.get_shape()``.
    torch: shape of ``layer.weight.data``.
    MindSpore: shape of ``default_input`` of the parameter named by the first
    key of ``layer._params``.
    Returns None for any other backend.
    """
    if zeus.is_tf_backend():
        shape = layer.get_shape()
    elif zeus.is_torch_backend():
        shape = layer.weight.data.shape
    elif zeus.is_ms_backend():
        first_param_name = list(layer._params)[0]
        shape = getattr(layer, first_param_name).default_input.shape
    else:
        shape = None
    return shape
예제 #25
0
 def _set_default_funcs(self):
     """Install the default step functions for the active backend.

     torch: ``make_batch``, ``train_step`` and ``valid_step``.
     TensorFlow: ``model_fn``, ``train_input_fn`` and ``valid_input_fn``.
     Nothing is set for other backends.
     """
     if zeus.is_torch_backend():
         defaults = {
             'make_batch': self._default_make_batch,
             'train_step': self._default_train_step,
             'valid_step': self._default_valid_step,
         }
     elif zeus.is_tf_backend():
         defaults = {
             'model_fn': self._default_model_fn,
             'train_input_fn': self._default_train_input_fn,
             'valid_input_fn': self._default_valid_input_fn,
         }
     else:
         defaults = {}
     for attr_name, func in defaults.items():
         setattr(self, attr_name, func)
예제 #26
0
 def _init_horovod_setting(self):
     """Init horovod setting.

     By default this worker is the chief. For distributed torch training,
     model parameters and optimizer state are broadcast from rank 0, and only
     rank 0 stays chief.
     """
     self.is_chief = True
     if not (self.distributed and zeus.is_torch_backend()):
         return
     hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
     hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)
     self.is_chief = hvd.rank() == 0
예제 #27
0
 def _load_pretrained_model(self):
     """Load ``config.pretrained_model_file`` into ``self.model`` (torch only).

     Does nothing when there is no model, the backend is not torch, or no
     pretrained file is configured. The path is made absolute before
     ``torch.load``.
     """
     if self.model is None:
         return
     if not zeus.is_torch_backend() or self.config.pretrained_model_file is None:
         return
     weights_path = os.path.abspath(self.config.pretrained_model_file)
     self.model.load_state_dict(torch.load(weights_path))
예제 #28
0
def Adapter(dataset):
    """Adapter of dataset.

    Wraps ``dataset`` in the adapter of the active backend: ``TorchAdapter``,
    ``TfAdapter`` or ``MsAdapter``.

    :param dataset: dataset to adapt
    :return: the backend-specific adapter instance
    :raises ValueError: if none of torch, TensorFlow or MindSpore is active
    """
    # bind to a distinct local name instead of shadowing this function's own name
    if zeus.is_torch_backend():
        from .pytorch import TorchAdapter as adapter_cls
    elif zeus.is_tf_backend():
        from .tensorflow import TfAdapter as adapter_cls
    elif zeus.is_ms_backend():
        from .mindspore import MsAdapter as adapter_cls
    else:
        raise ValueError("No dataset adapter for the current backend; "
                         "expected torch, tensorflow or mindspore.")
    return adapter_cls(dataset)
예제 #29
0
 def filter_in_channels(self, mask_code):
     """Mask in channels.

     Sets ``self.layer.in_channels`` to ``sum(mask_code)`` and slices every
     4-D weight along its input-channel axis with the indices from
     ``self._make_mask``: axis 1 for torch, axis 2 for other backends.
     Weights that are None or not 4-D (e.g. a 1-D bias) are left unchanged.

     :param mask_code: per-input-channel mask; its sum is the kept channel count
     """
     filter_idx = self._make_mask(mask_code)
     weights = self.layer.get_weights()
     self.layer.in_channels = sum(mask_code)
     for name, weight in weights.items():
         # a 1-D bias indexed with four indices used to raise IndexError
         if weight is None or len(weight.shape) != 4:
             continue
         if is_torch_backend():
             self.layer.set_weights(name, weight[:, filter_idx, :, :])
         else:
             self.layer.set_weights(name, weight[:, :, filter_idx, :])
예제 #30
0
 def _init_cuda_setting(self):
     """Init CUDA setting.

     torch only. If ``config.cuda`` is falsy, ``config.device`` is set to -1.
     Otherwise:

     - ``config.device`` becomes ``config.cuda`` (or 0 when it is exactly
       ``True``);
     - ``self.use_cuda`` is enabled;
     - for distributed runs, the current CUDA device is set to the local rank;
     - the CUDA RNG is seeded with ``config.seed``.
     """
     if not zeus.is_torch_backend():
         return
     if self.config.cuda:
         self.config.device = 0 if self.config.cuda is True else self.config.cuda
         self.use_cuda = True
         if self.distributed:
             torch.cuda.set_device(self._local_rank_id)
         torch.cuda.manual_seed(self.config.seed)
     else:
         self.config.device = -1