Пример #1
0
class GANAgent(object):
    """Actor-critic agent with a GAN-based curiosity module.

    The policy/value network (``self.model``) is optimised with a clipped
    surrogate objective (``ppo_eps``), while a generator ``netG`` and a
    discriminator ``netD`` are fitted on next observations.  The squared
    distance between the two latent codes returned by ``netG`` is exposed as
    an intrinsic reward via :meth:`compute_intrinsic_reward`.
    """

    # Weights of the adversarial, reconstruction and latent-consistency terms
    # in the generator loss.
    _W_ADV = 1
    _W_CON = 50
    _W_ENC = 1

    def __init__(self,
                 input_size,
                 output_size,
                 num_env,
                 num_step,
                 gamma,
                 lam=0.95,
                 learning_rate=1e-4,
                 ent_coef=0.01,
                 clip_grad_norm=0.5,
                 epoch=3,
                 batch_size=128,
                 ppo_eps=0.1,
                 update_proportion=0.25,
                 use_gae=True,
                 use_cuda=False,
                 use_noisy_net=False,
                 hidden_dim=512):
        """Build the policy network, the GAN and one Adam optimiser for each.

        Args:
            input_size: Forwarded to ``CnnActorCriticNetwork``.
            output_size: Number of policy logits (discrete actions).
            num_env, num_step, gamma, lam, use_gae: Stored on the instance;
                not read by any method of this class.
            learning_rate: Learning rate shared by all three optimisers.
            ent_coef: Weight of the entropy bonus in the policy loss.
            clip_grad_norm: Stored only; gradient clipping is not applied
                anywhere in this class.
            epoch: Number of passes over the batch per training call.
            batch_size: Mini-batch size used inside the training loops.
            ppo_eps: Clipping range of the probability ratio.
            update_proportion: Probability with which each sample of a
                mini-batch contributes to the GAN losses.
            use_cuda: Run on ``cuda`` when true, otherwise on ``cpu``.
            use_noisy_net: Forwarded to ``CnnActorCriticNetwork``.
            hidden_dim: Passed to ``NetG`` as ``z_dim``.
        """
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        self.clip_grad_norm = clip_grad_norm
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.model = CnnActorCriticNetwork(input_size, output_size,
                                           use_noisy_net)
        self.netG = NetG(z_dim=hidden_dim)
        self.netD = NetD(z_dim=1)
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)

        # Move the modules to the target device before handing their
        # parameters to the optimisers.
        self.model = self.model.to(self.device)
        self.netG = self.netG.to(self.device)
        self.netD = self.netD.to(self.device)

        self.optimizer_policy = optim.Adam(list(self.model.parameters()),
                                           lr=learning_rate)
        self.optimizer_G = optim.Adam(list(self.netG.parameters()),
                                      lr=learning_rate,
                                      betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(list(self.netD.parameters()),
                                      lr=learning_rate,
                                      betas=(0.5, 0.999))

    def reconstruct(self, state):
        """Return the generator's reconstruction of a single observation.

        ``state`` is given a leading batch axis, passed through ``netG`` and
        the first output (the reconstructed image) is returned as a NumPy
        array without the batch axis.
        """
        state = torch.Tensor(state).to(self.device).float()
        with torch.no_grad():
            # The generator is the reconstruction model of this agent; its
            # first output is the reconstructed observation.
            reconstructed = self.netG(state.unsqueeze(0))[0].squeeze(0)
        return reconstructed.cpu().numpy()

    def get_action(self, state):
        """Sample one action per row of ``state`` from the current policy.

        Returns:
            Tuple ``(action, value_ext, value_int, policy)`` where ``action``
            is an index array, the two values are squeezed NumPy arrays and
            ``policy`` is the detached logits tensor.
        """
        state = torch.Tensor(state).to(self.device).float()
        with torch.no_grad():
            policy, value_ext, value_int = self.model(state)
            action_prob = F.softmax(policy, dim=-1).cpu().numpy()

        action = self.random_choice_prob_index(action_prob)

        return (action,
                value_ext.cpu().numpy().squeeze(),
                value_int.cpu().numpy().squeeze(),
                policy.detach())

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        """Draw one categorical sample per distribution in the 2-D array ``p``.

        One uniform number is drawn per distribution and the first index whose
        cumulative probability exceeds it is returned (inverse-CDF sampling).
        """
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, obs):
        """Return half the squared latent distance for every observation.

        ``netG`` yields ``(reconstruction, embedding, reconstructed_embedding)``;
        only the two embeddings are used, the reconstruction error is not.
        """
        obs = torch.FloatTensor(obs).to(self.device)
        with torch.no_grad():
            _, embedding, reconstructed_embedding = self.netG(obs)
            intrinsic_reward = (embedding -
                                reconstructed_embedding).pow(2).sum(1) / 2

        return intrinsic_reward.cpu().numpy()

    @staticmethod
    def _per_sample_mean(values):
        """Average ``values`` over every axis except the leading batch axis."""
        return values.mean(dim=list(range(1, values.dim())))

    def _update_gan(self, real):
        """Run one generator step and one discriminator step on ``real``.

        Returns the detached scalar losses
        ``(err_g_adv, err_g_con, err_g_enc, err_d)``.
        """
        fake, latent_i, latent_o = self.netG(real)

        # Each sample takes part in the update with probability
        # ``update_proportion``; the same mask is shared by all loss terms.
        mask = (torch.rand(len(real)).to(self.device) <
                self.update_proportion).float()
        # Guard against an all-zero mask.
        denom = mask.sum().clamp(min=1.0)

        def masked_mean(per_sample):
            return (per_sample * mask).sum() / denom

        # ---------------- generator step ----------------
        self.optimizer_G.zero_grad()
        # Feature matching on the discriminator's intermediate features.
        err_g_adv = masked_mean(
            self._per_sample_mean(
                F.mse_loss(self.netD(real)[1],
                           self.netD(fake)[1],
                           reduction='none')))
        err_g_con = masked_mean(
            self._per_sample_mean(F.l1_loss(real, fake, reduction='none')))
        err_g_enc = masked_mean(
            F.mse_loss(latent_i, latent_o, reduction='none').mean(-1))

        err_g = (err_g_adv * self._W_ADV + err_g_con * self._W_CON +
                 err_g_enc * self._W_ENC)
        err_g.backward()
        self.optimizer_G.step()

        # ---------------- discriminator step ----------------
        # zero_grad also discards the gradients the generator loss left on
        # the discriminator parameters.
        self.optimizer_D.zero_grad()
        pred_real, _ = self.netD(real)
        # The fake batch is detached: the discriminator loss must not
        # back-propagate through netG, whose weights were just updated in
        # place by optimizer_G.step().
        pred_fake, _ = self.netD(fake.detach())

        err_d_real = masked_mean(
            F.binary_cross_entropy(pred_real,
                                   torch.ones_like(pred_real),
                                   reduction='none'))
        err_d_fake = masked_mean(
            F.binary_cross_entropy(pred_fake,
                                   torch.zeros_like(pred_fake),
                                   reduction='none'))
        err_d = (err_d_real + err_d_fake) / 2
        err_d.backward()
        self.optimizer_D.step()

        return (err_g_adv.detach(), err_g_con.detach(), err_g_enc.detach(),
                err_d.detach())

    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        """Jointly update the GAN and the policy for ``self.epoch`` epochs.

        Args:
            s_batch: States fed to the policy network.
            target_ext_batch: Regression targets for the extrinsic value head.
            target_int_batch: Regression targets for the intrinsic value head.
            y_batch: Actions that were taken (integer indices).
            adv_batch: Advantages used in the surrogate objective.
            next_obs_batch: Observations the GAN is fitted on.
            old_policy: Sequence of logits tensors recorded at collection
                time; stacked, axes 0 and 1 swapped, and flattened to
                ``(-1, output_size)``.

        Returns:
            Four float arrays with one entry per mini-batch: generator
            adversarial, reconstruction and encoder losses, and the
            discriminator loss.
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1,
                                           self.output_size).to(self.device)
            log_prob_old = Categorical(F.softmax(
                policy_old_list, dim=-1)).log_prob(y_batch)

        # One list per returned loss curve; appending to lists avoids the
        # repeated reallocations of np.append.
        history = ([], [], [], [])
        num_batches = len(s_batch) // self.batch_size

        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(num_batches):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # ---------------- curiosity (GAN) update ----------------
                errors = self._update_gan(next_obs_batch[sample_idx])
                for curve, err in zip(history, errors):
                    curve.append(err.item())

                # A collapsed discriminator loss triggers a re-initialisation
                # of the discriminator weights.
                if history[3][-1] < 1e-5:
                    self.netD.apply(weights_init)
                    print('Reloading net d')

                # ---------------- policy update ----------------
                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])
                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer_policy.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy
                loss.backward()
                # NOTE: gradient-norm clipping (clip_grad_norm) is not applied.
                self.optimizer_policy.step()

        return tuple(np.array(curve, dtype=np.float64) for curve in history)

    def train_just_vae(self, s_batch, next_obs_batch):
        """Update only the GAN (no policy step) for ``self.epoch`` epochs.

        ``s_batch`` is used solely for its length, which determines the index
        range sampled from ``next_obs_batch``.

        Returns:
            The same four per-mini-batch loss arrays as :meth:`train_model`.
        """
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        history = ([], [], [], [])
        num_batches = len(s_batch) // self.batch_size

        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(num_batches):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]
                errors = self._update_gan(next_obs_batch[sample_idx])
                for curve, err in zip(history, errors):
                    curve.append(err.item())

        return tuple(np.array(curve, dtype=np.float64) for curve in history)
Пример #2
0
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build the image dataset rooted at `dataroot`.
# NOTE(review): `transforms` is passed as the transform object; if that name
# is still the torchvision module rather than the pipeline built above, this
# call is wrong -- confirm the pipeline is what is bound here.
dataset = dset.ImageFolder(dataroot, transform=transforms)

# Shuffled mini-batches; the last incomplete batch of each epoch is dropped.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

netG = NetG(ngf, nz).to(device)  # Generator network
netG.apply(weights_init)  # Generator weights initialization
netD = NetD(ndf).to(device)  # Discriminator network
netD.apply(weights_init)  # Discriminator weights initialization
# Print the model architectures
print(netG)
print(netD)

criterion = nn.BCELoss()  # Initialize the binary cross-entropy loss function

# Create batch of latent vectors that we will use to visualize the progression of the generator
fixed_noise = torch.randn(100, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training: real is 1, fake is 0
real_label = 1
fake_label = 0
Пример #3
0
# Single-element normalisation statistics handed to ImageTransform
# (presumably for single-channel images -- TODO confirm against ImageTransform).
mean = (0.5, )
std = (0.5, )

# Training samples are listed in the CSV file named by the configuration.
train_dataset = Dataset(csv_file=CONFIG.train_csv_file,
                        transform=ImageTransform(mean, std))

train_dataloader = DataLoader(train_dataset,
                              batch_size=CONFIG.batch_size,
                              shuffle=True)

# Three networks built from the same configuration object.
G = NetG(CONFIG)
D = NetD(CONFIG)
E = NetE(CONFIG)

# Apply the custom weight initialisation to every sub-module.
G.apply(weights_init)
D.apply(weights_init)
E.apply(weights_init)

if not args.no_wandb:
    # Register each network with Weights & Biases; log="all" asks wandb to
    # track both gradients and parameters.
    wandb.watch(G, log="all")
    wandb.watch(D, log="all")
    wandb.watch(E, log="all")

G_update, D_update, E_update = train(
    G,
    D,
    E,
    z_dim=CONFIG.z_dim,
    dataloader=train_dataloader,
Пример #4
0
def train_network():
    """Train the generator/discriminator pair and validate after every epoch.

    All settings come from the ``ct`` constants module.  For each epoch the
    function trains on the ``OSCD_TRAIN`` data, saves the current images and
    weights, appends the mean losses to ``train_loss.txt``, then reloads the
    just-saved generator weights into a fresh ``NetG`` and evaluates it on
    the ``OSCD_TEST`` data.  The generator checkpoint with the best F1 score
    is copied to ``ct.BEST_WEIGHT_SAVE_DIR`` and every score is appended to
    ``f1_score.txt``.

    When ``ct.RESUME`` is set, the start epoch, both networks and the best
    F1 so far are restored from disk first.

    Raises:
        AssertionError: If resuming is requested but a checkpoint is missing.
    """
    init_epoch = 0
    best_f1 = 0
    train_dir = ct.TRAIN_TXT
    val_dir = ct.VAL_TXT
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    train_data = OSCD_TRAIN(train_dir)
    train_dataloader = DataLoader(train_data,
                                  batch_size=ct.BATCH_SIZE,
                                  shuffle=True)
    val_data = OSCD_TEST(val_dir)
    val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)
    netg = NetG(ct.ISIZE, ct.NC * 2, ct.NZ, ct.NDF,
                ct.EXTRALAYERS).to(device=device)
    netd = NetD(ct.ISIZE, ct.GT_C, 1, ct.NGF, ct.EXTRALAYERS).to(device=device)
    netg.apply(weights_init)
    netd.apply(weights_init)

    netg_path = os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netG.pth')
    netd_path = os.path.join(ct.WEIGHTS_SAVE_DIR, 'current_netD.pth')

    if ct.RESUME:
        # Both checkpoints are required (the discriminator one was previously
        # never checked).
        assert os.path.exists(netg_path) and os.path.exists(netd_path), \
                'There is not found any saved weights'
        print("\nLoading pre-trained networks.")
        # Load each checkpoint once; map_location lets a checkpoint saved on
        # another device type be restored on the current one.
        checkpoint_g = torch.load(netg_path, map_location=device)
        init_epoch = checkpoint_g['epoch']
        netg.load_state_dict(checkpoint_g['model_state_dict'])
        netd.load_state_dict(
            torch.load(netd_path, map_location=device)['model_state_dict'])
        with open(os.path.join(ct.OUTPUTS_DIR, 'f1_score.txt')) as f:
            lines = f.readlines()
            # NOTE(review): reads the second-to-last record, not the last
            # one -- confirm this is intended.
            best_f1 = float(lines[-2].strip().split(':')[-1])
        print("\tDone.\n")

    l_con = nn.L1Loss()
    l_bce = nn.BCELoss()
    optimizer_d = optim.Adam(netd.parameters(), lr=ct.LR, betas=(0.5, 0.999))
    optimizer_g = optim.Adam(netg.parameters(), lr=ct.LR, betas=(0.5, 0.999))

    start_time = time.time()
    for epoch in range(init_epoch + 1, ct.EPOCH):
        loss_g = []
        loss_d = []
        netg.train()
        netd.train()
        epoch_iter = 0
        for i, data in enumerate(train_dataloader):
            x1, x2, gt = data
            x1 = x1.to(device, dtype=torch.float)
            x2 = x2.to(device, dtype=torch.float)
            gt = gt.to(device, dtype=torch.float)
            # Keep only the first channel of the ground truth.
            gt = gt[:, 0, :, :].unsqueeze(1)
            # The generator sees both inputs concatenated along channels.
            x = torch.cat((x1, x2), 1)

            epoch_iter += ct.BATCH_SIZE
            real_label = torch.ones(size=(x1.shape[0], ),
                                    dtype=torch.float32,
                                    device=device)
            fake_label = torch.zeros(size=(x1.shape[0], ),
                                     dtype=torch.float32,
                                     device=device)

            # forward
            fake = netg(x)
            pred_real = netd(gt)
            # NOTE(review): pred_fake is detached and compared against
            # fake_label, so the D_WEIGHT term adds no gradient to the
            # generator and only shifts the logged value -- confirm intent.
            pred_fake = netd(fake).detach()
            err_d_fake = l_bce(pred_fake, fake_label)
            err_g = l_con(fake, gt)
            err_g_total = ct.G_WEIGHT * err_g + ct.D_WEIGHT * err_d_fake

            pred_fake_ = netd(fake.detach())
            err_d_real = l_bce(pred_real, real_label)
            err_d_fake_ = l_bce(pred_fake_, fake_label)
            err_d_total = (err_d_real + err_d_fake_) * 0.5

            # backward -- the generator and discriminator graphs share no
            # intermediate tensors, so the graph need not be retained.
            optimizer_g.zero_grad()
            err_g_total.backward()
            optimizer_g.step()
            optimizer_d.zero_grad()
            err_d_total.backward()
            optimizer_d.step()

            errors = utils.get_errors(err_d_total, err_g_total)
            loss_g.append(err_g_total.item())
            loss_d.append(err_d_total.item())

            counter_ratio = float(epoch_iter) / len(train_dataloader.dataset)
            if (i % ct.DISPOLAY_STEP == 0 and i > 0):
                print(
                    'epoch:', epoch, 'iteration:', i,
                    ' G|D loss is {}|{}'.format(np.mean(loss_g[-51:]),
                                                np.mean(loss_d[-51:])))
                if ct.DISPLAY:
                    utils.plot_current_errors(epoch, counter_ratio, errors,
                                              vis)
                    utils.display_current_images(gt.data, fake.data, vis)
        utils.save_current_images(epoch, gt.data, fake.data, ct.IM_SAVE_DIR,
                                  'training_output_images')

        with open(os.path.join(ct.OUTPUTS_DIR, 'train_loss.txt'), 'a') as f:
            # NOTE(review): the four fields are G, D, G, D mean losses; the
            # "loss2"/"loss3" labels repeat the first two values.
            f.write(
                'after %s epoch, loss is %g,loss1 is %g,loss2 is %g,loss3 is %g'
                % (epoch, np.mean(loss_g), np.mean(loss_d), np.mean(loss_g),
                   np.mean(loss_d)))
            f.write('\n')
        os.makedirs(ct.WEIGHTS_SAVE_DIR, exist_ok=True)
        utils.save_weights(epoch, netg, optimizer_g, ct.WEIGHTS_SAVE_DIR,
                           'netG')
        utils.save_weights(epoch, netd, optimizer_d, ct.WEIGHTS_SAVE_DIR,
                           'netD')
        duration = time.time() - start_time
        print('training duration is %g' % duration)

        # val phase
        print('Validating.................')
        pretrained_dict = torch.load(
            netg_path, map_location=device)['model_state_dict']
        net = NetG(ct.ISIZE, ct.NC * 2, ct.NZ, ct.NDF,
                   ct.EXTRALAYERS).to(device=device)
        net.load_state_dict(pretrained_dict, strict=False)

        # Switch to inference mode explicitly, then disable autograd; the
        # former `with net.eval() and torch.no_grad()` relied on the module
        # being truthy.
        net.eval()
        TP = 0
        FN = 0
        FP = 0
        TN = 0
        with torch.no_grad():
            for data in val_dataloader:
                x1, x2, label = data
                x1 = x1.to(device, dtype=torch.float)
                x2 = x2.to(device, dtype=torch.float)
                label = label.to(device, dtype=torch.float)
                label = label[:, 0, :, :].unsqueeze(1)
                x = torch.cat((x1, x2), 1)
                v_fake = net(x)

                tp, fp, tn, fn = eva.f1(v_fake, label)
                TP += tp
                FN += fn
                TN += tn
                FP += fp

        # The small epsilon keeps the ratios defined when a count is zero.
        precision = TP / (TP + FP + 1e-8)
        oa = (TP + TN) / (TP + FN + TN + FP + 1e-8)
        recall = TP / (TP + FN + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        os.makedirs(ct.BEST_WEIGHT_SAVE_DIR, exist_ok=True)
        if f1 > best_f1:
            best_f1 = f1
            shutil.copy(netg_path,
                        os.path.join(ct.BEST_WEIGHT_SAVE_DIR, 'netG.pth'))
        print('current F1: {}'.format(f1))
        print('best f1: {}'.format(best_f1))
        with open(os.path.join(ct.OUTPUTS_DIR, 'f1_score.txt'), 'a') as f:
            f.write('current epoch:{},current f1:{},best f1:{}'.format(
                epoch, f1, best_f1))
            f.write('\n')