def debug_z_by_group_matrix(t): fig, ax = plt.subplots() W_col_norms = torch.sqrt( torch.sum(torch.pow(generative_net.Ws[t].data, 2), dim=2)) ax.imshow(W_col_norms, aspect='equal') ax.set_xlabel('z') ax.set_ylabel('group') ax.xaxis.tick_top() ax.xaxis.set_label_position('top') lr = 1e-3 optimizer = torch.optim.Adam([ { 'params': inference_net.parameters(), 'lr': lr }, # {'params': [inference_net_log_stddev], 'lr': lr}, { 'params': generative_net.group_generators_parameters(), 'lr': lr }, { 'params': [ gen.sigma_net.extra_args[0] for gen in generative_net.group_generators ], 'lr': lr }
# Adam hyperparameters. The generic `lr` / `betas` (which re-assign the
# `lr = 1e-3` set earlier) are used only by the sigma_net parameter group
# below; the inference and generative groups get their own settings.
lr = 1e-4
betas = (0.9,0.999)
lr_inferencenet = 1e-4
betas_inferencenet = (0.9,0.999)
lr_generativenet = 1e-4
betas_generativenet = (0.9,0.999)
# One Adam optimizer with three parameter groups:
#   1. all inference-network parameters,
#   2. the generative net's group-generator parameters,
#   3. the first `extra_args` entry of each group generator's `sigma_net`
#      (NOTE(review): presumably a learned noise-scale parameter -- confirm).
optimizer = torch.optim.Adam([
    {'params': inference_net.parameters(), 'lr': lr_inferencenet, 'betas':betas_inferencenet},
    # {'params': [inference_net_log_stddev], 'lr': lr},
    {'params': generative_net.group_generators_parameters(), 'lr': lr_generativenet,'betas': betas_generativenet},
    {'params': [gen.sigma_net.extra_args[0] for gen in generative_net.group_generators], 'lr': lr, 'betas':betas}
])
# `generative_net.Ws` gets its own plain SGD optimizer (momentum 0) instead
# of being in the Adam groups above.
# NOTE(review): presumably so a separate (e.g. proximal / sparsity) update can
# be applied to Ws around each step -- confirm in the training loop.
Ws_lr = 1e-4
optimizer_Ws = torch.optim.SGD([
    {'params': [generative_net.Ws], 'lr': Ws_lr, 'momentum': 0}
])
vae = OIVAE(
    inference_model=inference_net,
    generative_model=generative_net,
    #prior_z=prior_z,