示例#1
0
def main():
    """Parse options, build data loaders and the model, then run training.

    The input representation (character- or word-level) is chosen from
    ``--model``. The train/test loaders, the model (moved to the available
    device), an Adam optimizer and a tqdm iterator over the training loader
    are then handed to ``train_test``.
    """

    def _str2bool(value):
        """Parse a textual boolean ('true'/'false', 'yes'/'no', '1'/'0').

        argparse's ``type=bool`` turns any non-empty string, including
        'False', into True, so boolean options are parsed explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'Boolean value expected, got %r' % (value,))

    # Input representation consumed by each supported model.
    model_datatypes = {
        'charcnn': 'char',
        'simplernn': 'word',
        'bilstm': 'word',
        'smallcharrnn': 'char',
        'wordcnn': 'word',
    }

    parser = argparse.ArgumentParser(description='Data')
    parser.add_argument('--data',
                        type=int,
                        default=0,
                        metavar='N',
                        help='data 0 - 7')
    parser.add_argument('--charlength',
                        type=int,
                        default=1014,
                        metavar='N',
                        help='length: default 1014')
    parser.add_argument('--wordlength',
                        type=int,
                        default=500,
                        metavar='N',
                        help='length: default 500')
    parser.add_argument('--model',
                        type=str,
                        default='simplernn',
                        choices=sorted(model_datatypes),
                        metavar='N',
                        help='model type: charcnn, simplernn, bilstm, '
                        'smallcharrnn or wordcnn (default: simplernn)')
    parser.add_argument('--space',
                        type=_str2bool,
                        default=False,
                        metavar='B',
                        help='Whether including space in the alphabet')
    parser.add_argument(
        '--trans',
        type=_str2bool,
        default=False,
        metavar='B',
        help='Not implemented yet, add thesaurus transformation')
    parser.add_argument('--backward',
                        type=int,
                        default=-1,
                        metavar='B',
                        help='Backward direction')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='B',
                        help='Number of epochs')
    parser.add_argument('--power',
                        type=int,
                        default=25,
                        metavar='N',
                        help='Attack power')
    parser.add_argument('--batchsize',
                        type=int,
                        default=200,
                        metavar='B',
                        help='batch size')
    parser.add_argument('--maxbatches',
                        type=int,
                        default=None,
                        metavar='B',
                        help='maximum batches of adv samples generated')
    parser.add_argument('--scoring',
                        type=str,
                        default='replaceone',
                        metavar='N',
                        help='Scoring function.')
    parser.add_argument('--transformer',
                        type=str,
                        default='homoglyph',
                        metavar='N',
                        help='Transformer function.')
    parser.add_argument('--dictionarysize',
                        type=int,
                        default=20000,
                        metavar='B',
                        help='Size of the word dictionary')
    parser.add_argument('--lr',
                        type=float,
                        default=5e-4,
                        metavar='B',
                        help='learning rate')
    parser.add_argument('--maxnorm',
                        type=float,
                        default=400,
                        metavar='B',
                        help='maximum norm')
    parser.add_argument('--adv_train',
                        type=_str2bool,
                        default=True,
                        help='is adversarial training?')
    parser.add_argument('--hidden_loss',
                        type=_str2bool,
                        default=True,
                        help='add loss on hidden')
    parser.add_argument('--temperature',
                        default=1.0,
                        type=float,
                        help='temperature for smoothing the soft target')
    args = parser.parse_args()

    torch.manual_seed(9527)
    torch.cuda.manual_seed_all(9527)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # `choices` above guarantees the model name is a key of the mapping.
    args.datatype = model_datatypes[args.model]
    if args.model == "smallcharrnn":
        # smallcharrnn always uses 300-character inputs, overriding --charlength.
        args.charlength = 300

    if args.datatype == "char":
        (train, test, numclass) = loaddata(args.data)
        trainchar = Chardata(train, getidx=True)
        testchar = Chardata(test, getidx=True)
        train_loader = DataLoader(trainchar,
                                  batch_size=args.batchsize,
                                  num_workers=4,
                                  shuffle=True)
        test_loader = DataLoader(testchar,
                                 batch_size=args.batchsize,
                                 num_workers=4,
                                 shuffle=False)
        alphabet = trainchar.alphabet
        word_index = None
    elif args.datatype == "word":
        (train, test, tokenizer, numclass, rawtrain,
         rawtest) = loaddatawithtokenize(args.data,
                                         nb_words=args.dictionarysize,
                                         datalen=args.wordlength,
                                         withraw=True)
        word_index = tokenizer.word_index
        trainword = Worddata(train, getidx=True, rawdata=rawtrain)
        testword = Worddata(test, getidx=True, rawdata=rawtest)
        train_loader = DataLoader(trainword,
                                  batch_size=args.batchsize,
                                  num_workers=4,
                                  shuffle=True)
        test_loader = DataLoader(testword,
                                 batch_size=args.batchsize,
                                 num_workers=4,
                                 shuffle=False)
        alphabet = None

    if args.model == "charcnn":
        model = CharCNN(classes=numclass)
    elif args.model == "simplernn":
        model = SmallRNN(classes=numclass)
    elif args.model == "bilstm":
        model = SmallRNN(classes=numclass, bidirection=True)
    elif args.model == "smallcharrnn":
        model = SmallCharRNN(classes=numclass)
    elif args.model == "wordcnn":
        model = WordCNN(classes=numclass)

    model = model.to(device)
    print(model)
    print(args)
    iterator = tqdm(train_loader, ncols=0, leave=False)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    train_test(alphabet, args, device, iterator, model, numclass, optimizer,
               test_loader, word_index)
示例#2
0
        # Excerpt: presumably the tail of an `if args.datatype == "char":`
        # branch inside a function; the lines that load `train`/`test` are
        # not part of this fragment.
        trainchar = dataloader.Chardata(train, getidx=True)
        testchar = dataloader.Chardata(test, getidx=True)
        train_loader = DataLoader(trainchar,
                                  batch_size=args.batchsize,
                                  num_workers=4,
                                  shuffle=True)
        # NOTE(review): the test loader is also shuffled here.
        test_loader = DataLoader(testchar,
                                 batch_size=args.batchsize,
                                 num_workers=4,
                                 shuffle=True)
        alphabet = trainchar.alphabet
        maxlength = args.charlength
    elif args.datatype == "word":
        # Word-level data: tokenized train/test splits plus the raw text
        # (rawtrain/rawtest), which is passed to the datasets as rawdata.
        (train, test, tokenizer, numclass, rawtrain,
         rawtest) = loaddata.loaddatawithtokenize(args.data,
                                                  nb_words=args.dictionarysize,
                                                  datalen=args.wordlength,
                                                  withraw=True)
        word_index = tokenizer.word_index
        trainword = dataloader.Worddata(train, getidx=True, rawdata=rawtrain)
        testword = dataloader.Worddata(test, getidx=True, rawdata=rawtest)
        train_loader = DataLoader(trainword,
                                  batch_size=args.batchsize,
                                  num_workers=4,
                                  shuffle=True)
        test_loader = DataLoader(testword,
                                 batch_size=args.batchsize,
                                 num_workers=4,
                                 shuffle=True)
        maxlength = args.wordlength
# NOTE(review): this line is at module level while the lines above are
# indented inside a function, so the excerpt is not runnable as-is. The
# assignment below also rebinds the name `model` (presumably the model
# module) to the constructed instance.
if args.model == "charcnn":
    model = model.CharCNN(classes=numclass)
示例#3
0
def main():
    """Load a trained text classifier and run the adversarial attack on it.

    Test samples come either from ``--externaldata`` (a pickled dataset) or
    from the built-in dataset selected by ``--data``. Character-level models
    are attacked with ``attackchar`` and word-level models with ``attackword``.
    """
    np.random.seed(9527)

    # Input representation consumed by each supported model.
    model_datatypes = {
        'charcnn': 'char',
        'simplernn': 'word',
        'bilstm': 'word',
    }

    parser = argparse.ArgumentParser(description='Data')
    parser.add_argument('--data', type=int, default=0, metavar='N',
                        help='data: can be 0,1,2,3,5,6,7 which specify a textdata file')
    parser.add_argument('--externaldata', type=str, default='', metavar='S',
                        help='External database file. Default: Empty string')
    parser.add_argument('--model', type=str, default='charcnn', metavar='S',
                        choices=sorted(model_datatypes),
                        help='model type(simplernn, charcnn, bilstm). charcnn as default.')
    parser.add_argument('--modelpath', type=str, default='outputs/charcnn_0_bestmodel.dat', metavar='S',
                        help='model file path')
    parser.add_argument('--power', type=int, default=30, metavar='N',
                        help='Attack power')
    parser.add_argument('--batchsize', type=int, default=30, metavar='N',
                        help='batch size')
    parser.add_argument('--scoring', type=str, default='replaceone', metavar='N',
                        help='Scoring function.')
    parser.add_argument('--transformer', type=str, default='homoglyph', metavar='N',
                        help='Transformer function.')
    parser.add_argument('--maxbatches', type=int, default=None, metavar='B',
                        help='maximum batches of adv samples generated')
    parser.add_argument('--advsamplepath', type=str, default=None, metavar='B',
                        help='advsamplepath: If default, will generate one according to parameters')
    parser.add_argument('--dictionarysize', type=int, default=20000, metavar='B',
                        help='Size of the dictionary used in RNN model')
    parser.add_argument('--charlength', type=int, default=1014, metavar='N',
                        help='length: default 1014')
    parser.add_argument('--wordlength', type=int, default=500, metavar='N',
                        help='word length: default 500')
    args = parser.parse_args()

    torch.manual_seed(8)
    torch.cuda.manual_seed(8)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `choices` above guarantees the model name is a key of the mapping.
    args.datatype = model_datatypes[args.model]

    if args.externaldata != '':
        # NOTE(review): pickle.load can execute arbitrary code; only pass
        # trusted files via --externaldata.
        with open(args.externaldata, 'rb') as handle:
            payload = pickle.load(handle)
        if args.datatype == 'char':
            (data, numclass) = payload
            testchar = Chardata(data, getidx=True)
            test_loader = DataLoader(testchar, batch_size=args.batchsize, num_workers=4, shuffle=False)
            alphabet = testchar.alphabet
        elif args.datatype == 'word':
            (data, word_index, numclass) = payload
            testword = Worddata(data, getidx=True)
            test_loader = DataLoader(testword, batch_size=args.batchsize, num_workers=4, shuffle=False)
    else:
        if args.datatype == "char":
            (train, test, numclass) = loaddata(args.data)
            trainchar = Chardata(train, getidx=True)
            testchar = Chardata(test, getidx=True)
            train_loader = DataLoader(trainchar, batch_size=args.batchsize, num_workers=4, shuffle=True)
            test_loader = DataLoader(testchar, batch_size=args.batchsize, num_workers=4, shuffle=True)
            alphabet = trainchar.alphabet
        elif args.datatype == "word":
            (train, test, tokenizer,
             numclass, rawtrain, rawtest) = loaddatawithtokenize(args.data,
                                                                 nb_words=args.dictionarysize,
                                                                 datalen=args.wordlength,
                                                                 withraw=True)
            word_index = tokenizer.word_index
            trainword = Worddata(train, getidx=True, rawdata=rawtrain)
            testword = Worddata(test, getidx=True, rawdata=rawtest)
            train_loader = DataLoader(trainword, batch_size=args.batchsize, num_workers=4, shuffle=True)
            test_loader = DataLoader(testword, batch_size=args.batchsize, num_workers=4, shuffle=True)

    if args.model == "charcnn":
        model = CharCNN(classes=numclass)
    elif args.model == "simplernn":
        model = SmallRNN(classes=numclass)
    elif args.model == "bilstm":
        model = SmallRNN(classes=numclass, bidirection=True)

    print(model)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state = torch.load(args.modelpath, map_location=device)
    model = model.to(device)
    try:
        model.load_state_dict(state['state_dict'])
    except RuntimeError:
        # Presumably a checkpoint saved from a DataParallel-wrapped model
        # (keys prefixed with 'module.'): load through a wrapper, then unwrap.
        model = torch.nn.DataParallel(model).to(device)
        model.load_state_dict(state['state_dict'])
        model = model.module

    if args.datatype == "char":
        attackchar(model, args, numclass, alphabet, test_loader, device, maxbatch=args.maxbatches)
    elif args.datatype == "word":
        index2word = {0: '[PADDING]', 1: '[START]', 2: '[UNKNOWN]', 3: ''}
        # Tokenizer indices are shifted by 3 to make room for the reserved
        # entries above; shifted indices beyond the dictionary size are dropped.
        for word, index in word_index.items():
            if index + 3 < args.dictionarysize:
                index2word[index + 3] = word
        attackword(model, args, numclass, test_loader, device, index2word, word_index, maxbatch=args.maxbatches)
示例#4
0
# Excerpt: continues an `if args.model == ...` chain that begins before this
# fragment; the RNN models consume word-level input.
elif args.model == "simplernn":
    args.datatype = "word"
elif args.model == "bilstm":
    args.datatype = "word"

# Build datasets for the chosen representation. Only the character datasets
# receive the `backward` and `length` options here.
# NOTE(review): `args.length` is read here, while the other scripts in this
# file use `args.charlength` -- confirm the parser defines `--length`.
if args.datatype == "char":
    (train, test, numclass) = loaddata.loaddata(args.data)
    traintext = dataloader.Chardata(train,
                                    backward=args.backward,
                                    length=args.length)
    testtext = dataloader.Chardata(test,
                                   backward=args.backward,
                                   length=args.length)
elif args.datatype == "word":
    (train, test, tokenizer,
     numclass) = loaddata.loaddatawithtokenize(args.data,
                                               nb_words=args.dictionarysize)
    traintext = dataloader.Worddata(train)
    testtext = dataloader.Worddata(test)

# Both loaders shuffle, including the test loader.
train_loader = DataLoader(traintext,
                          batch_size=args.batchsize,
                          num_workers=4,
                          shuffle=True)
test_loader = DataLoader(testtext,
                         batch_size=args.batchsize,
                         num_workers=4,
                         shuffle=True)

# Model selection; this excerpt is truncated after the `simplernn` branch
# header. The assignment rebinds the name `model` (presumably the model
# module) to the constructed instance.
if args.model == "charcnn":
    model = model.CharCNN(classes=numclass)
elif args.model == "simplernn":
示例#5
0
        if rawdata:
            self.raw = rawdata

    def __len__(self):
        """Return how many samples the dataset holds (the size of ``inputs``)."""
        total = len(self.inputs)
        return total

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a tuple.

        The tuple is ``(input, label)``. When ``getidx`` is set, ``idx`` is
        appended, followed by the raw sample ``raw[idx]`` if ``raw`` is truthy.
        """
        sample = (self.inputs[idx], self.labels[idx])
        if not self.getidx:
            return sample
        if self.raw:
            return sample + (idx, self.raw[idx])
        return sample + (idx,)


if __name__ == '__main__':
    # Example for generating an external dataset: keep the first 100 test
    # samples of dataset 0 and pickle them with the tokenizer's word index
    # and the class count, as the tuple (data, word_index, numclass).
    import pickle
    import loaddata

    (train, test, tokenizer, numclass) = loaddata.loaddatawithtokenize(0)
    test.content = test.content[:100]
    test.output = test.output[:100]
    word_index = tokenizer.word_index

    # A context manager guarantees the pickle is flushed and the file closed.
    with open('textdata/ag_news_small_word.pickle', 'wb') as handle:
        pickle.dump((test, word_index, numclass), handle)
        # Excerpt (fused onto the previous example without a separator): tail
        # of an `--externaldata` branch whose opening lines are not shown.
        alphabet = trainchar.alphabet
    elif args.datatype == 'word':
        # NOTE(review): the handle from open() is never closed explicitly, and
        # pickle.load must only be used on trusted files.
        (data,word_index,numclass) = pickle.load(open(args.externaldata,'rb'))
        testword = dataloader.Worddata(data,backward = args.backward, getidx = True)
        test_loader = DataLoader(testword,batch_size=args.batchsize, num_workers=4,shuffle=False)  
else:
    # Built-in dataset path: every dataset is built with `backward` and
    # getidx=True; both loaders shuffle.
    if args.datatype == "char":
        (train,test,numclass) = loaddata.loaddata(args.data)
        trainchar = dataloader.Chardata(train,backward = args.backward, getidx = True)
        testchar = dataloader.Chardata(test,backward = args.backward, getidx = True)
        train_loader = DataLoader(trainchar,batch_size=args.batchsize, num_workers=4, shuffle = True)
        test_loader = DataLoader(testchar,batch_size=args.batchsize, num_workers=4, shuffle=True)
        alphabet = trainchar.alphabet
        maxlength = args.charlength
    elif args.datatype == "word":
        (train,test,tokenizer,numclass) = loaddata.loaddatawithtokenize(args.data, nb_words = args.dictionarysize, datalen = args.wordlength)
        word_index = tokenizer.word_index
        trainword = dataloader.Worddata(train,backward = args.backward, getidx = True)
        testword = dataloader.Worddata(test,backward = args.backward, getidx = True)
        train_loader = DataLoader(trainword,batch_size=args.batchsize, num_workers=4, shuffle = True)
        test_loader = DataLoader(testword,batch_size=args.batchsize, num_workers=4,shuffle=True)
        maxlength =  args.wordlength
# NOTE(review): `smallRNN` here vs `SmallRNN` in other examples -- confirm the
# name exported by the model module. Each assignment rebinds the name `model`
# (presumably that module) to the constructed instance.
if args.model == "charcnn":
    model = model.CharCNN(classes = numclass)
elif args.model == "simplernn":
    model = model.smallRNN(classes = numclass)
elif args.model == "bilstm":
    model = model.smallRNN(classes = numclass, bidirection = True)

print(model)
示例#7
0
# Excerpt: `args` (including `args.datatype`) is set up before this fragment.
# print(model)

if args.datatype == "char":
    (train, test, numclass) = loaddata.loaddata(args.data)
    trainchar = dataloader.Chardata(train, backward=args.backward)
    testchar = dataloader.Chardata(test, backward=args.backward)
    train_loader = DataLoader(trainchar,
                              batch_size=args.batchsize,
                              num_workers=4,
                              shuffle=True)
    # No shuffle argument: the test loader uses DataLoader's default ordering.
    test_loader = DataLoader(testchar,
                             batch_size=args.batchsize,
                             num_workers=4)
elif args.datatype == "word":
    # No nb_words/datalen are passed, so loaddatawithtokenize's own
    # defaults apply here.
    (train, test, tokenizer,
     numclass) = loaddata.loaddatawithtokenize(args.data)
    trainword = dataloader.Worddata(train, backward=args.backward)
    testword = dataloader.Worddata(test, backward=args.backward)
    train_loader = DataLoader(trainword,
                              batch_size=args.batchsize,
                              num_workers=4,
                              shuffle=True)
    test_loader = DataLoader(testword,
                             batch_size=args.batchsize,
                             num_workers=4)

# Model selection; this excerpt is truncated after the `bilstm` branch header.
if args.model == "charcnn":
    model = model.CharCNN(classes=numclass)
elif args.model == "simplernn":
    model = model.smallRNN(classes=numclass)
elif args.model == "bilstm":
示例#8
0
        # Excerpt: tail of an `--externaldata` branch whose opening lines are
        # not part of this fragment.
        alphabet = trainchar.alphabet
    elif args.datatype == 'word':
        # NOTE(review): the handle from open() is never closed explicitly, and
        # pickle.load must only be used on trusted files.
        (data,word_index,numclass) = pickle.load(open(args.externaldata,'rb'))
        testword = dataloader.Worddata(data, getidx = True)
        test_loader = DataLoader(testword,batch_size=args.batchsize, num_workers=4,shuffle=False)  
else:
    # Built-in dataset path; the word datasets also keep the raw text
    # (rawtrain/rawtest) via rawdata. Both loaders shuffle.
    if args.datatype == "char":
        (train,test,numclass) = loaddata.loaddata(args.data)
        trainchar = dataloader.Chardata(train, getidx = True)
        testchar = dataloader.Chardata(test, getidx = True)
        train_loader = DataLoader(trainchar,batch_size=args.batchsize, num_workers=4, shuffle = True)
        test_loader = DataLoader(testchar,batch_size=args.batchsize, num_workers=4, shuffle=True)
        alphabet = trainchar.alphabet
        maxlength = args.charlength
    elif args.datatype == "word":
        (train,test,tokenizer,numclass, rawtrain, rawtest) = loaddata.loaddatawithtokenize(args.data, nb_words = args.dictionarysize, datalen = args.wordlength, withraw=True)
        word_index = tokenizer.word_index
        trainword = dataloader.Worddata(train, getidx = True, rawdata = rawtrain)
        testword = dataloader.Worddata(test, getidx = True, rawdata = rawtest)
        train_loader = DataLoader(trainword,batch_size=args.batchsize, num_workers=4, shuffle = True)
        test_loader = DataLoader(testword,batch_size=args.batchsize, num_workers=4,shuffle=True)
        maxlength =  args.wordlength
# NOTE(review): `smallRNN` here vs `SmallRNN` in other examples -- confirm the
# name exported by the model module. Each assignment rebinds the name `model`
# (presumably that module) to the constructed instance.
if args.model == "charcnn":
    model = model.CharCNN(classes = numclass)
elif args.model == "simplernn":
    model = model.smallRNN(classes = numclass)
elif args.model == "bilstm":
    model = model.smallRNN(classes = numclass, bidirection = True)

print(model)
示例#9
0
def main():
    """Plot clean vs adversarial model outputs for a few fixed test samples.

    Loads a trained checkpoint. For each test batch it then:
    selects the samples at fixed batch positions; runs them and their
    adversarial versions (from ``get_adv``) through the model; and passes
    the stacked, flattened outputs with duplicated labels to
    ``draw_pca_plot``.
    """

    def _str2bool(value):
        """Parse a textual boolean ('true'/'false', 'yes'/'no', '1'/'0').

        argparse's ``type=bool`` turns any non-empty string, including
        'False', into True, so boolean options are parsed explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'Boolean value expected, got %r' % (value,))

    # Input representation consumed by each supported model.
    model_datatypes = {
        'charcnn': 'char',
        'simplernn': 'word',
        'bilstm': 'word',
        'smallcharrnn': 'char',
        'wordcnn': 'word',
    }

    parser = argparse.ArgumentParser(description='Data')
    parser.add_argument('--modelpath',
                        type=str,
                        default='./outputs/simplernn_0_clean.dat',
                        metavar='S',
                        help='path of the trained model checkpoint')
    parser.add_argument('--data',
                        type=int,
                        default=0,
                        metavar='N',
                        help='data 0 - 7')
    parser.add_argument('--charlength',
                        type=int,
                        default=1014,
                        metavar='N',
                        help='length: default 1014')
    parser.add_argument('--wordlength',
                        type=int,
                        default=500,
                        metavar='N',
                        help='length: default 500')
    parser.add_argument('--model',
                        type=str,
                        default='simplernn',
                        choices=sorted(model_datatypes),
                        metavar='N',
                        help='model type: charcnn, simplernn, bilstm, '
                        'smallcharrnn or wordcnn (default: simplernn)')
    parser.add_argument('--space',
                        type=_str2bool,
                        default=False,
                        metavar='B',
                        help='Whether including space in the alphabet')
    parser.add_argument(
        '--trans',
        type=_str2bool,
        default=False,
        metavar='B',
        help='Not implemented yet, add thesaurus transformation')
    parser.add_argument('--backward',
                        type=int,
                        default=-1,
                        metavar='B',
                        help='Backward direction')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='B',
                        help='Number of epochs')
    parser.add_argument('--power',
                        type=int,
                        default=25,
                        metavar='N',
                        help='Attack power')
    parser.add_argument('--batchsize',
                        type=int,
                        default=100,
                        metavar='B',
                        help='batch size')
    parser.add_argument('--maxbatches',
                        type=int,
                        default=None,
                        metavar='B',
                        help='maximum batches of adv samples generated')
    parser.add_argument('--scoring',
                        type=str,
                        default='replaceone',
                        metavar='N',
                        help='Scoring function.')
    parser.add_argument('--transformer',
                        type=str,
                        default='homoglyph',
                        metavar='N',
                        help='Transformer function.')
    parser.add_argument('--dictionarysize',
                        type=int,
                        default=20000,
                        metavar='B',
                        help='Size of the word dictionary')
    parser.add_argument('--lr',
                        type=float,
                        default=5e-4,
                        metavar='B',
                        help='learning rate')
    parser.add_argument('--maxnorm',
                        type=float,
                        default=400,
                        metavar='B',
                        help='maximum norm')
    parser.add_argument('--adv_train',
                        type=_str2bool,
                        default=False,
                        help='is adversarial training?')
    parser.add_argument('--hidden_loss',
                        type=_str2bool,
                        default=False,
                        help='add loss on hidden')
    args = parser.parse_args()
    model_path = args.modelpath

    torch.manual_seed(9527)
    torch.cuda.manual_seed_all(9527)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # `choices` above guarantees the model name is a key of the mapping.
    args.datatype = model_datatypes[args.model]
    if args.model == "smallcharrnn":
        # smallcharrnn always uses 300-character inputs, overriding --charlength.
        args.charlength = 300

    if args.datatype == "char":
        (train, test, numclass) = loaddata(args.data)
        trainchar = Chardata(train, getidx=True)
        testchar = Chardata(test, getidx=True)
        train_loader = DataLoader(trainchar,
                                  batch_size=args.batchsize,
                                  num_workers=4,
                                  shuffle=True)
        test_loader = DataLoader(testchar,
                                 batch_size=args.batchsize,
                                 num_workers=4,
                                 shuffle=False)
        alphabet = trainchar.alphabet
        word2index = None
    elif args.datatype == "word":
        (train, test, tokenizer, numclass, rawtrain,
         rawtest) = loaddatawithtokenize(args.data,
                                         nb_words=args.dictionarysize,
                                         datalen=args.wordlength,
                                         withraw=True)
        word2index = tokenizer.word_index
        trainword = Worddata(train, getidx=True, rawdata=rawtrain)
        testword = Worddata(test, getidx=True, rawdata=rawtest)
        train_loader = DataLoader(trainword,
                                  batch_size=args.batchsize,
                                  num_workers=4,
                                  shuffle=True)
        test_loader = DataLoader(testword,
                                 batch_size=args.batchsize,
                                 num_workers=4,
                                 shuffle=False)
        alphabet = None

    if args.model == "charcnn":
        model = CharCNN(classes=numclass)
    elif args.model == "simplernn":
        model = SmallRNN(classes=numclass)
    elif args.model == "bilstm":
        model = SmallRNN(classes=numclass, bidirection=True)
    elif args.model == "smallcharrnn":
        model = SmallCharRNN(classes=numclass)
    elif args.model == "wordcnn":
        model = WordCNN(classes=numclass)

    model = model.to(device)
    print(model)
    print(args)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    # Positions, within each test batch, of the samples to plot.
    indices = torch.tensor([58, 3, 26, 34]).to(device)
    # index_select needs every position above to exist in the batch.
    min_batch_size = int(indices.max().item()) + 1

    model.eval()
    with torch.no_grad():
        for data in test_loader:
            # Only inputs/targets are used here. Word batches carry raw text
            # as a fourth field; char batches presumably do not (Chardata is
            # built without rawdata), so fields are indexed, not unpacked.
            inputs, targets = data[0], data[1]
            if inputs.size(0) < min_batch_size:
                # e.g. a short final batch: skip it instead of crashing.
                continue
            inputs, targets = inputs.to(device), targets.to(device)
            print(targets)
            inputs = torch.index_select(inputs, 0, indices)
            targets = torch.index_select(targets, 0, indices)
            h = model(inputs)
            h_orig = h.view(h.size()[0], -1)
            _y_adv, x_adv = get_adv(args, data, device, model, numclass,
                                    word2index, alphabet)
            x_adv = torch.index_select(x_adv, 0, indices)
            ah = model(x_adv)
            ah = ah.view(ah.size()[0], -1)
            # Clean outputs first, then adversarial ones, with labels repeated.
            h = torch.cat((h_orig, ah), 0)
            y = torch.cat((targets, targets), 0)
            draw_pca_plot(h.cpu().detach().numpy(), y)
示例#10
0
def main():
    """Train a text classifier and checkpoint it after every epoch.

    Builds character- or word-level train/test loaders (chosen from
    ``--model``) and trains with Adam on NLL loss for ``--epochs`` epochs.
    After each epoch it reports the mean per-sample test loss and the
    accuracy, then saves a checkpoint via ``save_checkpoint``, flagging the
    best accuracy so far.
    """

    def _str2bool(value):
        """Parse a textual boolean ('true'/'false', 'yes'/'no', '1'/'0').

        argparse's ``type=bool`` turns any non-empty string, including
        'False', into True, so boolean options are parsed explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'Boolean value expected, got %r' % (value,))

    # Input representation consumed by each supported model.
    model_datatypes = {
        'charcnn': 'char',
        'simplernn': 'word',
        'bilstm': 'word',
        'smallcharrnn': 'char',
        'wordcnn': 'word',
    }

    parser = argparse.ArgumentParser(description='Data')
    parser.add_argument('--data',
                        type=int,
                        default=0,
                        metavar='N',
                        help='data 0 - 6')
    parser.add_argument('--charlength',
                        type=int,
                        default=1014,
                        metavar='N',
                        help='length: default 1014')
    parser.add_argument('--wordlength',
                        type=int,
                        default=500,
                        metavar='N',
                        help='length: default 500')
    parser.add_argument('--model',
                        type=str,
                        default='charcnn',
                        choices=sorted(model_datatypes),
                        metavar='N',
                        help='model type: charcnn, simplernn, bilstm, '
                        'smallcharrnn or wordcnn (default: charcnn)')
    parser.add_argument('--space',
                        type=_str2bool,
                        default=False,
                        metavar='B',
                        help='Whether including space in the alphabet')
    parser.add_argument(
        '--trans',
        type=_str2bool,
        default=False,
        metavar='B',
        help='Not implemented yet, add thesaurus transformation')
    parser.add_argument('--backward',
                        type=int,
                        default=-1,
                        metavar='B',
                        help='Backward direction')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='B',
                        help='Number of epochs')
    parser.add_argument('--batchsize',
                        type=int,
                        default=100,
                        metavar='B',
                        help='batch size')
    parser.add_argument('--dictionarysize',
                        type=int,
                        default=20000,
                        metavar='B',
                        help='Size of the word dictionary')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='B',
                        help='learning rate')
    parser.add_argument('--maxnorm',
                        type=float,
                        default=400,
                        metavar='B',
                        help='maximum norm')
    args = parser.parse_args()

    torch.manual_seed(9527)
    torch.cuda.manual_seed_all(9527)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `choices` above guarantees the model name is a key of the mapping.
    args.datatype = model_datatypes[args.model]
    if args.model == "smallcharrnn":
        # smallcharrnn always uses 300-character inputs, overriding --charlength.
        args.charlength = 300

    print("Loading data...")
    if args.datatype == "char":
        (train, test, numclass) = loaddata.loaddata(args.data)
        trainchar = dataloader.Chardata(train,
                                        backward=args.backward,
                                        length=args.charlength)
        testchar = dataloader.Chardata(test,
                                       backward=args.backward,
                                       length=args.charlength)
        train_loader = DataLoader(trainchar,
                                  batch_size=args.batchsize,
                                  num_workers=4,
                                  shuffle=True)
        test_loader = DataLoader(testchar,
                                 batch_size=args.batchsize,
                                 num_workers=4)
    elif args.datatype == "word":
        (train, test, tokenizer, numclass) = loaddata.loaddatawithtokenize(
            args.data, nb_words=args.dictionarysize, datalen=args.wordlength)
        trainword = dataloader.Worddata(train, backward=args.backward)
        testword = dataloader.Worddata(test, backward=args.backward)
        train_loader = DataLoader(trainword,
                                  batch_size=args.batchsize,
                                  num_workers=4,
                                  shuffle=True)
        test_loader = DataLoader(testword,
                                 batch_size=args.batchsize,
                                 num_workers=4)

    if args.model == "charcnn":
        model = CharCNN(classes=numclass)
    elif args.model == "simplernn":
        model = SmallRNN(classes=numclass)
    elif args.model == "bilstm":
        model = SmallRNN(classes=numclass, bidirection=True)
    elif args.model == "smallcharrnn":
        model = SmallCharRNN(classes=numclass)
    elif args.model == "wordcnn":
        model = WordCNN(classes=numclass)

    model = model.to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Checkpoint path prefix (constant across epochs). It encodes the
    # dictionary size only when that differs from the default 20000.
    if args.dictionarysize != 20000:
        fname = "outputs/" + args.model + str(
            args.dictionarysize) + "_" + str(args.data)
    else:
        fname = "outputs/" + args.model + "_" + str(args.data)

    bestacc = 0
    for epoch in range(1, args.epochs + 1):
        print('Start epoch %d' % epoch)
        model.train()
        for inputs, target in train_loader:
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)

            loss = F.nll_loss(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        correct = .0
        total_loss = 0
        model.eval()
        # Evaluation needs no gradients; skipping autograd saves memory/time.
        with torch.no_grad():
            for inputs, target in test_loader:
                inputs, target = inputs.to(device), target.to(device)
                output = model(inputs)
                loss = F.nll_loss(output, target)
                # nll_loss averages over the batch; scale back to a sum so the
                # division by the dataset size below yields a per-sample mean.
                total_loss += loss.item() * target.size(0)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).cpu().sum().item()

        acc = correct / len(test_loader.dataset)
        avg_loss = total_loss / len(test_loader.dataset)
        print('Epoch %d : Loss %.4f Accuracy %.5f' % (epoch, avg_loss, acc))
        is_best = acc > bestacc
        if is_best:
            bestacc = acc

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'bestacc': bestacc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            filename=fname)