import pytest

# Import path assumed from flambe's standard layout
from flambe.dataset import TabularDataset


@pytest.fixture
def train_dataset():
    """Dummy dataset from file"""
    return TabularDataset.from_path('tests/data/dummy_tabular/train.csv', sep=',')


@pytest.fixture
def dir_dataset():
    """Dummy dataset from directory"""
    return TabularDataset.from_path('tests/data/dummy_tabular', sep=',')


@pytest.fixture
def full_dataset():
    """Dummy dataset from train and validation files"""
    return TabularDataset.from_path(train_path='tests/data/dummy_tabular/train.csv',
                                    val_path='tests/data/dummy_tabular/val.csv',
                                    sep=',')


@pytest.fixture
def train_dataset_reversed():
    """Dummy dataset from file, with the columns in reversed order"""
    return TabularDataset.from_path('tests/data/dummy_tabular/train.csv',
                                    sep=',',
                                    columns=['label', 'text'])


@pytest.fixture
def train_dataset_no_header():
    """Dummy dataset from a headerless file"""
    return TabularDataset.from_path('tests/data/no_header_dataset.csv',
                                    sep=',',
                                    header=None)
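
# A minimal sketch of a test consuming the fixtures above, assuming pytest
# collects them as fixtures and that the train/val splits support len().
# The test name is illustrative, not taken from the original suite.
def test_full_dataset_has_both_splits(full_dataset):
    """from_path with explicit train_path and val_path populates both splits."""
    assert len(full_dataset.train) > 0
    assert len(full_dataset.val) > 0
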
import os
from typing import Dict

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter

# flambe import paths assumed from its standard layout
from flambe.dataset import TabularDataset
from flambe.field import LabelField, TextField
from flambe.sampler import BaseSampler, EpisodicSampler

# PrototypicalTextClassifier is defined elsewhere in this project; the
# module path here is an assumption.
from model import PrototypicalTextClassifier


def train(args):
    """Run training."""
    global_step = 0
    best_metric = None
    best_model: Dict[str, torch.Tensor] = dict()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    writer = SummaryWriter(log_dir=args.output_dir)

    # We use flambe to do the data preprocessing
    # More info at https://flambe.ai
    print("Performing preprocessing (possibly downloading embeddings).")
    embeddings = args.embeddings if args.use_pretrained_embeddings else None
    text_field = TextField(lower=args.lowercase,
                           embeddings=embeddings,
                           embeddings_format='gensim')
    label_field = LabelField()
    transforms = {'text': text_field, 'label': label_field}
    dataset = TabularDataset.from_path(args.train_path,
                                       args.val_path,
                                       sep=',' if args.file_type == 'csv' else '\t',
                                       transform=transforms)

    # Create samplers
    train_sampler = EpisodicSampler(dataset.train,
                                    n_support=args.n_support,
                                    n_query=args.n_query,
                                    n_episodes=args.n_episodes,
                                    n_classes=args.n_classes)
    # The train_eval_sampler is used to compute prototypes over the full dataset
    train_eval_sampler = BaseSampler(dataset.train, batch_size=args.eval_batch_size)
    val_sampler = BaseSampler(dataset.val, batch_size=args.eval_batch_size)

    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    # Build model, criterion and optimizer
    model = PrototypicalTextClassifier(
        vocab_size=dataset.text.vocab_size,
        distance=args.distance,
        embedding_dim=args.embedding_dim,
        pretrained_embeddings=dataset.text.embedding_matrix,
        rnn_type='sru',
        n_layers=args.n_layers,
        hidden_dim=args.hidden_dim,
        freeze_pretrained_embeddings=True)
    # Move the model to the target device so it matches the batches below
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    parameters = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(parameters, lr=args.learning_rate)

    print("Beginning training.")
    for epoch in range(args.num_epochs):

        ######################
        #       TRAIN        #
        ######################
        print(f'Epoch: {epoch}')

        model.train()
        with torch.enable_grad():
            for batch in train_sampler:
                # Zero the gradients and clear the accumulated loss
                optimizer.zero_grad()

                # Move to device
                batch = tuple(t.to(device) for t in batch)
                query, query_label, support, support_label = batch

                # Compute loss
                pred = model(query, support, support_label)
                loss = loss_fn(pred, query_label)
                loss.backward()

                # Clip gradients if necessary
                if args.max_grad_norm is not None:
                    clip_grad_norm_(model.parameters(), args.max_grad_norm)
                writer.add_scalar('Training/Loss', loss.item(), global_step)

                # Optimize
                optimizer.step()
                global_step += 1

            # Zero the gradients when exiting a train step
            optimizer.zero_grad()

        #########################
        #       EVALUATE        #
        #########################
        model.eval()
        with torch.no_grad():
            # First compute prototypes over the training data
            encodings, labels = [], []
            for text, label in train_eval_sampler:
                # Inputs must live on the same device as the model
                text = text.to(device)
                padding_mask = (text != model.padding_idx).byte()
                text_embeddings = model.embedding_dropout(model.embedding(text))
                text_encoding = model.encoder(text_embeddings,
                                              padding_mask=padding_mask)
                labels.append(label.cpu())
                encodings.append(text_encoding.cpu())

            # Compute prototypes
            encodings = torch.cat(encodings, dim=0)
            labels = torch.cat(labels, dim=0)
            prototypes = model.compute_prototypes(encodings, labels).to(device)

            _preds, _targets = [], []
            for batch in val_sampler:
                # Move to device
                source, target = tuple(t.to(device) for t in batch)
                pred = model(source, prototypes=prototypes)
                _preds.append(pred.cpu())
                _targets.append(target.cpu())

            preds = torch.cat(_preds, dim=0)
            targets = torch.cat(_targets, dim=0)
            val_loss = loss_fn(preds, targets).item()
            # Accuracy over the full validation set, not just the last batch
            val_metric = (preds.argmax(dim=1) == targets).float().mean().item()

        # Update best model
        if best_metric is None or val_metric > best_metric:
            best_metric = val_metric
            best_model_state = model.state_dict()
            for k, t in best_model_state.items():
                best_model_state[k] = t.cpu().detach()
            best_model = best_model_state

        # Log metrics
        print(f'Validation loss: {val_loss}')
        print(f'Validation accuracy: {val_metric}')
        writer.add_scalar('Validation/Loss', val_loss, epoch)
        writer.add_scalar('Validation/Accuracy', val_metric, epoch)

    # Save the best model
    print("Finished training.")
    torch.save(best_model, os.path.join(args.output_dir, 'model.pt'))
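
# For reference, a minimal standalone sketch of the prototype computation used
# above. In a prototypical network, each class prototype is the mean of that
# class's encodings; the actual compute_prototypes method on
# PrototypicalTextClassifier may differ in details such as class ordering.
def compute_prototypes_sketch(encodings: torch.Tensor,
                              labels: torch.Tensor) -> torch.Tensor:
    """Return an (n_classes, hidden_dim) tensor of per-class mean encodings."""
    classes = labels.unique(sorted=True)
    return torch.stack([encodings[labels == c].mean(dim=0) for c in classes])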
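
# A hypothetical entry point wiring up the namespace that train() reads.
# Every flag name and default below is illustrative: it is derived only from
# the attributes accessed in train() above, not from the original script.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Prototypical text classifier training')
    # Data
    parser.add_argument('--train_path', required=True)
    parser.add_argument('--val_path', required=True)
    parser.add_argument('--file_type', choices=['csv', 'tsv'], default='csv')
    parser.add_argument('--output_dir', default='output')
    # Embeddings
    parser.add_argument('--use_pretrained_embeddings', action='store_true')
    parser.add_argument('--embeddings', default=None)
    parser.add_argument('--lowercase', action='store_true')
    parser.add_argument('--embedding_dim', type=int, default=300)
    # Episodes
    parser.add_argument('--n_support', type=int, default=5)
    parser.add_argument('--n_query', type=int, default=5)
    parser.add_argument('--n_episodes', type=int, default=100)
    parser.add_argument('--n_classes', type=int, default=5)
    # Model
    parser.add_argument('--distance', default='euclidean')
    parser.add_argument('--n_layers', type=int, default=2)
    parser.add_argument('--hidden_dim', type=int, default=256)
    # Optimization
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--max_grad_norm', type=float, default=None)
    parser.add_argument('--eval_batch_size', type=int, default=32)
    parser.add_argument('--device', default=None)

    train(parser.parse_args())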