def train_main():
    """Train the simple ``Net`` model from scratch and save its weights.

    Builds the network on ``device``, optimizes it with SGD (lr 0.03) against
    a cross-entropy loss for 50 epochs using batches of 50, runs a validation
    pass after every epoch, prints the total wall-clock time, and finally
    writes the model's ``state_dict`` to ``model_file``.
    """
    net = Net().to(device)
    sgd = optim.SGD(net.parameters(), lr=0.03)
    loss_fn = nn.CrossEntropyLoss()

    print(net)

    samples_per_batch = 50
    training_data = get_train_loader(samples_per_batch)
    validation_data = get_validation_loader(samples_per_batch)

    board = get_tensorboard('simple')

    n_epochs = 50
    began = datetime.now()

    for step in range(n_epochs):
        # Epochs are reported 1-based to train() and evaluate().
        epoch = step + 1
        train(net, training_data, loss_fn, sgd, epoch, board)

        # The validation pass runs with gradient tracking disabled.
        with torch.no_grad():
            print('\nValidation:')
            evaluate(net, validation_data, loss_fn, epoch, board)

    elapsed = datetime.now() - began
    print('Total training time: {}.'.format(elapsed))

    weights = net.state_dict()
    torch.save(weights, model_file)
    print('Wrote model to', model_file)
def _run_epochs(model, train_loader, validation_loader, criterion, optimizer,
                log, epochs, epoch_offset=0):
    """Run ``epochs`` train/validate rounds and print the elapsed wall time.

    The epoch number handed to ``train`` and ``evaluate`` runs from
    ``epoch_offset + 1`` to ``epoch_offset + epochs``, so a later training
    phase can continue the numbering of an earlier one.
    """
    start_time = datetime.now()
    for epoch in range(epoch_offset + 1, epoch_offset + epochs + 1):
        train(model, train_loader, criterion, optimizer, epoch, log)

        # Validation runs with gradient tracking disabled.
        with torch.no_grad():
            print('\nValidation:')
            evaluate(model, validation_loader, criterion, epoch, log)
    end_time = datetime.now()
    print('Total training time: {}.'.format(end_time - start_time))


def _unfreeze_from(features, first_trainable):
    """Enable gradients for children of ``features`` named >= ``first_trainable``.

    Prints one line per child: its name, a marker (' ' = no parameters,
    '-' = parameters left unchanged, '+' = parameters set trainable), the
    layer itself, and the length of the layer's last parameter (0 when the
    layer has no parameters).
    """
    for name, layer in features.named_children():
        note = ' '
        # Tracked per layer: reading the inner loop variable after the loop
        # would report the previous layer's value for parameterless layers
        # (or raise NameError if the very first layer had no parameters).
        last_param_len = 0
        for param in layer.parameters():
            note = '-'
            # Child names are expected to be numeric indices here.
            if int(name) >= first_trainable:
                param.requires_grad = True
                note = '+'
            last_param_len = len(param)
        print(name, note, layer, last_param_len)


def train_main():
    """Train ``PretrainedNet`` in two phases and save a checkpoint after each.

    Phase 1 ("new layers") optimizes only the parameters that currently have
    ``requires_grad`` set, using SGD (lr 0.01) for 20 epochs, and writes the
    weights to ``model_file``.

    Phase 2 ("fine-tuning") additionally enables gradients for the children
    of ``model.vgg_features`` whose index is 24 or higher, continues for 20
    more epochs with RMSprop (lr 1e-5), and writes the weights to
    ``model_file_ft``. Epoch numbers in phase 2 continue where phase 1 ended.
    """
    # Learning 1: New layers
    model = PretrainedNet().to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.01)
    criterion = nn.CrossEntropyLoss()

    print(model)

    batch_size = 50
    train_loader = get_train_loader(batch_size)
    validation_loader = get_validation_loader(batch_size)

    log = get_tensorboard('pretrained')
    epochs = 20

    _run_epochs(model, train_loader, validation_loader, criterion, optimizer,
                log, epochs)

    torch.save(model.state_dict(), model_file)
    print('Wrote model to', model_file)

    # Learning 2: Fine-tuning
    log = get_tensorboard('finetuned')

    _unfreeze_from(model.vgg_features, 24)

    # Rebuild the trainable-parameter list so it includes the layers that
    # were just unfrozen.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.RMSprop(params, lr=1e-5)
    criterion = nn.CrossEntropyLoss()

    print(model)

    prev_epochs = epochs
    epochs = 20

    _run_epochs(model, train_loader, validation_loader, criterion, optimizer,
                log, epochs, epoch_offset=prev_epochs)

    torch.save(model.state_dict(), model_file_ft)
    print('Wrote finetuned model to', model_file_ft)