def dis_pre_train_step():
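    """Run one supervised pre-training step of the discriminator on a single
    labeled batch and return the batch loss."""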
    discriminator.train()
    lab_batch = next(labeled_train_loader)
    lab_token_seqs = lab_batch.content[0]
    lab_seq_lengths = np.array([len(seq) for seq in lab_token_seqs])
    labels = lab_batch.label
    lab_token_seqs = torch.from_numpy(np.transpose(lab_token_seqs.numpy()))
    labels = torch.from_numpy(np.transpose(labels.numpy()))
    num_lab_sample = lab_token_seqs.shape[1]
    lab_hidden = discriminator.init_hidden(num_lab_sample)
    lab_output = discriminator(lab_token_seqs, lab_hidden, lab_seq_lengths)
    lab_element_loss = criterion(lab_output, labels)
    lab_loss = torch.mean(lab_element_loss)
    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because, by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called.
    dis_optimizer.zero_grad()

    lab_loss.backward()

    # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), args.clip)
    dis_optimizer.step()

    return lab_loss
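

# A minimal usage sketch (not part of the original script): the pre-training
# step above would typically be repeated for a fixed number of iterations
# before adversarial training starts. `num_iters` and `log_every` are
# hypothetical parameters.
def pretrain_discriminator(num_iters, log_every=100):
    for it in range(1, num_iters + 1):
        pretrain_loss = dis_pre_train_step()
        if it % log_every == 0:
            print('| pretrain iter {:5d} | loss {:5.4f}'.format(
                it, pretrain_loss.item()))

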
def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.
    start_time = time.time()
    for i_batch, sample_batched in enumerate(train_loader):
        # the sample batched has the following information
        # {token_seqs, next_token_seqs, importance_seqs, labels, seq_lengths, pad_length}
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        token_seqs = torch.from_numpy(np.transpose(
            sample_batched[0])).to(device)
        labels = torch.from_numpy(np.transpose(sample_batched[3])).to(device)
        seq_lengths = np.transpose(sample_batched[4])
        hidden = model.init_hidden(token_seqs.shape[1])
        output = model(token_seqs, hidden, seq_lengths)
        element_loss = criterion(output, labels)
        loss = torch.mean(element_loss)
        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable
        # weights of the model). This is because, by default, gradients are
        # accumulated in buffers (i.e., not overwritten) whenever .backward()
        # is called.
        optimizer.zero_grad()

        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        total_loss += loss.item()

        if i_batch % args.log_interval == 0 and i_batch > 0:
            cur_loss = total_loss / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  'loss {:5.4f} | ppl {:8.2f}'.format(
                      epoch, i_batch,
                      len(train_data) // args.batch_size,
                      elapsed * 1000 / args.log_interval, cur_loss,
                      math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           collate_fn=data.collate_fn)

print('The size of the dictionary is', len(Corpus_Dic))

###############################################################################
# Build the model
###############################################################################
learning_rate = args.lr

ntokens = len(Corpus_Dic)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                       args.nlayers, args.nclass, args.dropout_em,
                       args.dropout_rnn, args.dropout_cl, args.tied).to(device)

criterion = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=10,
                                            gamma=args.reduce_rate)
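

# Sketch (assumed, not the original driver code) of how these objects fit
# together: one call to train() per epoch, followed by a scheduler step so the
# learning rate decays by args.reduce_rate every 10 epochs. `num_epochs` is a
# hypothetical parameter; train() reads the module-level `epoch` when logging.
def run_supervised_training(num_epochs):
    global epoch
    for epoch in range(1, num_epochs + 1):
        train()
        scheduler.step()
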

###############################################################################
# Training code
###############################################################################


def adv_train_step(judge_only=True):
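    """Run one adversarial training step: update the judge, and, unless
    judge_only is set, also update the discriminator using judge-weighted
    pseudo-labeled losses."""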
    discriminator.train()
    judger.train()

    # {token_seqs, next_token_seqs, importance_seqs, labels, seq_lengths, pad_length}
    # Sample m labeled instances from DL
    lab_batch = next(labeled_train_loader)
    lab_token_seqs = lab_batch.content[0]
    lab_seq_lengths = np.array([len(seq) for seq in lab_token_seqs])
    labels = lab_batch.label
    lab_token_seqs = torch.from_numpy(np.transpose(lab_token_seqs.numpy()))
    labels = torch.from_numpy(np.transpose(labels.numpy()))
    num_lab_sample = lab_token_seqs.shape[1]

    # Sample m labeled instances from DU and predict their corresponding label
    unl_batch = next(unlabeled_train_loader)
    unl_token_seqs = unl_batch.content[0]
    unl_seq_lengths = [len(seq) for seq in unl_token_seqs]
    unl_token_seqs = torch.from_numpy(np.transpose(unl_token_seqs.numpy()))
    num_unl_sample = unl_token_seqs.shape[1]
    unl_hidden = discriminator.init_hidden(num_unl_sample)
    unl_output = discriminator(unl_token_seqs, unl_hidden, unl_seq_lengths)
    _, fake_labels = torch.max(unl_output, 1)
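    # fake_labels are the discriminator's argmax predictions, used below as
    # pseudo-labels for the unlabeled instances.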

    if judge_only:
        k = 1
    else:
        k = 3

    for _k in range(k):
        # Update the judge model
        ###############################################################################
        lab_judge_hidden = judger.init_hidden(num_lab_sample)
        one_hot_label = one_hot_embedding(labels,
                                          args.nclass)  # one hot encoder
        lab_judge_prob = judger(lab_token_seqs, lab_judge_hidden,
                                lab_seq_lengths, one_hot_label)
        lab_labeled = torch.ones(num_lab_sample)

        unl_judge_hidden = judger.init_hidden(num_unl_sample)
        one_hot_unl = one_hot_embedding(fake_labels,
                                        args.nclass)  # one hot encoder
        unl_judge_prob = judger(unl_token_seqs, unl_judge_hidden,
                                unl_seq_lengths, one_hot_unl)
        unl_labeled = torch.zeros(num_unl_sample)
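        # The judge is trained to output 1 for genuinely labeled instances and
        # 0 for pseudo-labeled (unlabeled) ones.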

        if_labeled = torch.cat((lab_labeled, unl_labeled))
        all_judge_prob = torch.cat((lab_judge_prob, unl_judge_prob))
        all_judge_prob = all_judge_prob.view(-1)
        judge_loss = criterion_judge(all_judge_prob, if_labeled)
        judge_optimizer.zero_grad()

        judge_loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(judger.parameters(), args.clip)
        judge_optimizer.step()

        unl_loss_value = 0.0
        lab_loss_value = 0.0
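        # Detach the pseudo-labels and judge probabilities from the judge's
        # graph so the discriminator update below does not backpropagate
        # through the judge.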
        fake_labels = repackage_hidden(fake_labels)
        unl_judge_prob = repackage_hidden(unl_judge_prob)
        if not judge_only:
            # Update the predictor
            ###############################################################################
            lab_hidden = discriminator.init_hidden(num_lab_sample)
            lab_output = discriminator(lab_token_seqs, lab_hidden,
                                       lab_seq_lengths)
            lab_element_loss = criterion(lab_output, labels)
            lab_loss = torch.mean(lab_element_loss)

            # calculate loss for unlabeled instances
            unl_hidden = discriminator.init_hidden(num_unl_sample)
            unl_output = discriminator(unl_token_seqs, unl_hidden,
                                       unl_seq_lengths)
            unl_element_loss = criterion(unl_output, fake_labels)
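            # Weight each unlabeled instance's loss by the judge's probability
            # that it looks like a labeled example, then average over the batch.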
            unl_loss = unl_element_loss.dot(
                unl_judge_prob.view(-1)) / num_unl_sample
            # do not include this in version 1
            if _k < int(k / 2):
                lab_unl_loss = lab_loss + unl_loss
            else:
                lab_unl_loss = unl_loss
            dis_optimizer.zero_grad()
            lab_unl_loss.backward()
            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip)
            dis_optimizer.step()

            unl_loss_value = unl_loss.item()
            lab_loss_value = lab_loss.item()

    return judge_loss, unl_loss_value, lab_loss_value
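

# For reference: `one_hot_embedding` and `repackage_hidden` are used above but
# defined elsewhere in the project (e.g. in a utilities module). Plausible
# implementations (assumptions, not this project's own code) would look like:
#
#     def one_hot_embedding(labels, num_classes):
#         # Index rows of an identity matrix to obtain a (batch, num_classes)
#         # one-hot float tensor.
#         return torch.eye(num_classes)[labels]
#
#     def repackage_hidden(h):
#         # Detach tensors (or nested tuples of tensors) from their history so
#         # later backward passes do not propagate through the old graph.
#         if isinstance(h, torch.Tensor):
#             return h.detach()
#         return tuple(repackage_hidden(v) for v in h)
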
(ntokens, embedding_vectors, labeled_train_loader, unlabeled_train_loader,
 valid_loader, test_loader, labeled_data_length,
 unlabeled_data_length) = dataset.load(args.embedding)

discriminator = discriminator.RNNModel(args.model, ntokens, args.emsize,
                                       args.nhid, args.nlayers, args.nclass,
                                       embedding_vectors, args.dropout_em,
                                       args.dropout_rnn, args.dropout_cl,
                                       args.tied)
judger = judge.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                        args.nlayers, args.nclass, embedding_vectors,
                        args.dropout_em, args.dropout_rnn, args.dropout_cl,
                        args.tied)

# Per-element cross-entropy for the discriminator, so unlabeled losses can be
# re-weighted by the judge before averaging.
criterion = nn.CrossEntropyLoss(reduction='none')
# The judge outputs a per-instance probability compared against 0/1 targets,
# so binary cross-entropy (mean reduction, yielding a scalar for .backward())
# is the appropriate criterion here.
criterion_judge = nn.BCELoss()
dis_optimizer = torch.optim.Adam(discriminator.parameters(),
                                 lr=dis_learning_rate,
                                 weight_decay=0.005)
dis_scheduler = torch.optim.lr_scheduler.StepLR(dis_optimizer,
                                                step_size=10,
                                                gamma=args.reduce_rate)

judge_optimizer = torch.optim.Adam(judger.parameters(),
                                   lr=judge_learning_rate,
                                   weight_decay=0.005)
judge_scheduler = torch.optim.lr_scheduler.StepLR(judge_optimizer,
                                                  step_size=5,
                                                  gamma=args.reduce_rate)
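

# Sketch (assumed, not the original training loop) of one adversarial epoch:
# warm the judge up with judge_only=True, then update both networks, and step
# both schedulers once per epoch. The step counts are hypothetical.
def run_adversarial_epoch(num_steps, judge_warmup_steps=100):
    for step in range(num_steps):
        judge_only = step < judge_warmup_steps
        judge_loss, unl_loss, lab_loss = adv_train_step(judge_only=judge_only)
        if step % 100 == 0:
            print('| step {:5d} | judge {:5.4f} | unl {:5.4f} | lab {:5.4f}'
                  .format(step, judge_loss.item(), unl_loss, lab_loss))
    dis_scheduler.step()
    judge_scheduler.step()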

###############################################################################
# Training code