/
train.py
74 lines (53 loc) · 2.15 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
Author: Justin Seymour
Date: 12 March 2020
"""
#Declare libraries for import
import argparse
import torch
from collections import OrderedDict
from os.path import isdir
from torch import nn
from torch import optim
from torchvision import datasets, transforms, models
from get_input_args import arg_parser
from data_preparation import create_transforms, image_datasets, data_loaders
from classifier import create_model, create_classifier, check_gpu, train, test_model, validate
from checkpoint import initial_checkpoint
def main():
    """Train an image classifier end to end from the command-line settings.

    Steps, in order:
      1. Parse the console arguments with ``arg_parser()``.
      2. Build the transforms, datasets and data loaders.
      3. Create the backbone with ``create_model(arch=args.arch)``.
      4. Replace its ``classifier`` with a new head that has one output per
         training class.
      5. Move the model to the device returned by ``check_gpu(args.gpu)``.
      6. Optimise only ``model.classifier``'s parameters with Adam and
         ``NLLLoss``.
      7. Train, then evaluate on the test loader.
      8. Write a checkpoint via ``initial_checkpoint`` to
         ``args.checkpoint_dir``.
    """
    args = arg_parser()
    print(f"GPU setting: {args.gpu}")

    # Per-channel normalisation using the ImageNet mean/std values that
    # torchvision documents for its pretrained models.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Transformations for the training, validation and test sets.
    # NOTE(review): the meaning of 30 / 224 / 256 is defined by
    # create_transforms (presumably rotation degrees, crop size and resize
    # size). Confirm there before changing them.
    data_transforms = create_transforms(30, 224, 256, normalize)

    # Named ``image_data`` rather than ``datasets`` so this function does not
    # shadow the module-level ``torchvision.datasets`` import.
    image_data = image_datasets(data_transforms)
    # NOTE(review): 32 is presumably the batch size. Confirm in data_loaders.
    loaders = data_loaders(image_data, 32)

    model = create_model(arch=args.arch)
    # One output unit per class in the training dataset.
    output_units = len(image_data['training'].classes)
    model.classifier = create_classifier(
        model, args.hidden_layers, output_units, args.dropout
    )

    device = check_gpu(args.gpu)
    print(device)
    # Move the model before building the optimiser so the optimiser holds
    # references to the parameters on the target device.
    model.to(device)

    # NLLLoss expects log-probabilities, so the head built by
    # create_classifier presumably ends in LogSoftmax. Confirm there.
    criterion = nn.NLLLoss()
    # Only the new classifier head's parameters are handed to the optimiser.
    optimizer = optim.Adam(model.classifier.parameters(), lr=args.learning_rate)

    steps = 0
    trained_model = train(
        model,
        loaders['training'],
        loaders['validation'],
        device,
        criterion,
        optimizer,
        args.epochs,
        args.print_every,
        steps,
    )
    print("Training has completed")

    test_model(trained_model, loaders['testing'], device)
    initial_checkpoint(trained_model, args.checkpoint_dir, image_data['training'])


if __name__ == '__main__':
    main()