Example #1
import sys

import numpy as np
import torch
import torch.optim as optim
from progressbar import progressbar

# `args` is assumed to come from an argparse parser defined elsewhere.
# Run on the GPU when one is available (seeding the CUDA RNG with a random
# draw from NumPy); otherwise fall back to the CPU.
if torch.cuda.is_available():
	args.device = torch.device('cuda')
	torch.cuda.manual_seed(np.random.randint(1, 10000))
	torch.backends.cudnn.enabled = True
else:
	args.device = torch.device('cpu')
args.classes = ["desert", "rainforest", "grassland", "tundra", "ocean"]
if args.embedding_type == 'linear':
	model = EmbeddingModel(len(args.classes))
elif args.embedding_type == 'conv':
	model = ConvolutionalEmbeddingModel(len(args.classes))
else:
	print("Model type [{0}] not supported".format(args.embedding_type))
	sys.exit(1)
model = model.to(args.device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
model.train()  # put the model in training mode
starting_epoch = 0
running_loss = 0.0
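# Optionally resume training from a saved checkpoint (model and optimizer
# state, last completed epoch, and accumulated loss).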
if len(args.model_checkpoint) > 0:
	checkpoint = torch.load(args.model_checkpoint)
	model.load_state_dict(checkpoint['model_state_dict'])
	optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
	starting_epoch = checkpoint['epoch']
	running_loss = checkpoint['running_loss']
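# generateData presumably loads the corpus, builds or loads the embedding
# dictionary, and splits the examples into train and test sets.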
train_set, train_labels, test_set, test_labels, classes = generateData(
	args.corpus_file, args.classes, args.train_split_percentage,
	args.load_embedding_from_file, args.save_embedding_dict)
print('Training examples: %d' % len(train_set))
print('Training model...')
print('Starting from epoch %d' % (starting_epoch + 1))
for epoch in progressbar(range(starting_epoch, args.epochs)):
	for i in range(len(train_set)):
		# Assumed training step (the loop body is not shown in the source):
		# treat train_set[i] as a feature tensor and train_labels[i] as a
		# class-index tensor, and run one optimizer update per example.
		optimizer.zero_grad()
		output = model(train_set[i].unsqueeze(0).to(args.device))
		loss = criterion(output, train_labels[i].view(1).to(args.device))
		loss.backward()
		optimizer.step()
		running_loss += loss.item()
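
Neither EmbeddingModel nor ConvolutionalEmbeddingModel is defined in this
snippet. For reference, a minimal sketch of what the linear variant might
look like, assuming each example is a pre-computed embedding vector
(embedding_dim is a hypothetical parameter, not from the original):

import torch.nn as nn

class EmbeddingModel(nn.Module):
	# Hypothetical sketch: map an embedding vector to class logits.
	def __init__(self, num_classes, embedding_dim=300):
		super().__init__()
		self.fc = nn.Linear(embedding_dim, num_classes)

	def forward(self, x):
		# Return raw logits; CrossEntropyLoss applies log-softmax itself.
		return self.fc(x)
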
Example #2
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as tud
from tqdm import tqdm

# BATCH_SIZE, EMBEDDING_SIZE, LR, EPOCHS, OUT_DIR, vocab_count and
# subsampling are assumed to be defined earlier in the script.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reserve an <UNK> entry so out-of-vocabulary words map to a valid index.
vocab_count['<UNK>'] = 1

idx2word = list(vocab_count.keys())
word2idx = {word: i for i, word in enumerate(idx2word)}

# Raise counts to the 3/4 power, as in word2vec, to flatten the
# negative-sampling distribution, then normalize into probabilities.
word_counts = np.array(list(vocab_count.values()), dtype=np.float32) ** (3. / 4.)
word_freqs = word_counts / np.sum(word_counts)

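# WordEmbeddingDataset presumably yields (center word, positive context,
# sampled negatives) index triples, drawing negatives from word_freqs.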
dataset = WordEmbeddingDataset(subsampling, word2idx, word_freqs)
dataloader = tud.DataLoader(dataset, BATCH_SIZE, shuffle=True)

model = EmbeddingModel(len(idx2word), EMBEDDING_SIZE)
model.to(device)
model.train()
optimizer = optim.Adam(model.parameters(), lr=LR)

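# Each batch provides center-word, positive-context, and negative-sample
# indices; the model returns a per-example loss that is averaged per batch.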
for epoch in range(EPOCHS):
    pbar = tqdm(dataloader)
    pbar.set_description("[Epoch {}]".format(epoch))
    for i, (input_labels, pos_labels, neg_labels) in enumerate(pbar):
        input_labels = input_labels.to(device)
        pos_labels = pos_labels.to(device)
        neg_labels = neg_labels.to(device)
        model.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item())

model.save_embedding(OUT_DIR, idx2word)
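
The loop above implies a skip-gram model trained with negative sampling:
forward takes center-word, positive-context, and negative-sample index
batches and returns a per-example loss. A minimal sketch consistent with
that interface (layer names and the save format are assumptions, not from
the original):

import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, embed_size)   # center words
        self.out_embed = nn.Embedding(vocab_size, embed_size)  # context words

    def forward(self, input_labels, pos_labels, neg_labels):
        # input_labels: (B,), pos_labels: (B, C), neg_labels: (B, C*K)
        center = self.in_embed(input_labels).unsqueeze(2)      # (B, E, 1)
        pos_score = torch.bmm(self.out_embed(pos_labels), center).squeeze(2)
        neg_score = torch.bmm(self.out_embed(neg_labels), -center).squeeze(2)
        # Negative-sampling objective: log-sigmoid of positive scores plus
        # log-sigmoid of negated negative scores, summed per example.
        return -(F.logsigmoid(pos_score).sum(1)
                 + F.logsigmoid(neg_score).sum(1))             # (B,)

    def save_embedding(self, out_dir, idx2word):
        # Assumed format: one word and its input-embedding vector per line.
        vectors = self.in_embed.weight.detach().cpu().numpy()
        with open(out_dir + '/embeddings.txt', 'w', encoding='utf-8') as f:
            for word, vec in zip(idx2word, vectors):
                f.write(word + ' ' + ' '.join(map(str, vec)) + '\n')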