Code example #1
File: train.py Project: bcmi220/multilingual_srl
    def __init__(self, model, **kwargs):
        super(ClassifyTrainerPyTorch, self).__init__()
        self.clip = float(kwargs.get('clip', 5))
        self.labels = model.labels
        self.optimizer = OptimizerManager(model, **kwargs)
        self.crit = model.create_loss().cuda()
        self.model = torch.nn.DataParallel(model).cuda()
        self.nsteps = kwargs.get('nsteps', six.MAXSIZE)
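Each of these trainers reads a clip keyword into self.clip. A value like this is typically applied as a global gradient-norm limit during each optimization step; the sketch below shows that pattern in isolation and is not part of the original code (the linear model, loss, and data are made-up stand-ins).

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                  # hypothetical stand-in model
crit = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clip = 5.0                                # same default as kwargs.get('clip', 5)

x = torch.randn(4, 10)
y = torch.randint(0, 2, (4,))

optimizer.zero_grad()
loss = crit(model(x), y)
loss.backward()
# Cap the global gradient norm before stepping, as a trainer would do with self.clip.
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()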
Code example #2
File: train.py Project: tanthml/baseline
    def __init__(self, model, **kwargs):
        super(LanguageModelTrainerPyTorch, self).__init__()
        self.model = model
        self.clip = float(kwargs.get('clip', 5))
        self.gpu = not bool(kwargs.get('nogpu', False))
        self.crit = model.create_loss()

        if self.gpu:
            self.model = self.model.cuda()
            self.crit.cuda()
        self.nsteps = kwargs.get('nsteps', 500)

        self.optimizer = OptimizerManager(self.model, **kwargs)
Code example #3
File: train.py Project: bjayakumar/mead-baseline
    def __init__(self, model, **kwargs):
        super(Seq2SeqTrainerPyTorch, self).__init__()
        self.gpu = bool(kwargs.get('gpu', True))
        self.clip = float(kwargs.get('clip', 5))
        self.model = model
        self.optimizer = OptimizerManager(self.model, **kwargs)
        self._input = model.make_input
        self._predict = model.predict
        self.crit = model.create_loss()
        self.tgt_rlut = kwargs['tgt_rlut']
        if self.gpu:
            self.model = torch.nn.DataParallel(model).cuda()
            self.crit.cuda()
        self.nsteps = kwargs.get('nsteps', 500)
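Examples #1, #3, and #5 put the model on the GPU through torch.nn.DataParallel(model).cuda(), which replicates the module across all visible GPUs and scatters each batch along its first dimension. A minimal sketch of that wrapping under the same kind of flag (the toy module is an assumption for illustration, and the code falls back to CPU when no GPU is available):

import torch
import torch.nn as nn

model = nn.Linear(8, 3)          # toy stand-in for the real model
gpu = torch.cuda.is_available()  # plays the role of the kwargs-driven self.gpu flag

if gpu:
    # Replicate over all visible GPUs; inputs are scattered on dim 0 at call time.
    model = torch.nn.DataParallel(model).cuda()

x = torch.randn(16, 8)
if gpu:
    x = x.cuda()
out = model(x)                   # shape (16, 3) regardless of how many GPUs are used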
Code example #4
File: train.py Project: byfaith/baseline
    def __init__(self, model, **kwargs):
        super(TaggerTrainerPyTorch, self).__init__()
        self.gpu = not bool(kwargs.get('nogpu', False))
        # By default support IOB1/IOB2
        self.span_type = kwargs.get('span_type', 'iob')
        self.verbose = kwargs.get('verbose', False)

        logger.info('Setting span type %s', self.span_type)
        self.model = model
        self.idx2label = revlut(self.model.labels)
        self.clip = float(kwargs.get('clip', 5))
        self.optimizer = OptimizerManager(self.model, **kwargs)
        if self.gpu:
            self.model = model.to_gpu()
        self.nsteps = kwargs.get('nsteps', six.MAXSIZE)
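The tagger trainers build self.idx2label with revlut(self.model.labels). Assuming labels maps each label string to an integer index, revlut presumably just inverts that mapping; a plain-Python equivalent under that assumption:

def revlut(lut):
    # Invert a name -> index lookup table into index -> name.
    return {v: k for k, v in lut.items()}

labels = {'O': 0, 'B-PER': 1, 'I-PER': 2}   # hypothetical label set
idx2label = revlut(labels)
assert idx2label[1] == 'B-PER'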
Code example #5
    def __init__(self, model, **kwargs):
        super(ClassifyTrainerPyTorch, self).__init__()
        self.clip = float(kwargs.get('clip', 5))
        self.labels = model.labels
        self.gpus = int(kwargs.get('gpus', 1))
        if self.gpus == -1:
            self.gpus = len(
                os.getenv('CUDA_VISIBLE_DEVICES', os.getenv('NV_GPU',
                                                            '0')).split(','))

        self.optimizer = OptimizerManager(model, **kwargs)
        self.model = model
        if self.gpus > 0:
            self.crit = model.create_loss().cuda()
            if self.gpus > 1:
                self.model = torch.nn.DataParallel(model).cuda()
            else:
                self.model.cuda()
        else:
            logger.warning("Requested training on CPU.  This will be slow.")
            self.crit = model.create_loss()
            self.model = model
        self.nsteps = kwargs.get('nsteps', six.MAXSIZE)
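In example #5, gpus=-1 means "use every visible device": the count is recovered by splitting CUDA_VISIBLE_DEVICES (falling back to NV_GPU, then '0') on commas. That branch stands alone as plain Python:

import os

def visible_gpu_count():
    # Mirrors the gpus == -1 branch above: count comma-separated device ids,
    # preferring CUDA_VISIBLE_DEVICES, then NV_GPU, then a single default id.
    return len(os.getenv('CUDA_VISIBLE_DEVICES', os.getenv('NV_GPU', '0')).split(','))

# e.g. CUDA_VISIBLE_DEVICES="0,2,3" -> 3; neither variable set -> 1
print(visible_gpu_count())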
Code example #6
    def __init__(self, model, **kwargs):
        super(TaggerTrainerPyTorch, self).__init__()
        self.gpus = int(kwargs.get('gpus', 1))
        # By default support IOB1/IOB2
        self.span_type = kwargs.get('span_type', 'iob')
        self.verbose = kwargs.get('verbose', False)

        logger.info('Setting span type %s', self.span_type)
        self.model = model
        self.idx2label = revlut(self.model.labels)
        self.clip = float(kwargs.get('clip', 5))
        self.optimizer = OptimizerManager(self.model, **kwargs)
        if self.gpus > 1:
            logger.info(
                "Trainer for PyTorch tagger currently doesnt support multiple GPUs.  Setting to 1"
            )
            self.gpus = 1
        if self.gpus > 0:
            self.model = model.to_gpu()
        else:
            logger.warning("Requested training on CPU.  This will be slow.")

        self.nsteps = kwargs.get('nsteps', six.MAXSIZE)