Example #1
def evaluate_valid_loss(data_loader, model, opt):
    model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0

    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            # load one2many dataset
            src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
            # number of target keyphrases for each example in the batch; len(num_trgs) == batch_size
            num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

            max_num_oov = max([len(oov) for oov in oov_lists])  # max number of OOV words in this batch

            batch_size = src.size(0)
            n_batch += batch_size

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)

            start_time = time.time()
            decoder_dist, attention_dist = model(src, src_lens, trg, src_oov,
                                                 max_num_oov, src_mask,
                                                 num_trgs)
            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()

            loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

    eval_loss_stat = LossStatistics(evaluation_loss_sum,
                                    total_trg_tokens,
                                    n_batch,
                                    forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
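
These examples all rely on a time_since helper and a LossStatistics accumulator that are defined elsewhere in the repository. The sketch below is a minimal, hypothetical reconstruction that is merely consistent with how they are called here (constructor arguments, update/clear, xent/ppl); the actual implementations may differ.

import math
import time


def time_since(start_time):
    # elapsed wall-clock seconds since start_time (a time.time() value)
    return time.time() - start_time


class LossStatistics:
    """Accumulates summed cross-entropy loss, token counts, and timings across batches."""

    def __init__(self, loss=0.0, n_tokens=0, n_batch=0,
                 forward_time=0.0, loss_compute_time=0.0, backward_time=0.0):
        self.loss = loss
        self.n_tokens = n_tokens
        self.n_batch = n_batch
        self.forward_time = forward_time
        self.loss_compute_time = loss_compute_time
        self.backward_time = backward_time

    def update(self, stat):
        # merge another LossStatistics object into this one
        self.loss += stat.loss
        self.n_tokens += stat.n_tokens
        self.n_batch += stat.n_batch
        self.forward_time += stat.forward_time
        self.loss_compute_time += stat.loss_compute_time
        self.backward_time += stat.backward_time

    def xent(self):
        # average cross-entropy per target token
        return self.loss / self.n_tokens if self.n_tokens > 0 else float('inf')

    def ppl(self):
        # perplexity = exp(average cross-entropy), capped to avoid overflow
        return math.exp(min(self.xent(), 100))

    def clear(self):
        self.__init__()
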
Example #2
def evaluate_loss(data_loader, model, opt):
    model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0

    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            if not opt.one2many:  # load one2one dataset
                src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, title, title_oov, title_lens, title_mask = batch
            else:  # load one2many dataset
                src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, title, title_oov, title_lens, title_mask = batch
                # number of target keyphrases for each example in the batch; len(num_trgs) == batch_size
                num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

            max_num_oov = max([len(oov) for oov in oov_lists])  # max number of OOV words in this batch

            batch_size = src.size(0)
            n_batch += batch_size

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)
            if opt.title_guided:
                title = title.to(opt.device)
                title_mask = title_mask.to(opt.device)
                # title_oov = title_oov.to(opt.device)

            start_time = time.time()
            if not opt.one2many:
                decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ = model(
                    src,
                    src_lens,
                    trg,
                    src_oov,
                    max_num_oov,
                    src_mask,
                    title=title,
                    title_lens=title_lens,
                    title_mask=title_mask)
            else:
                decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ = model(
                    src,
                    src_lens,
                    trg,
                    src_oov,
                    max_num_oov,
                    src_mask,
                    num_trgs,
                    title=title,
                    title_lens=title_lens,
                    title_mask=title_mask)
            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()
            if opt.copy_attention:  # Compute the loss using target with oov words
                loss = masked_cross_entropy(decoder_dist,
                                            trg_oov,
                                            trg_mask,
                                            trg_lens,
                                            opt.coverage_attn,
                                            coverage,
                                            attention_dist,
                                            opt.lambda_coverage,
                                            coverage_loss=False)
            else:  # Compute the loss using target without oov words
                loss = masked_cross_entropy(decoder_dist,
                                            trg,
                                            trg_mask,
                                            trg_lens,
                                            opt.coverage_attn,
                                            coverage,
                                            attention_dist,
                                            opt.lambda_coverage,
                                            coverage_loss=False)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

    eval_loss_stat = LossStatistics(evaluation_loss_sum,
                                    total_trg_tokens,
                                    n_batch,
                                    forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
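
The masked_cross_entropy function called above is not shown in these examples. The following is a rough, assumption-based sketch of what a call with this signature typically computes for a copy-attention model (summed negative log-likelihood over non-padded target tokens, plus an optional See-et-al.-style coverage penalty); it is illustrative only, not the repository's actual implementation, and the shape comments are assumptions.

import torch


def masked_cross_entropy_sketch(decoder_dist, target, mask, trg_lens=None,
                                coverage_attn=False, coverage=None,
                                attention_dist=None, lambda_coverage=1.0,
                                coverage_loss=False):
    """Summed NLL over non-padding target tokens, plus an optional coverage penalty.

    decoder_dist: [batch, trg_len, vocab + max_num_oov] word probabilities
    target:       [batch, trg_len] gold word indices (trg or trg_oov)
    mask:         [batch, trg_len] with 1.0 for real tokens and 0.0 for padding
    """
    # gather the probability assigned to each gold token
    gold_probs = decoder_dist.gather(2, target.unsqueeze(2)).squeeze(2)  # [batch, trg_len]
    nll = -torch.log(gold_probs + 1e-12) * mask
    loss = nll.sum()
    if coverage_loss and coverage_attn and coverage is not None:
        # coverage penalty: sum of elementwise min(attention, coverage) over source positions
        cov_penalty = torch.min(attention_dist, coverage).sum(dim=2) * mask
        loss = loss + lambda_coverage * cov_penalty.sum()
    return loss
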
Example #3
def train_model(model, optimizer_ml, train_data_loader, valid_data_loader,
                opt):

    logging.info(
        '======================  Start Training  =========================')

    total_batch = -1
    early_stop_flag = False

    total_train_loss_statistics = LossStatistics()
    report_train_loss_statistics = LossStatistics()
    report_train_ppl = []
    report_valid_ppl = []
    report_train_loss = []
    report_valid_loss = []
    best_valid_ppl = float('inf')
    best_valid_loss = float('inf')
    num_stop_dropping = 0

    model.train()

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        if early_stop_flag:
            break

        for batch_i, batch in enumerate(train_data_loader):
            total_batch += 1

            # Training
            batch_loss_stat, decoder_dist = train_one_batch(
                batch, model, optimizer_ml, opt, batch_i)
            report_train_loss_statistics.update(batch_loss_stat)
            total_train_loss_statistics.update(batch_loss_stat)

            # Checkpoint: decay the learning rate if the validation loss stops dropping,
            # and apply early stopping if it fails to improve for several checkpoints.
            # Save the model parameters whenever the validation loss improves.
            if total_batch % 4000 == 0:
                print("Epoch %d; batch: %d; total batch: %d" %
                      (epoch, batch_i, total_batch))
                sys.stdout.flush()

            if epoch >= opt.start_checkpoint_at:
                if (opt.checkpoint_interval == -1 and batch_i == len(train_data_loader) - 1) or \
                        (opt.checkpoint_interval > -1 and total_batch > 1 and total_batch % opt.checkpoint_interval == 0):

                    # test the model on the validation dataset for one epoch
                    valid_loss_stat = evaluate_valid_loss(
                        valid_data_loader, model, opt)
                    model.train()
                    current_valid_loss = valid_loss_stat.xent()
                    current_valid_ppl = valid_loss_stat.ppl()
                    print("Enter check point!")
                    sys.stdout.flush()

                    current_train_ppl = report_train_loss_statistics.ppl()
                    current_train_loss = report_train_loss_statistics.xent()

                    # debug
                    if math.isnan(current_valid_loss) or math.isnan(
                            current_train_loss):
                        logging.info(
                            "NaN valid loss. Epoch: %d; batch_i: %d, total_batch: %d"
                            % (epoch, batch_i, total_batch))
                        exit()
                    # update the best valid loss and save the model parameters
                    if current_valid_loss < best_valid_loss:
                        print("Valid loss drops")
                        sys.stdout.flush()
                        best_valid_loss = current_valid_loss
                        best_valid_ppl = current_valid_ppl
                        num_stop_dropping = 0

                        check_pt_model_path = os.path.join(
                            opt.model_path,
                            '%s.epoch=%d.batch=%d.total_batch=%d' %
                            (opt.exp, epoch, batch_i, total_batch) + '.model')
                        torch.save(  # save model parameters
                            model.state_dict(), open(check_pt_model_path,
                                                     'wb'))
                        logging.info('Saving checkpoint to %s' %
                                     check_pt_model_path)

                    else:
                        print("Valid loss does not drop")
                        sys.stdout.flush()
                        num_stop_dropping += 1
                        # decay the learning rate by a factor
                        for i, param_group in enumerate(
                                optimizer_ml.param_groups):
                            old_lr = float(param_group['lr'])
                            new_lr = old_lr * opt.learning_rate_decay
                            if old_lr - new_lr > EPS:
                                param_group['lr'] = new_lr

                    logging.info(
                        'Epoch: %d; batch idx: %d; total batches: %d' %
                        (epoch, batch_i, total_batch))
                    logging.info(
                        'avg training ppl: %.3f; avg validation ppl: %.3f; best validation ppl: %.3f'
                        %
                        (current_train_ppl, current_valid_ppl, best_valid_ppl))
                    logging.info(
                        'avg training loss: %.3f; avg validation loss: %.3f; best validation loss: %.3f'
                        % (current_train_loss, current_valid_loss,
                           best_valid_loss))

                    report_train_ppl.append(current_train_ppl)
                    report_valid_ppl.append(current_valid_ppl)
                    report_train_loss.append(current_train_loss)
                    report_valid_loss.append(current_valid_loss)

                    if num_stop_dropping >= opt.early_stop_tolerance:
                        logging.info(
                            'Validation loss has not improved for %d checkpoints, stopping training early'
                            % num_stop_dropping)
                        early_stop_flag = True
                        break
                    report_train_loss_statistics.clear()

    # export the training curve
    train_valid_curve_path = opt.exp_path + '/train_valid_curve'
    export_train_and_valid_loss(report_train_loss, report_valid_loss,
                                report_train_ppl, report_valid_ppl,
                                opt.checkpoint_interval,
                                train_valid_curve_path)
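
The checkpoint block above builds a file name from opt.exp / opt.model_path and saves the model's state_dict. Below is a minimal, self-contained illustration of that save/reload pattern; the model, paths, and option values are placeholders, not the ones used in this repository.

import os
import torch
import torch.nn as nn

# A tiny stand-in model just to illustrate the state_dict round trip.
model = nn.Linear(4, 2)
model_path, exp = "model", "demo_exp"          # hypothetical stand-ins for opt.model_path / opt.exp
os.makedirs(model_path, exist_ok=True)

epoch, batch_i, total_batch = 1, 0, 0
check_pt_model_path = os.path.join(
    model_path,
    '%s.epoch=%d.batch=%d.total_batch=%d' % (exp, epoch, batch_i, total_batch) + '.model')

# torch.save accepts either a path or an open file object, as in the code above
torch.save(model.state_dict(), check_pt_model_path)

# reloading the checkpoint later
model.load_state_dict(torch.load(check_pt_model_path))
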
Example #4
def train_one_batch(batch, model, optimizer, opt, batch_i):
    # load one2many data
    """
    src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
    src_lens: a list containing the length of each src sequence in the batch, with len=batch_size
    src_mask: a FloatTensor, [batch, src_seq_len]
    src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
    trg: a LongTensor [batch, trg_seq_len]; each target trg[i] contains the indices of a set of concatenated keyphrases, separated by opt.word2idx[pykp.io.SEP_WORD]
         if opt.delimiter_type == 0, SEP_WORD=<sep>; if opt.delimiter_type == 1, SEP_WORD=<eok>
    trg_lens: a list containing the length of each trg sequence in the batch, with len=batch_size
    trg_mask: a FloatTensor, [batch, trg_seq_len]
    trg_oov: same as trg, but all unk words are replaced with temporary idx, e.g. 50000, 50001 etc.
    """
    src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
    # number of target keyphrases for each example in the batch; len(num_trgs) == batch_size
    num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

    max_num_oov = max([len(oov) for oov in oov_lists])  # max number of OOV words in this batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)

    optimizer.zero_grad()

    start_time = time.time()

    decoder_dist, attention_dist = model(src,
                                         src_lens,
                                         trg,
                                         src_oov,
                                         max_num_oov,
                                         src_mask,
                                         num_trgs=num_trgs)
    forward_time = time_since(start_time)

    start_time = time.time()

    loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask)

    loss_compute_time = time_since(start_time)

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_str_list)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_str_2dlist)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            model.parameters(), opt.max_grad_norm)

    optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(),
                          total_trg_tokens,
                          n_batch=1,
                          forward_time=forward_time,
                          loss_compute_time=loss_compute_time,
                          backward_time=backward_time)

    return stat, decoder_dist.detach()
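
The tail of train_one_batch normalizes the summed loss, backpropagates, clips gradients, and steps the optimizer. The toy snippet below isolates just that sequence with a stand-in model and loss; all names and values here are illustrative assumptions.

import torch
import torch.nn as nn

# Toy illustration of the normalize -> backward -> clip -> step sequence above.
model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(2, 8)
loss = model(x).pow(2).sum()            # stand-in for the summed cross-entropy
total_trg_tokens = 10                   # stand-in for sum(trg_lens)

optimizer.zero_grad()
loss.div(total_trg_tokens).backward()   # normalize by token count before backprop
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
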
Example #5
def evaluate_loss(data_loader, model, ntm_model, opt):
    model.eval()
    ntm_model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0
    loss_compute_time_total = 0.0
    forward_time_total = 0.0
    print("Evaluate loss for %d batches" % len(data_loader))
    with torch.no_grad():
        for batch_i, batch in enumerate(data_loader):
            if not opt.one2many:  # load one2one dataset
                src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, src_bow = batch
            else:  # load one2many dataset
                src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch
                # number of target keyphrases for each example in the batch; len(num_trgs) == batch_size
                num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]

            max_num_oov = max([len(oov) for oov in oov_lists])  # max number of OOV words in this batch

            batch_size = src.size(0)
            n_batch += batch_size

            # move data to GPU if available
            src = src.to(opt.device)
            src_mask = src_mask.to(opt.device)
            trg = trg.to(opt.device)
            trg_mask = trg_mask.to(opt.device)
            src_oov = src_oov.to(opt.device)
            trg_oov = trg_oov.to(opt.device)

            if opt.use_topic_represent:
                src_bow = src_bow.to(opt.device)
                src_bow_norm = F.normalize(src_bow)
                if opt.topic_type == 'z':
                    topic_represent, _, _, _, _ = ntm_model(src_bow_norm)
                else:
                    _, topic_represent, _, _, _ = ntm_model(src_bow_norm)
            else:
                topic_represent = None

            start_time = time.time()

            # one2one setting
            decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ \
                = model(src, src_lens, trg, src_oov, max_num_oov, src_mask, topic_represent)

            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()
            if opt.copy_attention:  # Compute the loss using target with oov words
                loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask, trg_lens,
                                            opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage,
                                            coverage_loss=False)
            else:  # Compute the loss using target without oov words
                loss = masked_cross_entropy(decoder_dist, trg, trg_mask, trg_lens,
                                            opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage,
                                            coverage_loss=False)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += sum(trg_lens)

            if (batch_i + 1) % (len(data_loader) // 5) == 0:
                print("Valid: %d/%d batches, current avg loss: %.3f" %
                      ((batch_i + 1), len(data_loader), evaluation_loss_sum / total_trg_tokens))

    eval_loss_stat = LossStatistics(evaluation_loss_sum, total_trg_tokens, n_batch, forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat
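
F.normalize, as applied to src_bow above, defaults to L2 normalization along dim=1, so each document's bag-of-words vector is scaled to unit norm before being fed to the ntm_model. A quick standalone check:

import torch
import torch.nn.functional as F

# F.normalize performs L2 normalization along dim=1 by default.
src_bow = torch.tensor([[3.0, 0.0, 4.0],
                        [1.0, 1.0, 0.0]])
src_bow_norm = F.normalize(src_bow)
print(src_bow_norm)  # first row becomes [0.6, 0.0, 0.8]
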
Example #6
def train_one_batch(batch,
                    model,
                    optimizer,
                    opt,
                    batch_i,
                    source_representation_queue=None):
    if not opt.one2many:  # load one2one data
        src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, title, title_oov, title_lens, title_mask = batch
        """
        src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
        src_lens: a list containing the length of each src sequence in the batch, with len=batch_size
        src_mask: a FloatTensor, [batch, src_seq_len]
        trg: a LongTensor containing the word indices of target sentences, [batch, trg_seq_len]
        trg_lens: a list containing the length of each trg sequence in the batch, with len=batch_size
        trg_mask: a FloatTensor, [batch, trg_seq_len]
        src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
        trg_oov: a LongTensor containing the word indices of target sentences, [batch, trg_seq_len], contains the index of oov words (used by copy)
        """
    else:  # load one2many data
        src, src_lens, src_mask, src_oov, oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, title, title_oov, title_lens, title_mask = batch
        # number of target keyphrases for each example in the batch; len(num_trgs) == batch_size
        num_trgs = [len(trg_str_list) for trg_str_list in trg_str_2dlist]
        """
        trg: a LongTensor [batch, trg_seq_len]; each target trg[i] contains the indices of a set of concatenated keyphrases, separated by opt.word2idx[pykp.io.SEP_WORD]
             if opt.delimiter_type == 0, SEP_WORD=<sep>; if opt.delimiter_type == 1, SEP_WORD=<eos>
        trg_oov: same as trg, but all unk words are replaced with temporary idx, e.g. 50000, 50001 etc.
        """
    batch_size = src.size(0)
    max_num_oov = max([len(oov) for oov in oov_lists])  # max number of OOV words in this batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)
    if opt.title_guided:
        title = title.to(opt.device)
        title_mask = title_mask.to(opt.device)
        #title_oov = title_oov.to(opt.device)
    # title, title_oov, title_lens, title_mask

    optimizer.zero_grad()

    #if opt.one2many_mode == 0 or opt.one2many_mode == 1:
    start_time = time.time()

    if opt.use_target_encoder:  # Sample encoder representations
        if len(source_representation_queue) < opt.source_representation_sample_size:
            source_representation_samples_2dlist = None
            source_representation_target_list = None
        else:
            source_representation_samples_2dlist = []
            source_representation_target_list = []
            for i in range(batch_size):
                # sample N encoder representations from the queue
                source_representation_samples_list = source_representation_queue.sample(
                    opt.source_representation_sample_size)
                # insert a place-holder for the ground-truth source representation at a random index
                place_holder_idx = np.random.randint(
                    0, opt.source_representation_sample_size + 1)
                source_representation_samples_list.insert(
                    place_holder_idx, None)  # len=N+1
                # insert the sample list of one batch to the 2d list
                source_representation_samples_2dlist.append(
                    source_representation_samples_list)
                # store the idx of place-holder for that batch
                source_representation_target_list.append(place_holder_idx)
    else:
        source_representation_samples_2dlist = None
        source_representation_target_list = None
        """
        if encoder_representation_samples_2dlist[0] is None and batch_i > math.ceil(
                opt.encoder_representation_sample_size / batch_size):
            # a return value of none indicates we don't have sufficient samples
            # it will only occurs in the first few training steps
            raise ValueError("encoder_representation_samples should not be none at this batch!")
        """

    if not opt.one2many:
        decoder_dist, h_t, attention_dist, encoder_final_state, coverage, delimiter_decoder_states, delimiter_decoder_states_lens, source_classification_dist = model(
            src,
            src_lens,
            trg,
            src_oov,
            max_num_oov,
            src_mask,
            sampled_source_representation_2dlist=
            source_representation_samples_2dlist,
            source_representation_target_list=source_representation_target_list,
            title=title,
            title_lens=title_lens,
            title_mask=title_mask)
    else:
        decoder_dist, h_t, attention_dist, encoder_final_state, coverage, delimiter_decoder_states, delimiter_decoder_states_lens, source_classification_dist = model(
            src,
            src_lens,
            trg,
            src_oov,
            max_num_oov,
            src_mask,
            num_trgs=num_trgs,
            sampled_source_representation_2dlist=
            source_representation_samples_2dlist,
            source_representation_target_list=source_representation_target_list,
            title=title,
            title_lens=title_lens,
            title_mask=title_mask)
    forward_time = time_since(start_time)

    if opt.use_target_encoder:  # Put all the encoder final states into the queue; they must be detached first
        # encoder_final_state: [batch, memory_bank_size]
        for i in range(batch_size):
            source_representation_queue.put(encoder_final_state[i, :].detach())

    start_time = time.time()
    if opt.copy_attention:  # Compute the loss using target with oov words
        loss = masked_cross_entropy(
            decoder_dist, trg_oov, trg_mask, trg_lens, opt.coverage_attn,
            coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss,
            delimiter_decoder_states, opt.orthogonal_loss,
            opt.lambda_orthogonal, delimiter_decoder_states_lens)
    else:  # Compute the loss using target without oov words
        loss = masked_cross_entropy(
            decoder_dist, trg, trg_mask, trg_lens, opt.coverage_attn, coverage,
            attention_dist, opt.lambda_coverage, opt.coverage_loss,
            delimiter_decoder_states, opt.orthogonal_loss,
            opt.lambda_orthogonal, delimiter_decoder_states_lens)

    loss_compute_time = time_since(start_time)

    #else:  # opt.one2many_mode == 2
    #    forward_time = 0
    #    loss_compute_time = 0
    #    # TODO: a for loop to accumulate loss for each keyphrase
    #    # TODO: meanwhile, accumulate the forward time and loss_compute time
    #    pass

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_str_list)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_str_2dlist)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(h_t)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(
            model.parameters(), opt.max_grad_norm)
        # grad_norm_after_clipping = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
        # logging.info('clip grad (%f -> %f)' % (grad_norm_before_clipping, grad_norm_after_clipping))

    optimizer.step()

    # Compute target encoder loss
    if opt.use_target_encoder and source_classification_dist is not None:
        start_time = time.time()
        optimizer.zero_grad()
        # convert source_representation_target_list to a LongTensor with size=[batch_size, max_num_delimiters]
        max_num_delimiters = delimiter_decoder_states.size(2)
        source_representation_target = torch.LongTensor(
            source_representation_target_list).to(trg.device)  # [batch_size]
        # expand along the second dimension, since the target is the same for every delimiter state within an example
        source_representation_target = source_representation_target.view(
            -1,
            1).repeat(1,
                      max_num_delimiters)  # [batch_size, max_num_delimiters]
        # mask for source representation classification
        source_representation_target_mask = torch.zeros(
            batch_size, max_num_delimiters).to(trg.device)
        for i in range(batch_size):
            source_representation_target_mask[
                i, :delimiter_decoder_states_lens[i]].fill_(1)
        # compute the masked loss
        loss_te = masked_cross_entropy(source_classification_dist,
                                       source_representation_target,
                                       source_representation_target_mask)
        loss_compute_time += time_since(start_time)
        # back propagation on the normalized loss
        start_time = time.time()
        loss_te.div(normalization).backward()
        backward_time += time_since(start_time)

        if opt.max_grad_norm > 0:
            grad_norm_before_clipping = nn.utils.clip_grad_norm_(
                model.parameters(), opt.max_grad_norm)

        optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(),
                          total_trg_tokens,
                          n_batch=1,
                          forward_time=forward_time,
                          loss_compute_time=loss_compute_time,
                          backward_time=backward_time)

    return stat, decoder_dist.detach()
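
source_representation_queue is only used through len(), put(), and sample() above. The class below is a hypothetical sketch consistent with that usage (a bounded FIFO of detached encoder states with random negative sampling); the repository's actual SourceRepresentationQueue may be implemented differently.

import random
from collections import deque


class SourceRepresentationQueueSketch:
    """Hypothetical sketch of the queue of detached encoder states used above."""

    def __init__(self, max_size):
        self.queue = deque(maxlen=max_size)   # oldest representations are dropped first

    def __len__(self):
        return len(self.queue)

    def put(self, representation):
        # representation: a detached [memory_bank_size] tensor
        self.queue.append(representation)

    def sample(self, sample_size):
        # draw sample_size negative representations without replacement
        return random.sample(list(self.queue), sample_size)
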
Example #7
def train_model(model, optimizer_ml, optimizer_rl, criterion,
                train_data_loader, valid_data_loader, opt):
    '''
    generator = SequenceGenerator(model,
                                  eos_idx=opt.word2idx[pykp.io.EOS_WORD],
                                  beam_size=opt.beam_size,
                                  max_sequence_length=opt.max_sent_length
                                  )
    '''
    logging.info(
        '======================  Start Training  =========================')

    total_batch = -1
    early_stop_flag = False

    total_train_loss_statistics = LossStatistics()
    report_train_loss_statistics = LossStatistics()
    report_train_ppl = []
    report_valid_ppl = []
    report_train_loss = []
    report_valid_loss = []
    best_valid_ppl = float('inf')
    best_valid_loss = float('inf')
    num_stop_dropping = 0

    if opt.use_target_encoder:
        source_representation_queue = SourceRepresentationQueue(
            opt.source_representation_queue_size)
    else:
        source_representation_queue = None

    if opt.train_from:
        # TODO: load the training state
        raise ValueError(
            "Loading the training state from a trained model is not implemented yet")

    model.train()

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        if early_stop_flag:
            break

        # TODO: progress bar
        #progbar = Progbar(logger=logging, title='Training', target=len(train_data_loader), batch_size=train_data_loader.batch_size,total_examples=len(train_data_loader.dataset.examples))

        for batch_i, batch in enumerate(train_data_loader):
            total_batch += 1

            # Training
            if opt.train_ml:
                batch_loss_stat, decoder_dist = train_one_batch(
                    batch, model, optimizer_ml, opt, batch_i,
                    source_representation_queue)
                report_train_loss_statistics.update(batch_loss_stat)
                total_train_loss_statistics.update(batch_loss_stat)
                #logging.info("one_batch")
                #report_loss.append(('train_ml_loss', loss_ml))
                #report_loss.append(('PPL', loss_ml))

                # Brief report
                '''
                if batch_i % opt.report_every == 0:
                    brief_report(epoch, batch_i, one2one_batch, loss_ml, decoder_log_probs, opt)
                '''

            #progbar.update(epoch, batch_i, report_loss)

            # Checkpoint: decay the learning rate if the validation loss stops dropping,
            # and apply early stopping if it fails to improve for several checkpoints.
            # Save the model parameters whenever the validation loss improves.
            if total_batch % 4000 == 0:
                print("Epoch %d; batch: %d; total batch: %d" %
                      (epoch, batch_i, total_batch))
                sys.stdout.flush()

            if epoch >= opt.start_checkpoint_at:
                if (opt.checkpoint_interval == -1 and batch_i == len(train_data_loader) - 1) or \
                        (opt.checkpoint_interval > -1 and total_batch > 1 and total_batch % opt.checkpoint_interval == 0):
                    if opt.train_ml:
                        # test the model on the validation dataset for one epoch
                        valid_loss_stat = evaluate_loss(
                            valid_data_loader, model, opt)
                        model.train()
                        current_valid_loss = valid_loss_stat.xent()
                        current_valid_ppl = valid_loss_stat.ppl()
                        print("Enter check point!")
                        sys.stdout.flush()

                        current_train_ppl = report_train_loss_statistics.ppl()
                        current_train_loss = report_train_loss_statistics.xent(
                        )

                        # debug
                        if math.isnan(current_valid_loss) or math.isnan(
                                current_train_loss):
                            logging.info(
                                "NaN valid loss. Epoch: %d; batch_i: %d, total_batch: %d"
                                % (epoch, batch_i, total_batch))
                            exit()

                        if current_valid_loss < best_valid_loss:  # update the best valid loss and save the model parameters
                            print("Valid loss drops")
                            sys.stdout.flush()
                            best_valid_loss = current_valid_loss
                            best_valid_ppl = current_valid_ppl
                            num_stop_dropping = 0

                            check_pt_model_path = os.path.join(
                                opt.model_path,
                                '%s.epoch=%d.batch=%d.total_batch=%d' %
                                (opt.exp, epoch, batch_i, total_batch) +
                                '.model')
                            torch.save(  # save model parameters
                                model.state_dict(),
                                open(check_pt_model_path, 'wb'))
                            logging.info('Saving checkpoint to %s' %
                                         check_pt_model_path)

                        else:
                            print("Valid loss does not drop")
                            sys.stdout.flush()
                            num_stop_dropping += 1
                            # decay the learning rate by a factor
                            for i, param_group in enumerate(
                                    optimizer_ml.param_groups):
                                old_lr = float(param_group['lr'])
                                new_lr = old_lr * opt.learning_rate_decay
                                if old_lr - new_lr > EPS:
                                    param_group['lr'] = new_lr

                        # log loss, ppl, and time
                        #print("check point!")
                        #sys.stdout.flush()
                        logging.info(
                            'Epoch: %d; batch idx: %d; total batches: %d' %
                            (epoch, batch_i, total_batch))
                        logging.info(
                            'avg training ppl: %.3f; avg validation ppl: %.3f; best validation ppl: %.3f'
                            % (current_train_ppl, current_valid_ppl,
                               best_valid_ppl))
                        logging.info(
                            'avg training loss: %.3f; avg validation loss: %.3f; best validation loss: %.3f'
                            % (current_train_loss, current_valid_loss,
                               best_valid_loss))

                        report_train_ppl.append(current_train_ppl)
                        report_valid_ppl.append(current_valid_ppl)
                        report_train_loss.append(current_train_loss)
                        report_valid_loss.append(current_valid_loss)

                        if num_stop_dropping >= opt.early_stop_tolerance:
                            logging.info(
                                'Validation loss has not improved for %d checkpoints, stopping training early'
                                % num_stop_dropping)
                            early_stop_flag = True
                            break
                        report_train_loss_statistics.clear()

    # export the training curve
    train_valid_curve_path = opt.exp_path + '/train_valid_curve'
    export_train_and_valid_loss(report_train_loss, report_valid_loss,
                                report_train_ppl, report_valid_ppl,
                                opt.checkpoint_interval,
                                train_valid_curve_path)
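
When the validation loss stops improving, the loop above multiplies every parameter group's learning rate by opt.learning_rate_decay. The snippet below shows that decay step in isolation, with placeholder values for EPS and the decay factor.

import torch
import torch.nn as nn

# Toy illustration of the learning-rate decay applied when validation loss does not drop.
EPS = 1e-12
learning_rate_decay = 0.5               # stand-in for opt.learning_rate_decay
optimizer_ml = torch.optim.Adam(nn.Linear(4, 4).parameters(), lr=1e-3)

for param_group in optimizer_ml.param_groups:
    old_lr = float(param_group['lr'])
    new_lr = old_lr * learning_rate_decay
    if old_lr - new_lr > EPS:           # only update when the change is non-negligible
        param_group['lr'] = new_lr
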
Example #8
def train_model(model, ntm_model, optimizer_ml, optimizer_ntm, optimizer_whole, train_data_loader, valid_data_loader,
                bow_dictionary, train_bow_loader, valid_bow_loader, opt):
    logging.info('======================  Start Training  =========================')

    if opt.only_train_ntm or (opt.use_topic_represent and not opt.load_pretrain_ntm):
        print("\nWarming up ntm for %d epochs" % opt.ntm_warm_up_epochs)
        for epoch in range(1, opt.ntm_warm_up_epochs + 1):
            sparsity = train_ntm_one_epoch(ntm_model, train_bow_loader, optimizer_ntm, opt, epoch)
            val_loss = test_ntm_one_epoch(ntm_model, valid_bow_loader, opt, epoch)
            if epoch % 10 == 0:
                ntm_model.print_topic_words(bow_dictionary, os.path.join(opt.model_path, 'topwords_e%d.txt' % epoch))
                best_ntm_model_path = os.path.join(opt.model_path, 'e%d.val_loss=%.3f.sparsity=%.3f.ntm_model' %
                                                   (epoch, val_loss, sparsity))
                logging.info("\nSaving warm up ntm model into %s" % best_ntm_model_path)
                torch.save(ntm_model.state_dict(), open(best_ntm_model_path, 'wb'))
    elif opt.use_topic_represent:
        print("Loading ntm model from %s" % opt.check_pt_ntm_model_path)
        ntm_model.load_state_dict(torch.load(opt.check_pt_ntm_model_path))

    if opt.only_train_ntm:
        return

    total_batch = 0
    total_train_loss_statistics = LossStatistics()
    report_train_loss_statistics = LossStatistics()
    report_train_ppl = []
    report_valid_ppl = []
    report_train_loss = []
    report_valid_loss = []
    best_valid_ppl = float('inf')
    best_valid_loss = float('inf')
    best_ntm_valid_loss = float('inf')
    joint_train_patience = 1
    ntm_train_patience = 1
    global_patience = 5
    num_stop_dropping = 0
    num_stop_dropping_ntm = 0
    num_stop_dropping_global = 0

    t0 = time.time()
    Train_Seq2seq = True
    begin_iterate_train_ntm = opt.iterate_train_ntm
    check_pt_model_path = ""
    print("\nEntering main training for %d epochs" % opt.epochs)
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        if Train_Seq2seq:
            if epoch <= opt.p_seq2seq_e or not opt.joint_train:
                optimizer = optimizer_ml
                model.train()
                ntm_model.eval()
                logging.info("\nTraining seq2seq epoch: {}/{}".format(epoch, opt.epochs))
            elif begin_iterate_train_ntm:
                optimizer = optimizer_ntm
                model.train()
                ntm_model.train()
                fix_model(model)
                logging.info("\nTraining ntm epoch: {}/{}".format(epoch, opt.epochs))
                begin_iterate_train_ntm = False
            else:
                optimizer = optimizer_whole
                unfix_model(model)
                model.train()
                ntm_model.train()
                logging.info("\nTraining seq2seq+ntm epoch: {}/{}".format(epoch, opt.epochs))
                if opt.iterate_train_ntm:
                    begin_iterate_train_ntm = True

            logging.info("The total num of batches: %d, current learning rate:%.6f" %
                         (len(train_data_loader), optimizer.param_groups[0]['lr']))

            for batch_i, batch in enumerate(train_data_loader):
                total_batch += 1
                batch_loss_stat, _ = train_one_batch(batch, model, ntm_model, optimizer, opt, batch_i)
                report_train_loss_statistics.update(batch_loss_stat)
                total_train_loss_statistics.update(batch_loss_stat)

                if (batch_i + 1) % (len(train_data_loader) // 10) == 0:
                    print("Train: %d/%d batches, current avg loss: %.3f" %
                          ((batch_i + 1), len(train_data_loader), batch_loss_stat.xent()))

            current_train_ppl = report_train_loss_statistics.ppl()
            current_train_loss = report_train_loss_statistics.xent()

            # test the model on the validation dataset for one epoch
            model.eval()
            valid_loss_stat = evaluate_loss(valid_data_loader, model, ntm_model, opt)
            current_valid_loss = valid_loss_stat.xent()
            current_valid_ppl = valid_loss_stat.ppl()

            # debug
            if math.isnan(current_valid_loss) or math.isnan(current_train_loss):
                logging.info(
                    "NaN valid loss. Epoch: %d; batch_i: %d, total_batch: %d" % (epoch, batch_i, total_batch))
                exit()

            if current_valid_loss < best_valid_loss:  # update the best valid loss and save the model parameters
                print("Valid loss drops")
                sys.stdout.flush()
                best_valid_loss = current_valid_loss
                best_valid_ppl = current_valid_ppl
                num_stop_dropping = 0
                num_stop_dropping_global = 0
                if epoch >= opt.start_checkpoint_at and epoch > opt.p_seq2seq_e and not opt.save_each_epoch:
                    check_pt_model_path = os.path.join(opt.model_path, 'e%d.val_loss=%.3f.model-%s' %
                                                       (epoch, current_valid_loss, convert_time2str(time.time() - t0)))
                    # save model parameters
                    torch.save(
                        model.state_dict(),
                        open(check_pt_model_path, 'wb')
                    )
                    logging.info('Saving seq2seq checkpoints to %s' % check_pt_model_path)

                    if opt.joint_train:
                        check_pt_ntm_model_path = check_pt_model_path.replace('.model-', '.model_ntm-')
                        # save model parameters
                        torch.save(
                            ntm_model.state_dict(),
                            open(check_pt_ntm_model_path, 'wb')
                        )
                        logging.info('Saving ntm checkpoints to %s' % check_pt_ntm_model_path)
            else:
                print("Valid loss does not drop")
                sys.stdout.flush()
                num_stop_dropping += 1
                num_stop_dropping_global += 1
                # decay the learning rate by a factor
                for i, param_group in enumerate(optimizer.param_groups):
                    old_lr = float(param_group['lr'])
                    new_lr = old_lr * opt.learning_rate_decay
                    if old_lr - new_lr > EPS:
                        param_group['lr'] = new_lr
                        print("The new learning rate for seq2seq is decayed to %.6f" % new_lr)

            if opt.save_each_epoch:
                check_pt_model_path = os.path.join(opt.model_path, 'e%d.train_loss=%.3f.val_loss=%.3f.model-%s' %
                                                   (epoch, current_train_loss, current_valid_loss,
                                                    convert_time2str(time.time() - t0)))
                torch.save(  # save model parameters
                    model.state_dict(),
                    open(check_pt_model_path, 'wb')
                )
                logging.info('Saving seq2seq checkpoints to %s' % check_pt_model_path)

                if opt.joint_train:
                    check_pt_ntm_model_path = check_pt_model_path.replace('.model-', '.model_ntm-')
                    torch.save(  # save model parameters
                        ntm_model.state_dict(),
                        open(check_pt_ntm_model_path, 'wb')
                    )
                    logging.info('Saving ntm checkpoints to %s' % check_pt_ntm_model_path)

            # log loss, ppl, and time

            logging.info('Epoch: %d; Time spent: %.2f' % (epoch, time.time() - t0))
            logging.info(
                'avg training ppl: %.3f; avg validation ppl: %.3f; best validation ppl: %.3f' % (
                    current_train_ppl, current_valid_ppl, best_valid_ppl))
            logging.info(
                'avg training loss: %.3f; avg validation loss: %.3f; best validation loss: %.3f' % (
                    current_train_loss, current_valid_loss, best_valid_loss))

            report_train_ppl.append(current_train_ppl)
            report_valid_ppl.append(current_valid_ppl)
            report_train_loss.append(current_train_loss)
            report_valid_loss.append(current_valid_loss)

            report_train_loss_statistics.clear()

            if not opt.save_each_epoch and num_stop_dropping >= opt.early_stop_tolerance:
                logging.info('Validation loss has not improved for %d checkpoints, stopping training early' % num_stop_dropping)

                break

                # if num_stop_dropping_global >= global_patience and opt.joint_train:
                #     logging.info('Reach global stoping dropping patience: %d' % global_patience)
                #     break

                # if num_stop_dropping >= joint_train_patience and opt.joint_train:
                #     Train_Seq2seq = False
                # num_stop_dropping_ntm = 0
                # break

        # else:
        #     logging.info("\nTraining ntm epoch: {}/{}".format(epoch, opt.epochs))
        #     logging.info("The total num of batches: {}".format(len(train_bow_loader)))
        #     sparsity = train_ntm_one_epoch(ntm_model, train_bow_loader, optimizer_ntm, opt, epoch)
        #     val_loss = test_ntm_one_epoch(ntm_model, valid_bow_loader, opt, epoch)
        #     if val_loss < best_ntm_valid_loss:
        #         print('Ntm loss drops...')
        #         best_ntm_valid_loss = val_loss
        #         num_stop_dropping_ntm = 0
        #         num_stop_dropping_global = 0
        #     else:
        #         print('Ntm loss does not drop...')
        #         num_stop_dropping_ntm += 1
        #         num_stop_dropping_global += 1
        #
        #     if num_stop_dropping_global > global_patience:
        #         logging.info('Reach global stoping dropping patience: %d' % global_patience)
        #         break
        #
        #     if num_stop_dropping_ntm >= ntm_train_patience:
        #         Train_Seq2seq = True
        #         num_stop_dropping = 0
        #         # continue
        #
        #     if opt.joint_train:
        #         ntm_model.print_topic_words(bow_dictionary, os.path.join(opt.model_path, 'topwords_e%d.txt' % epoch))

    return check_pt_model_path
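
fix_model and unfix_model are called above to alternate between training the ntm alone and joint training, but they are not defined in this section. The sketches below assume they simply toggle requires_grad on the seq2seq parameters; the real helpers may do more.

# Hypothetical sketches of fix_model / unfix_model: freeze the seq2seq parameters
# while only the ntm is updated, then unfreeze them again for joint training.
def fix_model_sketch(model):
    for param in model.parameters():
        param.requires_grad = False


def unfix_model_sketch(model):
    for param in model.parameters():
        param.requires_grad = True
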
Example #9
def train_one_batch(batch, model, ntm_model, optimizer, opt, batch_i):
    # train for one batch
    src, src_lens, src_mask, trg, trg_lens, trg_mask, src_oov, trg_oov, oov_lists, src_bow = batch
    max_num_oov = max([len(oov) for oov in oov_lists])  # max number of OOV words in this batch

    # move data to GPU if available
    src = src.to(opt.device)
    src_mask = src_mask.to(opt.device)
    trg = trg.to(opt.device)
    trg_mask = trg_mask.to(opt.device)
    src_oov = src_oov.to(opt.device)
    trg_oov = trg_oov.to(opt.device)

    # model.train()
    optimizer.zero_grad()

    if opt.use_topic_represent:
        src_bow = src_bow.to(opt.device)
        src_bow_norm = F.normalize(src_bow)
        if opt.topic_type == 'z':
            topic_represent, _, recon_batch, mu, logvar = ntm_model(src_bow_norm)
        else:
            _, topic_represent, recon_batch, mu, logvar = ntm_model(src_bow_norm)

        if opt.add_two_loss:
            ntm_loss = loss_function(recon_batch, src_bow, mu, logvar)
    else:
        topic_represent = None

    start_time = time.time()

    # for one2one setting
    decoder_dist, h_t, attention_dist, encoder_final_state, coverage, _, _, _ \
        = model(src, src_lens, trg, src_oov, max_num_oov, src_mask, topic_represent)

    forward_time = time_since(start_time)

    start_time = time.time()
    if opt.copy_attention:  # Compute the loss using target with oov words
        loss = masked_cross_entropy(decoder_dist, trg_oov, trg_mask, trg_lens,
                                    opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss)
    else:  # Compute the loss using target without oov words
        loss = masked_cross_entropy(decoder_dist, trg, trg_mask, trg_lens,
                                    opt.coverage_attn, coverage, attention_dist, opt.lambda_coverage, opt.coverage_loss)

    loss_compute_time = time_since(start_time)

    total_trg_tokens = sum(trg_lens)

    if math.isnan(loss.item()):
        print("Batch i: %d" % batch_i)
        print("src")
        print(src)
        print(src_oov)
        print(src_lens)
        print(src_mask)
        print("trg")
        print(trg)
        print(trg_oov)
        print(trg_lens)
        print(trg_mask)
        print("oov list")
        print(oov_lists)
        print("Decoder")
        print(decoder_dist)
        print(h_t)
        print(attention_dist)
        raise ValueError("Loss is NaN")

    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = src.size(0)
    else:
        raise ValueError('The type of loss normalization is invalid.')

    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    if opt.add_two_loss:
        loss += ntm_loss
    # back propagation on the normalized loss
    loss.div(normalization).backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        grad_norm_before_clipping = nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)

    optimizer.step()

    # construct a statistic object for the loss
    stat = LossStatistics(loss.item(), total_trg_tokens, n_batch=1, forward_time=forward_time,
                          loss_compute_time=loss_compute_time, backward_time=backward_time)

    return stat, decoder_dist.detach()
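
loss_function(recon_batch, src_bow, mu, logvar), used to form ntm_loss above, is not shown here. The sketch below is a generic VAE-style objective (bag-of-words reconstruction plus the closed-form KL divergence to a standard normal prior) written only to illustrate the expected inputs; it is an assumption, not the repository's actual loss_function.

import torch


def ntm_loss_sketch(recon_batch, src_bow, mu, logvar):
    """Hypothetical VAE-style objective for the neural topic model."""
    # reconstruction: how well the decoded distribution explains the observed bag-of-words
    recon_loss = -(src_bow * torch.log_softmax(recon_batch, dim=1)).sum()
    # KL(q(z|x) || N(0, I)) in closed form
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld
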