def test_beta_shape_tensor_params(self):
    dist = Beta(torch.Tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
                torch.Tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
    self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
    self.assertEqual(dist._event_shape, torch.Size(()))
    self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
    self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
def test_beta_shape_scalar_params(self):
    dist = Beta(0.1, 0.1)
    self.assertEqual(dist._batch_shape, torch.Size())
    self.assertEqual(dist._event_shape, torch.Size())
    self.assertEqual(dist.sample().size(), torch.Size((1,)))
    self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
    self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
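# A minimal, self-contained sketch (not part of the test suite above) of the
# batch/sample shape semantics the two tests exercise; it assumes only torch.
import torch
from torch.distributions import Beta

d = Beta(torch.ones(3, 2), torch.ones(3, 2))        # batch_shape (3, 2), event_shape ()
assert d.batch_shape == torch.Size((3, 2))
assert d.sample().shape == torch.Size((3, 2))
assert d.sample((5,)).shape == torch.Size((5, 3, 2))  # the sample shape is prepended
assert d.log_prob(torch.full((3, 2), 0.5)).shape == torch.Size((3, 2))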
def test_e_log_stick():
    """This test DOES NOT PASS, and maybe should not."""
    model = InfiniteIBP(4., 10, 0.1, 0.5, 36)
    model.init_z(10)
    K = model.K

    # take a lot of samples to get something working
    dist = Beta(model.tau.detach()[:, 0], model.tau.detach()[:, 1])
    samples = dist.sample((100000,))
    f = (1. - samples.cumprod(1)).log().mean(0)

    log_stick, q = model._E_log_stick(model.tau, model.K)

    jeffrey_q = np.zeros((K, K))
    jeffrey_log_stick = np.zeros((K,))
    for k in range(K):
        a, b = compute_q_Elogstick(model.tau.detach().numpy().T, k)
        jeffrey_q[k, :k + 1] = a
        jeffrey_log_stick[k] = b

    print("old: {}".format(jeffrey_log_stick))
    print("new: {}".format(log_stick.detach().numpy()))
    print("samples: {}".format(f.detach().numpy()))

    import ipdb
    ipdb.set_trace()
def sample(self, datas):
    alpha, beta = datas
    distribution = Beta(alpha, beta)
    action = distribution.sample().float().to(set_device(self.use_gpu))
    return action
class MixUp(Callback):
    run_valid = False

    def __init__(self, alpha=0.4, onehot=False):
        self.alpha = alpha
        self.distrib = Beta(alpha, alpha)
        self.onehot = onehot

    def before_batch(self):
        bs = self.xb[0].shape[0]
        device = self.xb[0].device
        lambd = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
        lambd = torch.stack([lambd, 1 - lambd], 1).max(1)[0]
        shuffle = torch.randperm(bs).to(device)
        xb1, yb1 = self.xb[0][shuffle], self.yb[0][shuffle]
        a = tensor(lambd).float().view(-1, 1, 1, 1).to(device)
        self.learn.xb = tuple([a * self.xb[0] + (1 - a) * xb1])
        a = a.view(-1)
        if self.onehot:
            while len(a.shape) < len(yb1.shape):
                a = a[..., None]
            self.learn.yb = tuple([a * self.learn.yb[0] + (1 - a) * yb1])
        else:
            self.learn.yb = tuple([{'yb': self.learn.yb[0], 'yb1': yb1, 'a': a}])
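# A small standalone sketch (outside any callback framework) of the lambda sampling
# used by MixUp above: draw lam ~ Beta(alpha, alpha) per example and keep
# max(lam, 1 - lam) so the "primary" image always dominates the mix.
# The batch tensor and alpha value below are illustrative only.
import torch
from torch.distributions import Beta

alpha = 0.4
x = torch.randn(8, 3, 32, 32)                         # hypothetical image batch
lam = Beta(torch.tensor(alpha), torch.tensor(alpha)).sample((x.size(0),))
lam = torch.stack([lam, 1 - lam], dim=1).max(dim=1)[0]
shuffle = torch.randperm(x.size(0))
lam_x = lam.view(-1, 1, 1, 1)                         # broadcast over channels and pixels
x_mixed = lam_x * x + (1 - lam_x) * x[shuffle]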
def chooseActionTrain(self, state):
    """
    Choose an action during training mode.

    Parameters
    ----------
    state : np.ndarray
        The current state of the car.

    Returns
    -------
    action : np.ndarray
        The actions to run on the track.
    coefficient : float
        The logarithmic probability for an action.

    Notes
    -----
    This function is only called when the --train flag IS provided.
    """
    state = torch.from_numpy(state).double().to(self.hardwareDevice).unsqueeze(0)
    with torch.no_grad():
        alpha, beta = self.nn(state)[0]
    dist = Beta(alpha, beta)
    action = dist.sample()
    coefficient = dist.log_prob(action).sum(dim=1)
    action = action.squeeze().cpu().numpy()
    coefficient = coefficient.item()
    return action, coefficient
def test_beta_likelihood(concentration1: float, concentration0: float) -> None:
    """
    Test to check that maximizing the likelihood recovers the parameters.
    """
    # generate samples
    concentration1s = torch.zeros((NUM_SAMPLES,)) + concentration1
    concentration0s = torch.zeros((NUM_SAMPLES,)) + concentration0

    distr = Beta(concentration1s, concentration0s)
    samples = distr.sample()

    init_biases = [
        inv_softplus(concentration1 - START_TOL_MULTIPLE * TOL * concentration1),
        inv_softplus(concentration0 - START_TOL_MULTIPLE * TOL * concentration0),
    ]

    concentration1_hat, concentration0_hat = maximum_likelihood_estimate_sgd(
        BetaOutput(),
        samples,
        init_biases=init_biases,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    assert (
        np.abs(concentration1_hat - concentration1) < TOL * concentration1
    ), f"concentration1 did not match: concentration1 = {concentration1}, concentration1_hat = {concentration1_hat}"
    assert (
        np.abs(concentration0_hat - concentration0) < TOL * concentration0
    ), f"concentration0 did not match: concentration0 = {concentration0}, concentration0_hat = {concentration0_hat}"
class CutMix(Callback):
    """CutMix callback which replaces a random patch of image data with the corresponding
    patch from another image. This callback also converts labels to one hot before combining
    them according to the lambda parameters, sampled from a beta distribution as is done in
    the paper.

    Example: ::

        >>> from torchbearer import Trial
        >>> from torchbearer.callbacks import CutMix

        # Example Trial which does CutMix regularisation
        >>> cutmix = CutMix(1, classes=10)
        >>> trial = Trial(None, callbacks=[cutmix], metrics=['acc'])

    Args:
        alpha (float): The alpha value for the beta distribution.
        classes (int): The number of classes for conversion to one hot.

    State Requirements:
        - :attr:`torchbearer.state.X`: State should have the current data stored
        - :attr:`torchbearer.state.Y_TRUE`: State should have the current target stored
    """

    def __init__(self, alpha, classes=-1):
        super(CutMix, self).__init__()
        self.classes = classes
        self.dist = Beta(torch.tensor([float(alpha)]), torch.tensor([float(alpha)]))

    def _to_one_hot(self, target):
        if target.dim() == 1:
            target = target.unsqueeze(1)
            one_hot = torch.zeros_like(target).repeat(1, self.classes)
            one_hot.scatter_(1, target, 1)
            return one_hot
        return target

    def on_sample(self, state):
        super(CutMix, self).on_sample(state)
        lam = self.dist.sample().to(state[torchbearer.DEVICE])
        length = (1 - lam).sqrt()
        cutter = BatchCutout(1,
                             (length * state[torchbearer.X].size(-1)).round().item(),
                             (length * state[torchbearer.X].size(-2)).round().item())
        mask = cutter(state[torchbearer.X])
        erase_locations = mask == 0
        permutation = torch.randperm(state[torchbearer.X].size(0))
        state[torchbearer.X][erase_locations] = state[torchbearer.X][permutation][erase_locations]
        target = self._to_one_hot(state[torchbearer.TARGET]).float()
        state[torchbearer.TARGET] = lam * target + (1 - lam) * target[permutation]

    def on_sample_validation(self, state):
        super(CutMix, self).on_sample_validation(state)
        state[torchbearer.TARGET] = self._to_one_hot(state[torchbearer.TARGET]).float()
def kl_bernoulli(pi, step, args):
    cap = min(args.h_cap, step * args.h_cap / args.total_steps)
    beta_dist = Beta(torch.ones_like(pi) * args.alpha_0, torch.ones_like(pi))
    pi_prior = Bernoulli(torch.cumprod(beta_dist.sample(), dim=-1))
    pi_posterior = Bernoulli(pi)
    klh_loss = kl_divergence(pi_posterior, pi_prior).sum(dim=1).mean()
    cap_klh_loss = args.gamma_h * (klh_loss - cap).abs()
    return cap_klh_loss
def test_beta_log_prob(self):
    for _ in range(100):
        alpha = np.exp(np.random.normal())
        beta = np.exp(np.random.normal())
        dist = Beta(alpha, beta)
        x = dist.sample()
        actual_log_prob = dist.log_prob(x).sum()
        expected_log_prob = scipy.stats.beta.logpdf(x, alpha, beta)[0]
        self.assertAlmostEqual(actual_log_prob, expected_log_prob, places=3, allow_inf=True)
def sample(self, device, epoch, num=64):
    sample = torch.randn(num, self.latent_dim).to(device)
    x_alpha, x_beta = self.decode(sample)
    beta = Beta(x_alpha, x_beta)
    p = beta.sample()
    binomial = Binomial(255, p)
    x_sample = binomial.sample()
    x_sample = x_sample.float() / 255.
    save_image(x_sample.view(num, 1, 28, 28),
               'results/epoch_{}_samples.png'.format(epoch))
def select_action(self, state):
    state = torch.from_numpy(state).double().to(device).unsqueeze(0)
    with torch.no_grad():
        alpha, beta = self.net(state)[0]
    dist = Beta(alpha, beta)
    action = dist.sample()  # 3 values in [0, 1]
    a_logp = dist.log_prob(action).sum(dim=1)  # for PPO
    action = action.squeeze().cpu().numpy()
    a_logp = a_logp.item()
    return action, a_logp
def select_action(self, state):
    state = torch.from_numpy(state).double().to(device).unsqueeze(0)
    with torch.no_grad():
        (alpha, beta), _, rcrc_s = self.net(state)
    dist = Beta(alpha, beta)
    action = dist.sample()
    a_logp = dist.log_prob(action).sum(dim=1)
    action = action.squeeze().cpu().numpy()
    a_logp = a_logp.item()
    return action, a_logp, rcrc_s
def select_action(self, state):
    # deal with the datatype of the state and transform it
    state = torch.from_numpy(state).double().unsqueeze(0)
    with torch.no_grad():
        alpha, beta = self.net(state)[0]
    dist = Beta(alpha, beta)
    action = dist.sample()  # sampled action in the interval (0, 1)
    a_logp = dist.log_prob(action).sum(dim=1)  # add the log probability densities of the 3-stack
    action = action.squeeze().numpy()
    a_logp = a_logp.item()
    return action, a_logp
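# The three select_action variants above share the same core pattern; the following
# minimal standalone sketch isolates it, with the actor-network outputs replaced by
# fixed tensors purely for illustration.
import torch
from torch.distributions import Beta

alpha = torch.tensor([[1.5, 2.0, 3.0]])    # stand-ins for the network's alpha head
beta = torch.tensor([[2.5, 1.0, 4.0]])     # stand-ins for the network's beta head
dist = Beta(alpha, beta)
action = dist.sample()                      # each component lies in (0, 1)
a_logp = dist.log_prob(action).sum(dim=1)   # joint log-density of the 3 components
print(action.squeeze().numpy(), a_logp.item())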
def mixup(x, y, num_classes, gamma=0.2, smooth_eps=0.1):
    if gamma == 0 and smooth_eps == 0:
        return x, y
    m = Beta(torch.tensor([gamma]), torch.tensor([gamma]))
    lambdas = m.sample([x.size(0), 1, 1]).to(x)
    my = onehot(y, num_classes).to(x)
    true_class = 1. - smooth_eps * num_classes / (num_classes - 1)
    false_class = smooth_eps / (num_classes - 1)
    my = my * true_class + torch.ones_like(my) * false_class
    perm = torch.randperm(x.size(0))
    x2 = x[perm]
    y2 = my[perm]
    return x * (1 - lambdas) + x2 * lambdas, my * (1 - lambdas) + y2 * lambdas
def reconstruct(self, x, device, epoch):
    x = x.view(-1, 784).float().to(device)
    z_mu, z_logvar = self.encode(x)
    z = self.reparameterize(z_mu, z_logvar)  # sample zs
    x_alpha, x_beta = self.decode(z)
    beta = Beta(x_alpha, x_beta)
    p = beta.sample()
    binomial = Binomial(255, p)
    x_recon = binomial.sample()
    x_recon = x_recon.float() / 255.
    x_with_recon = torch.cat((x, x_recon))
    save_image(x_with_recon.view(64, 1, 28, 28),
               'results/epoch_{}_recon.png'.format(epoch))
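# A minimal sketch of the Beta -> Binomial pixel model that sample() and reconstruct()
# above both use, detached from the VAE; the decoder outputs are replaced by constant
# tensors and the shapes are illustrative only.
import torch
from torch.distributions import Beta, Binomial

x_alpha = torch.full((4, 784), 2.0)          # hypothetical decoder alpha outputs
x_beta = torch.full((4, 784), 5.0)           # hypothetical decoder beta outputs
p = Beta(x_alpha, x_beta).sample()           # per-pixel success probability in (0, 1)
pixels = Binomial(255, p).sample() / 255.    # discrete intensities rescaled to [0, 1]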
def forward(self, x):
    with torch.no_grad():
        features = self.main(x)
        actor_features = self.actor(features)
        alpha = self.alpha_head(actor_features) + 1
        beta = self.beta_head(actor_features) + 1
        dist = Beta(alpha, beta)
        if not self.deterministic_sample:
            action = dist.sample().squeeze().numpy()
        else:
            action = dist.mean.squeeze().numpy()
    # rescale the first action component from (0, 1) to (-1, 1)
    action[0] = action[0] * 2 - 1
    return action
def kl_categorical_dp(eta, step, args):
    cap = min(args.m_cap, step * args.m_cap / args.total_steps)
    # cap = min(cap, np.log(args.disc))
    beta_dist = Beta(torch.ones_like(eta), torch.ones_like(eta) * args.beta_0)
    beta_sample = beta_dist.sample()
    neg_prod = torch.cumprod(1.0 - beta_sample, dim=-1)
    beta_sample[:, 1:] = beta_sample[:, :-1] * neg_prod[:, :-1]
    beta_sample = F.softmax(beta_sample, dim=-1)
    cat_prior = Categorical(probs=beta_sample)
    cat_posterior = Categorical(probs=eta)
    klm_loss = kl_divergence(cat_posterior, cat_prior).mean()
    cap_klm_loss = args.gamma_m * (klm_loss - cap).abs()
    return cap_klm_loss
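# A standalone sketch of the textbook truncated stick-breaking construction that the
# two KL helpers above approximate: v_k ~ Beta(1, beta_0) and the k-th weight is
# v_k * prod_{j<k} (1 - v_j). K and beta_0 here are illustrative values only.
import torch
from torch.distributions import Beta

K, beta_0 = 10, 5.0
v = Beta(torch.ones(K), torch.full((K,), beta_0)).sample()
one_minus = torch.cumprod(1.0 - v, dim=-1)
weights = v.clone()
weights[1:] = v[1:] * one_minus[:-1]
# weights sum to 1 - prod(1 - v); the remaining mass belongs to the truncated tail
print(weights, weights.sum())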
def select_action(self, state, hidden):
    with torch.no_grad():
        _, latent_mu, _ = self.vae(state)
        alpha, beta = self.net(latent_mu, hidden[0])[0]
        dist = Beta(alpha, beta)
        action = dist.sample()
        a_logp = dist.log_prob(action).sum(dim=1)
        a_logp = a_logp.item()
        _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden)
    return action.squeeze().cpu().numpy(), a_logp, latent_mu, next_hidden
def sample_from_beta_dist(y_hat):
    """
    y_hat (batch_size x seq_len x 2)
    """
    # take the exponential to ensure positive parameters
    loc_y = y_hat.exp()
    alpha = loc_y[:, :, 0].unsqueeze(-1)
    beta = loc_y[:, :, 1].unsqueeze(-1)
    dist = Beta(alpha, beta)
    sample = dist.sample()
    # rescale the sample from [0, 1] to [-1, 1]
    sample = 2.0 * sample - 1.0
    return sample
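# A hypothetical usage of sample_from_beta_dist above: the model is assumed to emit two
# unconstrained values per timestep, which the function exponentiates into (alpha, beta).
# The tensor below is a random stand-in for real model output.
import torch

y_hat = torch.randn(2, 16, 2)          # (batch_size, seq_len, 2)
wave = sample_from_beta_dist(y_hat)    # values rescaled to [-1, 1]
print(wave.shape)                      # torch.Size([2, 16, 1])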
class CarlaImgPolicy(nn.Module):
    def __init__(self, input_dim, action_dim, hidden_layer=[400, 300]):
        super(CarlaImgPolicy, self).__init__()
        self.main_actor = CarlaSimpleEncoder(latent_size=input_dim - 1)
        self.main_critic = CarlaSimpleEncoder(latent_size=input_dim - 1)

        actor_layer_size = [input_dim] + hidden_layer
        actor_feature_layers = nn.ModuleList([])
        for i in range(len(actor_layer_size) - 1):
            actor_feature_layers.append(
                nn.Linear(actor_layer_size[i], actor_layer_size[i + 1]))
            actor_feature_layers.append(nn.ReLU())
        self.actor = nn.Sequential(*actor_feature_layers)
        self.alpha_head = nn.Sequential(
            nn.Linear(hidden_layer[-1], action_dim), nn.Softplus())
        self.beta_head = nn.Sequential(
            nn.Linear(hidden_layer[-1], action_dim), nn.Softplus())

        critic_layer_size = [input_dim] + hidden_layer
        critic_layers = nn.ModuleList([])
        for i in range(len(critic_layer_size) - 1):
            critic_layers.append(
                nn.Linear(critic_layer_size[i], critic_layer_size[i + 1]))
            critic_layers.append(nn.ReLU())
        critic_layers.append(layer_init(nn.Linear(hidden_layer[-1], 1), gain=1))
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, x, action=None):
        speed = x[:, -1:]
        x = x[:, :-1].view(-1, 3, 128, 128)  # image size in the Carla driving task is 128x128
        x1 = self.main_actor(x)
        x1 = torch.cat([x1, speed], dim=1)
        x2 = self.main_critic(x)
        x2 = torch.cat([x2, speed], dim=1)

        actor_features = self.actor(x1)
        alpha = self.alpha_head(actor_features) + 1
        beta = self.beta_head(actor_features) + 1
        self.dist = Beta(alpha, beta)
        if action is None:
            action = self.dist.sample()
        else:
            action = (action + 1) / 2
        action_log_prob = self.dist.log_prob(action).sum(-1)
        entropy = self.dist.entropy().sum(-1)
        value = self.critic(x2)
        return action * 2 - 1, action_log_prob, value.squeeze(-1), entropy
def sq_log_posterior_predictive_eval(x_new, kappa, tau_0, tau_1, S):
    T = kappa.shape[0] + 1
    q_beta = Beta(torch.ones(T - 1), kappa)
    q_lambda = Gamma(tau_0, tau_1)
    beta_mc = q_beta.sample([S])
    lambda_mc = q_lambda.sample([S])
    log_prob = 0
    for s in range(S):
        post_pred_weights = mix_weights(beta_mc[s])
        post_pred_clusters = lambda_mc[s]
        for t in range(post_pred_clusters.shape[0]):
            log_prob -= post_pred_weights[t] * torch.exp(
                Poisson(post_pred_clusters[t]).log_prob(x_new)) ** 2
    log_prob /= S
    return log_prob
def experience(self, steps):
    total_obs = np.zeros((steps,) + self.last_ob.shape + (self.stack_size,))
    total_rewards = np.zeros((steps, 1))
    total_actions = np.zeros((steps, 3))
    total_values = np.zeros((steps + 1, 1))
    masks = np.zeros((steps, 1))

    for step in range(steps):
        total_obs[step] = np.roll(total_obs[step], shift=-1, axis=-1)
        total_obs[step, :, :, -1] = self.last_ob
        alpha, beta, values = self.network(
            torch.from_numpy(total_obs[step]).type(torch.FloatTensor).unsqueeze(0))
        total_values[step] = values.view(-1).detach().numpy()
        m = Beta(alpha, beta)
        actions = m.sample()
        total_actions[step] = actions.numpy()
        actions = actions.numpy() * np.array([2., 1., 1.]) - np.array([1., 0., 0.])
        actions = actions.reshape((-1))
        self.last_ob, rews, dones_, _ = self.env.step(actions)
        self.env.render()
        self.last_ob = rgb2gris(self.last_ob)
        dones = np.logical_not(dones_) * 1
        total_rewards[step] = rews
        masks[step] = dones
        if dones_:
            self.env.reset()

    temp_ob = np.roll(total_obs[step], shift=-1, axis=-1)
    temp_ob[..., -1] = self.last_ob
    _, _, values = self.network(
        torch.from_numpy(temp_ob).type(torch.FloatTensor).unsqueeze(0))
    total_values[steps] = values.view(-1).detach().numpy()

    advantage, real_values = gae(total_rewards, masks, total_values)
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-5)
    return (total_obs, total_values, total_rewards, total_actions, masks,
            advantage, real_values)
def posterior_predictive_sample(kappa, tau_0, tau_1, S, M):
    T = kappa.shape[0] + 1
    q_beta = Beta(torch.ones(T - 1), kappa)
    q_lambda = Gamma(tau_0, tau_1)
    beta_mc = q_beta.sample([S])
    lambda_mc = q_lambda.sample([S])
    hallucinated_samples = torch.zeros(S, M)
    for s in range(S):
        post_pred_weights = mix_weights(beta_mc[s])
        post_pred_clusters = lambda_mc[s]
        hallucinated_samples[s, :] = MixtureSameFamily(
            Categorical(post_pred_weights),
            Poisson(post_pred_clusters)).sample([M])
    return hallucinated_samples
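# A minimal, self-contained example of the MixtureSameFamily pattern used in
# posterior_predictive_sample above, with fixed mixture weights and Poisson rates
# standing in for the variational draws.
import torch
from torch.distributions import Categorical, MixtureSameFamily, Poisson

weights = torch.tensor([0.2, 0.5, 0.3])
rates = torch.tensor([1.0, 5.0, 20.0])
mixture = MixtureSameFamily(Categorical(probs=weights), Poisson(rates))
draws = mixture.sample([100])   # 100 draws from the 3-component Poisson mixture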
def select_action(self, state, deterministic=False):
    """
    Compute an action or vector of actions given a state or vector of states

    :param state: the input state(s)
    :param deterministic: whether the policy should be considered deterministic or not
    :return: the resulting action(s)
    """
    with torch.no_grad():
        alpha, beta = self.forward(state)
        if deterministic:
            return alpha.data.numpy() / (alpha.data.numpy() + beta.data.numpy()).astype(float)
            # return np.clip(alpha.data.numpy().astype(float), -2, 2)
        else:
            n = Beta(alpha, beta)
            action = n.sample()
    return action.data.numpy().astype(float)
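# Side note (an added check, not from the original snippet): the deterministic branch
# above returns alpha / (alpha + beta), which is exactly the Beta mean that other
# snippets obtain via dist.mean.
import torch
from torch.distributions import Beta

alpha, beta = torch.tensor([2.0, 4.0]), torch.tensor([3.0, 1.0])
assert torch.allclose(Beta(alpha, beta).mean, alpha / (alpha + beta))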
def select_action(self, state):
    if args.action_vec > 0:
        state = (torch.from_numpy(state[0]).float().to(device).unsqueeze(0),
                 torch.from_numpy(state[1]).float().to(device).unsqueeze(0))
    else:
        state = torch.from_numpy(state).float().to(device).unsqueeze(0)
    # TODO: change for vector actions
    with torch.no_grad():
        alpha, beta = self.net(state)[0]
    dist = Beta(alpha, beta)
    action = dist.sample()
    a_logp = dist.log_prob(action).sum(dim=1)
    action = action.squeeze().cpu().numpy()
    a_logp = a_logp.item()
    return action, a_logp
def traverse_grid(self, cont_dim=0, nrow=8, ncol=8, traverse=True,
                  use_prior=True, set_zero=True, file_name=None):
    if traverse and self.args.disc != 0:
        nrow = self.args.disc

    if set_zero:
        cont_samples = torch.zeros(nrow * ncol, self.args.cont).cuda()
    else:
        cont_samples = torch.randn(nrow * ncol, self.args.cont).cuda()

    fixed_value = torch.linspace(-2, 2, ncol).cuda()
    for row in range(nrow):
        for i in range(ncol):
            cont_samples[i + row * ncol, cont_dim] = fixed_value[i]

    if use_prior:
        v_prior = Beta(torch.ones_like(cont_samples) * self.args.alpha_0,
                       torch.ones_like(cont_samples))
        mask_prob = torch.cumprod(v_prior.sample(), dim=1)
        mask = Bernoulli(mask_prob).sample()
        cont_samples = cont_samples * mask

    if self.args.model_type != 'ibp':
        disc_samples = torch.zeros(nrow * ncol, self.args.disc).cuda()
        for i in range(nrow):
            for j in range(ncol):
                disc_samples[j + i * ncol, i] = 1.0
        samples = torch.cat([cont_samples, disc_samples], dim=-1)
    else:
        samples = cont_samples

    with torch.no_grad():
        x = self.model.decoder(samples).view(-1, self.args.img_channel,
                                             self.args.img_size, self.args.img_size)
    if self.save_img:
        save_image(x.data, file_name, nrow=ncol, padding=0, pad_value=0.0)
    else:
        return make_grid(x.data, nrow=ncol, padding=0, pad_value=0.0)
class BetaSeparatedPolicy(nn.Module):
    def __init__(self, input_dim, action_dim, hidden_layer=[64, 64]):
        super(BetaSeparatedPolicy, self).__init__()
        actor_layer_size = [input_dim] + hidden_layer
        alpha_feature_layers = nn.ModuleList([])
        beta_feature_layers = nn.ModuleList([])
        for i in range(len(actor_layer_size) - 1):
            alpha_feature_layers.append(
                nn.Linear(actor_layer_size[i], actor_layer_size[i + 1]))
            alpha_feature_layers.append(nn.ReLU())
            beta_feature_layers.append(
                nn.Linear(actor_layer_size[i], actor_layer_size[i + 1]))
            beta_feature_layers.append(nn.ReLU())
        self.alpha_body = nn.Sequential(*alpha_feature_layers)
        self.beta_body = nn.Sequential(*beta_feature_layers)
        self.alpha_head = nn.Sequential(
            nn.Linear(hidden_layer[-1], action_dim), nn.Softplus())
        self.beta_head = nn.Sequential(
            nn.Linear(hidden_layer[-1], action_dim), nn.Softplus())

        critic_layer_size = [input_dim] + hidden_layer
        critic_layers = nn.ModuleList([])
        for i in range(len(critic_layer_size) - 1):
            critic_layers.append(
                nn.Linear(critic_layer_size[i], critic_layer_size[i + 1]))
            critic_layers.append(nn.ReLU())
        critic_layers.append(nn.Linear(hidden_layer[-1], 1))
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, x, action=None):
        alpha = self.alpha_head(self.alpha_body(x)) + 1
        beta = self.beta_head(self.beta_body(x)) + 1
        self.dist = Beta(alpha, beta)
        if action is None:
            action = self.dist.sample()
        else:
            action = (action + 1) / 2
        action_log_prob = self.dist.log_prob(action).sum(-1)
        entropy = self.dist.entropy().sum(-1)
        value = self.critic(x)
        return action * 2 - 1, action_log_prob, value.squeeze(-1), entropy
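# Both CarlaImgPolicy and BetaSeparatedPolicy above use softplus(x) + 1 for the Beta
# concentrations, which keeps alpha, beta > 1 so the density is unimodal on (0, 1); the
# sampled action is then affinely mapped to (-1, 1). A minimal sketch of that head,
# with random tensors standing in for the network outputs:
import torch
import torch.nn.functional as F
from torch.distributions import Beta

alpha = F.softplus(torch.randn(4, 2)) + 1    # hypothetical alpha-head outputs, action_dim = 2
beta = F.softplus(torch.randn(4, 2)) + 1     # hypothetical beta-head outputs
dist = Beta(alpha, beta)
raw = dist.sample()                          # in (0, 1)
action = raw * 2 - 1                         # in (-1, 1), as returned by the policies
log_prob = dist.log_prob(raw).sum(-1)        # note: log-prob of the unscaled sample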
class MyDist(ActionDistribution):
    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 6

    def __init__(self, inputs, model):
        super(MyDist, self).__init__(inputs, model)
        self.dist = Beta(inputs[:, :3], inputs[:, 3:])

    def sample(self):
        self.sampled_action = self.dist.sample()
        return self.sampled_action

    def deterministic_sample(self):
        return self.dist.mean

    def sampled_action_logp(self):
        return self.logp(self.sampled_action)

    def logp(self, actions):
        return self.dist.log_prob(actions).sum(-1)

    # referred from https://github.com/pytorch/pytorch/blob/master/torch/distributions/kl.py
    def kl(self, other):
        p, q = self.dist, other.dist
        sum_params_p = p.concentration1 + p.concentration0
        sum_params_q = q.concentration1 + q.concentration0
        t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
        t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
        t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
        t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
        t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
        return (t1 - t2 + t3 + t4 + t5).sum(-1)

    def entropy(self):
        return self.dist.entropy().sum(-1)
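# A quick sanity check (not from the original snippet) that the hand-written Beta KL in
# MyDist.kl matches torch.distributions.kl_divergence for concrete parameters; the
# parameter values below are arbitrary.
import torch
from torch.distributions import Beta, kl_divergence

p = Beta(torch.tensor([2.0, 3.0]), torch.tensor([5.0, 1.5]))
q = Beta(torch.tensor([1.0, 4.0]), torch.tensor([1.0, 2.0]))
sum_p = p.concentration1 + p.concentration0
sum_q = q.concentration1 + q.concentration0
manual = (q.concentration1.lgamma() + q.concentration0.lgamma() + sum_p.lgamma()
          - p.concentration1.lgamma() - p.concentration0.lgamma() - sum_q.lgamma()
          + (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
          + (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
          + (sum_q - sum_p) * torch.digamma(sum_p))
assert torch.allclose(manual, kl_divergence(p, q))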
def traverse_line(self, cont_dim=0, disc_dim=None, size=10, use_prior=True,
                  set_zero=True, traverse=False, file_name=None):
    if set_zero:
        cont_samples = torch.zeros(size, self.args.cont).cuda()
    else:
        cont_samples = torch.randn(size, self.args.cont).cuda()

    fixed_value = torch.linspace(-2, 2, size).cuda()
    cont_samples[:, cont_dim] = fixed_value

    if use_prior:
        v_prior = Beta(torch.ones_like(cont_samples) * self.args.alpha_0,
                       torch.ones_like(cont_samples))
        mask_prob = torch.cumprod(v_prior.sample(), dim=1)
        mask = Bernoulli(mask_prob).sample()
        cont_samples = cont_samples * mask

    if self.args.model_type != 'ibp':
        disc_samples = torch.zeros(size, self.args.disc).cuda()
        if traverse:
            for i in range(size):
                disc_samples[i, i % self.args.disc] = 1.0
        else:
            disc_samples[:, disc_dim] = 1.0
        samples = torch.cat([cont_samples, disc_samples], dim=-1)
    else:
        samples = cont_samples

    with torch.no_grad():
        x = self.model.decoder(samples).view(-1, self.args.img_channel,
                                             self.args.img_size, self.args.img_size)
    if self.save_img:
        save_image(x.data, file_name, nrow=size, padding=0, pad_value=0.0)
    else:
        return make_grid(x.data, nrow=size, padding=0, pad_value=0.0)