Example No. 1
import torch
from torch.nn import MSELoss
from torch.optim import Adam


def assert_params_changed(model, input, exclude=[]):
    """
    Check that all model parameters are updated during training.

    Args:
        model (torch.nn.Module): model to test
        input (tuple): inputs to pass to the model
        exclude (list): prefixes of parameters that are not necessarily updated
    """
    # save state-dict
    torch.save(model.state_dict(), 'before')
    # do one training step
    optimizer = Adam(model.parameters())
    loss_fn = MSELoss()
    pred = model(*input)
    loss = loss_fn(pred, torch.rand(pred.shape))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # check if all trainable parameters have changed
    after = model.state_dict()
    before = torch.load('before')
    for key in before.keys():
        if any(key.startswith(exclude_layer) for exclude_layer in exclude):
            continue
        assert (before[key] != after[key]).any(), '{} layer has not been updated!'.format(key)
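
A minimal usage sketch, relying on the imports above (the two-layer model is hypothetical, not from the original source):

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1),
)
# The inputs are passed as a tuple because the helper unpacks them with model(*input).
assert_params_changed(model, (torch.rand(4, 10),))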
Example No. 2
    def __init__(self,
                 state_space,
                 channels,
                 action_space,
                 epsilon=0.99,
                 epsilon_min=0.01,
                 epsilon_decay=0.99,
                 gamma=0.9,
                 learning_rate=0.01):
        super(Agent, self).__init__()
        self.action_space = action_space
        self.state_space = state_space
        self.channels = channels
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma

        self.conv1 = Conv2d(self.channels, 32, 8)
        self.conv2 = Conv2d(32, 64, 4)
        self.conv3 = Conv2d(64, 128, 3)
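        # fc1's input of 128 * 52 * 52 assumes 64x64 inputs: the stride-1,
        # no-padding convs shrink the spatial size 64 -> 57 -> 54 -> 52.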
        self.fc1 = Linear(128 * 52 * 52, 64)
        self.fc2 = Linear(64, 32)
        self.output = Linear(32, action_space)

        self.loss_fn = MSELoss()
        self.optimizer = Adam(self.parameters(), lr=self.learning_rate)
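
The forward pass is not shown in this excerpt; a minimal sketch consistent with the layer sizes above (64x64 inputs and ReLU activations are assumptions, not taken from the original source):

    # Assumes "import torch.nn.functional as F" at module level.
    def forward(self, x):
        # x: (batch, channels, 64, 64)
        x = F.relu(self.conv1(x))    # -> (batch, 32, 57, 57)
        x = F.relu(self.conv2(x))    # -> (batch, 64, 54, 54)
        x = F.relu(self.conv3(x))    # -> (batch, 128, 52, 52)
        x = x.flatten(start_dim=1)   # -> (batch, 128 * 52 * 52)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.output(x)        # one Q-value per action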
Example No. 3
    def __init__(self, state_space, action_space, **kwargs):
        super(Agent, self).__init__()

        self.state_space = state_space
        self.action_space = action_space

        self.epsilon = HyperParams.EPSILON.value
        self.epsilon_min = HyperParams.EPSILON_MIN.value
        self.epsilon_decay = HyperParams.EPSILON_DECAY.value
        self.gamma = HyperParams.GAMMA.value
        self.learning_rate = HyperParams.LEARNING_RATE.value

        # Apply any keyword-argument overrides to the defaults above
        self.override_hyper_params(kwargs)

        self.in_layer = Linear(state_space, 128)
        self.hidden_layer = Linear(128, 64)
        self.out_layer = Linear(64, action_space)

        self.loss_fn = MSELoss()
        self.optimizer = Adam(self.parameters(), lr=self.learning_rate)
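
The epsilon fields above suggest an epsilon-greedy policy with multiplicative decay; a hypothetical sketch (the method names and the forward call are assumptions, not taken from the original source):

    # Assumes "import random" and "import torch" at module level,
    # and a forward() that maps a state tensor to Q-values.
    def act(self, state):
        # Explore with probability epsilon, otherwise take the greedy action.
        if random.random() < self.epsilon:
            return random.randrange(self.action_space)
        with torch.no_grad():
            q_values = self(torch.as_tensor(state, dtype=torch.float32))
        return int(q_values.argmax().item())

    def decay_epsilon(self):
        # Multiplicative decay, floored at epsilon_min.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)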
Example No. 4
    def __init__(self, threshold=0.5):
        super(BoxLoss, self).__init__()
        self.threshold = threshold
        self.mse = MSELoss()
Example No. 5
    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(Wav2EdgeLoss, self).__init__()
        # size_average and reduce are deprecated in PyTorch; reduction supersedes them.
        self.loss = MSELoss(size_average, reduce, reduction)
Example No. 6
# Loading the data into batches
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

####################################################################################################
#                                   MAIN CODE
####################################################################################################

# Defining the Model
model = Autoencoder()
model.to(device)
print("Model loaded on GPU")

# Defining the Optimizer and Loss Function
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
loss = MSELoss()
train_per_epoch = len(dataset) // batch_size  # number of batches per epoch
target = int(train_per_epoch * (1 - split_ratio))  # batches used for training; the rest validate

# Training and Validation Loop
print("Starting the Training")
for k in range(epochs):

    # EPOCH START
    print("\nEpoch = ", k + 1)
    kbar = pkbar.Kbar(target=target + 1, width=12)

    # Training and Validation of Epoch
    j, final_val_loss = 0, torch.tensor(0)
    for i, (input_data, output_data) in enumerate(dataloader):