Example #1
    from collections import defaultdict

    import torch
    from torch.distributions import Categorical


    def make_environment(images, labels, e):
        '''
            Build one Colored-MNIST-style environment; adapted from
            https://github.com/facebookresearch/InvariantRiskMinimization
        '''
        # unlike the IRM repo, the labels are used as-is here (no binarization
        # inside this function); e is the probability of a label-color mismatch
        images = images.reshape((-1, 28, 28))

        # change label with prob 0.25
        prob_label = torch.ones((10, 10)).float() * (0.25 / 9)
        for i in range(10):
            prob_label[i, i] = 0.75

        labels_prob = torch.index_select(prob_label, dim=0, index=labels)
        labels = Categorical(probs=labels_prob).sample()

        # assign the color variable
        prob_color = torch.ones((10, 10)).float() * (e / 9.0)
        for i in range(10):
            prob_color[i, i] = 1 - e

        color_prob = torch.index_select(prob_color, dim=0, index=labels)
        color = Categorical(probs=color_prob).sample()

        # apply the color by writing each image into its color channel and
        # leaving the other channels zero
        output_images = torch.zeros((len(images), 10, 28, 28))

        idx_dict = defaultdict(list)
        for i in range(len(images)):
            idx_dict[int(labels[i])].append(i)
            output_images[i, color[i], :, :] = images[i]

        cor = color.float()  # per-sample color index, kept as a float tensor

        idx_list = list(range(len(images)))

        return {
            'images': (output_images.float() / 255.),
            'labels': labels.long(),
            'idx_dict': idx_dict,
            'idx_list': idx_list,
            'cor': cor,
        }
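A minimal usage sketch (not part of the original snippet): it assumes torchvision's MNIST dataset supplies the raw `images` and `labels` tensors and that `make_environment` is importable; the dataset path and the `e` values below are placeholders chosen for illustration.

    # hedged usage sketch -- torchvision, the dataset path, and the e values
    # are assumptions for illustration, not part of the original code
    from torchvision import datasets

    mnist = datasets.MNIST(root='./data', train=True, download=True)
    images, labels = mnist.data, mnist.targets  # [60000, 28, 28] uint8, [60000] int64

    # two environments with different label-color mismatch probabilities
    env_a = make_environment(images[::2], labels[::2], e=0.1)
    env_b = make_environment(images[1::2], labels[1::2], e=0.2)

    print(env_a['images'].shape)  # torch.Size([30000, 10, 28, 28])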
Example #2
    def forward(self, decoder_output, mode='train'):
        # nn.Module forward pass; assumes math, torch, torch.nn.functional as F,
        # and torch.distributions.Categorical are imported at module level
        ret = self.fc(decoder_output)
        if self.hard or mode != 'train':
            # apply temperature, do softmax
            command = self.identity(
                ret[..., :self.command_len]) / self.mix_temperature
            command_max = torch.max(command, dim=-1, keepdim=True)[0]
            command = torch.exp(command - command_max)
            command = command / torch.sum(command, dim=-1, keepdim=True)

            # sample a command index from these probabilities (same role as
            # get_pi_idx); the result is a hard choice, not a soft distribution
            # [seq_len, batch, command_len]
            command = Categorical(probs=command).sample()
            # [seq_len, batch]
            command = F.one_hot(command, self.command_len).to(
                decoder_output.device).float()
            # print(command.size())

            arguments = ret[..., self.command_len:]
            # arguments are [seq_len, batch, 6*3*num_mix];
            # reshape to [seq_len*batch*6, 3*num_mix]
            arguments = arguments.reshape([-1, 3 * self.num_mix])
            mdn_coef = self.get_mdn_coef(arguments)
            out_logmix = mdn_coef['logmix']
            out_mean = mdn_coef['mean']
            out_logstd = mdn_coef['logstd']
            # each of these is [seq_len*batch*6, num_mix]

            # apply temp to logmix
            out_logmix = self.identity(out_logmix) / self.mix_temperature
            out_logmix_max = torch.max(out_logmix, dim=-1, keepdim=True)[0]
            out_logmix = torch.exp(out_logmix - out_logmix_max)
            out_logmix = out_logmix / torch.sum(
                out_logmix, dim=-1, keepdim=True)
            # get_pi_idx
            out_logmix = Categorical(probs=out_logmix).sample()
            # [seq_len*batch*arg_len]
            out_logmix = out_logmix.long().unsqueeze(1)
            out_logmix = torch.cat([
                torch.arange(out_logmix.size(0),
                             device=decoder_output.device).unsqueeze(1),
                out_logmix
            ], dim=-1)
            # [seq_len*batch*arg_len, 2] pairs of (row, sampled mixture index)
            # select the mean / logstd of the sampled component for each row
            chosen_mean = out_mean[out_logmix[:, 0], out_logmix[:, 1]]
            chosen_logstd = out_logstd[out_logmix[:, 0], out_logmix[:, 1]]

            rand_gaussian = (
                torch.randn(chosen_mean.size(), device=decoder_output.device) *
                math.sqrt(self.gauss_temperature))
            arguments = chosen_mean + torch.exp(chosen_logstd) * rand_gaussian

            batch_size = command.size(1)
            arguments = arguments.reshape(
                -1, batch_size, self.arg_len)  # [seq_len, batch, arg_len]

            # concat with the command we picked
            ret = torch.cat([command, arguments], dim=-1)

        return ret
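The method relies on a `get_mdn_coef` helper that is not shown. Judging from how its output is consumed above, it splits the `[-1, 3 * num_mix]` argument tensor into per-mixture log-weights, means, and log-standard-deviations. Below is a hedged sketch under that assumption; the chunk order and the log-softmax normalization are guesses, not the original implementation.

    def get_mdn_coef(self, arguments):
        # hypothetical sketch of the helper used in forward(); the real code may
        # differ -- it only needs to return 'logmix', 'mean', and 'logstd'
        # chunks of shape [seq_len*batch*6, num_mix]
        logmix, mean, logstd = torch.split(arguments, self.num_mix, dim=-1)
        # normalizing the log-weights is optional here, since forward()
        # re-normalizes them after temperature scaling anyway
        logmix = logmix - torch.logsumexp(logmix, dim=-1, keepdim=True)
        return {'logmix': logmix, 'mean': mean, 'logstd': logstd}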