Example #1
def main(args):
    # Initialise config vars
    NGRAMS = args.ngrams
    BATCH_SIZE = args.batch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.backend == 'gpu' and not torch.cuda.is_available():
        # Requested GPU is unavailable: log it and stay on the CPU fallback
        logger.error('Backend device: %s not available, falling back to cpu',
                     args.backend)
    elif args.backend != 'auto':
        device = torch.device('cpu' if args.backend == 'cpu' else 'cuda')
    EMBED_DIM = args.model_dim
    N_EPOCHS = args.epochs
    MAX_SEQ_LEN = args.max_seq_len
    TRAIN_SPLIT = args.train_split
    EVOLVED = args.evolved
    # Log the resolved config
    logger.info(
        'Device:%s|Batch size:%s|EmbedDim:%s|Epochs:%s|Ngrams:%s|MAX_LEN:%s|Split:%s',
        device.type, BATCH_SIZE, EMBED_DIM, N_EPOCHS, NGRAMS, MAX_SEQ_LEN,
        TRAIN_SPLIT)

    # Load the AG_NEWS data (4-class news topic classification)
    logger.info('Loading Data')
    data_reader = DataReader('./.data/ag_news_csv')
    logger.info('Data Loaded')

    VOCAB_SIZE = data_reader.get_vocab_size()
    NUM_CLASS = data_reader.get_num_classes()

    train_dataset = data_reader.get_training_data()
    test_dataset = data_reader.get_testing_data()

    model = ClassificationTransformer(EMBED_DIM,
                                      VOCAB_SIZE,
                                      NUM_CLASS,
                                      max_seq_len=MAX_SEQ_LEN,
                                      evolved=EVOLVED)
    model.to(device)
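    # Per the PyTorch convention, move the model to `device` before
    # constructing the optimizer below so it references device-resident params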

    from torch.utils.data import random_split
    min_valid_loss = float('inf')  # best validation loss seen so far

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)
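    # step_size=1 with gamma=0.9 multiplies the learning rate by 0.9 on every
    # scheduler.step() call (presumably once per epoch, inside train_func)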

    train_len = int(len(train_dataset) * TRAIN_SPLIT)
    sub_train_, sub_valid_ = random_split(
        train_dataset, [train_len, len(train_dataset) - train_len])
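    # Note: random_split shuffles indices, so the train/validation partition
    # differs between runs unless torch.manual_seed is set beforehand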

    for epoch in range(N_EPOCHS):
        train_loss, train_acc = train_func(sub_train_, data_reader, model,
                                           BATCH_SIZE, device, optimizer,
                                           scheduler, criterion, MAX_SEQ_LEN)
        logger.info('Trained for epoch %s', epoch)
        valid_loss, valid_acc = test(sub_valid_, data_reader, model,
                                     BATCH_SIZE, device, optimizer, criterion,
                                     MAX_SEQ_LEN)

        training_loss_str = f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)'
        validation_loss_str = f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)'
        logger.info(training_loss_str)
        logger.info(validation_loss_str)
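
For reference, main() only consumes the fields shown above from args. A minimal argparse wiring that supplies them could look like the following sketch; the flag names mirror the attributes main() reads, but the defaults and the __main__ guard are assumptions, not values taken from this project.

if __name__ == '__main__':
    import argparse

    # Hypothetical CLI wiring: flag names match the attributes main() reads,
    # default values are illustrative only
    parser = argparse.ArgumentParser(
        description='Train a transformer text classifier on AG_NEWS')
    parser.add_argument('--ngrams', type=int, default=2)
    parser.add_argument('--batch', type=int, default=16)
    parser.add_argument('--backend', choices=['auto', 'cpu', 'gpu'],
                        default='auto')
    parser.add_argument('--model_dim', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--max_seq_len', type=int, default=100)
    parser.add_argument('--train_split', type=float, default=0.95)
    parser.add_argument('--evolved', action='store_true')
    main(parser.parse_args())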