class GANAgent(object):
    """PPO actor-critic agent whose intrinsic reward comes from a
    GANomaly-style generator/discriminator pair (``netG`` / ``netD``).

    ``netG`` is assumed to be an encoder-decoder-encoder returning
    ``(reconstructed_obs, latent_i, latent_o)``; the squared distance
    between the two latent codes is the curiosity bonus.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 num_env,
                 num_step,
                 gamma,
                 lam=0.95,
                 learning_rate=1e-4,
                 ent_coef=0.01,
                 clip_grad_norm=0.5,
                 epoch=3,
                 batch_size=128,
                 ppo_eps=0.1,
                 update_proportion=0.25,
                 use_gae=True,
                 use_cuda=False,
                 use_noisy_net=False,
                 hidden_dim=512):
        self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        # NOTE(review): stored but never applied -- the grad-norm clipping
        # call in train_model is commented out; confirm whether it is needed.
        self.clip_grad_norm = clip_grad_norm
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.netG = NetG(z_dim=hidden_dim)
        self.netD = NetD(z_dim=1)
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)

        self.optimizer_policy = optim.Adam(
            list(self.model.parameters()), lr=learning_rate)
        self.optimizer_G = optim.Adam(
            list(self.netG.parameters()), lr=learning_rate, betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(
            list(self.netD.parameters()), lr=learning_rate, betas=(0.5, 0.999))

        self.netG = self.netG.to(self.device)
        self.netD = self.netD.to(self.device)
        self.model = self.model.to(self.device)

    def reconstruct(self, state):
        """Return netG's reconstruction of a single observation as numpy."""
        # BUG FIX: the original called self.vae, an attribute that is never
        # created on this class; the generator/autoencoder is self.netG and
        # its first return value is the reconstructed observation.
        state = torch.Tensor(state).to(self.device).float()
        reconstructed = self.netG(state.unsqueeze(0))[0].squeeze(0)
        return reconstructed.detach().cpu().numpy()

    def get_action(self, state):
        """Sample one action per env from the current policy.

        Returns (actions, ext_values, int_values, raw_policy_logits).
        """
        state = torch.Tensor(state).to(self.device).float()
        policy, value_ext, value_int = self.model(state)
        action_prob = F.softmax(policy, dim=-1).data.cpu().numpy()
        action = self.random_choice_prob_index(action_prob)
        return (action,
                value_ext.data.cpu().numpy().squeeze(),
                value_int.data.cpu().numpy().squeeze(),
                policy.detach())

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        """Vectorized categorical sampling: one draw per row of ``p``.

        ``p`` is a (rows, actions) array of probabilities; returns the index
        of the first cumulative-probability bin exceeding a uniform draw.
        """
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, obs):
        """Curiosity bonus: half the squared distance between netG's two
        latent codes (the reconstruction error itself is not used)."""
        obs = torch.FloatTensor(obs).to(self.device)
        reconstructed_img, embedding, reconstructed_embedding = self.netG(obs)
        intrinsic_reward = (embedding - reconstructed_embedding).pow(2).sum(1) / 2
        return intrinsic_reward.detach().cpu().numpy()

    def _update_gan(self, obs, reinit_d=False):
        """Run one GANomaly update (generator, then discriminator) on ``obs``.

        A random ``update_proportion`` subset of the images contributes to
        each loss. Returns the four scalar losses (adv, con, enc, d) as
        numpy values. When ``reinit_d`` is set and D's loss collapses,
        the discriminator weights are re-randomized.
        """
        l_adv = nn.MSELoss(reduction='none')   # feature-matching loss
        l_con = nn.L1Loss(reduction='none')    # reconstruction loss
        l_enc = nn.MSELoss(reduction='none')   # latent-consistency loss
        l_bce = nn.BCELoss(reduction='none')   # real/fake loss

        gen_obs, latent_i, latent_o = self.netG(obs)
        pred_real, feature_real = self.netD(obs)
        pred_fake, feature_fake = self.netD(gen_obs)

        # ---------------- generator update ----------------
        self.optimizer_G.zero_grad()
        err_g_adv_per_img = l_adv(
            self.netD(obs)[1], self.netD(gen_obs)[1]).mean(
                axis=list(range(1, len(feature_real.shape))))
        err_g_con_per_img = l_con(obs, gen_obs).mean(
            axis=list(range(1, len(gen_obs.shape))))
        err_g_enc_per_img = l_enc(latent_i, latent_o).mean(-1)

        # Only this random proportion of images drives the update
        # (same trick as RND's predictor-update proportion).
        img_num = len(err_g_con_per_img)
        mask = (torch.rand(img_num, device=self.device)
                < self.update_proportion).float()
        denom = torch.max(mask.sum(), torch.ones(1, device=self.device))
        mean_err_g_adv = (err_g_adv_per_img * mask).sum() / denom
        mean_err_g_con = (err_g_con_per_img * mask).sum() / denom
        mean_err_g_enc = (err_g_enc_per_img * mask).sum() / denom

        # Loss weights (GANomaly defaults: w_con dominates).
        w_adv, w_con, w_enc = 1, 50, 1
        mean_err_g = (mean_err_g_adv * w_adv
                      + mean_err_g_con * w_con
                      + mean_err_g_enc * w_enc)
        # retain_graph: pred_fake below still references the generator's
        # forward graph, which D's backward pass walks.
        mean_err_g.backward(retain_graph=True)
        self.optimizer_G.step()

        # ---------------- discriminator update ----------------
        self.optimizer_D.zero_grad()
        real_label = torch.ones_like(pred_real)
        fake_label = torch.zeros_like(pred_fake)
        err_d_real_per_img = l_bce(pred_real, real_label)
        err_d_fake_per_img = l_bce(pred_fake, fake_label)
        mean_err_d_real = (err_d_real_per_img * mask).sum() / denom
        mean_err_d_fake = (err_d_fake_per_img * mask).sum() / denom
        mean_err_d = (mean_err_d_real + mean_err_d_fake) / 2
        mean_err_d.backward()
        self.optimizer_D.step()

        # If D has collapsed, restart it from random weights (GANomaly trick).
        if reinit_d and mean_err_d.item() < 1e-5:
            self.netD.apply(weights_init)
            print('Reloading net d')

        return (mean_err_g_adv.detach().cpu().numpy(),
                mean_err_g_con.detach().cpu().numpy(),
                mean_err_g_enc.detach().cpu().numpy(),
                mean_err_d.detach().cpu().numpy())

    def train_model(self, s_batch, target_ext_batch, target_int_batch,
                    y_batch, adv_batch, next_obs_batch, old_policy):
        """Joint update: GANomaly curiosity module + PPO policy/value step.

        Returns four numpy arrays with the per-minibatch generator
        (adv / con / enc) and discriminator losses.
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))

        # Old-policy log-probs for the PPO ratio, fixed for all epochs.
        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1, self.output_size).to(self.device)
            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)

        g_adv_hist, g_con_hist, g_enc_hist, d_hist = [], [], [], []

        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:
                                          self.batch_size * (j + 1)]

                # ----- curiosity (GAN) update on next observations -----
                g_adv, g_con, g_enc, d_err = self._update_gan(
                    next_obs_batch[sample_idx], reinit_d=True)
                g_adv_hist.append(g_adv)
                g_con_hist.append(g_con)
                g_enc_hist.append(g_enc)
                d_hist.append(d_err)

                # ----- PPO policy/value update -----
                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])
                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio,
                                    1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]
                actor_loss = -torch.min(surr1, surr2).mean()

                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])
                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer_policy.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy
                loss.backward()
                # BUG FIX: original called self.optimizer_poilicy (typo),
                # which raised AttributeError before the policy ever stepped.
                self.optimizer_policy.step()

        return (np.array(g_adv_hist), np.array(g_con_hist),
                np.array(g_enc_hist), np.array(d_hist))

    def train_just_vae(self, s_batch, next_obs_batch):
        """Update only the curiosity (GAN) module -- no PPO step.

        Same loss bookkeeping as train_model's GAN half; the discriminator
        is never re-initialized here (matching the original behavior).
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        g_adv_hist, g_con_hist, g_enc_hist, d_hist = [], [], [], []

        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:
                                          self.batch_size * (j + 1)]
                g_adv, g_con, g_enc, d_err = self._update_gan(
                    next_obs_batch[sample_idx], reinit_d=False)
                g_adv_hist.append(g_adv)
                g_con_hist.append(g_con)
                g_enc_hist.append(g_enc)
                d_hist.append(d_err)

        return (np.array(g_adv_hist), np.array(g_con_hist),
                np.array(g_enc_hist), np.array(d_hist))
# NOTE(review): this span begins mid-expression -- the opening of the
# transforms.Compose([...]) list lies outside this chunk, so the first
# two lines are the tail of that call.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

# NOTE(review): `transform=transforms` passes the torchvision *module*
# itself, not the composed pipeline built above -- presumably the Compose
# result was meant to be bound to another name; confirm against the full file.
dataset = dset.ImageFolder(dataroot, transform=transforms)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

netG = NetG(ngf, nz).to(device)  # Generator network
netG.apply(weights_init)  # Generator weights initialization
netD = NetD(ndf).to(device)  # Discriminator network
netD.apply(weights_init)  # Discriminator weights initialization

# Print the models
print(netG)
print(netD)

criterion = nn.BCELoss()  # Initialize BCELoss function

# Create batch of latent vectors that we will use to visualize the
# progression of the generator
fixed_noise = torch.randn(100, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training:
# real -> 1, fake -> 0
real_label = 1
fake_label = 0

# Setup Adam optimizer for D
# NOTE(review): the matching optimizer for G is outside this chunk.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
# Normalization statistics for single-channel images.
mean = (0.5, )
std = (0.5, )

# Dataset/loader built from the CSV listed in the experiment CONFIG.
train_dataset = Dataset(csv_file=CONFIG.train_csv_file,
                        transform=ImageTransform(mean, std))
train_dataloader = DataLoader(train_dataset,
                              batch_size=CONFIG.batch_size,
                              shuffle=True)

# Generator, discriminator and encoder, all configured from CONFIG.
G = NetG(CONFIG)
D = NetD(CONFIG)
E = NetE(CONFIG)
G.apply(weights_init)
D.apply(weights_init)
E.apply(weights_init)

if not args.no_wandb:
    # Magic
    wandb.watch(G, log="all")
    wandb.watch(D, log="all")
    wandb.watch(E, log="all")

# NOTE(review): this call is truncated at the chunk boundary -- the
# remaining arguments and closing parenthesis are outside this view.
G_update, D_update, E_update = train(
    G, D, E,
    z_dim=CONFIG.z_dim,
    dataloader=train_dataloader,
    CONFIG=CONFIG,
def train_network():
    """Train the change-detection GAN (netg/netd) and validate each epoch.

    Per epoch: adversarial + L1 training on OSCD pairs, checkpointing via
    ``utils.save_weights``, then an F1 validation pass that promotes the
    checkpoint to BEST_WEIGHT_SAVE_DIR when it improves.

    Fixes vs. original:
    - the RESUME precondition tested current_netG.pth twice; it now also
      tests current_netD.pth, which is actually loaded below;
    - ``with net.eval() and torch.no_grad():`` only entered the no_grad()
      context (eval() ran merely as a side effect of evaluating the
      ``and``); both steps are now explicit.
    """
    init_epoch = 0
    best_f1 = 0
    total_steps = 0
    train_dir = ct.TRAIN_TXT
    val_dir = ct.VAL_TXT
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    train_data = OSCD_TRAIN(train_dir)
    train_dataloader = DataLoader(train_data, batch_size=ct.BATCH_SIZE,
                                  shuffle=True)
    val_data = OSCD_TEST(val_dir)
    val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)

    # Generator sees both images concatenated (NC * 2 channels);
    # discriminator judges single-channel change maps.
    netg = NetG(ct.ISIZE, ct.NC * 2, ct.NZ, ct.NDF,
                ct.EXTRALAYERS).to(device=device)
    netd = NetD(ct.ISIZE, ct.GT_C, 1, ct.NGF,
                ct.EXTRALAYERS).to(device=device)
    netg.apply(weights_init)
    netd.apply(weights_init)

    if ct.RESUME:
        # BUG FIX: the second exists() check also tested current_netG.pth,
        # so a missing discriminator checkpoint slipped past the assert.
        assert os.path.exists(os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netG.pth')) \
            and os.path.exists(os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netD.pth')), \
            'There is not found any saved weights'
        print("\nLoading pre-trained networks.")
        init_epoch = torch.load(
            os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netG.pth'))['epoch']
        netg.load_state_dict(
            torch.load(os.path.join(ct.WEIGHTS_SAVE_DIR,
                                    'current_netG.pth'))['model_state_dict'])
        netd.load_state_dict(
            torch.load(os.path.join(ct.WEIGHTS_SAVE_DIR,
                                    'current_netD.pth'))['model_state_dict'])
        # Resume the best F1 recorded so far (second-to-last line of the log).
        with open(os.path.join(ct.OUTPUTS_DIR, 'f1_score.txt')) as f:
            lines = f.readlines()
            best_f1 = float(lines[-2].strip().split(':')[-1])
        print("\tDone.\n")

    l_adv = l2_loss
    l_con = nn.L1Loss()
    l_enc = l2_loss
    l_bce = nn.BCELoss()
    l_cos = cos_loss
    dice = DiceLoss()
    optimizer_d = optim.Adam(netd.parameters(), lr=ct.LR, betas=(0.5, 0.999))
    optimizer_g = optim.Adam(netg.parameters(), lr=ct.LR, betas=(0.5, 0.999))

    start_time = time.time()
    for epoch in range(init_epoch + 1, ct.EPOCH):
        loss_g = []
        loss_d = []
        netg.train()
        netd.train()
        epoch_iter = 0
        for i, data in enumerate(train_dataloader):
            x1, x2, gt = data
            x1 = x1.to(device, dtype=torch.float)
            x2 = x2.to(device, dtype=torch.float)
            gt = gt.to(device, dtype=torch.float)
            gt = gt[:, 0, :, :].unsqueeze(1)  # keep one GT channel
            x = torch.cat((x1, x2), 1)        # pair -> 2*NC-channel input
            epoch_iter += ct.BATCH_SIZE
            total_steps += ct.BATCH_SIZE
            real_label = torch.ones(size=(x1.shape[0], ),
                                    dtype=torch.float32, device=device)
            fake_label = torch.zeros(size=(x1.shape[0], ),
                                     dtype=torch.float32, device=device)

            # ---- forward ----
            fake = netg(x)
            pred_real = netd(gt)
            # NOTE(review): pred_fake is detached, so the D_WEIGHT term in
            # err_g_total carries no gradient back into the generator --
            # presumably intentional (G trains on L1 only); confirm.
            pred_fake = netd(fake).detach()
            err_d_fake = l_bce(pred_fake, fake_label)
            err_g = l_con(fake, gt)
            err_g_total = ct.G_WEIGHT * err_g + ct.D_WEIGHT * err_d_fake
            pred_fake_ = netd(fake.detach())
            err_d_real = l_bce(pred_real, real_label)
            err_d_fake_ = l_bce(pred_fake_, fake_label)
            err_d_total = (err_d_real + err_d_fake_) * 0.5

            # ---- backward: G first (retain graph for D's pass), then D ----
            optimizer_g.zero_grad()
            err_g_total.backward(retain_graph=True)
            optimizer_g.step()
            optimizer_d.zero_grad()
            err_d_total.backward()
            optimizer_d.step()

            errors = utils.get_errors(err_d_total, err_g_total)
            loss_g.append(err_g_total.item())
            loss_d.append(err_d_total.item())

            counter_ratio = float(epoch_iter) / len(train_dataloader.dataset)
            if (i % ct.DISPOLAY_STEP == 0 and i > 0):
                print('epoch:', epoch, 'iteration:', i,
                      ' G|D loss is {}|{}'.format(np.mean(loss_g[-51:]),
                                                  np.mean(loss_d[-51:])))
                if ct.DISPLAY:
                    utils.plot_current_errors(epoch, counter_ratio, errors, vis)
                    utils.display_current_images(gt.data, fake.data, vis)
                utils.save_current_images(epoch, gt.data, fake.data,
                                          ct.IM_SAVE_DIR,
                                          'training_output_images')

        # NOTE(review): loss1..loss3 repeat the same two means -- kept
        # byte-identical to preserve the log format; looks like a leftover.
        with open(os.path.join(ct.OUTPUTS_DIR, 'train_loss.txt'), 'a') as f:
            f.write(
                'after %s epoch, loss is %g,loss1 is %g,loss2 is %g,loss3 is %g'
                % (epoch, np.mean(loss_g), np.mean(loss_d), np.mean(loss_g),
                   np.mean(loss_d)))
            f.write('\n')

        if not os.path.exists(ct.WEIGHTS_SAVE_DIR):
            os.makedirs(ct.WEIGHTS_SAVE_DIR)
        utils.save_weights(epoch, netg, optimizer_g, ct.WEIGHTS_SAVE_DIR, 'netG')
        utils.save_weights(epoch, netd, optimizer_d, ct.WEIGHTS_SAVE_DIR, 'netD')
        duration = time.time() - start_time
        print('training duration is %g' % duration)

        # ---- validation phase ----
        print('Validating.................')
        pretrained_dict = torch.load(
            os.path.join(ct.WEIGHTS_SAVE_DIR,
                         'current_netG.pth'))['model_state_dict']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = NetG(ct.ISIZE, ct.NC * 2, ct.NZ, ct.NDF,
                   ct.EXTRALAYERS).to(device=device)
        net.load_state_dict(pretrained_dict, False)
        # BUG FIX: `with net.eval() and torch.no_grad():` entered only the
        # no_grad() context manager; eval() is now an explicit statement.
        net.eval()
        with torch.no_grad():
            TP = 0
            FN = 0
            FP = 0
            TN = 0
            for k, data in enumerate(val_dataloader):
                x1, x2, label = data
                x1 = x1.to(device, dtype=torch.float)
                x2 = x2.to(device, dtype=torch.float)
                label = label.to(device, dtype=torch.float)
                label = label[:, 0, :, :].unsqueeze(1)
                x = torch.cat((x1, x2), 1)
                v_fake = net(x)
                tp, fp, tn, fn = eva.f1(v_fake, label)
                TP += tp
                FN += fn
                TN += tn
                FP += fp
            precision = TP / (TP + FP + 1e-8)
            oa = (TP + TN) / (TP + FN + TN + FP + 1e-8)  # overall accuracy (not logged)
            recall = TP / (TP + FN + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)

        if not os.path.exists(ct.BEST_WEIGHT_SAVE_DIR):
            os.makedirs(ct.BEST_WEIGHT_SAVE_DIR)
        if f1 > best_f1:
            best_f1 = f1
            shutil.copy(
                os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netG.pth'),
                os.path.join(ct.BEST_WEIGHT_SAVE_DIR, 'netG.pth'))
        print('current F1: {}'.format(f1))
        print('best f1: {}'.format(best_f1))
        with open(os.path.join(ct.OUTPUTS_DIR, 'f1_score.txt'), 'a') as f:
            f.write('current epoch:{},current f1:{},best f1:{}'.format(
                epoch, f1, best_f1))
            f.write('\n')