Example #1
0
def train(config):
    """Train the text-classification model named by ``config.task_name``.

    Loads pickled train/dev/test datasets, the vocabulary, and pretrained
    embedding weights from ``config.data_path``, builds the requested model,
    trains it with early stopping, then evaluates on the test set.

    Raises:
        ValueError: if ``config.task_name`` names no known model.
    """
    def _load(name):
        # ``with`` closes the handle; the original ``pickle.load(open(...))``
        # leaked every file descriptor it opened.
        with open(os.path.join(config.data_path, name), "rb") as f:
            # NOTE(review): pickle is unsafe on untrusted input; these files
            # are assumed to be project-generated artifacts.
            return pickle.load(f)

    train_data = _load(config.train_name)
    dev_data = _load(config.dev_name)
    test_data = _load(config.test_name)
    vocabulary = _load(config.vocabulary_name)
    # Pretrained w2v weights — only consumed by the "cnn_w2v" model below.
    weight = _load(config.weight_name)

    # Shared constructor arguments, factored out of the duplicated branches.
    rnn_kwargs = dict(vocab_size=len(vocabulary), embed_dim=config.embed_dim,
                      output_dim=config.class_num, hidden_dim=config.hidden_dim,
                      num_layers=config.num_layers, dropout=config.dropout)
    cnn_kwargs = dict(vocab_size=len(vocabulary), embed_dim=config.embed_dim,
                      class_num=config.class_num, kernel_num=config.kernel_num,
                      kernel_sizes=config.kernel_sizes, dropout=config.dropout,
                      static=config.static, in_channels=config.in_channels)

    if config.task_name == "lstm":
        text_model = LSTM(**rnn_kwargs)
    elif config.task_name == "lstm_maxpool":
        text_model = LSTM_maxpool(**rnn_kwargs)
    elif config.task_name == "rnn":
        text_model = RNN(**rnn_kwargs)
    elif config.task_name == "cnn":
        text_model = CNN(**cnn_kwargs)
    elif config.task_name == "cnn_w2v":
        text_model = CNN_w2v(weight=weight, **cnn_kwargs)
    elif config.task_name == "rcnn":
        text_model = RCNN(**rnn_kwargs)
    else:
        # Fail fast with a clear message instead of a NameError on text_model.
        raise ValueError("unknown task_name: {!r}".format(config.task_name))

    optimizer = Adam(lr=config.lr, weight_decay=config.weight_decay)
    timing = TimingCallback()
    early_stop = EarlyStopCallback(config.patience)
    accuracy = AccuracyMetric(pred='output', target='target')

    trainer = Trainer(train_data=train_data, model=text_model, loss=CrossEntropyLoss(),
                      batch_size=config.batch_size, check_code_level=0,
                      metrics=accuracy, n_epochs=config.epoch,
                      dev_data=dev_data, save_path=config.save_path,
                      print_every=config.print_every, validate_every=config.validate_every,
                      optimizer=optimizer, use_tqdm=False,
                      device=config.device, callbacks=[timing, early_stop])
    trainer.train()

    # Final evaluation on the held-out test set.
    tester = Tester(test_data, text_model, metrics=accuracy)
    tester.test()
Example #2
0
def train():
    """Train the poetry language model, tracking perplexity on the dev set.

    Builds the model and optimizer from a fresh ``Config()`` and runs the
    fastNLP ``Trainer`` with timing and early-stop callbacks.

    Raises:
        ValueError: if ``config.optimizer`` names no supported optimizer.
    """
    config = Config()

    train_data, dev_data, vocabulary = get_dataset(config.data_path)

    poetry_model = PoetryModel(vocabulary_size=len(vocabulary),
                               embedding_size=config.embedding_size,
                               hidden_size=config.hidden_size)
    loss = Loss(pred='output', target='target')
    perplexity = Perplexity(pred='output', target='target')

    print("optimizer:", config.optimizer)
    print("momentum:", config.momentum)
    if config.optimizer == 'adam':
        optimizer = Adam(lr=config.lr, weight_decay=config.weight_decay)
    elif config.optimizer == 'sgd':
        optimizer = SGD(lr=config.lr, momentum=config.momentum)
    elif config.optimizer == 'adagrad':
        optimizer = Adagrad(lr=config.lr, weight_decay=config.weight_decay)
    elif config.optimizer == 'adadelta':
        optimizer = Adadelta(lr=config.lr,
                             rho=config.rho,
                             eps=config.eps,
                             weight_decay=config.weight_decay)
    else:
        # Previously an unknown name fell through and crashed later with a
        # NameError on ``optimizer``; fail fast with a clear message instead.
        raise ValueError("unknown optimizer: {!r}".format(config.optimizer))

    timing = TimingCallback()
    early_stop = EarlyStopCallback(config.patience)

    trainer = Trainer(train_data=train_data,
                      model=poetry_model,
                      loss=loss,
                      metrics=perplexity,
                      n_epochs=config.epoch,
                      batch_size=config.batch_size,
                      print_every=config.print_every,
                      validate_every=config.validate_every,
                      dev_data=dev_data,
                      save_path=config.save_path,
                      optimizer=optimizer,
                      check_code_level=config.check_code_level,
                      # "-PPL": lower perplexity is better, hence the minus.
                      metric_key="-PPL",
                      sampler=RandomSampler(),
                      prefetch=False,
                      use_tqdm=True,
                      device=config.device,
                      callbacks=[timing, early_stop])
    trainer.train()
def train(config, task_name):
    """Fine-tune BERT for multi-label sequence classification; evaluate on dev.

    Args:
        config: experiment configuration (paths, hyper-parameters, device).
        task_name: unused here; kept so the signature matches sibling trainers.
    """
    def _load(name):
        # ``with`` closes the handle; the original ``pickle.load(open(...))``
        # leaked every file descriptor it opened.
        with open(os.path.join(config.bert_data_path, name), "rb") as f:
            return pickle.load(f)

    train_data = _load(config.train_name)
    print(train_data[0])
    # debug: shrink the training set for quick iteration
    if config.debug:
        train_data = train_data[0:30]
    dev_data = _load(config.dev_name)
    print(dev_data[0])
    # test_data = pickle.load(open(os.path.join(config.bert_data_path, config.test_name), "rb"))

    schemas = get_schemas(config.source_path)
    state_dict = torch.load(config.bert_path)
    # print(state_dict)
    text_model = BertForMultiLabelSequenceClassification.from_pretrained(
        config.bert_folder, state_dict=state_dict, num_labels=len(schemas))

    # Standard BERT weight-decay split: biases and LayerNorm parameters are
    # excluded from L2 regularization.
    param_optimizer = list(text_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    decay_params = [
        p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
    ]
    no_decay_params = [
        p for n, p in param_optimizer if any(nd in n for nd in no_decay)
    ]
    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': 0.01},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    # Total optimizer steps = batches per epoch / gradient accumulation * epochs.
    num_train_optimization_steps = int(
        len(train_data) / config.batch_size /
        config.update_every) * config.epoch
    if config.local_rank != -1:
        # Under distributed training each worker takes a share of the steps.
        num_train_optimization_steps //= torch.distributed.get_world_size()

    optimizer = BertAdam(
        lr=config.lr,
        warmup=config.warmup_proportion,
        t_total=num_train_optimization_steps).construct_from_pytorch(
            optimizer_grouped_parameters)

    timing = TimingCallback()
    early_stop = EarlyStopCallback(config.patience)
    logs = FitlogCallback(dev_data)
    f1 = F1_score(pred='output', target='label_id')

    trainer = Trainer(train_data=train_data,
                      model=text_model,
                      loss=BCEWithLogitsLoss(),
                      batch_size=config.batch_size,
                      check_code_level=-1,
                      metrics=f1,
                      metric_key='f1',
                      n_epochs=int(config.epoch),
                      dev_data=dev_data,
                      save_path=config.save_path,
                      print_every=config.print_every,
                      validate_every=config.validate_every,
                      update_every=config.update_every,
                      optimizer=optimizer,
                      use_tqdm=False,
                      device=config.device,
                      callbacks=[timing, early_stop, logs])
    trainer.train()

    # Final evaluation — on dev data, since the test-set load is commented out.
    tester = Tester(
        dev_data,
        text_model,
        metrics=f1,
        device=config.device,
        batch_size=config.batch_size,
    )
    tester.test()
Example #4
0
def train(config):
    """Train the BiLSTM-CRF tagger on pickled datasets from ``config.data_path``.

    Loads train/dev data plus the word/char/pos/spo/tag vocabularies, builds
    the model, and runs the fastNLP ``Trainer`` with early stopping.
    """
    def _load(name):
        # ``with`` closes the handle; the original ``pickle.load(open(...))``
        # leaked every file descriptor it opened.
        with open(os.path.join(config.data_path, name), "rb") as f:
            return pickle.load(f)

    train_data = _load(config.train_name)
    # NOTE(review): this truncation is UNCONDITIONAL — unlike the sibling
    # trainers it is not guarded by ``config.debug``, so full training never
    # sees more than 100 samples. Kept as-is to preserve behavior; confirm
    # whether this debug leftover should be removed.
    train_data = train_data[0:100]
    dev_data = _load(config.dev_name)
    print(len(train_data), len(dev_data))
    # test_data = pickle.load(open(os.path.join(config.data_path, config.test_name), "rb"))
    # load w2v data
    # weight = pickle.load(open(os.path.join(config.data_path, config.weight_name), "rb"))

    word_vocab = _load(config.word_vocab_name)
    char_vocab = _load(config.char_vocab_name)
    pos_vocab = _load(config.pos_vocab_name)
    spo_vocab = _load(config.spo_vocab_name)
    tag_vocab = _load(config.tag_vocab_name)
    print('word vocab', len(word_vocab))
    print('char vocab', len(char_vocab))
    print('pos vocab', len(pos_vocab))
    print('spo vocab', len(spo_vocab))
    print('tag vocab', len(tag_vocab))

    model = BiLSTM_CRF(config.batch_size,
                       len(word_vocab),
                       len(char_vocab),
                       len(pos_vocab),
                       len(spo_vocab),
                       config.embed_dim,
                       config.hidden_dim,
                       tag_vocab.idx2word,
                       dropout=0.5)

    optimizer = SGD(lr=config.lr, momentum=config.momentum)
    timing = TimingCallback()
    early_stop = EarlyStopCallback(config.patience)
    loss = NLLLoss()
    metrics = SpanFPreRecMetric(tag_vocab)
    # accuracy = AccuracyMetric(pred='output', target='target')

    trainer = Trainer(train_data=train_data,
                      model=model,
                      loss=loss,
                      metrics=metrics,
                      batch_size=config.batch_size,
                      n_epochs=config.epoch,
                      dev_data=dev_data,
                      save_path=config.save_path,
                      check_code_level=-1,
                      print_every=100,
                      validate_every=0,
                      optimizer=optimizer,
                      use_tqdm=False,
                      device=config.device,
                      callbacks=[timing, early_stop])
    trainer.train()
Example #5
0
# Flat training script: build 6-channel tensors, train a Keras model, and
# checkpoint the best weights by validation accuracy.
# NOTE(review): depends on names defined elsewhere (create_tensor_data_6channel,
# train_df, input_shape, NUMBER_OF_CLASSES, SAVE_BEST_WEIGHTS, BATCH_SIZE,
# EPOCHS, TimingCallback, create_model_6channel) — not visible in this chunk.
print('Create Tensor Data...')
x_train, x_test, y_train, y_test = create_tensor_data_6channel(
    train_df, input_shape, NUMBER_OF_CLASSES)

##################################################
################ Train Model #####################
# Optional Weights & Biases experiment tracking (disabled):
# import wandb
# from wandb.keras import WandbCallback
# wandb.init(name=RUN_NAME, project="6_class", notes=COMMENTS)

# Keep only the weights of the epoch with the highest validation accuracy.
checkpoint = ModelCheckpoint(SAVE_BEST_WEIGHTS,
                             monitor='val_accuracy',
                             verbose=1,
                             save_best_only=True,
                             mode='max')
cb = TimingCallback()
# callbacks = [WandbCallback(), checkpoint, cb]
callbacks = [checkpoint, cb]
model = create_model_6channel(input_shape, NUMBER_OF_CLASSES)
model.summary()

# The held-out split doubles as validation data during fit.
history = model.fit(x_train,
                    y_train,
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    callbacks=callbacks,
                    verbose=1,
                    validation_data=(x_test, y_test),
                    shuffle=True)

# model.save(os.path.join(wandb.run.dir, "model.h5"))
Example #6
0
def train(config, task_name):
    """Train a sequence-labeling model (BiLSTM-CRF or Transformer-CRF).

    Args:
        config: experiment configuration (paths, hyper-parameters, device).
        task_name: 'bilstm_crf' or 'trans_crf' — selects the architecture.

    Raises:
        ValueError: if ``task_name`` names no known model.
    """
    def _load(name):
        # ``with`` closes the handle; the original ``pickle.load(open(...))``
        # leaked every file descriptor it opened.
        with open(os.path.join(config.data_path, name), "rb") as f:
            return pickle.load(f)

    train_data = _load(config.train_name)
    # debug: shrink the training set for quick iteration
    if config.debug:
        train_data = train_data[0:100]
    dev_data = _load(config.dev_name)
    print(len(train_data), len(dev_data))
    # test_data = pickle.load(open(os.path.join(config.data_path, config.test_name), "rb"))
    # load w2v data
    # weight = pickle.load(open(os.path.join(config.data_path, config.weight_name), "rb"))

    word_vocab = _load(config.word_vocab_name)
    char_vocab = _load(config.char_vocab_name)
    pos_vocab = _load(config.pos_vocab_name)
    # spo_vocab = pickle.load(open(os.path.join(config.data_path, config.spo_vocab_name), "rb"))
    tag_vocab = _load(config.tag_vocab_name)
    print('word vocab', len(word_vocab))
    print('char vocab', len(char_vocab))
    print('pos vocab', len(pos_vocab))
    # print('spo vocab', len(spo_vocab))
    print('tag vocab', len(tag_vocab))

    schema = get_schemas(config.source_path)

    if task_name == 'bilstm_crf':
        model = AdvSeqLabel(
            char_init_embed=(len(char_vocab), config.char_embed_dim),
            word_init_embed=(len(word_vocab), config.word_embed_dim),
            pos_init_embed=(len(pos_vocab), config.pos_embed_dim),
            spo_embed_dim=len(schema),
            sentence_length=config.sentence_length,
            hidden_size=config.hidden_dim,
            num_classes=len(tag_vocab),
            dropout=config.dropout,
            id2words=tag_vocab.idx2word,
            encoding_type=config.encoding_type)
    elif task_name == 'trans_crf':
        model = TransformerSeqLabel(
            char_init_embed=(len(char_vocab), config.char_embed_dim),
            word_init_embed=(len(word_vocab), config.word_embed_dim),
            pos_init_embed=(len(pos_vocab), config.pos_embed_dim),
            spo_embed_dim=len(schema),
            num_classes=len(tag_vocab),
            id2words=tag_vocab.idx2word,
            encoding_type=config.encoding_type,
            num_layers=config.num_layers,
            inner_size=config.inner_size,
            key_size=config.key_size,
            value_size=config.value_size,
            num_head=config.num_head,
            dropout=config.dropout)
    else:
        # Fail fast with a clear message instead of a NameError on ``model``.
        raise ValueError("unknown task_name: {!r}".format(task_name))

    optimizer = Adam(lr=config.lr, weight_decay=config.weight_decay)
    timing = TimingCallback()
    early_stop = EarlyStopCallback(config.patience)
    # loss = NLLLoss()
    logs = FitlogCallback(dev_data)
    metrics = SpanFPreRecMetric(tag_vocab,
                                pred='pred',
                                seq_len='seq_len',
                                target='tag')

    # Expose the fields the model / metric read during training & evaluation.
    train_data.set_input('tag')
    dev_data.set_input('tag')
    dev_data.set_target('seq_len')
    #print(train_data.get_field_names())
    trainer = Trainer(
        train_data=train_data,
        model=model,
        # loss=loss,
        metrics=metrics,
        metric_key='f',
        batch_size=config.batch_size,
        n_epochs=config.epoch,
        dev_data=dev_data,
        save_path=config.save_path,
        check_code_level=-1,
        print_every=config.print_every,
        validate_every=config.validate_every,
        optimizer=optimizer,
        use_tqdm=False,
        device=config.device,
        callbacks=[timing, early_stop, logs])
    trainer.train()

    # Final evaluation — on dev data, since the test-set load is commented out.
    tester = Tester(dev_data,
                    model,
                    metrics=metrics,
                    device=config.device,
                    batch_size=config.batch_size)
    tester.test()
Example #7
0
def train(config, task_name):
    """Train a text classifier selected by ``task_name`` and evaluate with F1.

    Args:
        config: experiment configuration (paths, hyper-parameters, device).
        task_name: one of 'lstm', 'lstm_maxpool', 'cnn', 'rnn', 'rcnn'.

    Raises:
        ValueError: if ``task_name`` names no known model.
    """
    def _load(name):
        # ``with`` closes the handle; the original ``pickle.load(open(...))``
        # leaked every file descriptor it opened.
        with open(os.path.join(config.data_path, name), "rb") as f:
            return pickle.load(f)

    train_data = _load(config.train_name)
    # debug: shrink the training set for quick iteration
    if config.debug:
        train_data = train_data[0:30]
    dev_data = _load(config.dev_name)
    # test_data = pickle.load(open(os.path.join(config.data_path, config.test_name), "rb"))
    vocabulary = _load(config.vocabulary_name)

    # load w2v data
    # weight = pickle.load(open(os.path.join(config.data_path, config.weight_name), "rb"))

    # Shared constructor arguments, factored out of the duplicated branches.
    rnn_kwargs = dict(vocab_size=len(vocabulary),
                      embed_dim=config.embed_dim,
                      output_dim=config.class_num,
                      hidden_dim=config.hidden_dim,
                      num_layers=config.num_layers,
                      dropout=config.dropout)

    if task_name == "lstm":
        text_model = LSTM(**rnn_kwargs)
    elif task_name == "lstm_maxpool":
        text_model = LSTM_maxpool(**rnn_kwargs)
    elif task_name == "cnn":
        text_model = CNN(vocab_size=len(vocabulary),
                         embed_dim=config.embed_dim,
                         class_num=config.class_num,
                         kernel_num=config.kernel_num,
                         kernel_sizes=config.kernel_sizes,
                         dropout=config.dropout,
                         static=config.static,
                         in_channels=config.in_channels)
    elif task_name == "rnn":
        text_model = RNN(**rnn_kwargs)
    # elif task_name == "cnn_w2v":
    #     text_model = CNN_w2v(vocab_size=len(vocabulary), embed_dim=config.embed_dim,
    #                          class_num=config.class_num, kernel_num=config.kernel_num,
    #                          kernel_sizes=config.kernel_sizes, dropout=config.dropout,
    #                          static=config.static, in_channels=config.in_channels,
    #                          weight=weight)
    elif task_name == "rcnn":
        text_model = RCNN(**rnn_kwargs)
    #elif task_name == "bert":
    #    text_model = BertModel.from_pretrained(config.bert_path)
    else:
        # Fail fast with a clear message instead of a NameError on text_model.
        raise ValueError("unknown task_name: {!r}".format(task_name))

    optimizer = Adam(lr=config.lr, weight_decay=config.weight_decay)
    timing = TimingCallback()
    early_stop = EarlyStopCallback(config.patience)
    logs = FitlogCallback(dev_data)
    f1 = F1_score(pred='output', target='target')

    trainer = Trainer(train_data=train_data,
                      model=text_model,
                      loss=BCEWithLogitsLoss(),
                      batch_size=config.batch_size,
                      check_code_level=-1,
                      metrics=f1,
                      metric_key='f1',
                      n_epochs=config.epoch,
                      dev_data=dev_data,
                      save_path=config.save_path,
                      print_every=config.print_every,
                      validate_every=config.validate_every,
                      optimizer=optimizer,
                      use_tqdm=False,
                      device=config.device,
                      callbacks=[timing, early_stop, logs])
    trainer.train()

    # Final evaluation — on dev data, since the test-set load is commented out.
    tester = Tester(
        dev_data,
        text_model,
        metrics=f1,
        device=config.device,
        batch_size=config.batch_size,
    )
    tester.test()