Example #1
0
def train(FLAGS):
    """Train an NMT model with periodic checkpointing, loss and BLEU validation.

    Reads a YAML configuration, builds the vocabularies, datasets, model,
    criterion, optimizer and (optionally) a learning-rate scheduler, then runs
    the training loop for ``training_configs['max_epochs']`` epochs.

    FLAGS:
        saveto: str, directory where checkpoints and best models are written
        reload: store_true, resume from the saved best model if it exists
        config_path: str, path to the YAML configuration file
        pretrain_path: str, default=""
        model_name: str, prefix for every file written under ``saveto``
        log_path: str, directory handed to ``SummaryWriter``
        use_gpu: copied into ``GlobalNames.USE_GPU``
        debug: when truthy, annealing / saving / validation run on every update
    """
    # Local import: only needed to snapshot the best-loss weights below.
    import copy

    GlobalNames.USE_GPU = FLAGS.use_gpu

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML 6; safe_load also refuses to build
        # arbitrary Python objects from the config file.
        configs = yaml.safe_load(f)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    if "seed" in training_configs:
        # Set random seed
        GlobalNames.SEED = training_configs['seed']

    if 'buffer_size' not in training_configs:
        training_configs['buffer_size'] = 100 * training_configs['batch_size']

    saveto_collections = '%s.pkl' % os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_CHECKPOINIS_PREFIX)
    saveto_best_model = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)
    saveto_best_optim_params = os.path.join(FLAGS.saveto,
                                            FLAGS.model_name + GlobalNames.MY_BEST_OPTIMIZER_PARAMS_SUFFIX)

    def _ensure_save_dir():
        # Create the output directory (including parents) right before writing.
        if not os.path.exists(FLAGS.saveto):
            os.makedirs(FLAGS.saveto)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Build source and target vocabularies
    vocab_src = Vocab(dict_path=data_configs['dictionaries'][0], max_n_words=data_configs['n_words'][0])
    vocab_tgt = Vocab(dict_path=data_configs['dictionaries'][1], max_n_words=data_configs['n_words'][1])

    train_bitext_dataset = ZipDatasets(
        TextDataset(data_path=data_configs['train_data'][0],
                    vocab=vocab_src,
                    bpe_codes=data_configs['bpe_codes'][0],
                    max_len=data_configs['max_len'][0]
                    ),
        TextDataset(data_path=data_configs['train_data'][1],
                    vocab=vocab_tgt,
                    bpe_codes=data_configs['bpe_codes'][1],
                    max_len=data_configs['max_len'][1]
                    ),
        shuffle=training_configs['shuffle']
    )

    valid_bitext_dataset = ZipDatasets(
        TextDataset(data_path=data_configs['valid_data'][0],
                    vocab=vocab_src,
                    bpe_codes=data_configs['bpe_codes'][0]),
        TextDataset(data_path=data_configs['valid_data'][1],
                    vocab=vocab_tgt,
                    bpe_codes=data_configs['bpe_codes'][1])
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=training_configs['batch_size'],
                                     sort_buffer=training_configs['use_bucket'],
                                     buffer_size=training_configs['buffer_size'],
                                     sort_fn=lambda line: len(line[-1]))

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  sort_buffer=False)

    bleu_scorer = ExternalScriptBLEUScorer(reference_path=data_configs['bleu_valid_reference'],
                                           lang_pair=data_configs['lang_pair'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    model_collections = Collections()

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================================================================== #
    # Build Model & Sampler & Validation
    INFO('Building model...')
    timer.tic()

    model_cls = model_configs.get("model")
    if model_cls not in src.models.__all__:
        raise ValueError(
            "Invalid model class \'{}\' provided. Only {} are supported now.".format(
                model_cls, src.models.__all__))

    # NOTE(review): eval() on a config string; it is only reached after the
    # whitelist check against src.models.__all__ above.
    nmt_model = eval(model_cls)(n_src_vocab=vocab_src.max_n_words,
                                n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCritierion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params']
                      )

    # Initialize training indicators
    uidx = 0
    bad_count = 0

    # Whether Reloading model
    if FLAGS.reload is True and os.path.exists(saveto_best_model):
        timer.tic()
        INFO("Reloading model...")
        params = torch.load(saveto_best_model)
        nmt_model.load_state_dict(params)

        model_archives = Collections.unpickle(path=saveto_collections)
        model_collections.load(archives=model_archives)

        # Resume the update counter and early-stop patience from the archive.
        uidx = model_archives['uidx']
        bad_count = model_archives['bad_count']

        INFO("Done. Model reloaded.")

        if os.path.exists(saveto_best_optim_params):
            INFO("Reloading optimizer params...")
            optimizer_params = torch.load(saveto_best_optim_params)
            optim.optim.load_state_dict(optimizer_params)

            INFO("Done. Optimizer params reloaded.")
        elif uidx > 0:
            INFO("Failed to reload optimizer params: {} does not exist".format(
                saveto_best_optim_params))

        INFO('Done. Elapsed time {0}'.format(timer.toc()))
    # New training. Check if pretraining needed
    else:
        # pretrain
        load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None)

    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    if training_configs['decay_method'] == "loss":

        scheduler = LossScheduler(optimizer=optim,
                                  max_patience=training_configs['lrate_decay_patience'],
                                  min_lr=training_configs['min_lrate'],
                                  decay_scale=0.5,
                                  warmup_steps=training_configs['decay_warmup_steps']
                                  )

    elif training_configs['decay_method'] == "noam":
        optim.init_lr = optimizer_configs['learning_rate'] * model_configs['d_model'] ** (-0.5)

        scheduler = NoamScheduler(optimizer=optim,
                                  warmup_steps=training_configs['decay_warmup_steps'])
    else:
        # No learning-rate schedule requested.
        scheduler = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Prepare training

    params_best_loss = None

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    valid_loss = 1.0 * 1e12  # Placeholder until the first loss validation runs
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    for eidx in range(training_configs['max_epochs']):
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                     total=len(training_iterator),
                                     unit="sents"
                                     )
        for batch in training_iter:

            uidx += 1

            # ================================================================================== #
            # Learning rate annealing
            if np.mod(uidx, training_configs['decay_freq']) == 0 or FLAGS.debug:

                # scheduler is None when decay_method is neither "loss" nor "noam".
                if scheduler is not None and scheduler.step(global_step=uidx, loss=valid_loss):

                    # Roll back to the best-loss weights, once a snapshot exists.
                    if training_configs['decay_method'] == "loss" and params_best_loss is not None:
                        nmt_model.load_state_dict(params_best_loss)

                new_lr = list(optim.get_lrate())[0]
                summary_writer.add_scalar("lrate", new_lr, global_step=uidx)

            seqs_x, seqs_y = batch

            batch_size_t = len(seqs_x)
            cum_samples += batch_size_t
            cum_words += sum(len(s) for s in seqs_y)

            training_progress_bar.update(batch_size_t)

            # Prepare data
            x, y = prepare_data(seqs_x, seqs_y, cuda=GlobalNames.USE_GPU)

            # optim.zero_grad()
            nmt_model.zero_grad()
            loss = compute_forward(model=nmt_model,
                                   critic=critic,
                                   seqs_x=x,
                                   seqs_y=y,
                                   eval=False,
                                   normalization=batch_size_t,
                                   shard_size=training_configs['shard_size'])
            optim.step()

            # ================================================================================== #
            # Display some information
            if np.mod(uidx, training_configs['disp_freq']) == 0:
                # Use the dedicated speed timer so the first report is not
                # skewed by the model-building time held in ``timer``.
                elapsed = timer_for_speed.toc(return_seconds=True)
                words_per_sec = cum_words / elapsed
                sents_per_sec = cum_samples / elapsed
                summary_writer.add_scalar("Speed (words/sec)", scalar_value=words_per_sec, global_step=uidx)
                # The tag is kept unchanged for existing logs; the value is sentences per second.
                summary_writer.add_scalar("Speed (words/sen)", scalar_value=sents_per_sec, global_step=uidx)

                # Reset timer
                timer_for_speed.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if np.mod(uidx, training_configs['save_freq']) == 0 or FLAGS.debug:

                _ensure_save_dir()

                INFO('Saving the model at iteration {}...'.format(uidx))

                saveto_uidx = os.path.join(FLAGS.saveto, FLAGS.model_name + '.iter%d.tpz' % uidx)
                torch.save(nmt_model.state_dict(), saveto_uidx)

                Collections.pickle(path=saveto_collections,
                                   uidx=uidx,
                                   bad_count=bad_count,
                                   **model_collections.export())

                saving_files.append(saveto_uidx)

                INFO('Done')

                # Once more than five checkpoints pile up, keep only the newest.
                if len(saving_files) > 5:
                    for stale_path in saving_files[:-1]:
                        os.remove(stale_path)

                    saving_files = [saving_files[-1]]

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if np.mod(uidx, training_configs['loss_valid_freq']) == 0 or FLAGS.debug:

                valid_loss, valid_n_correct = loss_validation(model=nmt_model,
                                                              critic=critic,
                                                              valid_iterator=valid_iterator,
                                                              )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)
                summary_writer.add_scalar("n_correct", valid_n_correct, global_step=uidx)

                # Snapshot the weights if none is stored yet or this is a new
                # best loss. state_dict() returns references to the live
                # tensors, so a deep copy is required for the snapshot to stay
                # fixed while training continues.
                if params_best_loss is None or valid_loss <= min_history_loss:
                    params_best_loss = copy.deepcopy(nmt_model.state_dict())

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if (np.mod(uidx, training_configs['bleu_valid_freq']) == 0 and uidx > training_configs['bleu_valid_warmup']) \
                    or FLAGS.debug:

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs['bleu_valid_batch_size'],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             eval_at_char_level=data_configs['eval_at_char_level'],
                                             vocab_tgt=vocab_tgt
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                summary_writer.add_scalar("bleu", valid_bleu, uidx)
                summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        INFO('Saving best model...')

                        # BLEU validation can fire before the first checkpoint,
                        # so make sure the directory exists here too.
                        _ensure_save_dir()

                        # save model
                        best_params = nmt_model.state_dict()
                        torch.save(best_params, saveto_best_model)

                        # save optim params
                        INFO('Saving best optimizer params...')
                        best_optim_params = optim.optim.state_dict()
                        torch.save(best_optim_params, saveto_best_optim_params)

                        INFO('Done.')

                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    # NOTE(review): this only sets the flag (which disables
                    # best-model saving); the loops keep running. Confirm
                    # whether training is meant to terminate here.
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                # Log the optimizer's current learning rate rather than the
                # initial ``lrate``, which is never updated after construction.
                current_lrate = list(optim.get_lrate())[0]

                with open("./valid.txt", 'a') as f:
                    f.write("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}\n".format(uidx, valid_loss,
                                                                                                   valid_bleu,
                                                                                                   current_lrate,
                                                                                                   bad_count))

        training_progress_bar.close()
Example #2
0
def train(FLAGS):
    """Train the Theano NMT model with checkpointing, loss and BLEU validation.

    Reads a JSON configuration, builds the vocabularies, datasets, model and
    the compiled training / evaluation / decoding functions, then runs the
    training loop for ``training_configs['max_epochs']`` epochs.

    FLAGS:
        saveto: str, directory where checkpoints and the best model are written
        reload: store_true, resume from the saved best model if it exists
        config_path: str, path to the JSON configuration file
        pretrain_path: str, default=""
        model_name: str, prefix for saved files and for the model parameters
        log_path: str, directory handed to ``SummaryWriter``
        debug: when truthy, logging / saving / validation run on every update
    """
    config_path = os.path.abspath(FLAGS.config_path)

    with open(config_path.strip()) as f:
        configs = json.load(f)

    data_configs = configs['data_configs']
    # data_configs = set_default_configs(data_configs, default_configs['data_configs'])

    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']

    training_configs = configs['training_configs']
    # training_configs = set_default_configs(training_configs, default_configs['training_configs'])

    saveto_collections = '%s.pkl' % os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_CHECKPOINIS_PREFIX)
    saveto_best_model = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    def _ensure_save_dir():
        # Create the output directory (including parents) right before writing.
        if not os.path.exists(FLAGS.saveto):
            os.makedirs(FLAGS.saveto)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Build source and target vocabularies

    vocab_src = Vocab(dict_path=data_configs['dictionaries'][0],
                      max_n_words=data_configs['n_words_src'])
    vocab_tgt = Vocab(dict_path=data_configs['dictionaries'][1],
                      max_n_words=data_configs['n_words_tgt'])

    train_bitext_dataset = ZipDatasets(
        TextDataset(data_path=data_configs['train_data'][0],
                    vocab=vocab_src,
                    bpe_codes=data_configs['src_bpe_codes'],
                    max_len=data_configs['max_len'][0]),
        TextDataset(data_path=data_configs['train_data'][1],
                    vocab=vocab_tgt,
                    bpe_codes=data_configs['tgt_bpe_codes'],
                    max_len=data_configs['max_len'][1]),
        shuffle=training_configs['shuffle'])

    valid_bitext_dataset = ZipDatasets(
        TextDataset(data_path=data_configs['valid_data'][0],
                    vocab=vocab_src,
                    bpe_codes=data_configs['src_bpe_codes']),
        TextDataset(data_path=data_configs['valid_data'][1],
                    vocab=vocab_tgt,
                    bpe_codes=data_configs['tgt_bpe_codes']))

    train_reader = DataIterator(dataset=train_bitext_dataset,
                                batch_size=training_configs['batch_size'],
                                sort_buffer=True,
                                sort_fn=lambda line: len(line[1]))

    valid_reader = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        sort_buffer=False)

    # bleu_scorer = BLEUScorer(reference_path=['{0}{1}'.format(data_configs['valid_data'][1], i) for i in range(data_configs['n_refs'])],
    #                          use_char=training_configs["eval_at_char_level"])

    bleu_scorer = ExternalScriptBLEUScorer(
        reference_path=data_configs['bleu_valid_reference'],
        lang_pair=data_configs['lang_pair'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Reload theano shared variables

    parameters = Parameters()
    model_collections = Collections()

    # Pre-trained weights are loaded only for a fresh run (no reload requested,
    # or no output directory yet).
    if len(FLAGS.pretrain_path) > 0 and (FLAGS.reload is False
                                         or not os.path.exists(FLAGS.saveto)):
        if os.path.exists(FLAGS.pretrain_path):
            INFO('Reloading model parameters...')
            timer.tic()
            params_pretrain = np.load(FLAGS.pretrain_path)
            parameters.load(params=params_pretrain)
            INFO('Done. Elapsed time {0}'.format(timer.toc()))
        else:
            WARN("Pre-trained model not found at {0}!".format(
                FLAGS.pretrain_path))

    if FLAGS.reload is True and os.path.exists(saveto_best_model):
        INFO('Reloading model...')
        timer.tic()
        params = np.load(saveto_best_model)
        parameters.load(params)

        model_archives = Collections.unpickle(path=saveto_collections)
        model_collections.load(archives=model_archives)

        # Resume the update counter and early-stop patience from the archive.
        uidx = model_archives['uidx']
        bad_count = model_archives['bad_count']

        INFO('Done. Elapsed time {0}'.format(timer.toc()))

    else:
        uidx = 0
        bad_count = 0

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================================================================== #
    # Build Model & Sampler & Validation
    INFO('Building model...')
    timer.tic()

    nmt_model = NMTModel(parameters=parameters,
                         prefix=FLAGS.model_name,
                         n_words_src=vocab_src.max_n_words,
                         n_words_tgt=vocab_tgt.max_n_words,
                         **model_configs)

    INFO('Building training&loss_eval function...')
    f_cost, f_update, f_loss_eval = \
        build_training_loss_eval(parameters=parameters,
                                 model=nmt_model,
                                 optimizer_configs=optimizer_configs
                                 )

    INFO('Building decoding function...')

    f_decoding = build_decoding_func(parameters, nmt_model)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Prepare training

    params_best_loss = None

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    # if lrate_decay_patience=20, loss_valid_freq=100, lrate_scheduler start at 20 * 100 = 2000 steps
    lrate_scheduler = LearningRateDecay(
        max_patience=training_configs['lrate_decay_patience'],
        start_steps=training_configs['loss_valid_freq'] *
        training_configs['lrate_decay_patience'])

    cum_samples = 0
    cum_words = 0
    # Placeholder so the valid.txt line below can be written even when BLEU
    # validation runs before the first loss validation.
    valid_loss = 1.0 * 1e12
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    for eidx in range(training_configs['max_epochs']):
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        train_iter = train_reader.build_generator()

        for seqs_x, seqs_y in train_iter:

            b_samples = len(seqs_y)
            b_words = sum(len(s) for s in seqs_y)

            uidx += 1

            x, x_mask, y, y_mask = prepare_data(seqs_x, seqs_y)

            cum_samples += b_samples
            cum_words += b_words

            # Noise switch on for the training step; it is set to 0.0 before validation.
            GlobalNames.USE_NOISE.set_value(1.0)

            ud_start = time.time()

            loss_ = f_cost(x, x_mask, y, y_mask)
            f_update(lrate)

            ud = time.time() - ud_start

            # check for bad numbers, usually we remove non-finite elements
            # and continue training - but not done here
            # NOTE(review): training resumes after the pause with no recovery step.
            if np.isnan(loss_) or np.isinf(loss_):
                WARN('NaN detected')
                time.sleep(300)  # if NaN is detected, sleep 5min

            # ================================================================================== #
            # Verbose
            if np.mod(uidx, training_configs['disp_freq']) == 0 or FLAGS.debug:

                # Read the timer once so both rates share the same interval.
                elapsed = timer_for_speed.toc(return_seconds=True)

                INFO(
                    "Epoch: {0} Update: {1} Loss: {2:.2f} UD: {3:.2f} {4:.2f} words/sec {5:.2f} samples/sec"
                    .format(
                        eidx, uidx, float(loss_), ud,
                        cum_words / elapsed,
                        cum_samples / elapsed))
                cum_words = 0
                cum_samples = 0
                timer_for_speed.tic()

            # ================================================================================== #
            # Saving checkpoints

            if np.mod(uidx, training_configs['save_freq']) == 0 or FLAGS.debug:

                INFO('Saving the model at iteration {}...'.format(uidx))

                _ensure_save_dir()

                params_uidx = parameters.export()

                saveto_uidx = os.path.join(FLAGS.saveto,
                                           FLAGS.model_name + '.%d.npz' % uidx)

                np.savez(saveto_uidx, **params_uidx)

                Collections.pickle(path=saveto_collections,
                                   uidx=uidx,
                                   bad_count=bad_count,
                                   **model_collections.export())

                saving_files.append(saveto_uidx)

                INFO('Done')

                # ================================================================================== #
                # Remove models

                # Once more than five checkpoints pile up, keep only the newest.
                if len(saving_files) > 5:
                    for stale_path in saving_files[:-1]:
                        os.remove(stale_path)

                    saving_files = [saving_files[-1]]

            # ================================================================================== #
            # Loss Validation & Learning rate annealing

            if np.mod(uidx,
                      training_configs['loss_valid_freq']) == 0 or FLAGS.debug:

                GlobalNames.USE_NOISE.set_value(0.0)

                valid_loss = loss_validation(f_loss_eval=f_loss_eval,
                                             valid_data_reader=valid_reader,
                                             prepare_data=prepare_data)

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss",
                                          min_history_loss,
                                          global_step=uidx)

                # Snapshot the parameters if none is stored yet or this is a
                # new best validation loss.
                if params_best_loss is None or valid_loss <= min_history_loss:
                    params_best_loss = parameters.export()

                if training_configs['lrate_decay'] is True or FLAGS.debug:

                    new_lrate = lrate_scheduler.decay(uidx, valid_loss, lrate)

                    summary_writer.add_scalar("lrate_half_patience",
                                              lrate_scheduler._bad_counts,
                                              uidx)

                    # If learning rate decay happened,
                    # reload from the best loss model.
                    if new_lrate < lrate:
                        parameters.reload_value(params_best_loss, exclude=None)

                    lrate = new_lrate

                    summary_writer.add_scalar("lrate", lrate, uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if np.mod(uidx,
                      training_configs['bleu_valid_freq']) == 0 or FLAGS.debug:

                GlobalNames.USE_NOISE.set_value(0.0)

                valid_bleu = bleu_validation(
                    func_beam_search=f_decoding,
                    valid_data_reader=valid_reader,
                    prepare_data=prepare_data,
                    bleu_scorer=bleu_scorer,
                    vocab_tgt=vocab_tgt,
                    uidx=uidx,
                    debug=FLAGS.debug,
                    tgt_bpe=data_configs['tgt_bpe_codes'],
                    batch_size=training_configs['bleu_valid_batch_size'],
                    eval_at_char_level=training_configs["eval_at_char_level"])

                model_collections.add_to_collection(key="history_bleus",
                                                    value=valid_bleu)

                best_valid_bleu = float(
                    np.array(model_collections.get_collection(
                        "history_bleus")).max())

                summary_writer.add_scalar("bleu", valid_bleu, uidx)
                summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        INFO('Saving best model...')

                        # BLEU validation can fire before the first checkpoint,
                        # so make sure the directory exists here too.
                        _ensure_save_dir()

                        best_params = parameters.export()
                        np.savez(saveto_best_model, **best_params)

                        INFO('Done.')

                else:
                    bad_count += 1

                    # NOTE(review): this only sets the flag (which disables
                    # best-model saving); the loops keep running. Confirm
                    # whether training is meant to terminate here.
                    if bad_count >= training_configs['early_stop_patience']:
                        is_early_stop = True
                        WARN("Early Stop!")

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                with open("./valid.txt", 'a') as f:
                    f.write(
                        "{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}\n"
                        .format(uidx, valid_loss, valid_bleu, lrate,
                                bad_count))