def main(args, logger):
    """Train a BiLSTM tagger with early stopping and return the best model.

    Seeds RNGs, builds the dataset and model from ``args``, then runs up to
    ``args.epochs`` training epochs, evaluating on the validation batches
    after each one.  Training stops early when the loss becomes NaN, when
    ``args.max_bad_epochs`` consecutive epochs fail to improve, or on
    KeyboardInterrupt (which returns the best model found so far).

    Args:
        args: parsed command-line namespace; fields read here include
            seed, cuda, data, batch_size, wdim, cdim, hdim, dropout,
            layers, nochar, loss, init, lr, epochs, clip, check_interval,
            and max_bad_epochs.
        logger: object exposing ``log(msg, newline=...)``.

    Returns:
        (best_model, best_perf): a deep copy of the best-performing model
        and its validation performance -- span F1 when the tag set contains
        'O', token accuracy otherwise.
    """
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    device = torch.device('cuda' if args.cuda else 'cpu')

    dat = TaggingDataset(args.data, args.batch_size, device)
    dat.log(logger)
    logger.log(str(args))

    model = BiLSTMTagger(len(dat.word2x), len(dat.tag2y), len(dat.char2c),
                         args.wdim, args.cdim, args.hdim, args.dropout,
                         args.layers, args.nochar, args.loss,
                         args.init).to(device)
    model.apply(get_init_weights(args.init))
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Track the best validation score; start from a copy so we always have
    # something to return even if no epoch completes.
    best_model = copy.deepcopy(model)
    best_perf = float('-inf')
    bad_epochs = 0

    try:
        for ep in range(1, args.epochs + 1):
            random.shuffle(dat.batches_train)
            output = model.do_epoch(ep, dat.batches_train, args.clip, optim,
                                    logger=logger,
                                    check_interval=args.check_interval)
            if math.isnan(output['loss']):
                break  # training diverged; keep the best model so far

            # NOTE(review): no explicit model.eval() here -- presumably
            # model.evaluate switches off dropout itself; confirm in
            # BiLSTMTagger.evaluate.
            with torch.no_grad():
                eval_result = model.evaluate(dat.batches_val, dat.tag2y)

            # An 'O' tag signals a span-labeling task: report overall F1;
            # otherwise report token accuracy.
            perf = eval_result['acc'] if 'O' not in dat.tag2y else \
                   eval_result['f1_<all>']

            logger.log('Epoch {:3d} | '.format(ep) +
                       ' '.join(['{:s} {:8.3f} | '.format(key, output[key])
                                 for key in output]) +
                       ' val perf {:8.3f}'.format(perf), newline=False)

            if perf > best_perf:
                best_perf = perf
                bad_epochs = 0
                logger.log('\t*Updating best model*')
                best_model.load_state_dict(model.state_dict())
            else:
                bad_epochs += 1
                logger.log('\tBad epoch %d' % bad_epochs)

            if bad_epochs >= args.max_bad_epochs:
                break  # early stopping: patience exhausted

    except KeyboardInterrupt:
        logger.log('-' * 89)
        logger.log('Exiting from training early')

    return best_model, best_perf