import math
import time

import paddle
import paddle.distributed as dist
import paddle.nn as nn
from paddle.io import DataLoader

# Helpers defined elsewhere in this example (import paths assumed):
# load_vocab, OneBillionWordDataset, ELMo, ELMoLoss, save_params, parse_args.


def eval(args):
    # Evaluate the bidirectional ELMo language model on the dev set and
    # report the average loss and perplexity.
    paddle.set_device(args.device)
    if not args.init_from_ckpt:
        raise ValueError('init_from_ckpt should be set when eval.')
    vocab = load_vocab(args.vocab_file, args.max_characters_per_token)

    elmo = ELMo(args.batch_size,
                args.char_embed_dim,
                args.projection_dim,
                vocab.size,
                dropout=args.dropout,
                num_layers=args.num_layers,
                num_highways=args.num_highways,
                char_vocab_size=vocab.char_size)
    elmo.eval()

    elmo_loss = ELMoLoss()

    # Loads pre-trained parameters.
    weight_state_dict = paddle.load(args.init_from_ckpt + '.pdparams')
    elmo.set_state_dict(weight_state_dict)
    print("Loaded checkpoint from %s" % args.init_from_ckpt)

    dev_dataset = OneBillionWordDataset(args.dev_data_path,
                                        vocab,
                                        args.batch_size,
                                        args.unroll_steps,
                                        mode='test',
                                        shuffle=False,
                                        seed=args.seed)
    dev_dataloader = DataLoader(dev_dataset, return_list=True, batch_size=None)

    total_step = total_loss = 0
    total_time = 0.0
    batch_start_time = time.time()
    for step, inputs in enumerate(dev_dataloader, start=1):
        ids, next_ids, ids_reverse, next_ids_reverse = inputs
        outputs = elmo([ids, ids_reverse])
        loss = elmo_loss(outputs, [next_ids, next_ids_reverse])
        ppl = paddle.exp(loss)

        total_loss += loss.numpy()[0]
        total_step += 1

        total_time += (time.time() - batch_start_time)
        if step % args.log_freq == 0:
            print("Eval step %d - loss: %.4f - Perplexity: %.4f - %.3fs/step" %
                  (step, loss.numpy()[0] * args.unroll_steps, ppl.numpy()[0],
                   total_time / args.log_freq))
            total_time = 0.0
        batch_start_time = time.time()

    avg_loss = total_loss / total_step
    avg_ppl = math.exp(avg_loss)
    print("Eval - average loss: %.4f - average Perplexity: %.4f" %
          (avg_loss * args.unroll_steps, avg_ppl))
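# Usage sketch (assumption, not part of the original file): the single-card
# eval(args) above takes the example's parsed CLI namespace. Assuming the
# parse_args() helper referenced by the multi-GPU eval() variant below, a
# minimal entry point would look like:
#
#     if __name__ == '__main__':
#         eval(parse_args())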
def eval():
    # Multi-GPU variant: parses its own CLI args and shards the dev data
    # across processes via paddle.distributed.
    paddle.disable_static()
    n_gpus = dist.get_world_size()
    rank = dist.get_rank()

    if n_gpus > 1:
        dist.init_parallel_env()

    args = parse_args()
    if not args.init_from_ckpt:
        raise ValueError('init_from_ckpt should be set when eval.')
    vocab = load_vocab(args.vocab_file, args.max_characters_per_token)

    elmo = ELMo(args.batch_size,
                args.char_embed_dim,
                args.projection_dim,
                vocab.size,
                dropout=args.dropout,
                num_layers=args.num_layers,
                num_highways=args.num_highways,
                char_vocab_size=vocab.char_size)
    if n_gpus > 1:
        elmo = paddle.DataParallel(elmo)
    elmo.eval()

    elmo_loss = ELMoLoss()

    # Loads pre-trained parameters.
    weight_state_dict = paddle.load(args.init_from_ckpt + '.pdparams')
    elmo.set_state_dict(weight_state_dict)
    print("Loaded checkpoint from %s" % args.init_from_ckpt)

    dev_dataset = OneBillionWordDataset(args.dev_data_path,
                                        vocab,
                                        args.batch_size,
                                        args.unroll_steps,
                                        n_gpus,
                                        rank,
                                        mode='test',
                                        shuffle=False,
                                        seed=args.random_seed)

    # FIXME(xiemoyuan): When DataLoader supports setting batch_size to None,
    # set batch_size to None here.
    dev_dataloader = DataLoader(dev_dataset, return_list=True, batch_size=1)

    total_step = total_loss = 0
    total_time = 0.0
    batch_start_time = time.time()
    for step, inputs in enumerate(dev_dataloader, start=1):
        # FIXME(xiemoyuan): When DataLoader supports setting batch_size to
        # None, delete this squeeze.
        for j in range(len(inputs)):
            inputs[j] = paddle.squeeze(inputs[j], axis=0)

        ids, next_ids, ids_reverse, next_ids_reverse = inputs
        outputs = elmo([ids, ids_reverse])
        loss = elmo_loss(outputs, [next_ids, next_ids_reverse])
        ppl = paddle.exp(loss)

        total_loss += loss.numpy()[0]
        total_step += 1

        total_time += (time.time() - batch_start_time)
        if rank == 0:
            if step % args.log_freq == 0:
                print(
                    "Eval step %d - loss: %.4f - Perplexity: %.4f - %.3fs/step"
                    % (step, loss.numpy()[0] * args.unroll_steps,
                       ppl.numpy()[0], total_time / args.log_freq))
                total_time = 0.0
        batch_start_time = time.time()

    avg_loss = total_loss / total_step
    avg_ppl = math.exp(avg_loss)
    if rank == 0:
        print("Eval - average loss: %.4f - average Perplexity: %.4f" %
              (avg_loss * args.unroll_steps, avg_ppl))
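# Launch note (assumption, not part of the original file): the eval() variant
# above reads the world size and rank from paddle.distributed, so multi-GPU
# evaluation would be started with Paddle's distributed launcher. The script
# name run_eval.py and the CLI flags (which mirror the args attributes read
# above) are assumptions, and the GPU flag spelling varies across Paddle
# releases (--gpus vs. --selected_gpus):
#
#     python -m paddle.distributed.launch --gpus "0,1" run_eval.py \
#         --dev_data_path <dev_files> --init_from_ckpt <checkpoint_prefix>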
def train(args):
    paddle.set_device(args.device)
    n_procs = dist.get_world_size()
    rank = dist.get_rank()

    if n_procs > 1:
        dist.init_parallel_env()

    vocab = load_vocab(args.vocab_file, args.max_characters_per_token)

    elmo = ELMo(args.batch_size,
                args.char_embed_dim,
                args.projection_dim,
                vocab.size,
                dropout=args.dropout,
                num_layers=args.num_layers,
                num_highways=args.num_highways,
                char_vocab_size=vocab.char_size)
    if n_procs > 1:
        elmo = paddle.DataParallel(elmo)
    elmo.train()

    global_norm_clip = nn.ClipGradByGlobalNorm(args.max_grad_norm)
    optimizer = paddle.optimizer.Adagrad(learning_rate=args.lr,
                                         parameters=elmo.parameters(),
                                         initial_accumulator_value=1.0,
                                         grad_clip=global_norm_clip)
    elmo_loss = ELMoLoss()

    # Loads pre-trained parameters.
    if args.init_from_ckpt:
        weight_state_dict = paddle.load(args.init_from_ckpt + '.pdparams')
        opt_state_dict = paddle.load(args.init_from_ckpt + '.pdopt')
        elmo.set_state_dict(weight_state_dict)
        optimizer.set_state_dict(opt_state_dict)
        print("Loaded checkpoint from %s" % args.init_from_ckpt)

    train_dataset = OneBillionWordDataset(args.train_data_path,
                                          vocab,
                                          args.batch_size,
                                          args.unroll_steps,
                                          n_procs=n_procs,
                                          rank=rank,
                                          mode='train',
                                          shuffle=True,
                                          seed=args.seed)
    train_dataloader = DataLoader(train_dataset,
                                  return_list=True,
                                  batch_size=None)

    n_tokens_per_batch = args.batch_size * args.unroll_steps * n_procs
    n_steps_per_epoch = int(train_dataset.number_of_tokens /
                            n_tokens_per_batch)
    n_steps_total = args.epochs * n_steps_per_epoch
    print("Training for %s epochs and %s steps" % (args.epochs, n_steps_total))

    total_time = 0.0
    batch_start_time = time.time()
    for step, inputs in enumerate(train_dataloader, start=1):
        ids, next_ids, ids_reverse, next_ids_reverse = inputs
        outputs = elmo([ids, ids_reverse])
        loss = elmo_loss(outputs, [next_ids, next_ids_reverse])
        ppl = paddle.exp(loss)
        loss *= args.unroll_steps
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        total_time += (time.time() - batch_start_time)
        if step % args.log_freq == 0:
            print("step %d/%d - loss: %.4f - Perplexity: %.4f - %.3fs/step" %
                  (step, n_steps_total, loss.numpy()[0], ppl.numpy()[0],
                   total_time / args.log_freq))
            total_time = 0.0
        if rank == 0 and step % args.save_freq == 0:
            save_params(elmo, optimizer, args.save_dir, step)
        if step == n_steps_total:
            # Training done.
            if rank == 0:
                save_params(elmo, optimizer, args.save_dir, 'final')
            break
        batch_start_time = time.time()
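# save_params() is called by train() above but not defined in this section.
# A minimal sketch of what it is assumed to do - persist model weights and
# optimizer state under save_dir with the '.pdparams'/'.pdopt' suffixes that
# the init_from_ckpt loading above expects - could look like the following;
# the file naming scheme (step number or 'final' as the basename) is an
# assumption:
import os


def save_params(model, optimizer, save_dir, name):
    # e.g. <save_dir>/10000.pdparams and <save_dir>/final.pdopt
    path = os.path.join(save_dir, str(name))
    paddle.save(model.state_dict(), path + '.pdparams')
    paddle.save(optimizer.state_dict(), path + '.pdopt')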