def train(loaders, save_path, model=None, n_epochs=None, lr=None):
    """Train a classifier, checkpointing whenever validation loss does not increase.

    Args:
        loaders: Mapping with ``'train'`` and ``'valid'`` iterables that yield
            ``(data, target)`` batches.
        save_path: Path passed to ``torch.save``. It receives the model's
            ``state_dict`` after every epoch whose validation loss is <= the
            best seen so far.
        model: Module to train. Defaults to a fresh ``Net()``.
        n_epochs: Number of epochs. Defaults to ``config.n_epochs``.
        lr: Adam learning rate. Defaults to ``config.lr``.

    Returns:
        The trained model, left in eval mode. It holds the weights from the
        final epoch, which are not necessarily the best checkpoint on disk.
    """
    if model is None:
        model = Net()
    if n_epochs is None:
        n_epochs = config.n_epochs
    if lr is None:
        lr = config.lr

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()

    # Cross-entropy loss for a classification task.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # float('inf') rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
    valid_loss_min = float('inf')

    for epoch in range(1, n_epochs + 1):
        train_loss = 0.0
        valid_loss = 0.0

        # Training pass.
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            # Incremental mean of per-batch losses. .item() gives a plain
            # float instead of accumulating a tensor.
            train_loss += (loss.item() - train_loss) / (batch_idx + 1)

        # Validation pass: no gradients are needed, so skip graph construction.
        model.eval()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loaders['valid']):
                if use_cuda:
                    data, target = data.cuda(), target.cuda()
                loss = criterion(model(data), target)
                valid_loss += (loss.item() - valid_loss) / (batch_idx + 1)

        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.
              format(epoch, train_loss, valid_loss))

        # Checkpoint when validation loss has not increased.
        if valid_loss <= valid_loss_min:
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss

    return model
# Example no. 2
def main():
    """Evaluate the checkpoint ``model.pt`` on the validation set.

    Each output row is split into ``setting.char_num`` consecutive slices of
    ``setting.pool_length`` scores. The argmax of each slice is mapped through
    ``ohe.num2char`` to build the predicted string, which is compared with
    ``ohe.decode`` of the label. The function prints the sequence-level
    accuracy. It also writes per-character confusions to
    ``miss_character.txt`` as ``'<true> -> <predicted> : <count>'`` lines,
    most frequent first.
    """
    net = Net().to(setting.device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load('model.pt', map_location=setting.device))
    net.eval()
    print('model has been loaded.')

    char_num = setting.char_num
    pool_length = setting.pool_length
    correct = 0
    # Count the samples actually evaluated, not the files in the folder, so
    # the denominator matches what the dataloader yielded.
    total = 0
    miss_character = collections.Counter()
    valid_dataloader = data_preprocess.get_valid_dataloader()
    with torch.no_grad():
        for imgs, labels in tqdm(valid_dataloader):
            outputs = net(imgs.to(setting.device))
            # One device->host copy per batch, then a vectorized per-slice argmax.
            # Assumes outputs is (batch, >= char_num * pool_length) — TODO confirm.
            scores = outputs[:, :char_num * pool_length].cpu().numpy()
            char_indices = scores.reshape(-1, char_num, pool_length).argmax(axis=2)
            for pred_row, label_row in zip(char_indices, labels.cpu().numpy()):
                predict_label = ''.join(ohe.num2char[int(idx)] for idx in pred_row)
                true_label = ohe.decode(label_row)
                total += 1
                if predict_label == true_label:
                    correct += 1
                    continue
                # Tally each mismatched position as 'true -> predicted'.
                for true_char, pred_char in zip(true_label, predict_label):
                    if true_char != pred_char:
                        miss_character['{} -> {}'.format(true_char, pred_char)] += 1

    with open('miss_character.txt', 'w', encoding='utf-8') as f:
        # most_common() sorts by count descending; ties keep first-seen order.
        for error_info, count in miss_character.most_common():
            f.write('{} : {}\n'.format(error_info, count))
    accuracy = correct / total if total else 0.0
    print('accuracy: {}/{} -- {:.4f}'.format(correct, total, accuracy))
# Build the test split and a batched loader over it.
test_dataset = AnimalsDataset(filename='meta-data_new/meta-data/test.csv', root_dir='./test_new/test', train=False,
                              transform=data_transform1)

test_loader = DataLoader(test_dataset, batch_size=60)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

net = Net().to(device)
print(net)
# NOTE(review): this optimizer is never used in the visible code (inference only).
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
net.load_state_dict(torch.load("./models/model_ep50.net", map_location=device))

net.eval()

#######################
# Testing an image
data_iter = iter(test_dataset)
sample = next(data_iter)
data, labels = sample['image'], sample['labels']

# Add a leading batch dimension of size 1 and move to the model's device.
img = torch.unsqueeze(data, 0).float().to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    out = net(img)
# Explicit dim (the implicit form is deprecated): normalise over the last axis,
# assumed to be the class scores — TODO confirm out is (1, num_classes).
prediction = torch.nn.functional.softmax(out, dim=-1)
# NOTE(review): `labels_to_idx` is indexed by a class index here; confirm it maps index -> label.
predicted_label = animals_dataset.labels_to_idx[prediction.argmax().item()]
print("Predicted Label:", predicted_label)