import torch
from torch import optim
from torch.distributions import Normal
from torch.utils.data import DataLoader

# MyDataset and MDN are project-specific classes assumed to be defined elsewhere.


def train_method(path, lr=1e-3, epochs=50, dropout=0.2, batch_size=1024):
    print('TRAINING MODEL')
    dataset = MyDataset(bucket='test', path=path)
    data_loader = DataLoader(dataset)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = MDN(dataset.shape[1] - 1, dropout=dropout).to(device=device)
    optimizer = optim.Adam(net.parameters(), lr=lr)

    for i in range(1, epochs + 1):
        for file_data in data_loader:
            # Each loader item holds one file's rows; the last column is the
            # target, the remaining columns are features.
            file_data = file_data.squeeze().float()
            x_file, y_file = file_data[:, :-1], file_data[:, -1]
            # Step through the file in mini-batches of batch_size rows.
            for j in range(0, x_file.shape[0], batch_size):
                x_batch = x_file[j:j + batch_size].to(device=device)
                y_batch = y_file[j:j + batch_size].to(device=device)
                y_train_pred = net(x_batch).float()

                # The two network outputs parameterize a Normal distribution;
                # train by minimizing the negative log-likelihood of y_batch.
                dist = Normal(y_train_pred[:, 0], y_train_pred[:, 1])
                loss = -dist.log_prob(y_batch).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Report the loss of the last mini-batch after each epoch.
        print(i, loss.item())
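
# A minimal sketch (not part of the original snippet) of how train_method
# might be invoked; the path and hyperparameter values below are assumptions
# for illustration only.
if __name__ == "__main__":
    train_method("data/train", lr=1e-3, epochs=10, batch_size=1024)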
Example #2
    def test_normal(self):
        mean = Variable(torch.randn(5, 5), requires_grad=True)
        std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        mean_1d = Variable(torch.randn(1), requires_grad=True)
        std_1d = Variable(torch.randn(1), requires_grad=True)
        mean_delta = torch.Tensor([1.0, 0.0])
        std_delta = torch.Tensor([1e-5, 1e-5])
        self.assertEqual(Normal(mean, std).sample().size(), (5, 5))
        self.assertEqual(Normal(mean, std).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(Normal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Normal(mean_1d, std_1d).sample().size(), (1, ))
        self.assertEqual(Normal(0.2, .6).sample_n(1).size(), (1, ))
        self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1, ))

        # sample check for extreme value of mean, std
        self._set_rng_seed(1)
        self.assertEqual(Normal(mean_delta,
                                std_delta).sample(sample_shape=(1, 2)),
                         torch.Tensor([[[1.0, 0.0], [1.0, 0.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(Normal, (mean, std))
        self._gradcheck_log_prob(Normal, (mean, 1.0))
        self._gradcheck_log_prob(Normal, (0.0, std))

        state = torch.get_rng_state()
        eps = torch.normal(torch.zeros_like(mean), torch.ones_like(std))
        torch.set_rng_state(state)
        z = Normal(mean, std).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(mean.grad, torch.ones_like(mean))
        self.assertEqual(std.grad, eps)
        mean.grad.zero_()
        std.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = mean.data.view(-1)[idx]
            s = std.data.view(-1)[idx]
            expected = (math.exp(-(x - m)**2 / (2 * s**2)) /
                        math.sqrt(2 * math.pi * s**2))
            self.assertAlmostEqual(log_prob, math.log(expected), places=3)

        self._check_log_prob(Normal(mean, std), ref_log_prob)
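
# A standalone sketch (not taken from the test file above) of the identity
# that ref_log_prob verifies: Normal(m, s).log_prob(x) should equal the log
# of the Gaussian density exp(-(x - m)^2 / (2 s^2)) / sqrt(2 pi s^2).
import math
import torch
from torch.distributions import Normal

m, s, x = 0.5, 1.5, 2.0
log_prob = Normal(torch.tensor(m), torch.tensor(s)).log_prob(torch.tensor(x)).item()
expected = math.log(math.exp(-(x - m) ** 2 / (2 * s ** 2)) /
                    math.sqrt(2 * math.pi * s ** 2))
assert abs(log_prob - expected) < 1e-5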
Example #3
    def test_normal(self):
        mean = Variable(torch.randn(5, 5), requires_grad=True)
        std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        mean_1d = Variable(torch.randn(1), requires_grad=True)
        std_1d = Variable(torch.randn(1), requires_grad=True)
        self.assertEqual(Normal(mean, std).sample().size(), (5, 5))
        self.assertEqual(Normal(mean, std).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(Normal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Normal(mean_1d, std_1d).sample().size(), (1, ))
        self.assertEqual(Normal(0.2, .6).sample_n(1).size(), (1, 1))
        self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1, 1))

        self._gradcheck_log_prob(Normal, (mean, std))
        self._gradcheck_log_prob(Normal, (mean, 1.0))
        self._gradcheck_log_prob(Normal, (0.0, std))

        state = torch.get_rng_state()
        eps = torch.normal(torch.zeros_like(mean), torch.ones_like(std))
        torch.set_rng_state(state)
        z = Normal(mean, std).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(mean.grad, torch.ones_like(mean))
        self.assertEqual(std.grad, eps)
        mean.grad.zero_()
        std.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = mean.data.view(-1)[idx]
            s = std.data.view(-1)[idx]
            expected = (math.exp(-(x - m)**2 / (2 * s**2)) /
                        math.sqrt(2 * math.pi * s**2))
            self.assertAlmostEqual(log_prob, math.log(expected), places=3)

        self._check_log_prob(Normal(mean, std), ref_log_prob)

        def call_sample_wshape_gt_2():
            return Normal(mean, std).sample((1, 2))

        self.assertRaises(NotImplementedError, call_sample_wshape_gt_2)
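
# A minimal sketch (using the current torch API rather than the old
# Variable-based test harness above) of the reparameterization property these
# tests check: for z = mean + eps * std, backpropagating ones through z gives
# d z / d mean = 1 and d z / d std = eps, which is what makes rsample()
# differentiable with respect to the distribution's parameters.
import torch

mean = torch.randn(5, 5, requires_grad=True)
std = torch.randn(5, 5).abs().requires_grad_()
eps = torch.randn(5, 5)  # fixed standard-normal noise

z = mean + eps * std  # the same form rsample() uses internally
z.backward(torch.ones_like(z))
assert torch.allclose(mean.grad, torch.ones_like(mean))
assert torch.allclose(std.grad, eps)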