Example #1
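# (keyword arguments continuing a model-construction call whose opening line is not
#  shown in this snippet; presumably a BertForMultitask-style .from_pretrained(...)
#  call like the one in Example #2)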
        tok_num_labels=ARGS.num_tok_labels,
        cache_dir=ARGS.working_dir + '/cache',
        tok2id=tok2id)
if CUDA:
    model = model.cuda()
    print("cuda available")

print('PREPPING RUN...')

# # # # # # # # # # # OPTIMIZER, LOSS # # # # # # # # # # # # # # # #

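# total optimizer steps for the schedule: examples per epoch * epochs / batch size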
optimizer = tagging_utils.build_optimizer(
    model, int((num_train_examples * ARGS.epochs) / ARGS.train_batch_size),
    ARGS.learning_rate)

loss_fn = tagging_utils.build_loss_fn()

# # # # # # # # # # # TRAIN # # # # # # # # # # # # # # # #

writer = SummaryWriter(ARGS.working_dir)

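# evaluate the untrained model once to log baseline token loss and accuracy at step 0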
print('INITIAL EVAL...')
model.eval()
results = tagging_utils.run_inference(model, eval_dataloader, loss_fn,
                                      tokenizer)
writer.add_scalar('eval/tok_loss', np.mean(results['tok_loss']), 0)
writer.add_scalar('eval/tok_acc', np.mean(results['labeling_hits']), 0)

attention_results = []

Example #2
    print("Ding -- Use BertForMultitask")
    model = detection_model.BertForMultitask.from_pretrained(
        ARGS.bert_model,
        cls_num_labels=ARGS.num_categories,
        tok_num_labels=ARGS.num_tok_labels,
        cache_dir=ARGS.working_dir + '/cache',
        tok2id=tok2id)

if CUDA:
    model = model.cuda()

optimizer = utils.build_optimizer(
    model, int((num_train_examples * ARGS.epochs) / ARGS.train_batch_size),
    ARGS.learning_rate)

loss_fn = utils.build_loss_fn()

writer = SummaryWriter(ARGS.working_dir)

print('INITIAL EVAL...')
model.eval()
# previous version ran the tagging-only inference on the eval data:
# results = tagging_utils.run_inference(model, eval_dataloader, loss_fn, tokenizer)
results = utils.run_inference(model, eval_dataloader, loss_fn, tokenizer)
writer.add_scalar('eval/tok_loss', np.mean(results['tok_loss']), 0)
writer.add_scalar('eval/tok_acc', np.mean(results['labeling_hits']), 0)

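# main training loop (the epoch body is truncated in this example)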
print('TRAINING...')
model.train()
for epoch in range(ARGS.epochs):
    print('STARTING EPOCH ', epoch)
Example #3
if CUDA:
    model = model.cuda()

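# count trainable parameters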
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print('NUM PARAMS: ', params)

# # # # # # # # # # # OPTIMIZER, LOSS # # # # # # # # # # # # # # # #

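# step budget used to size the learning-rate schedule; the factor of 40 appears to be a
# hard-coded epoch count, with optional pretraining steps added on top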
num_train_steps = (num_train_examples * 40)
if ARGS.pretrain_data:
    num_train_steps += (num_pretrain_examples * ARGS.pretrain_epochs)

optimizer = utils.build_optimizer(model, num_train_steps)

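# build_loss_fn returns the main training loss plus a plain cross-entropy loss, which is
# used below for the optional pretraining pass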
loss_fn, cross_entropy_loss = utils.build_loss_fn(vocab_size=len(tok2id))

writer = SummaryWriter(ARGS.working_dir)

# # # # # # # # # # # PRETRAINING (optional) # # # # # # # # # # # # # # # #
if ARGS.pretrain_data:
    print('PRETRAINING...')
    for epoch in range(ARGS.pretrain_epochs):
        model.train()
        losses = utils.train_for_epoch(
            model,
            pretrain_dataloader,
            tok2id,
            optimizer,
            cross_entropy_loss,
            ignore_enrich=not ARGS.use_pretrain_enrich)
Example #4
def main(_):
    # Create the dataset.
    tokenizer = utils.init_tokenizer(FLAGS.dataset)
    graph_tokenizer = utils.init_graph_tokenizer()
    dataset_class = utils.get_dataset_class(FLAGS.dataset, FLAGS.model_type)
    has_graph = FLAGS.model_type == 'graph2text'
    local_devices = jax.local_devices()
    num_gpus = min(FLAGS.num_gpus, len(local_devices))

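    # Dispatch on the requested job mode: train, eval, sample, or graph retrieval.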
    if FLAGS.job_mode == 'train':
        train_dataset = dataset_class(tokenizer=tokenizer,
                                      graph_tokenizer=graph_tokenizer,
                                      batch_size=FLAGS.train_batch_size,
                                      subset='train',
                                      timesteps=FLAGS.train_timesteps,
                                      version=FLAGS.graph_data_version,
                                      shuffle_data=True,
                                      repeat=True,
                                      debug=FLAGS.debug)
        train_iter = iter(train_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.train_memory_size)
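        # Optimizer: clip gradients by global norm, rescale with Adam statistics, apply
        # the learning-rate schedule, then scale by -1 so updates descend the loss.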
        optimizer = optax.chain(
            optax.clip_by_global_norm(FLAGS.grad_clip), optax.scale_by_adam(),
            optax.scale_by_schedule(
                functools.partial(utils.schedule,
                                  lr_schedule=FLAGS.lr_schedule,
                                  init_lr=FLAGS.init_lr,
                                  min_lr_ratio=FLAGS.min_lr_ratio,
                                  max_steps=FLAGS.max_steps)), optax.scale(-1))
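        # Skip updates whose gradients contain NaNs/Infs, tolerating up to 5 such
        # updates in a row.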
        optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)
        updater = Updater(loss_fn,
                          optimizer,
                          devices=local_devices[:num_gpus],
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _train(updater, train_iter, num_gpus)
    elif FLAGS.job_mode == 'eval':
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=FLAGS.eval_batch_size,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.eval_timesteps,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=False,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.eval_memory_size)
        # only use one device for evaluation
        devices = local_devices[:1]
        updater = Updater(loss_fn,
                          optimizer=None,
                          devices=devices,
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _eval(updater, eval_iter)
    elif FLAGS.job_mode == 'sample':
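        # Sampling: batch size 1, with the sequence window sized to the requested sample length.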
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=1,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.sample_length,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=True,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        _sample(eval_iter, tokenizer, local_devices[:num_gpus])
    elif FLAGS.job_mode == 'retrieve':
        eval_dataset = dataset_class(tokenizer=tokenizer,
                                     graph_tokenizer=graph_tokenizer,
                                     batch_size=1,
                                     subset=FLAGS.eval_subset,
                                     timesteps=FLAGS.eval_timesteps,
                                     version=FLAGS.graph_data_version,
                                     shuffle_data=False,
                                     repeat=False,
                                     graph_retrieval_dataset=True,
                                     debug=FLAGS.debug)
        eval_iter = iter(eval_dataset)
        loss_fn = utils.build_loss_fn(vocab_size=tokenizer.vocab_size,
                                      cache_steps=FLAGS.eval_memory_size)
        # only use one device for evaluation
        devices = local_devices[:1]
        updater = Updater(loss_fn,
                          optimizer=None,
                          devices=devices,
                          has_graph=has_graph)
        updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)
        _retrieve(updater, eval_iter)