Exemplo n.º 1
0
def train(args):
    """Train a parser with maximum likelihood estimation.

    Builds the grammar, transition system, datasets, vocabulary, model and
    optimizer from ``args`` and then loops over epochs.  After every epoch
    the model is optionally validated on the dev set; whenever it improves,
    the model is saved to ``args.save_to + '.bin'`` and the optimizer state
    to ``args.save_to + '.optim.bin'``.  When patience runs out, the best
    checkpoint is reloaded and the learning rate is multiplied by
    ``args.lr_decay``.

    Args:
        args: configuration namespace.  Fields read here: ``asdl_file``,
            ``lang``, ``train_file``, ``dev_file``, ``vocab``, ``cuda``,
            ``optimizer``, ``lr``, ``uniform_init``, ``glorot_init``,
            ``glove_embed_path``, ``batch_size``, ``decode_max_time_step``,
            ``sup_attention``, ``clip_grad``, ``log_every``,
            ``save_all_models``, ``save_to``, ``valid_every_epoch``,
            ``eval_top_pred_only``, ``lr_decay_after_epoch``, ``lr_decay``,
            ``patience``, ``max_epoch``, ``max_num_trial`` and
            ``reset_optimizer``.

    This function does not return normally: it ends the process with
    ``sys.exit(0)`` once ``args.max_epoch`` is reached or after
    ``args.max_num_trial`` patience-triggered restarts.
    """
    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)
    train_set = Dataset.from_bin_file(args.train_file)

    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    if args.lang == 'wikisql':
        # import additional packages for wikisql dataset
        # (names are unused here; presumably imported for side effects --
        # TODO confirm)
        from model.wikisql.dataset import WikiSqlExample, WikiSqlTable, TableColumn

    parser_cls = get_parser_class(args.lang)
    model = parser_cls(args, vocab, transition_system)
    model.train()
    if args.cuda:
        model.cuda()

    # Look the optimizer class up by name instead of eval()-ing a string.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' %
              (args.uniform_init, args.uniform_init),
              file=sys.stderr)
        nn_utils.uniform_init(-args.uniform_init, args.uniform_init,
                              model.parameters())
    elif args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    # load pre-trained word embedding (optional)
    if args.glove_embed_path:
        print('load glove embedding from: %s' % args.glove_embed_path,
              file=sys.stderr)
        glove_embedding = GloveHelper(args.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # drop examples whose action sequence exceeds the decoding limit
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            # model.score returns scores to maximize; negate to get a loss
            loss = -ret_val[0]

            # .item() replaces the pre-0.4 ``.data[0]`` idiom, which raises
            # on 0-dim tensors in current PyTorch.
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (
                    train_iter, report_loss / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        # Reset every epoch so that an epoch without validation neither hits
        # an unbound name nor reuses the verdict of an earlier epoch.
        is_better = False
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples,
                    model,
                    args,
                    verbose=True,
                    eval_top_pred_only=args.eval_top_pred_only)
                dev_acc = eval_results['accuracy']
                print('[Epoch %d] code generation accuracy=%.5f took %ds' %
                      (epoch, dev_acc, time.time() - eval_start),
                      file=sys.stderr)
                is_better = (not history_dev_scores
                             or dev_acc > max(history_dev_scores))
                history_dev_scores.append(dev_acc)
        else:
            # without a dev set every epoch counts as an improvement
            is_better = True

            if epoch > args.lr_decay_after_epoch:
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('decay learning rate to %f' % lr, file=sys.stderr)

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model (mapped to CPU storage first, then moved back)
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # NOTE(review): always resets to Adam, regardless of
                # args.optimizer -- confirm this is intended.
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Exemplo n.º 2
0
def train(cfg: argparse.Namespace):
    """Train a parser and checkpoint the best model.

    Runs ``prologue(cfg)``, builds datasets, vocabulary, grammar, transition
    system, model, evaluator and optimizer from ``cfg``, then loops over
    epochs.  After every epoch the model is optionally validated on the dev
    set; whenever ``evaluator.default_metric`` improves, the model is saved
    to ``cfg.save_to + '.bin'`` and the optimizer state to
    ``cfg.save_to + '.optim.bin'``.  When patience runs out, the best
    checkpoint is reloaded and the learning rate is multiplied by
    ``cfg.lr_decay``.

    Args:
        cfg: configuration namespace.  Fields read here: ``train_file``,
            ``dev_file``, ``vocab``, ``asdl_file``, ``transition_system``,
            ``parser``, ``evaluator``, ``cuda``, ``optimizer``, ``lr``,
            ``uniform_init``, ``xavier_init``, ``glove_embed_path``,
            ``batch_size``, ``decode_max_time_step``, ``sup_attention``,
            ``clip_grad``, ``log_every``, ``save_all_models``, ``save_to``,
            ``valid_every_epoch``, ``eval_top_pred_only``,
            ``decay_lr_every_epoch``, ``lr_decay_after_epoch``,
            ``lr_decay``, ``patience``, ``max_epoch``, ``max_num_trial``
            and ``reset_optimizer``.

    This function does not return normally: once ``cfg.max_epoch`` is
    reached or ``cfg.max_num_trial`` patience-triggered restarts have
    happened, it runs ``epilogue(cfg, summary_writer)`` and ends the process
    with ``sys.exit(0)``.
    """
    cli_logger.info("=== Training ===")

    # initial setup
    summary_writer = prologue(cfg)

    # load train/dev set
    train_set = Dataset.from_bin_file(cfg.train_file)

    if cfg.dev_file:
        dev_set = Dataset.from_bin_file(cfg.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(cfg.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    with open(cfg.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(cfg.transition_system)(grammar)

    parser_cls = Registrable.by_name(cfg.parser)
    model = parser_cls(cfg, vocab, transition_system)
    model.train()

    evaluator = Registrable.by_name(cfg.evaluator)(transition_system, args=cfg)
    if cfg.cuda:
        model.cuda()

    # Look the optimizer class up by name instead of eval()-ing a string.
    optimizer_cls = getattr(torch.optim, cfg.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=cfg.lr)

    if cfg.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' %
              (cfg.uniform_init, cfg.uniform_init),
              file=sys.stderr)
        nn_utils.uniform_init(-cfg.uniform_init, cfg.uniform_init,
                              model.parameters())
    elif cfg.xavier_init:
        print('use xavier initialization', file=sys.stderr)
        nn_utils.xavier_init(model.parameters())

    # load pre-trained word embedding (optional)
    if cfg.glove_embed_path:
        print('load glove embedding from: %s' % cfg.glove_embed_path,
              file=sys.stderr)
        glove_embedding = GloveHelper(cfg.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=cfg.batch_size,
                                                   shuffle=True):
            # drop examples whose action sequence exceeds the decoding limit
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= cfg.decode_max_time_step
            ]
            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            # model.score returns scores to maximize; negate to get a loss
            loss = -ret_val[0]

            # .item() replaces the pre-0.4 ``.data[0]`` idiom, which raises
            # on 0-dim tensors in current PyTorch.
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if cfg.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient
            if cfg.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               cfg.clip_grad)

            optimizer.step()

            if train_iter % cfg.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (
                    train_iter, report_loss / report_examples)
                if cfg.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        if cfg.save_all_models:
            model_file = cfg.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        # Reset every epoch so that an epoch without validation neither hits
        # an unbound name nor reuses the verdict of an earlier epoch.
        is_better = False
        if cfg.dev_file:
            if epoch % cfg.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples,
                    model,
                    evaluator,
                    cfg,
                    verbose=True,
                    eval_top_pred_only=cfg.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print(
                    '[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)'
                    % (epoch, eval_results, evaluator.default_metric,
                       dev_score, time.time() - eval_start),
                    file=sys.stderr)

                is_better = (not history_dev_scores
                             or dev_score > max(history_dev_scores))
                history_dev_scores.append(dev_score)
        else:
            # without a dev set every epoch counts as an improvement
            is_better = True

        if cfg.decay_lr_every_epoch and epoch > cfg.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * cfg.lr_decay
            print('decay learning rate to %f' % lr, file=sys.stderr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = cfg.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), cfg.save_to + '.optim.bin')
        elif patience < cfg.patience and epoch >= cfg.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == cfg.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            # leave the loop so the epilogue below runs before exiting
            break

        if patience >= cfg.patience and epoch >= cfg.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == cfg.max_num_trial:
                print('early stop!', file=sys.stderr)
                # leave the loop so the epilogue below runs before exiting
                break

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * cfg.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model (mapped to CPU storage first, then moved back)
            params = torch.load(cfg.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if cfg.cuda:
                model = model.cuda()

            # load optimizers
            if cfg.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # NOTE(review): always resets to Adam, regardless of
                # cfg.optimizer -- confirm this is intended.
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(cfg.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0

    # final setup
    epilogue(cfg, summary_writer)
    # keep the process-terminating behaviour of the training entry point
    sys.exit(0)
Exemplo n.º 3
0
def train(args):
    """Train a parser with maximum likelihood estimation.

    Builds the grammar, transition system, datasets, vocabulary, model and
    optimizer from ``args`` and then loops over epochs.  After every epoch
    the model is optionally validated on the dev set; whenever it improves,
    the model is saved to ``args.save_to + '.bin'`` and the optimizer state
    to ``args.save_to + '.optim.bin'``.  When patience runs out, the best
    checkpoint is reloaded and the learning rate is multiplied by
    ``args.lr_decay``.

    Args:
        args: configuration namespace.  Fields read here: ``asdl_file``,
            ``lang``, ``train_file``, ``dev_file``, ``vocab``, ``cuda``,
            ``optimizer``, ``lr``, ``uniform_init``, ``glorot_init``,
            ``glove_embed_path``, ``batch_size``, ``decode_max_time_step``,
            ``sup_attention``, ``clip_grad``, ``log_every``,
            ``save_all_models``, ``save_to``, ``valid_every_epoch``,
            ``eval_top_pred_only``, ``lr_decay_after_epoch``, ``lr_decay``,
            ``patience``, ``max_epoch``, ``max_num_trial`` and
            ``reset_optimizer``.

    This function does not return normally: it ends the process with
    ``sys.exit(0)`` once ``args.max_epoch`` is reached or after
    ``args.max_num_trial`` patience-triggered restarts.
    """
    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)
    train_set = Dataset.from_bin_file(args.train_file)

    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    if args.lang == 'wikisql':
        # import additional packages for wikisql dataset
        # (names are unused here; presumably imported for side effects --
        # TODO confirm)
        from model.wikisql.dataset import WikiSqlExample, WikiSqlTable, TableColumn

    parser_cls = get_parser_class(args.lang)
    model = parser_cls(args, vocab, transition_system)
    model.train()
    if args.cuda:
        model.cuda()

    # Look the optimizer class up by name instead of eval()-ing a string.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' % (args.uniform_init, args.uniform_init), file=sys.stderr)
        nn_utils.uniform_init(-args.uniform_init, args.uniform_init, model.parameters())
    elif args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    # load pre-trained word embedding (optional)
    if args.glove_embed_path:
        print('load glove embedding from: %s' % args.glove_embed_path, file=sys.stderr)
        glove_embedding = GloveHelper(args.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            # drop examples whose action sequence exceeds the decoding limit
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            # model.score returns scores to maximize; negate to get a loss
            loss = -ret_val[0]

            # .item() replaces the pre-0.4 ``.data[0]`` idiom, which raises
            # on 0-dim tensors in current PyTorch.
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        # Reset every epoch so that an epoch without validation neither hits
        # an unbound name nor reuses the verdict of an earlier epoch.
        is_better = False
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(dev_set.examples, model, args,
                                                   verbose=True, eval_top_pred_only=args.eval_top_pred_only)
                dev_acc = eval_results['accuracy']
                print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start), file=sys.stderr)
                is_better = not history_dev_scores or dev_acc > max(history_dev_scores)
                history_dev_scores.append(dev_acc)
        else:
            # without a dev set every epoch counts as an improvement
            is_better = True

            if epoch > args.lr_decay_after_epoch:
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('decay learning rate to %f' % lr, file=sys.stderr)

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model (mapped to CPU storage first, then moved back)
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # NOTE(review): always resets to Adam, regardless of
                # args.optimizer -- confirm this is intended.
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Exemplo n.º 4
0
def train(args):
    """Train (or fine-tune) a parser with maximum likelihood estimation.

    Builds datasets, vocabulary, grammar, transition system, model,
    evaluator and optimizer from ``args`` and then loops over epochs.  The
    model is loaded from ``args.pretrain`` when that is set, otherwise it is
    constructed and initialized from scratch.  After every epoch the model
    is optionally validated on the dev set; whenever
    ``evaluator.default_metric`` improves, the model is saved to
    ``args.save_to + '.bin'`` and the optimizer state to
    ``args.save_to + '.optim.bin'``.  When patience runs out, training
    either stops (when the warmup/annealing schedule is active) or reloads
    the best checkpoint with the learning rate multiplied by
    ``args.lr_decay``.

    Args:
        args: configuration namespace.  Fields read here: ``train_file``,
            ``dev_file``, ``vocab``, ``asdl_file``, ``transition_system``,
            ``parser``, ``pretrain``, ``evaluator``, ``cuda``,
            ``optimizer``, ``finetune_bert``, ``lr``, ``uniform_init``,
            ``glorot_init``, ``glove_embed_path``, ``warmup_step``,
            ``annealing_step``, ``batch_size``, ``decode_max_time_step``,
            ``sup_attention``, ``clip_grad``, ``log_every``,
            ``save_all_models``, ``save_to``, ``valid_every_epoch``,
            ``eval_top_pred_only``, ``decay_lr_every_epoch``,
            ``lr_decay_after_epoch``, ``lr_decay``, ``patience``,
            ``max_epoch``, ``max_num_trial`` and ``reset_optimizer``.

    This function does not return normally: it ends the process with
    ``sys.exit(0)`` once a stopping criterion is met.
    """
    # load in train/dev set
    train_set = Dataset.from_bin_file(args.train_file)

    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    if args.pretrain:
        print('Finetune with: ', args.pretrain, file=sys.stderr)
        model = parser_cls.load(model_path=args.pretrain, cuda=args.cuda)
    else:
        model = parser_cls(args, vocab, transition_system)

    model.train()
    evaluator = Registrable.by_name(args.evaluator)(transition_system,
                                                    args=args)
    if args.cuda:
        model.cuda()

    # Split parameters by name: anything under 'automodel' is treated as
    # belonging to the BERT encoder, everything else as the parser proper.
    trainable_parameters = [
        p for n, p in model.named_parameters()
        if 'automodel' not in n and p.requires_grad
    ]
    bert_parameters = [
        p for n, p in model.named_parameters()
        if 'automodel' in n and p.requires_grad
    ]

    # BERT parameters are optimized only when fine-tuning is requested.
    # The same list is used for optimizer construction, gradient clipping
    # and optimizer reset so the three always agree.
    if args.finetune_bert:
        optimized_parameters = trainable_parameters + bert_parameters
    else:
        optimized_parameters = trainable_parameters

    # Look the optimizer class up by name instead of eval()-ing a string.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(optimized_parameters, lr=args.lr)

    if not args.pretrain:
        if args.uniform_init:
            print('uniformly initialize parameters [-%f, +%f]' %
                  (args.uniform_init, args.uniform_init),
                  file=sys.stderr)
            nn_utils.uniform_init(-args.uniform_init, args.uniform_init,
                                  trainable_parameters)
        elif args.glorot_init:
            print('use glorot initialization', file=sys.stderr)
            nn_utils.glorot_init(trainable_parameters)

        # load pre-trained word embedding (optional)
        if args.glove_embed_path:
            print('load glove embedding from: %s' % args.glove_embed_path,
                  file=sys.stderr)
            glove_embedding = GloveHelper(args.glove_embed_path)
            glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0

    # The warmup/annealing schedule is active only for a valid step range.
    use_lr_schedule = (args.warmup_step > 0
                       and args.annealing_step > args.warmup_step)
    if use_lr_schedule:
        lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                       args.warmup_step,
                                                       args.annealing_step)

    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # drop examples whose action sequence exceeds the decoding limit
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]
            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            # model.score returns scores to maximize; negate to get a loss
            loss = -ret_val[0]

            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    # .item(), as for loss_val above; ``.data[0]`` raises on
                    # 0-dim tensors in current PyTorch.
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(optimized_parameters,
                                               args.clip_grad)

            optimizer.step()

            # warmup and annealing
            if use_lr_schedule:
                lr_scheduler.step()

            if train_iter % args.log_every == 0:
                lr = optimizer.param_groups[0]['lr']
                log_str = '[Iter %d] encoder loss=%.5f, lr=%.6f' % (
                    train_iter, report_loss / report_examples, lr)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        # Defaults to False so epochs without validation count against
        # patience instead of reusing an earlier verdict.
        is_better = False
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples,
                    model,
                    evaluator,
                    args,
                    verbose=False,
                    eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print(
                    '[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)'
                    % (epoch, eval_results, evaluator.default_metric,
                       dev_score, time.time() - eval_start),
                    file=sys.stderr)

                is_better = (not history_dev_scores
                             or dev_score > max(history_dev_scores))
                history_dev_scores.append(dev_score)
        else:
            # without a dev set every epoch counts as an improvement
            is_better = True

        if args.decay_lr_every_epoch and epoch > args.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('decay learning rate to %f' % lr, file=sys.stderr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            # With the warmup/annealing schedule active, running out of
            # patience always stops training (no checkpoint restart).
            if num_trial == args.max_num_trial or use_lr_schedule:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model (mapped to CPU storage first, then moved back)
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # NOTE(review): always resets to Adam, regardless of
                # args.optimizer -- confirm this is intended.
                optimizer = torch.optim.Adam(optimized_parameters, lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Exemplo n.º 5
0
def train(args):
    """Train a parser with maximum likelihood estimation.

    Loads the train/dev sets, the pickled vocabulary and the ASDL grammar
    named in ``args``, builds the parser (or reloads it from
    ``args.pretrain``), and then runs epochs of mini-batch training with
    optional validation, learning-rate decay and patience-based restarts from
    the best checkpoint.

    The epoch loop has no ``break``: the function only ends by terminating the
    process with status 0, either when ``epoch == args.max_epoch`` or when
    ``args.max_num_trial`` patience trials have been used up.

    Args:
        args: option namespace. Attributes read here: ``train_file``,
            ``dev_file``, ``vocab``, ``asdl_file``, ``transition_system``,
            ``parser``, ``evaluator``, ``pretrain``, ``cuda``, ``optimizer``,
            ``lr``, ``uniform_init``, ``glorot_init``, ``glove_embed_path``,
            ``batch_size``, ``decode_max_time_step``, ``sup_attention``,
            ``clip_grad``, ``log_every``, ``save_all_models``, ``save_to``,
            ``valid_every_epoch``, ``eval_top_pred_only``,
            ``decay_lr_every_epoch``, ``lr_decay_after_epoch``, ``lr_decay``,
            ``patience``, ``max_epoch``, ``max_num_trial`` and
            ``reset_optimizer``.

    Side effects:
        Writes ``<save_to>.bin`` and ``<save_to>.optim.bin`` whenever the
        model improves, plus ``<save_to>.iter<N>.bin`` after every epoch when
        ``args.save_all_models`` is set.
    """

    # load in train/dev set
    print("loading files")
    print(f"Loading Train at {args.train_file}")
    train_set = Dataset.from_bin_file(args.train_file)
    # A set makes the id sanity check below O(1) per lookup.
    train_ids = {e.idx.split('-')[-1] for e in train_set.examples}
    print(f"{len(train_set.examples)} total examples")
    print('Checking ids:')
    for idx in ['4170655', '13704860', '4170655', '13704860', '3862010']:
        print(f"\t{idx} is {'not in' if idx not in train_ids else 'in'} train")
    print("")
    # Preview at most 10 examples; the bound keeps small datasets from raising
    # IndexError and the header now reports the number actually shown.
    preview_count = min(10, len(train_set.examples))
    print(f"First {preview_count} Examples in Train:")
    for i in range(preview_count):
        print(f'\tExample {i + 1}(idx:{train_set.examples[i].idx}):')
        print(f"\t\tSource:{repr(' '.join(train_set.all_source[i])[:100])}")
        print(f"\t\tTarget:{repr(train_set.all_targets[i][:100])}")
    if args.dev_file:
        print(f"Loading dev at {args.dev_file}")
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])
    print("Loading vocab")

    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    print(f"Loading grammar {args.asdl_file}")
    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    if args.pretrain:
        print('Finetune with: ', args.pretrain)
        model = parser_cls.load(model_path=args.pretrain, cuda=args.cuda)
    else:
        model = parser_cls(args, vocab, transition_system)

    model.train()
    evaluator = Registrable.by_name(args.evaluator)(transition_system,
                                                    args=args)
    if args.cuda:
        model.cuda()

    # Attribute lookup replaces the former eval() on a command-line string:
    # same result for a plain optimizer name, without executing arbitrary code.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    # A model restored from args.pretrain keeps its weights and embeddings.
    if not args.pretrain:
        if args.uniform_init:
            print('uniformly initialize parameters [-%f, +%f]' %
                  (args.uniform_init, args.uniform_init))
            nn_utils.uniform_init(-args.uniform_init, args.uniform_init,
                                  model.parameters())
        elif args.glorot_init:
            print('use glorot initialization')
            nn_utils.glorot_init(model.parameters())

        # load pre-trained word embedding (optional)
        if args.glove_embed_path:
            print('load glove embedding from: %s' % args.glove_embed_path)
            glove_embedding = GloveHelper(args.glove_embed_path)
            glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab))

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # Drop examples whose action sequence exceeds the decoding limit.
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]
            if not batch_examples:
                # Every example was filtered out; there is nothing to score.
                continue
            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            # ret_val[0] is negated to obtain the loss that is minimized.
            loss = -ret_val[0]

            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    # .item() instead of .data[0]: indexing a 0-dim tensor
                    # raises on PyTorch >= 0.4.
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (
                    train_iter, report_loss / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin))

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file)
            model.save(model_file)

        # perform validation
        is_better = False
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples,
                    model,
                    evaluator,
                    args,
                    verbose=False,
                    eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print(
                    '[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)'
                    % (epoch, eval_results, evaluator.default_metric,
                       dev_score, time.time() - eval_start))

                is_better = (not history_dev_scores
                             or dev_score > max(history_dev_scores))
                history_dev_scores.append(dev_score)
        else:
            # Without a dev set every epoch counts as an improvement, so the
            # latest model is always the saved one.
            is_better = True

        if args.decay_lr_every_epoch and epoch > args.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('decay learning rate to %f' % lr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..')
            print('save model to [%s]' % model_file)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!')
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial)
            if num_trial == args.max_num_trial:
                print('early stop!')
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr)

            # load model (onto CPU first, then moved back to GPU if requested)
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer')
                # NOTE(review): the reset always uses Adam, regardless of
                # args.optimizer - confirm this is intended.
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers')
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr (the restored optimizer state carries the old one)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Exemplo n.º 6
0
def _fits_bert_length(header, src_sent, tokenizer, max_len=130):
    """Return True when an example's word-piece length is below ``max_len``.

    The length is the number of word pieces of all source tokens, plus the
    number of word pieces of all header tokens, plus ``len(header)``
    (presumably one separator slot per column - TODO confirm).
    """
    src_pieces = sum(len(tokenizer._tokenize(token)) for token in src_sent)
    header_pieces = sum(
        len(tokenizer._tokenize(token))
        for column in header for token in column.tokens)
    return header_pieces + len(header) + src_pieces < max_len


def train(args):
    """Train a BERT-augmented parser with maximum likelihood estimation.

    Loads a BERT tokenizer/model from fixed local paths, merges the tokenizer
    into the pickled source vocabulary, builds the parser and attaches the
    BERT model to it, then trains with two optimizers: one over the non-BERT
    parameters (``args.optimizer`` at ``args.lr``) and one over the BERT
    parameters (fixed learning rate 1e-5).

    Training ends when the hard epoch cap is hit (the function then returns
    ``None``), or by terminating the process with status 0 when
    ``epoch == args.max_epoch`` or ``args.max_num_trial`` patience trials have
    been used up.

    Args:
        args: option namespace. Attributes read here: ``train_file``,
            ``dev_file``, ``vocab``, ``asdl_file``, ``transition_system``,
            ``parser``, ``evaluator``, ``cuda``, ``optimizer``, ``lr``,
            ``uniform_init``, ``glorot_init``, ``glove_embed_path``,
            ``batch_size``, ``decode_max_time_step``, ``sup_attention``,
            ``clip_grad``, ``log_every``, ``save_all_models``, ``save_to``,
            ``valid_every_epoch``, ``eval_top_pred_only``,
            ``decay_lr_every_epoch``, ``lr_decay_after_epoch``, ``lr_decay``,
            ``patience``, ``max_epoch``, ``max_num_trial`` and
            ``reset_optimizer``.

    Side effects:
        Writes ``<save_to>.bin`` and ``<save_to>.optim.bin`` whenever the
        model counts as improved, plus ``<save_to>.iter<N>.bin`` after every
        epoch when ``args.save_all_models`` is set. Only the main optimizer's
        state is saved; the BERT optimizer's state is not.
    """
    # The loop breaks once ``epoch`` exceeds this value, i.e. 41 epochs run.
    epoch_cap = 40
    # Validation only starts from this epoch; earlier epochs always save.
    first_valid_epoch = 6
    bert_lr = 0.00001

    tokenizer = BertTokenizer.from_pretrained(
        './pretrained_models/bert-base-uncased-vocab.txt')
    bertmodels = BertModel.from_pretrained('./pretrained_models/base-uncased/')
    print(len(tokenizer.vocab))

    # Auxiliary vocabularies handed to the parser constructor.
    tmpvocab = {'null': 0}
    tmpvocab1 = {'null': 0, 'unk': 1}
    tmpvocab2 = {'null': 0, 'unk': 1}
    tmpvocab3 = {'null': 0, 'unk': 1}

    # load in train/dev set
    train_set = Dataset.from_bin_file(args.train_file)
    # NOTE(review): these names are not referenced by any live code below; the
    # import is kept only in case the module has import-time side effects.
    from dependency import sentencetoadj, sentencetoextra_message
    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)
    print(len(vocab.source))
    vocab.source.copyandmerge(tokenizer)
    print(len(vocab.source))

    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    model = parser_cls(args, vocab, transition_system,
                       tmpvocab, tmpvocab1, tmpvocab2, tmpvocab3)

    model.train()
    model.tokenizer = tokenizer
    evaluator = Registrable.by_name(args.evaluator)(transition_system,
                                                    args=args)
    if args.cuda:
        model.cuda()

    # Initialization runs before the BERT model is attached, so the
    # pre-trained BERT weights are not re-initialized.
    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' %
              (args.uniform_init, args.uniform_init), file=sys.stderr)
        nn_utils.uniform_init(-args.uniform_init, args.uniform_init,
                              model.parameters())
    elif args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    # load pre-trained word embedding (optional)
    if args.glove_embed_path:
        print('load glove embedding from: %s' % args.glove_embed_path,
              file=sys.stderr)
        glove_embedding = GloveHelper(args.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)
    print([name for name, _ in model.named_parameters()])
    model.bert_model = bertmodels
    # Repeat train()/cuda() so the newly attached BERT module is covered too.
    model.train()
    if args.cuda:
        model.cuda()

    # Attribute lookup replaces the former eval() on a command-line string:
    # same result for a plain optimizer name, without executing arbitrary code.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    # Split parameters so BERT gets its own optimizer and learning rate.
    parameters = [p for name, p in model.named_parameters()
                  if 'bert_model' not in name]
    parameters1 = [p for name, p in model.named_parameters()
                   if 'bert_model' in name]
    optimizer = optimizer_cls(parameters, lr=args.lr)
    optimizer1 = optimizer_cls(parameters1, lr=bert_lr)
    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    report_loss1 = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        if epoch > epoch_cap:
            break
        epoch += 1
        epoch_begin = time.time()
        model.train()
        for batch_examples in tqdm.tqdm(
                train_set.batch_iter(batch_size=args.batch_size,
                                     shuffle=True)):
            # Keep examples that respect both the decoding step limit and the
            # word-piece length budget. The filter used to print each example
            # and block on input(), which stalled training; that debugging
            # residue is gone.
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
                and _fits_bert_length(e.table.header, e.src_sent, tokenizer)
            ]
            if not batch_examples:
                # Every example was filtered out; there is nothing to score.
                continue
            train_iter += 1
            optimizer.zero_grad()
            optimizer1.zero_grad()

            ret_val, _ = model.score(batch_examples)
            # ret_val[0] is negated to obtain the loss that is minimized.
            loss = -ret_val[0]
            loss1 = ret_val[1]
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            # Accumulate a Python float: adding the tensor itself would keep
            # autograd history alive between log intervals.
            report_loss1 += torch.sum(ret_val[2]).item()
            report_examples += len(batch_examples)
            # The auxiliary terms are weighted by 0, so they do not change the
            # loss value.
            loss = torch.mean(loss) + 0 * loss1 + 0 * torch.mean(ret_val[2])

            if args.sup_attention:
                # NOTE(review): ret_val[1] is also used as the scalar-like
                # ``loss1`` above, yet here it is treated as a sequence of
                # attention probabilities - confirm which is correct.
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    sup_att_loss_val = sup_att_loss.data.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient (in-place variant; clip_grad_norm is deprecated)
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)

            optimizer.step()
            optimizer1.step()
            # Drop the reference so the batch's graph can be freed promptly.
            loss = None
            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f,coverage loss=%.5f' % (
                    train_iter, report_loss / report_examples,
                    report_loss1 / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.
                report_loss1 = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin), file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        # Reset every epoch: otherwise an epoch that skips validation would
        # reuse the previous verdict and could overwrite the best checkpoint.
        is_better = False
        if args.dev_file and epoch >= first_valid_epoch:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples, model, evaluator, args,
                    verbose=True,
                    eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (
                    epoch, eval_results,
                    evaluator.default_metric,
                    dev_score,
                    time.time() - eval_start), file=sys.stderr)
                is_better = (not history_dev_scores
                             or dev_score > max(history_dev_scores))
                history_dev_scores.append(dev_score)
                # Log line kept as-is; no second validation pass runs here.
                print('[Epoch %d] begin validation2' % epoch, file=sys.stderr)
        else:
            # No dev set, or still before the first validation epoch: always
            # treat the current model as the best one.
            is_better = True

        if args.decay_lr_every_epoch and epoch > args.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('decay learning rate to %f' % lr, file=sys.stderr)

            # set new lr (main optimizer only; the BERT lr stays fixed)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr,
                  file=sys.stderr)

            # load model (onto CPU first, then moved back to GPU if requested)
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # Rebuild over the non-BERT parameters only: optimizer1 still
                # owns the BERT parameters, and covering them here as well
                # would step them twice per iteration.
                optimizer = torch.optim.Adam(parameters, lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr (the restored optimizer state carries the old one)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0