Example no. 1
def main(args):
    # preprocessor = Preprocessor(args.embedding_path)
    # train, valid = preprocessor.get_train_valid_dataset(args.data_path)

    with open(args.pickle_path, 'rb') as f:
        data = pickle.load(f)
        preprocessor = data['preprocessor']
        train, valid = data['train'], data['valid']

    if args.arch == 'NSMv1':
        from torch_solver import TorchSolver
        solver = TorchSolver(preprocessor.get_word_dim(),
                             args.dim_hidden,
                             valid=valid,
                             batch_size=args.batch_size,
                             n_epochs=args.n_epochs,
                             learning_rate=args.learning_rate,
                             device=args.device,
                             decoder_use_state=args.decoder_use_state)
    else:
        # without this guard, `solver` is undefined below for unknown archs
        raise ValueError('Unknown arch: {}'.format(args.arch))

    # load model
    if args.load is not None:
        solver.load(args.load)

    if not args.five_fold:
        model_checkpoint = ModelCheckpoint(args.model_path, 'loss', 1, 'all')
        metrics_logger = MetricsLogger(args.log_path)
        solver.fit_dataset(train, [model_checkpoint, metrics_logger])
    else:
        from utils import MWPDataset
        problems = train._problems
        fold_size = int(len(problems) * 0.2)
        # the last boundary is len(problems) so the final fold keeps any
        # remainder left over from the integer division
        fold_indices = [fold_size * i for i in range(5)] + [len(problems)]
        for fold in range(5):
            train = []
            for j in range(5):
                if j != fold:
                    start = fold_indices[j]
                    end = fold_indices[j + 1]
                    train += problems[start:end]

            transform = \
                PermuteStackOps(args.revert_prob, args.transpose_prob) \
                if args.permute else None
            train = MWPDataset(train, preprocessor.indices_to_embeddings)
            logging.info('Start training fold {}'.format(fold))
            model_checkpoint = ModelCheckpoint(
                '{}.fold{}'.format(args.model_path, fold), 'loss', 1, 'all')
            metrics_logger = MetricsLogger('{}.fold{}'.format(
                args.log_path, fold))
            solver = TorchSolver(preprocessor.get_word_dim(),
                                 args.dim_hidden,
                                 valid=valid,
                                 batch_size=args.batch_size,
                                 n_epochs=args.n_epochs,
                                 learning_rate=args.learning_rate,
                                 device=args.device)
            solver.fit_dataset(train, [model_checkpoint, metrics_logger])
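main() above consumes a number of argparse flags. A hedged sketch of a compatible entry point follows: every flag name is taken from the code, but the types and defaults are assumptions, not the project's actual CLI.

import argparse

# Hypothetical entry point; flag names come from main() above, while the
# types and defaults here are assumptions.
parser = argparse.ArgumentParser()
parser.add_argument('pickle_path')
parser.add_argument('model_path')
parser.add_argument('log_path')
parser.add_argument('--arch', default='NSMv1')
parser.add_argument('--dim_hidden', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--n_epochs', type=int, default=30)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--device', default='cuda')
parser.add_argument('--decoder_use_state', action='store_true')
parser.add_argument('--load', default=None)
parser.add_argument('--five_fold', action='store_true')
parser.add_argument('--permute', action='store_true')
parser.add_argument('--revert_prob', type=float, default=0.0)
parser.add_argument('--transpose_prob', type=float, default=0.0)
main(parser.parse_args())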
Example no. 2
def main(args):
    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    logging.info('loading embedding...')
    with open(config['model']['embedding'], 'rb') as f:
        embedding = pickle.load(f)
        config['model']['embedding'] = embedding.vectors

    logging.info('loading valid data...')
    with open(config['model']['valid'], 'rb') as f:
        config['model']['valid'] = pickle.load(f)

    logging.info('loading train data...')
    with open(config['train'], 'rb') as f:
        train = pickle.load(f)

    predictor = Predictor(arch=config['arch'],
                          device=args.device,
                          metrics=[Recall()],
                          **config['model'])

    if args.load is not None:
        predictor.load(args.load)

    model_checkpoint = ModelCheckpoint(
        os.path.join(args.model_dir, 'model.pkl'), **config['callbacks'])
    metrics_logger = MetricsLogger(os.path.join(args.model_dir, 'log.json'))

    logging.info('start training!')
    predictor.fit_dataset(train, train.collate_fn, model_checkpoint,
                          metrics_logger)
Example no. 3
def main(args):
    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    logging.info('loading valid data...')
    with open(config['model']['valid'], 'rb') as f:
        config['model']['valid'] = pickle.load(f)

    logging.info('loading character vocabulary...')
    with open(config['charmap'], 'rb') as f:
        charmap = pickle.load(f)
    logging.info('loading word vocabulary...')
    with open(config['wordmap'], 'rb') as f:
        wordmap = pickle.load(f)
    config['model']['num_embeddings'] = len(charmap)
    config['model']['padding_idx'] = charmap['<PAD>']
    config['model']['vocab_size'] = len(wordmap)

    predictor = Predictor(arch=config['arch'],
                          device=args.device,
                          metrics=[Perplexity()],
                          **config['model'])

    if args.load is not None:
        predictor.load(args.load)

    model_checkpoint = ModelCheckpoint(
        os.path.join(args.model_dir, 'model.pkl'), **config['callbacks'])
    metrics_logger = MetricsLogger(os.path.join(args.model_dir, 'log.json'))

    logging.info('start training!')
    predictor.fit_dataset(config['train'], model_checkpoint, metrics_logger)
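The character and word maps above are plain pickled lookup tables: `len(charmap)` supplies `num_embeddings` and `charmap['<PAD>']` the padding index. How such a map is built is not shown in the source; a minimal sketch assuming a plain dict vocabulary:

import pickle

# Hypothetical charmap construction compatible with the lookups above:
# len(charmap) -> num_embeddings, charmap['<PAD>'] -> padding_idx.
corpus = ['hello world', 'training data']
charmap = {'<PAD>': 0, '<UNK>': 1}
for ch in sorted({c for line in corpus for c in line}):
    charmap.setdefault(ch, len(charmap))

with open('charmap.pkl', 'wb') as f:
    pickle.dump(charmap, f)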
Example no. 4
def main(args):
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    model_dir, exp_dir = os.path.split(args.output_dir.rstrip('/'))
    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)
    logging.info(f'Save config file to {args.output_dir}.')
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)

    logging.info('Loading training data...')
    with open(config['train_path'], 'rb') as f:
        train = pickle.load(f)
    train.context_padded_len = config['train_context_padded_len']
    train.option_padded_len = config['train_option_padded_len']
    train.n_negative = config['train_n_negative']

    logging.info('Loading validation data...')
    with open(config['valid_path'], 'rb') as f:
        valid = pickle.load(f)
        config['model_parameters']['valid'] = valid
    valid.context_padded_len = config['valid_context_padded_len']
    valid.option_padded_len = config['valid_option_padded_len']

    logging.info('Loading preprocessed word embedding...')
    with open(config['embedding_path'], 'rb') as f:
        embedding = pickle.load(f)
        config['model_parameters']['embedding'] = embedding

    metric = Recall(at=10)
    predictor = Predictor(training=True,
                          metrics=[metric],
                          device=args.device,
                          **config['model_parameters'])
    model_checkpoint = ModelCheckpoint(filepath=os.path.join(
        args.output_dir, 'model'),
                                       monitor=metric.name,
                                       mode='max',
                                       all_saved=False)
    metrics_logger = MetricsLogger(
        log_dest=os.path.join(args.output_dir, 'log.json'))

    if args.load_dir is not None:
        predictor.load(args.load_dir)

    logging.info('Start training.')
    start = time.time()
    predictor.fit_dataset(data=train,
                          collate_fn=train.collate_fn,
                          callbacks=[model_checkpoint, metrics_logger],
                          output_dir=args.output_dir)
    end = time.time()
    total = end - start
    hrs = int(total // 3600)
    mins = int((total % 3600) // 60)
    secs = int(total % 60)
    logging.info('End training.')
    logging.info(f'Total time: {hrs}hrs {mins}mins {secs}secs.')
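The manual hours/minutes/seconds arithmetic above can also be written with the standard library; a small equivalent sketch:

import datetime

total = 4523.7  # elapsed seconds, i.e. end - start
elapsed = datetime.timedelta(seconds=int(total))
print(f'Total time: {elapsed}')  # -> Total time: 1:15:23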
Example no. 5
def main(args):
    print(args)
    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    logging.info('loading embedding...')
    with open(config['model_parameters']['embedding'], 'rb') as f:
        embedding = pickle.load(f)
        config['model_parameters']['embedding'] = embedding.vectors

    logging.info('loading valid data...')
    with open(config['model_parameters']['valid'], 'rb') as f:
        config['model_parameters']['valid'] = pickle.load(f)

    logging.info('loading train data...')
    with open(config['train'], 'rb') as f:
        train = pickle.load(f)

    if config['arch'] == 'ExampleNet':
        from example_predictor import ExamplePredictor
        PredictorClass = ExamplePredictor
    else:
        # without this guard, PredictorClass is undefined for unknown archs
        raise ValueError('Unknown arch: {}'.format(config['arch']))
    predictor = PredictorClass(metrics=[Recall(1), Recall(10)],
                               **config['model_parameters'])

    if args.load is not None:
        predictor.load(args.load)

    model_checkpoint = ModelCheckpoint(
        os.path.join(args.model_dir, 'model.pkl'), 'loss', 1, 'all')
    metrics_logger = MetricsLogger(os.path.join(args.model_dir, 'log.json'))

    logging.info('start training!')
    predictor.fit_dataset(train, train.collate_fn,
                          [model_checkpoint, metrics_logger])
Example no. 6
def main(args):
    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    logging.info('loading embedding...')
    with open(config['model_parameters']['embedding'], 'rb') as f:
        embedding = pickle.load(f)
        config['model_parameters']['embedding'] = embedding.vectors

    logging.info('loading valid data...')
    with open(config['model_parameters']['valid'], 'rb') as f:
        valid = pickle.load(f)
    logging.info('loading train data...')
    with open(config['train'], 'rb') as f:
        train = pickle.load(f)

    PredictorClass = BestPredictor

    predictor = PredictorClass(metrics=[Recall(at=1),
                                        Recall(at=5)],
                               **config['model_parameters'])

    model_checkpoint = ModelCheckpoint(
        os.path.join(args.model_dir, 'model.pkl'), 'loss', 1, 'all')
    metrics_logger = MetricsLogger(os.path.join(args.model_dir, 'log.json'))
    logging.info('start training!')
    predictor.fit_dataset(train, valid, train.collate_fn,
                          [model_checkpoint, metrics_logger])
Example no. 7
def main(args):

    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    logging.info(f"Using cuda device: {config['cuda_ids']}")
    os.environ['CUDA_VISIBLE_DEVICES'] = config['cuda_ids']

    logging.info('loading embedding...')
    with open(config['model_parameters']['embedding'], 'rb') as f:
        embedding = pickle.load(f)
        config['model_parameters']['embedding'] = embedding.vectors

    logging.info('loading valid data...')
    with open(config['model_parameters']['valid'], 'rb') as f:
        config['model_parameters']['valid'] = pickle.load(f)

    logging.info('loading train data...')
    with open(config['train'], 'rb') as f:
        train = pickle.load(f)

    if config['arch'] == 'ExampleNet':
        from example_predictor import ExamplePredictor
        PredictorClass = ExamplePredictor
    elif config['arch'] == 'RnnBaselineNet':
        from rnnbaseline_predictor import RnnBaselinePredictor
        PredictorClass = RnnBaselinePredictor
    elif config['arch'] == 'RnnAttentionNet':
        from rnnattention_predictor import RnnAttentionPredictor
        PredictorClass = RnnAttentionPredictor
    elif config['arch'] == 'RnnTransformerNet':
        from rnntransformer_predictor import RnnTransformerPredictor
        PredictorClass = RnnTransformerPredictor
    else:
        # without this guard, PredictorClass is undefined for unknown archs
        raise ValueError('Unknown arch: {}'.format(config['arch']))

    predictor = PredictorClass(
        metrics=[Recall(), Recall(1), Recall(5)],
        **config['model_parameters']
    )

    if args.load is not None:
        predictor.load(args.load)

    model_checkpoint = ModelCheckpoint(
        os.path.join(args.model_dir, 'model.pkl'),
        'loss', 1, 'all'
    )
    metrics_logger = MetricsLogger(
        os.path.join(args.model_dir, 'log.json')
    )

    logging.info('start training!')
    predictor.fit_dataset(train,
                          train.collate_fn,
                          [model_checkpoint, metrics_logger])
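The if/elif chain mapping `config['arch']` to a predictor class can be collapsed into a lookup table. A sketch using the module and class names from the branches above (it assumes those modules are importable in the project):

import importlib

# (module, class) pairs taken from the branches above.
ARCH_REGISTRY = {
    'ExampleNet': ('example_predictor', 'ExamplePredictor'),
    'RnnBaselineNet': ('rnnbaseline_predictor', 'RnnBaselinePredictor'),
    'RnnAttentionNet': ('rnnattention_predictor', 'RnnAttentionPredictor'),
    'RnnTransformerNet': ('rnntransformer_predictor', 'RnnTransformerPredictor'),
}

def resolve_predictor(arch):
    try:
        module_name, class_name = ARCH_REGISTRY[arch]
    except KeyError:
        raise ValueError('Unknown arch: {}'.format(arch))
    return getattr(importlib.import_module(module_name), class_name)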
Example no. 8
def main(args):
    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    logging.info('loading embedding...')
    with open(config['model_parameters']['embedding'], 'rb') as f:
        embedding = pickle.load(f)
        config['model_parameters']['embedding'] = embedding.vectors

    logging.info('loading valid data...')
    with open(config['model_parameters']['valid'], 'rb') as f:
        config['model_parameters']['valid'] = pickle.load(f)

    logging.info('loading train data...')
    with open(config['train'], 'rb') as f:
        train = pickle.load(f)

    if config['arch'] == 'ExampleNet':
        # alternatives tried during development: ExamplePredictor,
        # RNNPredictor, BestRNNAttPredictor; the RNN-with-attention
        # variant is the one in use
        from rnnatt_predictor import RNNAttPredictor
        PredictorClass = RNNAttPredictor
    else:
        raise ValueError('Unknown arch: {}'.format(config['arch']))

    #print("config['model_parameters']: ", config['model_parameters']) #it's a dict; {'valid': dataset.DialogDataset object, 'embedding':a big tensor}
    #print("**config['model_parameters']: ", **config['model_parameters'])
    predictor = PredictorClass(metrics=[Recall()],
                               **config['model_parameters'])
    if args.load is not None:
        predictor.load(args.load)

    model_checkpoint = ModelCheckpoint(
        os.path.join(args.model_dir,
                     'model_rnnatt_6_negative_samples_0324.pkl'), 'loss', 1,
        'all')
    metrics_logger = MetricsLogger(
        os.path.join(args.model_dir,
                     'log_rnnatt_6_negative_samples_0324.json'))

    logging.info('start training!')
    predictor.fit_dataset(train, train.collate_fn,
                          [model_checkpoint, metrics_logger])
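The `**config['model_parameters']` idiom used throughout these examples unpacks a dict into keyword arguments; a two-line illustration with a hypothetical function:

def build(valid=None, embedding=None):
    return valid, embedding

params = {'valid': 'valid_dataset', 'embedding': 'embedding_tensor'}
# build(**params) == build(valid='valid_dataset', embedding='embedding_tensor')
print(build(**params))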
Example no. 9
def main(args, config_path):
    logging.info('Loading configuration file from {}'.format(config_path))
    with open(config_path) as f:
        config = json.load(f)

    embedding_pkl_path = os.path.join(args.model_dir, config["embedding_pkl_path"])
    train_pkl_path = os.path.join(args.model_dir, config["train_pkl_path"])
    val_pkl_path = os.path.join(args.model_dir, config["val_pkl_path"])
    labelEncoder_path = os.path.join(args.model_dir, config["labelEncoder_path"])
    with open(embedding_pkl_path, "rb") as f:
        config["model_parameters"]["embedding"] = pickle.load(f).vectors
        logging.info("Loaded embedding from {}".format(embedding_pkl_path))
    with open(train_pkl_path, "rb") as f:
        train = pickle.load(f)
        logging.info("Loaded train data from {}".format(train_pkl_path))
    with open(val_pkl_path, "rb") as f:
        config["model_parameters"]["valid"] = pickle.load(f)
        logging.info("Loaded validation data from {}".format(val_pkl_path))
    with open(labelEncoder_path, "rb") as f:
        config["model_parameters"]["labelEncoder"] = pickle.load(f)
        logging.info("Loaded labelEncoder from {}".format(labelEncoder_path))


    predictor = Predictor(metric=Metric(), **config["model_parameters"])

    if args.load is not None:
        predictor.load(args.load)

    model_checkpoint = ModelCheckpoint(
            os.path.join(args.model_dir, 'model.pkl'),
            'loss', 1, 'all')
    metrics_logger = MetricsLogger(
            os.path.join(args.model_dir, 'log.json'))

    tensorboard = Tensorboard(config["tensorboard"])

    logging.info("start training!")
    predictor.fit_dataset(train, train.collate_fn, [model_checkpoint, metrics_logger, tensorboard])
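This variant additionally loads a pickled `labelEncoder`. The source does not say what class it is; assuming a scikit-learn `LabelEncoder`, producing a compatible pickle would look like:

import pickle
from sklearn.preprocessing import LabelEncoder  # assumption: sklearn encoder

labels = ['greeting', 'farewell', 'question', 'greeting']
encoder = LabelEncoder().fit(labels)

with open('labelEncoder.pkl', 'wb') as f:
    pickle.dump(encoder, f)

print(encoder.transform(['question']))  # -> [2] (classes are sorted)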
Example no. 10
def main(args):
    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    logging.info('loading word dictionary...')
    with open(config['words_dict'], 'rb') as f:
        words_dict = pickle.load(f)

    logging.info('loading train data...')
    with open(config['train'], 'rb') as f:
        train = pickle.load(f)

    logging.info('loading validation data...')
    with open(config['model_parameters']['valid'], 'rb') as f:
        valid = pickle.load(f)
    config['model_parameters']['valid'] = valid

    if not args.lr_finder:
        if config['arch'] == 'Predictor':
            from predictor import Predictor
            PredictorClass = Predictor
        else:
            raise ValueError('Unknown arch: {}'.format(config['arch']))

        predictor = PredictorClass(metrics=[Accuracy()],
                                   word_dict=words_dict,
                                   **config['model_parameters'])

        metrics_logger = MetricsLogger(os.path.join(args.model_dir,
                                                    'log.json'))

        if args.load is not None:
            predictor.load(args.load)
            try:
                metrics_logger.load(int(args.load.split('.')[-1]))
            except (ValueError, IndexError):
                # checkpoint name carries no usable epoch suffix
                metrics_logger.load(448)

        model_checkpoint = ModelCheckpoint(
            os.path.join(args.model_dir, 'model.pkl'), 'Accuracy', 1, 'max')

        logging.info('start training!')
        predictor.fit_dataset(train, train.collate_fn,
                              [model_checkpoint, metrics_logger])
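Resuming above assumes checkpoint files named like `model.pkl.12`, with the epoch as the final dotted suffix (the `448` fallback is project-specific). A standalone sketch of that parsing with a hypothetical helper:

def epoch_from_checkpoint(path, fallback=None):
    """Parse a trailing '.N' epoch suffix from a checkpoint path."""
    suffix = path.rsplit('.', 1)[-1]
    return int(suffix) if suffix.isdigit() else fallback

print(epoch_from_checkpoint('models/model.pkl.12'))   # -> 12
print(epoch_from_checkpoint('models/model.pkl', -1))  # -> -1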
Example no. 11
def main(args):
    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    # logging.info('loading embedding...')
    # with open(config['model_parameters']['embeddings'], 'rb') as f:
    #     embeddings = pickle.load(f)
    #     config['model_parameters']['embeddings'] = embeddings

    logging.info('loading dev data...')
    with open(config['model_parameters']['valid'], 'rb') as f:
        config['model_parameters']['valid'] = pickle.load(f)

    logging.info('loading train data...')
    with open(config['train'], 'rb') as f:
        train = pickle.load(f)

    if 'train_max_len' in config:
        train.data = list(
            filter(lambda s: s['context_len'] < config['train_max_len'],
                   train.data)
        )

    if config['arch'] == 'BiDAF':
        from bidaf_predictor import BiDAFPredictor
        PredictorClass = BiDAFPredictor
    elif config['arch'] == 'QANet':
        from qanet_predictor import QANetPredictor
        PredictorClass = QANetPredictor
    elif config['arch'] == 'BERT':
        from bert_predictor import BERTPredictor
        PredictorClass = BERTPredictor

    if config['arch'] != 'XLNet':
        predictor = PredictorClass(
            metrics=[SimpleEM(), QuACF1()],
            **config['model_parameters']
        )
    else:
        from bert_predictor import BERTPredictor
        predictor = BERTPredictor(
            metrics=[SimpleEM(), QuACF1()],
            ctx_emb='xlnet',
            **config['model_parameters']
        )


    if args.load is not None:
        predictor.load(args.load)

    model_checkpoint = ModelCheckpoint(
        os.path.join(args.model_dir, 'model.pkl'),
        'loss', 1, 'all'
    )
    metrics_logger = MetricsLogger(
        os.path.join(args.model_dir, 'log.json')
    )

    logging.info('start training!')
    predictor.fit_dataset(train,
                          train.collate_fn,
                          [model_checkpoint, metrics_logger])
Example no. 12
    def __init__(self, args):
        logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s',
                            level=logging.INFO,
                            datefmt='%Y-%m-%d %H:%M:%S')

        logging.info('Initiating task: %s' % args.taskname)

        self.config = config(args)
        if not all([os.path.isfile(i) for i in self.config.pickle_files]):
            logging.info('Preprocessing data...')
            if args.pick == 'neg4':
                build_processed_data(self.config.datadir,
                                     self.config.pickledir,
                                     neg_num=4)
            elif args.pick == 'last':
                build_processed_data(self.config.datadir,
                                     self.config.pickledir,
                                     last=True)
            elif args.pick == 'difemb':
                build_processed_data(self.config.datadir,
                                     self.config.pickledir,
                                     difemb=True)

            else:
                build_processed_data(self.config.datadir,
                                     self.config.pickledir)

        else:
            logging.info('Preprocessing already done.')

        with open(os.path.join(self.config.pickledir, 'embedding.pkl'),
                  'rb') as f:
            embedding = pickle.load(f)
            embedding = embedding.vectors
        self.embedding_dim = embedding.size(1)
        self.embedding = torch.nn.Embedding(embedding.size(0),
                                            embedding.size(1))
        self.embedding.weight = torch.nn.Parameter(embedding)

        self.Modelfunc = {
            'lin': LinearNet,
            'rnn': RnnNet,
            'att': RnnAttentionNet,
            'best': BestNet,
            'gru': GruNet,
            'last': LastNet,
        }

        if os.path.exists(self.config.outputdir):
            if not args.resume:
                logging.info(
                    'Task already exists; pass --resume True to continue. Exiting.')
                sys.exit(0)
            else:
                logging.info('Resuming....')
                with open(self.config.modeltype_path, 'r') as f:
                    resume_type = f.read()
                self.model = self.Modelfunc[resume_type]
                logging.info('model type is %s, model to be constructed' %
                             resume_type)
        else:
            os.mkdir(self.config.outputdir)
            with open(self.config.modeltype_path, 'w') as f:
                f.write(args.modeltype)
            self.model = self.Modelfunc[args.modeltype](self.embedding_dim)
            logging.info('model type is %s, model created' % args.modeltype)

        model_checkpoint = ModelCheckpoint(self.config.modelpath, 'loss', 1,
                                           'all')
        metrics_logger = MetricsLogger(self.config.logpath)

        if args.resume:
            self.config.start_epoch = metrics_logger.load()
            if args.resume_epoch != -1:
                self.config.resumepath = self.config.modelpath + '.%d' % args.resume_epoch
            else:
                self.config.resumepath = self.config.modelpath + '.%d' % (
                    self.config.start_epoch - 1)

            self.model = self.model(self.embedding_dim)
            self.model.load_state_dict(torch.load(self.config.resumepath))
            logging.info('config loaded, model constructed and loaded')

        print(self.model)

        logging.info('loading dataloaders')
        self.trainloader, self.testloader, self.validloader = make_dataloader(
            self.config.pickledir)

        self.metrics = [Recall()]
        self.callbacks = [model_checkpoint, metrics_logger]

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.device == 'cuda':
            self.model.to(self.device)
            self.embedding.to(self.device)

        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.config.lr)
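The three-step embedding setup above (construct `nn.Embedding`, then overwrite its weight with an `nn.Parameter`) can be done in one call; a sketch using PyTorch's `Embedding.from_pretrained`, where `freeze=False` keeps the weights trainable, matching the behavior above:

import torch

vectors = torch.randn(1000, 300)  # stand-in for embedding.vectors
embedding_layer = torch.nn.Embedding.from_pretrained(vectors, freeze=False)
print(embedding_layer.weight.shape)  # torch.Size([1000, 300])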
Example no. 13
def main(args):
    config_path = os.path.join(args.model_dir, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    logging.info('loading embedding...')
    with open(config['model_parameters']['embedding'], 'rb') as f:
        embedding = pickle.load(f)
        config['model_parameters']['embedding'] = embedding.vectors

    logging.info('loading valid data...')
    with open(config['model_parameters']['valid'], 'rb') as f:
        config['model_parameters']['valid'] = pickle.load(f)
        #valid = pickle.load(f)

    logging.info('loading train data...')
    with open(config['train'], 'rb') as f:
        train = pickle.load(f)

    if config['arch'] == 'ExampleNet':
        #from modules import ExampleNet
        from predictors import ExamplePredictor
        PredictorClass = ExamplePredictor
        predictor = PredictorClass(
            metrics=[Recall(1), Recall(10)],
            batch_size=128,
            max_epochs=1000000,
            dropout_rate=0.2,
            learning_rate=1e-3,
            grad_accumulate_steps=1,
            loss='BCELoss',  #BCELoss, FocalLoss
            margin=0,
            threshold=None,
            similarity='MLP',  #inner_product, Cosine, MLP
            device=args.device,
            **config['model_parameters'])
    elif config['arch'] == 'RnnNet':
        from predictors import RnnPredictor
        PredictorClass = RnnPredictor
        predictor = PredictorClass(
            metrics=[Recall(1), Recall(10)],
            batch_size=512,
            max_epochs=1000000,
            dropout_rate=0.2,
            learning_rate=1e-3,
            grad_accumulate_steps=1,
            loss='FocalLoss',  #BCELoss, FocalLoss
            margin=0,
            threshold=None,
            similarity='Cosine',  #inner_product, Cosine, MLP
            device=args.device,
            **config['model_parameters'])

    elif config['arch'] == 'RnnAttentionNet':
        from predictors import RnnAttentionPredictor
        PredictorClass = RnnAttentionPredictor
        predictor = PredictorClass(
            metrics=[Recall(1), Recall(10)],
            batch_size=32,
            max_epochs=1000000,
            dropout_rate=0.2,
            learning_rate=1e-3,
            grad_accumulate_steps=1,
            loss='BCELoss',  #BCELoss, FocalLoss
            margin=0,
            threshold=None,
            similarity='MLP',  #inner_product, Cosine, MLP
            device=args.device,
            **config['model_parameters'])
    else:
        # a warning alone would let execution continue and crash on the
        # undefined `predictor` below
        raise ValueError('Unknown config["arch"]: {}'.format(config['arch']))

    if args.load is not None:
        predictor.load(args.load)

    #def ModelCheckpoint(filepath, monitor='loss', verbose=0, mode='min')
    model_checkpoint = ModelCheckpoint(os.path.join(args.model_dir,
                                                    'model.pkl'),
                                       monitor='Recall@{}'.format(10),
                                       verbose=1,
                                       mode='all')
    metrics_logger = MetricsLogger(os.path.join(args.model_dir, 'log.json'))

    early_stopping = EarlyStopping(os.path.join(args.model_dir, 'model.pkl'),
                                   monitor='Recall@{}'.format(10),
                                   verbose=1,
                                   mode='max',
                                   patience=30)

    logging.info('start training!')
    predictor.fit_dataset(train, train.collate_fn,
                          [model_checkpoint, metrics_logger, early_stopping])
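`EarlyStopping` here is a project-local callback whose real interface is not shown. A minimal sketch of what a patience-based variant might look like, assuming a hypothetical `on_epoch_end(metrics)` hook that receives a per-epoch metrics dict:

class SimpleEarlyStopping:
    # Hypothetical sketch: stop when `monitor` has not improved for
    # `patience` consecutive epochs.
    def __init__(self, monitor='Recall@10', mode='max', patience=30):
        self.monitor = monitor
        self.sign = 1 if mode == 'max' else -1
        self.patience = patience
        self.best = None
        self.bad_epochs = 0

    def on_epoch_end(self, metrics):
        value = self.sign * metrics[self.monitor]
        if self.best is None or value > self.best:
            self.best, self.bad_epochs = value, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience  # True -> caller should stop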