Code example #1
File: model.py Project: bityangke/causality
class COPAModel(nn.Module):
    def __init__(self, model='bert', model_size='base',
                 pool_method='avg', emb_size=10, **kwargs):
        super(COPAModel, self).__init__()

        self.pool_method = pool_method
        self.encoder = Encoder(model=model, model_size=model_size, fine_tune=False,
                               cased=False)
        self.rel_type_emb = nn.Embedding(2, emb_size)
        sent_repr_size = get_repr_size(self.encoder.hidden_size, method=pool_method)
        self.label_net = nn.Sequential(
            nn.Linear(3 * sent_repr_size + emb_size, 1),
            nn.Sigmoid()
        )

        self.training_criterion = nn.BCELoss()

    def get_other_params(self):
        core_encoder_param_names = set()
        for name, param in self.encoder.model.named_parameters():
            if param.requires_grad:
                core_encoder_param_names.add(name)

        other_params = []
        print("\nParams outside core transformer params:\n")
        for name, param in self.named_parameters():
            if param.requires_grad and name not in core_encoder_param_names:
                print(name, param.data.size())
                other_params.append(param)
        print("\n")
        return other_params

    def get_core_params(self):
        return self.encoder.model.parameters()

    def forward(self, batch_data):
        # Encode the event tokens and pool them into a fixed-size sentence representation.
        event, event_lens = batch_data.event
        event_emb = self.encoder.get_sentence_repr(
            self.encoder(event.cuda()), event_lens.cuda(), method=self.pool_method)

        hyp_event, hyp_event_lens = batch_data.hyp_event
        hyp_event_emb = self.encoder.get_sentence_repr(
            self.encoder(hyp_event.cuda()), hyp_event_lens.cuda(), method=self.pool_method)

        # Combine the two event representations, their elementwise product, and the
        # relation-type embedding, then score the pair with the sigmoid classifier.
        rel_type_emb = self.rel_type_emb(batch_data.type_event.cuda())
        pred_label = self.label_net(torch.cat([event_emb, hyp_event_emb, event_emb * hyp_event_emb,
                                               rel_type_emb], dim=-1))
        pred_label = torch.squeeze(pred_label, dim=-1)
        loss = self.training_criterion(pred_label, batch_data.label.cuda().float())
        if self.training:
            return loss
        else:
            return loss, pred_label
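
A minimal usage sketch, not part of the original snippet: it assumes model.py's own imports (Encoder, get_repr_size, torch, nn) are available and picks an arbitrary learning rate. Since the encoder is built with fine_tune=False, only the parameters returned by get_other_params() would normally be optimized, mirroring the pattern in the train.py example below.

# Hypothetical usage sketch; COPAModel comes from the project's model.py.
import torch

model = COPAModel(model='bert', model_size='base',
                  pool_method='avg', emb_size=10).cuda()
# Encoder weights are frozen (fine_tune=False), so train only the remaining parameters.
optimizer = torch.optim.Adam(model.get_other_params(), lr=1e-3)

# forward() expects a torchtext-style batch with fields event, hyp_event,
# type_event, and label; in training mode it returns the BCE loss:
#   loss = model(batch_data)
#   loss.backward(); optimizer.step()
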
Code example #2
File: model.py Project: bityangke/causality
    def __init__(self, model='bert', model_size='base',
                 pool_method='avg', emb_size=10, **kwargs):
        super(COPAModel, self).__init__()

        self.pool_method = pool_method
        self.encoder = Encoder(model=model, model_size=model_size, fine_tune=False,
                               cased=False)
        self.rel_type_emb = nn.Embedding(2, emb_size)
        sent_repr_size = get_repr_size(self.encoder.hidden_size, method=pool_method)
        self.label_net = nn.Sequential(
            nn.Linear(3 * sent_repr_size + emb_size, 1),
            nn.Sigmoid()
        )

        self.training_criterion = nn.BCELoss()
Code example #3
File: model.py Project: shtoshni92/span-rep
    def __init__(self,
                 model='bert',
                 model_size='base',
                 just_last_layer=False,
                 span_dim=256,
                 pool_method='avg',
                 fine_tune=False,
                 num_spans=1,
                 **kwargs):
        super(CorefModel, self).__init__()

        self.pool_method = pool_method
        self.num_spans = num_spans
        self.just_last_layer = just_last_layer
        self.encoder = Encoder(model=model,
                               model_size=model_size,
                               fine_tune=True,
                               cased=True)
        self.span_net = nn.ModuleDict()
        self.span_net['0'] = get_span_module(
            method=pool_method,
            input_dim=self.encoder.hidden_size,
            use_proj=True,
            proj_dim=span_dim)

        self.pooled_dim = self.span_net['0'].get_output_dim()

        self.label_net = nn.Sequential(
            nn.Linear(2 * self.pooled_dim, span_dim), nn.Tanh(),
            nn.LayerNorm(span_dim), nn.Dropout(0.2), nn.Linear(span_dim, 1),
            nn.Sigmoid())

        self.training_criterion = nn.BCELoss()
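
CorefModel.forward is not included in this snippet. As a rough, hypothetical illustration of the input that label_net expects (2 * pooled_dim features per span pair), two pooled span representations would be concatenated and scored along these lines; s1, s2, and model are placeholder names:

# Illustrative only: s1 and s2 stand for pooled span representations
# produced by span_net, each with model.pooled_dim features.
pair_repr = torch.cat([s1, s2], dim=-1)   # shape: (batch, 2 * pooled_dim)
coref_prob = model.label_net(pair_repr)   # shape: (batch, 1), values in [0, 1]
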
Code example #4
File: data.py Project: shtoshni92/span-rep
        val_iter, test_iter = data.BucketIterator.splits(
            (val, test),
            batch_size=eval_batch_size,
            sort_within_batch=True,
            shuffle=False,
            repeat=False)

        label_field.build_vocab(train)
        num_labels = len(label_field.vocab.itos)
        return (train_iter, val_iter, test_iter, num_labels)


if __name__ == '__main__':
    from encoders.pretrained_transformers import Encoder
    encoder = Encoder(cased=True)
    path = "/home/shtoshni/Research/hackathon_2019/tasks/srl/data"
    train_iter, val_iter, test_iter, num_labels = SRLDataset.iters(
        path, encoder, train_frac=1.0)

    print("Train size:", len(train_iter.data()))
    print("Val size:", len(val_iter.data()))
    print("Test size:", len(test_iter.data()))

    for batch_data in train_iter:
        print(batch_data.text[0].shape)
        print(batch_data.span1.shape)
        print(batch_data.span2.shape)
        print(batch_data.label.shape)

        text, text_len = batch_data.text
Code example #5
File: train.py Project: shtoshni92/span-rep
def main():
    hp = parse_args()

    # Setup model directories
    model_name = get_model_name(hp)
    model_path = path.join(hp.model_dir, model_name)
    best_model_path = path.join(model_path, 'best_models')
    if not path.exists(model_path):
        os.makedirs(model_path)
    if not path.exists(best_model_path):
        os.makedirs(best_model_path)

    # Set random seed
    torch.manual_seed(hp.seed)

    # Hacky way of assigning the number of labels.
    encoder = Encoder(model=hp.model,
                      model_size=hp.model_size,
                      fine_tune=hp.fine_tune,
                      cased=True)
    # Load data
    logging.info("Loading data")
    train_iter, val_iter, test_iter, num_labels = SRLDataset.iters(
        hp.data_dir,
        encoder,
        batch_size=hp.batch_size,
        eval_batch_size=hp.eval_batch_size,
        train_frac=hp.train_frac)
    logging.info("Data loaded")

    # Initialize the model
    model = SRLModel(encoder, num_labels=num_labels, **vars(hp)).cuda()
    sys.stdout.flush()

    if not hp.fine_tune:
        optimizer = torch.optim.Adam(model.get_other_params(), lr=hp.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           patience=5,
                                                           factor=0.5,
                                                           verbose=True)
    steps_done = 0
    max_f1 = 0
    init_num_stuck_evals = 0
    num_steps = (hp.n_epochs * len(train_iter.data())) // hp.real_batch_size
    # Quantize the number of training steps to eval steps
    num_steps = (num_steps // hp.eval_steps) * hp.eval_steps
    logging.info("Total training steps: %d" % num_steps)

    location = path.join(model_path, "model.pt")
    if path.exists(location):
        logging.info("Loading previous checkpoint")
        checkpoint = torch.load(location)
        model.encoder.weighing_params = checkpoint['weighing_params']
        model.span_net.load_state_dict(checkpoint['span_net'])
        model.label_net.load_state_dict(checkpoint['label_net'])
        if hp.fine_tune:
            model.encoder.model.load_state_dict(checkpoint['encoder'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        steps_done = checkpoint['steps_done']
        init_num_stuck_evals = checkpoint['num_stuck_evals']
        max_f1 = checkpoint['max_f1']
        torch.set_rng_state(checkpoint['rng_state'])
        logging.info("Steps done: %d, Max F1: %.3f" % (steps_done, max_f1))

    if not hp.eval:
        train(hp,
              model,
              train_iter,
              val_iter,
              optimizer,
              scheduler,
              model_path,
              best_model_path,
              init_steps=steps_done,
              max_f1=max_f1,
              eval_steps=hp.eval_steps,
              num_steps=num_steps,
              init_num_stuck_evals=init_num_stuck_evals)

    val_f1, test_f1 = final_eval(hp, model, best_model_path, val_iter,
                                 test_iter)
    perf_dir = path.join(hp.model_dir, "perf")
    if not path.exists(perf_dir):
        os.makedirs(perf_dir)

    if hp.slurm_id:
        perf_file = path.join(perf_dir, hp.slurm_id + ".txt")
    else:
        perf_file = path.join(model_path, "perf.txt")
    with open(perf_file, "w") as f:
        f.write("%s\n" % (model_path))
        f.write("%s\t%.4f\n" % ("Valid", val_f1))
        f.write("%s\t%.4f\n" % ("Test", test_f1))
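
The checkpoint-loading branch above implies a matching save call somewhere in the training loop. That code is not shown here; the following is a minimal sketch of what such a save would look like, using exactly the keys the loader reads. The helper name save_checkpoint is an assumption, not part of the project.

# Hypothetical save helper; key names mirror the loading code above.
def save_checkpoint(model, optimizer, scheduler, steps_done,
                    num_stuck_evals, max_f1, location, fine_tune=False):
    checkpoint = {
        'weighing_params': model.encoder.weighing_params,
        'span_net': model.span_net.state_dict(),
        'label_net': model.label_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'steps_done': steps_done,
        'num_stuck_evals': num_stuck_evals,
        'max_f1': max_f1,
        'rng_state': torch.get_rng_state(),
    }
    if fine_tune:
        # Only persist the transformer weights when they are being fine-tuned.
        checkpoint['encoder'] = model.encoder.model.state_dict()
    torch.save(checkpoint, location)
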
Code example #6
File: ablation_both.py Project: shtoshni92/span-rep
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler = logging.FileHandler(
        os.path.join(args.model_path, args.model_name + '.log'), 'a')
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(formatter)
    logger.addHandler(console)
    logger.propagate = False

    # create data sets, tokenizers, and data loaders
    encoder = Encoder(args.model_type,
                      args.model_size,
                      args.cased,
                      use_proj=args.use_proj,
                      proj_dim=args.proj_dim)
    data_loader_path = os.path.join(args.model_path,
                                    args.model_name + '.loader.pt')
    if os.path.exists(data_loader_path):
        logger.info('Loading datasets.')
        data_info = torch.load(data_loader_path)
        data_loader = data_info['data_loader']
        ConstituentDataset.label_dict = data_info['label_dict']
        ConstituentDataset.encoder = encoder
    else:
        logger.info('Creating datasets.')
        data_set = dict()
        data_loader = dict()
        for split in ['train', 'development', 'test']:
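
The else branch that builds the datasets is cut off here. Given the keys read by torch.load above, the cache it eventually writes would presumably be produced by something like the following sketch, which is an assumption rather than code from the project:

# Hypothetical cache write; keys mirror the loading branch above.
torch.save({'data_loader': data_loader,
            'label_dict': ConstituentDataset.label_dict},
           data_loader_path)
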
Code example #7
                                         shuffle=True,
                                         repeat=False)

        val_iter, test_iter = data.BucketIterator.splits(
            (val, test),
            batch_size=eval_batch_size,
            sort_within_batch=True,
            shuffle=False,
            repeat=False)

        return (train_iter, val_iter, test_iter)


if __name__ == '__main__':
    from encoders.pretrained_transformers import Encoder
    encoder = Encoder(cased=False)
    path = "/home/shtoshni/Research/hackathon_2019/tasks/coref/data"
    train_iter, val_iter, test_iter = CorefDataset.iters(path,
                                                         encoder,
                                                         train_frac=1.0)

    print("Train size:", len(train_iter.data()))
    print("Val size:", len(val_iter.data()))
    print("Test size:", len(test_iter.data()))

    for batch_data in train_iter:
        print(batch_data.text[0].shape)
        print(batch_data.span1.shape)
        print(batch_data.span2.shape)
        print(batch_data.label.shape)
Code example #8
File: main.py Project: shtoshni92/span-rep
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler = logging.FileHandler(
        os.path.join(args.model_path, args.model_name + '.log'), 'a')
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(formatter)
    logger.addHandler(console)
    logger.propagate = False

    # create data sets, tokenizers, and data loaders
    encoder = Encoder(args.model_type,
                      args.model_size,
                      args.cased,
                      use_proj=False,
                      fine_tune=args.fine_tune)
    data_loader_path = os.path.join(args.model_path,
                                    args.model_name + '.loader.pt')
    if os.path.exists(data_loader_path):
        logger.info('Loading datasets.')
        data_info = torch.load(data_loader_path)
        data_loader = data_info['data_loader']
        ### To be removed
        for split in ['train', 'development', 'test']:
            data_loader[split] = DataLoader(data_loader[split].dataset,
                                            args.batch_size,
                                            collate_fn=collate_fn,
                                            shuffle=(split == 'train'))
        ### End to be removed
Code example #9
File: data.py Project: shtoshni92/span-rep

def collate_fn(data):
    # Each dataset item is (sentence ids of shape (1, length), span indices, label).
    sents, spans, labels = list(zip(*data))
    max_length = max(item.shape[1] for item in sents)
    pad_id = ConstituentDataset.encoder.tokenizer.pad_token_id
    batch_size = len(sents)
    # Right-pad every sentence with pad_id up to the longest sentence in the batch.
    padded_sents = pad_id * torch.ones(batch_size, max_length).long()
    for i, sent in enumerate(sents):
        padded_sents[i, :sent.shape[1]] = sent[0, :]
    spans = torch.cat(spans, dim=0)
    labels = torch.tensor(labels)
    return padded_sents, spans, labels


# unit test
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from encoders.pretrained_transformers import Encoder
    encoder = Encoder('bert', 'base', True)
    for split in ['train', 'development', 'test']:
        dataset = ConstituentDataset(
            f'tasks/constclass/data/debug/{split}.json', encoder)
        data_loader = DataLoader(dataset, 64, collate_fn=collate_fn)
        for sents, spans, labels in data_loader:
            pass
        print(f'Split "{split}" has passed the unit test '
              f'with {len(dataset)} instances.')
    from IPython import embed
    embed(using=False)
Code example #10
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler = logging.FileHandler(
        os.path.join(args.model_path, args.model_name + '.log'), 'a')
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(formatter)
    logger.addHandler(console)
    logger.propagate = False

    # create data sets, tokenizers, and data loaders
    encoder = Encoder(args.model_type,
                      args.model_size,
                      args.cased,
                      use_proj=args.use_proj,
                      proj_dim=args.proj_dim)
    data_loader_path = os.path.join(args.model_path,
                                    args.model_name + '.loader.pt')

    if os.path.exists(data_loader_path):
        logger.info('Loading datasets.')
        data_info = torch.load(data_loader_path)
        data_loader = data_info['data_loader']
        ConstituentDataset.label_dict = data_info['label_dict']
        ConstituentDataset.encoder = encoder
    else:
        logger.info('Creating datasets.')
        data_set = dict()
        data_loader = dict()