Example #1
def main():
    global checkpoint, waiting, best_loss, start_epoch
    # This script only evaluates a trained model, so a checkpoint is required.
    if checkpoint is None:
        print("No checkpoint found.")
        exit()

    # Vocabulary
    language = Language(file_path_tokens_map=file_path_tokens_map, file_path_vectors_map=file_path_vectors_map)
    vocab_size = language.get_n_tokens()

    # Dataset
    korean_dataset = KoreanDataset(file_path_data=file_path_data, file_path_tokens_map=file_path_tokens_map,
                                   max_len_sentence=max_len_sentence, max_len_morpheme=max_len_morpheme,
                                   noise=noise, continuous=continuous)
    test_loader = torch.utils.data.DataLoader(korean_dataset, batch_size=batch_size,
                                              pin_memory=True, drop_last=True)

    # Load the checkpoint
    checkpoint = torch.load(checkpoint)
    start_epoch = checkpoint['epoch'] + 1
    waiting = checkpoint['waiting']
    model = checkpoint['model']
    model_optimizer = checkpoint['model_optimizer']

    model = model.to(device)

    # Loss function
    criterion_is_noise = nn.BCELoss().to(device)
    criterion_is_next = nn.BCELoss().to(device)
    with torch.no_grad():
        mean_loss = test(test_loader=test_loader,
                         model=model,
                         criterion_is_noise=criterion_is_noise,
                         criterion_is_next=criterion_is_next)

        best_loss = min(mean_loss, best_loss)
        print('BEST LOSS:', best_loss)
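
The snippet above relies on a number of module-level names (file paths, hyperparameters, device) that are defined elsewhere in the module. A rough sketch of the kind of configuration block it expects is shown below; every value is an assumption for illustration only:

import torch

# Hypothetical module-level configuration; none of these values appear in the original example.
file_path_data = 'data/corpus.txt'              # path to the dataset (assumed)
file_path_tokens_map = 'data/tokens_map.pkl'    # path to the token map (assumed)
file_path_vectors_map = 'data/vectors_map.pkl'  # path to the vector map (assumed)
checkpoint = 'checkpoint.pth'                   # path to a trained checkpoint (required here)
batch_size = 32
max_len_sentence, max_len_morpheme = 30, 10
noise, continuous = True, False
best_loss = float('inf')
start_epoch, waiting = 0, 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')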
Example #2
def main():
    global checkpoint, waiting, best_loss, start_epoch
    # Vocabulary
    if checkpoint is None:
        init_vectors_map()
    language = Language(file_path_tokens_map=file_path_tokens_map, file_path_vectors_map=file_path_vectors_map)
    vocab_size = language.get_n_tokens()
    print('total vocab_size:', vocab_size)

    # Dataset
    korean_dataset = KoreanDataset(file_path_data=file_path_data, file_path_tokens_map=file_path_tokens_map,
                                   max_len_sentence=max_len_sentence, max_len_morpheme=max_len_morpheme,
                                   noise=noise, continuous=continuous)
    dataset_size = len(korean_dataset)
    print('total dataset_size:', dataset_size)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))  # split for training and validation set

    # Model
    if checkpoint is None:
        model = AnomalyKoreanDetector(len_morpheme=max_len_morpheme,
                                      len_sentence=max_len_sentence,
                                      syllable_layer_type=syllable_layer_type,
                                      syllable_num_layers=syllable_num_layers,
                                      vocab_size=vocab_size,
                                      attention_num_layer=attention_num_layer,
                                      morpheme_layer_type=morpheme_layer_type,
                                      morpheme_num_layers=morpheme_num_layers,
                                      sentence_layer_type=sentence_layer_type,
                                      sentence_num_layers=sentence_num_layers,
                                      classifier_num_layer=classifier_num_layer,
                                      embedding_size=embedding_dim,
                                      phoneme_in_size=phoneme_in_size,
                                      phoneme_out_size=phoneme_out_size,
                                      morpheme_out_size=morpheme_out_size,
                                      sentence_out_size=sentence_out_size,
                                      attention_type=attention_type)

        # Optimizer
        model_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=model_lr)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        waiting = checkpoint['waiting']
        model = checkpoint['model']
        model_optimizer = checkpoint['model_optimizer']

    model = model.to(device)

    # Loss function
    criterion_is_noise = nn.BCELoss().to(device)
    criterion_is_next = nn.BCELoss().to(device)

    # Create the data indices for the training and validation splits once,
    # before the epoch loop, so the two sets stay disjoint across epochs.
    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    for epoch in range(start_epoch, epochs):

        # Creating data samplers and loaders:
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(val_indices)
        train_loader = torch.utils.data.DataLoader(korean_dataset, batch_size=batch_size, sampler=train_sampler,
                                                   pin_memory=True, drop_last=True)
        validation_loader = torch.utils.data.DataLoader(korean_dataset, batch_size=batch_size, sampler=valid_sampler,
                                                        pin_memory=True, drop_last=True)

        if waiting >= patience:
            break
        if waiting > 0 and waiting % weight_decay_per_epoch == 0:
            adjust_learning_rate(optimizer=model_optimizer, shrink_factor=weight_decay_percentage)

        train(train_loader=train_loader,
              model=model,
              optimizer=model_optimizer,
              criterion_is_noise=criterion_is_noise,
              criterion_is_next=criterion_is_next,
              epoch=epoch)

        with torch.no_grad():
            mean_loss = validate(validation_loader=validation_loader,
                                 model=model,
                                 criterion_is_noise=criterion_is_noise,
                                 criterion_is_next=criterion_is_next)

        is_best = mean_loss < best_loss
        best_loss = min(mean_loss, best_loss)
        if not is_best:
            waiting += 1
        else:
            waiting = 0

        # Save checkpoint
        filepath = os.path.join(here, now, 'checkpoint.pth')
        save_checkpoint(filepath, epoch, waiting, model, model_optimizer, mean_loss, is_best)
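
Example #2 also calls two helpers, adjust_learning_rate and save_checkpoint, whose definitions are not shown. A minimal sketch of what they might look like, matching the call sites above, follows; the exact implementations are assumptions:

import torch

def adjust_learning_rate(optimizer, shrink_factor):
    # Shrink every parameter group's learning rate by the given factor (assumed behaviour).
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print('New learning rate:', optimizer.param_groups[0]['lr'])


def save_checkpoint(filepath, epoch, waiting, model, model_optimizer, loss, is_best):
    # Bundle everything the loading branch above reads back out of the checkpoint.
    state = {'epoch': epoch,
             'waiting': waiting,
             'model': model,
             'model_optimizer': model_optimizer,
             'loss': loss}
    torch.save(state, filepath)
    if is_best:
        # Keep a separate copy of the best checkpoint so a later, worse epoch cannot overwrite it.
        torch.save(state, filepath + '.best')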