pw.info('Loading dataset.')
dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb'))
w_map, test_data, range_idx = dataset['w_map'], dataset['test_data'], dataset['range']

train_loader = LargeDataset(args.dataset_folder, range_idx, args.batch_size, args.sequence_length)
test_loader = EvalDataset(test_data, args.batch_size)

pw.info('Building models.')
rnn_map = {'Basic': BasicRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop=args.layer_drop)}
rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate)

cut_off = args.cut_off + [len(w_map) + 1]
if args.label_dim > 0:
    soft_max = AdaptiveSoftmax(args.label_dim, cut_off)
else:
    soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)

lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim=args.label_dim, add_relu=args.add_relu)
lm_model.rand_ini()

pw.info('Building optimizer.')
optim_map = {'Adam': optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta}
if args.lr > 0:
    optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr)
else:
    optimizer = optim_map[args.update](lm_model.parameters())

if args.restore_checkpoint:
    if os.path.isfile(args.restore_checkpoint):
        pw.info("loading checkpoint: '{}'".format(args.restore_checkpoint))
        model_file = wrapper.restore_checkpoint(args.restore_checkpoint)['model']
        lm_model.load_state_dict(model_file, False)
    else:
        pw.info("no checkpoint found at: '{}'".format(args.restore_checkpoint))
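# Note: AdaptiveSoftmax above is the project's own class; the cut_off list appends
# len(w_map) + 1 so that the last cluster covers the tail of the vocabulary. A minimal,
# stand-alone sketch of the same idea with PyTorch's built-in nn.AdaptiveLogSoftmaxWithLoss
# (the sizes below are made up for illustration, not taken from the config):
import torch
import torch.nn as nn

hid_dim, vocab_size = 512, 100000             # assumed hidden size and vocabulary size
cutoffs = [2000, 20000]                       # frequent words stay in the full-rank head cluster
adaptive = nn.AdaptiveLogSoftmaxWithLoss(hid_dim, vocab_size, cutoffs=cutoffs)

hidden = torch.randn(32, hid_dim)             # stand-in for RNN outputs at 32 positions
targets = torch.randint(0, vocab_size, (32,)) # gold next-word indices
out = adaptive(hidden, targets)               # out.output: target log-probs, out.loss: mean NLL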
device = torch.device("cuda:" + str(gpu_index) if gpu_index >= 0 else "cpu")
if gpu_index >= 0:
    torch.cuda.set_device(gpu_index)

pw.info('Loading data from {}.'.format(args.corpus))
dataset = pickle.load(open(args.corpus, 'rb'))
name_list = ['flm_map', 'blm_map', 'gw_map', 'c_map', 'y_map', 'emb_array', 'train_data', 'test_data', 'dev_data']
flm_map, blm_map, gw_map, c_map, y_map, emb_array, train_data, test_data, dev_data = [dataset[tup] for tup in name_list]

pw.info('Building language models and sequence labeling models.')
rnn_map = {'Basic': BasicRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop=0)}
flm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
blm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
flm_model = LM(flm_rnn_layer, None, len(flm_map), args.lm_word_dim, args.lm_droprate, label_dim=args.lm_label_dim)
blm_model = LM(blm_rnn_layer, None, len(blm_map), args.lm_word_dim, args.lm_droprate, label_dim=args.lm_label_dim)
flm_model_seq = SparseSeqLM(flm_model, False, args.lm_droprate, False)
blm_model_seq = SparseSeqLM(blm_model, True, args.lm_droprate, False)

SL_map = {'vanilla': Vanilla_SeqLabel, 'lm-aug': SeqLabel}
seq_model = SL_map[args.seq_model](flm_model_seq, blm_model_seq, len(c_map), args.seq_c_dim, args.seq_c_hid, args.seq_c_layer, len(gw_map), args.seq_w_dim, args.seq_w_hid, args.seq_w_layer, len(y_map), args.seq_droprate, unit=args.seq_rnn_unit)

pw.info('Loading pre-trained models from {}.'.format(args.load_seq))
seq_file = wrapper.restore_checkpoint(args.load_seq)['model']
seq_model.load_state_dict(seq_file)
seq_model.to(device)

crit = CRFLoss(y_map)
decoder = CRFDecode(y_map)
evaluator = eval_wc(decoder, 'f1')
def main():
    global best_ppl

    print('loading dataset')
    dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb'))
    w_map, test_data, range_idx = dataset['w_map'], dataset['test_data'], dataset['range']
    cut_off = args.cut_off + [len(w_map) + 1]

    train_loader = LargeDataset(args.dataset_folder, range_idx, args.batch_size, args.sequence_length)
    test_loader = EvalDataset(test_data, args.batch_size)

    print('building model')
    rnn_map = {'Basic': BasicRNN, 'DDNet': DDRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop=args.layer_drop)}
    rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate)

    if args.label_dim > 0:
        soft_max = AdaptiveSoftmax(args.label_dim, cut_off)
    else:
        soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)

    lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim=args.label_dim, add_relu=args.add_relu)
    lm_model.rand_ini()
    # lm_model.cuda()

    # set up optimizers
    optim_map = {'Adam': optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta,
                 'SGD': functools.partial(optim.SGD, momentum=0.9),
                 'LSRAdam': LSRAdam, 'LSAdam': LSAdam, 'AdamW': AdamW, 'RAdam': RAdam,
                 'SRAdamW': SRAdamW, 'SRRAdam': SRRAdam}

    if args.update.lower() == 'lsradam' or args.update.lower() == 'lsadam':
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr * ((1. + 4. * args.sigma) ** 0.25),
                                           betas=(args.beta1, args.beta2), weight_decay=args.weight_decay, sigma=args.sigma)
    elif args.update.lower() == 'radam':
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2),
                                           weight_decay=args.weight_decay)
    elif args.update.lower() == 'adamw':
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2),
                                           weight_decay=args.weight_decay, warmup=args.warmup)
    elif args.update.lower() == 'sradamw':
        iter_count = 1
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2),
                                           iter_count=iter_count, weight_decay=args.weight_decay,
                                           warmup=args.warmup, restarting_iter=args.restart_schedule[0])
    elif args.update.lower() == 'srradam':
        # NOTE: need to double-check this
        iter_count = 1
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2),
                                           iter_count=iter_count, weight_decay=args.weight_decay,
                                           warmup=args.warmup, restarting_iter=args.restart_schedule[0])
    else:
        if args.lr > 0:
            optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr)
        else:
            optimizer = optim_map[args.update](lm_model.parameters())

    # Resume
    title = 'onebillionword-' + args.rnn_layer
    logger = Logger(os.path.join(args.checkpath, 'log.txt'), title=title)
    logger.set_names(['Learning Rate', 'Train Loss', 'Train PPL', 'Valid PPL'])

    if args.load_checkpoint:
        if os.path.isfile(args.load_checkpoint):
            print("loading checkpoint: '{}'".format(args.load_checkpoint))
            checkpoint_file = torch.load(args.load_checkpoint, map_location=lambda storage, loc: storage)
            lm_model.load_state_dict(checkpoint_file['lm_model'], False)
            optimizer.load_state_dict(checkpoint_file['opt'])
        else:
            print("no checkpoint found at: '{}'".format(args.load_checkpoint))

    test_lm = nn.NLLLoss()
    test_lm.cuda()
    lm_model.cuda()

    batch_index = 0
    epoch_loss = 0
    full_epoch_loss = 0
    best_train_ppl = float('inf')
    cur_lr = args.lr
    schedule_index = 1

    try:
        for indexs in range(args.epoch):
            print('#' * 89)
            print('Start: {}'.format(indexs))

            if args.optimizer.lower() == 'sradamw':
                if indexs in args.schedule:
                    optimizer = SRAdamW(lm_model.parameters(), lr=args.lr * (args.gamma ** schedule_index),
                                        betas=(args.beta1, args.beta2), iter_count=iter_count,
                                        weight_decay=args.weight_decay, warmup=0,
                                        restarting_iter=args.restart_schedule[schedule_index])
                    schedule_index += 1
            elif args.optimizer.lower() == 'srradam':
                if indexs in args.schedule:
                    optimizer = SRRAdam(lm_model.parameters(), lr=args.lr * (args.gamma ** schedule_index),
                                        betas=(args.beta1, args.beta2), iter_count=iter_count,
                                        weight_decay=args.weight_decay, warmup=0,
                                        restarting_iter=args.restart_schedule[schedule_index])
                    schedule_index += 1
            else:
                adjust_learning_rate(optimizer, indexs)

            logger.file.write('\nEpoch: [%d | %d] LR: %f' % (indexs + 1, args.epoch, state['lr']))

            iterator = train_loader.get_tqdm()
            full_epoch_loss = 0
            lm_model.train()

            for word_t, label_t in iterator:
                if 1 == train_loader.cur_idx:
                    lm_model.init_hidden()
                label_t = label_t.view(-1)

                lm_model.zero_grad()
                loss = lm_model(word_t, label_t)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(lm_model.parameters(), args.clip)
                optimizer.step()
                if args.optimizer.lower() == 'sradamw' or args.optimizer.lower() == 'srradam':
                    iter_count, iter_total = optimizer.update_iter()

                batch_index += 1

                if 0 == batch_index % args.interval:
                    s_loss = utils.to_scalar(loss)
                    writer.add_scalars('loss_tracking/train_loss', {args.model_name: s_loss}, batch_index)

                epoch_loss += utils.to_scalar(loss)
                full_epoch_loss += utils.to_scalar(loss)
                if 0 == batch_index % args.check_interval:
                    epoch_ppl = math.exp(epoch_loss / args.check_interval)
                    writer.add_scalars('loss_tracking/train_ppl', {args.model_name: epoch_ppl}, batch_index)
                    print('epoch_ppl: {} lr: {} @ batch_index: {}'.format(epoch_ppl, cur_lr, batch_index))
                    logger.file.write('epoch_ppl: {} lr: {} @ batch_index: {}'.format(epoch_ppl, cur_lr, batch_index))
                    epoch_loss = 0

            test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
            is_best = test_ppl < best_ppl
            best_ppl = min(test_ppl, best_ppl)
            writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, indexs)
            print('test_ppl: {} @ index: {}'.format(test_ppl, indexs))
            logger.file.write('test_ppl: {} @ index: {}'.format(test_ppl, indexs))

            save_checkpoint({
                'epoch': indexs + 1,
                'schedule_index': schedule_index,
                'lm_model': lm_model.state_dict(),
                'ppl': test_ppl,
                'best_ppl': best_ppl,
                'opt': optimizer.state_dict(),
            }, is_best, indexs, checkpoint=args.checkpath)

    except KeyboardInterrupt:
        print('Exiting from training early')
        logger.file.write('Exiting from training early')
        test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
        writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, args.epoch)
        is_best = False
        save_checkpoint({
            'epoch': indexs + 1,
            'schedule_index': schedule_index,
            'lm_model': lm_model.state_dict(),
            'ppl': test_ppl,
            'best_ppl': best_ppl,
            'opt': optimizer.state_dict(),
        }, is_best, indexs, checkpoint=args.checkpath)

    print('Best PPL:%f' % best_ppl)
    logger.file.write('Best PPL:%f' % best_ppl)
    logger.close()

    with open("./all_results.txt", "a") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        f.write("%s\n" % args.checkpath)
        f.write("best_ppl %f\n\n" % best_ppl)
        fcntl.flock(f, fcntl.LOCK_UN)
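# The helpers adjust_learning_rate() and save_checkpoint() are called in main() but not
# defined in this excerpt. Below is a minimal sketch of typical implementations, assuming the
# module-level `args` and `state` objects used above; it is an illustration, not the repo's
# exact code.
import os
import shutil
import torch


def adjust_learning_rate(optimizer, epoch_idx):
    # Decay the learning rate by args.gamma at the epochs listed in args.schedule, and keep
    # the logged state['lr'] in sync with the optimizer's parameter groups.
    if epoch_idx in args.schedule:
        state['lr'] *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']


def save_checkpoint(state_dict, is_best, epoch_idx, checkpoint='checkpoint'):
    # Save the latest checkpoint and copy it to model_best.pth.tar when the test PPL improves.
    filepath = os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(epoch_idx))
    torch.save(state_dict, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))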
device = torch.device("cuda:" + str(gpu_index) if gpu_index >= 0 else "cpu")
if gpu_index >= 0:
    torch.cuda.set_device(gpu_index)

logger.info('Loading data')
dataset = pickle.load(open(args.corpus, 'rb'))
name_list = ['flm_map', 'blm_map', 'gw_map', 'c_map', 'y_map', 'emb_array', 'train_data', 'test_data', 'dev_data']
flm_map, blm_map, gw_map, c_map, y_map, emb_array, train_data, test_data, dev_data = [dataset[tup] for tup in name_list]

logger.info('Loading language model')
rnn_map = {'Basic': BasicRNN}
flm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
blm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
flm_model = LM(flm_rnn_layer, None, len(flm_map), args.lm_word_dim, args.lm_droprate, label_dim=args.lm_label_dim)
blm_model = LM(blm_rnn_layer, None, len(blm_map), args.lm_word_dim, args.lm_droprate, label_dim=args.lm_label_dim)

flm_file = wrapper.restore_checkpoint(args.forward_lm)['model']
flm_model.load_state_dict(flm_file, False)
blm_file = wrapper.restore_checkpoint(args.backward_lm)['model']
blm_model.load_state_dict(blm_file, False)

flm_model_seq = ElmoLM(flm_model, False, args.lm_droprate, True)
blm_model_seq = ElmoLM(blm_model, True, args.lm_droprate, True)

logger.info('Building model')
SL_map = {'vanilla': Vanilla_SeqLabel, 'lm-aug': SeqLabel}
seq_model = SL_map[args.seq_model](flm_model_seq, blm_model_seq, len(c_map), args.seq_c_dim, args.seq_c_hid, args.seq_c_layer, len(gw_map), args.seq_w_dim, args.seq_w_hid, args.seq_w_layer, len(y_map), args.seq_droprate, unit=args.seq_rnn_unit)
seq_model.rand_init()
seq_model.load_pretrained_word_embedding(torch.FloatTensor(emb_array))
seq_model.to(device)
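# load_pretrained_word_embedding() is project code; conceptually it copies the pre-trained
# matrix into the word-embedding layer of seq_model. A minimal stand-alone sketch of that
# step, assuming emb_array has shape (len(gw_map), args.seq_w_dim); the names below are
# illustrative only:
import torch
import torch.nn as nn

emb_tensor = torch.FloatTensor(emb_array)            # (num_words, word_dim)
num_words, word_dim = emb_tensor.shape
word_embedding = nn.Embedding(num_words, word_dim)
with torch.no_grad():
    word_embedding.weight.copy_(emb_tensor)          # initialize from the pre-trained vectors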