def __init__(self, args, vocab, emb_matrix=None):
    super().__init__()

    self.vocab = vocab
    self.args = args
    self.unsaved_modules = []

    def add_unsaved_module(name, module):
        self.unsaved_modules += [name]
        setattr(self, name, module)

    # input layers
    input_size = 0
    if self.args['word_emb_dim'] > 0:
        self.word_emb = nn.Embedding(len(self.vocab['word']), self.args['word_emb_dim'], PAD_ID)
        # load pretrained embeddings if specified
        if emb_matrix is not None:
            self.init_emb(emb_matrix)
        if not self.args.get('emb_finetune', True):
            self.word_emb.weight.detach_()
        input_size += self.args['word_emb_dim']

    if self.args['char'] and self.args['char_emb_dim'] > 0:
        if self.args['charlm']:
            if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
                raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
            if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
                raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
            # frozen pretrained character LMs, registered as unsaved modules
            add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False))
            add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False))
            input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
        else:
            self.charmodel = CharacterModel(args, vocab, bidirectional=True, attention=False)
            input_size += self.args['char_hidden_dim'] * 2

    # optionally add an input transformation layer
    if self.args.get('input_transform', False):
        self.input_transform = nn.Linear(input_size, input_size)
    else:
        self.input_transform = None

    # recurrent layers
    self.taggerlstm = PackedLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'],
                                 batch_first=True, bidirectional=True,
                                 dropout=0 if self.args['num_layers'] == 1 else self.args['dropout'])
    # self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
    self.drop_replacement = None
    self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)
    self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)

    # tag classifier
    num_tag = len(self.vocab['tag'])
    self.tag_clf = nn.Linear(self.args['hidden_dim'] * 2, num_tag)
    self.tag_clf.bias.data.zero_()

    # criterion
    self.crit = CRFLoss(num_tag)

    self.drop = nn.Dropout(args['dropout'])
    self.worddrop = WordDropout(args['word_dropout'])
    self.lockeddrop = LockedDropout(args['locked_dropout'])
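# add_unsaved_module() above registers a submodule under `name` while also recording
# that name in self.unsaved_modules, so the frozen pretrained character LMs can be
# skipped when the model checkpoint is written.  The sketch below is a minimal
# illustration of how such a skip list could be applied at save time;
# save_without_unsaved_modules() and its argument names are illustrative assumptions,
# not this project's actual save routine.
def save_without_unsaved_modules(model, path):
    state = model.state_dict()
    # drop all parameters whose top-level attribute name is in the skip list
    kept = {k: v for k, v in state.items()
            if k.split('.', 1)[0] not in model.unsaved_modules}
    torch.save(kept, path)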
def train(args):
    model_file = args['save_dir'] + '/' + args['save_name'] if args['save_name'] is not None \
        else '{}/{}_{}_charlm.pt'.format(args['save_dir'], args['shorthand'], args['direction'])
    vocab_file = args['save_dir'] + '/' + args['vocab_save_name'] if args['vocab_save_name'] is not None \
        else '{}/{}_vocab.pt'.format(args['save_dir'], args['shorthand'])

    if os.path.exists(vocab_file):
        logger.info('Loading existing vocab file')
        vocab = {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage))}
    else:
        logger.info('Building and saving vocab')
        vocab = {'char': build_vocab(args['train_file'] if args['train_dir'] is None else args['train_dir'], cutoff=args['cutoff'])}
        torch.save(vocab['char'].state_dict(), vocab_file)
    logger.info("Training model with vocab size: {}".format(len(vocab['char'])))

    model = CharacterLanguageModel(args, vocab, is_forward_lm=True if args['direction'] == 'forward' else False)
    if args['cuda']:
        model = model.cuda()
    params = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=args['anneal'], patience=args['patience'])

    writer = None
    if args['summary']:
        from torch.utils.tensorboard import SummaryWriter
        summary_dir = '{}/{}_summary'.format(args['save_dir'], args['save_name']) if args['save_name'] is not None \
            else '{}/{}_{}_charlm_summary'.format(args['save_dir'], args['shorthand'], args['direction'])
        writer = SummaryWriter(log_dir=summary_dir)

    # evaluate model within an epoch if eval_steps is set
    eval_within_epoch = False
    if args['eval_steps'] > 0:
        eval_within_epoch = True

    best_loss = None
    global_step = 0
    for epoch in range(1, args['epochs'] + 1):
        # load train data from train_dir if not empty, otherwise load from file
        if args['train_dir'] is not None:
            train_path = args['train_dir']
        else:
            train_path = args['train_file']
        train_data = load_data(train_path, vocab, args['direction'])
        dev_data = load_file(args['eval_file'], vocab, args['direction'])  # dev must be a single file

        # run over entire training set
        for data_chunk in train_data:
            batches = batchify(data_chunk, args['batch_size'])
            hidden = None
            total_loss = 0.0
            total_batches = math.ceil((batches.size(1) - 1) / args['bptt_size'])
            iteration, i = 0, 0
            # over the data chunk
            while i < batches.size(1) - 1 - 1:
                model.train()
                global_step += 1
                start_time = time.time()
                # vary the BPTT window (occasionally halved) so batch boundaries shift between passes
                bptt = args['bptt_size'] if np.random.random() < 0.95 else args['bptt_size'] / 2.
                # prevent excessively small or negative sequence lengths
                seq_len = max(5, int(np.random.normal(bptt, 5)))
                # prevent very large sequence length, must be <= 1.2 x bptt
                seq_len = min(seq_len, int(args['bptt_size'] * 1.2))
                data, target = get_batch(batches, i, seq_len)
                lens = [data.size(1) for i in range(data.size(0))]
                if args['cuda']:
                    data = data.cuda()
                    target = target.cuda()

                optimizer.zero_grad()
                output, hidden, decoded = model.forward(data, lens, hidden)
                loss = criterion(decoded.view(-1, len(vocab['char'])), target)
                total_loss += loss.data.item()
                loss.backward()

                torch.nn.utils.clip_grad_norm_(params, args['max_grad_norm'])
                optimizer.step()

                hidden = repackage_hidden(hidden)

                if (iteration + 1) % args['report_steps'] == 0:
                    cur_loss = total_loss / args['report_steps']
                    elapsed = time.time() - start_time
                    logger.info("| epoch {:5d} | {:5d}/{:5d} batches | sec/batch {:.6f} | loss {:5.2f} | ppl {:8.2f}".format(
                        epoch,
                        iteration + 1,
                        total_batches,
                        elapsed / args['report_steps'],
                        cur_loss,
                        math.exp(cur_loss),
                    ))
                    total_loss = 0.0

                iteration += 1
                i += seq_len

                # evaluate if necessary
                if eval_within_epoch and global_step % args['eval_steps'] == 0:
                    _, _, best_loss = evaluate_and_save(args, vocab, dev_data, model, criterion, scheduler, best_loss,
                                                        global_step, model_file, writer)

        # if eval_steps isn't provided, run evaluation after each epoch
        if not eval_within_epoch:
            # use epoch in place of global_step for logging
            _, _, best_loss = evaluate_and_save(args, vocab, dev_data, model, criterion, scheduler, best_loss,
                                                epoch, model_file, writer)

    if writer:
        writer.close()
    return
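# repackage_hidden() is called above to cut the autograd graph between BPTT windows so
# gradients do not flow across the whole data chunk.  Its definition is not shown in this
# section; the sketch below is the standard pattern (detach tensors, recurse into tuples),
# not necessarily the exact helper used here.
def repackage_hidden_sketch(h):
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden_sketch(v) for v in h)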
def main():
    args = parse_args()
    seed = utils.set_random_seed(args.seed, args.cuda)
    logger.info("Using random seed: %d" % seed)

    utils.ensure_dir(args.save_dir)

    # TODO: maybe the dataset needs to be in a torch data loader in order to
    # make cuda operations faster
    if args.train:
        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
        logger.info("Using training set: %s" % args.train_file)
        logger.info("Training set has %d labels" % len(dataset_labels(train_set)))
    elif not args.load_name:
        raise ValueError("No model provided and not asked to train a model. This makes no sense")
    else:
        train_set = None

    pretrain = load_pretrain(args)

    if args.charlm:
        if args.charlm_shorthand is None:
            raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
        logger.info('Using pretrained contextualized char embedding')
        charlm_forward_file = '{}/{}_forward_charlm.pt'.format(args.charlm_save_dir, args.charlm_shorthand)
        charlm_backward_file = '{}/{}_backward_charlm.pt'.format(args.charlm_save_dir, args.charlm_shorthand)
        charmodel_forward = CharacterLanguageModel.load(charlm_forward_file, finetune=False)
        charmodel_backward = CharacterLanguageModel.load(charlm_backward_file, finetune=False)
    else:
        charmodel_forward = None
        charmodel_backward = None

    if args.load_name:
        model = cnn_classifier.load(args.load_name, pretrain, charmodel_forward, charmodel_backward)
    else:
        assert train_set is not None
        labels = dataset_labels(train_set)
        extra_vocab = dataset_vocab(train_set)
        model = cnn_classifier.CNNClassifier(pretrain=pretrain,
                                             extra_vocab=extra_vocab,
                                             labels=labels,
                                             charmodel_forward=charmodel_forward,
                                             charmodel_backward=charmodel_backward,
                                             args=args)

    if args.cuda:
        model.cuda()

    logger.info("Filter sizes: %s" % str(model.config.filter_sizes))
    logger.info("Filter channels: %s" % str(model.config.filter_channels))
    logger.info("Intermediate layers: %s" % str(model.config.fc_shapes))

    save_name = args.save_name
    if not save_name:
        save_name = args.base_name + "_" + args.shorthand + "_"
        save_name = save_name + "FS_%s_" % "_".join([str(x) for x in model.config.filter_sizes])
        save_name = save_name + "C_%d_" % model.config.filter_channels
        if model.config.fc_shapes:
            save_name = save_name + "FC_%s_" % "_".join([str(x) for x in model.config.fc_shapes])
        save_name = save_name + "classifier.pt"
    model_file = os.path.join(args.save_dir, save_name)

    if args.train:
        print_args(args)

        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, min_len=None)
        logger.info("Using dev set: %s" % args.dev_file)
        check_labels(model.labels, dev_set)

        train_model(model, model_file, args, train_set, dev_set, model.labels)

    test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
    logger.info("Using test set: %s" % args.test_file)
    check_labels(model.labels, test_set)

    if args.test_remap_labels is None:
        confusion = confusion_dataset(model, test_set)
        logger.info("Confusion matrix:\n{}".format(format_confusion(confusion, model.labels)))
        correct, total = confusion_to_accuracy(confusion)
        logger.info("Macro f1: {}".format(confusion_to_macro_f1(confusion)))
    else:
        correct = score_dataset(model, test_set,
                                remap_labels=args.test_remap_labels,
                                forgive_unmapped_labels=args.forgive_unmapped_labels)
        total = len(test_set)
    logger.info("Test set: %d correct of %d examples. Accuracy: %f" % (correct, total, correct / total))
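# confusion_to_accuracy() is used above but defined elsewhere.  The sketch below shows the
# usual computation -- sum the diagonal of the confusion matrix for correct predictions and
# all cells for the total -- assuming the matrix is a nested dict of counts
# {gold_label: {predicted_label: n}}; the structure actually returned by confusion_dataset()
# may differ.
def confusion_to_accuracy_sketch(confusion):
    correct = sum(confusion[label].get(label, 0) for label in confusion)
    total = sum(n for row in confusion.values() for n in row.values())
    return correct, total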