Пример #1
0
class Trainer():
    '''Training harness for a generator/discriminator super-resolution GAN.

    Owns the generator ``G``, the discriminator ``D`` and a frozen feature
    extractor ``FE``, together with their Adam optimizers and MultiStepLR
    schedulers, and exposes one-step training, validation and checkpointing.

    The loss callables ``cri_pix``, ``cri_fea`` and ``cri_gan`` used by
    :meth:`train` are not defined by this class; they are expected to exist
    at module level.
    '''

    def __init__(self, opt):
        '''Build networks, optimizers and LR schedulers from ``opt``.

        Args:
            opt: nested options mapping. Keys read by this class:
                ``network_G``, ``network_D``, ``path.pretrain_G``,
                ``path.checkpoints.{models,states,val_image_dir}`` and,
                under ``train``: ``lr_G``/``b1_G``/``b2_G``,
                ``lr_D``/``b1_D``/``b2_D``, ``lr_steps``, ``lr_gamma``,
                ``wt_pix``, ``wt_fea``, ``wt_gan``.
        '''
        # Use CUDA when present; fall back to CPU so the trainer can still be
        # constructed (e.g. for debugging) on machines without a GPU.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.opt = opt

        self.G = Generator(self.opt['network_G']).to(self.device)
        util.init_weights(self.G, init_type='kaiming', scale=0.1)
        pretrain_path = self.opt['path']['pretrain_G']
        if pretrain_path:
            # map_location keeps the load working regardless of which device
            # the checkpoint was saved from.
            self.G.load_state_dict(
                torch.load(pretrain_path, map_location=self.device),
                strict=True)

        self.D = Discriminator(self.opt['network_D']).to(self.device)
        util.init_weights(self.D, init_type='kaiming', scale=1)

        self.FE = VGGFeatureExtractor().to(self.device)

        self.G.train()
        self.D.train()
        # The feature extractor is only used to compute a loss; it has no
        # optimizer and is kept in eval mode.
        self.FE.eval()

        self.log_dict = OrderedDict()

        # Only generator parameters that require gradients are optimized.
        self.optim_params = [
            v for k, v in self.G.named_parameters() if v.requires_grad
        ]
        train_opt = self.opt['train']
        self.opt_G = torch.optim.Adam(self.optim_params,
                                      lr=train_opt['lr_G'],
                                      betas=(train_opt['b1_G'],
                                             train_opt['b2_G']))
        self.opt_D = torch.optim.Adam(self.D.parameters(),
                                      lr=train_opt['lr_D'],
                                      betas=(train_opt['b1_D'],
                                             train_opt['b2_D']))

        # Order matters: index 0 is the generator, index 1 the discriminator.
        self.optimizers = [self.opt_G, self.opt_D]
        self.schedulers = [
            lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                     train_opt['lr_gamma'])
            for optimizer in self.optimizers
        ]

    def update_learning_rate(self):
        '''Advance every LR scheduler by one step.'''
        for scheduler in self.schedulers:
            scheduler.step()

    def get_current_log(self):
        '''Return the ordered dict of values logged by the last ``train``.'''
        return self.log_dict

    def get_current_learning_rate(self):
        '''Return the current learning rate of the first optimizer (G).

        The value is read from the optimizer's first param group rather than
        from ``scheduler.get_lr()``: on recent PyTorch versions ``get_lr()``
        is only meaningful inside ``step()`` and reports an extra decay when
        queried on a milestone epoch.
        '''
        return self.optimizers[0].param_groups[0]['lr']

    def load_model(self, step, strict=True):
        '''Load generator and discriminator weights saved at ``step``.

        Args:
            step: identifier used in the checkpoint file names
                (``{step}_G.pth`` and ``{step}_D.pth``).
            strict: forwarded to ``load_state_dict``.
        '''
        models_dir = self.opt['path']['checkpoints']['models']
        self.G.load_state_dict(
            torch.load(f"{models_dir}/{step}_G.pth",
                       map_location=self.device),
            strict=strict)
        self.D.load_state_dict(
            torch.load(f"{models_dir}/{step}_D.pth",
                       map_location=self.device),
            strict=strict)

    def resume_training(self, resume_state):
        '''Restore optimizer and scheduler state for resumed training.

        Args:
            resume_state: mapping with ``'optimizers'`` and ``'schedulers'``
                lists of state dicts, in the same order as
                ``self.optimizers`` / ``self.schedulers`` (the layout written
                by :meth:`save_training_state`).

        Raises:
            AssertionError: if either list has a different length from the
                corresponding list on this trainer.
        '''
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(
            self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(
            self.schedulers), 'Wrong lengths of schedulers'
        for optimizer, saved in zip(self.optimizers, resume_optimizers):
            optimizer.load_state_dict(saved)
        for scheduler, saved in zip(self.schedulers, resume_schedulers):
            scheduler.load_state_dict(saved)

    def save_network(self, network, network_label, iter_step):
        '''Save ``network``'s weights as ``{iter_step}_{network_label}.pth``.

        The file is written into the ``path.checkpoints.models`` directory,
        which is created through ``util.mkdir`` first. Tensors are moved to
        CPU before saving.
        '''
        models_dir = self.opt['path']['checkpoints']['models']
        util.mkdir(models_dir)
        save_filename = '{}_{}.pth'.format(iter_step, network_label)
        save_path = os.path.join(models_dir, save_filename)

        # Unwrap DataParallel so the saved keys carry no 'module.' prefix.
        if isinstance(network, nn.DataParallel):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)

    def save_model(self, epoch, current_step):
        '''Save G and D weights plus the optimizer/scheduler state.'''
        self.save_network(self.G, 'G', current_step)
        self.save_network(self.D, 'D', current_step)
        self.save_training_state(epoch, current_step)

    def save_training_state(self, epoch, iter_step):
        '''Save the state needed to resume training.

        Writes ``{iter_step}.state`` into the ``path.checkpoints.states``
        directory, containing the epoch, the iteration and the state dicts
        of all schedulers and optimizers.
        '''
        state = {
            'epoch': epoch,
            'iter': iter_step,
            'schedulers': [s.state_dict() for s in self.schedulers],
            'optimizers': [o.state_dict() for o in self.optimizers],
        }
        save_filename = '{}.state'.format(iter_step)
        states_dir = self.opt['path']['checkpoints']['states']
        util.mkdir(states_dir)
        save_path = os.path.join(states_dir, save_filename)
        torch.save(state, save_path)

    def train(self, train_batch, step):
        '''Run one generator update followed by one discriminator update.

        Args:
            train_batch: mapping with tensors under ``'LR'`` and ``'HR'``.
            step: current iteration; accepted for interface compatibility
                and not used by this method.

        Side effects:
            Sets ``self.lr``, ``self.hr`` and ``self.sr``, steps both
            optimizers and refreshes ``self.log_dict`` with Python floats.
        '''
        self.lr = train_batch['LR'].to(self.device)
        self.hr = train_batch['HR'].to(self.device)

        # ---- generator step ----
        # Freeze D so the generator's backward pass does not accumulate
        # gradients in the discriminator's parameters.
        for p in self.D.parameters():
            p.requires_grad = False

        self.opt_G.zero_grad()

        self.sr = self.G(self.lr)

        l_g_total = 0
        # pixel loss
        l_g_pix = self.opt['train']['wt_pix'] * cri_pix(self.sr, self.hr)
        l_g_total += l_g_pix

        # feature loss
        real_fea = self.FE(self.hr).detach()
        fake_fea = self.FE(self.sr)
        l_g_fea = self.opt['train']['wt_fea'] * cri_fea(fake_fea, real_fea)
        l_g_total += l_g_fea

        # ragan loss: each prediction is compared against the mean of the
        # opposite batch, with the real/fake targets swapped for G.
        pred_g_fake = self.D(self.sr)
        pred_d_real = self.D(self.hr).detach()

        l_g_gan = self.opt['train']['wt_gan'] * (
            cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
            cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
        l_g_total += l_g_gan

        l_g_total.backward()
        self.opt_G.step()

        # ---- discriminator step ----
        for p in self.D.parameters():
            p.requires_grad = True

        self.opt_D.zero_grad()

        pred_d_real = self.D(self.hr)
        pred_d_fake = self.D(self.sr.detach())  # detach to avoid BP to G

        l_d_real = cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
        l_d_fake = cri_gan(pred_d_fake - torch.mean(pred_d_real), False)

        l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.opt_D.step()

        # set log
        # G
        self.log_dict['l_g_pix'] = l_g_pix.item()
        self.log_dict['l_g_fea'] = l_g_fea.item()
        self.log_dict['l_g_gan'] = l_g_gan.item()

        # D
        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()

        # D outputs, stored as plain floats like every other entry so the
        # log never holds on to device tensors.
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach()).item()
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()).item()

    def validate(self, val_batch, current_step):
        '''Run the generator over a validation set and average PSNR/SSIM.

        For every item the first image of the batch is super-resolved,
        written as ``{name}_{current_step}.png`` under
        ``path.checkpoints.val_image_dir/{name}`` and scored against the
        ground truth after cropping a 4-pixel border.

        Args:
            val_batch: iterable of mappings with ``'LR'``, ``'HR'`` and
                ``'LR_path'`` entries.
            current_step: iteration number used in the saved file names.

        Returns:
            ``(avg_psnr, avg_ssim)``; ``(0.0, 0.0)`` when ``val_batch``
            yields no items.
        '''
        crop_size = 4
        avg_psnr = 0.0
        avg_ssim = 0.0
        idx = 0

        self.G.eval()
        try:
            for val_data in val_batch:
                idx += 1
                img_name = os.path.splitext(
                    os.path.basename(val_data['LR_path'][0]))[0]
                img_dir = os.path.join(
                    self.opt['path']['checkpoints']['val_image_dir'],
                    img_name)
                util.mkdir(img_dir)

                self.val_lr = val_data['LR'].to(self.device)
                self.val_hr = val_data['HR'].to(self.device)

                with torch.no_grad():
                    self.val_sr = self.G(self.val_lr)

                val_SR = self.val_sr.detach()[0].float().cpu()
                val_HR = self.val_hr.detach()[0].float().cpu()

                sr_img = util.tensor2img(val_SR)  # uint8
                gt_img = util.tensor2img(val_HR)  # uint8

                # Save SR images for reference
                save_img_path = os.path.join(
                    img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                cv2.imwrite(save_img_path, sr_img)

                # Metrics are computed on float images in the 0-255 range
                # (the /255 then *255 round trip performs the uint8 -> float
                # conversion), with the border cropped away.
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.
                cropped_sr_img = sr_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]
                cropped_gt_img = gt_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]
                avg_psnr += PSNR(cropped_sr_img * 255, cropped_gt_img * 255)
                avg_ssim += SSIM(cropped_sr_img * 255, cropped_gt_img * 255)
        finally:
            # Always hand the generator back in training mode, even if a
            # validation item raised.
            self.G.train()

        if idx == 0:
            # Nothing to average; avoid dividing by zero.
            return 0.0, 0.0

        avg_psnr = avg_psnr / idx
        avg_ssim = avg_ssim / idx
        return avg_psnr, avg_ssim
Пример #2
0
def main(args):
    '''Train a CartPole policy with PPO using discriminator-derived rewards.

    Each iteration rolls out one episode with the current policy, trains the
    discriminator to separate expert (state, action) pairs from the agent's,
    then uses the discriminator's output as the reward signal for a PPO
    update. Environment rewards are only used for logging and for the
    stopping criterion.

    Training stops early, saving the model, once 100 consecutive episodes
    reach an environment return of at least 195.

    Args:
        args: namespace providing ``gamma``, ``logdir``, ``iteration`` and
            ``savedir``.
    '''
    env = gym.make('CartPole-v0')
    env.seed(0)
    ob_space = env.observation_space
    Policy = Policy_net('policy', env)
    Old_Policy = Policy_net('old_policy', env)
    PPO = PPOTrain(Policy, Old_Policy, gamma=args.gamma)
    D = Discriminator(env)

    expert_observations = np.genfromtxt('trajectory/observations.csv')
    expert_actions = np.genfromtxt('trajectory/actions.csv', dtype=np.int32)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(args.logdir, sess.graph)
        # try/finally so the event file is flushed and closed even when
        # training is interrupted by an exception.
        try:
            sess.run(tf.global_variables_initializer())

            obs = env.reset()
            success_num = 0

            for iteration in range(args.iteration):
                observations = []
                actions = []
                # do NOT use rewards to update policy
                rewards = []
                v_preds = []
                run_policy_steps = 0
                while True:
                    run_policy_steps += 1
                    obs = np.stack([obs]).astype(dtype=np.float32)  # prepare to feed placeholder Policy.obs
                    act, v_pred = Policy.act(obs=obs, stochastic=True)

                    # .item() replaces np.asscalar, which was removed from
                    # NumPy; both return the single element as a Python
                    # scalar.
                    act = act.item()
                    v_pred = v_pred.item()
                    next_obs, reward, done, info = env.step(act)

                    observations.append(obs)
                    actions.append(act)
                    rewards.append(reward)
                    v_preds.append(v_pred)

                    if done:
                        next_obs = np.stack([next_obs]).astype(dtype=np.float32)  # prepare to feed placeholder Policy.obs
                        _, v_pred = Policy.act(obs=next_obs, stochastic=True)
                        # Value predictions shifted by one step, closed with
                        # the prediction for the final observation.
                        v_preds_next = v_preds[1:] + [v_pred.item()]
                        obs = env.reset()
                        break
                    else:
                        obs = next_obs

                episode_reward = sum(rewards)

                writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_length', simple_value=run_policy_steps)])
                                   , iteration)
                writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_reward', simple_value=episode_reward)])
                                   , iteration)

                if episode_reward >= 195:
                    success_num += 1
                    if success_num >= 100:
                        saver.save(sess, args.savedir + '/model.ckpt')
                        print('Clear!! Model saved.')
                        break
                else:
                    # The success streak must be consecutive.
                    success_num = 0

                # convert list to numpy array for feeding tf.placeholder
                observations = np.reshape(observations, newshape=[-1] + list(ob_space.shape))
                actions = np.array(actions).astype(dtype=np.int32)

                # train discriminator
                for i in range(2):
                    D.train(expert_s=expert_observations,
                            expert_a=expert_actions,
                            agent_s=observations,
                            agent_a=actions)

                # output of this discriminator is reward
                d_rewards = D.get_rewards(agent_s=observations, agent_a=actions)
                d_rewards = np.reshape(d_rewards, newshape=[-1]).astype(dtype=np.float32)

                gaes = PPO.get_gaes(rewards=d_rewards, v_preds=v_preds, v_preds_next=v_preds_next)
                gaes = np.array(gaes).astype(dtype=np.float32)
                # gaes = (gaes - gaes.mean()) / gaes.std()
                v_preds_next = np.array(v_preds_next).astype(dtype=np.float32)

                # train policy
                inp = [observations, actions, gaes, d_rewards, v_preds_next]
                PPO.assign_policy_parameters()
                for epoch in range(6):
                    sample_indices = np.random.randint(low=0, high=observations.shape[0],
                                                       size=32)  # indices are in [low, high)
                    sampled_inp = [np.take(a=a, indices=sample_indices, axis=0) for a in inp]  # sample training data
                    PPO.train(obs=sampled_inp[0],
                              actions=sampled_inp[1],
                              gaes=sampled_inp[2],
                              rewards=sampled_inp[3],
                              v_preds_next=sampled_inp[4])

                summary = PPO.get_summary(obs=inp[0],
                                          actions=inp[1],
                                          gaes=inp[2],
                                          rewards=inp[3],
                                          v_preds_next=inp[4])

                writer.add_summary(summary, iteration)
        finally:
            writer.close()