def dis_pre_train_step():
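    """Run one supervised pre-training step of the discriminator on a labeled
    batch and return the mean classification loss."""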
    discriminator.train()
    lab_batch = next(labeled_train_loader)
    lab_token_seqs = lab_batch.content[0]
    lab_seq_lengths = np.array([len(seq) for seq in lab_token_seqs])
    labels = lab_batch.label
    lab_token_seqs = torch.from_numpy(np.transpose(lab_token_seqs.numpy()))
    labels = torch.from_numpy(np.transpose(labels.numpy()))
    num_lab_sample = lab_token_seqs.shape[1]
    lab_hidden = discriminator.init_hidden(num_lab_sample)
    lab_output = discriminator(lab_token_seqs, lab_hidden, lab_seq_lengths)
    lab_element_loss = criterion(lab_output, labels)
    lab_loss = torch.mean(lab_element_loss)
    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because, by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called.
    dis_optimizer.zero_grad()

    lab_loss.backward()

    # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), args.clip)
    dis_optimizer.step()

    return lab_loss


def evaluate(test=False):
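    """Evaluate the discriminator on the validation set (or the test set when
    ``test`` is True), log accuracy, recall and precision, and return the
    accuracy as a fraction."""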
    # Turn on evaluation mode, which disables dropout.
    correct = 0
    total = 0
    discriminator.eval()
    current_loader = valid_loader
    current_length = valid_length
    if test:
        current_loader = test_loader
        current_length = test_length
    with torch.no_grad():
        for i_batch in range(current_length):
            sample_batched = next(current_loader)
            token_seqs = sample_batched.content[0]
            seq_lengths = np.array([len(seq) for seq in token_seqs])
            labels = sample_batched.label
            token_seqs = torch.from_numpy(np.transpose(
                token_seqs.numpy())).cuda(env_settings.CUDA_DEVICE)
            labels = torch.from_numpy(np.transpose(labels.numpy())).cuda(
                env_settings.CUDA_DEVICE)
            hidden = discriminator.init_hidden(token_seqs.shape[1])
            output = discriminator(token_seqs, hidden, seq_lengths)
            _, predict_class = torch.max(output, 1)
            total += labels.size(0)
            correct += (predict_class == labels).sum().item()

            # Feed each (prediction, label) pair to the metrics handler.
            for i_metric in range(predict_class.size(0)):
                metrics_handler.metricsHandler.update(
                    predict_class[i_metric].item(),
                    labels[i_metric].item())
        test_acc = 100 * correct / total
        print(
            'Accuracy of the classifier on the {} data is : {:5.4f}'.format(
                'test' if test else 'validation', test_acc))

        if test:
            output_handler.outputFileHandler.write(
                f'Test Acc: {test_acc:.2f}%\n')
            output_handler.outputFileHandler.write(
                f'Test recall: {metrics_handler.metricsHandler.getRecall():.3f}%\n'
            )
            output_handler.outputFileHandler.write(
                f'Test precision: {metrics_handler.metricsHandler.getPrecision():.3f}%\n'
            )
        else:
            output_handler.outputFileHandler.write(
                f'Valid Acc: {test_acc:.2f}%\n')
            output_handler.outputFileHandler.write(
                f'Valid recall: {metrics_handler.metricsHandler.getRecall():.3f}%\n'
            )
            output_handler.outputFileHandler.write(
                f'Valid precision: {metrics_handler.metricsHandler.getPrecision():.3f}%\n'
            )
        return correct / total


def train():
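    """Train ``model`` for one epoch over ``train_loader``, logging the running
    loss every ``args.log_interval`` batches."""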
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.
    start_time = time.time()
    for i_batch, sample_batched in enumerate(train_loader):
        # Each sample_batched contains:
        # {token_seqs, next_token_seqs, importance_seqs, labels, seq_lengths, pad_length}
        # A fresh hidden state is created for every batch so that the model does
        # not try to backpropagate all the way to the start of the dataset.
        token_seqs = torch.from_numpy(np.transpose(
            sample_batched[0])).to(device)
        labels = torch.from_numpy(np.transpose(sample_batched[3])).to(device)
        seq_lengths = np.transpose(sample_batched[4])
        hidden = model.init_hidden(token_seqs.shape[1])
        output = model(token_seqs, hidden, seq_lengths)
        element_loss = criterion(output, labels)
        loss = torch.mean(element_loss)
        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable
        # weights of the model). This is because, by default, gradients are
        # accumulated in buffers (i.e., not overwritten) whenever .backward()
        # is called.
        optimizer.zero_grad()

        loss.backward()

        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        total_loss += loss.item()

        if i_batch % args.log_interval == 0 and i_batch > 0:
            cur_loss = total_loss / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  'loss {:5.4f} | ppl {:8.2f}'.format(
                      epoch, i_batch,
                      len(train_data) // args.batch_size,
                      elapsed * 1000 / args.log_interval, cur_loss,
                      math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()


def evaluate():
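    """Evaluate ``model`` on the test set and return the accuracy as a fraction."""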
    # Turn on evaluation mode, which disables dropout.
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for i_batch, sample_batched in enumerate(test_loader):
            token_seqs = torch.from_numpy(np.transpose(
                sample_batched[0])).to(device)
            labels = torch.from_numpy(np.transpose(
                sample_batched[3])).to(device)
            seq_lengths = np.transpose(sample_batched[4])
            hidden = model.init_hidden(token_seqs.shape[1])
            output = model(token_seqs, hidden, seq_lengths)
            _, predict_class = torch.max(output, 1)
            total += labels.size(0)
            correct += (predict_class == labels).sum().item()
        print(
            'Accuracy of the classifier on the test data is : {:5.4f}'.format(
                100 * correct / total))
        return correct / total


def adv_train_step(judge_only=True):
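    """Run one adversarial training step.

    The judge is updated to separate genuinely labeled instances from instances
    carrying pseudo-labels predicted by the discriminator.  Unless ``judge_only``
    is True, the discriminator is then also updated using the labeled loss and
    the judge-weighted pseudo-labeled loss.  Returns the judge loss together
    with the most recent unlabeled and labeled loss values.
    """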
    discriminator.train()
    judger.train()

    # {token_seqs, next_token_seqs, importance_seqs, labels, seq_lengths, pad_length}
    # Sample m labeled instances from DL
    lab_batch = next(labeled_train_loader)
    lab_token_seqs = lab_batch.content[0]
    lab_seq_lengths = np.array([len(seq) for seq in lab_token_seqs])
    labels = lab_batch.label
    lab_token_seqs = torch.from_numpy(np.transpose(lab_token_seqs.numpy()))
    labels = torch.from_numpy(np.transpose(labels.numpy()))
    num_lab_sample = lab_token_seqs.shape[1]

    # Sample m unlabeled instances from DU and use the current discriminator to
    # predict their (pseudo-)labels
    unl_batch = next(unlabeled_train_loader)
    unl_token_seqs = unl_batch.content[0]
    unl_seq_lengths = [len(seq) for seq in unl_token_seqs]
    unl_token_seqs = torch.from_numpy(np.transpose(unl_token_seqs.numpy()))
    num_unl_sample = unl_token_seqs.shape[1]
    unl_hidden = discriminator.init_hidden(num_unl_sample)
    unl_output = discriminator(unl_token_seqs, unl_hidden, unl_seq_lengths)
    _, fake_labels = torch.max(unl_output, 1)

    # One inner update iteration when only the judge is trained, three when the
    # discriminator is updated as well.
    k = 1 if judge_only else 3

    for _k in range(k):
        # Update the judge model
        ###############################################################################
        lab_judge_hidden = judger.init_hidden(num_lab_sample)
        one_hot_label = one_hot_embedding(labels,
                                          args.nclass)  # one hot encoder
        lab_judge_prob = judger(lab_token_seqs, lab_judge_hidden,
                                lab_seq_lengths, one_hot_label)
        lab_labeled = torch.ones(num_lab_sample)

        unl_judge_hidden = judger.init_hidden(num_unl_sample)
        one_hot_unl = one_hot_embedding(fake_labels,
                                        args.nclass)  # one hot encoder
        unl_judge_prob = judger(unl_token_seqs, unl_judge_hidden,
                                unl_seq_lengths, one_hot_unl)
        unl_labeled = torch.zeros(num_unl_sample)

        if_labeled = torch.cat((lab_labeled, unl_labeled))
        all_judge_prob = torch.cat((lab_judge_prob, unl_judge_prob))
        all_judge_prob = all_judge_prob.view(-1)
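        # Judge objective: predict whether each instance is genuinely labeled
        # (target 1) or carries a pseudo-label (target 0).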
        judge_loss = criterion_judge(all_judge_prob, if_labeled)
        judge_optimizer.zero_grad()

        judge_loss.backward()

        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(judger.parameters(), args.clip)
        judge_optimizer.step()

        unl_loss_value = 0.0
        lab_loss_value = 0.0
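        # repackage_hidden detaches these tensors from the judge's computation
        # graph, so the discriminator update below cannot backpropagate into the judge.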
        fake_labels = repackage_hidden(fake_labels)
        unl_judge_prob = repackage_hidden(unl_judge_prob)
        if not judge_only:
            # Update the predictor
            ###############################################################################
            lab_hidden = discriminator.init_hidden(num_lab_sample)
            lab_output = discriminator(lab_token_seqs, lab_hidden,
                                       lab_seq_lengths)
            lab_element_loss = criterion(lab_output, labels)
            lab_loss = torch.mean(lab_element_loss)

            # calculate loss for unlabeled instances
            unl_hidden = discriminator.init_hidden(num_unl_sample)
            unl_output = discriminator(unl_token_seqs, unl_hidden,
                                       unl_seq_lengths)
            unl_element_loss = criterion(unl_output, fake_labels)
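            # Weight each pseudo-labeled instance's loss by the judge's probability
            # that the (instance, label) pair comes from the labeled set, then average.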
            unl_loss = unl_element_loss.dot(
                unl_judge_prob.view(-1)) / num_unl_sample
            # do not include this in version 1
            if _k < int(k / 2):
                lab_unl_loss = lab_loss + unl_loss
            else:
                lab_unl_loss = unl_loss
            dis_optimizer.zero_grad()
            lab_unl_loss.backward()
            # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip)
            dis_optimizer.step()

            unl_loss_value = unl_loss.item()
            lab_loss_value = lab_loss.item()

    return judge_loss, unl_loss_value, lab_loss_value
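

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training code).  It assumes
# the globals referenced above (args, the data loaders, optimizers, handlers)
# are already constructed; the function name, its parameters, and the choice to
# warm the judge up alone in the first adversarial epoch are illustrative
# assumptions, and `evaluate` here refers to the `evaluate(test=...)` variant
# defined above.
# ---------------------------------------------------------------------------
def run_training_sketch(pre_epochs, adv_epochs, batches_per_epoch):
    # Phase 1: supervised pre-training of the discriminator on labeled batches.
    for _ in range(pre_epochs):
        for _ in range(batches_per_epoch):
            dis_pre_train_step()
        evaluate(test=False)

    # Phase 2: adversarial training; the judge is warmed up alone in epoch 0,
    # then the judge and the discriminator are updated together.
    for epoch in range(adv_epochs):
        for _ in range(batches_per_epoch):
            adv_train_step(judge_only=(epoch == 0))
        evaluate(test=False)

    # Final evaluation on the held-out test set.
    return evaluate(test=True)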