Example #1
def get_model():
    # EnsembleModel INFO
    # 1. Model_A : Normal CNN Model (Filter Size=3)
    # 2. Model_B : Normal CNN Model (Filter Size=5)
    # 3. Model_C : Custom ResNet Model (Filter Size=3)
    # 4. Model_D : Custom ResNet Model (Filter Size=5)
    # 5. Model_E : Custom DenseNet Model (Filter Size=3)
    Model_A = ModelA(image_size=128,
                     in_channel=3,
                     hidden_channels=[32, 64, 128],
                     output_channel=6)
    Model_B = ModelB(image_size=128,
                     in_channel=3,
                     hidden_channels=[32, 64, 128],
                     output_channel=6)
    Model_C = ModelC(image_size=128,
                     in_channel=3,
                     block_channels=[64, 128, 256],
                     output_channel=6)
    Model_D = ModelD(image_size=128,
                     in_channel=3,
                     block_channels=[64, 128, 256],
                     output_channel=6)
    Model_E = ModelE(image_size=128,
                     in_channel=3,
                     hidden_channels=[32, 64, 128],
                     out_channel=6)
    Model = EnsembleModel(Model_A, Model_B, Model_C, Model_D, Model_E)
    return Model
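
The EnsembleModel class itself is not defined in these examples. A minimal sketch of what the wrapper used above could look like, assuming it simply averages the logits of its submodels (the variadic constructor matches the call above; the averaging is an assumption):

import torch
import torch.nn as nn

class EnsembleModel(nn.Module):
    """Hypothetical wrapper that averages the logits of its submodels."""

    def __init__(self, *models):
        super().__init__()
        self.models = nn.ModuleList(models)

    def forward(self, x):
        # Every submodel maps the same input to logits of identical shape
        # (here: 6 classes per image), so the outputs can be averaged directly.
        outputs = torch.stack([m(x) for m in self.models], dim=0)
        return outputs.mean(dim=0)
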
Example #2
def get_Model():
    Model_A = LinearModel(in_channel=28 * 28, linear_channel=[256, 64, 32, 10])
    Model_B = CnnModel(image_size=[28, 28],
                       in_channel=1,
                       hidden_channel=[32, 64, 32, 10])
    Model_C = CustomModel(image_size=[28, 28],
                          in_channel=1,
                          block1_channel=[64, 32, 64],
                          block2_channel=[64, 32, 64],
                          out_channel=10)
    Model = EnsembleModel(Model_A, Model_B, Model_C)
    return Model
Example #3
def test(args):
    if args.load_var:
        test_utterances, test_labels, word_dict = read_data(
            load_var=args.load_var, input_=None, mode='test')
    else:
        test_utterances, test_labels, word_dict = read_data(
            load_var=args.load_var,
            input_=os.path.join(constant.data_path,
                                "entangled_{}.json".format(args.mode)),
            mode='test')

    if args.save_input:
        utils.save_or_read_input(
            os.path.join(constant.save_input_path,
                         "{}_utterances.pk".format(args.mode)),
            rw='w', input_obj=test_utterances)
        utils.save_or_read_input(
            os.path.join(constant.save_input_path,
                         "{}_labels.pk".format(args.mode)),
            rw='w', input_obj=test_labels)

    current_time = re.findall(r'.*model_(.+?)/.*', args.model_path)[0]
    step_cnt = re.findall(r'.step_(.+?)\.pkl', args.model_path)[0]

    test_dataloader = TrainDataLoader(test_utterances,
                                      test_labels,
                                      word_dict,
                                      name='test',
                                      batch_size=4)

    ensemble_model = EnsembleModel(word_dict,
                                   word_emb=None,
                                   bidirectional=False)
    if torch.cuda.is_available():
        ensemble_model.cuda()

    supervised_trainer = SupervisedTrainer(args,
                                           ensemble_model,
                                           current_time=current_time)

    supervised_trainer.test(test_dataloader,
                            args.model_path,
                            step_cnt=step_cnt)
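
The two re.findall calls above assume a checkpoint path of the form .../model_<timestamp>/...step_<n>.pkl. A quick illustration with a hypothetical path:

import re

path = 'save/model_2021-01-01_12:00:00/loss_0.1_step_100.pkl'  # hypothetical
print(re.findall(r'.*model_(.+?)/.*', path)[0])  # -> '2021-01-01_12:00:00'
print(re.findall(r'.step_(.+?)\.pkl', path)[0])  # -> '100'
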
Example #4

    # (fragment: fc1 and fc2 are feature calculators created earlier in the
    # original script and are not shown here)
    # normalized and there is no leakage.
    fc3 = DailyAggQuarterFeatures(columns=DAILY_AGG_COLUMNS,
                                  agg_day_counts=AGG_DAY_COUNTS,
                                  max_back_quarter=MAX_BACK_QUARTER)

    feature = FeatureMerger(fc1, fc2, on='ticker')
    feature = FeatureMerger(feature, fc3, on=['ticker', 'date'])

    target = QuarterlyTarget(col='marketcap', quarter_shift=0)

    base_models = [
        LogExpModel(lgbm.sklearn.LGBMRegressor()),
        LogExpModel(ctb.CatBoostRegressor(verbose=False))
    ]

    ensemble = EnsembleModel(base_models=base_models,
                             bagging_fraction=BAGGING_FRACTION,
                             model_cnt=MODEL_CNT)

    model = GroupedOOFModel(ensemble, group_column='ticker', fold_cnt=FOLD_CNT)

    pipeline = BasePipeline(feature=feature,
                            target=target,
                            model=model,
                            metric=median_absolute_relative_error,
                            out_name=OUT_NAME)

    result = pipeline.fit(data_loader, ticker_list)
    print(result)
    pipeline.export_core(SAVE_PATH)
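
median_absolute_relative_error is only referenced by name in this example. A plausible implementation, assuming the usual definition of the metric (a sketch, not the library's code):

import numpy as np

def median_absolute_relative_error(y_true, y_pred):
    # Median over samples of |y_true - y_pred| / |y_true|.
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return np.median(np.abs(y_true - y_pred) / np.abs(y_true))
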
Example #5
def test(args):
    """Run testing with the given args.

    The function consists of the following steps:
        1. Get model for evaluation.
        2. Get task sequence and class weights.
        3. Get data eval loaders and evaluator.
        4. Evaluate and save model performance (metrics and curves).
    """

    model_args = args.model_args
    logger_args = args.logger_args
    data_args = args.data_args
    transform_args = args.transform_args

    # Get model
    if args.use_multi_model:
        model = load_multi_model(args.multi, model_args, data_args,
                                 args.gpu_ids)
    elif model_args.use_csv_probs:
        model = CSVReaderModel(model_args.ckpt_path,
                               TASK_SEQUENCES[data_args.task_sequence])
        ckpt_info = {'epoch': 0}
    elif args.config_path is not None:
        task2models, aggregation_fn = get_config(args.config_path)
        model = EnsembleModel(task2models, aggregation_fn, args.gpu_ids,
                              model_args, data_args)
        ckpt_info = {'epoch': 0}
    elif model_args.ckpt_paths:
        model, ckpt_info = ModelSaver.load_ensemble(model_args.ckpt_paths,
                                                    args.gpu_ids, model_args,
                                                    data_args)
    else:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path,
                                                 args.gpu_ids, model_args,
                                                 data_args)

    model = model.to(args.device)
    model.eval()

    # Get the task sequence that the model outputs.
    # Newer models have an attribute called 'task_sequence'.
    # For older models we need to specify what
    # task sequence was used.
    # if hasattr(model.module, 'task_sequence'):
    #     task_sequence = model.module.task_sequence
    # else:
    #     task_sequence = TASK_SEQUENCES[data_args.task_sequence]
    #     print(f"WARNING: assuming that the model's task sequence is \n {task_sequence}")
    task_sequence = TASK_SEQUENCES[data_args.task_sequence]

    cxr_frac = {
        'pocus': data_args.eval_pocus,
        'hocus': data_args.eval_hocus,
        'pulm': data_args.eval_pulm
    }
    # Get train loader in order to get the class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path
    loader = get_loader(data_args,
                        args.transform_args,
                        train_csv_name,
                        task_sequence,
                        su_frac=1 if data_args.eval_su else 0,
                        nih_frac=1 if data_args.eval_nih else 0,
                        cxr_frac=cxr_frac,
                        tcga_frac=1 if data_args.eval_tcga else 0,
                        batch_size=args.batch_size,
                        covar_list=model_args.covar_list,
                        fold_num=data_args.fold_num)
    class_weights = loader.dataset.class_weights

    # Get eval loaders and radiologist performance
    eval_loader = get_eval_loaders(
        data_args,
        transform_args,
        task_sequence,
        args.batch_size,
        frontal_lateral=model_args.frontal_lateral,
        return_info_dict=model_args.use_csv_probs or logger_args.save_cams,
        covar_list=model_args.covar_list,
        fold_num=data_args.fold_num)[-1]  # Evaluate only on valid

    rad_perf = (pd.read_csv(data_args.su_rad_perf_path)
                if data_args.su_rad_perf_path is not None else None)

    if data_args.split != 'valid':
        eval_loader = get_loader(data_args,
                                 args.transform_args,
                                 data_args.split,
                                 task_sequence,
                                 su_frac=1 if data_args.eval_su else 0,
                                 nih_frac=1 if data_args.eval_nih else 0,
                                 cxr_frac=cxr_frac,
                                 tcga_frac=1 if data_args.eval_tcga else 0,
                                 batch_size=args.batch_size,
                                 covar_list=model_args.covar_list,
                                 fold_num=data_args.fold_num)

    results_dir = os.path.join(logger_args.results_dir, data_args.split)
    os.makedirs(results_dir, exist_ok=True)
    write_model_paths(results_dir, model_args.ckpt_path, model_args.ckpt_paths)

    visuals_dir = Path(results_dir) / 'visuals'
    if args.config_path is None and not model_args.ckpt_paths:
        # Get evaluator

        eval_args = {
            'num_visuals': None,
            'iters_per_eval': None,
            'has_missing_tasks': args.has_tasks_missing,
            'model_uncertainty': model_args.model_uncertainty,
            'class_weights': class_weights,
            'max_eval': None,
            'device': args.device,
            'optimizer': None,
        }
        evaluator = get_evaluator('classification', [eval_loader], None,
                                  eval_args)

        metrics, curves = evaluator.evaluate(model,
                                             args.device,
                                             results_dir=results_dir,
                                             report_probabilities=True)
        # TODO: Generalize the plot function. Remove hard-coded values.
        # plot(curves, metrics, visuals_dir, rad_perf)

        eval_metrics = [
            'AUPRC', 'AUROC', 'log_loss', 'rads_below_ROC', 'rads_below_PR',
            'accuracy'
        ]

        if logger_args.write_results:
            results_path = os.path.join(results_dir, 'scores.csv')
            evaluate_task_sequence = ('competition'
                                      if data_args.dataset_name == 'stanford'
                                      else data_args.task_sequence)
            write_results(data_args.dataset_name, data_args.split,
                          eval_metrics, metrics, results_path,
                          logger_args.name, ckpt_info, evaluate_task_sequence)

    # Save visuals
    if logger_args.save_cams:
        cams_dir = visuals_dir / 'cams'
        save_grad_cams(args,
                       eval_loader,
                       model,
                       cams_dir,
                       only_competition=logger_args.only_competition_cams,
                       only_top_task=False,
                       probabilities_csv=logger_args.probabilities_csv)
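
The structure returned by get_config(args.config_path) is not shown in these examples. A hypothetical shape consistent with how task2models and aggregation_fn are used above (all names and paths here are assumptions):

# Hypothetical return values of get_config(args.config_path):
task2models = {
    'Pneumonia': ['ckpts/model1.pth', 'ckpts/model2.pth'],  # assumed paths
    'Edema': ['ckpts/model3.pth'],
}

def aggregation_fn(probs):
    # e.g. average the per-model probabilities for each task
    return sum(probs) / len(probs)
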
Example #6
def test(args):
    """Run testing with the given args.

    The function consists of the following steps:
        1. Get model for evaluation.
        2. Get task sequence and class weights.
        3. Get data eval loaders and evaluator.
        4. Evaluate and save model performance (metrics and curves).
    """

    model_args = args.model_args
    logger_args = args.logger_args
    data_args = args.data_args
    transform_args = args.transform_args

    # Get model
    if args.use_multi_model:
        model = load_multi_model(args.multi, model_args, data_args,
                                 args.gpu_ids)
    elif model_args.use_csv_probs:
        model = CSVReaderModel(model_args.ckpt_path,
                               TASK_SEQUENCES[data_args.task_sequence])
        ckpt_info = {'epoch': 0}
    elif args.config_path is not None:
        task2models, aggregation_fn = get_config(args.config_path)
        model = EnsembleModel(task2models, aggregation_fn, args.gpu_ids,
                              model_args, data_args)
        ckpt_info = {'epoch': 0}
    else:
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path,
                                                 args.gpu_ids, model_args,
                                                 data_args)
    model = model.to(args.device)
    model.eval()

    # Get the task sequence that the model outputs.
    # Newer models have an attribute called 'task_sequence'.
    # For older models we need to specify what
    # task sequence was used.
    if hasattr(model.module, 'task_sequence'):
        task_sequence = model.module.task_sequence
    else:
        task_sequence = TASK_SEQUENCES[data_args.task_sequence]
        print(
            f"WARNING: assuming that the model's task sequence is \n {task_sequence}"
        )

    # Get train loader in order to get the class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path
    '''
    loader = get_loader(data_args,
                        args.transform_args,
                        train_csv_name,
                        task_sequence,
                        su_frac = 1 if data_args.eval_su else 0,
                        nih_frac = 1 if data_args.eval_nih else 0,
                        pocus_frac = 1 if data_args.eval_pocus else 0,
                        tcga_frac = 1 if data_args.eval_tcga else 0,
                        batch_size=args.batch_size)
    '''
    class_weights = [[0.5], [0.5]]  #loader.dataset.class_weights
    '''
    # Get eval loaders and radiologist performance
    eval_loader = get_eval_loaders(data_args,
                                   transform_args,
                                   task_sequence,
                                   args.batch_size,
                                   frontal_lateral=model_args.frontal_lateral,
                                   return_info_dict=model_args.use_csv_probs or logger_args.save_cams)[-1] # Evaluate only on valid
    '''
    rad_perf = (pd.read_csv(data_args.su_rad_perf_path)
                if data_args.su_rad_perf_path is not None else None)

    eval_loader = get_loader(
        data_args,
        transform_args,
        data_args.split,
        task_sequence,
        su_frac=1 if data_args.eval_su else 0,
        nih_frac=1 if data_args.eval_nih else 0,
        pocus_frac=1 if data_args.eval_pocus else 0,
        tcga_frac=1 if data_args.eval_tcga else 0,
        tcga_google_frac=1 if data_args.eval_tcga_google else 0,
        tcga_stanford_frac=1 if data_args.eval_tcga_stanford else 0,
        batch_size=args.batch_size,
        normalize=model_args.normalize,
        study_level=not model_args.frontal_lateral,
        frontal_lateral=model_args.frontal_lateral,
        return_info_dict=model_args.use_csv_probs or logger_args.save_cams)

    results_dir = os.path.join(logger_args.results_dir, data_args.split)
    visuals_dir = Path(results_dir) / 'visuals'
    if args.config_path is None:
        # Get evaluator

        eval_args = {
            'num_visuals': None,
            'iters_per_eval': None,
            'has_missing_tasks': args.has_tasks_missing,
            'model_uncertainty': model_args.model_uncertainty,
            'class_weights': class_weights,
            'max_eval': None,
            'device': args.device,
        }
        evaluator = get_evaluator('classification', [eval_loader], None,
                                  eval_args)

        metrics, curves = evaluator.evaluate(model, args.device)
        print(metrics)

        # TODO: Generalize the plot function. Remove hard-coded values.
        # Plot Results
        # print(f"Plotting to {visuals_dir}")
        # TODO: uncomment once we have curves to plot
        # plot(curves, metrics, visuals_dir, rad_perf)

        eval_metrics = [
            'AUPRC', 'AUROC', 'log_loss', 'rads_below_ROC', 'rads_below_PR'
        ]

        if logger_args.write_results:
            results_path = os.path.join(results_dir, 'scores.csv')
            write_results(data_args.dataset_name, data_args.split,
                          eval_metrics, metrics, results_path,
                          logger_args.name, ckpt_info)

    # Save visuals
    if logger_args.save_cams:
        cams_dir = visuals_dir / 'cams'
        save_grad_cams(args,
                       eval_loader,
                       model,
                       cams_dir,
                       only_competition=logger_args.only_competition_cams,
                       only_top_task=False)
Example #7
def main():
    logging.basicConfig(level=logging.INFO)
    task_id = str(int(time.time()))
    tmp_model_path = os.path.join('/tmp', '%s.h5' % task_id)

    if True:
        task = {
            'task_id': task_id,
            'score_metric': 'val_rmse',
            'dataset_path': 'showdown_full',
            'final': True,
            'model_config': TransferLstmModel.create_cnn(
                tmp_model_path,
                transform_model_config={
                    'model_uri': '/models/snapshots/regression/1480182349/31.h5',
                    'scale': 16,
                    'type': 'regression'
                },
                timesteps=50,
                W_l2=0.001,
                scale=16.,
                input_shape=(120, 320, 3)),
            'training_args': {
                'batch_size': 32,
                'epochs': 100,
            },
        }

    if False:
        task = {
            'task_id': task_id,
            'score_metric': 'loss',
            'dataset_path': 'shinale_full',
            'final': False,
            'model_config': RegressionModel.create_resnet_inception_v2(
                tmp_model_path,
                learning_rate=0.001,
                input_shape=(120, 320, 3)),
            'training_args': {
                'batch_size': 16,
                'epochs': 100,
                'pctl_sampling': 'uniform',
                'pctl_thresholds': showdown_percentiles(),
            },
        }

    if False:
        task = {
            'task_id': task_id,
            'score_metric': 'loss',
            'dataset_path': 'showdown_full',
            'final': True,
            'model_config': {
                'model_uri': '/models/output/1480004259.h5',
                'scale': 16,
                'type': 'regression'
            },
            'training_args': {
                'batch_size': 32,
                'epochs': 40,
            },
        }

    if False:
        # sharp left vs center vs sharp right
        task = {
            'task_id': task_id,
            'dataset_path': 'finale_full',
            'score_metric': 'val_categorical_accuracy',
            'model_config': CategoricalModel.create(
                tmp_model_path,
                use_adadelta=True,
                W_l2=0.001,
                thresholds=[-0.061, 0.061]),
            'training_args': {
                'batch_size': 32,
                'epochs': 30,
                'pctl_sampling': 'uniform',
            },
        }

    if False:
        # half degree model
        task = {
            'task_id': task_id,
            'dataset_path': 'finale_center',
            'model_config': CategoricalModel.create(
                tmp_model_path,
                use_adadelta=True,
                learning_rate=0.001,
                thresholds=np.linspace(-0.061, 0.061, 14)[1:-1],
                input_shape=(120, 320, 3)),
            'training_args': {
                'pctl_sampling': 'uniform',
                'batch_size': 32,
                'epochs': 20,
            },
        }

    if False:
        input_model_config = {
            'model_uri': 's3://sdc-matt/simple/1477715388/model.h5',
            'type': 'simple',
            'cat_classes': 5
        }

        ensemble_model_config = EnsembleModel.create(tmp_model_path,
                                                     input_model_config,
                                                     timesteps=3,
                                                     timestep_noise=0.1,
                                                     timestep_dropout=0.5)

        task = {
            'task_id': task_id,
            'dataset_path': 'final_training',
            'model_config': ensemble_model_config,
            'training_args': {
                'batch_size': 64,
                'epochs': 3
            },
        }

    if False:
        lstm_model_config = LstmModel.create(tmp_model_path, (10, 120, 320, 3),
                                             timesteps=10,
                                             W_l2=0.0001,
                                             scale=60.0)

        task = {
            'task_id': task_id,
            'dataset_path': 'showdown_full',
            'final': True,
            'model_config': lstm_model_config,
            'training_args': {
                'pctl_sampling': 'uniform',
                'batch_size': 32,
                'epochs': 10,
            },
        }

    handle_task(task)
Example #8
    # (fragment: acoustic_model, linguistic_cfg, args and device are defined
    # earlier in the original script and are not shown here)
    linguistic_model = AttentionModel(linguistic_cfg)
    linguistic_model.float().to(device)

    try:
        linguistic_model.load_state_dict(torch.load(args.linguistic_model))
    except Exception:
        print("Failed to load model from {} without device mapping. "
              "Trying to load with mapping to {}".format(
                  args.linguistic_model, device))
        linguistic_model.load_state_dict(
            torch.load(args.linguistic_model, map_location=device))
    """Defining loss and optimizer"""
    criterion = torch.nn.CrossEntropyLoss().to(device)

    model = EnsembleModel(acoustic_model, linguistic_model)

    model_run_path = MODEL_PATH + "/" + strftime("%Y-%m-%d_%H:%M:%S", gmtime())
    model_weights_path = "{}/{}".format(model_run_path, "ensemble_model.torch")
    result_path = "{}/result.txt".format(model_run_path)
    os.makedirs(model_run_path, exist_ok=True)
    """Choosing hardware"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == "cuda":
        print(
            "Using GPU. Setting default tensor type to torch.cuda.FloatTensor")
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    else:
        print("Using CPU. Setting default tensor type to torch.FloatTensor")
        torch.set_default_tensor_type("torch.FloatTensor")
    """Converting model to specified hardware and format"""
Example #9
    def test_fit_predict(self):
        X, y = gen_data(1000)
        base_model = LinearRegression()
        base_model.fit(X[:600], y['y'][:600])
        pred = base_model.predict(X[600:])
        base_score = mean_squared_error(y['y'][600:], pred)

        model = EnsembleModel([LinearRegression(),
                               lgbm.sklearn.LGBMRegressor()],
                              bagging_fraction=0.8,
                              model_cnt=20)

        model.fit(X[:600], y['y'][:600])
        pred = model.predict(X[600:])
        ans_score = mean_squared_error(y['y'][600:], pred)
        assert len(model.models) == 20
        assert len(pred) == len(X[600:])
        assert ans_score < base_score


        model = EnsembleModel([ConstModel(-1),
                               ConstModel(1)],
                              bagging_fraction=0.8,
                              model_cnt=5000)
        model.fit(X[:600], y['y'][:600])
        pred = model.predict(X[600:])
        assert len(set(pred)) == 1
        assert np.abs(pred[0]) < 0.1

        model = EnsembleModel([ConstModel(1),
                               ConstModel(1)],
                              bagging_fraction=0.8,
                              model_cnt=5000)
        model.fit(X[:600], y['y'][:600])
        pred = model.predict(X[600:])
        assert len(set(pred)) == 1
        assert pred[0] == 1

        model = EnsembleModel([lgbm.sklearn.LGBMClassifier(max_depth=3),
                               lgbm.sklearn.LGBMClassifier()],
                              bagging_fraction=0.8,
                              model_cnt=20)
        model.fit(X[:600], np.log(y['y'])[:600] > 0)
        pred = model.predict(X[600:])
        assert (pred >= 0).min()
        assert (pred <= 1).min()
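
The assertions above pin down this EnsembleModel's semantics: model_cnt copies of the base models are fit on random bagging_fraction subsets of the rows, and predict averages their outputs (which is why ConstModel(-1) and ConstModel(1) average to roughly 0). A minimal sketch consistent with that behavior, assuming pandas inputs (a reconstruction, not the library's code):

import copy
import numpy as np

class EnsembleModel:
    def __init__(self, base_models, bagging_fraction, model_cnt):
        self.base_models = base_models
        self.bagging_fraction = bagging_fraction
        self.model_cnt = model_cnt
        self.models = []

    def fit(self, X, y):
        n = int(len(X) * self.bagging_fraction)
        for k in range(self.model_cnt):
            idx = np.random.choice(len(X), size=n, replace=False)
            # Cycle through the base models so each contributes equally.
            model = copy.deepcopy(self.base_models[k % len(self.base_models)])
            model.fit(X.iloc[idx], y.iloc[idx])
            self.models.append(model)

    def predict(self, X):
        # Average the predictions of all fitted models.
        return np.mean([m.predict(X) for m in self.models], axis=0)
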
Example #10
def train(args):
    utils.make_all_dirs(current_time)
    if args.load_var:
        all_utterances, labels, word_dict = read_data(load_var=args.load_var,
                                                      input_=None,
                                                      mode='train')
        dev_utterances, dev_labels, _ = read_data(load_var=args.load_var,
                                                  input_=None,
                                                  mode='dev')
    else:
        all_utterances, labels, word_dict = read_data(
            load_var=args.load_var,
            input_=os.path.join(constant.data_path, "entangled_train.json"),
            mode='train')
        dev_utterances, dev_labels, _ = read_data(
            load_var=args.load_var,
            input_=os.path.join(constant.data_path, "entangled_dev.json"),
            mode='dev')

    word_emb = build_embedding_matrix(
        word_dict,
        glove_loc=args.glove_loc,
        emb_loc=os.path.join(constant.save_input_path, "word_emb.pk"),
        load_emb=False)

    if args.save_input:
        utils.save_or_read_input(
            os.path.join(constant.save_input_path, "train_utterances.pk"),
            rw='w', input_obj=all_utterances)
        utils.save_or_read_input(
            os.path.join(constant.save_input_path, "train_labels.pk"),
            rw='w', input_obj=labels)
        utils.save_or_read_input(
            os.path.join(constant.save_input_path, "word_dict.pk"),
            rw='w', input_obj=word_dict)
        utils.save_or_read_input(
            os.path.join(constant.save_input_path, "word_emb.pk"),
            rw='w', input_obj=word_emb)
        utils.save_or_read_input(
            os.path.join(constant.save_input_path, "dev_utterances.pk"),
            rw='w', input_obj=dev_utterances)
        utils.save_or_read_input(
            os.path.join(constant.save_input_path, "dev_labels.pk"),
            rw='w', input_obj=dev_labels)

    train_dataloader = TrainDataLoader(all_utterances, labels, word_dict)
    if args.add_noise:
        noise_train_dataloader = TrainDataLoader(all_utterances,
                                                 labels,
                                                 word_dict,
                                                 add_noise=True)
    else:
        noise_train_dataloader = None
    dev_dataloader = TrainDataLoader(dev_utterances,
                                     dev_labels,
                                     word_dict,
                                     name='dev')

    logger_name = os.path.join(constant.log_path,
                               "{}.txt".format(current_time))
    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT,
                        level=logging.INFO,
                        filename=logger_name,
                        filemode='w')
    logger = logging.getLogger()
    global log_head
    log_head = log_head + "Training Model: {}; ".format(args.model)
    if args.add_noise:
        log_head += "Add Noise: True; "
    logger.info(log_head)

    if args.model == 'T':
        ensemble_model_bidirectional = EnsembleModel(word_dict,
                                                     word_emb=word_emb,
                                                     bidirectional=True)
    elif args.model == 'TS':
        ensemble_model_bidirectional = EnsembleModel(word_dict,
                                                     word_emb=None,
                                                     bidirectional=True)
    else:
        ensemble_model_bidirectional = None
    if args.model == 'TS':
        ensemble_model_bidirectional.load_state_dict(
            torch.load(args.model_path))
    ensemble_model = EnsembleModel(word_dict,
                                   word_emb=word_emb,
                                   bidirectional=False)

    if torch.cuda.is_available():
        ensemble_model.cuda()
        if args.model == 'T' or args.model == 'TS':
            ensemble_model_bidirectional.cuda()

    supervised_trainer = SupervisedTrainer(args,
                                           ensemble_model,
                                           teacher_model=ensemble_model_bidirectional,
                                           logger=logger,
                                           current_time=current_time)

    supervised_trainer.train(train_dataloader, noise_train_dataloader,
                             dev_dataloader)