parser.add_argument('-w', '--weights', default=None,
                        help="The path of the saved weights. Should be specified when testing")

    args = parser.parse_args()

    print(args)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # load data
    train_list,test_list = read(1)

    # train_list = torch.LongTensor(train_list)
    # test_list = torch.LongTensor(test_list)
    train_loader, test_loader = load_mnist(train_list,test_list, batch_size=args.batch_size)
    print('Data loading finish')
    # define model
    model = SpeechResModel(input_size=[1, 98, 60], classes=10, routings=3)
    model.cuda()
    print(model)

    # train or test
    if args.weights is not None:  # init the model weights with provided one
        model.load_state_dict(torch.load(args.weights))

    if not args.testing:
        train(model, train_loader, test_loader, args)

    else:  # testing
        if args.weights is None:
# Example #2 (original scraped marker: "Пример #2", score: 0)
def main():
    """Train or evaluate SpeechResModel on the keyword dataset.

    Command-line flags control batch sizes, epoch count, optimizer settings,
    CUDA usage, and whether to run in training or testing mode.  Pass
    ``-w/--weights`` to initialize the model from a saved state dict before
    either mode.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=40,
                        metavar='N',
                        help='number of epochs to train (default: 40)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save_dir', default='./result')
    parser.add_argument('-t',
                        '--testing',
                        action='store_true',
                        help="Test the trained model on testing dataset")
    parser.add_argument(
        '-w',
        '--weights',
        default=None,
        help="The path of the saved weights. Should be specified when testing")
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)

    # BUG FIX: the original hard-coded torch.device("cuda"), which ignored
    # the --no-cuda flag and crashed on CPU-only machines.  Honor use_cuda.
    device = torch.device("cuda" if use_cuda else "cpu")

    # Load the dataset and wrap it in DataLoaders.
    # NOTE(review): read(1)'s argument semantics are defined elsewhere —
    # presumably a dataset split/version selector; confirm against read().
    train_list, test_list = read(1)
    train_loader, test_loader = load_mnist(train_list,
                                           test_list,
                                           batch_size=args.batch_size)

    model = SpeechResModel().to(device)
    print(model)
    optimizer = Adam(model.parameters(), lr=args.lr)

    # Optionally initialize the model with previously saved weights.
    # map_location lets CUDA-saved checkpoints load on a CPU-only machine.
    if args.weights is not None:
        model.load_state_dict(torch.load(args.weights, map_location=device))

    if not args.testing:
        print('training start')
        for epoch in range(1, args.epochs + 1):
            ti = time()  # epoch start time, forwarded for timing reports

            train(args, model, device, train_loader, test_loader, optimizer,
                  epoch, ti)

    else:  # testing
        if args.weights is None:
            print(
                'No weights are provided. Will test using random initialized weights.'
            )

        # Placeholders matching the reporting interface expected by test().
        train_loss = 0
        epoch = 0
        ti = 0

        test(args, model, device, train_loss, test_loader, epoch, ti)