Exemplo n.º 1
0
    def __init__(self, n_src_vocab, n_tgt_vocab, **config):
        """Attach a capsule decoder, optional word predictors and a multi-part criterion.

        Args:
            n_src_vocab: source vocabulary size, forwarded to the parent class.
            n_tgt_vocab: target vocabulary size, used as the decoder's ``n_words``.
            **config: model hyper-parameters. Keys read here: ``d_word_vec``,
                ``d_model``, ``capsule_type``, ``dropout``, ``bridge_type``,
                ``num_capsules``, ``label_smoothing`` and, when word prediction
                is enabled, ``d_contextual_capsule``. ``decoder_cell_type`` is
                optional and defaults to ``'cgru'``.
        """
        super().__init__(n_src_vocab, n_tgt_vocab, **config)

        # NOTE(review): this replaces whatever ``self.decoder`` the parent may
        # have created — confirm that is intended.
        cell_type = config.setdefault('decoder_cell_type', 'cgru')
        word_dim = config['d_word_vec']
        model_dim = config['d_model']
        self.decoder = Decoder(
            cell_type=cell_type,
            n_words=n_tgt_vocab,
            input_size=word_dim,
            hidden_size=model_dim,
            context_size=model_dim,
            capsule_type=config['capsule_type'],
            dropout_rate=config['dropout'],
            bridge_type=config['bridge_type'],
            num_capsules=config['num_capsules'],
        )

        # ``self.config`` and ``self.generator`` are presumably populated by the
        # parent constructor — TODO confirm.
        if self.config["apply_word_prediction_loss"]:
            # Both predictors take half of ``d_contextual_capsule`` as input size.
            half_dim = config['d_contextual_capsule'] // 2

            def make_predictor():
                return WordPredictor(generator=self.generator,
                                     input_size=half_dim,
                                     d_word_vec=word_dim)

            self.wp_past = make_predictor()
            self.wp_future = make_predictor()

        # The three loss terms are always registered with equal weight, even
        # when the word predictors above were not created.
        smoothing = config['label_smoothing']
        self.criterion = MultiCriterion(
            weights={'nmt_nll': 1., 'wploss_past': 1., 'wploss_future': 1.},
            nmt_nll=NMTCriterion(label_smoothing=smoothing),
            wploss_past=MultiTargetNMTCriterion(label_smoothing=smoothing),
            wploss_future=MultiTargetNMTCriterion(label_smoothing=smoothing),
        )
Exemplo n.º 2
0
    def __init__(self,
                 n_src_vocab,
                 n_tgt_vocab,
                 n_layers=6,
                 n_head=8,
                 d_word_vec=512,
                 d_model=512,
                 d_inner_hid=1024,
                 dim_per_head=None,
                 dropout=0.1,
                 proj_share_weight=True,
                 tie_embedding=True,
                 **kwargs):
        """Assemble encoder, decoder, output generator and (optionally) criterion.

        Args:
            n_src_vocab: source vocabulary size, given to ``Encoder``.
            n_tgt_vocab: target vocabulary size, given to ``Decoder`` and
                ``Generator``.
            n_layers, n_head, d_word_vec, d_model, d_inner_hid, dim_per_head,
            dropout: hyper-parameters passed identically to both stacks.
            proj_share_weight: when true, the generator receives the decoder
                embedding weight as ``shared_weight``.
            tie_embedding: when true, the encoder embedding weight is replaced
                by the decoder embedding weight (a single shared Parameter).
            **kwargs: must contain ``criterion``; if it equals ``"basic"``,
                ``label_smoothing`` is required as well.

        Raises:
            AssertionError: if ``d_model != d_word_vec``.
            KeyError: if ``criterion`` (or ``label_smoothing`` for the basic
                criterion) is missing from ``kwargs``.
        """
        super(Transformer, self).__init__()

        # Encoder and decoder differ only in their vocabulary size.
        stack_hparams = dict(n_layers=n_layers,
                             n_head=n_head,
                             d_word_vec=d_word_vec,
                             d_model=d_model,
                             d_inner_hid=d_inner_hid,
                             dropout=dropout,
                             dim_per_head=dim_per_head)
        self.encoder = Encoder(n_src_vocab, **stack_hparams)
        self.decoder = Decoder(n_tgt_vocab, **stack_hparams)

        self.dropout = nn.Dropout(dropout)

        assert d_model == d_word_vec, \
            'To facilitate the residual connections, \
             the dimensions of all module output shall be the same.'

        if tie_embedding:
            # NOTE(review): tying assumes source ids stay within the decoder
            # embedding's range (n_src_vocab <= n_tgt_vocab) — confirm.
            self.encoder.embeddings.embeddings.weight = \
                self.decoder.embeddings.embeddings.weight
            print('tie embedding')

        generator_args = dict(n_words=n_tgt_vocab,
                              hidden_size=d_word_vec,
                              padding_idx=PAD)
        if proj_share_weight:
            # Output projection reuses the decoder's embedding matrix.
            generator_args['shared_weight'] = \
                self.decoder.embeddings.embeddings.weight
        self.generator = Generator(**generator_args)

        # Only the "basic" criterion is built here; any other value leaves
        # ``self.criterion`` unset.
        if kwargs["criterion"] == "basic":
            self.criterion = NMTCriterion(
                label_smoothing=kwargs['label_smoothing'])
Exemplo n.º 3
0
def train(FLAGS):
    """Train the model produced by ``build_model`` on target-side text only.

    Only ``data_configs['vocabularies'][0]`` and the first train/valid files
    are read, and the model is built from ``n_tgt_vocab`` alone (the local name
    ``lm_model`` suggests a language model). Training runs epoch by epoch. At
    configurable step intervals it logs speed statistics, saves checkpoints,
    runs loss validation and keeps the best models.

    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
        use_gpu: bool, move model and criterion to GPU
        debug: forwarded to ``should_trigger_by_steps``
    """

    # write log of training to file.
    write_log_to_file(
        os.path.join(FLAGS.log_path,
                     "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    # Device used when loading pretrained parameters (step 3 below). It must
    # match where the model lives: GPU when USE_GPU is set, CPU otherwise.
    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and a TypeError since 6.0. safe_load also refuses to construct
        # arbitrary Python objects from the config file.
        configs = yaml.safe_load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][0])

    # Fetched batches are later cut with split_shard(split_size=update_cycle),
    # so the iterator's batch/buffer sizes are scaled by update_cycle.
    train_batch_size = training_configs["batch_size"] * max(
        1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(
        1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['train_data'][0],
            vocabulary=vocab_tgt,
            max_len=data_configs['max_len'][0],
        ),
        shuffle=training_configs['shuffle'])

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_tgt,
        ))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=train_batch_size,
        use_bucket=training_configs['use_bucket'],
        buffer_size=train_buffer_size,
        batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(FLAGS.saveto, FLAGS.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    lm_model = build_model(n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(lm_model)

    params_total = sum([p.numel() for n, p in lm_model.named_parameters()])
    # Parameters whose name does not contain 'embedding'.
    params_without_embedding = sum([
        p.numel() for n, p in lm_model.named_parameters()
        if n.find('embedding') == -1
    ])
    INFO('Total parameters: {}'.format(params_total))
    INFO('Total parameters (excluding word embeddings): {}'.format(
        params_without_embedding))

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        lm_model = lm_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    lm_model.init_parameters(FLAGS.pretrain_path, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=lm_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'])

    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(
                optimizer=optim, **optimizer_configs["scheduler_configs"])

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim,
                                      **optimizer_configs['scheduler_configs'])
        else:
            WARN(
                "Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average

    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=lm_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=lm_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    valid_loss = best_valid_loss = float('inf')  # Max Float
    saving_files = []

    # Timer for computing speed (reset at every display step)
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
            eidx, uidx),
                                     total=len(training_iterator),
                                     unit="sents")
        for batch in training_iter:

            uidx += 1

            # Step-based schedulers advance on every update; the loss-based
            # scheduler is stepped after validation instead. ``scheduler`` is
            # None for an unknown schedule_method, so test it directly.
            if scheduler is not None and optimizer_configs[
                    "schedule_method"] != "loss":
                scheduler.step(global_step=uidx)

            seqs_y = batch

            n_samples_t = len(seqs_y)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            train_loss = 0.
            optim.zero_grad()
            try:
                # Prepare data
                for (seqs_y_t, ) in split_shard(
                        seqs_y, split_size=training_configs['update_cycle']):
                    y = prepare_data(seqs_y_t, cuda=GlobalNames.USE_GPU)

                    loss = compute_forward(
                        model=lm_model,
                        critic=critic,
                        seqs_y=y,
                        eval=False,
                        normalization=n_samples_t,
                        norm_by_words=training_configs["norm_by_words"])
                    # Parsed as: train_loss += (loss / y.size(1) if ... else loss)
                    train_loss += loss / y.size(
                        1) if not training_configs["norm_by_words"] else loss
                optim.step()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                    optim.zero_grad()
                else:
                    raise e

            if ma is not None and eidx >= training_configs[
                    'moving_average_start_epoch']:
                ma.step()

            training_progress_bar.update(n_samples_t)
            training_progress_bar.set_description(
                ' - (Epc {}, Upd {}) '.format(eidx, uidx))
            training_progress_bar.set_postfix_str(
                'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f})'.format(
                    train_loss, valid_loss, best_valid_loss))
            summary_writer.add_scalar("train_loss",
                                      scalar_value=train_loss,
                                      global_step=uidx)

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second since the last display
                # (or since training began). Uses the dedicated speed timer so
                # data loading / model building time is not included.
                elapsed = timer_for_speed.toc(return_seconds=True)
                words_per_sec = cum_words / elapsed
                sents_per_sec = cum_samples / elapsed
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)",
                                          scalar_value=words_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sec)",
                                          scalar_value=sents_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("lrate",
                                          scalar_value=lrate,
                                          global_step=uidx)
                summary_writer.add_scalar("oom_count",
                                          scalar_value=oom_count,
                                          global_step=uidx)

                # Reset speed timer and counters
                timer_for_speed.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    checkpoint_saver.save(global_step=uidx,
                                          model=lm_model,
                                          optim=optim,
                                          lr_scheduler=scheduler,
                                          collections=model_collections,
                                          ma=ma)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=FLAGS.debug):

                # Validate with the moving-average weights, then restore.
                if ma is not None:
                    origin_state_dict = deepcopy(lm_model.state_dict())
                    lm_model.load_state_dict(ma.export_ma_params(),
                                             strict=False)

                valid_loss = loss_validation(
                    model=lm_model,
                    critic=critic,
                    valid_iterator=valid_iterator,
                    norm_by_words=training_configs["norm_by_words"])

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss",
                                          min_history_loss,
                                          global_step=uidx)

                if ma is not None:
                    lm_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                if optimizer_configs["schedule_method"] == "loss":
                    # Feed the up-to-date best loss (including this
                    # validation); best_valid_loss still holds the previous
                    # best at this point, which would lag by one validation.
                    scheduler.step(metric=min_history_loss)

                # If model get new best valid loss
                if valid_loss < best_valid_loss:
                    bad_count = 0

                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(lm_model.state_dict(),
                                   best_model_prefix + ".final")

                        # 2. record all several best models
                        best_model_saver.save(global_step=uidx, model=lm_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        # NOTE: early stop only disables saving below/above;
                        # training itself continues until max_epochs.
                        is_early_stop = True
                        WARN("Early Stop!")

                best_valid_loss = min_history_loss

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} lrate: {2:6f} patience: {3}".format(
                    uidx, valid_loss, lrate, bad_count))

        training_progress_bar.close()

        eidx += 1
        # The loop exits once eidx exceeds max_epochs.
        if eidx > training_configs["max_epochs"]:
            break

    # Flush pending TensorBoard events and release the file handle.
    summary_writer.close()
Exemplo n.º 4
0
def tune(flags):
    """Fine-tune an NMT model on a parallel corpus, single-process or distributed.

    The model is built from the source/target vocabularies and initialised from
    ``flags.pretrain_path`` (``load_pretrained_model``). It is then partially
    frozen according to ``flags.froze_config`` and trained with gradient
    accumulation over ``update_cycle`` batches. No validation is performed. At
    the end of every epoch the root rank writes the current parameters to
    ``<saveto>/<model_name><MY_BEST_MODEL_SUFFIX>.final``.

    flags:
        saveto: str
        reload: store_true
        config_path: str
        predefined_config: forwarded to ``prepare_configs``
        pretrain_path: str, default=""
        pretrain_exclude_prefix: forwarded to ``load_pretrained_model``
        froze_config: forwarded to ``froze_params``
        model_name: str
        log_path: str
        use_gpu: bool
        multi_gpu: bool, enable distributed training
        shared_dir: forwarded to ``dist.distributed_init``
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    # else write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Generate source and target dictionaries
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    # Global special-token ids are taken from the source vocabulary.
    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    # Parallel corpus: (source, target) sentence pairs.
    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        is_train_dataset=True),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        is_train_dataset=True))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial

    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            vocab_tgt=vocab_tgt,
                            **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'],
                          padding_idx=vocab_tgt.pad)

    INFO(critic)

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=flags.pretrain_exclude_prefix,
                          device=Constants.CURRENT_DEVICE)
    # freeze parameters as described by flags.froze_config
    froze_params(nmt_model, flags.froze_config)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')

    if not flags.multi_gpu:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])
    else:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=nmt_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]

    train_loss_meter = AverageMeter()
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    # Gradient accumulation state: update_cycle counts down to 0 across
    # batches; grad_denom sums the batch sizes used to normalise the step.
    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    # No validation is run here, so these only feed the progress bar.
    valid_loss = best_valid_loss = float('inf')

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None

        for batch in training_iter:
            # Each batch is a (source sentences, target sentences) pair.
            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, it means end of one batch. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            if update_cycle == 0:

                # 0. reduce variables
                if world_size > 1:
                    grad_denom = dist.all_reduce_py(grad_denom)
                    train_loss = dist.all_reduce_py(train_loss)
                    cum_n_words = dist.all_reduce_py(cum_n_words)

                # 1. update parameters
                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)
                    training_progress_bar.set_description(
                        ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                    postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                        train_loss, valid_loss, best_valid_loss)
                    training_progress_bar.set_postfix_str(postfix_str)

                # 2. learning rate scheduling
                if scheduler is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                    scheduler.step(global_step=uidx)

                # 3. update moving average
                if ma is not None and eidx >= training_configs[
                        'moving_average_start_epoch']:
                    ma.step()

                # 4. update meters
                train_loss_meter.update(train_loss, grad_denom)
                sent_per_sec_meter.update(grad_denom)
                tok_per_sec_meter.update(cum_n_words)

                # 5. reset accumulated variables, update uidx
                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1
                cum_n_words = 0.0
                train_loss = 0.0

            else:
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()

            # ================================================================================== #
            # Saving checkpoints (periodic checkpointing is currently disabled;
            # the original logic is kept below for reference)
            # if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=flags.debug):
            #     model_collections.add_to_collection("uidx", uidx)
            #     model_collections.add_to_collection("eidx", eidx)
            #     model_collections.add_to_collection("bad_count", bad_count)
            #
            #     if not is_early_stop:
            #         if rank == 0:
            #             checkpoint_saver.save(global_step=uidx,
            #                                   model=nmt_model,
            #                                   optim=optim,
            #                                   lr_scheduler=scheduler,
            #                                   collections=model_collections,
            #                                   ma=ma)

        # Only the root rank writes the epoch-end model. In distributed runs
        # every rank would otherwise write the same path concurrently.
        if rank == 0:
            torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        # The loop exits once eidx exceeds max_epochs.
        if eidx > training_configs["max_epochs"]:
            break

    # Flush pending TensorBoard events (root rank only owns a writer).
    if summary_writer is not None:
        summary_writer.close()
Exemplo n.º 5
0
def train(FLAGS):
    """
    Train an NMT model, optionally data-parallel across GPUs with Horovod.

    FLAGS:
        saveto: str
            Directory in which checkpoints and best models are written.
        reload: store_true
            Resume from the latest checkpoint under ``saveto``.
        config_path: str
            Path of the YAML configuration file.
        pretrain_path: str, default=""
            Optional pre-trained parameters loaded before training starts.
        model_name: str
            File-name prefix for checkpoints and best models.
        log_path: str
            Directory for the text log and the TensorBoard summaries.
        valid_path: str
            Passed to ``bleu_validation`` as ``valid_dir``.
        use_gpu: bool
            Train on CUDA devices.
        multi_gpu: bool
            Enable Horovod distributed training.
        debug: bool
            Forwarded to ``should_trigger_by_steps`` for validation/saving.
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    GlobalNames.USE_GPU = FLAGS.use_gpu

    if FLAGS.multi_gpu:

        if hvd is None or distributed is None:
            # NOTE(review): hvd.init() below dereferences ``hvd``; confirm that
            # ERROR aborts the process instead of only logging.
            ERROR("Distributed training is disable. Please check the installation of Horovod.")

        hvd.init()
        world_size = hvd.size()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if GlobalNames.USE_GPU:
        # Bind this worker process to its local GPU.
        torch.cuda.set_device(local_rank)
        CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    if rank != 0:
        close_logging()

    # write log of training to file.
    if rank == 0:
        write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    # ================================================================================== #
    # Parsing configuration files

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # ``yaml.load`` without an explicit Loader may construct arbitrary Python
        # objects and is a TypeError on PyYAML >= 6; configs only need plain YAML.
        configs = yaml.safe_load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_baseline_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    # Buffer is scaled by update_cycle, the number of micro-batches whose
    # gradients are accumulated per parameter update (see the training loop).
    actual_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        )
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=training_configs["batch_size"],
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=actual_buffer_size,
                                     batching_func=training_configs['batching_key'],
                                     world_size=world_size,
                                     rank=rank)

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True,
                                  world_size=world_size, rank=rank)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build model & criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'],
                      distributed=True if world_size > 1 else False,
                      update_cycle=training_configs['update_cycle']
                      )
    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average

    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(moving_average_method=training_configs['moving_average_method'],
                           named_params=nmt_model.named_parameters(),
                           alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections, ma=ma)

    # broadcast parameters and optimizer states so every worker starts from rank 0's state
    if world_size > 1:
        hvd.broadcast_parameters(params=nmt_model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer=optim.optim, root_rank=0)

    # ================================================================================== #
    # Prepare training

    # Counters are restored from the collections when reloading; defaults otherwise.
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    cum_n_samples = 0
    cum_n_words = 0
    best_valid_loss = 1.0 * 1e10  # Max Float
    # Set by loss validation; NaN until the first loss validation so that the
    # BLEU log line never reads an unbound local.
    valid_loss = float("nan")
    update_cycle = training_configs['update_cycle']
    grad_denom = 0

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=FLAGS.log_path)
    else:
        summary_writer = None

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                         total=len(training_iterator),
                                         unit="sents"
                                         )
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)

            cum_n_samples += batch_size
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=GlobalNames.USE_GPU)

                # eval=False: the returned loss is unused here, so presumably
                # compute_forward also back-propagates -- confirm. Gradients are
                # accumulated until update_cycle reaches 0.
                compute_forward(model=nmt_model,
                                critic=critic,
                                seqs_x=x,
                                seqs_y=y,
                                eval=False,
                                normalization=1.0,
                                norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, it means end of one batch. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom
            # - update uidx
            # - update moving average

            if update_cycle == 0:
                if world_size > 1:
                    grad_denom = distributed.all_reduce(grad_denom)

                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)

                update_cycle = training_configs['update_cycle']
                grad_denom = 0

                uidx += 1

                if scheduler is None:
                    pass
                elif optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)
                else:
                    scheduler.step(global_step=uidx)

                if ma is not None and eidx >= training_configs['moving_average_start_epoch']:
                    ma.step()
            else:
                # Still accumulating gradients: skip display/validation/saving.
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):

                if world_size > 1:
                    cum_n_words = sum(distributed.all_gather(cum_n_words))
                    cum_n_samples = sum(distributed.all_gather(cum_n_samples))

                # words per second and sents per second, measured since the
                # previous display (or since training began).
                elapsed_seconds = timer_for_speed.toc(return_seconds=True)
                words_per_sec = cum_n_words / elapsed_seconds
                sents_per_sec = cum_n_samples / elapsed_seconds
                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                    summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx)
                    summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)
                    summary_writer.add_scalar("oom_count", scalar_value=oom_count, global_step=uidx)

                # Reset timer
                timer_for_speed.tic()
                cum_n_words = 0
                cum_n_samples = 0

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             rank=rank,
                                             world_size=world_size
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                    summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"],
                                             world_size=world_size,
                                             rank=rank,
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    # After early stop has fired, training continues but no
                    # further best models are written.
                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                            # 2. record all several best models
                            best_model_saver.save(global_step=uidx, model=nmt_model, ma=ma)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)
                # Persist oom_count too: it is restored from the collections on reload.
                model_collections.add_to_collection("oom_count", oom_count)

                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
Exemplo n.º 6
0
def train(FLAGS):
    """
    Train an NMT model on a single device, with optional EMA of the parameters.

    FLAGS:
        saveto: str
            Directory in which checkpoints and best models are written.
        reload: store_true
            Resume from the latest checkpoint under ``saveto``.
        config_path: str
            Path of the YAML configuration file.
        pretrain_path: str, default=""
            Optional pre-trained parameters loaded before training starts.
        model_name: str
            File-name prefix for checkpoints and best models.
        log_path: str
            Directory for the text log and the TensorBoard summaries.
        valid_path: str
            Passed to ``bleu_validation`` as ``valid_dir``.
        use_gpu: bool
            Train on CUDA device 0.
        debug: bool
            Forwarded to ``should_trigger_by_steps`` for validation/saving.
    """

    # write log of training to file.
    write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    # Device used when loading pre-trained weights (the original had the two
    # branches inverted, mapping GPU runs to "cpu").
    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # ``yaml.load`` without an explicit Loader may construct arbitrary Python
        # objects and is a TypeError on PyYAML >= 6; configs only need plain YAML.
        configs = yaml.safe_load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    # Each iterator batch holds update_cycle micro-batches; the training loop
    # splits it back with split_shard before the forward passes.
    train_batch_size = training_configs["batch_size"] * max(1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        ),
        shuffle=training_configs['shuffle']
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=train_batch_size,
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=train_buffer_size,
                                     batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params']
                      )
    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build EMA (disabled when ema_decay <= 0)
    if training_configs['ema_decay'] > 0.0:
        ema = ExponentialMovingAverage(named_params=nmt_model.named_parameters(), decay=training_configs['ema_decay'])
    else:
        ema = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint. ``ema`` is passed because checkpoints are
    # saved with ``ema=ema`` below; without it the EMA state is lost on resume.
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections, ema=ema)

    # ================================================================================== #
    # Prepare training

    # Counters are restored from the collections when reloading; defaults otherwise.
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    best_valid_loss = 1.0 * 1e10  # Max Float
    # Set by loss validation; NaN until the first loss validation so that the
    # BLEU log line never reads an unbound local.
    valid_loss = float("nan")

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                     total=len(training_iterator),
                                     unit="sents"
                                     )
        for batch in training_iter:

            uidx += 1

            if scheduler is None:
                pass
            elif optimizer_configs["schedule_method"] == "loss":
                scheduler.step(metric=best_valid_loss)
            else:
                scheduler.step(global_step=uidx)

            seqs_x, seqs_y = batch

            n_samples_t = len(seqs_x)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            training_progress_bar.update(n_samples_t)

            optim.zero_grad()

            # Split the large batch into update_cycle shards; with eval=False,
            # compute_forward presumably back-propagates each shard (its return
            # value is unused) -- confirm. One optimizer step per full batch.
            for seqs_x_t, seqs_y_t in split_shard(seqs_x, seqs_y, split_size=training_configs['update_cycle']):
                x, y = prepare_data(seqs_x_t, seqs_y_t, cuda=GlobalNames.USE_GPU)

                compute_forward(model=nmt_model,
                                critic=critic,
                                seqs_x=x,
                                seqs_y=y,
                                eval=False,
                                normalization=n_samples_t,
                                norm_by_words=training_configs["norm_by_words"])
            optim.step()

            if ema is not None:
                ema.step()

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second, measured since the
                # previous display (or since training began).
                elapsed_seconds = timer_for_speed.toc(return_seconds=True)
                words_per_sec = cum_words / elapsed_seconds
                sents_per_sec = cum_samples / elapsed_seconds
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx)
                summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)

                # Reset timer
                timer_for_speed.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:

                    checkpoint_saver.save(global_step=uidx, model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                          collections=model_collections, ema=ema)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                # Validate with the EMA weights, then restore the raw weights.
                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

                best_valid_loss = min_history_loss

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                # Decode with the EMA weights; best models below are therefore
                # saved while the EMA weights are loaded.
                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"]
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                summary_writer.add_scalar("bleu", valid_bleu, uidx)
                summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    # After early stop has fired, training continues but no
                    # further best models are written.
                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                        # 2. record all several best models
                        best_model_saver.save(global_step=uidx, model=nmt_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
Exemplo n.º 7
0
def train(FLAGS):
    """Train an NMT model as described by a YAML config file.

    Loads vocabularies and datasets, builds the model, criterion, optimizer
    and an optional learning-rate scheduler, then runs the epoch loop with
    periodic speed logging, checkpointing, loss validation and BLEU
    validation.

    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
        valid_path: str
        use_gpu: store_true
        debug: store_true
    """
    import copy  # used to snapshot the best-loss weights (see below)

    # write log of training to file.
    write_log_to_file(
        os.path.join(FLAGS.log_path,
                     "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # ``yaml.load`` without an explicit Loader can construct arbitrary
        # Python objects and raises TypeError on PyYAML >= 6.
        configs = yaml.safe_load(f)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    if "seed" in training_configs:
        # Set random seed
        GlobalNames.SEED = training_configs['seed']

    if 'buffer_size' not in training_configs:
        training_configs['buffer_size'] = 100 * training_configs['batch_size']

    saveto_collections = '%s.pkl' % os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_CHECKPOINIS_PREFIX)
    saveto_best_model = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)
    saveto_best_optim_params = os.path.join(
        FLAGS.saveto,
        FLAGS.model_name + GlobalNames.MY_BEST_OPTIMIZER_PARAMS_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocab(dict_path=data_configs['dictionaries'][0],
                      max_n_words=data_configs['n_words'][0])
    vocab_tgt = Vocab(dict_path=data_configs['dictionaries'][1],
                      max_n_words=data_configs['n_words'][1])

    train_bitext_dataset = ZipDatasets(
        TextDataset(data_path=data_configs['train_data'][0],
                    vocab=vocab_src,
                    bpe_codes=data_configs['bpe_codes'][0],
                    max_len=data_configs['max_len'][0],
                    use_char=data_configs['use_char'][0]),
        TextDataset(data_path=data_configs['train_data'][1],
                    vocab=vocab_tgt,
                    bpe_codes=data_configs['bpe_codes'][1],
                    max_len=data_configs['max_len'][1],
                    use_char=data_configs['use_char'][1]),
        shuffle=training_configs['shuffle'])

    valid_bitext_dataset = ZipDatasets(
        TextDataset(data_path=data_configs['valid_data'][0],
                    vocab=vocab_src,
                    bpe_codes=data_configs['bpe_codes'][0],
                    use_char=data_configs['use_char'][0]),
        TextDataset(data_path=data_configs['valid_data'][1],
                    vocab=vocab_tgt,
                    bpe_codes=data_configs['bpe_codes'][1],
                    use_char=data_configs['use_char'][1]))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs['batch_size'],
        sort_buffer=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        sort_fn=lambda line: len(line[-1]))

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        sort_buffer=False)

    bleu_scorer = ExternalScriptBLEUScorer(
        reference_path=data_configs['bleu_valid_reference'],
        lang=data_configs['lang_pair'].split('-')[1],
        bleu_script=training_configs['bleu_valid_configs']['bleu_script'],
        digits_only=True,
        lc=training_configs['bleu_valid_configs']['lowercase'],
        postprocess=training_configs['bleu_valid_configs']['postprocess'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    model_collections = Collections()

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================================================================== #
    # Build Model & Sampler & Validation
    INFO('Building model...')
    timer.tic()

    model_cls = model_configs.get("model")
    if model_cls not in src.models.__all__:
        raise ValueError(
            "Invalid model class \'{}\' provided. Only {} are supported now.".
            format(model_cls, src.models.__all__))

    # Resolve the whitelisted class on the package instead of eval-ing a
    # string that comes from the config file.
    nmt_model = getattr(src.models, model_cls)(
        n_src_vocab=vocab_src.max_n_words,
        n_tgt_vocab=vocab_tgt.max_n_words,
        **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)

    # Move to GPU *before* the optimizer is built and its state reloaded:
    # optimizer state is cast to the parameters' device at load time, so
    # moving the model afterwards would leave that state behind on the CPU.
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Building Optimizer...')
    timer.tic()
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'])

    # Initialize training indicators
    uidx = 0
    bad_count = 0

    # Whether Reloading model
    if FLAGS.reload is True and os.path.exists(saveto_best_model):
        timer.tic()
        INFO("Reloading model...")
        # Load onto CPU first; load_state_dict copies onto the model's device.
        params = torch.load(saveto_best_model, map_location="cpu")
        nmt_model.load_state_dict(params)

        model_archives = Collections.unpickle(path=saveto_collections)
        model_collections.load(archives=model_archives)

        uidx = model_archives['uidx']
        bad_count = model_archives['bad_count']

        INFO("Done. Model reloaded.")

        if os.path.exists(saveto_best_optim_params):
            INFO("Reloading optimizer params...")
            optimizer_params = torch.load(saveto_best_optim_params,
                                          map_location="cpu")
            optim.optim.load_state_dict(optimizer_params)

            INFO("Done. Optimizer params reloaded.")
        elif uidx > 0:
            INFO("Failed to reload optimizer params: {} does not exist".format(
                saveto_best_optim_params))

        INFO('Done. Elapsed time {0}'.format(timer.toc()))
    # New training. Check if pretraining needed
    else:
        # pretrain
        load_pretrained_model(nmt_model,
                              FLAGS.pretrain_path,
                              exclude_prefix=None)

    # Configure Learning Scheduler
    # Here we have two policies, "loss" and "noam"
    schedule_method = optimizer_configs['schedule_method']
    if schedule_method == "loss":
        scheduler = LossScheduler(optimizer=optim,
                                  **optimizer_configs['scheduler_configs'])
    elif schedule_method == "noam":
        scheduler = NoamScheduler(optimizer=optim,
                                  **optimizer_configs['scheduler_configs'])
    else:
        if schedule_method is not None:
            WARN(
                "Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    schedule_method))
        scheduler = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Prepare training

    # Independent copy of the weights with the lowest validation loss; the
    # "loss" scheduler rolls back to it when it anneals the learning rate.
    params_best_loss = None

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    valid_loss = 1.0 * 1e12  # Max Float
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    for eidx in range(training_configs['max_epochs']):
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                     total=len(training_iterator),
                                     unit="sents")
        for batch in training_iter:

            uidx += 1

            # ================================================================================== #
            # Learning rate annealing

            if scheduler is not None and (np.mod(uidx, scheduler.schedule_freq)
                                          == 0 or FLAGS.debug):

                if scheduler.step(global_step=uidx, loss=valid_loss):
                    # Nothing to roll back to before the first loss validation.
                    if schedule_method == "loss" and params_best_loss is not None:
                        nmt_model.load_state_dict(params_best_loss)

                # Keep ``lrate`` current so the validation log reports it.
                lrate = list(optim.get_lrate())[0]
                summary_writer.add_scalar("lrate", lrate, global_step=uidx)

            seqs_x, seqs_y = batch

            batch_size_t = len(seqs_x)
            cum_samples += batch_size_t
            cum_words += sum(len(s) for s in seqs_y)

            training_progress_bar.update(batch_size_t)

            # Prepare data
            x, y = prepare_data(seqs_x, seqs_y, cuda=GlobalNames.USE_GPU)

            nmt_model.zero_grad()
            compute_forward(model=nmt_model,
                            critic=critic,
                            seqs_x=x,
                            seqs_y=y,
                            eval=False,
                            normalization=batch_size_t,
                            shard_size=training_configs['shard_size'])
            optim.step()

            # ================================================================================== #
            # Display some information
            if np.mod(uidx, training_configs['disp_freq']) == 0:
                # words per second and sents per second, measured since the
                # previous display (or since training began)
                elapsed = timer_for_speed.toc(return_seconds=True)
                words_per_sec = cum_words / elapsed
                sents_per_sec = cum_samples / elapsed
                summary_writer.add_scalar("Speed(words/sec)",
                                          scalar_value=words_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)",
                                          scalar_value=sents_per_sec,
                                          global_step=uidx)

                # Reset timer
                timer_for_speed.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if np.mod(uidx, training_configs['save_freq']) == 0 or FLAGS.debug:

                os.makedirs(FLAGS.saveto, exist_ok=True)

                INFO('Saving the model at iteration {}...'.format(uidx))

                saveto_uidx = os.path.join(
                    FLAGS.saveto, FLAGS.model_name + '.iter%d.tpz' % uidx)
                torch.save(nmt_model.state_dict(), saveto_uidx)

                Collections.pickle(path=saveto_collections,
                                   uidx=uidx,
                                   bad_count=bad_count,
                                   **model_collections.export())

                saving_files.append(saveto_uidx)

                INFO('Done')

                # Once more than 5 iteration checkpoints exist, delete all
                # but the newest one.
                if len(saving_files) > 5:
                    for stale_file in saving_files[:-1]:
                        os.remove(stale_file)

                    saving_files = saving_files[-1:]

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if np.mod(uidx,
                      training_configs['loss_valid_freq']) == 0 or FLAGS.debug:

                valid_loss, valid_n_correct = loss_validation(
                    model=nmt_model,
                    critic=critic,
                    valid_iterator=valid_iterator,
                )

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss",
                                          min_history_loss,
                                          global_step=uidx)
                summary_writer.add_scalar("n_correct",
                                          valid_n_correct,
                                          global_step=uidx)

                if params_best_loss is None or valid_loss <= min_history_loss:
                    # state_dict() returns tensors sharing storage with the
                    # live parameters; copy them so later optimizer steps do
                    # not silently overwrite the "best" snapshot.
                    params_best_loss = copy.deepcopy(nmt_model.state_dict())

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if (np.mod(uidx, training_configs['bleu_valid_freq']) == 0 and uidx > training_configs['bleu_valid_warmup']) \
                    or FLAGS.debug:

                valid_bleu = bleu_validation(
                    uidx=uidx,
                    valid_iterator=valid_iterator,
                    batch_size=training_configs['bleu_valid_batch_size'],
                    model=nmt_model,
                    bleu_scorer=bleu_scorer,
                    eval_at_char_level=data_configs['eval_at_char_level'],
                    vocab_tgt=vocab_tgt,
                    valid_dir=FLAGS.valid_path,
                    max_steps=training_configs["bleu_valid_max_steps"])

                model_collections.add_to_collection(key="history_bleus",
                                                    value=valid_bleu)

                best_valid_bleu = float(
                    np.array(model_collections.get_collection(
                        "history_bleus")).max())

                summary_writer.add_scalar("bleu", valid_bleu, uidx)
                summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    # After early stop the best checkpoint is frozen; training
                    # itself continues until max_epochs.
                    if is_early_stop is False:
                        INFO('Saving best model...')

                        # save model
                        torch.save(nmt_model.state_dict(), saveto_best_model)

                        # save optim params
                        INFO('Saving best optimizer params...')
                        torch.save(optim.optim.state_dict(),
                                   saveto_best_optim_params)

                        INFO('Done.')

                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                with open("./valid.txt", 'a') as f:
                    f.write(
                        "{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}\n"
                        .format(uidx, valid_loss, valid_bleu, lrate,
                                bad_count))

        training_progress_bar.close()
Exemplo n.º 8
0
def _strip_after_eos(ids, eos):
    """Return the elements of ``ids`` that precede the first ``eos`` (all if none)."""
    prefix = []
    for wid in ids:
        if wid == eos:
            break
        prefix.append(wid)
    return prefix


def _rank_by_bleu(candidates, reference, beam_size):
    """Keep every ``beam_size``-th candidate and sort them by descending BLEU.

    Args:
        candidates: flat list of id sequences; indices 0, beam_size,
            2 * beam_size, ... are kept (presumably the top hypothesis of each
            constraint group returned by ``fixwords_beam_search`` -- confirm).
        reference: reference id sequence passed to ``bleuScore``.
        beam_size: stride used to pick candidates.

    Returns:
        The kept candidates, highest ``bleuScore`` first.
    """
    kept = candidates[::beam_size]
    scores = []
    for cand in kept:
        bleu, _ = bleuScore(cand, reference)
        scores.append(bleu)
    order = np.argsort(scores).tolist()[::-1]
    return [kept[ii] for ii in order]


def _pad_to_long_tensor(seqs, pad, use_gpu):
    """Right-pad variable-length id lists with ``pad`` into a 2-D LongTensor."""
    max_len = max(len(seq) for seq in seqs)
    padded = [seq + [pad] * (max_len - len(seq)) for seq in seqs]
    tensor = torch.tensor(padded, dtype=torch.long)
    return tensor.cuda() if use_gpu else tensor


def interactive_FBS(FLAGS):
    """Simulate interactive translation with reference-sampled constraints.

    Translates ``FLAGS.source_path`` with plain beam search on the forward
    model. Then, for ``FLAGS.imt_step`` rounds, it samples constraints from
    the references (``FLAGS.ref_path``) via ``sample_constrains`` and
    re-decodes with ``fixwords_beam_search`` using the forward and backward
    models.

    Writes:
        ``<saveto>.<i>``: unconstrained beam outputs (``keep_n`` files).
        ``<saveto>.cons<idx>``: constraints sampled in round ``idx``.
        ``<saveto>.imt<i>``: final constrained outputs ranked by BLEU against
            the reference (``FLAGS.try_times`` files).
    """
    patience = FLAGS.try_times
    GlobalNames.USE_GPU = FLAGS.use_gpu
    config_path = os.path.abspath(FLAGS.config_path)

    with open(config_path.strip()) as f:
        # ``yaml.load`` without a Loader is unsafe and rejected by PyYAML >= 6.
        configs = yaml.safe_load(f)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()
    #===================================================================================
    #load data
    INFO('loading data...')
    timer.tic()

    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)
    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    # References indexed by source line number (matches ``numbers`` below).
    with open(FLAGS.ref_path) as f:
        valid_ref = [vocab_tgt.sent2ids(sent) for sent in f]

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    #===================================================================================
    #build Model & Sampler & Validation
    INFO('Building model...')
    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        critic = critic.cuda()

    timer.tic()
    fw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)

    bw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)
    fw_nmt_model.eval()
    bw_nmt_model.eval()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Reloading model parameters...')
    timer.tic()

    fw_params = load_model_parameters(FLAGS.fw_model_path, map_location="cpu")
    bw_params = load_model_parameters(FLAGS.bw_model_path, map_location="cpu")

    fw_nmt_model.load_state_dict(fw_params)
    bw_nmt_model.load_state_dict(bw_params)

    if GlobalNames.USE_GPU:
        fw_nmt_model.cuda()
        bw_nmt_model.cuda()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('begin...')
    timer.tic()

    result_numbers = []
    result = []
    n_words = 0

    imt_numbers = []
    imt_result = []
    # imt_constrains[round] holds one entry per sentence (in decoding order).
    imt_constrains = [[] for _ in range(FLAGS.imt_step)]
    bidirection = bool(FLAGS.bidirection)

    infer_progress_bar = tqdm(total=len(valid_iterator),
                              desc=' - (Infer)',
                              unit='sents')

    valid_iter = valid_iterator.build_generator()
    for batch in valid_iter:
        batch_result = []
        numbers, seqs_x = batch
        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        with torch.no_grad():
            word_ids = beam_search(nmt_model=fw_nmt_model,
                                   beam_size=FLAGS.beam_size,
                                   max_steps=FLAGS.max_steps,
                                   src_seqs=x,
                                   alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)
            batch_result.append(sent_t[0])

            n_words += len(sent_t[0])

        result_numbers += numbers
        imt_numbers += numbers
        batch_ref = [valid_ref[ii] for ii in numbers]

        last_sents = copy.deepcopy(batch_result)
        constrains = [[[] for _ in range(patience)]
                      for _ in range(batch_size_t)]
        positions = [[[] for _ in range(patience)]
                     for _ in range(batch_size_t)]
        for idx in range(FLAGS.imt_step):
            cons, pos = sample_constrains(last_sents, batch_ref, patience)

            for ii in range(batch_size_t):
                for jj in range(patience):
                    constrains[ii][jj].append(cons[ii][jj])
                    positions[ii][jj].append(pos[ii][jj])

            # One entry per sentence (not per batch) so the later
            # ``origin_order`` re-indexing lines up for any batch size.
            imt_constrains[idx].extend(
                [vocab_tgt.ids2sent(c)] for c in cons)
            with torch.no_grad():
                constrained_word_ids, positions = fixwords_beam_search(
                    fw_nmt_model=fw_nmt_model,
                    bw_nmt_model=bw_nmt_model,
                    beam_size=FLAGS.beam_size,
                    max_steps=FLAGS.max_steps,
                    src_seqs=x,
                    alpha=FLAGS.alpha,
                    constrains=constrains,
                    positions=positions,
                    last_sentences=last_sents,
                    imt_step=idx + 1,
                    bidirection=bidirection)
            constrained_word_ids = constrained_word_ids.cpu().numpy().tolist()
            last_sents = []
            for i, sent_t in enumerate(constrained_word_ids):
                sent_t = [[wid for wid in line if wid != PAD]
                          for line in sent_t]
                if idx == FLAGS.imt_step - 1:
                    imt_result.append(copy.deepcopy(sent_t))
                samples = [_strip_after_eos(trans, vocab_tgt.EOS)
                           for trans in sent_t]
                # The BLEU-best hypothesis seeds the next constraint round.
                ranked = _rank_by_bleu(samples, batch_ref[i], FLAGS.beam_size)
                last_sents.append(ranked[0])

            if FLAGS.online_learning and idx == FLAGS.imt_step - 1:
                # Hypotheses were cut at EOS above, so frame both directions
                # as BOS ... EOS explicitly; sequences are ragged, so pad.
                # NOTE(review): no optimizer step follows these calls --
                # confirm whether compute_forward applies any update itself.
                fw_seqs_y = [[BOS] + sent + [EOS] for sent in last_sents]
                bw_seqs_y = [[BOS] + sent[::-1] + [EOS] for sent in last_sents]
                compute_forward(fw_nmt_model, critic, x,
                                _pad_to_long_tensor(fw_seqs_y, PAD,
                                                    GlobalNames.USE_GPU))
                compute_forward(bw_nmt_model, critic, x,
                                _pad_to_long_tensor(bw_seqs_y, PAD,
                                                    GlobalNames.USE_GPU))

        infer_progress_bar.update(batch_size_t)

    infer_progress_bar.close()
    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / (timer.toc(return_seconds=True))))

    translation = []
    for sent in result:
        samples = []
        for trans in sent:
            tokens = [vocab_tgt.id2token(w)
                      for w in _strip_after_eos(trans, vocab_tgt.EOS)]
            samples.append(vocab_tgt.tokenizer.detokenize(tokens))
        translation.append(samples)

    origin_order = np.argsort(result_numbers).tolist()
    translation = [translation[ii] for ii in origin_order]

    keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(
        FLAGS.beam_size, FLAGS.keep_n)
    outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')

    imt_translation = [
        [_strip_after_eos(trans, vocab_tgt.EOS) for trans in sent]
        for sent in imt_result
    ]

    origin_order = np.argsort(imt_numbers).tolist()
    imt_translation = [imt_translation[ii] for ii in origin_order]
    for idx in range(FLAGS.imt_step):
        imt_constrains[idx] = [
            ' '.join(imt_constrains[idx][ii]) + '\n' for ii in origin_order
        ]

        with open('%s.cons%d' % (FLAGS.saveto, idx), 'w') as f:
            f.writelines(imt_constrains[idx])

    # After re-ordering, position idx is source line idx, matching valid_ref.
    bleu_translation = []
    for idx, sent in enumerate(imt_translation):
        ranked = _rank_by_bleu(sent, valid_ref[idx], FLAGS.beam_size)
        bleu_translation.append([vocab_tgt.ids2sent(s) for s in ranked])

    keep_n = patience
    outputs = ['%s.imt%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in bleu_translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')
Exemplo n.º 9
0
    def __init__(self, n_src_vocab, n_tgt_vocab, **config):
        """Build the encoder, capsule decoder, generator and loss criteria.

        Args:
            n_src_vocab: source vocabulary size passed to ``Encoder``.
            n_tgt_vocab: target vocabulary size for ``Decoder`` and ``Generator``.
            **config: hyper-parameters. Read directly (``KeyError`` if absent):
                ``n_layers``, ``n_head``, ``d_word_vec``, ``d_model``,
                ``d_inner_hid``, ``dropout``, ``capsule_type``, ``d_capsule``,
                ``num_capsules``, ``null_capsule``, ``proj_share_weight``,
                ``label_smoothing``, ``auxiliary_loss``.
                Defaulted via ``setdefault``: ``comb_type`` ("ffn"),
                ``routing_type`` ("dynamic_routing"),
                ``tie_source_target_embeddings`` (False).
                ``apply_word_prediction_loss`` is read only when "wp" is not
                in ``auxiliary_loss``; it is forced to True when word
                prediction losses are enabled.

        Raises:
            AssertionError: if ``d_model != d_word_vec``.
        """
        super().__init__()

        # Keep the (local) config dict; it is updated below when the word
        # prediction losses are switched on.
        self.config = config

        self.encoder = Encoder(
            n_src_vocab,
            n_layers=config["n_layers"],
            n_head=config["n_head"],
            d_word_vec=config["d_word_vec"],
            d_model=config["d_model"],
            d_inner_hid=config["d_inner_hid"],
            dropout=config["dropout"],
        )

        self.decoder = Decoder(
            n_tgt_vocab,
            n_layers=config["n_layers"],
            n_head=config["n_head"],
            d_word_vec=config["d_word_vec"],
            d_model=config["d_model"],
            d_inner_hid=config["d_inner_hid"],
            dropout=config["dropout"],
            # capsule configs
            capsule_type=config["capsule_type"],
            comb_type=config.setdefault("comb_type", "ffn"),
            routing_type=config.setdefault("routing_type", "dynamic_routing"),
            dim_capsule=config["d_capsule"],
            num_capsules=config["num_capsules"],
            null_capsule=config["null_capsule"])

        assert config["d_model"] == config["d_word_vec"], \
            'To facilitate the residual connections, \
             the dimensions of all module output shall be the same.'

        self.generator = Generator(n_words=n_tgt_vocab,
                                   hidden_size=config["d_word_vec"],
                                   padding_idx=PAD)
        # Optionally share the output projection with the target embeddings.
        if config["proj_share_weight"]:
            self.generator.proj.weight = self.decoder.embeddings.embeddings.weight

        # Optionally share source and target embedding matrices.
        if config.setdefault("tie_source_target_embeddings", False):
            self.encoder.embeddings.embeddings.weight = self.decoder.embeddings.embeddings.weight

        # Size of each third of the capsule dimension; presumably the capsule
        # vector is split into past / present / future parts, matching the
        # three word predictors below -- TODO confirm against the decoder.
        self.capsule_per_dim = config['d_capsule'] // 3
        # if self.config["apply_word_prediction_loss"]:
        #     self.wp_past = WordPredictor(generator=self.generator,
        #                                  input_size=self.capsule_per_dim,
        #                                  d_word_vec=config['d_word_vec'])
        #     self.wp_future = WordPredictor(generator=self.generator,
        #                                    input_size=self.capsule_per_dim,
        #                                    d_word_vec=config['d_word_vec'])

        # criterion: starts with the main NLL term; auxiliary terms are
        # registered below with weight 1.
        self.criterion = MultiCriterion(
            weights=dict(nmt_nll=1.),
            nmt_nll=NMTCriterion(label_smoothing=config['label_smoothing']))

        # Word-prediction ("wp") auxiliary losses: one predictor per capsule
        # part, all sharing ``self.generator``.
        if "wp" in config["auxiliary_loss"] or self.config[
                "apply_word_prediction_loss"]:
            self.config["apply_word_prediction_loss"] = True
            self.wp_present = WordPredictor(generator=self.generator,
                                            input_size=self.capsule_per_dim,
                                            d_word_vec=config['d_word_vec'])
            self.criterion.add(
                "wploss_present",
                NMTCriterion(label_smoothing=config['label_smoothing']),
                weight=1.)
            self.wp_past = WordPredictor(generator=self.generator,
                                         input_size=self.capsule_per_dim,
                                         d_word_vec=config['d_word_vec'])
            self.wp_future = WordPredictor(generator=self.generator,
                                           input_size=self.capsule_per_dim,
                                           d_word_vec=config['d_word_vec'])
            # Past/future use multi-target criteria (several gold words each).
            self.criterion.add("wploss_past",
                               MultiTargetNMTCriterion(
                                   label_smoothing=config['label_smoothing']),
                               weight=1.)
            self.criterion.add("wploss_future",
                               MultiTargetNMTCriterion(
                                   label_smoothing=config['label_smoothing']),
                               weight=1.)
        # "bca" auxiliary losses: project past/future capsule parts to
        # d_model and compare with an MSE loss (targets defined elsewhere).
        if "bca" in config["auxiliary_loss"]:
            self.linear_bca_past = nn.Linear(self.capsule_per_dim,
                                             config["d_model"])
            self.criterion.add("bca_past", MSELoss(), weight=1.)
            self.linear_bca_future = nn.Linear(self.capsule_per_dim,
                                               config["d_model"])
            self.criterion.add("bca_future", MSELoss(), weight=1.)