Example #1
import pickle

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

from data import NewsDataset
from models.custom_transformer_classifier import OneHotMoralClassifier


def test(path, gpus):
    # load
    print("Start")
    # file = open('data/nela-covid-2020/combined/headlines_contentmorals_cnn_bart_split.pkl', 'rb')
    with open('data/nela-covid-2020/combined/headlines_cnn_bart_split.pkl',
              'rb') as file:
        data = pickle.load(file)
    print("Data Loaded")

    test_dataset = NewsDataset(data['test'])
    # test_loader = DataLoader(test_dataset, batch_size=32, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=128, num_workers=4)

    # model = OneHotMoralClassifier.load_from_checkpoint(path)
    model = OneHotMoralClassifier({}, use_mask=False)
    model.load_state_dict(torch.load(path))
    trainer = Trainer(gpus=gpus, distributed_backend='dp')

    trainer.test(model, test_dataloaders=test_loader)
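
The scraped excerpt omits the script's entry point. A minimal driver (not part of the original: the state-dict path is borrowed from the discriminator file used in the later examples, and the GPU count is a placeholder):

if __name__ == '__main__':
    # Hypothetical invocation; path and gpus are placeholders.
    test(path='saved_models/discriminator_titlemorals_state.pkl', gpus=1)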
Example #2

import pickle

import torch
from torch.utils.data import DataLoader

from models import MoralClassifier, MoralTransformer
from models.custom_transformer_classifier import OneHotMoralClassifier
from data import NewsDataset

# load
print("Loading data...")
# file = open('headlines_cnn_bart_split.pkl', 'rb')
with open('data/nela-covid-2020/combined/headlines_cnn_bart_split.pkl',
          'rb') as file:
    data = pickle.load(file)
print("Data loaded")

dataset = NewsDataset(data['test'],
                      moral_mode='random',
                      include_moral_tokens=True)
dataloader = DataLoader(dataset, batch_size=64, num_workers=4)

discriminator = OneHotMoralClassifier({}, use_mask=False)
print('Loading discriminator...')
discriminator.load_state_dict(
    torch.load('final_models/discriminator_titlemorals_state.pkl'))
print('Discriminator loaded')

model = MoralTransformer(discriminator=discriminator,
                         feed_moral_tokens_to='decoder',
                         contextual_injection=False)
print('Loading generator state...')
model.load_state_dict(
    torch.load('final_models/special_finetuned.ckpt')['state_dict'])
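
The excerpt ends once both state dicts are loaded; a natural follow-up (an assumption, not shown in the original) is to flip the generator into inference mode:

print('Generator state loaded')
model.eval()  # disable dropout/batch-norm updates before generating headlines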
Example #3
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from config import args
from model import Text_cnn
from data import NewsDataset, generate_batches
from utils import (make_embedding_matrix, make_train_state,
                   update_train_state, compute_accuracy, plot_performance)

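# `args` (from config.py) carries the experiment settings referenced below:
# reload_from_files, news_csv, vectorizer_file, use_glove, and glove_filepath.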
if args.reload_from_files:
    # training from a checkpoint
    dataset = NewsDataset.load_dataset_and_load_vectorizer(
        args.news_csv, args.vectorizer_file)
else:
    # create dataset and vectorizer
    dataset = NewsDataset.load_dataset_and_make_vectorizer(args.news_csv)
    dataset.save_vectorizer(args.vectorizer_file)
vectorizer = dataset.get_vectorizer()

# Use GloVe or randomly initialized embeddings
if args.use_glove:
    words = vectorizer.title_vocab._token_to_idx.keys()
    embeddings = make_embedding_matrix(glove_filepath=args.glove_filepath,
                                       words=words)
    print("Using pre-trained embeddings")
else:
    print("Not using pre-trained embeddings")
    embeddings = None
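
The excerpt stops before the classifier is instantiated. A lightweight sanity check (a sketch, assuming `make_embedding_matrix` returns an array with one row per token, which the call above suggests but does not show):

if embeddings is not None:
    # Every vocabulary token should map to exactly one embedding row.
    assert embeddings.shape[0] == len(words)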
Example #4
import pickle

import torch
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from models import MoralClassifier, MoralTransformer
from models.custom_transformer_classifier import OneHotMoralClassifier
from data import NewsDataset

# load
print("Start")
# file = open('headlines_cnn_bart_split.pkl', 'rb')
with open('data/nela-covid-2020/combined/headlines_cnn_bart_split.pkl',
          'rb') as file:
    data = pickle.load(file)
print("Data Loaded")

test_dataset = NewsDataset(data['test'])

discriminator = OneHotMoralClassifier({}, use_mask=False)
discriminator.load_state_dict(
    torch.load('saved_models/discriminator_titlemorals_state.pkl'))

model = MoralTransformer(discriminator=discriminator,
                         feed_moral_tokens_to='decoder',
                         contextual_injection=False)
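# The commented-out loads below are alternative fine-tuned checkpoints, one per
# experiment configuration (encoder vs. decoder feeding, moral mode, loss variants).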
# model.load_state_dict(torch.load('experiments/exp1/checkpoints/epoch=6-step=69999.ckpt')['state_dict'])
# model.load_state_dict(torch.load('experiments/decoder_1e-06_id+random_normalized_pairwise_False/checkpoints/epoch=9-step=26589.ckpt')['state_dict'])
# model.load_state_dict(torch.load('experiments/decoder_1e-06_identity_normalized_pairwise_False/checkpoints/epoch=17-step=95723.ckpt')['state_dict'])
# model.load_state_dict(torch.load('experiments/encoder_1e-06_identity_normalized_pairwise_False/checkpoints/epoch=14-step=79769.ckpt')['state_dict'])
# model.load_state_dict(torch.load('experiments/decoder_1e-06_random_normalized_pairwise_True/checkpoints/epoch=22-step=122313.ckpt')['state_dict'])
# model.load_state_dict(torch.load('experiments/encoder_1e-06_random_normalized_pairwise_True/checkpoints/epoch=23-step=127631.ckpt')['state_dict'])
model.load_state_dict(
    # `checkpoint_path` is a placeholder: the scraped snippet is cut off mid-call.
    torch.load(checkpoint_path)['state_dict'])
Example #5
import os
import pickle

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from data import NewsDataset
from models.custom_transformer_classifier import OneHotMoralClassifier
# MoralTransformerSpecial's module path is assumed; the excerpt does not show it.
from models import MoralTransformerSpecial


def train(gpus):
    print("Loading data...")
    # file = open('headlines_cnn_bart_split.pkl', 'rb')
    with open('data/nela-covid-2020/combined/headlines_cnn_bart_split.pkl',
              'rb') as file:
        data = pickle.load(file)
    print("Data loaded")

    # `experiments` and `exp_idx` are module-level (see the sketch after this function).
    exp = experiments[exp_idx]

    # Tuple layout: (feed_moral_tokens_to, lr, moral_mode, content_loss_type, use_moral_loss)
    feed_moral_tokens_to, lr, moral_mode, content_loss_type, use_moral_loss = exp
    use_content_loss = bool(content_loss_type)

    exp_name = '_'.join([
        feed_moral_tokens_to,
        str(lr), moral_mode,
        str(content_loss_type),
        str(use_moral_loss)
    ])
    exp_name = "RESUME " + exp_name

    # exp_name='TMP'

    # exp_name += '_content_weighted_10x'
    # exp_name += '_moral_weighted_10x'
    # exp_name += ''

    print(exp_name)

    # stuff to keep
    freeze_encoder = True
    freeze_decoder = False
    include_moral_tokens = True

    if feed_moral_tokens_to == 'injection':
        freeze_encoder = False
        include_moral_tokens = False

    train_dataset = NewsDataset(data['train'],
                                moral_mode=moral_mode,
                                include_moral_tokens=include_moral_tokens)
    val_dataset = NewsDataset(data['val'],
                              moral_mode=moral_mode,
                              include_moral_tokens=include_moral_tokens)
    test_dataset = NewsDataset(data['test'],
                               moral_mode=moral_mode,
                               include_moral_tokens=include_moral_tokens)

    train_loader = DataLoader(train_dataset,
                              batch_size=8,
                              num_workers=4,
                              shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, num_workers=4)

    # ------------
    # training
    # ------------
    print('Loading discriminator...')
    discriminator = OneHotMoralClassifier({}, use_mask=False)
    discriminator.load_state_dict(
        torch.load('saved_models/discriminator_titlemorals_state.pkl'))
    print('Discriminator loaded')

    model = MoralTransformerSpecial(
        lr=lr,
        discriminator=discriminator,
        use_content_loss=use_content_loss,
        contextual_injection=(not include_moral_tokens),
        freeze_encoder=freeze_encoder,
        freeze_decoder=freeze_decoder,
        feed_moral_tokens_to=feed_moral_tokens_to,
        content_loss_type=content_loss_type,
        use_moral_loss=use_moral_loss,

        # content_loss_weighting=10,
        # moral_loss_weighting=10,
    )

    # model.load_state_dict(torch.load('experiments/decoder_1e-06_id+random_normalized_pairwise_False/checkpoints/epoch=9-step=26589.ckpt')['state_dict'])

    checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(
        "./experiments", exp_name, "checkpoints"),
                                          save_top_k=1,
                                          save_last=True,
                                          monitor='train_loss',
                                          mode='min')
    trainer = Trainer(
        gpus=gpus,
        # auto_lr_find=False, # use to explore LRs
        # distributed_backend='dp',
        resume_from_checkpoint='saved_models/special_finetuned_30.ckpt',
        # max_epochs=30,
        max_epochs=50,
        callbacks=[checkpoint_callback],
    )

    # LR Exploration
    # lr_finder = trainer.tuner.lr_find(model, train_loader, val_loader)
    # new_lr = lr_finder.suggestion()
    # print(new_lr)

    trainer.fit(model, train_loader, val_loader)

    with open(os.path.join("./experiments", exp_name, 'loss_history.pkl'),
              'wb') as f:
        pickle.dump(model.loss_history, f)

    print("Training Done")
Example #6
import os
import pickle

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from data import NewsDataset
# EmbeddingDataset's import path is assumed; the excerpt does not show it.
from data import EmbeddingDataset
from models.custom_transformer_classifier import OneHotMoralClassifier


def train(exp_name, gpus):
    print("Start")
    # file = open('data/nela-covid-2020/combined/headlines_contentmorals_cnn_bart_split.pkl', 'rb')
    with open('data/nela-covid-2020/combined/headlines_cnn_bart_split.pkl',
              'rb') as file:
        data = pickle.load(file)
    print("Data Loaded")

    # create datasets
    # train_dataset = NewsDataset(data['train'][0:1])
    train_dataset = NewsDataset(data['train'])
    val_dataset = NewsDataset(data['val'])
    test_dataset = NewsDataset(data['test'])

    embedding_dataset = EmbeddingDataset()

    train_loader = DataLoader(train_dataset, batch_size=32, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, num_workers=4)

    # train_loader = DataLoader(train_dataset, batch_size=16, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=16, num_workers=4)

    # train_loader = DataLoader(embedding_dataset, batch_size=32, num_workers=4)
    # train_loader = DataLoader(embedding_dataset, batch_size=512, num_workers=4)
    # val_loader = DataLoader(embedding_dataset, batch_size=64, num_workers=4)

    # ------------
    # training
    # ------------
    LEARNING_RATE = 1e-5
    hparams = {'lr': LEARNING_RATE}
    model = OneHotMoralClassifier(hparams, use_mask=False)
    # model = CustomMoralClassifier(hparams)
    # model = MoralClassifier(hparams)
    # model = PseudoEmbedding(hparams)
    early_stop_callback = EarlyStopping(monitor='val_loss',
                                        min_delta=0.00,
                                        patience=3,
                                        verbose=True,
                                        mode='auto')
    checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(
        "./experiments", exp_name, "checkpoints"),
                                          save_top_k=1,
                                          monitor='train_loss',
                                          mode='min')
    trainer = Trainer(
        gpus=gpus,
        # auto_lr_find=False, # use to explore LRs
        # distributed_backend='dp',
        max_epochs=20,
        callbacks=[early_stop_callback, checkpoint_callback],
    )

    # LR Exploration
    # lr_finder = trainer.tuner.lr_find(model, train_loader, val_loader)
    # print(lr_finder.results)
    # fig = lr_finder.plot(suggest=True)
    # # fig.show()
    # # fig.savefig('lr.png')
    # new_lr = lr_finder.suggestion()
    # print(new_lr)

    trainer.fit(model, train_loader, val_loader)
    print("Training Done")