def get_data_iter(mnist_train, mnist_test, batch_size=32):
    """Build DataLoader iterators over the train and test datasets.

        @params:
            mnist_train - training dataset.
            mnist_test - test dataset.
            batch_size - mini-batch size.

        @return:
            On success - (train_iter, test_iter) DataLoader pair.
            On failure - error information.
    """
    # On Windows extra worker processes are not worth it (no cheap fork),
    # so fall back to loading in the main process.
    num_workers = 0 if sys.platform.startswith('win') else 4
    train_iter = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)
    logger.info('train_iter len:{}'.format(len(train_iter)))
    logger.info('test_iter len:{}'.format(len(test_iter)))
    return train_iter, test_iter
# Exemple #2 — snippet-aggregator artifact (score line "0" folded in); kept as a comment so the file parses.
def train_net(net, train_iter, dev_iter, max_epoch, optimizer, loss_func):
    """Train a neural network and log per-epoch loss/accuracy.

        @params:
            net - the neural network.
            train_iter - training data iterator.
            dev_iter - development data iterator.
            max_epoch - maximum number of epochs.
            optimizer - the optimizer.
            loss_func - the loss function.
    """
    for epoch in range(max_epoch):
        logger.info('epoch {} begin to train'.format(epoch + 1))
        loss_total, correct_total, sample_total = 0.0, 0.0, 0
        for X, y in tqdm(train_iter):
            y_hat = net(X)
            loss = loss_func(y_hat, y).sum()
            # Clear stale gradients before the backward pass.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_total += loss.item()
            correct_total += (y_hat.argmax(dim=1) == y).sum().item()
            sample_total += y.shape[0]
        test_acc = evaluate_accuracy(dev_iter, net)
        logger.info('epoch %d, loss %.4f, train acc %.3f, test acc %.3f' %
                    (epoch + 1, loss_total / sample_total,
                     correct_total / sample_total, test_acc))
def get_parameter_number(net):
    """Count the total and trainable parameters of a network.

        @params:
            net - the neural network.

        @return:
            On success - dict, e.g. {'total': total_num, 'trainable': trainable_num}.
            On failure - error information.
    """
    params = list(net.parameters())
    total_num = sum(p.numel() for p in params)
    trainable_num = sum(p.numel() for p in params if p.requires_grad)
    logger.info(' total:{}, trainable:{}'.format(
        _convert_num_to_English(total_num),
        _convert_num_to_English(trainable_num)))
    return {'total': total_num, 'trainable': trainable_num}
# Exemple #4 — snippet-aggregator artifact (score line "0" folded in); kept as a comment so the file parses.
def get_dataset(data_path, augmentation_funcs=None):
    """Fetch the CIFAR10 dataset.

        @params:
            data_path - directory the data is stored in / downloaded to.
            augmentation_funcs - extra augmentation transforms applied to the
                training split only (the test split always gets plain ToTensor).

        @return:
            On success - (train, test) dataset pair.
            On failure - error information.
    """
    # Fix: the original used a mutable default (`augmentation_funcs=list()`),
    # which is shared across calls; use None as the sentinel instead.
    train_augmentation_funcs = [torchvision.transforms.ToTensor()]
    if augmentation_funcs:
        train_augmentation_funcs.extend(augmentation_funcs)
    augmentation_func = torchvision.transforms.Compose(
        train_augmentation_funcs)
    # Locals renamed from mnist_* — this loader actually fetches CIFAR10.
    cifar_train = torchvision.datasets.CIFAR10(root=data_path,
                                               train=True,
                                               download=True,
                                               transform=augmentation_func)
    cifar_test = torchvision.datasets.CIFAR10(
        root=data_path,
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor())
    logger.info('dataset is :{}'.format(type(cifar_train)))
    logger.info('train data len :{}'.format(len(cifar_train)))
    logger.info('test data len :{}'.format(len(cifar_test)))
    return cifar_train, cifar_test
# Exemple #5 — snippet-aggregator artifact (score line "0" folded in); kept as a comment so the file parses.
def train_machine_translation_net(encoder,
                                  decoder,
                                  dataset,
                                  out_vocab,
                                  lr=0.01,
                                  batch_size=8,
                                  max_epoch=5):
    """Train an encoder-decoder machine-translation network.

        @params:
            encoder - encoder network.
            decoder - decoder network.
            dataset - the dataset.
            out_vocab - decoder-side vocabulary.
            lr - learning rate.
            batch_size - mini-batch size.
            max_epoch - maximum number of epochs.
    """
    enc_optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)
    dec_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)

    # Per-element losses; aggregation is handled inside get_batch_loss.
    loss = nn.CrossEntropyLoss(reduction='none')
    data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)
    for epoch in range(max_epoch):
        epoch_loss = 0.0
        for X, Y in data_iter:
            # Reset both optimizers before the joint backward pass.
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()
            batch_loss = get_batch_loss(encoder,
                                        decoder,
                                        X,
                                        Y,
                                        loss,
                                        out_vocab=out_vocab)
            batch_loss.backward()
            enc_optimizer.step()
            dec_optimizer.step()
            epoch_loss += batch_loss.item()
        logger.info("epoch %d, loss %.3f" %
                    (epoch + 1, epoch_loss / len(data_iter)))
def get_dataset(data_path):
    """Fetch the FashionMNIST dataset.

        NOTE(review): a `get_dataset` with a different signature appears
        earlier in this file (CIFAR10 variant); if both live in one module,
        this definition shadows it — confirm which one callers need.

        @params:
            data_path - directory the data is stored in / downloaded to.

        @return:
            On success - (train, test) dataset pair.
            On failure - error information.
    """
    to_tensor = torchvision.transforms.ToTensor()
    mnist_train = torchvision.datasets.FashionMNIST(root=data_path,
                                                    train=True,
                                                    download=True,
                                                    transform=to_tensor)
    mnist_test = torchvision.datasets.FashionMNIST(root=data_path,
                                                   train=False,
                                                   download=True,
                                                   transform=to_tensor)
    logger.info('dataset is :{}'.format(type(mnist_train)))
    logger.info('train data len :{}'.format(len(mnist_train)))
    logger.info('test data len :{}'.format(len(mnist_test)))
    return mnist_train, mnist_test
# Exemple #7 — snippet-aggregator artifact (score line "0" folded in); kept as a comment so the file parses.
 def test_train_net(self):
     """Smoke-test the softmax-regression training pipeline on FashionMNIST.

     Downloads the data, builds a flatten+linear net, initializes its
     parameters, then runs train_net for a few epochs.
     """
     print('{} test_train_net {}'.format('-' * 15, '-' * 15))
     data_path = './data/FashionMNIST'
     batch_size = 64
     num_inputs = 1 * 28 * 28  # FashionMNIST images are 1x28x28
     num_outputs = 10  # number of output classes
     max_epoch = 5
     logger.info('加载数据')
     mnist_train, mnist_test = get_dataset(data_path=data_path)
     train_iter, test_iter = get_data_iter(mnist_train,
                                           mnist_test,
                                           batch_size=batch_size)
     logger.info('定义网络')
     net = nn.Sequential()
     net.add_module('flatten', FlattenLayer())
     net.add_module('linear', nn.Linear(num_inputs, num_outputs))
     logger.info(net)
     logger.info('参数初始化')
     # Small random weights, zero bias — standard linear-layer init.
     torch.nn.init.normal_(net.linear.weight, mean=0, std=0.01)
     torch.nn.init.constant_(net.linear.bias, val=0)
     logger.info('定义损失函数')
     loss_func = nn.CrossEntropyLoss()
     logger.info('定义优化器')
     # optimizer = torch.optim.SGD(net.parameters(), lr=0.03)
     optimizer = torch.optim.Adam(net.parameters())
     logger.info('模型训练')
     train_net(net, train_iter, test_iter, max_epoch, optimizer, loss_func)
     """