Example #1
def __train():
    log_file = os.path.join(config.LOG_DIR, '{}-{}-{}-{}.log'.format(os.path.splitext(
        os.path.basename(__file__))[0], args.idx, str_today, config.MACHINE_NAME))
    # log_file = None
    init_universal_logging(log_file, mode='a', to_stdout=True)
    logging.info('logging to {}'.format(log_file))

    train_config = srlfetexp.TrainConfig(loss_name='mm', neg_scale=0.1, n_steps=70000)

    lstm_dim = 250
    mlp_hidden_dim = 500
    type_embed_dim = 500
    word_vecs_file = config.WIKI_FETEL_WORDVEC_FILE

    # dataset = 'figer'
    dataset = 'bbn'
    datafiles = config.FIGER_FILES if dataset == 'figer' else config.BBN_FILES

    data_prefix = datafiles['srl-train-data-prefix']
    dev_data_pkl = data_prefix + '-dev.pkl'
    train_data_pkl = data_prefix + '-train.pkl'

    test_file_tup = (datafiles['test-mentions'], datafiles['test-sents'],
                     datafiles['test-sents-dep'], datafiles['test-srl'])
    single_type_path = (dataset != 'figer')

    # output_model_file = None
    save_model_file_prefix = os.path.join(config.DATA_DIR, 'models/srl-{}'.format(dataset))

    gres = expdata.ResData(datafiles['type-vocab'], word_vecs_file)
    logging.info('dataset={} {}'.format(dataset, data_prefix))
    srlfetexp.train_srlfet(device, gres, train_data_pkl, dev_data_pkl, None, test_file_tup, lstm_dim, mlp_hidden_dim,
                           type_embed_dim, train_config, single_type_path,
                           save_model_file_prefix=save_model_file_prefix)
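This entry point (and the ones below) first calls init_universal_logging(log_file, mode='a', to_stdout=True) to set up logging. The helper is not shown on this page; a minimal sketch of what it presumably does, assuming it simply wires a file handler and a stdout handler into the root logger, is:

import logging
import sys


def init_universal_logging(log_file, mode='a', to_stdout=True):
    # hypothetical re-implementation, not the project's actual helper
    handlers = []
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file, mode=mode))
    if to_stdout:
        handlers.append(logging.StreamHandler(sys.stdout))
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s',
                        handlers=handlers)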
Example #2
def __train3():
    log_file = os.path.join(config.LOG_DIR, '{}-{}-{}-{}.log'.format(os.path.splitext(
        os.path.basename(__file__))[0], args.idx, str_today, config.MACHINE_NAME))
    # log_file = None
    init_universal_logging(log_file, mode='a', to_stdout=True)
    logging.info('logging to {}'.format(log_file))

    margin = 1.0
    train_config = srlfetexp.TrainConfig(
        pos_margin=margin, neg_margin=margin, neg_scale=1.0, batch_size=128, schedule_lr=True, n_steps=70000)

    lstm_dim = 250
    mlp_hidden_dim = 500
    type_embed_dim = 500
    word_vecs_file = config.WIKI_FETEL_WORDVEC_FILE

    dataset = 'figer'
    # dataset = 'bbn'
    datafiles = config.FIGER_FILES if dataset == 'figer' else config.BBN_FILES

    data_prefix = datafiles['srl-train-data-prefix']
    dev_data_pkl = data_prefix + '-dev.pkl'
    train_data_pkl = data_prefix + '-train.pkl'

    test_file_tup = (datafiles['test-mentions'], datafiles['test-sents'],
                     datafiles['test-sents-dep'], datafiles['test-srl'])
    single_type_path = (dataset != 'figer')

    # output_model_file = None
    save_model_file_prefix = os.path.join(config.DATA_DIR, 'models/srl3-{}'.format(dataset))

    val_mentions_file = os.path.join(config.DATA_DIR, 'figer/wiki-valcands-figer-mentions.json')
    val_sents_file = os.path.join(config.DATA_DIR, 'figer/wiki-valcands-figer-sents.json')
    val_srl_file = os.path.join(config.DATA_DIR, 'figer/wiki-valcands-figer-srl.txt')
    val_dep_file = os.path.join(config.DATA_DIR, 'figer/wiki-valcands-figer-tok-dep.txt')
    val_manual_label_file = os.path.join(config.DATA_DIR, 'figer/figer-dev-man-labeled.txt')
    # manual_val_file_tup = (val_mentions_file, val_sents_file, val_dep_file, val_srl_file, val_manual_label_file)
    manual_val_file_tup = None

    gres = expdata.ResData(datafiles['type-vocab'], word_vecs_file)
    logging.info('dataset={} {}'.format(dataset, data_prefix))
    srlfetexp.train_srlfet(
        device, gres, train_data_pkl, dev_data_pkl, manual_val_file_tup, test_file_tup, lstm_dim, mlp_hidden_dim,
        type_embed_dim, train_config, single_type_path, save_model_file_prefix=save_model_file_prefix)
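Examples #1 and #2 construct srlfetexp.TrainConfig with different keyword arguments (a max-margin loss with a small negative scale versus explicit margins, a larger batch size and a learning-rate schedule). The real class lives in srlfetexp and is not reproduced here; a hypothetical sketch covering only the fields used above, with placeholder defaults, could look like:

from dataclasses import dataclass


@dataclass
class TrainConfig:
    # placeholder defaults for illustration; the actual defaults are defined in srlfetexp
    loss_name: str = 'mm'
    pos_margin: float = 1.0
    neg_margin: float = 1.0
    neg_scale: float = 1.0
    batch_size: int = 128
    schedule_lr: bool = False
    n_steps: int = 70000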
Example #3
def train_model():
    if config.dataset == 'cufe':
        data_prefix = config.CUFE_FILES['training_data_prefix']
        if config.train_on_crowd:
            data_prefix = config.CUFE_FILES['crowd_training_data_prefix']

    elif config.dataset == 'ufet':
        data_prefix = config.UFET_FILES['training_data_prefix']
    else:
        raise ValueError(f'unsupported dataset: {config.dataset}')

    data_prefix += f'-{config.seq_tokenizer_name}'
    if config.dir_suffix:
        data_prefix += '-' + config.dir_suffix
    if not os.path.isdir(data_prefix): os.mkdir(data_prefix)

    str_today = datetime.datetime.now().strftime('%m_%d_%H%M')
    if not os.path.isdir(join(data_prefix, 'log')):
        os.mkdir(join(data_prefix, 'log'))
    if not config.test:
        log_file = os.path.join(
            join(data_prefix, 'log'),
            '{}-{}.log'.format(str_today, config.model_name))
        print(log_file)
    else:
        log_file = os.path.join(
            config.LOG_DIR, '{}-{}-test.log'.format(str_today,
                                                    config.model_name))
    # if not os.path.isdir(log_file) and not config.test: os.mkdir(log_file)
    init_universal_logging(log_file, mode='a', to_stdout=True)

    save_model_dir = join(data_prefix, 'models')
    if not os.path.isdir(save_model_dir): os.mkdir(save_model_dir)

    gres = exp_utils.GlobalRes(config)

    run_name = f'{config.dataset}-{config.model_name}-{config.seq_tokenizer_name}-{str_today}'
    logging.info(f'run_name: {run_name}')
    logging.info(
        f'/data/cleeag/word_embeddings/{config.mention_tokenizer_name}/{config.mention_tokenizer_name}_tokenizer&vecs.pkl -- loaded'
    )
    logging.info(f'training on {config.dataset}')
    logging.info(
        f'total type count: {len(gres.type2type_id_dict)}, '
        f'general type count: {0 if config.without_general_types else len(gres.general_type_set)}'
    )

    if config.dataset == 'ufet':
        crowd_training_samples = f'{config.CROWD_TRAIN_DATA_PREFIX}-{config.seq_tokenizer_name}.pkl'
        if config.test:

            train_data_pkl = join(data_prefix, 'dev.pkl')
            training_samples = datautils.load_pickle_data(train_data_pkl)
            crowd_training_samples = datautils.load_pickle_data(
                crowd_training_samples)

        else:
            train_data_pkl = join(data_prefix, 'train.pkl')
            print('loading training data {} ...'.format(train_data_pkl),
                  end=' ',
                  flush=True)
            training_samples = datautils.load_pickle_data(train_data_pkl)
            print('done', flush=True)
            logging.info('training data {} -- loaded'.format(train_data_pkl))

            crowd_training_samples = datautils.load_pickle_data(
                crowd_training_samples)

            if config.fine_tune and config.use_bert:
                # training_samples = random.choices(len(training_samples) // 10, training_samples)
                random.shuffle(training_samples)
                training_samples = training_samples[:len(training_samples) //
                                                    10]
                logging.info(
                    f'fine-tuning with {len(training_samples)} samples')

        # dev_data_pkl = join(data_prefix, 'dev.pkl')
        # dev_samples = datautils.load_pickle_data(dev_data_pkl)
        dev_json_path = join(data_prefix, 'dev.json')
        dev_samples = exp_utils.model_samples_from_json(dev_json_path)
        dev_true_labels_dict = {
            s['mention_id']: [
                gres.type2type_id_dict.get(x) for x in s['types']
                if x in gres.type2type_id_dict
            ]
            for s in dev_samples
        }

        test_data_pkl = join(data_prefix, 'test.pkl')
        test_samples = datautils.load_pickle_data(test_data_pkl)
        test_true_labels_dict = {
            s['mention_id']: [
                gres.type2type_id_dict.get(x) for x in s['types']
                if x in gres.type2type_id_dict
            ]
            for s in test_samples
        }

    else:
        dev_data_pkl = join(data_prefix, 'dev.pkl')
        test_data_pkl = join(data_prefix, 'test.pkl')
        if config.test:
            train_data_pkl = join(data_prefix, 'dev.pkl')
        else:
            train_data_pkl = join(data_prefix, 'train.pkl')
        # test_data_pkl = config.CUFE_FILES['test_data_file_prefix'] + f'-{config.seq_tokenizer_name}/test.pkl'

        print('loading training data {} ...'.format(train_data_pkl),
              end=' ',
              flush=True)
        training_samples = datautils.load_pickle_data(train_data_pkl)
        print('done', flush=True)
        logging.info('training data {} -- loaded'.format(train_data_pkl))

        if not config.train_on_crowd:
            # crowd_training_samples = f'{config.CROWD_TRAIN_DATA_PREFIX}-{config.seq_tokenizer_name}/train.pkl'
            crowd_training_samples = datautils.load_pickle_data(
                join(data_prefix, 'crowd-train.pkl'))

        print('loading dev data {} ...'.format(dev_data_pkl),
              end=' ',
              flush=True)
        dev_samples = datautils.load_pickle_data(dev_data_pkl)
        print('done', flush=True)
        dev_true_labels_dict = {
            s['mention_id']: [
                gres.type2type_id_dict.get(x)
                for x in exp_utils.general_mapping(s['types'], gres)
            ]
            for s in dev_samples
        }

        test_samples = datautils.load_pickle_data(test_data_pkl)
        test_true_labels_dict = {
            s['mention_id']: [
                gres.type2type_id_dict.get(x)
                for x in exp_utils.general_mapping(s['types'], gres)
            ]
            for s in test_samples
        }

    logging.info(
        f'total training samples: {len(training_samples)}, '
        f'dev samples: {len(dev_samples)}, testing samples: {len(test_samples)}'
    )

    if not config.test:
        result_dir = join(data_prefix, f'{str_today}-results')
        if config.dataset == 'cufe':
            type_scope = 'general_types' if config.only_general_types else 'all_types'
        else:
            type_scope = config.dataset_type
        dev_results_file = join(
            result_dir,
            f'dev-{config.model_name}-{type_scope}-results-{config.inference_threshhold}.txt'
        )
        dev_incorrect_results_file = join(
            result_dir,
            f'dev-{config.model_name}-{type_scope}-incorrect_results-{config.inference_threshhold}.txt'
        )
        test_results_file = join(
            result_dir,
            f'test-{config.model_name}-{type_scope}-results-{config.inference_threshhold}.txt'
        )
        test_incorrect_results_file = join(
            result_dir,
            f'test-{config.model_name}-{type_scope}-incorrect_results-{config.inference_threshhold}.txt'
        )
    else:
        result_dir = join(data_prefix, 'test-results')
        dev_results_file = join(result_dir, 'dev-results.txt')
        dev_incorrect_results_file = join(result_dir,
                                          'dev-incorrect_results.txt')
        test_results_file = join(result_dir, 'test-results.txt')
        test_incorrect_results_file = join(result_dir,
                                           'test-incorrect_results.txt')

    if not os.path.isdir(result_dir): os.mkdir(result_dir)

    logging.info(
        'use_bert = {}, use_lstm = {}, use_mlp={}, bert_param_frozen={}, bert_fine_tune={}'
        .format(config.use_bert, config.use_lstm, config.use_mlp,
                config.freeze_bert, config.fine_tune))
    logging.info(
        'type_embed_dim={} context_lstm_hidden_dim={} pred_mlp_hdim={}'.format(
            config.type_embed_dim, config.lstm_hidden_dim,
            config.pred_mlp_hdim))

    # setup training
    use_cuda = torch.cuda.device_count() > 0
    device = torch.device(f'cuda:{config.gpu_ids[0]}') if use_cuda else torch.device('cpu')
    device_name = torch.cuda.get_device_name(config.gpu_ids[0]) if use_cuda else 'cpu'

    logging.info(f'running on device: {device_name}')
    logging.info('building model...')

    model = fet_model(config, device, gres)
    logging.info(f'transfer={config.transfer}')

    if config.continue_train:
        model_path = config.CONTINUE_TRAINING_PATH[config.continue_train]
        logging.info(f'loading checkpoint from {model_path}')
        trained_weights = torch.load(model_path)
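        # strip the leading name component ('module.' added by DataParallel when the
        # checkpoint was saved) so the keys match the bare model's state dict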
        trained_weights = {
            '.'.join(k.split('.')[1:]): v
            for k, v in trained_weights.items()
        }
        cur_model_dict = model.state_dict()
        cur_model_dict.update(trained_weights)
        model.load_state_dict(cur_model_dict)

    if config.transfer and config.use_lstm:
        logging.info(f'loading checkpoint from {config.TRANSFER_MODEL_PATH}')
        cur_model_dict = model.state_dict()
        trained_weights = torch.load(config.TRANSFER_MODEL_PATH)
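        # keep only the BiLSTM parameters from the transfer checkpoint, again dropping
        # the leading 'module.' prefix so the keys match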
        trained_weights_bilstm = {
            '.'.join(k.split('.')[1:]): v
            for k, v in trained_weights.items() if 'bi_lstm' in k
        }
        cur_model_dict.update(trained_weights_bilstm)
        model.load_state_dict(cur_model_dict)

    model.to(device)
    model = torch.nn.DataParallel(model, device_ids=config.gpu_ids)

    batch_size = 32 if config.dataset == 'cufe' and config.train_on_crowd else config.batch_size
    n_iter = 150 if config.dataset == 'cufe' and config.train_on_crowd else config.n_iter
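    # ceil-divide to get the number of mini-batches in one pass over the training data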
    n_batches = (len(training_samples) + batch_size - 1) // batch_size
    n_steps = n_iter * n_batches
    eval_cycle = config.eval_cycle

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    losses = list()
    best_dev_acc = -1
    best_maf1_v = -1
    step = 0
    steps_since_last_best = 0

    # start training
    logging.info('{}'.format(model.__class__.__name__))
    logging.info('training batch size: {}'.format(batch_size))
    logging.info(
        '{} epochs, {} steps, {} steps per iter, learning rate={}, lr_decay={}, start training ...'
        .format(n_iter, n_steps, n_batches, config.learning_rate,
                config.lr_gamma))

    while True:
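        # main training loop: run until n_steps parameter updates have been performed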
        if step == n_steps:
            break

        batch_idx = step % n_batches
        batch_beg, batch_end = batch_idx * batch_size, min(
            (batch_idx + 1) * batch_size, len(training_samples))
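        # fill part of each distant-supervision batch with randomly drawn
        # crowd-annotated samples (unless training on crowd data only)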
        if config.dataset == 'ufet':
            batch_samples = training_samples[batch_beg:batch_end - batch_size * 2 // 3] \
                            + random.choices(crowd_training_samples, k=batch_size * 1 // 3)
        elif config.dataset == 'cufe':
            if not config.train_on_crowd:
                batch_samples = training_samples[batch_beg:batch_end - batch_size * 2 // 3] \
                                + random.choices(crowd_training_samples, k=batch_size * 1 // 3)
            else:
                batch_samples = training_samples[batch_beg:batch_end]

        try:
            input_dataset, type_vecs = exp_utils.samples_to_tensor(
                config, gres, batch_samples)

            input_dataset = tuple(x.to(device) for x in input_dataset)
            type_vecs = type_vecs.to(device)
            model.module.train()
            logits = model(input_dataset, gres)
        except Exception:
            # skip batches that fail to build or run on the device
            step += 1
            continue

        if config.dataset == 'ufet':
            loss = model.module.define_loss(logits, type_vecs,
                                            config.dataset_type)
        elif config.GENERAL_TYPES_MAPPING and not config.only_general_types:
            loss = model.module.get_uw_loss(logits, type_vecs, gres)
        else:
            loss = model.module.get_loss(logits, type_vecs)
        optimizer.zero_grad()

        loss.backward()
        if config.use_lstm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0,
                                           float('inf'))
        optimizer.step()
        losses.append(loss.data.cpu().numpy())
        step += 1
        if step % eval_cycle == 0 and step > 0:
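            # periodically evaluate on the dev and test sets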
            print('\nevaluating...')
            l_v, acc_v, pacc_v, maf1_v, ma_p_v, ma_r_v, mif1_v, dev_results, incorrect_dev_result_objs = \
                model_utils.eval_fetel(config, gres, model, dev_samples, dev_true_labels_dict)

            l_t, acc_t, pacc_t, maf1_t, ma_p_t, ma_r_t, mif1_t, test_results, incorrect_test_result_objs = \
                model_utils.eval_fetel(config, gres, model, test_samples, test_true_labels_dict)

            if maf1_v > best_maf1_v:
                steps_since_last_best = 0
            best_tag = '*' if maf1_v > best_maf1_v else ''
            logging.info(
                'run name={}, step={}/{}, learning rate={}, losses={:.4f}, steps_since_last_best={}'
                .format(run_name,
                        step, n_steps, optimizer.param_groups[0]['lr'],
                        sum(losses), steps_since_last_best))
            logging.info(
                'dev evaluation result: '
                'l_v={:.4f} acc_v={:.4f} pacc_v={:.4f} macro_f1_v={:.4f} micro_f1_v={:.4f}{}'
                .format(l_v, acc_v, pacc_v, maf1_v, mif1_v, best_tag))
            logging.info(
                f'dev evaluation result: macro_p={ma_p_v:.4f}, macro_r={ma_r_v:.4f}'
            )

            logging.info(
                'test evaluation result: '
                'l_v={:.4f} acc_t={:.4f} pacc_t={:.4f} macro_f1_t={:.4f} micro_f1_t={:.4f}{}'
                .format(l_t, acc_t, pacc_t, maf1_t, mif1_t, best_tag))
            logging.info(
                f'test evaluation result: macro_p={ma_p_t:.4f}, macro_r={ma_r_t:.4f}'
            )

            if maf1_v > best_maf1_v:
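                # new best dev macro-F1: checkpoint the model and dump prediction files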
                if save_model_dir and not config.test:
                    save_model_file = join(
                        save_model_dir, f'{config.model_name}-{str_today}.pt')
                    torch.save(model.state_dict(), save_model_file)
                    logging.info('model saved to {}'.format(save_model_file))

                logging.info(
                    'prediction result saved to {}'.format(result_dir))
                datautils.save_json_objs(dev_results, dev_results_file)
                datautils.save_json_objs(incorrect_dev_result_objs,
                                         dev_incorrect_results_file)
                datautils.save_json_objs(test_results, test_results_file)
                datautils.save_json_objs(incorrect_test_result_objs,
                                         test_incorrect_results_file)
                # best_dev_acc = acc_v
                best_maf1_v = maf1_v

            losses = list()

        steps_since_last_best += 1
Example #4


if __name__ == '__main__':
    torch.random.manual_seed(config.RANDOM_SEED)
    np.random.seed(config.NP_RANDOM_SEED)
    random.seed(config.PY_RANDOM_SEED)

    str_today = datetime.datetime.now().strftime('%m_%d_%H%M')
    model_used = 'use_bert' if config.use_bert else 'use_lstm'
    if not os.path.isdir(config.LOG_DIR): os.mkdir(config.LOG_DIR)
    if not config.test:
        log_file = os.path.join(
            config.LOG_DIR, '{}-{}_{}.log'.format(
                os.path.splitext(os.path.basename(__file__))[0], str_today,
                model_used))
    else:
        log_file = os.path.join(
            config.LOG_DIR, '{}-{}_{}_test.log'.format(
                os.path.splitext(os.path.basename(__file__))[0], str_today,
                model_used))
    init_universal_logging(log_file, mode='a', to_stdout=True)
    train_model()
    # model_utils.check_breakdown_performance()
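
The training script above relies on two small I/O helpers from datautils whose definitions are not shown here. A minimal sketch, assuming load_pickle_data unpickles a single object and save_json_objs writes one JSON object per line, might be:

import json
import pickle


def load_pickle_data(filename):
    # assumption: the file holds one pickled object (e.g. a list of samples)
    with open(filename, 'rb') as f:
        return pickle.load(f)


def save_json_objs(objs, filename):
    # assumption: one JSON object per line, as consumed by the result files above
    with open(filename, 'w', encoding='utf-8') as f:
        for obj in objs:
            f.write(json.dumps(obj) + '\n')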