Example #1
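# Assumed module-level context, not shown in this snippet: torch, pickle, json,
# logging, plus repo helpers (models, constant, get_datasets, to_torch,
# load_model, get_output_index, get_gold_pred_str).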
def _test(args, device):
  assert args.load
  test_fname = args.eval_data
  data_gens, _, _ = get_datasets([(test_fname, 'test')], args)
  if args.model_type == 'ETModel':
    print('==> ETModel')
    model = models.ETModel(args, constant.ANSWER_NUM_DICT[args.goal])
  else:
    print('Invalid model type: -model_type ' + args.model_type)
    raise NotImplementedError
  model.to(device)
  model.eval()
  load_model(args.reload_model_name, constant.EXP_ROOT, args.model_id, model)
  if args.multi_gpu:
    model = torch.nn.DataParallel(model)
    print("==> use", torch.cuda.device_count(), "GPUs.")
  for name, dataset in [(test_fname, data_gens[0])]:
    print('Processing... ' + name)
    total_gold_pred = []
    total_annot_ids = []
    total_probs = []
    total_ys = []
    for batch_num, batch in enumerate(dataset):
      if batch_num % 10 == 0:
        print(batch_num)
      if not isinstance(batch, dict):
        print('==> batch: ', batch) 
      eval_batch, annot_ids = to_torch(batch, device)
      if args.multi_gpu:
        output_logits = model(eval_batch)
      else:
        _, output_logits, _ = model(eval_batch)
      output_index = get_output_index(output_logits, threshold=args.threshold)
      # Unwrap DataParallel to reach the model's own sigmoid_fn attribute.
      base_model = model.module if args.multi_gpu else model
      output_prob = base_model.sigmoid_fn(output_logits).data.cpu().clone().numpy()
      y = eval_batch['y'].data.cpu().clone().numpy()
      gold_pred = get_gold_pred_str(output_index, y, args.goal)
      total_probs.extend(output_prob)
      total_ys.extend(y)
      total_gold_pred.extend(gold_pred)
      total_annot_ids.extend(annot_ids)
    with open(constant.OUT_ROOT + '{0:s}.pkl'.format(args.model_id), "wb") as f_pkl:
      pickle.dump({'gold_id_array': total_ys, 'pred_dist': total_probs}, f_pkl)
    print(len(total_annot_ids), len(total_gold_pred))
    with open(constant.OUT_ROOT + '{0:s}.json'.format(args.model_id), 'w') as f_out:
      output_dict = {}
      for a_id, (gold, pred) in zip(total_annot_ids, total_gold_pred):
        output_dict[a_id] = {"gold": gold, "pred": pred}
      json.dump(output_dict, f_out)
    logging.info('processing: ' + name)
  print('Done!')
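Both test functions turn raw logits into multi-label type predictions with a probability threshold via the repo helper get_output_index, whose body is not shown. A minimal, self-contained sketch of what such threshold-based decoding typically looks like, assuming sigmoid activations (the function name and signature here are illustrative, not the repo's):

import torch

def decode_multilabel(logits, threshold=0.5):
  # Convert raw logits to per-type probabilities, then keep every type
  # index whose probability clears the threshold (multi-label decoding).
  probs = torch.sigmoid(logits)
  return [(row >= threshold).nonzero(as_tuple=True)[0].tolist()
          for row in probs]

# Example: a batch of 2 mentions over 4 candidate types.
logits = torch.tensor([[2.0, -1.0, 0.3, -3.0],
                       [-0.5, 1.5, 1.2, 0.1]])
print(decode_multilabel(logits))  # -> [[0, 2], [1, 2, 3]]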
Example #2
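# Assumed module-level context, not shown in this snippet: torch, json, logging,
# plus repo helpers (models, constant, get_datasets, to_torch, load_model,
# get_output_index, get_gold_pred_str, get_eval_string).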
def _test(args):
  assert args.load
  test_fname = args.eval_data
  data_gens, _ = get_datasets([(test_fname, 'test', args.goal)], args)
  if args.model_type == 'et_model':
    print('==> Entity Typing Model')
    model = models.ETModel(args, constant.ANSWER_NUM_DICT[args.goal])
  elif args.model_type == 'bert_uncase_small':
    print('==> Bert Uncased Small')
    model = models.Bert(args, constant.ANSWER_NUM_DICT[args.goal])
  else:
    print('Invalid model type: -model_type ' + args.model_type)
    raise NotImplementedError
  model.cuda()
  model.eval()
  load_model(args.reload_model_name, constant.EXP_ROOT, args.model_id, model)

  for name, dataset in [(test_fname, data_gens[0])]:
    print('Processing... ' + name)
    total_gold_pred = []
    total_annot_ids = []
    total_probs = []
    total_ys = []
    batch_attn = []
    for batch_num, batch in enumerate(dataset):
      print(batch_num)
      if not isinstance(batch, dict):
        print('==> batch: ', batch)
      eval_batch, annot_ids = to_torch(batch)
      loss, output_logits, attn_score = model(eval_batch, args.goal)
      #batch_attn.append((batch, attn_score.data))
      output_index = get_output_index(output_logits, threshold=args.threshold)
      #output_prob = model.sigmoid_fn(output_logits).data.cpu().clone().numpy()
      y = eval_batch['y'].data.cpu().clone().numpy()
      gold_pred = get_gold_pred_str(output_index, y, args.goal)
      #total_probs.extend(output_prob)
      #total_ys.extend(y)
      total_gold_pred.extend(gold_pred)
      total_annot_ids.extend(annot_ids)
    #mrr_val = mrr(total_probs, total_ys)
    #print('mrr_value: ', mrr_val)
    #pickle.dump({'gold_id_array': total_ys, 'pred_dist': total_probs},
    #            open('./{0:s}.p'.format(args.reload_model_name), "wb"))
    with open('./{0:s}.json'.format(args.reload_model_name), 'w') as f_out:
      output_dict = {}
      counter = 0
      for a_id, (gold, pred) in zip(total_annot_ids, total_gold_pred):
        #attn = batch_attn[0][1].squeeze(2)[counter]
        #attn = attn.cpu().numpy().tolist()
        #print(attn, int(batch_attn[0][0]['mention_span_length'][counter]), sum(attn))
        #print(mntn_emb[counter])
        #print()
        #print(int(batch_attn[0][0]['mention_span_length'][counter]), batch_attn[0][0]['mention_embed'][counter].shape)
        #attn = attn[:int(batch_attn[0][0]['mention_span_length'][counter])]
        output_dict[a_id] = {"gold": gold, "pred": pred} #, "attn": attn, "mntn_len": int(batch_attn[0][0]['mention_span_length'][counter])}
        counter += 1
      json.dump(output_dict, f_out)
    eval_str = get_eval_string(total_gold_pred)
    print(eval_str)
    logging.info('processing: ' + name)
    logging.info(eval_str)
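get_eval_string is another repo helper whose body is not shown. In entity typing, evaluation is typically example-averaged (macro) precision/recall/F1 over the gold and predicted type sets; a self-contained sketch of that computation, assuming gold_pred is a list of (gold, pred) type-list pairs like the one built above (the helper name is illustrative):

def macro_prf1(gold_pred):
  # Example-averaged precision/recall over (gold, pred) type-list pairs.
  p_sum = r_sum = 0.0
  for gold, pred in gold_pred:
    overlap = len(set(gold) & set(pred))
    p_sum += overlap / len(pred) if pred else 0.0
    r_sum += overlap / len(gold) if gold else 0.0
  p, r = p_sum / len(gold_pred), r_sum / len(gold_pred)
  f1 = 2 * p * r / (p + r) if p + r else 0.0
  return p, r, f1

print(macro_prf1([(['person', 'artist'], ['person'])]))  # (1.0, 0.5, 0.666...)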
Example #3
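# Assumed module-level context, not shown in this snippet: torch, torch.optim,
# os, gc, time, logging, a tensorboard SummaryWriter, plus repo helpers
# (models, constant, BERTAdam, TensorboardWriter, and the data/eval utilities).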
def _train(args):
  if args.data_setup == 'joint':
    train_gen_list, val_gen_list, crowd_dev_gen, elmo, bert, vocab = get_joint_datasets(args)
  else:
    train_fname = args.train_data
    dev_fname = args.dev_data
    print(train_fname, dev_fname)
    data_gens, elmo = get_datasets([(train_fname, 'train', args.goal),
                                    (dev_fname, 'dev', args.goal)], args)
    train_gen_list = [(args.goal, data_gens[0])]
    val_gen_list = [(args.goal, data_gens[1])]
  train_log = SummaryWriter(os.path.join(constant.EXP_ROOT, args.model_id, "log", "train"))
  validation_log = SummaryWriter(os.path.join(constant.EXP_ROOT, args.model_id, "log", "validation"))
  tensorboard = TensorboardWriter(train_log, validation_log)

  if args.model_type == 'et_model':
    print('==> Entity Typing Model')
    model = models.ETModel(args, constant.ANSWER_NUM_DICT[args.goal])
  elif args.model_type == 'bert_uncase_small':
    print('==> Bert Uncased Small')
    model = models.Bert(args, constant.ANSWER_NUM_DICT[args.goal])
  else:
    print('Invalid model type: -model_type ' + args.model_type)
    raise NotImplementedError

  model.cuda()
  total_loss = 0
  batch_num = 0
  best_macro_f1 = 0.
  start_time = time.time()
  init_time = time.time()
  if args.bert:
    if args.bert_param_path:
      print('==> Loading BERT from ' + args.bert_param_path)
      model.bert.load_state_dict(torch.load(args.bert_param_path, map_location='cpu'))
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_parameters = [
        # Match by substring: parameter names look like 'encoder.layer.0.LayerNorm.bias',
        # so exact membership in no_decay would never fire.
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
        ]
    optimizer = BERTAdam(optimizer_parameters,
                         lr=args.bert_learning_rate,
                         warmup=args.bert_warmup_proportion,
                         t_total=-1) # TODO: 
  else:
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
  #optimizer = optim.SGD(model.parameters(), lr=1., momentum=0.)

  if args.load:
    load_model(args.reload_model_name, constant.EXP_ROOT, args.model_id, model, optimizer)

  for idx, m in enumerate(model.modules()):
    logging.info(str(idx) + '->' + str(m))

  while True:
    batch_num += 1  # one step = one pass over every training generator below.
    for (type_name, data_gen) in train_gen_list: 
      try:
        batch = next(data_gen)
        batch, _ = to_torch(batch)
      except StopIteration:
        logging.info(type_name + " finished at " + str(batch_num))
        print('Done!')
        torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
                   '{0:s}/{1:s}.pt'.format(constant.EXP_ROOT, args.model_id))
        return
      optimizer.zero_grad()
      loss, output_logits, _ = model(batch, type_name)
      loss.backward()
      total_loss += loss.item()
      optimizer.step()

      if batch_num % args.log_period == 0 and batch_num > 0:
        gc.collect()
        cur_loss = loss.item()
        elapsed = time.time() - start_time
        train_loss_str = ('| loss {0:.3f} | step {1:d} | @ {2:.2f} ms/batch'.format(cur_loss, batch_num,
                                                                                    elapsed * 1000 / args.log_period))
        start_time = time.time()
        print(train_loss_str)
        logging.info(train_loss_str)
        tensorboard.add_train_scalar('train_loss_' + type_name, cur_loss, batch_num)

      if batch_num % args.eval_period == 0 and batch_num > 0:
        output_index = get_output_index(output_logits, threshold=args.threshold)
        gold_pred_train = get_gold_pred_str(output_index, batch['y'].data.cpu().clone(), args.goal)
        print(gold_pred_train[:10]) 
        accuracy = sum([set(y) == set(yp) for y, yp in gold_pred_train]) * 1.0 / len(gold_pred_train)

        train_acc_str = '{1:s} Train accuracy: {0:.1f}%'.format(accuracy * 100, type_name)
        print(train_acc_str)
        logging.info(train_acc_str)
        tensorboard.add_train_scalar('train_acc_' + type_name, accuracy, batch_num)
        if args.goal != 'onto':
          for (val_type_name, val_data_gen) in val_gen_list:
            if val_type_name == type_name:
              eval_batch, _ = to_torch(next(val_data_gen))
              evaluate_batch(batch_num, eval_batch, model, tensorboard, val_type_name, args, args.goal)

    if batch_num % args.eval_period == 0 and batch_num > 0 and args.data_setup == 'joint':
      # Evaluate loss and macro F1 on the crowdsourced dev set.
      print('---- eval at step {0:d} ---'.format(batch_num))
      ###############
      #feed_dict = next(crowd_dev_gen)
      #eval_batch, _ = to_torch(feed_dict)
      #crowd_eval_loss = evaluate_batch(batch_num, eval_batch, model, tensorboard, "open", args.goal, single_type=args.single_type)
      ###############
      crowd_eval_loss, macro_f1 = evaluate_data(batch_num, 'crowd/dev_tree.json', model,
                                                tensorboard, "open", args, elmo, bert)

      if best_macro_f1 < macro_f1:
        best_macro_f1 = macro_f1
        save_fname = '{0:s}/{1:s}_best.pt'.format(constant.EXP_ROOT, args.model_id)
        torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, save_fname)
        print(
          'Total {0:.2f} minutes have passed, saving at {1:s} '.format((time.time() - init_time) / 60, save_fname))

    if batch_num % args.eval_period == 0 and batch_num > 0 and args.goal == 'onto':
      # Evaluate on the OntoNotes dev set.
      print('---- OntoNotes: eval at step {0:d} ---'.format(batch_num))
      crowd_eval_loss, macro_f1 = evaluate_data(batch_num, args.dev_data, model, tensorboard,
                                                args.goal, args, elmo, bert)

    if batch_num % args.save_period == 0 and batch_num > 30000:
      save_fname = '{0:s}/{1:s}_{2:d}.pt'.format(constant.EXP_ROOT, args.model_id, batch_num)
      torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, save_fname)
      print(
        'Total {0:.2f} minutes have passed, saving at {1:s} '.format((time.time() - init_time) / 60, save_fname))
  # Unreachable: the while-loop above only exits via the return in the
  # StopIteration handler, which already saves this checkpoint.
  torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
             '{0:s}/{1:s}.pt'.format(constant.EXP_ROOT, args.model_id))
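Both training loops checkpoint with the same two-key dict ({'state_dict', 'optimizer'}), and the load_model helper called at startup presumably inverts that layout. A minimal sketch of the matching save/load pair (function names are illustrative, not the repo's):

import torch

def save_checkpoint(model, optimizer, path):
  # Persist model weights and optimizer state together in one file.
  torch.save({'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict()}, path)

def load_checkpoint(model, optimizer, path, device='cpu'):
  # Restore both; map_location keeps GPU checkpoints loadable on CPU.
  ckpt = torch.load(path, map_location=device)
  model.load_state_dict(ckpt['state_dict'])
  if optimizer is not None:
    optimizer.load_state_dict(ckpt['optimizer'])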
Example #4
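# Assumed module-level context, not shown in this snippet: torch, torch.optim,
# gc, time, logging, plus repo helpers (models, constant, get_all_datasets,
# to_torch, load_model, get_output_index, get_gold_pred_str, evaluate_data).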
def _train(args, device):
  print('==> Loading data generator... ')
  train_gen_list, elmo, char_vocab = get_all_datasets(args)

  if args.model_type == 'ETModel':
    print('==> ETModel')
    model = models.ETModel(args, constant.ANSWER_NUM_DICT[args.goal])
  else:
    print('ERROR: Invalid model type: -model_type ' + args.model_type)
    raise NotImplementedError

  model.to(device)
  total_loss = 0
  batch_num = 0
  best_macro_f1 = 0.
  start_time = time.time()
  init_time = time.time()

  optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

  if args.load:
    load_model(args.reload_model_name, constant.EXP_ROOT, args.model_id, model, optimizer)

  for idx, m in enumerate(model.modules()):
    logging.info(str(idx) + '->' + str(m))

  while True:
    batch_num += 1
    for data_gen in train_gen_list:
      try:
        batch = next(data_gen)
        batch, _ = to_torch(batch, device)
      except StopIteration:
        logging.info('finished at ' + str(batch_num))
        print('Done!')
        torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
                   '{0:s}/{1:s}.pt'.format(constant.EXP_ROOT, args.model_id))
        return
      optimizer.zero_grad()
      loss, output_logits, _ = model(batch)
      loss.backward()
      total_loss += loss.item()
      optimizer.step()

      if batch_num % args.log_period == 0 and batch_num > 0:
        gc.collect()
        cur_loss = loss.item()
        elapsed = time.time() - start_time
        train_loss_str = ('| loss {0:.3f} | step {1:d} | @ {2:.2f} ms/batch'.format(cur_loss, batch_num,
                                                                                    elapsed * 1000 / args.log_period))
        start_time = time.time()
        print(train_loss_str)
        logging.info(train_loss_str)

      if batch_num % args.eval_period == 0 and batch_num > 0:
        output_index = get_output_index(output_logits, threshold=args.threshold)
        gold_pred_train = get_gold_pred_str(output_index, batch['y'].data.cpu().clone(), args.goal)
        #print(gold_pred_train[:10])
        accuracy = sum([set(y) == set(yp) for y, yp in gold_pred_train]) * 1.0 / len(gold_pred_train)
        train_acc_str = '==> Train accuracy: {0:.1f}%'.format(accuracy * 100)
        print(train_acc_str)
        logging.info(train_acc_str)

    if batch_num % args.eval_period == 0 and batch_num > args.eval_after:
      print('---- eval at step {0:d} ---'.format(batch_num))
      _, macro_f1 = evaluate_data(
        batch_num, args.dev_data, model, args, elmo, device, char_vocab, dev_type='original'
      )

      if best_macro_f1 < macro_f1:
        best_macro_f1 = macro_f1
        save_fname = '{0:s}/{1:s}_best.pt'.format(constant.EXP_ROOT, args.model_id)
        torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, save_fname)
        print(
          'Total {0:.2f} minutes have passed, saving at {1:s} '.format((time.time() - init_time) / 60, save_fname))

    if batch_num % args.save_period == 0 and batch_num > args.save_after:
      save_fname = '{0:s}/{1:s}_{2:d}.pt'.format(constant.EXP_ROOT, args.model_id, batch_num)
      torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, save_fname)
      print(
        'Total {0:.2f} minutes have passed, saving at {1:s} '.format((time.time() - init_time) / 60, save_fname))
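The while True loop in both training functions interleaves batches round-robin across the training generators and ends the whole run as soon as any one of them is exhausted. That control flow can be isolated into a small, self-contained generator (a sketch; names are illustrative):

def round_robin(named_gens):
  # Yield one batch from each generator per cycle; stop the whole stream
  # as soon as any generator runs out, mirroring the loops above.
  while True:
    for name, gen in named_gens:
      try:
        yield name, next(gen)
      except StopIteration:
        return

# Example with two toy streams of different lengths.
gens = [('a', iter(range(3))), ('b', iter(range(2)))]
print(list(round_robin(gens)))
# [('a', 0), ('b', 0), ('a', 1), ('b', 1), ('a', 2)]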