# Example #1
# (snippet header from the page this script was collected from; kept as a
# comment because as code it evaluated an undefined name and raised NameError)
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# Copyright Nils Schaetti <*****@*****.**>

# Imports
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from echotorch import datasets
from echotorch.transforms import text

# Reuters C50 dataset: 2 authors, dataset_size=2, dataset_start=20,
# every text passed through the text.Token() transform
reutersloader = torch.utils.data.DataLoader(
    datasets.ReutersC50Dataset(
        root="../../data/reutersc50/",
        download=True,
        n_authors=2,
        transform=text.Token(),
        dataset_size=2,
        dataset_start=20,
    ),
    batch_size=1,
    shuffle=True,
)

# Number of fold indices visited below (set_fold receives 0..n_folds-1)
n_folds = 10

# For each fold (not each batch): select the fold, then switch to training mode
for k in range(n_folds):
    # Set fold and training mode
    reutersloader.dataset.set_fold(k)
    reutersloader.dataset.set_train(True)

    # Get training data for this fold
    for i, data in enumerate(reutersloader):
        # Inputs and labels
        # NOTE(review): the original loop body is missing from this file (the
        # snippet was truncated), which turned the next top-level line into an
        # IndentationError. `pass` keeps the module loadable; presumably `data`
        # should be unpacked into inputs and labels here -- confirm against
        # ReutersC50Dataset.__getitem__ before filling this in.
        pass
import echotorch.nn as etnn
import echotorch.utils
import os

# Settings
n_epoch = 1
embedding_dim = 10
n_authors = 15
# NOTE(review): use_cuda is never read in this view; the model is not moved
# to a GPU here -- confirm whether that was intended.
use_cuda = True
voc_size = 15723

# Input transform: text.Character3Gram() applied to every document
transform = text.Character3Gram()

# Reuters C50 dataset. Uses the n_authors setting above rather than repeating
# the literal 15, so the two values cannot drift apart.
reutersloader = torch.utils.data.DataLoader(
    datasets.ReutersC50Dataset(
        download=True,
        n_authors=n_authors,
        transform=transform,
    ),
    batch_size=1,
    shuffle=False,
)

# Model
# NOTE(review): CNNCharacterEmbedding is neither defined nor imported in this
# file's visible imports -- presumably it lives in a project module; add the
# import where it is defined.
model = CNNCharacterEmbedding(voc_size=voc_size, embedding_dim=embedding_dim)

# Optimizer: SGD with lr=0.001 and momentum=0.9 over all model parameters
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Loss function
loss_function = nn.CrossEntropyLoss()

# Select fold 0 of the dataset
reutersloader.dataset.set_fold(0)
# Experiment: a ResultManager built from the parsed command-line arguments.
# NOTE(review): `nsNLP` and `args` are not defined anywhere in this view --
# presumably imported / parsed earlier in the original script; confirm.
xp = nsNLP.tools.ResultManager(
    args.output,
    args.name,
    args.description,
    args.get_space(),
    args.n_samples,
    args.k,
    verbose=args.verbose,
)

# Reuters C50 dataset; root, n_authors and dataset_size come straight from
# args, and the dataset always starts at index 0
reutersloader = torch.utils.data.DataLoader(
    datasets.ReutersC50Dataset(
        root=args.dataset,
        download=True,
        n_authors=args.n_authors,
        dataset_size=args.dataset_size,
        dataset_start=0,
    ),
    batch_size=1,
    shuffle=True,
)

# Log the authors selected by the dataset
xp.write(u"Authors : {}".format(reutersloader.dataset.authors), log_level=0)

# Starting parameters: the first value listed under each key of args.get_space()
rc_size = int(args.get_space()['reservoir_size'][0])
rc_w_sparsity = args.get_space()['w_sparsity'][0]

# Previous reservoir size, starting at 0
last_rc_size = 0

# W index, starting at 0
w_index = 0