Example #1
    def conditional(self, data, condition, level):
        # Per-point density under the base distribution.
        base = self.base.log_prob(data).exp()
        # Gaussian kernel centered on the conditioning value.
        delta = Normal(condition, self.sigma)
        delta = delta.log_prob(data)
        # Sum log-probs over all non-batch dimensions, then exponentiate.
        delta = delta.view(delta.size(0), -1).sum(dim=1, keepdim=True).exp()
        # Mixture of the two densities, weighted by `level`.
        prob = (1 - level) * delta + level * base
        return prob.log()
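A minimal standalone sketch of the same computation, useful for checking shapes; the function name, the standard-normal base, and the default sigma are assumptions, not part of the source:

import torch
from torch.distributions import Normal

def conditional_log_prob(data, condition, level, sigma=0.1):
    # Standalone rewrite of the method above; `sigma` and the
    # standard-normal base stand in for self.sigma / self.base.
    base = Normal(0.0, 1.0).log_prob(data).exp()
    delta = Normal(condition, sigma).log_prob(data)
    delta = delta.view(delta.size(0), -1).sum(dim=1, keepdim=True).exp()
    # Interpolate between the conditioning kernel and the base density.
    return ((1 - level) * delta + level * base).log()

data = torch.randn(4, 1)
condition = torch.zeros(4, 1)
print(conditional_log_prob(data, condition, level=0.5).shape)  # torch.Size([4, 1])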
Example #2
    def forward(
        self,
        source: torch.Tensor,  # (b, max_sou_seq_len)
        source_mask: torch.Tensor,  # (b, max_sou_seq_len)
        target: torch.Tensor,  # (b, max_tar_seq_len)
        target_mask: torch.Tensor,  # (b, max_tar_seq_len)
        label: torch.Tensor,  # (b, max_tar_seq_len)
        annealing: float
    ) -> Tuple[torch.Tensor, Tuple]:  # scalar loss and a detail tuple
        b = source.size(0)
        source_embedded = self.source_embed(
            source, source_mask)  # (b, max_sou_seq_len, d_s_emb)
        e_out, (hidden, _) = self.encoder(source_embedded, source_mask)

        h = self.transform(hidden, True)  # (n_e_lay * b, d_e_hid * n_dir)
        z_mu = self.z_mu(h)  # (n_e_lay * b, d_e_hid * n_dir)
        z_ln_var = self.z_ln_var(h)  # (n_e_lay * b, d_e_hid * n_dir)
        hidden = Gaussian(z_mu, z_ln_var).rsample()  # reparameterization trick
        # (n_e_lay * b, d_e_hid * n_dir) -> (b, d_e_hid * n_dir), initialize cell state
        states = (self.transform(hidden, False),
                  self.transform(hidden.new_zeros(hidden.size()), False))

        max_tar_seq_len = target.size(1)
        output = source_embedded.new_zeros(
            (b, max_tar_seq_len, self.target_vocab_size))
        target_embedded = self.target_embed(
            target, target_mask)  # (b, max_tar_seq_len, d_t_emb)
        target_embedded = target_embedded.transpose(
            1, 0)  # (max_tar_seq_len, b, d_t_emb)
        total_context_loss = 0
        # decode per word
        for i in range(max_tar_seq_len):
            d_out, states = self.decoder(target_embedded[i], target_mask[:, i],
                                         states)
            if self.attention:
                context, cs = self.calculate_context_vector(
                    e_out, states[0], source_mask, True)  # (b, d_d_hid)
                total_context_loss += self.calculate_context_loss(cs)
                d_out = torch.cat((d_out, context), dim=-1)  # (b, d_d_hid * 2)
            output[:, i, :] = self.w(self.maxout(
                d_out))  # (b, d_d_hid) -> (b, d_out) -> (b, tar_vocab_size)
        loss, details = self.calculate_loss(output, target_mask, label, z_mu,
                                            z_ln_var, total_context_loss,
                                            annealing)
        if torch.isnan(loss).any():
            raise ValueError('nan detected')
        return loss, details
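The `Gaussian` helper used for the reparameterization trick is not shown in this example. A minimal sketch of such a wrapper, assuming (from the names `z_mu` and `z_ln_var`) that it is parameterized by mean and log-variance:

import torch
from torch.distributions import Normal

class Gaussian:
    # Hypothetical stand-in for the helper above; the log-variance
    # parameterization is an assumption based on the variable names.
    def __init__(self, mu, ln_var):
        self.dist = Normal(mu, (0.5 * ln_var).exp())  # std = exp(ln_var / 2)

    def rsample(self):
        # Differentiable draw mu + std * eps, so gradients reach mu / ln_var.
        return self.dist.rsample()

    def sample(self):
        return self.dist.sample()  # non-differentiable draw (used in predict)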
Example #3
    def predict(
        self,
        source: torch.Tensor,  # (b, max_sou_seq_len)
        source_mask: torch.Tensor,  # (b, max_sou_seq_len)
        sampling: bool = True
    ) -> torch.Tensor:  # (b, max_seq_len, 1)
        self.eval()
        with torch.no_grad():
            b = source.size(0)
            source_embedded = self.source_embed(
                source, source_mask)  # (b, max_sou_seq_len, d_s_emb)
            e_out, (hidden, _) = self.encoder(source_embedded, source_mask)

            h = self.transform(hidden, True)
            z_mu = self.z_mu(h)
            z_ln_var = self.z_ln_var(h)
            hidden = Gaussian(z_mu, z_ln_var).sample() if sampling else z_mu
            states = (self.transform(hidden, False),
                      self.transform(hidden.new_zeros(hidden.size()), False))

            target_id = torch.full((b, 1), BOS,
                                   dtype=source.dtype).to(source.device)
            target_mask = torch.full(
                (b, 1), 1, dtype=source_mask.dtype).to(source_mask.device)
            predictions = source_embedded.new_zeros(b, self.max_seq_len, 1)
            for i in range(self.max_seq_len):
                target_embedded = self.target_embed(
                    target_id, target_mask).squeeze(1)  # (b, d_t_emb)
                d_out, states = self.decoder(target_embedded,
                                             target_mask[:, 0], states)
                if self.attention:
                    context, _ = self.calculate_context_vector(
                        e_out, states[0], source_mask, False)
                    d_out = torch.cat((d_out, context), dim=-1)

                output = self.w(self.maxout(d_out))  # (b, tar_vocab_size)
                output[:, UNK] -= 1e6  # mask <UNK>
                if i == 0:
                    output[:, EOS] -= 1e6  # avoid 0 length output
                prediction = torch.argmax(F.softmax(output, dim=1),
                                          dim=1).unsqueeze(1)  # (b, 1), greedy
                target_mask = target_mask * prediction.ne(EOS).type(
                    target_mask.dtype)
                target_id = prediction
                predictions[:, i, :] = prediction
        return predictions
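Decoding stops implicitly through the mask update above: once a sequence predicts EOS, its mask entry stays zero for every later step. A small demonstration of that idiom (the EOS id is an arbitrary assumption):

import torch

EOS = 2  # assumed token id
mask = torch.ones(3, 1, dtype=torch.long)
for prediction in (torch.tensor([[5], [EOS], [7]]),
                   torch.tensor([[EOS], [9], [8]])):
    # A row that has ever emitted EOS keeps mask 0 from then on.
    mask = mask * prediction.ne(EOS).type(mask.dtype)
    print(mask.squeeze(1).tolist())
# step 1: [1, 0, 1]
# step 2: [0, 0, 1]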
Example #4
    def test_normal(self):
        mean = Variable(torch.randn(5, 5), requires_grad=True)
        std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        mean_1d = Variable(torch.randn(1), requires_grad=True)
        std_1d = Variable(torch.randn(1), requires_grad=True)
        mean_delta = torch.Tensor([1.0, 0.0])
        std_delta = torch.Tensor([1e-5, 1e-5])
        self.assertEqual(Normal(mean, std).sample().size(), (5, 5))
        self.assertEqual(Normal(mean, std).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(Normal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Normal(mean_1d, std_1d).sample().size(), (1, ))
        self.assertEqual(Normal(0.2, .6).sample_n(1).size(), (1, ))
        self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1, ))

        # sample check for extreme value of mean, std
        self._set_rng_seed(1)
        self.assertEqual(Normal(mean_delta,
                                std_delta).sample(sample_shape=(1, 2)),
                         torch.Tensor([[[1.0, 0.0], [1.0, 0.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(Normal, (mean, std))
        self._gradcheck_log_prob(Normal, (mean, 1.0))
        self._gradcheck_log_prob(Normal, (0.0, std))

        state = torch.get_rng_state()
        eps = torch.normal(torch.zeros_like(mean), torch.ones_like(std))
        torch.set_rng_state(state)
        z = Normal(mean, std).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(mean.grad, torch.ones_like(mean))
        self.assertEqual(std.grad, eps)
        mean.grad.zero_()
        std.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = mean.data.view(-1)[idx]
            s = std.data.view(-1)[idx]
            expected = (math.exp(-(x - m)**2 / (2 * s**2)) /
                        math.sqrt(2 * math.pi * s**2))
            self.assertAlmostEqual(log_prob, math.log(expected), places=3)

        self._check_log_prob(Normal(mean, std), ref_log_prob)
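This test targets an old torch.distributions API: `sample_n(n)` has since been deprecated in favor of passing a sample shape, which should produce the same result:

import torch
from torch.distributions import Normal

d = Normal(torch.zeros(5, 5), torch.ones(5, 5))
old = d.sample_n(7)   # legacy spelling, emits a deprecation warning
new = d.sample((7,))  # equivalent modern call
assert old.shape == new.shape == torch.Size([7, 5, 5])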
Example #5
    def test_normal(self):
        mean = Variable(torch.randn(5, 5), requires_grad=True)
        std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
        mean_1d = Variable(torch.randn(1), requires_grad=True)
        std_1d = Variable(torch.randn(1), requires_grad=True)
        self.assertEqual(Normal(mean, std).sample().size(), (5, 5))
        self.assertEqual(Normal(mean, std).sample_n(7).size(), (7, 5, 5))
        self.assertEqual(Normal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
        self.assertEqual(Normal(mean_1d, std_1d).sample().size(), (1, ))
        self.assertEqual(Normal(0.2, .6).sample_n(1).size(), (1, 1))
        self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1, 1))

        self._gradcheck_log_prob(Normal, (mean, std))
        self._gradcheck_log_prob(Normal, (mean, 1.0))
        self._gradcheck_log_prob(Normal, (0.0, std))

        state = torch.get_rng_state()
        eps = torch.normal(torch.zeros_like(mean), torch.ones_like(std))
        torch.set_rng_state(state)
        z = Normal(mean, std).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(mean.grad, torch.ones_like(mean))
        self.assertEqual(std.grad, eps)
        mean.grad.zero_()
        std.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = mean.data.view(-1)[idx]
            s = std.data.view(-1)[idx]
            expected = (math.exp(-(x - m)**2 / (2 * s**2)) /
                        math.sqrt(2 * math.pi * s**2))
            self.assertAlmostEqual(log_prob, math.log(expected), places=3)

        self._check_log_prob(Normal(mean, std), ref_log_prob)

        def call_sample_wshape_gt_2():
            return Normal(mean, std).sample((1, 2))

        self.assertRaises(NotImplementedError, call_sample_wshape_gt_2)
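The final assertion also reflects an old version of the library, in which multi-dimensional sample shapes were not yet implemented; current torch.distributions accepts them directly:

import torch
from torch.distributions import Normal

d = Normal(torch.zeros(5, 5), torch.ones(5, 5))
# The sample shape (1, 2) is prepended to the distribution's batch shape.
print(d.sample((1, 2)).shape)  # torch.Size([1, 2, 5, 5])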
Example #6
    def __getitem__(self, index):
        data = self.data[index]

        if 'image' in data:
            image = data['image'].copy()
        elif 'original_image' in data:
            image = data['original_image'].copy()
        else:
            raise RuntimeError(data.keys())
        image = torch.tensor(image.transpose([2, 0, 1]) / 255.0,
                             dtype=torch.float32)

        boxes_xyxy = np.array(self.proposals[index]['boxes'], dtype=np.float32)

        # # visualization
        # import matplotlib.pyplot as plt
        # from space.utils.plt_utils import draw_boxes
        # draw_boxes(data['image'], boxes_xyxy)
        # plt.show()

        if len(boxes_xyxy) == 0:
            patches = torch.empty((0, 3) + self.size, dtype=torch.float32)
            boxes_xywh_tensor = torch.empty((0, 4), dtype=torch.float32)
            edge_index = torch.empty((2, 0), dtype=torch.int64)
            # print(f'Data point {index} has empty detection.')
        else:
            # Normalize boxes
            boxes_xyxy_tensor = torch.tensor(boxes_xyxy, dtype=torch.float32)
            boxes_xyxy_tensor[:, [0, 2]] /= image.shape[2]
            boxes_xyxy_tensor[:, [1, 3]] /= image.shape[1]
            boxes_xywh_tensor = boxes_xyxy2xywh(boxes_xyxy_tensor)

            if self.fixed_crop:
                boxes_xywh_tensor[:, 2] = self.size[1] / image.shape[2]
                boxes_xywh_tensor[:, 3] = self.size[0] / image.shape[1]

            # add augmentation
            if self.std:
                std_tensor = boxes_xywh_tensor.new_tensor(self.std)
                boxes_xywh_tensor = Normal(boxes_xywh_tensor,
                                           std_tensor).sample()
            patches = image_to_glimpse(image.unsqueeze(0),
                                       boxes_xywh_tensor.unsqueeze(0),
                                       self.size)

            n = boxes_xywh_tensor.size(0)
            edge_index = torch.tensor([[i, j] for i in range(n)
                                       for j in range(n)],
                                      dtype=torch.int64).transpose(0, 1)

        # get target
        if 'action' in data:
            action = data['action']  # scalar
        else:
            action = data['q'].argmax()

        out = Data(
            x=patches,
            action=torch.tensor([action], dtype=torch.int64),
            edge_index=edge_index.long(),
            pos=boxes_xywh_tensor.float(),
            idx=torch.tensor([index],
                             dtype=torch.int64),  # for visualization and dp
            size=torch.tensor([1], dtype=torch.int64),  # indicate batch size
        )
        return out
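The `self.std` branch above jitters each normalized box with Gaussian noise as data augmentation. A minimal standalone sketch of that step; the box values and per-coordinate stds are placeholder assumptions:

import torch
from torch.distributions import Normal

boxes_xywh = torch.tensor([[0.50, 0.50, 0.20, 0.30],
                           [0.25, 0.75, 0.10, 0.10]])  # normalized (x, y, w, h)
std = boxes_xywh.new_tensor([0.01, 0.01, 0.005, 0.005])  # placeholder stds
# One noisy draw per box; the (4,) std broadcasts over both boxes.
jittered = Normal(boxes_xywh, std).sample()
print(jittered.shape)  # torch.Size([2, 4])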