예제 #1
0
파일: train_lstm.py 프로젝트: marcwww/rlt
def build_iters(args):
    """Build SST iterators over the constituency-tree files.

    Reads the train/dev/test tree files under ``data/sst/sst/trees/`` with
    ``load_examples``. Builds the text vocabulary (optionally attaching
    ``args.pretrained`` vectors) and the label vocabulary from the training
    split only. Wraps every split in a length-sorted ``basic.BucketIterator``.

    Args:
        args: options namespace; reads ``lower``, ``pretrained``, ``gpu``
            (-1 selects the CPU, anything else is passed to ``torch.device``)
            and ``batch_size``.

    Returns:
        tuple: ``(train_iter, valid_iter, test_iter, (TXT, TREE, LBL))``.
    """
    TXT = data.Field(lower=args.lower, include_lengths=True, batch_first=True)
    LBL = data.Field(sequential=False, unk_token=None)
    # The tree field holds raw integers (no vocabulary) and is padded with 0.
    TREE = data.Field(sequential=True, use_vocab=False, pad_token=0)

    def to_dataset(examples):
        # A fresh field list is built for every split.
        return Dataset(examples,
                       fields=[('txt', TXT), ('tree', TREE), ('lbl', LBL)])

    # The training file is loaded with subtrees=True, dev/test with
    # subtrees=False. len_ave (presumably the mean training length) scales the
    # dynamic batch size below.
    examples_train, len_ave = load_examples('data/sst/sst/trees/train.txt',
                                            subtrees=True)
    examples_valid, _ = load_examples('data/sst/sst/trees/dev.txt',
                                      subtrees=False)
    examples_test, _ = load_examples('data/sst/sst/trees/test.txt',
                                     subtrees=False)

    train = to_dataset(examples_train)
    # Vocabularies are built from the training split only.
    TXT.build_vocab(train, vectors=args.pretrained)
    LBL.build_vocab(train)
    valid = to_dataset(examples_valid)
    test = to_dataset(examples_test)

    def batch_size_fn(new_example, current_count, ebsz):
        # Each example adds (len / len_ave) ** 0.3 to the running batch size.
        # Under the torchtext batch_size_fn contract, batches of longer
        # sentences therefore hold fewer examples.
        return ebsz + (len(new_example.txt) / len_ave)**0.3

    device = torch.device('cpu' if args.gpu == -1 else args.gpu)

    def bucket(dataset):
        # All splits share the same sorting/shuffling/batching configuration.
        return basic.BucketIterator(dataset,
                                    batch_size=args.batch_size,
                                    sort=True,
                                    shuffle=True,
                                    repeat=False,
                                    sort_key=lambda ex: len(ex.txt),
                                    batch_size_fn=batch_size_fn,
                                    device=device)

    train_iter, valid_iter, test_iter = (bucket(split)
                                         for split in (train, valid, test))
    return train_iter, valid_iter, test_iter, (TXT, TREE, LBL)
예제 #2
0
def build_iters(args):
    """Build bucketed iterators for the expression-evaluation task.

    Loads ``data/train_d20s.tsv`` and ``data/test_d20s.tsv`` with
    ``load_examples`` and builds the EXPR/VAL vocabularies from the training
    examples. Partitions the training examples with ``split_examples`` and
    builds one length-sorted ``basic.BucketIterator`` per partition, plus one
    for the validation file.

    Args:
        args: options namespace; reads ``gpu`` (-1 selects the CPU) and
            ``bsz`` (the batch size).

    Returns:
        tuple: ``(train_iters, valid_iter, EXPR, VAL)``.
            ``train_iters`` maps each key produced by ``split_examples`` to
            the iterator over that partition.
    """
    EXPR = torchtext.data.Field(sequential=True,
                                use_vocab=True,
                                batch_first=True,
                                include_lengths=True,
                                pad_token=PAD,
                                eos_token=None)
    VAL = torchtext.data.Field(sequential=False)

    def make_fields():
        # A fresh field list is built for every Dataset.
        return [('expr', EXPR), ('val', VAL)]

    # NOTE(review): the validation split is read from test_d20s.tsv; there is
    # no separate held-out dev file here. Confirm this is intended.
    examples_train, len_ave = load_examples('data/train_d20s.tsv')
    examples_valid, _ = load_examples('data/test_d20s.tsv')

    train = Dataset(examples_train, fields=make_fields())
    # Vocabularies come from the training examples only.
    EXPR.build_vocab(train)
    VAL.build_vocab(train)
    valid = Dataset(examples_valid, fields=make_fields())

    device = torch.device('cpu' if args.gpu == -1 else args.gpu)

    def batch_size_fn(new_example, current_count, ebsz):
        # Each example adds (len / len_ave) ** 0.3 to the running batch size.
        return ebsz + (len(new_example.expr) / len_ave) ** 0.3

    def bucket(dataset):
        # Every iterator (train partitions and valid) uses the same settings.
        return basic.BucketIterator(dataset,
                                    batch_size=args.bsz,
                                    sort=True,
                                    shuffle=True,
                                    repeat=False,
                                    sort_key=lambda ex: len(ex.expr),
                                    batch_size_fn=batch_size_fn,
                                    device=device)

    # split_examples also returns per-partition average lengths. They are not
    # used: every partition is batched against the global training len_ave.
    # NOTE(review): confirm whether per-partition averages were intended.
    splits, _ = split_examples(examples_train)
    train_iters = {
        srange: bucket(Dataset(split, fields=make_fields()))
        for srange, split in splits.items()
    }

    valid_iter = bucket(valid)
    return train_iters, valid_iter, EXPR, VAL
예제 #3
0
파일: train_bucket.py 프로젝트: marcwww/rlt
def train(args):
    """Train an SSTModel classifier on SST and checkpoint the best model.

    Steps:
      1. Load SST through ``datasets.SST.splits`` with
         ``train_subtrees=True``. Neutral examples are filtered out unless
         ``args.fine_grained``.
      2. Build vocabularies and length-bucketed train/valid iterators.
      3. Build the model, optimizer and ReduceLROnPlateau scheduler.
      4. Train for up to ``args.max_epoch`` epochs.

    About 20 times per epoch the model is evaluated on the validation split.
    At each evaluation:
      * loss/accuracy are written to TensorBoard (``<save_dir>/log/train``
        and ``<save_dir>/log/valid``);
      * the scheduler is stepped on validation accuracy;
      * the state dict is saved to ``args.save_dir`` whenever validation
        accuracy improves.

    Args:
        args: parsed command-line options. Attributes read: lower,
            fine_grained, pretrained, gpu, batch_size, word_dim, hidden_dim,
            clf_hidden_dim, clf_num_layers, leaf_rnn, bidirectional,
            intra_attention, batchnorm, dropout, fix_word_embedding,
            optimizer, l2reg, halve_lr_every, save_dir, max_epoch.

    Raises:
        NotImplementedError: if ``args.optimizer`` is not 'adam', 'adagrad'
            or 'adadelta'.
    """
    text_field = data.Field(lower=args.lower,
                            include_lengths=True,
                            batch_first=True)
    label_field = data.Field(sequential=False)

    # In the non-fine-grained setting, examples labelled 'neutral' are dropped.
    filter_pred = None
    if not args.fine_grained:
        filter_pred = lambda ex: ex.label != 'neutral'
    dataset_splits = datasets.SST.splits(root='./data/sst',
                                         text_field=text_field,
                                         label_field=label_field,
                                         fine_grained=args.fine_grained,
                                         train_subtrees=True,
                                         filter_pred=filter_pred)

    # Mean token length of the training examples; it scales batch_size_fn.
    len_ave = np.mean([len(e.text) for e in dataset_splits[0].examples])

    # NOTE(review): vocabularies are built over every split returned by
    # SST.splits, not only the training split.
    text_field.build_vocab(*dataset_splits, vectors=args.pretrained)
    label_field.build_vocab(*dataset_splits)

    logging.info(f'Initialize with pretrained vectors: {args.pretrained}')
    logging.info(f'Number of classes: {len(label_field.vocab)}')

    def batch_size_fn(new_example, current_count, ebsz):
        # Each example adds (len / len_ave) ** 0.3 to the running batch size,
        # so batches of long sentences hold fewer examples.
        return ebsz + (len(new_example.text) / len_ave)**0.3

    device = torch.device(args.gpu if args.gpu != -1 else 'cpu')

    def make_loader(dataset):
        # Train and valid loaders share the same bucketing configuration.
        return basic.BucketIterator(dataset,
                                    batch_size=args.batch_size,
                                    sort=True,
                                    shuffle=True,
                                    repeat=False,
                                    sort_key=lambda x: len(x.text),
                                    batch_size_fn=batch_size_fn,
                                    device=device)

    train_loader = make_loader(dataset_splits[0])
    valid_loader = make_loader(dataset_splits[1])

    num_classes = len(label_field.vocab)
    model = SSTModel(num_classes=num_classes,
                     num_words=len(text_field.vocab),
                     word_dim=args.word_dim,
                     hidden_dim=args.hidden_dim,
                     clf_hidden_dim=args.clf_hidden_dim,
                     clf_num_layers=args.clf_num_layers,
                     use_leaf_rnn=args.leaf_rnn,
                     bidirectional=args.bidirectional,
                     intra_attention=args.intra_attention,
                     use_batchnorm=args.batchnorm,
                     dropout_prob=args.dropout)
    if args.pretrained:
        model.word_embedding.weight.data.set_(text_field.vocab.vectors)
    if args.fix_word_embedding:
        logging.info('Will not update word embeddings')
        model.word_embedding.weight.requires_grad = False
    logging.info(f'Using device {args.gpu}')
    model.to(device)
    # Only trainable parameters go to the optimizer and gradient clipping.
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'adam':
        optimizer_class = optim.Adam
    elif args.optimizer == 'adagrad':
        optimizer_class = optim.Adagrad
    elif args.optimizer == 'adadelta':
        optimizer_class = optim.Adadelta
    else:
        raise NotImplementedError
    optimizer = optimizer_class(params=params, weight_decay=args.l2reg)
    # The scheduler is stepped once per validation (~20 per epoch), so this
    # patience is roughly `halve_lr_every` epochs without improvement.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                               mode='max',
                                               factor=0.5,
                                               patience=20 *
                                               args.halve_lr_every,
                                               verbose=True)
    criterion = nn.CrossEntropyLoss()

    train_summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_dir, 'log', 'train'))
    valid_summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_dir, 'log', 'valid'))

    def run_iter(batch, is_training):
        """Run one batch; update the model only when `is_training`.

        Returns the batch-mean loss and accuracy as scalar tensors.
        """
        model.train(is_training)
        words, length = batch.text
        label = batch.label
        logits = model(words=words, length=length)
        label_pred = logits.max(1)[1]
        accuracy = torch.eq(label, label_pred).float().mean()
        loss = criterion(input=logits, target=label)
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(parameters=params, max_norm=5)
            optimizer.step()
        return loss, accuracy

    def add_scalar_summary(summary_writer, name, value, step):
        # Tensors are converted to Python numbers before logging.
        if torch.is_tensor(value):
            value = value.item()
        summary_writer.add_scalar(tag=name,
                                  scalar_value=value,
                                  global_step=step)

    num_train_batches = len(train_loader)
    # Validate ~20 times per epoch. Clamp to 1 so that epochs with fewer than
    # 20 batches do not cause a modulo-by-zero below.
    validate_every = max(1, num_train_batches // 20)
    best_valid_accuracy = 0
    iter_count = 0
    for epoch in range(args.max_epoch):
        for batch_iter, train_batch in enumerate(tqdm.tqdm(train_loader)):
            train_loss, train_accuracy = run_iter(batch=train_batch,
                                                  is_training=True)
            iter_count += 1
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='loss',
                               value=train_loss,
                               step=iter_count)
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='accuracy',
                               value=train_accuracy,
                               step=iter_count)

            if (batch_iter + 1) % validate_every == 0:
                valid_loss_sum = valid_correct_sum = 0.0
                num_valid_examples = 0
                # Evaluation needs no autograd graph; no_grad saves memory
                # and time without changing the computed values.
                with torch.no_grad():
                    for valid_batch in valid_loader:
                        valid_loss, valid_accuracy = run_iter(
                            batch=valid_batch, is_training=False)
                        # batch_size_fn makes batch sizes vary, so each
                        # batch-mean is weighted by its example count to get
                        # true per-example averages.
                        batch_len = valid_batch.label.size(0)
                        valid_loss_sum += valid_loss.item() * batch_len
                        valid_correct_sum += valid_accuracy.item() * batch_len
                        num_valid_examples += batch_len
                valid_loss = valid_loss_sum / num_valid_examples
                valid_accuracy = valid_correct_sum / num_valid_examples
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='loss',
                                   value=valid_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='accuracy',
                                   value=valid_accuracy,
                                   step=iter_count)
                scheduler.step(valid_accuracy)
                progress = train_loader.iterations / len(train_loader)
                logging.info(f'Epoch {epoch}: '
                             f'valid loss = {valid_loss:.4f}, '
                             f'valid accuracy = {valid_accuracy:.4f}')
                if valid_accuracy > best_valid_accuracy:
                    best_valid_accuracy = valid_accuracy
                    model_filename = (f'nets-{progress:.2f}'
                                      f'-{valid_loss:.4f}'
                                      f'-{valid_accuracy:.4f}.pkl')
                    model_path = os.path.join(args.save_dir, model_filename)
                    torch.save(model.state_dict(), model_path)
                    logging.info(f'Saved the new best nets to {model_path}')
                if progress > args.max_epoch:
                    break