def val(model, mode='dev'):
    """Evaluate `model` on one held-out split and return (avg_loss, accuracy).

    Args:
        model: classifier called as ``model(x)`` returning per-class logits.
        mode: split name, ``'dev'`` or ``'test'`` (selects the pickle file).

    Returns:
        Tuple ``(val_loss, val_acc)``, both normalized by the number of
        examples in the split.

    Note:
        Reads the module-level ``args`` global (``args.infre``,
        ``args.batch_sz``).
    """
    temp = 'infre.' if args.infre else ''
    # `with` closes the file handle; the original leaked it via a bare open()
    with open('../data/{}.{}pkl'.format(mode, temp), 'rb') as f:
        val_dataset = pickle.load(f)
    val_dataset = ProbingListMaxDataset(val_dataset)
    # NOTE(review): shuffle is unnecessary for evaluation, kept for parity
    dataLoader = DataLoader(val_dataset, batch_size=args.batch_sz,
                            shuffle=True)
    criterion = torch.nn.CrossEntropyLoss()
    val_loss = 0
    val_acc = 0
    model.eval()
    with torch.no_grad():
        for batch in dataLoader:
            x = torch.stack(batch['input'])  # 5 x bz
            y = batch['label']  # bz
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            output = model(x)
            loss = criterion(output, y)
            # accumulates per-batch *mean* losses, then divides by dataset
            # size below — a scaling quirk, preserved from the original
            val_loss += loss.item()
            val_acc += (output.argmax(1) == y).sum().item()
    val_loss = val_loss / len(val_dataset)
    val_acc = val_acc / len(val_dataset)
    return val_loss, val_acc
# create output directories and a per-run file logger
create_dir('../{}/{}/'.format(model_dir, run))
create_dir('../{}/{}/{}'.format(model_dir, run, args.embed))
create_dir(model_save_dir)
logging.basicConfig(
    level=logging.INFO,
    filename='{}/log.txt'.format(model_save_dir),
    datefmt='%Y/%m/%d %H:%M:%S',
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('{}'.format(model_save_dir))

# temp = 'infre.' if args.infre else ''
temp = 'regre.'
# NOTE(review): these `.diff.` datasets are loaded here but train() loads its
# own data internally; they are currently unused by the training call below.
with open('../data/{}.diff.{}pkl'.format('dev', temp), 'rb') as f:
    val_dataset = ProbingListMaxDataset(pickle.load(f))
with open('../data/{}.diff.{}pkl'.format('test', temp), 'rb') as f:
    test_dataset = ProbingListMaxDataset(pickle.load(f))
with open('../data/{}.diff.{}pkl'.format('train', temp), 'rb') as f:
    train_dataset = ProbingListMaxDataset(pickle.load(f))

# small grid search over learning rate (hidden_dim fixed at 100)
for lr in [0.1, 1, 0.01, 5, 10]:
    for hidden_dim in [100]:
        args.lr = lr
        args.hidden_dim = hidden_dim
        # BUG FIX: train() is declared as train(args, logger, model_save_dir);
        # the original call passed six positional arguments and would raise
        # TypeError on the first iteration.
        train(args, logger, model_save_dir)
def train(args, logger, model_save_dir):
    """Train one model configuration and checkpoint the best-on-dev weights.

    Args:
        args: namespace providing ``seed``, ``infre``, ``embed``,
            ``batch_sz``, ``model``, ``hidden_dim``, ``lr`` and ``n_epoch``.
        logger: logger receiving per-epoch train/dev/test metrics.
        model_save_dir: directory the best checkpoint is written into.

    Side effects:
        Saves the state dict with the highest dev accuracy to
        ``model_save_dir/model-<best_dev_test_acc>-<lr>.pt``.
    """
    # make runs reproducible across torch / numpy / python RNGs
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # frequent vs infrequent vocabulary variants of embeddings + train data;
    # `with` closes the handles (original used bare open())
    if args.infre:
        with open('../embed_infre/{}'.format(args.embed), 'rb') as f:
            pretrain_embed = pickle.load(f)
        with open('../data/train.infre.pkl', 'rb') as f:
            train_dataset = pickle.load(f)
    else:
        with open('../embed/{}'.format(args.embed), 'rb') as f:
            pretrain_embed = pickle.load(f)
        with open('../data/train.pkl', 'rb') as f:
            train_dataset = pickle.load(f)

    # embeddings are pickled either as a numpy array or a torch tensor;
    # bare `except:` narrowed to TypeError (what from_numpy raises on a tensor)
    try:
        pretrain_embed = torch.from_numpy(pretrain_embed).float()
    except TypeError:
        pretrain_embed = pretrain_embed.float()

    train_dataset = ProbingListMaxDataset(train_dataset)
    dataLoader = DataLoader(train_dataset, batch_size=args.batch_sz,
                            shuffle=True)

    if args.model == 'BiLSTM':
        model = ListMax(args.hidden_dim, pretrain_embed)
    elif args.model == 'CNN':
        model = CNN(pretrained=pretrain_embed)
    else:
        model = TransformerModel(pretrained=pretrain_embed, nhead=5,
                                 nhid=50, nlayers=2)
        # model = ListMaxTransformer(args.hidden_dim, pretrain_embed)
    if torch.cuda.is_available():
        model.cuda()

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    best_dev_acc = 0
    best_dev_model = None
    best_dev_test_acc = 0
    counter = 0  # epochs since last dev improvement (early-stop counter)
    for epoch in range(1, args.n_epoch + 1):
        train_loss = 0
        train_acc = 0
        model.train()
        iteration = 0
        for batch in dataLoader:
            optimizer.zero_grad()
            x = torch.stack(batch['input'])  # 5 x bz
            y = batch['label']  # bz
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            output = model(x)
            loss = criterion(output, y)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            train_acc += (output.argmax(1) == y).sum().item()
            iteration += 1
        train_loss = train_loss / len(train_dataset)
        train_acc = train_acc / len(train_dataset)

        dev_loss, dev_acc = val(model, mode='dev')
        test_loss, test_acc = val(model, mode='test')
        if dev_acc > best_dev_acc:
            # BUG FIX: state_dict().copy() is a *shallow* dict copy whose
            # tensors keep being updated in place by training, so the "best"
            # snapshot silently became the final-epoch weights. Clone each
            # tensor to take a true snapshot.
            best_dev_model = {k: v.detach().clone()
                              for k, v in model.state_dict().items()}
            best_dev_acc = dev_acc
            best_dev_test_acc = test_acc
            counter = 0
        else:
            counter += 1

        logger.info('TRAIN: epoch:{}-loss:{}-acc:{}'.format(
            epoch, train_loss, train_acc))
        logger.info('DEV: epoch:{}-loss:{}-acc:{}'.format(
            epoch, dev_loss, dev_acc))
        logger.info('TEST: epoch:{}-loss:{}-acc:{}'.format(
            epoch, test_loss, test_acc))
        logger.info('BEST-DEV-ACC: {}, BEST-DEV-TEST-ACC:{}'.format(
            best_dev_acc, best_dev_test_acc))
        # if counter > 30:
        #     break

    # guard: with n_epoch == 0 there is no snapshot — don't save None
    if best_dev_model is not None:
        torch.save(
            best_dev_model,
            model_save_dir + '/model-{}-{}.pt'.format(best_dev_test_acc,
                                                      args.lr))