def main():
    """Train an analog version of the network built by ``create_model``.

    The digital model is converted with ``RPU_CONFIG``, optionally moved to
    the GPU (``USE_CUDA``), trained for ``N_EPOCHS`` epochs on the data
    returned by ``load_images`` and its final weights are saved to
    ``WEIGHT_PATH``.
    """
    # Make sure the directory where the results are saved exists.
    # Results include: Loss vs Epoch graph, Accuracy vs Epoch graph and vector data.
    os.makedirs(RESULTS, exist_ok=True)
    manual_seed(SEED)

    # Load datasets.
    train_data, validation_data = load_images()

    # Build the digital model and convert it to its analog version.
    model = create_model()
    model = convert_to_analog(model, RPU_CONFIG, weight_scaling_omega=0.6)

    # To resume from previously saved weights, load them at this point:
    #     model.load_state_dict(load(WEIGHT_PATH))

    if USE_CUDA:
        model = model.cuda()

    optimizer = create_sgd_optimizer(model, LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    print(f'\n{datetime.now().time().replace(microsecond=0)} --- '
          f'Started ResNet Training')

    # Only the trained model is needed afterwards; the optimizer is discarded.
    model, _ = training_loop(model, criterion, optimizer, train_data,
                             validation_data, N_EPOCHS)

    print(f'{datetime.now().time().replace(microsecond=0)} --- '
          f'Completed ResNet Training')

    save(model.state_dict(), WEIGHT_PATH)
def test_conversion_torchvision_alexnet(self):
    """Test converting alexnet model from torchvision."""
    model = alexnet()
    if self.use_cuda:
        model = model.cuda()

    analog_model = convert_to_analog(model, FloatingPointRPUConfig())

    # Individual layers must be replaced by their analog counterparts.
    self.assertEqual(analog_model.features[0].__class__, AnalogConv2d)
    self.assertEqual(analog_model.classifier[6].__class__, AnalogLinear)
    # Container modules are converted as well.
    self.assertEqual(analog_model.features.__class__, AnalogSequential)
def test_conversion_torchvision_resnet(self):
    """Check that torchvision's resnet18 has its layers converted to analog."""
    digital = resnet18()
    if self.use_cuda:
        digital = digital.cuda()

    converted = convert_to_analog(digital, FloatingPointRPUConfig())

    # Top-level convolution.
    self.assertEqual(converted.conv1.__class__, AnalogConv2d)
    # First residual stage: the container and a convolution nested inside it.
    first_stage = converted.layer1
    self.assertEqual(first_stage.__class__, AnalogSequential)
    self.assertEqual(first_stage[0].conv1.__class__, AnalogConv2d)
def main():
    """Load a predefined ResNet-34 and print it before and after analog conversion."""
    # Digital network as provided by the model library.
    digital_model = resnet34()
    print(digital_model)

    # Analog counterpart built with the module-level RPU configuration.
    analog_model = convert_to_analog(digital_model, RPU_CONFIG,
                                     weight_scaling_omega=0.6)
    print(analog_model)
def test_load_state_dict_conversion(self):
    """Test loading and setting conversion with alpha.

    A torch model is trained, converted to analog and its state dict is
    serialized. A freshly converted model then loads that state (including
    the RPU config) and must reproduce the analog tile weights, the alpha
    scales and the loss of the original analog model.
    """
    # Input batch and regression targets.
    x_b = Tensor([[0.1, 0.2, 0.3, 0.4], [0.2, 0.4, 0.3, 0.1]])
    y_b = Tensor([[0.3], [0.6]])

    if self.use_cuda:
        x_b = x_b.cuda()
        y_b = y_b.cuda()

    model = self.get_torch_model(self.use_cuda)
    self.train_model_torch(model, mse_loss, x_b, y_b)

    analog_model = convert_to_analog(model, self.get_rpu_config())
    analog_loss = mse_loss(analog_model(x_b), y_b)

    # Round-trip the analog state dict through a temporary file.
    with TemporaryFile() as file:
        save(analog_model.state_dict(), file)
        file.seek(0)  # Rewind so load() reads what was just written.
        state_dict = load(file)

    # Load the saved state into a freshly converted model.
    model = self.get_torch_model(self.use_cuda)
    new_analog_model = convert_to_analog(model, self.get_rpu_config())
    new_analog_model.load_state_dict(state_dict, load_rpu_config=True)

    new_state_dict = new_analog_model.state_dict()
    for key, state1 in new_state_dict.items():
        # Only the per-tile analog states are compared here.
        if not key.endswith('analog_tile_state'):
            continue
        state2 = state_dict[key]
        assert_array_almost_equal(state1['analog_tile_weights'],
                                  state2['analog_tile_weights'])
        assert_array_almost_equal(state1['analog_alpha_scale'],
                                  state2['analog_alpha_scale'])

    new_analog_loss = mse_loss(new_analog_model(x_b), y_b)
    self.assertTensorAlmostEqual(new_analog_loss, analog_loss)
def test_conversion_linear_sequential(self):
    """Check that nested Sequential/Linear stacks convert and keep the same loss."""
    inputs = Tensor([[0.1, 0.2, 0.3, 0.4], [0.2, 0.4, 0.3, 0.1]])
    targets = Tensor([[0.3], [0.6]])

    # Seed immediately before construction so the layer init is reproducible.
    manual_seed(4321)
    model = Sequential(
        Linear(4, 3),
        Linear(3, 3),
        Sequential(Linear(3, 1), Linear(1, 1)),
    )

    if self.use_cuda:
        inputs, targets, model = inputs.cuda(), targets.cuda(), model.cuda()

    self.train_model_torch(model, mse_loss, inputs, targets)
    digital_loss = mse_loss(model(inputs), targets)

    analog_model = convert_to_analog(model, FloatingPointRPUConfig())

    self.assertEqual(analog_model[0].__class__, AnalogLinear)
    self.assertEqual(analog_model.__class__, AnalogSequential)
    analog_loss = mse_loss(analog_model(inputs), targets)
    self.assertTensorAlmostEqual(analog_loss, digital_loss)