Example no. 1
import torch

# parse_arguments, initialize_distributed_backend, download_data, allocate_model
# and worker_procedure are defined elsewhere in the source project.
def main():
    arguments = parse_arguments()
    initialize_distributed_backend(arguments)
    download_data(arguments)
    model = allocate_model()
    # Wrap the model for synchronous data-parallel training across CPU processes.
    model = torch.nn.parallel.DistributedDataParallelCPU(model)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=arguments.lr,
                                momentum=arguments.momentum)
    worker_procedure(arguments, model, optimizer)
Example no. 2

def main():
    arguments = parse_arguments()
    initialize_distributed_backend(arguments)
    download_data(arguments)
    model = allocate_model()
    optimizer = GEM(model.parameters(),
                    lr=arguments.lr,
                    momentum=arguments.momentum)
    # Alternative asynchronous optimizer:
    # optimizer = DOWNPOUR(model.parameters(), lr=arguments.lr)
    if is_master():
        master_procedure(arguments, model, optimizer)
    else:
        worker_procedure(arguments, model, optimizer)
    optimizer.ready()
Example no. 3
import os

import tensorflow as tf

# util.download_data and get_filenames come from the source project.
def make_dataset() -> tf.data.TFRecordDataset:
    download_directory = util.download_data("/tmp/data")

    files = get_filenames(os.path.join(download_directory, "train"))
    ds = tf.data.TFRecordDataset(files)
    return ds
Example no. 4
def main():
    parser = argparse.ArgumentParser(description='Spoken Language Identification')
    parser.add_argument('--dropout', type=float, default=0.2, help='dropout rate in training (default: 0.2)')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size in training (default: 64)')
    parser.add_argument('--workers', type=int, default=4, help='number of workers in dataset loader (default: 4)')
    parser.add_argument('--max_epochs', type=int, default=5, help='number of max epochs in training (default: 5)')
    parser.add_argument('--lr', type=float, default=1e-04, help='learning rate (default: 0.0001)')
    parser.add_argument('--n_class', type=int, default=2, help='number of classes (default: 2)')
    parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--nn_type', type=str, default='crnn', help='type of neural networks')
    parser.add_argument('--save_name', type=str, default='model', help='the name of model')
    parser.add_argument('--mode', type=str, default='train')

    args = parser.parse_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if args.cuda else 'cpu')

    model = DNN.DNN()

    model = nn.DataParallel(model).to(device)

    optimizer = optim.Adam(model.module.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss(reduction='sum').to(device)

    if args.mode != 'train':
        return

    download_data()

    kor_db_list = []
    search('dataset/train/train_data', kor_db_list)

    train_wav_paths = np.loadtxt("dataset/TRAIN_list.csv", delimiter=',', dtype=str)
    valid_wav_paths = np.loadtxt("dataset/TEST_developmentset_list.csv", delimiter=',', dtype=str)
    test_wav_paths = np.loadtxt("dataset/TEST_coreset_list.csv", delimiter=',', dtype=str)

    train_wav_paths = list(map(lambda x: "dataset/TIMIT/{}.WAV".format(x), train_wav_paths))
    valid_wav_paths = list(map(lambda x: "dataset/TIMIT/{}.WAV".format(x), valid_wav_paths))
    test_wav_paths = list(map(lambda x: "dataset/TIMIT/{}.WAV".format(x), test_wav_paths))

    min_loss = float('inf')
    begin_epoch = 0
    save_epoch = begin_epoch  # epoch of the best validation loss so far

    loss_acc = [[], [], [], []]  # per-epoch [train_loss, train_acc, eval_loss, eval_acc]

    train_batch_num, train_dataset_list, valid_dataset, test_dataset = \
        split_dataset(args, train_wav_paths, valid_wav_paths, test_wav_paths, kor_db_list)

    logger.info('start')

    train_begin = time.time()

    for epoch in range(begin_epoch, args.max_epochs):

        train_queue = queue.Queue(args.workers * 2)

        train_loader = MultiLoader(train_dataset_list, train_queue, args.batch_size, args.workers, args.nn_type)
        train_loader.start()

        train_loss, train_acc = train(model, train_batch_num, train_queue, criterion, optimizer, device, train_begin, args.workers, 10)
        logger.info('Epoch %d (Training) Loss %0.4f Acc %0.4f' % (epoch, train_loss,  train_acc))

        train_loader.join()

        loss_acc[0].append(train_loss)
        loss_acc[1].append(train_acc)

        valid_queue = queue.Queue(args.workers * 2)

        valid_loader = BaseDataLoader(valid_dataset, valid_queue, args.batch_size, 0, args.nn_type)
        valid_loader.start()

        eval_loss, eval_acc = evaluate(model, valid_loader, valid_queue, criterion, device)
        logger.info('Epoch %d (Evaluate) Loss %0.4f Acc %0.4f' % (epoch, eval_loss, eval_acc))

        valid_loader.join()

        loss_acc[2].append(eval_loss)
        loss_acc[3].append(eval_acc)

        best_model = (eval_loss < min_loss)

        if best_model:
            min_loss = eval_loss
            torch.save(model.state_dict(), './save_model/best_model.pt')
            save_epoch = epoch

    model.load_state_dict(torch.load('./save_model/best_model.pt'))

    test_queue = queue.Queue(args.workers * 2)

    test_loader = BaseDataLoader(test_dataset, test_queue, args.batch_size, 0, args.nn_type)
    test_loader.start()

    confusion_matrix = torch.zeros((args.n_class, args.n_class))
    test_loss, test_acc = evaluate(model, test_loader, test_queue, criterion, device, confusion_matrix)
    logger.info('Epoch %d (Test) Loss %0.4f Acc %0.4f' % (save_epoch, test_loss, test_acc))

    test_loader.join()

    save_data(loss_acc, test_loss, test_acc, confusion_matrix.to('cpu').numpy())
    plot_data(loss_acc, test_loss, test_acc)

    return 0
Example no. 5
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import sys

from util import download_data


# Download the dataset archive from Google Drive and extract it
# (the MD5 sum identifies the expected file).
download_data(
    url='https://drive.google.com/uc?id=0B9P1L--7Wd2vMVNlR1JNV1RXLVE',
    md5='2da610302072f99ba7aa34ab145c4b0d',
    path='all.tgz',
    extract=True
)