# Module-level imports assumed by the functions below.  `basic` (the custom
# BucketIterator module), `load_examples`, `split_examples`, `PAD`, and
# `SSTModel` are project-local and defined elsewhere in this repository.
import logging
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import tqdm
from tensorboardX import SummaryWriter  # or: from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils import clip_grad_norm_
from torch.optim import lr_scheduler
from torchtext import data, datasets
from torchtext.data import Dataset


def build_iters(args):
    # Fields: token sequence (txt), integer-encoded tree structure (tree,
    # use_vocab=False so the values are kept as-is) and sentiment label (lbl).
    TXT = data.Field(lower=args.lower, include_lengths=True, batch_first=True)
    LBL = data.Field(sequential=False, unk_token=None)
    TREE = data.Field(sequential=True, use_vocab=False, pad_token=0)

    ftrain = 'data/sst/sst/trees/train.txt'
    fvalid = 'data/sst/sst/trees/dev.txt'
    ftest = 'data/sst/sst/trees/test.txt'

    # Subtrees are only used as extra training examples.
    examples_train, len_ave = load_examples(ftrain, subtrees=True)
    examples_valid, _ = load_examples(fvalid, subtrees=False)
    examples_test, _ = load_examples(ftest, subtrees=False)

    fields = [('txt', TXT), ('tree', TREE), ('lbl', LBL)]
    train = Dataset(examples_train, fields=fields)
    TXT.build_vocab(train, vectors=args.pretrained)
    LBL.build_vocab(train)
    valid = Dataset(examples_valid, fields=fields)
    test = Dataset(examples_test, fields=fields)

    def batch_size_fn(new_example, current_count, ebsz):
        # Each example counts as (len / len_ave) ** 0.3 towards the effective
        # batch size, so batches of long sentences contain fewer examples.
        return ebsz + (len(new_example.txt) / len_ave) ** 0.3

    device = torch.device(args.gpu if args.gpu != -1 else 'cpu')
    train_iter = basic.BucketIterator(train,
                                      batch_size=args.batch_size,
                                      sort=True, shuffle=True, repeat=False,
                                      sort_key=lambda x: len(x.txt),
                                      batch_size_fn=batch_size_fn,
                                      device=device)
    valid_iter = basic.BucketIterator(valid,
                                      batch_size=args.batch_size,
                                      sort=True, shuffle=True, repeat=False,
                                      sort_key=lambda x: len(x.txt),
                                      batch_size_fn=batch_size_fn,
                                      device=device)
    test_iter = basic.BucketIterator(test,
                                     batch_size=args.batch_size,
                                     sort=True, shuffle=True, repeat=False,
                                     sort_key=lambda x: len(x.txt),
                                     batch_size_fn=batch_size_fn,
                                     device=device)
    return train_iter, valid_iter, test_iter, (TXT, TREE, LBL)
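# Usage note (not from the original code): with `batch_size_fn` above, the
# BucketIterator treats `batch_size` as a soft budget of "effective" examples
# rather than a fixed count.  The standalone sketch below mimics how such a
# budget closes batches; `len_ave`, `budget` and the lengths passed in are
# made-up illustration values, not values taken from this repository.
def _demo_effective_batches(lengths, len_ave=20.0, budget=64):
    batches, current, ebsz = [], [], 0.0
    for n in lengths:
        ebsz += (n / len_ave) ** 0.3   # long sequences cost slightly more
        current.append(n)
        if ebsz >= budget:             # close the batch once the budget is met
            batches.append(current)
            current, ebsz = [], 0.0
    if current:
        batches.append(current)
    return batches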
def build_iters(args):
    EXPR = torchtext.data.Field(sequential=True, use_vocab=True,
                                batch_first=True, include_lengths=True,
                                pad_token=PAD, eos_token=None)
    VAL = torchtext.data.Field(sequential=False)

    ftrain = 'data/train_d20s.tsv'
    fvalid = 'data/test_d20s.tsv'
    # ftest = 'data/test_d20s.tsv'

    examples_train, len_ave = load_examples(ftrain)
    examples_valid, _ = load_examples(fvalid)

    train = Dataset(examples_train, fields=[('expr', EXPR), ('val', VAL)])
    EXPR.build_vocab(train)
    VAL.build_vocab(train)
    valid = Dataset(examples_valid, fields=[('expr', EXPR), ('val', VAL)])

    device = torch.device(args.gpu if args.gpu != -1 else 'cpu')

    def batch_size_fn(new_example, current_count, ebsz):
        return ebsz + (len(new_example.expr) / len_ave) ** 0.3

    splits, split_avels = split_examples(examples_train)
    train_iters = {srange: None for srange in splits}
    for srange, split in splits.items():
        train_split = Dataset(split, fields=[('expr', EXPR), ('val', VAL)])
        data_iter = basic.BucketIterator(train_split,
                                         batch_size=args.bsz,
                                         sort=True, shuffle=True, repeat=False,
                                         sort_key=lambda x: len(x.expr),
                                         batch_size_fn=batch_size_fn,
                                         device=device)
        train_iters[srange] = data_iter

    valid_iter = basic.BucketIterator(valid,
                                      batch_size=args.bsz,
                                      sort=True, shuffle=True, repeat=False,
                                      sort_key=lambda x: len(x.expr),
                                      batch_size_fn=batch_size_fn,
                                      device=device)
    return train_iters, valid_iter, EXPR, VAL
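# Consumption sketch (hypothetical helper, not part of the original code): the
# expression-data build_iters returns `train_iters` as a dict keyed by
# sequence-length range, so a simple length-based curriculum can walk the
# ranges in order and train on each bucket.  `step_fn` is an assumed callback
# standing in for the real training step.
def _demo_curriculum_epoch(train_iters, step_fn):
    for srange in sorted(train_iters):
        for batch in train_iters[srange]:
            expr, lengths = batch.expr   # padded token ids and true lengths
            values = batch.val           # target value for each expression
            step_fn(expr, lengths, values)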
def train(args):
    text_field = data.Field(lower=args.lower, include_lengths=True,
                            batch_first=True)
    label_field = data.Field(sequential=False)

    # In the binary setting, drop the neutral examples.
    filter_pred = None
    if not args.fine_grained:
        filter_pred = lambda ex: ex.label != 'neutral'
    dataset_splits = datasets.SST.splits(
        root='./data/sst', text_field=text_field, label_field=label_field,
        fine_grained=args.fine_grained, train_subtrees=True,
        filter_pred=filter_pred)

    # Average training-sentence length, used by batch_size_fn below.
    lens = []
    for e in dataset_splits[0].examples:
        lens.append(len(e.text))
    len_ave = np.mean(lens)

    text_field.build_vocab(*dataset_splits, vectors=args.pretrained)
    label_field.build_vocab(*dataset_splits)

    logging.info(f'Initialize with pretrained vectors: {args.pretrained}')
    logging.info(f'Number of classes: {len(label_field.vocab)}')

    def batch_size_fn(new_example, current_count, ebsz):
        return ebsz + (len(new_example.text) / len_ave) ** 0.3

    device = torch.device(args.gpu if args.gpu != -1 else 'cpu')
    train_loader = basic.BucketIterator(dataset_splits[0],
                                        batch_size=args.batch_size,
                                        sort=True, shuffle=True, repeat=False,
                                        sort_key=lambda x: len(x.text),
                                        batch_size_fn=batch_size_fn,
                                        device=device)
    valid_loader = basic.BucketIterator(dataset_splits[1],
                                        batch_size=args.batch_size,
                                        sort=True, shuffle=True, repeat=False,
                                        sort_key=lambda x: len(x.text),
                                        batch_size_fn=batch_size_fn,
                                        device=device)

    num_classes = len(label_field.vocab)
    model = SSTModel(num_classes=num_classes,
                     num_words=len(text_field.vocab),
                     word_dim=args.word_dim,
                     hidden_dim=args.hidden_dim,
                     clf_hidden_dim=args.clf_hidden_dim,
                     clf_num_layers=args.clf_num_layers,
                     use_leaf_rnn=args.leaf_rnn,
                     bidirectional=args.bidirectional,
                     intra_attention=args.intra_attention,
                     use_batchnorm=args.batchnorm,
                     dropout_prob=args.dropout)
    if args.pretrained:
        model.word_embedding.weight.data.set_(text_field.vocab.vectors)
    if args.fix_word_embedding:
        logging.info('Will not update word embeddings')
        model.word_embedding.weight.requires_grad = False
    logging.info(f'Using device {args.gpu}')
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'adam':
        optimizer_class = optim.Adam
    elif args.optimizer == 'adagrad':
        optimizer_class = optim.Adagrad
    elif args.optimizer == 'adadelta':
        optimizer_class = optim.Adadelta
    else:
        raise NotImplementedError
    optimizer = optimizer_class(params=params, weight_decay=args.l2reg)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='max',
                                               factor=0.5,
                                               patience=20 * args.halve_lr_every,
                                               verbose=True)
    criterion = nn.CrossEntropyLoss()

    train_summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_dir, 'log', 'train'))
    valid_summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_dir, 'log', 'valid'))

    def run_iter(batch, is_training):
        model.train(is_training)
        words, length = batch.text
        label = batch.label
        logits = model(words=words, length=length)
        label_pred = logits.max(1)[1]
        accuracy = torch.eq(label, label_pred).float().mean()
        loss = criterion(input=logits, target=label)
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(parameters=params, max_norm=5)
            optimizer.step()
        return loss, accuracy

    def add_scalar_summary(summary_writer, name, value, step):
        if torch.is_tensor(value):
            value = value.item()
        summary_writer.add_scalar(tag=name, scalar_value=value,
                                  global_step=step)

    num_train_batches = len(train_loader)
    validate_every = num_train_batches // 20
    best_valid_accuracy = 0
    iter_count = 0
    for epoch in range(args.max_epoch):
        for batch_iter, train_batch in enumerate(tqdm.tqdm(train_loader)):
            train_loss, train_accuracy = run_iter(batch=train_batch,
                                                  is_training=True)
            iter_count += 1
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='loss', value=train_loss, step=iter_count)
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='accuracy', value=train_accuracy,
                               step=iter_count)

            if (batch_iter + 1) % validate_every == 0:
                # Evaluate on the full validation set.
                valid_loss_sum = valid_accuracy_sum = 0
                num_valid_batches = len(valid_loader)
                for valid_batch in valid_loader:
                    valid_loss, valid_accuracy = run_iter(batch=valid_batch,
                                                          is_training=False)
                    valid_loss_sum += valid_loss.item()
                    valid_accuracy_sum += valid_accuracy.item()
                valid_loss = valid_loss_sum / num_valid_batches
                valid_accuracy = valid_accuracy_sum / num_valid_batches
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='loss', value=valid_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='accuracy', value=valid_accuracy,
                                   step=iter_count)
                scheduler.step(valid_accuracy)
                progress = train_loader.iterations / len(train_loader)
                logging.info(f'Epoch {epoch}: '
                             f'valid loss = {valid_loss:.4f}, '
                             f'valid accuracy = {valid_accuracy:.4f}')
                if valid_accuracy > best_valid_accuracy:
                    best_valid_accuracy = valid_accuracy
                    model_filename = (f'nets-{progress:.2f}'
                                      f'-{valid_loss:.4f}'
                                      f'-{valid_accuracy:.4f}.pkl')
                    model_path = os.path.join(args.save_dir, model_filename)
                    torch.save(model.state_dict(), model_path)
                    print(f'Saved the new best nets to {model_path}')
                if progress > args.max_epoch:
                    break
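# Reload sketch (hypothetical helper, not part of the original code): train()
# saves state_dicts under args.save_dir, so restoring a checkpoint means
# rebuilding SSTModel with the same constructor arguments and loading the
# weights.  `model_path` is a placeholder for one of the saved .pkl files.
def _demo_load_checkpoint(model_path, text_field, label_field, args):
    model = SSTModel(num_classes=len(label_field.vocab),
                     num_words=len(text_field.vocab),
                     word_dim=args.word_dim,
                     hidden_dim=args.hidden_dim,
                     clf_hidden_dim=args.clf_hidden_dim,
                     clf_num_layers=args.clf_num_layers,
                     use_leaf_rnn=args.leaf_rnn,
                     bidirectional=args.bidirectional,
                     intra_attention=args.intra_attention,
                     use_batchnorm=args.batchnorm,
                     dropout_prob=args.dropout)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()  # switch to evaluation mode before inference
    return model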