def make_environment(images, labels, e):
    """Create a colored-digit environment in the IRM style.

    Reference: https://github.com/facebookresearch/InvariantRiskMinimization
    (unlike the IRM repo, the labels passed in here are already binarized).

    Args:
        images: tensor reshapable to (N, 28, 28); pixel values in [0, 255].
        labels: long tensor of N class indices in [0, 10).
        e: probability that the sampled color disagrees with the noisy label.

    Returns:
        dict with keys 'images' (normalized (N, 10, 28, 28) colored digits),
        'labels' (noise-corrupted labels), 'idx_dict' (label -> sample
        indices), 'idx_list' (all indices), and 'cor' (sampled color, float).
    """
    images = images.reshape((-1, 28, 28))
    n = len(images)

    # Corrupt each label: keep it w.p. 0.75, otherwise move to one of the
    # other 9 classes uniformly at random.
    label_table = torch.full((10, 10), 0.25 / 9).fill_diagonal_(0.75)
    labels = Categorical(probs=label_table[labels]).sample()

    # Sample the color channel: matches the (noisy) label w.p. 1 - e.
    color_table = torch.full((10, 10), e / 9.0).fill_diagonal_(1 - e)
    color = Categorical(probs=color_table[labels]).sample()

    # Paint each digit into its sampled channel; the other 9 channels stay 0.
    colored = torch.zeros((n, 10, 28, 28))
    per_label_indices = defaultdict(list)
    for idx in range(n):
        per_label_indices[int(labels[idx])].append(idx)
        colored[idx, color[idx], :, :] = images[idx]

    return {
        'images': (colored.float() / 255.),
        'labels': labels.long(),
        'idx_dict': per_label_indices,
        'idx_list': list(range(n)),
        'cor': color.float(),
    }
def forward(self, decoder_output, mode='train'):
    """Project decoder features to command logits and MDN argument coefficients.

    In soft mode (training with ``self.hard`` unset) the raw ``self.fc``
    output is returned unchanged. Otherwise a hard one-hot command is
    sampled from the temperature-scaled command logits, and concrete
    argument values are sampled from the mixture density network.

    Args:
        decoder_output: decoder features; assumed [seq_len, batch, d_model]
            per the inline shape comments — TODO confirm against caller.
        mode: 'train' keeps soft outputs unless ``self.hard`` is set.

    Returns:
        Raw fc output in soft mode; otherwise a tensor of one-hot commands
        concatenated with sampled arguments, [seq_len, batch,
        command_len + arg_len].
    """
    ret = self.fc(decoder_output)
    if self.hard or mode != 'train':
        # Temperature-scaled, max-subtracted (numerically stable) softmax
        # over command logits, then sample a hard command per position.
        command = self.identity(
            ret[..., :self.command_len]) / self.mix_temperature
        command_max = torch.max(command, dim=-1, keepdim=True)[0]
        command = torch.exp(command - command_max)
        command = command / torch.sum(command, dim=-1, keepdim=True)
        # Same as get_pi_idx in the reference code: sampling already yields
        # a hard (non-soft) choice.  [seq_len, batch]
        command = Categorical(probs=command).sample()
        command = F.one_hot(command, self.command_len).to(
            decoder_output.device).float()

        # args are [seq_len, batch, 6*3*num_mix]; flatten to
        # [seq_len*batch*6, 3*num_mix] for the MDN head.
        arguments = ret[..., self.command_len:]
        arguments = arguments.reshape([-1, 3 * self.num_mix])
        mdn_coef = self.get_mdn_coef(arguments)
        # Each of these is [seq_len*batch*arg_len, num_mix].
        out_logmix = mdn_coef['logmix']
        out_mean = mdn_coef['mean']
        out_logstd = mdn_coef['logstd']

        # Temperature-scaled stable softmax over mixture weights, then
        # pick one mixture component per argument slot (get_pi_idx).
        out_logmix = self.identity(out_logmix) / self.mix_temperature
        out_logmix_max = torch.max(out_logmix, dim=-1, keepdim=True)[0]
        out_logmix = torch.exp(out_logmix - out_logmix_max)
        out_logmix = out_logmix / torch.sum(
            out_logmix, dim=-1, keepdim=True)
        mix_idx = Categorical(probs=out_logmix).sample()
        # [seq_len*batch*arg_len] -> [seq_len*batch*arg_len, 1]
        mix_idx = mix_idx.long().unsqueeze(1)

        # Vectorized gather of the chosen component's mean/logstd.
        # (Replaces a per-element Python list comprehension plus
        # torch.tensor(list_of_0-d_tensors), which looped in Python over
        # seq_len*batch*arg_len elements and was extremely slow.)
        chosen_mean = torch.gather(out_mean, 1, mix_idx).squeeze(1)
        chosen_logstd = torch.gather(out_logstd, 1, mix_idx).squeeze(1)

        # Sample the argument value from the chosen Gaussian component,
        # with the std scaled by sqrt(gauss_temperature).
        rand_gaussian = (
            torch.randn(chosen_mean.size(), device=decoder_output.device) *
            math.sqrt(self.gauss_temperature))
        arguments = chosen_mean + torch.exp(chosen_logstd) * rand_gaussian

        batch_size = command.size(1)
        arguments = arguments.reshape(
            -1, batch_size, self.arg_len)  # [seq_len, batch, arg_len]
        # Concat the hard command with the sampled arguments.
        ret = torch.cat([command, arguments], dim=-1)
    return ret