Esempio n. 1
0
def train(num_epochs, model, dataloader, optim, criterion, val_loader, output_path):
    """Train `model` for `num_epochs` epochs, checkpointing on best validation accuracy.

    Per-iteration statistics are accumulated in a utils.TrainLogger; each epoch
    is reported either through the module-level visdom `plotter` (when set) or
    printed to stdout.
    """
    logger = utils.TrainLogger()
    best_val_accuracy = 0
    for epoch in range(num_epochs):
        model.train()
        for sentences, labels in dataloader:
            if utils.is_cuda(model):
                labels = labels.cuda()

            # Standard forward / backward / update step.
            optim.zero_grad()
            preds = model(sentences)
            loss = criterion(preds, labels)
            loss.backward()
            optim.step()

            # Re-project weights after the update (s=3) — presumably a max-norm
            # constraint; confirm against l2_norm_constraint's definition.
            l2_norm_constraint(model, s=3)

            batch_correct = (labels == preds.argmax(dim=1)).sum().item()
            logger.update_iter(
                total=labels.size(0),
                correct=batch_correct,
                loss=loss.item()
            )

        accuracy, ep_loss = logger.get_epoch()
        val_accuracy = eval(model, val_loader)
        # Checkpoint whenever validation accuracy matches or beats the best so far.
        if val_accuracy >= best_val_accuracy:
            best_val_accuracy = val_accuracy
            utils.save_model(output_path, model, optim, epoch)
        if plotter:
            total_iter = (epoch + 1) * len(dataloader)
            plotter.plot('loss', 'epoch', 'Loss', total_iter, ep_loss)
            plotter.plot('accuracy', 'epoch', 'Accuracy', total_iter, accuracy)
            plotter.plot('accuracy', 'val', 'Accuracy', total_iter, val_accuracy)
        else:
            print('epoch: %d\t loss: %f\t accuracy: %3.2f\t test accuracy: %3.2f'
                  % (epoch, ep_loss, accuracy, val_accuracy))
Esempio n. 2
0
    def forward(self, input_):
        """Turn `input_` into a padded sentence tensor and classify it with the CNN."""
        tensor = self.generate_sentence_tensor(input_)
        # Move the freshly built tensor to the GPU when the model lives there.
        if is_cuda(self):
            tensor = tensor.cuda()
        return self.cnn(tensor)
Esempio n. 3
0
 def forward(self, sentence):
     """Look up the embedding vector for each word of `sentence`.

     Raises KeyError for words absent from self.word2index.
     """
     index_list = [self.word2index[word] for word in sentence]
     index_tensor = torch.LongTensor(index_list)
     if is_cuda(self):
         index_tensor = index_tensor.cuda()
     return self.index2vec(index_tensor)
Esempio n. 4
0
def eval(model, dataloader):
    """Return the classification accuracy (percent) of `model` over `dataloader`.

    NOTE: this shadows the builtin `eval`; the name is kept because existing
    callers (e.g. the training loop) invoke it by this name.
    """
    model.eval()  # switch to inference mode (disables dropout etc.)
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for sentences, labels in dataloader:
            if utils.is_cuda(model):
                labels = labels.cuda()
            preds = model(sentences)
            total += labels.size(0)
            correct += (labels == preds.argmax(dim=1)).sum().item()
    # Guard against an empty dataloader: the original divided by zero here.
    return 100 * correct / total if total else 0.0
Esempio n. 5
0
    def forward(self, input_):
        """Embed a batch of tokenized sentences on two channels and classify with the CNN.

        Every sentence is embedded twice (static and non-static lookup tables),
        stacked into a (len, word_vec_size, 2) matrix, then zero-padded to a
        common length of at least 5 — presumably the widest conv kernel; confirm
        against self.cnn's configuration.
        """
        matrices = []
        longest = 0
        for words in input_:
            longest = max(longest, len(words))
            static = self.sentence2mat_static(words)
            non_static = self.sentence2mat_non_static(words)
            matrices.append(torch.stack((static, non_static), dim=-1))

        # Zero-pad every sentence matrix into one batch tensor.
        longest = max(longest, 5)
        batch = torch.zeros(len(input_), longest, self.word_vec_size, 2)
        for row, matrix in zip(batch, matrices):
            row.narrow(0, 0, matrix.size(0)).copy_(matrix)

        if is_cuda(self):
            batch = batch.cuda()

        return self.cnn(batch)
Esempio n. 6
0
                    type=str,
                    default='cifar10',
                    help='dataset name')
def _str2bool(value):
    """Parse a boolean CLI value.

    argparse's `type=bool` is a known trap: bool("False") is True because any
    non-empty string is truthy, so `--logging False` would have ENABLED logging.
    This parser accepts the common spellings and stays compatible with
    `--flag True` style invocations.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ('true', '1', 'yes', 'y', 't')


parser.add_argument('--logging', type=_str2bool, default=False, help='log or not')
parser.add_argument('--log_port',
                    type=int,
                    default=8080,
                    help='visdom log panel port')
parser.add_argument('--use_cpu',
                    type=_str2bool,
                    default=False,
                    help='if testing on cpu')
opt = parser.parse_args()
print(opt)

# NOTE(review): elsewhere in this project is_cuda takes a model; passing the
# use_cpu flag here looks suspicious — confirm is_cuda's contract for booleans.
cuda = is_cuda(opt.use_cpu)


# Create z noise for generator input
def create_noise(batch_size, latent_dim):
    """Return standard-normal generator noise shaped (batch_size, latent_dim, 1, 1)."""
    noise = Tensor(batch_size, latent_dim).normal_()
    return Variable(noise.view(-1, latent_dim, 1, 1))


# Logging
# When --logging is enabled, create one logger per tracked GAN loss
# (discriminator real / fake / total, and generator) plus a raw Visdom
# connection for image snapshots, all on the same panel port.
if opt.logging:
    d_real_loss_logger = helper.get_logger(opt.log_port, 'd_loss_real')
    d_fake_loss_logger = helper.get_logger(opt.log_port, 'd_loss_fake')
    d_total_loss_logger = helper.get_logger(opt.log_port, 'd_loss_total')
    g_loss_logger = helper.get_logger(opt.log_port, 'g_loss')
    viz_image_logger = Visdom(port=opt.log_port, env="images")