Exemplo n.º 1
0
def _train_labeler(args):
  """Train the denoising model selected by ``args.model_type`` and checkpoint it.

  Data generators come from ``get_joint_datasets`` when
  ``args.data_setup == 'joint'``; otherwise from ``get_datasets`` on
  ``args.train_data`` / ``args.dev_data``. Every step pulls one batch from each
  training generator, runs one optimizer update per batch, and periodically
  logs (``args.log_period``), evaluates (``args.eval_period``) and saves
  (``args.save_period``).

  Training ends when any training generator raises ``StopIteration``; at that
  point the final checkpoint is written to ``<EXP_ROOT>/<model_id>.pt`` and the
  function returns. This is the only exit from the training loop.

  Args:
    args: Parsed option namespace. Attributes read here: ``data_setup``,
      ``train_data``, ``dev_data``, ``goal``, ``model_id``, ``model_type``,
      ``bert``, ``bert_param_path``, ``bert_learning_rate``,
      ``bert_warmup_proportion``, ``learning_rate``, ``load``,
      ``reload_model_name``, ``log_period``, ``eval_period``, ``save_period``
      and ``threshold``.

  Raises:
    NotImplementedError: If ``args.model_type`` is neither ``'labeler'`` nor
      ``'filter'``.
  """
  if args.data_setup == 'joint':
    train_gen_list, val_gen_list, _crowd_dev_gen, elmo, bert, vocab = get_joint_datasets(args)
  else:
    train_fname = args.train_data
    dev_fname = args.dev_data
    print(train_fname, dev_fname)
    data_gens, elmo = get_datasets([(train_fname, 'train', args.goal),
                                    (dev_fname, 'dev', args.goal)], args)
    train_gen_list = [(args.goal, data_gens[0])]
    val_gen_list = [(args.goal, data_gens[1])]
  train_log = SummaryWriter(os.path.join(constant.EXP_ROOT, args.model_id, "log", "train"))
  validation_log = SummaryWriter(os.path.join(constant.EXP_ROOT, args.model_id, "log", "validation"))
  tensorboard = TensorboardWriter(train_log, validation_log)

  if args.model_type == 'labeler':
    print('==> Labeler')
    model = denoising_models.Labeler(args, constant.ANSWER_NUM_DICT[args.goal])
  elif args.model_type == 'filter':
    print('==> Filter')
    model = denoising_models.Filter(args, constant.ANSWER_NUM_DICT[args.goal])
  else:
    print('Invalid model type: -model_type ' + args.model_type)
    raise NotImplementedError

  model.cuda()
  total_loss = 0  # accumulated for parity with the original code; not reported anywhere
  batch_num = 0
  best_macro_f1 = 0.
  start_time = time.time()
  init_time = time.time()

  if args.bert:
    if args.bert_param_path:
      print('==> Loading BERT from ' + args.bert_param_path)
      model.bert.load_state_dict(torch.load(args.bert_param_path, map_location='cpu'))
    # Parameter names are dotted paths (e.g. "...dense.bias"), so the no-decay
    # markers must be matched as substrings. An exact `n in no_decay` test never
    # matches and would apply weight decay to every parameter.
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
        ]
    optimizer = BERTAdam(optimizer_parameters,
                         lr=args.bert_learning_rate,
                         warmup=args.bert_warmup_proportion,
                         t_total=-1) # TODO: pass the real number of training steps
  else:
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

  if args.load:
    load_model(args.reload_model_name, constant.EXP_ROOT, args.model_id, model, optimizer)

  for idx, m in enumerate(model.modules()):
    logging.info(str(idx) + '->' + str(m))

  while True:
    batch_num += 1  # single batch composed of all train signal passed by.
    for (type_name, data_gen) in train_gen_list:
      try:
        batch = next(data_gen)
        batch, _ = to_torch(batch)
      except StopIteration:
        # Training data exhausted: write the final checkpoint and stop.
        logging.info(type_name + " finished at " + str(batch_num))
        print('Done!')
        torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
                   '{0:s}/{1:s}.pt'.format(constant.EXP_ROOT, args.model_id))
        return
      optimizer.zero_grad()
      loss, output_logits, cls_logits = model(batch, type_name)
      loss.backward()
      total_loss += loss.item()
      optimizer.step()

      if batch_num % args.log_period == 0 and batch_num > 0:
        gc.collect()
        cur_loss = float(1.0 * loss.clone().item())
        elapsed = time.time() - start_time
        train_loss_str = ('|loss {0:3f} | at {1:d}step | @ {2:.2f} ms/batch'.format(cur_loss, batch_num,
                                                                                    elapsed * 1000 / args.log_period))
        start_time = time.time()
        print(train_loss_str)
        logging.info(train_loss_str)
        tensorboard.add_train_scalar('train_loss_' + type_name, cur_loss, batch_num)

      if batch_num % args.eval_period == 0 and batch_num > 0:
        output_index = get_output_index(output_logits, threshold=args.threshold)
        gold_pred_train = get_gold_pred_str(output_index, batch['y'].data.cpu().clone(), args.goal)
        print(gold_pred_train[:10])
        accuracy = sum([set(y) == set(yp) for y, yp in gold_pred_train]) * 1.0 / len(gold_pred_train)

        train_acc_str = '{1:s} Train accuracy: {0:.1f}%'.format(accuracy * 100, type_name)
        if cls_logits is not None:
          # Copy the gold labels to the host once and binarize the logits once
          # (a logit > 0 counts as a positive prediction).
          gold_cls = batch['y_cls'].data.cpu().numpy()
          pred_cls = [1. if pred > 0. else 0. for pred in cls_logits]
          cls_accuracy = sum([p == g for p, g in zip(pred_cls, gold_cls)]) / float(cls_logits.size()[0])
          cls_tp = sum([p == 1. and g == 1. for p, g in zip(pred_cls, gold_cls)])
          num_pred_pos = float(sum(pred_cls))
          num_gold_pos = float(sum(gold_cls))
          # A batch with no positive predictions (or no positive gold labels)
          # must not abort training with a division by zero.
          cls_precision = cls_tp / num_pred_pos if num_pred_pos else 0.
          cls_recall = cls_tp / num_gold_pos if num_gold_pos else 0.
          cls_f1 = f1(cls_precision, cls_recall)
          train_cls_acc_str = '{1:s} Train cls accuracy: {0:.2f}%  P: {2:.3f}  R: {3:.3f}  F1: {4:.3f}'.format(cls_accuracy * 100, type_name, cls_precision, cls_recall, cls_f1)
        print(train_acc_str)
        if cls_logits is not None:
          print(train_cls_acc_str)
        logging.info(train_acc_str)
        tensorboard.add_train_scalar('train_acc_' + type_name, accuracy, batch_num)
        if args.goal != 'onto':
          for (val_type_name, val_data_gen) in val_gen_list:
            if val_type_name == type_name:
              # NOTE(review): a StopIteration from the validation generator is
              # not caught here -- confirm that generator never runs out.
              eval_batch, _ = to_torch(next(val_data_gen))
              evaluate_batch(batch_num, eval_batch, model, tensorboard, val_type_name, args, args.goal)

    if batch_num % args.eval_period == 0 and batch_num > 0 and args.data_setup == 'joint':
      # Evaluate Loss on the Turk Dev dataset.
      print('---- eval at step {0:d} ---'.format(batch_num))
      crowd_eval_loss, macro_f1 = evaluate_data(batch_num, 'crowd/dev_tree.json', model,
                                                tensorboard, "open", args, elmo, bert, vocab=vocab)

      # Keep a separate "best" checkpoint whenever dev macro-F1 improves.
      if best_macro_f1 < macro_f1:
        best_macro_f1 = macro_f1
        save_fname = '{0:s}/{1:s}_best.pt'.format(constant.EXP_ROOT, args.model_id)
        torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, save_fname)
        print(
          'Total {0:.2f} minutes have passed, saving at {1:s} '.format((time.time() - init_time) / 60, save_fname))

    if batch_num % args.eval_period == 0 and batch_num > 0 and args.goal == 'onto':
      # Evaluate on the dev file given by args.dev_data.
      print('---- OntoNotes: eval at step {0:d} ---'.format(batch_num))
      crowd_eval_loss, macro_f1 = evaluate_data(batch_num, args.dev_data, model, tensorboard,
                                                args.goal, args, elmo)

    if batch_num % args.save_period == 0 and batch_num > 0:
      save_fname = '{0:s}/{1:s}_{2:d}.pt'.format(constant.EXP_ROOT, args.model_id, batch_num)
      torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, save_fname)
      print(
        'Total {0:.2f} minutes have passed, saving at {1:s} '.format((time.time() - init_time) / 60, save_fname))
def _build_arg_parser():
    """Build the command-line parser for the BERT sequence-classification trainer."""
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .csv files (or other data files) for the task."
    )
    parser.add_argument(
        "--vocab_file",
        default=None,
        type=str,
        required=True,
        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help=
        "The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )
    parser.add_argument(
        "--init_checkpoint",
        default=None,
        type=str,
        required=True,
        help="Initial checkpoint (usually from a pre-trained BERT model).")

    ## Other parameters
    parser.add_argument("--eval_test",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as "%%"; a bare "%" makes `--help` raise instead of printing.
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--accumulate_gradients",
        type=int,
        default=1,
        help=
        "Number of steps to accumulate gradient on (divide the batch_size and accumulate)"
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )

    parser.add_argument('--save_model',
                        type=str,
                        default=None,
                        help="Path to save the trained model to")

    parser.add_argument('--load_model',
                        type=str,
                        default=None,
                        help="Path to load the trained model from")
    return parser


def _features_to_dataset(features):
    """Pack input ids, input mask, segment ids and label ids into a TensorDataset."""
    input_ids = torch.tensor([f.input_ids for f in features],
                             dtype=torch.long)
    input_mask = torch.tensor([f.input_mask for f in features],
                              dtype=torch.long)
    segment_ids = torch.tensor([f.segment_ids for f in features],
                               dtype=torch.long)
    label_ids = torch.tensor([f.label_id for f in features],
                             dtype=torch.long)
    return TensorDataset(input_ids, input_mask, segment_ids, label_ids)


def main():
    """Fine-tune ``BertForSequenceClassification`` from command-line options.

    Steps: parse arguments, pick the device, seed the RNGs, build the train
    (and optionally test) data loaders, load the initial BERT weights, then
    train for ``int(--num_train_epochs)`` epochs. After every epoch one
    tab-separated row is appended to ``<output_dir>/log.txt``; with
    ``--eval_test`` the test set is also scored and per-example predictions
    are written to ``<output_dir>/test_ep_<epoch>.txt``. If ``--save_model``
    is given, a checkpoint holding the model state, optimizer state, epoch,
    mean loss and global step is written there at the end.

    Raises:
        ValueError: If ``--accumulate_gradients`` is below 1, if
            ``--max_seq_length`` exceeds the model's position-embedding
            limit, or if the output directory already exists and is not
            empty.
    """
    args = _build_arg_parser().parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        # Count GPUs only when actually running on CUDA; otherwise --no_cuda
        # on a multi-GPU host would still wrap the CPU model in DataParallel.
        n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.accumulate_gradients < 1:
        raise ValueError(
            "Invalid accumulate_gradients parameter: {}, should be >= 1".
            format(args.accumulate_gradients))

    # NOTE(review): the batch size is divided by --accumulate_gradients, but
    # the training loop below steps on --gradient_accumulation_steps; confirm
    # the two options are meant to be set together.
    args.train_batch_size = int(args.train_batch_size /
                                args.accumulate_gradients)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}"
            .format(args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # prepare dataloaders

    processor = processor_class()
    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
                                           do_lower_case=args.do_lower_case)

    # training set
    train_examples = processor.get_train_examples(args.data_dir)
    num_train_steps = int(
        len(train_examples) / args.train_batch_size * args.num_train_epochs)

    train_features = convert_examples_to_features(train_examples, label_list,
                                                  args.max_seq_length,
                                                  tokenizer)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    train_data = _features_to_dataset(train_features)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # test set
    if args.eval_test:
        test_examples = processor.get_test_examples(args.data_dir)
        test_features = convert_examples_to_features(test_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        test_data = _features_to_dataset(test_features)
        test_dataloader = DataLoader(test_data,
                                     batch_size=args.eval_batch_size,
                                     shuffle=False)

    # model and optimizer

    model = BertForSequenceClassification(bert_config, len(label_list))

    if args.init_checkpoint is not None:
        model.bert.load_state_dict(
            torch.load(args.init_checkpoint, map_location='cpu'))
    model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Parameters whose name contains one of these markers get no weight decay.
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay_rate':
        0.01
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay_rate':
        0.0
    }]

    optimizer = BERTAdam(optimizer_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)

    # train
    output_log_file = os.path.join(args.output_dir, "log.txt")
    print("output_log_file=", output_log_file)
    with open(output_log_file, "w") as writer:
        if args.eval_test:
            writer.write(
                "epoch\tglobal_step\tloss\ttest_loss\ttest_accuracy\n")
        else:
            writer.write("epoch\tglobal_step\tloss\n")

    if args.load_model:
        # Load onto the CPU first so a checkpoint written on a GPU can be
        # resumed on any device; load_state_dict copies the values into the
        # already-placed model and optimizer.
        checkpoint = torch.load(args.load_model, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        print("===Loaded previous checkpoint===")
        for key in checkpoint:
            print(key, "=", checkpoint[key])

    else:
        global_step = 0
        epoch = 0

    # Defined up front so the final checkpoint can still be written when the
    # epoch loop does not run (e.g. --num_train_epochs below 1).
    tr_loss, nb_tr_steps = 0, 0

    for _ in trange(int(args.num_train_epochs), desc="Epoch"):
        epoch += 1
        model.train()
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()  # We have accumulated enough gradients
                model.zero_grad()
                global_step += 1

        # eval_test
        if args.eval_test:
            model.eval()
            test_loss, test_accuracy = 0, 0
            nb_test_steps, nb_test_examples = 0, 0
            with open(
                    os.path.join(args.output_dir,
                                 "test_ep_" + str(epoch) + ".txt"),
                    "w") as f_test:
                for input_ids, input_mask, segment_ids, label_ids in test_dataloader:
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)

                    with torch.no_grad():
                        tmp_test_loss, logits = model(input_ids, segment_ids,
                                                      input_mask, label_ids)

                    logits = F.softmax(logits, dim=-1)
                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    outputs = np.argmax(logits, axis=1)
                    # One line per example: predicted class followed by the
                    # softmax probability of every class.
                    for output_i in range(len(outputs)):
                        f_test.write(str(outputs[output_i]))
                        for ou in logits[output_i]:
                            f_test.write(" " + str(ou))
                        f_test.write("\n")
                    tmp_test_accuracy = np.sum(outputs == label_ids)

                    test_loss += tmp_test_loss.mean().item()
                    test_accuracy += tmp_test_accuracy

                    nb_test_examples += input_ids.size(0)
                    nb_test_steps += 1

            # max(..., 1) keeps an empty test set from dividing by zero.
            test_loss = test_loss / max(nb_test_steps, 1)
            test_accuracy = test_accuracy / max(nb_test_examples, 1)

        # Key order must match the header row written to the log file above.
        result = {
            'epoch': epoch,
            'global_step': global_step,
            'loss': tr_loss / max(nb_tr_steps, 1),
        }
        if args.eval_test:
            result['test_loss'] = test_loss
            result['test_accuracy'] = test_accuracy

        logger.info("***** Eval results *****")
        with open(output_log_file, "a+") as writer:
            for key in result.keys():
                logger.info("  %s = %s\n", key, str(result[key]))
                writer.write("%s\t" % (str(result[key])))
            writer.write("\n")

    if args.save_model:
        # Create the target directory up front instead of retrying after a
        # FileNotFoundError: re-opening the same path in text mode cannot
        # succeed, and torch.save needs a path or a binary file object.
        save_dir = os.path.dirname(os.path.abspath(args.save_model))
        os.makedirs(save_dir, exist_ok=True)
        torch.save(
            {
                'model_state_dict': model.state_dict(),
                'epoch': epoch,
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': tr_loss / max(nb_tr_steps, 1),
                'global_step': global_step,
            },
            f=args.save_model)