Пример #1
0
 def __init__(self, opt):
     """Build the agent and world and initialize training bookkeeping.

     Args:
         opt: option dictionary. Passing a ``ParlaiParser`` is deprecated
             but still accepted; it is parsed into options first. ``opt``
             is modified in place (``init_model`` and ``dict_file`` may be
             set) and is stored as ``self.opt``.

     Non-positive values of ``num_epochs``, ``max_train_time``,
     ``log_every_n_secs``, ``validation_every_n_secs``,
     ``save_every_n_secs`` and ``validation_every_n_epochs`` are treated
     as "no limit" and stored as ``float('inf')``.
     """
     if isinstance(opt, ParlaiParser):
         print(
             '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
         )
         opt = opt.parse_args()

     def positive_or_inf(key):
         # A non-positive option value disables the corresponding limit.
         value = opt[key]
         return value if value > 0 else float('inf')

     # Possibly load from checkpoint
     if opt['load_from_checkpoint'] and opt.get('model_file'):
         checkpoint_path = opt['model_file'] + '.checkpoint'
         if os.path.isfile(checkpoint_path):
             opt['init_model'] = checkpoint_path

     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         # If data built via pytorch data teacher, we need to load prebuilt dict
         if opt.get('pytorch_teacher_task'):
             opt['dict_file'] = get_pyt_dict_file(opt)
         elif opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict(opt, skip_if_built=True)

     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)

     # One timer per periodic activity.
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     self.save_time = Timer()
     print('[ training... ]')

     self.parleys = 0
     self.max_num_epochs = positive_or_inf('num_epochs')
     self.max_train_time = positive_or_inf('max_train_time')
     self.log_every_n_secs = positive_or_inf('log_every_n_secs')
     self.val_every_n_secs = positive_or_inf('validation_every_n_secs')
     self.save_every_n_secs = positive_or_inf('save_every_n_secs')
     self.val_every_n_epochs = positive_or_inf('validation_every_n_epochs')

     self.last_valid_epoch = 0
     # +1 when the validation metric is maximized, -1 otherwise.
     self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1

     # Restore the best validation score written next to the model file.
     self.best_valid = None
     if opt.get('model_file'):
         best_valid_path = opt['model_file'] + '.best_valid'
         if os.path.isfile(best_valid_path):
             # The context manager closes the file; no explicit close needed.
             with open(best_valid_path, 'r') as f:
                 self.best_valid = float(f.readline())

     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt

     # self.writer only exists when tensorboard logging is explicitly True.
     if opt['tensorboard_log'] is True:
         self.writer = TensorboardLogger(opt)
Пример #2
0
    def __init__(self, opt):
        """Build the agent and world and initialize training bookkeeping.

        Also restores training statistics (total epochs, train time,
        impatience) from a ``.trainstats`` file next to the model file when
        one exists, so a preempted run continues its counters.

        Args:
            opt: option dictionary. Passing a ``ParlaiParser`` is deprecated
                but still accepted; it is parsed into options first. ``opt``
                is modified in place (``init_model``, ``dict_file`` and
                ``validation_metric_mode`` may be set) and is stored as
                ``self.opt``.

        Non-positive values of ``num_epochs``, ``max_train_time``,
        ``log_every_n_secs``, ``validation_every_n_secs``,
        ``save_every_n_secs`` and ``validation_every_n_epochs`` are treated
        as "no limit" and stored as ``float('inf')``.
        """
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()

        def positive_or_inf(key):
            # A non-positive option value disables the corresponding limit.
            value = opt[key]
            return value if value > 0 else float('inf')

        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if opt['load_from_checkpoint'] and opt.get('model_file'):
            checkpoint_path = opt['model_file'] + '.checkpoint'
            if os.path.isfile(checkpoint_path):
                opt['init_model'] = checkpoint_path
                # Statistics saved with the checkpoint take precedence.
                trainstats_suffix = '.checkpoint.trainstats'

        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # specify model such as seq2seq
        self.world = create_task(opt, self.agent)  # batch world or other world

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')

        self.parleys = 0
        self.max_num_epochs = positive_or_inf('num_epochs')
        self.max_train_time = positive_or_inf('max_train_time')
        self.log_every_n_secs = positive_or_inf('log_every_n_secs')
        self.val_every_n_secs = positive_or_inf('validation_every_n_secs')
        self.save_every_n_secs = positive_or_inf('save_every_n_secs')
        self.val_every_n_epochs = positive_or_inf('validation_every_n_epochs')

        # smart defaults for --validation-metric-mode: a recognized metric
        # overrides whatever mode was given; otherwise fall back to 'max'
        # only when no mode was set.
        metric = opt['validation_metric']
        if metric in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif metric in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        # +1 when the validation metric is maximized, -1 otherwise.
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1

        # Restore the best validation score written next to the model file.
        self.best_valid = None
        if opt.get('model_file'):
            best_valid_path = opt['model_file'] + '.best_valid'
            if os.path.isfile(best_valid_path):
                # The context manager closes the file; no explicit close needed.
                with open(best_valid_path, 'r') as f:
                    self.best_valid = float(f.readline())

        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file'):
            trainstats_path = opt['model_file'] + trainstats_suffix
            if os.path.isfile(trainstats_path):
                # looks like we were preempted. make sure we load up our total
                # training stats, etc
                with open(trainstats_path) as ts:
                    stats = json.load(ts)
                self._preempted_epochs = stats.get('total_epochs', 0)
                self.train_time.total = stats.get('train_time', 0)
                self.impatience = stats.get('impatience', 0)

        # self.writer only exists when tensorboard logging is explicitly True.
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
Пример #3
0
    def __init__(self, opt, shared=None):
        """Initialize DictionaryAgent.

        Reads the ``dict_*`` options (falling back to the class-level
        defaults), resolves the tokenizer method, then either reuses the
        tables passed in ``shared`` or builds fresh ones, registering the
        special tokens and loading a dictionary file when one is available.

        Args:
            opt: option dictionary. A deep copy is kept as ``self.opt``, but
                ``opt`` itself is also updated in place (``dict_file`` and
                ``dict_initpath`` may be rewritten).
            shared: optional mapping providing ``freq``, ``tok2ind`` and
                ``ind2tok``; when truthy, nothing is loaded from disk.

        Raises:
            AttributeError: no ``<tokenizer>_tokenize`` attribute exists.
            ImportError: the nltk or spacy tokenizer is requested but the
                package cannot be imported.
            RuntimeError: the bpe tokenizer is requested without a dict file.
            ValueError: the gpt2 tokenizer is combined with lowercasing or
                vocabulary filtering.
        """
        defaults = DictionaryAgent
        self.opt = copy.deepcopy(opt)
        self.minfreq = opt.get('dict_minfreq', defaults.default_minfreq)
        self.null_token = opt.get('dict_nulltoken', defaults.default_null)
        self.end_token = opt.get('dict_endtoken', defaults.default_end)
        self.unk_token = opt.get('dict_unktoken', defaults.default_unk)
        self.start_token = opt.get('dict_starttoken', defaults.default_start)
        self.max_ngram_size = opt.get('dict_max_ngram_size',
                                      defaults.default_maxngram)
        self.tokenizer = opt.get('dict_tokenizer', defaults.default_tok)
        self.lower = opt.get('dict_lower', defaults.default_lower)
        self.maxtokens = opt.get('dict_maxtokens', defaults.default_maxtokens)
        textfields = opt.get('dict_textfields', defaults.default_textfields)
        self.textfields = textfields.split(",")

        # The tokenizer name selects the bound method "<name>_tokenize".
        try:
            self.tokenizer_fun = getattr(self, self.tokenizer + '_tokenize')
        except AttributeError:
            raise AttributeError('tokenizer type {} not yet supported'.format(
                self.tokenizer))

        if not shared:
            self.freq = defaultdict(int)
            self.tok2ind = {}
            self.ind2tok = {}

            # Register the special tokens in a fixed order:
            # null, start of sentence, end of sentence, unknown word.
            for special in (self.null_token, self.start_token,
                            self.end_token, self.unk_token):
                if special:
                    self.add_token(special)

            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                from parlai.scripts.build_pytorch_data import get_pyt_dict_file

                opt['dict_file'] = get_pyt_dict_file(opt)

            loaded = False
            if opt.get('dict_file'):
                opt['dict_file'] = modelzoo_path(opt.get('datapath'),
                                                 opt['dict_file'])
                dict_path = opt['dict_file']
                if os.path.isfile(dict_path):
                    # load pre-existing dictionary
                    self.load(dict_path)
                    loaded = True

            if not loaded and opt.get('dict_initpath'):
                # load seed dictionary
                opt['dict_initpath'] = modelzoo_path(opt.get('datapath'),
                                                     opt['dict_initpath'])
                # don't check isfile first, should fail if file not found
                self.load(opt['dict_initpath'])
        else:
            # Reuse the tables handed over by another instance.
            self.freq = shared.get('freq', {})
            self.tok2ind = shared.get('tok2ind', {})
            self.ind2tok = shared.get('ind2tok', {})

        # initialize tokenizers
        if self.tokenizer == 'nltk':
            try:
                import nltk
            except ImportError:
                raise ImportError('Please install nltk (pip install nltk)')
            # nltk-specific setup
            punkt_path = 'tokenizers/punkt/{0}.pickle'.format(
                opt['dict_language'])
            try:
                self.sent_tok = nltk.data.load(punkt_path)
            except LookupError:
                # Fetch the punkt data once, then retry the load.
                nltk.download('punkt')
                self.sent_tok = nltk.data.load(punkt_path)
            self.word_tok = nltk.tokenize.treebank.TreebankWordTokenizer()
        elif self.tokenizer == 'spacy':
            try:
                import spacy
            except ImportError:
                raise ImportError('Please install spacy and spacy "en" model: '
                                  '`pip install -U spacy && '
                                  'python -m spacy download en` '
                                  'or find alternative installation options '
                                  'at spacy.io')
            self.NLP = spacy.load('en')
        elif self.tokenizer == 'bpe':
            if not opt.get('dict_file'):
                raise RuntimeError('--dict-file is mandatory.')
            codecs_path = opt.get('dict_file') + '.codecs'
            self.bpehelper = _BPEHelper(codecs_path)
        elif self.tokenizer == 'gpt2':
            if self.lower:
                raise ValueError(
                    'Only use --dict-lower false with --dict-tokenizer gpt2')
            if self.maxtokens > 0 or self.minfreq > 0:
                raise ValueError(
                    'You should not filter vocabulary with using --dict-tokenizer gpt2'
                    ' (no --dict-minfreq or --dict-maxtokens).')

            self.gpt2_bpe = Gpt2BpeHelper(opt)
            # Every token reported by the helper is added with a count of one.
            for bpe_token in self.gpt2_bpe.list_tokens():
                self.add_token(bpe_token)
                self.freq[bpe_token] = 1

        if not shared:
            # Pin the special-token counts to fixed values just above one
            # billion: null (+3), start (+2), end (+1), unknown (+0).
            pinned_counts = (
                (self.null_token, 1000000003),
                (self.start_token, 1000000002),
                (self.end_token, 1000000001),
                (self.unk_token, 1000000000),
            )
            for special, count in pinned_counts:
                if special:
                    self.freq[special] = count

            if opt.get('dict_file'):
                self.save_path = opt['dict_file']
Пример #4
0
    def __init__(self, opt):
        """Build the active/static agent pair, the world and training state.

        The created agent is deep-copied; the copy is tagged ``STATIC`` and
        frozen via ``freeze_agent`` while the original is tagged ``ACTIVE``.
        Both are handed to ``create_task``. Training statistics are restored
        from a ``.trainstats`` file next to the model file when one exists.

        Args:
            opt: option dictionary, or a ``ParlaiParser`` which is parsed
                into options first. ``opt`` is modified in place
                (``init_model``, ``dict_file`` and ``validation_metric_mode``
                may be set) and is stored as ``self.opt``.

        Raises:
            RuntimeError: ``dict_build_first`` is set but neither
                ``dict_file`` nor ``model_file`` is given.

        Non-positive values of ``num_epochs``, ``max_train_time``,
        ``log_every_n_secs``, ``validation_every_n_secs``,
        ``save_every_n_secs`` and ``validation_every_n_epochs`` are treated
        as "no limit" and stored as ``float('inf')``.
        """
        # Restore the default SIGINT handler so KeyboardInterrupt is raised.
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()

        def positive_or_inf(key):
            # A non-positive option value disables the corresponding limit.
            value = opt[key]
            return value if value > 0 else float('inf')

        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'
        if opt.get('model_file'):
            checkpoint_path = opt['model_file'] + '.checkpoint'
            if isfile(checkpoint_path):
                opt['init_model'] = checkpoint_path
                trainstats_suffix = '.checkpoint.trainstats'
        # TODO: a missing checkpoint is tolerated for testing only. The
        # intent is to raise a RuntimeError here, since reinforcement
        # learning must be initialized by a model.checkpoint file.

        # Possibly build a dictionary (not all models do this).
        if (opt['dict_build_first']
                and not (opt.get('dict_file') or opt.get('model_file'))):
            raise RuntimeError('WARNING: For train_model, '
                               'please specify either a '
                               'model_file or dict_file.')

        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)

        # Freeze the model for the static dialogue partner
        static_agent = copy.deepcopy(self.agent)
        self.agent.id = ACTIVE

        static_agent.id = STATIC
        freeze_agent(static_agent)

        self.world = create_task(opt, self.agent, static_agent)

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')

        self.parleys = 0
        self.max_num_epochs = positive_or_inf('num_epochs')
        self.max_train_time = positive_or_inf('max_train_time')
        self.log_every_n_secs = positive_or_inf('log_every_n_secs')
        self.val_every_n_secs = positive_or_inf('validation_every_n_secs')
        self.save_every_n_secs = positive_or_inf('save_every_n_secs')
        self.val_every_n_epochs = positive_or_inf('validation_every_n_epochs')

        # smart defaults for --validation-metric-mode: a recognized metric
        # overrides whatever mode was given; otherwise fall back to 'max'
        # only when no mode was set.
        metric = opt['validation_metric']
        if metric in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif metric in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        # +1 when the validation metric is maximized, -1 otherwise.
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []

        # Restore the best validation score written next to the model file.
        self.best_valid = None
        if opt.get('model_file'):
            best_valid_path = opt['model_file'] + '.best_valid'
            if isfile(best_valid_path):
                # The context manager closes the file; no explicit close needed.
                with open(best_valid_path, 'r') as f:
                    self.best_valid = float(f.readline())

        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file'):
            trainstats_path = opt['model_file'] + trainstats_suffix
            if isfile(trainstats_path):
                # looks like we were preempted. make sure we load up our total
                # training stats, etc
                with open(trainstats_path) as ts:
                    stats = json.load(ts)
                self._preempted_epochs = stats.get('total_epochs', 0)
                self.train_time.total = stats.get('train_time', 0)
                self.impatience = stats.get('impatience', 0)
                self.valid_reports = stats.get('valid_reports', [])

        # self.writer only exists when tensorboard logging is explicitly True.
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)