Example #1
0
def run_train(args, hparams):
    """Train an ``NKChartParser`` and checkpoint the model with the best dev F-score.

    The function seeds the numpy and pytorch RNGs, loads the training and
    development treebanks, builds tag/word/label/character vocabularies from
    the training trees, creates (or reloads from ``args.load_path``) the
    parser and an Adam optimizer, and then runs the training loop.

    The development set is evaluated with ``evaluate.evalb`` roughly
    ``args.checks_per_epoch`` times per epoch. Whenever the dev F-score
    improves, the previous best checkpoint file is deleted and a new one is
    written to ``"<args.model_path_base>_dev=<fscore>.pt"``.

    Training stops when ``args.epochs`` epochs have been run (if it is not
    ``None``) or when more than
    ``(step_decay_patience + 1) * max_consecutive_decays`` epochs' worth of
    examples have been processed since the last dev improvement.

    Args:
        args: Parsed command-line options. Attributes read here:
            ``numpy_seed``, ``train_path``, ``dev_path``, ``print_vocabs``,
            ``load_path``, ``checks_per_epoch``, ``eval_batch_size``,
            ``evalb_dir``, ``model_path_base``, ``epochs``, ``batch_size``
            and ``subbatch_max_tokens``.
        hparams: Hyperparameter container. It is updated in place through
            ``hparams.set_from_args(args)``. Attributes read here:
            ``predict_tags``, ``max_len_train``, ``max_len_dev``,
            ``step_decay``, ``learning_rate``, ``learning_rate_warmup_steps``,
            ``step_decay_factor``, ``step_decay_patience``,
            ``clip_grad_norm`` and ``max_consecutive_decays``.

    Returns:
        None. Results are reported on stdout and through the checkpoint file.

    Raises:
        AssertionError: If ``hparams.step_decay`` is falsy; no other learning
            rate schedule is implemented here.
    """
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    # The pytorch seed is drawn from numpy's RNG, so fixing args.numpy_seed
    # also fixes the pytorch seed. 2147483648 == 2**31 (exclusive upper bound).
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()

    print("Loading training trees from {}...".format(args.train_path))
    if hparams.predict_tags and args.train_path.endswith('10way.clean'):
        print("WARNING: The data distributed with this repository contains "
              "predicted part-of-speech tags only (not gold tags!) We do not "
              "recommend enabling predict_tags in this configuration.")
    train_treebank = trees.load_trees(args.train_path)
    # A non-positive max_len_train disables the length filter.
    if hparams.max_len_train > 0:
        train_treebank = [tree for tree in train_treebank if len(list(tree.leaves())) <= hparams.max_len_train]
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = trees.load_trees(args.dev_path)
    # A non-positive max_len_dev disables the length filter.
    if hparams.max_len_dev > 0:
        dev_treebank = [tree for tree in dev_treebank if len(list(tree.leaves())) <= hparams.max_len_dev]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]

    print("Constructing vocabularies...")

    # Special symbols are indexed first so they receive the lowest indices.
    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)

    label_vocab = vocabulary.Vocabulary()
    # The empty tuple is registered before any real label.
    # NOTE(review): presumably it stands for "no constituent label" - confirm.
    label_vocab.index(())

    char_set = set()

    # Walk every training tree with an explicit stack. Children are pushed in
    # reverse so that they are popped (and therefore indexed) left to right.
    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    # NOTE: max() raises ValueError if char_set is empty (no training words).
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        # Round the table size up to 256 or 512 entries.
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        # Large codepoints: index the special character symbols first, then
        # the observed characters in sorted order for a deterministic mapping.
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    def print_vocabulary(name, vocab):
        """Print the vocabulary size and its values, special tokens first."""
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)

    print("Initializing model...")

    load_path = args.load_path
    if load_path is not None:
        # Resume: the parser is rebuilt from the saved spec and weights, so
        # the vocabularies constructed above are not passed to it.
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict'])
    else:
        parser = parse_nk.NKChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            hparams,
        )

    print("Initializing optimizer...")
    trainable_parameters = [param for param in parser.parameters() if param.requires_grad]
    # lr=1. is only the constructor value; set_lr() overwrites it during warmup.
    trainer = torch.optim.Adam(trainable_parameters, lr=1., betas=(0.9, 0.98), eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    def set_lr(new_lr):
        """Set the learning rate of every optimizer parameter group."""
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    # Linear warmup: the learning rate grows by warmup_coeff per step until it
    # reaches hparams.learning_rate after learning_rate_warmup_steps steps.
    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    # 'max' mode: the monitored quantity (dev F-score) is expected to increase.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer, 'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )
    def schedule_lr(iteration):
        """Apply the warmup learning rate for the given 0-based iteration."""
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    # clip_grad_norm == 0 means "do not clip": an infinite threshold is used.
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm

    print("Training...")
    total_processed = 0
    current_processed = 0
    # True division: the interval between dev checks may be fractional.
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    best_dev_processed = 0

    start_time = time.time()

    def check_dev():
        """Evaluate on the dev set and checkpoint the model if F-score improved."""
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal best_dev_processed

        dev_start_time = time.time()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank), args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index+args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(subbatch_sentences)
            # Drop the reference to the unused second return value right away.
            del _
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, dev_predicted)

        print(
            "dev-fscore {} "
            "dev-elapsed {} "
            "total-elapsed {}".format(
                dev_fscore,
                format_elapsed(dev_start_time),
                format_elapsed(start_time),
            )
        )

        if dev_fscore.fscore > best_dev_fscore:
            # Only one checkpoint is kept on disk: remove the previous best
            # before writing the new one.
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print("Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            best_dev_model_path = "{}_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore)
            best_dev_processed = total_processed
            print("Saving new best model to {}...".format(best_dev_model_path))
            # The optimizer state is stored too, so training can be resumed
            # through args.load_path.
            torch.save({
                'spec': parser.spec,
                'state_dict': parser.state_dict(),
                'trainer' : trainer.state_dict(),
                }, best_dev_model_path + ".pt")

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        # In-place shuffle using numpy's global RNG (seeded above).
        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in tqdm(range(0, len(train_parse), args.batch_size)):
            trainer.zero_grad()
            # The step index is derived from the number of examples seen so far.
            schedule_lr(total_processed // args.batch_size)

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index + args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in batch_trees]
            batch_num_tokens = sum(len(sentence) for sentence in batch_sentences)

            # Gradients are accumulated over sub-batches; one optimizer step is
            # taken per full batch.
            for subbatch_sentences, subbatch_trees in parser.split_batch(batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences, subbatch_trees)

                # Normalize by the size of the whole batch (not the sub-batch)
                # so that the accumulated gradient is a per-batch average.
                if hparams.predict_tags:
                    # With predict_tags the loss has two parts: the first is
                    # scaled per tree, the second per token.
                    loss = loss[0] / len(batch_trees) + loss[1] / batch_num_tokens
                else:
                    loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                # backward() is skipped when the loss is not positive.
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)

            # grad_norm is only consumed by the disabled progress print below.
            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters, grad_clip_threshold)

            trainer.step()

            # Disabled per-epoch progress report (printed on the last batch
            # of an epoch when enabled):
            # if start_index // args.batch_size + 1 == int(np.ceil(len(train_parse) / args.batch_size)):
            #     print(
            #         "epoch {:,} "
            #         "batch {:,}/{:,} "
            #         "processed {:,} "
            #         "batch-loss {:.4f} "
            #         "grad-norm {:.4f} "
            #         "epoch-elapsed {} "
            #         "total-elapsed {}".format(
            #             epoch,
            #             start_index // args.batch_size + 1,
            #             int(np.ceil(len(train_parse) / args.batch_size)),
            #             total_processed,
            #             batch_loss_value,
            #             grad_norm,
            #             format_elapsed(epoch_start_time),
            #             format_elapsed(start_time),
            #         )
            #     )

            if current_processed >= check_every:
                # Subtract instead of resetting to zero so that the overshoot
                # carries over to the next check interval.
                current_processed -= check_every
                # print('\nEpoch {}, weights {}'.format(epoch, parser.weighted_layer.weight.data))
                check_dev()

        # adjust learning rate at the end of an epoch
        # (only once the warmup phase is over)
        if (total_processed // args.batch_size + 1) > hparams.learning_rate_warmup_steps:
            scheduler.step(best_dev_fscore)
            # Early stopping: too many examples processed since the last dev
            # improvement.
            if (total_processed - best_dev_processed) > ((hparams.step_decay_patience + 1) * hparams.max_consecutive_decays * len(train_parse)):
                print("Terminating due to lack of improvement in dev fscore.")
                print("The layer weights are: ", parser.weighted_layer.weight)
                break
Example #2
0
def run_train(args, hparams):
    """Train an ``NKChartParser``, selecting the model by dev E-f-score.

    The function seeds the numpy and pytorch RNGs, loads the training and
    development treebanks, builds tag/word/label/character vocabularies from
    the training trees, creates the parser and an Adam optimizer, and runs
    the training loop for ``args.epochs`` epochs (forever if it is ``None``).

    The development set is evaluated roughly ``args.checks_per_epoch`` times
    per epoch with both ``evaluate.evalb`` and ``evaluate_EDITED.Evaluate``.
    Whenever the latter's ``efscore`` improves, the previous checkpoint file
    is removed and, if ``args.model_path_base`` is set, a new checkpoint is
    written to ``"<model_path_base>_Edev=<efscore>.pt"``.

    After epoch 1 and epoch 10 the best dev score is compared with
    ``args.epoch1_hurdle`` / ``args.epoch10_hurdle``; a run that is below the
    hurdle is aborted. At the end, the table of the most recent dev
    evaluation is written to ``args.results_path`` (or stdout if unset).

    Args:
        args: Parsed command-line options. Attributes read here:
            ``numpy_seed``, ``train_path``, ``dev_path``, ``print_vocabs``,
            ``checks_per_epoch``, ``eval_batch_size``, ``evalb_dir``,
            ``model_path_base`` (optional), ``results_path``, ``epochs``,
            ``batch_size``, ``subbatch_max_tokens``, ``print_minibatches``,
            ``epoch1_hurdle`` and ``epoch10_hurdle``.
        hparams: Hyperparameter container. It is updated in place through
            ``hparams.set_from_args(args)``. Attributes read here:
            ``max_len_train``, ``max_len_dev``, ``step_decay``,
            ``learning_rate``, ``learning_rate_warmup_steps``,
            ``step_decay_factor``, ``step_decay_patience`` and
            ``clip_grad_norm``.

    Returns:
        None.

    Raises:
        AssertionError: If ``hparams.step_decay`` is falsy.
        SystemExit: If the best dev score is below the epoch-1 or epoch-10
            hurdle.
    """
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed),
              flush=True)
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    # The pytorch seed is drawn from numpy's RNG, so fixing args.numpy_seed
    # also fixes the pytorch seed. 2147483648 == 2**31 (exclusive upper bound).
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy, flush=True)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:", flush=True)
    hparams.print()

    print("Loading training trees from {}...".format(args.train_path),
          flush=True)
    train_treebank = trees.load_trees(args.train_path)
    # A non-positive max_len_train disables the length filter.
    if hparams.max_len_train > 0:
        train_treebank = [
            tree for tree in train_treebank
            if len(list(tree.leaves())) <= hparams.max_len_train
        ]
    print("Loaded {:,} training examples.".format(len(train_treebank)),
          flush=True)

    print("Loading development trees from {}...".format(args.dev_path),
          flush=True)
    dev_treebank = trees.load_trees(args.dev_path)
    # A non-positive max_len_dev disables the length filter.
    if hparams.max_len_dev > 0:
        dev_treebank = [
            tree for tree in dev_treebank
            if len(list(tree.leaves())) <= hparams.max_len_dev
        ]
    print("Loaded {:,} development examples.".format(len(dev_treebank)),
          flush=True)

    print("Processing trees for training...", flush=True)
    train_parse = [tree.convert() for tree in train_treebank]

    print("Constructing vocabularies...", flush=True)

    # Special symbols are indexed first so they receive the lowest indices.
    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)

    label_vocab = vocabulary.Vocabulary()
    # The empty tuple is registered before any real label.
    # NOTE(review): presumably it stands for "no constituent label" - confirm.
    label_vocab.index(())

    char_set = set()

    # Walk every training tree with an explicit stack. Children are pushed in
    # reverse so that they are popped (and therefore indexed) left to right.
    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        # Round the table size up to 256 or 512 entries.
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        # Large codepoints: index the special character symbols first, then
        # the observed characters in sorted order for a deterministic mapping.
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    def print_vocabulary(name, vocab):
        """Print the vocabulary size and its values, special tokens first."""
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)),
              flush=True)

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)

    print("Initializing model...", flush=True)

    # Resuming from a checkpoint is switched off in this variant: load_path is
    # hard-coded to None, so the loading branches below never run.
    load_path = None
    if load_path is not None:
        print(f"Loading parameters from {load_path}", flush=True)
        info = torch_load(load_path)
        parser = parse_nk.NKChartParser.from_spec(info['spec'],
                                                  info['state_dict'])
    else:
        parser = parse_nk.NKChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            hparams,
        )

    print("Initializing optimizer...", flush=True)
    trainable_parameters = [
        param for param in parser.parameters() if param.requires_grad
    ]
    # lr=1. is only the constructor value; set_lr() overwrites it during warmup.
    trainer = torch.optim.Adam(trainable_parameters,
                               lr=1.,
                               betas=(0.9, 0.98),
                               eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    def set_lr(new_lr):
        """Set the learning rate of every optimizer parameter group."""
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    # Linear warmup: the learning rate grows by warmup_coeff per step until it
    # reaches hparams.learning_rate after learning_rate_warmup_steps steps.
    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    # 'max' mode: the monitored quantity (dev score) is expected to increase.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer,
        'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        """Apply the warmup learning rate for the given 0-based iteration."""
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    # clip_grad_norm == 0 means "do not clip": an infinite threshold is used.
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm

    print("Training...", flush=True)
    total_processed = 0
    current_processed = 0
    # True division: the interval between dev checks may be fractional.
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    # Result of the most recent dev evaluation (None until the first check).
    dev_efscore = None

    start_time = time.time()

    def check_dev():
        """Evaluate on the dev set and checkpoint the model if efscore improved."""
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal dev_efscore

        dev_start_time = time.time()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank),
                                     args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index +
                                          args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word)
                                   for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(subbatch_sentences)
            # Drop the reference to the unused second return value right away.
            del _
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank,
                                    dev_predicted)

        print(" dev-fscore {} "
              "dev-elapsed {} "
              "total-elapsed {}".format(
                  dev_fscore,
                  format_elapsed(dev_start_time),
                  format_elapsed(start_time),
              ),
              flush=True)

        dev_efscore = evaluate_EDITED.Evaluate(dev_treebank, dev_predicted)

        print(" dev-Efscore: {}".format(dev_efscore), flush=True)

        # MJ - keep model with best efscore
        if dev_efscore.efscore > best_dev_fscore:
            best_dev_fscore = dev_efscore.efscore

            # Only one checkpoint is kept on disk: remove the previous best
            # before writing the new one.
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print(
                            " Removing previous model file {}...".format(path),
                            flush=True)
                        os.remove(path)

            # The save used to be nested inside the "is not None" check above
            # and was therefore unreachable (the path starts out as None), so
            # no checkpoint was ever written. It is only attempted when a
            # model path base was actually supplied.
            model_path_base = getattr(args, "model_path_base", None)
            if model_path_base:
                best_dev_model_path = "{}_Edev={:.4}".format(
                    model_path_base, best_dev_fscore)
                print(" Saving new best model to {}...".format(
                    best_dev_model_path),
                      flush=True)
                torch.save(
                    {
                        'spec': parser.spec,
                        'state_dict': parser.state_dict(),
                        'trainer': trainer.state_dict(),
                    }, best_dev_model_path + ".pt")

    def check_hurdle(epoch, hurdle):
        """Abort the run if the best dev score so far is below ``hurdle``."""
        if best_dev_fscore < hurdle:
            message = ("FAILURE: Epoch {} hurdle failed, stopping now!\n"
                       "best_dev_fscore = {} < epoch{}_hurdle = {}".format(
                           epoch, best_dev_fscore, epoch, hurdle))
            print(message, flush=True)
            if args.results_path:
                # The context manager closes (and flushes) the results file
                # before the process exits.
                with open(args.results_path, 'w') as results_file:
                    print(message, file=results_file)
            sys.exit(message)

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        # In-place shuffle using numpy's global RNG (seeded above).
        np.random.shuffle(train_parse)
        epoch_start_time = time.time()
        sum_grad_norm = 0
        sum_batch_loss_value = 0
        batch_processed = 0

        for start_index in range(0, len(train_parse), args.batch_size):
            trainer.zero_grad()
            # The step index is derived from the number of examples seen so far.
            schedule_lr(total_processed // args.batch_size)

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index +
                                      args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word)
                                for leaf in tree.leaves()]
                               for tree in batch_trees]

            # Gradients are accumulated over sub-batches; one optimizer step is
            # taken per full batch.
            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences,
                                             subbatch_trees)

                # Normalize by the size of the whole batch (not the sub-batch)
                # so that the accumulated gradient is a per-batch average.
                loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                # backward() is skipped when the loss is not positive.
                if loss_value > 0:
                    loss.backward()
                del loss

                len_subbatch_trees = len(subbatch_trees)
                total_processed += len_subbatch_trees
                current_processed += len_subbatch_trees
                batch_processed += len_subbatch_trees

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)
            sum_batch_loss_value += batch_loss_value
            sum_grad_norm += grad_norm

            trainer.step()

            if args.print_minibatches:
                print(" epoch {:,} "
                      "batch {:,}/{:,} "
                      "processed {:,} "
                      "batch-loss {:.4f} "
                      "grad-norm {:.4f} "
                      "epoch-elapsed {} "
                      "total-elapsed {}".format(
                          epoch, start_index // args.batch_size + 1,
                          int(np.ceil(len(train_parse) / args.batch_size)),
                          total_processed, batch_loss_value, grad_norm,
                          format_elapsed(epoch_start_time),
                          format_elapsed(start_time)),
                      flush=True)

            if current_processed >= check_every:
                # Subtract instead of resetting to zero so that the overshoot
                # carries over to the next check interval.
                current_processed -= check_every
                check_dev()

        # adjust learning rate at the end of an epoch
        # (only once the warmup phase is over)
        if hparams.step_decay:
            if (total_processed // args.batch_size +
                    1) > hparams.learning_rate_warmup_steps:
                scheduler.step(best_dev_fscore)

        # NOTE(review): the two "average" figures are divided by the number of
        # trees processed in this epoch, not by the number of batches.
        print("Epoch {:,} "
              "total-processed {:,} "
              "average-batch-loss {:.4} "
              "average-grad-norm {:.4} "
              "total-elapsed {}".format(epoch, total_processed,
                                        sum_batch_loss_value / batch_processed,
                                        sum_grad_norm / batch_processed,
                                        format_elapsed(start_time)),
              flush=True)

        if epoch == 1:
            check_hurdle(epoch, args.epoch1_hurdle)
        elif epoch == 10:
            check_hurdle(epoch, args.epoch10_hurdle)

    # dev_efscore holds the most recent dev evaluation (not necessarily the
    # best one). It is still None when no dev check ran at all, e.g. with
    # args.epochs == 0; report that instead of failing on None.table().
    if dev_efscore is None:
        print("No development evaluation was run; no results to report.",
              flush=True)
        return

    results_table = dev_efscore.table()
    if args.results_path:
        # The context manager makes sure the results file is closed.
        with open(args.results_path, 'w') as outf:
            print(results_table, file=outf)
    else:
        print(results_table, flush=True)