Example #1
# Imports assumed by this example; write_log and preprocessing are
# project-local helpers defined elsewhere in the repository, and the
# tokenizer import reflects the modern transformers package rather than a
# confirmed original.
import argparse
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from transformers import BertTokenizer


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--original_data",
        default=None,
        type=str,
        required=True,
        help="The input data file path."
        " Should be the .tsv file (or other data file) for the task.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the processed data will be written.")
    parser.add_argument(
        "--temp_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the intermediate processed data will be written."
    )
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task.")
    parser.add_argument("--log_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The log file path.")
    parser.add_argument(
        "--id_num_neg",
        default=None,
        type=int,
        required=True,
        help=
        "The number of admission ids that we want to use for negative category."
    )
    parser.add_argument(
        "--id_num_pos",
        default=None,
        type=int,
        required=True,
        help=
        "The number of admission ids that we want to use for positive category."
    )
    parser.add_argument("--random_seed",
                        default=1,
                        type=int,
                        required=True,
                        help="The random_seed for train/val/test split.")
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )

    ## Other parameters
    parser.add_argument(
        "--Kfold",
        default=None,
        type=int,
        required=False,
        help="The number of folds that we want ot use for cross validation. "
        "Default is not doing cross validation")

    args = parser.parse_args()
    RANDOM_SEED = args.random_seed
    LOG_PATH = args.log_path
    TEMP_DIR = args.temp_dir

    if os.path.exists(TEMP_DIR) and os.listdir(TEMP_DIR):
        raise ValueError(
            "Temp Output directory ({}) already exists and is not empty.".
            format(TEMP_DIR))
    os.makedirs(TEMP_DIR, exist_ok=True)

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    original_df = pd.read_csv(args.original_data, header=None)
    original_df.rename(columns={
        0: "Adm_ID",
        1: "Note_ID",
        2: "chartdate",
        3: "charttime",
        4: "TEXT",
        5: "Label"
    },
                       inplace=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=True)

    write_log(("New Pre-processing Job Start! \n"
               "original_data: {}, output_dir: {}, temp_dir: {} \n"
               "task_name: {}, log_path: {}\n"
               "id_num_neg: {}, id_num_pos: {}\n"
               "random_seed: {}, bert_model: {}").format(
                   args.original_data, args.output_dir, args.temp_dir,
                   args.task_name, args.log_path, args.id_num_neg,
                   args.id_num_pos, args.random_seed, args.bert_model),
              LOG_PATH)

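    # Tokenize in chunks of 10,000 notes so the intermediate results fit in
    # memory; each processed chunk is written to TEMP_DIR and re-read below.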
    for i in range(int(np.ceil(len(original_df) / 10000))):
        write_log("chunk {} tokenize start!".format(i), LOG_PATH)
        df_chunk = original_df.iloc[i * 10000:(i + 1) * 10000].copy()
        df_processed_chunk = preprocessing(df_chunk, tokenizer)
        df_processed_chunk = df_processed_chunk.astype({
            'Adm_ID': 'int64',
            'Note_ID': 'int64',
            'Label': 'int64'
        })
        temp_file_dir = os.path.join(TEMP_DIR, 'Processed_{}.csv'.format(i))
        df_processed_chunk.to_csv(temp_file_dir, index=False)

    df = pd.DataFrame({
        'Adm_ID': [],
        'Note_ID': [],
        'TEXT': [],
        'Input_ID': [],
        'Label': [],
        'chartdate': [],
        'charttime': []
    })
    for i in range(int(np.ceil(len(original_df) / 10000))):
        temp_file_dir = os.path.join(TEMP_DIR, 'Processed_{}.csv'.format(i))
        df_chunk = pd.read_csv(temp_file_dir, header=0)
        write_log("chunk {} has {} notes".format(i, len(df_chunk)), LOG_PATH)
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        df = pd.concat([df, df_chunk], ignore_index=True)

    result = df.Label.value_counts()
    write_log(
        "In the full dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}"
        .format(result[1], result[0]), LOG_PATH)

    dead_ID = pd.Series(df[df.Label == 1].Adm_ID.unique())
    not_dead_ID = pd.Series(df[df.Label == 0].Adm_ID.unique())
    write_log(
        "Total Positive Patients' ids: {}, Total Negative Patients' ids: {}".
        format(len(dead_ID), len(not_dead_ID)), LOG_PATH)

    not_dead_ID_use = not_dead_ID.sample(n=args.id_num_neg,
                                         random_state=RANDOM_SEED)
    dead_ID_use = dead_ID.sample(n=args.id_num_pos, random_state=RANDOM_SEED)

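    # Without --Kfold: hold out 20% of the admission ids of each class, then
    # split that holdout evenly into validation and test (an 80/10/10 split
    # per class); a standalone sketch of this rule follows the example.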
    if args.Kfold is None:
        id_val_test_t = dead_ID_use.sample(frac=0.2, random_state=RANDOM_SEED)
        id_val_test_f = not_dead_ID_use.sample(frac=0.2,
                                               random_state=RANDOM_SEED)

        id_train_t = dead_ID_use.drop(id_val_test_t.index)
        id_train_f = not_dead_ID_use.drop(id_val_test_f.index)

        id_val_t = id_val_test_t.sample(frac=0.5, random_state=RANDOM_SEED)
        id_test_t = id_val_test_t.drop(id_val_t.index)
        id_val_f = id_val_test_f.sample(frac=0.5, random_state=RANDOM_SEED)
        id_test_f = id_val_test_f.drop(id_val_f.index)

        id_test = pd.concat([id_test_t, id_test_f])
        test_id_label = pd.DataFrame(data=list(
            zip(id_test, [1] * len(id_test_t) + [0] * len(id_test_f))),
                                     columns=['id', 'label'])

        id_val = pd.concat([id_val_t, id_val_f])
        val_id_label = pd.DataFrame(data=list(
            zip(id_val, [1] * len(id_val_t) + [0] * len(id_val_f))),
                                    columns=['id', 'label'])

        id_train = pd.concat([id_train_t, id_train_f])
        train_id_label = pd.DataFrame(data=list(
            zip(id_train, [1] * len(id_train_t) + [0] * len(id_train_f))),
                                      columns=['id', 'label'])

        mortality_train = df[df.Adm_ID.isin(train_id_label.id)]
        mortality_val = df[df.Adm_ID.isin(val_id_label.id)]
        mortality_test = df[df.Adm_ID.isin(test_id_label.id)]
        mortality_not_use = df[~df.Adm_ID.isin(train_id_label.id)
                               & ~df.Adm_ID.isin(val_id_label.id)
                               & ~df.Adm_ID.isin(test_id_label.id)]

        train_result = mortality_train.Label.value_counts()

        val_result = mortality_val.Label.value_counts()

        test_result = mortality_test.Label.value_counts()

        no_result = mortality_not_use.Label.value_counts()

        mortality_train.to_csv(os.path.join(args.output_dir, 'train.csv'),
                               index=False)
        mortality_val.to_csv(os.path.join(args.output_dir, 'val.csv'),
                             index=False)
        mortality_test.to_csv(os.path.join(args.output_dir, 'test.csv'),
                              index=False)
        mortality_not_use.to_csv(os.path.join(args.output_dir, 'not_use.csv'),
                                 index=False)
        df.to_csv(os.path.join(args.output_dir, 'full.csv'), index=False)

        if len(no_result) == 2:
            write_log((
                "In the train dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                "In the validation dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                "In the test dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                "In the not use dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}"
            ).format(train_result[1], train_result[0], val_result[1],
                     val_result[0], test_result[1], test_result[0],
                     no_result[1], no_result[0]), LOG_PATH)
        else:
            try:
                write_log((
                    "In the train dataset Positive Patients' Notes: {}, Negative  Patients' Notes: {}\n"
                    "In the validation dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                    "In the test dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                    "In the not use dataset Negative Patients' Notes: {}"
                ).format(train_result[1], train_result[0], val_result[1],
                         val_result[0], test_result[1], test_result[0],
                         no_result[0]), LOG_PATH)
            except KeyError:
                write_log((
                    "In the train dataset Positive Patients' Notes: {}, Negative  Patients' Notes: {}\n"
                    "In the validation dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                    "In the test dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                    "In the not use dataset Positive Patients' Notes: {}"
                ).format(train_result[1], train_result[0], val_result[1],
                         val_result[0], test_result[1], test_result[0],
                         no_result[1]), LOG_PATH)

        write_log("Data saved in the {}".format(args.output_dir), LOG_PATH)
    else:
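        # With --Kfold: run K-fold over the positive and negative id lists in
        # parallel so each fold keeps the class ratio, then halve each fold's
        # held-out ids into validation and test.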
        # Recent scikit-learn requires keyword arguments here, and
        # random_state only takes effect when shuffle=True.
        folds_t = KFold(n_splits=args.Kfold, shuffle=True, random_state=RANDOM_SEED)
        folds_f = KFold(n_splits=args.Kfold, shuffle=True, random_state=RANDOM_SEED)
        dead_ID_use.reset_index(inplace=True, drop=True)
        not_dead_ID_use.reset_index(inplace=True, drop=True)
        for num, ((train_t, test_t), (train_f, test_f)) in enumerate(
                zip(folds_t.split(dead_ID_use),
                    folds_f.split(not_dead_ID_use))):
            id_train_t = dead_ID_use[train_t]
            id_val_test_t = dead_ID_use[test_t]
            id_train_f = not_dead_ID_use[train_f]
            id_val_test_f = not_dead_ID_use[test_f]
            id_val_t = id_val_test_t.sample(frac=0.5, random_state=RANDOM_SEED)
            id_test_t = id_val_test_t.drop(id_val_t.index)
            id_val_f = id_val_test_f.sample(frac=0.5, random_state=RANDOM_SEED)
            id_test_f = id_val_test_f.drop(id_val_f.index)

            id_test = pd.concat([id_test_t, id_test_f])
            test_id_label = pd.DataFrame(data=list(
                zip(id_test, [1] * len(id_test_t) + [0] * len(id_test_f))),
                                         columns=['id', 'label'])

            id_val = pd.concat([id_val_t, id_val_f])
            val_id_label = pd.DataFrame(data=list(
                zip(id_val, [1] * len(id_val_t) + [0] * len(id_val_f))),
                                        columns=['id', 'label'])

            id_train = pd.concat([id_train_t, id_train_f])
            train_id_label = pd.DataFrame(data=list(
                zip(id_train, [1] * len(id_train_t) + [0] * len(id_train_f))),
                                          columns=['id', 'label'])

            mortality_train = df[df.Adm_ID.isin(train_id_label.id)]
            mortality_val = df[df.Adm_ID.isin(val_id_label.id)]
            mortality_test = df[df.Adm_ID.isin(test_id_label.id)]
            mortality_not_use = df[~df.Adm_ID.isin(train_id_label.id)
                                   & ~df.Adm_ID.isin(val_id_label.id)
                                   & ~df.Adm_ID.isin(test_id_label.id)]

            train_result = mortality_train.Label.value_counts()

            val_result = mortality_val.Label.value_counts()

            test_result = mortality_test.Label.value_counts()

            no_result = mortality_not_use.Label.value_counts()

            os.makedirs(os.path.join(args.output_dir, str(num)))
            mortality_train.to_csv(os.path.join(args.output_dir, str(num),
                                                'train.csv'),
                                   index=False)
            mortality_val.to_csv(os.path.join(args.output_dir, str(num),
                                              'val.csv'),
                                 index=False)
            mortality_test.to_csv(os.path.join(args.output_dir, str(num),
                                               'test.csv'),
                                  index=False)
            mortality_not_use.to_csv(os.path.join(args.output_dir, str(num),
                                                  'not_use.csv'),
                                     index=False)
            df.to_csv(os.path.join(args.output_dir, str(num), 'full.csv'),
                      index=False)

            if len(no_result) == 2:
                write_log((
                    "In the {}th split of {} folds\n"
                    "In the train dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                    "In the validation dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                    "In the test dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                    "In the not use dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}"
                ).format(num, args.Kfold, train_result[1], train_result[0],
                         val_result[1], val_result[0], test_result[1],
                         test_result[0], no_result[1], no_result[0]), LOG_PATH)
            else:
                try:
                    write_log((
                        "In the {}th split of {} folds\n"
                        "In the train dataset Positive Patients' Notes: {}, Negative  Patients' Notes: {}\n"
                        "In the validation dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                        "In the test dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                        "In the not use dataset Negative Patients' Notes: {}"
                    ).format(num, args.Kfold, train_result[1], train_result[0],
                             val_result[1], val_result[0], test_result[1],
                             test_result[0], no_result[0]), LOG_PATH)
                except KeyError:
                    write_log((
                        "In the {}th split of {} folds\n"
                        "In the train dataset Positive Patients' Notes: {}, Negative  Patients' Notes: {}\n"
                        "In the validation dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                        "In the test dataset Positive Patients' Notes: {}, Negative Patients' Notes: {}\n"
                        "In the not use dataset Positive Patients' Notes: {}"
                    ).format(num, args.Kfold, train_result[1], train_result[0],
                             val_result[1], val_result[0], test_result[1],
                             test_result[0], no_result[1]), LOG_PATH)

            write_log(
                "Data saved in the {}".format(
                    os.path.join(args.output_dir, str(num))), LOG_PATH)
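
The train/val/test partition above reduces to one rule applied to each class
separately: hold out 20% of the admission ids, then halve the holdout. A
minimal standalone sketch of that rule (the Series of admission ids is
illustrative, not from the source):

import pandas as pd

ids = pd.Series(range(100))                      # admission ids of one class
val_test = ids.sample(frac=0.2, random_state=1)  # 20% held out
train = ids.drop(val_test.index)                 # remaining 80% for training
val = val_test.sample(frac=0.5, random_state=1)  # half the holdout for validation
test = val_test.drop(val.index)                  # the other half for testing
assert len(train) + len(val) + len(test) == len(ids)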
Example #2
# Imports assumed by this example; write_log, Tokenize_with_note_id,
# concat_by_id_list_with_note_chunk_id, mask_batch_generator, flat_accuracy
# and write_performance are project-local helpers defined elsewhere in the
# repository. BertAdam ships with the legacy pytorch_pretrained_bert package;
# the BERT model and tokenizer imports assume a tuple-returning transformers
# version.
import argparse
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from dotmap import DotMap
from pytorch_pretrained_bert import BertAdam
from tqdm import trange
from transformers import BertForSequenceClassification, BertTokenizer


def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")

    parser.add_argument("--train_data",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data file name."
                             " Should be the .tsv file (or other data file) for the task.")

    parser.add_argument("--val_data",
                        default=None,
                        type=str,
                        required=True,
                        help="The input validation data file name."
                             " Should be the .tsv file (or other data file) for the task.")

    parser.add_argument("--test_data",
                        default=None,
                        type=str,
                        required=True,
                        help="The input test data file name."
                             " Should be the .tsv file (or other data file) for the task.")

    parser.add_argument("--log_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The log file path.")

    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")

    parser.add_argument("--save_model",
                        default=False,
                        action='store_true',
                        help="Whether to save the model.")

    parser.add_argument("--bert_model",
                        default="bert-base-uncased",
                        type=str,
                        required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")

    parser.add_argument("--embed_mode",
                        default=None,
                        type=str,
                        required=True,
                        help="The embedding type selected in the list: all, note, chunk, no.")

    parser.add_argument("--task_name",
                        default="ClBERT_mortality_sm",
                        type=str,
                        required=True,
                        help="The name of the task.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--max_chunk_num",
                        default=64,
                        type=int,
                        help="The maximum total input chunk numbers after WordPiece tokenization.")
    parser.add_argument("--train_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--warmup_proportion",
                        default=0.0,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--num_train_epochs",
                        default=3,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumualte before performing a backward/update pass.")

    args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.save_model:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    LOG_PATH = args.log_path
    MAX_LEN = args.max_seq_length

    config = DotMap()
    config.hidden_dropout_prob = 0.1
    config.layer_norm_eps = 1e-12
    config.initializer_range = 0.02
    config.max_note_position_embedding = 1000
    config.max_chunk_position_embedding = 1000
    config.embed_mode = args.embed_mode
    config.hidden_size = 768

    config.task_name = args.task_name

    write_log(("New Job Start! \n"
               "Data directory: {}, Directory Code: {}, Save Model: {}\n"
               "Output_dir: {}, Task Name: {}, embed_mode: {}\n"
               "max_seq_length: {},  max_chunk_num: {}\n"
               "train_batch_size: {}, eval_batch_size: {}\n"
               "learning_rate: {}, warmup_proportion: {}\n"
               "num_train_epochs: {}, seed: {}, gradient_accumulation_steps: {}").format(args.data_dir,
                                                       args.data_dir.split('_')[-1],
                                                       args.save_model,
                                                       args.output_dir,
                                                       config.task_name,
                                                       config.embed_mode,
                                                       args.max_seq_length,
                                                       args.max_chunk_num,
                                                       args.train_batch_size,
                                                       args.eval_batch_size,
                                                       args.learning_rate,
                                                       args.warmup_proportion,
                                                       args.num_train_epochs,
                                                       args.seed,
                                                       args.gradient_accumulation_steps),
              LOG_PATH)

    content = "config setting: \n"
    for k, v in config.items():
        content += "{}: {} \n".format(k, v)
    write_log(content, LOG_PATH)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    write_log("Number of GPU is {}".format(n_gpu), LOG_PATH)
    for i in range(n_gpu):
        write_log(("Device Name: {},"
                   "Device Capability: {}").format(torch.cuda.get_device_name(i),
                                                   torch.cuda.get_device_capability(i)), LOG_PATH)

    train_file_path = os.path.join(args.data_dir, args.train_data)
    val_file_path = os.path.join(args.data_dir, args.val_data)
    test_file_path = os.path.join(args.data_dir, args.test_data)
    train_df = pd.read_csv(train_file_path)
    val_df = pd.read_csv(val_file_path)
    test_df = pd.read_csv(test_file_path)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)

    write_log("Tokenize Start!", LOG_PATH)
    train_labels, train_inputs, train_masks, train_note_ids = Tokenize_with_note_id(train_df, MAX_LEN, tokenizer)
    validation_labels, validation_inputs, validation_masks, validation_note_ids = Tokenize_with_note_id(val_df, MAX_LEN,
                                                                                                        tokenizer)
    test_labels, test_inputs, test_masks, test_note_ids = Tokenize_with_note_id(test_df, MAX_LEN, tokenizer)
    write_log("Tokenize Finished!", LOG_PATH)
    train_inputs = torch.tensor(train_inputs)
    validation_inputs = torch.tensor(validation_inputs)
    test_inputs = torch.tensor(test_inputs)
    train_labels = torch.tensor(train_labels)
    validation_labels = torch.tensor(validation_labels)
    test_labels = torch.tensor(test_labels)
    train_masks = torch.tensor(train_masks)
    validation_masks = torch.tensor(validation_masks)
    test_masks = torch.tensor(test_masks)
    write_log(("train dataset size is %d,\n"
               "validation dataset size is %d,\n"
               "test dataset size is %d") % (len(train_inputs), len(validation_inputs), len(test_inputs)), LOG_PATH)

    (train_labels, train_inputs,
     train_masks, train_ids,
     train_note_ids, train_chunk_ids) = concat_by_id_list_with_note_chunk_id(train_df, train_labels,
                                                                             train_inputs, train_masks,
                                                                             train_note_ids, MAX_LEN)
    (validation_labels, validation_inputs,
     validation_masks, validation_ids,
     validation_note_ids, validation_chunk_ids) = concat_by_id_list_with_note_chunk_id(val_df, validation_labels,
                                                                                       validation_inputs,
                                                                                       validation_masks,
                                                                                       validation_note_ids, MAX_LEN)
    (test_labels, test_inputs,
     test_masks, test_ids,
     test_note_ids, test_chunk_ids) = concat_by_id_list_with_note_chunk_id(test_df, test_labels,
                                                                           test_inputs, test_masks,
                                                                           test_note_ids, MAX_LEN)

    model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=1)
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
    num_train_steps = int(
        len(train_df) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)

    m = torch.nn.Sigmoid()

    start = time.time()
    # Store our loss and accuracy for plotting
    train_loss_set = []

    # Number of training epochs (authors recommend between 2 and 4)
    epochs = args.num_train_epochs

    train_batch_generator = mask_batch_generator(args.max_chunk_num, train_inputs, train_labels, train_masks)
    validation_batch_generator = mask_batch_generator(args.max_chunk_num, validation_inputs, validation_labels,
                                                      validation_masks)
    write_log("Training start!", LOG_PATH)
    # trange is a tqdm wrapper around the normal python range
    with torch.autograd.set_detect_anomaly(True):
        for epoch in trange(epochs, desc="Epoch"):
            # Training

            # Set our model to training mode (as opposed to evaluation mode)
            model.train()

            # Tracking variables
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0

            # Train the data for one epoch
            tr_ids_num = len(train_ids)
            tr_batch_loss = []
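            # Each "step" covers one admission: b_input_ids stacks all of that
            # admission's chunks, and the single admission-level label is
            # repeated once per chunk before the forward pass.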
            for step in range(tr_ids_num):
                b_input_ids, b_labels, b_input_mask = next(train_batch_generator)
                b_input_ids = b_input_ids.to(device)
                b_input_mask = b_input_mask.to(device)
                b_labels = b_labels.repeat(b_input_ids.shape[0]).to(device)
                # Forward pass
                outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
                loss, logits = outputs[:2]

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                # Collect the per-chunk-batch loss; its mean is appended to
                # train_loss_set once per optimizer step below.
                tr_batch_loss.append(loss.item())
                # Backward pass
                loss.backward()
                # Update parameters and take a step using the computed gradient
                if (step + 1) % args.train_batch_size == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    train_loss_set.append(np.mean(tr_batch_loss))
                    tr_batch_loss = []

                # Update tracking variables
                tr_loss += loss.item()
                nb_tr_examples += b_input_ids.size(0)
                nb_tr_steps += 1

                del outputs, b_input_ids, b_input_mask, b_labels
                torch.cuda.empty_cache()

            write_log("Train loss: {}".format(tr_loss / nb_tr_steps), LOG_PATH)

            # Validation

            # Put model in evaluation mode to evaluate loss on the validation set
            model.eval()

            # Tracking variables
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            # Evaluate data for one epoch
            ev_ids_num = len(validation_ids)
            for step in range(ev_ids_num):
                with torch.no_grad():
                    b_input_ids, b_labels, b_input_mask = next(validation_batch_generator)
                    b_input_ids = b_input_ids.to(device)
                    b_input_mask = b_input_mask.to(device)
                    b_labels = b_labels.repeat(b_input_ids.shape[0])
                    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
                    # Move logits and labels to CPU
                    # outputs is a (logits,) tuple when no labels are passed,
                    # so take outputs[0] before applying the sigmoid.
                    logits = torch.squeeze(m(outputs[0])).detach().cpu().numpy()
                    label_ids = b_labels.numpy()
                    tmp_eval_accuracy = flat_accuracy(logits, label_ids)
                    eval_accuracy += tmp_eval_accuracy
                    nb_eval_steps += 1

            write_log("Validation Accuracy: {}".format(eval_accuracy / nb_eval_steps), LOG_PATH)
            output_checkpoints_path = os.path.join(args.output_dir,
                                                   "bert_fine_tuned_with_note_checkpoint_%d.pt" % epoch)
            if args.save_model:
                if n_gpu > 1:
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss,
                    },
                        output_checkpoints_path)

                else:
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss,
                    },
                        output_checkpoints_path)
    end = time.time()

    write_log("total training time is: {}s".format(end - start), LOG_PATH)

    fig1 = plt.figure(figsize=(15, 8))
    plt.title("Training loss")
    plt.xlabel("Chunk Batch")
    plt.ylabel("Loss")
    plt.plot(train_loss_set)
    if args.save_model:
        output_fig_path = os.path.join(args.output_dir, "bert_fine_tuned_with_note_training_loss.png")
        plt.savefig(output_fig_path, dpi=fig1.dpi)
        output_model_state_dict_path = os.path.join(args.output_dir,
                                                    "bert_fine_tuned_with_note_state_dict.pt")
        if n_gpu > 1:
            torch.save(model.module.state_dict(), output_model_state_dict_path)
        else:
            torch.save(model.state_dict(), output_model_state_dict_path)
        write_log("Model saved!", LOG_PATH)
    else:
        output_fig_path = os.path.join(args.output_dir,
                                       "bert_fine_tuned_with_note_training_loss_{}_{}.png".format(
                                           args.seed,
                                           args.data_dir.split(
                                               '_')[-1]))
        plt.savefig(output_fig_path, dpi=fig1.dpi)
        write_log("Model not saved as required", LOG_PATH)

    # Prediction on test set

    # Put model in evaluation mode
    model.eval()

    # Tracking variables
    predictions, true_labels = [], []

    # Predict
    te_ids_num = len(test_ids)
    for step in range(te_ids_num):
        b_input_ids = test_inputs[step][-args.max_chunk_num:, :].to(device)
        b_input_mask = test_masks[step][-args.max_chunk_num:, :].to(device)
        b_labels = test_labels[step].repeat(b_input_ids.shape[0])
        # Telling the model not to compute or store gradients, saving memory and speeding up prediction
        with torch.no_grad():
            # Forward pass, calculate logit predictions
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

        # Move logits and labels to CPU
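        # Aggregate chunk-level scores into one admission-level prediction:
        # average the sigmoid probabilities over chunks; the label is constant
        # within an admission, so max() simply recovers that single value.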
        # outputs is a (logits,) tuple; outputs[0] has shape [num_chunks, 1]
        logits = torch.squeeze(m(outputs[0])).detach().cpu().numpy().mean()
        label_ids = b_labels.numpy().max()

        # Store predictions and true labels
        predictions.append(logits)
        true_labels.append(label_ids)

    flat_logits = predictions
    # np.int was removed in NumPy 1.24; the built-in int works everywhere.
    flat_predictions = (np.array(flat_logits) >= 0.5).astype(int)
    flat_true_labels = true_labels

    output_df = pd.DataFrame({'logits': flat_logits,
                              'pred_label': flat_predictions,
                              'label': flat_true_labels,
                              'Adm_ID': test_ids})
    if args.save_model:
        output_df.to_csv(os.path.join(args.output_dir, 'test_predictions.csv'), index=False)
    else:
        output_df.to_csv(os.path.join(args.output_dir,
                                      'test_predictions_{}_{}.csv'.format(args.seed,
                                                                          args.data_dir.split('_')[-1])),
                         index=False)
    write_performance(flat_true_labels, flat_predictions, flat_logits, config, args)
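
At test time each admission is scored from all of its chunks at once. A
minimal standalone sketch of the aggregation rule used above (the numbers
are illustrative):

import numpy as np

chunk_probs = np.array([0.92, 0.81, 0.40])  # sigmoid outputs for one admission's chunks
adm_prob = chunk_probs.mean()               # admission-level score: mean over chunks
adm_pred = int(adm_prob >= 0.5)             # thresholded hard prediction -> 1 here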
Example #3
# Imports assumed by this example; write_log and split_into_chunks are
# project-local helpers defined elsewhere in the repository.
import argparse
import os

import pandas as pd


def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )

    parser.add_argument(
        "--train_data",
        default=None,
        type=str,
        required=True,
        help="The input training data file name."
        " Should be the .tsv file (or other data file) for the task.")

    parser.add_argument(
        "--val_data",
        default=None,
        type=str,
        required=True,
        help="The input validation data file name."
        " Should be the .tsv file (or other data file) for the task.")

    parser.add_argument(
        "--test_data",
        default=None,
        type=str,
        required=True,
        help="The input test data file name."
        " Should be the .tsv file (or other data file) for the task.")

    parser.add_argument("--log_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The log file path.")

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    args = parser.parse_args()
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    LOG_PATH = args.log_path
    MAX_LEN = args.max_seq_length

    write_log(
        ("New Split Job Start! \n"
         "data_dir: {}, train_data: {}, val_data: {}, test_data: {} \n"
         "log_path: {}, output_dir: {}, max_seq_length: {}").format(
             args.data_dir, args.train_data, args.val_data, args.test_data,
             args.log_path, args.output_dir, args.max_seq_length), LOG_PATH)

    train_file_path = os.path.join(args.data_dir, args.train_data)
    val_file_path = os.path.join(args.data_dir, args.val_data)
    test_file_path = os.path.join(args.data_dir, args.test_data)
    train_df = pd.read_csv(train_file_path)
    val_df = pd.read_csv(val_file_path)
    test_df = pd.read_csv(test_file_path)

    new_train_df = split_into_chunks(train_df, MAX_LEN)
    new_val_df = split_into_chunks(val_df, MAX_LEN)
    new_test_df = split_into_chunks(test_df, MAX_LEN)

    train_result = new_train_df.Label.value_counts()
    val_result = new_val_df.Label.value_counts()
    test_result = new_test_df.Label.value_counts()

    write_log((
        "In the train dataset Positive Patients' Chunks: {}, Negative Patients' Chunks: {}\n"
        "In the validation dataset Positive Patients' Chunks: {}, Negative Patients' Chunks: {}\n"
        "In the test dataset Positive Patients' Chunks: {}, Negative Patients' Chunks: {}"
    ).format(train_result[1], train_result[0], val_result[1], val_result[0],
             test_result[1], test_result[0]), LOG_PATH)

    new_train_df.to_csv(os.path.join(args.output_dir, args.train_data),
                        index=False)
    new_val_df.to_csv(os.path.join(args.output_dir, args.val_data),
                      index=False)
    new_test_df.to_csv(os.path.join(args.output_dir, args.test_data),
                       index=False)

    write_log("Split finished", LOG_PATH)