def test(self,
         session,
         step,
         summary_writer=None,
         print_rate=1,
         sample_dir=None,
         meta=None):
    feed_dict, original_sequence = self.get_feed_dict_and_orig(session)
    g_loss_pure, g_reg, d_loss_val, d_pen, rmse_temp, rmse_cc, rmse_sh, \
        rmse_sp, rmse_geo, summary = session.run(
            [
                self.g_cost_pure, self.gen_reg, self.d_cost, self.d_penalty,
                self.rmse_temp, self.rmse_cc, self.rmse_sh, self.rmse_sp,
                self.rmse_geo, self.summary_op
            ],
            feed_dict=feed_dict)
    summary_writer.add_summary(summary, step)

    original_sequence = original_sequence.reshape([
        1, self.frame_size, self.crop_size, self.crop_size, self.channels
    ])
    # print(original_sequence.shape)
    # images = zero state of weather
    # generate forecast from state zero
    forecast = session.run(self.sample, feed_dict=feed_dict)

    # return all RMSEs
    denorm_original_sequence = denormalize(original_sequence, self.wvars,
                                           self.crop_size, self.frame_size,
                                           self.channels, meta)
    denorm_forecast = denormalize(forecast, self.wvars, self.crop_size,
                                  self.frame_size, self.channels, meta)
    diff = []
    for orig, gen in zip(denorm_original_sequence, denorm_forecast):
        dif = orig - gen
        diff.append(dif[:, 1, :, :, :])

    if step % print_rate == 0:
        print(
            "Step: %d, generator loss: (%g + %g), discriminator_loss: (%g + %g)"
            % (step, g_loss_pure, g_reg, d_loss_val, d_pen))
        print("RMSE - Temp: %g, CC: %g, SH: %g, SP: %g, Geo: %g" %
              (rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo))
        print('saving original')
        save_image(denorm_original_sequence, sample_dir,
                   'init_%d_image' % step)
        print('saving forecast / fakes')
        save_image(denorm_forecast, sample_dir, 'gen_%d_future' % step)

    rmse_all = [rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo]
    costs = [g_loss_pure, g_reg, d_loss_val, d_pen]
    return rmse_all, costs, diff
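# The `save_image` helper used above is defined elsewhere in the project. A
# minimal sketch of what it might look like, assuming it receives a
# [batch, frames, height, width, channels] array and writes the first
# channel of each frame as a PNG; the layout, channel choice, and the name
# `save_image_sketch` are assumptions, not the original implementation.
import os

import matplotlib.pyplot as plt
import numpy as np


def save_image_sketch(sequences, sample_dir, prefix):
    """Write each frame of each sequence to sample_dir as a PNG."""
    os.makedirs(sample_dir, exist_ok=True)
    sequences = np.asarray(sequences)
    for b, sequence in enumerate(sequences):
        for f, frame in enumerate(sequence):
            path = os.path.join(sample_dir, '%s_b%d_f%d.png' % (prefix, b, f))
            # frame[..., 0] would be the first weather variable (e.g. temp)
            plt.imsave(path, frame[..., 0], cmap='viridis')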
def train(self,
          session,
          step,
          summary_writer=None,
          log_summary=False,
          sample_dir=None,
          generate_sample=False,
          meta=None):
    if log_summary:
        start_time = time.time()

    critic_itrs = self.critic_iterations
    for critic_itr in range(critic_itrs):
        session.run(self.d_adam, feed_dict=self.get_feed_dict(session))

    feed_dict = self.get_feed_dict(session)
    session.run(self.g_adam_gan, feed_dict=feed_dict)
    session.run(self.g_adam_first, feed_dict=feed_dict)

    if log_summary:
        g_loss_pure, g_reg, d_loss_val, d_pen, rmse_temp, rmse_cc, rmse_sh, \
            rmse_sp, rmse_geo, fake_min, fake_max, summary = session.run(
                [
                    self.g_cost_pure, self.gen_reg, self.d_cost,
                    self.d_penalty, self.rmse_temp, self.rmse_cc,
                    self.rmse_sh, self.rmse_sp, self.rmse_geo,
                    self.fake_min, self.fake_max, self.summary_op
                ],
                feed_dict=feed_dict)
        summary_writer.add_summary(summary, step)
        print(
            "Time: %g/itr, Step: %d, generator loss: (%g + %g), discriminator_loss: (%g + %g)"
            % (time.time() - start_time, step, g_loss_pure, g_reg,
               d_loss_val, d_pen))
        print("RMSE - Temp: %g, CC: %g, SH: %g, SP: %g, Geo: %g" %
              (rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo))
        print("Fake_vid min: %g, max: %g" % (fake_min, fake_max))

    if generate_sample:
        original_sequence = session.run(self.videos)
        original_sequence = original_sequence.reshape([
            self.batch_size, self.frame_size, self.crop_size,
            self.crop_size, self.channels
        ])
        print(original_sequence.shape)
        # images = zero state of weather
        images = original_sequence[:, 0, :, :, :]
        # generate forecast from state zero
        forecast = session.run(self.sample,
                               feed_dict={self.input_images: images})
        original_sequence = denormalize(original_sequence, self.wvars,
                                        self.crop_size, self.frame_size,
                                        self.channels, meta)
        print('saving original')
        save_image(original_sequence, sample_dir, 'init_%d_image' % step)
        forecast = denormalize(forecast, self.wvars, self.crop_size,
                               self.frame_size, self.channels, meta)
        print('saving forecast / fakes')
        save_image(forecast, sample_dir, 'gen_%d_future' % step)
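# The `denormalize` used by both `test` and `train` maps network-scaled
# fields back to physical units. A minimal sketch, assuming `meta` stores a
# per-variable (mean, std) pair keyed by the names in `wvars` and that the
# data was standardized per channel; the project's actual normalization
# scheme may differ.
import numpy as np


def denormalize_sketch(data, wvars, crop_size, frame_size, channels, meta):
    """Undo per-channel standardization on a [batch, frames, H, W, C] array."""
    data = np.asarray(data).reshape(
        [-1, frame_size, crop_size, crop_size, channels])
    out = np.empty_like(data)
    for c, var in enumerate(wvars):
        mean, std = meta[var]  # assumed layout of `meta`
        out[..., c] = data[..., c] * std + mean
    return out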
def train(self):
    device = self.device
    self.net.to(device)
    opts = {
        "title": 'train',
        "xlabel": 'times',
        "ylabel": 'loss',
        "legend": ['dloss', 'gloss']
    }
    count = 0
    for step in range(100):
        for img in self.data_loader:
            img = img.to(device)
            count += 1
            self.net.reset_grad()
            self.net.g_forward(img)
            self.net.d_forward()

            dloss = self.net.calc_d_loss()
            dloss.backward(retain_graph=True)
            gloss = self.net.calc_g_loss()
            gloss.backward()

            self.net.d_opt.step()
            self.net.D.frozen(False)
            self.net.g_opt.step()
            self.net.D.frozen(True)

            if not count % self.out_inv:
                self.vis.line(
                    X=[self.out_inv * count],
                    Y=[[dloss.detach().cpu(), gloss.detach().cpu()]],
                    update='append',
                    opts=opts,
                    win='training loss')
                self.vis.images([
                    denormalize(self.net.raw_img[0]),
                    denormalize(
                        random.choice(self.net.add_watermark_img_trans)[0])
                ], win='img')
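# `D.frozen(...)` is a project-specific toggle on the discriminator. A
# plausible sketch, assuming it simply flips requires_grad on D's
# parameters so gradients stop flowing into the frozen network; the real
# method and its exact semantics may differ.
import torch.nn as nn


class DiscriminatorBase(nn.Module):
    def frozen(self, freeze=True):
        """Enable or disable gradient tracking for all parameters."""
        for p in self.parameters():
            p.requires_grad = not freeze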
import pickle

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation


def main(plot_dir, epoch):
    # read in pickle files
    glimpses = pickle.load(open(plot_dir + "g_{}.p".format(epoch), "rb"))
    locations = pickle.load(open(plot_dir + "l_{}.p".format(epoch), "rb"))

    glimpses = np.concatenate(glimpses)

    # grab useful params
    size = int(plot_dir.split('_')[2][0])
    num_anims = len(locations)
    num_cols = glimpses.shape[0]
    img_shape = glimpses.shape[1]

    # denormalize coordinates
    coords = [denormalize(img_shape, l) for l in locations]

    fig, axs = plt.subplots(nrows=1, ncols=num_cols)
    # fig.set_dpi(100)

    # plot base image
    for j, ax in enumerate(axs.flat):
        ax.imshow(glimpses[j], cmap="Greys_r")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    def updateData(i):
        color = 'r'
        co = coords[i]
        for j, ax in enumerate(axs.flat):
            for p in ax.patches:
                p.remove()
            c = co[j]
            rect = bounding_box(c[0], c[1], size, color)
            ax.add_patch(rect)

    # animate
    anim = animation.FuncAnimation(fig,
                                   updateData,
                                   frames=num_anims,
                                   interval=500,
                                   repeat=True)

    # save as mp4
    name = plot_dir + 'epoch_{}.mp4'.format(epoch)
    anim.save(name, extra_args=['-vcodec', 'h264', '-pix_fmt', 'yuv420p'])
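# `bounding_box` builds the glimpse rectangle that `updateData` adds to each
# axis. A small sketch, assuming (x, y) is the box center in pixel
# coordinates and `size` is the glimpse side length; the original helper may
# anchor the rectangle differently.
from matplotlib import patches


def bounding_box_sketch(x, y, size, color='r'):
    """Return an unfilled square patch centered at (x, y)."""
    return patches.Rectangle((x - size / 2, y - size / 2),
                             size,
                             size,
                             linewidth=1,
                             edgecolor=color,
                             fill=False)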
def get_images(self, net_student=None, targets=None):
    print("get_images call")

    net_teacher = self.net_teacher
    use_fp16 = self.use_fp16
    save_every = self.save_every
    kl_loss = nn.KLDivLoss(reduction='batchmean').cuda()
    local_rank = torch.cuda.current_device()
    best_cost = 1e4
    criterion = self.criterion

    # setup target labels
    if targets is None:
        # only works for classification for now; other tasks need to
        # provide a target vector
        targets = torch.LongTensor(
            [random.randint(0, 999) for _ in range(self.bs)]).to('cuda')
        if not self.random_label:
            # preselected classes, good for ResNet50v1.5
            targets = [
                1, 933, 946, 980, 25, 63, 92, 94, 107, 985, 151, 154, 207,
                250, 270, 277, 283, 292, 294, 309, 311, 325, 340, 360, 386,
                402, 403, 409, 530, 440, 468, 417, 590, 670, 817, 762, 920,
                949, 963, 967, 574, 487
            ]
            targets = torch.LongTensor(
                targets * (int(self.bs / len(targets)))).to('cuda')

    img_original = self.image_resolution
    data_type = torch.half if use_fp16 else torch.float
    inputs = torch.randn((self.bs, 3, img_original, img_original),
                         requires_grad=True,
                         device='cuda',
                         dtype=data_type)
    pooling_function = nn.modules.pooling.AvgPool2d(kernel_size=2)

    if self.setting_id == 0:
        skipfirst = False
    else:
        skipfirst = True

    iteration = 0
    for lr_it, lower_res in enumerate([2, 1]):
        if lr_it == 0:
            iterations_per_layer = 2000
        else:
            iterations_per_layer = 1000 if not skipfirst else 2000
            if self.setting_id == 2:
                iterations_per_layer = 20000

        if lr_it == 0 and skipfirst:
            continue

        lim_0, lim_1 = self.jitter // lower_res, self.jitter // lower_res

        if self.setting_id == 0:
            # multi-resolution: 2k iterations at low resolution, 1k at
            # normal resolution; ResNet50v1.5 works best, ResNet50 is ok
            optimizer = optim.Adam([inputs],
                                   lr=self.lr,
                                   betas=[0.5, 0.9],
                                   eps=1e-8)
            do_clip = True
        elif self.setting_id == 1:
            # 2k iterations at normal resolution, for ResNet50v1.5;
            # ResNet50 works as well
            optimizer = optim.Adam([inputs],
                                   lr=self.lr,
                                   betas=[0.5, 0.9],
                                   eps=1e-8)
            do_clip = True
        elif self.setting_id == 2:
            # 20k iterations at normal resolution, closest to the paper
            # experiments for ResNet50
            optimizer = optim.Adam([inputs],
                                   lr=self.lr,
                                   betas=[0.9, 0.999],
                                   eps=1e-8)
            do_clip = False

        if use_fp16:
            static_loss_scale = "dynamic"
            _, optimizer = amp.initialize([],
                                          optimizer,
                                          opt_level="O2",
                                          loss_scale=static_loss_scale)

        lr_scheduler = lr_cosine_policy(self.lr, 100, iterations_per_layer)

        for iteration_loc in range(iterations_per_layer):
            iteration += 1
            # learning rate scheduling
            lr_scheduler(optimizer, iteration_loc, iteration_loc)

            # perform downsampling if needed
            if lower_res != 1:
                inputs_jit = pooling_function(inputs)
            else:
                inputs_jit = inputs

            # apply random jitter offsets
            off1 = random.randint(-lim_0, lim_0)
            off2 = random.randint(-lim_1, lim_1)
            inputs_jit = torch.roll(inputs_jit,
                                    shifts=(off1, off2),
                                    dims=(2, 3))

            # flipping
            flip = random.random() > 0.5
            if flip and self.do_flip:
                inputs_jit = torch.flip(inputs_jit, dims=(3, ))

            # forward pass
            optimizer.zero_grad()
            net_teacher.zero_grad()

            outputs = net_teacher(inputs_jit)
            outputs = self.network_output_function(outputs)

            # R_cross classification loss
            loss = criterion(outputs, targets)

            # R_prior losses
            loss_var_l1, loss_var_l2 = get_image_prior_losses(inputs_jit)

            # R_feature loss
            loss_r_feature = sum(
                [mod.r_feature for mod in self.loss_r_feature_layers])

            # R_ADI
            loss_verifier_cig = torch.zeros(1)
            if self.adi_scale != 0.0:
                if self.detach_student:
                    outputs_student = net_student(inputs_jit).detach()
                else:
                    outputs_student = net_student(inputs_jit)

                T = 3.0
                # Jensen-Shannon divergence:
                # another way to force KL between negative probabilities
                P = nn.functional.softmax(outputs_student / T, dim=1)
                Q = nn.functional.softmax(outputs / T, dim=1)
                M = 0.5 * (P + Q)

                P = torch.clamp(P, 0.01, 0.99)
                Q = torch.clamp(Q, 0.01, 0.99)
                M = torch.clamp(M, 0.01, 0.99)
                eps = 0.0
                loss_verifier_cig = 0.5 * kl_loss(
                    torch.log(P + eps), M) + 0.5 * kl_loss(
                        torch.log(Q + eps), M)
                # JS criterion: 0 means full correlation,
                # 1 means completely different
                loss_verifier_cig = 1.0 - torch.clamp(
                    loss_verifier_cig, 0.0, 1.0)

                if local_rank == 0:
                    if iteration % save_every == 0:
                        print('loss_verifier_cig', loss_verifier_cig.item())

            # l2 loss on images
            loss_l2 = torch.norm(inputs_jit.view(self.bs, -1), dim=1).mean()

            # combining losses
            loss_aux = self.var_scale_l2 * loss_var_l2 + \
                self.var_scale_l1 * loss_var_l1 + \
                self.bn_reg_scale * loss_r_feature + \
                self.l2_scale * loss_l2

            if self.adi_scale != 0.0:
                loss_aux += self.adi_scale * loss_verifier_cig

            loss = self.main_loss_multiplier * loss + loss_aux

            if local_rank == 0:
                if iteration % save_every == 0:
                    print("------------iteration {}----------".format(
                        iteration))
                    print("total loss", loss.item())
                    print("loss_r_feature", loss_r_feature.item())
                    print("main criterion",
                          criterion(outputs, targets).item())

                    if self.hook_for_display is not None:
                        self.hook_for_display(inputs, targets)

            # do image update
            if use_fp16:
                # optimizer.backward(loss)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            # clip color outliers
            if do_clip:
                inputs.data = clip(inputs.data, use_fp16=use_fp16)

            if best_cost > loss.item() or iteration == 1:
                best_inputs = inputs.data.clone()

            if iteration % save_every == 0 and (save_every > 0):
                if local_rank == 0:
                    vutils.save_image(
                        inputs,
                        '{}/best_images/output_{:05d}_gpu_{}.png'.format(
                            self.prefix, iteration // save_every,
                            local_rank),
                        normalize=True,
                        scale_each=True,
                        nrow=int(10))

    if self.store_best_images:
        best_inputs = denormalize(best_inputs)
        self.save_images(best_inputs, targets)

    # deallocate optimizer state to reduce memory consumption
    optimizer.state = collections.defaultdict(dict)
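# `get_image_prior_losses` supplies the R_prior terms above. A compact
# sketch of the usual total-variation formulation (L1 and L2 penalties on
# neighboring-pixel differences); the repository's exact variant may also
# penalize diagonal differences, so treat this as an approximation.
import torch


def get_image_prior_losses_sketch(inputs_jit):
    """Return (TV-L1, TV-L2) smoothness penalties for an NCHW batch."""
    diff_w = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:]
    diff_h = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :]
    loss_var_l1 = diff_w.abs().mean() + diff_h.abs().mean()
    loss_var_l2 = torch.norm(diff_w) + torch.norm(diff_h)
    return loss_var_l1, loss_var_l2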
import numpy as np
import pandas as pd


def over_sampling(X, T, Y, params_over):
    gen_lr = params_over['gen_lr']
    dis_lr = params_over['dis_lr']
    batch_size = params_over['batch_size']
    epochs = params_over['epochs']
    latent_dim = params_over['noise_size']
    major_multiple = params_over['major_multiple']
    minor_ratio = params_over['minor_ratio']
    loss_type = params_over['loss_type']

    seed = 1234
    max_loop = 10
    fake_multiple = 5

    init_tf(seed)

    train_df = pd.concat([X, T, Y], axis=1).copy()
    out_df = train_df.copy()

    n_samples = pd.Series(num_class(train_df, 'Y', 'T'))
    print('Initial samples:', n_samples.tolist())

    num_major = n_samples.max() * major_multiple
    idx_major = n_samples.argmax()
    num_minor = num_major * minor_ratio

    n_rest_samples = [num_minor] * len(n_samples)
    n_rest_samples[idx_major] = num_major
    n_rest_samples = pd.Series(n_rest_samples).round().astype('int32')
    n_rest_samples -= n_samples
    n_rest_samples[n_rest_samples < 0] = 0
    num_fake_data = n_rest_samples.sum() * fake_multiple
    print('Initial rest samples:', n_rest_samples.tolist())

    train_df, normalize_vars = normalize(train_df)
    data_dim = train_df.shape[1]
    generator, discriminator, combined = \
        build_gan_network(gen_lr, dis_lr, data_dim, latent_dim, loss_type)
    train(train_df, epochs, batch_size, latent_dim, generator,
          discriminator, combined)

    global stored_discriminator
    stored_discriminator = discriminator

    for _ in range(max_loop):
        if n_rest_samples.sum() == 0:
            break

        noise = np.random.normal(0, 1, (num_fake_data, latent_dim))
        gen_data = generator.predict(noise)
        gen_df = pd.DataFrame(gen_data, columns=train_df.columns)
        gen_df = denormalize(gen_df, normalize_vars)
        gen_df = gen_df.round()

        tr, tn, cr, cn = num_class(gen_df, 'Y', 'T')
        print('Generated data (tr, tn, cr, cn):', tr, tn, cr, cn)

        gen_df_list = split_class(gen_df, 'Y', 'T')
        for idx, df in enumerate(gen_df_list):
            # take no more samples from this class than are still needed
            n_sel_samples = min(df.shape[0], n_rest_samples[idx])
            n_rest_samples[idx] -= n_sel_samples
            sel_df = df.iloc[:n_sel_samples]
            out_df = pd.concat([out_df, sel_df])
        print('Rest samples:', n_rest_samples.tolist())

    out_df = out_df.reset_index(drop=True).sample(frac=1)
    X = out_df.drop(['T', 'Y'], axis=1)
    T = out_df['T']
    Y = out_df['Y']
    return X, T, Y
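# `num_class` and `split_class` are project helpers over the four
# treatment/outcome cells. A sketch assuming binary columns T (treatment)
# and Y (response), with the cell order treated-responders,
# treated-nonresponders, control-responders, control-nonresponders inferred
# from the `tr, tn, cr, cn` unpacking above; the actual helpers may differ.
def split_class_sketch(df, y_col, t_col):
    """Split df into the four (T, Y) cells in (tr, tn, cr, cn) order."""
    return [
        df[(df[t_col] == 1) & (df[y_col] == 1)],  # treated responders
        df[(df[t_col] == 1) & (df[y_col] == 0)],  # treated non-responders
        df[(df[t_col] == 0) & (df[y_col] == 1)],  # control responders
        df[(df[t_col] == 0) & (df[y_col] == 0)],  # control non-responders
    ]


def num_class_sketch(df, y_col, t_col):
    """Return the sample count of each (T, Y) cell."""
    return [len(part) for part in split_class_sketch(df, y_col, t_col)]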