Example #1
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
        # TODO: make this optional
        if 'test_file' in self.dataset:
            vocab_sources.append(self.dataset['test_file'])
        vocab1, vocab2 = self.reader.build_vocabs(vocab_sources,
                                                  min_f=Task._get_min_f(self.config_params),
                                                  vocab_file=self.dataset.get('vocab_file'))

        # To keep the config file simple, share a list between source and destination (tgt)
        features_src = []
        features_tgt = None
        for feature in self.config_params['features']:
            if feature['name'] == 'tgt':
                features_tgt = feature
            else:
                features_src += [feature]

        self.src_embeddings, self.feat2src = self._create_embeddings(embeddings_set, vocab1, features_src)
        # For now, don't allow multiple output vocabs
        baseline.save_vocabs(self.get_basedir(), self.feat2src)
        self.tgt_embeddings, self.feat2tgt = self._create_embeddings(embeddings_set, {'tgt': vocab2}, [features_tgt])
        baseline.save_vocabs(self.get_basedir(), self.feat2tgt)
        self.tgt_embeddings = self.tgt_embeddings['tgt']
        self.feat2tgt = self.feat2tgt['tgt']
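
The loop above shares a single `features` list from the mead config between encoder and decoder by peeling off the entry named 'tgt'. Below is a minimal, self-contained sketch of that split; the concrete feature dictionaries are hypothetical and only illustrate the shape the loop expects.

# Hypothetical feature entries; only the 'name' key matters for the split above.
config_features = [
    {'name': 'word', 'embeddings': {'label': 'glove-42B'}},
    {'name': 'tgt', 'embeddings': {'dsz': 512}},
]

features_src = [f for f in config_features if f['name'] != 'tgt']
features_tgt = next((f for f in config_features if f['name'] == 'tgt'), None)

assert [f['name'] for f in features_src] == ['word']
assert features_tgt['name'] == 'tgt'
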
Example #2
    def initialize(self, embeddings):
        embeddings_set = mead.utils.index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache, True).download()
        print("[train file]: {}\n[valid file]: {}\n[test file]: {}\n[vocab file]: {}".format(
            self.dataset['train_file'], self.dataset['valid_file'],
            self.dataset['test_file'], self.dataset.get('vocab_file', "None")))
        vocab_file = self.dataset.get('vocab_file', None)
        if vocab_file is not None:
            vocab1, vocab2 = self.reader.build_vocabs([vocab_file])
        else:
            vocab1, vocab2 = self.reader.build_vocabs(
                [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']])
        self.embeddings1, self.feat2index1 = self._create_embeddings(embeddings_set, {'word': vocab1})
        self.embeddings2, self.feat2index2 = self._create_embeddings(embeddings_set, {'word': vocab2})
Example #3
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocabs = self.reader.build_vocab(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file')
        )
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)
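
Every initialize above starts by turning the embeddings config into a lookup table with index_by_label. Here is a minimal sketch of that helper, assuming the embeddings config is a list of descriptor dicts each carrying a 'label' key; the sample entry is hypothetical.

def index_by_label(embeddings_list):
    # Key each embedding descriptor by its 'label' so a task can look it up by name.
    return {entry['label']: entry for entry in embeddings_list}

embeddings = [{'label': 'glove-6B-100', 'file': 'glove.6B.100d.txt', 'dsz': 100}]  # hypothetical entry
embeddings_set = index_by_label(embeddings)
assert embeddings_set['glove-6B-100']['dsz'] == 100
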
Example #4
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab1, vocab2 = self.reader.build_vocabs(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file')
        )

        # To keep the config file simple, share a list between source and destination (tgt)
        features_src = []
        features_tgt = None
        for feature in self.config_params['features']:
            if feature['name'] == 'tgt':
                features_tgt = feature
            else:
                features_src += [feature]

        self.src_embeddings, self.feat2src = self._create_embeddings(embeddings_set, vocab1, features_src)
        # For now, don't allow multiple output vocabs
        baseline.save_vocabs(self.get_basedir(), self.feat2src)
        self.tgt_embeddings, self.feat2tgt = self._create_embeddings(embeddings_set, {'tgt': vocab2}, [features_tgt])
        baseline.save_vocabs(self.get_basedir(), self.feat2tgt)
        self.tgt_embeddings = self.tgt_embeddings['tgt']
        self.feat2tgt = self.feat2tgt['tgt']
Example #5
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)

        vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
        # TODO: make this optional
        if 'test_file' in self.dataset:
            vocab_sources.append(self.dataset['test_file'])

        vocab, self.labels = self.reader.build_vocab(vocab_sources,
                                                     min_f=Task._get_min_f(self.config_params),
                                                     vocab_file=self.dataset.get('vocab_file'),
                                                     label_file=self.dataset.get('label_file'))
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocab, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)
Example #6
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocabs = self.reader.build_vocab(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file')
        )
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)
Example #7
    def initialize(self, embeddings):
        embeddings_set = mead.utils.index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print("[train file]: {}\n[valid file]: {}\n[test file]: {}".format(
            self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']))
        vocab, self.num_words = self.reader.build_vocab(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']])
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocab)
Example #8
    def initialize(self, embeddings):
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print("[train file]: {}\n[valid file]: {}\n[test file]: {}".format(
            self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']))
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        vocabs = self.reader.build_vocab(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']])
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs)
Example #9
class LanguageModelingTask(Task):
    def __init__(self, logging_config, mead_settings_config, **kwargs):
        super(LanguageModelingTask,
              self).__init__(logging_config, mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'lm'

    def _create_task_specific_reader(self):
        self._create_vectorizers()

        reader_params = self.config_params['loader']
        reader_params['nctx'] = reader_params.get(
            'nctx',
            self.config_params.get('nctx', self.config_params.get('nbptt',
                                                                  35)))
        reader_params['clean_fn'] = reader_params.get(
            'clean_fn',
            self.config_params.get('preproc', {}).get('clean_fn'))
        reader_params['mxlen'] = self.vectorizers[self.primary_key].mxlen
        if self.config_params['model'].get('gpus', 1) > 1:
            reader_params['truncate'] = True
        return baseline.reader.create_reader(
            self.task_name(), self.vectorizers,
            self.config_params['preproc'].get('trim', False), **reader_params)

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))

        if backend.name == 'pytorch':
            self.config_params.get('preproc', {})['trim'] = True

        elif backend.name == 'dy':
            self.config_params.get('preproc', {})['trim'] = True
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                dy_params.set_autobatch(True)
                batched = False
            else:
                batched = True
            dy_params.init()
            backend.params = {
                'pc': _dynet.ParameterCollection(),
                'batched': batched
            }

        backend.load(self.task_name())
        return backend

    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset,
                                      self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocabs = self.reader.build_vocab(
            [
                self.dataset['train_file'], self.dataset['valid_file'],
                self.dataset['test_file']
            ],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file'))
        self.embeddings, self.feat2index = self._create_embeddings(
            embeddings_set, vocabs, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)

    def _load_dataset(self):
        tgt_key = self.config_params['loader'].get('tgt_key', self.primary_key)
        self.train_data = self.reader.load(self.dataset['train_file'],
                                           self.feat2index,
                                           self.config_params['batchsz'],
                                           tgt_key=tgt_key)
        self.valid_data = self.reader.load(self.dataset['valid_file'],
                                           self.feat2index,
                                           self.config_params.get(
                                               'valid_batchsz',
                                               self.config_params['batchsz']),
                                           tgt_key=tgt_key)
        self.test_data = self.reader.load(self.dataset['test_file'],
                                          self.feat2index,
                                          1,
                                          tgt_key=tgt_key)

    def _create_model(self):

        model = self.config_params['model']
        unif = self.config_params.get('unif', 0.1)
        model['unif'] = model.get('unif', unif)
        model['batchsz'] = self.config_params['batchsz']
        model['tgt_key'] = self.config_params['loader'].get(
            'tgt_key', self.primary_key)
        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_lang_model(self.embeddings, **model)

    def train(self):
        self._load_dataset()
        if self.config_params['train'].get('lr_scheduler_type',
                                           None) == 'zaremba':
            first_range = int(
                self.config_params['train']['start_decay_epoch'] *
                self.train_data.steps)
            self.config_params['train']['bounds'] = [first_range] + list(
                np.arange(self.config_params['train']['start_decay_epoch'] + 1,
                          self.config_params['train']['epochs'] + 1,
                          dtype=np.int32) * self.train_data.steps)
        baseline.save_vectorizers(self.get_basedir(), self.vectorizers)
        model = self._create_model()
        baseline.train.fit(model, self.train_data, self.valid_data,
                           self.test_data, **self.config_params['train'])
        baseline.zip_files(self.get_basedir())
        self._close_reporting_hooks()

    @staticmethod
    def _num_steps_per_epoch(num_examples, nctx, batchsz):
        rest = num_examples // batchsz
        return rest // nctx
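
The 'zaremba' branch in train() above converts the epoch at which decay starts into a list of global-step boundaries. Here is a small worked sketch with hypothetical numbers: decay starting at epoch 6 of 10, and 1000 steps per epoch standing in for self.train_data.steps.

import numpy as np

start_decay_epoch = 6      # hypothetical config value
epochs = 10                # hypothetical config value
steps_per_epoch = 1000     # stands in for self.train_data.steps

first_range = int(start_decay_epoch * steps_per_epoch)
bounds = [first_range] + list(
    np.arange(start_decay_epoch + 1, epochs + 1, dtype=np.int32) * steps_per_epoch)
# bounds == [6000, 7000, 8000, 9000, 10000]: the scheduler drops the learning
# rate at each of these global steps once decay begins.
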
Example #10
class ClassifierTask(Task):

    def __init__(self, mead_settings_config, **kwargs):
        super(ClassifierTask, self).__init__(mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'classify'

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))
        if backend.name == 'dy':
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                self.config_params['train']['trainer_type'] = 'autobatch'
                dy_params.set_autobatch(True)
                batched = False
            else:
                batched = True
            dy_params.init()
            backend.params = {'pc': _dynet.ParameterCollection(), 'batched': batched}

        backend.load(self.task_name())

        return backend

    def _setup_task(self, **kwargs):
        super(ClassifierTask, self)._setup_task(**kwargs)
        if self.config_params.get('preproc', {}).get('clean', False) is True:
            self.config_params.get('preproc', {})['clean_fn'] = baseline.TSVSeqLabelReader.do_clean
            logger.info('Clean')
        else:
            self.config_params.setdefault('preproc', {})
            self.config_params['preproc']['clean_fn'] = None

    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)

        vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
        # TODO: make this optional
        if 'test_file' in self.dataset:
            vocab_sources.append(self.dataset['test_file'])

        vocab, self.labels = self.reader.build_vocab(vocab_sources,
                                                     min_f=Task._get_min_f(self.config_params),
                                                     vocab_file=self.dataset.get('vocab_file'),
                                                     label_file=self.dataset.get('label_file'))
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocab, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)

    def _create_model(self):
        unif = self.config_params.get('unif', 0.1)
        model = self.config_params['model']
        model['unif'] = model.get('unif', unif)
        lengths_key = model.get('lengths_key', self.primary_key)
        if lengths_key is not None:
            if not lengths_key.endswith('_lengths'):
                lengths_key = '{}_lengths'.format(lengths_key)
            model['lengths_key'] = lengths_key
        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_model(self.embeddings, self.labels, **model)

    def _load_dataset(self):
        read = self.config_params['reader'] if 'reader' in self.config_params else self.config_params['loader']
        sort_key = read.get('sort_key')
        bsz, vbsz, tbsz = Task._get_batchsz(self.config_params)
        self.train_data = self.reader.load(
            self.dataset['train_file'],
            self.feat2index,
            bsz,
            shuffle=True,
            sort_key=sort_key,
        )
        self.valid_data = self.reader.load(
            self.dataset['valid_file'],
            self.feat2index,
            vbsz,
        )
        self.test_data = None
        if 'test_file' in self.dataset:
            self.test_data = self.reader.load(
                self.dataset['test_file'],
                self.feat2index,
                tbsz,
            )
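
_create_model above (and its siblings in the other tasks) normalizes lengths_key by appending a '_lengths' suffix when it is missing. A tiny standalone sketch of that rule, using hypothetical keys:

def normalize_lengths_key(lengths_key):
    # Append the '_lengths' suffix unless the key already carries it.
    if lengths_key is not None and not lengths_key.endswith('_lengths'):
        lengths_key = '{}_lengths'.format(lengths_key)
    return lengths_key

assert normalize_lengths_key('word') == 'word_lengths'
assert normalize_lengths_key('word_lengths') == 'word_lengths'
assert normalize_lengths_key(None) is None
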
Example #11
class EncoderDecoderTask(Task):
    def __init__(self, logging_file, mead_config, **kwargs):
        super(EncoderDecoderTask, self).__init__(logging_file, mead_config,
                                                 **kwargs)
        self.task = None

    def _create_task_specific_reader(self):
        preproc = self.config_params['preproc']
        reader = baseline.create_parallel_corpus_reader(
            preproc['mxlen'], preproc['vec_alloc'], preproc['trim'],
            preproc['word_trans_fn'], **self.config_params['loader'])
        return reader

    def _setup_task(self):

        # If it's not vanilla seq2seq, don't bother reversing
        do_reverse = self.config_params['model']['model_type'] == 'default'
        backend = self.config_params.get('backend', 'tensorflow')
        if backend == 'pytorch':
            print('PyTorch backend')
            from baseline.pytorch import long_0_tensor_alloc as vec_alloc
            from baseline.pytorch import tensor_shape as vec_shape
            from baseline.pytorch import tensor_reverse_2nd as rev2nd
            import baseline.pytorch.seq2seq as seq2seq
            self.config_params['preproc']['vec_alloc'] = vec_alloc
            self.config_params['preproc']['vec_shape'] = vec_shape
            src_vec_trans = rev2nd if do_reverse else None
            self.config_params['preproc']['word_trans_fn'] = src_vec_trans
            self.config_params['preproc'][
                'show_ex'] = baseline.pytorch.show_examples_pytorch
            self.config_params['preproc']['trim'] = True
        else:
            import baseline.tf.seq2seq as seq2seq
            import mead.tf
            self.ExporterType = mead.tf.Seq2SeqTensorFlowExporter
            self.config_params['preproc']['vec_alloc'] = np.zeros
            self.config_params['preproc']['vec_shape'] = np.shape
            self.config_params['preproc']['trim'] = False
            src_vec_trans = baseline.reverse_2nd if do_reverse else None
            self.config_params['preproc']['word_trans_fn'] = src_vec_trans
            self.config_params['preproc'][
                'show_ex'] = baseline.tf.show_examples_tf

        self.task = seq2seq

    def initialize(self, embeddings):
        embeddings_set = mead.utils.index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache,
                                      True).download()
        print(
            "[train file]: {}\n[valid file]: {}\n[test file]: {}\n[vocab file]: {}"
            .format(self.dataset['train_file'], self.dataset['valid_file'],
                    self.dataset['test_file'],
                    self.dataset.get('vocab_file', "None")))
        vocab_file = self.dataset.get('vocab_file', None)
        if vocab_file is not None:
            vocab1, vocab2 = self.reader.build_vocabs([vocab_file])
        else:
            vocab1, vocab2 = self.reader.build_vocabs([
                self.dataset['train_file'], self.dataset['valid_file'],
                self.dataset['test_file']
            ])
        self.embeddings1, self.feat2index1 = self._create_embeddings(
            embeddings_set, {'word': vocab1})
        self.embeddings2, self.feat2index2 = self._create_embeddings(
            embeddings_set, {'word': vocab2})

    def _load_dataset(self):
        self.train_data = self.reader.load(self.dataset['train_file'],
                                           self.feat2index1['word'],
                                           self.feat2index2['word'],
                                           self.config_params['batchsz'],
                                           shuffle=True)
        self.valid_data = self.reader.load(self.dataset['valid_file'],
                                           self.feat2index1['word'],
                                           self.feat2index2['word'],
                                           self.config_params['batchsz'],
                                           shuffle=True)
        self.test_data = self.reader.load(
            self.dataset['test_file'],
            self.feat2index1['word'], self.feat2index2['word'],
            self.config_params.get('test_batchsz', 1))

    def _create_model(self):
        return self.task.create_model(self.embeddings1['word'],
                                      self.embeddings2['word'],
                                      **self.config_params['model'])

    def train(self):

        num_ex = self.config_params['num_valid_to_show']

        if num_ex > 0:
            print('Showing examples')
            preproc = self.config_params['preproc']
            show_ex_fn = preproc['show_ex']
            rlut1 = baseline.revlut(self.feat2index1['word'])
            rlut2 = baseline.revlut(self.feat2index2['word'])
            self.config_params['train'][
                'after_train_fn'] = lambda model: show_ex_fn(model,
                                                             self.valid_data,
                                                             rlut1,
                                                             rlut2,
                                                             self.embeddings2[
                                                                 'word'],
                                                             preproc['mxlen'],
                                                             False,
                                                             0,
                                                             num_ex,
                                                             reverse=False)
        super(EncoderDecoderTask, self).train()
Example #12
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--dataset_key",
                        type=str,
                        default='wikitext-2',
                        help="key from DATASETS global")
    parser.add_argument("--train_file",
                        type=str,
                        help='Optional file path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        help='Optional file path to use for valid file')
    parser.add_argument("--dataset_cache",
                        type=str,
                        default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--cache_features", type=str2bool, default=True)
    parser.add_argument("--d_model",
                        type=int,
                        default=410,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2100, help="FFN dimension")
    parser.add_argument("--num_heads",
                        type=int,
                        default=10,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=16,
                        help="Number of layers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch Size")
    parser.add_argument("--tokens",
                        choices=["words", "chars", "bpe", "wordpiece"],
                        default="wordpiece",
                        help="What tokens to use")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="If using subwords, pass this",
                        default='bert-base-cased')
    parser.add_argument(
        "--subword_vocab_file",
        type=str,
        help="If using subwords with separate vocab file, pass here")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=0.25,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=20,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=1000,
                        help="Num warmup steps")
    parser.add_argument("--mlm",
                        type=str2bool,
                        default=False,
                        help="Use Masked Language Model (MLM) objective")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )
    parser.add_argument("--chars_per_word",
                        type=int,
                        default=40,
                        help="How many max characters per word")

    args = parser.parse_args()

    if args.train_file and not args.valid_file:
        logger.error(
            "If you provide a train_file, you must provide a valid_file")
        return

    if not args.train_file and args.valid_file:
        logger.error(
            "If you provide a valid_file, you must also provide a train_file")
        return

    if args.tokens == "chars" and args.mlm:
        logger.error(
            "Character composition cannot currently be used with the MLM objective"
        )

    if args.basedir is None:
        args.basedir = 'transformer-{}-{}-{}'.format(args.dataset_key,
                                                     args.tokens, os.getpid())
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("Cache directory [%s]", args.dataset_cache)

    args.distributed = args.distributed or int(os.environ.get("WORLD_SIZE",
                                                              1)) > 1

    if args.distributed:
        if args.local_rank == -1:
            # https://github.com/kubeflow/pytorch-operator/issues/128
            # https://github.com/pytorch/examples/blob/master/imagenet/main.py
            logger.info("Setting local rank to RANK env variable")
            args.local_rank = int(os.environ['RANK'])
        logger.warning("Local rank (%d)", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    if args.train_file:
        dataset = {
            'train_file': args.train_file,
            'valid_file': args.valid_file
        }
    else:
        dataset = DataDownloader(DATASETS[args.dataset_key],
                                 args.dataset_cache).download()
    reader = create_reader(args.tokens, args.nctx, args.chars_per_word,
                           args.subword_model_file, args.subword_vocab_file)

    preproc_data = load_embed_and_vocab(args.tokens, reader, dataset,
                                        args.dataset_key, args.d_model,
                                        args.cache_features)

    vocabs = preproc_data['vocabs']
    if args.mlm:
        mask_from = vocabs['x']
        vocab_size = len(mask_from)
        mask_value = mask_from.get("[MASK]", mask_from.get("<MASK>", -1))
        if mask_value == -1:
            logger.error(
                "We could not find a suitable masking token in the vocab")
            return
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs['x'], os.path.join(args.basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    valid_num_words = preproc_data['valid_num_words']
    tgt_key = preproc_data['tgt_key']
    logger.info("Loaded embeddings")

    train_set = load_data(args.tokens, reader, dataset, 'train_file', vocabs,
                          args.cache_features)
    valid_set = load_data(args.tokens, reader, dataset, 'valid_file', vocabs,
                          args.cache_features)
    logger.info("valid. tokens [%s], valid. words [%s]",
                valid_set.tensors[-1].numel(), valid_num_words)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set) if args.distributed else None
    train_loader = DataLoader(train_set,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              shuffle=(not args.distributed))
    valid_loader = DataLoader(valid_set,
                              batch_size=args.batch_size,
                              shuffle=False)
    logger.info("Loaded datasets")

    model = TransformerLanguageModel.create(
        embeddings,
        hsz=args.d_model,
        d_ff=args.d_ff,
        tie_weights=(args.tokens != 'chars'),
        dropout=args.dropout,
        gpu=False,
        num_heads=args.num_heads,
        layers=args.num_layers,
        src_keys=['x'],
        tgt_key=tgt_key)
    model.to(args.device)
    loss_function = model.create_loss()
    loss_function.to(args.device)

    logger.info("Loaded model and loss")

    steps_per_epoch = len(train_loader)
    update_on = steps_per_epoch // 10
    cosine_decay = CosineDecaySchedulerPyTorch(len(train_loader) * args.epochs,
                                               lr=args.lr)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, cosine_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0
    if args.restart_from:
        model.load_state_dict(torch.load(args.restart_from))
        start_epoch = int(args.restart_from.split("-")[-1].split(".")[0]) - 1
        global_step = (start_epoch + 1) * steps_per_epoch
        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            args.restart_from, global_step, start_epoch + 1)
    optimizer = OptimizerManager(model,
                                 global_step,
                                 optim='adam',
                                 lr=args.lr,
                                 lr_function=lr_sched,
                                 weight_decay=args.weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        logger.info("Model located on %d", args.local_rank)

    # This is the training loop
    for epoch in range(start_epoch, args.epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()

        if args.distributed:
            train_sampler.set_epoch(epoch)

        start = time.time()
        model.train()
        for i, batch in enumerate(train_loader):
            x, y = batch
            inputs = {'x': x.to(args.device)}
            labels = y.to(args.device)
            if args.mlm:
                # Replace 15% of tokens
                masked_indices = torch.bernoulli(torch.full(
                    labels.shape, 0.15)).byte()
                # Anything not masked is 0 so no loss
                labels[~masked_indices] = 0
                # Of the masked items, mask 80% of them with [MASK]
                indices_replaced = torch.bernoulli(
                    torch.full(labels.shape, 0.8)).byte() & masked_indices
                inputs['x'][indices_replaced] = mask_value
                # Replace 10% of them with random words, rest preserved for auto-encoding
                indices_random = torch.bernoulli(torch.full(
                    labels.shape,
                    0.5)).byte() & masked_indices & ~indices_replaced
                random_words = torch.randint(vocab_size,
                                             labels.shape,
                                             dtype=torch.long,
                                             device=args.device)
                inputs['x'][indices_random] = random_words[indices_random]

            labels = labels.transpose(0, 1).contiguous()
            logits = model(inputs, None)[0].transpose(0, 1).contiguous()
            if args.mlm:
                loss = loss_function(logits, labels)
            else:
                shift_logits = logits[:-1]
                shift_labels = labels[1:]
                loss = loss_function(shift_logits, shift_labels)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % update_on == 0:
                logging.info(avg_loss)

        # How much time elapsed in minutes
        elapsed = (time.time() - start) / 60
        train_token_loss = avg_loss.avg
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        train_token_ppl = math.exp(train_token_loss)
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_token_loss
        metrics['train_ppl'] = train_token_ppl
        model_base = os.path.join(args.basedir, 'checkpoint')
        avg_valid_loss = Average('average_valid_loss')
        start = time.time()
        model.eval()
        for batch in valid_loader:
            with torch.no_grad():
                x, y = batch
                inputs = {'x': x.to(args.device)}
                labels = y.to(args.device)

                if args.mlm:
                    # Replace 15% of tokens
                    masked_indices = torch.bernoulli(
                        torch.full(labels.shape, 0.15)).byte()
                    # Anything not masked is 0 so no loss
                    labels[~masked_indices] = 0
                    # Of the masked items, mask 80% of them with [MASK]
                    indices_replaced = torch.bernoulli(
                        torch.full(labels.shape, 0.8)).byte() & masked_indices
                    inputs['x'][indices_replaced] = mask_value
                    # Replace 10% of them with random words; the rest are left unchanged
                    indices_random = torch.bernoulli(
                        torch.full(
                            labels.shape,
                            0.5)).byte() & masked_indices & ~indices_replaced
                    random_words = torch.randint(vocab_size,
                                                 labels.shape,
                                                 dtype=torch.long,
                                                 device=args.device)
                    inputs['x'][indices_random] = random_words[indices_random]

                labels = labels.transpose(0, 1).contiguous()
                logits = model(inputs, None)[0].transpose(0, 1).contiguous()
                if args.mlm:
                    loss = loss_function(logits, labels)
                else:
                    shift_logits = logits[:-1]
                    shift_labels = labels[1:]
                    loss = loss_function(shift_logits, shift_labels)
                avg_valid_loss.update(loss.item())

        valid_token_loss = avg_valid_loss.avg
        valid_token_ppl = math.exp(valid_token_loss)

        elapsed = (time.time() - start) / 60
        metrics['valid_elapsed_min'] = elapsed

        metrics['average_valid_loss'] = valid_token_loss
        if args.tokens in ['bpe', 'wordpiece']:
            metrics['valid_token_ppl'] = valid_token_ppl
            metrics['average_valid_word_ppl'] = math.exp(
                valid_token_loss * valid_set.tensors[-1].numel() /
                valid_num_words)
        else:
            metrics['average_valid_word_ppl'] = valid_token_ppl
        logger.info(metrics)

        if args.local_rank < 1:

            # Should probably do this more often
            checkpoint_name = checkpoint_for(model_base, epoch + 1)
            logger.info("Creating checkpoint: %s", checkpoint_name)
            if args.distributed:
                torch.save(model.module.state_dict(), checkpoint_name)
            else:
                torch.save(model.state_dict(), checkpoint_name)

            rm_old_checkpoints(model_base, epoch + 1)
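
The MLM branch in train() above appears twice, once for training and once for validation. Below is a minimal, self-contained sketch of that masking rule factored into a helper; the function name, its signature, and the use of boolean masks are assumptions for illustration, not part of the original script.

import torch

def mlm_mask(x, labels, mask_value, vocab_size):
    # Select 15% of positions; unselected labels become 0 so they contribute no loss.
    masked_indices = torch.bernoulli(torch.full(labels.shape, 0.15)).bool()
    labels[~masked_indices] = 0
    # 80% of the selected positions are replaced with the [MASK] token id.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    x[indices_replaced] = mask_value
    # Half of the remainder (10% overall) become random token ids; the rest stay unchanged.
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long, device=x.device)
    x[indices_random] = random_words[indices_random]
    return x, labels

With such a helper, both loops could call inputs['x'], labels = mlm_mask(inputs['x'], labels, mask_value, vocab_size) before the transpose, instead of repeating the masking inline.
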
Example #13
class ClassifierTask(Task):

    def __init__(self, mead_settings_config, **kwargs):
        super(ClassifierTask, self).__init__(mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'classify'

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))
        if backend.name == 'dy':
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                self.config_params['train']['trainer_type'] = 'autobatch'
                dy_params.set_autobatch(True)
                batched = False
            else:
                batched = True
            dy_params.init()
            backend.params = {'pc': _dynet.ParameterCollection(), 'batched': batched}

        backend.load(self.task_name())

        return backend

    def _setup_task(self, **kwargs):
        super(ClassifierTask, self)._setup_task(**kwargs)
        if self.config_params.get('preproc', {}).get('clean', False) is True:
            self.config_params.get('preproc', {})['clean_fn'] = baseline.TSVSeqLabelReader.do_clean
            logger.info('Clean')
        else:
            self.config_params['preproc'] = {}
            self.config_params['preproc']['clean_fn'] = None

    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab, self.labels = self.reader.build_vocab(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file'),
            label_file=self.dataset.get('label_file')
        )
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocab, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)

    def _create_model(self):
        unif = self.config_params.get('unif', 0.1)
        model = self.config_params['model']
        model['unif'] = model.get('unif', unif)
        lengths_key = model.get('lengths_key', self.primary_key)
        if lengths_key is not None:
            if not lengths_key.endswith('_lengths'):
                lengths_key = '{}_lengths'.format(lengths_key)
            model['lengths_key'] = lengths_key
        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_model(self.embeddings, self.labels, **model)

    def _load_dataset(self):
        read = self.config_params['reader'] if 'reader' in self.config_params else self.config_params['loader']
        sort_key = read.get('sort_key')
        bsz, vbsz, tbsz = Task._get_batchsz(self.config_params)
        self.train_data = self.reader.load(
            self.dataset['train_file'],
            self.feat2index,
            bsz,
            shuffle=True,
            sort_key=sort_key,
        )
        self.valid_data = self.reader.load(
            self.dataset['valid_file'],
            self.feat2index,
            vbsz,
        )
        self.test_data = self.reader.load(
            self.dataset['test_file'],
            self.feat2index,
            tbsz,
        )
Example #14
class EncoderDecoderTask(Task):

    def __init__(self, mead_settings_config, **kwargs):
        super(EncoderDecoderTask, self).__init__(mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'seq2seq'

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))
        if 'preproc' not in self.config_params:
            self.config_params['preproc'] = {}
        self.config_params['preproc']['show_ex'] = show_examples
        if backend.name == 'pytorch':
            self.config_params['preproc']['trim'] = True
        elif backend.name == 'dy':
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                self.config_params['train']['trainer_type'] = 'autobatch'
                dy_params.set_autobatch(True)
                batched = False
            else:
                batched = True
            dy_params.init()
            backend.params = {'pc': _dynet.ParameterCollection(), 'batched': batched}
            self.config_params['preproc']['trim'] = True
        else:
            self.config_params['preproc']['trim'] = True
        backend.load(self.task_name())

        return backend

    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab1, vocab2 = self.reader.build_vocabs(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file')
        )

        # To keep the config file simple, share a list between source and destination (tgt)
        features_src = []
        features_tgt = None
        for feature in self.config_params['features']:
            if feature['name'] == 'tgt':
                features_tgt = feature
            else:
                features_src += [feature]

        self.src_embeddings, self.feat2src = self._create_embeddings(embeddings_set, vocab1, features_src)
        # For now, don't allow multiple output vocabs
        baseline.save_vocabs(self.get_basedir(), self.feat2src)
        self.tgt_embeddings, self.feat2tgt = self._create_embeddings(embeddings_set, {'tgt': vocab2}, [features_tgt])
        baseline.save_vocabs(self.get_basedir(), self.feat2tgt)
        self.tgt_embeddings = self.tgt_embeddings['tgt']
        self.feat2tgt = self.feat2tgt['tgt']

    def _load_dataset(self):
        bsz, vbsz, tbsz = Task._get_batchsz(self.config_params)
        self.train_data = self.reader.load(
            self.dataset['train_file'],
            self.feat2src, self.feat2tgt,
            bsz,
            shuffle=True,
            sort_key='{}_lengths'.format(self.primary_key)
        )

        self.valid_data = self.reader.load(
            self.dataset['valid_file'],
            self.feat2src, self.feat2tgt,
            vbsz,
            shuffle=True
        )
        self.test_data = self.reader.load(
            self.dataset['test_file'],
            self.feat2src, self.feat2tgt,
            tbsz,
        )

    def _create_model(self):
        model = self.config_params['model']
        unif = self.config_params.get('unif', 0.1)
        model['unif'] = model.get('unif', unif)
        lengths_key = model.get('src_lengths_key', self.primary_key)
        if lengths_key is not None:
            if not lengths_key.endswith('_lengths'):
                lengths_key = '{}_lengths'.format(lengths_key)
            model['src_lengths_key'] = lengths_key
        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_seq2seq_model(self.src_embeddings, self.tgt_embeddings, **self.config_params['model'])

    def train(self, checkpoint=None):

        num_ex = self.config_params['num_valid_to_show']

        rlut1 = revlut(self.feat2src[self.primary_key])
        rlut2 = revlut(self.feat2tgt)
        if num_ex > 0:
            logger.info('Showing examples')
            preproc = self.config_params.get('preproc', {})
            show_ex_fn = preproc['show_ex']
            self.config_params['train']['after_train_fn'] = lambda model: show_ex_fn(model,
                                                                                     self.valid_data, rlut1, rlut2,
                                                                                     self.feat2tgt,
                                                                                     preproc['mxlen'], False, 0,
                                                                                     num_ex, reverse=False)
        self.config_params['train']['tgt_rlut'] = rlut2
        return super(EncoderDecoderTask, self).train(checkpoint)
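
train() above builds reverse lookup tables with revlut so that validation examples can be printed as tokens rather than indices. A minimal sketch of that idea, assuming a plain token-to-index dict (the sample vocab is hypothetical):

def revlut(vocab):
    # Invert a token -> index vocabulary into an index -> token table.
    return {index: token for token, index in vocab.items()}

vocab = {'<PAD>': 0, '<GO>': 1, '<EOS>': 2, 'hello': 3}  # hypothetical vocab
assert revlut(vocab)[3] == 'hello'
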
Example #15
class TaggerTask(Task):

    def __init__(self, mead_settings_config, **kwargs):
        super(TaggerTask, self).__init__(mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'tagger'

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))
        if 'preproc' not in self.config_params:
            self.config_params['preproc'] = {}
        if backend.name == 'pytorch':
            self.config_params['preproc']['trim'] = True
        elif backend.name == 'dy':
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                dy_params.set_autobatch(True)
            else:
                raise Exception('Tagger currently only supports autobatching. '
                                'Change "batchsz" to 1 and under "train", set "autobatchsz" to your desired batchsz')
            dy_params.init()
            backend.params = {'pc': _dynet.ParameterCollection(), 'batched': False}
            self.config_params['preproc']['trim'] = True
        else:
            self.config_params['preproc']['trim'] = False

        backend.load(self.task_name())

        return backend

    def initialize(self, embeddings):
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
        # TODO: make this optional
        if 'test_file' in self.dataset:
            vocab_sources.append(self.dataset['test_file'])

        vocabs = self.reader.build_vocab(vocab_sources,
                                         min_f=Task._get_min_f(self.config_params),
                                         vocab_file=self.dataset.get('vocab_file'))
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)

    def _create_model(self):
        labels = self.reader.label2index
        span_type = self.config_params['train'].get('span_type')
        constrain = bool(self.config_params['model'].get('constrain_decode', False))
        if span_type is None and constrain:
            logger.warning("Constrained Decoding was set but no span type could be found so no Constraints will be applied.")
        self.config_params['model']['span_type'] = span_type
        if span_type is not None and constrain:
            self.config_params['model']['constraint'] = self.backend.transition_mask(
                labels, span_type, Offsets.GO, Offsets.EOS, Offsets.PAD
            )

        model = self.config_params['model']
        unif = self.config_params.get('unif', 0.1)
        model['unif'] = model.get('unif', unif)

        lengths_key = model.get('lengths_key', self.primary_key)
        if lengths_key is not None:
            if not lengths_key.endswith('_lengths'):
                lengths_key = '{}_lengths'.format(lengths_key)
            model['lengths_key'] = lengths_key

        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_tagger_model(self.embeddings, labels, **self.config_params['model'])

    def _load_dataset(self):
        # TODO: get rid of sort_key=self.primary_key in favor of something explicit?
        bsz, vbsz, tbsz = Task._get_batchsz(self.config_params)
        self.train_data, _ = self.reader.load(
            self.dataset['train_file'],
            self.feat2index,
            bsz,
            shuffle=True,
            sort_key='{}_lengths'.format(self.primary_key)
        )
        self.valid_data, _ = self.reader.load(
            self.dataset['valid_file'],
            self.feat2index,
            vbsz,
            sort_key=None
        )
        self.test_data = None
        self.txts = None
        if 'test_file' in self.dataset:
            self.test_data, self.txts = self.reader.load(
                self.dataset['test_file'],
                self.feat2index,
                tbsz,
                shuffle=False,
                sort_key=None
            )


    def train(self, checkpoint=None):
        self._load_dataset()
        baseline.save_vectorizers(self.get_basedir(), self.vectorizers)
        model = self._create_model()
        conll_output = self.config_params.get("conll_output", None)
        train_params = self.config_params['train']
        train_params['checkpoint'] = checkpoint
        metrics = baseline.train.fit(model, self.train_data, self.valid_data, self.test_data,
                           conll_output=conll_output,
                           txts=self.txts, **train_params)
        baseline.zip_files(self.get_basedir())
        self._close_reporting_hooks()
        return model, metrics
Example #16
class ClassifierTask(Task):
    def __init__(self, logging_config, mead_settings_config, **kwargs):
        super(ClassifierTask, self).__init__(logging_config,
                                             mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'classify'

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))
        if backend.name == 'dy':
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                self.config_params['train']['trainer_type'] = 'autobatch'
                dy_params.set_autobatch(True)
                batched = False
            else:
                batched = True
            dy_params.init()
            backend.params = {
                'pc': _dynet.ParameterCollection(),
                'batched': batched
            }
        elif backend.name == 'tf':
            # FIXME this should be registered as well!
            exporter_type = kwargs.get('exporter_type', 'default')
            if exporter_type == 'default':
                from mead.tf.exporters import ClassifyTensorFlowExporter
                backend.exporter = ClassifyTensorFlowExporter
            elif exporter_type == 'preproc':
                from mead.tf.preproc_exporters import ClassifyTensorFlowPreProcExporter
                import mead.tf.preprocessors
                backend.exporter = ClassifyTensorFlowPreProcExporter

        backend.load(self.task_name())

        return backend

    def _setup_task(self, **kwargs):
        super(ClassifierTask, self)._setup_task(**kwargs)
        if self.config_params.get('preproc', {}).get('clean', False) is True:
            self.config_params.get(
                'preproc',
                {})['clean_fn'] = baseline.TSVSeqLabelReader.do_clean
            print('Clean')
        else:
            self.config_params['preproc'] = {}
            self.config_params['preproc']['clean_fn'] = None

    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset,
                                      self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab, self.labels = self.reader.build_vocab(
            [
                self.dataset['train_file'], self.dataset['valid_file'],
                self.dataset['test_file']
            ],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file'),
            label_file=self.dataset.get('label_file'))
        self.embeddings, self.feat2index = self._create_embeddings(
            embeddings_set, vocab, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)

    def _create_model(self):
        unif = self.config_params.get('unif', 0.1)
        model = self.config_params['model']
        model['unif'] = model.get('unif', unif)
        lengths_key = model.get('lengths_key', self.primary_key)
        if lengths_key is not None:
            if not lengths_key.endswith('_lengths'):
                lengths_key = '{}_lengths'.format(lengths_key)
            model['lengths_key'] = lengths_key
        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_model(self.embeddings, self.labels,
                                           **model)

    def _load_dataset(self):
        self.train_data = self.reader.load(
            self.dataset['train_file'],
            self.feat2index,
            self.config_params['batchsz'],
            shuffle=True,
            sort_key=self.config_params['loader'].get('sort_key'))
        self.valid_data = self.reader.load(
            self.dataset['valid_file'], self.feat2index,
            self.config_params.get('valid_batchsz',
                                   self.config_params['batchsz']))
        self.test_data = self.reader.load(
            self.dataset['test_file'], self.feat2index,
            self.config_params.get('test_batchsz', 1))
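
For orientation, here is a minimal sketch of the kind of mead configuration dictionary that the ClassifierTask methods above read (`batchsz`, `valid_batchsz`, `test_batchsz`, `features`, `loader.sort_key`, `model`, `train`). The concrete values and nested keys (dataset label, vectorizer/embedding labels, model hyperparameters) are illustrative assumptions, not a tested configuration.

# Illustrative (untested) classifier config consumed by _create_model and _load_dataset above.
classifier_config = {
    'dataset': 'SST2',                      # dataset label resolved via datasets.json
    'backend': 'tf',
    'unif': 0.25,                           # default init range picked up by _create_model
    'batchsz': 50,
    'valid_batchsz': 50,                    # falls back to batchsz when omitted
    'test_batchsz': 1,
    'features': [{'name': 'word',
                  'vectorizer': {'type': 'token1d'},
                  'embeddings': {'label': 'glove-6B-100'}}],
    'loader': {'reader_type': 'default', 'sort_key': 'word_lengths'},
    'model': {'model_type': 'default', 'filtsz': [3, 4, 5], 'cmotsz': 100},
    'train': {'epochs': 2, 'optim': 'adadelta'},
}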
Example #17
class LanguageModelingTask(Task):

    def __init__(self, mead_settings_config, **kwargs):
        super(LanguageModelingTask, self).__init__(mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'lm'

    def _create_task_specific_reader(self):
        self._create_vectorizers()

        reader_params = self.config_params['reader'] if 'reader' in self.config_params else self.config_params['loader']
        reader_params['nctx'] = reader_params.get('nctx', self.config_params.get('nctx', self.config_params.get('nbptt', 35)))
        reader_params['clean_fn'] = reader_params.get('clean_fn', self.config_params.get('preproc', {}).get('clean_fn'))
        if reader_params['clean_fn'] is not None and self.config_params['dataset'] != 'SST2':
            logger.warning('A reader preprocessing function (%s) is active; it is recommended that all data preprocessing be done outside of baseline to ensure the data at inference time matches the data at training time.', reader_params['clean_fn'])
        reader_params['mxlen'] = self.vectorizers[self.primary_key].mxlen
        if self.config_params['model'].get('gpus', 1) > 1:
            reader_params['truncate'] = True
        return baseline.reader.create_reader(self.task_name(), self.vectorizers, self.config_params['preproc'].get('trim', False), **reader_params)

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))

        if backend.name == 'pytorch':
            self.config_params.get('preproc', {})['trim'] = True

        elif backend.name == 'dy':
            self.config_params.get('preproc', {})['trim'] = True
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                dy_params.set_autobatch(True)
                batched = False
            else:
                batched = True
            dy_params.init()
            backend.params = {'pc': _dynet.ParameterCollection(), 'batched': batched}

        backend.load(self.task_name())
        return backend

    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocabs = self.reader.build_vocab(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file')
        )
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)

    def _load_dataset(self):
        read = self.config_params['reader'] if 'reader' in self.config_params else self.config_params['loader']
        tgt_key = read.get('tgt_key', self.primary_key)
        bsz, vbsz, tbsz = Task._get_batchsz(self.config_params)
        self.train_data = self.reader.load(
            self.dataset['train_file'],
            self.feat2index,
            bsz,
            tgt_key=tgt_key
        )
        self.valid_data = self.reader.load(
            self.dataset['valid_file'],
            self.feat2index,
            vbsz,
            tgt_key=tgt_key
        )
        self.test_data = self.reader.load(
            self.dataset['test_file'],
            self.feat2index,
            1,
            tgt_key=tgt_key
        )

    def _create_model(self):

        model = self.config_params['model']
        unif = self.config_params.get('unif', 0.1)
        model['unif'] = model.get('unif', unif)
        model['batchsz'] = self.config_params['batchsz']
        model['tgt_key'] = self.config_params.get('reader', self.config_params.get('loader', {})).get('tgt_key', self.primary_key)
        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_lang_model(self.embeddings, **model)

    def train(self, checkpoint=None):
        self._load_dataset()
        if self.config_params['train'].get('lr_scheduler_type', None) == 'zaremba':
            first_range = int(self.config_params['train']['start_decay_epoch'] * self.train_data.steps)
            self.config_params['train']['bounds'] = [first_range] + list(
                np.arange(
                    self.config_params['train']['start_decay_epoch'] + 1,
                    self.config_params['train']['epochs'] + 1,
                    dtype=np.int32
                ) * self.train_data.steps
            )
        baseline.save_vectorizers(self.get_basedir(), self.vectorizers)
        model = self._create_model()
        train_params = self.config_params['train']
        train_params['checkpoint'] = checkpoint
        metrics = baseline.train.fit(model, self.train_data, self.valid_data, self.test_data, **train_params)
        baseline.zip_files(self.get_basedir())
        self._close_reporting_hooks()
        return model, metrics

    @staticmethod
    def _num_steps_per_epoch(num_examples, nctx, batchsz):
        rest = num_examples // batchsz
        return rest // nctx
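
The 'zaremba' branch in train() above converts an epoch-based decay start into step-based boundaries for the learning-rate scheduler. Here is a standalone sketch of that computation, assuming the number of steps per epoch is already known (e.g. from _num_steps_per_epoch); the example numbers are made up.

import numpy as np

def zaremba_bounds(start_decay_epoch, epochs, steps_per_epoch):
    """Step indices at which a 'zaremba'-style scheduler decays the LR (mirrors the bounds logic in train())."""
    first = int(start_decay_epoch * steps_per_epoch)
    rest = list(np.arange(start_decay_epoch + 1, epochs + 1, dtype=np.int32) * steps_per_epoch)
    return [first] + rest

# e.g. start decaying at epoch 6 of 13 with 1327 steps per epoch (illustrative numbers)
print(zaremba_bounds(6, 13, 1327))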
Example #18
class EncoderDecoderTask(Task):

    def __init__(self, mead_settings_config, **kwargs):
        super(EncoderDecoderTask, self).__init__(mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'seq2seq'

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))
        if 'preproc' not in self.config_params:
            self.config_params['preproc'] = {}
        self.config_params['preproc']['show_ex'] = show_examples
        if backend.name == 'pytorch':
            self.config_params['preproc']['trim'] = True
        elif backend.name == 'dy':
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                self.config_params['train']['trainer_type'] = 'autobatch'
                dy_params.set_autobatch(True)
                batched = False
            else:
                batched = True
            dy_params.init()
            backend.params = {'pc': _dynet.ParameterCollection(), 'batched': batched}
            self.config_params['preproc']['trim'] = True
        else:
            self.config_params['preproc']['trim'] = True
        backend.load(self.task_name())

        return backend

    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
        # TODO: make this optional
        if 'test_file' in self.dataset:
            vocab_sources.append(self.dataset['test_file'])
        vocab1, vocab2 = self.reader.build_vocabs(vocab_sources,
                                                  min_f=Task._get_min_f(self.config_params),
                                                  vocab_file=self.dataset.get('vocab_file'))

        # To keep the config file simple, share a list between source and destination (tgt)
        features_src = []
        features_tgt = None
        for feature in self.config_params['features']:
            if feature['name'] == 'tgt':
                features_tgt = feature
            else:
                features_src += [feature]

        self.src_embeddings, self.feat2src = self._create_embeddings(embeddings_set, vocab1, features_src)
        # For now, don't allow multiple output vocabs
        baseline.save_vocabs(self.get_basedir(), self.feat2src)
        self.tgt_embeddings, self.feat2tgt = self._create_embeddings(embeddings_set, {'tgt': vocab2}, [features_tgt])
        baseline.save_vocabs(self.get_basedir(), self.feat2tgt)
        self.tgt_embeddings = self.tgt_embeddings['tgt']
        self.feat2tgt = self.feat2tgt['tgt']

    def _load_dataset(self):
        bsz, vbsz, tbsz = Task._get_batchsz(self.config_params)
        self.train_data = self.reader.load(
            self.dataset['train_file'],
            self.feat2src, self.feat2tgt,
            bsz,
            shuffle=True,
            sort_key='{}_lengths'.format(self.primary_key)
        )

        self.valid_data = self.reader.load(
            self.dataset['valid_file'],
            self.feat2src, self.feat2tgt,
            vbsz,
            shuffle=True
        )
        self.test_data = None
        if 'test_file' in self.dataset:
            self.test_data = self.reader.load(
                self.dataset['test_file'],
                self.feat2src, self.feat2tgt,
                tbsz,
            )


    def _create_model(self):
        self.config_params['model']["unif"] = self.config_params["unif"]
        model = self.config_params['model']
        unif = self.config_params.get('unif', 0.1)
        model['unif'] = model.get('unif', unif)
        lengths_key = model.get('src_lengths_key', self.primary_key)
        if lengths_key is not None:
            if not lengths_key.endswith('_lengths'):
                lengths_key = '{}_lengths'.format(lengths_key)
            model['src_lengths_key'] = lengths_key
        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_seq2seq_model(self.src_embeddings, self.tgt_embeddings, **model)

    def train(self, checkpoint=None):

        num_ex = self.config_params['num_valid_to_show']

        rlut1 = revlut(self.feat2src[self.primary_key])
        rlut2 = revlut(self.feat2tgt)
        if num_ex > 0:
            logger.info('Showing examples')
            preproc = self.config_params.get('preproc', {})
            show_ex_fn = preproc['show_ex']
            self.config_params['train']['after_train_fn'] = lambda model: show_ex_fn(model,
                                                                                     self.valid_data, rlut1, rlut2,
                                                                                     self.feat2tgt,
                                                                                     preproc['mxlen'], False, 0,
                                                                                     num_ex, reverse=False)
        self.config_params['train']['tgt_rlut'] = rlut2
        return super(EncoderDecoderTask, self).train(checkpoint)
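
The initialize() method above splits the single feature list from the config into source features and one 'tgt' feature. A minimal standalone sketch of that split is below; the example feature dicts are illustrative.

def split_src_tgt(features):
    """Separate the feature named 'tgt' from the source features (mirrors the loop in initialize)."""
    features_src = [f for f in features if f['name'] != 'tgt']
    features_tgt = next((f for f in features if f['name'] == 'tgt'), None)
    return features_src, features_tgt

features = [
    {'name': 'word', 'embeddings': {'label': 'glove-6B-100'}},   # assumed example entries
    {'name': 'tgt', 'embeddings': {'dsz': 300}},
]
src, tgt = split_src_tgt(features)
print([f['name'] for f in src], tgt['name'])   # ['word'] tgt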
Example #19
parser.add_argument('--datasets',
                    help='json library of dataset labels',
                    default='config/datasets.json',
                    type=convert_path)
parser.add_argument('--embeddings',
                    help='json library of embeddings',
                    default='config/embeddings.json',
                    type=convert_path)
args = parser.parse_args()

datasets = read_json(args.datasets)
datasets = index_by_label(datasets)

for name, d in datasets.items():
    print(name)
    try:
        DataDownloader(d, args.cache).download()
    except Exception as e:
        print(e)

emb = read_json(args.embeddings)
emb = index_by_label(emb)

for name, e in emb.items():
    print(name)
    try:
        EmbeddingDownloader(e['file'], e['dsz'], e.get('sha1'),
                            args.cache).download()
    except Exception as e:
        print(e)
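
For context, index_by_label re-keys the JSON list of dataset or embedding entries by their 'label' field so they can be looked up by name, which is what the loops above iterate over. A hedged re-implementation of that behavior (the real mead.utils.index_by_label may differ in details; the file paths below are made up):

def index_by_label_sketch(entries):
    """Turn [{'label': 'SST2', ...}, ...] into {'SST2': {...}, ...} (approximation of index_by_label)."""
    return {entry['label']: entry for entry in entries}

datasets = index_by_label_sketch([
    {'label': 'SST2', 'train_file': 'sst2/train.tsv', 'valid_file': 'sst2/dev.tsv', 'test_file': 'sst2/test.tsv'},
])
print(datasets['SST2']['train_file'])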
Example #20
class LanguageModelingTask(Task):

    def __init__(self, mead_settings_config, **kwargs):
        super(LanguageModelingTask, self).__init__(mead_settings_config, **kwargs)

    @classmethod
    def task_name(cls):
        return 'lm'

    def _create_task_specific_reader(self):
        self._create_vectorizers()

        reader_params = self.config_params['reader'] if 'reader' in self.config_params else self.config_params['loader']
        reader_params['nctx'] = reader_params.get('nctx', self.config_params.get('nctx', self.config_params.get('nbptt', 35)))
        reader_params['clean_fn'] = reader_params.get('clean_fn', self.config_params.get('preproc', {}).get('clean_fn'))
        if reader_params['clean_fn'] is not None and self.config_params['dataset'] != 'SST2':
            logger.warning('A reader preprocessing function (%s) is active; it is recommended that all data preprocessing be done outside of baseline to ensure the data at inference time matches the data at training time.', reader_params['clean_fn'])
        reader_params['mxlen'] = self.vectorizers[self.primary_key].mxlen
        if self.config_params['model'].get('gpus', 1) > 1:
            reader_params['truncate'] = True
        return baseline.reader.create_reader(self.task_name(), self.vectorizers, self.config_params['preproc'].get('trim', False), **reader_params)

    def _create_backend(self, **kwargs):
        backend = Backend(self.config_params.get('backend', 'tf'))

        if backend.name == 'pytorch':
            self.config_params.get('preproc', {})['trim'] = True

        elif backend.name == 'dy':
            self.config_params.get('preproc', {})['trim'] = True
            import _dynet
            dy_params = _dynet.DynetParams()
            dy_params.from_args()
            dy_params.set_requested_gpus(1)
            if 'autobatchsz' in self.config_params['train']:
                dy_params.set_autobatch(True)
                batched = False
            else:
                batched = True
            dy_params.init()
            backend.params = {'pc': _dynet.ParameterCollection(), 'batched': batched}

        backend.load(self.task_name())
        return backend

    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
        # TODO: make this optional
        if 'test_file' in self.dataset:
            vocab_sources.append(self.dataset['test_file'])
        vocabs = self.reader.build_vocab(vocab_sources,
                                         min_f=Task._get_min_f(self.config_params),
                                         vocab_file=self.dataset.get('vocab_file'))
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)

    def _load_dataset(self):
        read = self.config_params['reader'] if 'reader' in self.config_params else self.config_params['loader']
        tgt_key = read.get('tgt_key', self.primary_key)
        bsz, vbsz, tbsz = Task._get_batchsz(self.config_params)
        self.train_data = self.reader.load(
            self.dataset['train_file'],
            self.feat2index,
            bsz,
            tgt_key=tgt_key
        )
        self.valid_data = self.reader.load(
            self.dataset['valid_file'],
            self.feat2index,
            vbsz,
            tgt_key=tgt_key
        )
        self.test_data = None
        if 'test_file' in self.dataset:
            self.test_data = self.reader.load(
                self.dataset['test_file'],
                self.feat2index,
                1,
                tgt_key=tgt_key
            )

    def _create_model(self):

        model = self.config_params['model']
        unif = self.config_params.get('unif', 0.1)
        model['unif'] = model.get('unif', unif)
        model['batchsz'] = self.config_params['batchsz']
        model['tgt_key'] = self.config_params.get('reader',
                                                  self.config_params.get('loader', {})).get('tgt_key', self.primary_key)
        model['src_keys'] = listify(self.config_params.get('reader', self.config_params.get('loader', {})).get('src_keys', list(self.embeddings.keys())))
        if self.backend.params is not None:
            for k, v in self.backend.params.items():
                model[k] = v
        return baseline.model.create_lang_model(self.embeddings, **model)

    def train(self, checkpoint=None):
        self._load_dataset()
        if self.config_params['train'].get('lr_scheduler_type', None) == 'zaremba':
            first_range = int(self.config_params['train']['start_decay_epoch'] * self.train_data.steps)
            self.config_params['train']['bounds'] = [first_range] + list(
                np.arange(
                    self.config_params['train']['start_decay_epoch'] + 1,
                    self.config_params['train']['epochs'] + 1,
                    dtype=np.int32
                ) * self.train_data.steps
            )
        baseline.save_vectorizers(self.get_basedir(), self.vectorizers)
        model = self._create_model()
        train_params = self.config_params['train']
        train_params['checkpoint'] = checkpoint
        metrics = baseline.train.fit(model, self.train_data, self.valid_data, self.test_data, **train_params)
        baseline.zip_files(self.get_basedir())
        self._close_reporting_hooks()
        return model, metrics

    @staticmethod
    def _num_steps_per_epoch(num_examples, nctx, batchsz):
        rest = num_examples // batchsz
        return rest // nctx
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--dataset_key",
                        type=str,
                        default='wikitext-2',
                        help="key from DATASETS global")
    parser.add_argument("--train_file",
                        type=str,
                        help='Optional file path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        help='Optional file path to use for valid file')
    parser.add_argument("--dataset_cache",
                        type=str,
                        default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--cache_features", type=str2bool, default=True)
    parser.add_argument("--d_model",
                        type=int,
                        default=410,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2100, help="FFN dimension")
    parser.add_argument("--num_heads",
                        type=int,
                        default=10,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch Size")
    parser.add_argument("--tokens",
                        choices=["words", "chars", "subwords"],
                        default="subwords",
                        help="What tokens to use")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=0.25,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=20,
                        help="Num training epochs")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=1000,
                        help="Num warmup steps")
    parser.add_argument("--eval_every",
                        type=int,
                        default=-1,
                        help="Evaluate every X steps (-1 => end of epoch)")

    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="Local rank for distributed training (-1 means use environment variables to determine it)")
    parser.add_argument("--chars_per_word",
                        type=int,
                        default=40,
                        help="How many max characters per word")
    parser.add_argument(
        "--accum_grad_steps",
        type=int,
        default=1,
        help="Create effective batch size by accumulating grads without updates"
    )
    args = parser.parse_args()

    if args.train_file and not args.valid_file:
        logger.error(
            "If you provide a train_file, you must provide a valid_file")
        return

    if not args.train_file and args.valid_file:
        logger.error(
            "If you provide a valid_file, you must also provide a train_file")
        return

    if args.basedir is None:
        args.basedir = 'transformer-{}-{}-{}'.format(args.dataset_key,
                                                     args.tokens, os.getpid())
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("Cache directory [%s]", args.dataset_cache)

    args.distributed = args.distributed or int(os.environ.get("WORLD_SIZE",
                                                              1)) > 1

    if args.distributed:
        if args.local_rank == -1:
            # https://github.com/kubeflow/pytorch-operator/issues/128
            # https://github.com/pytorch/examples/blob/master/imagenet/main.py
            logger.info("Setting local rank to RANK env variable")
            args.local_rank = int(os.environ['RANK'])
        logger.warning("Local rank (%d)", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    if args.train_file:
        dataset = {
            'train_file': args.train_file,
            'valid_file': args.valid_file
        }
    else:
        dataset = DataDownloader(DATASETS[args.dataset_key],
                                 args.dataset_cache).download()
    reader = create_reader(args.tokens, args.nctx, args.chars_per_word)

    preproc_data = load_embed_and_vocab(args.tokens, reader, dataset,
                                        args.dataset_key, args.d_model,
                                        args.cache_features)

    vocabs = preproc_data['vocabs']
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs['x'], os.path.join(args.basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    valid_num_words = preproc_data['valid_num_words']
    tgt_key = preproc_data['tgt_key']
    logger.info("Loaded embeddings")

    train_set = load_data(args.tokens, reader, dataset, 'train_file', vocabs,
                          args.cache_features)
    valid_set = load_data(args.tokens, reader, dataset, 'valid_file', vocabs,
                          args.cache_features)
    logger.info("valid. tokens [%s], valid. words [%s]",
                valid_set.tensors[-1].numel(), valid_num_words)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set) if args.distributed else None
    train_loader = DataLoader(train_set,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              shuffle=(not args.distributed))

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_set) if args.distributed else None
    valid_loader = DataLoader(valid_set,
                              sampler=valid_sampler,
                              batch_size=args.batch_size,
                              shuffle=False)

    logger.info("Loaded datasets")

    model = TransformerLanguageModel.create(
        embeddings,
        hsz=args.d_model,
        d_ff=args.d_ff,
        tie_weights=(args.tokens != 'chars'),
        dropout=args.dropout,
        gpu=False,
        num_heads=args.num_heads,
        layers=args.num_layers,
        src_keys=['x'],
        tgt_key=tgt_key)
    model.to(args.device)
    train_loss = model.create_loss()
    train_loss.to(args.device)

    logger.info("Loaded model and loss")

    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    logger.info("Model has %s parameters",
                sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        logger.info("Model located on %d", args.local_rank)

    def update(engine, batch):
        model.train()
        x, y = batch
        inputs = {'x': x.to(args.device)}
        labels = y.to(args.device).transpose(0, 1).contiguous()
        logits = model(inputs, None)[0].transpose(0, 1).contiguous()
        shift_logits = logits[:-1]
        shift_labels = labels[1:]
        loss = train_loss(shift_logits, shift_labels)
        loss = loss / args.accum_grad_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        if engine.state.iteration % args.accum_grad_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    def inference(_, batch):
        model.eval()
        with torch.no_grad():
            x, y = batch
            inputs = {'x': x.to(args.device)}
            labels = y.to(args.device).transpose(0, 1).contiguous()
            logits = model(inputs, None)[0].transpose(0, 1).contiguous()
            shift_logits = logits[:-1]
            shift_labels = labels[1:]
            return shift_logits.view(-1,
                                     logits.size(-1)), shift_labels.view(-1)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and every 'eval_every' iterations if needed
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(valid_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            lambda engine: evaluator.run(valid_loader)
            if engine.state.iteration % args.eval_every == 0 else None)
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0,
                                             len(train_loader) * args.epochs)
    scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr,
                                                args.warmup_steps)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })

    if args.tokens == 'subwords':
        # If training on subwords, renormalize the NLL by the number of words to report word-level perplexity
        metrics["average_subword_ppl"] = MetricsLambda(math.exp,
                                                       metrics["average_nll"])
        metrics["average_word_ppl"] = MetricsLambda(
            lambda x: math.exp(x * valid_set.tensors[-1].numel() /
                               valid_num_words), metrics["average_nll"])
    else:
        metrics["average_word_ppl"] = MetricsLambda(math.exp,
                                                    metrics["average_nll"])

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    if args.local_rank < 1:
        RunningAverage(output_transform=lambda x: x).attach(
            trainer, "valid_loss")
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, lambda _: print(
                "Epoch[{}] Training Loss: {:.2f}, Perplexity {:.2f}".format(
                    trainer.state.epoch, trainer.state.output,
                    np.exp(trainer.state.output))))
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: print("Validation: %s" % pformat(
                evaluator.state.metrics)))
        checkpoint_handler = ModelCheckpoint(args.basedir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3,
                                             create_dir=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})
    trainer.run(train_loader, max_epochs=args.epochs)
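
To make the subword-to-word perplexity renormalization above concrete: the average NLL is computed per subword token, so multiplying it by the ratio of subword tokens to whitespace words before exponentiating gives an (approximate) word-level perplexity. A small worked example with assumed counts:

import math

avg_subword_nll = 3.2         # assumed average per-subword NLL from the evaluator
num_subword_tokens = 120_000  # assumed valid_set.tensors[-1].numel()
num_words = 100_000           # assumed valid_num_words

subword_ppl = math.exp(avg_subword_nll)                                # ~24.5
word_ppl = math.exp(avg_subword_nll * num_subword_tokens / num_words)  # ~46.5, matching the MetricsLambda above
print(round(subword_ppl, 1), round(word_ppl, 1))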