def prepare_model(learning_rate, momentum, checkpoint_file):
    """Build a wrapped ResNet-34 with a cross-entropy loss and SGD.

    The torchvision ResNet-34 is modified in two places before wrapping:
    its first convolution is replaced by a 3x3, stride-1, padding-1,
    bias-free convolution with 1 input channel and 64 output channels,
    and its average-pooling layer is replaced by ``AvgPool2d(2)``.

    Args:
        learning_rate (float): Learning rate passed to SGD.
        momentum (float): Momentum passed to SGD.
        checkpoint_file (str or None): Path of a checkpoint file to load
            into the wrapped model. Any falsy value (``None`` or an empty
            string) skips loading.

    Returns:
        ModelWrapper: The network bundled with its loss function and
            optimizer, with the checkpoint loaded if one was given.
    """
    network = torchvision.models.resnet34()

    # Replace the first convolution with a single-input-channel 3x3 one.
    network.conv1 = torch.nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    network.avgpool = torch.nn.AvgPool2d(kernel_size=2)

    # The optimizer is created after the layer swaps so that it tracks the
    # parameters of the replacement convolution.
    criterion = torch.nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(
        network.parameters(),
        lr=learning_rate,
        momentum=momentum,
    )

    wrapped = ModelWrapper(network, criterion, sgd)
    if not checkpoint_file:
        return wrapped

    wrapped.load(checkpoint_file)
    return wrapped
Esempio n. 2
0
# List every entry of model_dir with its index so the user can pick one.
model_list = os.listdir(model_dir)
print('\nFind model:')

for num, model in enumerate(model_list):
    print(num, model)

# The number typed by the user is used directly as an index into model_list.
selected_index = int(input('\nChoose a model (enter a number): '))
model_name = model_list[selected_index]
print('Load Model {}'.format(model_name))

model_path = os.path.join(model_dir, model_name)

# The model name is '-'-separated: field 2 is parsed as the learning rate
# and field 3 as the image size.
name_fields = model_name.split('-')
learning_rate = float(name_fields[2])
img_size = int(name_fields[3])

# Build the model from the parsed settings, then load it from
# <model_dir>/<model_name>/<model_name>.
model = ModelWrapper(learning_rate, img_size).model
checkpoint_path = os.path.join(model_path, model_name)
model.load(checkpoint_path)
print('Model loaded!')

# Use the processed test-set file for this image size when it exists;
# otherwise hand off to process_data.process_test_set_data (presumably it
# builds the array from the raw data and saves it to the given path —
# confirm against process_data).
test_set_file_name = 'test_{}.npy'.format(img_size)
test_set_processed_data_path = os.path.join(processed_data_dir,
                                            test_set_file_name)
if not os.path.exists(test_set_processed_data_path):
    test_set_data = process_data.process_test_set_data(
        img_size, test_set_raw_data_dir, test_set_processed_data_path)
    print('Data processed!')
else:
    test_set_data = np.load(test_set_processed_data_path)
    print('Data loaded!')

fig = plt.figure()

for num, data in enumerate(test_set_data[:12]):