示例#1
0
def _print_label_counts(header, total, counts):
    """Print ``header`` followed by ``total``, then one aligned line per label.

    Args:
        header: Text printed in front of the sample total.
        total: Number of samples in the dataset.
        counts: Iterable of integer per-label sample counts; the position in
            the iterable is used as the label index.
    """
    print(header + str(total))
    for label, count in enumerate(counts):
        print("\tLabel {:d}:".format(label).ljust(15) + "{:d}".format(count).rjust(8))


def main():
    """Load the train/dev data, report the setup, build the model and train it.

    Steps, in order:
      1. Parse the command-line arguments.
      2. Build the training and developing datasets and their loaders.
      3. Optionally replace ``args.class_weight`` with a weight tensor derived
         from the training set.
      4. Print per-label sample counts and the full configuration.
      5. Create the save folder and, if requested, start ``result.csv``.
      6. Instantiate ``CharCNN`` and hand everything to ``train``.

    Raises:
        OSError: If the save folder cannot be created for any reason other
            than it already existing.
    """
    # parse arguments
    args = parser.parse_args()

    # load training data
    print("\nLoading training data...")
    train_dataset = AGNEWs(label_data_path=args.train_path, alphabet_path=args.alphabet_path)
    print("Transferring training data into iterator...")
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True, shuffle=True)
    # feature length: one input feature per character of the alphabet
    args.num_features = len(train_dataset.alphabet)

    # load developing data
    print("\nLoading developing data...")
    dev_dataset = AGNEWs(label_data_path=args.val_path, alphabet_path=args.alphabet_path)
    print("Transferring developing data into iterator...")
    dev_loader = DataLoader(dev_dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True)

    # get_class_weight() is unpacked as (weights, per-label counts)
    class_weight, num_class_train = train_dataset.get_class_weight()
    _, num_class_dev = dev_dataset.get_class_weight()

    # when you have an unbalanced training set: any non-None value of
    # args.class_weight is replaced by the square-rooted training-set weights
    if args.class_weight is not None:
        args.class_weight = torch.FloatTensor(class_weight).sqrt_()
        if args.cuda:
            args.class_weight = args.class_weight.cuda()

    _print_label_counts('\nNumber of training samples: ', len(train_dataset), num_class_train)
    _print_label_counts('\nNumber of developing samples: ', len(dev_dataset), num_class_dev)

    # make save folder; an already existing folder is fine, anything else is fatal
    try:
        os.makedirs(args.save_folder)
    except OSError as e:
        if e.errno == errno.EEXIST:
            print('Directory already exists.')
        else:
            raise
    # args.save_folder = os.path.join(args.save_folder, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

    # configuration
    print("\nConfiguration:")
    for attr, value in sorted(vars(args).items()):
        print("\t{}:".format(attr.capitalize().replace('_', ' ')).ljust(25) + "{}".format(value))

    # log result: (re)create the CSV with only its header row.
    # NOTE(review): the header is written without a trailing newline; this is
    # kept as-is because the code appending rows is not visible here and may
    # start each row with a newline -- confirm before changing.
    if args.log_result:
        with open(os.path.join(args.save_folder, 'result.csv'), 'w') as r:
            r.write('{:s},{:s},{:s},{:s},{:s}'.format('epoch', 'batch', 'loss', 'acc', 'lr'))

    # model
    model = CharCNN(args)
    print(model)

    # train
    train(train_loader, dev_loader, model, args)
示例#2
0
def make_data_loader(dataset_path, alphabet_path, l0, batch_size, num_workers):
    """Create an ``AGNEWs`` dataset and a shuffling ``DataLoader`` over it.

    Args:
        dataset_path: Passed to ``AGNEWs`` as ``label_data_path``.
        alphabet_path: Passed to ``AGNEWs`` as ``alphabet_path``.
        l0: Passed to ``AGNEWs`` as ``l0``.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the loader.

    Returns:
        A ``(dataset, loader)`` tuple. The loader shuffles the data and drops
        the final incomplete batch.
    """
    print("\nLoading data from {}".format(dataset_path))

    data = AGNEWs(label_data_path=dataset_path, alphabet_path=alphabet_path, l0=l0)

    # Same loader settings as before, gathered in one place.
    loader_options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'drop_last': True,
        'shuffle': True,
    }
    return data, DataLoader(data, **loader_options)
示例#3
0
                    help='Number of workers used in data-loading')
# NOTE: '--cuda' is a store_true flag whose default is already True, so passing
# it never changes anything. It is kept for backward compatibility; '--no-cuda'
# writes False into the same destination and is the way to run without the GPU.
parser.add_argument('--cuda',
                    action='store_true',
                    default=True,
                    help='enable the gpu')
parser.add_argument('--no-cuda',
                    dest='cuda',
                    action='store_false',
                    help='disable the gpu')
# logging options
parser.add_argument('--save-folder',
                    default='Results/',
                    help='Location to save epoch models')
# Parsed at import time; the script body below reads its settings from `args`.
args = parser.parse_args()

if __name__ == '__main__':

    # load testing data
    print("\nLoading testing data...")
    test_dataset = AGNEWs(label_data_path=args.test_path,
                          alphabet_path=args.alphabet_path)
    print("Transferring testing data to iterator...")
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             drop_last=True)

    _, num_class_test = test_dataset.get_class_weight()
    print('\nNumber of testing samples: ' + str(test_dataset.__len__()))
    for i, c in enumerate(num_class_test):
        print("\tLabel {:d}:".format(i).ljust(15) + "{:d}".format(c).rjust(8))

    args.num_features = len(test_dataset.alphabet)
    model = CharCNN(args)
    print("=> loading weights from '{}'".format(args.model_path))
    assert os.path.isfile(